Initial commit
This commit is contained in:
0
app/core/memory/agent/__init__.py
Normal file
0
app/core/memory/agent/__init__.py
Normal file
16
app/core/memory/agent/langgraph_graph/__init__.py
Normal file
16
app/core/memory/agent/langgraph_graph/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
LangGraph Graph package for memory agent.
|
||||
|
||||
This package provides the LangGraph workflow orchestrator with modular
|
||||
node implementations, routing logic, and state management.
|
||||
|
||||
Package structure:
|
||||
- read_graph: Main graph factory for read operations
|
||||
- write_graph: Main graph factory for write operations
|
||||
- nodes: LangGraph node implementations
|
||||
- routing: State routing logic
|
||||
- state: State management utilities
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||||
|
||||
__all__ = ['make_read_graph']
|
||||
10
app/core/memory/agent/langgraph_graph/nodes/__init__.py
Normal file
10
app/core/memory/agent/langgraph_graph/nodes/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
LangGraph node implementations.
|
||||
|
||||
This module contains custom node implementations for the LangGraph workflow.
|
||||
"""
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode
|
||||
from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message
|
||||
|
||||
__all__ = ["ToolExecutionNode", "create_input_message"]
|
||||
144
app/core/memory/agent/langgraph_graph/nodes/input_node.py
Normal file
144
app/core/memory/agent/langgraph_graph/nodes/input_node.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
Input node for LangGraph workflow entry point.
|
||||
|
||||
This module provides the create_input_message function which processes initial
|
||||
user input with multimodal support and creates the first tool call message.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def create_input_message(
|
||||
state: Dict[str, Any],
|
||||
tool_name: str,
|
||||
session_id: str,
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
multimodal_processor: MultimodalProcessor
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create initial tool call message from user input.
|
||||
|
||||
This function:
|
||||
1. Extracts the last message content from state
|
||||
2. Processes multimodal inputs (images/audio) using the multimodal processor
|
||||
3. Generates a unique message ID
|
||||
4. Extracts namespace from session_id
|
||||
5. Handles verified_data extraction for backward compatibility
|
||||
6. Returns AIMessage with complete tool_calls structure
|
||||
|
||||
Args:
|
||||
state: LangGraph state dictionary containing messages
|
||||
tool_name: Name of the tool to invoke (typically "Split_The_Problem")
|
||||
session_id: Session identifier (format: "call_id_{namespace}")
|
||||
search_switch: Search routing parameter
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
multimodal_processor: Processor for handling image/audio inputs
|
||||
|
||||
Returns:
|
||||
State update with AIMessage containing tool_call
|
||||
|
||||
Examples:
|
||||
>>> state = {"messages": [HumanMessage(content="What is AI?")]}
|
||||
>>> result = await create_input_message(
|
||||
... state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor
|
||||
... )
|
||||
>>> result["messages"][0].tool_calls[0]["name"]
|
||||
'Split_The_Problem'
|
||||
"""
|
||||
messages = state.get("messages", [])
|
||||
|
||||
# Extract last message content
|
||||
if messages:
|
||||
last_message = messages[-1].content if hasattr(messages[-1], 'content') else str(messages[-1])
|
||||
else:
|
||||
logger.warning("[create_input_message] No messages in state, using empty string")
|
||||
last_message = ""
|
||||
|
||||
logger.debug(f"[create_input_message] Original input: {last_message[:100]}...")
|
||||
|
||||
# Process multimodal input (images/audio)
|
||||
try:
|
||||
processed_content = await multimodal_processor.process_input(last_message)
|
||||
if processed_content != last_message:
|
||||
logger.info(
|
||||
f"[create_input_message] Multimodal processing converted input "
|
||||
f"from {len(last_message)} to {len(processed_content)} chars"
|
||||
)
|
||||
last_message = processed_content
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[create_input_message] Multimodal processing failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Continue with original content
|
||||
|
||||
# Generate unique message ID
|
||||
uuid_str = uuid.uuid4()
|
||||
time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
# Extract namespace from session_id
|
||||
# Expected format: "call_id_{namespace}" or similar
|
||||
try:
|
||||
namespace = str(session_id).split('_id_')[1]
|
||||
except (IndexError, AttributeError):
|
||||
logger.warning(
|
||||
f"[create_input_message] Could not extract namespace from session_id: {session_id}"
|
||||
)
|
||||
namespace = "unknown"
|
||||
|
||||
# Handle verified_data extraction (backward compatibility)
|
||||
# This regex-based extraction is kept for compatibility with existing data formats
|
||||
if 'verified_data' in str(last_message):
|
||||
try:
|
||||
messages_last = str(last_message).replace('\\n', '').replace('\\', '')
|
||||
query_match = re.findall(r'"query": "(.*?)",', messages_last)
|
||||
if query_match:
|
||||
last_message = query_match[0]
|
||||
logger.debug(
|
||||
f"[create_input_message] Extracted query from verified_data: {last_message}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[create_input_message] Failed to extract query from verified_data: {e}"
|
||||
)
|
||||
|
||||
# Construct tool call message
|
||||
tool_call_id = f"{session_id}_{uuid_str}"
|
||||
|
||||
logger.info(
|
||||
f"[create_input_message] Creating tool call for '{tool_name}' "
|
||||
f"with ID: {tool_call_id}"
|
||||
)
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{
|
||||
"name": tool_name,
|
||||
"args": {
|
||||
"sentence": last_message,
|
||||
"sessionid": session_id,
|
||||
"messages_id": str(uuid_str),
|
||||
"search_switch": search_switch,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id
|
||||
},
|
||||
"id": tool_call_id
|
||||
}]
|
||||
)
|
||||
]
|
||||
}
|
||||
199
app/core/memory/agent/langgraph_graph/nodes/tool_node.py
Normal file
199
app/core/memory/agent/langgraph_graph/nodes/tool_node.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""
|
||||
Tool execution node for LangGraph workflow.
|
||||
|
||||
This module provides the ToolExecutionNode class which wraps tool execution
|
||||
with parameter transformation logic using the ParameterBuilder service.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.state.extractors import (
|
||||
extract_tool_call_id,
|
||||
extract_content_payload
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolExecutionNode:
|
||||
"""
|
||||
Custom LangGraph node that wraps tool execution with parameter transformation.
|
||||
|
||||
This node extracts content from previous tool results, transforms parameters
|
||||
based on tool type using ParameterBuilder, and invokes the tool with the
|
||||
correct argument structure.
|
||||
|
||||
Attributes:
|
||||
tool_node: LangGraph ToolNode wrapping the actual tool
|
||||
id: Node identifier for message IDs
|
||||
tool_name: Name of the tool being executed
|
||||
namespace: Namespace for session management
|
||||
search_switch: Search routing parameter
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
parameter_builder: Service for building tool-specific arguments
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool: Callable,
|
||||
node_id: str,
|
||||
namespace: str,
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
parameter_builder: ParameterBuilder,
|
||||
storage_type:str,
|
||||
user_rag_memory_id:str
|
||||
):
|
||||
"""
|
||||
Initialize the tool execution node.
|
||||
|
||||
Args:
|
||||
tool: The tool function to execute
|
||||
node_id: Identifier for this node (used in message IDs)
|
||||
namespace: Namespace for session management
|
||||
search_switch: Search routing parameter
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
parameter_builder: Service for building tool-specific arguments
|
||||
"""
|
||||
self.tool_node = ToolNode([tool])
|
||||
self.id = node_id
|
||||
self.tool_name = tool.name if hasattr(tool, 'name') else str(tool)
|
||||
self.namespace = namespace
|
||||
self.search_switch = search_switch
|
||||
self.apply_id = apply_id
|
||||
self.group_id = group_id
|
||||
self.parameter_builder = parameter_builder
|
||||
self.storage_type=storage_type
|
||||
self.user_rag_memory_id=user_rag_memory_id
|
||||
|
||||
logger.info(
|
||||
f"[ToolExecutionNode] Initialized node '{self.id}' for tool '{self.tool_name}'"
|
||||
)
|
||||
|
||||
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute the tool with transformed parameters.
|
||||
|
||||
This method:
|
||||
1. Extracts the last message from state
|
||||
2. Extracts tool call ID using state extractors
|
||||
3. Extracts content payload using state extractors
|
||||
4. Builds tool arguments using parameter builder
|
||||
5. Constructs AIMessage with tool_calls
|
||||
6. Invokes the tool and returns the result
|
||||
|
||||
Args:
|
||||
state: LangGraph state dictionary
|
||||
|
||||
Returns:
|
||||
Updated state with tool result in messages
|
||||
"""
|
||||
messages = state.get("messages", [])
|
||||
logger.debug( self.tool_name)
|
||||
|
||||
if not messages:
|
||||
logger.warning(f"[ToolExecutionNode] {self.id} - No messages in state")
|
||||
return {"messages": [AIMessage(content="Error: No messages in state")]}
|
||||
|
||||
last_message = messages[-1]
|
||||
logger.debug(
|
||||
f"[ToolExecutionNode] {self.id} - Processing message at {time.time()}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Extract tool call ID using state extractors
|
||||
tool_call_id = extract_tool_call_id(last_message)
|
||||
logger.debug(f"[ToolExecutionNode] {self.id} - Extracted tool_call_id: {tool_call_id}")
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"[ToolExecutionNode] {self.id} - Failed to extract tool call ID: {e}"
|
||||
)
|
||||
return {"messages": [AIMessage(content=f"Error: {str(e)}")]}
|
||||
|
||||
try:
|
||||
# Extract content payload using state extractors
|
||||
content = extract_content_payload(last_message)
|
||||
logger.debug(
|
||||
f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[ToolExecutionNode] {self.id} - Failed to extract content: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
content = {}
|
||||
|
||||
try:
|
||||
# Build tool arguments using parameter builder
|
||||
tool_args = self.parameter_builder.build_tool_args(
|
||||
tool_name=self.tool_name,
|
||||
content=content,
|
||||
tool_call_id=tool_call_id,
|
||||
search_switch=self.search_switch,
|
||||
apply_id=self.apply_id,
|
||||
group_id=self.group_id,
|
||||
storage_type=self.storage_type,
|
||||
user_rag_memory_id=self.user_rag_memory_id
|
||||
)
|
||||
logger.debug(
|
||||
f"[ToolExecutionNode] {self.id} - Built tool args with keys: {list(tool_args.keys())}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[ToolExecutionNode] {self.id} - Failed to build tool args: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {"messages": [AIMessage(content=f"Error building arguments: {str(e)}")]}
|
||||
|
||||
# Construct tool input message
|
||||
tool_input = {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{
|
||||
"name": self.tool_name,
|
||||
"args": tool_args,
|
||||
"id": f"{self.id}_{tool_call_id}",
|
||||
}]
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
try:
|
||||
# Invoke the tool
|
||||
result = await self.tool_node.ainvoke(tool_input)
|
||||
|
||||
logger.debug(
|
||||
f"[ToolExecutionNode] {self.id} - Tool execution completed"
|
||||
)
|
||||
|
||||
# Return the result directly - it already contains the messages list
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[ToolExecutionNode] {self.id} - Tool execution failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return error as ToolMessage to maintain message chain consistency
|
||||
from langchain_core.messages import ToolMessage
|
||||
return {
|
||||
"messages": [
|
||||
ToolMessage(
|
||||
content=f"Error executing tool: {str(e)}",
|
||||
tool_call_id=f"{self.id}_{tool_call_id}"
|
||||
)
|
||||
]
|
||||
}
|
||||
508
app/core/memory/agent/langgraph_graph/read_graph.py
Normal file
508
app/core/memory/agent/langgraph_graph/read_graph.py
Normal file
@@ -0,0 +1,508 @@
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.constants import START, END
|
||||
from langgraph.graph import StateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from functools import partial
|
||||
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
# Import new modular components
|
||||
from app.core.memory.agent.langgraph_graph.nodes import ToolExecutionNode, create_input_message
|
||||
from app.core.memory.agent.langgraph_graph.routing.routers import (
|
||||
Verify_continue,
|
||||
Retrieve_continue,
|
||||
Split_continue
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
|
||||
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
load_dotenv()
|
||||
redishost=os.getenv("REDISHOST")
|
||||
redisport=os.getenv('REDISPORT')
|
||||
redisdb=os.getenv('REDISDB')
|
||||
redispassword=os.getenv('REDISPASSWORD')
|
||||
counter = COUNTState(limit=3)
|
||||
|
||||
# 在工作流中添加循环计数更新
|
||||
async def update_loop_count(state):
|
||||
"""更新循环计数器"""
|
||||
current_count = state.get("loop_count", 0)
|
||||
return {"loop_count": current_count + 1}
|
||||
|
||||
|
||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||
messages = state["messages"]
|
||||
|
||||
# 添加边界检查
|
||||
if not messages:
|
||||
return END
|
||||
counter.add(1) # 累加 1
|
||||
|
||||
loop_count = counter.get_total()
|
||||
logger.debug(f"[should_continue] 当前循环次数: {loop_count}")
|
||||
|
||||
last_message = messages[-1]
|
||||
last_message_str = str(last_message).replace('\\', '')
|
||||
status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str)
|
||||
logger.debug(f"Status tools: {status_tools}")
|
||||
|
||||
if "success" in status_tools:
|
||||
counter.reset()
|
||||
return "Summary"
|
||||
elif "failed" in status_tools:
|
||||
if loop_count < 2: # 最大循环次数 3
|
||||
return "content_input"
|
||||
else:
|
||||
counter.reset()
|
||||
return "Summary_fails"
|
||||
else:
|
||||
# 添加默认返回值,避免返回 None
|
||||
counter.reset()
|
||||
return "Summary" # 或根据业务需求选择合适的默认值
|
||||
|
||||
|
||||
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
|
||||
Args:
|
||||
state: State dictionary containing search_switch
|
||||
|
||||
Returns:
|
||||
Next node to execute
|
||||
"""
|
||||
# Direct dictionary access instead of regex parsing
|
||||
search_switch = state.get("search_switch")
|
||||
|
||||
# Handle case where search_switch might be in messages
|
||||
if search_switch is None and "messages" in state:
|
||||
messages = state.get("messages", [])
|
||||
if messages:
|
||||
last_message = messages[-1]
|
||||
# Try to extract from tool_calls args
|
||||
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
|
||||
for tool_call in last_message.tool_calls:
|
||||
if isinstance(tool_call, dict) and "args" in tool_call:
|
||||
search_switch = tool_call["args"].get("search_switch")
|
||||
break
|
||||
|
||||
# Convert to string for comparison if needed
|
||||
if search_switch is not None:
|
||||
search_switch = str(search_switch)
|
||||
if search_switch == '0':
|
||||
return 'Verify'
|
||||
elif search_switch == '1':
|
||||
return 'Retrieve_Summary'
|
||||
|
||||
# 添加默认返回值,避免返回 None
|
||||
return 'Retrieve_Summary' # 或根据业务逻辑选择合适的默认值
|
||||
|
||||
|
||||
def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
|
||||
Args:
|
||||
state: State dictionary containing search_switch
|
||||
|
||||
Returns:
|
||||
Next node to execute
|
||||
"""
|
||||
logger.debug(f"Split_continue state: {state}")
|
||||
|
||||
# Direct dictionary access instead of regex parsing
|
||||
search_switch = state.get("search_switch")
|
||||
|
||||
# Handle case where search_switch might be in messages
|
||||
if search_switch is None and "messages" in state:
|
||||
messages = state.get("messages", [])
|
||||
if messages:
|
||||
last_message = messages[-1]
|
||||
# Try to extract from tool_calls args
|
||||
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
|
||||
for tool_call in last_message.tool_calls:
|
||||
if isinstance(tool_call, dict) and "args" in tool_call:
|
||||
search_switch = tool_call["args"].get("search_switch")
|
||||
break
|
||||
|
||||
# Convert to string for comparison if needed
|
||||
if search_switch is not None:
|
||||
search_switch = str(search_switch)
|
||||
if search_switch == '2':
|
||||
return 'Input_Summary'
|
||||
return 'Split_The_Problem' # 默认情况
|
||||
|
||||
# 在 input_sentence 函数中修改参数名称
|
||||
async def input_sentence(state, name, id, search_switch,apply_id,group_id):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1].content if messages else ""
|
||||
|
||||
if last_message.endswith('.jpg') or last_message.endswith('.png'):
|
||||
last_message=await picture_model_requests(last_message)
|
||||
if any(last_message.endswith(ext) for ext in audio_extensions):
|
||||
last_message=await Vico_recognition([last_message]).run()
|
||||
logger.debug(f"Audio recognition result: {last_message}")
|
||||
|
||||
|
||||
uuid_str = uuid.uuid4()
|
||||
time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
namespace = str(id).split('_id_')[1]
|
||||
if 'verified_data' in str(last_message):
|
||||
messages_last = str(last_message).replace('\\n', '').replace('\\', '')
|
||||
last_message = re.findall(r'"query": "(.*?)",', str(messages_last))[0]
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{
|
||||
"name": name,
|
||||
"args": {
|
||||
"sentence": last_message,
|
||||
'sessionid': id,
|
||||
'messages_id': str(uuid_str),
|
||||
"search_switch": search_switch, # 正确地将 search_switch 放入 args 中
|
||||
"apply_id":apply_id,
|
||||
"group_id":group_id
|
||||
},
|
||||
"id": id + f'_{uuid_str}'
|
||||
}]
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class ProblemExtensionNode:
|
||||
def __init__(self, tool, id, namespace, search_switch, apply_id, group_id, storage_type="", user_rag_memory_id=""):
|
||||
self.tool_node = ToolNode([tool])
|
||||
self.id = id
|
||||
self.tool_name = tool.name if hasattr(tool, 'name') else str(tool)
|
||||
self.namespace = namespace
|
||||
self.search_switch = search_switch
|
||||
self.apply_id = apply_id
|
||||
self.group_id = group_id
|
||||
self.storage_type = storage_type
|
||||
self.user_rag_memory_id = user_rag_memory_id
|
||||
|
||||
async def __call__(self, state):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1] if messages else ""
|
||||
logger.debug(f"ProblemExtensionNode {self.id} - 当前时间: {time.time()} - Message: {last_message}")
|
||||
if self.tool_name=='Input_Summary':
|
||||
tool_call =re.findall(f"'id': '(.*?)'",str(last_message))[0]
|
||||
else:tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1]
|
||||
# try:
|
||||
# content = json.loads(last_message.content) if hasattr(last_message, 'content') else last_message
|
||||
# except:
|
||||
# content = last_message.content if hasattr(last_message, 'content') else str(last_message)
|
||||
# 尝试从上一工具的结果中提取实际的内容载荷(而不是整个对象的字符串表示)
|
||||
raw_msg = last_message.content if hasattr(last_message, 'content') else str(last_message)
|
||||
extracted_payload = None
|
||||
# 捕获 ToolMessage 的 content 字段(支持单/双引号),并避免贪婪匹配
|
||||
m = re.search(r"content=(?:\"|\')(.*?)(?:\"|\'),\s*name=", raw_msg, flags=re.S)
|
||||
if m:
|
||||
extracted_payload = m.group(1)
|
||||
else:
|
||||
# 回退:直接尝试使用原始字符串
|
||||
extracted_payload = raw_msg
|
||||
|
||||
# 优先尝试将内容解析为 JSON
|
||||
try:
|
||||
content = json.loads(extracted_payload)
|
||||
except Exception:
|
||||
# 尝试从文本中提取 JSON 片段再解析
|
||||
parsed = None
|
||||
candidates = re.findall(r"[\[{].*[\]}]", extracted_payload, flags=re.S)
|
||||
for cand in candidates:
|
||||
try:
|
||||
parsed = json.loads(cand)
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
# 如果仍然失败,则以原始字符串作为内容
|
||||
content = parsed if parsed is not None else extracted_payload
|
||||
|
||||
# 根据工具名称构建正确的参数
|
||||
tool_args = {}
|
||||
|
||||
if self.tool_name == "Verify":
|
||||
# Verify工具需要context和usermessages参数
|
||||
if isinstance(content, dict):
|
||||
tool_args["context"] = content
|
||||
else:
|
||||
tool_args["context"] = {"content": content}
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name == "Retrieve":
|
||||
# Retrieve工具需要context和usermessages参数
|
||||
if isinstance(content, dict):
|
||||
tool_args["context"] = content
|
||||
else:
|
||||
tool_args["context"] = {"content": content}
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["search_switch"] = str(self.search_switch)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name == "Summary":
|
||||
# Summary工具需要字符串类型的context参数
|
||||
if isinstance(content, dict):
|
||||
# 将字典转换为JSON字符串
|
||||
tool_args["context"] = json.dumps(content, ensure_ascii=False)
|
||||
else:
|
||||
tool_args["context"] = str(content)
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name == "Summary_fails":
|
||||
# Summary工具需要字符串类型的context参数
|
||||
if isinstance(content, dict):
|
||||
# 将字典转换为JSON字符串
|
||||
tool_args["context"] = json.dumps(content, ensure_ascii=False)
|
||||
else:
|
||||
tool_args["context"] = str(content)
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name=='Input_Summary':
|
||||
tool_args["context"] =str(last_message)
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["search_switch"] = str(self.search_switch)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
tool_args["storage_type"] = getattr(self, 'storage_type', "")
|
||||
tool_args["user_rag_memory_id"] = getattr(self, 'user_rag_memory_id', "")
|
||||
elif self.tool_name=='Retrieve_Summary' :
|
||||
# Retrieve_Summary expects dict directly, not JSON string
|
||||
# content might be a JSON string, try to parse it
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
parsed_content = json.loads(content)
|
||||
# Check if it has a "context" key
|
||||
if isinstance(parsed_content, dict) and "context" in parsed_content:
|
||||
tool_args["context"] = parsed_content["context"]
|
||||
else:
|
||||
tool_args["context"] = parsed_content
|
||||
except json.JSONDecodeError:
|
||||
# If parsing fails, wrap the string
|
||||
tool_args["context"] = {"content": content}
|
||||
elif isinstance(content, dict):
|
||||
# Check if content has a "context" key that needs unwrapping
|
||||
if "context" in content:
|
||||
tool_args["context"] = content["context"]
|
||||
else:
|
||||
tool_args["context"] = content
|
||||
else:
|
||||
tool_args["context"] = {"content": str(content)}
|
||||
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
else:
|
||||
# 其他工具使用context参数
|
||||
if isinstance(content, dict):
|
||||
tool_args["context"] = content
|
||||
else:
|
||||
tool_args["context"] = {"content": content}
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
|
||||
|
||||
tool_input = {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{
|
||||
"name": self.tool_name,
|
||||
"args": tool_args,
|
||||
"id": self.id + f"{tool_call}",
|
||||
}]
|
||||
)
|
||||
]
|
||||
}
|
||||
result = await self.tool_node.ainvoke(tool_input)
|
||||
result_text = str(result)
|
||||
|
||||
return {"messages": [AIMessage(content=result_text)]}
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config_id=None,storage_type=None,user_rag_memory_id=None):
|
||||
memory = InMemorySaver()
|
||||
tool=[i.name for i in tools ]
|
||||
logger.info(f"Initializing read graph with tools: {tool}")
|
||||
if config_id:
|
||||
logger.info(f"使用配置 ID: {config_id}")
|
||||
|
||||
# Extract tool functions
|
||||
Split_The_Problem_ = next((t for t in tools if t.name == "Split_The_Problem"), None)
|
||||
Problem_Extension_ = next((t for t in tools if t.name == "Problem_Extension"), None)
|
||||
Retrieve_ = next((t for t in tools if t.name == "Retrieve"), None)
|
||||
Verify_ = next((t for t in tools if t.name == "Verify"), None)
|
||||
Summary_ = next((t for t in tools if t.name == "Summary"), None)
|
||||
Summary_fails_ = next((t for t in tools if t.name == "Summary_fails"), None)
|
||||
Retrieve_Summary_ = next((t for t in tools if t.name == "Retrieve_Summary"), None)
|
||||
Input_Summary_ = next((t for t in tools if t.name == "Input_Summary"), None)
|
||||
|
||||
# Instantiate services
|
||||
parameter_builder = ParameterBuilder()
|
||||
multimodal_processor = MultimodalProcessor()
|
||||
|
||||
# Create nodes using new modular components
|
||||
Split_The_Problem_node = ToolNode([Split_The_Problem_])
|
||||
|
||||
Problem_Extension_node = ToolExecutionNode(
|
||||
tool=Problem_Extension_,
|
||||
node_id="Problem_Extension_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
Retrieve_node = ToolExecutionNode(
|
||||
tool=Retrieve_,
|
||||
node_id="Retrieve_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
Verify_node = ToolExecutionNode(
|
||||
tool=Verify_,
|
||||
node_id="Verify_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
Summary_node = ToolExecutionNode(
|
||||
tool=Summary_,
|
||||
node_id="Summary_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
Summary_fails_node = ToolExecutionNode(
|
||||
tool=Summary_fails_,
|
||||
node_id="Summary_fails_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
Retrieve_Summary_node = ToolExecutionNode(
|
||||
tool=Retrieve_Summary_,
|
||||
node_id="Retrieve_Summary_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
Input_Summary_node = ToolExecutionNode(
|
||||
tool=Input_Summary_,
|
||||
node_id="Input_Summary_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
|
||||
async def content_input_node(state):
|
||||
state_search_switch = state.get("search_switch", search_switch)
|
||||
|
||||
tool_name = "Input_Summary" if state_search_switch == '2' else "Split_The_Problem"
|
||||
session_prefix = "input_summary_call_id" if state_search_switch == '2' else "split_call_id"
|
||||
|
||||
return await create_input_message(
|
||||
state=state,
|
||||
tool_name=tool_name,
|
||||
session_id=f"{session_prefix}_{namespace}",
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
multimodal_processor=multimodal_processor
|
||||
)
|
||||
|
||||
|
||||
# Build workflow graph
|
||||
workflow = StateGraph(ReadState)
|
||||
workflow.add_node("content_input", content_input_node)
|
||||
workflow.add_node("Split_The_Problem", Split_The_Problem_node)
|
||||
workflow.add_node("Problem_Extension", Problem_Extension_node)
|
||||
workflow.add_node("Retrieve", Retrieve_node)
|
||||
workflow.add_node("Verify", Verify_node)
|
||||
workflow.add_node("Summary", Summary_node)
|
||||
workflow.add_node("Summary_fails", Summary_fails_node)
|
||||
workflow.add_node("Retrieve_Summary", Retrieve_Summary_node)
|
||||
workflow.add_node("Input_Summary", Input_Summary_node)
|
||||
|
||||
# Add edges using imported routers
|
||||
workflow.add_edge(START, "content_input")
|
||||
workflow.add_conditional_edges("content_input", Split_continue)
|
||||
workflow.add_edge("Input_Summary", END)
|
||||
workflow.add_edge("Split_The_Problem", "Problem_Extension")
|
||||
workflow.add_edge("Problem_Extension", "Retrieve")
|
||||
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
||||
workflow.add_edge("Retrieve_Summary", END)
|
||||
workflow.add_conditional_edges("Verify", Verify_continue)
|
||||
workflow.add_edge("Summary_fails", END)
|
||||
workflow.add_edge("Summary", END)
|
||||
|
||||
graph = workflow.compile(checkpointer=memory)
|
||||
yield graph
|
||||
|
||||
|
||||
# 添加到文件末尾或创建新的执行脚本
|
||||
# 在 memory_agent_service.py 文件中添加以下函数
|
||||
|
||||
13
app/core/memory/agent/langgraph_graph/routing/__init__.py
Normal file
13
app/core/memory/agent/langgraph_graph/routing/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""LangGraph routing logic."""
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.routing.routers import (
|
||||
Verify_continue,
|
||||
Retrieve_continue,
|
||||
Split_continue,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Verify_continue",
|
||||
"Retrieve_continue",
|
||||
"Split_continue",
|
||||
]
|
||||
123
app/core/memory/agent/langgraph_graph/routing/routers.py
Normal file
123
app/core/memory/agent/langgraph_graph/routing/routers.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
Routing functions for LangGraph conditional edges.
|
||||
|
||||
This module provides routing functions that determine the next node to execute
|
||||
based on state values. All functions return Literal types for type safety.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.state.extractors import extract_search_switch
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global counter for Verify routing
|
||||
counter = COUNTState(limit=3)
|
||||
|
||||
|
||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||
"""
|
||||
Determine routing after Verify node based on verification result.
|
||||
|
||||
This function checks the verification result in the last message and routes to:
|
||||
- Summary: if verification succeeded
|
||||
- content_input: if verification failed and retry limit not reached
|
||||
- Summary_fails: if verification failed and retry limit reached
|
||||
|
||||
Args:
|
||||
state: LangGraph state containing messages
|
||||
|
||||
Returns:
|
||||
Next node name as Literal type
|
||||
"""
|
||||
messages = state.get("messages", [])
|
||||
|
||||
# Boundary check
|
||||
if not messages:
|
||||
logger.warning("[Verify_continue] No messages in state, defaulting to Summary")
|
||||
counter.reset()
|
||||
return "Summary"
|
||||
|
||||
# Increment counter
|
||||
counter.add(1)
|
||||
loop_count = counter.get_total()
|
||||
logger.debug(f"[Verify_continue] Current loop count: {loop_count}")
|
||||
|
||||
# Extract verification result from last message
|
||||
last_message = messages[-1]
|
||||
last_message_str = str(last_message).replace('\\', '')
|
||||
status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str)
|
||||
logger.debug(f"[Verify_continue] Status tools: {status_tools}")
|
||||
|
||||
# Route based on verification result
|
||||
if "success" in status_tools:
|
||||
counter.reset()
|
||||
return "Summary"
|
||||
elif "failed" in status_tools:
|
||||
if loop_count < 2: # Max retry count is 2
|
||||
return "content_input"
|
||||
else:
|
||||
counter.reset()
|
||||
return "Summary_fails"
|
||||
else:
|
||||
# Default to Summary if status is unclear
|
||||
counter.reset()
|
||||
return "Summary"
|
||||
|
||||
|
||||
def Retrieve_continue(state: dict) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
"""
|
||||
Determine routing after Retrieve node based on search_switch value.
|
||||
|
||||
This function routes based on the search_switch parameter:
|
||||
- search_switch == '0': Route to Verify (verification needed)
|
||||
- search_switch == '1': Route to Retrieve_Summary (direct summary)
|
||||
|
||||
Args:
|
||||
state: LangGraph state dictionary
|
||||
|
||||
Returns:
|
||||
Next node name as Literal type
|
||||
"""
|
||||
search_switch = extract_search_switch(state)
|
||||
|
||||
logger.debug(f"[Retrieve_continue] search_switch: {search_switch}")
|
||||
|
||||
if search_switch == '0':
|
||||
return 'Verify'
|
||||
elif search_switch == '1':
|
||||
return 'Retrieve_Summary'
|
||||
|
||||
# Default to Retrieve_Summary
|
||||
logger.debug("[Retrieve_continue] No valid search_switch, defaulting to Retrieve_Summary")
|
||||
return 'Retrieve_Summary'
|
||||
|
||||
|
||||
def Split_continue(state: dict) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
"""
|
||||
Determine routing after content_input node based on search_switch value.
|
||||
|
||||
This function routes based on the search_switch parameter:
|
||||
- search_switch == '2': Route to Input_Summary (direct input summary)
|
||||
- Otherwise: Route to Split_The_Problem (problem decomposition)
|
||||
|
||||
Args:
|
||||
state: LangGraph state dictionary
|
||||
|
||||
Returns:
|
||||
Next node name as Literal type
|
||||
"""
|
||||
logger.debug(f"[Split_continue] state keys: {state.keys()}")
|
||||
|
||||
search_switch = extract_search_switch(state)
|
||||
|
||||
logger.debug(f"[Split_continue] search_switch: {search_switch}")
|
||||
|
||||
if search_switch == '2':
|
||||
return 'Input_Summary'
|
||||
|
||||
# Default to Split_The_Problem
|
||||
return 'Split_The_Problem'
|
||||
13
app/core/memory/agent/langgraph_graph/state/__init__.py
Normal file
13
app/core/memory/agent/langgraph_graph/state/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""LangGraph state management utilities."""
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.state.extractors import (
|
||||
extract_search_switch,
|
||||
extract_tool_call_id,
|
||||
extract_content_payload,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"extract_search_switch",
|
||||
"extract_tool_call_id",
|
||||
"extract_content_payload",
|
||||
]
|
||||
164
app/core/memory/agent/langgraph_graph/state/extractors.py
Normal file
164
app/core/memory/agent/langgraph_graph/state/extractors.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
State extraction utilities for type-safe access to LangGraph state values.
|
||||
|
||||
This module provides utility functions for extracting values from LangGraph state
|
||||
dictionaries with proper error handling and sensible defaults.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def extract_search_switch(state: dict) -> Optional[str]:
|
||||
"""
|
||||
Extract search_switch from state or messages.
|
||||
"""
|
||||
|
||||
search_switch = state.get("search_switch")
|
||||
|
||||
if search_switch is not None:
|
||||
return str(search_switch)
|
||||
|
||||
# Try to extract from messages
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# 从最新的消息开始查找
|
||||
for message in reversed(messages):
|
||||
# 尝试从 tool_calls 中提取
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
# 从 tool_call 的 args 中提取
|
||||
if "args" in tool_call and isinstance(tool_call["args"], dict):
|
||||
search_switch = tool_call["args"].get("search_switch")
|
||||
if search_switch is not None:
|
||||
return str(search_switch)
|
||||
# 直接从 tool_call 中提取
|
||||
search_switch = tool_call.get("search_switch")
|
||||
if search_switch is not None:
|
||||
return str(search_switch)
|
||||
|
||||
# 尝试从 content 中提取(如果是 JSON 格式)
|
||||
if hasattr(message, "content"):
|
||||
try:
|
||||
import json
|
||||
if isinstance(message.content, str):
|
||||
content_data = json.loads(message.content)
|
||||
if isinstance(content_data, dict):
|
||||
search_switch = content_data.get("search_switch")
|
||||
if search_switch is not None:
|
||||
return str(search_switch)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_tool_call_id(message: Any) -> str:
|
||||
"""
|
||||
Extract tool call ID from message using structured attributes.
|
||||
|
||||
This function extracts the tool call ID from a message object, handling both
|
||||
direct attribute access and tool_calls list structures.
|
||||
|
||||
Args:
|
||||
message: Message object (typically ToolMessage or AIMessage)
|
||||
|
||||
Returns:
|
||||
Tool call ID as string
|
||||
|
||||
Raises:
|
||||
ValueError: If tool call ID cannot be extracted
|
||||
|
||||
Examples:
|
||||
>>> message = ToolMessage(content="...", tool_call_id="call_123")
|
||||
>>> extract_tool_call_id(message)
|
||||
'call_123'
|
||||
"""
|
||||
# Try direct attribute access for ToolMessage
|
||||
if hasattr(message, "tool_call_id"):
|
||||
tool_call_id = message.tool_call_id
|
||||
if tool_call_id:
|
||||
return str(tool_call_id)
|
||||
|
||||
# Try extracting from tool_calls list for AIMessage
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
tool_call = message.tool_calls[0]
|
||||
if isinstance(tool_call, dict) and "id" in tool_call:
|
||||
return str(tool_call["id"])
|
||||
|
||||
# Try extracting from id attribute
|
||||
if hasattr(message, "id"):
|
||||
message_id = message.id
|
||||
if message_id:
|
||||
return str(message_id)
|
||||
|
||||
# If all else fails, raise an error
|
||||
raise ValueError(f"Could not extract tool call ID from message: {type(message)}")
|
||||
|
||||
|
||||
def extract_content_payload(message: Any) -> Any:
|
||||
"""
|
||||
Extract content payload from ToolMessage, parsing JSON if needed.
|
||||
|
||||
This function extracts the content from a message and attempts to parse it as JSON
|
||||
if it appears to be a JSON string. It handles various message formats and provides
|
||||
sensible fallbacks.
|
||||
|
||||
Args:
|
||||
message: Message object (typically ToolMessage)
|
||||
|
||||
Returns:
|
||||
Parsed content (dict, list, or str)
|
||||
|
||||
Examples:
|
||||
>>> message = ToolMessage(content='{"key": "value"}')
|
||||
>>> extract_content_payload(message)
|
||||
{'key': 'value'}
|
||||
|
||||
>>> message = ToolMessage(content='plain text')
|
||||
>>> extract_content_payload(message)
|
||||
'plain text'
|
||||
"""
|
||||
# Extract raw content
|
||||
# For ToolMessages (responses from tools), extract from content
|
||||
if hasattr(message, "content"):
|
||||
raw_content = message.content
|
||||
|
||||
# If content is empty and this is an AIMessage with tool_calls,
|
||||
# extract from args (this handles the initial tool call from content_input)
|
||||
if not raw_content and hasattr(message, "tool_calls") and message.tool_calls:
|
||||
tool_call = message.tool_calls[0]
|
||||
if isinstance(tool_call, dict) and "args" in tool_call:
|
||||
return tool_call["args"]
|
||||
else:
|
||||
raw_content = str(message)
|
||||
|
||||
# If content is already a dict or list, return it directly
|
||||
if isinstance(raw_content, (dict, list)):
|
||||
return raw_content
|
||||
|
||||
# Try to parse as JSON
|
||||
if isinstance(raw_content, str):
|
||||
# First, try direct JSON parsing
|
||||
try:
|
||||
return json.loads(raw_content)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
# If that fails, try to extract JSON from the string
|
||||
# This handles cases where the content is embedded in a larger string
|
||||
import re
|
||||
json_candidates = re.findall(r'[\[{].*[\]}]', raw_content, flags=re.DOTALL)
|
||||
for candidate in json_candidates:
|
||||
try:
|
||||
return json.loads(candidate)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
continue
|
||||
|
||||
# If all parsing attempts fail, return the raw content
|
||||
return raw_content
|
||||
78
app/core/memory/agent/langgraph_graph/write_graph.py
Normal file
78
app/core/memory/agent/langgraph_graph/write_graph.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import asyncio
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
from langgraph.constants import START, END
|
||||
from langgraph.graph import add_messages, StateGraph
|
||||
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
import warnings
|
||||
import sys
|
||||
from langchain_core.messages import AIMessage
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
import asyncio
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
@asynccontextmanager
|
||||
async def make_write_graph(user_id, tools, apply_id, group_id, config_id=None):
|
||||
logger.info("加载 MCP 工具: %s", [t.name for t in tools])
|
||||
if config_id:
|
||||
logger.info(f"使用配置 ID: {config_id}")
|
||||
|
||||
data_type_tool = next((t for t in tools if t.name == "Data_type_differentiation"), None)
|
||||
data_write_tool = next((t for t in tools if t.name == "Data_write"), None)
|
||||
|
||||
if not data_type_tool or not data_write_tool:
|
||||
logger.error('不存在数据存储工具', exc_info=True)
|
||||
raise ValueError('不存在数据存储工具')
|
||||
# ToolNode
|
||||
write_node = ToolNode([data_write_tool])
|
||||
|
||||
|
||||
async def call_model(state):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
|
||||
result = await data_type_tool.ainvoke({
|
||||
"context": last_message[1] if isinstance(last_message, tuple) else last_message.content
|
||||
})
|
||||
result=json.loads( result)
|
||||
|
||||
# 调用 Data_write,传递 config_id
|
||||
write_params = {
|
||||
"content": result["context"],
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id,
|
||||
"user_id": user_id
|
||||
}
|
||||
|
||||
# 如果提供了 config_id,添加到参数中
|
||||
if config_id:
|
||||
write_params["config_id"] = config_id
|
||||
logger.debug(f"传递 config_id 到 Data_write: {config_id}")
|
||||
|
||||
write_result = await data_write_tool.ainvoke(write_params)
|
||||
|
||||
if isinstance(write_result, dict):
|
||||
content = write_result.get("data", str(write_result))
|
||||
else:
|
||||
content = str(write_result)
|
||||
logger.info("写入内容: %s", content)
|
||||
return {"messages": [AIMessage(content=content)]}
|
||||
|
||||
workflow = StateGraph(WriteState)
|
||||
workflow.add_node("content_input", call_model)
|
||||
workflow.add_node("save_neo4j", write_node)
|
||||
workflow.add_edge(START, "content_input")
|
||||
workflow.add_edge("content_input", "save_neo4j")
|
||||
workflow.add_edge("save_neo4j", END)
|
||||
|
||||
graph = workflow.compile()
|
||||
|
||||
|
||||
yield graph
|
||||
285
app/core/memory/agent/logger_file/log_streamer.py
Normal file
285
app/core/memory/agent/logger_file/log_streamer.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
Log Streamer Module
|
||||
|
||||
Manages streaming of log file content with file watching and real-time transmission.
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class LogStreamer:
|
||||
"""Manages log file streaming with file watching and content transmission"""
|
||||
|
||||
def __init__(self, log_path: str, keepalive_interval: int = 300):
|
||||
"""
|
||||
Initialize LogStreamer
|
||||
|
||||
Args:
|
||||
log_path: Path to the log file to stream
|
||||
keepalive_interval: Interval in seconds for sending keepalive messages (default: 300)
|
||||
"""
|
||||
self.log_path = log_path
|
||||
self.keepalive_interval = keepalive_interval
|
||||
self.last_position = 0
|
||||
|
||||
# Pattern to match and remove timestamp and log level prefix
|
||||
# Matches: "YYYY-MM-DD HH:MM:SS,mmm - [LEVEL] - module_name - "
|
||||
# This pattern is comprehensive to handle various log formats
|
||||
self.pattern = re.compile(
|
||||
r'^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3} - \[(?:INFO|DEBUG|WARNING|ERROR|CRITICAL)\] - \S+ - '
|
||||
)
|
||||
|
||||
logger.info(f"LogStreamer initialized for {log_path}")
|
||||
|
||||
@staticmethod
|
||||
def clean_log_line(line: str) -> str:
|
||||
"""
|
||||
Static method to clean log entry by removing timestamp and log level prefix.
|
||||
This is the canonical log cleaning method used by both file mode and transmission mode.
|
||||
|
||||
Args:
|
||||
line: Raw log line
|
||||
|
||||
Returns:
|
||||
Cleaned log line without timestamp and log level prefix
|
||||
"""
|
||||
# Pattern to match and remove timestamp and log level prefix
|
||||
# Matches: "YYYY-MM-DD HH:MM:SS,mmm - [LEVEL] - module_name - "
|
||||
pattern = re.compile(
|
||||
r'^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3} - \[(?:INFO|DEBUG|WARNING|ERROR|CRITICAL)\] - \S+ - '
|
||||
)
|
||||
cleaned = re.sub(pattern, '', line)
|
||||
return cleaned
|
||||
|
||||
def clean_log_entry(self, line: str) -> str:
|
||||
"""
|
||||
Clean log entry by removing timestamp and log level prefix.
|
||||
This instance method delegates to the static method for consistency.
|
||||
|
||||
Args:
|
||||
line: Raw log line
|
||||
|
||||
Returns:
|
||||
Cleaned log line without timestamp and log level prefix
|
||||
"""
|
||||
return LogStreamer.clean_log_line(line)
|
||||
|
||||
async def send_keepalive(self) -> dict:
|
||||
"""
|
||||
Generate keepalive message
|
||||
|
||||
Returns:
|
||||
Keepalive message dict with timestamp
|
||||
"""
|
||||
return {
|
||||
"event": "keepalive",
|
||||
"data": {
|
||||
"timestamp": int(time.time())
|
||||
}
|
||||
}
|
||||
|
||||
async def read_existing_and_stream(self) -> AsyncGenerator[dict, None]:
|
||||
"""
|
||||
Read existing log content first, then watch for new content
|
||||
|
||||
This method reads all existing content in the file first,
|
||||
then continues to watch for new content as it's written.
|
||||
|
||||
Yields:
|
||||
Dict messages with event type and data:
|
||||
- log events: {"event": "log", "data": {"content": "...", "timestamp": ...}}
|
||||
- keepalive events: {"event": "keepalive", "data": {"timestamp": ...}}
|
||||
- error events: {"event": "error", "data": {"code": ..., "message": "...", "error": "..."}}
|
||||
- done events: {"event": "done", "data": {"message": "..."}}
|
||||
"""
|
||||
logger.info(f"Starting log stream (read existing) for {self.log_path}")
|
||||
|
||||
# Check if file exists
|
||||
if not os.path.exists(self.log_path):
|
||||
logger.error(f"Log file not found: {self.log_path}")
|
||||
yield {
|
||||
"event": "error",
|
||||
"data": {
|
||||
"code": 4006,
|
||||
"message": "日志文件不存在",
|
||||
"error": f"File not found: {self.log_path}"
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
try:
|
||||
with open(self.log_path, 'r', encoding='utf-8') as f:
|
||||
# First, read all existing content
|
||||
for line in f:
|
||||
if line.strip(): # Skip empty lines
|
||||
cleaned_line = self.clean_log_entry(line)
|
||||
yield {
|
||||
"event": "log",
|
||||
"data": {
|
||||
"content": cleaned_line.rstrip('\n'),
|
||||
"timestamp": int(time.time())
|
||||
}
|
||||
}
|
||||
|
||||
# Now watch for new content
|
||||
self.last_position = f.tell()
|
||||
last_keepalive = time.time()
|
||||
|
||||
while True:
|
||||
line = f.readline()
|
||||
if line:
|
||||
cleaned_line = self.clean_log_entry(line)
|
||||
yield {
|
||||
"event": "log",
|
||||
"data": {
|
||||
"content": cleaned_line.rstrip('\n'),
|
||||
"timestamp": int(time.time())
|
||||
}
|
||||
}
|
||||
last_keepalive = time.time()
|
||||
else:
|
||||
# No new content, check if we need to send keepalive
|
||||
current_time = time.time()
|
||||
if current_time - last_keepalive >= self.keepalive_interval:
|
||||
keepalive_msg = await self.send_keepalive()
|
||||
yield keepalive_msg
|
||||
last_keepalive = current_time
|
||||
|
||||
# Sleep briefly before checking again
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Log file disappeared during streaming: {self.log_path}")
|
||||
yield {
|
||||
"event": "error",
|
||||
"data": {
|
||||
"code": 4006,
|
||||
"message": "日志文件在流式传输期间变得不可用",
|
||||
"error": "File not found during streaming"
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error during log streaming: {e}", exc_info=True)
|
||||
yield {
|
||||
"event": "error",
|
||||
"data": {
|
||||
"code": 8001,
|
||||
"message": "流式传输期间发生错误",
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
finally:
|
||||
logger.info(f"Log stream ended for {self.log_path}")
|
||||
yield {
|
||||
"event": "done",
|
||||
"data": {
|
||||
"message": "流式传输完成"
|
||||
}
|
||||
}
|
||||
|
||||
async def watch_and_stream(self) -> AsyncGenerator[dict, None]:
|
||||
"""
|
||||
Watch log file and stream only new content as it's written
|
||||
|
||||
This method starts from the end of the file and only streams
|
||||
new content that is written after the stream starts.
|
||||
|
||||
Yields:
|
||||
Dict messages with event type and data:
|
||||
- log events: {"event": "log", "data": {"content": "...", "timestamp": ...}}
|
||||
- keepalive events: {"event": "keepalive", "data": {"timestamp": ...}}
|
||||
- error events: {"event": "error", "data": {"code": ..., "message": "...", "error": "..."}}
|
||||
- done events: {"event": "done", "data": {"message": "..."}}
|
||||
"""
|
||||
logger.info(f"Starting log stream (new content only) for {self.log_path}")
|
||||
|
||||
# Check if file exists
|
||||
if not os.path.exists(self.log_path):
|
||||
logger.error(f"Log file not found: {self.log_path}")
|
||||
yield {
|
||||
"event": "error",
|
||||
"data": {
|
||||
"code": 4006,
|
||||
"message": "日志文件不存在",
|
||||
"error": f"File not found: {self.log_path}"
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
try:
|
||||
# Open file and seek to end to start streaming new content
|
||||
with open(self.log_path, 'r', encoding='utf-8') as f:
|
||||
# Move to end of file
|
||||
f.seek(0, os.SEEK_END)
|
||||
self.last_position = f.tell()
|
||||
|
||||
last_keepalive = time.time()
|
||||
|
||||
while True:
|
||||
# Check if file has new content
|
||||
current_position = f.tell()
|
||||
|
||||
# Read new lines if available
|
||||
line = f.readline()
|
||||
if line:
|
||||
# Clean the log entry
|
||||
cleaned_line = self.clean_log_entry(line)
|
||||
|
||||
# Yield log event
|
||||
yield {
|
||||
"event": "log",
|
||||
"data": {
|
||||
"content": cleaned_line.rstrip('\n'),
|
||||
"timestamp": int(time.time())
|
||||
}
|
||||
}
|
||||
|
||||
# Update last keepalive time since we sent data
|
||||
last_keepalive = time.time()
|
||||
else:
|
||||
# No new content, check if we need to send keepalive
|
||||
current_time = time.time()
|
||||
if current_time - last_keepalive >= self.keepalive_interval:
|
||||
keepalive_msg = await self.send_keepalive()
|
||||
yield keepalive_msg
|
||||
last_keepalive = current_time
|
||||
|
||||
# Sleep briefly before checking again
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Log file disappeared during streaming: {self.log_path}")
|
||||
yield {
|
||||
"event": "error",
|
||||
"data": {
|
||||
"code": 4006,
|
||||
"message": "日志文件在流式传输期间变得不可用",
|
||||
"error": "File not found during streaming"
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error during log streaming: {e}", exc_info=True)
|
||||
yield {
|
||||
"event": "error",
|
||||
"data": {
|
||||
"code": 8001,
|
||||
"message": "流式传输期间发生错误",
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
finally:
|
||||
logger.info(f"Log stream ended for {self.log_path}")
|
||||
yield {
|
||||
"event": "done",
|
||||
"data": {
|
||||
"message": "流式传输完成"
|
||||
}
|
||||
}
|
||||
32
app/core/memory/agent/logger_file/logger_data.py
Normal file
32
app/core/memory/agent/logger_file/logger_data.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
Agent logger module for backward compatibility.
|
||||
|
||||
This module maintains the get_named_logger() function for backward compatibility
|
||||
while delegating to the centralized logging configuration.
|
||||
|
||||
All new code should import directly from app.core.logging_config instead.
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__author__ = "RED_BEAR"
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
|
||||
def get_named_logger(name):
|
||||
"""Get a named logger for agent operations.
|
||||
|
||||
This function maintains backward compatibility with existing code.
|
||||
It delegates to the centralized get_agent_logger() function.
|
||||
|
||||
Args:
|
||||
name: Logger name for namespacing
|
||||
|
||||
Returns:
|
||||
Logger configured for agent operations
|
||||
|
||||
Example:
|
||||
>>> logger = get_named_logger("my_agent")
|
||||
>>> logger.info("Agent operation started")
|
||||
"""
|
||||
return get_agent_logger(name)
|
||||
28
app/core/memory/agent/mcp_server/__init__.py
Normal file
28
app/core/memory/agent/mcp_server/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
MCP Server package for memory agent.
|
||||
|
||||
This package provides the FastMCP server implementation with context-based
|
||||
dependency injection for tool functions.
|
||||
|
||||
Package structure:
|
||||
- server: FastMCP server initialization and context setup
|
||||
- tools: MCP tool implementations
|
||||
- models: Pydantic response models
|
||||
- services: Business logic services
|
||||
"""
|
||||
from app.core.memory.agent.mcp_server.server import (
|
||||
mcp,
|
||||
initialize_context,
|
||||
main,
|
||||
get_context_resource
|
||||
)
|
||||
|
||||
# Import tools to register them (but don't export them)
|
||||
from app.core.memory.agent.mcp_server import tools
|
||||
|
||||
__all__ = [
|
||||
'mcp',
|
||||
'initialize_context',
|
||||
'main',
|
||||
'get_context_resource',
|
||||
]
|
||||
11
app/core/memory/agent/mcp_server/mcp_instance.py
Normal file
11
app/core/memory/agent/mcp_server/mcp_instance.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
MCP Server Instance
|
||||
|
||||
This module contains the FastMCP server instance that is shared across all modules.
|
||||
It's in a separate file to avoid circular import issues.
|
||||
"""
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
# Initialize FastMCP server instance
|
||||
# This instance is shared across all tool modules
|
||||
mcp = FastMCP('data_flow')
|
||||
30
app/core/memory/agent/mcp_server/models/__init__.py
Normal file
30
app/core/memory/agent/mcp_server/models/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Pydantic models for MCP server responses."""
|
||||
|
||||
from .problem_models import (
|
||||
ProblemBreakdownItem,
|
||||
ProblemBreakdownResponse,
|
||||
ExtendedQuestionItem,
|
||||
ProblemExtensionResponse,
|
||||
)
|
||||
from .summary_models import (
|
||||
SummaryData,
|
||||
SummaryResponse,
|
||||
RetrieveSummaryData,
|
||||
RetrieveSummaryResponse,
|
||||
)
|
||||
from .verification_models import VerificationResult
|
||||
from .retrieval_models import RetrievalResult, DistinguishTypeResponse
|
||||
|
||||
__all__ = [
|
||||
"ProblemBreakdownItem",
|
||||
"ProblemBreakdownResponse",
|
||||
"ExtendedQuestionItem",
|
||||
"ProblemExtensionResponse",
|
||||
"SummaryData",
|
||||
"SummaryResponse",
|
||||
"RetrieveSummaryData",
|
||||
"RetrieveSummaryResponse",
|
||||
"VerificationResult",
|
||||
"RetrievalResult",
|
||||
"DistinguishTypeResponse",
|
||||
]
|
||||
34
app/core/memory/agent/mcp_server/models/problem_models.py
Normal file
34
app/core/memory/agent/mcp_server/models/problem_models.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Pydantic models for problem breakdown and extension operations."""
|
||||
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field, RootModel
|
||||
|
||||
|
||||
class ProblemBreakdownItem(BaseModel):
|
||||
"""Individual item in problem breakdown response."""
|
||||
|
||||
id: str
|
||||
question: str
|
||||
type: str
|
||||
reason: Optional[str] = None
|
||||
|
||||
|
||||
class ProblemBreakdownResponse(RootModel[List[ProblemBreakdownItem]]):
|
||||
"""Response model for problem breakdown containing list of breakdown items."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ExtendedQuestionItem(BaseModel):
|
||||
"""Individual extended question item with reasoning."""
|
||||
|
||||
original_question: str = Field(..., description="原始初步问题")
|
||||
extended_question: str = Field(..., description="扩展后的问题")
|
||||
type: str = Field(..., description="类型(事实检索 / 澄清 / 定义 / 比较 / 行动建议等)")
|
||||
reason: str = Field(..., description="生成该扩展问题的理由")
|
||||
|
||||
|
||||
class ProblemExtensionResponse(RootModel[List[ExtendedQuestionItem]]):
|
||||
"""Response model for problem extension containing list of extended questions."""
|
||||
|
||||
pass
|
||||
17
app/core/memory/agent/mcp_server/models/retrieval_models.py
Normal file
17
app/core/memory/agent/mcp_server/models/retrieval_models.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Pydantic models for retrieval operations."""
|
||||
|
||||
from typing import List, Dict, Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RetrievalResult(BaseModel):
|
||||
"""Result model for retrieval operation."""
|
||||
|
||||
Query: str
|
||||
Expansion_issue: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class DistinguishTypeResponse(BaseModel):
|
||||
"""Response model for data type differentiation."""
|
||||
|
||||
type: str
|
||||
31
app/core/memory/agent/mcp_server/models/summary_models.py
Normal file
31
app/core/memory/agent/mcp_server/models/summary_models.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Pydantic models for summary operations."""
|
||||
|
||||
from typing import List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SummaryData(BaseModel):
|
||||
"""Data structure for summary input."""
|
||||
|
||||
query: str
|
||||
history: List[str] = Field(default_factory=list)
|
||||
retrieve_info: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class SummaryResponse(BaseModel):
|
||||
"""Response model for summary operation."""
|
||||
|
||||
data: SummaryData
|
||||
query_answer: str
|
||||
|
||||
|
||||
class RetrieveSummaryData(BaseModel):
|
||||
"""Data structure for retrieve summary response."""
|
||||
|
||||
query_answer: str = Field(default="")
|
||||
|
||||
|
||||
class RetrieveSummaryResponse(BaseModel):
|
||||
"""Response model for retrieve summary operation."""
|
||||
|
||||
data: RetrieveSummaryData
|
||||
@@ -0,0 +1,14 @@
|
||||
"""Pydantic models for verification operations."""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VerificationResult(BaseModel):
|
||||
"""Result model for verification operation."""
|
||||
|
||||
query: str
|
||||
expansion_issue: List[Dict[str, Any]]
|
||||
split_result: str
|
||||
reason: Optional[str] = None
|
||||
history: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
161
app/core/memory/agent/mcp_server/server.py
Normal file
161
app/core/memory/agent/mcp_server/server.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
MCP Server initialization with FastMCP context setup.
|
||||
|
||||
This module initializes the FastMCP server and registers shared resources
|
||||
in the context for dependency injection into tool functions.
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.redis_tool import RedisSessionStore, store
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID,reload_configuration_from_database
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.agent.mcp_server.services.template_service import TemplateService
|
||||
from app.core.memory.agent.mcp_server.services.search_service import SearchService
|
||||
from app.core.memory.agent.mcp_server.services.session_service import SessionService
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
def get_context_resource(ctx, resource_name: str):
|
||||
"""
|
||||
Helper function to retrieve a resource from the FastMCP context.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP Context object (passed to tool functions)
|
||||
resource_name: Name of the resource to retrieve
|
||||
|
||||
Returns:
|
||||
The requested resource
|
||||
|
||||
Raises:
|
||||
AttributeError: If the resource doesn't exist
|
||||
|
||||
Example:
|
||||
@mcp.tool()
|
||||
async def my_tool(ctx: Context):
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
"""
|
||||
if not hasattr(ctx, 'fastmcp') or ctx.fastmcp is None:
|
||||
raise RuntimeError("Context does not have fastmcp attribute")
|
||||
|
||||
if not hasattr(ctx.fastmcp, resource_name):
|
||||
raise AttributeError(
|
||||
f"Resource '{resource_name}' not found in context. "
|
||||
f"Available resources: {[k for k in dir(ctx.fastmcp) if not k.startswith('_')]}"
|
||||
)
|
||||
|
||||
return getattr(ctx.fastmcp, resource_name)
|
||||
|
||||
|
||||
def initialize_context():
|
||||
"""
|
||||
Initialize and register shared resources in FastMCP context.
|
||||
|
||||
This function sets up all shared resources that will be available
|
||||
to tool functions via dependency injection through the context parameter.
|
||||
|
||||
Resources are stored as attributes on the FastMCP instance and can be
|
||||
accessed via ctx.fastmcp in tool functions.
|
||||
|
||||
Resources registered:
|
||||
- session_store: RedisSessionStore for session management
|
||||
- llm_client: LLM client for structured API calls
|
||||
- app_settings: Application settings (renamed to avoid conflict with FastMCP settings)
|
||||
- template_service: Service for template rendering
|
||||
- search_service: Service for hybrid search
|
||||
- session_service: Service for session operations
|
||||
"""
|
||||
try:
|
||||
# Register Redis session store
|
||||
logger.info("Registering session_store in context")
|
||||
mcp.session_store = store
|
||||
|
||||
# Register LLM client
|
||||
try:
|
||||
logger.info(f"Registering llm_client in context with model ID: {SELECTED_LLM_ID}")
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
mcp.llm_client = llm_client
|
||||
logger.info("llm_client registered successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register llm_client: {e}", exc_info=True)
|
||||
# 注册一个 None 值,避免工具调用时找不到资源
|
||||
mcp.llm_client = None
|
||||
logger.warning("llm_client set to None due to initialization failure")
|
||||
|
||||
# Register application settings (renamed to avoid conflict with FastMCP's settings)
|
||||
logger.info("Registering app_settings in context")
|
||||
mcp.app_settings = settings
|
||||
|
||||
# Register template service
|
||||
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
|
||||
# logger.info(f"Registering template_service in context with root: {template_root}")
|
||||
template_service = TemplateService(template_root)
|
||||
mcp.template_service = template_service
|
||||
|
||||
# Register search service
|
||||
# logger.info("Registering search_service in context")
|
||||
search_service = SearchService()
|
||||
mcp.search_service = search_service
|
||||
|
||||
# Register session service
|
||||
# logger.info("Registering session_service in context")
|
||||
session_service = SessionService(store)
|
||||
mcp.session_service = session_service
|
||||
|
||||
# logger.info("All context resources registered successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize context: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main entry point for the MCP server.
|
||||
|
||||
Initializes context and starts the server with SSE transport.
|
||||
"""
|
||||
try:
|
||||
# logger.info("Starting MCP server initialization")
|
||||
reload_configuration_from_database(config_id=os.getenv("config_id"), force_reload=True)
|
||||
# Initialize context resources
|
||||
initialize_context()
|
||||
|
||||
# Import and register tools
|
||||
# logger.info("Importing MCP tools")
|
||||
from app.core.memory.agent.mcp_server.tools import (
|
||||
problem_tools,
|
||||
retrieval_tools,
|
||||
verification_tools,
|
||||
summary_tools,
|
||||
data_tools
|
||||
)
|
||||
# logger.info("All MCP tools imported and registered")
|
||||
|
||||
# Log registered tools for debugging
|
||||
import asyncio
|
||||
tools_list = asyncio.run(mcp.list_tools())
|
||||
# logger.info(f"Registered {len(tools_list)} MCP tools: {[t.name for t in tools_list]}")
|
||||
# logger.info(f"Starting MCP server on {settings.SERVER_IP}:8081 with SSE transport")
|
||||
|
||||
# Run the server with SSE transport for HTTP connections
|
||||
# The server will be available at http://127.0.0.1:8081
|
||||
import uvicorn
|
||||
app = mcp.sse_app()
|
||||
uvicorn.run(app, host=settings.SERVER_IP, port=8081, log_level="info")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start MCP server: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
23
app/core/memory/agent/mcp_server/services/__init__.py
Normal file
23
app/core/memory/agent/mcp_server/services/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
MCP Server Services
|
||||
|
||||
This module provides business logic services for the MCP server:
|
||||
- TemplateService: Template loading and rendering
|
||||
- SearchService: Search result processing
|
||||
- SessionService: Session and history management
|
||||
- ParameterBuilder: Tool parameter construction
|
||||
"""
|
||||
|
||||
from .template_service import TemplateService, TemplateRenderError
|
||||
from .search_service import SearchService
|
||||
from .session_service import SessionService
|
||||
from .parameter_builder import ParameterBuilder
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TemplateService",
|
||||
"TemplateRenderError",
|
||||
"SearchService",
|
||||
"SessionService",
|
||||
"ParameterBuilder",
|
||||
]
|
||||
157
app/core/memory/agent/mcp_server/services/parameter_builder.py
Normal file
157
app/core/memory/agent/mcp_server/services/parameter_builder.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
Parameter Builder for constructing tool call arguments.
|
||||
|
||||
This service provides tool-specific parameter transformation logic
|
||||
to build correct arguments for each tool type.
|
||||
"""
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class ParameterBuilder:
|
||||
"""Service for building tool call arguments based on tool type."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the parameter builder."""
|
||||
logger.info("ParameterBuilder initialized")
|
||||
|
||||
def build_tool_args(
|
||||
self,
|
||||
tool_name: str,
|
||||
content: Any,
|
||||
tool_call_id: str,
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build tool arguments based on tool type.
|
||||
|
||||
Different tools expect different argument formats:
|
||||
- Verify: dict context
|
||||
- Retrieve: dict context + search_switch
|
||||
- Summary/Summary_fails: JSON string context
|
||||
- Retrieve_Summary: unwrap nested context structures
|
||||
- Input_Summary: raw message string
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool being invoked
|
||||
content: Parsed content from previous tool result
|
||||
tool_call_id: Extracted tool call identifier
|
||||
search_switch: Search routing parameter
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
|
||||
|
||||
Returns:
|
||||
Dictionary of tool arguments ready for invocation
|
||||
"""
|
||||
# Base arguments common to most tools
|
||||
base_args = {
|
||||
"usermessages": tool_call_id,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id
|
||||
}
|
||||
|
||||
# Always add storage_type and user_rag_memory_id (with defaults if None)
|
||||
base_args["storage_type"] = storage_type if storage_type is not None else ""
|
||||
base_args["user_rag_memory_id"] = user_rag_memory_id if user_rag_memory_id is not None else ""
|
||||
|
||||
# Tool-specific argument construction
|
||||
if tool_name == "Verify":
|
||||
# Verify expects dict context
|
||||
return {
|
||||
"context": content if isinstance(content, dict) else {},
|
||||
**base_args
|
||||
}
|
||||
|
||||
elif tool_name == "Retrieve":
|
||||
# Retrieve expects dict context + search_switch
|
||||
return {
|
||||
"context": content if isinstance(content, dict) else {},
|
||||
"search_switch": search_switch,
|
||||
**base_args
|
||||
}
|
||||
|
||||
elif tool_name in ["Summary", "Summary_fails"]:
|
||||
# Summary tools expect JSON string context
|
||||
if isinstance(content, dict):
|
||||
context_str = json.dumps(content, ensure_ascii=False)
|
||||
elif isinstance(content, str):
|
||||
context_str = content
|
||||
else:
|
||||
context_str = json.dumps({"data": content}, ensure_ascii=False)
|
||||
|
||||
return {
|
||||
"context": context_str,
|
||||
**base_args
|
||||
}
|
||||
|
||||
elif tool_name == "Retrieve_Summary":
|
||||
# Retrieve_Summary needs to unwrap nested context structures
|
||||
# Handle both 'content' and 'context' keys
|
||||
context_dict = content
|
||||
|
||||
if isinstance(content, dict):
|
||||
# Check for nested 'content' wrapper
|
||||
if "content" in content:
|
||||
inner = content["content"]
|
||||
|
||||
# If it's a JSON string, parse it
|
||||
if isinstance(inner, str):
|
||||
try:
|
||||
parsed = json.loads(inner)
|
||||
# Check if parsed has 'context' wrapper
|
||||
if isinstance(parsed, dict) and "context" in parsed:
|
||||
context_dict = parsed["context"]
|
||||
else:
|
||||
context_dict = parsed
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse JSON content for {tool_name}: {inner[:100]}"
|
||||
)
|
||||
context_dict = {"Query": "", "Expansion_issue": []}
|
||||
elif isinstance(inner, dict):
|
||||
context_dict = inner
|
||||
|
||||
# Check for 'context' wrapper
|
||||
elif "context" in content:
|
||||
context_dict = content["context"] if isinstance(content["context"], dict) else content
|
||||
|
||||
return {
|
||||
"context": context_dict,
|
||||
**base_args
|
||||
}
|
||||
|
||||
elif tool_name == "Input_Summary":
|
||||
# Input_Summary expects raw message string + search_switch
|
||||
# Content should be the raw message string
|
||||
if isinstance(content, dict):
|
||||
# Try to extract message from dict
|
||||
message_str = content.get("sentence", str(content))
|
||||
else:
|
||||
message_str = str(content)
|
||||
|
||||
return {
|
||||
"context": message_str,
|
||||
"search_switch": search_switch,
|
||||
**base_args
|
||||
}
|
||||
|
||||
else:
|
||||
# Default: pass content as context
|
||||
logger.warning(
|
||||
f"Unknown tool name '{tool_name}', using default argument structure"
|
||||
)
|
||||
return {
|
||||
"context": content,
|
||||
**base_args
|
||||
}
|
||||
193
app/core/memory/agent/mcp_server/services/search_service.py
Normal file
193
app/core/memory/agent/mcp_server/services/search_service.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
Search Service for executing hybrid search and processing results.
|
||||
|
||||
This service provides clean search result processing with content extraction
|
||||
and deduplication.
|
||||
"""
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.src.search import run_hybrid_search
|
||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class SearchService:
|
||||
"""Service for executing hybrid search and processing results."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the search service."""
|
||||
logger.info("SearchService initialized")
|
||||
|
||||
def extract_content_from_result(self, result: dict) -> str:
|
||||
"""
|
||||
Extract only meaningful content from search results, dropping all metadata.
|
||||
|
||||
Extraction rules by node type:
|
||||
- Statements: extract 'statement' field
|
||||
- Entities: extract 'name' and 'fact_summary' fields
|
||||
- Summaries: extract 'content' field
|
||||
- Chunks: extract 'content' field
|
||||
|
||||
Args:
|
||||
result: Search result dictionary
|
||||
|
||||
Returns:
|
||||
Clean content string without metadata
|
||||
"""
|
||||
if not isinstance(result, dict):
|
||||
return str(result)
|
||||
|
||||
content_parts = []
|
||||
|
||||
# Statements: extract statement field
|
||||
if 'statement' in result and result['statement']:
|
||||
content_parts.append(result['statement'])
|
||||
|
||||
# Summaries/Chunks: extract content field
|
||||
if 'content' in result and result['content']:
|
||||
content_parts.append(result['content'])
|
||||
|
||||
# Entities: extract name and fact_summary (commented out in original)
|
||||
# if 'name' in result and result['name']:
|
||||
# content_parts.append(result['name'])
|
||||
# if result.get('fact_summary'):
|
||||
# content_parts.append(result['fact_summary'])
|
||||
|
||||
# Return concatenated content or empty string
|
||||
return '\n'.join(content_parts) if content_parts else ""
|
||||
|
||||
def clean_query(self, query: str) -> str:
|
||||
"""
|
||||
Clean and escape query text for Lucene.
|
||||
|
||||
- Removes wrapping quotes
|
||||
- Removes newlines and carriage returns
|
||||
- Applies Lucene escaping
|
||||
|
||||
Args:
|
||||
query: Raw query string
|
||||
|
||||
Returns:
|
||||
Cleaned and escaped query string
|
||||
"""
|
||||
q = str(query).strip()
|
||||
|
||||
# Remove wrapping quotes
|
||||
if (q.startswith("'") and q.endswith("'")) or (
|
||||
q.startswith('"') and q.endswith('"')
|
||||
):
|
||||
q = q[1:-1]
|
||||
|
||||
# Remove newlines and carriage returns
|
||||
q = q.replace('\r', ' ').replace('\n', ' ').strip()
|
||||
|
||||
# Apply Lucene escaping
|
||||
q = escape_lucene_query(q)
|
||||
|
||||
return q
|
||||
|
||||
async def execute_hybrid_search(
|
||||
self,
|
||||
group_id: str,
|
||||
question: str,
|
||||
limit: int = 5,
|
||||
search_type: str = "hybrid",
|
||||
include: Optional[List[str]] = None,
|
||||
rerank_alpha: float = 0.4,
|
||||
output_path: str = "search_results.json",
|
||||
return_raw_results: bool = False
|
||||
) -> Tuple[str, str, Optional[dict]]:
|
||||
"""
|
||||
Execute hybrid search and return clean content.
|
||||
|
||||
Args:
|
||||
group_id: Group identifier for filtering results
|
||||
question: Search query text
|
||||
limit: Maximum number of results to return (default: 5)
|
||||
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
|
||||
include: List of result types to include (default: ["statements", "chunks", "entities", "summaries"])
|
||||
rerank_alpha: Weight for BM25 scores in reranking (default: 0.4)
|
||||
output_path: Path to save search results (default: "search_results.json")
|
||||
return_raw_results: If True, also return the raw search results as third element (default: False)
|
||||
|
||||
Returns:
|
||||
Tuple of (clean_content, cleaned_query, raw_results)
|
||||
raw_results is None if return_raw_results=False
|
||||
"""
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries"]
|
||||
|
||||
# Clean query
|
||||
cleaned_query = self.clean_query(question)
|
||||
|
||||
try:
|
||||
# Execute search
|
||||
answer = await run_hybrid_search(
|
||||
query_text=cleaned_query,
|
||||
search_type=search_type,
|
||||
group_id=group_id,
|
||||
limit=limit,
|
||||
include=include,
|
||||
output_path=output_path,
|
||||
rerank_alpha=rerank_alpha
|
||||
)
|
||||
|
||||
# Extract results based on search type and include parameter
|
||||
# Prioritize summaries as they contain synthesized contextual information
|
||||
answer_list = []
|
||||
|
||||
# For hybrid search, use reranked_results
|
||||
if search_type == "hybrid":
|
||||
reranked_results = answer.get('reranked_results', {})
|
||||
|
||||
# Priority order: summaries first (most contextual), then statements, chunks, entities
|
||||
priority_order = ['summaries', 'statements', 'chunks', 'entities']
|
||||
|
||||
for category in priority_order:
|
||||
if category in include and category in reranked_results:
|
||||
category_results = reranked_results[category]
|
||||
if isinstance(category_results, list):
|
||||
answer_list.extend(category_results)
|
||||
else:
|
||||
# For keyword or embedding search, results are directly in answer dict
|
||||
# Apply same priority order
|
||||
priority_order = ['summaries', 'statements', 'chunks', 'entities']
|
||||
|
||||
for category in priority_order:
|
||||
if category in include and category in answer:
|
||||
category_results = answer[category]
|
||||
if isinstance(category_results, list):
|
||||
answer_list.extend(category_results)
|
||||
|
||||
# Extract clean content from all results
|
||||
content_list = [
|
||||
self.extract_content_from_result(ans)
|
||||
for ans in answer_list
|
||||
]
|
||||
|
||||
|
||||
# Filter out empty strings and join with newlines
|
||||
clean_content = '\n'.join([c for c in content_list if c])
|
||||
|
||||
# Log first 200 chars
|
||||
logger.info(f"检索接口搜索结果==>>:{clean_content[:200]}...")
|
||||
|
||||
# Return raw results if requested
|
||||
if return_raw_results:
|
||||
return clean_content, cleaned_query, answer
|
||||
else:
|
||||
return clean_content, cleaned_query, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Search failed for query '{question}' in group '{group_id}': {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty results on failure
|
||||
if return_raw_results:
|
||||
return "", cleaned_query, {}
|
||||
else:
|
||||
return "", cleaned_query, None
|
||||
169
app/core/memory/agent/mcp_server/services/session_service.py
Normal file
169
app/core/memory/agent/mcp_server/services/session_service.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
Session Service for managing user sessions and conversation history.
|
||||
|
||||
This service provides clean Redis interactions with error handling and
|
||||
session management utilities.
|
||||
"""
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.redis_tool import RedisSessionStore
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class SessionService:
|
||||
"""Service for managing user sessions and conversation history."""
|
||||
|
||||
def __init__(self, store: RedisSessionStore):
|
||||
"""
|
||||
Initialize the session service.
|
||||
|
||||
Args:
|
||||
store: Redis session store instance
|
||||
"""
|
||||
self.store = store
|
||||
logger.info("SessionService initialized")
|
||||
|
||||
def resolve_user_id(self, session_string: str) -> str:
|
||||
"""
|
||||
Extract user ID from session string.
|
||||
|
||||
Handles formats like:
|
||||
- 'call_id_user123' -> 'user123'
|
||||
- 'prefix_id_user456_suffix' -> 'user456_suffix'
|
||||
|
||||
Args:
|
||||
session_string: Session identifier string
|
||||
|
||||
Returns:
|
||||
Extracted user ID
|
||||
"""
|
||||
try:
|
||||
# Split by '_id_' and take everything after it
|
||||
parts = session_string.split('_id_')
|
||||
if len(parts) > 1:
|
||||
return parts[1]
|
||||
|
||||
# Fallback: return original string
|
||||
return session_string
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to parse user ID from session string '{session_string}': {e}"
|
||||
)
|
||||
return session_string
|
||||
|
||||
async def get_history(
|
||||
self,
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Retrieve conversation history from Redis.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
|
||||
Returns:
|
||||
List of conversation history items with Query and Answer keys
|
||||
Returns empty list if no history found or on error
|
||||
"""
|
||||
try:
|
||||
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
|
||||
|
||||
# Validate history structure
|
||||
if not isinstance(history, list):
|
||||
logger.warning(
|
||||
f"Invalid history format for user {user_id}, "
|
||||
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
|
||||
)
|
||||
return []
|
||||
|
||||
return history
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to retrieve history for user {user_id}, "
|
||||
f"apply {apply_id}, group {group_id}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty list on error to allow execution to continue
|
||||
return []
|
||||
|
||||
async def save_session(
|
||||
self,
|
||||
user_id: str,
|
||||
query: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
ai_response: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Save conversation turn to Redis.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
query: User query/message
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
ai_response: AI response/answer
|
||||
|
||||
Returns:
|
||||
Session ID if successful, None on error
|
||||
"""
|
||||
try:
|
||||
# Validate required fields
|
||||
if not user_id:
|
||||
logger.warning("Cannot save session: user_id is empty")
|
||||
return None
|
||||
|
||||
if not query:
|
||||
logger.warning("Cannot save session: query is empty")
|
||||
return None
|
||||
|
||||
# Save session
|
||||
session_id = self.store.save_session(
|
||||
userid=user_id,
|
||||
messages=query,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
aimessages=ai_response
|
||||
)
|
||||
|
||||
logger.info(f"Session saved successfully: {session_id}")
|
||||
return session_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save session for user {user_id}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
async def cleanup_duplicates(self) -> int:
|
||||
"""
|
||||
Remove duplicate session entries.
|
||||
|
||||
Duplicates are identified by matching:
|
||||
- sessionid
|
||||
- user_id (id field)
|
||||
- group_id
|
||||
- messages
|
||||
- aimessages
|
||||
|
||||
Returns:
|
||||
Number of duplicate sessions deleted
|
||||
"""
|
||||
try:
|
||||
deleted_count = self.store.delete_duplicate_sessions()
|
||||
logger.info(f"Cleaned up {deleted_count} duplicate sessions")
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup duplicate sessions: {e}", exc_info=True)
|
||||
return 0
|
||||
116
app/core/memory/agent/mcp_server/services/template_service.py
Normal file
116
app/core/memory/agent/mcp_server/services/template_service.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
Template Service for loading and rendering Jinja2 templates.
|
||||
|
||||
This service provides centralized template management with caching and error handling.
|
||||
"""
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
from jinja2 import Environment, FileSystemLoader, Template, TemplateNotFound
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_prompt_rendering
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class TemplateRenderError(Exception):
|
||||
"""Exception raised when template rendering fails."""
|
||||
|
||||
def __init__(self, template_name: str, error: Exception, variables: dict):
|
||||
self.template_name = template_name
|
||||
self.error = error
|
||||
self.variables = variables
|
||||
super().__init__(
|
||||
f"Failed to render template '{template_name}': {str(error)}"
|
||||
)
|
||||
|
||||
|
||||
class TemplateService:
|
||||
"""Service for loading and rendering Jinja2 templates with caching."""
|
||||
|
||||
def __init__(self, template_root: str):
|
||||
"""
|
||||
Initialize the template service.
|
||||
|
||||
Args:
|
||||
template_root: Root directory containing template files
|
||||
"""
|
||||
self.template_root = template_root
|
||||
self.env = Environment(
|
||||
loader=FileSystemLoader(template_root),
|
||||
autoescape=False # Disable autoescape for prompt templates
|
||||
)
|
||||
logger.info(f"TemplateService initialized with root: {template_root}")
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def _load_template(self, template_name: str) -> Template:
|
||||
"""
|
||||
Load a template from disk with caching.
|
||||
|
||||
Args:
|
||||
template_name: Relative path to template file
|
||||
|
||||
Returns:
|
||||
Loaded Jinja2 Template object
|
||||
|
||||
Raises:
|
||||
TemplateNotFound: If template file doesn't exist
|
||||
"""
|
||||
try:
|
||||
return self.env.get_template(template_name)
|
||||
except TemplateNotFound as e:
|
||||
expected_path = os.path.join(self.template_root, template_name)
|
||||
logger.error(
|
||||
f"Template not found: {template_name}. "
|
||||
f"Expected path: {expected_path}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def render_template(
|
||||
self,
|
||||
template_name: str,
|
||||
operation_name: str,
|
||||
**variables
|
||||
) -> str:
|
||||
"""
|
||||
Load and render a Jinja2 template.
|
||||
|
||||
Args:
|
||||
template_name: Relative path to template file
|
||||
operation_name: Name for logging (e.g., "split_the_problem")
|
||||
**variables: Template variables to render
|
||||
|
||||
Returns:
|
||||
Rendered template string
|
||||
|
||||
Raises:
|
||||
TemplateRenderError: If template loading or rendering fails
|
||||
"""
|
||||
try:
|
||||
# Load template (cached)
|
||||
template = self._load_template(template_name)
|
||||
|
||||
# Render template
|
||||
rendered = template.render(**variables)
|
||||
|
||||
# Log rendered prompt
|
||||
log_prompt_rendering(operation_name, rendered)
|
||||
|
||||
return rendered
|
||||
|
||||
except TemplateNotFound as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for {operation_name} "
|
||||
f"({template_name}): Template not found",
|
||||
exc_info=True
|
||||
)
|
||||
raise TemplateRenderError(template_name, e, variables)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for {operation_name} "
|
||||
f"({template_name}): {e}",
|
||||
exc_info=True
|
||||
)
|
||||
raise TemplateRenderError(template_name, e, variables)
|
||||
27
app/core/memory/agent/mcp_server/tools/__init__.py
Normal file
27
app/core/memory/agent/mcp_server/tools/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""
|
||||
MCP Tools module.
|
||||
|
||||
This module contains all MCP tool implementations organized by functionality.
|
||||
|
||||
Tools are organized into the following modules:
|
||||
- problem_tools: Question segmentation and extension
|
||||
- retrieval_tools: Database and context retrieval
|
||||
- verification_tools: Data verification
|
||||
- summary_tools: Summarization and summary retrieval
|
||||
- data_tools: Data type differentiation and writing
|
||||
"""
|
||||
|
||||
# Import all tool modules to register them with the MCP server
|
||||
from . import problem_tools
|
||||
from . import retrieval_tools
|
||||
from . import verification_tools
|
||||
from . import summary_tools
|
||||
from . import data_tools
|
||||
|
||||
__all__ = [
|
||||
'problem_tools',
|
||||
'retrieval_tools',
|
||||
'verification_tools',
|
||||
'summary_tools',
|
||||
'data_tools',
|
||||
]
|
||||
149
app/core/memory/agent/mcp_server/tools/data_tools.py
Normal file
149
app/core/memory/agent/mcp_server/tools/data_tools.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
Data Tools for data type differentiation and writing.
|
||||
|
||||
This module contains MCP tools for distinguishing data types and writing data.
|
||||
"""
|
||||
import os
|
||||
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.mcp_server.models.retrieval_models import DistinguishTypeResponse
|
||||
from app.core.memory.agent.utils.write_tools import write
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Data_type_differentiation(
|
||||
ctx: Context,
|
||||
context: str
|
||||
) -> dict:
|
||||
"""
|
||||
Distinguish the type of data (read or write).
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Text to analyze for type differentiation
|
||||
|
||||
Returns:
|
||||
dict: Contains 'context' with the original text and 'type' field
|
||||
"""
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
|
||||
# Render template
|
||||
try:
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='distinguish_types_prompt.jinja2',
|
||||
operation_name='status_typle',
|
||||
user_query=context
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for Data_type_differentiation: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"type": "error",
|
||||
"message": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
# Call LLM with structured response
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=DistinguishTypeResponse
|
||||
)
|
||||
|
||||
result = structured.model_dump()
|
||||
|
||||
# Add context to result
|
||||
result["context"] = context
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM call failed for Data_type_differentiation: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": context,
|
||||
"type": "error",
|
||||
"message": f"LLM call failed: {str(e)}"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Data_type_differentiation failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": context,
|
||||
"type": "error",
|
||||
"message": str(e)
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Data_write(
|
||||
ctx: Context,
|
||||
content: str,
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
config_id: str
|
||||
) -> dict:
|
||||
"""
|
||||
Write data to the database/file system.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
content: Data content to write
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
config_id: Configuration ID for processing (optional, integer)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'status', 'saved_to', and 'data' fields
|
||||
"""
|
||||
try:
|
||||
# Ensure output directory exists
|
||||
os.makedirs("data_output", exist_ok=True)
|
||||
file_path = os.path.join("data_output", "user_data.csv")
|
||||
|
||||
# Write data using utility function
|
||||
try:
|
||||
await write(content, user_id, apply_id, group_id, config_id=config_id)
|
||||
logger.info(f"写入成功!Config ID: {config_id if config_id else 'None'}")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"saved_to": file_path,
|
||||
"data": content,
|
||||
"config_id": config_id
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"写入失败: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Data_write failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e)
|
||||
}
|
||||
293
app/core/memory/agent/mcp_server/tools/problem_tools.py
Normal file
293
app/core/memory/agent/mcp_server/tools/problem_tools.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""
|
||||
Problem Tools for question segmentation and extension.
|
||||
|
||||
This module contains MCP tools for breaking down and extending user questions.
|
||||
"""
|
||||
import json
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field, RootModel
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.mcp_server.models.problem_models import (
|
||||
ProblemBreakdownItem,
|
||||
ProblemBreakdownResponse,
|
||||
ExtendedQuestionItem,
|
||||
ProblemExtensionResponse
|
||||
)
|
||||
from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Split_The_Problem(
|
||||
ctx: Context,
|
||||
sentence: str,
|
||||
sessionid: str,
|
||||
messages_id: str,
|
||||
apply_id: str,
|
||||
group_id: str
|
||||
) -> dict:
|
||||
"""
|
||||
Segment the dialogue or sentence into sub-problems.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
sentence: Original sentence to split
|
||||
sessionid: Session identifier
|
||||
messages_id: Message identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
|
||||
Returns:
|
||||
dict: Contains 'context' (JSON string of split results) and 'original' sentence
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
|
||||
# Extract user ID from session
|
||||
user_id = session_service.resolve_user_id(sessionid)
|
||||
|
||||
# Get conversation history
|
||||
history = await session_service.get_history(user_id, apply_id, group_id)
|
||||
# Override with empty list for now (as in original)
|
||||
history = []
|
||||
|
||||
# Render template
|
||||
try:
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='problem_breakdown_prompt.jinja2',
|
||||
operation_name='split_the_problem',
|
||||
history=history,
|
||||
sentence=sentence
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for Split_The_Problem: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": json.dumps([], ensure_ascii=False),
|
||||
"original": sentence,
|
||||
"error": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
# Call LLM with structured response
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=ProblemBreakdownResponse
|
||||
)
|
||||
|
||||
# Handle RootModel response with .root attribute access
|
||||
if structured is None:
|
||||
# LLM returned None, use empty list as fallback
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
elif hasattr(structured, 'root') and structured.root is not None:
|
||||
split_result = json.dumps(
|
||||
[item.model_dump() for item in structured.root],
|
||||
ensure_ascii=False
|
||||
)
|
||||
elif isinstance(structured, list):
|
||||
# Fallback: treat structured itself as the list
|
||||
split_result = json.dumps(
|
||||
[item.model_dump() for item in structured],
|
||||
ensure_ascii=False
|
||||
)
|
||||
else:
|
||||
# Last resort: use empty list
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM call failed for Split_The_Problem: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
|
||||
logger.info(f"问题拆分")
|
||||
logger.info(f"问题拆分结果==>>:{split_result}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
result = {
|
||||
"context": split_result,
|
||||
"original": sentence,
|
||||
"_intermediate": {
|
||||
"type": "problem_split",
|
||||
"data": json.loads(split_result) if split_result else [],
|
||||
"original_query": sentence
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Split_The_Problem failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": json.dumps([], ensure_ascii=False),
|
||||
"original": sentence,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
finally:
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('问题拆分', duration)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Problem_Extension(
|
||||
ctx: Context,
|
||||
context: dict,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
) -> dict:
|
||||
"""
|
||||
Extend the problem with additional sub-questions.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Dictionary containing split problem results
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'context' (aggregated questions) and 'original' question
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
|
||||
# Resolve session ID from usermessages
|
||||
from app.core.memory.agent.utils.messages_tool import Resolve_username
|
||||
sessionid = Resolve_username(usermessages)
|
||||
|
||||
# Get conversation history
|
||||
history = await session_service.get_history(sessionid, apply_id, group_id)
|
||||
# Override with empty list for now (as in original)
|
||||
history = []
|
||||
|
||||
# Process context to extract questions
|
||||
extent_quest, original = await Problem_Extension_messages_deal(context)
|
||||
|
||||
# Format questions for template rendering
|
||||
questions_formatted = []
|
||||
for msg in extent_quest:
|
||||
if msg.get("role") == "user":
|
||||
questions_formatted.append(msg.get("content", ""))
|
||||
|
||||
# Render template
|
||||
try:
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='Problem_Extension_prompt.jinja2',
|
||||
operation_name='problem_extension',
|
||||
history=history,
|
||||
questions=questions_formatted
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for Problem_Extension: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": {},
|
||||
"original": original,
|
||||
"error": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
# Call LLM with structured response
|
||||
try:
|
||||
response_content = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=ProblemExtensionResponse
|
||||
)
|
||||
|
||||
# Aggregate results by original question
|
||||
aggregated_dict = {}
|
||||
for item in response_content.root:
|
||||
key = getattr(item, "original_question", None) or (
|
||||
item.get("original_question") if isinstance(item, dict) else None
|
||||
)
|
||||
value = getattr(item, "extended_question", None) or (
|
||||
item.get("extended_question") if isinstance(item, dict) else None
|
||||
)
|
||||
if not key or not value:
|
||||
continue
|
||||
aggregated_dict.setdefault(key, []).append(value)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM call failed for Problem_Extension: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
aggregated_dict = {}
|
||||
|
||||
logger.info(f"问题扩展")
|
||||
logger.info(f"问题扩展==>>:{aggregated_dict}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
result = {
|
||||
"context": aggregated_dict,
|
||||
"original": original,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "problem_extension",
|
||||
"data": aggregated_dict,
|
||||
"original_query": original,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Problem_Extension failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": {},
|
||||
"original": context.get("original", ""),
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
finally:
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('问题扩展', duration)
|
||||
282
app/core/memory/agent/mcp_server/tools/retrieval_tools.py
Normal file
282
app/core/memory/agent/mcp_server/tools/retrieval_tools.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""
|
||||
Retrieval Tools for database and context retrieval.
|
||||
|
||||
This module contains MCP tools for retrieving data using hybrid search.
|
||||
"""
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
|
||||
# 加载.env文件
|
||||
load_dotenv()
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.llm_tools import deduplicate_entries, merge_to_key_value_pairs
|
||||
from app.core.memory.agent.utils.messages_tool import Retriev_messages_deal
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Retrieve(
|
||||
ctx: Context,
|
||||
context,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
) -> dict:
|
||||
"""
|
||||
Retrieve data from the database using hybrid search.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Dictionary or string containing query information
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
|
||||
user_rag_memory_id: User RAG memory identifier
|
||||
|
||||
Returns:
|
||||
dict: Contains 'context' with Query and Expansion_issue results
|
||||
"""
|
||||
kb_config = {
|
||||
"knowledge_bases": [
|
||||
{
|
||||
"kb_id": user_rag_memory_id,
|
||||
"similarity_threshold": 0.7,
|
||||
"vector_similarity_weight": 0.5,
|
||||
"top_k": 10,
|
||||
"retrieve_type": "participle"
|
||||
}
|
||||
],
|
||||
"merge_strategy": "weight",
|
||||
"reranker_id": os.getenv('reranker_id'),
|
||||
"reranker_top_k": 10
|
||||
}
|
||||
start = time.time()
|
||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
search_service = get_context_resource(ctx, 'search_service')
|
||||
|
||||
databases_anser = []
|
||||
|
||||
# Handle both dict and string context
|
||||
if isinstance(context, dict):
|
||||
# Process dict context with extended questions
|
||||
all_items = []
|
||||
content, original = await Retriev_messages_deal(context)
|
||||
|
||||
# Extract all query items from content
|
||||
# content is like {original_question: [extended_questions...], ...}
|
||||
for key, values in content.items():
|
||||
if isinstance(values, list):
|
||||
all_items.extend(values)
|
||||
elif isinstance(values, str):
|
||||
all_items.append(values)
|
||||
elif values is not None:
|
||||
# Fallback: convert non-empty non-list values to string
|
||||
all_items.append(str(values))
|
||||
|
||||
# Execute search for each question
|
||||
for idx, question in enumerate(all_items):
|
||||
try:
|
||||
# Prepare search parameters based on storage type
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"question": question,
|
||||
"return_raw_results": True
|
||||
}
|
||||
|
||||
# Add storage-specific parameters
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config,[str(group_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
clean_content = '\n\n'.join(retrieval_knowledge)
|
||||
cleaned_query=question
|
||||
raw_results=clean_content
|
||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||
except:
|
||||
clean_content = ''
|
||||
raw_results=''
|
||||
cleaned_query = question
|
||||
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
|
||||
else:
|
||||
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(**search_params)
|
||||
|
||||
databases_anser.append({
|
||||
"Query_small": cleaned_query,
|
||||
"Result_small": clean_content,
|
||||
"_intermediate": {
|
||||
"type": "search_result",
|
||||
"query": cleaned_query,
|
||||
"raw_results": raw_results,
|
||||
"index": idx + 1,
|
||||
"total": len(all_items)
|
||||
}
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve: hybrid_search failed for question '{question}': {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Continue with empty result for this question
|
||||
databases_anser.append({
|
||||
"Query_small": question,
|
||||
"Result_small": ""
|
||||
})
|
||||
|
||||
# Build initial database data structure
|
||||
databases_data = {
|
||||
"Query": original,
|
||||
"Expansion_issue": databases_anser
|
||||
}
|
||||
|
||||
# Collect intermediate outputs before deduplication
|
||||
intermediate_outputs = []
|
||||
for item in databases_anser:
|
||||
if '_intermediate' in item:
|
||||
intermediate_outputs.append(item['_intermediate'])
|
||||
|
||||
# Deduplicate and merge results
|
||||
deduplicated_data = deduplicate_entries(databases_data['Expansion_issue'])
|
||||
deduplicated_data_merged = merge_to_key_value_pairs(
|
||||
deduplicated_data,
|
||||
'Query_small',
|
||||
'Result_small'
|
||||
)
|
||||
|
||||
# Restructure for Verify/Retrieve_Summary compatibility
|
||||
keys, val = [], []
|
||||
for item in deduplicated_data_merged:
|
||||
for items_key, items_value in item.items():
|
||||
keys.append(items_key)
|
||||
val.append(items_value)
|
||||
|
||||
send_verify = []
|
||||
for i, j in zip(keys, val):
|
||||
send_verify.append({
|
||||
"Query_small": i,
|
||||
"Answer_Small": j
|
||||
})
|
||||
|
||||
dup_databases = {
|
||||
"Query": original,
|
||||
"Expansion_issue": send_verify,
|
||||
"_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs
|
||||
}
|
||||
|
||||
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
||||
|
||||
else:
|
||||
# Handle string context (simple query)
|
||||
query = str(context).strip()
|
||||
|
||||
try:
|
||||
# Prepare search parameters based on storage type
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"question": query,
|
||||
"return_raw_results": True
|
||||
}
|
||||
|
||||
# Add storage-specific parameters
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
clean_content = '\n\n'.join(retrieval_knowledge)
|
||||
cleaned_query = query
|
||||
raw_results = clean_content
|
||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||
except:
|
||||
clean_content = ''
|
||||
raw_results = ''
|
||||
cleaned_query = query
|
||||
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
|
||||
else:
|
||||
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(**search_params)
|
||||
# Keep structure for Verify/Retrieve_Summary compatibility
|
||||
dup_databases = {
|
||||
"Query": cleaned_query,
|
||||
"Expansion_issue": [{
|
||||
"Query_small": cleaned_query,
|
||||
"Answer_Small": clean_content,
|
||||
"_intermediate": {
|
||||
"type": "search_result",
|
||||
"query": cleaned_query,
|
||||
"raw_results": raw_results,
|
||||
"index": 1,
|
||||
"total": 1
|
||||
}
|
||||
}]
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve: hybrid_search failed for query '{query}': {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty results on failure
|
||||
dup_databases = {
|
||||
"Query": query,
|
||||
"Expansion_issue": []
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"检索==>>:{storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, "
|
||||
f"Expansion_issue count={len(dup_databases.get('Expansion_issue', []))}"
|
||||
)
|
||||
|
||||
# Build result with intermediate outputs
|
||||
result = {
|
||||
"context": dup_databases,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
|
||||
# Add intermediate outputs list if they exist
|
||||
intermediate_outputs = dup_databases.get('_intermediate_outputs', [])
|
||||
if intermediate_outputs:
|
||||
result['_intermediates'] = intermediate_outputs
|
||||
logger.info(f"Adding {len(intermediate_outputs)} intermediate outputs to result")
|
||||
else:
|
||||
logger.warning("No intermediate outputs found in dup_databases")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": {
|
||||
"Query": "",
|
||||
"Expansion_issue": []
|
||||
},
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
finally:
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('检索', duration)
|
||||
647
app/core/memory/agent/mcp_server/tools/summary_tools.py
Normal file
647
app/core/memory/agent/mcp_server/tools/summary_tools.py
Normal file
@@ -0,0 +1,647 @@
|
||||
"""
|
||||
Summary Tools for data summarization.
|
||||
|
||||
This module contains MCP tools for summarizing retrieved data and generating responses.
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.mcp_server.models.summary_models import (
|
||||
SummaryData,
|
||||
SummaryResponse,
|
||||
RetrieveSummaryData,
|
||||
RetrieveSummaryResponse
|
||||
)
|
||||
from app.core.memory.agent.utils.messages_tool import (
|
||||
Summary_messages_deal,
|
||||
Resolve_username
|
||||
)
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
# 加载.env文件
|
||||
load_dotenv()
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Summary(
|
||||
ctx: Context,
|
||||
context: str,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
) -> dict:
|
||||
"""
|
||||
Summarize the verified data.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: JSON string containing verified data
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'status' and 'summary_result'
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages)
|
||||
|
||||
# Process context to extract answer and query
|
||||
answer_small, query = await Summary_messages_deal(context)
|
||||
|
||||
|
||||
# Get conversation history
|
||||
history = await session_service.get_history(sessionid, apply_id, group_id)
|
||||
# Override with empty list for now (as in original)
|
||||
# Prepare data for template
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
"retrieve_info": answer_small
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Summary: initialization failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"summary_result": "信息不足,无法回答"
|
||||
}
|
||||
|
||||
try:
|
||||
# Render template
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='summary_prompt.jinja2',
|
||||
operation_name='summary',
|
||||
data=data,
|
||||
query=query
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for Summary: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
try:
|
||||
# Call LLM with structured response
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=SummaryResponse
|
||||
)
|
||||
|
||||
aimessages = structured.query_answer or ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM call failed for Summary: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
aimessages = ""
|
||||
|
||||
try:
|
||||
# Save session
|
||||
if aimessages != "":
|
||||
await session_service.save_session(
|
||||
user_id=sessionid,
|
||||
query=query,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
ai_response=aimessages
|
||||
)
|
||||
logger.info(f"sessionid: {aimessages} 写入成功")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"sessionid: {sessionid} 写入失败,错误信息:{str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e)
|
||||
}
|
||||
|
||||
# Cleanup duplicate sessions
|
||||
await session_service.cleanup_duplicates()
|
||||
|
||||
# Use fallback if empty
|
||||
if aimessages == '':
|
||||
aimessages = '信息不足,无法回答'
|
||||
|
||||
logger.info(f"验证之后的总结==>>:{aimessages}")
|
||||
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('总结', duration)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Retrieve_Summary(
|
||||
ctx: Context,
|
||||
context: dict,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
) -> dict:
|
||||
"""
|
||||
Summarize data directly from retrieval results.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Dictionary containing Query and Expansion_issue from Retrieve
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'status' and 'summary_result'
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages)
|
||||
|
||||
|
||||
|
||||
# Handle both 'content' and 'context' keys (LangGraph uses 'content')
|
||||
if isinstance(context, dict):
|
||||
if "content" in context:
|
||||
inner = context["content"]
|
||||
# If it's a JSON string, parse it
|
||||
if isinstance(inner, str):
|
||||
try:
|
||||
parsed = json.loads(inner)
|
||||
logger.info(f"Retrieve_Summary: successfully parsed JSON")
|
||||
except json.JSONDecodeError:
|
||||
# Try unescaping first
|
||||
try:
|
||||
unescaped = inner.encode('utf-8').decode('unicode_escape')
|
||||
parsed = json.loads(unescaped)
|
||||
logger.info(f"Retrieve_Summary: parsed after unescaping")
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||
logger.error(
|
||||
f"Retrieve_Summary: parsing failed even after unescape: {e}"
|
||||
)
|
||||
context_dict = {"Query": "", "Expansion_issue": []}
|
||||
parsed = None
|
||||
|
||||
if parsed:
|
||||
# Check if parsed has 'context' wrapper
|
||||
if isinstance(parsed, dict) and "context" in parsed:
|
||||
context_dict = parsed["context"]
|
||||
else:
|
||||
context_dict = parsed
|
||||
elif isinstance(inner, dict):
|
||||
context_dict = inner
|
||||
else:
|
||||
context_dict = {"Query": "", "Expansion_issue": []}
|
||||
elif "context" in context:
|
||||
context_dict = context["context"] if isinstance(context["context"], dict) else context
|
||||
else:
|
||||
context_dict = context
|
||||
else:
|
||||
context_dict = {"Query": "", "Expansion_issue": []}
|
||||
|
||||
query = context_dict.get("Query", "")
|
||||
expansion_issue = context_dict.get("Expansion_issue", [])
|
||||
|
||||
# Extract retrieve_info from expansion_issue
|
||||
retrieve_info = []
|
||||
for item in expansion_issue:
|
||||
# Check for both Answer_Small and Answer_Samll (typo) for backward compatibility
|
||||
answer = None
|
||||
if isinstance(item, dict):
|
||||
if "Answer_Small" in item:
|
||||
answer = item["Answer_Small"]
|
||||
elif "Answer_Samll" in item:
|
||||
answer = item["Answer_Samll"]
|
||||
|
||||
if answer is not None:
|
||||
# Handle both string and list formats
|
||||
if isinstance(answer, list):
|
||||
# Join list of characters/strings into a single string
|
||||
retrieve_info.append(''.join(str(x) for x in answer))
|
||||
elif isinstance(answer, str):
|
||||
retrieve_info.append(answer)
|
||||
else:
|
||||
retrieve_info.append(str(answer))
|
||||
|
||||
# Join all retrieve_info into a single string
|
||||
retrieve_info_str = '\n\n'.join(retrieve_info) if retrieve_info else ""
|
||||
|
||||
# Get conversation history
|
||||
history = await session_service.get_history(sessionid, apply_id, group_id)
|
||||
# Override with empty list for now (as in original)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve_Summary: initialization failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"summary_result": "信息不足,无法回答"
|
||||
}
|
||||
|
||||
try:
|
||||
# Render template
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='Retrieve_Summary_prompt.jinja2',
|
||||
operation_name='retrieve_summary',
|
||||
query=query,
|
||||
history=history,
|
||||
retrieve_info=retrieve_info_str
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for Retrieve_Summary: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
try:
|
||||
# Call LLM with structured response
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=RetrieveSummaryResponse
|
||||
)
|
||||
|
||||
# Handle case where structured response might be None or incomplete
|
||||
if structured and hasattr(structured, 'data') and structured.data:
|
||||
aimessages = structured.data.query_answer or ""
|
||||
else:
|
||||
logger.warning("Structured response is None or incomplete, using default message")
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
|
||||
# Check for insufficient information response
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages)!="":
|
||||
# Save session
|
||||
await session_service.save_session(
|
||||
user_id=sessionid,
|
||||
query=query,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
ai_response=aimessages
|
||||
)
|
||||
logger.info(f"sessionid: {aimessages} 写入成功")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve_Summary: LLM call failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
aimessages = ""
|
||||
# Cleanup duplicate sessions
|
||||
await session_service.cleanup_duplicates()
|
||||
|
||||
# Use fallback if empty
|
||||
if aimessages == '':
|
||||
aimessages = '信息不足,无法回答'
|
||||
|
||||
logger.info(f"检索之后的总结==>>:{aimessages}")
|
||||
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('检索总结', duration)
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "retrieval_summary",
|
||||
"summary": aimessages,
|
||||
"query": query,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Input_Summary(
|
||||
ctx: Context,
|
||||
context: str,
|
||||
usermessages: str,
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
) -> dict:
|
||||
"""
|
||||
Generate a quick summary for direct input without verification.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: String containing the input sentence
|
||||
usermessages: User messages identifier
|
||||
search_switch: Search switch value for routing ('2' for summaries only)
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
|
||||
user_rag_memory_id: User RAG memory identifier
|
||||
|
||||
Returns:
|
||||
dict: Contains 'query_answer' with the summary result
|
||||
"""
|
||||
start = time.time()
|
||||
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
|
||||
# Initialize variables to avoid UnboundLocalError
|
||||
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
search_service = get_context_resource(ctx, 'search_service')
|
||||
|
||||
# Check if llm_client is None
|
||||
if llm_client is None:
|
||||
error_msg = "LLM client is not available. Please check server configuration and SELECTED_LLM_ID environment variable."
|
||||
logger.error(error_msg)
|
||||
return error_msg
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages) or ""
|
||||
sessionid = sessionid.replace('call_id_', '')
|
||||
|
||||
# Get conversation history
|
||||
history = await session_service.get_history(
|
||||
str(sessionid),
|
||||
str(apply_id),
|
||||
str(group_id)
|
||||
)
|
||||
# Override with empty list for now (as in original)
|
||||
|
||||
# Log the raw context for debugging
|
||||
logger.info(f"Input_Summary: Received context type={type(context)}, value={context[:200] if isinstance(context, str) else context}")
|
||||
|
||||
# Extract sentence from context
|
||||
# Context can be a string or might contain the sentence in various formats
|
||||
try:
|
||||
# Try to parse as JSON first
|
||||
if isinstance(context, str) and (context.startswith('{') or context.startswith('[')):
|
||||
try:
|
||||
import json
|
||||
context_dict = json.loads(context)
|
||||
if isinstance(context_dict, dict):
|
||||
query = context_dict.get('sentence', context_dict.get('content', context))
|
||||
else:
|
||||
query = context
|
||||
except json.JSONDecodeError:
|
||||
# Not valid JSON, try regex
|
||||
match = re.search(r"'sentence':\s*['\"]?(.*?)['\"]?\s*,", context)
|
||||
query = match.group(1) if match else context
|
||||
else:
|
||||
query = context
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract query from context: {e}")
|
||||
query = context
|
||||
|
||||
# Clean query
|
||||
query = str(query).strip().strip("\"'")
|
||||
|
||||
logger.debug(f"Input_Summary: Extracted query='{query}' from context type={type(context)}")
|
||||
|
||||
# Execute search based on search_switch and storage_type
|
||||
try:
|
||||
logger.info(f"search_switch: {search_switch}, storage_type: {storage_type}")
|
||||
|
||||
# Prepare search parameters based on storage type
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"question": query,
|
||||
"return_raw_results": True
|
||||
}
|
||||
|
||||
# Add storage-specific parameters
|
||||
|
||||
'''检索'''
|
||||
if search_switch == '2':
|
||||
search_params["include"] = ["summaries"]
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
raw_results = []
|
||||
retrieve_info = ""
|
||||
kb_config={
|
||||
"knowledge_bases": [
|
||||
{
|
||||
"kb_id": user_rag_memory_id,
|
||||
"similarity_threshold": 0.7,
|
||||
"vector_similarity_weight": 0.5,
|
||||
"top_k": 10,
|
||||
"retrieve_type": "participle"
|
||||
}
|
||||
],
|
||||
"merge_strategy": "weight",
|
||||
"reranker_id":os.getenv('reranker_id'),
|
||||
"reranker_top_k": 10
|
||||
}
|
||||
|
||||
retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
retrieve_info = '\n\n'.join(retrieval_knowledge)
|
||||
raw_results=[retrieve_info]
|
||||
logger.info(f"Input_Summary: Using RAG storage with memory_id={user_rag_memory_id}")
|
||||
except:
|
||||
retrieve_info=''
|
||||
raw_results=['']
|
||||
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
|
||||
else:
|
||||
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(**search_params)
|
||||
logger.info(f"Input_Summary: 使用 summary 进行检索")
|
||||
else:
|
||||
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(**search_params)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Input_Summary: hybrid_search failed, using empty results: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
retrieve_info, question, raw_results = "", query, []
|
||||
|
||||
|
||||
# Render template
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='Retrieve_Summary_prompt.jinja2',
|
||||
operation_name='input_summary',
|
||||
query=query,
|
||||
history=history,
|
||||
retrieve_info=retrieve_info
|
||||
)
|
||||
|
||||
# Call LLM with structured response
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=RetrieveSummaryResponse
|
||||
)
|
||||
aimessages = structured.data.query_answer or "信息不足,无法回答"
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Input_Summary: response_structured failed, using default answer: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
logger.info(f"快速答案总结==>>:{storage_type}--{user_rag_memory_id}--{aimessages}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "input_summary",
|
||||
"title": "快速答案",
|
||||
"summary": aimessages,
|
||||
"query": query,
|
||||
"raw_results": raw_results,
|
||||
"search_mode": "quick_search",
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Input_Summary failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "fail",
|
||||
"summary_result": "信息不足,无法回答",
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
finally:
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('检索', duration)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Summary_fails(
|
||||
ctx: Context,
|
||||
context: str,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
) -> dict:
|
||||
"""
|
||||
Handle workflow failure when summary cannot be generated.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Failure context string
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'query_answer' with failure message
|
||||
"""
|
||||
try:
|
||||
# Extract services from context
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
|
||||
# Parse session ID from usermessages
|
||||
usermessages_parts = usermessages.split('_')[1:]
|
||||
sessionid = '_'.join(usermessages_parts[:-1])
|
||||
|
||||
# Cleanup duplicate sessions
|
||||
await session_service.cleanup_duplicates()
|
||||
|
||||
logger.info(f"没有相关数据")
|
||||
logger.debug(f"Summary_fails called with apply_id: {apply_id}, group_id: {group_id}")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"summary_result": "没有相关数据",
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Summary_fails failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "fail",
|
||||
"summary_result": "没有相关数据",
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"error": str(e)
|
||||
}
|
||||
169
app/core/memory/agent/mcp_server/tools/verification_tools.py
Normal file
169
app/core/memory/agent/mcp_server/tools/verification_tools.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
Verification Tools for data verification.
|
||||
|
||||
This module contains MCP tools for verifying retrieved data.
|
||||
"""
|
||||
import time
|
||||
|
||||
from jinja2 import Template
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.verify_tool import VerifyTool
|
||||
from app.core.memory.agent.utils.messages_tool import (
|
||||
Verify_messages_deal,
|
||||
Retrieve_verify_tool_messages_deal,
|
||||
Resolve_username
|
||||
)
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Verify(
|
||||
ctx: Context,
|
||||
context: dict,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
) -> dict:
|
||||
"""
|
||||
Verify the retrieved data.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Dictionary containing query and expansion issues
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'status' and 'verified_data' with verification results
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
|
||||
# Load verification prompt template
|
||||
file_path = PROJECT_ROOT_ + '/agent/utils/prompt/split_verify_prompt.jinja2'
|
||||
|
||||
# Read template file directly (VerifyTool expects raw template content)
|
||||
from app.core.memory.agent.utils.messages_tool import read_template_file
|
||||
system_prompt = await read_template_file(file_path)
|
||||
|
||||
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages)
|
||||
|
||||
# Get conversation history
|
||||
history = await session_service.get_history(sessionid, apply_id, group_id)
|
||||
|
||||
template = Template(system_prompt)
|
||||
system_prompt = template.render(history=history, sentence=context)
|
||||
|
||||
# Process context to extract query and results
|
||||
Query_small, Result_small, query = await Verify_messages_deal(context)
|
||||
|
||||
# Build query list for verification
|
||||
query_list = []
|
||||
for query_small, anser in zip(Query_small, Result_small):
|
||||
query_list.append({
|
||||
'Query_small': query_small,
|
||||
'Answer_Small': anser
|
||||
})
|
||||
|
||||
messages = {
|
||||
"Query": query,
|
||||
"Expansion_issue": query_list
|
||||
}
|
||||
|
||||
|
||||
|
||||
# Call verification workflow
|
||||
verify_tool = VerifyTool(system_prompt, messages)
|
||||
verify_result = await verify_tool.verify()
|
||||
|
||||
# Parse LLM verification result with error handling
|
||||
try:
|
||||
messages_deal = await Retrieve_verify_tool_messages_deal(
|
||||
verify_result,
|
||||
history,
|
||||
query
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve_verify_tool_messages_deal parsing failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Fallback to avoid 500 errors
|
||||
messages_deal = {
|
||||
"data": {
|
||||
"query": query,
|
||||
"expansion_issue": []
|
||||
},
|
||||
"split_result": "failed",
|
||||
"reason": str(e),
|
||||
"history": history,
|
||||
}
|
||||
|
||||
logger.info(f"验证==>>:{messages_deal}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
"status": "success",
|
||||
"verified_data": messages_deal,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "verification",
|
||||
"title": "数据验证",
|
||||
"result": messages_deal.get("split_result", "unknown"),
|
||||
"reason": messages_deal.get("reason", ""),
|
||||
"query": query,
|
||||
"verified_count": len(query_list),
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Verify failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"verified_data": {
|
||||
"data": {
|
||||
"query": "",
|
||||
"expansion_issue": []
|
||||
},
|
||||
"split_result": "failed",
|
||||
"reason": str(e),
|
||||
"history": [],
|
||||
}
|
||||
}
|
||||
|
||||
finally:
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('验证', duration)
|
||||
7
app/core/memory/agent/utils/__init__.py
Normal file
7
app/core/memory/agent/utils/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Agent utilities."""
|
||||
|
||||
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
|
||||
|
||||
__all__ = [
|
||||
"MultimodalProcessor",
|
||||
]
|
||||
70
app/core/memory/agent/utils/get_dialogs.py
Normal file
70
app/core/memory/agent/utils/get_dialogs.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import os
|
||||
import json
|
||||
from typing import List
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
|
||||
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
|
||||
|
||||
|
||||
async def get_chunked_dialogs(
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
group_id: str = "group_1",
|
||||
user_id: str = "user1",
|
||||
apply_id: str = "applyid",
|
||||
content: str = "这是用户的输入",
|
||||
ref_id: str = "wyl_20251027",
|
||||
config_id: str = None
|
||||
) -> List[DialogData]:
|
||||
"""Generate chunks from all test data entries using the specified chunker strategy.
|
||||
|
||||
Args:
|
||||
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
||||
group_id: Group identifier
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
content: Dialog content
|
||||
ref_id: Reference identifier
|
||||
config_id: Configuration ID for processing
|
||||
|
||||
Returns:
|
||||
List of DialogData objects with generated chunks for each test entry
|
||||
"""
|
||||
dialog_data_list = []
|
||||
messages = []
|
||||
|
||||
messages.append(ConversationMessage(role="用户", msg=content))
|
||||
|
||||
# Create DialogData
|
||||
conversation_context = ConversationContext(msgs=messages)
|
||||
# Create DialogData with group_id based on the entry's id for uniqueness
|
||||
dialog_data = DialogData(
|
||||
context=conversation_context,
|
||||
ref_id=ref_id,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
config_id=config_id
|
||||
)
|
||||
# Create DialogueChunker and process the dialogue
|
||||
chunker = DialogueChunker(chunker_strategy)
|
||||
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
||||
dialog_data.chunks = extracted_chunks
|
||||
|
||||
dialog_data_list.append(dialog_data)
|
||||
|
||||
# Convert to dict with datetime serialized
|
||||
def serialize_datetime(obj):
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
|
||||
|
||||
combined_output = [dd.model_dump() for dd in dialog_data_list]
|
||||
|
||||
print(dialog_data_list)
|
||||
|
||||
# with open(os.path.join(os.path.dirname(__file__), "chunker_test_output.txt"), "w", encoding="utf-8") as f:
|
||||
# json.dump(combined_output, f, ensure_ascii=False, indent=4, default=serialize_datetime)
|
||||
|
||||
|
||||
return dialog_data_list
|
||||
204
app/core/memory/agent/utils/llm_tools.py
Normal file
204
app/core/memory/agent/utils/llm_tools.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import asyncio
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from typing import TypedDict, Annotated
|
||||
import os
|
||||
import logging
|
||||
|
||||
from jinja2 import Template
|
||||
from langchain_core.messages import AnyMessage
|
||||
from dotenv import load_dotenv
|
||||
from langgraph.graph import add_messages
|
||||
from openai import OpenAI
|
||||
|
||||
from app.core.memory.agent.utils.messages_tool import read_template_file
|
||||
from app.core.memory.utils.config.config_utils import get_picture_config, get_voice_config
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID, SELECTED_LLM_PICTURE_NAME, SELECTED_LLM_VOICE_NAME
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.src.llm_tools.openai_client import OpenAIClient
|
||||
|
||||
PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
#TODO: Refactor entire picture/voice
|
||||
# async def LLM_model_request(context,data,query):
|
||||
# '''
|
||||
# Agent model request
|
||||
# Args:
|
||||
# context:Input request
|
||||
# data: template parameters
|
||||
# query:request content
|
||||
# Returns:
|
||||
|
||||
# '''
|
||||
# template = Template(context)
|
||||
# system_prompt = template.render(**data)
|
||||
# llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
# result = await llm_client.chat(
|
||||
# messages=[{"role": "system", "content": system_prompt}] + [{"role": "user", "content": query}]
|
||||
# )
|
||||
# return result
|
||||
|
||||
async def picture_model_requests(image_url):
|
||||
'''
|
||||
|
||||
Args:
|
||||
image_url:
|
||||
Returns:
|
||||
|
||||
'''
|
||||
file_path = PROJECT_ROOT_ + '/agent/utils/prompt/Template_for_image_recognition_prompt.jinja2 '
|
||||
system_prompt = await read_template_file(file_path)
|
||||
result = await Picture_recognize(image_url,system_prompt)
|
||||
return (result)
|
||||
class WriteState(TypedDict):
|
||||
'''
|
||||
Langgrapg Writing TypedDict
|
||||
'''
|
||||
messages: Annotated[list[AnyMessage], add_messages]
|
||||
user_id:str
|
||||
apply_id:str
|
||||
group_id:str
|
||||
|
||||
class ReadState(TypedDict):
|
||||
'''
|
||||
Langgrapg READING TypedDict
|
||||
name:
|
||||
id:user id
|
||||
loop_count:Traverse times
|
||||
search_switch:type
|
||||
config_id: configuration id for filtering results
|
||||
'''
|
||||
messages: Annotated[list[AnyMessage], add_messages] #消息追加的模式增加消息
|
||||
name: str
|
||||
id: str
|
||||
loop_count:int
|
||||
search_switch: str
|
||||
user_id: str
|
||||
apply_id: str
|
||||
group_id: str
|
||||
config_id: str
|
||||
|
||||
|
||||
class COUNTState:
|
||||
'''
|
||||
The number of times the workflow dialogue retrieval content has no correct message recall traversal
|
||||
'''
|
||||
def __init__(self, limit: int = 5):
|
||||
self.total: int = 0 # 当前累加值
|
||||
self.limit: int = limit # 最大上限
|
||||
|
||||
def add(self, value: int = 1):
|
||||
"""累加数字,如果达到上限就保持最大值"""
|
||||
self.total += value
|
||||
print(f"[COUNTState] 当前值: {self.total}")
|
||||
if self.total >= self.limit:
|
||||
print(f"[COUNTState] 达到上限 {self.limit}")
|
||||
self.total = self.limit # 达到上限不再增加
|
||||
|
||||
def get_total(self) -> int:
|
||||
"""获取当前累加值"""
|
||||
return self.total
|
||||
|
||||
def reset(self):
|
||||
"""手动重置累加值"""
|
||||
self.total = 0
|
||||
print(f"[COUNTState] 已重置为 0")
|
||||
|
||||
|
||||
|
||||
# def embed(texts: list[str]) -> list[list[float]]:
|
||||
# # 这里可以换成 LangChain Embeddings
|
||||
# return [[float(len(t) % 5), float(len(t) % 3)] for t in texts]
|
||||
|
||||
|
||||
# def export_store_to_json(store, namespace):
|
||||
# """Export the entire storage content to a JSON file"""
|
||||
# # 搜索所有存储项
|
||||
# all_items = store.search(namespace)
|
||||
|
||||
# # 整理数据
|
||||
# export_data = {}
|
||||
# for item in all_items:
|
||||
# if hasattr(item, 'key') and hasattr(item, 'value'):
|
||||
# export_data[item.key] = item.value
|
||||
|
||||
# # 保存到文件
|
||||
# os.makedirs("memory_logs", exist_ok=True)
|
||||
# with open("memory_logs/full_memory_export.json", "w", encoding="utf-8") as f:
|
||||
# json.dump(export_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# print(f"{len(export_data)} 条记忆到 JSON 文件")
|
||||
|
||||
def merge_to_key_value_pairs(data, query_key, result_key):
|
||||
grouped = defaultdict(list)
|
||||
for item in data:
|
||||
grouped[item[query_key]].append(item[result_key])
|
||||
return [{key: values} for key, values in grouped.items()]
|
||||
|
||||
def deduplicate_entries(entries):
|
||||
seen = set()
|
||||
deduped = []
|
||||
for entry in entries:
|
||||
key = (entry['Query_small'], entry['Result_small'])
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
deduped.append(entry)
|
||||
return deduped
|
||||
|
||||
|
||||
|
||||
async def Picture_recognize(image_path,PROMPT_TICKET_EXTRACTION) -> str:
|
||||
try:
|
||||
model_config = get_picture_config(SELECTED_LLM_PICTURE_NAME)
|
||||
except Exception as e:
|
||||
err = f"LLM配置不可用:{str(e)}。请检查 config.json 和 runtime.json。"
|
||||
logger.error(err)
|
||||
return err
|
||||
api_key = os.getenv(model_config["api_key"]) # 从环境变量读取对应后端的 API key
|
||||
backend_model_name = model_config["llm_name"].split("/")[-1]
|
||||
api_base=model_config['api_base']
|
||||
|
||||
logger.info(f"model_name: {backend_model_name}")
|
||||
logger.info(f"api_key set: {'yes' if api_key else 'no'}")
|
||||
logger.info(f"base_url: {model_config['api_base']}")
|
||||
|
||||
client = OpenAI(
|
||||
api_key=api_key, base_url=api_base,
|
||||
)
|
||||
completion = client.chat.completions.create(
|
||||
model=backend_model_name,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url":image_path,
|
||||
},
|
||||
{"type": "text",
|
||||
"text": PROMPT_TICKET_EXTRACTION}
|
||||
]
|
||||
}
|
||||
])
|
||||
picture_text = completion.choices[0].message.content
|
||||
picture_text = picture_text.replace('```json', '').replace('```', '')
|
||||
picture_text = json.loads(picture_text)
|
||||
return (picture_text['statement'])
|
||||
|
||||
async def Voice_recognize():
|
||||
try:
|
||||
model_config = get_voice_config(SELECTED_LLM_VOICE_NAME)
|
||||
except Exception as e:
|
||||
err = f"LLM配置不可用:{str(e)}。请检查 config.json 和 runtime.json。"
|
||||
logger.error(err)
|
||||
return err
|
||||
api_key = os.getenv(model_config["api_key"]) # 从环境变量读取对应后端的 API key
|
||||
backend_model_name = model_config["llm_name"].split("/")[-1]
|
||||
api_base = model_config['api_base']
|
||||
return api_key,backend_model_name,api_base
|
||||
|
||||
|
||||
15
app/core/memory/agent/utils/mcp_tools.py
Normal file
15
app/core/memory/agent/utils/mcp_tools.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from app.core.config import settings
|
||||
|
||||
def get_mcp_server_config():
|
||||
"""
|
||||
Get the MCP server configuration
|
||||
"""
|
||||
mcp_server_config = {
|
||||
"data_flow": {
|
||||
"url": f"http://{settings.SERVER_IP}:8081/sse", # 你前面的 FastMCP(weather) 服务端口
|
||||
"transport": "sse",
|
||||
"timeout": 15000,
|
||||
"sse_read_timeout": 15000,
|
||||
}
|
||||
}
|
||||
return mcp_server_config
|
||||
239
app/core/memory/agent/utils/messages_tool.py
Normal file
239
app/core/memory/agent/utils/messages_tool.py
Normal file
@@ -0,0 +1,239 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Any
|
||||
|
||||
from langchain_core.messages import AnyMessage
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
def _to_openai_messages(msgs: List[AnyMessage]) -> List[dict]:
|
||||
out = []
|
||||
for m in msgs:
|
||||
if hasattr(m, "content"):
|
||||
out.append({"role": "user", "content": getattr(m, "content", "")})
|
||||
elif isinstance(m, dict) and "role" in m and "content" in m:
|
||||
out.append(m)
|
||||
else:
|
||||
out.append({"role": "user", "content": str(m)})
|
||||
return out
|
||||
|
||||
|
||||
def _extract_content(resp: Any) -> str:
|
||||
"""Extract LLM content and sanitize to raw JSON/text.
|
||||
|
||||
- Supports both object and dict response shapes.
|
||||
- Removes leading role labels (e.g., "Assistant:").
|
||||
- Strips Markdown code fences like ```json ... ```.
|
||||
- Attempts to isolate the first valid JSON array/object block when extra text is present.
|
||||
"""
|
||||
|
||||
def _to_text(r: Any) -> str:
|
||||
try:
|
||||
# 对象形式: resp.choices[0].message.content
|
||||
if hasattr(r, "choices") and getattr(r, "choices", None):
|
||||
msg = r.choices[0].message
|
||||
if hasattr(msg, "content"):
|
||||
return msg.content
|
||||
if isinstance(msg, dict) and "content" in msg:
|
||||
return msg["content"]
|
||||
# 字典形式: resp["choices"][0]["message"]["content"]
|
||||
if isinstance(r, dict):
|
||||
return r.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||
except Exception:
|
||||
pass
|
||||
return str(r)
|
||||
|
||||
def _clean_text(text: str) -> str:
|
||||
s = str(text).strip()
|
||||
# 移除可能的角色前缀
|
||||
s = re.sub(r"^\s*(Assistant|assistant)\s*:\s*", "", s)
|
||||
# 提取 ```json ... ``` 代码块
|
||||
m = re.search(r"```json\s*(.*?)\s*```", s, flags=re.S | re.I)
|
||||
if m:
|
||||
s = m.group(1).strip()
|
||||
# 如果仍然包含多余文本,尝试截取第一个 JSON 数组/对象片段
|
||||
if not (s.startswith("{") or s.startswith("[")):
|
||||
left = s.find("[")
|
||||
right = s.rfind("]")
|
||||
if left != -1 and right != -1 and right > left:
|
||||
s = s[left:right + 1].strip()
|
||||
else:
|
||||
left = s.find("{")
|
||||
right = s.rfind("}")
|
||||
if left != -1 and right != -1 and right > left:
|
||||
s = s[left:right + 1].strip()
|
||||
return s
|
||||
|
||||
raw = _to_text(resp)
|
||||
return _clean_text(raw)
|
||||
|
||||
def Resolve_username(usermessages):
|
||||
'''
|
||||
Extract username
|
||||
Args:
|
||||
usermessages: user name
|
||||
|
||||
Returns:
|
||||
|
||||
'''
|
||||
usermessages = usermessages.split('_')[1:]
|
||||
sessionid = '_'.join(usermessages[:-1])
|
||||
return sessionid
|
||||
|
||||
|
||||
# TODO: USE app.core.memory.src.utils.render_template instead
|
||||
async def read_template_file(template_path: str) -> str:
|
||||
"""
|
||||
读取模板文件
|
||||
|
||||
Args:
|
||||
template_path: 模板文件路径
|
||||
|
||||
Returns:
|
||||
模板内容字符串
|
||||
|
||||
Note:
|
||||
建议使用 app.core.memory.utils.template_render 中的统一模板渲染功能
|
||||
"""
|
||||
try:
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
logger.error(f"模板文件未找到: {template_path}")
|
||||
raise
|
||||
except IOError as e:
|
||||
logger.error(f"读取模板文件失败: {template_path}, 错误: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
async def Problem_Extension_messages_deal(context):
|
||||
'''
|
||||
Extract data
|
||||
Args:
|
||||
context:
|
||||
Returns:
|
||||
'''
|
||||
extent_quest = []
|
||||
original = context.get('original', '')
|
||||
messages = context.get('context', '')
|
||||
messages = json.loads(messages)
|
||||
for message in messages:
|
||||
question = message.get('question', '')
|
||||
type = message.get('type', '')
|
||||
extent_quest.append({"role": "user", "content": f"问题:{question};问题类型:{type}"})
|
||||
|
||||
return extent_quest, original
|
||||
|
||||
|
||||
async def Retriev_messages_deal(context):
|
||||
'''
|
||||
Extract data
|
||||
Args:
|
||||
context:
|
||||
Returns:
|
||||
'''
|
||||
if isinstance(context, dict):
|
||||
if 'context' in context or 'original' in context:
|
||||
return context.get('context', {}), context.get('original', '')
|
||||
return content, original_value
|
||||
|
||||
async def Verify_messages_deal(context):
|
||||
'''
|
||||
Extract data
|
||||
Args:
|
||||
context:
|
||||
Returns:
|
||||
'''
|
||||
|
||||
query = context['context']['Query']
|
||||
Query_small_list = context['context']['Expansion_issue']
|
||||
Result_small = []
|
||||
Query_small = []
|
||||
for i in Query_small_list:
|
||||
Result_small.append(i['Answer_Small'][0])
|
||||
Query_small.append(i['Query_small'])
|
||||
return Query_small, Result_small, query
|
||||
|
||||
|
||||
async def Summary_messages_deal(context):
|
||||
'''
|
||||
Extract data
|
||||
Args:
|
||||
context:
|
||||
Returns:
|
||||
'''
|
||||
messages = str(context).replace('\\n', '').replace('\n', '').replace('\\', '')
|
||||
query = re.findall(r'"query": (.*?),', messages)[0]
|
||||
query = query.replace('[', '').replace(']', '').strip()
|
||||
matches = re.findall(r'"answer_small"\s*:\s*"(\[.*?\])"', messages)
|
||||
answer_small_texts = []
|
||||
for m in matches:
|
||||
try:
|
||||
parsed = json.loads(m)
|
||||
for item in parsed:
|
||||
answer_small_texts.append(item.strip().replace('\\', '').replace('[', '').replace(']', ''))
|
||||
except Exception:
|
||||
answer_small_texts.append(m.strip().replace('\\', '').replace('[', '').replace(']', ''))
|
||||
|
||||
return answer_small_texts, query
|
||||
|
||||
|
||||
async def VerifyTool_messages_deal(context):
|
||||
'''
|
||||
Extract data
|
||||
Args:
|
||||
context:
|
||||
Returns:
|
||||
'''
|
||||
messages = str(context).replace('\\n', '').replace('\n', '').replace('\\', '')
|
||||
content_messages = messages.split('"context":')[1].replace('""', '"')
|
||||
messages = str(content_messages).split("name='Retrieve'")[0]
|
||||
query = re.findall(f'"Query": "(.*?)"', messages)[0]
|
||||
Query_small = re.findall(f'"Query_small": "(.*?)"', messages)
|
||||
Result_small = re.findall(f'"Result_small": "(.*?)"', messages)
|
||||
return Query_small, Result_small, query
|
||||
|
||||
|
||||
async def Retrieve_Summary_messages_deal(context):
|
||||
pass
|
||||
|
||||
|
||||
async def Retrieve_verify_tool_messages_deal(context, history, query):
|
||||
'''
|
||||
Extract data
|
||||
Args:
|
||||
context:
|
||||
Returns:
|
||||
'''
|
||||
results = []
|
||||
# 统一转为字符串,避免 None 或非字符串导致正则报错
|
||||
text = str(context)
|
||||
blocks = re.findall(r'\{(.*?)\}', text, flags=re.S)
|
||||
for block in blocks:
|
||||
query_small = re.search(r'"Query_small"\s*:\s*"([^"]*)"', block)
|
||||
answer_small = re.search(r'"Answer_Small"\s*:\s*(\[[^\]]*\])', block)
|
||||
status = re.search(r'"status"\s*:\s*"([^"]*)"', block)
|
||||
query_answer = re.search(r'"Query_answer"\s*:\s*"([^"]*)"', block)
|
||||
|
||||
results.append({
|
||||
"query_small": query_small.group(1) if query_small else None,
|
||||
"answer_small": answer_small.group(1) if answer_small else None,
|
||||
# 将缺失的 status 统一为空字符串,后续用字符串判定,避免 NoneType 错误
|
||||
"status": status.group(1) if status else "",
|
||||
"query_answer": query_answer.group(1) if query_answer else None
|
||||
})
|
||||
result = []
|
||||
for r in results:
|
||||
# 统一按字符串判定状态,兼容大小写和缺失情况
|
||||
status_str = str(r.get('status', '')).strip().lower()
|
||||
if status_str == 'false':
|
||||
continue
|
||||
else:
|
||||
result.append(r)
|
||||
split_result = 'failed' if not result else 'success'
|
||||
result = {"data": {"query": query, "expansion_issue": result}, "split_result": split_result, "reason": "",
|
||||
"history": history}
|
||||
return result
|
||||
38
app/core/memory/agent/utils/model_tool.py
Normal file
38
app/core/memory/agent/utils/model_tool.py
Normal file
@@ -0,0 +1,38 @@
|
||||
|
||||
|
||||
# project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
# sys.path.insert(0, project_root)
|
||||
|
||||
# load_dotenv()
|
||||
|
||||
# async def llm_client_chat(messages: List[dict]) -> str:
|
||||
# """使用 OpenAI 兼容接口进行对话,返回内容字符串。"""
|
||||
# try:
|
||||
# cfg = get_model_config(SELECTED_LLM_ID)
|
||||
# rb_config = RedBearModelConfig(
|
||||
# model_name=cfg["model_name"],
|
||||
# provider=cfg["provider"],
|
||||
# api_key=cfg["api_key"],
|
||||
# base_url=cfg["base_url"],
|
||||
# )
|
||||
# client = OpenAIClient(model_config=rb_config, type_="chat")
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"获取模型配置失败:{e}")
|
||||
# err = f"获取模型配置失败:{str(e)}。请检查!!!"
|
||||
# return err
|
||||
# try:
|
||||
# response = await client.chat(messages)
|
||||
# print(f"model_tool's llm_client_chat response ======>:\n {response}")
|
||||
# return _extract_content(response)
|
||||
# # return _extract_content(result)
|
||||
# except Exception as e:
|
||||
# logger.error(f"LLM调用失败:{str(e)}。请检查 model_name、api_key、api_base 是否正确。")
|
||||
# return f"LLM调用失败:{str(e)}。请检查 model_name、api_key、api_base 是否正确。"
|
||||
|
||||
# async def main(image_url):
|
||||
# await llm_client_chat(image_url)
|
||||
#
|
||||
# # 运行主函数
|
||||
# asyncio.run(main(['https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_male2.wav']))
|
||||
#
|
||||
131
app/core/memory/agent/utils/multimodal.py
Normal file
131
app/core/memory/agent/utils/multimodal.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
Multimodal input processor for handling image and audio content.
|
||||
|
||||
This module provides utilities for detecting and processing multimodal inputs
|
||||
(images and audio files) by converting them to text using appropriate models.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from app.core.memory.agent.multimodal.speech_model import Vico_recognition
|
||||
from app.core.memory.agent.utils.llm_tools import picture_model_requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MultimodalProcessor:
|
||||
"""
|
||||
Processor for handling multimodal inputs (images and audio).
|
||||
|
||||
This class detects image and audio file paths in input content and converts
|
||||
them to text using appropriate recognition models.
|
||||
"""
|
||||
|
||||
# Supported file extensions
|
||||
IMAGE_EXTENSIONS = ['.jpg', '.png']
|
||||
AUDIO_EXTENSIONS = [
|
||||
'aac', 'amr', 'avi', 'flac', 'flv', 'm4a', 'mkv', 'mov',
|
||||
'mp3', 'mp4', 'mpeg', 'ogg', 'opus', 'wav', 'webm', 'wma', 'wmv'
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the multimodal processor."""
|
||||
pass
|
||||
|
||||
def is_image(self, content: str) -> bool:
|
||||
"""
|
||||
Check if content is an image file path.
|
||||
|
||||
Args:
|
||||
content: Input string to check
|
||||
|
||||
Returns:
|
||||
True if content ends with a supported image extension
|
||||
|
||||
Examples:
|
||||
>>> processor = MultimodalProcessor()
|
||||
>>> processor.is_image("photo.jpg")
|
||||
True
|
||||
>>> processor.is_image("document.pdf")
|
||||
False
|
||||
"""
|
||||
if not isinstance(content, str):
|
||||
return False
|
||||
|
||||
content_lower = content.lower()
|
||||
return any(content_lower.endswith(ext) for ext in self.IMAGE_EXTENSIONS)
|
||||
|
||||
def is_audio(self, content: str) -> bool:
|
||||
"""
|
||||
Check if content is an audio file path.
|
||||
|
||||
Args:
|
||||
content: Input string to check
|
||||
|
||||
Returns:
|
||||
True if content ends with a supported audio extension
|
||||
|
||||
Examples:
|
||||
>>> processor = MultimodalProcessor()
|
||||
>>> processor.is_audio("recording.mp3")
|
||||
True
|
||||
>>> processor.is_audio("video.mp4")
|
||||
True
|
||||
>>> processor.is_audio("document.txt")
|
||||
False
|
||||
"""
|
||||
if not isinstance(content, str):
|
||||
return False
|
||||
|
||||
content_lower = content.lower()
|
||||
return any(content_lower.endswith(f'.{ext}') for ext in self.AUDIO_EXTENSIONS)
|
||||
|
||||
async def process_input(self, content: str) -> str:
|
||||
"""
|
||||
Process input content, converting images/audio to text if needed.
|
||||
|
||||
This method detects if the input is an image or audio file and converts
|
||||
it to text using the appropriate recognition model. If processing fails
|
||||
or the content is not multimodal, it returns the original content.
|
||||
|
||||
Args:
|
||||
content: Input string (may be file path or regular text)
|
||||
|
||||
Returns:
|
||||
Text content (original or converted from image/audio)
|
||||
|
||||
Examples:
|
||||
>>> processor = MultimodalProcessor()
|
||||
>>> await processor.process_input("photo.jpg")
|
||||
"Recognized text from image..."
|
||||
|
||||
>>> await processor.process_input("Hello world")
|
||||
"Hello world"
|
||||
"""
|
||||
if not isinstance(content, str):
|
||||
logger.warning(f"[MultimodalProcessor] Content is not a string: {type(content)}")
|
||||
return str(content)
|
||||
|
||||
try:
|
||||
# Check for image input
|
||||
if self.is_image(content):
|
||||
logger.info(f"[MultimodalProcessor] Detected image input: {content}")
|
||||
result = await picture_model_requests(content)
|
||||
logger.info(f"[MultimodalProcessor] Image recognition result: {result[:100]}...")
|
||||
return result
|
||||
|
||||
# Check for audio input
|
||||
if self.is_audio(content):
|
||||
logger.info(f"[MultimodalProcessor] Detected audio input: {content}")
|
||||
result = await Vico_recognition([content]).run()
|
||||
logger.info(f"[MultimodalProcessor] Audio recognition result: {result[:100]}...")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[MultimodalProcessor] Error processing multimodal input: {e}", exc_info=True)
|
||||
logger.info(f"[MultimodalProcessor] Falling back to original content")
|
||||
return content
|
||||
|
||||
# Return original content if not multimodal
|
||||
return content
|
||||
@@ -0,0 +1,81 @@
|
||||
|
||||
你是一个高效的问题拆分助手,任务是根据用户提供的原始问题和问题类型,生成可操作的扩展问题,用于精确回答原问题。请严格遵循以下规则:
|
||||
|
||||
角色:
|
||||
- 你是“问题拆分专家”,专注于逻辑、信息完整性和可操作性。
|
||||
- 你能够结合【历史信息】、【上下文】、【背景知识】进行分析,以保持问题拆分的连贯性和相关性。
|
||||
- 如果历史信息或上下文与当前问题无关,可忽略。
|
||||
|
||||
---
|
||||
|
||||
### 历史信息参考
|
||||
在生成扩展问题时,你可以参考以下历史数据(如果提供):
|
||||
- 历史对话或任务的主题;
|
||||
- 历史中出现的关键实体(时间、人物、地点、研究主题等);
|
||||
- 历史中已解答的问题(避免重复);
|
||||
- 历史推理链(保持逻辑一致性)。
|
||||
|
||||
> 如果没有提供历史信息,则仅根据当前输入问题进行分析。
|
||||
输入历史信息内容:{{history}}
|
||||
|
||||
## User Input
|
||||
{% if questions is string %}
|
||||
{{ questions }}
|
||||
{% else %}
|
||||
{% for question in questions %}
|
||||
- {{ question }}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
需求:
|
||||
- 如果问题是单跳问题(单步可答),直接保留原问题提取重要提问部分作为拆分/扩展问题。
|
||||
- 如果问题是多跳问题(需多个信息点才能回答),对问题进行扩展拆分。
|
||||
- 扩展问题必须完整覆盖原问题的所有关键要素,包括时间、主体、动作、目标等,不得遗漏。
|
||||
- 扩展问题不得冗余:避免重复询问相同信息或过度拆分同一主题。
|
||||
- 扩展问题必须高度相关:每个子问题直接服务于原问题,不引入未提及的新概念、人物或细节。
|
||||
- 扩展问题必须可操作:每个子问题能在有限资源下独立解答。
|
||||
- 子问题数量不超过4个。
|
||||
- 拆分问题的时候可以考虑输入的历史内容,以保持逻辑连贯。
|
||||
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
||||
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
||||
|
||||
|
||||
|
||||
输出要求:
|
||||
- 仅输出 JSON 数组,不要包含任何解释或代码块。
|
||||
- 每个元素包含:
|
||||
- `original_question`: 原始问题
|
||||
- `extended_question`: 扩展后的问题
|
||||
- `type`: 类型(事实检索/澄清/定义/比较/行动建议)
|
||||
- `reason`: 生成该扩展问题的简短理由
|
||||
- 使用标准 ASCII 双引号,无换行;确保字符串正确关闭并以逗号分隔。
|
||||
|
||||
示例:
|
||||
输入:
|
||||
[
|
||||
"问题:今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?;问题类型:多跳",
|
||||
]
|
||||
|
||||
输出:
|
||||
[
|
||||
{
|
||||
"original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?",
|
||||
"extended_question": "今年诺贝尔物理学奖的获奖者有哪些人?",
|
||||
"type": "多跳",
|
||||
"reason": "输出原问题的关键要素"
|
||||
},
|
||||
{
|
||||
"original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?",
|
||||
"extended_question": "今年诺贝尔物理学奖的获奖者是因哪些具体贡献获奖的?",
|
||||
"type": "多跳",
|
||||
"reason": "输出原问题的关键要素"
|
||||
}
|
||||
]
|
||||
**Output format**
|
||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
||||
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
|
||||
3. Ensure all JSON strings are properly closed and comma-separated
|
||||
4. Do not include line breaks within JSON string values
|
||||
|
||||
The output language should always be the same as the input language.{{ json_schema }}
|
||||
@@ -0,0 +1,37 @@
|
||||
# 角色
|
||||
你是一个专业的问答助手,擅长基于检索信息和历史对话回答用户问题。
|
||||
|
||||
# 任务
|
||||
根据提供的上下文信息回答用户的问题。
|
||||
|
||||
# 输入信息
|
||||
- 历史对话:{{history}}
|
||||
- 检索信息:{{retrieve_info}}
|
||||
|
||||
## User Query
|
||||
{{query}}
|
||||
|
||||
# 回答指南
|
||||
1. 仔细分析用户的问题
|
||||
2. 优先使用检索信息中的相关内容回答
|
||||
3. 结合历史对话提供连贯的回复
|
||||
4. 如果信息不足:
|
||||
- 对于简单问候或日常对话,给出自然简短的回复
|
||||
- 对于复杂问题,诚实说明信息不足
|
||||
5. 保持回答简洁、相关、自然
|
||||
6. 使用与问题相同的语言回答
|
||||
|
||||
**Output format**
|
||||
- 直接回答问题,像人类对话一样自然流畅
|
||||
- 不要提及"检索信息"、"搜索结果"、"根据资料"等技术术语
|
||||
- 不要解释推理过程或评论信息来源
|
||||
- 如果只能部分回答问题,先回答能回答的部分,然后说明哪些方面信息不足
|
||||
- 如果完全无法回答,简洁地说明:"信息不足,无法回答。"
|
||||
|
||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
||||
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
|
||||
3. Ensure all JSON strings are properly closed and comma-separated
|
||||
4. Do not include line breaks within JSON string values
|
||||
|
||||
The output language should always be the same as the input language.{{ json_schema }}
|
||||
29
app/core/memory/agent/utils/prompt/Retrieve_prompt.jinja2
Normal file
29
app/core/memory/agent/utils/prompt/Retrieve_prompt.jinja2
Normal file
@@ -0,0 +1,29 @@
|
||||
|
||||
# 角色:{#InputSlot placeholder="角色名称" mode="input"#}{#/InputSlot#}
|
||||
你是一个智能问答助手,任务如下
|
||||
## 目标:
|
||||
|
||||
1. 接收一个字典,格式为 {'问题': [答案列表]}。
|
||||
2. 接收一个问题(字典中的 key)。
|
||||
3. 找到与问题匹配的答案列表。
|
||||
4. 将答案列表合并成一句自然流畅的话:
|
||||
- 如果答案有两条,使用“是”连接,例如:“A,是B”。
|
||||
- 如果答案有三条或以上,使用“,并且”“另外”等自然连词,保证句子流畅。
|
||||
5. 输出内容时只输出合并后的答案,不输出关键点或其他文字。
|
||||
6. 如果问题未在字典中找到对应答案,请输出:
|
||||
对不起,我没有找到相关信息。
|
||||
|
||||
|
||||
输出要求:
|
||||
- 文本形式
|
||||
---
|
||||
|
||||
字典示例:
|
||||
{
|
||||
'今天的天气怎么样': ['今天天气很好', '今天是晴天']
|
||||
}
|
||||
|
||||
问题示例:
|
||||
今天的天气怎么样
|
||||
输出要求:
|
||||
今天天气很好,是晴天
|
||||
@@ -0,0 +1,10 @@
|
||||
请提图像内的文本
|
||||
返回数据格式以json方式输出,
|
||||
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
||||
- 关键的JSON格式要求{"statement":识别出的文本内容}
|
||||
1.JSON结构仅使用标准ASCII双引号(“)-切勿使用中文引号(“”)或其他Unicode引号
|
||||
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
|
||||
3.确保所有JSON字符串都正确关闭并以逗号分隔
|
||||
4.JSON字符串值中不包括换行符
|
||||
5.正确转义的例子:“statement”:“Zhang Xinhua said:\”我非常喜欢这本书\""
|
||||
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```
|
||||
@@ -0,0 +1,34 @@
|
||||
你是一个输入分类助手,负责判断用户输入的意图类型。
|
||||
|
||||
## User Input
|
||||
{{ user_query }}
|
||||
|
||||
请你根据以下规则判断:
|
||||
1. 如果输入是在寻求信息、提问、请求解释、或疑问句(包括隐含的问题),则分类为 "question"。
|
||||
2. 如果输入是命令、陈述、描述、感叹、或其他类型,不在寻求答案,则分类为 "other"。
|
||||
只输出:
|
||||
{
|
||||
"type": "question"
|
||||
}
|
||||
或
|
||||
{
|
||||
"type": "other"
|
||||
}
|
||||
示例:
|
||||
输入:"Python怎么读取文件?"
|
||||
输出:{"type": "question"}
|
||||
|
||||
输入:"帮我写个读取文件的函数"
|
||||
输出:{"type": "other"}
|
||||
|
||||
输入:"今天是星期几?"
|
||||
输出:{"type": "question"}
|
||||
返回数据格式以json方式输出,
|
||||
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
||||
- 关键的JSON格式要求{"statement":识别出的文本内容}
|
||||
1.JSON结构仅使用标准ASCII双引号(“)-切勿使用中文引号(“”)或其他Unicode引号
|
||||
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
|
||||
3.确保所有JSON字符串都正确关闭并以逗号分隔
|
||||
4.JSON字符串值中不包括换行符
|
||||
5.正确转义的例子:“statement”:“Zhang Xinhua said:\”我非常喜欢这本书\""
|
||||
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```
|
||||
@@ -0,0 +1,160 @@
|
||||
|
||||
# 角色:{#InputSlot placeholder="角色名称" mode="input"#}{#/InputSlot#}
|
||||
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
|
||||
## 目标:
|
||||
你需要根据以下类型对输入数据进行分类,并生成相应的拆分策略和示例。
|
||||
---
|
||||
|
||||
### 历史信息参考
|
||||
在生成扩展问题时,你可以参考以下历史数据(如果提供):
|
||||
- 历史对话或任务的主题;
|
||||
- 历史中出现的关键实体(时间、人物、地点、研究主题等);
|
||||
- 历史中已解答的问题(避免重复);
|
||||
- 历史推理链(保持逻辑一致性)。
|
||||
|
||||
> 如果没有提供历史信息,则仅根据当前输入问题进行分析。
|
||||
输入历史信息内容:{{history}}
|
||||
|
||||
## User Input
|
||||
{{ sentence }}
|
||||
|
||||
## 需求:
|
||||
1:首先判断类型(单跳、多跳、开放域、时间)。
|
||||
2:根据类型进行拆分。
|
||||
3:拆分后的内容需保证信息完整且可独立处理。
|
||||
4:对每个拆分条目,可附加示例或说明。
|
||||
5:拆分问题的时候可以考虑输入的历史内容,以保持逻辑连贯。
|
||||
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
||||
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
||||
|
||||
## 指令:
|
||||
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
|
||||
单跳(Single-hop)
|
||||
描述:问题或数据只需要通过一步即可得到答案或完成拆分,不依赖其他信息。
|
||||
拆分策略:直接识别核心信息或关键字段,生成可独立处理的片段。
|
||||
示例:
|
||||
输入数据:"请列出今年诺贝尔物理学奖的得主"
|
||||
拆分结果:[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": "今年诺贝尔物理学奖得主是谁",
|
||||
"type": "单跳’",
|
||||
}
|
||||
]
|
||||
注意: 当遇到上下文依赖问题时,明确指出缺失的信息类型并且,question可填写输入问题
|
||||
多跳(Multi-hop):
|
||||
描述:问题或数据需要通过多步推理或跨多个信息源才能得到答案。
|
||||
拆分策略:将问题拆解为多个子问题,每个子问题对应一个独立处理步骤,需要具备推理链条与逻辑连接数量。
|
||||
示例:
|
||||
输入数据:"今年诺贝尔物理学奖得主的研究领域及代表性成果"
|
||||
拆分结果:
|
||||
[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": 今年诺贝尔物理学奖得主是谁?",
|
||||
"type": "多跳’",
|
||||
},
|
||||
{
|
||||
"id": "Q2",
|
||||
"question": "该得主的研究领域是什么?",
|
||||
"type": "多跳’",
|
||||
},
|
||||
{
|
||||
"id": "Q3",
|
||||
"question": "该得主的代表性成果有哪些?",
|
||||
"type": "多跳’"
|
||||
}
|
||||
]
|
||||
开放域(Open-domain):
|
||||
描述:问题或数据不局限于特定知识库,需要从大范围信息中检索和生成答案,而不是从一个已知的小范围数据源中查找。。
|
||||
拆分策略:根据主题或关键实体拆分,同时保留上下文以便检索外部知识,问题涉及一般性、常识性、跨学科内容,可能是开放式回答(描述性、推理性、综合性)
|
||||
需要外部知识检索或推理才能确定,比如:“为什么人类需要睡眠?”、“量子计算与经典计算的主要区别是什么?”。
|
||||
示例:
|
||||
输入数据:"介绍量子计算的最新研究进展"
|
||||
拆分结果:
|
||||
[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": 量子计算的基本概念是什么?",
|
||||
"type": "开放域’",
|
||||
},
|
||||
{
|
||||
"id": "Q2",
|
||||
"question": "当前量子计算的主要研究方向有哪些?",
|
||||
"type": "开放域’",
|
||||
},
|
||||
{
|
||||
"id": "Q3",
|
||||
"question": "近期在量子计算领域有哪些重大进展?",
|
||||
"type": "开放域’",
|
||||
}
|
||||
]
|
||||
|
||||
时间(Temporal):
|
||||
描述:问题或数据涉及时间维度,需要按时间顺序或时间点拆分。
|
||||
拆分策略:根据事件时间或时间段拆分为独立条目或问题。
|
||||
示例:
|
||||
输入数据:"列出苹果公司过去五年的重大事件"
|
||||
拆分结果:
|
||||
[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": 苹果公司2019年的重大事件有哪些?",
|
||||
"type": "时间’",
|
||||
},
|
||||
{
|
||||
"id": "Q2",
|
||||
"question": "苹果公司2020年的重大事件有哪些?",
|
||||
"type": "时间’",
|
||||
},
|
||||
{
|
||||
"id": "Q3",
|
||||
"question": "苹果公司2021年的重大事件有哪些?",
|
||||
"type": "时间’",
|
||||
},
|
||||
{
|
||||
"id": "Q3",
|
||||
"question": "苹果公司2022年的重大事件有哪些?",
|
||||
"type": "时间’",
|
||||
}
|
||||
,
|
||||
{
|
||||
"id": "Q4",
|
||||
"question": "苹果公司2023年的重大事件有哪些?",
|
||||
"type": "时间’",
|
||||
}
|
||||
]
|
||||
|
||||
输出要求:
|
||||
- 每个子问题包括:
|
||||
- `id`: 子问题编号(Q1, Q2...)
|
||||
- `question`: 子问题内容
|
||||
- `type`: 类型(事实检索 / 澄清 / 定义 / 比较 / 行动建议等)
|
||||
- `reason`: 拆分的理由(为什么要这样拆)
|
||||
- 格式案例:
|
||||
[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": 量子计算的基本概念是什么?",
|
||||
"type": "开放域’",
|
||||
},
|
||||
{
|
||||
"id": "Q2",
|
||||
"question": "当前量子计算的主要研究方向有哪些?",
|
||||
"type": "开放域’",
|
||||
},
|
||||
{
|
||||
"id": "Q3",
|
||||
"question": "近期在量子计算领域有哪些重大进展?",
|
||||
"type": "开放域’",
|
||||
}
|
||||
]
|
||||
- 必须通过json.loads()的格式支持的形式输出
|
||||
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
||||
- 关键的JSON格式要求
|
||||
1.JSON结构仅使用标准ASCII双引号(“)-切勿使用中文引号(“”)或其他Unicode引号
|
||||
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
|
||||
3.确保所有JSON字符串都正确关闭并以逗号分隔
|
||||
4.JSON字符串值中不包括换行符
|
||||
5.正确转义的例子:“statement”:“Zhang Xinhua said:\”我非常喜欢这本书\""
|
||||
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```
|
||||
@@ -0,0 +1,60 @@
|
||||
# 角色
|
||||
你是验证专家
|
||||
你的目标是针对用户的输入Query_Samll字段的提问和Answer_Samll的回答分析,是不是回答Query_Samll这个字段的问题
|
||||
|
||||
{#以下可以采用先总括,再展开详细说明的方式,描述你希望智能体在每一个步骤如何进行工作,具体的工作步骤数量可以根据实际需求增删#}
|
||||
## 工作步骤
|
||||
1. 获取所有的Query_Samll字段和Answer_Samll字段
|
||||
2. 分析Answer_Samll的回复是不是和Query_Samll有关系
|
||||
3. 判断Answer_Samll和Query_Samll之间分析出来的关系状态
|
||||
4. 如果是True保留,否则不要相对应的问题和回答
|
||||
5. 输出,需要严格按照模版
|
||||
输入:{{history}}
|
||||
历史消息:{"history":{{sentence}}}
|
||||
### 第一步 获取用户的输入
|
||||
获取用户的输入提取对应的Query_Samll和Answer_Samll
|
||||
### 第二步 分析验证
|
||||
需要分析Query_Samll和Answer_Samll之间的关系可以参考history字段的内容,如果有关系不是答非所问
|
||||
## 核心验证标准
|
||||
在评估子问题拆分时,必须严格遵循以下标准,且验证过程中完全不依赖于子问题的相关信息(Answer_Samll):
|
||||
1. 合理性标准(必须全部满足):
|
||||
- 完整性:每个不同的子问题必须完整覆盖原问题的所有关键要素(如时间、主体、动作、目标等),无遗漏。
|
||||
- 最小化:每个不同的子问题数量应尽可能少,通常不超过原问题关键要素数量的2倍(建议2-4个),避免冗余和不必要拆分。
|
||||
- 相关性:每个不同的子问题必须直接服务于原问题的解答,不引入无关内容或扩展原问题未提及的主题。
|
||||
- 可操作性:每个不同的子问题应能在有限资源(如标准工具或合理时间)内独立解答,且难度适中。
|
||||
- 逻辑性:每个不同的子问题间应有清晰的逻辑关系(如并列、递进、因果),共同构成原问题的解答路径。
|
||||
|
||||
2. 不合理拆分的特征(出现任一特征即为不合理):
|
||||
- 不同的子问题数量超过5个或明显多于必要数量。
|
||||
- 引入原问题未提及的新主题、人物、细节或个人看法。
|
||||
- 拆分过于细碎,失去实用价值,无法高效合成原问题答案。
|
||||
|
||||
3. 特殊情况说明:
|
||||
- 每个不同的子问题与原问题相同,需进一步判断:
|
||||
- 每个不同的子问题不可进一步拆分 → success(合理,最小化拆分)
|
||||
- 每个不同的子问题能够进一步拆分为更小、更合理的问题 → failed(不合理,拆分没有最小化)
|
||||
- 每个不同的子问题数量=原问题核心要素数量 → success(理想情况)
|
||||
- 每个不同的子问题数量=核心要素数量+1 → success(通常合理)
|
||||
|
||||
### 第三步 添加状态
|
||||
如果有相关性并且比较高给一个状态TRUE,否则给一个FLASE的状态
|
||||
### 第四步 判断
|
||||
如果状态是TRUE保留这条数据,否则需不需要这条数据
|
||||
### 第五步 输出格式
|
||||
按照json的形式输出
|
||||
{"data":"Query":原来Query的字段,"history":原来的history字段,
|
||||
"expansion_issue":以为列表的形式存储验证之后的数据比如[
|
||||
{"query_small": query_small,
|
||||
"answer_small": answer_small,,
|
||||
"status": 回答的结果是否符合query_small,填写状态,
|
||||
"query_answer": answer_small},
|
||||
{
|
||||
"query_small": "张曼婷生日是什么时候?",
|
||||
"answer_small": "张曼婷喜欢绘画。",
|
||||
"status": "True",
|
||||
"query_answer": "张曼 婷喜欢绘画。"
|
||||
},{}......]
|
||||
,
|
||||
"split_result":如果expansion_issue是空的列表返回failed,不是空列表返回success,
|
||||
"reason": 为以上分析完之后的结果给一个说明
|
||||
}
|
||||
57
app/core/memory/agent/utils/prompt/summary_prompt.jinja2
Normal file
57
app/core/memory/agent/utils/prompt/summary_prompt.jinja2
Normal file
@@ -0,0 +1,57 @@
|
||||
{# 角色定义 #}
|
||||
你是专业的问题解答专家,负责根据上下文信息和检索到的所有信息准确回答用户的问题。
|
||||
|
||||
{# 输入数据展示 #}
|
||||
{% if data %}
|
||||
## 输入数据
|
||||
上下文信息:
|
||||
{% for item in data.history %}
|
||||
- {{ item }}
|
||||
{% endfor %}
|
||||
检索到的所有信息:
|
||||
{% for item in data.retrieve_info %}
|
||||
- {{ item }}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
## User Query
|
||||
{{ query }}
|
||||
|
||||
{# 问题回答标准 #}
|
||||
## 问题回答核心标准
|
||||
根据上下文信息(history)和检索到的所有信息(retrieve_info)准确回答用户的问题(query)。注意,若不能根据已有信息回答用户的问题,应直接回复“信息不足,无法回答。”,不能自己编造答案。
|
||||
- 若能根据已有信息回答用户的问题,应根据上下文信息和检索到的所有信息提供简明扼要的答案。
|
||||
- 若不能根据已有信息回答用户的问题,应直接回复“信息不足,无法回答。”,不能自己编造答案。
|
||||
|
||||
{# 重要提醒 #}
|
||||
再次提醒,给出问题的答案时,仅根据已有的信息进行回答,不能自己编造答案。
|
||||
|
||||
{# 输出格式模板 #}
|
||||
## 输出格式
|
||||
严格按照以下JSON格式输出,不添加任何其他内容:
|
||||
{
|
||||
"data": {
|
||||
"query": "{{ query }}",
|
||||
"history": [
|
||||
{% for item in data.history %}
|
||||
"{{ item | replace('"', '\\"') }}"
|
||||
{% if not loop.last %},{% endif %}
|
||||
{% endfor %}
|
||||
],
|
||||
"retrieve_info": [
|
||||
{% for item in data.retrieve_info %}
|
||||
"{{ item | replace('"', '\\"') }}"
|
||||
{% if not loop.last %},{% endif %}
|
||||
{% endfor %}
|
||||
]
|
||||
},
|
||||
"query_answer": "{% if not data.history and not data.retrieve_info %}信息不足,无法回答。{% endif %}"
|
||||
}
|
||||
**Output format**
|
||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
||||
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
|
||||
3. Ensure all JSON strings are properly closed and comma-separated
|
||||
4. Do not include line breaks within JSON string values
|
||||
|
||||
The output language should always be the same as the input language.{{ json_schema }}
|
||||
203
app/core/memory/agent/utils/redis_tool.py
Normal file
203
app/core/memory/agent/utils/redis_tool.py
Normal file
@@ -0,0 +1,203 @@
|
||||
import redis
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from app.core.config import settings
|
||||
class RedisSessionStore:
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None,session_id=''):
|
||||
self.r = redis.Redis(host=host, port=port, db=db, password=password)
|
||||
self.uudi=session_id
|
||||
|
||||
|
||||
# 修改后的 save_session 方法
|
||||
def save_session(self, userid, messages, aimessages, apply_id, group_id):
|
||||
"""
|
||||
写入一条会话数据,返回 session_id
|
||||
"""
|
||||
try:
|
||||
session_id = str(uuid.uuid4()) # 为每次会话生成新的 ID
|
||||
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
key = f"session:{session_id}" # 使用新生成的 session_id 作为 key
|
||||
|
||||
# 使用 Hash 存储结构化数据
|
||||
result = self.r.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"sessionid": userid,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id,
|
||||
"messages": messages,
|
||||
"aimessages": aimessages,
|
||||
"starttime": starttime
|
||||
})
|
||||
print(f"保存结果: {result}, session_id: {session_id}")
|
||||
return session_id # 返回新生成的 session_id
|
||||
except Exception as e:
|
||||
print(f"保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
# ---------------- 读取 ----------------
|
||||
def get_session(self, session_id):
|
||||
"""
|
||||
读取一条会话数据
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
data = self.r.hgetall(key)
|
||||
if data:
|
||||
return {k.decode('utf-8'): v.decode('utf-8') for k, v in data.items()}
|
||||
return None
|
||||
|
||||
def get_session_apply_group(self, sessionid, apply_id, group_id):
|
||||
"""
|
||||
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据
|
||||
"""
|
||||
result_items = []
|
||||
|
||||
# 遍历所有会话数据
|
||||
for key_bytes in self.r.keys('session:*'):
|
||||
key = key_bytes.decode('utf-8')
|
||||
data = self.r.hgetall(key)
|
||||
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 解码数据
|
||||
decoded_data = {k.decode('utf-8'): v.decode('utf-8') for k, v in data.items()}
|
||||
|
||||
# 检查三个条件是否都匹配
|
||||
if (decoded_data.get('sessionid') == sessionid and
|
||||
decoded_data.get('apply_id') == apply_id and
|
||||
decoded_data.get('group_id') == group_id):
|
||||
result_items.append(decoded_data)
|
||||
|
||||
return result_items
|
||||
|
||||
def get_all_sessions(self):
|
||||
"""
|
||||
获取所有会话数据
|
||||
"""
|
||||
sessions = {}
|
||||
for key in self.r.keys('session:*'):
|
||||
sid = key.decode('utf-8').split(':')[1]
|
||||
sessions[sid] = self.get_session(sid)
|
||||
return sessions
|
||||
|
||||
# ---------------- 更新 ----------------
|
||||
def update_session(self, session_id, field, value):
|
||||
"""
|
||||
更新单个字段
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
if self.r.exists(key):
|
||||
self.r.hset(key, field, value)
|
||||
return True
|
||||
return False
|
||||
|
||||
# ---------------- 删除 ----------------
|
||||
def delete_session(self, session_id):
|
||||
"""
|
||||
删除单条会话
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
return self.r.delete(key)
|
||||
|
||||
def delete_all_sessions(self):
|
||||
"""
|
||||
删除所有会话
|
||||
"""
|
||||
keys = self.r.keys('session:*')
|
||||
if keys:
|
||||
return self.r.delete(*keys)
|
||||
return 0
|
||||
|
||||
def delete_duplicate_sessions(self):
|
||||
"""
|
||||
删除重复会话数据,条件:
|
||||
"sessionid"、"user_id"、"group_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除
|
||||
"""
|
||||
seen = set() # 用来记录已出现的唯一组合
|
||||
deleted_count = 0
|
||||
|
||||
for key_bytes in self.r.keys('session:*'):
|
||||
key = key_bytes.decode('utf-8')
|
||||
data = self.r.hgetall(key)
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 获取五个字段的值并解码
|
||||
sessionid = data.get(b'sessionid', b'').decode('utf-8')
|
||||
user_id = data.get(b'id', b'').decode('utf-8') # 对应user_id
|
||||
group_id = data.get(b'group_id', b'').decode('utf-8')
|
||||
messages = data.get(b'messages', b'').decode('utf-8')
|
||||
aimessages = data.get(b'aimessages', b'').decode('utf-8')
|
||||
|
||||
# 用五元组作为唯一标识
|
||||
identifier = (sessionid, user_id, group_id, messages, aimessages)
|
||||
|
||||
if identifier in seen:
|
||||
# 重复,删除该 key
|
||||
self.r.delete(key)
|
||||
deleted_count += 1
|
||||
else:
|
||||
# 第一次出现,加入 seen
|
||||
seen.add(identifier)
|
||||
|
||||
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}")
|
||||
return deleted_count
|
||||
|
||||
def find_user_session(self,sessionid):
|
||||
user_id = sessionid
|
||||
|
||||
result_items = []
|
||||
for key, values in store.get_all_sessions().items():
|
||||
history = {}
|
||||
if user_id == str(values['sessionid']):
|
||||
history["Query"] = values['messages']
|
||||
history["Answer"] = values['aimessages']
|
||||
result_items.append(history)
|
||||
|
||||
if len(result_items) <= 1:
|
||||
result_items = []
|
||||
return (result_items)
|
||||
|
||||
def find_user_apply_group(self, sessionid, apply_id, group_id):
|
||||
"""
|
||||
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据
|
||||
"""
|
||||
result_items = []
|
||||
|
||||
# 遍历所有会话数据
|
||||
for key_bytes in self.r.keys('session:*'):
|
||||
key = key_bytes.decode('utf-8')
|
||||
data = self.r.hgetall(key)
|
||||
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 解码数据
|
||||
decoded_data = {k.decode('utf-8'): v.decode('utf-8') for k, v in data.items()}
|
||||
|
||||
|
||||
# 检查三个条件是否都匹配
|
||||
if (decoded_data.get('sessionid') == sessionid and
|
||||
decoded_data.get('apply_id') == apply_id and
|
||||
decoded_data.get('group_id') == group_id):
|
||||
history = {
|
||||
"Query": decoded_data.get('messages'),
|
||||
"Answer": decoded_data.get('aimessages')
|
||||
}
|
||||
|
||||
|
||||
result_items.append(history)
|
||||
|
||||
# 如果结果少于等于1条,返回空列表
|
||||
if len(result_items) <= 1:
|
||||
result_items = []
|
||||
|
||||
return result_items
|
||||
|
||||
store = RedisSessionStore(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
||||
session_id=str(uuid.uuid4())
|
||||
)
|
||||
59
app/core/memory/agent/utils/type_classifier.py
Normal file
59
app/core/memory/agent/utils/type_classifier.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Type classification utility for distinguishing read/write operations.
|
||||
"""
|
||||
from jinja2 import Template
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_prompt_rendering
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.messages_tool import read_template_file
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class DistinguishTypeResponse(BaseModel):
|
||||
"""Response model for type classification"""
|
||||
type: str
|
||||
|
||||
|
||||
async def status_typle(messages: str) -> dict:
|
||||
"""
|
||||
Classify message type as read or write operation.
|
||||
|
||||
Args:
|
||||
messages: User message to classify
|
||||
|
||||
Returns:
|
||||
dict: Contains 'type' field with classification result
|
||||
"""
|
||||
try:
|
||||
file_path = PROJECT_ROOT_ + '/agent/utils/prompt/distinguish_types_prompt.jinja2'
|
||||
template_content = await read_template_file(file_path)
|
||||
template = Template(template_content)
|
||||
system_prompt = template.render(user_query=messages)
|
||||
log_prompt_rendering("status_typle", system_prompt)
|
||||
except Exception as e:
|
||||
logger.error(f"Template rendering failed for status_typle: {e}", exc_info=True)
|
||||
return {
|
||||
"type": "error",
|
||||
"message": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=DistinguishTypeResponse
|
||||
)
|
||||
return structured.model_dump()
|
||||
except Exception as e:
|
||||
logger.error(f"LLM call failed for status_typle: {e}", exc_info=True)
|
||||
return {
|
||||
"type": "error",
|
||||
"message": f"LLM call failed: {str(e)}"
|
||||
}
|
||||
76
app/core/memory/agent/utils/verify_tool.py
Normal file
76
app/core/memory/agent/utils/verify_tool.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from typing import TypedDict, Annotated, List, Any
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langgraph.constants import START, END
|
||||
from langgraph.graph import StateGraph, add_messages
|
||||
import asyncio
|
||||
import json
|
||||
from dotenv import load_dotenv, find_dotenv
|
||||
import os
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from langchain_core.messages import HumanMessage
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from app.core.memory.agent.utils.messages_tool import _to_openai_messages
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
load_dotenv(find_dotenv())
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
def keep_last(_, right):
|
||||
return right
|
||||
class State(TypedDict):
|
||||
user_input: Annotated[dict, keep_last]
|
||||
messages: Annotated[List[AnyMessage], add_messages]
|
||||
agent1_response: str
|
||||
agent2_response: str
|
||||
agent3_response: str
|
||||
final_response: str
|
||||
status: Annotated[str, keep_last]
|
||||
|
||||
|
||||
class VerifyTool:
|
||||
def __init__(self, system_prompt: str="", verify_data: Any=None):
|
||||
self.system_prompt = system_prompt
|
||||
if isinstance(verify_data, str):
|
||||
self.verify_data = verify_data
|
||||
else:
|
||||
try:
|
||||
self.verify_data = json.dumps(verify_data, ensure_ascii=False)
|
||||
except Exception:
|
||||
self.verify_data = str(verify_data)
|
||||
|
||||
async def model_1(self, state: State) -> State:
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
response_content = await llm_client.chat(
|
||||
messages=[{"role": "system", "content": self.system_prompt}] + _to_openai_messages(state["messages"])
|
||||
)
|
||||
return {
|
||||
"agent1_response": response_content,
|
||||
"status": "processed",
|
||||
}
|
||||
|
||||
|
||||
def get_graph(self):
|
||||
graph = StateGraph(State)
|
||||
graph.add_node("model_1", self.model_1)
|
||||
|
||||
graph.add_edge(START, "model_1")
|
||||
graph.add_edge("model_1", END)
|
||||
|
||||
compiled_graph = graph.compile()
|
||||
return compiled_graph
|
||||
|
||||
async def verify(self):
|
||||
graph = self.get_graph()
|
||||
initial_state = {
|
||||
"user_input": self.verify_data,
|
||||
"messages": [HumanMessage(content=self.verify_data)],
|
||||
"final_response": "",
|
||||
"status": ""
|
||||
}
|
||||
final_state = await graph.ainvoke(initial_state)
|
||||
# return final_state["final_response"]
|
||||
return final_state["agent1_response"]
|
||||
|
||||
49
app/core/memory/agent/utils/write_to_database.py
Normal file
49
app/core/memory/agent/utils/write_to_database.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from sqlalchemy.orm import Session
|
||||
import logging
|
||||
import json
|
||||
|
||||
from app.db import get_db
|
||||
from app.models.retrieval_info import RetrievalInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def write_to_database(host_id: uuid.UUID, data: Any) -> str:
|
||||
"""
|
||||
将数据写入数据库
|
||||
:param host_id: 宿主 ID
|
||||
:param data: 要写入的数据
|
||||
:return: 写入数据库的结果
|
||||
"""
|
||||
# 从数据库会话中获取会话
|
||||
db: Session = next(get_db())
|
||||
try:
|
||||
if isinstance(data, (dict, list)):
|
||||
serialized = json.dumps(data, ensure_ascii=False)
|
||||
elif isinstance(data, str):
|
||||
serialized = data
|
||||
else:
|
||||
serialized = str(data)
|
||||
|
||||
new_retrieval_info = RetrievalInfo(
|
||||
# host_id=host_id,
|
||||
host_id=uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122"),
|
||||
retrieve_info=serialized,
|
||||
created_at=datetime.now()
|
||||
)
|
||||
db.add(new_retrieval_info)
|
||||
db.commit()
|
||||
logger.info(f"success to write data to database, host_id: {host_id}, retrieve_info: {serialized}")
|
||||
return "success to write data to database"
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"failed to write data to database, host_id: {host_id}, retrieve_info: {data}, error: {e}")
|
||||
raise e
|
||||
finally:
|
||||
try:
|
||||
db.close()
|
||||
except Exception:
|
||||
pass
|
||||
183
app/core/memory/agent/utils/write_tools.py
Normal file
183
app/core/memory/agent/utils/write_tools.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import asyncio
|
||||
from dotenv import load_dotenv
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||
|
||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
# 使用新的模块化架构
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
|
||||
embedding_generation_all,
|
||||
)
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
# 导入配置模块(而不是直接导入变量)
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.log.logging_utils import log_time
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import Memory_summary_generation
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
load_dotenv()
|
||||
|
||||
|
||||
async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id: str = "wyl20251027", config_id: str = None) -> None:
|
||||
"""
|
||||
执行完整的知识提取流水线(使用新的 ExtractionOrchestrator)
|
||||
|
||||
Args:
|
||||
content: 对话内容
|
||||
user_id: 用户ID
|
||||
apply_id: 应用ID
|
||||
group_id: 组ID
|
||||
ref_id: 参考ID,默认为 "wyl20251027"
|
||||
config_id: 配置ID,用于标记数据处理配置
|
||||
"""
|
||||
logger.info("=== MemSci Knowledge Extraction Pipeline ===")
|
||||
logger.info(f"Using model: {config_defs.SELECTED_LLM_NAME}")
|
||||
logger.info(f"Using LLM ID: {config_defs.SELECTED_LLM_ID}")
|
||||
logger.info(f"Using chunker strategy: {config_defs.SELECTED_CHUNKER_STRATEGY}")
|
||||
logger.info(f"Using group ID: {config_defs.SELECTED_GROUP_ID}")
|
||||
logger.info(f"Using embedding ID: {config_defs.SELECTED_EMBEDDING_ID}")
|
||||
logger.info(f"Config ID: {config_id if config_id else 'None'}")
|
||||
logger.info(f"LANGFUSE_ENABLED: {config_defs.LANGFUSE_ENABLED}")
|
||||
logger.info(f"AGENTA_ENABLED: {config_defs.AGENTA_ENABLED}")
|
||||
|
||||
# Initialize timing log
|
||||
log_file = "logs/time.log"
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"\n=== Pipeline Run Started: {timestamp} ===\n")
|
||||
|
||||
pipeline_start = time.time()
|
||||
|
||||
# 初始化客户端
|
||||
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
|
||||
# 获取 embedder 配置
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
|
||||
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
|
||||
neo4j_connector = Neo4jConnector()
|
||||
|
||||
# Step 1: 加载和分块数据
|
||||
step_start = time.time()
|
||||
chunked_dialogs = await get_chunked_dialogs(
|
||||
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
content=content,
|
||||
ref_id=ref_id,
|
||||
config_id=config_id,
|
||||
)
|
||||
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
||||
|
||||
# Step 2: 初始化并运行 ExtractionOrchestrator
|
||||
step_start = time.time()
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
config = get_pipeline_config()
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=neo4j_connector,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# 运行完整的提取流水线
|
||||
# orchestrator.run returns a flat tuple of 7 values after deduplication
|
||||
(
|
||||
all_dialogue_nodes,
|
||||
all_chunk_nodes,
|
||||
all_statement_nodes,
|
||||
all_entity_nodes,
|
||||
all_statement_chunk_edges,
|
||||
all_statement_entity_edges,
|
||||
all_entity_entity_edges,
|
||||
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
||||
|
||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||
|
||||
# Step 8: Save all data to Neo4j database using graph models
|
||||
step_start = time.time()
|
||||
# 运行索引创建
|
||||
from app.repositories.neo4j.create_indexes import create_fulltext_indexes
|
||||
try:
|
||||
await create_fulltext_indexes()
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
||||
|
||||
try:
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=all_dialogue_nodes,
|
||||
chunk_nodes=all_chunk_nodes,
|
||||
statement_nodes=all_statement_nodes,
|
||||
entity_nodes=all_entity_nodes,
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
connector=neo4j_connector
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
else:
|
||||
logger.warning("Failed to save some data to Neo4j")
|
||||
finally:
|
||||
await neo4j_connector.close()
|
||||
|
||||
log_time("Neo4j Database Save", time.time() - step_start, log_file)
|
||||
|
||||
# Step 9: Generate Memory summaries and save to local vector DB and Neo4j
|
||||
step_start = time.time()
|
||||
try:
|
||||
summaries = await Memory_summary_generation(
|
||||
chunked_dialogs, llm_client=llm_client, embedding_id=config_defs.SELECTED_EMBEDDING_ID
|
||||
)
|
||||
|
||||
# Save memory summaries to Neo4j as nodes
|
||||
try:
|
||||
ms_connector = Neo4jConnector()
|
||||
await add_memory_summary_nodes(summaries, ms_connector)
|
||||
# Link summaries to statements via chunks for summary→entity queries
|
||||
await add_memory_summary_statement_edges(summaries, ms_connector)
|
||||
finally:
|
||||
try:
|
||||
await ms_connector.close()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Memory summary step failed: {e}", exc_info=True)
|
||||
finally:
|
||||
log_time("Memory Summary (Local Vector DB & Neo4j)", time.time() - step_start, log_file)
|
||||
|
||||
|
||||
|
||||
# Log total pipeline time
|
||||
total_time = time.time() - pipeline_start
|
||||
log_time("TOTAL PIPELINE TIME", total_time, log_file)
|
||||
|
||||
# Add completion marker to log
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
|
||||
|
||||
logger.info("=== Pipeline Complete ===")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
logger.info(f"Timing details saved to: {log_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
content = "你好,我是张三,是张曼婷的新朋友。请问张曼婷喜欢什么?"
|
||||
asyncio.run(write(content, ref_id="wyl20251027"))
|
||||
Reference in New Issue
Block a user