Initial commit

This commit is contained in:
Ke Sun
2025-11-30 18:22:17 +08:00
commit aea2fe391e
449 changed files with 83030 additions and 0 deletions

View File

View File

@@ -0,0 +1,16 @@
"""
LangGraph Graph package for memory agent.
This package provides the LangGraph workflow orchestrator with modular
node implementations, routing logic, and state management.
Package structure:
- read_graph: Main graph factory for read operations
- write_graph: Main graph factory for write operations
- nodes: LangGraph node implementations
- routing: State routing logic
- state: State management utilities
"""
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
__all__ = ['make_read_graph']

View File

@@ -0,0 +1,10 @@
"""
LangGraph node implementations.
This module contains custom node implementations for the LangGraph workflow.
"""
from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode
from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message
__all__ = ["ToolExecutionNode", "create_input_message"]

View File

@@ -0,0 +1,144 @@
"""
Input node for LangGraph workflow entry point.
This module provides the create_input_message function which processes initial
user input with multimodal support and creates the first tool call message.
"""
import logging
import re
import uuid
from datetime import datetime
from typing import Dict, Any
from langchain_core.messages import AIMessage
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
logger = logging.getLogger(__name__)
async def create_input_message(
    state: Dict[str, Any],
    tool_name: str,
    session_id: str,
    search_switch: str,
    apply_id: str,
    group_id: str,
    multimodal_processor: MultimodalProcessor
) -> Dict[str, Any]:
    """
    Create the initial tool-call message from the latest user input.

    Steps performed:
    1. Take the content of the last message in ``state`` (an empty string is
       used, with a warning, when there are no messages).
    2. Pass it through ``multimodal_processor.process_input``; if that raises,
       the error is logged and the unprocessed content is kept.
    3. Generate a UUID for the message.
    4. Warn when ``session_id`` does not contain the ``_id_`` marker.
    5. For legacy ``verified_data`` payloads, replace the content with the
       first embedded ``"query"`` value (regex based, kept for compatibility).
    6. Wrap everything in an ``AIMessage`` with a single tool call.

    Args:
        state: LangGraph state dictionary containing ``messages``.
        tool_name: Name of the tool to invoke.
        session_id: Session identifier, expected to look like
            ``"call_id_{namespace}"``.
        search_switch: Search routing parameter, forwarded in the tool args.
        apply_id: Application identifier, forwarded in the tool args.
        group_id: Group identifier, forwarded in the tool args.
        multimodal_processor: Object exposing an awaitable ``process_input``.

    Returns:
        State update ``{"messages": [AIMessage(...)]}`` whose single tool call
        has the id ``"{session_id}_{uuid}"``.

    Examples:
        >>> state = {"messages": [HumanMessage(content="What is AI?")]}
        >>> result = await create_input_message(
        ...     state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor
        ... )
        >>> result["messages"][0].tool_calls[0]["name"]
        'Split_The_Problem'
    """
    messages = state.get("messages", [])

    # Extract last message content
    if messages:
        newest = messages[-1]
        last_message = newest.content if hasattr(newest, 'content') else str(newest)
    else:
        logger.warning("[create_input_message] No messages in state, using empty string")
        last_message = ""
    logger.debug(f"[create_input_message] Original input: {last_message[:100]}...")

    # Process multimodal input (images/audio); best effort by design
    try:
        processed_content = await multimodal_processor.process_input(last_message)
        if processed_content != last_message:
            logger.info(
                f"[create_input_message] Multimodal processing converted input "
                f"from {len(last_message)} to {len(processed_content)} chars"
            )
            last_message = processed_content
    except Exception as e:
        logger.error(
            f"[create_input_message] Multimodal processing failed: {e}",
            exc_info=True
        )
        # Continue with original content

    # Generate unique message ID
    message_uuid = uuid.uuid4()

    # The namespace itself is not needed here; only report a session_id that
    # does not follow the expected "call_id_{namespace}" shape.
    if '_id_' not in str(session_id):
        logger.warning(
            f"[create_input_message] Could not extract namespace from session_id: {session_id}"
        )

    # Handle verified_data extraction (backward compatibility)
    # This regex-based extraction is kept for compatibility with existing data formats
    if 'verified_data' in str(last_message):
        try:
            flattened = str(last_message).replace('\\n', '').replace('\\', '')
            query_match = re.findall(r'"query": "(.*?)",', flattened)
            if query_match:
                last_message = query_match[0]
                logger.debug(
                    f"[create_input_message] Extracted query from verified_data: {last_message}"
                )
        except Exception as e:
            logger.warning(
                f"[create_input_message] Failed to extract query from verified_data: {e}"
            )

    # Construct tool call message
    tool_call_id = f"{session_id}_{message_uuid}"
    logger.info(
        f"[create_input_message] Creating tool call for '{tool_name}' "
        f"with ID: {tool_call_id}"
    )
    return {
        "messages": [
            AIMessage(
                content="",
                tool_calls=[{
                    "name": tool_name,
                    "args": {
                        "sentence": last_message,
                        "sessionid": session_id,
                        "messages_id": str(message_uuid),
                        "search_switch": search_switch,
                        "apply_id": apply_id,
                        "group_id": group_id
                    },
                    "id": tool_call_id
                }]
            )
        ]
    }

View File

@@ -0,0 +1,199 @@
"""
Tool execution node for LangGraph workflow.
This module provides the ToolExecutionNode class which wraps tool execution
with parameter transformation logic using the ParameterBuilder service.
"""
import logging
import time
from typing import Any, Callable, Dict
from langchain_core.messages import AIMessage
from langgraph.prebuilt import ToolNode
from app.core.memory.agent.langgraph_graph.state.extractors import (
extract_tool_call_id,
extract_content_payload
)
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
logger = logging.getLogger(__name__)
class ToolExecutionNode:
    """
    LangGraph node that re-targets the previous message at a single tool.

    On each call the node reads the last message in the state, derives a tool
    call ID and a content payload from it, asks the ``ParameterBuilder`` for
    the arguments of its tool, and invokes the wrapped ``ToolNode`` with a
    freshly built ``AIMessage`` tool call.

    Attributes:
        tool_node: ``ToolNode`` wrapping the tool.
        id: Node identifier; prefixed to outgoing tool call IDs.
        tool_name: ``tool.name`` when present, otherwise ``str(tool)``.
        namespace: Namespace given at construction (stored only).
        search_switch: Search routing parameter passed to the builder.
        apply_id: Application identifier passed to the builder.
        group_id: Group identifier passed to the builder.
        parameter_builder: Service that builds the tool arguments.
        storage_type: Passed through to the builder unchanged.
        user_rag_memory_id: Passed through to the builder unchanged.
    """

    def __init__(
        self,
        tool: Callable,
        node_id: str,
        namespace: str,
        search_switch: str,
        apply_id: str,
        group_id: str,
        parameter_builder: ParameterBuilder,
        storage_type:str,
        user_rag_memory_id:str
    ):
        """
        Store the configuration and wrap ``tool`` in a ``ToolNode``.

        Args:
            tool: The tool to execute.
            node_id: Identifier for this node (used in tool call IDs).
            namespace: Namespace for session management.
            search_switch: Search routing parameter.
            apply_id: Application identifier.
            group_id: Group identifier.
            parameter_builder: Service for building tool-specific arguments.
            storage_type: Forwarded to ``parameter_builder.build_tool_args``.
            user_rag_memory_id: Forwarded to
                ``parameter_builder.build_tool_args``.
        """
        self.tool_node = ToolNode([tool])
        self.id = node_id
        if hasattr(tool, 'name'):
            self.tool_name = tool.name
        else:
            self.tool_name = str(tool)
        self.namespace = namespace
        self.search_switch = search_switch
        self.apply_id = apply_id
        self.group_id = group_id
        self.parameter_builder = parameter_builder
        self.storage_type = storage_type
        self.user_rag_memory_id = user_rag_memory_id
        logger.info(
            f"[ToolExecutionNode] Initialized node '{self.id}' for tool '{self.tool_name}'"
        )

    @staticmethod
    def _error_reply(text: str) -> Dict[str, Any]:
        """Wrap an error description in a one-message state update."""
        return {"messages": [AIMessage(content=text)]}

    async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run the wrapped tool against the last message of ``state``.

        Failure handling, in order:
        - no messages, no extractable tool call ID, or argument building
          failure -> an ``AIMessage`` describing the error is returned;
        - content extraction failure -> logged, and an empty dict is used as
          the content;
        - tool invocation failure -> a ``ToolMessage`` carrying the error and
          the outgoing tool call ID is returned.

        Args:
            state: LangGraph state dictionary.

        Returns:
            The ``ToolNode`` result on success, otherwise one of the error
            updates described above.
        """
        history = state.get("messages", [])
        logger.debug( self.tool_name)
        if not history:
            logger.warning(f"[ToolExecutionNode] {self.id} - No messages in state")
            return self._error_reply("Error: No messages in state")

        previous = history[-1]
        logger.debug(
            f"[ToolExecutionNode] {self.id} - Processing message at {time.time()}"
        )

        # Step 1: which call are we answering?
        try:
            tool_call_id = extract_tool_call_id(previous)
            logger.debug(f"[ToolExecutionNode] {self.id} - Extracted tool_call_id: {tool_call_id}")
        except ValueError as e:
            logger.error(
                f"[ToolExecutionNode] {self.id} - Failed to extract tool call ID: {e}"
            )
            return self._error_reply(f"Error: {str(e)}")

        # Step 2: payload of the previous step (falls back to an empty dict)
        try:
            content = extract_content_payload(previous)
            logger.debug(
                f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}"
            )
        except Exception as e:
            logger.error(
                f"[ToolExecutionNode] {self.id} - Failed to extract content: {e}",
                exc_info=True
            )
            content = {}

        # Step 3: tool-specific arguments
        try:
            tool_args = self.parameter_builder.build_tool_args(
                tool_name=self.tool_name,
                content=content,
                tool_call_id=tool_call_id,
                search_switch=self.search_switch,
                apply_id=self.apply_id,
                group_id=self.group_id,
                storage_type=self.storage_type,
                user_rag_memory_id=self.user_rag_memory_id
            )
            logger.debug(
                f"[ToolExecutionNode] {self.id} - Built tool args with keys: {list(tool_args.keys())}"
            )
        except Exception as e:
            logger.error(
                f"[ToolExecutionNode] {self.id} - Failed to build tool args: {e}",
                exc_info=True
            )
            return self._error_reply(f"Error building arguments: {str(e)}")

        # Step 4: synthesize the tool call and hand it to the ToolNode
        outgoing_id = f"{self.id}_{tool_call_id}"
        request = AIMessage(
            content="",
            tool_calls=[{
                "name": self.tool_name,
                "args": tool_args,
                "id": outgoing_id,
            }]
        )
        try:
            result = await self.tool_node.ainvoke({"messages": [request]})
            logger.debug(
                f"[ToolExecutionNode] {self.id} - Tool execution completed"
            )
            # The ToolNode result already has the {"messages": [...]} shape
            return result
        except Exception as e:
            logger.error(
                f"[ToolExecutionNode] {self.id} - Tool execution failed: {e}",
                exc_info=True
            )
            # Reply with a ToolMessage so the call/response pairing stays intact
            from langchain_core.messages import ToolMessage
            failure = ToolMessage(
                content=f"Error executing tool: {str(e)}",
                tool_call_id=outgoing_id
            )
            return {"messages": [failure]}

View File

@@ -0,0 +1,508 @@
import asyncio
import io
import json
import logging
import os
import re
import time
import uuid
import warnings
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Literal
from dotenv import load_dotenv
from langchain_core.messages import AIMessage
from langgraph.constants import START, END
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
from functools import partial
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
from langgraph.checkpoint.memory import InMemorySaver
from app.core.memory.agent.utils.redis_tool import store
from app.core.logging_config import get_agent_logger
# Import new modular components
from app.core.memory.agent.langgraph_graph.nodes import ToolExecutionNode, create_input_message
from app.core.memory.agent.langgraph_graph.routing.routers import (
Verify_continue,
Retrieve_continue,
Split_continue
)
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
# Module logger obtained from the project's agent logging configuration.
logger = get_agent_logger(__name__)
# Ignore every RuntimeWarning for the whole process, not just this module.
warnings.filterwarnings("ignore", category=RuntimeWarning)
# Load variables from a .env file (if any) before reading the settings below.
load_dotenv()
# Redis connection settings taken from the environment; each is None when the
# variable is unset.
# NOTE(review): none of these four names is referenced again in the visible
# part of this module - confirm whether they are still needed.
redishost=os.getenv("REDISHOST")
redisport=os.getenv('REDISPORT')
redisdb=os.getenv('REDISDB')
redispassword=os.getenv('REDISPASSWORD')
# Module-level counter used by Verify_continue below to count verification
# passes. NOTE(review): it is shared by every graph built in this process -
# confirm that concurrent sessions cannot interfere with each other.
counter = COUNTState(limit=3)
# Loop-count bookkeeping for the workflow.
async def update_loop_count(state):
    """Return a state update that bumps ``loop_count`` by one.

    A state without ``loop_count`` is treated as having a count of 0.
    """
    previous = state.get("loop_count", 0)
    return {"loop_count": previous + 1}
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
    """
    Choose the node that follows ``Verify``.

    NOTE: this definition replaces the ``Verify_continue`` imported from the
    routing package at the top of this file, so it is the one the graph uses.

    Routing:
    - ``"Summary"``: the last message reports ``"split_result": "success"``,
      the status cannot be determined, or there are no messages at all.
    - ``"content_input"``: verification failed and the running count is
      still below 2.
    - ``"Summary_fails"``: verification failed and the count reached 2.

    The module-level ``counter`` is reset whenever a terminal route
    (``Summary`` / ``Summary_fails``) is chosen.

    Args:
        state: LangGraph state containing ``messages``.

    Returns:
        Name of the next node; always one of the declared ``Literal`` values.
    """
    messages = state.get("messages", [])
    # Boundary check. Previously this returned END, which is not one of the
    # destinations declared in the return annotation; route to "Summary"
    # instead, as the routing package's version of this function does.
    # NOTE(review): LangGraph derives conditional-edge destinations from this
    # Literal annotation when no path map is given - confirm for the
    # installed version.
    if not messages:
        logger.warning("[Verify_continue] No messages in state, defaulting to Summary")
        counter.reset()
        return "Summary"

    counter.add(1)  # one more verification pass
    loop_count = counter.get_total()
    logger.debug(f"[should_continue] 当前循环次数: {loop_count}")

    last_message = messages[-1]
    last_message_str = str(last_message).replace('\\', '')
    status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str)
    logger.debug(f"Status tools: {status_tools}")

    if "success" in status_tools:
        counter.reset()
        return "Summary"
    elif "failed" in status_tools:
        # The count was incremented above, so only a count of 1 loops back
        # (assuming get_total() returns the running total).
        if loop_count < 2:
            return "content_input"
        else:
            counter.reset()
            return "Summary_fails"
    else:
        # Unknown status: fall back to Summary rather than returning None
        counter.reset()
        return "Summary"
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
    """
    Pick the node that follows ``Retrieve`` from the ``search_switch`` value.

    The value is read from ``state["search_switch"]``. When that is missing
    or ``None`` and the state has messages, the first dict-shaped tool call
    with an ``"args"`` entry on the last message supplies it instead.

    Args:
        state: State dictionary, optionally containing ``search_switch``
            and ``messages``.

    Returns:
        ``'Verify'`` when the value stringifies to ``'0'``; otherwise
        ``'Retrieve_Summary'`` (which is also the default).
    """
    switch = state.get("search_switch")

    # Fall back to the tool-call arguments of the most recent message
    if switch is None and "messages" in state:
        history = state.get("messages", [])
        pending_calls = getattr(history[-1], "tool_calls", None) if history else None
        first_with_args = next(
            (
                call for call in (pending_calls or ())
                if isinstance(call, dict) and "args" in call
            ),
            None,
        )
        if first_with_args is not None:
            switch = first_with_args["args"].get("search_switch")

    # Compare as text so that 0 and '0' are treated alike
    if switch is not None and str(switch) == '0':
        return 'Verify'
    return 'Retrieve_Summary'
def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]:
    """
    Pick the node that follows ``content_input`` from ``search_switch``.

    The value is read from ``state["search_switch"]``. When that is missing
    or ``None`` and the state has messages, the first dict-shaped tool call
    with an ``"args"`` entry on the last message supplies it instead.

    Args:
        state: State dictionary, optionally containing ``search_switch``
            and ``messages``.

    Returns:
        ``'Input_Summary'`` when the value stringifies to ``'2'``; otherwise
        ``'Split_The_Problem'``.
    """
    logger.debug(f"Split_continue state: {state}")
    switch = state.get("search_switch")

    # Fall back to the tool-call arguments of the most recent message
    if switch is None and "messages" in state:
        history = state.get("messages", [])
        pending_calls = getattr(history[-1], "tool_calls", None) if history else None
        first_with_args = next(
            (
                call for call in (pending_calls or ())
                if isinstance(call, dict) and "args" in call
            ),
            None,
        )
        if first_with_args is not None:
            switch = first_with_args["args"].get("search_switch")

    # Compare as text so that 2 and '2' are treated alike
    if switch is not None and str(switch) == '2':
        return 'Input_Summary'
    return 'Split_The_Problem'  # default route
# Legacy entry point; the work is done by create_input_message.
async def input_sentence(state, name, id, search_switch,apply_id,group_id):
    """
    Build the initial tool-call message for the tool called ``name``.

    The previous inline implementation referenced ``audio_extensions``,
    ``picture_model_requests`` and ``Vico_recognition``, none of which is
    imported or defined in this module, so it raised ``NameError`` on every
    call. It now delegates to ``create_input_message`` (imported at the top
    of this file), which produces the same message shape: one ``AIMessage``
    whose tool call has the args ``sentence``, ``sessionid``,
    ``messages_id``, ``search_switch``, ``apply_id`` and ``group_id``, and
    the id ``"{id}_{uuid}"``.

    NOTE(review): image/audio handling now goes through
    ``MultimodalProcessor.process_input`` - confirm it covers the cases the
    old helpers were meant to handle.

    Args:
        state: LangGraph state dictionary containing ``messages``.
        name: Name of the tool to invoke.
        id: Session identifier used as the prefix of the tool call id.
            (Shadows the builtin; kept for backward compatibility.)
        search_switch: Search routing parameter, forwarded in the tool args.
        apply_id: Application identifier, forwarded in the tool args.
        group_id: Group identifier, forwarded in the tool args.

    Returns:
        State update ``{"messages": [AIMessage(...)]}``.
    """
    return await create_input_message(
        state=state,
        tool_name=name,
        session_id=id,
        search_switch=search_switch,
        apply_id=apply_id,
        group_id=group_id,
        multimodal_processor=MultimodalProcessor(),
    )
class ProblemExtensionNode:
    """
    Node that feeds the previous message to a single tool.

    The tool call id and the content payload are recovered by running
    regular expressions over the string form of the last message, after
    which tool-specific arguments are assembled inline.

    NOTE(review): this class is not instantiated anywhere in the visible
    part of this file (make_read_graph builds its nodes with
    ToolExecutionNode) - confirm whether it is still used elsewhere.
    """
    def __init__(self, tool, id, namespace, search_switch, apply_id, group_id, storage_type="", user_rag_memory_id=""):
        """Wrap ``tool`` in a ToolNode and remember the routing identifiers."""
        self.tool_node = ToolNode([tool])
        self.id = id
        # Fall back to the tool's string form when it has no ``name``
        self.tool_name = tool.name if hasattr(tool, 'name') else str(tool)
        self.namespace = namespace
        self.search_switch = search_switch
        self.apply_id = apply_id
        self.group_id = group_id
        self.storage_type = storage_type
        self.user_rag_memory_id = user_rag_memory_id
    async def __call__(self, state):
        """
        Invoke the wrapped tool using data recovered from the last message.

        Args:
            state: Mapping with a ``messages`` list.

        Returns:
            ``{"messages": [AIMessage]}`` whose content is the string form of
            the ToolNode result.

        Raises:
            KeyError: if ``state`` has no ``messages`` key.
            IndexError: if the id regexes below find nothing in the message.
        """
        messages = state["messages"]
        last_message = messages[-1] if messages else ""
        logger.debug(f"ProblemExtensionNode {self.id} - 当前时间: {time.time()} - Message: {last_message}")
        if self.tool_name=='Input_Summary':
            # First "'id': '...'" value found in the message's string form
            tool_call =re.findall(f"'id': '(.*?)'",str(last_message))[0]
        # Otherwise: take the first quoted value after "tool_call_id=", drop
        # backslashes, and keep the piece after the first "_id" (up to the
        # next "_id", if any).
        else:tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1]
        # Earlier approach, kept for reference:
        # try:
        # content = json.loads(last_message.content) if hasattr(last_message, 'content') else last_message
        # except:
        # content = last_message.content if hasattr(last_message, 'content') else str(last_message)
        # Try to pull the actual content payload out of the previous tool's
        # result instead of using the string form of the whole object.
        raw_msg = last_message.content if hasattr(last_message, 'content') else str(last_message)
        extracted_payload = None
        # Capture a ToolMessage-style content='...' field (single or double
        # quotes) with a non-greedy match that stops at ", name=".
        m = re.search(r"content=(?:\"|\')(.*?)(?:\"|\'),\s*name=", raw_msg, flags=re.S)
        if m:
            extracted_payload = m.group(1)
        else:
            # Fallback: use the raw string as-is
            extracted_payload = raw_msg
        # Prefer parsing the payload as JSON
        try:
            content = json.loads(extracted_payload)
        except Exception:
            # Try to find a JSON-looking fragment in the text and parse that
            parsed = None
            candidates = re.findall(r"[\[{].*[\]}]", extracted_payload, flags=re.S)
            for cand in candidates:
                try:
                    parsed = json.loads(cand)
                    break
                except Exception:
                    continue
            # If that fails as well, keep the raw string as the content
            content = parsed if parsed is not None else extracted_payload
        # Build the arguments according to the tool name
        tool_args = {}
        if self.tool_name == "Verify":
            # Verify takes a dict context plus usermessages
            if isinstance(content, dict):
                tool_args["context"] = content
            else:
                tool_args["context"] = {"content": content}
            tool_args["usermessages"] = str(tool_call)
            tool_args["apply_id"] = str(self.apply_id)
            tool_args["group_id"] = str(self.group_id)
        elif self.tool_name == "Retrieve":
            # Retrieve takes a dict context, usermessages and search_switch
            if isinstance(content, dict):
                tool_args["context"] = content
            else:
                tool_args["context"] = {"content": content}
            tool_args["usermessages"] = str(tool_call)
            tool_args["search_switch"] = str(self.search_switch)
            tool_args["apply_id"] = str(self.apply_id)
            tool_args["group_id"] = str(self.group_id)
        elif self.tool_name == "Summary":
            # Summary takes its context as a string
            if isinstance(content, dict):
                # Serialize the dict to a JSON string
                tool_args["context"] = json.dumps(content, ensure_ascii=False)
            else:
                tool_args["context"] = str(content)
            tool_args["usermessages"] = str(tool_call)
            tool_args["apply_id"] = str(self.apply_id)
            tool_args["group_id"] = str(self.group_id)
        elif self.tool_name == "Summary_fails":
            # Summary_fails takes its context as a string, like Summary
            if isinstance(content, dict):
                # Serialize the dict to a JSON string
                tool_args["context"] = json.dumps(content, ensure_ascii=False)
            else:
                tool_args["context"] = str(content)
            tool_args["usermessages"] = str(tool_call)
            tool_args["apply_id"] = str(self.apply_id)
            tool_args["group_id"] = str(self.group_id)
        elif self.tool_name=='Input_Summary':
            # Input_Summary receives the whole last message as text, plus the
            # storage settings
            tool_args["context"] =str(last_message)
            tool_args["usermessages"] = str(tool_call)
            tool_args["search_switch"] = str(self.search_switch)
            tool_args["apply_id"] = str(self.apply_id)
            tool_args["group_id"] = str(self.group_id)
            tool_args["storage_type"] = getattr(self, 'storage_type', "")
            tool_args["user_rag_memory_id"] = getattr(self, 'user_rag_memory_id', "")
        elif self.tool_name=='Retrieve_Summary' :
            # Retrieve_Summary expects dict directly, not JSON string
            # content might be a JSON string, try to parse it
            if isinstance(content, str):
                try:
                    parsed_content = json.loads(content)
                    # Check if it has a "context" key
                    if isinstance(parsed_content, dict) and "context" in parsed_content:
                        tool_args["context"] = parsed_content["context"]
                    else:
                        tool_args["context"] = parsed_content
                except json.JSONDecodeError:
                    # If parsing fails, wrap the string
                    tool_args["context"] = {"content": content}
            elif isinstance(content, dict):
                # Check if content has a "context" key that needs unwrapping
                if "context" in content:
                    tool_args["context"] = content["context"]
                else:
                    tool_args["context"] = content
            else:
                tool_args["context"] = {"content": str(content)}
            tool_args["usermessages"] = str(tool_call)
            tool_args["apply_id"] = str(self.apply_id)
            tool_args["group_id"] = str(self.group_id)
        else:
            # Any other tool gets a dict context
            if isinstance(content, dict):
                tool_args["context"] = content
            else:
                tool_args["context"] = {"content": content}
            tool_args["usermessages"] = str(tool_call)
            tool_args["apply_id"] = str(self.apply_id)
            tool_args["group_id"] = str(self.group_id)
        tool_input = {
            "messages": [
                AIMessage(
                    content="",
                    tool_calls=[{
                        "name": self.tool_name,
                        "args": tool_args,
                        # Node id and extracted id are joined with no separator
                        "id": self.id + f"{tool_call}",
                    }]
                )
            ]
        }
        result = await self.tool_node.ainvoke(tool_input)
        # The whole ToolNode result is stringified into the reply content
        result_text = str(result)
        return {"messages": [AIMessage(content=result_text)]}
@asynccontextmanager
async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config_id=None,storage_type=None,user_rag_memory_id=None):
    """
    Build and yield the compiled read workflow graph.

    Graph layout:
    - ``START -> content_input``, then ``Split_continue`` routes to
      ``Split_The_Problem`` or ``Input_Summary``.
    - ``Split_The_Problem -> Problem_Extension -> Retrieve``, then
      ``Retrieve_continue`` routes to ``Verify`` or ``Retrieve_Summary``.
    - After ``Verify``, ``Verify_continue`` routes to ``Summary``,
      ``Summary_fails`` or back to ``content_input``.
    - ``Input_Summary``, ``Retrieve_Summary``, ``Summary`` and
      ``Summary_fails`` lead to ``END``.

    ``Split_The_Problem`` is a plain ``ToolNode``; every other tool is
    wrapped in a ``ToolExecutionNode`` whose ``node_id`` is
    ``"<tool name>_id"``. The graph is compiled with an ``InMemorySaver``
    checkpointer.

    Args:
        namespace: Namespace used for node configuration and session ids.
        tools: Re-iterable collection of tools, each exposing ``name``.
        search_switch: Search routing parameter.
        apply_id: Application identifier.
        group_id: Group identifier.
        config_id: Optional configuration id; only logged.
        storage_type: Passed through to every ``ToolExecutionNode``.
        user_rag_memory_id: Passed through to every ``ToolExecutionNode``.

    Yields:
        The compiled graph.
    """
    memory = InMemorySaver()
    tool_names = [t.name for t in tools]
    logger.info(f"Initializing read graph with tools: {tool_names}")
    if config_id:
        logger.info(f"使用配置 ID: {config_id}")

    # Index tools by name; the first tool registered under a name wins,
    # which matches a front-to-back linear search.
    tools_by_name = {}
    for candidate in tools:
        tools_by_name.setdefault(candidate.name, candidate)

    # Tools wrapped in ToolExecutionNode, in node registration order
    wrapped_names = (
        "Problem_Extension",
        "Retrieve",
        "Verify",
        "Summary",
        "Summary_fails",
        "Retrieve_Summary",
        "Input_Summary",
    )
    missing = [
        tool_name for tool_name in ("Split_The_Problem", *wrapped_names)
        if tool_name not in tools_by_name
    ]
    if missing:
        # Make the cause visible before node construction trips over None
        logger.warning(f"[make_read_graph] Tools not provided: {missing}")

    # Instantiate services
    parameter_builder = ParameterBuilder()
    multimodal_processor = MultimodalProcessor()

    def build_node(tool_name):
        """Create the ToolExecutionNode for ``tool_name`` with shared settings."""
        return ToolExecutionNode(
            tool=tools_by_name.get(tool_name),
            node_id=f"{tool_name}_id",
            namespace=namespace,
            search_switch=search_switch,
            apply_id=apply_id,
            group_id=group_id,
            parameter_builder=parameter_builder,
            storage_type=storage_type,
            user_rag_memory_id=user_rag_memory_id
        )

    split_node = ToolNode([tools_by_name.get("Split_The_Problem")])
    wrapped_nodes = {tool_name: build_node(tool_name) for tool_name in wrapped_names}

    async def content_input_node(state):
        """Turn the latest user input into the first tool call."""
        # Resolve the switch the way Split_continue does: the state value
        # wins unless it is missing/None, and the comparison is done on the
        # string form, so the chosen tool name always matches the route.
        effective_switch = state.get("search_switch")
        if effective_switch is None:
            effective_switch = search_switch
        wants_input_summary = str(effective_switch) == '2'
        tool_name = "Input_Summary" if wants_input_summary else "Split_The_Problem"
        session_prefix = "input_summary_call_id" if wants_input_summary else "split_call_id"
        return await create_input_message(
            state=state,
            tool_name=tool_name,
            session_id=f"{session_prefix}_{namespace}",
            search_switch=search_switch,
            apply_id=apply_id,
            group_id=group_id,
            multimodal_processor=multimodal_processor
        )

    # Build workflow graph
    workflow = StateGraph(ReadState)
    workflow.add_node("content_input", content_input_node)
    workflow.add_node("Split_The_Problem", split_node)
    for tool_name, node in wrapped_nodes.items():
        workflow.add_node(tool_name, node)

    # Wire the edges; conditional edges use the router functions
    workflow.add_edge(START, "content_input")
    workflow.add_conditional_edges("content_input", Split_continue)
    workflow.add_edge("Input_Summary", END)
    workflow.add_edge("Split_The_Problem", "Problem_Extension")
    workflow.add_edge("Problem_Extension", "Retrieve")
    workflow.add_conditional_edges("Retrieve", Retrieve_continue)
    workflow.add_edge("Retrieve_Summary", END)
    workflow.add_conditional_edges("Verify", Verify_continue)
    workflow.add_edge("Summary_fails", END)
    workflow.add_edge("Summary", END)

    graph = workflow.compile(checkpointer=memory)
    yield graph
# Add to the end of this file, or create a new execution script.
# Add the following function to memory_agent_service.py.

View File

@@ -0,0 +1,13 @@
"""LangGraph routing logic."""
from app.core.memory.agent.langgraph_graph.routing.routers import (
Verify_continue,
Retrieve_continue,
Split_continue,
)
__all__ = [
"Verify_continue",
"Retrieve_continue",
"Split_continue",
]

View File

@@ -0,0 +1,123 @@
"""
Routing functions for LangGraph conditional edges.
This module provides routing functions that determine the next node to execute
based on state values. All functions return Literal types for type safety.
"""
import logging
import re
from typing import Literal
from app.core.memory.agent.langgraph_graph.state.extractors import extract_search_switch
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
logger = logging.getLogger(__name__)
# Global counter for Verify routing
counter = COUNTState(limit=3)
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
    """
    Decide where to go after the Verify node.

    The last message is stringified, backslashes are removed, and every
    ``"split_result": "<value>"`` occurrence is collected. Routing:
    - ``Summary``: a ``success`` value was found, no recognizable status was
      found, or the state holds no messages.
    - ``content_input``: a ``failed`` value was found and the running
      counter (incremented on every call that has messages) is below 2.
    - ``Summary_fails``: a ``failed`` value was found and the counter has
      reached 2.

    The module-level ``counter`` is reset on every route except
    ``content_input``.

    Args:
        state: LangGraph state containing messages

    Returns:
        Next node name as Literal type
    """
    history = state.get("messages", [])

    # Nothing to inspect: clear the counter and summarize
    if not history:
        logger.warning("[Verify_continue] No messages in state, defaulting to Summary")
        counter.reset()
        return "Summary"

    # Count this verification pass
    counter.add(1)
    attempts = counter.get_total()
    logger.debug(f"[Verify_continue] Current loop count: {attempts}")

    # Collect the reported verification statuses from the last message
    flattened = str(history[-1]).replace('\\', '')
    status_tools = re.findall(r'"split_result": "(.*?)"', flattened)
    logger.debug(f"[Verify_continue] Status tools: {status_tools}")

    # Only a failure without any success can loop back or fail outright
    if "success" not in status_tools and "failed" in status_tools:
        if attempts < 2:
            return "content_input"
        counter.reset()
        return "Summary_fails"

    # Success, or an unclear status, ends in Summary
    counter.reset()
    return "Summary"
def Retrieve_continue(state: dict) -> Literal["Verify", "Retrieve_Summary"]:
    """
    Decide where to go after the Retrieve node.

    The decision is driven by the value returned by
    ``extract_search_switch(state)``:
    - ``'0'`` -> ``Verify``
    - ``'1'`` -> ``Retrieve_Summary``
    - anything else (including ``None``) -> ``Retrieve_Summary``, with a
      debug log noting the default

    Args:
        state: LangGraph state dictionary

    Returns:
        Next node name as Literal type
    """
    switch = extract_search_switch(state)
    logger.debug(f"[Retrieve_continue] search_switch: {switch}")

    if switch not in ('0', '1'):
        # Unknown or missing value: take the default route
        logger.debug("[Retrieve_continue] No valid search_switch, defaulting to Retrieve_Summary")
        return 'Retrieve_Summary'

    return 'Verify' if switch == '0' else 'Retrieve_Summary'
def Split_continue(state: dict) -> Literal["Split_The_Problem", "Input_Summary"]:
    """
    Decide where to go after the content_input node.

    The decision is driven by the value returned by
    ``extract_search_switch(state)``:
    - ``'2'`` -> ``Input_Summary``
    - anything else (including ``None``) -> ``Split_The_Problem``

    Args:
        state: LangGraph state dictionary

    Returns:
        Next node name as Literal type
    """
    logger.debug(f"[Split_continue] state keys: {state.keys()}")
    switch = extract_search_switch(state)
    logger.debug(f"[Split_continue] search_switch: {switch}")

    # Problem decomposition is the default path
    return 'Input_Summary' if switch == '2' else 'Split_The_Problem'

View File

@@ -0,0 +1,13 @@
"""LangGraph state management utilities."""
from app.core.memory.agent.langgraph_graph.state.extractors import (
extract_search_switch,
extract_tool_call_id,
extract_content_payload,
)
__all__ = [
"extract_search_switch",
"extract_tool_call_id",
"extract_content_payload",
]

View File

@@ -0,0 +1,164 @@
"""
State extraction utilities for type-safe access to LangGraph state values.
This module provides utility functions for extracting values from LangGraph state
dictionaries with proper error handling and sensible defaults.
"""
import json
import logging
from typing import Any, Optional
logger = logging.getLogger(__name__)
def extract_search_switch(state: dict) -> Optional[str]:
    """
    Find the ``search_switch`` value in the state or in its messages.

    Lookup order:
    1. ``state["search_switch"]`` when it is not ``None``.
    2. Otherwise the messages, newest first. For each message:
       a. every dict-shaped entry of ``tool_calls`` is checked, first in its
          ``"args"`` dict and then at its top level;
       b. a string ``content`` that parses as a JSON object is checked for a
          ``"search_switch"`` key.

    Args:
        state: LangGraph state dictionary.

    Returns:
        The first value found, converted to ``str``; ``None`` when nothing
        is found.
    """
    direct = state.get("search_switch")
    if direct is not None:
        return str(direct)

    # Walk the history from the newest message backwards
    for message in reversed(state.get("messages") or []):
        # (a) structured tool calls
        for call in getattr(message, "tool_calls", None) or ():
            if not isinstance(call, dict):
                continue
            call_args = call.get("args")
            if isinstance(call_args, dict):
                found = call_args.get("search_switch")
                if found is not None:
                    return str(found)
            found = call.get("search_switch")
            if found is not None:
                return str(found)

        # (b) JSON-encoded content
        if not hasattr(message, "content"):
            continue
        body = message.content
        if not isinstance(body, str):
            continue
        try:
            decoded = json.loads(body)
        except (json.JSONDecodeError, ValueError):
            continue
        if isinstance(decoded, dict):
            found = decoded.get("search_switch")
            if found is not None:
                return str(found)

    return None
def extract_tool_call_id(message: Any) -> str:
    """
    Return the tool call ID carried by a message.

    Sources are tried in this order, and the first usable one wins:
    1. a truthy ``tool_call_id`` attribute;
    2. the ``"id"`` entry of the first element of a non-empty ``tool_calls``
       list, when that element is a dict;
    3. a truthy ``id`` attribute.

    Args:
        message: Message object (typically ToolMessage or AIMessage)

    Returns:
        Tool call ID as string

    Raises:
        ValueError: If none of the sources yields an ID

    Examples:
        >>> message = ToolMessage(content="...", tool_call_id="call_123")
        >>> extract_tool_call_id(message)
        'call_123'
    """
    # 1. Tool responses carry the ID directly
    direct = getattr(message, "tool_call_id", None)
    if direct:
        return str(direct)

    # 2. Tool requests carry it on their first tool call
    calls = getattr(message, "tool_calls", None)
    if calls:
        head = calls[0]
        if isinstance(head, dict) and "id" in head:
            return str(head["id"])

    # 3. Last resort: the message's own ID
    fallback = getattr(message, "id", None)
    if fallback:
        return str(fallback)

    raise ValueError(f"Could not extract tool call ID from message: {type(message)}")
def extract_content_payload(message: Any) -> Any:
    """
    Return the payload of a message, decoding JSON text when possible.

    Resolution steps:

    * A message without a ``content`` attribute is replaced by ``str(message)``.
    * A message whose content is empty but which carries ``tool_calls`` yields
      the ``"args"`` of the first tool call (when that call is a dict).
    * Non-string content is returned untouched.
    * String content is parsed as JSON; if that fails, the widest
      bracket/brace-delimited substring is tried; if that fails too, the
      original string is returned.

    Args:
        message: Message object (typically ToolMessage)

    Returns:
        The decoded JSON value, or the raw content when nothing could be decoded

    Examples:
        >>> message = ToolMessage(content='{"key": "value"}')
        >>> extract_content_payload(message)
        {'key': 'value'}
        >>> message = ToolMessage(content='plain text')
        >>> extract_content_payload(message)
        'plain text'
    """
    import re

    if not hasattr(message, "content"):
        payload = str(message)
    else:
        payload = message.content
        # Empty content plus tool calls: the arguments of the first call are the payload
        if not payload and hasattr(message, "tool_calls") and message.tool_calls:
            first_call = message.tool_calls[0]
            if isinstance(first_call, dict) and "args" in first_call:
                return first_call["args"]

    # Structured (or otherwise non-text) content needs no decoding
    if not isinstance(payload, str):
        return payload

    def _decode(text):
        """Return (True, value) when *text* is valid JSON, else (False, None)."""
        try:
            return True, json.loads(text)
        except (json.JSONDecodeError, ValueError):
            return False, None

    # Attempt 1: the whole string is JSON
    decoded_ok, value = _decode(payload)
    if decoded_ok:
        return value

    # Attempt 2: JSON embedded inside a larger string
    for fragment in re.findall(r'[\[{].*[\]}]', payload, flags=re.DOTALL):
        decoded_ok, value = _decode(fragment)
        if decoded_ok:
            return value

    # Nothing decodable: hand back the text unchanged
    return payload

View File

@@ -0,0 +1,78 @@
import asyncio
import json
from contextlib import asynccontextmanager
from langgraph.constants import START, END
from langgraph.graph import add_messages, StateGraph
from langgraph.prebuilt import ToolNode
from app.core.memory.agent.utils.llm_tools import WriteState
import warnings
import sys
from langchain_core.messages import AIMessage
from app.core.logging_config import get_agent_logger
warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__)
if sys.platform.startswith("win"):
import asyncio
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
@asynccontextmanager
async def make_write_graph(user_id, tools, apply_id, group_id, config_id=None):
    """
    Build and yield the compiled LangGraph workflow for write operations.

    The workflow is START -> content_input -> save_neo4j -> END. The
    ``content_input`` node calls the ``Data_type_differentiation`` tool on the
    last message, then calls ``Data_write`` with the resulting context and
    returns the write result as an AIMessage. ``save_neo4j`` is a ToolNode
    wrapping ``Data_write``.

    Args:
        user_id: User identifier forwarded to Data_write
        tools: Available tools; each must expose ``name`` and ``ainvoke``
        apply_id: Application identifier forwarded to Data_write
        group_id: Group identifier forwarded to Data_write
        config_id: Optional configuration ID; forwarded to Data_write only when truthy

    Yields:
        The compiled graph

    Raises:
        ValueError: If either required tool is missing from ``tools``
    """
    logger.info("加载 MCP 工具: %s", [t.name for t in tools])
    if config_id:
        logger.info(f"使用配置 ID: {config_id}")

    data_type_tool = next((t for t in tools if t.name == "Data_type_differentiation"), None)
    data_write_tool = next((t for t in tools if t.name == "Data_write"), None)
    if not data_type_tool or not data_write_tool:
        # No exception is active here, so exc_info=True would only append a
        # meaningless "NoneType: None" traceback to the record.
        logger.error('不存在数据存储工具')
        raise ValueError('不存在数据存储工具')

    # ToolNode wrapping the write tool
    write_node = ToolNode([data_write_tool])

    async def call_model(state):
        """Classify the last message, write it, and return the write result."""
        messages = state["messages"]
        last_message = messages[-1]
        # Messages may be (role, text) tuples or message objects
        result = await data_type_tool.ainvoke({
            "context": last_message[1] if isinstance(last_message, tuple) else last_message.content
        })
        result = json.loads(result)

        # Call Data_write, passing config_id along when available
        write_params = {
            "content": result["context"],
            "apply_id": apply_id,
            "group_id": group_id,
            "user_id": user_id
        }
        # Add config_id to the parameters only if one was provided
        if config_id:
            write_params["config_id"] = config_id
            logger.debug(f"传递 config_id 到 Data_write: {config_id}")

        write_result = await data_write_tool.ainvoke(write_params)
        if isinstance(write_result, dict):
            content = write_result.get("data", str(write_result))
        else:
            content = str(write_result)
        logger.info("写入内容: %s", content)
        return {"messages": [AIMessage(content=content)]}

    workflow = StateGraph(WriteState)
    workflow.add_node("content_input", call_model)
    workflow.add_node("save_neo4j", write_node)
    workflow.add_edge(START, "content_input")
    workflow.add_edge("content_input", "save_neo4j")
    workflow.add_edge("save_neo4j", END)

    graph = workflow.compile()
    yield graph

View File

@@ -0,0 +1,285 @@
"""
Log Streamer Module
Manages streaming of log file content with file watching and real-time transmission.
"""
import os
import re
import time
import asyncio
from typing import AsyncGenerator, Optional
from pathlib import Path
from app.core.logging_config import get_logger
logger = get_logger(__name__)
class LogStreamer:
    """Manages log file streaming with file watching and content transmission.

    Both public streaming methods are async generators that yield event dicts:

    - log events: {"event": "log", "data": {"content": "...", "timestamp": ...}}
    - keepalive events: {"event": "keepalive", "data": {"timestamp": ...}}
    - error events: {"event": "error", "data": {"code": ..., "message": "...", "error": "..."}}
    - done events: {"event": "done", "data": {"message": "..."}}
    """

    # Matches the prefix "YYYY-MM-DD HH:MM:SS,mmm - [LEVEL] - module_name - ".
    # Compiled once for the class instead of on every cleaned line.
    _PREFIX_PATTERN = re.compile(
        r'^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3} - \[(?:INFO|DEBUG|WARNING|ERROR|CRITICAL)\] - \S+ - '
    )

    def __init__(self, log_path: str, keepalive_interval: int = 300):
        """
        Initialize LogStreamer

        Args:
            log_path: Path to the log file to stream
            keepalive_interval: Interval in seconds for sending keepalive messages (default: 300)
        """
        self.log_path = log_path
        self.keepalive_interval = keepalive_interval
        self.last_position = 0
        # Kept as an instance attribute for callers that read it directly
        self.pattern = self._PREFIX_PATTERN
        logger.info(f"LogStreamer initialized for {log_path}")

    @staticmethod
    def clean_log_line(line: str) -> str:
        """
        Remove the timestamp / log level / module prefix from a log line.

        This is the canonical log cleaning method used by both file mode and
        transmission mode.

        Args:
            line: Raw log line

        Returns:
            The line without the prefix; unchanged when no prefix matches
        """
        return LogStreamer._PREFIX_PATTERN.sub('', line)

    def clean_log_entry(self, line: str) -> str:
        """
        Instance-level alias of :meth:`clean_log_line`.

        Args:
            line: Raw log line

        Returns:
            Cleaned log line without timestamp and log level prefix
        """
        return LogStreamer.clean_log_line(line)

    async def send_keepalive(self) -> dict:
        """
        Generate keepalive message

        Returns:
            Keepalive message dict with the current Unix timestamp
        """
        return {
            "event": "keepalive",
            "data": {
                "timestamp": int(time.time())
            }
        }

    def _log_event(self, line: str) -> dict:
        """Build the log event for one raw line (prefix and trailing newline removed)."""
        return {
            "event": "log",
            "data": {
                "content": self.clean_log_entry(line).rstrip('\n'),
                "timestamp": int(time.time())
            }
        }

    async def _stream(self, include_existing: bool) -> AsyncGenerator[dict, None]:
        """Shared implementation: optionally replay existing lines, then follow the file."""
        mode = "read existing" if include_existing else "new content only"
        logger.info(f"Starting log stream ({mode}) for {self.log_path}")

        # A missing file produces a single error event and no done event
        if not os.path.exists(self.log_path):
            logger.error(f"Log file not found: {self.log_path}")
            yield {
                "event": "error",
                "data": {
                    "code": 4006,
                    "message": "日志文件不存在",
                    "error": f"File not found: {self.log_path}"
                }
            }
            return

        try:
            with open(self.log_path, 'r', encoding='utf-8') as f:
                if include_existing:
                    # Replay what is already in the file, skipping blank lines
                    for line in f:
                        if line.strip():
                            yield self._log_event(line)
                else:
                    # Start from the current end of the file
                    f.seek(0, os.SEEK_END)

                self.last_position = f.tell()
                last_keepalive = time.time()

                # Follow the file until the consumer stops or an error occurs
                while True:
                    line = f.readline()
                    if line:
                        yield self._log_event(line)
                        # Data was sent, so the keepalive timer restarts
                        last_keepalive = time.time()
                        continue

                    # No new content: send a keepalive when the interval elapsed
                    current_time = time.time()
                    if current_time - last_keepalive >= self.keepalive_interval:
                        yield await self.send_keepalive()
                        last_keepalive = current_time

                    # Sleep briefly before checking again
                    await asyncio.sleep(0.1)

        except FileNotFoundError:
            logger.error(f"Log file disappeared during streaming: {self.log_path}")
            yield {
                "event": "error",
                "data": {
                    "code": 4006,
                    "message": "日志文件在流式传输期间变得不可用",
                    "error": "File not found during streaming"
                }
            }
        except Exception as e:
            logger.error(f"Error during log streaming: {e}", exc_info=True)
            yield {
                "event": "error",
                "data": {
                    "code": 8001,
                    "message": "流式传输期间发生错误",
                    "error": str(e)
                }
            }
        finally:
            # Only log here. Yielding inside ``finally`` would raise
            # RuntimeError when the consumer closes the generator early.
            logger.info(f"Log stream ended for {self.log_path}")

        # Reached only after an error event, never on close or cancellation
        yield {
            "event": "done",
            "data": {
                "message": "流式传输完成"
            }
        }

    async def read_existing_and_stream(self) -> AsyncGenerator[dict, None]:
        """
        Read existing log content first, then watch for new content

        Yields:
            Event dicts as described on the class: log, keepalive, error, done.
        """
        stream = self._stream(include_existing=True)
        try:
            async for event in stream:
                yield event
        finally:
            # Close the inner generator promptly so the file handle is released
            await stream.aclose()

    async def watch_and_stream(self) -> AsyncGenerator[dict, None]:
        """
        Watch log file and stream only new content as it's written

        Streaming starts at the end of the file, so only lines written after
        the stream starts are sent.

        Yields:
            Event dicts as described on the class: log, keepalive, error, done.
        """
        stream = self._stream(include_existing=False)
        try:
            async for event in stream:
                yield event
        finally:
            # Close the inner generator promptly so the file handle is released
            await stream.aclose()

View File

@@ -0,0 +1,32 @@
"""
Agent logger module for backward compatibility.
This module maintains the get_named_logger() function for backward compatibility
while delegating to the centralized logging configuration.
All new code should import directly from app.core.logging_config instead.
"""
__version__ = "0.1.0"
__author__ = "RED_BEAR"
from app.core.logging_config import get_agent_logger
def get_named_logger(name):
    """Return the agent logger for ``name``.

    Backward-compatibility shim: the call is forwarded unchanged to the
    centralized ``get_agent_logger()``.

    Args:
        name: Logger name for namespacing

    Returns:
        Whatever ``get_agent_logger(name)`` returns

    Example:
        >>> logger = get_named_logger("my_agent")
        >>> logger.info("Agent operation started")
    """
    agent_logger = get_agent_logger(name)
    return agent_logger

View File

@@ -0,0 +1,28 @@
"""
MCP Server package for memory agent.
This package provides the FastMCP server implementation with context-based
dependency injection for tool functions.
Package structure:
- server: FastMCP server initialization and context setup
- tools: MCP tool implementations
- models: Pydantic response models
- services: Business logic services
"""
from app.core.memory.agent.mcp_server.server import (
mcp,
initialize_context,
main,
get_context_resource
)
# Import tools to register them (but don't export them)
from app.core.memory.agent.mcp_server import tools
__all__ = [
'mcp',
'initialize_context',
'main',
'get_context_resource',
]

View File

@@ -0,0 +1,11 @@
"""
MCP Server Instance
This module contains the FastMCP server instance that is shared across all modules.
It's in a separate file to avoid circular import issues.
"""
from mcp.server.fastmcp import FastMCP
# Initialize FastMCP server instance
# This instance is shared across all tool modules
mcp = FastMCP('data_flow')

View File

@@ -0,0 +1,30 @@
"""Pydantic models for MCP server responses."""
from .problem_models import (
ProblemBreakdownItem,
ProblemBreakdownResponse,
ExtendedQuestionItem,
ProblemExtensionResponse,
)
from .summary_models import (
SummaryData,
SummaryResponse,
RetrieveSummaryData,
RetrieveSummaryResponse,
)
from .verification_models import VerificationResult
from .retrieval_models import RetrievalResult, DistinguishTypeResponse
__all__ = [
"ProblemBreakdownItem",
"ProblemBreakdownResponse",
"ExtendedQuestionItem",
"ProblemExtensionResponse",
"SummaryData",
"SummaryResponse",
"RetrieveSummaryData",
"RetrieveSummaryResponse",
"VerificationResult",
"RetrievalResult",
"DistinguishTypeResponse",
]

View File

@@ -0,0 +1,34 @@
"""Pydantic models for problem breakdown and extension operations."""
from typing import List, Optional
from pydantic import BaseModel, Field, RootModel
class ProblemBreakdownItem(BaseModel):
    """A single entry of a problem breakdown response."""

    # Required string fields
    id: str
    question: str
    type: str

    # Optional; defaults to None when omitted
    reason: Optional[str] = None
class ProblemBreakdownResponse(RootModel[List[ProblemBreakdownItem]]):
    """Problem breakdown response: a root-level list of ProblemBreakdownItem."""
class ExtendedQuestionItem(BaseModel):
    """One extended question together with the reasoning behind it."""

    # The original preliminary question
    original_question: str = Field(
        ...,
        description="原始初步问题",
    )
    # The question after extension
    extended_question: str = Field(
        ...,
        description="扩展后的问题",
    )
    # Category (fact retrieval / clarification / definition / comparison / action advice, etc.)
    type: str = Field(
        ...,
        description="类型(事实检索 / 澄清 / 定义 / 比较 / 行动建议等)",
    )
    # Why this extended question was generated
    reason: str = Field(
        ...,
        description="生成该扩展问题的理由",
    )
class ProblemExtensionResponse(RootModel[List[ExtendedQuestionItem]]):
    """Problem extension response: a root-level list of ExtendedQuestionItem."""

View File

@@ -0,0 +1,17 @@
"""Pydantic models for retrieval operations."""
from typing import List, Dict, Any
from pydantic import BaseModel
class RetrievalResult(BaseModel):
    """Outcome of a retrieval operation."""

    # Note the capitalised field names: they are part of the schema
    Query: str
    Expansion_issue: List[Dict[str, Any]]
class DistinguishTypeResponse(BaseModel):
    """Response of the data type differentiation step."""

    # Required string field
    type: str

View File

@@ -0,0 +1,31 @@
"""Pydantic models for summary operations."""
from typing import List
from pydantic import BaseModel, Field
class SummaryData(BaseModel):
    """Input data of a summary operation."""

    # Required
    query: str

    # Both lists default to a fresh empty list per instance
    history: List[str] = Field(default_factory=list)
    retrieve_info: List[str] = Field(default_factory=list)
class SummaryResponse(BaseModel):
    """Result of a summary operation: the input data plus the answer text."""

    data: SummaryData
    query_answer: str
class RetrieveSummaryData(BaseModel):
    """Payload of a retrieve summary response."""

    # Defaults to the empty string when omitted
    query_answer: str = Field(default="")
class RetrieveSummaryResponse(BaseModel):
    """Envelope of a retrieve summary operation result."""

    data: RetrieveSummaryData

View File

@@ -0,0 +1,14 @@
"""Pydantic models for verification operations."""
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
class VerificationResult(BaseModel):
    """Outcome of a verification operation."""

    # Required fields
    query: str
    expansion_issue: List[Dict[str, Any]]
    split_result: str

    # Optional; defaults to None when omitted
    reason: Optional[str] = None

    # Defaults to a fresh empty list per instance
    history: List[Dict[str, Any]] = Field(default_factory=list)

View File

@@ -0,0 +1,161 @@
"""
MCP Server initialization with FastMCP context setup.
This module initializes the FastMCP server and registers shared resources
in the context for dependency injection into tool functions.
"""
import os
import sys
from mcp.server.fastmcp import FastMCP
from app.core.config import settings
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.redis_tool import RedisSessionStore, store
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID,reload_configuration_from_database
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.agent.mcp_server.services.template_service import TemplateService
from app.core.memory.agent.mcp_server.services.search_service import SearchService
from app.core.memory.agent.mcp_server.services.session_service import SessionService
from app.core.memory.agent.mcp_server.mcp_instance import mcp
logger = get_agent_logger(__name__)
def get_context_resource(ctx, resource_name: str):
    """
    Look up a shared resource stored as an attribute on ``ctx.fastmcp``.

    Args:
        ctx: FastMCP Context object (passed to tool functions)
        resource_name: Name of the resource to retrieve

    Returns:
        The attribute of ``ctx.fastmcp`` named ``resource_name``

    Raises:
        RuntimeError: If ``ctx`` has no ``fastmcp`` attribute or it is None
        AttributeError: If ``ctx.fastmcp`` has no attribute named
            ``resource_name``; the message lists the public attribute names

    Example:
        @mcp.tool()
        async def my_tool(ctx: Context):
            template_service = get_context_resource(ctx, 'template_service')
            llm_client = get_context_resource(ctx, 'llm_client')
    """
    server = getattr(ctx, 'fastmcp', None)
    if server is None:
        raise RuntimeError("Context does not have fastmcp attribute")

    if hasattr(server, resource_name):
        return getattr(server, resource_name)

    # Help the caller by listing what is actually registered
    public_names = [attr for attr in dir(server) if not attr.startswith('_')]
    raise AttributeError(
        f"Resource '{resource_name}' not found in context. "
        f"Available resources: {public_names}"
    )
def initialize_context():
    """
    Attach the shared resources to the FastMCP instance.

    Tool functions reach these resources through ``ctx.fastmcp``. Attributes
    set on ``mcp``, in order:

    - session_store: the module-level Redis ``store``
    - llm_client: client for ``SELECTED_LLM_ID``, or None if it cannot be created
    - app_settings: application settings (named to avoid FastMCP's own ``settings``)
    - template_service: TemplateService rooted at the prompt directory
    - search_service: SearchService instance
    - session_service: SessionService bound to ``store``

    Raises:
        Exception: Any failure other than LLM client creation is logged and re-raised
    """
    try:
        # Redis session store
        logger.info("Registering session_store in context")
        mcp.session_store = store

        # LLM client: a failure here is tolerated
        try:
            logger.info(f"Registering llm_client in context with model ID: {SELECTED_LLM_ID}")
            mcp.llm_client = get_llm_client(SELECTED_LLM_ID)
            logger.info("llm_client registered successfully")
        except Exception as e:
            logger.error(f"Failed to register llm_client: {e}", exc_info=True)
            # Register None so tools still find the resource when they look it up
            mcp.llm_client = None
            logger.warning("llm_client set to None due to initialization failure")

        # Application settings
        logger.info("Registering app_settings in context")
        mcp.app_settings = settings

        # Template service
        prompt_dir = PROJECT_ROOT_ + '/agent/utils/prompt'
        mcp.template_service = TemplateService(prompt_dir)

        # Search service
        mcp.search_service = SearchService()

        # Session service
        mcp.session_service = SessionService(store)

    except Exception as e:
        logger.error(f"Failed to initialize context: {e}", exc_info=True)
        raise
def main():
    """
    Start the MCP server.

    Steps: reload configuration from the database (config ID taken from the
    ``config_id`` environment variable), initialize the context resources,
    import the tool modules, list the tools, and serve the SSE app with
    uvicorn on port 8081. Any exception is logged and the process exits with
    status 1.
    """
    try:
        selected_config_id = os.getenv("config_id")
        reload_configuration_from_database(config_id=selected_config_id, force_reload=True)

        # Shared resources must exist before the tools are imported
        initialize_context()

        # Imported for their side effect of registering the tools
        from app.core.memory.agent.mcp_server.tools import (
            problem_tools,
            retrieval_tools,
            verification_tools,
            summary_tools,
            data_tools
        )

        # Listing the tools is kept as a startup check; the result is not used
        import asyncio
        asyncio.run(mcp.list_tools())

        # Serve over SSE transport for HTTP connections
        import uvicorn
        sse_application = mcp.sse_app()
        uvicorn.run(sse_application, host=settings.SERVER_IP, port=8081, log_level="info")

    except Exception as e:
        logger.error(f"Failed to start MCP server: {e}", exc_info=True)
        sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,23 @@
"""
MCP Server Services
This module provides business logic services for the MCP server:
- TemplateService: Template loading and rendering
- SearchService: Search result processing
- SessionService: Session and history management
- ParameterBuilder: Tool parameter construction
"""
from .template_service import TemplateService, TemplateRenderError
from .search_service import SearchService
from .session_service import SessionService
from .parameter_builder import ParameterBuilder
__all__ = [
"TemplateService",
"TemplateRenderError",
"SearchService",
"SessionService",
"ParameterBuilder",
]

View File

@@ -0,0 +1,157 @@
"""
Parameter Builder for constructing tool call arguments.
This service provides tool-specific parameter transformation logic
to build correct arguments for each tool type.
"""
import json
from typing import Any, Dict, Optional
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
class ParameterBuilder:
    """Builds the argument dict for a tool call according to the tool's name."""

    def __init__(self):
        """Create the builder; it holds no state."""
        logger.info("ParameterBuilder initialized")

    @staticmethod
    def _unwrap_retrieve_summary(tool_name: str, content: Any) -> Any:
        """Strip the 'content' / 'context' wrappers from a Retrieve_Summary payload."""
        if not isinstance(content, dict):
            return content

        if "content" in content:
            inner = content["content"]
            if isinstance(inner, dict):
                return inner
            if not isinstance(inner, str):
                # Neither JSON text nor a dict: leave the payload untouched
                return content
            try:
                parsed = json.loads(inner)
            except json.JSONDecodeError:
                logger.warning(
                    f"Failed to parse JSON content for {tool_name}: {inner[:100]}"
                )
                return {"Query": "", "Expansion_issue": []}
            # The decoded value may itself be wrapped in 'context'
            if isinstance(parsed, dict) and "context" in parsed:
                return parsed["context"]
            return parsed

        if "context" in content and isinstance(content["context"], dict):
            return content["context"]
        return content

    def build_tool_args(
        self,
        tool_name: str,
        content: Any,
        tool_call_id: str,
        search_switch: str,
        apply_id: str,
        group_id: str,
        storage_type: Optional[str] = None,
        user_rag_memory_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Assemble the arguments for ``tool_name``.

        The ``context`` entry depends on the tool:
        - Verify / Retrieve: ``content`` if it is a dict, otherwise ``{}``
        - Summary / Summary_fails: a JSON string (strings pass through as-is)
        - Retrieve_Summary: ``content`` with nested wrappers removed
        - Input_Summary: the ``"sentence"`` entry of a dict, else ``str(content)``
        - any other name: ``content`` unchanged (a warning is logged)

        Retrieve and Input_Summary additionally receive ``search_switch``.

        Args:
            tool_name: Name of the tool being invoked
            content: Parsed content from previous tool result
            tool_call_id: Extracted tool call identifier (sent as ``usermessages``)
            search_switch: Search routing parameter
            apply_id: Application identifier
            group_id: Group identifier
            storage_type: Storage type for the workspace; None becomes ""
            user_rag_memory_id: User RAG memory ID; None becomes ""

        Returns:
            Dictionary of tool arguments ready for invocation
        """
        # Entries every tool receives; None is normalised to the empty string
        shared_args = {
            "usermessages": tool_call_id,
            "apply_id": apply_id,
            "group_id": group_id,
            "storage_type": "" if storage_type is None else storage_type,
            "user_rag_memory_id": "" if user_rag_memory_id is None else user_rag_memory_id,
        }

        if tool_name in ("Verify", "Retrieve"):
            # These tools expect a dict context
            context = content if isinstance(content, dict) else {}
        elif tool_name in ("Summary", "Summary_fails"):
            # These tools expect a JSON string context
            if isinstance(content, str):
                context = content
            else:
                serializable = content if isinstance(content, dict) else {"data": content}
                context = json.dumps(serializable, ensure_ascii=False)
        elif tool_name == "Retrieve_Summary":
            context = self._unwrap_retrieve_summary(tool_name, content)
        elif tool_name == "Input_Summary":
            # This tool expects the raw message string
            if isinstance(content, dict):
                context = content.get("sentence", str(content))
            else:
                context = str(content)
        else:
            logger.warning(
                f"Unknown tool name '{tool_name}', using default argument structure"
            )
            context = content

        tool_args = {"context": context}
        if tool_name in ("Retrieve", "Input_Summary"):
            tool_args["search_switch"] = search_switch
        tool_args.update(shared_args)
        return tool_args

View File

@@ -0,0 +1,193 @@
"""
Search Service for executing hybrid search and processing results.
This service provides clean search result processing with content extraction
and deduplication.
"""
from typing import List, Tuple, Optional
from app.core.logging_config import get_agent_logger
from app.core.memory.src.search import run_hybrid_search
from app.core.memory.utils.data.text_utils import escape_lucene_query
logger = get_agent_logger(__name__)
class SearchService:
    """Runs hybrid search and reduces the results to plain text content."""

    def __init__(self):
        """Create the service; it holds no state."""
        logger.info("SearchService initialized")

    def extract_content_from_result(self, result: dict) -> str:
        """
        Reduce one search result to its text, dropping all other fields.

        Only the 'statement' and 'content' fields are used, in that order,
        and only when they are truthy. Entity fields ('name',
        'fact_summary') are not extracted.

        Args:
            result: Search result dictionary

        Returns:
            The extracted parts joined by newlines, "" when there are none,
            or ``str(result)`` when ``result`` is not a dict
        """
        if not isinstance(result, dict):
            return str(result)

        text_parts = [
            result[field]
            for field in ('statement', 'content')
            if field in result and result[field]
        ]
        return '\n'.join(text_parts)

    def clean_query(self, query: str) -> str:
        """
        Normalise a query and escape it for Lucene.

        Steps: strip whitespace, remove one pair of wrapping quotes, replace
        carriage returns and newlines with spaces, strip again, then apply
        ``escape_lucene_query``.

        Args:
            query: Raw query string

        Returns:
            Cleaned and escaped query string
        """
        text = str(query).strip()

        # Remove a single pair of wrapping quotes (single quotes checked first)
        for quote in ("'", '"'):
            if text.startswith(quote) and text.endswith(quote):
                text = text[1:-1]
                break

        # Line breaks become spaces
        text = text.translate(str.maketrans({'\r': ' ', '\n': ' '})).strip()

        return escape_lucene_query(text)

    async def execute_hybrid_search(
        self,
        group_id: str,
        question: str,
        limit: int = 5,
        search_type: str = "hybrid",
        include: Optional[List[str]] = None,
        rerank_alpha: float = 0.4,
        output_path: str = "search_results.json",
        return_raw_results: bool = False
    ) -> Tuple[str, str, Optional[dict]]:
        """
        Run a search and return its content as clean text.

        Results are collected category by category in the order summaries,
        statements, chunks, entities. For "hybrid" searches they are read
        from the 'reranked_results' entry of the search answer, otherwise
        from the answer itself.

        Args:
            group_id: Group identifier for filtering results
            question: Search query text
            limit: Maximum number of results to return (default: 5)
            search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
            include: List of result types to include (default: ["statements", "chunks", "entities", "summaries"])
            rerank_alpha: Weight for BM25 scores in reranking (default: 0.4)
            output_path: Path to save search results (default: "search_results.json")
            return_raw_results: If True, also return the raw search results as third element (default: False)

        Returns:
            Tuple of (clean_content, cleaned_query, raw_results).
            raw_results is None if return_raw_results=False. When the search
            fails, clean_content is "" and raw_results is {} (or None).
        """
        if include is None:
            include = ["statements", "chunks", "entities", "summaries"]

        cleaned_query = self.clean_query(question)

        try:
            answer = await run_hybrid_search(
                query_text=cleaned_query,
                search_type=search_type,
                group_id=group_id,
                limit=limit,
                include=include,
                output_path=output_path,
                rerank_alpha=rerank_alpha
            )

            # Hybrid searches keep their results under 'reranked_results'
            if search_type == "hybrid":
                results_by_category = answer.get('reranked_results', {})
            else:
                results_by_category = answer

            # Summaries come first, then statements, chunks and entities
            ordered_hits = []
            for category in ('summaries', 'statements', 'chunks', 'entities'):
                if category not in include or category not in results_by_category:
                    continue
                hits = results_by_category[category]
                if isinstance(hits, list):
                    ordered_hits.extend(hits)

            # Keep only the non-empty extracted texts
            extracted = [self.extract_content_from_result(hit) for hit in ordered_hits]
            clean_content = '\n'.join(text for text in extracted if text)

            # Log the first 200 characters
            logger.info(f"检索接口搜索结果==>>:{clean_content[:200]}...")

            raw_results = answer if return_raw_results else None
            return clean_content, cleaned_query, raw_results

        except Exception as e:
            logger.error(
                f"Search failed for query '{question}' in group '{group_id}': {e}",
                exc_info=True
            )
            # A failed search yields empty results instead of raising
            empty_raw = {} if return_raw_results else None
            return "", cleaned_query, empty_raw

View File

@@ -0,0 +1,169 @@
"""
Session Service for managing user sessions and conversation history.
This service provides clean Redis interactions with error handling and
session management utilities.
"""
from typing import List, Optional
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.redis_tool import RedisSessionStore
logger = get_agent_logger(__name__)
class SessionService:
    """Service for managing user sessions and conversation history.

    Wraps a session store and adds validation, logging and defensive error
    handling: the methods below log failures and return a neutral value
    (empty list, ``None`` or ``0``) instead of propagating store exceptions.
    """

    def __init__(self, store: "RedisSessionStore"):
        """
        Initialize the session service.

        Args:
            store: Redis session store instance
        """
        self.store = store
        logger.info("SessionService initialized")

    def resolve_user_id(self, session_string: str) -> str:
        """
        Extract user ID from session string.

        The user ID is everything after the FIRST ``'_id_'`` marker, so IDs
        that themselves contain ``'_id_'`` are returned intact.

        Handles formats like:
        - 'call_id_user123' -> 'user123'
        - 'prefix_id_user456_suffix' -> 'user456_suffix'
        - 'call_id_team_id_42' -> 'team_id_42'

        Args:
            session_string: Session identifier string

        Returns:
            Extracted user ID, or the input unchanged when it contains no
            ``'_id_'`` marker or is not a string.
        """
        try:
            # partition() splits on the first marker only; a plain split()
            # would cut the ID short at any later '_id_' occurrence.
            _, marker, user_id = session_string.partition('_id_')
        except (AttributeError, TypeError) as e:
            # Non-string input (e.g. None or bytes): fall back to the raw value.
            logger.warning(
                f"Failed to parse user ID from session string '{session_string}': {e}"
            )
            return session_string
        # Fallback: return original string when the marker is absent
        return user_id if marker else session_string

    async def get_history(
        self,
        user_id: str,
        apply_id: str,
        group_id: str
    ) -> List[dict]:
        """
        Retrieve conversation history from Redis.

        The store lookup is invoked synchronously (it is not awaited).

        Args:
            user_id: User identifier
            apply_id: Application identifier
            group_id: Group identifier

        Returns:
            List of conversation history items with Query and Answer keys
            Returns empty list if no history found or on error
        """
        try:
            history = self.store.find_user_apply_group(user_id, apply_id, group_id)
            # Validate history structure
            if not isinstance(history, list):
                logger.warning(
                    f"Invalid history format for user {user_id}, "
                    f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
                )
                return []
            return history
        except Exception as e:
            logger.error(
                f"Failed to retrieve history for user {user_id}, "
                f"apply {apply_id}, group {group_id}: {e}",
                exc_info=True
            )
            # Return empty list on error to allow execution to continue
            return []

    async def save_session(
        self,
        user_id: str,
        query: str,
        apply_id: str,
        group_id: str,
        ai_response: str
    ) -> Optional[str]:
        """
        Save conversation turn to Redis.

        Args:
            user_id: User identifier
            query: User query/message
            apply_id: Application identifier
            group_id: Group identifier
            ai_response: AI response/answer

        Returns:
            Session ID if successful, None when ``user_id`` or ``query`` is
            empty or when the store raises
        """
        try:
            # Validate required fields
            if not user_id:
                logger.warning("Cannot save session: user_id is empty")
                return None
            if not query:
                logger.warning("Cannot save session: query is empty")
                return None
            # Save session (the store uses its own keyword names)
            session_id = self.store.save_session(
                userid=user_id,
                messages=query,
                apply_id=apply_id,
                group_id=group_id,
                aimessages=ai_response
            )
            logger.info(f"Session saved successfully: {session_id}")
            return session_id
        except Exception as e:
            logger.error(
                f"Failed to save session for user {user_id}: {e}",
                exc_info=True
            )
            return None

    async def cleanup_duplicates(self) -> int:
        """
        Remove duplicate session entries.

        Delegates to the store's ``delete_duplicate_sessions``. Per the
        original documentation, duplicates are identified by matching:
        - sessionid
        - user_id (id field)
        - group_id
        - messages
        - aimessages

        Returns:
            Number of duplicate sessions deleted, or 0 on error
        """
        try:
            deleted_count = self.store.delete_duplicate_sessions()
            logger.info(f"Cleaned up {deleted_count} duplicate sessions")
            return deleted_count
        except Exception as e:
            logger.error(f"Failed to cleanup duplicate sessions: {e}", exc_info=True)
            return 0

View File

@@ -0,0 +1,116 @@
"""
Template Service for loading and rendering Jinja2 templates.
This service provides centralized template management with caching and error handling.
"""
import os
from functools import lru_cache
from typing import Optional
from jinja2 import Environment, FileSystemLoader, Template, TemplateNotFound
from app.core.logging_config import get_agent_logger, log_prompt_rendering
logger = get_agent_logger(__name__)
class TemplateRenderError(Exception):
    """Signals that a template could not be loaded or rendered.

    The exception message names the template and the underlying error; the
    details are also kept on the instance for callers that want them.
    """

    def __init__(self, template_name: str, error: Exception, variables: dict):
        message = f"Failed to render template '{template_name}': {str(error)}"
        super().__init__(message)
        # Name of the template that failed
        self.template_name = template_name
        # Underlying exception that caused the failure
        self.error = error
        # Variables that were supplied for rendering
        self.variables = variables
class TemplateService:
    """Service for loading and rendering Jinja2 templates with caching."""

    def __init__(self, template_root: str):
        """
        Initialize the template service.

        Args:
            template_root: Root directory containing template files
        """
        self.template_root = template_root
        self.env = Environment(
            loader=FileSystemLoader(template_root),
            autoescape=False  # Disable autoescape for prompt templates
        )
        # Per-instance cache of loaded templates, keyed by template name.
        # A method-level lru_cache would key on `self` and keep every
        # TemplateService instance alive inside a class-level cache.
        # Only successful loads are stored, so the cache holds at most one
        # entry per template that this service has loaded.
        self._template_cache: dict = {}
        logger.info(f"TemplateService initialized with root: {template_root}")

    def _load_template(self, template_name: str) -> Template:
        """
        Load a template from disk with caching.

        Args:
            template_name: Relative path to template file

        Returns:
            Loaded Jinja2 Template object

        Raises:
            TemplateNotFound: If template file doesn't exist
        """
        cached = self._template_cache.get(template_name)
        if cached is not None:
            return cached
        try:
            template = self.env.get_template(template_name)
        except TemplateNotFound:
            expected_path = os.path.join(self.template_root, template_name)
            logger.error(
                f"Template not found: {template_name}. "
                f"Expected path: {expected_path}"
            )
            raise
        self._template_cache[template_name] = template
        return template

    async def render_template(
        self,
        template_name: str,
        operation_name: str,
        **variables
    ) -> str:
        """
        Load and render a Jinja2 template.

        Args:
            template_name: Relative path to template file
            operation_name: Name for logging (e.g., "split_the_problem")
            **variables: Template variables to render

        Returns:
            Rendered template string

        Raises:
            TemplateRenderError: If template loading or rendering fails; the
                original exception is attached as ``__cause__``
        """
        try:
            # Load template (cached)
            template = self._load_template(template_name)
            # Render template
            rendered = template.render(**variables)
            # Log rendered prompt
            log_prompt_rendering(operation_name, rendered)
            return rendered
        except TemplateNotFound as e:
            logger.error(
                f"Template rendering failed for {operation_name} "
                f"({template_name}): Template not found",
                exc_info=True
            )
            raise TemplateRenderError(template_name, e, variables) from e
        except Exception as e:
            logger.error(
                f"Template rendering failed for {operation_name} "
                f"({template_name}): {e}",
                exc_info=True
            )
            raise TemplateRenderError(template_name, e, variables) from e

View File

@@ -0,0 +1,27 @@
"""
MCP Tools module.
This module contains all MCP tool implementations organized by functionality.
Tools are organized into the following modules:
- problem_tools: Question segmentation and extension
- retrieval_tools: Database and context retrieval
- verification_tools: Data verification
- summary_tools: Summarization and summary retrieval
- data_tools: Data type differentiation and writing
"""
# Import all tool modules to register them with the MCP server
from . import problem_tools
from . import retrieval_tools
from . import verification_tools
from . import summary_tools
from . import data_tools
__all__ = [
'problem_tools',
'retrieval_tools',
'verification_tools',
'summary_tools',
'data_tools',
]

View File

@@ -0,0 +1,149 @@
"""
Data Tools for data type differentiation and writing.
This module contains MCP tools for distinguishing data types and writing data.
"""
import os
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.mcp_server.models.retrieval_models import DistinguishTypeResponse
from app.core.memory.agent.utils.write_tools import write
logger = get_agent_logger(__name__)
@mcp.tool()
async def Data_type_differentiation(
    ctx: Context,
    context: str
) -> dict:
    """
    Distinguish the type of data (read or write).

    Renders the ``distinguish_types_prompt.jinja2`` prompt with the given
    text and asks the LLM for a structured ``DistinguishTypeResponse``.
    This tool does not raise: every failure path returns a dict whose
    ``type`` is ``"error"``.

    Args:
        ctx: FastMCP context for dependency injection
        context: Text to analyze for type differentiation

    Returns:
        dict: On success, the dumped structured response plus 'context'
        (the original text). On failure, 'context', 'type' set to "error"
        and a 'message' describing the failure.
    """
    try:
        # Extract services from context
        template_service = get_context_resource(ctx, 'template_service')
        llm_client = get_context_resource(ctx, 'llm_client')

        # Render template
        try:
            system_prompt = await template_service.render_template(
                template_name='distinguish_types_prompt.jinja2',
                # Logging label; spelling kept as-is because it is a runtime value
                operation_name='status_typle',
                user_query=context
            )
        except Exception as e:
            logger.error(
                f"Template rendering failed for Data_type_differentiation: {e}",
                exc_info=True
            )
            # Include 'context' so this error has the same shape as the
            # other error responses of this tool.
            return {
                "context": context,
                "type": "error",
                "message": f"Prompt rendering failed: {str(e)}"
            }

        # Call LLM with structured response
        try:
            structured = await llm_client.response_structured(
                messages=[{"role": "system", "content": system_prompt}],
                response_model=DistinguishTypeResponse
            )
            result = structured.model_dump()
            # Add context to result
            result["context"] = context
            return result
        except Exception as e:
            logger.error(
                f"LLM call failed for Data_type_differentiation: {e}",
                exc_info=True
            )
            return {
                "context": context,
                "type": "error",
                "message": f"LLM call failed: {str(e)}"
            }
    except Exception as e:
        # Reached when resolving the services from the context fails
        logger.error(
            f"Data_type_differentiation failed: {e}",
            exc_info=True
        )
        return {
            "context": context,
            "type": "error",
            "message": str(e)
        }
@mcp.tool()
async def Data_write(
    ctx: Context,
    content: str,
    user_id: str,
    apply_id: str,
    group_id: str,
    config_id: str
) -> dict:
    """
    Persist content through the ``write`` utility.

    Ensures the ``data_output`` directory exists, delegates the actual
    persistence to ``write`` and reports the outcome. This tool does not
    raise; failures are returned as an error dict.

    Args:
        ctx: FastMCP context for dependency injection
        content: Data content to write
        user_id: User identifier
        apply_id: Application identifier
        group_id: Group identifier
        config_id: Configuration ID forwarded to ``write``

    Returns:
        dict: On success 'status' ("success"), 'saved_to', 'data' and
        'config_id'; on failure 'status' ("error") and 'message'.
    """
    try:
        # Make sure the output directory is there before reporting a path in it
        output_dir = "data_output"
        os.makedirs(output_dir, exist_ok=True)
        # NOTE(review): this function never writes to this file itself;
        # presumably `write` persists the data elsewhere - confirm.
        target_path = os.path.join(output_dir, "user_data.csv")

        # Delegate the write; its failures are reported separately
        try:
            await write(content, user_id, apply_id, group_id, config_id=config_id)
            shown_config = config_id if config_id else 'None'
            logger.info(f"写入成功Config ID: {shown_config}")
            success_payload = {
                "status": "success",
                "saved_to": target_path,
                "data": content,
                "config_id": config_id
            }
            return success_payload
        except Exception as write_error:
            logger.error(f"写入失败: {write_error}", exc_info=True)
            failure_payload = {
                "status": "error",
                "message": str(write_error)
            }
            return failure_payload
    except Exception as setup_error:
        # Reached when preparing the output directory/path fails
        logger.error(
            f"Data_write failed: {setup_error}",
            exc_info=True
        )
        return {
            "status": "error",
            "message": str(setup_error)
        }

View File

@@ -0,0 +1,293 @@
"""
Problem Tools for question segmentation and extension.
This module contains MCP tools for breaking down and extending user questions.
"""
import json
import time
from typing import List
from pydantic import BaseModel, Field, RootModel
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.mcp_server.models.problem_models import (
ProblemBreakdownItem,
ProblemBreakdownResponse,
ExtendedQuestionItem,
ProblemExtensionResponse
)
from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal
logger = get_agent_logger(__name__)
@mcp.tool()
async def Split_The_Problem(
    ctx: Context,
    sentence: str,
    sessionid: str,
    messages_id: str,
    apply_id: str,
    group_id: str
) -> dict:
    """
    Segment the dialogue or sentence into sub-problems.

    Renders the ``problem_breakdown_prompt.jinja2`` prompt and asks the LLM
    for a structured ``ProblemBreakdownResponse``. This tool does not raise:
    an LLM failure degrades to an empty split, and other failures return a
    dict carrying an 'error' message.

    Args:
        ctx: FastMCP context for dependency injection
        sentence: Original sentence to split
        sessionid: Session identifier (currently unused, see note below)
        messages_id: Message identifier (currently unused)
        apply_id: Application identifier (currently unused)
        group_id: Group identifier (currently unused)

    Returns:
        dict: Contains 'context' (JSON string of split results), 'original'
        sentence and, on the normal path, an '_intermediate' payload for the
        frontend. Error paths carry 'error' instead of '_intermediate'.
    """
    start = time.time()
    try:
        # Extract services from context
        template_service = get_context_resource(ctx, 'template_service')
        llm_client = get_context_resource(ctx, 'llm_client')

        # Conversation history is deliberately not fed to this prompt for now.
        # The previous code fetched it from the session store and immediately
        # replaced it with an empty list; the wasted lookup is skipped here.
        history = []

        # Render template
        try:
            system_prompt = await template_service.render_template(
                template_name='problem_breakdown_prompt.jinja2',
                operation_name='split_the_problem',
                history=history,
                sentence=sentence
            )
        except Exception as e:
            logger.error(
                f"Template rendering failed for Split_The_Problem: {e}",
                exc_info=True
            )
            return {
                "context": json.dumps([], ensure_ascii=False),
                "original": sentence,
                "error": f"Prompt rendering failed: {str(e)}"
            }

        # Call LLM with structured response
        try:
            structured = await llm_client.response_structured(
                messages=[{"role": "system", "content": system_prompt}],
                response_model=ProblemBreakdownResponse
            )
            if structured is None:
                # LLM returned None, use empty list as fallback
                raw_items = []
            elif getattr(structured, 'root', None) is not None:
                # RootModel response: the items live on `.root`
                raw_items = structured.root
            elif isinstance(structured, list):
                # Fallback: treat structured itself as the list
                raw_items = structured
            else:
                # Last resort: use empty list
                raw_items = []
            split_result = json.dumps(
                [item.model_dump() for item in raw_items],
                ensure_ascii=False
            )
        except Exception as e:
            logger.error(
                f"LLM call failed for Split_The_Problem: {e}",
                exc_info=True
            )
            split_result = json.dumps([], ensure_ascii=False)

        logger.info("问题拆分")
        logger.info(f"问题拆分结果==>>:{split_result}")

        # Emit intermediate output for frontend. The data is decoded back
        # from the JSON string so it only contains JSON-native types.
        return {
            "context": split_result,
            "original": sentence,
            "_intermediate": {
                "type": "problem_split",
                "data": json.loads(split_result),
                "original_query": sentence
            }
        }
    except Exception as e:
        logger.error(
            f"Split_The_Problem failed: {e}",
            exc_info=True
        )
        return {
            "context": json.dumps([], ensure_ascii=False),
            "original": sentence,
            "error": str(e)
        }
    finally:
        # Log execution time on every exit path
        log_time('问题拆分', time.time() - start)
def _extension_item_field(item, name: str):
    """Read `name` from an extension item given as an object or a dict."""
    return getattr(item, name, None) or (
        item.get(name) if isinstance(item, dict) else None
    )


@mcp.tool()
async def Problem_Extension(
    ctx: Context,
    context: dict,
    usermessages: str,
    apply_id: str,
    group_id: str,
    storage_type: str = "",
    user_rag_memory_id: str = ""
) -> dict:
    """
    Extend the problem with additional sub-questions.

    Extracts the split questions from ``context``, renders the
    ``Problem_Extension_prompt.jinja2`` prompt and asks the LLM for a
    structured ``ProblemExtensionResponse``. The extended questions are
    grouped by their original question. This tool does not raise: an LLM
    failure degrades to an empty mapping, and other failures return a dict
    carrying an 'error' message.

    Args:
        ctx: FastMCP context for dependency injection
        context: Dictionary containing split problem results
        usermessages: User messages identifier (currently unused, see note below)
        apply_id: Application identifier (currently unused)
        group_id: Group identifier (currently unused)
        storage_type: Storage type for the workspace (optional)
        user_rag_memory_id: User RAG memory identifier (optional)

    Returns:
        dict: Contains 'context' (aggregated questions), 'original' question,
        'storage_type' and 'user_rag_memory_id'; the normal path adds an
        '_intermediate' payload for the frontend, error paths add 'error'.
    """
    start = time.time()
    try:
        # Extract services from context
        template_service = get_context_resource(ctx, 'template_service')
        llm_client = get_context_resource(ctx, 'llm_client')

        # Conversation history is deliberately not fed to this prompt for now.
        # The previous code fetched it from the session store and immediately
        # replaced it with an empty list; the wasted lookup is skipped here.
        history = []

        # Process context to extract questions
        extent_quest, original = await Problem_Extension_messages_deal(context)

        # Only the user-role messages are questions for the template
        questions_formatted = [
            msg.get("content", "")
            for msg in extent_quest
            if msg.get("role") == "user"
        ]

        # Render template
        try:
            system_prompt = await template_service.render_template(
                template_name='Problem_Extension_prompt.jinja2',
                operation_name='problem_extension',
                history=history,
                questions=questions_formatted
            )
        except Exception as e:
            logger.error(
                f"Template rendering failed for Problem_Extension: {e}",
                exc_info=True
            )
            # Same keys as the other responses of this tool
            return {
                "context": {},
                "original": original,
                "storage_type": storage_type,
                "user_rag_memory_id": user_rag_memory_id,
                "error": f"Prompt rendering failed: {str(e)}"
            }

        # Call LLM with structured response
        aggregated_dict = {}
        try:
            response_content = await llm_client.response_structured(
                messages=[{"role": "system", "content": system_prompt}],
                response_model=ProblemExtensionResponse
            )
            # Aggregate results by original question
            for item in response_content.root:
                key = _extension_item_field(item, "original_question")
                value = _extension_item_field(item, "extended_question")
                if not key or not value:
                    continue
                aggregated_dict.setdefault(key, []).append(value)
        except Exception as e:
            logger.error(
                f"LLM call failed for Problem_Extension: {e}",
                exc_info=True
            )
            # Discard partial results so the output is all-or-nothing
            aggregated_dict = {}

        logger.info("问题扩展")
        logger.info(f"问题扩展==>>:{aggregated_dict}")

        # Emit intermediate output for frontend
        return {
            "context": aggregated_dict,
            "original": original,
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id,
            "_intermediate": {
                "type": "problem_extension",
                "data": aggregated_dict,
                "original_query": original,
                "storage_type": storage_type,
                "user_rag_memory_id": user_rag_memory_id
            }
        }
    except Exception as e:
        logger.error(
            f"Problem_Extension failed: {e}",
            exc_info=True
        )
        # `context` may be the very thing that was malformed; reading it
        # unguarded here would raise from inside the error handler.
        fallback_original = context.get("original", "") if isinstance(context, dict) else ""
        return {
            "context": {},
            "original": fallback_original,
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id,
            "error": str(e)
        }
    finally:
        # Log execution time on every exit path
        log_time('问题扩展', time.time() - start)

View File

@@ -0,0 +1,282 @@
"""
Retrieval Tools for database and context retrieval.
This module contains MCP tools for retrieving data using hybrid search.
"""
from dotenv import load_dotenv
import os
from app.core.rag.nlp.search import knowledge_retrieval
# Load environment variables from the .env file
load_dotenv()
import time
from typing import List
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.llm_tools import deduplicate_entries, merge_to_key_value_pairs
from app.core.memory.agent.utils.messages_tool import Retriev_messages_deal
logger = get_agent_logger(__name__)
async def _retrieve_single_query(
    search_service,
    question: str,
    group_id: str,
    storage_type: str,
    user_rag_memory_id: str,
    kb_config: dict
):
    """
    Run one query against the RAG knowledge base or the hybrid search.

    Returns:
        The triple ``(clean_content, cleaned_query, raw_results)``.

    Raises:
        Whatever ``knowledge_retrieval`` or ``execute_hybrid_search`` raises;
        callers decide how to degrade.
    """
    if storage_type == "rag" and user_rag_memory_id:
        retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(group_id)])
        try:
            clean_content = '\n\n'.join(
                [chunk.page_content for chunk in retrieve_chunks_result]
            )
            logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
        except Exception:
            # Result could not be turned into text (e.g. it is None):
            # treat it as "nothing retrieved" rather than failing the query.
            clean_content = ''
            logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
        # For RAG the joined text doubles as the raw result
        return clean_content, question, clean_content

    return await search_service.execute_hybrid_search(
        group_id=group_id,
        question=question,
        return_raw_results=True
    )


@mcp.tool()
async def Retrieve(
    ctx: Context,
    context,
    usermessages: str,
    apply_id: str,
    group_id: str,
    storage_type: str = "",
    user_rag_memory_id: str = ""
) -> dict:
    """
    Retrieve data from the database using hybrid search.

    A dict ``context`` is expanded into its individual questions, each of
    which is searched separately; any other value is treated as a single
    query string. When ``storage_type`` is "rag" and ``user_rag_memory_id``
    is set, the knowledge base is queried instead of the hybrid search.
    This tool does not raise; failures degrade to empty results.

    Args:
        ctx: FastMCP context for dependency injection
        context: Dictionary or string containing query information
        usermessages: User messages identifier (not used by this tool)
        apply_id: Application identifier (not used by this tool)
        group_id: Group identifier
        storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
        user_rag_memory_id: User RAG memory identifier

    Returns:
        dict: Contains 'context' with Query and Expansion_issue results,
        'storage_type', 'user_rag_memory_id', optionally '_intermediates'
        and, on unexpected failure, 'error'.
    """
    kb_config = {
        "knowledge_bases": [
            {
                "kb_id": user_rag_memory_id,
                "similarity_threshold": 0.7,
                "vector_similarity_weight": 0.5,
                "top_k": 10,
                "retrieve_type": "participle"
            }
        ],
        "merge_strategy": "weight",
        "reranker_id": os.getenv('reranker_id'),
        "reranker_top_k": 10
    }
    start = time.time()
    logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
    try:
        # Extract services from context
        search_service = get_context_resource(ctx, 'search_service')

        # Handle both dict and string context
        if isinstance(context, dict):
            # Process dict context with extended questions
            content, original = await Retriev_messages_deal(context)

            # content is like {original_question: [extended_questions...], ...}
            all_items = []
            for values in content.values():
                if isinstance(values, list):
                    all_items.extend(values)
                elif isinstance(values, str):
                    all_items.append(values)
                elif values is not None:
                    # Fallback: convert non-empty non-list values to string
                    all_items.append(str(values))

            # Execute search for each question
            databases_anser = []
            total = len(all_items)
            for idx, question in enumerate(all_items, start=1):
                try:
                    clean_content, cleaned_query, raw_results = await _retrieve_single_query(
                        search_service,
                        question,
                        group_id,
                        storage_type,
                        user_rag_memory_id,
                        kb_config
                    )
                    databases_anser.append({
                        "Query_small": cleaned_query,
                        "Result_small": clean_content,
                        "_intermediate": {
                            "type": "search_result",
                            "query": cleaned_query,
                            "raw_results": raw_results,
                            "index": idx,
                            "total": total
                        }
                    })
                except Exception as e:
                    logger.error(
                        f"Retrieve: hybrid_search failed for question '{question}': {e}",
                        exc_info=True
                    )
                    # Continue with empty result for this question
                    databases_anser.append({
                        "Query_small": question,
                        "Result_small": ""
                    })

            # Collect intermediate outputs before deduplication
            intermediate_outputs = [
                item['_intermediate']
                for item in databases_anser
                if '_intermediate' in item
            ]

            # Deduplicate and merge results
            deduplicated_data = deduplicate_entries(databases_anser)
            deduplicated_data_merged = merge_to_key_value_pairs(
                deduplicated_data,
                'Query_small',
                'Result_small'
            )

            # Restructure for Verify/Retrieve_Summary compatibility
            send_verify = [
                {"Query_small": merged_query, "Answer_Small": merged_answer}
                for merged in deduplicated_data_merged
                for merged_query, merged_answer in merged.items()
            ]
            dup_databases = {
                "Query": original,
                "Expansion_issue": send_verify,
                "_intermediate_outputs": intermediate_outputs  # Preserve intermediate outputs
            }
            logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
        else:
            # Handle string context (simple query)
            query = str(context).strip()
            try:
                clean_content, cleaned_query, raw_results = await _retrieve_single_query(
                    search_service,
                    query,
                    group_id,
                    storage_type,
                    user_rag_memory_id,
                    kb_config
                )
                # Keep structure for Verify/Retrieve_Summary compatibility
                dup_databases = {
                    "Query": cleaned_query,
                    "Expansion_issue": [{
                        "Query_small": cleaned_query,
                        "Answer_Small": clean_content,
                        "_intermediate": {
                            "type": "search_result",
                            "query": cleaned_query,
                            "raw_results": raw_results,
                            "index": 1,
                            "total": 1
                        }
                    }]
                }
            except Exception as e:
                logger.error(
                    f"Retrieve: hybrid_search failed for query '{query}': {e}",
                    exc_info=True
                )
                # Return empty results on failure
                dup_databases = {
                    "Query": query,
                    "Expansion_issue": []
                }

        logger.info(
            f"检索==>>:{storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, "
            f"Expansion_issue count={len(dup_databases.get('Expansion_issue', []))}"
        )

        # Build result with intermediate outputs
        result = {
            "context": dup_databases,
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id
        }

        # Add intermediate outputs list if they exist.
        # NOTE(review): the string-context path embeds its '_intermediate'
        # inside the Expansion_issue item and never sets
        # '_intermediate_outputs', so it always takes the warning branch -
        # confirm whether that is intended.
        intermediate_outputs = dup_databases.get('_intermediate_outputs', [])
        if intermediate_outputs:
            result['_intermediates'] = intermediate_outputs
            logger.info(f"Adding {len(intermediate_outputs)} intermediate outputs to result")
        else:
            logger.warning("No intermediate outputs found in dup_databases")
        return result
    except Exception as e:
        logger.error(
            f"Retrieve failed: {e}",
            exc_info=True
        )
        return {
            "context": {
                "Query": "",
                "Expansion_issue": []
            },
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id,
            "error": str(e)
        }
    finally:
        # Log execution time on every exit path
        log_time('检索', time.time() - start)

View File

@@ -0,0 +1,647 @@
"""
Summary Tools for data summarization.
This module contains MCP tools for summarizing retrieved data and generating responses.
"""
import json
import re
import time
from typing import List
from pydantic import BaseModel, Field
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.mcp_server.models.summary_models import (
SummaryData,
SummaryResponse,
RetrieveSummaryData,
RetrieveSummaryResponse
)
from app.core.memory.agent.utils.messages_tool import (
Summary_messages_deal,
Resolve_username
)
from app.core.rag.nlp.search import knowledge_retrieval
from dotenv import load_dotenv
import os
# Load environment variables from the .env file
load_dotenv()
logger = get_agent_logger(__name__)
@mcp.tool()
async def Summary(
    ctx: Context,
    context: str,
    usermessages: str,
    apply_id: str,
    group_id: str,
    storage_type: str = "",
    user_rag_memory_id: str = ""
) -> dict:
    """
    Summarize the verified data.

    Builds the ``summary_prompt.jinja2`` prompt from the verified answers,
    the query and the stored conversation history, asks the LLM for a
    structured ``SummaryResponse`` and stores a non-empty answer as a new
    session turn. This tool does not catch errors raised by the final
    duplicate cleanup; every other failure is returned as an error dict.

    Args:
        ctx: FastMCP context for dependency injection
        context: JSON string containing verified data
        usermessages: User messages identifier
        apply_id: Application identifier
        group_id: Group identifier
        storage_type: Storage type for the workspace (optional)
        user_rag_memory_id: User RAG memory identifier (optional)

    Returns:
        dict: On success 'status', 'summary_result', 'storage_type' and
        'user_rag_memory_id'. On failure 'status' is "error" together with
        either 'summary_result' (initialization failure) or 'message'.
    """
    start = time.time()
    try:
        try:
            # Extract services from context
            template_service = get_context_resource(ctx, 'template_service')
            session_service = get_context_resource(ctx, 'session_service')
            llm_client = get_context_resource(ctx, 'llm_client')

            # Resolve session ID
            sessionid = Resolve_username(usermessages)

            # Process context to extract answer and query
            answer_small, query = await Summary_messages_deal(context)

            # Get conversation history (passed on to the prompt as-is)
            history = await session_service.get_history(sessionid, apply_id, group_id)

            # Prepare data for template
            data = {
                "query": query,
                "history": history,
                "retrieve_info": answer_small
            }
        except Exception as e:
            logger.error(
                f"Summary: initialization failed: {e}",
                exc_info=True
            )
            return {
                "status": "error",
                "summary_result": "信息不足,无法回答"
            }

        try:
            # Render template
            system_prompt = await template_service.render_template(
                template_name='summary_prompt.jinja2',
                operation_name='summary',
                data=data,
                query=query
            )
        except Exception as e:
            logger.error(
                f"Template rendering failed for Summary: {e}",
                exc_info=True
            )
            return {
                "status": "error",
                "message": f"Prompt rendering failed: {str(e)}"
            }

        try:
            # Call LLM with structured response
            structured = await llm_client.response_structured(
                messages=[{"role": "system", "content": system_prompt}],
                response_model=SummaryResponse
            )
            aimessages = structured.query_answer or ""
        except Exception as e:
            logger.error(
                f"LLM call failed for Summary: {e}",
                exc_info=True
            )
            aimessages = ""

        try:
            # Save session only when there is an answer worth storing
            if aimessages != "":
                await session_service.save_session(
                    user_id=sessionid,
                    query=query,
                    apply_id=apply_id,
                    group_id=group_id,
                    ai_response=aimessages
                )
                # Report the session ID (the old message printed the answer
                # text under the "sessionid" label).
                logger.info(f"sessionid: {sessionid} 写入成功")
        except Exception as e:
            logger.error(
                f"sessionid: {sessionid} 写入失败,错误信息:{str(e)}",
                exc_info=True
            )
            return {
                "status": "error",
                "message": str(e)
            }

        # Cleanup duplicate sessions
        await session_service.cleanup_duplicates()

        # Use fallback if empty
        if aimessages == '':
            aimessages = '信息不足,无法回答'
        logger.info(f"验证之后的总结==>>:{aimessages}")

        return {
            "status": "success",
            "summary_result": aimessages,
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id
        }
    finally:
        # Log execution time on every exit path, including early error returns
        log_time('总结', time.time() - start)
def _parse_retrieve_payload(raw: str):
    """
    Decode a JSON payload, retrying once after un-escaping backslashes.

    Returns:
        The decoded value, or None when both attempts fail.
    """
    try:
        parsed = json.loads(raw)
        logger.info("Retrieve_Summary: successfully parsed JSON")
        return parsed
    except json.JSONDecodeError:
        pass

    # Try unescaping first
    try:
        # Encoding with 'backslashreplace' turns non-Latin-1 characters into
        # \u escapes that 'unicode_escape' restores. Encoding as UTF-8 here
        # would make the decoder read each byte as Latin-1 and corrupt
        # non-ASCII text such as Chinese.
        unescaped = raw.encode('latin-1', 'backslashreplace').decode('unicode_escape')
        parsed = json.loads(unescaped)
        logger.info("Retrieve_Summary: parsed after unescaping")
        return parsed
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        logger.error(
            f"Retrieve_Summary: parsing failed even after unescape: {e}"
        )
        return None


def _resolve_retrieve_context(context) -> dict:
    """
    Locate the {'Query', 'Expansion_issue'} dict inside the tool input.

    Handles both 'content' and 'context' wrapper keys (LangGraph uses
    'content'). Falls back to an empty query with no results whenever the
    payload cannot be resolved to a dict.
    """
    candidate = None
    if isinstance(context, dict):
        if "content" in context:
            inner = context["content"]
            if isinstance(inner, str):
                # If it's a JSON string, parse it
                parsed = _parse_retrieve_payload(inner)
                # Check if parsed has 'context' wrapper
                if isinstance(parsed, dict) and "context" in parsed:
                    candidate = parsed["context"]
                else:
                    candidate = parsed
            elif isinstance(inner, dict):
                candidate = inner
        elif "context" in context:
            candidate = context["context"] if isinstance(context["context"], dict) else context
        else:
            candidate = context

    if isinstance(candidate, dict):
        return candidate
    return {"Query": "", "Expansion_issue": []}


def _collect_retrieve_info(expansion_issue) -> str:
    """Join the answers of all Expansion_issue items into one text block."""
    pieces = []
    for item in expansion_issue:
        if not isinstance(item, dict):
            continue
        # Check for both Answer_Small and Answer_Samll (typo) for backward compatibility
        if "Answer_Small" in item:
            answer = item["Answer_Small"]
        elif "Answer_Samll" in item:
            answer = item["Answer_Samll"]
        else:
            continue
        if answer is None:
            continue
        if isinstance(answer, list):
            # Join list of characters/strings into a single string
            pieces.append(''.join(str(x) for x in answer))
        elif isinstance(answer, str):
            pieces.append(answer)
        else:
            pieces.append(str(answer))
    return '\n\n'.join(pieces)


@mcp.tool()
async def Retrieve_Summary(
    ctx: Context,
    context: dict,
    usermessages: str,
    apply_id: str,
    group_id: str,
    storage_type: str = "",
    user_rag_memory_id: str = ""
) -> dict:
    """
    Summarize data directly from retrieval results.

    Builds the ``Retrieve_Summary_prompt.jinja2`` prompt from the query, the
    stored conversation history and the retrieved answers, then asks the LLM
    for a structured ``RetrieveSummaryResponse``. The answer is stored as a
    session turn only when it is non-empty and is not the
    "insufficient information" placeholder. This tool does not catch errors
    raised by the final duplicate cleanup; every other failure is returned
    as an error dict or degrades to the placeholder answer.

    Args:
        ctx: FastMCP context for dependency injection
        context: Dictionary containing Query and Expansion_issue from Retrieve
        usermessages: User messages identifier
        apply_id: Application identifier
        group_id: Group identifier
        storage_type: Storage type for the workspace (optional)
        user_rag_memory_id: User RAG memory identifier (optional)

    Returns:
        dict: On success 'status', 'summary_result', 'storage_type',
        'user_rag_memory_id' and an '_intermediate' payload for the
        frontend. On failure 'status' is "error" together with either
        'summary_result' (initialization failure) or 'message'.
    """
    start = time.time()
    try:
        try:
            # Extract services from context
            template_service = get_context_resource(ctx, 'template_service')
            session_service = get_context_resource(ctx, 'session_service')
            llm_client = get_context_resource(ctx, 'llm_client')

            # Resolve session ID
            sessionid = Resolve_username(usermessages)

            context_dict = _resolve_retrieve_context(context)
            query = context_dict.get("Query", "")
            expansion_issue = context_dict.get("Expansion_issue", [])

            # Join all retrieved answers into a single string
            retrieve_info_str = _collect_retrieve_info(expansion_issue)

            # Get conversation history (passed on to the prompt as-is)
            history = await session_service.get_history(sessionid, apply_id, group_id)
        except Exception as e:
            logger.error(
                f"Retrieve_Summary: initialization failed: {e}",
                exc_info=True
            )
            return {
                "status": "error",
                "summary_result": "信息不足,无法回答"
            }

        try:
            # Render template
            system_prompt = await template_service.render_template(
                template_name='Retrieve_Summary_prompt.jinja2',
                operation_name='retrieve_summary',
                query=query,
                history=history,
                retrieve_info=retrieve_info_str
            )
        except Exception as e:
            logger.error(
                f"Template rendering failed for Retrieve_Summary: {e}",
                exc_info=True
            )
            return {
                "status": "error",
                "message": f"Prompt rendering failed: {str(e)}"
            }

        try:
            # Call LLM with structured response
            structured = await llm_client.response_structured(
                messages=[{"role": "system", "content": system_prompt}],
                response_model=RetrieveSummaryResponse
            )
            # Handle case where structured response might be None or incomplete
            if structured and hasattr(structured, 'data') and structured.data:
                aimessages = structured.data.query_answer or ""
            else:
                logger.warning("Structured response is None or incomplete, using default message")
                aimessages = "信息不足,无法回答"

            # Save only real answers. The previous guard joined the two
            # checks with `or`, which is true for every possible answer.
            answer_text = str(aimessages)
            if answer_text != "" and '信息不足,无法回答' not in answer_text:
                await session_service.save_session(
                    user_id=sessionid,
                    query=query,
                    apply_id=apply_id,
                    group_id=group_id,
                    ai_response=aimessages
                )
                # Report the session ID (the old message printed the answer
                # text under the "sessionid" label).
                logger.info(f"sessionid: {sessionid} 写入成功")
        except Exception as e:
            logger.error(
                f"Retrieve_Summary: LLM call failed: {e}",
                exc_info=True
            )
            aimessages = ""

        # Cleanup duplicate sessions
        await session_service.cleanup_duplicates()

        # Use fallback if empty
        if aimessages == '':
            aimessages = '信息不足,无法回答'
        logger.info(f"检索之后的总结==>>:{aimessages}")

        # Emit intermediate output for frontend
        return {
            "status": "success",
            "summary_result": aimessages,
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id,
            "_intermediate": {
                "type": "retrieval_summary",
                "summary": aimessages,
                "query": query,
                "storage_type": storage_type,
                "user_rag_memory_id": user_rag_memory_id
            }
        }
    finally:
        # Log execution time on every exit path, including early error returns
        log_time('检索总结', time.time() - start)
@mcp.tool()
async def Input_Summary(
    ctx: Context,
    context: str,
    usermessages: str,
    search_switch: str,
    apply_id: str,
    group_id: str,
    storage_type: str = "",
    user_rag_memory_id: str = ""
) -> dict:
    """
    Generate a quick summary for direct input without verification.

    The query is extracted from ``context``, a retrieval step is run (RAG
    knowledge base or hybrid search, depending on ``search_switch`` and
    ``storage_type``), and the LLM is asked for a structured answer.

    Args:
        ctx: FastMCP context for dependency injection
        context: String containing the input sentence (plain text, or a JSON /
            dict-like string carrying a 'sentence' or 'content' field)
        usermessages: User messages identifier
        search_switch: Search switch value for routing ('2' for summaries only)
        apply_id: Application identifier
        group_id: Group identifier
        storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
        user_rag_memory_id: User RAG memory identifier

    Returns:
        dict: Always a dict with 'status' and 'summary_result' (plus
        'storage_type' and 'user_rag_memory_id'). On success it also carries an
        '_intermediate' payload for the frontend; on failure it carries 'error'.
    """
    import json

    start = time.time()
    logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
    try:
        # Extract services from context
        template_service = get_context_resource(ctx, 'template_service')
        session_service = get_context_resource(ctx, 'session_service')
        llm_client = get_context_resource(ctx, 'llm_client')
        search_service = get_context_resource(ctx, 'search_service')

        if llm_client is None:
            error_msg = "LLM client is not available. Please check server configuration and SELECTED_LLM_ID environment variable."
            logger.error(error_msg)
            # Return the same failure shape as the generic handler below so the
            # declared ``dict`` return type holds on every path.
            return {
                "status": "fail",
                "summary_result": "信息不足,无法回答",
                "storage_type": storage_type,
                "user_rag_memory_id": user_rag_memory_id,
                "error": error_msg
            }

        # Resolve session ID
        sessionid = Resolve_username(usermessages) or ""
        sessionid = sessionid.replace('call_id_', '')

        # Get conversation history
        history = await session_service.get_history(
            str(sessionid),
            str(apply_id),
            str(group_id)
        )

        # Log the raw context for debugging
        logger.info(f"Input_Summary: Received context type={type(context)}, value={context[:200] if isinstance(context, str) else context}")

        # Extract the sentence from context: try JSON first, then a regex for
        # dict-repr style strings, and finally fall back to the raw context.
        try:
            if isinstance(context, str) and (context.startswith('{') or context.startswith('[')):
                try:
                    context_dict = json.loads(context)
                    if isinstance(context_dict, dict):
                        query = context_dict.get('sentence', context_dict.get('content', context))
                    else:
                        query = context
                except json.JSONDecodeError:
                    # Not valid JSON, try regex
                    match = re.search(r"'sentence':\s*['\"]?(.*?)['\"]?\s*,", context)
                    query = match.group(1) if match else context
            else:
                query = context
        except Exception as e:
            logger.warning(f"Failed to extract query from context: {e}")
            query = context

        # Clean query
        query = str(query).strip().strip("\"'")
        logger.debug(f"Input_Summary: Extracted query='{query}' from context type={type(context)}")

        # Execute search based on search_switch and storage_type
        try:
            logger.info(f"search_switch: {search_switch}, storage_type: {storage_type}")
            search_params = {
                "group_id": group_id,
                "question": query,
                "return_raw_results": True
            }
            if search_switch == '2':
                search_params["include"] = ["summaries"]
                if storage_type == "rag" and user_rag_memory_id:
                    # Retrieval from the user's RAG knowledge base
                    raw_results = []
                    retrieve_info = ""
                    kb_config = {
                        "knowledge_bases": [
                            {
                                "kb_id": user_rag_memory_id,
                                "similarity_threshold": 0.7,
                                "vector_similarity_weight": 0.5,
                                "top_k": 10,
                                "retrieve_type": "participle"
                            }
                        ],
                        "merge_strategy": "weight",
                        "reranker_id": os.getenv('reranker_id'),
                        "reranker_top_k": 10
                    }
                    retrieve_chunks_result = knowledge_retrieval(query, kb_config, [str(group_id)])
                    try:
                        retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
                        retrieve_info = '\n\n'.join(retrieval_knowledge)
                        raw_results = [retrieve_info]
                        logger.info(f"Input_Summary: Using RAG storage with memory_id={user_rag_memory_id}")
                    except Exception:
                        # Best effort: an unusable retrieval result is treated
                        # as "nothing retrieved" (was a bare ``except:``).
                        retrieve_info = ''
                        raw_results = ['']
                        logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
                else:
                    retrieve_info, _, raw_results = await search_service.execute_hybrid_search(**search_params)
                    logger.info(f"Input_Summary: 使用 summary 进行检索")
            else:
                retrieve_info, _, raw_results = await search_service.execute_hybrid_search(**search_params)
        except Exception as e:
            logger.error(
                f"Input_Summary: hybrid_search failed, using empty results: {e}",
                exc_info=True
            )
            retrieve_info, raw_results = "", []

        # Render template
        system_prompt = await template_service.render_template(
            template_name='Retrieve_Summary_prompt.jinja2',
            operation_name='input_summary',
            query=query,
            history=history,
            retrieve_info=retrieve_info
        )

        # Call LLM with structured response
        try:
            structured = await llm_client.response_structured(
                messages=[{"role": "system", "content": system_prompt}],
                response_model=RetrieveSummaryResponse
            )
            aimessages = structured.data.query_answer or "信息不足,无法回答"
        except Exception as e:
            logger.error(
                f"Input_Summary: response_structured failed, using default answer: {e}",
                exc_info=True
            )
            aimessages = "信息不足,无法回答"

        logger.info(f"快速答案总结==>>:{storage_type}--{user_rag_memory_id}--{aimessages}")

        # Emit intermediate output for frontend
        return {
            "status": "success",
            "summary_result": aimessages,
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id,
            "_intermediate": {
                "type": "input_summary",
                "title": "快速答案",
                "summary": aimessages,
                "query": query,
                "raw_results": raw_results,
                "search_mode": "quick_search",
                "storage_type": storage_type,
                "user_rag_memory_id": user_rag_memory_id
            }
        }
    except Exception as e:
        logger.error(
            f"Input_Summary failed: {e}",
            exc_info=True
        )
        return {
            "status": "fail",
            "summary_result": "信息不足,无法回答",
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id,
            "error": str(e)
        }
    finally:
        # Log execution time on every exit path, including early returns
        log_time('检索', time.time() - start)
@mcp.tool()
async def Summary_fails(
    ctx: Context,
    context: str,
    usermessages: str,
    apply_id: str,
    group_id: str,
    storage_type: str = "",
    user_rag_memory_id: str = ""
) -> dict:
    """
    Terminal step used when the workflow could not produce a summary.

    Cleans up duplicate sessions and reports a fixed "no relevant data" answer.

    Args:
        ctx: FastMCP context for dependency injection
        context: Failure context string
        usermessages: User messages identifier
        apply_id: Application identifier
        group_id: Group identifier
        storage_type: Storage type for the workspace (optional)
        user_rag_memory_id: User RAG memory identifier (optional)

    Returns:
        dict: 'status' ("success" or "fail"), the fixed 'summary_result'
        message, the echoed storage fields, and 'error' when something raised.
    """
    # Fields shared by the success and the failure payloads.
    payload = {
        "summary_result": "没有相关数据",
        "storage_type": storage_type,
        "user_rag_memory_id": user_rag_memory_id,
    }
    try:
        session_service = get_context_resource(ctx, 'session_service')

        # Session id is derived from the identifier (first and last
        # '_'-separated pieces dropped); it is not used further here.
        pieces = usermessages.split('_')[1:]
        sessionid = '_'.join(pieces[:-1])

        await session_service.cleanup_duplicates()
        logger.info("没有相关数据")
        logger.debug(f"Summary_fails called with apply_id: {apply_id}, group_id: {group_id}")
        return {"status": "success", **payload}
    except Exception as exc:
        logger.error(
            f"Summary_fails failed: {exc}",
            exc_info=True
        )
        return {"status": "fail", **payload, "error": str(exc)}

View File

@@ -0,0 +1,169 @@
"""
Verification Tools for data verification.
This module contains MCP tools for verifying retrieved data.
"""
import time
from jinja2 import Template
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.verify_tool import VerifyTool
from app.core.memory.agent.utils.messages_tool import (
Verify_messages_deal,
Retrieve_verify_tool_messages_deal,
Resolve_username
)
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
logger = get_agent_logger(__name__)
@mcp.tool()
async def Verify(
    ctx: Context,
    context: dict,
    usermessages: str,
    apply_id: str,
    group_id: str,
    storage_type: str = "",
    user_rag_memory_id: str = ""
) -> dict:
    """
    Run the verification step over retrieved sub-question answers.

    Args:
        ctx: FastMCP context for dependency injection
        context: Dictionary containing query and expansion issues
        usermessages: User messages identifier
        apply_id: Application identifier
        group_id: Group identifier
        storage_type: Storage type for the workspace (optional)
        user_rag_memory_id: User RAG memory identifier (optional)

    Returns:
        dict: 'status' plus 'verified_data' holding the verification result;
        on success an '_intermediate' payload for the frontend is included,
        on error a 'message' with the exception text.
    """
    started_at = time.time()
    try:
        session_service = get_context_resource(ctx, 'session_service')

        # The raw template text is read directly because it is rendered here
        # with jinja2 and then handed to VerifyTool as the system prompt.
        template_path = PROJECT_ROOT_ + '/agent/utils/prompt/split_verify_prompt.jinja2'
        from app.core.memory.agent.utils.messages_tool import read_template_file
        raw_template = await read_template_file(template_path)

        sessionid = Resolve_username(usermessages)
        history = await session_service.get_history(sessionid, apply_id, group_id)
        system_prompt = Template(raw_template).render(history=history, sentence=context)

        # Split the context into sub-questions, their answers and the main query
        sub_queries, sub_answers, query = await Verify_messages_deal(context)
        query_list = [
            {'Query_small': sub_query, 'Answer_Small': sub_answer}
            for sub_query, sub_answer in zip(sub_queries, sub_answers)
        ]
        messages = {
            "Query": query,
            "Expansion_issue": query_list
        }

        verify_result = await VerifyTool(system_prompt, messages).verify()

        # Parse the verifier output; a parsing problem degrades to a "failed"
        # result instead of propagating (avoids 500 errors).
        try:
            messages_deal = await Retrieve_verify_tool_messages_deal(
                verify_result,
                history,
                query
            )
        except Exception as parse_error:
            logger.error(
                f"Retrieve_verify_tool_messages_deal parsing failed: {parse_error}",
                exc_info=True
            )
            messages_deal = {
                "data": {
                    "query": query,
                    "expansion_issue": []
                },
                "split_result": "failed",
                "reason": str(parse_error),
                "history": history,
            }

        logger.info(f"验证==>>:{messages_deal}")

        intermediate = {
            "type": "verification",
            "title": "数据验证",
            "result": messages_deal.get("split_result", "unknown"),
            "reason": messages_deal.get("reason", ""),
            "query": query,
            "verified_count": len(query_list),
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id
        }
        return {
            "status": "success",
            "verified_data": messages_deal,
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id,
            "_intermediate": intermediate
        }
    except Exception as exc:
        logger.error(
            f"Verify failed: {exc}",
            exc_info=True
        )
        return {
            "status": "error",
            "message": str(exc),
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id,
            "verified_data": {
                "data": {
                    "query": "",
                    "expansion_issue": []
                },
                "split_result": "failed",
                "reason": str(exc),
                "history": [],
            }
        }
    finally:
        # Log execution time on every exit path
        log_time('验证', time.time() - started_at)

View File

@@ -0,0 +1,7 @@
"""Agent utilities."""
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
# Names re-exported as the public interface of this package.
__all__ = [
    "MultimodalProcessor",
]

View File

@@ -0,0 +1,70 @@
import os
import json
from typing import List
from datetime import datetime
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
async def get_chunked_dialogs(
    chunker_strategy: str = "RecursiveChunker",
    group_id: str = "group_1",
    user_id: str = "user1",
    apply_id: str = "applyid",
    content: str = "这是用户的输入",
    ref_id: str = "wyl_20251027",
    config_id: str = None
) -> List[DialogData]:
    """Wrap a single user message in a DialogData and chunk it.

    Args:
        chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
        group_id: Group identifier
        user_id: User identifier
        apply_id: Application identifier
        content: Dialog content (stored as one user-role message)
        ref_id: Reference identifier
        config_id: Configuration ID for processing

    Returns:
        A one-element list holding the DialogData whose ``chunks`` attribute
        was filled by ``DialogueChunker.process_dialogue``.
    """
    # The whole conversation is the single user message passed in.
    conversation_context = ConversationContext(
        msgs=[ConversationMessage(role="用户", msg=content)]
    )
    dialog_data = DialogData(
        context=conversation_context,
        ref_id=ref_id,
        group_id=group_id,
        user_id=user_id,
        apply_id=apply_id,
        config_id=config_id
    )

    # Chunk the dialogue and attach the result to the dialog object.
    chunker = DialogueChunker(chunker_strategy)
    dialog_data.chunks = await chunker.process_dialogue(dialog_data)

    # NOTE: the former debug ``print`` and the unused ``model_dump()`` /
    # datetime-serializer scaffolding (only needed by a commented-out file
    # dump) were removed; they added per-call work and stdout noise.
    return [dialog_data]

View File

@@ -0,0 +1,204 @@
import asyncio
import json
from collections import defaultdict
from typing import TypedDict, Annotated
import os
import logging
from jinja2 import Template
from langchain_core.messages import AnyMessage
from dotenv import load_dotenv
from langgraph.graph import add_messages
from openai import OpenAI
from app.core.memory.agent.utils.messages_tool import read_template_file
from app.core.memory.utils.config.config_utils import get_picture_config, get_voice_config
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID, SELECTED_LLM_PICTURE_NAME, SELECTED_LLM_VOICE_NAME
from app.core.models.base import RedBearModelConfig
from app.core.memory.src.llm_tools.openai_client import OpenAIClient
# Base directory obtained by applying dirname three times to this file's
# absolute path; prompt template paths in this module are built relative to it.
PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
logger = logging.getLogger(__name__)
# Load variables from a .env file into the process environment at import time
# (API keys below are looked up with os.getenv).
load_dotenv()
#TODO: Refactor entire picture/voice
# async def LLM_model_request(context,data,query):
# '''
# Agent model request
# Args:
# context:Input request
# data: template parameters
# query:request content
# Returns:
# '''
# template = Template(context)
# system_prompt = template.render(**data)
# llm_client = get_llm_client(SELECTED_LLM_ID)
# result = await llm_client.chat(
# messages=[{"role": "system", "content": system_prompt}] + [{"role": "user", "content": query}]
# )
# return result
async def picture_model_requests(image_url):
    '''
    Recognize the content of an image using the image-recognition prompt.

    Args:
        image_url: Image location forwarded unchanged to ``Picture_recognize``.
    Returns:
        Whatever ``Picture_recognize`` returns for the image and prompt.
    '''
    # NOTE: the path previously ended with a trailing space
    # ('...prompt.jinja2 '), which made the template lookup fail.
    file_path = PROJECT_ROOT_ + '/agent/utils/prompt/Template_for_image_recognition_prompt.jinja2'
    system_prompt = await read_template_file(file_path)
    result = await Picture_recognize(image_url, system_prompt)
    return result
class WriteState(TypedDict):
    '''
    LangGraph state schema (TypedDict) for the write workflow.
    '''
    # Message list annotated with the add_messages reducer.
    messages: Annotated[list[AnyMessage], add_messages]
    user_id:str
    apply_id:str
    group_id:str
class ReadState(TypedDict):
    '''
    LangGraph state schema (TypedDict) for the read workflow.

    Fields:
    name: name
    id: user id
    loop_count: number of traversals
    search_switch: search type switch
    config_id: configuration id for filtering results
    '''
    messages: Annotated[list[AnyMessage], add_messages]  # messages are added in append mode
    name: str
    id: str
    loop_count:int
    search_switch: str
    user_id: str
    apply_id: str
    group_id: str
    config_id: str
class COUNTState:
    '''
    Saturating counter for the number of workflow retrieval rounds that
    produced no correct recall. The value never exceeds ``limit``.
    '''

    def __init__(self, limit: int = 5):
        self.total: int = 0  # current accumulated value
        self.limit: int = limit  # upper bound for ``total``

    def add(self, value: int = 1):
        """Increase the counter by ``value``, clamping it at the limit."""
        self.total = self.total + value
        print(f"[COUNTState] 当前值: {self.total}")
        limit_reached = self.total >= self.limit
        if limit_reached:
            print(f"[COUNTState] 达到上限 {self.limit}")
            # Once the cap is hit the counter stays at the cap.
            self.total = self.limit

    def get_total(self) -> int:
        """Return the current counter value."""
        return self.total

    def reset(self):
        """Set the counter back to zero."""
        self.total = 0
        print("[COUNTState] 已重置为 0")
# def embed(texts: list[str]) -> list[list[float]]:
# # 这里可以换成 LangChain Embeddings
# return [[float(len(t) % 5), float(len(t) % 3)] for t in texts]
# def export_store_to_json(store, namespace):
# """Export the entire storage content to a JSON file"""
# # 搜索所有存储项
# all_items = store.search(namespace)
# # 整理数据
# export_data = {}
# for item in all_items:
# if hasattr(item, 'key') and hasattr(item, 'value'):
# export_data[item.key] = item.value
# # 保存到文件
# os.makedirs("memory_logs", exist_ok=True)
# with open("memory_logs/full_memory_export.json", "w", encoding="utf-8") as f:
# json.dump(export_data, f, ensure_ascii=False, indent=2)
# print(f"{len(export_data)} 条记忆到 JSON 文件")
def merge_to_key_value_pairs(data, query_key, result_key):
    """Group ``item[result_key]`` values by ``item[query_key]``.

    Returns a list of single-entry dicts ``{key: [values...]}``, ordered by the
    first appearance of each key in ``data``.
    """
    buckets = {}
    for record in data:
        buckets.setdefault(record[query_key], []).append(record[result_key])
    return [{group: members} for group, members in buckets.items()]
def deduplicate_entries(entries):
    """Drop entries repeating an earlier (Query_small, Result_small) pair.

    The first entry seen for each pair is kept and input order is preserved.
    """
    first_seen = {}
    for entry in entries:
        # setdefault keeps the first entry stored for an already-known pair
        first_seen.setdefault((entry['Query_small'], entry['Result_small']), entry)
    return list(first_seen.values())
async def Picture_recognize(image_path,PROMPT_TICKET_EXTRACTION) -> str:
    """Send an image plus a text prompt to the configured picture model.

    Args:
        image_path: Image reference placed in the ``image_url`` content part.
        PROMPT_TICKET_EXTRACTION: Prompt text sent alongside the image.

    Returns:
        The ``statement`` field of the JSON object returned by the model, or
        an error message string when the model configuration cannot be loaded.

    Raises:
        json.JSONDecodeError / KeyError: if the model reply is not JSON or has
        no ``statement`` field (not handled here; callers are expected to cope).
    """
    try:
        model_config = get_picture_config(SELECTED_LLM_PICTURE_NAME)
    except Exception as e:
        err = f"LLM配置不可用:{str(e)}。请检查 config.json 和 runtime.json。"
        logger.error(err)
        return err

    # The config stores the NAME of the environment variable holding the key.
    api_key = os.getenv(model_config["api_key"])
    backend_model_name = model_config["llm_name"].split("/")[-1]
    api_base = model_config['api_base']
    logger.info(f"model_name: {backend_model_name}")
    logger.info(f"api_key set: {'yes' if api_key else 'no'}")
    logger.info(f"base_url: {model_config['api_base']}")

    client = OpenAI(
        api_key=api_key, base_url=api_base,
    )
    # The OpenAI client used here is synchronous; run the HTTP request in a
    # worker thread so this coroutine does not block the event loop.
    completion = await asyncio.to_thread(
        client.chat.completions.create,
        model=backend_model_name,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": image_path,
                    },
                    {"type": "text",
                     "text": PROMPT_TICKET_EXTRACTION}
                ]
            }
        ],
    )
    picture_text = completion.choices[0].message.content
    # Strip Markdown code fences before parsing the JSON payload.
    picture_text = picture_text.replace('```json', '').replace('```', '')
    picture_text = json.loads(picture_text)
    return picture_text['statement']
async def Voice_recognize():
    """Resolve connection settings for the configured voice model.

    Returns:
        A tuple ``(api_key, backend_model_name, api_base)``, or an error
        message string when the voice configuration cannot be loaded.
    """
    try:
        voice_cfg = get_voice_config(SELECTED_LLM_VOICE_NAME)
    except Exception as exc:
        message = f"LLM配置不可用:{str(exc)}。请检查 config.json 和 runtime.json。"
        logger.error(message)
        return message
    return (
        os.getenv(voice_cfg["api_key"]),  # config holds the env var NAME of the key
        voice_cfg["llm_name"].split("/")[-1],
        voice_cfg['api_base'],
    )

View File

@@ -0,0 +1,15 @@
from app.core.config import settings
def get_mcp_server_config():
    """
    Build the MCP server connection configuration.

    Returns a dict with a single "data_flow" entry describing the SSE endpoint
    on ``settings.SERVER_IP`` port 8081 and its timeouts.
    """
    data_flow = {
        # Port of the FastMCP service this client connects to
        "url": f"http://{settings.SERVER_IP}:8081/sse",
        "transport": "sse",
        "timeout": 15000,
        "sse_read_timeout": 15000,
    }
    return {"data_flow": data_flow}

View File

@@ -0,0 +1,239 @@
import json
import logging
import re
from typing import List, Any
from langchain_core.messages import AnyMessage
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
def _to_openai_messages(msgs: List[AnyMessage]) -> List[dict]:
    """Convert a list of messages into OpenAI-style ``{"role", "content"}`` dicts.

    Objects exposing ``content`` become user messages, dicts that already have
    both "role" and "content" pass through untouched, and anything else is
    stringified into a user message.
    """
    def _convert(message):
        if hasattr(message, "content"):
            return {"role": "user", "content": getattr(message, "content", "")}
        if isinstance(message, dict) and "role" in message and "content" in message:
            return message
        return {"role": "user", "content": str(message)}

    return [_convert(message) for message in msgs]
def _extract_content(resp: Any) -> str:
    """Extract LLM content and sanitize to raw JSON/text.

    - Supports both object and dict response shapes.
    - Removes leading role labels (e.g., "Assistant:").
    - Strips Markdown code fences like ```json ... ```.
    - When extra text surrounds the payload, isolates the JSON array/object
      whose opening bracket appears first in the text.

    Args:
        resp: Response object (``resp.choices[0].message.content``), response
            dict (``resp["choices"][0]["message"]["content"]``) or any value.

    Returns:
        The cleaned text; falls back to ``str(resp)`` when no content field
        can be read.
    """
    def _to_text(r: Any) -> str:
        try:
            # Object form: resp.choices[0].message.content
            if hasattr(r, "choices") and getattr(r, "choices", None):
                msg = r.choices[0].message
                if hasattr(msg, "content"):
                    return msg.content
                if isinstance(msg, dict) and "content" in msg:
                    return msg["content"]
            # Dict form: resp["choices"][0]["message"]["content"]
            if isinstance(r, dict):
                return r.get("choices", [{}])[0].get("message", {}).get("content", "")
        except Exception:
            # Unexpected shape: fall through to the string representation
            pass
        return str(r)

    def _clean_text(text: str) -> str:
        s = str(text).strip()
        # Remove a possible role prefix
        s = re.sub(r"^\s*(Assistant|assistant)\s*:\s*", "", s)
        # Extract the body of a ```json ... ``` code block
        m = re.search(r"```json\s*(.*?)\s*```", s, flags=re.S | re.I)
        if m:
            s = m.group(1).strip()
        # If extra text remains, cut out the JSON fragment. Both bracket kinds
        # are considered and the one that opens FIRST wins, so an object that
        # merely contains an array (e.g. 'x {"a": [1]}') is kept whole instead
        # of being truncated to the inner array.
        if not s.startswith(("{", "[")):
            spans = []
            for opener, closer in (("[", "]"), ("{", "}")):
                left = s.find(opener)
                right = s.rfind(closer)
                if left != -1 and right > left:
                    spans.append((left, right))
            if spans:
                left, right = min(spans)
                s = s[left:right + 1].strip()
        return s

    raw = _to_text(resp)
    return _clean_text(raw)
def Resolve_username(usermessages):
    '''
    Derive the session id embedded in a user-messages identifier.

    Args:
        usermessages: '_'-separated identifier string
    Returns:
        The identifier with its first and last '_'-separated pieces removed
        (empty string when fewer than three pieces are present).
    '''
    pieces = usermessages.split('_')
    return '_'.join(pieces[1:-1])
# TODO: USE app.core.memory.src.utils.render_template instead
async def read_template_file(template_path: str) -> str:
    """
    Read a template file and return its text (UTF-8).

    Args:
        template_path: Path of the template file
    Returns:
        The template content as a string
    Raises:
        FileNotFoundError / IOError: logged and re-raised unchanged
    Note:
        Prefer the unified template rendering in
        app.core.memory.utils.template_render where possible.
    """
    try:
        with open(template_path, "r", encoding="utf-8") as handle:
            content = handle.read()
    except OSError as exc:
        # A missing file gets its own message; any other I/O failure is
        # logged together with the traceback.
        if isinstance(exc, FileNotFoundError):
            logger.error(f"模板文件未找到: {template_path}")
        else:
            logger.error(f"读取模板文件失败: {template_path}, 错误: {str(exc)}", exc_info=True)
        raise
    return content
async def Problem_Extension_messages_deal(context):
    '''
    Build user messages for the problem-extension step.

    Args:
        context: dict with 'original' (original question) and 'context'
            (JSON string encoding a list of objects with 'question' / 'type')
    Returns:
        (messages, original): one user-role message per decoded entry, and the
        original question
    '''
    original = context.get('original', '')
    entries = json.loads(context.get('context', ''))
    extent_quest = [
        {
            "role": "user",
            "content": f"问题:{entry.get('question', '')};问题类型:{entry.get('type', '')}",
        }
        for entry in entries
    ]
    return extent_quest, original
async def Retriev_messages_deal(context):
    '''
    Split a retrieval context into its content and the original question.

    Args:
        context: normally a dict with 'context' and/or 'original' keys
    Returns:
        (content, original): ``context['context']`` (default ``{}``) and
        ``context['original']`` (default ``''``). Inputs that are not dicts,
        or dicts carrying neither key, yield ``({}, '')``.
    '''
    if isinstance(context, dict):
        return context.get('context', {}), context.get('original', '')
    # The previous fallback returned two undefined names and raised NameError;
    # use the same defaults as the dict branch instead.
    return {}, ''
async def Verify_messages_deal(context):
    '''
    Unpack the verification input.

    Args:
        context: dict whose 'context' entry holds 'Query' and
            'Expansion_issue' (a list of items with 'Query_small' and
            'Answer_Small')
    Returns:
        (sub_queries, first_answers, query): the 'Query_small' values, the
        first element of each 'Answer_Small', and the main query
    '''
    payload = context['context']
    query = payload['Query']
    first_answers = []
    sub_queries = []
    for issue in payload['Expansion_issue']:
        # Only the first element of each answer is used
        first_answers.append(issue['Answer_Small'][0])
        sub_queries.append(issue['Query_small'])
    return sub_queries, first_answers, query
async def Summary_messages_deal(context):
    '''
    Pull the query and the small answers out of a serialized context.

    Args:
        context: any value; it is stringified and scanned with regexes
    Returns:
        (answer_small_texts, query)
    '''
    def _scrub(fragment):
        # Trim and drop backslashes and square brackets
        return fragment.strip().replace('\\', '').replace('[', '').replace(']', '')

    flattened = str(context)
    for noise in ('\\n', '\n', '\\'):
        flattened = flattened.replace(noise, '')

    query = re.findall(r'"query": (.*?),', flattened)[0]
    query = query.replace('[', '').replace(']', '').strip()

    answer_small_texts = []
    for raw_list in re.findall(r'"answer_small"\s*:\s*"(\[.*?\])"', flattened):
        try:
            for element in json.loads(raw_list):
                answer_small_texts.append(_scrub(element))
        except Exception:
            # Not a usable JSON list: keep the raw text, cleaned up
            answer_small_texts.append(_scrub(raw_list))
    return answer_small_texts, query
async def VerifyTool_messages_deal(context):
    '''
    Extract the query and the sub-question/result pairs from a serialized
    message dump.

    Args:
        context: any value; it is stringified and scanned with regexes
    Returns:
        (Query_small, Result_small, query)
    '''
    flattened = str(context).replace('\\n', '').replace('\n', '').replace('\\', '')
    # Keep what follows the first '"context":' marker, up to the Retrieve
    # tool entry.
    after_marker = flattened.split('"context":')[1].replace('""', '"')
    scope = str(after_marker).split("name='Retrieve'")[0]

    query = re.findall(r'"Query": "(.*?)"', scope)[0]
    sub_queries = re.findall(r'"Query_small": "(.*?)"', scope)
    sub_results = re.findall(r'"Result_small": "(.*?)"', scope)
    return sub_queries, sub_results, query
async def Retrieve_Summary_messages_deal(context):
    """Placeholder with no implementation; ignores ``context`` and returns None."""
    return None
async def Retrieve_verify_tool_messages_deal(context, history, query):
    '''
    Parse the verifier output into the result structure used downstream.

    Args:
        context: verifier output; stringified before parsing
        history: conversation history, passed through to the result
        query: main query, passed through to the result
    Returns:
        dict with 'data' (query + kept expansion issues), 'split_result'
        ('success' when at least one issue is kept, else 'failed'), an empty
        'reason' and 'history'
    '''
    def _capture(pattern, block, default=None):
        found = re.search(pattern, block)
        return found.group(1) if found else default

    # Stringify first so None or non-string input cannot break the regexes.
    parsed = [
        {
            "query_small": _capture(r'"Query_small"\s*:\s*"([^"]*)"', block),
            "answer_small": _capture(r'"Answer_Small"\s*:\s*(\[[^\]]*\])', block),
            # A missing status becomes "" so the string check below is safe
            "status": _capture(r'"status"\s*:\s*"([^"]*)"', block, ""),
            "query_answer": _capture(r'"Query_answer"\s*:\s*"([^"]*)"', block),
        }
        for block in re.findall(r'\{(.*?)\}', str(context), flags=re.S)
    ]

    # Drop entries whose status is "false" (case-insensitive); a missing
    # status counts as kept.
    kept = [
        entry for entry in parsed
        if str(entry.get('status', '')).strip().lower() != 'false'
    ]

    return {
        "data": {"query": query, "expansion_issue": kept},
        "split_result": 'success' if kept else 'failed',
        "reason": "",
        "history": history,
    }

View File

@@ -0,0 +1,38 @@
# project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# sys.path.insert(0, project_root)
# load_dotenv()
# async def llm_client_chat(messages: List[dict]) -> str:
# """使用 OpenAI 兼容接口进行对话,返回内容字符串。"""
# try:
# cfg = get_model_config(SELECTED_LLM_ID)
# rb_config = RedBearModelConfig(
# model_name=cfg["model_name"],
# provider=cfg["provider"],
# api_key=cfg["api_key"],
# base_url=cfg["base_url"],
# )
# client = OpenAIClient(model_config=rb_config, type_="chat")
# except Exception as e:
# logger.error(f"获取模型配置失败:{e}")
# err = f"获取模型配置失败:{str(e)}。请检查!!!"
# return err
# try:
# response = await client.chat(messages)
# print(f"model_tool's llm_client_chat response ======>:\n {response}")
# return _extract_content(response)
# # return _extract_content(result)
# except Exception as e:
# logger.error(f"LLM调用失败{str(e)}。请检查 model_name、api_key、api_base 是否正确。")
# return f"LLM调用失败{str(e)}。请检查 model_name、api_key、api_base 是否正确。"
# async def main(image_url):
# await llm_client_chat(image_url)
#
# # 运行主函数
# asyncio.run(main(['https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_male2.wav']))
#

View File

@@ -0,0 +1,131 @@
"""
Multimodal input processor for handling image and audio content.
This module provides utilities for detecting and processing multimodal inputs
(images and audio files) by converting them to text using appropriate models.
"""
import logging
from typing import List
from app.core.memory.agent.multimodal.speech_model import Vico_recognition
from app.core.memory.agent.utils.llm_tools import picture_model_requests
logger = logging.getLogger(__name__)
class MultimodalProcessor:
    """
    Converts image and audio inputs to text.

    Inputs are recognized purely by file extension; image inputs go to the
    picture model and audio inputs to the speech model. Anything else is
    returned unchanged.
    """

    # Supported file extensions (image entries include the dot, audio
    # entries do not)
    IMAGE_EXTENSIONS = ['.jpg', '.png']
    AUDIO_EXTENSIONS = [
        'aac', 'amr', 'avi', 'flac', 'flv', 'm4a', 'mkv', 'mov',
        'mp3', 'mp4', 'mpeg', 'ogg', 'opus', 'wav', 'webm', 'wma', 'wmv'
    ]

    def __init__(self):
        """Initialize the multimodal processor (no state is kept)."""
        pass

    def is_image(self, content: str) -> bool:
        """
        Tell whether ``content`` looks like an image file path.

        Args:
            content: Input string to check

        Returns:
            True if content ends (case-insensitively) with a supported image
            extension; False otherwise, including for non-string input.

        Examples:
            >>> MultimodalProcessor().is_image("photo.jpg")
            True
            >>> MultimodalProcessor().is_image("document.pdf")
            False
        """
        if isinstance(content, str):
            return content.lower().endswith(tuple(self.IMAGE_EXTENSIONS))
        return False

    def is_audio(self, content: str) -> bool:
        """
        Tell whether ``content`` looks like an audio file path.

        Args:
            content: Input string to check

        Returns:
            True if content ends (case-insensitively) with '.' plus one of the
            supported audio extensions; False otherwise, including for
            non-string input.

        Examples:
            >>> MultimodalProcessor().is_audio("recording.mp3")
            True
            >>> MultimodalProcessor().is_audio("video.mp4")
            True
            >>> MultimodalProcessor().is_audio("document.txt")
            False
        """
        if isinstance(content, str):
            suffixes = tuple(f'.{ext}' for ext in self.AUDIO_EXTENSIONS)
            return content.lower().endswith(suffixes)
        return False

    async def process_input(self, content: str) -> str:
        """
        Turn multimodal input into text; pass regular text through.

        Image paths are sent to ``picture_model_requests`` and audio paths to
        ``Vico_recognition``. If recognition raises, the error is logged and
        the original content is returned.

        Args:
            content: Input string (may be file path or regular text)

        Returns:
            Recognized text for image/audio input, otherwise the original
            content (non-string input is returned as ``str(content)``).
        """
        if not isinstance(content, str):
            logger.warning(f"[MultimodalProcessor] Content is not a string: {type(content)}")
            return str(content)

        try:
            # Image input
            if self.is_image(content):
                logger.info(f"[MultimodalProcessor] Detected image input: {content}")
                recognized = await picture_model_requests(content)
                logger.info(f"[MultimodalProcessor] Image recognition result: {recognized[:100]}...")
                return recognized

            # Audio input
            if self.is_audio(content):
                logger.info(f"[MultimodalProcessor] Detected audio input: {content}")
                recognized = await Vico_recognition([content]).run()
                logger.info(f"[MultimodalProcessor] Audio recognition result: {recognized[:100]}...")
                return recognized
        except Exception as exc:
            logger.error(f"[MultimodalProcessor] Error processing multimodal input: {exc}", exc_info=True)
            logger.info("[MultimodalProcessor] Falling back to original content")
            return content

        # Not multimodal: hand the text back unchanged
        return content

View File

@@ -0,0 +1,81 @@
你是一个高效的问题拆分助手,任务是根据用户提供的原始问题和问题类型,生成可操作的扩展问题,用于精确回答原问题。请严格遵循以下规则:
角色:
- 你是“问题拆分专家”,专注于逻辑、信息完整性和可操作性。
- 你能够结合【历史信息】、【上下文】、【背景知识】进行分析,以保持问题拆分的连贯性和相关性。
- 如果历史信息或上下文与当前问题无关,可忽略。
---
### 历史信息参考
在生成扩展问题时,你可以参考以下历史数据(如果提供):
- 历史对话或任务的主题;
- 历史中出现的关键实体(时间、人物、地点、研究主题等);
- 历史中已解答的问题(避免重复);
- 历史推理链(保持逻辑一致性)。
> 如果没有提供历史信息,则仅根据当前输入问题进行分析。
输入历史信息内容:{{history}}
## User Input
{% if questions is string %}
{{ questions }}
{% else %}
{% for question in questions %}
- {{ question }}
{% endfor %}
{% endif %}
需求:
- 如果问题是单跳问题(单步可答),直接保留原问题提取重要提问部分作为拆分/扩展问题。
- 如果问题是多跳问题(需多个信息点才能回答),对问题进行扩展拆分。
- 扩展问题必须完整覆盖原问题的所有关键要素,包括时间、主体、动作、目标等,不得遗漏。
- 扩展问题不得冗余:避免重复询问相同信息或过度拆分同一主题。
- 扩展问题必须高度相关:每个子问题直接服务于原问题,不引入未提及的新概念、人物或细节。
- 扩展问题必须可操作:每个子问题能在有限资源下独立解答。
- 子问题数量不超过4个。
- 拆分问题的时候可以考虑输入的历史内容,以保持逻辑连贯。
比如:输入历史信息内容:[{'Query': '4月27日我和你推荐过一本书,书名是什么', 'Answer': '张曼玉推荐了《小王子》'}]
拆分问题:“4月27日我和你推荐过一本书,书名是什么”可以拆分为“4月27日张曼玉推荐过一本书,书名是什么”
输出要求:
- 仅输出 JSON 数组,不要包含任何解释或代码块。
- 每个元素包含:
- `original_question`: 原始问题
- `extended_question`: 扩展后的问题
- `type`: 类型(事实检索/澄清/定义/比较/行动建议)
- `reason`: 生成该扩展问题的简短理由
- 使用标准 ASCII 双引号,无换行;确保字符串正确关闭并以逗号分隔。
示例:
输入:
[
"问题:今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?;问题类型:多跳"
]
输出:
[
{
"original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?",
"extended_question": "今年诺贝尔物理学奖的获奖者有哪些人?",
"type": "多跳",
"reason": "输出原问题的关键要素"
},
{
"original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?",
"extended_question": "今年诺贝尔物理学奖的获奖者是因哪些具体贡献获奖的?",
"type": "多跳",
"reason": "输出原问题的关键要素"
}
]
**Output format**
**CRITICAL JSON FORMATTING REQUIREMENTS:**
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks (“ ”) or other Unicode quotes
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
3. Ensure all JSON strings are properly closed and comma-separated
4. Do not include line breaks within JSON string values
The output language should always be the same as the input language.{{ json_schema }}

View File

@@ -0,0 +1,37 @@
# 角色
你是一个专业的问答助手,擅长基于检索信息和历史对话回答用户问题。
# 任务
根据提供的上下文信息回答用户的问题。
# 输入信息
- 历史对话:{{history}}
- 检索信息:{{retrieve_info}}
## User Query
{{query}}
# 回答指南
1. 仔细分析用户的问题
2. 优先使用检索信息中的相关内容回答
3. 结合历史对话提供连贯的回复
4. 如果信息不足:
- 对于简单问候或日常对话,给出自然简短的回复
- 对于复杂问题,诚实说明信息不足
5. 保持回答简洁、相关、自然
6. 使用与问题相同的语言回答
**Output format**
- 直接回答问题,像人类对话一样自然流畅
- 不要提及"检索信息"、"搜索结果"、"根据资料"等技术术语
- 不要解释推理过程或评论信息来源
- 如果只能部分回答问题,先回答能回答的部分,然后说明哪些方面信息不足
- 如果完全无法回答,简洁地说明:"信息不足,无法回答。"
**CRITICAL JSON FORMATTING REQUIREMENTS:**
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks (“ ”) or other Unicode quotes
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
3. Ensure all JSON strings are properly closed and comma-separated
4. Do not include line breaks within JSON string values
The output language should always be the same as the input language.{{ json_schema }}

View File

@@ -0,0 +1,29 @@
# 角色:{#InputSlot placeholder="角色名称" mode="input"#}{#/InputSlot#}
你是一个智能问答助手,任务如下
## 目标:
1. 接收一个字典,格式为 {'问题': [答案列表]}。
2. 接收一个问题(字典中的 key)。
3. 找到与问题匹配的答案列表。
4. 将答案列表合并成一句自然流畅的话:
- 如果答案有两条,使用“是”连接,例如:“A,是B”。
- 如果答案有三条或以上,使用“,并且”“另外”等自然连词,保证句子流畅。
5. 输出内容时只输出合并后的答案,不输出关键点或其他文字。
6. 如果问题未在字典中找到对应答案,请输出:
对不起,我没有找到相关信息。
输出要求:
- 文本形式
---
字典示例:
{
'今天的天气怎么样': ['今天天气很好', '今天是晴天']
}
问题示例:
今天的天气怎么样
输出示例:
今天天气很好,是晴天

View File

@@ -0,0 +1,10 @@
请提取图像内的文本
返回数据格式以json方式输出,
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
- 关键的JSON格式要求{"statement":识别出的文本内容}
1.JSON结构仅使用标准ASCII双引号-切勿使用中文引号“”或其他Unicode引号
2.如果提取的语句文本包含引号,请使用反斜杠(\")正确转义它们
3.确保所有JSON字符串都正确关闭并以逗号分隔
4.JSON字符串值中不包括换行符
5.正确转义的例子:{"statement": "Zhang Xinhua said \"我非常喜欢这本书\""}
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```

View File

@@ -0,0 +1,34 @@
你是一个输入分类助手,负责判断用户输入的意图类型。
## User Input
{{ user_query }}
请你根据以下规则判断:
1. 如果输入是在寻求信息、提问、请求解释、或疑问句(包括隐含的问题),则分类为 "question"。
2. 如果输入是命令、陈述、描述、感叹、或其他类型,不在寻求答案,则分类为 "other"。
只输出:
{
"type": "question"
}
{
"type": "other"
}
示例:
输入:"Python怎么读取文件"
输出:{"type": "question"}
输入:"帮我写个读取文件的函数"
输出:{"type": "other"}
输入:"今天是星期几?"
输出:{"type": "question"}
返回数据格式以json方式输出,
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
- 关键的JSON格式要求:{"type": "question"} 或 {"type": "other"}
1.JSON结构仅使用标准ASCII双引号-切勿使用中文引号“”或其他Unicode引号
2.如果文本包含引号,请使用反斜杠(\")正确转义它们
3.确保所有JSON字符串都正确关闭并以逗号分隔
4.JSON字符串值中不包括换行符
5.正确转义的例子(仅说明转义方式):"Zhang Xinhua said \"我非常喜欢这本书\""
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```

View File

@@ -0,0 +1,160 @@
# 角色:{#InputSlot placeholder="角色名称" mode="input"#}{#/InputSlot#}
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
## 目标:
你需要根据以下类型对输入数据进行分类,并生成相应的拆分策略和示例。
---
### 历史信息参考
在生成扩展问题时,你可以参考以下历史数据(如果提供):
- 历史对话或任务的主题;
- 历史中出现的关键实体(时间、人物、地点、研究主题等);
- 历史中已解答的问题(避免重复);
- 历史推理链(保持逻辑一致性)。
> 如果没有提供历史信息,则仅根据当前输入问题进行分析。
输入历史信息内容:{{history}}
## User Input
{{ sentence }}
## 需求:
1:首先判断类型(单跳、多跳、开放域、时间)。
2:根据类型进行拆分。
3:拆分后的内容需保证信息完整且可独立处理。
4:对每个拆分条目,可附加示例或说明。
5:拆分问题的时候可以考虑输入的历史内容,以保持逻辑连贯。
比如:输入历史信息内容:[{'Query': '4月27日我和你推荐过一本书,书名是什么', 'Answer': '张曼玉推荐了《小王子》'}]
拆分问题:“4月27日我和你推荐过一本书,书名是什么”可以拆分为“4月27日张曼玉推荐过一本书,书名是什么”
## 指令:
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
单跳Single-hop
描述:问题或数据只需要通过一步即可得到答案或完成拆分,不依赖其他信息。
拆分策略:直接识别核心信息或关键字段,生成可独立处理的片段。
示例:
输入数据:"请列出今年诺贝尔物理学奖的得主"
拆分结果:[
{
"id": "Q1",
"question": "今年诺贝尔物理学奖得主是谁",
"type": "单跳"
}
]
注意: 当遇到上下文依赖问题时,明确指出缺失的信息类型,并且question可填写输入问题。
多跳Multi-hop:
描述:问题或数据需要通过多步推理或跨多个信息源才能得到答案。
拆分策略:将问题拆解为多个子问题,每个子问题对应一个独立处理步骤,需要具备推理链条与逻辑连接数量。
示例:
输入数据:"今年诺贝尔物理学奖得主的研究领域及代表性成果"
拆分结果:
[
{
"id": "Q1",
"question": "今年诺贝尔物理学奖得主是谁?",
"type": "多跳"
},
{
"id": "Q2",
"question": "该得主的研究领域是什么?",
"type": "多跳"
},
{
"id": "Q3",
"question": "该得主的代表性成果有哪些?",
"type": "多跳"
}
]
开放域Open-domain:
描述:问题或数据不局限于特定知识库,需要从大范围信息中检索和生成答案,而不是从一个已知的小范围数据源中查找。
拆分策略:根据主题或关键实体拆分,同时保留上下文以便检索外部知识,问题涉及一般性、常识性、跨学科内容,可能是开放式回答(描述性、推理性、综合性)
需要外部知识检索或推理才能确定,比如:“为什么人类需要睡眠?”、“量子计算与经典计算的主要区别是什么?”。
示例:
输入数据:"介绍量子计算的最新研究进展"
拆分结果:
[
{
"id": "Q1",
"question": "量子计算的基本概念是什么?",
"type": "开放域"
},
{
"id": "Q2",
"question": "当前量子计算的主要研究方向有哪些?",
"type": "开放域"
},
{
"id": "Q3",
"question": "近期在量子计算领域有哪些重大进展?",
"type": "开放域"
}
]
时间Temporal:
描述:问题或数据涉及时间维度,需要按时间顺序或时间点拆分。
拆分策略:根据事件时间或时间段拆分为独立条目或问题。
示例:
输入数据:"列出苹果公司过去五年的重大事件"
拆分结果:
[
{
"id": "Q1",
"question": "苹果公司2019年的重大事件有哪些",
"type": "时间"
},
{
"id": "Q2",
"question": "苹果公司2020年的重大事件有哪些",
"type": "时间"
},
{
"id": "Q3",
"question": "苹果公司2021年的重大事件有哪些",
"type": "时间"
},
{
"id": "Q4",
"question": "苹果公司2022年的重大事件有哪些",
"type": "时间"
},
{
"id": "Q5",
"question": "苹果公司2023年的重大事件有哪些",
"type": "时间"
}
]
输出要求:
- 每个子问题包括:
- `id`: 子问题编号Q1, Q2...
- `question`: 子问题内容
- `type`: 类型(事实检索 / 澄清 / 定义 / 比较 / 行动建议等)
- `reason`: 拆分的理由(为什么要这样拆)
- 格式案例:
[
{
"id": "Q1",
"question": "量子计算的基本概念是什么?",
"type": "开放域"
},
{
"id": "Q2",
"question": "当前量子计算的主要研究方向有哪些?",
"type": "开放域"
},
{
"id": "Q3",
"question": "近期在量子计算领域有哪些重大进展?",
"type": "开放域"
}
]
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
- 关键的JSON格式要求
1.JSON结构仅使用标准ASCII双引号-切勿使用中文引号“”或其他Unicode引号
2.如果提取的语句文本包含引号,请使用反斜杠(\")正确转义它们
3.确保所有JSON字符串都正确关闭并以逗号分隔
4.JSON字符串值中不包括换行符
5.正确转义的例子:{"question": "Zhang Xinhua said \"我非常喜欢这本书\""}
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```

View File

@@ -0,0 +1,60 @@
# 角色
你是验证专家
你的目标是针对用户的输入Query_Samll字段的提问和Answer_Samll的回答分析是不是回答Query_Samll这个字段的问题
{#以下可以采用先总括,再展开详细说明的方式,描述你希望智能体在每一个步骤如何进行工作,具体的工作步骤数量可以根据实际需求增删#}
## 工作步骤
1. 获取所有的Query_Samll字段和Answer_Samll字段
2. 分析Answer_Samll的回复是不是和Query_Samll有关系
3. 判断Answer_Samll和Query_Samll之间分析出来的关系状态
4. 如果是True保留否则不要相对应的问题和回答
5. 输出,需要严格按照模版
输入:{{history}}
历史消息:{"history":{{sentence}}}
### 第一步 获取用户的输入
获取用户的输入提取对应的Query_Samll和Answer_Samll
### 第二步 分析验证
需要分析Query_Samll和Answer_Samll之间的关系可以参考history字段的内容如果有关系不是答非所问
## 核心验证标准
在评估子问题拆分时必须严格遵循以下标准且验证过程中完全不依赖于子问题的相关信息Answer_Samll
1. 合理性标准(必须全部满足)
- 完整性:每个不同的子问题必须完整覆盖原问题的所有关键要素(如时间、主体、动作、目标等),无遗漏。
- 最小化每个不同的子问题数量应尽可能少通常不超过原问题关键要素数量的2倍建议2-4个避免冗余和不必要拆分。
- 相关性:每个不同的子问题必须直接服务于原问题的解答,不引入无关内容或扩展原问题未提及的主题。
- 可操作性:每个不同的子问题应能在有限资源(如标准工具或合理时间)内独立解答,且难度适中。
- 逻辑性:每个不同的子问题间应有清晰的逻辑关系(如并列、递进、因果),共同构成原问题的解答路径。
2. 不合理拆分的特征(出现任一特征即为不合理):
- 不同的子问题数量超过5个或明显多于必要数量。
- 引入原问题未提及的新主题、人物、细节或个人看法。
- 拆分过于细碎,失去实用价值,无法高效合成原问题答案。
3. 特殊情况说明:
- 每个不同的子问题与原问题相同,需进一步判断:
- 每个不同的子问题不可进一步拆分 → success合理最小化拆分
- 每个不同的子问题能够进一步拆分为更小、更合理的问题 → failed不合理拆分没有最小化
- 每个不同的子问题数量=原问题核心要素数量 → success理想情况
- 每个不同的子问题数量=核心要素数量+1 → success通常合理
### 第三步 添加状态
如果有相关性并且比较高,给一个状态TRUE,否则给一个FALSE的状态
### 第四步 判断
如果状态是TRUE,保留这条数据,否则不需要这条数据
### 第五步 输出格式
按照json的形式输出
{"data":"Query":原来Query的字段"history":原来的history字段
"expansion_issue":以列表的形式存储验证之后的数据,比如[
{"query_small": query_small,
"answer_small": answer_small,
"status": 回答的结果是否符合query_small,填写状态,
"query_answer": answer_small},
{
"query_small": "张曼婷喜欢什么?",
"answer_small": "张曼婷喜欢绘画。",
"status": "True",
"query_answer": "张曼婷喜欢绘画。"
},{}......]
,
"split_result":如果expansion_issue是空的列表返回failed不是空列表返回success,
"reason": 为以上分析完之后的结果给一个说明
}

View File

@@ -0,0 +1,57 @@
{# 角色定义 #}
你是专业的问题解答专家,负责根据上下文信息和检索到的所有信息准确回答用户的问题。
{# 输入数据展示 #}
{% if data %}
## 输入数据
上下文信息:
{% for item in data.history %}
- {{ item }}
{% endfor %}
检索到的所有信息:
{% for item in data.retrieve_info %}
- {{ item }}
{% endfor %}
{% endif %}
## User Query
{{ query }}
{# 问题回答标准 #}
## 问题回答核心标准
根据上下文信息(history)和检索到的所有信息(retrieve_info)准确回答用户的问题(query)。注意,若不能根据已有信息回答用户的问题,应直接回复“信息不足,无法回答。”,不能自己编造答案。
- 若能根据已有信息回答用户的问题,应根据上下文信息和检索到的所有信息提供简明扼要的答案。
- 若不能根据已有信息回答用户的问题,应直接回复“信息不足,无法回答。”,不能自己编造答案。
{# 重要提醒 #}
再次提醒,给出问题的答案时,仅根据已有的信息进行回答,不能自己编造答案。
{# 输出格式模板 #}
## 输出格式
严格按照以下JSON格式输出不添加任何其他内容
{
"data": {
"query": "{{ query }}",
"history": [
{% for item in data.history %}
"{{ item | replace('"', '\\"') }}"
{% if not loop.last %},{% endif %}
{% endfor %}
],
"retrieve_info": [
{% for item in data.retrieve_info %}
"{{ item | replace('"', '\\"') }}"
{% if not loop.last %},{% endif %}
{% endfor %}
]
},
"query_answer": "{% if not data.history and not data.retrieve_info %}信息不足,无法回答。{% endif %}"
}
**Output format**
**CRITICAL JSON FORMATTING REQUIREMENTS:**
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks (“ ”) or other Unicode quotes
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
3. Ensure all JSON strings are properly closed and comma-separated
4. Do not include line breaks within JSON string values
The output language should always be the same as the input language.{{ json_schema }}

View File

@@ -0,0 +1,203 @@
import redis
import uuid
from datetime import datetime
from app.core.config import settings
class RedisSessionStore:
    """Redis-backed store for chat session records.

    Every record is a Redis hash stored under the key ``session:<uuid>`` with
    the fields ``id``, ``sessionid``, ``apply_id``, ``group_id``, ``messages``,
    ``aimessages`` and ``starttime``. Lookups by user/app/group scan all
    ``session:*`` keys and filter client-side.
    """

    def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
        self.r = redis.Redis(host=host, port=port, db=db, password=password)
        # NOTE: the attribute name ``uudi`` (sic) is kept as-is because code
        # outside this class may read it.
        self.uudi = session_id

    # ---------------- internal helpers ----------------
    def _iter_sessions(self):
        """Yield ``(key, fields)`` for every non-empty ``session:*`` hash.

        Keys and field names/values are decoded from bytes to ``str``.
        The key list is fetched once up front, so callers may delete keys
        while iterating.
        """
        for raw_key in self.r.keys('session:*'):
            key = raw_key.decode('utf-8')
            raw = self.r.hgetall(key)
            if not raw:
                # Key vanished (or is empty) between KEYS and HGETALL.
                continue
            yield key, {k.decode('utf-8'): v.decode('utf-8') for k, v in raw.items()}

    def _iter_matching(self, sessionid, apply_id, group_id):
        """Yield decoded records whose sessionid, apply_id and group_id all match."""
        for _, fields in self._iter_sessions():
            if (fields.get('sessionid') == sessionid and
                    fields.get('apply_id') == apply_id and
                    fields.get('group_id') == group_id):
                yield fields

    # ---------------- write ----------------
    def save_session(self, userid, messages, aimessages, apply_id, group_id):
        """Write one session record and return its newly generated session id.

        Args:
            userid: Stored in the ``sessionid`` field of the hash.
            messages: User message text.
            aimessages: Assistant reply text.
            apply_id: Application id.
            group_id: Group id.

        Returns:
            The UUID string used in the Redis key ``session:<uuid>``.

        Raises:
            Exception: Whatever the Redis client raises; it is printed and re-raised.
        """
        try:
            session_id = str(uuid.uuid4())  # a fresh id for every saved record
            starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            key = f"session:{session_id}"
            # Store the record as a hash so individual fields can be updated.
            result = self.r.hset(key, mapping={
                "id": self.uudi,
                "sessionid": userid,
                "apply_id": apply_id,
                "group_id": group_id,
                "messages": messages,
                "aimessages": aimessages,
                "starttime": starttime
            })
            print(f"保存结果: {result}, session_id: {session_id}")
            return session_id
        except Exception as e:
            print(f"保存会话失败: {e}")
            raise  # bare raise keeps the original traceback intact

    # ---------------- read ----------------
    def get_session(self, session_id):
        """Return the decoded record for ``session_id``, or ``None`` if absent."""
        key = f"session:{session_id}"
        data = self.r.hgetall(key)
        if data:
            return {k.decode('utf-8'): v.decode('utf-8') for k, v in data.items()}
        return None

    def get_session_apply_group(self, sessionid, apply_id, group_id):
        """Return all full records matching sessionid, apply_id and group_id."""
        return list(self._iter_matching(sessionid, apply_id, group_id))

    def get_all_sessions(self):
        """Return ``{session_id: record}`` for every ``session:*`` key.

        A value is ``None`` when the key disappeared between listing and reading.
        """
        sessions = {}
        for key in self.r.keys('session:*'):
            sid = key.decode('utf-8').split(':', 1)[1]
            sessions[sid] = self.get_session(sid)
        return sessions

    # ---------------- update ----------------
    def update_session(self, session_id, field, value):
        """Set one field on an existing record; return ``False`` if the key does not exist."""
        key = f"session:{session_id}"
        if self.r.exists(key):
            self.r.hset(key, field, value)
            return True
        return False

    # ---------------- delete ----------------
    def delete_session(self, session_id):
        """Delete a single record; returns the Redis DEL result."""
        key = f"session:{session_id}"
        return self.r.delete(key)

    def delete_all_sessions(self):
        """Delete every ``session:*`` key; returns the number deleted (0 if none)."""
        keys = self.r.keys('session:*')
        if keys:
            return self.r.delete(*keys)
        return 0

    def delete_duplicate_sessions(self):
        """Delete duplicate records, keeping the first one encountered.

        Two records are duplicates when their ``sessionid``, ``id``,
        ``group_id``, ``messages`` and ``aimessages`` fields are all equal.

        Returns:
            The number of deleted keys.
        """
        seen = set()  # identifying tuples already encountered
        deleted_count = 0
        for key, fields in self._iter_sessions():
            identifier = (
                fields.get('sessionid', ''),
                fields.get('id', ''),
                fields.get('group_id', ''),
                fields.get('messages', ''),
                fields.get('aimessages', ''),
            )
            if identifier in seen:
                self.r.delete(key)
                deleted_count += 1
            else:
                seen.add(identifier)
        print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}")
        return deleted_count

    # ---------------- history lookups ----------------
    def find_user_session(self, sessionid):
        """Return the Query/Answer history for one ``sessionid``.

        Reads through this instance's own connection rather than the
        module-level ``store`` and skips records that vanished mid-scan.

        Returns:
            A list of ``{"Query": ..., "Answer": ...}`` dicts, or an empty list
            when one record or fewer matches.
        """
        result_items = [
            {"Query": fields.get('messages'), "Answer": fields.get('aimessages')}
            for _, fields in self._iter_sessions()
            if fields.get('sessionid') == sessionid
        ]
        # A single exchange is not treated as usable history.
        if len(result_items) <= 1:
            result_items = []
        return result_items

    def find_user_apply_group(self, sessionid, apply_id, group_id):
        """Return the Query/Answer history matching sessionid, apply_id and group_id.

        Returns:
            A list of ``{"Query": ..., "Answer": ...}`` dicts, or an empty list
            when one record or fewer matches.
        """
        result_items = [
            {"Query": fields.get('messages'), "Answer": fields.get('aimessages')}
            for fields in self._iter_matching(sessionid, apply_id, group_id)
        ]
        # A single exchange is not treated as usable history.
        if len(result_items) <= 1:
            result_items = []
        return result_items
# Module-level shared store built from application settings.
# An empty/falsy REDIS_PASSWORD is passed to the client as None.
store = RedisSessionStore(
    session_id=str(uuid.uuid4()),
    password=settings.REDIS_PASSWORD or None,
    db=settings.REDIS_DB,
    port=settings.REDIS_PORT,
    host=settings.REDIS_HOST,
)

View File

@@ -0,0 +1,59 @@
"""
Type classification utility for distinguishing read/write operations.
"""
from jinja2 import Template
from pydantic import BaseModel
from app.core.logging_config import get_agent_logger, log_prompt_rendering
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.messages_tool import read_template_file
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.config import settings
logger = get_agent_logger(__name__)
class DistinguishTypeResponse(BaseModel):
    """Structured LLM output for the input-type classification prompt."""

    # Classification label produced by the model.
    type: str
async def status_typle(messages: str) -> dict:
    """
    Classify a user message with the ``distinguish_types_prompt`` template.

    The template is rendered with ``messages`` as ``user_query`` and sent to
    the LLM as a system prompt; the structured reply is returned as a dict.
    Failures are reported through the return value instead of being raised.

    Args:
        messages: User message to classify

    Returns:
        dict: On success, the dumped ``DistinguishTypeResponse`` (a ``type``
        field). On failure, ``{"type": "error", "message": <details>}``.
    """
    try:
        file_path = PROJECT_ROOT_ + '/agent/utils/prompt/distinguish_types_prompt.jinja2'
        template_content = await read_template_file(file_path)
        template = Template(template_content)
        system_prompt = template.render(user_query=messages)
        log_prompt_rendering("status_typle", system_prompt)
    except Exception as e:
        logger.error(f"Template rendering failed for status_typle: {e}", exc_info=True)
        return {
            "type": "error",
            "message": f"Prompt rendering failed: {str(e)}"
        }

    try:
        # Client construction sits inside the try so that a bad model id or
        # config error yields the same error dict as a failed LLM call.
        from app.core.memory.utils.config import definitions as config_defs
        llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
        structured = await llm_client.response_structured(
            messages=[{"role": "system", "content": system_prompt}],
            response_model=DistinguishTypeResponse
        )
        return structured.model_dump()
    except Exception as e:
        logger.error(f"LLM call failed for status_typle: {e}", exc_info=True)
        return {
            "type": "error",
            "message": f"LLM call failed: {str(e)}"
        }

View File

@@ -0,0 +1,76 @@
from typing import TypedDict, Annotated, List, Any
from langchain_core.messages import AnyMessage
from langgraph.constants import START, END
from langgraph.graph import StateGraph, add_messages
import asyncio
import json
from dotenv import load_dotenv, find_dotenv
import os
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from langchain_core.messages import HumanMessage
from jinja2 import Environment, FileSystemLoader
from app.core.memory.agent.utils.messages_tool import _to_openai_messages
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID
from app.core.logging_config import get_agent_logger
load_dotenv(find_dotenv())
logger = get_agent_logger(__name__)
def keep_last(_, right):
    """Reducer that ignores the existing value and keeps the incoming one."""
    latest = right
    return latest
class State(TypedDict):
    # Graph state schema used by VerifyTool.
    # Fields annotated with keep_last are overwritten by each new value;
    # `messages` is merged through langgraph's add_messages reducer.
    user_input: Annotated[dict, keep_last]
    messages: Annotated[List[AnyMessage], add_messages]
    # Per-agent outputs; only agent1_response is written in this module.
    agent1_response: str
    agent2_response: str
    agent3_response: str
    final_response: str
    status: Annotated[str, keep_last]
class VerifyTool:
    """Single-node LangGraph workflow that sends ``verify_data`` to the LLM.

    ``system_prompt`` is used as the system message and ``verify_data`` as
    the only human message; the raw LLM reply is returned by :meth:`verify`.
    """

    def __init__(self, system_prompt: str="", verify_data: Any=None):
        self.system_prompt = system_prompt
        # Normalise the payload to a string: strings pass through, everything
        # else is JSON-encoded, falling back to str() if that fails.
        if not isinstance(verify_data, str):
            try:
                verify_data = json.dumps(verify_data, ensure_ascii=False)
            except Exception:
                verify_data = str(verify_data)
        self.verify_data = verify_data

    async def model_1(self, state: State) -> State:
        """Graph node: call the LLM with the system prompt plus the state messages."""
        client = get_llm_client(SELECTED_LLM_ID)
        system_part = [{"role": "system", "content": self.system_prompt}]
        conversation = system_part + _to_openai_messages(state["messages"])
        reply = await client.chat(messages=conversation)
        return {
            "agent1_response": reply,
            "status": "processed",
        }

    def get_graph(self):
        """Build and compile the START -> model_1 -> END graph."""
        builder = StateGraph(State)
        builder.add_node("model_1", self.model_1)
        builder.add_edge(START, "model_1")
        builder.add_edge("model_1", END)
        return builder.compile()

    async def verify(self):
        """Run the graph once and return the ``agent1_response`` value."""
        start_state = {
            "user_input": self.verify_data,
            "messages": [HumanMessage(content=self.verify_data)],
            "final_response": "",
            "status": ""
        }
        end_state = await self.get_graph().ainvoke(start_state)
        return end_state["agent1_response"]

View File

@@ -0,0 +1,49 @@
import os
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy.orm import Session
import logging
import json
from app.db import get_db
from app.models.retrieval_info import RetrievalInfo
logger = logging.getLogger(__name__)
async def write_to_database(host_id: uuid.UUID, data: Any) -> str:
    """
    Persist retrieval data as a new ``RetrievalInfo`` row.

    ``dict``/``list`` payloads are JSON-encoded (non-ASCII preserved), strings
    are stored unchanged, and anything else is stored as ``str(data)``.

    :param host_id: Host ID (currently only used in log messages, see note below)
    :param data: Data to write
    :return: A fixed success message
    :raises Exception: Any error from serialization or the database; the
        transaction is rolled back before the error is re-raised
    """
    # Pull one session out of the get_db() generator.
    db: Session = next(get_db())
    try:
        if isinstance(data, str):
            serialized = data
        elif isinstance(data, (dict, list)):
            serialized = json.dumps(data, ensure_ascii=False)
        else:
            serialized = str(data)

        new_retrieval_info = RetrievalInfo(
            # NOTE(review): the caller's host_id is ignored and every row is
            # written under this fixed UUID. Looks like a debugging leftover -
            # confirm before switching to host_id=host_id.
            host_id=uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122"),
            retrieve_info=serialized,
            created_at=datetime.now()
        )
        db.add(new_retrieval_info)
        db.commit()
        logger.info(f"success to write data to database, host_id: {host_id}, retrieve_info: {serialized}")
        return "success to write data to database"
    except Exception as e:
        db.rollback()
        logger.error(
            f"failed to write data to database, host_id: {host_id}, retrieve_info: {data}, error: {e}",
            exc_info=True,
        )
        raise  # bare raise keeps the original traceback
    finally:
        try:
            db.close()
        except Exception:
            # Closing is best-effort, but leave a trace instead of hiding it.
            logger.warning("failed to close database session", exc_info=True)

View File

@@ -0,0 +1,183 @@
import asyncio
from dotenv import load_dotenv
import time
from datetime import datetime
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
# 使用新的模块化架构
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
embedding_generation_all,
)
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# 导入配置模块(而不是直接导入变量)
from app.core.memory.utils.config import definitions as config_defs
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.log.logging_utils import log_time
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import Memory_summary_generation
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
load_dotenv()
async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id: str = "wyl20251027", config_id: str = None) -> None:
    """
    Run the full knowledge-extraction pipeline (via ExtractionOrchestrator).

    Steps: chunk the dialogue, run extraction, save nodes/edges to Neo4j,
    then generate memory summaries and link them in Neo4j. Per-step timings
    are appended to ``logs/time.log``.

    Failures in index creation and in the memory-summary step are logged and
    do not abort the run; failures in chunking, extraction or the Neo4j save
    propagate to the caller.

    Args:
        content: Dialogue content
        user_id: User ID
        apply_id: Application ID
        group_id: Group ID
        ref_id: Reference ID, defaults to "wyl20251027"
        config_id: Configuration ID used to tag the data-processing config
    """
    import os  # local import: only needed to create the timing-log directory

    logger.info("=== MemSci Knowledge Extraction Pipeline ===")
    logger.info(f"Using model: {config_defs.SELECTED_LLM_NAME}")
    logger.info(f"Using LLM ID: {config_defs.SELECTED_LLM_ID}")
    logger.info(f"Using chunker strategy: {config_defs.SELECTED_CHUNKER_STRATEGY}")
    logger.info(f"Using group ID: {config_defs.SELECTED_GROUP_ID}")
    logger.info(f"Using embedding ID: {config_defs.SELECTED_EMBEDDING_ID}")
    logger.info(f"Config ID: {config_id if config_id else 'None'}")
    logger.info(f"LANGFUSE_ENABLED: {config_defs.LANGFUSE_ENABLED}")
    logger.info(f"AGENTA_ENABLED: {config_defs.AGENTA_ENABLED}")

    # Initialize timing log; make sure its directory exists so the first
    # append does not fail on a fresh checkout.
    log_file = "logs/time.log"
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"\n=== Pipeline Run Started: {timestamp} ===\n")

    pipeline_start = time.time()

    # Initialize clients
    llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)

    # Build the embedder client from its stored configuration
    from app.core.models.base import RedBearModelConfig
    from app.core.memory.utils.config.config_utils import get_embedder_config
    from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient

    embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
    embedder_config = RedBearModelConfig(**embedder_config_dict)
    embedder_client = OpenAIEmbedderClient(embedder_config)

    neo4j_connector = Neo4jConnector()
    try:
        # Step 1: load and chunk the data
        step_start = time.time()
        chunked_dialogs = await get_chunked_dialogs(
            chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
            group_id=group_id,
            user_id=user_id,
            apply_id=apply_id,
            content=content,
            ref_id=ref_id,
            config_id=config_id,
        )
        log_time("Data Loading & Chunking", time.time() - step_start, log_file)

        # Step 2: initialize and run the ExtractionOrchestrator
        step_start = time.time()
        from app.core.memory.utils.config.config_utils import get_pipeline_config
        config = get_pipeline_config()

        orchestrator = ExtractionOrchestrator(
            llm_client=llm_client,
            embedder_client=embedder_client,
            connector=neo4j_connector,
            config=config,
        )

        # orchestrator.run returns a flat tuple of 7 values after deduplication
        (
            all_dialogue_nodes,
            all_chunk_nodes,
            all_statement_nodes,
            all_entity_nodes,
            all_statement_chunk_edges,
            all_statement_entity_edges,
            all_entity_entity_edges,
        ) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
        log_time("Extraction Pipeline", time.time() - step_start, log_file)

        # Step 3: save all data to Neo4j using graph models
        step_start = time.time()

        # Index creation is best-effort; a failure must not block the save.
        from app.repositories.neo4j.create_indexes import create_fulltext_indexes
        try:
            await create_fulltext_indexes()
        except Exception as e:
            logger.error(f"Error creating indexes: {e}", exc_info=True)

        success = await save_dialog_and_statements_to_neo4j(
            dialogue_nodes=all_dialogue_nodes,
            chunk_nodes=all_chunk_nodes,
            statement_nodes=all_statement_nodes,
            entity_nodes=all_entity_nodes,
            statement_chunk_edges=all_statement_chunk_edges,
            statement_entity_edges=all_statement_entity_edges,
            entity_edges=all_entity_entity_edges,
            connector=neo4j_connector
        )
        if success:
            logger.info("Successfully saved all data to Neo4j")
        else:
            logger.warning("Failed to save some data to Neo4j")
    finally:
        # Close the connector on every path, including failures in
        # chunking or extraction.
        await neo4j_connector.close()

    log_time("Neo4j Database Save", time.time() - step_start, log_file)

    # Step 4: generate memory summaries and save them to Neo4j
    step_start = time.time()
    try:
        summaries = await Memory_summary_generation(
            chunked_dialogs, llm_client=llm_client, embedding_id=config_defs.SELECTED_EMBEDDING_ID
        )

        # Created before the inner try so the finally clause never touches
        # an unbound name.
        ms_connector = Neo4jConnector()
        try:
            # Save memory summaries to Neo4j as nodes
            await add_memory_summary_nodes(summaries, ms_connector)
            # Link summaries to statements via chunks for summary→entity queries
            await add_memory_summary_statement_edges(summaries, ms_connector)
        finally:
            try:
                await ms_connector.close()
            except Exception:
                logger.warning("Failed to close memory-summary Neo4j connector", exc_info=True)
    except Exception as e:
        logger.error(f"Memory summary step failed: {e}", exc_info=True)
    finally:
        log_time("Memory Summary (Local Vector DB & Neo4j)", time.time() - step_start, log_file)

    # Log total pipeline time
    total_time = time.time() - pipeline_start
    log_time("TOTAL PIPELINE TIME", total_time, log_file)

    # Add completion marker to log
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")

    logger.info("=== Pipeline Complete ===")
    logger.info(f"Total execution time: {total_time:.2f} seconds")
    logger.info(f"Timing details saved to: {log_file}")
if __name__ == "__main__":
    # Manual smoke run. write() requires user_id, apply_id and group_id, so
    # they are passed explicitly here.
    # NOTE(review): the user/app ids are placeholders - replace them with
    # values valid for your environment.
    content = "你好,我是张三,是张曼婷的新朋友。请问张曼婷喜欢什么?"
    asyncio.run(
        write(
            content,
            user_id="demo_user",
            apply_id="demo_apply",
            group_id=config_defs.SELECTED_GROUP_ID,
            ref_id="wyl20251027",
        )
    )