feat: Add base project structure with API and web components
This commit is contained in:
0
api/app/core/memory/__init__.py
Normal file
0
api/app/core/memory/__init__.py
Normal file
0
api/app/core/memory/agent/__init__.py
Normal file
0
api/app/core/memory/agent/__init__.py
Normal file
16
api/app/core/memory/agent/langgraph_graph/__init__.py
Normal file
16
api/app/core/memory/agent/langgraph_graph/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
LangGraph Graph package for memory agent.
|
||||
|
||||
This package provides the LangGraph workflow orchestrator with modular
|
||||
node implementations, routing logic, and state management.
|
||||
|
||||
Package structure:
|
||||
- read_graph: Main graph factory for read operations
|
||||
- write_graph: Main graph factory for write operations
|
||||
- nodes: LangGraph node implementations
|
||||
- routing: State routing logic
|
||||
- state: State management utilities
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||||
|
||||
__all__ = ['make_read_graph']
|
||||
10
api/app/core/memory/agent/langgraph_graph/nodes/__init__.py
Normal file
10
api/app/core/memory/agent/langgraph_graph/nodes/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
LangGraph node implementations.
|
||||
|
||||
This module contains custom node implementations for the LangGraph workflow.
|
||||
"""
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode
|
||||
from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message
|
||||
|
||||
__all__ = ["ToolExecutionNode", "create_input_message"]
|
||||
144
api/app/core/memory/agent/langgraph_graph/nodes/input_node.py
Normal file
144
api/app/core/memory/agent/langgraph_graph/nodes/input_node.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
Input node for LangGraph workflow entry point.
|
||||
|
||||
This module provides the create_input_message function which processes initial
|
||||
user input with multimodal support and creates the first tool call message.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def create_input_message(
|
||||
state: Dict[str, Any],
|
||||
tool_name: str,
|
||||
session_id: str,
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
multimodal_processor: MultimodalProcessor
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create initial tool call message from user input.
|
||||
|
||||
This function:
|
||||
1. Extracts the last message content from state
|
||||
2. Processes multimodal inputs (images/audio) using the multimodal processor
|
||||
3. Generates a unique message ID
|
||||
4. Extracts namespace from session_id
|
||||
5. Handles verified_data extraction for backward compatibility
|
||||
6. Returns AIMessage with complete tool_calls structure
|
||||
|
||||
Args:
|
||||
state: LangGraph state dictionary containing messages
|
||||
tool_name: Name of the tool to invoke (typically "Split_The_Problem")
|
||||
session_id: Session identifier (format: "call_id_{namespace}")
|
||||
search_switch: Search routing parameter
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
multimodal_processor: Processor for handling image/audio inputs
|
||||
|
||||
Returns:
|
||||
State update with AIMessage containing tool_call
|
||||
|
||||
Examples:
|
||||
>>> state = {"messages": [HumanMessage(content="What is AI?")]}
|
||||
>>> result = await create_input_message(
|
||||
... state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor
|
||||
... )
|
||||
>>> result["messages"][0].tool_calls[0]["name"]
|
||||
'Split_The_Problem'
|
||||
"""
|
||||
messages = state.get("messages", [])
|
||||
|
||||
# Extract last message content
|
||||
if messages:
|
||||
last_message = messages[-1].content if hasattr(messages[-1], 'content') else str(messages[-1])
|
||||
else:
|
||||
logger.warning("[create_input_message] No messages in state, using empty string")
|
||||
last_message = ""
|
||||
|
||||
logger.debug(f"[create_input_message] Original input: {last_message[:100]}...")
|
||||
|
||||
# Process multimodal input (images/audio)
|
||||
try:
|
||||
processed_content = await multimodal_processor.process_input(last_message)
|
||||
if processed_content != last_message:
|
||||
logger.info(
|
||||
f"[create_input_message] Multimodal processing converted input "
|
||||
f"from {len(last_message)} to {len(processed_content)} chars"
|
||||
)
|
||||
last_message = processed_content
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[create_input_message] Multimodal processing failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Continue with original content
|
||||
|
||||
# Generate unique message ID
|
||||
uuid_str = uuid.uuid4()
|
||||
time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
# Extract namespace from session_id
|
||||
# Expected format: "call_id_{namespace}" or similar
|
||||
try:
|
||||
namespace = str(session_id).split('_id_')[1]
|
||||
except (IndexError, AttributeError):
|
||||
logger.warning(
|
||||
f"[create_input_message] Could not extract namespace from session_id: {session_id}"
|
||||
)
|
||||
namespace = "unknown"
|
||||
|
||||
# Handle verified_data extraction (backward compatibility)
|
||||
# This regex-based extraction is kept for compatibility with existing data formats
|
||||
if 'verified_data' in str(last_message):
|
||||
try:
|
||||
messages_last = str(last_message).replace('\\n', '').replace('\\', '')
|
||||
query_match = re.findall(r'"query": "(.*?)",', messages_last)
|
||||
if query_match:
|
||||
last_message = query_match[0]
|
||||
logger.debug(
|
||||
f"[create_input_message] Extracted query from verified_data: {last_message}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[create_input_message] Failed to extract query from verified_data: {e}"
|
||||
)
|
||||
|
||||
# Construct tool call message
|
||||
tool_call_id = f"{session_id}_{uuid_str}"
|
||||
|
||||
logger.info(
|
||||
f"[create_input_message] Creating tool call for '{tool_name}' "
|
||||
f"with ID: {tool_call_id}"
|
||||
)
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{
|
||||
"name": tool_name,
|
||||
"args": {
|
||||
"sentence": last_message,
|
||||
"sessionid": session_id,
|
||||
"messages_id": str(uuid_str),
|
||||
"search_switch": search_switch,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id
|
||||
},
|
||||
"id": tool_call_id
|
||||
}]
|
||||
)
|
||||
]
|
||||
}
|
||||
199
api/app/core/memory/agent/langgraph_graph/nodes/tool_node.py
Normal file
199
api/app/core/memory/agent/langgraph_graph/nodes/tool_node.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""
|
||||
Tool execution node for LangGraph workflow.
|
||||
|
||||
This module provides the ToolExecutionNode class which wraps tool execution
|
||||
with parameter transformation logic using the ParameterBuilder service.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.state.extractors import (
|
||||
extract_tool_call_id,
|
||||
extract_content_payload
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolExecutionNode:
|
||||
"""
|
||||
Custom LangGraph node that wraps tool execution with parameter transformation.
|
||||
|
||||
This node extracts content from previous tool results, transforms parameters
|
||||
based on tool type using ParameterBuilder, and invokes the tool with the
|
||||
correct argument structure.
|
||||
|
||||
Attributes:
|
||||
tool_node: LangGraph ToolNode wrapping the actual tool
|
||||
id: Node identifier for message IDs
|
||||
tool_name: Name of the tool being executed
|
||||
namespace: Namespace for session management
|
||||
search_switch: Search routing parameter
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
parameter_builder: Service for building tool-specific arguments
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool: Callable,
|
||||
node_id: str,
|
||||
namespace: str,
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
parameter_builder: ParameterBuilder,
|
||||
storage_type:str,
|
||||
user_rag_memory_id:str
|
||||
):
|
||||
"""
|
||||
Initialize the tool execution node.
|
||||
|
||||
Args:
|
||||
tool: The tool function to execute
|
||||
node_id: Identifier for this node (used in message IDs)
|
||||
namespace: Namespace for session management
|
||||
search_switch: Search routing parameter
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
parameter_builder: Service for building tool-specific arguments
|
||||
"""
|
||||
self.tool_node = ToolNode([tool])
|
||||
self.id = node_id
|
||||
self.tool_name = tool.name if hasattr(tool, 'name') else str(tool)
|
||||
self.namespace = namespace
|
||||
self.search_switch = search_switch
|
||||
self.apply_id = apply_id
|
||||
self.group_id = group_id
|
||||
self.parameter_builder = parameter_builder
|
||||
self.storage_type=storage_type
|
||||
self.user_rag_memory_id=user_rag_memory_id
|
||||
|
||||
logger.info(
|
||||
f"[ToolExecutionNode] Initialized node '{self.id}' for tool '{self.tool_name}'"
|
||||
)
|
||||
|
||||
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute the tool with transformed parameters.
|
||||
|
||||
This method:
|
||||
1. Extracts the last message from state
|
||||
2. Extracts tool call ID using state extractors
|
||||
3. Extracts content payload using state extractors
|
||||
4. Builds tool arguments using parameter builder
|
||||
5. Constructs AIMessage with tool_calls
|
||||
6. Invokes the tool and returns the result
|
||||
|
||||
Args:
|
||||
state: LangGraph state dictionary
|
||||
|
||||
Returns:
|
||||
Updated state with tool result in messages
|
||||
"""
|
||||
messages = state.get("messages", [])
|
||||
logger.debug( self.tool_name)
|
||||
|
||||
if not messages:
|
||||
logger.warning(f"[ToolExecutionNode] {self.id} - No messages in state")
|
||||
return {"messages": [AIMessage(content="Error: No messages in state")]}
|
||||
|
||||
last_message = messages[-1]
|
||||
logger.debug(
|
||||
f"[ToolExecutionNode] {self.id} - Processing message at {time.time()}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Extract tool call ID using state extractors
|
||||
tool_call_id = extract_tool_call_id(last_message)
|
||||
logger.debug(f"[ToolExecutionNode] {self.id} - Extracted tool_call_id: {tool_call_id}")
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"[ToolExecutionNode] {self.id} - Failed to extract tool call ID: {e}"
|
||||
)
|
||||
return {"messages": [AIMessage(content=f"Error: {str(e)}")]}
|
||||
|
||||
try:
|
||||
# Extract content payload using state extractors
|
||||
content = extract_content_payload(last_message)
|
||||
logger.debug(
|
||||
f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[ToolExecutionNode] {self.id} - Failed to extract content: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
content = {}
|
||||
|
||||
try:
|
||||
# Build tool arguments using parameter builder
|
||||
tool_args = self.parameter_builder.build_tool_args(
|
||||
tool_name=self.tool_name,
|
||||
content=content,
|
||||
tool_call_id=tool_call_id,
|
||||
search_switch=self.search_switch,
|
||||
apply_id=self.apply_id,
|
||||
group_id=self.group_id,
|
||||
storage_type=self.storage_type,
|
||||
user_rag_memory_id=self.user_rag_memory_id
|
||||
)
|
||||
logger.debug(
|
||||
f"[ToolExecutionNode] {self.id} - Built tool args with keys: {list(tool_args.keys())}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[ToolExecutionNode] {self.id} - Failed to build tool args: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {"messages": [AIMessage(content=f"Error building arguments: {str(e)}")]}
|
||||
|
||||
# Construct tool input message
|
||||
tool_input = {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{
|
||||
"name": self.tool_name,
|
||||
"args": tool_args,
|
||||
"id": f"{self.id}_{tool_call_id}",
|
||||
}]
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
try:
|
||||
# Invoke the tool
|
||||
result = await self.tool_node.ainvoke(tool_input)
|
||||
|
||||
logger.debug(
|
||||
f"[ToolExecutionNode] {self.id} - Tool execution completed"
|
||||
)
|
||||
|
||||
# Return the result directly - it already contains the messages list
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[ToolExecutionNode] {self.id} - Tool execution failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return error as ToolMessage to maintain message chain consistency
|
||||
from langchain_core.messages import ToolMessage
|
||||
return {
|
||||
"messages": [
|
||||
ToolMessage(
|
||||
content=f"Error executing tool: {str(e)}",
|
||||
tool_call_id=f"{self.id}_{tool_call_id}"
|
||||
)
|
||||
]
|
||||
}
|
||||
508
api/app/core/memory/agent/langgraph_graph/read_graph.py
Normal file
508
api/app/core/memory/agent/langgraph_graph/read_graph.py
Normal file
@@ -0,0 +1,508 @@
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.constants import START, END
|
||||
from langgraph.graph import StateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from functools import partial
|
||||
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
# Import new modular components
|
||||
from app.core.memory.agent.langgraph_graph.nodes import ToolExecutionNode, create_input_message
|
||||
from app.core.memory.agent.langgraph_graph.routing.routers import (
|
||||
Verify_continue,
|
||||
Retrieve_continue,
|
||||
Split_continue
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
|
||||
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
load_dotenv()
|
||||
redishost=os.getenv("REDISHOST")
|
||||
redisport=os.getenv('REDISPORT')
|
||||
redisdb=os.getenv('REDISDB')
|
||||
redispassword=os.getenv('REDISPASSWORD')
|
||||
counter = COUNTState(limit=3)
|
||||
|
||||
# 在工作流中添加循环计数更新
|
||||
async def update_loop_count(state):
|
||||
"""更新循环计数器"""
|
||||
current_count = state.get("loop_count", 0)
|
||||
return {"loop_count": current_count + 1}
|
||||
|
||||
|
||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||
messages = state["messages"]
|
||||
|
||||
# 添加边界检查
|
||||
if not messages:
|
||||
return END
|
||||
counter.add(1) # 累加 1
|
||||
|
||||
loop_count = counter.get_total()
|
||||
logger.debug(f"[should_continue] 当前循环次数: {loop_count}")
|
||||
|
||||
last_message = messages[-1]
|
||||
last_message_str = str(last_message).replace('\\', '')
|
||||
status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str)
|
||||
logger.debug(f"Status tools: {status_tools}")
|
||||
|
||||
if "success" in status_tools:
|
||||
counter.reset()
|
||||
return "Summary"
|
||||
elif "failed" in status_tools:
|
||||
if loop_count < 2: # 最大循环次数 3
|
||||
return "content_input"
|
||||
else:
|
||||
counter.reset()
|
||||
return "Summary_fails"
|
||||
else:
|
||||
# 添加默认返回值,避免返回 None
|
||||
counter.reset()
|
||||
return "Summary" # 或根据业务需求选择合适的默认值
|
||||
|
||||
|
||||
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
|
||||
Args:
|
||||
state: State dictionary containing search_switch
|
||||
|
||||
Returns:
|
||||
Next node to execute
|
||||
"""
|
||||
# Direct dictionary access instead of regex parsing
|
||||
search_switch = state.get("search_switch")
|
||||
|
||||
# Handle case where search_switch might be in messages
|
||||
if search_switch is None and "messages" in state:
|
||||
messages = state.get("messages", [])
|
||||
if messages:
|
||||
last_message = messages[-1]
|
||||
# Try to extract from tool_calls args
|
||||
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
|
||||
for tool_call in last_message.tool_calls:
|
||||
if isinstance(tool_call, dict) and "args" in tool_call:
|
||||
search_switch = tool_call["args"].get("search_switch")
|
||||
break
|
||||
|
||||
# Convert to string for comparison if needed
|
||||
if search_switch is not None:
|
||||
search_switch = str(search_switch)
|
||||
if search_switch == '0':
|
||||
return 'Verify'
|
||||
elif search_switch == '1':
|
||||
return 'Retrieve_Summary'
|
||||
|
||||
# 添加默认返回值,避免返回 None
|
||||
return 'Retrieve_Summary' # 或根据业务逻辑选择合适的默认值
|
||||
|
||||
|
||||
def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
|
||||
Args:
|
||||
state: State dictionary containing search_switch
|
||||
|
||||
Returns:
|
||||
Next node to execute
|
||||
"""
|
||||
logger.debug(f"Split_continue state: {state}")
|
||||
|
||||
# Direct dictionary access instead of regex parsing
|
||||
search_switch = state.get("search_switch")
|
||||
|
||||
# Handle case where search_switch might be in messages
|
||||
if search_switch is None and "messages" in state:
|
||||
messages = state.get("messages", [])
|
||||
if messages:
|
||||
last_message = messages[-1]
|
||||
# Try to extract from tool_calls args
|
||||
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
|
||||
for tool_call in last_message.tool_calls:
|
||||
if isinstance(tool_call, dict) and "args" in tool_call:
|
||||
search_switch = tool_call["args"].get("search_switch")
|
||||
break
|
||||
|
||||
# Convert to string for comparison if needed
|
||||
if search_switch is not None:
|
||||
search_switch = str(search_switch)
|
||||
if search_switch == '2':
|
||||
return 'Input_Summary'
|
||||
return 'Split_The_Problem' # 默认情况
|
||||
|
||||
# 在 input_sentence 函数中修改参数名称
|
||||
async def input_sentence(state, name, id, search_switch,apply_id,group_id):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1].content if messages else ""
|
||||
|
||||
if last_message.endswith('.jpg') or last_message.endswith('.png'):
|
||||
last_message=await picture_model_requests(last_message)
|
||||
if any(last_message.endswith(ext) for ext in audio_extensions):
|
||||
last_message=await Vico_recognition([last_message]).run()
|
||||
logger.debug(f"Audio recognition result: {last_message}")
|
||||
|
||||
|
||||
uuid_str = uuid.uuid4()
|
||||
time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
namespace = str(id).split('_id_')[1]
|
||||
if 'verified_data' in str(last_message):
|
||||
messages_last = str(last_message).replace('\\n', '').replace('\\', '')
|
||||
last_message = re.findall(r'"query": "(.*?)",', str(messages_last))[0]
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{
|
||||
"name": name,
|
||||
"args": {
|
||||
"sentence": last_message,
|
||||
'sessionid': id,
|
||||
'messages_id': str(uuid_str),
|
||||
"search_switch": search_switch, # 正确地将 search_switch 放入 args 中
|
||||
"apply_id":apply_id,
|
||||
"group_id":group_id
|
||||
},
|
||||
"id": id + f'_{uuid_str}'
|
||||
}]
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class ProblemExtensionNode:
|
||||
def __init__(self, tool, id, namespace, search_switch, apply_id, group_id, storage_type="", user_rag_memory_id=""):
|
||||
self.tool_node = ToolNode([tool])
|
||||
self.id = id
|
||||
self.tool_name = tool.name if hasattr(tool, 'name') else str(tool)
|
||||
self.namespace = namespace
|
||||
self.search_switch = search_switch
|
||||
self.apply_id = apply_id
|
||||
self.group_id = group_id
|
||||
self.storage_type = storage_type
|
||||
self.user_rag_memory_id = user_rag_memory_id
|
||||
|
||||
async def __call__(self, state):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1] if messages else ""
|
||||
logger.debug(f"ProblemExtensionNode {self.id} - 当前时间: {time.time()} - Message: {last_message}")
|
||||
if self.tool_name=='Input_Summary':
|
||||
tool_call =re.findall(f"'id': '(.*?)'",str(last_message))[0]
|
||||
else:tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1]
|
||||
# try:
|
||||
# content = json.loads(last_message.content) if hasattr(last_message, 'content') else last_message
|
||||
# except:
|
||||
# content = last_message.content if hasattr(last_message, 'content') else str(last_message)
|
||||
# 尝试从上一工具的结果中提取实际的内容载荷(而不是整个对象的字符串表示)
|
||||
raw_msg = last_message.content if hasattr(last_message, 'content') else str(last_message)
|
||||
extracted_payload = None
|
||||
# 捕获 ToolMessage 的 content 字段(支持单/双引号),并避免贪婪匹配
|
||||
m = re.search(r"content=(?:\"|\')(.*?)(?:\"|\'),\s*name=", raw_msg, flags=re.S)
|
||||
if m:
|
||||
extracted_payload = m.group(1)
|
||||
else:
|
||||
# 回退:直接尝试使用原始字符串
|
||||
extracted_payload = raw_msg
|
||||
|
||||
# 优先尝试将内容解析为 JSON
|
||||
try:
|
||||
content = json.loads(extracted_payload)
|
||||
except Exception:
|
||||
# 尝试从文本中提取 JSON 片段再解析
|
||||
parsed = None
|
||||
candidates = re.findall(r"[\[{].*[\]}]", extracted_payload, flags=re.S)
|
||||
for cand in candidates:
|
||||
try:
|
||||
parsed = json.loads(cand)
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
# 如果仍然失败,则以原始字符串作为内容
|
||||
content = parsed if parsed is not None else extracted_payload
|
||||
|
||||
# 根据工具名称构建正确的参数
|
||||
tool_args = {}
|
||||
|
||||
if self.tool_name == "Verify":
|
||||
# Verify工具需要context和usermessages参数
|
||||
if isinstance(content, dict):
|
||||
tool_args["context"] = content
|
||||
else:
|
||||
tool_args["context"] = {"content": content}
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name == "Retrieve":
|
||||
# Retrieve工具需要context和usermessages参数
|
||||
if isinstance(content, dict):
|
||||
tool_args["context"] = content
|
||||
else:
|
||||
tool_args["context"] = {"content": content}
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["search_switch"] = str(self.search_switch)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name == "Summary":
|
||||
# Summary工具需要字符串类型的context参数
|
||||
if isinstance(content, dict):
|
||||
# 将字典转换为JSON字符串
|
||||
tool_args["context"] = json.dumps(content, ensure_ascii=False)
|
||||
else:
|
||||
tool_args["context"] = str(content)
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name == "Summary_fails":
|
||||
# Summary工具需要字符串类型的context参数
|
||||
if isinstance(content, dict):
|
||||
# 将字典转换为JSON字符串
|
||||
tool_args["context"] = json.dumps(content, ensure_ascii=False)
|
||||
else:
|
||||
tool_args["context"] = str(content)
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name=='Input_Summary':
|
||||
tool_args["context"] =str(last_message)
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["search_switch"] = str(self.search_switch)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
tool_args["storage_type"] = getattr(self, 'storage_type', "")
|
||||
tool_args["user_rag_memory_id"] = getattr(self, 'user_rag_memory_id', "")
|
||||
elif self.tool_name=='Retrieve_Summary' :
|
||||
# Retrieve_Summary expects dict directly, not JSON string
|
||||
# content might be a JSON string, try to parse it
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
parsed_content = json.loads(content)
|
||||
# Check if it has a "context" key
|
||||
if isinstance(parsed_content, dict) and "context" in parsed_content:
|
||||
tool_args["context"] = parsed_content["context"]
|
||||
else:
|
||||
tool_args["context"] = parsed_content
|
||||
except json.JSONDecodeError:
|
||||
# If parsing fails, wrap the string
|
||||
tool_args["context"] = {"content": content}
|
||||
elif isinstance(content, dict):
|
||||
# Check if content has a "context" key that needs unwrapping
|
||||
if "context" in content:
|
||||
tool_args["context"] = content["context"]
|
||||
else:
|
||||
tool_args["context"] = content
|
||||
else:
|
||||
tool_args["context"] = {"content": str(content)}
|
||||
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
else:
|
||||
# 其他工具使用context参数
|
||||
if isinstance(content, dict):
|
||||
tool_args["context"] = content
|
||||
else:
|
||||
tool_args["context"] = {"content": content}
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
|
||||
|
||||
tool_input = {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{
|
||||
"name": self.tool_name,
|
||||
"args": tool_args,
|
||||
"id": self.id + f"{tool_call}",
|
||||
}]
|
||||
)
|
||||
]
|
||||
}
|
||||
result = await self.tool_node.ainvoke(tool_input)
|
||||
result_text = str(result)
|
||||
|
||||
return {"messages": [AIMessage(content=result_text)]}
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config_id=None,storage_type=None,user_rag_memory_id=None):
|
||||
memory = InMemorySaver()
|
||||
tool=[i.name for i in tools ]
|
||||
logger.info(f"Initializing read graph with tools: {tool}")
|
||||
if config_id:
|
||||
logger.info(f"使用配置 ID: {config_id}")
|
||||
|
||||
# Extract tool functions
|
||||
Split_The_Problem_ = next((t for t in tools if t.name == "Split_The_Problem"), None)
|
||||
Problem_Extension_ = next((t for t in tools if t.name == "Problem_Extension"), None)
|
||||
Retrieve_ = next((t for t in tools if t.name == "Retrieve"), None)
|
||||
Verify_ = next((t for t in tools if t.name == "Verify"), None)
|
||||
Summary_ = next((t for t in tools if t.name == "Summary"), None)
|
||||
Summary_fails_ = next((t for t in tools if t.name == "Summary_fails"), None)
|
||||
Retrieve_Summary_ = next((t for t in tools if t.name == "Retrieve_Summary"), None)
|
||||
Input_Summary_ = next((t for t in tools if t.name == "Input_Summary"), None)
|
||||
|
||||
# Instantiate services
|
||||
parameter_builder = ParameterBuilder()
|
||||
multimodal_processor = MultimodalProcessor()
|
||||
|
||||
# Create nodes using new modular components
|
||||
Split_The_Problem_node = ToolNode([Split_The_Problem_])
|
||||
|
||||
Problem_Extension_node = ToolExecutionNode(
|
||||
tool=Problem_Extension_,
|
||||
node_id="Problem_Extension_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
Retrieve_node = ToolExecutionNode(
|
||||
tool=Retrieve_,
|
||||
node_id="Retrieve_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
Verify_node = ToolExecutionNode(
|
||||
tool=Verify_,
|
||||
node_id="Verify_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
Summary_node = ToolExecutionNode(
|
||||
tool=Summary_,
|
||||
node_id="Summary_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
Summary_fails_node = ToolExecutionNode(
|
||||
tool=Summary_fails_,
|
||||
node_id="Summary_fails_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
Retrieve_Summary_node = ToolExecutionNode(
|
||||
tool=Retrieve_Summary_,
|
||||
node_id="Retrieve_Summary_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
Input_Summary_node = ToolExecutionNode(
|
||||
tool=Input_Summary_,
|
||||
node_id="Input_Summary_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
|
||||
async def content_input_node(state):
|
||||
state_search_switch = state.get("search_switch", search_switch)
|
||||
|
||||
tool_name = "Input_Summary" if state_search_switch == '2' else "Split_The_Problem"
|
||||
session_prefix = "input_summary_call_id" if state_search_switch == '2' else "split_call_id"
|
||||
|
||||
return await create_input_message(
|
||||
state=state,
|
||||
tool_name=tool_name,
|
||||
session_id=f"{session_prefix}_{namespace}",
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
multimodal_processor=multimodal_processor
|
||||
)
|
||||
|
||||
|
||||
# Build workflow graph
|
||||
workflow = StateGraph(ReadState)
|
||||
workflow.add_node("content_input", content_input_node)
|
||||
workflow.add_node("Split_The_Problem", Split_The_Problem_node)
|
||||
workflow.add_node("Problem_Extension", Problem_Extension_node)
|
||||
workflow.add_node("Retrieve", Retrieve_node)
|
||||
workflow.add_node("Verify", Verify_node)
|
||||
workflow.add_node("Summary", Summary_node)
|
||||
workflow.add_node("Summary_fails", Summary_fails_node)
|
||||
workflow.add_node("Retrieve_Summary", Retrieve_Summary_node)
|
||||
workflow.add_node("Input_Summary", Input_Summary_node)
|
||||
|
||||
# Add edges using imported routers
|
||||
workflow.add_edge(START, "content_input")
|
||||
workflow.add_conditional_edges("content_input", Split_continue)
|
||||
workflow.add_edge("Input_Summary", END)
|
||||
workflow.add_edge("Split_The_Problem", "Problem_Extension")
|
||||
workflow.add_edge("Problem_Extension", "Retrieve")
|
||||
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
||||
workflow.add_edge("Retrieve_Summary", END)
|
||||
workflow.add_conditional_edges("Verify", Verify_continue)
|
||||
workflow.add_edge("Summary_fails", END)
|
||||
workflow.add_edge("Summary", END)
|
||||
|
||||
graph = workflow.compile(checkpointer=memory)
|
||||
yield graph
|
||||
|
||||
|
||||
# 添加到文件末尾或创建新的执行脚本
|
||||
# 在 memory_agent_service.py 文件中添加以下函数
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
"""LangGraph routing logic."""
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.routing.routers import (
|
||||
Verify_continue,
|
||||
Retrieve_continue,
|
||||
Split_continue,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Verify_continue",
|
||||
"Retrieve_continue",
|
||||
"Split_continue",
|
||||
]
|
||||
123
api/app/core/memory/agent/langgraph_graph/routing/routers.py
Normal file
123
api/app/core/memory/agent/langgraph_graph/routing/routers.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
Routing functions for LangGraph conditional edges.
|
||||
|
||||
This module provides routing functions that determine the next node to execute
|
||||
based on state values. All functions return Literal types for type safety.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.state.extractors import extract_search_switch
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global counter for Verify routing
|
||||
counter = COUNTState(limit=3)
|
||||
|
||||
|
||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||
"""
|
||||
Determine routing after Verify node based on verification result.
|
||||
|
||||
This function checks the verification result in the last message and routes to:
|
||||
- Summary: if verification succeeded
|
||||
- content_input: if verification failed and retry limit not reached
|
||||
- Summary_fails: if verification failed and retry limit reached
|
||||
|
||||
Args:
|
||||
state: LangGraph state containing messages
|
||||
|
||||
Returns:
|
||||
Next node name as Literal type
|
||||
"""
|
||||
messages = state.get("messages", [])
|
||||
|
||||
# Boundary check
|
||||
if not messages:
|
||||
logger.warning("[Verify_continue] No messages in state, defaulting to Summary")
|
||||
counter.reset()
|
||||
return "Summary"
|
||||
|
||||
# Increment counter
|
||||
counter.add(1)
|
||||
loop_count = counter.get_total()
|
||||
logger.debug(f"[Verify_continue] Current loop count: {loop_count}")
|
||||
|
||||
# Extract verification result from last message
|
||||
last_message = messages[-1]
|
||||
last_message_str = str(last_message).replace('\\', '')
|
||||
status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str)
|
||||
logger.debug(f"[Verify_continue] Status tools: {status_tools}")
|
||||
|
||||
# Route based on verification result
|
||||
if "success" in status_tools:
|
||||
counter.reset()
|
||||
return "Summary"
|
||||
elif "failed" in status_tools:
|
||||
if loop_count < 2: # Max retry count is 2
|
||||
return "content_input"
|
||||
else:
|
||||
counter.reset()
|
||||
return "Summary_fails"
|
||||
else:
|
||||
# Default to Summary if status is unclear
|
||||
counter.reset()
|
||||
return "Summary"
|
||||
|
||||
|
||||
def Retrieve_continue(state: dict) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
"""
|
||||
Determine routing after Retrieve node based on search_switch value.
|
||||
|
||||
This function routes based on the search_switch parameter:
|
||||
- search_switch == '0': Route to Verify (verification needed)
|
||||
- search_switch == '1': Route to Retrieve_Summary (direct summary)
|
||||
|
||||
Args:
|
||||
state: LangGraph state dictionary
|
||||
|
||||
Returns:
|
||||
Next node name as Literal type
|
||||
"""
|
||||
search_switch = extract_search_switch(state)
|
||||
|
||||
logger.debug(f"[Retrieve_continue] search_switch: {search_switch}")
|
||||
|
||||
if search_switch == '0':
|
||||
return 'Verify'
|
||||
elif search_switch == '1':
|
||||
return 'Retrieve_Summary'
|
||||
|
||||
# Default to Retrieve_Summary
|
||||
logger.debug("[Retrieve_continue] No valid search_switch, defaulting to Retrieve_Summary")
|
||||
return 'Retrieve_Summary'
|
||||
|
||||
|
||||
def Split_continue(state: dict) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
"""
|
||||
Determine routing after content_input node based on search_switch value.
|
||||
|
||||
This function routes based on the search_switch parameter:
|
||||
- search_switch == '2': Route to Input_Summary (direct input summary)
|
||||
- Otherwise: Route to Split_The_Problem (problem decomposition)
|
||||
|
||||
Args:
|
||||
state: LangGraph state dictionary
|
||||
|
||||
Returns:
|
||||
Next node name as Literal type
|
||||
"""
|
||||
logger.debug(f"[Split_continue] state keys: {state.keys()}")
|
||||
|
||||
search_switch = extract_search_switch(state)
|
||||
|
||||
logger.debug(f"[Split_continue] search_switch: {search_switch}")
|
||||
|
||||
if search_switch == '2':
|
||||
return 'Input_Summary'
|
||||
|
||||
# Default to Split_The_Problem
|
||||
return 'Split_The_Problem'
|
||||
13
api/app/core/memory/agent/langgraph_graph/state/__init__.py
Normal file
13
api/app/core/memory/agent/langgraph_graph/state/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""LangGraph state management utilities."""
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.state.extractors import (
|
||||
extract_search_switch,
|
||||
extract_tool_call_id,
|
||||
extract_content_payload,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"extract_search_switch",
|
||||
"extract_tool_call_id",
|
||||
"extract_content_payload",
|
||||
]
|
||||
164
api/app/core/memory/agent/langgraph_graph/state/extractors.py
Normal file
164
api/app/core/memory/agent/langgraph_graph/state/extractors.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
State extraction utilities for type-safe access to LangGraph state values.
|
||||
|
||||
This module provides utility functions for extracting values from LangGraph state
|
||||
dictionaries with proper error handling and sensible defaults.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def extract_search_switch(state: dict) -> Optional[str]:
|
||||
"""
|
||||
Extract search_switch from state or messages.
|
||||
"""
|
||||
|
||||
search_switch = state.get("search_switch")
|
||||
|
||||
if search_switch is not None:
|
||||
return str(search_switch)
|
||||
|
||||
# Try to extract from messages
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# 从最新的消息开始查找
|
||||
for message in reversed(messages):
|
||||
# 尝试从 tool_calls 中提取
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
# 从 tool_call 的 args 中提取
|
||||
if "args" in tool_call and isinstance(tool_call["args"], dict):
|
||||
search_switch = tool_call["args"].get("search_switch")
|
||||
if search_switch is not None:
|
||||
return str(search_switch)
|
||||
# 直接从 tool_call 中提取
|
||||
search_switch = tool_call.get("search_switch")
|
||||
if search_switch is not None:
|
||||
return str(search_switch)
|
||||
|
||||
# 尝试从 content 中提取(如果是 JSON 格式)
|
||||
if hasattr(message, "content"):
|
||||
try:
|
||||
import json
|
||||
if isinstance(message.content, str):
|
||||
content_data = json.loads(message.content)
|
||||
if isinstance(content_data, dict):
|
||||
search_switch = content_data.get("search_switch")
|
||||
if search_switch is not None:
|
||||
return str(search_switch)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_tool_call_id(message: Any) -> str:
|
||||
"""
|
||||
Extract tool call ID from message using structured attributes.
|
||||
|
||||
This function extracts the tool call ID from a message object, handling both
|
||||
direct attribute access and tool_calls list structures.
|
||||
|
||||
Args:
|
||||
message: Message object (typically ToolMessage or AIMessage)
|
||||
|
||||
Returns:
|
||||
Tool call ID as string
|
||||
|
||||
Raises:
|
||||
ValueError: If tool call ID cannot be extracted
|
||||
|
||||
Examples:
|
||||
>>> message = ToolMessage(content="...", tool_call_id="call_123")
|
||||
>>> extract_tool_call_id(message)
|
||||
'call_123'
|
||||
"""
|
||||
# Try direct attribute access for ToolMessage
|
||||
if hasattr(message, "tool_call_id"):
|
||||
tool_call_id = message.tool_call_id
|
||||
if tool_call_id:
|
||||
return str(tool_call_id)
|
||||
|
||||
# Try extracting from tool_calls list for AIMessage
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
tool_call = message.tool_calls[0]
|
||||
if isinstance(tool_call, dict) and "id" in tool_call:
|
||||
return str(tool_call["id"])
|
||||
|
||||
# Try extracting from id attribute
|
||||
if hasattr(message, "id"):
|
||||
message_id = message.id
|
||||
if message_id:
|
||||
return str(message_id)
|
||||
|
||||
# If all else fails, raise an error
|
||||
raise ValueError(f"Could not extract tool call ID from message: {type(message)}")
|
||||
|
||||
|
||||
def extract_content_payload(message: Any) -> Any:
|
||||
"""
|
||||
Extract content payload from ToolMessage, parsing JSON if needed.
|
||||
|
||||
This function extracts the content from a message and attempts to parse it as JSON
|
||||
if it appears to be a JSON string. It handles various message formats and provides
|
||||
sensible fallbacks.
|
||||
|
||||
Args:
|
||||
message: Message object (typically ToolMessage)
|
||||
|
||||
Returns:
|
||||
Parsed content (dict, list, or str)
|
||||
|
||||
Examples:
|
||||
>>> message = ToolMessage(content='{"key": "value"}')
|
||||
>>> extract_content_payload(message)
|
||||
{'key': 'value'}
|
||||
|
||||
>>> message = ToolMessage(content='plain text')
|
||||
>>> extract_content_payload(message)
|
||||
'plain text'
|
||||
"""
|
||||
# Extract raw content
|
||||
# For ToolMessages (responses from tools), extract from content
|
||||
if hasattr(message, "content"):
|
||||
raw_content = message.content
|
||||
|
||||
# If content is empty and this is an AIMessage with tool_calls,
|
||||
# extract from args (this handles the initial tool call from content_input)
|
||||
if not raw_content and hasattr(message, "tool_calls") and message.tool_calls:
|
||||
tool_call = message.tool_calls[0]
|
||||
if isinstance(tool_call, dict) and "args" in tool_call:
|
||||
return tool_call["args"]
|
||||
else:
|
||||
raw_content = str(message)
|
||||
|
||||
# If content is already a dict or list, return it directly
|
||||
if isinstance(raw_content, (dict, list)):
|
||||
return raw_content
|
||||
|
||||
# Try to parse as JSON
|
||||
if isinstance(raw_content, str):
|
||||
# First, try direct JSON parsing
|
||||
try:
|
||||
return json.loads(raw_content)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
# If that fails, try to extract JSON from the string
|
||||
# This handles cases where the content is embedded in a larger string
|
||||
import re
|
||||
json_candidates = re.findall(r'[\[{].*[\]}]', raw_content, flags=re.DOTALL)
|
||||
for candidate in json_candidates:
|
||||
try:
|
||||
return json.loads(candidate)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
continue
|
||||
|
||||
# If all parsing attempts fail, return the raw content
|
||||
return raw_content
|
||||
78
api/app/core/memory/agent/langgraph_graph/write_graph.py
Normal file
78
api/app/core/memory/agent/langgraph_graph/write_graph.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import asyncio
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
from langgraph.constants import START, END
|
||||
from langgraph.graph import add_messages, StateGraph
|
||||
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
import warnings
|
||||
import sys
|
||||
from langchain_core.messages import AIMessage
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
import asyncio
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
@asynccontextmanager
|
||||
async def make_write_graph(user_id, tools, apply_id, group_id, config_id=None):
|
||||
logger.info("加载 MCP 工具: %s", [t.name for t in tools])
|
||||
if config_id:
|
||||
logger.info(f"使用配置 ID: {config_id}")
|
||||
|
||||
data_type_tool = next((t for t in tools if t.name == "Data_type_differentiation"), None)
|
||||
data_write_tool = next((t for t in tools if t.name == "Data_write"), None)
|
||||
|
||||
if not data_type_tool or not data_write_tool:
|
||||
logger.error('不存在数据存储工具', exc_info=True)
|
||||
raise ValueError('不存在数据存储工具')
|
||||
# ToolNode
|
||||
write_node = ToolNode([data_write_tool])
|
||||
|
||||
|
||||
async def call_model(state):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
|
||||
result = await data_type_tool.ainvoke({
|
||||
"context": last_message[1] if isinstance(last_message, tuple) else last_message.content
|
||||
})
|
||||
result=json.loads( result)
|
||||
|
||||
# 调用 Data_write,传递 config_id
|
||||
write_params = {
|
||||
"content": result["context"],
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id,
|
||||
"user_id": user_id
|
||||
}
|
||||
|
||||
# 如果提供了 config_id,添加到参数中
|
||||
if config_id:
|
||||
write_params["config_id"] = config_id
|
||||
logger.debug(f"传递 config_id 到 Data_write: {config_id}")
|
||||
|
||||
write_result = await data_write_tool.ainvoke(write_params)
|
||||
|
||||
if isinstance(write_result, dict):
|
||||
content = write_result.get("data", str(write_result))
|
||||
else:
|
||||
content = str(write_result)
|
||||
logger.info("写入内容: %s", content)
|
||||
return {"messages": [AIMessage(content=content)]}
|
||||
|
||||
workflow = StateGraph(WriteState)
|
||||
workflow.add_node("content_input", call_model)
|
||||
workflow.add_node("save_neo4j", write_node)
|
||||
workflow.add_edge(START, "content_input")
|
||||
workflow.add_edge("content_input", "save_neo4j")
|
||||
workflow.add_edge("save_neo4j", END)
|
||||
|
||||
graph = workflow.compile()
|
||||
|
||||
|
||||
yield graph
|
||||
285
api/app/core/memory/agent/logger_file/log_streamer.py
Normal file
285
api/app/core/memory/agent/logger_file/log_streamer.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
Log Streamer Module
|
||||
|
||||
Manages streaming of log file content with file watching and real-time transmission.
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class LogStreamer:
|
||||
"""Manages log file streaming with file watching and content transmission"""
|
||||
|
||||
def __init__(self, log_path: str, keepalive_interval: int = 300):
|
||||
"""
|
||||
Initialize LogStreamer
|
||||
|
||||
Args:
|
||||
log_path: Path to the log file to stream
|
||||
keepalive_interval: Interval in seconds for sending keepalive messages (default: 300)
|
||||
"""
|
||||
self.log_path = log_path
|
||||
self.keepalive_interval = keepalive_interval
|
||||
self.last_position = 0
|
||||
|
||||
# Pattern to match and remove timestamp and log level prefix
|
||||
# Matches: "YYYY-MM-DD HH:MM:SS,mmm - [LEVEL] - module_name - "
|
||||
# This pattern is comprehensive to handle various log formats
|
||||
self.pattern = re.compile(
|
||||
r'^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3} - \[(?:INFO|DEBUG|WARNING|ERROR|CRITICAL)\] - \S+ - '
|
||||
)
|
||||
|
||||
logger.info(f"LogStreamer initialized for {log_path}")
|
||||
|
||||
@staticmethod
|
||||
def clean_log_line(line: str) -> str:
|
||||
"""
|
||||
Static method to clean log entry by removing timestamp and log level prefix.
|
||||
This is the canonical log cleaning method used by both file mode and transmission mode.
|
||||
|
||||
Args:
|
||||
line: Raw log line
|
||||
|
||||
Returns:
|
||||
Cleaned log line without timestamp and log level prefix
|
||||
"""
|
||||
# Pattern to match and remove timestamp and log level prefix
|
||||
# Matches: "YYYY-MM-DD HH:MM:SS,mmm - [LEVEL] - module_name - "
|
||||
pattern = re.compile(
|
||||
r'^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3} - \[(?:INFO|DEBUG|WARNING|ERROR|CRITICAL)\] - \S+ - '
|
||||
)
|
||||
cleaned = re.sub(pattern, '', line)
|
||||
return cleaned
|
||||
|
||||
def clean_log_entry(self, line: str) -> str:
|
||||
"""
|
||||
Clean log entry by removing timestamp and log level prefix.
|
||||
This instance method delegates to the static method for consistency.
|
||||
|
||||
Args:
|
||||
line: Raw log line
|
||||
|
||||
Returns:
|
||||
Cleaned log line without timestamp and log level prefix
|
||||
"""
|
||||
return LogStreamer.clean_log_line(line)
|
||||
|
||||
async def send_keepalive(self) -> dict:
|
||||
"""
|
||||
Generate keepalive message
|
||||
|
||||
Returns:
|
||||
Keepalive message dict with timestamp
|
||||
"""
|
||||
return {
|
||||
"event": "keepalive",
|
||||
"data": {
|
||||
"timestamp": int(time.time())
|
||||
}
|
||||
}
|
||||
|
||||
async def read_existing_and_stream(self) -> AsyncGenerator[dict, None]:
|
||||
"""
|
||||
Read existing log content first, then watch for new content
|
||||
|
||||
This method reads all existing content in the file first,
|
||||
then continues to watch for new content as it's written.
|
||||
|
||||
Yields:
|
||||
Dict messages with event type and data:
|
||||
- log events: {"event": "log", "data": {"content": "...", "timestamp": ...}}
|
||||
- keepalive events: {"event": "keepalive", "data": {"timestamp": ...}}
|
||||
- error events: {"event": "error", "data": {"code": ..., "message": "...", "error": "..."}}
|
||||
- done events: {"event": "done", "data": {"message": "..."}}
|
||||
"""
|
||||
logger.info(f"Starting log stream (read existing) for {self.log_path}")
|
||||
|
||||
# Check if file exists
|
||||
if not os.path.exists(self.log_path):
|
||||
logger.error(f"Log file not found: {self.log_path}")
|
||||
yield {
|
||||
"event": "error",
|
||||
"data": {
|
||||
"code": 4006,
|
||||
"message": "日志文件不存在",
|
||||
"error": f"File not found: {self.log_path}"
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
try:
|
||||
with open(self.log_path, 'r', encoding='utf-8') as f:
|
||||
# First, read all existing content
|
||||
for line in f:
|
||||
if line.strip(): # Skip empty lines
|
||||
cleaned_line = self.clean_log_entry(line)
|
||||
yield {
|
||||
"event": "log",
|
||||
"data": {
|
||||
"content": cleaned_line.rstrip('\n'),
|
||||
"timestamp": int(time.time())
|
||||
}
|
||||
}
|
||||
|
||||
# Now watch for new content
|
||||
self.last_position = f.tell()
|
||||
last_keepalive = time.time()
|
||||
|
||||
while True:
|
||||
line = f.readline()
|
||||
if line:
|
||||
cleaned_line = self.clean_log_entry(line)
|
||||
yield {
|
||||
"event": "log",
|
||||
"data": {
|
||||
"content": cleaned_line.rstrip('\n'),
|
||||
"timestamp": int(time.time())
|
||||
}
|
||||
}
|
||||
last_keepalive = time.time()
|
||||
else:
|
||||
# No new content, check if we need to send keepalive
|
||||
current_time = time.time()
|
||||
if current_time - last_keepalive >= self.keepalive_interval:
|
||||
keepalive_msg = await self.send_keepalive()
|
||||
yield keepalive_msg
|
||||
last_keepalive = current_time
|
||||
|
||||
# Sleep briefly before checking again
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Log file disappeared during streaming: {self.log_path}")
|
||||
yield {
|
||||
"event": "error",
|
||||
"data": {
|
||||
"code": 4006,
|
||||
"message": "日志文件在流式传输期间变得不可用",
|
||||
"error": "File not found during streaming"
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error during log streaming: {e}", exc_info=True)
|
||||
yield {
|
||||
"event": "error",
|
||||
"data": {
|
||||
"code": 8001,
|
||||
"message": "流式传输期间发生错误",
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
finally:
|
||||
logger.info(f"Log stream ended for {self.log_path}")
|
||||
yield {
|
||||
"event": "done",
|
||||
"data": {
|
||||
"message": "流式传输完成"
|
||||
}
|
||||
}
|
||||
|
||||
async def watch_and_stream(self) -> AsyncGenerator[dict, None]:
|
||||
"""
|
||||
Watch log file and stream only new content as it's written
|
||||
|
||||
This method starts from the end of the file and only streams
|
||||
new content that is written after the stream starts.
|
||||
|
||||
Yields:
|
||||
Dict messages with event type and data:
|
||||
- log events: {"event": "log", "data": {"content": "...", "timestamp": ...}}
|
||||
- keepalive events: {"event": "keepalive", "data": {"timestamp": ...}}
|
||||
- error events: {"event": "error", "data": {"code": ..., "message": "...", "error": "..."}}
|
||||
- done events: {"event": "done", "data": {"message": "..."}}
|
||||
"""
|
||||
logger.info(f"Starting log stream (new content only) for {self.log_path}")
|
||||
|
||||
# Check if file exists
|
||||
if not os.path.exists(self.log_path):
|
||||
logger.error(f"Log file not found: {self.log_path}")
|
||||
yield {
|
||||
"event": "error",
|
||||
"data": {
|
||||
"code": 4006,
|
||||
"message": "日志文件不存在",
|
||||
"error": f"File not found: {self.log_path}"
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
try:
|
||||
# Open file and seek to end to start streaming new content
|
||||
with open(self.log_path, 'r', encoding='utf-8') as f:
|
||||
# Move to end of file
|
||||
f.seek(0, os.SEEK_END)
|
||||
self.last_position = f.tell()
|
||||
|
||||
last_keepalive = time.time()
|
||||
|
||||
while True:
|
||||
# Check if file has new content
|
||||
current_position = f.tell()
|
||||
|
||||
# Read new lines if available
|
||||
line = f.readline()
|
||||
if line:
|
||||
# Clean the log entry
|
||||
cleaned_line = self.clean_log_entry(line)
|
||||
|
||||
# Yield log event
|
||||
yield {
|
||||
"event": "log",
|
||||
"data": {
|
||||
"content": cleaned_line.rstrip('\n'),
|
||||
"timestamp": int(time.time())
|
||||
}
|
||||
}
|
||||
|
||||
# Update last keepalive time since we sent data
|
||||
last_keepalive = time.time()
|
||||
else:
|
||||
# No new content, check if we need to send keepalive
|
||||
current_time = time.time()
|
||||
if current_time - last_keepalive >= self.keepalive_interval:
|
||||
keepalive_msg = await self.send_keepalive()
|
||||
yield keepalive_msg
|
||||
last_keepalive = current_time
|
||||
|
||||
# Sleep briefly before checking again
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Log file disappeared during streaming: {self.log_path}")
|
||||
yield {
|
||||
"event": "error",
|
||||
"data": {
|
||||
"code": 4006,
|
||||
"message": "日志文件在流式传输期间变得不可用",
|
||||
"error": "File not found during streaming"
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error during log streaming: {e}", exc_info=True)
|
||||
yield {
|
||||
"event": "error",
|
||||
"data": {
|
||||
"code": 8001,
|
||||
"message": "流式传输期间发生错误",
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
finally:
|
||||
logger.info(f"Log stream ended for {self.log_path}")
|
||||
yield {
|
||||
"event": "done",
|
||||
"data": {
|
||||
"message": "流式传输完成"
|
||||
}
|
||||
}
|
||||
32
api/app/core/memory/agent/logger_file/logger_data.py
Normal file
32
api/app/core/memory/agent/logger_file/logger_data.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
Agent logger module for backward compatibility.
|
||||
|
||||
This module maintains the get_named_logger() function for backward compatibility
|
||||
while delegating to the centralized logging configuration.
|
||||
|
||||
All new code should import directly from app.core.logging_config instead.
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__author__ = "RED_BEAR"
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
|
||||
def get_named_logger(name):
|
||||
"""Get a named logger for agent operations.
|
||||
|
||||
This function maintains backward compatibility with existing code.
|
||||
It delegates to the centralized get_agent_logger() function.
|
||||
|
||||
Args:
|
||||
name: Logger name for namespacing
|
||||
|
||||
Returns:
|
||||
Logger configured for agent operations
|
||||
|
||||
Example:
|
||||
>>> logger = get_named_logger("my_agent")
|
||||
>>> logger.info("Agent operation started")
|
||||
"""
|
||||
return get_agent_logger(name)
|
||||
28
api/app/core/memory/agent/mcp_server/__init__.py
Normal file
28
api/app/core/memory/agent/mcp_server/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
MCP Server package for memory agent.
|
||||
|
||||
This package provides the FastMCP server implementation with context-based
|
||||
dependency injection for tool functions.
|
||||
|
||||
Package structure:
|
||||
- server: FastMCP server initialization and context setup
|
||||
- tools: MCP tool implementations
|
||||
- models: Pydantic response models
|
||||
- services: Business logic services
|
||||
"""
|
||||
from app.core.memory.agent.mcp_server.server import (
|
||||
mcp,
|
||||
initialize_context,
|
||||
main,
|
||||
get_context_resource
|
||||
)
|
||||
|
||||
# Import tools to register them (but don't export them)
|
||||
from app.core.memory.agent.mcp_server import tools
|
||||
|
||||
__all__ = [
|
||||
'mcp',
|
||||
'initialize_context',
|
||||
'main',
|
||||
'get_context_resource',
|
||||
]
|
||||
11
api/app/core/memory/agent/mcp_server/mcp_instance.py
Normal file
11
api/app/core/memory/agent/mcp_server/mcp_instance.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
MCP Server Instance
|
||||
|
||||
This module contains the FastMCP server instance that is shared across all modules.
|
||||
It's in a separate file to avoid circular import issues.
|
||||
"""
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
# Initialize FastMCP server instance
|
||||
# This instance is shared across all tool modules
|
||||
mcp = FastMCP('data_flow')
|
||||
30
api/app/core/memory/agent/mcp_server/models/__init__.py
Normal file
30
api/app/core/memory/agent/mcp_server/models/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Pydantic models for MCP server responses."""
|
||||
|
||||
from .problem_models import (
|
||||
ProblemBreakdownItem,
|
||||
ProblemBreakdownResponse,
|
||||
ExtendedQuestionItem,
|
||||
ProblemExtensionResponse,
|
||||
)
|
||||
from .summary_models import (
|
||||
SummaryData,
|
||||
SummaryResponse,
|
||||
RetrieveSummaryData,
|
||||
RetrieveSummaryResponse,
|
||||
)
|
||||
from .verification_models import VerificationResult
|
||||
from .retrieval_models import RetrievalResult, DistinguishTypeResponse
|
||||
|
||||
__all__ = [
|
||||
"ProblemBreakdownItem",
|
||||
"ProblemBreakdownResponse",
|
||||
"ExtendedQuestionItem",
|
||||
"ProblemExtensionResponse",
|
||||
"SummaryData",
|
||||
"SummaryResponse",
|
||||
"RetrieveSummaryData",
|
||||
"RetrieveSummaryResponse",
|
||||
"VerificationResult",
|
||||
"RetrievalResult",
|
||||
"DistinguishTypeResponse",
|
||||
]
|
||||
@@ -0,0 +1,34 @@
|
||||
"""Pydantic models for problem breakdown and extension operations."""
|
||||
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field, RootModel
|
||||
|
||||
|
||||
class ProblemBreakdownItem(BaseModel):
|
||||
"""Individual item in problem breakdown response."""
|
||||
|
||||
id: str
|
||||
question: str
|
||||
type: str
|
||||
reason: Optional[str] = None
|
||||
|
||||
|
||||
class ProblemBreakdownResponse(RootModel[List[ProblemBreakdownItem]]):
|
||||
"""Response model for problem breakdown containing list of breakdown items."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ExtendedQuestionItem(BaseModel):
|
||||
"""Individual extended question item with reasoning."""
|
||||
|
||||
original_question: str = Field(..., description="原始初步问题")
|
||||
extended_question: str = Field(..., description="扩展后的问题")
|
||||
type: str = Field(..., description="类型(事实检索 / 澄清 / 定义 / 比较 / 行动建议等)")
|
||||
reason: str = Field(..., description="生成该扩展问题的理由")
|
||||
|
||||
|
||||
class ProblemExtensionResponse(RootModel[List[ExtendedQuestionItem]]):
|
||||
"""Response model for problem extension containing list of extended questions."""
|
||||
|
||||
pass
|
||||
@@ -0,0 +1,17 @@
|
||||
"""Pydantic models for retrieval operations."""
|
||||
|
||||
from typing import List, Dict, Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RetrievalResult(BaseModel):
|
||||
"""Result model for retrieval operation."""
|
||||
|
||||
Query: str
|
||||
Expansion_issue: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class DistinguishTypeResponse(BaseModel):
|
||||
"""Response model for data type differentiation."""
|
||||
|
||||
type: str
|
||||
@@ -0,0 +1,31 @@
|
||||
"""Pydantic models for summary operations."""
|
||||
|
||||
from typing import List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SummaryData(BaseModel):
|
||||
"""Data structure for summary input."""
|
||||
|
||||
query: str
|
||||
history: List[str] = Field(default_factory=list)
|
||||
retrieve_info: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class SummaryResponse(BaseModel):
|
||||
"""Response model for summary operation."""
|
||||
|
||||
data: SummaryData
|
||||
query_answer: str
|
||||
|
||||
|
||||
class RetrieveSummaryData(BaseModel):
|
||||
"""Data structure for retrieve summary response."""
|
||||
|
||||
query_answer: str = Field(default="")
|
||||
|
||||
|
||||
class RetrieveSummaryResponse(BaseModel):
|
||||
"""Response model for retrieve summary operation."""
|
||||
|
||||
data: RetrieveSummaryData
|
||||
@@ -0,0 +1,14 @@
|
||||
"""Pydantic models for verification operations."""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VerificationResult(BaseModel):
|
||||
"""Result model for verification operation."""
|
||||
|
||||
query: str
|
||||
expansion_issue: List[Dict[str, Any]]
|
||||
split_result: str
|
||||
reason: Optional[str] = None
|
||||
history: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
161
api/app/core/memory/agent/mcp_server/server.py
Normal file
161
api/app/core/memory/agent/mcp_server/server.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
MCP Server initialization with FastMCP context setup.
|
||||
|
||||
This module initializes the FastMCP server and registers shared resources
|
||||
in the context for dependency injection into tool functions.
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.redis_tool import RedisSessionStore, store
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID,reload_configuration_from_database
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.agent.mcp_server.services.template_service import TemplateService
|
||||
from app.core.memory.agent.mcp_server.services.search_service import SearchService
|
||||
from app.core.memory.agent.mcp_server.services.session_service import SessionService
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
def get_context_resource(ctx, resource_name: str):
|
||||
"""
|
||||
Helper function to retrieve a resource from the FastMCP context.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP Context object (passed to tool functions)
|
||||
resource_name: Name of the resource to retrieve
|
||||
|
||||
Returns:
|
||||
The requested resource
|
||||
|
||||
Raises:
|
||||
AttributeError: If the resource doesn't exist
|
||||
|
||||
Example:
|
||||
@mcp.tool()
|
||||
async def my_tool(ctx: Context):
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
"""
|
||||
if not hasattr(ctx, 'fastmcp') or ctx.fastmcp is None:
|
||||
raise RuntimeError("Context does not have fastmcp attribute")
|
||||
|
||||
if not hasattr(ctx.fastmcp, resource_name):
|
||||
raise AttributeError(
|
||||
f"Resource '{resource_name}' not found in context. "
|
||||
f"Available resources: {[k for k in dir(ctx.fastmcp) if not k.startswith('_')]}"
|
||||
)
|
||||
|
||||
return getattr(ctx.fastmcp, resource_name)
|
||||
|
||||
|
||||
def initialize_context():
|
||||
"""
|
||||
Initialize and register shared resources in FastMCP context.
|
||||
|
||||
This function sets up all shared resources that will be available
|
||||
to tool functions via dependency injection through the context parameter.
|
||||
|
||||
Resources are stored as attributes on the FastMCP instance and can be
|
||||
accessed via ctx.fastmcp in tool functions.
|
||||
|
||||
Resources registered:
|
||||
- session_store: RedisSessionStore for session management
|
||||
- llm_client: LLM client for structured API calls
|
||||
- app_settings: Application settings (renamed to avoid conflict with FastMCP settings)
|
||||
- template_service: Service for template rendering
|
||||
- search_service: Service for hybrid search
|
||||
- session_service: Service for session operations
|
||||
"""
|
||||
try:
|
||||
# Register Redis session store
|
||||
logger.info("Registering session_store in context")
|
||||
mcp.session_store = store
|
||||
|
||||
# Register LLM client
|
||||
try:
|
||||
logger.info(f"Registering llm_client in context with model ID: {SELECTED_LLM_ID}")
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
mcp.llm_client = llm_client
|
||||
logger.info("llm_client registered successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register llm_client: {e}", exc_info=True)
|
||||
# 注册一个 None 值,避免工具调用时找不到资源
|
||||
mcp.llm_client = None
|
||||
logger.warning("llm_client set to None due to initialization failure")
|
||||
|
||||
# Register application settings (renamed to avoid conflict with FastMCP's settings)
|
||||
logger.info("Registering app_settings in context")
|
||||
mcp.app_settings = settings
|
||||
|
||||
# Register template service
|
||||
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
|
||||
# logger.info(f"Registering template_service in context with root: {template_root}")
|
||||
template_service = TemplateService(template_root)
|
||||
mcp.template_service = template_service
|
||||
|
||||
# Register search service
|
||||
# logger.info("Registering search_service in context")
|
||||
search_service = SearchService()
|
||||
mcp.search_service = search_service
|
||||
|
||||
# Register session service
|
||||
# logger.info("Registering session_service in context")
|
||||
session_service = SessionService(store)
|
||||
mcp.session_service = session_service
|
||||
|
||||
# logger.info("All context resources registered successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize context: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main entry point for the MCP server.
|
||||
|
||||
Initializes context and starts the server with SSE transport.
|
||||
"""
|
||||
try:
|
||||
# logger.info("Starting MCP server initialization")
|
||||
reload_configuration_from_database(config_id=os.getenv("config_id"), force_reload=True)
|
||||
# Initialize context resources
|
||||
initialize_context()
|
||||
|
||||
# Import and register tools
|
||||
# logger.info("Importing MCP tools")
|
||||
from app.core.memory.agent.mcp_server.tools import (
|
||||
problem_tools,
|
||||
retrieval_tools,
|
||||
verification_tools,
|
||||
summary_tools,
|
||||
data_tools
|
||||
)
|
||||
# logger.info("All MCP tools imported and registered")
|
||||
|
||||
# Log registered tools for debugging
|
||||
import asyncio
|
||||
tools_list = asyncio.run(mcp.list_tools())
|
||||
# logger.info(f"Registered {len(tools_list)} MCP tools: {[t.name for t in tools_list]}")
|
||||
# logger.info(f"Starting MCP server on {settings.SERVER_IP}:8081 with SSE transport")
|
||||
|
||||
# Run the server with SSE transport for HTTP connections
|
||||
# The server will be available at http://127.0.0.1:8081
|
||||
import uvicorn
|
||||
app = mcp.sse_app()
|
||||
uvicorn.run(app, host=settings.SERVER_IP, port=8081, log_level="info")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start MCP server: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
23
api/app/core/memory/agent/mcp_server/services/__init__.py
Normal file
23
api/app/core/memory/agent/mcp_server/services/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
MCP Server Services
|
||||
|
||||
This module provides business logic services for the MCP server:
|
||||
- TemplateService: Template loading and rendering
|
||||
- SearchService: Search result processing
|
||||
- SessionService: Session and history management
|
||||
- ParameterBuilder: Tool parameter construction
|
||||
"""
|
||||
|
||||
from .template_service import TemplateService, TemplateRenderError
|
||||
from .search_service import SearchService
|
||||
from .session_service import SessionService
|
||||
from .parameter_builder import ParameterBuilder
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TemplateService",
|
||||
"TemplateRenderError",
|
||||
"SearchService",
|
||||
"SessionService",
|
||||
"ParameterBuilder",
|
||||
]
|
||||
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
Parameter Builder for constructing tool call arguments.
|
||||
|
||||
This service provides tool-specific parameter transformation logic
|
||||
to build correct arguments for each tool type.
|
||||
"""
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class ParameterBuilder:
|
||||
"""Service for building tool call arguments based on tool type."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the parameter builder."""
|
||||
logger.info("ParameterBuilder initialized")
|
||||
|
||||
def build_tool_args(
|
||||
self,
|
||||
tool_name: str,
|
||||
content: Any,
|
||||
tool_call_id: str,
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build tool arguments based on tool type.
|
||||
|
||||
Different tools expect different argument formats:
|
||||
- Verify: dict context
|
||||
- Retrieve: dict context + search_switch
|
||||
- Summary/Summary_fails: JSON string context
|
||||
- Retrieve_Summary: unwrap nested context structures
|
||||
- Input_Summary: raw message string
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool being invoked
|
||||
content: Parsed content from previous tool result
|
||||
tool_call_id: Extracted tool call identifier
|
||||
search_switch: Search routing parameter
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
|
||||
|
||||
Returns:
|
||||
Dictionary of tool arguments ready for invocation
|
||||
"""
|
||||
# Base arguments common to most tools
|
||||
base_args = {
|
||||
"usermessages": tool_call_id,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id
|
||||
}
|
||||
|
||||
# Always add storage_type and user_rag_memory_id (with defaults if None)
|
||||
base_args["storage_type"] = storage_type if storage_type is not None else ""
|
||||
base_args["user_rag_memory_id"] = user_rag_memory_id if user_rag_memory_id is not None else ""
|
||||
|
||||
# Tool-specific argument construction
|
||||
if tool_name == "Verify":
|
||||
# Verify expects dict context
|
||||
return {
|
||||
"context": content if isinstance(content, dict) else {},
|
||||
**base_args
|
||||
}
|
||||
|
||||
elif tool_name == "Retrieve":
|
||||
# Retrieve expects dict context + search_switch
|
||||
return {
|
||||
"context": content if isinstance(content, dict) else {},
|
||||
"search_switch": search_switch,
|
||||
**base_args
|
||||
}
|
||||
|
||||
elif tool_name in ["Summary", "Summary_fails"]:
|
||||
# Summary tools expect JSON string context
|
||||
if isinstance(content, dict):
|
||||
context_str = json.dumps(content, ensure_ascii=False)
|
||||
elif isinstance(content, str):
|
||||
context_str = content
|
||||
else:
|
||||
context_str = json.dumps({"data": content}, ensure_ascii=False)
|
||||
|
||||
return {
|
||||
"context": context_str,
|
||||
**base_args
|
||||
}
|
||||
|
||||
elif tool_name == "Retrieve_Summary":
|
||||
# Retrieve_Summary needs to unwrap nested context structures
|
||||
# Handle both 'content' and 'context' keys
|
||||
context_dict = content
|
||||
|
||||
if isinstance(content, dict):
|
||||
# Check for nested 'content' wrapper
|
||||
if "content" in content:
|
||||
inner = content["content"]
|
||||
|
||||
# If it's a JSON string, parse it
|
||||
if isinstance(inner, str):
|
||||
try:
|
||||
parsed = json.loads(inner)
|
||||
# Check if parsed has 'context' wrapper
|
||||
if isinstance(parsed, dict) and "context" in parsed:
|
||||
context_dict = parsed["context"]
|
||||
else:
|
||||
context_dict = parsed
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse JSON content for {tool_name}: {inner[:100]}"
|
||||
)
|
||||
context_dict = {"Query": "", "Expansion_issue": []}
|
||||
elif isinstance(inner, dict):
|
||||
context_dict = inner
|
||||
|
||||
# Check for 'context' wrapper
|
||||
elif "context" in content:
|
||||
context_dict = content["context"] if isinstance(content["context"], dict) else content
|
||||
|
||||
return {
|
||||
"context": context_dict,
|
||||
**base_args
|
||||
}
|
||||
|
||||
elif tool_name == "Input_Summary":
|
||||
# Input_Summary expects raw message string + search_switch
|
||||
# Content should be the raw message string
|
||||
if isinstance(content, dict):
|
||||
# Try to extract message from dict
|
||||
message_str = content.get("sentence", str(content))
|
||||
else:
|
||||
message_str = str(content)
|
||||
|
||||
return {
|
||||
"context": message_str,
|
||||
"search_switch": search_switch,
|
||||
**base_args
|
||||
}
|
||||
|
||||
else:
|
||||
# Default: pass content as context
|
||||
logger.warning(
|
||||
f"Unknown tool name '{tool_name}', using default argument structure"
|
||||
)
|
||||
return {
|
||||
"context": content,
|
||||
**base_args
|
||||
}
|
||||
193
api/app/core/memory/agent/mcp_server/services/search_service.py
Normal file
193
api/app/core/memory/agent/mcp_server/services/search_service.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
Search Service for executing hybrid search and processing results.
|
||||
|
||||
This service provides clean search result processing with content extraction
|
||||
and deduplication.
|
||||
"""
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.src.search import run_hybrid_search
|
||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class SearchService:
|
||||
"""Service for executing hybrid search and processing results."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the search service."""
|
||||
logger.info("SearchService initialized")
|
||||
|
||||
def extract_content_from_result(self, result: dict) -> str:
|
||||
"""
|
||||
Extract only meaningful content from search results, dropping all metadata.
|
||||
|
||||
Extraction rules by node type:
|
||||
- Statements: extract 'statement' field
|
||||
- Entities: extract 'name' and 'fact_summary' fields
|
||||
- Summaries: extract 'content' field
|
||||
- Chunks: extract 'content' field
|
||||
|
||||
Args:
|
||||
result: Search result dictionary
|
||||
|
||||
Returns:
|
||||
Clean content string without metadata
|
||||
"""
|
||||
if not isinstance(result, dict):
|
||||
return str(result)
|
||||
|
||||
content_parts = []
|
||||
|
||||
# Statements: extract statement field
|
||||
if 'statement' in result and result['statement']:
|
||||
content_parts.append(result['statement'])
|
||||
|
||||
# Summaries/Chunks: extract content field
|
||||
if 'content' in result and result['content']:
|
||||
content_parts.append(result['content'])
|
||||
|
||||
# Entities: extract name and fact_summary (commented out in original)
|
||||
# if 'name' in result and result['name']:
|
||||
# content_parts.append(result['name'])
|
||||
# if result.get('fact_summary'):
|
||||
# content_parts.append(result['fact_summary'])
|
||||
|
||||
# Return concatenated content or empty string
|
||||
return '\n'.join(content_parts) if content_parts else ""
|
||||
|
||||
def clean_query(self, query: str) -> str:
|
||||
"""
|
||||
Clean and escape query text for Lucene.
|
||||
|
||||
- Removes wrapping quotes
|
||||
- Removes newlines and carriage returns
|
||||
- Applies Lucene escaping
|
||||
|
||||
Args:
|
||||
query: Raw query string
|
||||
|
||||
Returns:
|
||||
Cleaned and escaped query string
|
||||
"""
|
||||
q = str(query).strip()
|
||||
|
||||
# Remove wrapping quotes
|
||||
if (q.startswith("'") and q.endswith("'")) or (
|
||||
q.startswith('"') and q.endswith('"')
|
||||
):
|
||||
q = q[1:-1]
|
||||
|
||||
# Remove newlines and carriage returns
|
||||
q = q.replace('\r', ' ').replace('\n', ' ').strip()
|
||||
|
||||
# Apply Lucene escaping
|
||||
q = escape_lucene_query(q)
|
||||
|
||||
return q
|
||||
|
||||
async def execute_hybrid_search(
|
||||
self,
|
||||
group_id: str,
|
||||
question: str,
|
||||
limit: int = 5,
|
||||
search_type: str = "hybrid",
|
||||
include: Optional[List[str]] = None,
|
||||
rerank_alpha: float = 0.4,
|
||||
output_path: str = "search_results.json",
|
||||
return_raw_results: bool = False
|
||||
) -> Tuple[str, str, Optional[dict]]:
|
||||
"""
|
||||
Execute hybrid search and return clean content.
|
||||
|
||||
Args:
|
||||
group_id: Group identifier for filtering results
|
||||
question: Search query text
|
||||
limit: Maximum number of results to return (default: 5)
|
||||
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
|
||||
include: List of result types to include (default: ["statements", "chunks", "entities", "summaries"])
|
||||
rerank_alpha: Weight for BM25 scores in reranking (default: 0.4)
|
||||
output_path: Path to save search results (default: "search_results.json")
|
||||
return_raw_results: If True, also return the raw search results as third element (default: False)
|
||||
|
||||
Returns:
|
||||
Tuple of (clean_content, cleaned_query, raw_results)
|
||||
raw_results is None if return_raw_results=False
|
||||
"""
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries"]
|
||||
|
||||
# Clean query
|
||||
cleaned_query = self.clean_query(question)
|
||||
|
||||
try:
|
||||
# Execute search
|
||||
answer = await run_hybrid_search(
|
||||
query_text=cleaned_query,
|
||||
search_type=search_type,
|
||||
group_id=group_id,
|
||||
limit=limit,
|
||||
include=include,
|
||||
output_path=output_path,
|
||||
rerank_alpha=rerank_alpha
|
||||
)
|
||||
|
||||
# Extract results based on search type and include parameter
|
||||
# Prioritize summaries as they contain synthesized contextual information
|
||||
answer_list = []
|
||||
|
||||
# For hybrid search, use reranked_results
|
||||
if search_type == "hybrid":
|
||||
reranked_results = answer.get('reranked_results', {})
|
||||
|
||||
# Priority order: summaries first (most contextual), then statements, chunks, entities
|
||||
priority_order = ['summaries', 'statements', 'chunks', 'entities']
|
||||
|
||||
for category in priority_order:
|
||||
if category in include and category in reranked_results:
|
||||
category_results = reranked_results[category]
|
||||
if isinstance(category_results, list):
|
||||
answer_list.extend(category_results)
|
||||
else:
|
||||
# For keyword or embedding search, results are directly in answer dict
|
||||
# Apply same priority order
|
||||
priority_order = ['summaries', 'statements', 'chunks', 'entities']
|
||||
|
||||
for category in priority_order:
|
||||
if category in include and category in answer:
|
||||
category_results = answer[category]
|
||||
if isinstance(category_results, list):
|
||||
answer_list.extend(category_results)
|
||||
|
||||
# Extract clean content from all results
|
||||
content_list = [
|
||||
self.extract_content_from_result(ans)
|
||||
for ans in answer_list
|
||||
]
|
||||
|
||||
|
||||
# Filter out empty strings and join with newlines
|
||||
clean_content = '\n'.join([c for c in content_list if c])
|
||||
|
||||
# Log first 200 chars
|
||||
logger.info(f"检索接口搜索结果==>>:{clean_content[:200]}...")
|
||||
|
||||
# Return raw results if requested
|
||||
if return_raw_results:
|
||||
return clean_content, cleaned_query, answer
|
||||
else:
|
||||
return clean_content, cleaned_query, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Search failed for query '{question}' in group '{group_id}': {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty results on failure
|
||||
if return_raw_results:
|
||||
return "", cleaned_query, {}
|
||||
else:
|
||||
return "", cleaned_query, None
|
||||
169
api/app/core/memory/agent/mcp_server/services/session_service.py
Normal file
169
api/app/core/memory/agent/mcp_server/services/session_service.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
Session Service for managing user sessions and conversation history.
|
||||
|
||||
This service provides clean Redis interactions with error handling and
|
||||
session management utilities.
|
||||
"""
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.redis_tool import RedisSessionStore
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class SessionService:
|
||||
"""Service for managing user sessions and conversation history."""
|
||||
|
||||
def __init__(self, store: RedisSessionStore):
|
||||
"""
|
||||
Initialize the session service.
|
||||
|
||||
Args:
|
||||
store: Redis session store instance
|
||||
"""
|
||||
self.store = store
|
||||
logger.info("SessionService initialized")
|
||||
|
||||
def resolve_user_id(self, session_string: str) -> str:
|
||||
"""
|
||||
Extract user ID from session string.
|
||||
|
||||
Handles formats like:
|
||||
- 'call_id_user123' -> 'user123'
|
||||
- 'prefix_id_user456_suffix' -> 'user456_suffix'
|
||||
|
||||
Args:
|
||||
session_string: Session identifier string
|
||||
|
||||
Returns:
|
||||
Extracted user ID
|
||||
"""
|
||||
try:
|
||||
# Split by '_id_' and take everything after it
|
||||
parts = session_string.split('_id_')
|
||||
if len(parts) > 1:
|
||||
return parts[1]
|
||||
|
||||
# Fallback: return original string
|
||||
return session_string
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to parse user ID from session string '{session_string}': {e}"
|
||||
)
|
||||
return session_string
|
||||
|
||||
async def get_history(
|
||||
self,
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Retrieve conversation history from Redis.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
|
||||
Returns:
|
||||
List of conversation history items with Query and Answer keys
|
||||
Returns empty list if no history found or on error
|
||||
"""
|
||||
try:
|
||||
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
|
||||
|
||||
# Validate history structure
|
||||
if not isinstance(history, list):
|
||||
logger.warning(
|
||||
f"Invalid history format for user {user_id}, "
|
||||
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
|
||||
)
|
||||
return []
|
||||
|
||||
return history
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to retrieve history for user {user_id}, "
|
||||
f"apply {apply_id}, group {group_id}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty list on error to allow execution to continue
|
||||
return []
|
||||
|
||||
async def save_session(
|
||||
self,
|
||||
user_id: str,
|
||||
query: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
ai_response: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Save conversation turn to Redis.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
query: User query/message
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
ai_response: AI response/answer
|
||||
|
||||
Returns:
|
||||
Session ID if successful, None on error
|
||||
"""
|
||||
try:
|
||||
# Validate required fields
|
||||
if not user_id:
|
||||
logger.warning("Cannot save session: user_id is empty")
|
||||
return None
|
||||
|
||||
if not query:
|
||||
logger.warning("Cannot save session: query is empty")
|
||||
return None
|
||||
|
||||
# Save session
|
||||
session_id = self.store.save_session(
|
||||
userid=user_id,
|
||||
messages=query,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
aimessages=ai_response
|
||||
)
|
||||
|
||||
logger.info(f"Session saved successfully: {session_id}")
|
||||
return session_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save session for user {user_id}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
async def cleanup_duplicates(self) -> int:
|
||||
"""
|
||||
Remove duplicate session entries.
|
||||
|
||||
Duplicates are identified by matching:
|
||||
- sessionid
|
||||
- user_id (id field)
|
||||
- group_id
|
||||
- messages
|
||||
- aimessages
|
||||
|
||||
Returns:
|
||||
Number of duplicate sessions deleted
|
||||
"""
|
||||
try:
|
||||
deleted_count = self.store.delete_duplicate_sessions()
|
||||
logger.info(f"Cleaned up {deleted_count} duplicate sessions")
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup duplicate sessions: {e}", exc_info=True)
|
||||
return 0
|
||||
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
Template Service for loading and rendering Jinja2 templates.
|
||||
|
||||
This service provides centralized template management with caching and error handling.
|
||||
"""
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
from jinja2 import Environment, FileSystemLoader, Template, TemplateNotFound
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_prompt_rendering
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class TemplateRenderError(Exception):
|
||||
"""Exception raised when template rendering fails."""
|
||||
|
||||
def __init__(self, template_name: str, error: Exception, variables: dict):
|
||||
self.template_name = template_name
|
||||
self.error = error
|
||||
self.variables = variables
|
||||
super().__init__(
|
||||
f"Failed to render template '{template_name}': {str(error)}"
|
||||
)
|
||||
|
||||
|
||||
class TemplateService:
|
||||
"""Service for loading and rendering Jinja2 templates with caching."""
|
||||
|
||||
def __init__(self, template_root: str):
|
||||
"""
|
||||
Initialize the template service.
|
||||
|
||||
Args:
|
||||
template_root: Root directory containing template files
|
||||
"""
|
||||
self.template_root = template_root
|
||||
self.env = Environment(
|
||||
loader=FileSystemLoader(template_root),
|
||||
autoescape=False # Disable autoescape for prompt templates
|
||||
)
|
||||
logger.info(f"TemplateService initialized with root: {template_root}")
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def _load_template(self, template_name: str) -> Template:
|
||||
"""
|
||||
Load a template from disk with caching.
|
||||
|
||||
Args:
|
||||
template_name: Relative path to template file
|
||||
|
||||
Returns:
|
||||
Loaded Jinja2 Template object
|
||||
|
||||
Raises:
|
||||
TemplateNotFound: If template file doesn't exist
|
||||
"""
|
||||
try:
|
||||
return self.env.get_template(template_name)
|
||||
except TemplateNotFound as e:
|
||||
expected_path = os.path.join(self.template_root, template_name)
|
||||
logger.error(
|
||||
f"Template not found: {template_name}. "
|
||||
f"Expected path: {expected_path}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def render_template(
|
||||
self,
|
||||
template_name: str,
|
||||
operation_name: str,
|
||||
**variables
|
||||
) -> str:
|
||||
"""
|
||||
Load and render a Jinja2 template.
|
||||
|
||||
Args:
|
||||
template_name: Relative path to template file
|
||||
operation_name: Name for logging (e.g., "split_the_problem")
|
||||
**variables: Template variables to render
|
||||
|
||||
Returns:
|
||||
Rendered template string
|
||||
|
||||
Raises:
|
||||
TemplateRenderError: If template loading or rendering fails
|
||||
"""
|
||||
try:
|
||||
# Load template (cached)
|
||||
template = self._load_template(template_name)
|
||||
|
||||
# Render template
|
||||
rendered = template.render(**variables)
|
||||
|
||||
# Log rendered prompt
|
||||
log_prompt_rendering(operation_name, rendered)
|
||||
|
||||
return rendered
|
||||
|
||||
except TemplateNotFound as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for {operation_name} "
|
||||
f"({template_name}): Template not found",
|
||||
exc_info=True
|
||||
)
|
||||
raise TemplateRenderError(template_name, e, variables)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for {operation_name} "
|
||||
f"({template_name}): {e}",
|
||||
exc_info=True
|
||||
)
|
||||
raise TemplateRenderError(template_name, e, variables)
|
||||
27
api/app/core/memory/agent/mcp_server/tools/__init__.py
Normal file
27
api/app/core/memory/agent/mcp_server/tools/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""
|
||||
MCP Tools module.
|
||||
|
||||
This module contains all MCP tool implementations organized by functionality.
|
||||
|
||||
Tools are organized into the following modules:
|
||||
- problem_tools: Question segmentation and extension
|
||||
- retrieval_tools: Database and context retrieval
|
||||
- verification_tools: Data verification
|
||||
- summary_tools: Summarization and summary retrieval
|
||||
- data_tools: Data type differentiation and writing
|
||||
"""
|
||||
|
||||
# Import all tool modules to register them with the MCP server
|
||||
from . import problem_tools
|
||||
from . import retrieval_tools
|
||||
from . import verification_tools
|
||||
from . import summary_tools
|
||||
from . import data_tools
|
||||
|
||||
__all__ = [
|
||||
'problem_tools',
|
||||
'retrieval_tools',
|
||||
'verification_tools',
|
||||
'summary_tools',
|
||||
'data_tools',
|
||||
]
|
||||
149
api/app/core/memory/agent/mcp_server/tools/data_tools.py
Normal file
149
api/app/core/memory/agent/mcp_server/tools/data_tools.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
Data Tools for data type differentiation and writing.
|
||||
|
||||
This module contains MCP tools for distinguishing data types and writing data.
|
||||
"""
|
||||
import os
|
||||
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.mcp_server.models.retrieval_models import DistinguishTypeResponse
|
||||
from app.core.memory.agent.utils.write_tools import write
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Data_type_differentiation(
|
||||
ctx: Context,
|
||||
context: str
|
||||
) -> dict:
|
||||
"""
|
||||
Distinguish the type of data (read or write).
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Text to analyze for type differentiation
|
||||
|
||||
Returns:
|
||||
dict: Contains 'context' with the original text and 'type' field
|
||||
"""
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
|
||||
# Render template
|
||||
try:
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='distinguish_types_prompt.jinja2',
|
||||
operation_name='status_typle',
|
||||
user_query=context
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for Data_type_differentiation: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"type": "error",
|
||||
"message": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
# Call LLM with structured response
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=DistinguishTypeResponse
|
||||
)
|
||||
|
||||
result = structured.model_dump()
|
||||
|
||||
# Add context to result
|
||||
result["context"] = context
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM call failed for Data_type_differentiation: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": context,
|
||||
"type": "error",
|
||||
"message": f"LLM call failed: {str(e)}"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Data_type_differentiation failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": context,
|
||||
"type": "error",
|
||||
"message": str(e)
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Data_write(
|
||||
ctx: Context,
|
||||
content: str,
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
config_id: str
|
||||
) -> dict:
|
||||
"""
|
||||
Write data to the database/file system.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
content: Data content to write
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
config_id: Configuration ID for processing (optional, integer)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'status', 'saved_to', and 'data' fields
|
||||
"""
|
||||
try:
|
||||
# Ensure output directory exists
|
||||
os.makedirs("data_output", exist_ok=True)
|
||||
file_path = os.path.join("data_output", "user_data.csv")
|
||||
|
||||
# Write data using utility function
|
||||
try:
|
||||
await write(content, user_id, apply_id, group_id, config_id=config_id)
|
||||
logger.info(f"写入成功!Config ID: {config_id if config_id else 'None'}")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"saved_to": file_path,
|
||||
"data": content,
|
||||
"config_id": config_id
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"写入失败: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Data_write failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e)
|
||||
}
|
||||
293
api/app/core/memory/agent/mcp_server/tools/problem_tools.py
Normal file
293
api/app/core/memory/agent/mcp_server/tools/problem_tools.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""
|
||||
Problem Tools for question segmentation and extension.
|
||||
|
||||
This module contains MCP tools for breaking down and extending user questions.
|
||||
"""
|
||||
import json
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field, RootModel
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.mcp_server.models.problem_models import (
|
||||
ProblemBreakdownItem,
|
||||
ProblemBreakdownResponse,
|
||||
ExtendedQuestionItem,
|
||||
ProblemExtensionResponse
|
||||
)
|
||||
from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Split_The_Problem(
|
||||
ctx: Context,
|
||||
sentence: str,
|
||||
sessionid: str,
|
||||
messages_id: str,
|
||||
apply_id: str,
|
||||
group_id: str
|
||||
) -> dict:
|
||||
"""
|
||||
Segment the dialogue or sentence into sub-problems.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
sentence: Original sentence to split
|
||||
sessionid: Session identifier
|
||||
messages_id: Message identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
|
||||
Returns:
|
||||
dict: Contains 'context' (JSON string of split results) and 'original' sentence
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
|
||||
# Extract user ID from session
|
||||
user_id = session_service.resolve_user_id(sessionid)
|
||||
|
||||
# Get conversation history
|
||||
history = await session_service.get_history(user_id, apply_id, group_id)
|
||||
# Override with empty list for now (as in original)
|
||||
history = []
|
||||
|
||||
# Render template
|
||||
try:
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='problem_breakdown_prompt.jinja2',
|
||||
operation_name='split_the_problem',
|
||||
history=history,
|
||||
sentence=sentence
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for Split_The_Problem: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": json.dumps([], ensure_ascii=False),
|
||||
"original": sentence,
|
||||
"error": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
# Call LLM with structured response
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=ProblemBreakdownResponse
|
||||
)
|
||||
|
||||
# Handle RootModel response with .root attribute access
|
||||
if structured is None:
|
||||
# LLM returned None, use empty list as fallback
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
elif hasattr(structured, 'root') and structured.root is not None:
|
||||
split_result = json.dumps(
|
||||
[item.model_dump() for item in structured.root],
|
||||
ensure_ascii=False
|
||||
)
|
||||
elif isinstance(structured, list):
|
||||
# Fallback: treat structured itself as the list
|
||||
split_result = json.dumps(
|
||||
[item.model_dump() for item in structured],
|
||||
ensure_ascii=False
|
||||
)
|
||||
else:
|
||||
# Last resort: use empty list
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM call failed for Split_The_Problem: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
|
||||
logger.info(f"问题拆分")
|
||||
logger.info(f"问题拆分结果==>>:{split_result}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
result = {
|
||||
"context": split_result,
|
||||
"original": sentence,
|
||||
"_intermediate": {
|
||||
"type": "problem_split",
|
||||
"data": json.loads(split_result) if split_result else [],
|
||||
"original_query": sentence
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Split_The_Problem failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": json.dumps([], ensure_ascii=False),
|
||||
"original": sentence,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
finally:
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('问题拆分', duration)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Problem_Extension(
|
||||
ctx: Context,
|
||||
context: dict,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
) -> dict:
|
||||
"""
|
||||
Extend the problem with additional sub-questions.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Dictionary containing split problem results
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'context' (aggregated questions) and 'original' question
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
|
||||
# Resolve session ID from usermessages
|
||||
from app.core.memory.agent.utils.messages_tool import Resolve_username
|
||||
sessionid = Resolve_username(usermessages)
|
||||
|
||||
# Get conversation history
|
||||
history = await session_service.get_history(sessionid, apply_id, group_id)
|
||||
# Override with empty list for now (as in original)
|
||||
history = []
|
||||
|
||||
# Process context to extract questions
|
||||
extent_quest, original = await Problem_Extension_messages_deal(context)
|
||||
|
||||
# Format questions for template rendering
|
||||
questions_formatted = []
|
||||
for msg in extent_quest:
|
||||
if msg.get("role") == "user":
|
||||
questions_formatted.append(msg.get("content", ""))
|
||||
|
||||
# Render template
|
||||
try:
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='Problem_Extension_prompt.jinja2',
|
||||
operation_name='problem_extension',
|
||||
history=history,
|
||||
questions=questions_formatted
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for Problem_Extension: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": {},
|
||||
"original": original,
|
||||
"error": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
# Call LLM with structured response
|
||||
try:
|
||||
response_content = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=ProblemExtensionResponse
|
||||
)
|
||||
|
||||
# Aggregate results by original question
|
||||
aggregated_dict = {}
|
||||
for item in response_content.root:
|
||||
key = getattr(item, "original_question", None) or (
|
||||
item.get("original_question") if isinstance(item, dict) else None
|
||||
)
|
||||
value = getattr(item, "extended_question", None) or (
|
||||
item.get("extended_question") if isinstance(item, dict) else None
|
||||
)
|
||||
if not key or not value:
|
||||
continue
|
||||
aggregated_dict.setdefault(key, []).append(value)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM call failed for Problem_Extension: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
aggregated_dict = {}
|
||||
|
||||
logger.info(f"问题扩展")
|
||||
logger.info(f"问题扩展==>>:{aggregated_dict}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
result = {
|
||||
"context": aggregated_dict,
|
||||
"original": original,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "problem_extension",
|
||||
"data": aggregated_dict,
|
||||
"original_query": original,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Problem_Extension failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": {},
|
||||
"original": context.get("original", ""),
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
finally:
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('问题扩展', duration)
|
||||
282
api/app/core/memory/agent/mcp_server/tools/retrieval_tools.py
Normal file
282
api/app/core/memory/agent/mcp_server/tools/retrieval_tools.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""
|
||||
Retrieval Tools for database and context retrieval.
|
||||
|
||||
This module contains MCP tools for retrieving data using hybrid search.
|
||||
"""
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
|
||||
# 加载.env文件
|
||||
load_dotenv()
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.llm_tools import deduplicate_entries, merge_to_key_value_pairs
|
||||
from app.core.memory.agent.utils.messages_tool import Retriev_messages_deal
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Retrieve(
|
||||
ctx: Context,
|
||||
context,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
) -> dict:
|
||||
"""
|
||||
Retrieve data from the database using hybrid search.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Dictionary or string containing query information
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
|
||||
user_rag_memory_id: User RAG memory identifier
|
||||
|
||||
Returns:
|
||||
dict: Contains 'context' with Query and Expansion_issue results
|
||||
"""
|
||||
kb_config = {
|
||||
"knowledge_bases": [
|
||||
{
|
||||
"kb_id": user_rag_memory_id,
|
||||
"similarity_threshold": 0.7,
|
||||
"vector_similarity_weight": 0.5,
|
||||
"top_k": 10,
|
||||
"retrieve_type": "participle"
|
||||
}
|
||||
],
|
||||
"merge_strategy": "weight",
|
||||
"reranker_id": os.getenv('reranker_id'),
|
||||
"reranker_top_k": 10
|
||||
}
|
||||
start = time.time()
|
||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
search_service = get_context_resource(ctx, 'search_service')
|
||||
|
||||
databases_anser = []
|
||||
|
||||
# Handle both dict and string context
|
||||
if isinstance(context, dict):
|
||||
# Process dict context with extended questions
|
||||
all_items = []
|
||||
content, original = await Retriev_messages_deal(context)
|
||||
|
||||
# Extract all query items from content
|
||||
# content is like {original_question: [extended_questions...], ...}
|
||||
for key, values in content.items():
|
||||
if isinstance(values, list):
|
||||
all_items.extend(values)
|
||||
elif isinstance(values, str):
|
||||
all_items.append(values)
|
||||
elif values is not None:
|
||||
# Fallback: convert non-empty non-list values to string
|
||||
all_items.append(str(values))
|
||||
|
||||
# Execute search for each question
|
||||
for idx, question in enumerate(all_items):
|
||||
try:
|
||||
# Prepare search parameters based on storage type
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"question": question,
|
||||
"return_raw_results": True
|
||||
}
|
||||
|
||||
# Add storage-specific parameters
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config,[str(group_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
clean_content = '\n\n'.join(retrieval_knowledge)
|
||||
cleaned_query=question
|
||||
raw_results=clean_content
|
||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||
except:
|
||||
clean_content = ''
|
||||
raw_results=''
|
||||
cleaned_query = question
|
||||
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
|
||||
else:
|
||||
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(**search_params)
|
||||
|
||||
databases_anser.append({
|
||||
"Query_small": cleaned_query,
|
||||
"Result_small": clean_content,
|
||||
"_intermediate": {
|
||||
"type": "search_result",
|
||||
"query": cleaned_query,
|
||||
"raw_results": raw_results,
|
||||
"index": idx + 1,
|
||||
"total": len(all_items)
|
||||
}
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve: hybrid_search failed for question '{question}': {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Continue with empty result for this question
|
||||
databases_anser.append({
|
||||
"Query_small": question,
|
||||
"Result_small": ""
|
||||
})
|
||||
|
||||
# Build initial database data structure
|
||||
databases_data = {
|
||||
"Query": original,
|
||||
"Expansion_issue": databases_anser
|
||||
}
|
||||
|
||||
# Collect intermediate outputs before deduplication
|
||||
intermediate_outputs = []
|
||||
for item in databases_anser:
|
||||
if '_intermediate' in item:
|
||||
intermediate_outputs.append(item['_intermediate'])
|
||||
|
||||
# Deduplicate and merge results
|
||||
deduplicated_data = deduplicate_entries(databases_data['Expansion_issue'])
|
||||
deduplicated_data_merged = merge_to_key_value_pairs(
|
||||
deduplicated_data,
|
||||
'Query_small',
|
||||
'Result_small'
|
||||
)
|
||||
|
||||
# Restructure for Verify/Retrieve_Summary compatibility
|
||||
keys, val = [], []
|
||||
for item in deduplicated_data_merged:
|
||||
for items_key, items_value in item.items():
|
||||
keys.append(items_key)
|
||||
val.append(items_value)
|
||||
|
||||
send_verify = []
|
||||
for i, j in zip(keys, val):
|
||||
send_verify.append({
|
||||
"Query_small": i,
|
||||
"Answer_Small": j
|
||||
})
|
||||
|
||||
dup_databases = {
|
||||
"Query": original,
|
||||
"Expansion_issue": send_verify,
|
||||
"_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs
|
||||
}
|
||||
|
||||
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
||||
|
||||
else:
|
||||
# Handle string context (simple query)
|
||||
query = str(context).strip()
|
||||
|
||||
try:
|
||||
# Prepare search parameters based on storage type
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"question": query,
|
||||
"return_raw_results": True
|
||||
}
|
||||
|
||||
# Add storage-specific parameters
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
clean_content = '\n\n'.join(retrieval_knowledge)
|
||||
cleaned_query = query
|
||||
raw_results = clean_content
|
||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||
except:
|
||||
clean_content = ''
|
||||
raw_results = ''
|
||||
cleaned_query = query
|
||||
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
|
||||
else:
|
||||
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(**search_params)
|
||||
# Keep structure for Verify/Retrieve_Summary compatibility
|
||||
dup_databases = {
|
||||
"Query": cleaned_query,
|
||||
"Expansion_issue": [{
|
||||
"Query_small": cleaned_query,
|
||||
"Answer_Small": clean_content,
|
||||
"_intermediate": {
|
||||
"type": "search_result",
|
||||
"query": cleaned_query,
|
||||
"raw_results": raw_results,
|
||||
"index": 1,
|
||||
"total": 1
|
||||
}
|
||||
}]
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve: hybrid_search failed for query '{query}': {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty results on failure
|
||||
dup_databases = {
|
||||
"Query": query,
|
||||
"Expansion_issue": []
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"检索==>>:{storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, "
|
||||
f"Expansion_issue count={len(dup_databases.get('Expansion_issue', []))}"
|
||||
)
|
||||
|
||||
# Build result with intermediate outputs
|
||||
result = {
|
||||
"context": dup_databases,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
|
||||
# Add intermediate outputs list if they exist
|
||||
intermediate_outputs = dup_databases.get('_intermediate_outputs', [])
|
||||
if intermediate_outputs:
|
||||
result['_intermediates'] = intermediate_outputs
|
||||
logger.info(f"Adding {len(intermediate_outputs)} intermediate outputs to result")
|
||||
else:
|
||||
logger.warning("No intermediate outputs found in dup_databases")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": {
|
||||
"Query": "",
|
||||
"Expansion_issue": []
|
||||
},
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
finally:
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('检索', duration)
|
||||
647
api/app/core/memory/agent/mcp_server/tools/summary_tools.py
Normal file
647
api/app/core/memory/agent/mcp_server/tools/summary_tools.py
Normal file
@@ -0,0 +1,647 @@
|
||||
"""
|
||||
Summary Tools for data summarization.
|
||||
|
||||
This module contains MCP tools for summarizing retrieved data and generating responses.
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.mcp_server.models.summary_models import (
|
||||
SummaryData,
|
||||
SummaryResponse,
|
||||
RetrieveSummaryData,
|
||||
RetrieveSummaryResponse
|
||||
)
|
||||
from app.core.memory.agent.utils.messages_tool import (
|
||||
Summary_messages_deal,
|
||||
Resolve_username
|
||||
)
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
# 加载.env文件
|
||||
load_dotenv()
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Summary(
|
||||
ctx: Context,
|
||||
context: str,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
) -> dict:
|
||||
"""
|
||||
Summarize the verified data.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: JSON string containing verified data
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'status' and 'summary_result'
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages)
|
||||
|
||||
# Process context to extract answer and query
|
||||
answer_small, query = await Summary_messages_deal(context)
|
||||
|
||||
|
||||
# Get conversation history
|
||||
history = await session_service.get_history(sessionid, apply_id, group_id)
|
||||
# Override with empty list for now (as in original)
|
||||
# Prepare data for template
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
"retrieve_info": answer_small
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Summary: initialization failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"summary_result": "信息不足,无法回答"
|
||||
}
|
||||
|
||||
try:
|
||||
# Render template
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='summary_prompt.jinja2',
|
||||
operation_name='summary',
|
||||
data=data,
|
||||
query=query
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for Summary: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
try:
|
||||
# Call LLM with structured response
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=SummaryResponse
|
||||
)
|
||||
|
||||
aimessages = structured.query_answer or ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM call failed for Summary: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
aimessages = ""
|
||||
|
||||
try:
|
||||
# Save session
|
||||
if aimessages != "":
|
||||
await session_service.save_session(
|
||||
user_id=sessionid,
|
||||
query=query,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
ai_response=aimessages
|
||||
)
|
||||
logger.info(f"sessionid: {aimessages} 写入成功")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"sessionid: {sessionid} 写入失败,错误信息:{str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e)
|
||||
}
|
||||
|
||||
# Cleanup duplicate sessions
|
||||
await session_service.cleanup_duplicates()
|
||||
|
||||
# Use fallback if empty
|
||||
if aimessages == '':
|
||||
aimessages = '信息不足,无法回答'
|
||||
|
||||
logger.info(f"验证之后的总结==>>:{aimessages}")
|
||||
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('总结', duration)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Retrieve_Summary(
|
||||
ctx: Context,
|
||||
context: dict,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
) -> dict:
|
||||
"""
|
||||
Summarize data directly from retrieval results.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Dictionary containing Query and Expansion_issue from Retrieve
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'status' and 'summary_result'
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages)
|
||||
|
||||
|
||||
|
||||
# Handle both 'content' and 'context' keys (LangGraph uses 'content')
|
||||
if isinstance(context, dict):
|
||||
if "content" in context:
|
||||
inner = context["content"]
|
||||
# If it's a JSON string, parse it
|
||||
if isinstance(inner, str):
|
||||
try:
|
||||
parsed = json.loads(inner)
|
||||
logger.info(f"Retrieve_Summary: successfully parsed JSON")
|
||||
except json.JSONDecodeError:
|
||||
# Try unescaping first
|
||||
try:
|
||||
unescaped = inner.encode('utf-8').decode('unicode_escape')
|
||||
parsed = json.loads(unescaped)
|
||||
logger.info(f"Retrieve_Summary: parsed after unescaping")
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||
logger.error(
|
||||
f"Retrieve_Summary: parsing failed even after unescape: {e}"
|
||||
)
|
||||
context_dict = {"Query": "", "Expansion_issue": []}
|
||||
parsed = None
|
||||
|
||||
if parsed:
|
||||
# Check if parsed has 'context' wrapper
|
||||
if isinstance(parsed, dict) and "context" in parsed:
|
||||
context_dict = parsed["context"]
|
||||
else:
|
||||
context_dict = parsed
|
||||
elif isinstance(inner, dict):
|
||||
context_dict = inner
|
||||
else:
|
||||
context_dict = {"Query": "", "Expansion_issue": []}
|
||||
elif "context" in context:
|
||||
context_dict = context["context"] if isinstance(context["context"], dict) else context
|
||||
else:
|
||||
context_dict = context
|
||||
else:
|
||||
context_dict = {"Query": "", "Expansion_issue": []}
|
||||
|
||||
query = context_dict.get("Query", "")
|
||||
expansion_issue = context_dict.get("Expansion_issue", [])
|
||||
|
||||
# Extract retrieve_info from expansion_issue
|
||||
retrieve_info = []
|
||||
for item in expansion_issue:
|
||||
# Check for both Answer_Small and Answer_Samll (typo) for backward compatibility
|
||||
answer = None
|
||||
if isinstance(item, dict):
|
||||
if "Answer_Small" in item:
|
||||
answer = item["Answer_Small"]
|
||||
elif "Answer_Samll" in item:
|
||||
answer = item["Answer_Samll"]
|
||||
|
||||
if answer is not None:
|
||||
# Handle both string and list formats
|
||||
if isinstance(answer, list):
|
||||
# Join list of characters/strings into a single string
|
||||
retrieve_info.append(''.join(str(x) for x in answer))
|
||||
elif isinstance(answer, str):
|
||||
retrieve_info.append(answer)
|
||||
else:
|
||||
retrieve_info.append(str(answer))
|
||||
|
||||
# Join all retrieve_info into a single string
|
||||
retrieve_info_str = '\n\n'.join(retrieve_info) if retrieve_info else ""
|
||||
|
||||
# Get conversation history
|
||||
history = await session_service.get_history(sessionid, apply_id, group_id)
|
||||
# Override with empty list for now (as in original)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve_Summary: initialization failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"summary_result": "信息不足,无法回答"
|
||||
}
|
||||
|
||||
try:
|
||||
# Render template
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='Retrieve_Summary_prompt.jinja2',
|
||||
operation_name='retrieve_summary',
|
||||
query=query,
|
||||
history=history,
|
||||
retrieve_info=retrieve_info_str
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for Retrieve_Summary: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
try:
|
||||
# Call LLM with structured response
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=RetrieveSummaryResponse
|
||||
)
|
||||
|
||||
# Handle case where structured response might be None or incomplete
|
||||
if structured and hasattr(structured, 'data') and structured.data:
|
||||
aimessages = structured.data.query_answer or ""
|
||||
else:
|
||||
logger.warning("Structured response is None or incomplete, using default message")
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
|
||||
# Check for insufficient information response
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages)!="":
|
||||
# Save session
|
||||
await session_service.save_session(
|
||||
user_id=sessionid,
|
||||
query=query,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
ai_response=aimessages
|
||||
)
|
||||
logger.info(f"sessionid: {aimessages} 写入成功")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve_Summary: LLM call failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
aimessages = ""
|
||||
# Cleanup duplicate sessions
|
||||
await session_service.cleanup_duplicates()
|
||||
|
||||
# Use fallback if empty
|
||||
if aimessages == '':
|
||||
aimessages = '信息不足,无法回答'
|
||||
|
||||
logger.info(f"检索之后的总结==>>:{aimessages}")
|
||||
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('检索总结', duration)
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "retrieval_summary",
|
||||
"summary": aimessages,
|
||||
"query": query,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Input_Summary(
|
||||
ctx: Context,
|
||||
context: str,
|
||||
usermessages: str,
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
) -> dict:
|
||||
"""
|
||||
Generate a quick summary for direct input without verification.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: String containing the input sentence
|
||||
usermessages: User messages identifier
|
||||
search_switch: Search switch value for routing ('2' for summaries only)
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
|
||||
user_rag_memory_id: User RAG memory identifier
|
||||
|
||||
Returns:
|
||||
dict: Contains 'query_answer' with the summary result
|
||||
"""
|
||||
start = time.time()
|
||||
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
|
||||
# Initialize variables to avoid UnboundLocalError
|
||||
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
search_service = get_context_resource(ctx, 'search_service')
|
||||
|
||||
# Check if llm_client is None
|
||||
if llm_client is None:
|
||||
error_msg = "LLM client is not available. Please check server configuration and SELECTED_LLM_ID environment variable."
|
||||
logger.error(error_msg)
|
||||
return error_msg
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages) or ""
|
||||
sessionid = sessionid.replace('call_id_', '')
|
||||
|
||||
# Get conversation history
|
||||
history = await session_service.get_history(
|
||||
str(sessionid),
|
||||
str(apply_id),
|
||||
str(group_id)
|
||||
)
|
||||
# Override with empty list for now (as in original)
|
||||
|
||||
# Log the raw context for debugging
|
||||
logger.info(f"Input_Summary: Received context type={type(context)}, value={context[:200] if isinstance(context, str) else context}")
|
||||
|
||||
# Extract sentence from context
|
||||
# Context can be a string or might contain the sentence in various formats
|
||||
try:
|
||||
# Try to parse as JSON first
|
||||
if isinstance(context, str) and (context.startswith('{') or context.startswith('[')):
|
||||
try:
|
||||
import json
|
||||
context_dict = json.loads(context)
|
||||
if isinstance(context_dict, dict):
|
||||
query = context_dict.get('sentence', context_dict.get('content', context))
|
||||
else:
|
||||
query = context
|
||||
except json.JSONDecodeError:
|
||||
# Not valid JSON, try regex
|
||||
match = re.search(r"'sentence':\s*['\"]?(.*?)['\"]?\s*,", context)
|
||||
query = match.group(1) if match else context
|
||||
else:
|
||||
query = context
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract query from context: {e}")
|
||||
query = context
|
||||
|
||||
# Clean query
|
||||
query = str(query).strip().strip("\"'")
|
||||
|
||||
logger.debug(f"Input_Summary: Extracted query='{query}' from context type={type(context)}")
|
||||
|
||||
# Execute search based on search_switch and storage_type
|
||||
try:
|
||||
logger.info(f"search_switch: {search_switch}, storage_type: {storage_type}")
|
||||
|
||||
# Prepare search parameters based on storage type
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"question": query,
|
||||
"return_raw_results": True
|
||||
}
|
||||
|
||||
# Add storage-specific parameters
|
||||
|
||||
'''检索'''
|
||||
if search_switch == '2':
|
||||
search_params["include"] = ["summaries"]
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
raw_results = []
|
||||
retrieve_info = ""
|
||||
kb_config={
|
||||
"knowledge_bases": [
|
||||
{
|
||||
"kb_id": user_rag_memory_id,
|
||||
"similarity_threshold": 0.7,
|
||||
"vector_similarity_weight": 0.5,
|
||||
"top_k": 10,
|
||||
"retrieve_type": "participle"
|
||||
}
|
||||
],
|
||||
"merge_strategy": "weight",
|
||||
"reranker_id":os.getenv('reranker_id'),
|
||||
"reranker_top_k": 10
|
||||
}
|
||||
|
||||
retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
retrieve_info = '\n\n'.join(retrieval_knowledge)
|
||||
raw_results=[retrieve_info]
|
||||
logger.info(f"Input_Summary: Using RAG storage with memory_id={user_rag_memory_id}")
|
||||
except:
|
||||
retrieve_info=''
|
||||
raw_results=['']
|
||||
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
|
||||
else:
|
||||
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(**search_params)
|
||||
logger.info(f"Input_Summary: 使用 summary 进行检索")
|
||||
else:
|
||||
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(**search_params)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Input_Summary: hybrid_search failed, using empty results: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
retrieve_info, question, raw_results = "", query, []
|
||||
|
||||
|
||||
# Render template
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='Retrieve_Summary_prompt.jinja2',
|
||||
operation_name='input_summary',
|
||||
query=query,
|
||||
history=history,
|
||||
retrieve_info=retrieve_info
|
||||
)
|
||||
|
||||
# Call LLM with structured response
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=RetrieveSummaryResponse
|
||||
)
|
||||
aimessages = structured.data.query_answer or "信息不足,无法回答"
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Input_Summary: response_structured failed, using default answer: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
logger.info(f"快速答案总结==>>:{storage_type}--{user_rag_memory_id}--{aimessages}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "input_summary",
|
||||
"title": "快速答案",
|
||||
"summary": aimessages,
|
||||
"query": query,
|
||||
"raw_results": raw_results,
|
||||
"search_mode": "quick_search",
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Input_Summary failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "fail",
|
||||
"summary_result": "信息不足,无法回答",
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
finally:
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('检索', duration)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Summary_fails(
|
||||
ctx: Context,
|
||||
context: str,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
) -> dict:
|
||||
"""
|
||||
Handle workflow failure when summary cannot be generated.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Failure context string
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'query_answer' with failure message
|
||||
"""
|
||||
try:
|
||||
# Extract services from context
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
|
||||
# Parse session ID from usermessages
|
||||
usermessages_parts = usermessages.split('_')[1:]
|
||||
sessionid = '_'.join(usermessages_parts[:-1])
|
||||
|
||||
# Cleanup duplicate sessions
|
||||
await session_service.cleanup_duplicates()
|
||||
|
||||
logger.info(f"没有相关数据")
|
||||
logger.debug(f"Summary_fails called with apply_id: {apply_id}, group_id: {group_id}")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"summary_result": "没有相关数据",
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Summary_fails failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "fail",
|
||||
"summary_result": "没有相关数据",
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"error": str(e)
|
||||
}
|
||||
169
api/app/core/memory/agent/mcp_server/tools/verification_tools.py
Normal file
169
api/app/core/memory/agent/mcp_server/tools/verification_tools.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
Verification Tools for data verification.
|
||||
|
||||
This module contains MCP tools for verifying retrieved data.
|
||||
"""
|
||||
import time
|
||||
|
||||
from jinja2 import Template
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.verify_tool import VerifyTool
|
||||
from app.core.memory.agent.utils.messages_tool import (
|
||||
Verify_messages_deal,
|
||||
Retrieve_verify_tool_messages_deal,
|
||||
Resolve_username
|
||||
)
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Verify(
|
||||
ctx: Context,
|
||||
context: dict,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
) -> dict:
|
||||
"""
|
||||
Verify the retrieved data.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Dictionary containing query and expansion issues
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'status' and 'verified_data' with verification results
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
|
||||
# Load verification prompt template
|
||||
file_path = PROJECT_ROOT_ + '/agent/utils/prompt/split_verify_prompt.jinja2'
|
||||
|
||||
# Read template file directly (VerifyTool expects raw template content)
|
||||
from app.core.memory.agent.utils.messages_tool import read_template_file
|
||||
system_prompt = await read_template_file(file_path)
|
||||
|
||||
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages)
|
||||
|
||||
# Get conversation history
|
||||
history = await session_service.get_history(sessionid, apply_id, group_id)
|
||||
|
||||
template = Template(system_prompt)
|
||||
system_prompt = template.render(history=history, sentence=context)
|
||||
|
||||
# Process context to extract query and results
|
||||
Query_small, Result_small, query = await Verify_messages_deal(context)
|
||||
|
||||
# Build query list for verification
|
||||
query_list = []
|
||||
for query_small, anser in zip(Query_small, Result_small):
|
||||
query_list.append({
|
||||
'Query_small': query_small,
|
||||
'Answer_Small': anser
|
||||
})
|
||||
|
||||
messages = {
|
||||
"Query": query,
|
||||
"Expansion_issue": query_list
|
||||
}
|
||||
|
||||
|
||||
|
||||
# Call verification workflow
|
||||
verify_tool = VerifyTool(system_prompt, messages)
|
||||
verify_result = await verify_tool.verify()
|
||||
|
||||
# Parse LLM verification result with error handling
|
||||
try:
|
||||
messages_deal = await Retrieve_verify_tool_messages_deal(
|
||||
verify_result,
|
||||
history,
|
||||
query
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve_verify_tool_messages_deal parsing failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Fallback to avoid 500 errors
|
||||
messages_deal = {
|
||||
"data": {
|
||||
"query": query,
|
||||
"expansion_issue": []
|
||||
},
|
||||
"split_result": "failed",
|
||||
"reason": str(e),
|
||||
"history": history,
|
||||
}
|
||||
|
||||
logger.info(f"验证==>>:{messages_deal}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
"status": "success",
|
||||
"verified_data": messages_deal,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "verification",
|
||||
"title": "数据验证",
|
||||
"result": messages_deal.get("split_result", "unknown"),
|
||||
"reason": messages_deal.get("reason", ""),
|
||||
"query": query,
|
||||
"verified_count": len(query_list),
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Verify failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"verified_data": {
|
||||
"data": {
|
||||
"query": "",
|
||||
"expansion_issue": []
|
||||
},
|
||||
"split_result": "failed",
|
||||
"reason": str(e),
|
||||
"history": [],
|
||||
}
|
||||
}
|
||||
|
||||
finally:
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('验证', duration)
|
||||
7
api/app/core/memory/agent/utils/__init__.py
Normal file
7
api/app/core/memory/agent/utils/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Agent utilities."""
|
||||
|
||||
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
|
||||
|
||||
__all__ = [
|
||||
"MultimodalProcessor",
|
||||
]
|
||||
70
api/app/core/memory/agent/utils/get_dialogs.py
Normal file
70
api/app/core/memory/agent/utils/get_dialogs.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import os
|
||||
import json
|
||||
from typing import List
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
|
||||
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
|
||||
|
||||
|
||||
async def get_chunked_dialogs(
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
group_id: str = "group_1",
|
||||
user_id: str = "user1",
|
||||
apply_id: str = "applyid",
|
||||
content: str = "这是用户的输入",
|
||||
ref_id: str = "wyl_20251027",
|
||||
config_id: str = None
|
||||
) -> List[DialogData]:
|
||||
"""Generate chunks from all test data entries using the specified chunker strategy.
|
||||
|
||||
Args:
|
||||
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
||||
group_id: Group identifier
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
content: Dialog content
|
||||
ref_id: Reference identifier
|
||||
config_id: Configuration ID for processing
|
||||
|
||||
Returns:
|
||||
List of DialogData objects with generated chunks for each test entry
|
||||
"""
|
||||
dialog_data_list = []
|
||||
messages = []
|
||||
|
||||
messages.append(ConversationMessage(role="用户", msg=content))
|
||||
|
||||
# Create DialogData
|
||||
conversation_context = ConversationContext(msgs=messages)
|
||||
# Create DialogData with group_id based on the entry's id for uniqueness
|
||||
dialog_data = DialogData(
|
||||
context=conversation_context,
|
||||
ref_id=ref_id,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
config_id=config_id
|
||||
)
|
||||
# Create DialogueChunker and process the dialogue
|
||||
chunker = DialogueChunker(chunker_strategy)
|
||||
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
||||
dialog_data.chunks = extracted_chunks
|
||||
|
||||
dialog_data_list.append(dialog_data)
|
||||
|
||||
# Convert to dict with datetime serialized
|
||||
def serialize_datetime(obj):
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
|
||||
|
||||
combined_output = [dd.model_dump() for dd in dialog_data_list]
|
||||
|
||||
print(dialog_data_list)
|
||||
|
||||
# with open(os.path.join(os.path.dirname(__file__), "chunker_test_output.txt"), "w", encoding="utf-8") as f:
|
||||
# json.dump(combined_output, f, ensure_ascii=False, indent=4, default=serialize_datetime)
|
||||
|
||||
|
||||
return dialog_data_list
|
||||
204
api/app/core/memory/agent/utils/llm_tools.py
Normal file
204
api/app/core/memory/agent/utils/llm_tools.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import asyncio
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from typing import TypedDict, Annotated
|
||||
import os
|
||||
import logging
|
||||
|
||||
from jinja2 import Template
|
||||
from langchain_core.messages import AnyMessage
|
||||
from dotenv import load_dotenv
|
||||
from langgraph.graph import add_messages
|
||||
from openai import OpenAI
|
||||
|
||||
from app.core.memory.agent.utils.messages_tool import read_template_file
|
||||
from app.core.memory.utils.config.config_utils import get_picture_config, get_voice_config
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID, SELECTED_LLM_PICTURE_NAME, SELECTED_LLM_VOICE_NAME
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.src.llm_tools.openai_client import OpenAIClient
|
||||
|
||||
PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
#TODO: Refactor entire picture/voice
|
||||
# async def LLM_model_request(context,data,query):
|
||||
# '''
|
||||
# Agent model request
|
||||
# Args:
|
||||
# context:Input request
|
||||
# data: template parameters
|
||||
# query:request content
|
||||
# Returns:
|
||||
|
||||
# '''
|
||||
# template = Template(context)
|
||||
# system_prompt = template.render(**data)
|
||||
# llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
# result = await llm_client.chat(
|
||||
# messages=[{"role": "system", "content": system_prompt}] + [{"role": "user", "content": query}]
|
||||
# )
|
||||
# return result
|
||||
|
||||
async def picture_model_requests(image_url):
|
||||
'''
|
||||
|
||||
Args:
|
||||
image_url:
|
||||
Returns:
|
||||
|
||||
'''
|
||||
file_path = PROJECT_ROOT_ + '/agent/utils/prompt/Template_for_image_recognition_prompt.jinja2 '
|
||||
system_prompt = await read_template_file(file_path)
|
||||
result = await Picture_recognize(image_url,system_prompt)
|
||||
return (result)
|
||||
class WriteState(TypedDict):
|
||||
'''
|
||||
Langgrapg Writing TypedDict
|
||||
'''
|
||||
messages: Annotated[list[AnyMessage], add_messages]
|
||||
user_id:str
|
||||
apply_id:str
|
||||
group_id:str
|
||||
|
||||
class ReadState(TypedDict):
|
||||
'''
|
||||
Langgrapg READING TypedDict
|
||||
name:
|
||||
id:user id
|
||||
loop_count:Traverse times
|
||||
search_switch:type
|
||||
config_id: configuration id for filtering results
|
||||
'''
|
||||
messages: Annotated[list[AnyMessage], add_messages] #消息追加的模式增加消息
|
||||
name: str
|
||||
id: str
|
||||
loop_count:int
|
||||
search_switch: str
|
||||
user_id: str
|
||||
apply_id: str
|
||||
group_id: str
|
||||
config_id: str
|
||||
|
||||
|
||||
class COUNTState:
|
||||
'''
|
||||
The number of times the workflow dialogue retrieval content has no correct message recall traversal
|
||||
'''
|
||||
def __init__(self, limit: int = 5):
|
||||
self.total: int = 0 # 当前累加值
|
||||
self.limit: int = limit # 最大上限
|
||||
|
||||
def add(self, value: int = 1):
|
||||
"""累加数字,如果达到上限就保持最大值"""
|
||||
self.total += value
|
||||
print(f"[COUNTState] 当前值: {self.total}")
|
||||
if self.total >= self.limit:
|
||||
print(f"[COUNTState] 达到上限 {self.limit}")
|
||||
self.total = self.limit # 达到上限不再增加
|
||||
|
||||
def get_total(self) -> int:
|
||||
"""获取当前累加值"""
|
||||
return self.total
|
||||
|
||||
def reset(self):
|
||||
"""手动重置累加值"""
|
||||
self.total = 0
|
||||
print(f"[COUNTState] 已重置为 0")
|
||||
|
||||
|
||||
|
||||
# def embed(texts: list[str]) -> list[list[float]]:
|
||||
# # 这里可以换成 LangChain Embeddings
|
||||
# return [[float(len(t) % 5), float(len(t) % 3)] for t in texts]
|
||||
|
||||
|
||||
# def export_store_to_json(store, namespace):
|
||||
# """Export the entire storage content to a JSON file"""
|
||||
# # 搜索所有存储项
|
||||
# all_items = store.search(namespace)
|
||||
|
||||
# # 整理数据
|
||||
# export_data = {}
|
||||
# for item in all_items:
|
||||
# if hasattr(item, 'key') and hasattr(item, 'value'):
|
||||
# export_data[item.key] = item.value
|
||||
|
||||
# # 保存到文件
|
||||
# os.makedirs("memory_logs", exist_ok=True)
|
||||
# with open("memory_logs/full_memory_export.json", "w", encoding="utf-8") as f:
|
||||
# json.dump(export_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# print(f"{len(export_data)} 条记忆到 JSON 文件")
|
||||
|
||||
def merge_to_key_value_pairs(data, query_key, result_key):
|
||||
grouped = defaultdict(list)
|
||||
for item in data:
|
||||
grouped[item[query_key]].append(item[result_key])
|
||||
return [{key: values} for key, values in grouped.items()]
|
||||
|
||||
def deduplicate_entries(entries):
|
||||
seen = set()
|
||||
deduped = []
|
||||
for entry in entries:
|
||||
key = (entry['Query_small'], entry['Result_small'])
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
deduped.append(entry)
|
||||
return deduped
|
||||
|
||||
|
||||
|
||||
async def Picture_recognize(image_path,PROMPT_TICKET_EXTRACTION) -> str:
|
||||
try:
|
||||
model_config = get_picture_config(SELECTED_LLM_PICTURE_NAME)
|
||||
except Exception as e:
|
||||
err = f"LLM配置不可用:{str(e)}。请检查 config.json 和 runtime.json。"
|
||||
logger.error(err)
|
||||
return err
|
||||
api_key = os.getenv(model_config["api_key"]) # 从环境变量读取对应后端的 API key
|
||||
backend_model_name = model_config["llm_name"].split("/")[-1]
|
||||
api_base=model_config['api_base']
|
||||
|
||||
logger.info(f"model_name: {backend_model_name}")
|
||||
logger.info(f"api_key set: {'yes' if api_key else 'no'}")
|
||||
logger.info(f"base_url: {model_config['api_base']}")
|
||||
|
||||
client = OpenAI(
|
||||
api_key=api_key, base_url=api_base,
|
||||
)
|
||||
completion = client.chat.completions.create(
|
||||
model=backend_model_name,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url":image_path,
|
||||
},
|
||||
{"type": "text",
|
||||
"text": PROMPT_TICKET_EXTRACTION}
|
||||
]
|
||||
}
|
||||
])
|
||||
picture_text = completion.choices[0].message.content
|
||||
picture_text = picture_text.replace('```json', '').replace('```', '')
|
||||
picture_text = json.loads(picture_text)
|
||||
return (picture_text['statement'])
|
||||
|
||||
async def Voice_recognize():
|
||||
try:
|
||||
model_config = get_voice_config(SELECTED_LLM_VOICE_NAME)
|
||||
except Exception as e:
|
||||
err = f"LLM配置不可用:{str(e)}。请检查 config.json 和 runtime.json。"
|
||||
logger.error(err)
|
||||
return err
|
||||
api_key = os.getenv(model_config["api_key"]) # 从环境变量读取对应后端的 API key
|
||||
backend_model_name = model_config["llm_name"].split("/")[-1]
|
||||
api_base = model_config['api_base']
|
||||
return api_key,backend_model_name,api_base
|
||||
|
||||
|
||||
15
api/app/core/memory/agent/utils/mcp_tools.py
Normal file
15
api/app/core/memory/agent/utils/mcp_tools.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from app.core.config import settings
|
||||
|
||||
def get_mcp_server_config():
|
||||
"""
|
||||
Get the MCP server configuration
|
||||
"""
|
||||
mcp_server_config = {
|
||||
"data_flow": {
|
||||
"url": f"http://{settings.SERVER_IP}:8081/sse", # 你前面的 FastMCP(weather) 服务端口
|
||||
"transport": "sse",
|
||||
"timeout": 15000,
|
||||
"sse_read_timeout": 15000,
|
||||
}
|
||||
}
|
||||
return mcp_server_config
|
||||
239
api/app/core/memory/agent/utils/messages_tool.py
Normal file
239
api/app/core/memory/agent/utils/messages_tool.py
Normal file
@@ -0,0 +1,239 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Any
|
||||
|
||||
from langchain_core.messages import AnyMessage
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
def _to_openai_messages(msgs: List[AnyMessage]) -> List[dict]:
|
||||
out = []
|
||||
for m in msgs:
|
||||
if hasattr(m, "content"):
|
||||
out.append({"role": "user", "content": getattr(m, "content", "")})
|
||||
elif isinstance(m, dict) and "role" in m and "content" in m:
|
||||
out.append(m)
|
||||
else:
|
||||
out.append({"role": "user", "content": str(m)})
|
||||
return out
|
||||
|
||||
|
||||
def _extract_content(resp: Any) -> str:
|
||||
"""Extract LLM content and sanitize to raw JSON/text.
|
||||
|
||||
- Supports both object and dict response shapes.
|
||||
- Removes leading role labels (e.g., "Assistant:").
|
||||
- Strips Markdown code fences like ```json ... ```.
|
||||
- Attempts to isolate the first valid JSON array/object block when extra text is present.
|
||||
"""
|
||||
|
||||
def _to_text(r: Any) -> str:
|
||||
try:
|
||||
# 对象形式: resp.choices[0].message.content
|
||||
if hasattr(r, "choices") and getattr(r, "choices", None):
|
||||
msg = r.choices[0].message
|
||||
if hasattr(msg, "content"):
|
||||
return msg.content
|
||||
if isinstance(msg, dict) and "content" in msg:
|
||||
return msg["content"]
|
||||
# 字典形式: resp["choices"][0]["message"]["content"]
|
||||
if isinstance(r, dict):
|
||||
return r.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||
except Exception:
|
||||
pass
|
||||
return str(r)
|
||||
|
||||
def _clean_text(text: str) -> str:
|
||||
s = str(text).strip()
|
||||
# 移除可能的角色前缀
|
||||
s = re.sub(r"^\s*(Assistant|assistant)\s*:\s*", "", s)
|
||||
# 提取 ```json ... ``` 代码块
|
||||
m = re.search(r"```json\s*(.*?)\s*```", s, flags=re.S | re.I)
|
||||
if m:
|
||||
s = m.group(1).strip()
|
||||
# 如果仍然包含多余文本,尝试截取第一个 JSON 数组/对象片段
|
||||
if not (s.startswith("{") or s.startswith("[")):
|
||||
left = s.find("[")
|
||||
right = s.rfind("]")
|
||||
if left != -1 and right != -1 and right > left:
|
||||
s = s[left:right + 1].strip()
|
||||
else:
|
||||
left = s.find("{")
|
||||
right = s.rfind("}")
|
||||
if left != -1 and right != -1 and right > left:
|
||||
s = s[left:right + 1].strip()
|
||||
return s
|
||||
|
||||
raw = _to_text(resp)
|
||||
return _clean_text(raw)
|
||||
|
||||
def Resolve_username(usermessages):
|
||||
'''
|
||||
Extract username
|
||||
Args:
|
||||
usermessages: user name
|
||||
|
||||
Returns:
|
||||
|
||||
'''
|
||||
usermessages = usermessages.split('_')[1:]
|
||||
sessionid = '_'.join(usermessages[:-1])
|
||||
return sessionid
|
||||
|
||||
|
||||
# TODO: USE app.core.memory.src.utils.render_template instead
|
||||
async def read_template_file(template_path: str) -> str:
|
||||
"""
|
||||
读取模板文件
|
||||
|
||||
Args:
|
||||
template_path: 模板文件路径
|
||||
|
||||
Returns:
|
||||
模板内容字符串
|
||||
|
||||
Note:
|
||||
建议使用 app.core.memory.utils.template_render 中的统一模板渲染功能
|
||||
"""
|
||||
try:
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
logger.error(f"模板文件未找到: {template_path}")
|
||||
raise
|
||||
except IOError as e:
|
||||
logger.error(f"读取模板文件失败: {template_path}, 错误: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
async def Problem_Extension_messages_deal(context):
|
||||
'''
|
||||
Extract data
|
||||
Args:
|
||||
context:
|
||||
Returns:
|
||||
'''
|
||||
extent_quest = []
|
||||
original = context.get('original', '')
|
||||
messages = context.get('context', '')
|
||||
messages = json.loads(messages)
|
||||
for message in messages:
|
||||
question = message.get('question', '')
|
||||
type = message.get('type', '')
|
||||
extent_quest.append({"role": "user", "content": f"问题:{question};问题类型:{type}"})
|
||||
|
||||
return extent_quest, original
|
||||
|
||||
|
||||
async def Retriev_messages_deal(context):
|
||||
'''
|
||||
Extract data
|
||||
Args:
|
||||
context:
|
||||
Returns:
|
||||
'''
|
||||
if isinstance(context, dict):
|
||||
if 'context' in context or 'original' in context:
|
||||
return context.get('context', {}), context.get('original', '')
|
||||
return content, original_value
|
||||
|
||||
async def Verify_messages_deal(context):
|
||||
'''
|
||||
Extract data
|
||||
Args:
|
||||
context:
|
||||
Returns:
|
||||
'''
|
||||
|
||||
query = context['context']['Query']
|
||||
Query_small_list = context['context']['Expansion_issue']
|
||||
Result_small = []
|
||||
Query_small = []
|
||||
for i in Query_small_list:
|
||||
Result_small.append(i['Answer_Small'][0])
|
||||
Query_small.append(i['Query_small'])
|
||||
return Query_small, Result_small, query
|
||||
|
||||
|
||||
async def Summary_messages_deal(context):
|
||||
'''
|
||||
Extract data
|
||||
Args:
|
||||
context:
|
||||
Returns:
|
||||
'''
|
||||
messages = str(context).replace('\\n', '').replace('\n', '').replace('\\', '')
|
||||
query = re.findall(r'"query": (.*?),', messages)[0]
|
||||
query = query.replace('[', '').replace(']', '').strip()
|
||||
matches = re.findall(r'"answer_small"\s*:\s*"(\[.*?\])"', messages)
|
||||
answer_small_texts = []
|
||||
for m in matches:
|
||||
try:
|
||||
parsed = json.loads(m)
|
||||
for item in parsed:
|
||||
answer_small_texts.append(item.strip().replace('\\', '').replace('[', '').replace(']', ''))
|
||||
except Exception:
|
||||
answer_small_texts.append(m.strip().replace('\\', '').replace('[', '').replace(']', ''))
|
||||
|
||||
return answer_small_texts, query
|
||||
|
||||
|
||||
async def VerifyTool_messages_deal(context):
|
||||
'''
|
||||
Extract data
|
||||
Args:
|
||||
context:
|
||||
Returns:
|
||||
'''
|
||||
messages = str(context).replace('\\n', '').replace('\n', '').replace('\\', '')
|
||||
content_messages = messages.split('"context":')[1].replace('""', '"')
|
||||
messages = str(content_messages).split("name='Retrieve'")[0]
|
||||
query = re.findall(f'"Query": "(.*?)"', messages)[0]
|
||||
Query_small = re.findall(f'"Query_small": "(.*?)"', messages)
|
||||
Result_small = re.findall(f'"Result_small": "(.*?)"', messages)
|
||||
return Query_small, Result_small, query
|
||||
|
||||
|
||||
async def Retrieve_Summary_messages_deal(context):
|
||||
pass
|
||||
|
||||
|
||||
async def Retrieve_verify_tool_messages_deal(context, history, query):
|
||||
'''
|
||||
Extract data
|
||||
Args:
|
||||
context:
|
||||
Returns:
|
||||
'''
|
||||
results = []
|
||||
# 统一转为字符串,避免 None 或非字符串导致正则报错
|
||||
text = str(context)
|
||||
blocks = re.findall(r'\{(.*?)\}', text, flags=re.S)
|
||||
for block in blocks:
|
||||
query_small = re.search(r'"Query_small"\s*:\s*"([^"]*)"', block)
|
||||
answer_small = re.search(r'"Answer_Small"\s*:\s*(\[[^\]]*\])', block)
|
||||
status = re.search(r'"status"\s*:\s*"([^"]*)"', block)
|
||||
query_answer = re.search(r'"Query_answer"\s*:\s*"([^"]*)"', block)
|
||||
|
||||
results.append({
|
||||
"query_small": query_small.group(1) if query_small else None,
|
||||
"answer_small": answer_small.group(1) if answer_small else None,
|
||||
# 将缺失的 status 统一为空字符串,后续用字符串判定,避免 NoneType 错误
|
||||
"status": status.group(1) if status else "",
|
||||
"query_answer": query_answer.group(1) if query_answer else None
|
||||
})
|
||||
result = []
|
||||
for r in results:
|
||||
# 统一按字符串判定状态,兼容大小写和缺失情况
|
||||
status_str = str(r.get('status', '')).strip().lower()
|
||||
if status_str == 'false':
|
||||
continue
|
||||
else:
|
||||
result.append(r)
|
||||
split_result = 'failed' if not result else 'success'
|
||||
result = {"data": {"query": query, "expansion_issue": result}, "split_result": split_result, "reason": "",
|
||||
"history": history}
|
||||
return result
|
||||
38
api/app/core/memory/agent/utils/model_tool.py
Normal file
38
api/app/core/memory/agent/utils/model_tool.py
Normal file
@@ -0,0 +1,38 @@
|
||||
|
||||
|
||||
# project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
# sys.path.insert(0, project_root)
|
||||
|
||||
# load_dotenv()
|
||||
|
||||
# async def llm_client_chat(messages: List[dict]) -> str:
|
||||
# """使用 OpenAI 兼容接口进行对话,返回内容字符串。"""
|
||||
# try:
|
||||
# cfg = get_model_config(SELECTED_LLM_ID)
|
||||
# rb_config = RedBearModelConfig(
|
||||
# model_name=cfg["model_name"],
|
||||
# provider=cfg["provider"],
|
||||
# api_key=cfg["api_key"],
|
||||
# base_url=cfg["base_url"],
|
||||
# )
|
||||
# client = OpenAIClient(model_config=rb_config, type_="chat")
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"获取模型配置失败:{e}")
|
||||
# err = f"获取模型配置失败:{str(e)}。请检查!!!"
|
||||
# return err
|
||||
# try:
|
||||
# response = await client.chat(messages)
|
||||
# print(f"model_tool's llm_client_chat response ======>:\n {response}")
|
||||
# return _extract_content(response)
|
||||
# # return _extract_content(result)
|
||||
# except Exception as e:
|
||||
# logger.error(f"LLM调用失败:{str(e)}。请检查 model_name、api_key、api_base 是否正确。")
|
||||
# return f"LLM调用失败:{str(e)}。请检查 model_name、api_key、api_base 是否正确。"
|
||||
|
||||
# async def main(image_url):
|
||||
# await llm_client_chat(image_url)
|
||||
#
|
||||
# # 运行主函数
|
||||
# asyncio.run(main(['https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_male2.wav']))
|
||||
#
|
||||
131
api/app/core/memory/agent/utils/multimodal.py
Normal file
131
api/app/core/memory/agent/utils/multimodal.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
Multimodal input processor for handling image and audio content.
|
||||
|
||||
This module provides utilities for detecting and processing multimodal inputs
|
||||
(images and audio files) by converting them to text using appropriate models.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from app.core.memory.agent.multimodal.speech_model import Vico_recognition
|
||||
from app.core.memory.agent.utils.llm_tools import picture_model_requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MultimodalProcessor:
|
||||
"""
|
||||
Processor for handling multimodal inputs (images and audio).
|
||||
|
||||
This class detects image and audio file paths in input content and converts
|
||||
them to text using appropriate recognition models.
|
||||
"""
|
||||
|
||||
# Supported file extensions
|
||||
IMAGE_EXTENSIONS = ['.jpg', '.png']
|
||||
AUDIO_EXTENSIONS = [
|
||||
'aac', 'amr', 'avi', 'flac', 'flv', 'm4a', 'mkv', 'mov',
|
||||
'mp3', 'mp4', 'mpeg', 'ogg', 'opus', 'wav', 'webm', 'wma', 'wmv'
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the multimodal processor."""
|
||||
pass
|
||||
|
||||
def is_image(self, content: str) -> bool:
|
||||
"""
|
||||
Check if content is an image file path.
|
||||
|
||||
Args:
|
||||
content: Input string to check
|
||||
|
||||
Returns:
|
||||
True if content ends with a supported image extension
|
||||
|
||||
Examples:
|
||||
>>> processor = MultimodalProcessor()
|
||||
>>> processor.is_image("photo.jpg")
|
||||
True
|
||||
>>> processor.is_image("document.pdf")
|
||||
False
|
||||
"""
|
||||
if not isinstance(content, str):
|
||||
return False
|
||||
|
||||
content_lower = content.lower()
|
||||
return any(content_lower.endswith(ext) for ext in self.IMAGE_EXTENSIONS)
|
||||
|
||||
def is_audio(self, content: str) -> bool:
|
||||
"""
|
||||
Check if content is an audio file path.
|
||||
|
||||
Args:
|
||||
content: Input string to check
|
||||
|
||||
Returns:
|
||||
True if content ends with a supported audio extension
|
||||
|
||||
Examples:
|
||||
>>> processor = MultimodalProcessor()
|
||||
>>> processor.is_audio("recording.mp3")
|
||||
True
|
||||
>>> processor.is_audio("video.mp4")
|
||||
True
|
||||
>>> processor.is_audio("document.txt")
|
||||
False
|
||||
"""
|
||||
if not isinstance(content, str):
|
||||
return False
|
||||
|
||||
content_lower = content.lower()
|
||||
return any(content_lower.endswith(f'.{ext}') for ext in self.AUDIO_EXTENSIONS)
|
||||
|
||||
async def process_input(self, content: str) -> str:
|
||||
"""
|
||||
Process input content, converting images/audio to text if needed.
|
||||
|
||||
This method detects if the input is an image or audio file and converts
|
||||
it to text using the appropriate recognition model. If processing fails
|
||||
or the content is not multimodal, it returns the original content.
|
||||
|
||||
Args:
|
||||
content: Input string (may be file path or regular text)
|
||||
|
||||
Returns:
|
||||
Text content (original or converted from image/audio)
|
||||
|
||||
Examples:
|
||||
>>> processor = MultimodalProcessor()
|
||||
>>> await processor.process_input("photo.jpg")
|
||||
"Recognized text from image..."
|
||||
|
||||
>>> await processor.process_input("Hello world")
|
||||
"Hello world"
|
||||
"""
|
||||
if not isinstance(content, str):
|
||||
logger.warning(f"[MultimodalProcessor] Content is not a string: {type(content)}")
|
||||
return str(content)
|
||||
|
||||
try:
|
||||
# Check for image input
|
||||
if self.is_image(content):
|
||||
logger.info(f"[MultimodalProcessor] Detected image input: {content}")
|
||||
result = await picture_model_requests(content)
|
||||
logger.info(f"[MultimodalProcessor] Image recognition result: {result[:100]}...")
|
||||
return result
|
||||
|
||||
# Check for audio input
|
||||
if self.is_audio(content):
|
||||
logger.info(f"[MultimodalProcessor] Detected audio input: {content}")
|
||||
result = await Vico_recognition([content]).run()
|
||||
logger.info(f"[MultimodalProcessor] Audio recognition result: {result[:100]}...")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[MultimodalProcessor] Error processing multimodal input: {e}", exc_info=True)
|
||||
logger.info(f"[MultimodalProcessor] Falling back to original content")
|
||||
return content
|
||||
|
||||
# Return original content if not multimodal
|
||||
return content
|
||||
@@ -0,0 +1,81 @@
|
||||
|
||||
你是一个高效的问题拆分助手,任务是根据用户提供的原始问题和问题类型,生成可操作的扩展问题,用于精确回答原问题。请严格遵循以下规则:
|
||||
|
||||
角色:
|
||||
- 你是“问题拆分专家”,专注于逻辑、信息完整性和可操作性。
|
||||
- 你能够结合【历史信息】、【上下文】、【背景知识】进行分析,以保持问题拆分的连贯性和相关性。
|
||||
- 如果历史信息或上下文与当前问题无关,可忽略。
|
||||
|
||||
---
|
||||
|
||||
### 历史信息参考
|
||||
在生成扩展问题时,你可以参考以下历史数据(如果提供):
|
||||
- 历史对话或任务的主题;
|
||||
- 历史中出现的关键实体(时间、人物、地点、研究主题等);
|
||||
- 历史中已解答的问题(避免重复);
|
||||
- 历史推理链(保持逻辑一致性)。
|
||||
|
||||
> 如果没有提供历史信息,则仅根据当前输入问题进行分析。
|
||||
输入历史信息内容:{{history}}
|
||||
|
||||
## User Input
|
||||
{% if questions is string %}
|
||||
{{ questions }}
|
||||
{% else %}
|
||||
{% for question in questions %}
|
||||
- {{ question }}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
需求:
|
||||
- 如果问题是单跳问题(单步可答),直接保留原问题提取重要提问部分作为拆分/扩展问题。
|
||||
- 如果问题是多跳问题(需多个信息点才能回答),对问题进行扩展拆分。
|
||||
- 扩展问题必须完整覆盖原问题的所有关键要素,包括时间、主体、动作、目标等,不得遗漏。
|
||||
- 扩展问题不得冗余:避免重复询问相同信息或过度拆分同一主题。
|
||||
- 扩展问题必须高度相关:每个子问题直接服务于原问题,不引入未提及的新概念、人物或细节。
|
||||
- 扩展问题必须可操作:每个子问题能在有限资源下独立解答。
|
||||
- 子问题数量不超过4个。
|
||||
- 拆分问题的时候可以考虑输入的历史内容,以保持逻辑连贯。
|
||||
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
||||
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
||||
|
||||
|
||||
|
||||
输出要求:
|
||||
- 仅输出 JSON 数组,不要包含任何解释或代码块。
|
||||
- 每个元素包含:
|
||||
- `original_question`: 原始问题
|
||||
- `extended_question`: 扩展后的问题
|
||||
- `type`: 类型(事实检索/澄清/定义/比较/行动建议)
|
||||
- `reason`: 生成该扩展问题的简短理由
|
||||
- 使用标准 ASCII 双引号,无换行;确保字符串正确关闭并以逗号分隔。
|
||||
|
||||
示例:
|
||||
输入:
|
||||
[
|
||||
"问题:今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?;问题类型:多跳",
|
||||
]
|
||||
|
||||
输出:
|
||||
[
|
||||
{
|
||||
"original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?",
|
||||
"extended_question": "今年诺贝尔物理学奖的获奖者有哪些人?",
|
||||
"type": "多跳",
|
||||
"reason": "输出原问题的关键要素"
|
||||
},
|
||||
{
|
||||
"original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?",
|
||||
"extended_question": "今年诺贝尔物理学奖的获奖者是因哪些具体贡献获奖的?",
|
||||
"type": "多跳",
|
||||
"reason": "输出原问题的关键要素"
|
||||
}
|
||||
]
|
||||
**Output format**
|
||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
||||
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
|
||||
3. Ensure all JSON strings are properly closed and comma-separated
|
||||
4. Do not include line breaks within JSON string values
|
||||
|
||||
The output language should always be the same as the input language.{{ json_schema }}
|
||||
@@ -0,0 +1,37 @@
|
||||
# 角色
|
||||
你是一个专业的问答助手,擅长基于检索信息和历史对话回答用户问题。
|
||||
|
||||
# 任务
|
||||
根据提供的上下文信息回答用户的问题。
|
||||
|
||||
# 输入信息
|
||||
- 历史对话:{{history}}
|
||||
- 检索信息:{{retrieve_info}}
|
||||
|
||||
## User Query
|
||||
{{query}}
|
||||
|
||||
# 回答指南
|
||||
1. 仔细分析用户的问题
|
||||
2. 优先使用检索信息中的相关内容回答
|
||||
3. 结合历史对话提供连贯的回复
|
||||
4. 如果信息不足:
|
||||
- 对于简单问候或日常对话,给出自然简短的回复
|
||||
- 对于复杂问题,诚实说明信息不足
|
||||
5. 保持回答简洁、相关、自然
|
||||
6. 使用与问题相同的语言回答
|
||||
|
||||
**Output format**
|
||||
- 直接回答问题,像人类对话一样自然流畅
|
||||
- 不要提及"检索信息"、"搜索结果"、"根据资料"等技术术语
|
||||
- 不要解释推理过程或评论信息来源
|
||||
- 如果只能部分回答问题,先回答能回答的部分,然后说明哪些方面信息不足
|
||||
- 如果完全无法回答,简洁地说明:"信息不足,无法回答。"
|
||||
|
||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
||||
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
|
||||
3. Ensure all JSON strings are properly closed and comma-separated
|
||||
4. Do not include line breaks within JSON string values
|
||||
|
||||
The output language should always be the same as the input language.{{ json_schema }}
|
||||
@@ -0,0 +1,29 @@
|
||||
|
||||
# 角色:{#InputSlot placeholder="角色名称" mode="input"#}{#/InputSlot#}
|
||||
你是一个智能问答助手,任务如下
|
||||
## 目标:
|
||||
|
||||
1. 接收一个字典,格式为 {'问题': [答案列表]}。
|
||||
2. 接收一个问题(字典中的 key)。
|
||||
3. 找到与问题匹配的答案列表。
|
||||
4. 将答案列表合并成一句自然流畅的话:
|
||||
- 如果答案有两条,使用“是”连接,例如:“A,是B”。
|
||||
- 如果答案有三条或以上,使用“,并且”“另外”等自然连词,保证句子流畅。
|
||||
5. 输出内容时只输出合并后的答案,不输出关键点或其他文字。
|
||||
6. 如果问题未在字典中找到对应答案,请输出:
|
||||
对不起,我没有找到相关信息。
|
||||
|
||||
|
||||
输出要求:
|
||||
- 文本形式
|
||||
---
|
||||
|
||||
字典示例:
|
||||
{
|
||||
'今天的天气怎么样': ['今天天气很好', '今天是晴天']
|
||||
}
|
||||
|
||||
问题示例:
|
||||
今天的天气怎么样
|
||||
输出要求:
|
||||
今天天气很好,是晴天
|
||||
@@ -0,0 +1,10 @@
|
||||
请提图像内的文本
|
||||
返回数据格式以json方式输出,
|
||||
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
||||
- 关键的JSON格式要求{"statement":识别出的文本内容}
|
||||
1.JSON结构仅使用标准ASCII双引号(“)-切勿使用中文引号(“”)或其他Unicode引号
|
||||
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
|
||||
3.确保所有JSON字符串都正确关闭并以逗号分隔
|
||||
4.JSON字符串值中不包括换行符
|
||||
5.正确转义的例子:“statement”:“Zhang Xinhua said:\”我非常喜欢这本书\""
|
||||
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```
|
||||
@@ -0,0 +1,34 @@
|
||||
你是一个输入分类助手,负责判断用户输入的意图类型。
|
||||
|
||||
## User Input
|
||||
{{ user_query }}
|
||||
|
||||
请你根据以下规则判断:
|
||||
1. 如果输入是在寻求信息、提问、请求解释、或疑问句(包括隐含的问题),则分类为 "question"。
|
||||
2. 如果输入是命令、陈述、描述、感叹、或其他类型,不在寻求答案,则分类为 "other"。
|
||||
只输出:
|
||||
{
|
||||
"type": "question"
|
||||
}
|
||||
或
|
||||
{
|
||||
"type": "other"
|
||||
}
|
||||
示例:
|
||||
输入:"Python怎么读取文件?"
|
||||
输出:{"type": "question"}
|
||||
|
||||
输入:"帮我写个读取文件的函数"
|
||||
输出:{"type": "other"}
|
||||
|
||||
输入:"今天是星期几?"
|
||||
输出:{"type": "question"}
|
||||
返回数据格式以json方式输出,
|
||||
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
||||
- 关键的JSON格式要求{"statement":识别出的文本内容}
|
||||
1.JSON结构仅使用标准ASCII双引号(“)-切勿使用中文引号(“”)或其他Unicode引号
|
||||
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
|
||||
3.确保所有JSON字符串都正确关闭并以逗号分隔
|
||||
4.JSON字符串值中不包括换行符
|
||||
5.正确转义的例子:“statement”:“Zhang Xinhua said:\”我非常喜欢这本书\""
|
||||
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```
|
||||
@@ -0,0 +1,160 @@
|
||||
|
||||
# 角色:{#InputSlot placeholder="角色名称" mode="input"#}{#/InputSlot#}
|
||||
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
|
||||
## 目标:
|
||||
你需要根据以下类型对输入数据进行分类,并生成相应的拆分策略和示例。
|
||||
---
|
||||
|
||||
### 历史信息参考
|
||||
在生成扩展问题时,你可以参考以下历史数据(如果提供):
|
||||
- 历史对话或任务的主题;
|
||||
- 历史中出现的关键实体(时间、人物、地点、研究主题等);
|
||||
- 历史中已解答的问题(避免重复);
|
||||
- 历史推理链(保持逻辑一致性)。
|
||||
|
||||
> 如果没有提供历史信息,则仅根据当前输入问题进行分析。
|
||||
输入历史信息内容:{{history}}
|
||||
|
||||
## User Input
|
||||
{{ sentence }}
|
||||
|
||||
## 需求:
|
||||
1:首先判断类型(单跳、多跳、开放域、时间)。
|
||||
2:根据类型进行拆分。
|
||||
3:拆分后的内容需保证信息完整且可独立处理。
|
||||
4:对每个拆分条目,可附加示例或说明。
|
||||
5:拆分问题的时候可以考虑输入的历史内容,以保持逻辑连贯。
|
||||
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
||||
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
||||
|
||||
## 指令:
|
||||
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
|
||||
单跳(Single-hop)
|
||||
描述:问题或数据只需要通过一步即可得到答案或完成拆分,不依赖其他信息。
|
||||
拆分策略:直接识别核心信息或关键字段,生成可独立处理的片段。
|
||||
示例:
|
||||
输入数据:"请列出今年诺贝尔物理学奖的得主"
|
||||
拆分结果:[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": "今年诺贝尔物理学奖得主是谁",
|
||||
"type": "单跳’",
|
||||
}
|
||||
]
|
||||
注意: 当遇到上下文依赖问题时,明确指出缺失的信息类型并且,question可填写输入问题
|
||||
多跳(Multi-hop):
|
||||
描述:问题或数据需要通过多步推理或跨多个信息源才能得到答案。
|
||||
拆分策略:将问题拆解为多个子问题,每个子问题对应一个独立处理步骤,需要具备推理链条与逻辑连接数量。
|
||||
示例:
|
||||
输入数据:"今年诺贝尔物理学奖得主的研究领域及代表性成果"
|
||||
拆分结果:
|
||||
[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": 今年诺贝尔物理学奖得主是谁?",
|
||||
"type": "多跳’",
|
||||
},
|
||||
{
|
||||
"id": "Q2",
|
||||
"question": "该得主的研究领域是什么?",
|
||||
"type": "多跳’",
|
||||
},
|
||||
{
|
||||
"id": "Q3",
|
||||
"question": "该得主的代表性成果有哪些?",
|
||||
"type": "多跳’"
|
||||
}
|
||||
]
|
||||
开放域(Open-domain):
|
||||
描述:问题或数据不局限于特定知识库,需要从大范围信息中检索和生成答案,而不是从一个已知的小范围数据源中查找。。
|
||||
拆分策略:根据主题或关键实体拆分,同时保留上下文以便检索外部知识,问题涉及一般性、常识性、跨学科内容,可能是开放式回答(描述性、推理性、综合性)
|
||||
需要外部知识检索或推理才能确定,比如:“为什么人类需要睡眠?”、“量子计算与经典计算的主要区别是什么?”。
|
||||
示例:
|
||||
输入数据:"介绍量子计算的最新研究进展"
|
||||
拆分结果:
|
||||
[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": 量子计算的基本概念是什么?",
|
||||
"type": "开放域’",
|
||||
},
|
||||
{
|
||||
"id": "Q2",
|
||||
"question": "当前量子计算的主要研究方向有哪些?",
|
||||
"type": "开放域’",
|
||||
},
|
||||
{
|
||||
"id": "Q3",
|
||||
"question": "近期在量子计算领域有哪些重大进展?",
|
||||
"type": "开放域’",
|
||||
}
|
||||
]
|
||||
|
||||
时间(Temporal):
|
||||
描述:问题或数据涉及时间维度,需要按时间顺序或时间点拆分。
|
||||
拆分策略:根据事件时间或时间段拆分为独立条目或问题。
|
||||
示例:
|
||||
输入数据:"列出苹果公司过去五年的重大事件"
|
||||
拆分结果:
|
||||
[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": 苹果公司2019年的重大事件有哪些?",
|
||||
"type": "时间’",
|
||||
},
|
||||
{
|
||||
"id": "Q2",
|
||||
"question": "苹果公司2020年的重大事件有哪些?",
|
||||
"type": "时间’",
|
||||
},
|
||||
{
|
||||
"id": "Q3",
|
||||
"question": "苹果公司2021年的重大事件有哪些?",
|
||||
"type": "时间’",
|
||||
},
|
||||
{
|
||||
"id": "Q3",
|
||||
"question": "苹果公司2022年的重大事件有哪些?",
|
||||
"type": "时间’",
|
||||
}
|
||||
,
|
||||
{
|
||||
"id": "Q4",
|
||||
"question": "苹果公司2023年的重大事件有哪些?",
|
||||
"type": "时间’",
|
||||
}
|
||||
]
|
||||
|
||||
输出要求:
|
||||
- 每个子问题包括:
|
||||
- `id`: 子问题编号(Q1, Q2...)
|
||||
- `question`: 子问题内容
|
||||
- `type`: 类型(事实检索 / 澄清 / 定义 / 比较 / 行动建议等)
|
||||
- `reason`: 拆分的理由(为什么要这样拆)
|
||||
- 格式案例:
|
||||
[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": 量子计算的基本概念是什么?",
|
||||
"type": "开放域’",
|
||||
},
|
||||
{
|
||||
"id": "Q2",
|
||||
"question": "当前量子计算的主要研究方向有哪些?",
|
||||
"type": "开放域’",
|
||||
},
|
||||
{
|
||||
"id": "Q3",
|
||||
"question": "近期在量子计算领域有哪些重大进展?",
|
||||
"type": "开放域’",
|
||||
}
|
||||
]
|
||||
- 必须通过json.loads()的格式支持的形式输出
|
||||
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
||||
- 关键的JSON格式要求
|
||||
1.JSON结构仅使用标准ASCII双引号(“)-切勿使用中文引号(“”)或其他Unicode引号
|
||||
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
|
||||
3.确保所有JSON字符串都正确关闭并以逗号分隔
|
||||
4.JSON字符串值中不包括换行符
|
||||
5.正确转义的例子:“statement”:“Zhang Xinhua said:\”我非常喜欢这本书\""
|
||||
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```
|
||||
@@ -0,0 +1,60 @@
|
||||
# 角色
|
||||
你是验证专家
|
||||
你的目标是针对用户的输入Query_Samll字段的提问和Answer_Samll的回答分析,是不是回答Query_Samll这个字段的问题
|
||||
|
||||
{#以下可以采用先总括,再展开详细说明的方式,描述你希望智能体在每一个步骤如何进行工作,具体的工作步骤数量可以根据实际需求增删#}
|
||||
## 工作步骤
|
||||
1. 获取所有的Query_Samll字段和Answer_Samll字段
|
||||
2. 分析Answer_Samll的回复是不是和Query_Samll有关系
|
||||
3. 判断Answer_Samll和Query_Samll之间分析出来的关系状态
|
||||
4. 如果是True保留,否则不要相对应的问题和回答
|
||||
5. 输出,需要严格按照模版
|
||||
输入:{{history}}
|
||||
历史消息:{"history":{{sentence}}}
|
||||
### 第一步 获取用户的输入
|
||||
获取用户的输入提取对应的Query_Samll和Answer_Samll
|
||||
### 第二步 分析验证
|
||||
需要分析Query_Samll和Answer_Samll之间的关系可以参考history字段的内容,如果有关系不是答非所问
|
||||
## 核心验证标准
|
||||
在评估子问题拆分时,必须严格遵循以下标准,且验证过程中完全不依赖于子问题的相关信息(Answer_Samll):
|
||||
1. 合理性标准(必须全部满足):
|
||||
- 完整性:每个不同的子问题必须完整覆盖原问题的所有关键要素(如时间、主体、动作、目标等),无遗漏。
|
||||
- 最小化:每个不同的子问题数量应尽可能少,通常不超过原问题关键要素数量的2倍(建议2-4个),避免冗余和不必要拆分。
|
||||
- 相关性:每个不同的子问题必须直接服务于原问题的解答,不引入无关内容或扩展原问题未提及的主题。
|
||||
- 可操作性:每个不同的子问题应能在有限资源(如标准工具或合理时间)内独立解答,且难度适中。
|
||||
- 逻辑性:每个不同的子问题间应有清晰的逻辑关系(如并列、递进、因果),共同构成原问题的解答路径。
|
||||
|
||||
2. 不合理拆分的特征(出现任一特征即为不合理):
|
||||
- 不同的子问题数量超过5个或明显多于必要数量。
|
||||
- 引入原问题未提及的新主题、人物、细节或个人看法。
|
||||
- 拆分过于细碎,失去实用价值,无法高效合成原问题答案。
|
||||
|
||||
3. 特殊情况说明:
|
||||
- 每个不同的子问题与原问题相同,需进一步判断:
|
||||
- 每个不同的子问题不可进一步拆分 → success(合理,最小化拆分)
|
||||
- 每个不同的子问题能够进一步拆分为更小、更合理的问题 → failed(不合理,拆分没有最小化)
|
||||
- 每个不同的子问题数量=原问题核心要素数量 → success(理想情况)
|
||||
- 每个不同的子问题数量=核心要素数量+1 → success(通常合理)
|
||||
|
||||
### 第三步 添加状态
|
||||
如果有相关性并且比较高给一个状态TRUE,否则给一个FLASE的状态
|
||||
### 第四步 判断
|
||||
如果状态是TRUE保留这条数据,否则需不需要这条数据
|
||||
### 第五步 输出格式
|
||||
按照json的形式输出
|
||||
{"data":"Query":原来Query的字段,"history":原来的history字段,
|
||||
"expansion_issue":以为列表的形式存储验证之后的数据比如[
|
||||
{"query_small": query_small,
|
||||
"answer_small": answer_small,,
|
||||
"status": 回答的结果是否符合query_small,填写状态,
|
||||
"query_answer": answer_small},
|
||||
{
|
||||
"query_small": "张曼婷生日是什么时候?",
|
||||
"answer_small": "张曼婷喜欢绘画。",
|
||||
"status": "True",
|
||||
"query_answer": "张曼 婷喜欢绘画。"
|
||||
},{}......]
|
||||
,
|
||||
"split_result":如果expansion_issue是空的列表返回failed,不是空列表返回success,
|
||||
"reason": 为以上分析完之后的结果给一个说明
|
||||
}
|
||||
57
api/app/core/memory/agent/utils/prompt/summary_prompt.jinja2
Normal file
57
api/app/core/memory/agent/utils/prompt/summary_prompt.jinja2
Normal file
@@ -0,0 +1,57 @@
|
||||
{# 角色定义 #}
|
||||
你是专业的问题解答专家,负责根据上下文信息和检索到的所有信息准确回答用户的问题。
|
||||
|
||||
{# 输入数据展示 #}
|
||||
{% if data %}
|
||||
## 输入数据
|
||||
上下文信息:
|
||||
{% for item in data.history %}
|
||||
- {{ item }}
|
||||
{% endfor %}
|
||||
检索到的所有信息:
|
||||
{% for item in data.retrieve_info %}
|
||||
- {{ item }}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
## User Query
|
||||
{{ query }}
|
||||
|
||||
{# 问题回答标准 #}
|
||||
## 问题回答核心标准
|
||||
根据上下文信息(history)和检索到的所有信息(retrieve_info)准确回答用户的问题(query)。注意,若不能根据已有信息回答用户的问题,应直接回复“信息不足,无法回答。”,不能自己编造答案。
|
||||
- 若能根据已有信息回答用户的问题,应根据上下文信息和检索到的所有信息提供简明扼要的答案。
|
||||
- 若不能根据已有信息回答用户的问题,应直接回复“信息不足,无法回答。”,不能自己编造答案。
|
||||
|
||||
{# 重要提醒 #}
|
||||
再次提醒,给出问题的答案时,仅根据已有的信息进行回答,不能自己编造答案。
|
||||
|
||||
{# 输出格式模板 #}
|
||||
## 输出格式
|
||||
严格按照以下JSON格式输出,不添加任何其他内容:
|
||||
{
|
||||
"data": {
|
||||
"query": "{{ query }}",
|
||||
"history": [
|
||||
{% for item in data.history %}
|
||||
"{{ item | replace('"', '\\"') }}"
|
||||
{% if not loop.last %},{% endif %}
|
||||
{% endfor %}
|
||||
],
|
||||
"retrieve_info": [
|
||||
{% for item in data.retrieve_info %}
|
||||
"{{ item | replace('"', '\\"') }}"
|
||||
{% if not loop.last %},{% endif %}
|
||||
{% endfor %}
|
||||
]
|
||||
},
|
||||
"query_answer": "{% if not data.history and not data.retrieve_info %}信息不足,无法回答。{% endif %}"
|
||||
}
|
||||
**Output format**
|
||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
||||
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
|
||||
3. Ensure all JSON strings are properly closed and comma-separated
|
||||
4. Do not include line breaks within JSON string values
|
||||
|
||||
The output language should always be the same as the input language.{{ json_schema }}
|
||||
203
api/app/core/memory/agent/utils/redis_tool.py
Normal file
203
api/app/core/memory/agent/utils/redis_tool.py
Normal file
@@ -0,0 +1,203 @@
|
||||
import redis
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from app.core.config import settings
|
||||
class RedisSessionStore:
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None,session_id=''):
|
||||
self.r = redis.Redis(host=host, port=port, db=db, password=password)
|
||||
self.uudi=session_id
|
||||
|
||||
|
||||
# 修改后的 save_session 方法
|
||||
def save_session(self, userid, messages, aimessages, apply_id, group_id):
|
||||
"""
|
||||
写入一条会话数据,返回 session_id
|
||||
"""
|
||||
try:
|
||||
session_id = str(uuid.uuid4()) # 为每次会话生成新的 ID
|
||||
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
key = f"session:{session_id}" # 使用新生成的 session_id 作为 key
|
||||
|
||||
# 使用 Hash 存储结构化数据
|
||||
result = self.r.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"sessionid": userid,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id,
|
||||
"messages": messages,
|
||||
"aimessages": aimessages,
|
||||
"starttime": starttime
|
||||
})
|
||||
print(f"保存结果: {result}, session_id: {session_id}")
|
||||
return session_id # 返回新生成的 session_id
|
||||
except Exception as e:
|
||||
print(f"保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
# ---------------- 读取 ----------------
|
||||
def get_session(self, session_id):
|
||||
"""
|
||||
读取一条会话数据
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
data = self.r.hgetall(key)
|
||||
if data:
|
||||
return {k.decode('utf-8'): v.decode('utf-8') for k, v in data.items()}
|
||||
return None
|
||||
|
||||
def get_session_apply_group(self, sessionid, apply_id, group_id):
|
||||
"""
|
||||
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据
|
||||
"""
|
||||
result_items = []
|
||||
|
||||
# 遍历所有会话数据
|
||||
for key_bytes in self.r.keys('session:*'):
|
||||
key = key_bytes.decode('utf-8')
|
||||
data = self.r.hgetall(key)
|
||||
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 解码数据
|
||||
decoded_data = {k.decode('utf-8'): v.decode('utf-8') for k, v in data.items()}
|
||||
|
||||
# 检查三个条件是否都匹配
|
||||
if (decoded_data.get('sessionid') == sessionid and
|
||||
decoded_data.get('apply_id') == apply_id and
|
||||
decoded_data.get('group_id') == group_id):
|
||||
result_items.append(decoded_data)
|
||||
|
||||
return result_items
|
||||
|
||||
def get_all_sessions(self):
|
||||
"""
|
||||
获取所有会话数据
|
||||
"""
|
||||
sessions = {}
|
||||
for key in self.r.keys('session:*'):
|
||||
sid = key.decode('utf-8').split(':')[1]
|
||||
sessions[sid] = self.get_session(sid)
|
||||
return sessions
|
||||
|
||||
# ---------------- 更新 ----------------
|
||||
def update_session(self, session_id, field, value):
|
||||
"""
|
||||
更新单个字段
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
if self.r.exists(key):
|
||||
self.r.hset(key, field, value)
|
||||
return True
|
||||
return False
|
||||
|
||||
# ---------------- 删除 ----------------
|
||||
def delete_session(self, session_id):
|
||||
"""
|
||||
删除单条会话
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
return self.r.delete(key)
|
||||
|
||||
def delete_all_sessions(self):
|
||||
"""
|
||||
删除所有会话
|
||||
"""
|
||||
keys = self.r.keys('session:*')
|
||||
if keys:
|
||||
return self.r.delete(*keys)
|
||||
return 0
|
||||
|
||||
def delete_duplicate_sessions(self):
|
||||
"""
|
||||
删除重复会话数据,条件:
|
||||
"sessionid"、"user_id"、"group_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除
|
||||
"""
|
||||
seen = set() # 用来记录已出现的唯一组合
|
||||
deleted_count = 0
|
||||
|
||||
for key_bytes in self.r.keys('session:*'):
|
||||
key = key_bytes.decode('utf-8')
|
||||
data = self.r.hgetall(key)
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 获取五个字段的值并解码
|
||||
sessionid = data.get(b'sessionid', b'').decode('utf-8')
|
||||
user_id = data.get(b'id', b'').decode('utf-8') # 对应user_id
|
||||
group_id = data.get(b'group_id', b'').decode('utf-8')
|
||||
messages = data.get(b'messages', b'').decode('utf-8')
|
||||
aimessages = data.get(b'aimessages', b'').decode('utf-8')
|
||||
|
||||
# 用五元组作为唯一标识
|
||||
identifier = (sessionid, user_id, group_id, messages, aimessages)
|
||||
|
||||
if identifier in seen:
|
||||
# 重复,删除该 key
|
||||
self.r.delete(key)
|
||||
deleted_count += 1
|
||||
else:
|
||||
# 第一次出现,加入 seen
|
||||
seen.add(identifier)
|
||||
|
||||
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}")
|
||||
return deleted_count
|
||||
|
||||
def find_user_session(self,sessionid):
|
||||
user_id = sessionid
|
||||
|
||||
result_items = []
|
||||
for key, values in store.get_all_sessions().items():
|
||||
history = {}
|
||||
if user_id == str(values['sessionid']):
|
||||
history["Query"] = values['messages']
|
||||
history["Answer"] = values['aimessages']
|
||||
result_items.append(history)
|
||||
|
||||
if len(result_items) <= 1:
|
||||
result_items = []
|
||||
return (result_items)
|
||||
|
||||
def find_user_apply_group(self, sessionid, apply_id, group_id):
|
||||
"""
|
||||
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据
|
||||
"""
|
||||
result_items = []
|
||||
|
||||
# 遍历所有会话数据
|
||||
for key_bytes in self.r.keys('session:*'):
|
||||
key = key_bytes.decode('utf-8')
|
||||
data = self.r.hgetall(key)
|
||||
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 解码数据
|
||||
decoded_data = {k.decode('utf-8'): v.decode('utf-8') for k, v in data.items()}
|
||||
|
||||
|
||||
# 检查三个条件是否都匹配
|
||||
if (decoded_data.get('sessionid') == sessionid and
|
||||
decoded_data.get('apply_id') == apply_id and
|
||||
decoded_data.get('group_id') == group_id):
|
||||
history = {
|
||||
"Query": decoded_data.get('messages'),
|
||||
"Answer": decoded_data.get('aimessages')
|
||||
}
|
||||
|
||||
|
||||
result_items.append(history)
|
||||
|
||||
# 如果结果少于等于1条,返回空列表
|
||||
if len(result_items) <= 1:
|
||||
result_items = []
|
||||
|
||||
return result_items
|
||||
|
||||
store = RedisSessionStore(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
||||
session_id=str(uuid.uuid4())
|
||||
)
|
||||
59
api/app/core/memory/agent/utils/type_classifier.py
Normal file
59
api/app/core/memory/agent/utils/type_classifier.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Type classification utility for distinguishing read/write operations.
|
||||
"""
|
||||
from jinja2 import Template
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_prompt_rendering
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.messages_tool import read_template_file
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class DistinguishTypeResponse(BaseModel):
|
||||
"""Response model for type classification"""
|
||||
type: str
|
||||
|
||||
|
||||
async def status_typle(messages: str) -> dict:
|
||||
"""
|
||||
Classify message type as read or write operation.
|
||||
|
||||
Args:
|
||||
messages: User message to classify
|
||||
|
||||
Returns:
|
||||
dict: Contains 'type' field with classification result
|
||||
"""
|
||||
try:
|
||||
file_path = PROJECT_ROOT_ + '/agent/utils/prompt/distinguish_types_prompt.jinja2'
|
||||
template_content = await read_template_file(file_path)
|
||||
template = Template(template_content)
|
||||
system_prompt = template.render(user_query=messages)
|
||||
log_prompt_rendering("status_typle", system_prompt)
|
||||
except Exception as e:
|
||||
logger.error(f"Template rendering failed for status_typle: {e}", exc_info=True)
|
||||
return {
|
||||
"type": "error",
|
||||
"message": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=DistinguishTypeResponse
|
||||
)
|
||||
return structured.model_dump()
|
||||
except Exception as e:
|
||||
logger.error(f"LLM call failed for status_typle: {e}", exc_info=True)
|
||||
return {
|
||||
"type": "error",
|
||||
"message": f"LLM call failed: {str(e)}"
|
||||
}
|
||||
76
api/app/core/memory/agent/utils/verify_tool.py
Normal file
76
api/app/core/memory/agent/utils/verify_tool.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from typing import TypedDict, Annotated, List, Any
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langgraph.constants import START, END
|
||||
from langgraph.graph import StateGraph, add_messages
|
||||
import asyncio
|
||||
import json
|
||||
from dotenv import load_dotenv, find_dotenv
|
||||
import os
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from langchain_core.messages import HumanMessage
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from app.core.memory.agent.utils.messages_tool import _to_openai_messages
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
load_dotenv(find_dotenv())
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
def keep_last(_, right):
|
||||
return right
|
||||
class State(TypedDict):
|
||||
user_input: Annotated[dict, keep_last]
|
||||
messages: Annotated[List[AnyMessage], add_messages]
|
||||
agent1_response: str
|
||||
agent2_response: str
|
||||
agent3_response: str
|
||||
final_response: str
|
||||
status: Annotated[str, keep_last]
|
||||
|
||||
|
||||
class VerifyTool:
|
||||
def __init__(self, system_prompt: str="", verify_data: Any=None):
|
||||
self.system_prompt = system_prompt
|
||||
if isinstance(verify_data, str):
|
||||
self.verify_data = verify_data
|
||||
else:
|
||||
try:
|
||||
self.verify_data = json.dumps(verify_data, ensure_ascii=False)
|
||||
except Exception:
|
||||
self.verify_data = str(verify_data)
|
||||
|
||||
async def model_1(self, state: State) -> State:
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
response_content = await llm_client.chat(
|
||||
messages=[{"role": "system", "content": self.system_prompt}] + _to_openai_messages(state["messages"])
|
||||
)
|
||||
return {
|
||||
"agent1_response": response_content,
|
||||
"status": "processed",
|
||||
}
|
||||
|
||||
|
||||
def get_graph(self):
|
||||
graph = StateGraph(State)
|
||||
graph.add_node("model_1", self.model_1)
|
||||
|
||||
graph.add_edge(START, "model_1")
|
||||
graph.add_edge("model_1", END)
|
||||
|
||||
compiled_graph = graph.compile()
|
||||
return compiled_graph
|
||||
|
||||
async def verify(self):
|
||||
graph = self.get_graph()
|
||||
initial_state = {
|
||||
"user_input": self.verify_data,
|
||||
"messages": [HumanMessage(content=self.verify_data)],
|
||||
"final_response": "",
|
||||
"status": ""
|
||||
}
|
||||
final_state = await graph.ainvoke(initial_state)
|
||||
# return final_state["final_response"]
|
||||
return final_state["agent1_response"]
|
||||
|
||||
49
api/app/core/memory/agent/utils/write_to_database.py
Normal file
49
api/app/core/memory/agent/utils/write_to_database.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from sqlalchemy.orm import Session
|
||||
import logging
|
||||
import json
|
||||
|
||||
from app.db import get_db
|
||||
from app.models.retrieval_info import RetrievalInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def write_to_database(host_id: uuid.UUID, data: Any) -> str:
|
||||
"""
|
||||
将数据写入数据库
|
||||
:param host_id: 宿主 ID
|
||||
:param data: 要写入的数据
|
||||
:return: 写入数据库的结果
|
||||
"""
|
||||
# 从数据库会话中获取会话
|
||||
db: Session = next(get_db())
|
||||
try:
|
||||
if isinstance(data, (dict, list)):
|
||||
serialized = json.dumps(data, ensure_ascii=False)
|
||||
elif isinstance(data, str):
|
||||
serialized = data
|
||||
else:
|
||||
serialized = str(data)
|
||||
|
||||
new_retrieval_info = RetrievalInfo(
|
||||
# host_id=host_id,
|
||||
host_id=uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122"),
|
||||
retrieve_info=serialized,
|
||||
created_at=datetime.now()
|
||||
)
|
||||
db.add(new_retrieval_info)
|
||||
db.commit()
|
||||
logger.info(f"success to write data to database, host_id: {host_id}, retrieve_info: {serialized}")
|
||||
return "success to write data to database"
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"failed to write data to database, host_id: {host_id}, retrieve_info: {data}, error: {e}")
|
||||
raise e
|
||||
finally:
|
||||
try:
|
||||
db.close()
|
||||
except Exception:
|
||||
pass
|
||||
183
api/app/core/memory/agent/utils/write_tools.py
Normal file
183
api/app/core/memory/agent/utils/write_tools.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import asyncio
|
||||
from dotenv import load_dotenv
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||
|
||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
# 使用新的模块化架构
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
|
||||
embedding_generation_all,
|
||||
)
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
# 导入配置模块(而不是直接导入变量)
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.log.logging_utils import log_time
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import Memory_summary_generation
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
load_dotenv()
|
||||
|
||||
|
||||
async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id: str = "wyl20251027", config_id: str = None) -> None:
|
||||
"""
|
||||
执行完整的知识提取流水线(使用新的 ExtractionOrchestrator)
|
||||
|
||||
Args:
|
||||
content: 对话内容
|
||||
user_id: 用户ID
|
||||
apply_id: 应用ID
|
||||
group_id: 组ID
|
||||
ref_id: 参考ID,默认为 "wyl20251027"
|
||||
config_id: 配置ID,用于标记数据处理配置
|
||||
"""
|
||||
logger.info("=== MemSci Knowledge Extraction Pipeline ===")
|
||||
logger.info(f"Using model: {config_defs.SELECTED_LLM_NAME}")
|
||||
logger.info(f"Using LLM ID: {config_defs.SELECTED_LLM_ID}")
|
||||
logger.info(f"Using chunker strategy: {config_defs.SELECTED_CHUNKER_STRATEGY}")
|
||||
logger.info(f"Using group ID: {config_defs.SELECTED_GROUP_ID}")
|
||||
logger.info(f"Using embedding ID: {config_defs.SELECTED_EMBEDDING_ID}")
|
||||
logger.info(f"Config ID: {config_id if config_id else 'None'}")
|
||||
logger.info(f"LANGFUSE_ENABLED: {config_defs.LANGFUSE_ENABLED}")
|
||||
logger.info(f"AGENTA_ENABLED: {config_defs.AGENTA_ENABLED}")
|
||||
|
||||
# Initialize timing log
|
||||
log_file = "logs/time.log"
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"\n=== Pipeline Run Started: {timestamp} ===\n")
|
||||
|
||||
pipeline_start = time.time()
|
||||
|
||||
# 初始化客户端
|
||||
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
|
||||
# 获取 embedder 配置
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
|
||||
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
|
||||
neo4j_connector = Neo4jConnector()
|
||||
|
||||
# Step 1: 加载和分块数据
|
||||
step_start = time.time()
|
||||
chunked_dialogs = await get_chunked_dialogs(
|
||||
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
content=content,
|
||||
ref_id=ref_id,
|
||||
config_id=config_id,
|
||||
)
|
||||
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
||||
|
||||
# Step 2: 初始化并运行 ExtractionOrchestrator
|
||||
step_start = time.time()
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
config = get_pipeline_config()
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=neo4j_connector,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# 运行完整的提取流水线
|
||||
# orchestrator.run returns a flat tuple of 7 values after deduplication
|
||||
(
|
||||
all_dialogue_nodes,
|
||||
all_chunk_nodes,
|
||||
all_statement_nodes,
|
||||
all_entity_nodes,
|
||||
all_statement_chunk_edges,
|
||||
all_statement_entity_edges,
|
||||
all_entity_entity_edges,
|
||||
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
||||
|
||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||
|
||||
# Step 8: Save all data to Neo4j database using graph models
|
||||
step_start = time.time()
|
||||
# 运行索引创建
|
||||
from app.repositories.neo4j.create_indexes import create_fulltext_indexes
|
||||
try:
|
||||
await create_fulltext_indexes()
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
||||
|
||||
try:
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=all_dialogue_nodes,
|
||||
chunk_nodes=all_chunk_nodes,
|
||||
statement_nodes=all_statement_nodes,
|
||||
entity_nodes=all_entity_nodes,
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
connector=neo4j_connector
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
else:
|
||||
logger.warning("Failed to save some data to Neo4j")
|
||||
finally:
|
||||
await neo4j_connector.close()
|
||||
|
||||
log_time("Neo4j Database Save", time.time() - step_start, log_file)
|
||||
|
||||
# Step 9: Generate Memory summaries and save to local vector DB and Neo4j
|
||||
step_start = time.time()
|
||||
try:
|
||||
summaries = await Memory_summary_generation(
|
||||
chunked_dialogs, llm_client=llm_client, embedding_id=config_defs.SELECTED_EMBEDDING_ID
|
||||
)
|
||||
|
||||
# Save memory summaries to Neo4j as nodes
|
||||
try:
|
||||
ms_connector = Neo4jConnector()
|
||||
await add_memory_summary_nodes(summaries, ms_connector)
|
||||
# Link summaries to statements via chunks for summary→entity queries
|
||||
await add_memory_summary_statement_edges(summaries, ms_connector)
|
||||
finally:
|
||||
try:
|
||||
await ms_connector.close()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Memory summary step failed: {e}", exc_info=True)
|
||||
finally:
|
||||
log_time("Memory Summary (Local Vector DB & Neo4j)", time.time() - step_start, log_file)
|
||||
|
||||
|
||||
|
||||
# Log total pipeline time
|
||||
total_time = time.time() - pipeline_start
|
||||
log_time("TOTAL PIPELINE TIME", total_time, log_file)
|
||||
|
||||
# Add completion marker to log
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
|
||||
|
||||
logger.info("=== Pipeline Complete ===")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
logger.info(f"Timing details saved to: {log_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
content = "你好,我是张三,是张曼婷的新朋友。请问张曼婷喜欢什么?"
|
||||
asyncio.run(write(content, ref_id="wyl20251027"))
|
||||
19
api/app/core/memory/llm_tools/__init__.py
Normal file
19
api/app/core/memory/llm_tools/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
LLM 工具模块
|
||||
|
||||
提供 LLM 和 Embedder 客户端的抽象基类和具体实现。
|
||||
"""
|
||||
|
||||
from app.core.memory.llm_tools.llm_client import LLMClient
|
||||
from app.core.memory.llm_tools.embedder_client import EmbedderClient
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.llm_tools.chunker_client import ChunkerClient
|
||||
|
||||
__all__ = [
|
||||
"LLMClient",
|
||||
"EmbedderClient",
|
||||
"OpenAIClient",
|
||||
"OpenAIEmbedderClient",
|
||||
"ChunkerClient",
|
||||
]
|
||||
330
api/app/core/memory/llm_tools/chunker_client.py
Normal file
330
api/app/core/memory/llm_tools/chunker_client.py
Normal file
@@ -0,0 +1,330 @@
|
||||
from typing import Any, List
|
||||
import re
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
# Fix tokenizer parallelism warning
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
from chonkie import (
|
||||
SemanticChunker,
|
||||
RecursiveChunker,
|
||||
RecursiveRules,
|
||||
LateChunker,
|
||||
NeuralChunker,
|
||||
SentenceChunker,
|
||||
TokenChunker,
|
||||
)
|
||||
|
||||
from app.core.memory.models.config_models import ChunkerConfig
|
||||
from app.core.memory.models.message_models import DialogData, Chunk
|
||||
try:
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
except Exception:
|
||||
# 在测试或无可用依赖(如 langfuse)环境下,允许惰性导入
|
||||
OpenAIClient = Any
|
||||
|
||||
|
||||
class LLMChunker:
|
||||
"""基于LLM的智能分块策略"""
|
||||
def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000):
|
||||
self.llm_client = llm_client
|
||||
self.chunk_size = chunk_size
|
||||
|
||||
async def __call__(self, text: str) -> List[Any]:
|
||||
# 使用LLM分析文本结构并进行智能分块
|
||||
prompt = f"""
|
||||
请将以下文本分割成语义连贯的段落。每个段落应该围绕一个主题,长度大约在{self.chunk_size}字符左右。
|
||||
请以JSON格式返回结果,包含chunks数组,每个chunk有text字段。
|
||||
|
||||
文本内容:
|
||||
{text[:5000]}
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个专业的文本分析助手,擅长将长文本分割成语义连贯的段落。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
try:
|
||||
# 使用异步的 achat 方法
|
||||
if hasattr(self.llm_client, 'achat'):
|
||||
response = await self.llm_client.achat(messages)
|
||||
else:
|
||||
# 如果没有异步方法,使用同步方法并转换为异步
|
||||
response = await asyncio.to_thread(self.llm_client.chat, messages)
|
||||
|
||||
# 检查响应格式并提取内容
|
||||
if hasattr(response, 'choices') and len(response.choices) > 0:
|
||||
content = response.choices[0].message.content
|
||||
elif hasattr(response, 'content'):
|
||||
content = response.content
|
||||
else:
|
||||
content = str(response)
|
||||
|
||||
# 解析LLM响应
|
||||
if "```json" in content:
|
||||
json_str = content.split("```json")[1].split("```")[0].strip()
|
||||
elif "```" in content:
|
||||
json_str = content.split("```")[1].split("```")[0].strip()
|
||||
else:
|
||||
json_str = content
|
||||
|
||||
result = json.loads(json_str)
|
||||
|
||||
class SimpleChunk:
|
||||
def __init__(self, text, index):
|
||||
self.text = text
|
||||
self.start_index = index * 100 # 近似位置
|
||||
self.end_index = (index + 1) * 100
|
||||
|
||||
return [SimpleChunk(chunk["text"], i) for i, chunk in enumerate(result.get("chunks", []))]
|
||||
|
||||
except Exception as e:
|
||||
print(f"LLM分块失败: {e}")
|
||||
# 失败时返回空列表,外层会处理回退方案
|
||||
return []
|
||||
|
||||
|
||||
class HybridChunker:
|
||||
"""混合分块策略:先按结构分块,再按语义合并"""
|
||||
def __init__(self, semantic_threshold: float = 0.8, base_chunk_size: int = 300):
|
||||
self.semantic_threshold = semantic_threshold
|
||||
self.base_chunk_size = base_chunk_size
|
||||
self.base_chunker = TokenChunker(tokenizer="character", chunk_size=base_chunk_size)
|
||||
self.semantic_chunker = SemanticChunker(threshold=semantic_threshold)
|
||||
|
||||
def __call__(self, text: str) -> List[Any]:
|
||||
# 先用基础分块
|
||||
base_chunks = self.base_chunker(text)
|
||||
|
||||
# 如果文本不长,直接返回基础分块
|
||||
if len(base_chunks) <= 3:
|
||||
return base_chunks
|
||||
|
||||
# 对基础分块进行语义合并
|
||||
combined_text = " ".join([chunk.text for chunk in base_chunks])
|
||||
return self.semantic_chunker(combined_text)
|
||||
|
||||
|
||||
class ChunkerClient:
|
||||
def __init__(self, chunker_config: ChunkerConfig, llm_client: OpenAIClient = None):
|
||||
self.chunker_config = chunker_config
|
||||
self.embedding_model = chunker_config.embedding_model
|
||||
self.chunk_size = chunker_config.chunk_size
|
||||
self.threshold = chunker_config.threshold
|
||||
self.language = chunker_config.language
|
||||
self.skip_window = chunker_config.skip_window
|
||||
self.min_sentences = chunker_config.min_sentences
|
||||
self.min_characters_per_chunk = chunker_config.min_characters_per_chunk
|
||||
self.llm_client = llm_client
|
||||
|
||||
# 可选参数(从配置中安全获取,提供默认值)
|
||||
self.chunk_overlap = getattr(chunker_config, 'chunk_overlap', 0)
|
||||
self.min_sentences_per_chunk = getattr(chunker_config, 'min_sentences_per_chunk', 1)
|
||||
self.min_characters_per_sentence = getattr(chunker_config, 'min_characters_per_sentence', 12)
|
||||
self.delim = getattr(chunker_config, 'delim', [".", "!", "?", "\n"])
|
||||
self.include_delim = getattr(chunker_config, 'include_delim', "prev")
|
||||
self.tokenizer_or_token_counter = getattr(chunker_config, 'tokenizer_or_token_counter', "character")
|
||||
|
||||
# 初始化具体分块器策略
|
||||
if chunker_config.chunker_strategy == "TokenChunker":
|
||||
self.chunker = TokenChunker(
|
||||
tokenizer=self.tokenizer_or_token_counter,
|
||||
chunk_size=self.chunk_size,
|
||||
chunk_overlap=self.chunk_overlap,
|
||||
)
|
||||
elif chunker_config.chunker_strategy == "SemanticChunker":
|
||||
self.chunker = SemanticChunker(
|
||||
embedding_model=self.embedding_model,
|
||||
threshold=self.threshold,
|
||||
chunk_size=self.chunk_size,
|
||||
min_sentences=self.min_sentences,
|
||||
)
|
||||
elif chunker_config.chunker_strategy == "RecursiveChunker":
|
||||
self.chunker = RecursiveChunker(
|
||||
rules=RecursiveRules(),
|
||||
min_characters_per_chunk=self.min_characters_per_chunk or 50,
|
||||
chunk_size=self.chunk_size,
|
||||
)
|
||||
elif chunker_config.chunker_strategy == "LateChunker":
|
||||
self.chunker = LateChunker(
|
||||
embedding_model=self.embedding_model,
|
||||
chunk_size=self.chunk_size,
|
||||
rules=RecursiveRules(),
|
||||
min_characters_per_chunk=self.min_characters_per_chunk,
|
||||
)
|
||||
elif chunker_config.chunker_strategy == "NeuralChunker":
|
||||
self.chunker = NeuralChunker(
|
||||
model=self.embedding_model,
|
||||
min_characters_per_chunk=self.min_characters_per_chunk,
|
||||
)
|
||||
elif chunker_config.chunker_strategy == "LLMChunker":
|
||||
if not llm_client:
|
||||
raise ValueError("LLMChunker requires an LLM client")
|
||||
self.chunker = LLMChunker(llm_client, self.chunk_size)
|
||||
elif chunker_config.chunker_strategy == "HybridChunker":
|
||||
self.chunker = HybridChunker(
|
||||
semantic_threshold=self.threshold,
|
||||
base_chunk_size=self.chunk_size,
|
||||
)
|
||||
elif chunker_config.chunker_strategy == "SentenceChunker":
|
||||
# 某些 chonkie 版本的 SentenceChunker 不支持 tokenizer_or_token_counter 参数
|
||||
# 为了兼容不同版本,这里仅传递广泛支持的参数
|
||||
self.chunker = SentenceChunker(
|
||||
chunk_size=self.chunk_size,
|
||||
chunk_overlap=self.chunk_overlap,
|
||||
min_sentences_per_chunk=self.min_sentences_per_chunk,
|
||||
min_characters_per_sentence=self.min_characters_per_sentence,
|
||||
delim=self.delim,
|
||||
include_delim=self.include_delim,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown chunker strategy: {chunker_config.chunker_strategy}")
|
||||
|
||||
async def generate_chunks(self, dialogue: DialogData):
|
||||
"""
|
||||
生成分块,支持异步操作
|
||||
"""
|
||||
try:
|
||||
# 预处理文本:确保对话标记格式统一
|
||||
content = dialogue.content
|
||||
content = content.replace('AI:', 'AI:').replace('用户:', '用户:') # 统一冒号
|
||||
content = re.sub(r'(\n\s*)+\n', '\n\n', content) # 合并多个空行
|
||||
|
||||
if hasattr(self.chunker, '__call__') and not asyncio.iscoroutinefunction(self.chunker.__call__):
|
||||
# 同步分块器
|
||||
chunks = self.chunker(content)
|
||||
else:
|
||||
# 异步分块器(如LLMChunker)
|
||||
chunks = await self.chunker(content)
|
||||
|
||||
# 过滤空块和过小的块
|
||||
valid_chunks = []
|
||||
for c in chunks:
|
||||
chunk_text = getattr(c, 'text', str(c)) if not isinstance(c, str) else c
|
||||
if isinstance(chunk_text, str) and len(chunk_text.strip()) >= (self.min_characters_per_chunk or 50):
|
||||
valid_chunks.append(c)
|
||||
|
||||
dialogue.chunks = [
|
||||
Chunk(
|
||||
content=c.text if hasattr(c, 'text') else str(c),
|
||||
metadata={
|
||||
"start_index": getattr(c, "start_index", None),
|
||||
"end_index": getattr(c, "end_index", None),
|
||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||
},
|
||||
)
|
||||
for c in valid_chunks
|
||||
]
|
||||
return dialogue
|
||||
|
||||
except Exception as e:
|
||||
print(f"分块失败: {e}")
|
||||
|
||||
# 改进的后备方案:尝试按对话回合分割
|
||||
try:
|
||||
# 简单的按对话分割
|
||||
dialogue_pattern = r'(AI:|用户:)(.*?)(?=AI:|用户:|$)'
|
||||
matches = re.findall(dialogue_pattern, dialogue.content, re.DOTALL)
|
||||
|
||||
class SimpleChunk:
|
||||
def __init__(self, text, start_index, end_index):
|
||||
self.text = text
|
||||
self.start_index = start_index
|
||||
self.end_index = end_index
|
||||
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
current_start = 0
|
||||
|
||||
for match in matches:
|
||||
speaker, ct = match[0], match[1].strip()
|
||||
turn_text = f"{speaker} {ct}"
|
||||
|
||||
if len(current_chunk) + len(turn_text) > (self.chunk_size or 500):
|
||||
if current_chunk:
|
||||
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
|
||||
current_chunk = turn_text
|
||||
current_start = dialogue.content.find(turn_text, current_start)
|
||||
else:
|
||||
current_chunk += ("\n" + turn_text) if current_chunk else turn_text
|
||||
|
||||
if current_chunk:
|
||||
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
|
||||
|
||||
dialogue.chunks = [
|
||||
Chunk(
|
||||
content=c.text,
|
||||
metadata={
|
||||
"start_index": c.start_index,
|
||||
"end_index": c.end_index,
|
||||
"chunker_strategy": "DialogueTurnFallback",
|
||||
},
|
||||
)
|
||||
for c in chunks
|
||||
]
|
||||
|
||||
except Exception:
|
||||
# 最后的手段:单一大块
|
||||
dialogue.chunks = [Chunk(
|
||||
content=dialogue.content,
|
||||
metadata={"chunker_strategy": "SingleChunkFallback"},
|
||||
)]
|
||||
|
||||
return dialogue
|
||||
|
||||
def evaluate_chunking(self, dialogue: DialogData) -> dict:
|
||||
"""
|
||||
评估分块质量
|
||||
"""
|
||||
if not getattr(dialogue, 'chunks', None):
|
||||
return {}
|
||||
|
||||
chunks = dialogue.chunks
|
||||
total_chars = sum(len(chunk.content) for chunk in chunks)
|
||||
avg_chunk_size = total_chars / len(chunks)
|
||||
|
||||
# 计算各种指标
|
||||
chunk_sizes = [len(chunk.content) for chunk in chunks]
|
||||
|
||||
metrics = {
|
||||
"strategy": self.chunker_config.chunker_strategy,
|
||||
"num_chunks": len(chunks),
|
||||
"total_characters": total_chars,
|
||||
"avg_chunk_size": avg_chunk_size,
|
||||
"min_chunk_size": min(chunk_sizes),
|
||||
"max_chunk_size": max(chunk_sizes),
|
||||
"chunk_size_std": np.std(chunk_sizes) if len(chunk_sizes) > 1 else 0,
|
||||
"coverage_ratio": total_chars / len(dialogue.content) if dialogue.content else 0,
|
||||
}
|
||||
|
||||
return metrics
|
||||
|
||||
def save_chunking_results(self, dialogue: DialogData, output_path: str):
|
||||
"""
|
||||
保存分块结果到文件,文件名包含策略名称
|
||||
"""
|
||||
strategy_name = self.chunker_config.chunker_strategy
|
||||
# 在文件名中添加策略名称
|
||||
base_name, ext = os.path.splitext(output_path)
|
||||
strategy_output_path = f"{base_name}_{strategy_name}{ext}"
|
||||
|
||||
with open(strategy_output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(f"=== Chunking Strategy: {strategy_name} ===\n")
|
||||
f.write(f"Total chunks: {len(dialogue.chunks)}\n")
|
||||
f.write(f"Total characters: {sum(len(chunk.content) for chunk in dialogue.chunks)}\n")
|
||||
f.write("=" * 60 + "\n\n")
|
||||
|
||||
for i, chunk in enumerate(dialogue.chunks):
|
||||
f.write(f"Chunk {i+1}:\n")
|
||||
f.write(f"Size: {len(chunk.content)} characters\n")
|
||||
if hasattr(chunk, 'metadata') and 'start_index' in chunk.metadata:
|
||||
f.write(f"Position: {chunk.metadata.get('start_index')}-{chunk.metadata.get('end_index')}\n")
|
||||
f.write(f"Content: {chunk.content}\n")
|
||||
f.write("-" * 40 + "\n\n")
|
||||
|
||||
print(f"Chunking results saved to: {strategy_output_path}")
|
||||
return strategy_output_path
|
||||
176
api/app/core/memory/llm_tools/embedder_client.py
Normal file
176
api/app/core/memory/llm_tools/embedder_client.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
Embedder 客户端抽象基类
|
||||
|
||||
提供统一的嵌入向量生成接口,支持重试机制和错误处理。
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
import asyncio
|
||||
import logging
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
before_sleep_log,
|
||||
)
|
||||
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbedderClientException(BusinessException):
|
||||
"""Embedder 客户端异常"""
|
||||
def __init__(self, message: str, code: str = BizCode.EMBEDDING_ERROR):
|
||||
super().__init__(message, code=code)
|
||||
|
||||
|
||||
class EmbedderClient(ABC):
|
||||
"""
|
||||
Embedder 客户端抽象基类
|
||||
|
||||
提供统一的嵌入向量生成接口,包括:
|
||||
- 批量文本嵌入(response)
|
||||
- 自动重试机制
|
||||
- 错误处理
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: RedBearModelConfig):
|
||||
"""
|
||||
初始化 Embedder 客户端
|
||||
|
||||
Args:
|
||||
model_config: 模型配置,包含模型名称、提供商、API密钥等信息
|
||||
"""
|
||||
self.config = model_config
|
||||
self.model_name = model_config.model_name
|
||||
self.provider = model_config.provider
|
||||
self.api_key = model_config.api_key
|
||||
self.base_url = model_config.base_url
|
||||
self.max_retries = model_config.max_retries
|
||||
self.timeout = model_config.timeout
|
||||
|
||||
logger.info(
|
||||
f"初始化 Embedder 客户端: provider={self.provider}, "
|
||||
f"model={self.model_name}, max_retries={self.max_retries}"
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def response(
|
||||
self,
|
||||
messages: List[str],
|
||||
**kwargs
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
生成嵌入向量
|
||||
|
||||
Args:
|
||||
messages: 文本列表
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
嵌入向量列表,每个向量是一个浮点数列表
|
||||
|
||||
Raises:
|
||||
EmbedderClientException: 嵌入向量生成失败
|
||||
"""
|
||||
pass
|
||||
|
||||
def _create_retry_decorator(self):
|
||||
"""
|
||||
创建重试装饰器
|
||||
|
||||
Returns:
|
||||
配置好的 tenacity retry 装饰器
|
||||
"""
|
||||
return retry(
|
||||
stop=stop_after_attempt(self.max_retries),
|
||||
wait=wait_exponential(multiplier=1, min=2, max=10),
|
||||
retry=retry_if_exception_type((
|
||||
asyncio.TimeoutError,
|
||||
ConnectionError,
|
||||
Exception, # 可以根据需要细化异常类型
|
||||
)),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
reraise=True,
|
||||
)
|
||||
|
||||
async def response_with_retry(
|
||||
self,
|
||||
messages: List[str],
|
||||
**kwargs
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
带重试机制的嵌入向量生成接口
|
||||
|
||||
Args:
|
||||
messages: 文本列表
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
嵌入向量列表
|
||||
|
||||
Raises:
|
||||
EmbedderClientException: 重试失败后抛出
|
||||
"""
|
||||
retry_decorator = self._create_retry_decorator()
|
||||
|
||||
@retry_decorator
|
||||
async def _response_with_retry():
|
||||
try:
|
||||
return await self.response(messages, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"嵌入向量生成失败: {e}")
|
||||
raise EmbedderClientException(f"嵌入向量生成失败: {e}") from e
|
||||
|
||||
return await _response_with_retry()
|
||||
|
||||
async def embed_single(self, text: str, **kwargs) -> List[float]:
|
||||
"""
|
||||
为单个文本生成嵌入向量
|
||||
|
||||
Args:
|
||||
text: 单个文本
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
嵌入向量(浮点数列表)
|
||||
|
||||
Raises:
|
||||
EmbedderClientException: 嵌入向量生成失败
|
||||
"""
|
||||
embeddings = await self.response_with_retry([text], **kwargs)
|
||||
return embeddings[0] if embeddings else []
|
||||
|
||||
async def embed_batch(
|
||||
self,
|
||||
texts: List[str],
|
||||
batch_size: int = 100,
|
||||
**kwargs
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
批量生成嵌入向量(支持大批量文本)
|
||||
|
||||
Args:
|
||||
texts: 文本列表
|
||||
batch_size: 每批处理的文本数量
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
嵌入向量列表
|
||||
|
||||
Raises:
|
||||
EmbedderClientException: 嵌入向量生成失败
|
||||
"""
|
||||
all_embeddings = []
|
||||
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i:i + batch_size]
|
||||
batch_embeddings = await self.response_with_retry(batch, **kwargs)
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
|
||||
return all_embeddings
|
||||
187
api/app/core/memory/llm_tools/llm_client.py
Normal file
187
api/app/core/memory/llm_tools/llm_client.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
LLM 客户端抽象基类
|
||||
|
||||
提供统一的 LLM 调用接口,支持重试机制和错误处理。
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pydantic import BaseModel
|
||||
import asyncio
|
||||
import logging
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
before_sleep_log,
|
||||
)
|
||||
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMClientException(BusinessException):
|
||||
"""LLM 客户端异常"""
|
||||
def __init__(self, message: str, code: str = BizCode.LLM_ERROR):
|
||||
super().__init__(message, code=code)
|
||||
|
||||
|
||||
class LLMClient(ABC):
|
||||
"""
|
||||
LLM 客户端抽象基类
|
||||
|
||||
提供统一的 LLM 调用接口,包括:
|
||||
- 聊天接口(chat)
|
||||
- 结构化输出接口(response_structured)
|
||||
- 自动重试机制
|
||||
- 错误处理
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: RedBearModelConfig):
|
||||
"""
|
||||
初始化 LLM 客户端
|
||||
|
||||
Args:
|
||||
model_config: 模型配置,包含模型名称、提供商、API密钥等信息
|
||||
"""
|
||||
self.config = model_config
|
||||
self.model_name = self.config.model_name
|
||||
self.provider = self.config.provider
|
||||
self.api_key = self.config.api_key
|
||||
self.base_url = self.config.base_url
|
||||
self.max_retries = self.config.max_retries
|
||||
self.timeout = self.config.timeout
|
||||
|
||||
logger.info(
|
||||
f"初始化 LLM 客户端: provider={self.provider}, "
|
||||
f"model={self.model_name}, max_retries={self.max_retries}"
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any:
|
||||
"""
|
||||
聊天接口
|
||||
|
||||
Args:
|
||||
messages: 消息列表,每个消息包含 role 和 content
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
LLM 响应内容
|
||||
|
||||
Raises:
|
||||
LLMClientException: LLM 调用失败
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def response_structured(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
response_model: type[BaseModel],
|
||||
**kwargs
|
||||
) -> BaseModel:
|
||||
"""
|
||||
结构化输出接口
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
response_model: 期望的响应模型类型(Pydantic BaseModel)
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
解析后的 Pydantic 模型实例
|
||||
|
||||
Raises:
|
||||
LLMClientException: LLM 调用或解析失败
|
||||
"""
|
||||
pass
|
||||
|
||||
def _create_retry_decorator(self):
|
||||
"""
|
||||
创建重试装饰器
|
||||
|
||||
Returns:
|
||||
配置好的 tenacity retry 装饰器
|
||||
"""
|
||||
return retry(
|
||||
stop=stop_after_attempt(self.max_retries),
|
||||
wait=wait_exponential(multiplier=1, min=2, max=10),
|
||||
retry=retry_if_exception_type((
|
||||
asyncio.TimeoutError,
|
||||
ConnectionError,
|
||||
Exception, # 可以根据需要细化异常类型
|
||||
)),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
reraise=True,
|
||||
)
|
||||
|
||||
async def chat_with_retry(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
**kwargs
|
||||
) -> Any:
|
||||
"""
|
||||
带重试机制的聊天接口
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
LLM 响应内容
|
||||
|
||||
Raises:
|
||||
LLMClientException: 重试失败后抛出
|
||||
"""
|
||||
retry_decorator = self._create_retry_decorator()
|
||||
|
||||
@retry_decorator
|
||||
async def _chat_with_retry():
|
||||
try:
|
||||
return await self.chat(messages, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"LLM 调用失败: {e}")
|
||||
raise LLMClientException(f"LLM 调用失败: {e}") from e
|
||||
|
||||
return await _chat_with_retry()
|
||||
|
||||
async def response_structured_with_retry(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
response_model: type[BaseModel],
|
||||
**kwargs
|
||||
) -> BaseModel:
|
||||
"""
|
||||
带重试机制的结构化输出接口
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
response_model: 期望的响应模型类型
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
解析后的 Pydantic 模型实例
|
||||
|
||||
Raises:
|
||||
LLMClientException: 重试失败后抛出
|
||||
"""
|
||||
retry_decorator = self._create_retry_decorator()
|
||||
|
||||
@retry_decorator
|
||||
async def _response_structured_with_retry():
|
||||
try:
|
||||
return await self.response_structured(
|
||||
messages,
|
||||
response_model,
|
||||
**kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"LLM 结构化输出失败: {e}")
|
||||
raise LLMClientException(f"LLM 结构化输出失败: {e}") from e
|
||||
|
||||
return await _response_structured_with_retry()
|
||||
198
api/app/core/memory/llm_tools/openai_client.py
Normal file
198
api/app/core/memory/llm_tools/openai_client.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
OpenAI LLM 客户端实现
|
||||
|
||||
基于 LangChain 和 RedBearLLM 的 OpenAI 客户端实现。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any
|
||||
import json
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.output_parsers import PydanticOutputParser
|
||||
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models.llm import RedBearLLM
|
||||
from app.core.memory.llm_tools.llm_client import LLMClient, LLMClientException
|
||||
from app.core.memory.utils.config.definitions import LANGFUSE_ENABLED
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIClient(LLMClient):
|
||||
"""
|
||||
OpenAI LLM 客户端实现
|
||||
|
||||
基于 LangChain 和 RedBearLLM 的实现,支持:
|
||||
- 聊天接口
|
||||
- 结构化输出
|
||||
- Langfuse 追踪(可选)
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: RedBearModelConfig, type_: str = "chat"):
|
||||
"""
|
||||
初始化 OpenAI 客户端
|
||||
|
||||
Args:
|
||||
model_config: 模型配置
|
||||
type_: 模型类型,"chat" 或 "completion"
|
||||
"""
|
||||
super().__init__(model_config)
|
||||
|
||||
# 初始化 Langfuse 回调处理器(如果启用)
|
||||
self.langfuse_handler = None
|
||||
if LANGFUSE_ENABLED:
|
||||
try:
|
||||
from langfuse.langchain import CallbackHandler
|
||||
self.langfuse_handler = CallbackHandler()
|
||||
logger.info("Langfuse 追踪已启用")
|
||||
except ImportError:
|
||||
logger.warning("Langfuse 未安装,跳过追踪功能")
|
||||
except Exception as e:
|
||||
logger.warning(f"初始化 Langfuse 处理器失败: {e}")
|
||||
|
||||
# 初始化 RedBearLLM 客户端
|
||||
self.client = RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
model_name=self.model_name,
|
||||
provider=self.provider,
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
max_retries=self.max_retries,
|
||||
timeout=self.timeout,
|
||||
),
|
||||
type=type_
|
||||
)
|
||||
|
||||
logger.info(f"OpenAI 客户端初始化完成: type={type_}")
|
||||
|
||||
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any:
|
||||
"""
|
||||
聊天接口实现
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
LLM 响应内容
|
||||
|
||||
Raises:
|
||||
LLMClientException: LLM 调用失败
|
||||
"""
|
||||
try:
|
||||
template = """{messages}"""
|
||||
prompt = ChatPromptTemplate.from_template(template)
|
||||
chain = prompt | self.client
|
||||
|
||||
# 添加 Langfuse 回调(如果可用)
|
||||
config = {}
|
||||
if self.langfuse_handler:
|
||||
config["callbacks"] = [self.langfuse_handler]
|
||||
|
||||
response = await chain.ainvoke({"messages": messages}, config=config)
|
||||
|
||||
logger.debug(f"LLM 响应成功: {len(str(response))} 字符")
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM 调用失败: {e}")
|
||||
raise LLMClientException(f"LLM 调用失败: {e}") from e
|
||||
|
||||
async def response_structured(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
response_model: type[BaseModel],
|
||||
**kwargs
|
||||
) -> BaseModel:
|
||||
"""
|
||||
结构化输出接口实现
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
response_model: 期望的响应模型类型
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
解析后的 Pydantic 模型实例
|
||||
|
||||
Raises:
|
||||
LLMClientException: LLM 调用或解析失败
|
||||
"""
|
||||
try:
|
||||
# 构建问题文本
|
||||
question_text = "\n\n".join([
|
||||
str(m.get("content", "")) for m in messages
|
||||
])
|
||||
|
||||
# 准备配置(包含 Langfuse 回调)
|
||||
config = {}
|
||||
if self.langfuse_handler:
|
||||
config["callbacks"] = [self.langfuse_handler]
|
||||
|
||||
# 方法 1: 使用 PydanticOutputParser
|
||||
if PydanticOutputParser is not None:
|
||||
try:
|
||||
parser = PydanticOutputParser(pydantic_object=response_model)
|
||||
format_instructions = parser.get_format_instructions()
|
||||
prompt = ChatPromptTemplate.from_template(
|
||||
"{question}\n{format_instructions}"
|
||||
)
|
||||
chain = prompt | self.client | parser
|
||||
|
||||
parsed = await chain.ainvoke(
|
||||
{
|
||||
"question": question_text,
|
||||
"format_instructions": format_instructions,
|
||||
},
|
||||
config=config
|
||||
)
|
||||
|
||||
logger.debug(f"使用 PydanticOutputParser 解析成功")
|
||||
return parsed
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"PydanticOutputParser 解析失败,尝试其他方法: {e}"
|
||||
)
|
||||
|
||||
# 方法 2: 使用 LangChain 的 with_structured_output
|
||||
template = """{question}"""
|
||||
prompt = ChatPromptTemplate.from_template(template)
|
||||
|
||||
try:
|
||||
with_so = getattr(self.client, "with_structured_output", None)
|
||||
|
||||
if callable(with_so):
|
||||
structured_chain = prompt | with_so(response_model, strict=True)
|
||||
parsed = await structured_chain.ainvoke(
|
||||
{"question": question_text},
|
||||
config=config
|
||||
)
|
||||
|
||||
# 验证并返回结果
|
||||
try:
|
||||
return response_model.model_validate(parsed)
|
||||
except Exception:
|
||||
# 如果已经是 Pydantic 实例,直接返回
|
||||
if hasattr(parsed, "model_dump"):
|
||||
return parsed
|
||||
# 尝试从 JSON 解析
|
||||
return response_model.model_validate_json(json.dumps(parsed))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"结构化输出失败: {e}")
|
||||
raise LLMClientException(f"结构化输出失败: {e}") from e
|
||||
|
||||
# 如果所有方法都失败,抛出异常
|
||||
raise LLMClientException(
|
||||
"无法生成结构化输出,所有解析方法均失败"
|
||||
)
|
||||
|
||||
except LLMClientException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"结构化输出处理失败: {e}")
|
||||
raise LLMClientException(f"结构化输出处理失败: {e}") from e
|
||||
87
api/app/core/memory/llm_tools/openai_embedder.py
Normal file
87
api/app/core/memory/llm_tools/openai_embedder.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
OpenAI Embedder 客户端实现
|
||||
|
||||
基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
import logging
|
||||
|
||||
from app.core.memory.llm_tools.embedder_client import (
|
||||
EmbedderClient,
|
||||
EmbedderClientException
|
||||
)
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models.embedding import RedBearEmbeddings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIEmbedderClient(EmbedderClient):
|
||||
"""
|
||||
OpenAI Embedder 客户端实现
|
||||
|
||||
基于 LangChain 和 RedBearEmbeddings 的实现,支持:
|
||||
- 批量文本嵌入
|
||||
- 自动重试机制
|
||||
- 错误处理
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: RedBearModelConfig):
|
||||
"""
|
||||
初始化 OpenAI Embedder 客户端
|
||||
|
||||
Args:
|
||||
model_config: 模型配置
|
||||
"""
|
||||
super().__init__(model_config)
|
||||
|
||||
# 初始化 RedBearEmbeddings 模型
|
||||
self.model = RedBearEmbeddings(
|
||||
RedBearModelConfig(
|
||||
model_name=self.model_name,
|
||||
provider=self.provider,
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
max_retries=self.max_retries,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("OpenAI Embedder 客户端初始化完成")
|
||||
|
||||
async def response(
|
||||
self,
|
||||
messages: List[str],
|
||||
**kwargs
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
生成嵌入向量实现
|
||||
|
||||
Args:
|
||||
messages: 文本列表
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
嵌入向量列表
|
||||
|
||||
Raises:
|
||||
EmbedderClientException: 嵌入向量生成失败
|
||||
"""
|
||||
try:
|
||||
# 过滤空文本
|
||||
texts: List[str] = [str(m) for m in messages if m is not None]
|
||||
|
||||
if not texts:
|
||||
logger.warning("输入文本列表为空,返回空结果")
|
||||
return []
|
||||
|
||||
# 生成嵌入向量
|
||||
embeddings = await self.model.aembed_documents(texts)
|
||||
|
||||
logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量")
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"嵌入向量生成失败: {e}")
|
||||
raise EmbedderClientException(f"嵌入向量生成失败: {e}") from e
|
||||
332
api/app/core/memory/main.py
Normal file
332
api/app/core/memory/main.py
Normal file
@@ -0,0 +1,332 @@
|
||||
"""
|
||||
MemSci 记忆系统主入口 - 重构版本
|
||||
|
||||
该模块是重构后的记忆系统主入口,使用新的模块化架构。
|
||||
旧版本入口(app/core/memory/src/main.py)已删除。
|
||||
|
||||
主要功能:
|
||||
1. 协调整个知识提取流水线
|
||||
2. 支持试运行模式和正常运行模式
|
||||
3. 使用重构后的 storage_services 模块
|
||||
4. 提供统一的配置管理和日志记录
|
||||
|
||||
作者:Lance77
|
||||
日期:2025-11-22
|
||||
"""
|
||||
|
||||
# 必须在最开始禁用 LangSmith 追踪,避免速率限制错误
|
||||
import os
|
||||
os.environ["LANGCHAIN_TRACING_V2"] = "false"
|
||||
os.environ["LANGCHAIN_TRACING"] = "false"
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 导入重构后的模块
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.models.message_models import ConversationMessage, ConversationContext, DialogData
|
||||
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||
|
||||
# 导入数据加载函数
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
||||
get_chunked_dialogs_with_preprocessing,
|
||||
get_chunked_dialogs_from_preprocessed,
|
||||
)
|
||||
# 导入配置模块(而不是直接导入变量)
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.logging_config import get_memory_logger, log_time
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
||||
"""
|
||||
记忆系统主流程 - 重构版本
|
||||
|
||||
该函数是重构后的主入口,使用新的模块化架构。
|
||||
|
||||
Args:
|
||||
dialogue_text: 输入的对话文本(可选,用于试运行模式)
|
||||
is_pilot_run: 是否为试运行模式
|
||||
- True: 试运行模式,不保存到 Neo4j
|
||||
- False: 正常运行模式,保存到 Neo4j
|
||||
|
||||
工作流程:
|
||||
1. 初始化客户端和配置
|
||||
2. 加载或准备数据
|
||||
3. 执行知识提取流水线
|
||||
4. 保存结果(正常模式)或输出结果(试运行模式)
|
||||
"""
|
||||
print("=" * 60)
|
||||
print("MemSci 知识提取流水线 - 重构版本")
|
||||
print("=" * 60)
|
||||
print(f"运行模式: {'试运行(不保存到Neo4j)' if is_pilot_run else '正常运行(保存到Neo4j)'}")
|
||||
print("Using chunker strategy:", config_defs.SELECTED_CHUNKER_STRATEGY)
|
||||
print("Using group ID:", config_defs.SELECTED_GROUP_ID)
|
||||
print("Using model ID:", config_defs.SELECTED_LLM_ID)
|
||||
print("Using embedding model ID:", config_defs.SELECTED_EMBEDDING_ID)
|
||||
print("LANGFUSE_ENABLED:", config_defs.LANGFUSE_ENABLED)
|
||||
print("AGENTA_ENABLED:", config_defs.AGENTA_ENABLED)
|
||||
print("=" * 60)
|
||||
|
||||
# 初始化日志
|
||||
log_file = "logs/time.log"
|
||||
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"\n=== Pipeline Run Started: {timestamp} ({'Pilot Run' if is_pilot_run else 'Normal Run'}) ===\n")
|
||||
|
||||
pipeline_start = time.time()
|
||||
|
||||
try:
|
||||
# 步骤 1: 初始化客户端
|
||||
logger.info("Initializing clients...")
|
||||
step_start = time.time()
|
||||
|
||||
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
|
||||
# 获取 embedder 配置并转换为 RedBearModelConfig 对象
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
|
||||
neo4j_connector = Neo4jConnector()
|
||||
|
||||
log_time("Client Initialization", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 2: 加载或准备数据
|
||||
logger.info("Loading data...")
|
||||
logger.info(f"[MAIN] dialogue_text type={type(dialogue_text)}, length={len(dialogue_text) if dialogue_text else 0}, is_pilot_run={is_pilot_run}")
|
||||
logger.info(f"[MAIN] dialogue_text preview: {repr(dialogue_text)[:200] if dialogue_text else 'None'}")
|
||||
logger.info(f"[MAIN] Condition check: dialogue_text={bool(dialogue_text)}, isinstance={isinstance(dialogue_text, str) if dialogue_text else False}, strip={bool(dialogue_text.strip()) if dialogue_text and isinstance(dialogue_text, str) else False}")
|
||||
step_start = time.time()
|
||||
|
||||
if dialogue_text and isinstance(dialogue_text, str) and dialogue_text.strip():
|
||||
# 试运行模式:处理前端传入的对话文本
|
||||
logger.info("[MAIN] ✓ Using frontend dialogue text (pilot run mode)")
|
||||
import re
|
||||
|
||||
# 解析对话文本,支持 "用户:" 和 "AI:" 格式
|
||||
pattern = r"(用户|AI)[::]\s*([^\n]+(?:\n(?!(?:用户|AI)[::])[^\n]*)*?)"
|
||||
matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL)
|
||||
messages = [
|
||||
ConversationMessage(role=r, msg=c.strip())
|
||||
for r, c in matches if c.strip()
|
||||
]
|
||||
|
||||
# 如果没有匹配到格式化的对话,将整个文本作为用户消息
|
||||
if not messages:
|
||||
messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]
|
||||
|
||||
# 创建对话上下文和对话数据
|
||||
context = ConversationContext(msgs=messages)
|
||||
dialog = DialogData(
|
||||
context=context,
|
||||
ref_id="pilot_dialog_1",
|
||||
group_id=config_defs.SELECTED_GROUP_ID,
|
||||
user_id=config_defs.SELECTED_USER_ID,
|
||||
apply_id=config_defs.SELECTED_APPLY_ID,
|
||||
metadata={"source": "pilot_run", "input_type": "frontend_text"}
|
||||
)
|
||||
|
||||
# 对前端传入的对话进行分块处理
|
||||
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
|
||||
data=[dialog],
|
||||
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
|
||||
llm_client=llm_client,
|
||||
)
|
||||
logger.info(f"Processed frontend dialogue text: {len(messages)} messages")
|
||||
else:
|
||||
# 正常运行模式:从 testdata.json 文件加载
|
||||
logger.warning("[MAIN] ✗ Falling back to testdata.json (dialogue_text not provided or empty)")
|
||||
logger.info("Loading data from testdata.json...")
|
||||
test_data_path = os.path.join(
|
||||
os.path.dirname(__file__), "data", "testdata.json"
|
||||
)
|
||||
|
||||
if not os.path.exists(test_data_path):
|
||||
raise FileNotFoundError(f"Test data file not found: {test_data_path}")
|
||||
|
||||
chunked_dialogs = await get_chunked_dialogs_with_preprocessing(
|
||||
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
|
||||
group_id=config_defs.SELECTED_GROUP_ID,
|
||||
user_id=config_defs.SELECTED_USER_ID,
|
||||
apply_id=config_defs.SELECTED_APPLY_ID,
|
||||
indices=config_defs.SELECTED_TEST_DATA_INDICES,
|
||||
input_data_path=test_data_path,
|
||||
llm_client=llm_client,
|
||||
skip_cleaning=True,
|
||||
)
|
||||
logger.info(f"Loaded {len(chunked_dialogs)} dialogues from testdata.json")
|
||||
|
||||
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 3: 初始化流水线编排器
|
||||
logger.info("Initializing extraction orchestrator...")
|
||||
step_start = time.time()
|
||||
|
||||
# 从 runtime.json 加载配置(已经过数据库覆写)
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
config = get_pipeline_config()
|
||||
|
||||
logger.info(f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}")
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=neo4j_connector,
|
||||
config=config,
|
||||
)
|
||||
|
||||
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 4: 执行知识提取流水线
|
||||
logger.info("Running extraction pipeline...")
|
||||
step_start = time.time()
|
||||
|
||||
extraction_result = await orchestrator.run(
|
||||
dialog_data_list=chunked_dialogs,
|
||||
is_pilot_run=is_pilot_run, # 传递试运行模式标志
|
||||
)
|
||||
|
||||
# 解包 extraction_result tuple
|
||||
# extraction_result 是一个包含 7 个元素的 tuple:
|
||||
# (dialogue_nodes, chunk_nodes, statement_nodes, entity_nodes,
|
||||
# statement_chunk_edges, statement_entity_edges, entity_edges)
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_edges,
|
||||
) = extraction_result
|
||||
|
||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 5: 保存结果或输出结果
|
||||
if is_pilot_run:
|
||||
logger.info("Pilot run mode: Skipping Neo4j save")
|
||||
print("\n试运行模式:跳过 Neo4j 保存,流水线处理完成。")
|
||||
print("提取结果已生成,可在相关输出中查看。")
|
||||
else:
|
||||
logger.info("Normal mode: Saving to Neo4j...")
|
||||
step_start = time.time()
|
||||
|
||||
# 创建索引和约束
|
||||
try:
|
||||
from app.repositories.neo4j.create_indexes import (
|
||||
create_fulltext_indexes,
|
||||
create_unique_constraints,
|
||||
)
|
||||
await create_fulltext_indexes()
|
||||
await create_unique_constraints()
|
||||
logger.info("Successfully created indexes and constraints")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating indexes/constraints: {e}")
|
||||
|
||||
# 保存数据到 Neo4j
|
||||
try:
|
||||
from app.repositories.neo4j.graph_saver import (
|
||||
save_dialog_and_statements_to_neo4j,
|
||||
)
|
||||
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=dialogue_nodes,
|
||||
chunk_nodes=chunk_nodes,
|
||||
statement_nodes=statement_nodes,
|
||||
entity_nodes=entity_nodes,
|
||||
statement_chunk_edges=statement_chunk_edges,
|
||||
statement_entity_edges=statement_entity_edges,
|
||||
entity_edges=entity_edges,
|
||||
connector=neo4j_connector,
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
print("\n✓ 成功保存所有数据到 Neo4j")
|
||||
else:
|
||||
logger.warning("Failed to save some data to Neo4j")
|
||||
print("\n⚠ 部分数据保存到 Neo4j 失败")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving to Neo4j: {e}", exc_info=True)
|
||||
print(f"\n✗ 保存到 Neo4j 失败: {e}")
|
||||
|
||||
log_time("Neo4j Database Save", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 6: 生成记忆摘要(可选)
|
||||
try:
|
||||
logger.info("Generating memory summaries...")
|
||||
step_start = time.time()
|
||||
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
Memory_summary_generation,
|
||||
)
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.add_edges import (
|
||||
add_memory_summary_statement_edges,
|
||||
)
|
||||
|
||||
summaries = await Memory_summary_generation(
|
||||
chunked_dialogs, llm_client=llm_client, embedding_id=config_defs.SELECTED_EMBEDDING_ID
|
||||
)
|
||||
|
||||
if not is_pilot_run:
|
||||
# 保存记忆摘要到 Neo4j
|
||||
ms_connector = Neo4jConnector()
|
||||
try:
|
||||
await add_memory_summary_nodes(summaries, ms_connector)
|
||||
await add_memory_summary_statement_edges(summaries, ms_connector)
|
||||
finally:
|
||||
await ms_connector.close()
|
||||
|
||||
log_time("Memory Summary Generation", time.time() - step_start, log_file)
|
||||
except Exception as e:
|
||||
logger.error(f"Memory summary step failed: {e}", exc_info=True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pipeline execution failed: {e}", exc_info=True)
|
||||
print(f"\n✗ 流水线执行失败: {e}")
|
||||
raise
|
||||
finally:
|
||||
# 清理资源
|
||||
try:
|
||||
await neo4j_connector.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 记录总时间
|
||||
total_time = time.time() - pipeline_start
|
||||
log_time("TOTAL PIPELINE TIME", total_time, log_file)
|
||||
|
||||
# 添加完成标记
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
|
||||
|
||||
logger.info("=== Pipeline Complete ===")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
logger.info(f"Timing details saved to: {log_file}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(f"✓ 流水线执行完成")
|
||||
print(f"✓ 总耗时: {total_time:.2f} 秒")
|
||||
print(f"✓ 详细日志: {log_file}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
115
api/app/core/memory/models/__init__.py
Normal file
115
api/app/core/memory/models/__init__.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Data models for the Memory module.
|
||||
|
||||
This package contains all Pydantic models used in the memory system,
|
||||
including models for messages, dialogues, statements, entities, triplets,
|
||||
graph nodes/edges, configurations, and deduplication decisions.
|
||||
"""
|
||||
|
||||
# Base response models
|
||||
from app.core.memory.models.base_response import RobustLLMResponse
|
||||
|
||||
# Configuration models
|
||||
from app.core.memory.models.config_models import (
|
||||
LLMConfig,
|
||||
ChunkerConfig,
|
||||
PruningConfig,
|
||||
TemporalSearchParams,
|
||||
)
|
||||
|
||||
# Deduplication models
|
||||
from app.core.memory.models.dedup_models import (
|
||||
EntityDedupDecision,
|
||||
EntityDisambDecision,
|
||||
)
|
||||
|
||||
# Graph models (nodes and edges)
|
||||
from app.core.memory.models.graph_models import (
|
||||
# Edges
|
||||
Edge,
|
||||
ChunkEdge,
|
||||
ChunkEntityEdge,
|
||||
ChunkDialogEdge,
|
||||
StatementChunkEdge,
|
||||
StatementEntityEdge,
|
||||
EntityEntityEdge,
|
||||
# Nodes
|
||||
Node,
|
||||
DialogueNode,
|
||||
StatementNode,
|
||||
ChunkNode,
|
||||
ExtractedEntityNode,
|
||||
MemorySummaryNode,
|
||||
)
|
||||
|
||||
# Message and dialogue models
|
||||
from app.core.memory.models.message_models import (
|
||||
ConversationMessage,
|
||||
TemporalValidityRange,
|
||||
Statement,
|
||||
ConversationContext,
|
||||
Chunk,
|
||||
DialogData,
|
||||
)
|
||||
|
||||
# Triplet and entity models
|
||||
from app.core.memory.models.triplet_models import (
|
||||
Entity,
|
||||
Triplet,
|
||||
TripletExtractionResponse,
|
||||
)
|
||||
|
||||
# Variable configuration models
|
||||
from app.core.memory.models.variate_config import (
|
||||
StatementExtractionConfig,
|
||||
ForgettingEngineConfig,
|
||||
TripletExtractionConfig,
|
||||
TemporalExtractionConfig,
|
||||
DedupConfig,
|
||||
ExtractionPipelineConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Base response
|
||||
"RobustLLMResponse",
|
||||
# Configuration
|
||||
"LLMConfig",
|
||||
"ChunkerConfig",
|
||||
"PruningConfig",
|
||||
"TemporalSearchParams",
|
||||
# Deduplication
|
||||
"EntityDedupDecision",
|
||||
"EntityDisambDecision",
|
||||
# Graph edges
|
||||
"Edge",
|
||||
"ChunkEdge",
|
||||
"ChunkEntityEdge",
|
||||
"ChunkDialogEdge",
|
||||
"StatementChunkEdge",
|
||||
"StatementEntityEdge",
|
||||
"EntityEntityEdge",
|
||||
# Graph nodes
|
||||
"Node",
|
||||
"DialogueNode",
|
||||
"StatementNode",
|
||||
"ChunkNode",
|
||||
"ExtractedEntityNode",
|
||||
"MemorySummaryNode",
|
||||
# Messages and dialogues
|
||||
"ConversationMessage",
|
||||
"TemporalValidityRange",
|
||||
"Statement",
|
||||
"ConversationContext",
|
||||
"Chunk",
|
||||
"DialogData",
|
||||
# Triplets and entities
|
||||
"Entity",
|
||||
"Triplet",
|
||||
"TripletExtractionResponse",
|
||||
# Variable configuration
|
||||
"StatementExtractionConfig",
|
||||
"ForgettingEngineConfig",
|
||||
"TripletExtractionConfig",
|
||||
"TemporalExtractionConfig",
|
||||
"DedupConfig",
|
||||
"ExtractionPipelineConfig",
|
||||
]
|
||||
59
api/app/core/memory/models/base_response.py
Normal file
59
api/app/core/memory/models/base_response.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Base classes for LLM response models with common validators.
|
||||
|
||||
This module provides reusable base classes for Pydantic models that handle
|
||||
common LLM response patterns and edge cases.
|
||||
|
||||
Classes:
|
||||
RobustLLMResponse: Base class for LLM response models with robust validation
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
|
||||
class RobustLLMResponse(BaseModel):
|
||||
"""Base class for LLM response models with robust validation.
|
||||
|
||||
This base class provides:
|
||||
- Automatic handling of list-wrapped responses (e.g., [{"field": "value"}])
|
||||
- Ignoring extra fields from LLM output
|
||||
- Validation on assignment
|
||||
|
||||
Usage:
|
||||
class MyResponse(RobustLLMResponse):
|
||||
field1: str
|
||||
field2: int
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="ignore", # Allow extra fields to be ignored (more forgiving)
|
||||
validate_assignment=True # Validate on assignment
|
||||
)
|
||||
|
||||
@model_validator(mode='before')
|
||||
@classmethod
|
||||
def handle_list_input(cls, data: Any) -> Any:
|
||||
"""Handle cases where LLM returns a list instead of a dict.
|
||||
|
||||
Some LLMs may wrap the response in a list like [{"field": "value"}].
|
||||
This validator extracts the first item if that happens.
|
||||
|
||||
Args:
|
||||
data: The input data from the LLM
|
||||
|
||||
Returns:
|
||||
The unwrapped data (dict)
|
||||
|
||||
Raises:
|
||||
ValueError: If the input is invalid (empty list, wrong type, etc.)
|
||||
"""
|
||||
if isinstance(data, list):
|
||||
if len(data) == 0:
|
||||
raise ValueError("Received empty list from LLM")
|
||||
# Extract first item from list
|
||||
data = data[0]
|
||||
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Expected dict or list, got {type(data).__name__}")
|
||||
|
||||
return data
|
||||
93
api/app/core/memory/models/config_models.py
Normal file
93
api/app/core/memory/models/config_models.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Configuration models for Memory module components.
|
||||
|
||||
This module contains Pydantic models for configuring various components
|
||||
of the memory system including LLM, chunking, pruning, and search.
|
||||
|
||||
Classes:
|
||||
LLMConfig: Configuration for LLM client
|
||||
ChunkerConfig: Configuration for dialogue chunking
|
||||
PruningConfig: Configuration for semantic pruning
|
||||
TemporalSearchParams: Parameters for temporal search queries
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class LLMConfig(BaseModel):
|
||||
"""Configuration for Large Language Model client.
|
||||
|
||||
Attributes:
|
||||
llm_name: The name of the LLM model to use (e.g., 'gpt-4', 'claude-3')
|
||||
api_base: Optional base URL for the API endpoint
|
||||
max_retries: Maximum number of retries for failed API calls (default: 3)
|
||||
"""
|
||||
llm_name: str = Field(..., description="The name of the LLM model to use.")
|
||||
api_base: Optional[str] = Field(None, description="The base URL for the API endpoint.")
|
||||
max_retries: Optional[int] = Field(3, ge=0, description="The maximum number of retries for API calls.")
|
||||
|
||||
|
||||
class ChunkerConfig(BaseModel):
|
||||
"""Configuration for dialogue chunking strategy.
|
||||
|
||||
Attributes:
|
||||
chunker_strategy: Name of the chunking strategy (e.g., 'RecursiveChunker', 'SemanticChunker')
|
||||
embedding_model: Name of the embedding model to use for semantic chunking
|
||||
chunk_size: Maximum size of each chunk in characters (default: 2048)
|
||||
threshold: Similarity threshold for semantic chunking (0-1, default: 0.8)
|
||||
language: Language of the text (default: 'zh' for Chinese)
|
||||
skip_window: Window size for skip-and-merge strategy (default: 0)
|
||||
min_sentences: Minimum number of sentences per chunk (default: 1)
|
||||
min_characters_per_chunk: Minimum characters per chunk (default: 24)
|
||||
"""
|
||||
chunker_strategy: str = Field(..., description="The name of the chunker strategy to use.")
|
||||
embedding_model: str = Field(..., description="The name of the embedding model to use.")
|
||||
chunk_size: Optional[int] = Field(2048, ge=0, description="The size of each chunk.")
|
||||
threshold: Optional[float] = Field(0.8, ge=0, le=1, description="The threshold for similarity.")
|
||||
language: Optional[str] = Field("zh", description="The language of the text.")
|
||||
skip_window: Optional[int] = Field(0, ge=0, description="The window for skip-and-merge.")
|
||||
min_sentences: Optional[int] = Field(1, ge=0, description="The minimum number of sentences in each chunk.")
|
||||
min_characters_per_chunk: Optional[int] = Field(24, ge=0, description="The minimum number of characters in each chunk.")
|
||||
|
||||
|
||||
class PruningConfig(BaseModel):
|
||||
"""Configuration for semantic pruning of dialogue content.
|
||||
|
||||
Attributes:
|
||||
pruning_switch: Enable or disable semantic pruning
|
||||
pruning_scene: Scene type for pruning ('education', 'online_service', 'outbound')
|
||||
pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal)
|
||||
"""
|
||||
pruning_switch: bool = Field(False, description="Enable semantic pruning when True.")
|
||||
pruning_scene: str = Field(
|
||||
"education",
|
||||
description="Scene for pruning: one of 'education', 'online_service', 'outbound'.",
|
||||
)
|
||||
pruning_threshold: float = Field(
|
||||
0.5, ge=0.0, le=0.9,
|
||||
description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).")
|
||||
|
||||
|
||||
class TemporalSearchParams(BaseModel):
|
||||
"""Parameters for temporal search queries in the knowledge graph.
|
||||
|
||||
Attributes:
|
||||
group_id: Group ID to filter search results (default: 'test')
|
||||
apply_id: Application ID to filter search results
|
||||
user_id: User ID to filter search results
|
||||
start_date: Start date for temporal filtering (format: 'YYYY-MM-DD')
|
||||
end_date: End date for temporal filtering (format: 'YYYY-MM-DD')
|
||||
valid_date: Date when memory should be valid (format: 'YYYY-MM-DD')
|
||||
invalid_date: Date when memory should be invalid (format: 'YYYY-MM-DD')
|
||||
limit: Maximum number of results to return (default: 3)
|
||||
"""
|
||||
group_id: Optional[str] = Field("test", description="The group ID to filter the search.")
|
||||
apply_id: Optional[str] = Field(None, description="The apply ID to filter the search.")
|
||||
user_id: Optional[str] = Field(None, description="The user ID to filter the search.")
|
||||
start_date: Optional[str] = Field(None, description="The start date for the search.")
|
||||
end_date: Optional[str] = Field(None, description="The end date for the search.")
|
||||
valid_date: Optional[str] = Field(None, description="The valid date for the search.")
|
||||
invalid_date: Optional[str] = Field(None, description="The invalid date for the search.")
|
||||
limit: int = Field(default=3, description="The maximum number of results to return.")
|
||||
|
||||
|
||||
52
api/app/core/memory/models/dedup_models.py
Normal file
52
api/app/core/memory/models/dedup_models.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Models for entity deduplication and disambiguation decisions.
|
||||
|
||||
This module contains Pydantic models for structured LLM responses
|
||||
during entity deduplication and disambiguation processes.
|
||||
|
||||
Classes:
|
||||
EntityDedupDecision: Decision model for entity deduplication
|
||||
EntityDisambDecision: Decision model for entity disambiguation
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class EntityDedupDecision(BaseModel):
|
||||
"""Structured decision returned by LLM for entity deduplication.
|
||||
|
||||
This model represents the LLM's decision on whether two entities
|
||||
refer to the same real-world entity and should be merged.
|
||||
|
||||
Attributes:
|
||||
same_entity: Whether the two entities refer to the same real-world entity
|
||||
confidence: Model confidence in the decision (0.0 to 1.0)
|
||||
canonical_idx: Index of the canonical entity to keep when merging (0 or 1, -1 if not applicable)
|
||||
reason: Brief rationale for the decision (1-3 sentences, kept for audit)
|
||||
"""
|
||||
same_entity: bool = Field(..., description="Two entities refer to the same entity")
|
||||
confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence of the decision")
|
||||
canonical_idx: int = Field(..., description="Index of canonical entity among the pair: 0 or 1; -1 if not applicable")
|
||||
reason: str = Field(..., description="Short rationale, 1-3 sentences")
|
||||
|
||||
|
||||
class EntityDisambDecision(BaseModel):
|
||||
"""Structured disambiguation decision for same-name but different-type entities.
|
||||
|
||||
This model represents the LLM's decision on whether two entities with
|
||||
the same name but different types should be merged or kept separate.
|
||||
|
||||
Attributes:
|
||||
should_merge: Whether the two entities should be merged despite type difference
|
||||
confidence: Model confidence in the decision (0.0 to 1.0)
|
||||
canonical_idx: Index of the canonical entity to keep when merging (0 or 1, -1 if not applicable)
|
||||
block_pair: If True, this pair should be blocked from fuzzy/auto merges
|
||||
suggested_type: Optional unified type to apply when should_merge is True
|
||||
reason: Brief rationale for audit and analysis (1-3 sentences)
|
||||
"""
|
||||
should_merge: bool = Field(..., description="Merge the pair despite type difference")
|
||||
confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence of the decision")
|
||||
canonical_idx: int = Field(..., description="Index of canonical entity among the pair: 0 or 1; -1 if not applicable")
|
||||
block_pair: bool = Field(False, description="Block this pair from fuzzy or heuristic merges")
|
||||
suggested_type: Optional[str] = Field(None, description="Unified entity type when merging")
|
||||
reason: str = Field(..., description="Short rationale, 1-3 sentences")
|
||||
304
api/app/core/memory/models/graph_models.py
Normal file
304
api/app/core/memory/models/graph_models.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""Graph models for Neo4j knowledge graph nodes and edges.
|
||||
|
||||
This module contains Pydantic models representing nodes and edges
|
||||
in the Neo4j knowledge graph, including dialogues, statements,
|
||||
chunks, entities, and their relationships.
|
||||
|
||||
Classes:
|
||||
Edge: Base class for all graph edges
|
||||
ChunkEdge: Edge connecting chunks
|
||||
ChunkEntityEdge: Edge connecting chunks to entities
|
||||
ChunkDialogEdge: Edge connecting chunks to dialogues
|
||||
StatementChunkEdge: Edge connecting statements to chunks
|
||||
StatementEntityEdge: Edge connecting statements to entities
|
||||
EntityEntityEdge: Edge connecting related entities
|
||||
Node: Base class for all graph nodes
|
||||
DialogueNode: Node representing a dialogue
|
||||
StatementNode: Node representing a statement
|
||||
ChunkNode: Node representing a conversation chunk
|
||||
ExtractedEntityNode: Node representing an extracted entity
|
||||
MemorySummaryNode: Node representing a memory summary
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
import re
|
||||
|
||||
from app.core.memory.utils.data.ontology import TemporalInfo
|
||||
|
||||
|
||||
def parse_historical_datetime(v):
|
||||
"""支持任意年份的日期解析,包括历史日期(如公元755年)
|
||||
|
||||
Python datetime 支持公元1年到9999年的日期
|
||||
此函数手动解析 ISO 8601 格式的日期字符串,支持1-4位年份
|
||||
|
||||
Args:
|
||||
v: 日期值(可以是 None、datetime 对象或字符串)
|
||||
|
||||
Returns:
|
||||
datetime 对象或 None
|
||||
"""
|
||||
if v is None or isinstance(v, datetime):
|
||||
return v
|
||||
|
||||
if isinstance(v, str):
|
||||
# 匹配 ISO 8601 格式:YYYY-MM-DD 或 YYYY-MM-DDTHH:MM:SS[.ffffff][Z|±HH:MM]
|
||||
# 支持1-4位年份
|
||||
pattern = r'^(\d{1,4})-(\d{2})-(\d{2})(?:T(\d{2}):(\d{2}):(\d{2})(?:\.(\d+))?(?:Z|([+-]\d{2}:\d{2}))?)?'
|
||||
match = re.match(pattern, v)
|
||||
|
||||
if match:
|
||||
try:
|
||||
year = int(match.group(1))
|
||||
month = int(match.group(2))
|
||||
day = int(match.group(3))
|
||||
hour = int(match.group(4)) if match.group(4) else 0
|
||||
minute = int(match.group(5)) if match.group(5) else 0
|
||||
second = int(match.group(6)) if match.group(6) else 0
|
||||
microsecond = 0
|
||||
|
||||
# 处理微秒
|
||||
if match.group(7):
|
||||
# 补齐或截断到6位
|
||||
us_str = match.group(7).ljust(6, '0')[:6]
|
||||
microsecond = int(us_str)
|
||||
|
||||
# 处理时区
|
||||
tzinfo = None
|
||||
if 'Z' in v or match.group(8):
|
||||
tzinfo = timezone.utc
|
||||
|
||||
# 创建 datetime 对象
|
||||
return datetime(year, month, day, hour, minute, second, microsecond, tzinfo=tzinfo)
|
||||
|
||||
except (ValueError, OverflowError):
|
||||
# 日期值无效(如月份13、日期32等)
|
||||
return None
|
||||
|
||||
# 如果不匹配模式,尝试使用 fromisoformat(用于标准格式)
|
||||
try:
|
||||
return datetime.fromisoformat(v.replace('Z', '+00:00'))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return v
|
||||
|
||||
|
||||
class Edge(BaseModel):
|
||||
"""Base class for all graph edges in the knowledge graph.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the edge
|
||||
source: ID of the source node
|
||||
target: ID of the target node
|
||||
group_id: Group ID for multi-tenancy
|
||||
user_id: User ID for user-specific data
|
||||
apply_id: Application ID for application-specific data
|
||||
run_id: Unique identifier for the pipeline run that created this edge
|
||||
created_at: Timestamp when the edge was created (system perspective)
|
||||
expired_at: Optional timestamp when the edge expires (system perspective)
|
||||
"""
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.")
|
||||
source: str = Field(..., description="The ID of the source node.")
|
||||
target: str = Field(..., description="The ID of the target node.")
|
||||
group_id: str = Field(..., description="The group ID of the edge.")
|
||||
user_id: str = Field(..., description="The user ID of the edge.")
|
||||
apply_id: str = Field(..., description="The apply ID of the edge.")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
|
||||
|
||||
|
||||
class ChunkEdge(Edge):
|
||||
"""Edge connecting two chunks in sequence."""
|
||||
pass
|
||||
|
||||
|
||||
class ChunkEntityEdge(Edge):
|
||||
"""Edge connecting a chunk to an entity mentioned in it."""
|
||||
pass
|
||||
|
||||
|
||||
class ChunkDialogEdge(Edge):
|
||||
"""Edge connecting a chunk to its parent dialog.
|
||||
|
||||
Attributes:
|
||||
sequence_number: Order of this chunk within the dialog
|
||||
"""
|
||||
sequence_number: int = Field(..., description="Order of this chunk within the dialog")
|
||||
|
||||
|
||||
class StatementChunkEdge(Edge):
|
||||
"""Edge connecting a statement to its parent chunk."""
|
||||
pass
|
||||
|
||||
|
||||
class StatementEntityEdge(Edge):
|
||||
"""Edge connecting a statement to entities extracted from it.
|
||||
|
||||
Attributes:
|
||||
connect_strength: Classification of connection strength ('Strong' or 'Weak')
|
||||
"""
|
||||
connect_strength: str = Field(..., description="Strong VS Weak about this statement-entity edge")
|
||||
|
||||
|
||||
class EntityEntityEdge(Edge):
|
||||
"""Edge connecting related entities (from triplet relationships).
|
||||
|
||||
Attributes:
|
||||
relation_type: Type of relationship as defined in ontology
|
||||
relation_value: Optional value of the relation
|
||||
statement: The statement text where this relationship was found
|
||||
source_statement_id: ID of the statement where this relationship was extracted
|
||||
valid_at: Optional start date of temporal validity
|
||||
invalid_at: Optional end date of temporal validity
|
||||
"""
|
||||
relation_type: str = Field(..., description="Relation type as defined in ontology")
|
||||
relation_value: Optional[str] = Field(None, description="Value of the relation")
|
||||
statement: str = Field(..., description='The statement of the edge.')
|
||||
source_statement_id: str = Field(..., description="Statement where this relationship was extracted")
|
||||
valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
|
||||
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
|
||||
|
||||
@field_validator('valid_at', 'invalid_at', mode='before')
|
||||
@classmethod
|
||||
def validate_datetime(cls, v):
|
||||
"""使用通用的历史日期解析函数"""
|
||||
return parse_historical_datetime(v)
|
||||
|
||||
|
||||
class Node(BaseModel):
|
||||
"""Base class for all graph nodes in the knowledge graph.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the node
|
||||
name: Name of the node
|
||||
group_id: Group ID for multi-tenancy
|
||||
user_id: User ID for user-specific data
|
||||
apply_id: Application ID for application-specific data
|
||||
run_id: Unique identifier for the pipeline run that created this node
|
||||
created_at: Timestamp when the node was created (system perspective)
|
||||
expired_at: Optional timestamp when the node expires (system perspective)
|
||||
"""
|
||||
id: str = Field(..., description="The unique identifier for the node.")
|
||||
name: str = Field(..., description="The name of the node.")
|
||||
group_id: str = Field(..., description="The group ID of the node.")
|
||||
user_id: str = Field(..., description="The user ID of the edge.")
|
||||
apply_id: str = Field(..., description="The apply ID of the edge.")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(..., description="The valid time of the node from system perspective.")
|
||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.")
|
||||
|
||||
|
||||
class DialogueNode(Node):
|
||||
"""Node representing a dialogue in the knowledge graph.
|
||||
|
||||
Attributes:
|
||||
ref_id: Reference identifier linking to external dialog system
|
||||
content: Full dialogue content as text
|
||||
dialog_embedding: Optional embedding vector for the entire dialogue
|
||||
config_id: Configuration ID used to process this dialogue
|
||||
"""
|
||||
ref_id: str = Field(..., description="Reference identifier of the dialog")
|
||||
content: str = Field(..., description="Dialogue content")
|
||||
dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialogue (integer or string)")
|
||||
|
||||
|
||||
class StatementNode(Node):
|
||||
"""Node representing a statement extracted from dialogue.
|
||||
|
||||
Attributes:
|
||||
chunk_id: ID of the parent chunk this statement belongs to
|
||||
stmt_type: Type of the statement (from ontology)
|
||||
temporal_info: Temporal information extracted from the statement
|
||||
statement: The actual statement text content
|
||||
connect_strength: Classification of connection strength ('Strong' or 'Weak')
|
||||
valid_at: Optional start date of temporal validity
|
||||
invalid_at: Optional end date of temporal validity
|
||||
statement_embedding: Optional embedding vector for the statement
|
||||
chunk_embedding: Optional embedding vector for the parent chunk
|
||||
config_id: Configuration ID used to process this statement
|
||||
"""
|
||||
chunk_id: str = Field(..., description="ID of the parent chunk")
|
||||
stmt_type: str = Field(..., description="Type of the statement")
|
||||
temporal_info: TemporalInfo = Field(..., description="Temporal information")
|
||||
statement: str = Field(..., description="The statement text content")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
|
||||
valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
|
||||
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
|
||||
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
|
||||
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)")
|
||||
|
||||
@field_validator('valid_at', 'invalid_at', mode='before')
|
||||
@classmethod
|
||||
def validate_datetime(cls, v):
|
||||
"""使用通用的历史日期解析函数"""
|
||||
return parse_historical_datetime(v)
|
||||
|
||||
|
||||
class ChunkNode(Node):
|
||||
"""Node representing a chunk of conversation in the knowledge graph.
|
||||
|
||||
Attributes:
|
||||
dialog_id: ID of the parent dialog
|
||||
content: The text content of the chunk
|
||||
chunk_embedding: Optional embedding vector for the chunk
|
||||
sequence_number: Order of this chunk within the dialog
|
||||
metadata: Additional chunk metadata as key-value pairs
|
||||
"""
|
||||
dialog_id: str = Field(..., description="ID of the parent dialog")
|
||||
content: str = Field(..., description="The text content of the chunk")
|
||||
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
|
||||
sequence_number: int = Field(..., description="Order of this chunk within the dialog")
|
||||
metadata: dict = Field(default_factory=dict, description="Additional chunk metadata")
|
||||
|
||||
|
||||
class ExtractedEntityNode(Node):
|
||||
"""Node representing an extracted entity in the knowledge graph.
|
||||
|
||||
Attributes:
|
||||
entity_idx: Unique numeric identifier for the entity
|
||||
statement_id: ID of the statement this entity was extracted from
|
||||
entity_type: Type/category of the entity
|
||||
description: Textual description of the entity
|
||||
aliases: Optional list of alternative names for the entity
|
||||
name_embedding: Optional embedding vector for the entity name
|
||||
fact_summary: Summary of facts about this entity
|
||||
connect_strength: Classification of connection strength ('Strong' or 'Weak')
|
||||
config_id: Configuration ID used to process this entity
|
||||
"""
|
||||
entity_idx: int = Field(..., description="Unique identifier for the entity")
|
||||
statement_id: str = Field(..., description="Statement this entity was extracted from")
|
||||
entity_type: str = Field(..., description="Type of the entity")
|
||||
description: str = Field(..., description="Entity description")
|
||||
aliases: Optional[List[str]] = Field(default_factory=list, description="Entity aliases")
|
||||
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
|
||||
fact_summary: str = Field(..., description="Summary of the fact about this entity")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
||||
|
||||
|
||||
class MemorySummaryNode(Node):
|
||||
"""Node representing a memory summary with vector embedding.
|
||||
|
||||
Attributes:
|
||||
summary_id: Unique identifier for the summary
|
||||
dialog_id: ID of the parent dialog
|
||||
chunk_ids: List of chunk IDs used to generate this summary
|
||||
content: Summary text content
|
||||
summary_embedding: Optional embedding vector for the summary
|
||||
metadata: Additional metadata for the summary
|
||||
config_id: Configuration ID used to process this summary
|
||||
"""
|
||||
summary_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for the summary")
|
||||
dialog_id: str = Field(..., description="ID of the parent dialog")
|
||||
chunk_ids: List[str] = Field(default_factory=list, description="List of chunk IDs used in the summary")
|
||||
content: str = Field(..., description="Summary text content")
|
||||
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
|
||||
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)")
|
||||
247
api/app/core/memory/models/message_models.py
Normal file
247
api/app/core/memory/models/message_models.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""Models for dialogue messages, conversations, and statements.
|
||||
|
||||
This module contains Pydantic models for representing dialogue data,
|
||||
including messages, conversation context, chunks, and statements.
|
||||
|
||||
Classes:
|
||||
ConversationMessage: Single message in a conversation
|
||||
TemporalValidityRange: Temporal validity range for statements
|
||||
Statement: Statement extracted from dialogue with metadata
|
||||
ConversationContext: Full conversation history
|
||||
Chunk: Chunk of conversation text
|
||||
DialogData: Complete dialogue data structure
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from uuid import uuid4
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.memory.utils.data.ontology import StatementType, TemporalInfo, RelevenceInfo
|
||||
from app.core.memory.models.triplet_models import TripletExtractionResponse, Triplet
|
||||
|
||||
|
||||
class ConversationMessage(BaseModel):
|
||||
"""Represents a single message in a conversation.
|
||||
|
||||
Attributes:
|
||||
role: Role of the speaker (e.g., '用户' for user, 'AI' for assistant)
|
||||
msg: Text content of the message
|
||||
"""
|
||||
role: str = Field(..., description="The role of the speaker (e.g., '用户', 'AI').")
|
||||
msg: str = Field(..., description="The text content of the message.")
|
||||
|
||||
|
||||
class TemporalValidityRange(BaseModel):
|
||||
"""Represents the temporal validity range of a statement.
|
||||
|
||||
Attributes:
|
||||
valid_at: Start date of validity in 'YYYY-MM-DD' format (None if not specified)
|
||||
invalid_at: End date of validity in 'YYYY-MM-DD' format (None if not specified)
|
||||
"""
|
||||
valid_at: Optional[str] = Field(
|
||||
None,
|
||||
description="The start date of the statement's validity, in 'YYYY-MM-DD' format or 'None'.",
|
||||
)
|
||||
invalid_at: Optional[str] = Field(
|
||||
None,
|
||||
description="The end date of the statement's validity, in 'YYYY-MM-DD' format or 'None'.",
|
||||
)
|
||||
|
||||
|
||||
class Statement(BaseModel):
|
||||
"""Represents a statement extracted from dialogue with metadata.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the statement
|
||||
chunk_id: ID of the parent chunk this statement belongs to
|
||||
group_id: Optional group ID for multi-tenancy
|
||||
statement: The actual statement text content
|
||||
statement_embedding: Optional embedding vector for the statement
|
||||
stmt_type: Type of the statement (from ontology)
|
||||
temporal_info: Temporal information extracted from the statement
|
||||
relevence_info: Relevance classification (RELEVANT or IRRELEVANT)
|
||||
connect_strength: Optional connection strength ('Strong' or 'Weak')
|
||||
temporal_validity: Optional temporal validity range
|
||||
triplet_extraction_info: Optional triplet extraction results
|
||||
"""
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.")
|
||||
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
|
||||
group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
|
||||
statement: str = Field(..., description="The text content of the statement.")
|
||||
statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.")
|
||||
stmt_type: StatementType = Field(..., description="The type of the statement.")
|
||||
temporal_info: TemporalInfo = Field(..., description="The temporal information of the statement.")
|
||||
relevence_info: RelevenceInfo = Field(RelevenceInfo.RELEVANT, description="The relevence information of the statement.")
|
||||
connect_strength: Optional[str] = Field(None, description="Strong VS Weak about this entity")
|
||||
temporal_validity: Optional[TemporalValidityRange] = Field(
|
||||
None, description="The temporal validity range of the statement."
|
||||
)
|
||||
triplet_extraction_info: Optional[TripletExtractionResponse] = Field(
|
||||
None, description="The triplet extraction information of the statement."
|
||||
)
|
||||
|
||||
|
||||
class ConversationContext(BaseModel):
|
||||
"""Represents the full conversation history.
|
||||
|
||||
Attributes:
|
||||
msgs: List of messages in the conversation
|
||||
|
||||
Properties:
|
||||
content: Formatted string representation of the conversation
|
||||
"""
|
||||
msgs: List[ConversationMessage] = Field(..., description="A list of messages in the conversation.")
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
"""Get the content of the conversation as a formatted string.
|
||||
|
||||
Returns:
|
||||
String with format "role: message" for each message, joined by newlines
|
||||
"""
|
||||
return "\n".join([f"{msg.role}: {msg.msg}" for msg in self.msgs])
|
||||
|
||||
class Chunk(BaseModel):
|
||||
"""A chunk of text from the conversation context.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the chunk
|
||||
text: List of messages in the chunk
|
||||
content: The content of the chunk as a formatted string
|
||||
statements: List of statements extracted from this chunk
|
||||
chunk_embedding: Optional embedding vector for the chunk
|
||||
metadata: Additional metadata as key-value pairs
|
||||
"""
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the chunk.")
|
||||
text: List[ConversationMessage] = Field(default_factory=list, description="A list of messages in the chunk.")
|
||||
content: str = Field(..., description="The content of the chunk as a string.")
|
||||
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
|
||||
chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
|
||||
|
||||
@classmethod
|
||||
def from_messages(cls, messages: List[ConversationMessage], metadata: Optional[Dict[str, Any]] = None):
|
||||
"""Create a chunk from a list of messages.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages
|
||||
metadata: Optional metadata dictionary
|
||||
|
||||
Returns:
|
||||
Chunk instance with formatted content
|
||||
"""
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
# Generate content from messages
|
||||
content = "\n".join([f"{msg.role}: {msg.msg}" for msg in messages])
|
||||
return cls(text=messages, content=content, metadata=metadata)
|
||||
|
||||
|
||||
class DialogData(BaseModel):
|
||||
"""Represents the complete data structure for a dialog record.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the dialog
|
||||
context: Full conversation context
|
||||
dialog_embedding: Optional embedding vector for the entire dialog
|
||||
ref_id: Reference ID linking to external dialog system
|
||||
group_id: Group ID for multi-tenancy
|
||||
user_id: User ID for user-specific data
|
||||
apply_id: Application ID for application-specific data
|
||||
created_at: Timestamp when the dialog was created
|
||||
expired_at: Timestamp when the dialog expires (default: far future)
|
||||
metadata: Additional metadata as key-value pairs
|
||||
chunks: List of chunks from the conversation
|
||||
config_id: Configuration ID used to process this dialog
|
||||
|
||||
Properties:
|
||||
content: Formatted string representation of the dialog
|
||||
"""
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the dialog.")
|
||||
context: ConversationContext = Field(..., description="The full conversation context as a single string.")
|
||||
dialog_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the dialog.")
|
||||
ref_id: str = Field(..., description="Refer to external dialog id. This is used to link to the original dialog.")
|
||||
group_id: str = Field(default=..., description="Group ID of dialogue data")
|
||||
user_id: str = Field(..., description="USER ID of dialogue data")
|
||||
apply_id: str = Field(..., description="APPLY ID of dialogue data")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.")
|
||||
expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the dialog.")
|
||||
chunks: List[Chunk] = Field(default_factory=list, description="A list of chunks from the conversation context.")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialog (integer or string)")
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
"""Get the content of the dialog as a formatted string.
|
||||
|
||||
Returns:
|
||||
String representation of the conversation context
|
||||
"""
|
||||
return self.context.content
|
||||
|
||||
def get_statement_chunk(self, statement_id: str) -> Optional[Chunk]:
|
||||
"""Find the chunk containing a specific statement.
|
||||
|
||||
Args:
|
||||
statement_id: ID of the statement to find
|
||||
|
||||
Returns:
|
||||
Chunk containing the statement, or None if not found
|
||||
"""
|
||||
for chunk in self.chunks:
|
||||
for statement in chunk.statements:
|
||||
if statement.id == statement_id:
|
||||
return chunk
|
||||
return None
|
||||
|
||||
def get_all_statements(self) -> List[Statement]:
|
||||
"""Get all statements from all chunks.
|
||||
|
||||
Returns:
|
||||
List of all statements in the dialog
|
||||
"""
|
||||
all_statements = []
|
||||
for chunk in self.chunks:
|
||||
all_statements.extend(chunk.statements)
|
||||
return all_statements
|
||||
|
||||
def get_statement_by_id(self, statement_id: str) -> Optional[Statement]:
|
||||
"""Find a specific statement by its ID.
|
||||
|
||||
Args:
|
||||
statement_id: ID of the statement to find
|
||||
|
||||
Returns:
|
||||
Statement with the given ID, or None if not found
|
||||
"""
|
||||
for chunk in self.chunks:
|
||||
for statement in chunk.statements:
|
||||
if statement.id == statement_id:
|
||||
return statement
|
||||
return None
|
||||
|
||||
def get_triplets_for_statement(self, statement_id: str) -> List[Triplet]:
|
||||
"""Get all triplets extracted from a specific statement.
|
||||
|
||||
Args:
|
||||
statement_id: ID of the statement
|
||||
|
||||
Returns:
|
||||
List of triplets from the statement, or empty list if none found
|
||||
"""
|
||||
statement = self.get_statement_by_id(statement_id)
|
||||
if statement and statement.triplet_extraction_info:
|
||||
return statement.triplet_extraction_info.triplets
|
||||
return []
|
||||
|
||||
def assign_group_id_to_statements(self) -> None:
|
||||
"""Assign this dialog's group_id to all statements in all chunks.
|
||||
|
||||
This method updates statements that don't have a group_id set.
|
||||
"""
|
||||
for chunk in self.chunks:
|
||||
for statement in chunk.statements:
|
||||
if statement.group_id is None:
|
||||
statement.group_id = self.group_id
|
||||
85
api/app/core/memory/models/triplet_models.py
Normal file
85
api/app/core/memory/models/triplet_models.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Models for knowledge triplets and entities.
|
||||
|
||||
This module contains Pydantic models for representing extracted knowledge
|
||||
in the form of entities and triplets (subject-predicate-object relationships).
|
||||
|
||||
Classes:
|
||||
Entity: Represents an extracted entity
|
||||
Triplet: Represents a knowledge triplet (subject-predicate-object)
|
||||
TripletExtractionResponse: Response model containing extracted triplets and entities
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
class Entity(BaseModel):
|
||||
"""Represents an extracted entity from dialogue.
|
||||
|
||||
Attributes:
|
||||
id: Unique string identifier for the entity
|
||||
entity_idx: Numeric index for the entity
|
||||
name: Name of the entity
|
||||
name_embedding: Optional embedding vector for the entity name
|
||||
type: Type/category of the entity (e.g., 'Person', 'Organization')
|
||||
description: Textual description of the entity
|
||||
|
||||
Config:
|
||||
extra: Ignore extra fields from LLM output
|
||||
"""
|
||||
model_config = ConfigDict(extra='ignore')
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for the entity.")
|
||||
entity_idx: int = Field(..., description="Unique identifier for the entity")
|
||||
name: str = Field(..., description="Name of the entity")
|
||||
name_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the entity name")
|
||||
type: str = Field(..., description="Type/category of the entity")
|
||||
description: str = Field(..., description="Description of the entity")
|
||||
|
||||
|
||||
class Triplet(BaseModel):
|
||||
"""Represents an extracted knowledge triplet (subject-predicate-object).
|
||||
|
||||
A triplet represents a relationship between two entities, forming
|
||||
the basic unit of knowledge in the knowledge graph.
|
||||
|
||||
Attributes:
|
||||
id: Unique string identifier for the triplet
|
||||
statement_id: Optional ID of the parent statement (set programmatically)
|
||||
subject_name: Name of the subject entity
|
||||
subject_id: Numeric ID of the subject entity
|
||||
predicate: Relationship/predicate between subject and object
|
||||
object_name: Name of the object entity
|
||||
object_id: Numeric ID of the object entity
|
||||
value: Optional additional value or context for the relationship
|
||||
|
||||
Config:
|
||||
extra: Ignore extra fields from LLM output
|
||||
"""
|
||||
model_config = ConfigDict(extra='ignore')
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for the triplet.")
|
||||
statement_id: Optional[str] = Field(None, description="ID of the parent statement this triplet was extracted from.")
|
||||
subject_name: str = Field(..., description="Name of the subject entity")
|
||||
subject_id: int = Field(..., description="ID of the subject entity")
|
||||
predicate: str = Field(..., description="Relationship/predicate between subject and object")
|
||||
object_name: str = Field(..., description="Name of the object entity")
|
||||
object_id: int = Field(..., description="ID of the object entity")
|
||||
value: Optional[str] = Field(None, description="Additional value or context")
|
||||
|
||||
|
||||
class TripletExtractionResponse(BaseModel):
|
||||
"""Response model for triplet extraction from LLM.
|
||||
|
||||
This model represents the structured output from the LLM when
|
||||
extracting knowledge triplets and entities from statements.
|
||||
|
||||
Attributes:
|
||||
triplets: List of extracted knowledge triplets
|
||||
entities: List of extracted entities
|
||||
|
||||
Config:
|
||||
extra: Ignore extra fields from LLM output
|
||||
"""
|
||||
model_config = ConfigDict(extra='ignore')
|
||||
triplets: List[Triplet] = Field(default_factory=list, description="List of extracted triplets")
|
||||
entities: List[Entity] = Field(default_factory=list, description="List of extracted entities")
|
||||
151
api/app/core/memory/models/variate_config.py
Normal file
151
api/app/core/memory/models/variate_config.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""Variable configuration models for extraction pipeline components.
|
||||
|
||||
This module contains Pydantic models for configuring various aspects
|
||||
of the extraction pipeline, including statement extraction, triplet extraction,
|
||||
temporal extraction, deduplication, and forgetting mechanisms.
|
||||
|
||||
Classes:
|
||||
StatementExtractionConfig: Configuration for statement extraction
|
||||
ForgettingEngineConfig: Configuration for forgetting engine
|
||||
TripletExtractionConfig: Configuration for triplet extraction
|
||||
TemporalExtractionConfig: Configuration for temporal extraction
|
||||
DedupConfig: Configuration for entity deduplication
|
||||
ExtractionPipelineConfig: Combined configuration for entire pipeline
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class StatementExtractionConfig(BaseModel):
|
||||
"""Configuration for statement extraction behavior.
|
||||
|
||||
Attributes:
|
||||
statement_granularity: Granularity level (1-3):
|
||||
- 1: Split sentences into different statements
|
||||
- 2: Sentence-level statements
|
||||
- 3: Combine sentences, shorten long statements
|
||||
temperature: LLM temperature for statement extraction (0-2, default: 0.1)
|
||||
include_dialogue_context: Whether to include full dialogue context
|
||||
max_dialogue_context_chars: Maximum characters from dialogue context (default: 2000)
|
||||
"""
|
||||
statement_granularity: Optional[int] = Field(None, ge=1, le=3, description="Granularity of statements to extract, level 1 to 3")
|
||||
temperature: Optional[float] = Field(0.1, ge=0, le=2, description="LLM temperature for statement extraction")
|
||||
include_dialogue_context: bool = Field(True, description="Whether to include full dialogue context in extraction")
|
||||
max_dialogue_context_chars: Optional[int] = Field(2000, ge=100, description="Maximum number of characters to include from dialogue context")
|
||||
|
||||
|
||||
class ForgettingEngineConfig(BaseModel):
|
||||
"""Configuration for the forgetting engine.
|
||||
|
||||
The forgetting engine implements a memory decay mechanism based on
|
||||
time and memory strength parameters.
|
||||
|
||||
Attributes:
|
||||
offset: Minimum retention level (0-1, prevents complete forgetting, default: 0.1)
|
||||
lambda_time: Lambda parameter controlling time decay effect (default: 0.1)
|
||||
lambda_mem: Lambda parameter controlling memory strength effect (default: 1.0)
|
||||
"""
|
||||
offset: float = Field(0.1, ge=0.0, le=1.0, description="Minimum retention level (prevents complete forgetting).")
|
||||
lambda_time: float = Field(0.1, gt=0.0, description="Lambda parameter controlling time effect.")
|
||||
lambda_mem: float = Field(1.0, gt=0.0, description="Lambda parameter controlling memory strength effect.")
|
||||
|
||||
|
||||
class TripletExtractionConfig(BaseModel):
|
||||
"""Configuration for triplet extraction behavior.
|
||||
|
||||
Attributes:
|
||||
temperature: LLM temperature for triplet extraction (0-2, default: 0.1)
|
||||
enable_entity_normalization: Whether to normalize entity names (default: True)
|
||||
confidence_threshold: Minimum confidence for extracted triplets (0-1, default: 0.7)
|
||||
"""
|
||||
temperature: Optional[float] = Field(0.1, ge=0, le=2, description="LLM temperature for triplet extraction")
|
||||
enable_entity_normalization: bool = Field(True, description="Whether to normalize entity names")
|
||||
confidence_threshold: Optional[float] = Field(0.7, ge=0, le=1, description="Minimum confidence threshold for extracted triplets")
|
||||
|
||||
|
||||
class TemporalExtractionConfig(BaseModel):
|
||||
"""Configuration for temporal extraction behavior.
|
||||
|
||||
Attributes:
|
||||
temperature: LLM temperature for temporal extraction (0-2, default: 0.1)
|
||||
"""
|
||||
temperature: Optional[float] = Field(0.1, ge=0, le=2, description="LLM temperature for temporal extraction")
|
||||
|
||||
|
||||
class DedupConfig(BaseModel):
|
||||
"""Configuration for entity deduplication behavior.
|
||||
|
||||
This configuration controls the multi-stage deduplication process,
|
||||
including fuzzy matching, LLM-based deduplication, and disambiguation.
|
||||
|
||||
Attributes:
|
||||
enable_llm_dedup_blockwise: Enable blockwise LLM-driven deduplication (default: False)
|
||||
enable_llm_disambiguation: Enable LLM disambiguation for same-name different-type entities (default: False)
|
||||
enable_llm_fallback_only_on_borderline: Only trigger LLM when borderline pairs exist (default: True)
|
||||
fuzzy_name_threshold_strict: Strict threshold for name similarity (0-1, default: 0.90)
|
||||
fuzzy_type_threshold_strict: Strict threshold for type similarity (0-1, default: 0.75)
|
||||
fuzzy_overall_threshold: Overall similarity threshold to merge (0-1, default: 0.82)
|
||||
fuzzy_unknown_type_name_threshold: Name threshold when entity type is UNKNOWN (0-1, default: 0.92)
|
||||
fuzzy_unknown_type_type_threshold: Type threshold when entity type is UNKNOWN (0-1, default: 0.50)
|
||||
name_weight: Weight of name similarity in overall score (0-1, default: 0.50)
|
||||
desc_weight: Weight of description similarity in overall score (0-1, default: 0.30)
|
||||
type_weight: Weight of type similarity in overall score (0-1, default: 0.20)
|
||||
context_bonus: Bonus when entities co-occur in same statements (0-0.2, default: 0.03)
|
||||
llm_fallback_floor: Lower bound for borderline score (0-1, default: 0.76)
|
||||
llm_fallback_ceiling: Upper bound for borderline score (0-1, default: 0.82)
|
||||
llm_block_size: Entities per block for LLM dedup (1-500, default: 50)
|
||||
llm_block_concurrency: Concurrent blocks processed by LLM (1-64, default: 4)
|
||||
llm_pair_concurrency: Concurrent pairwise decisions per block (1-64, default: 4)
|
||||
llm_max_rounds: Maximum LLM iterative dedup rounds (1-10, default: 3)
|
||||
"""
|
||||
# LLM deduplication toggles
|
||||
enable_llm_dedup_blockwise: bool = Field(False, description="Toggle blockwise LLM-driven deduplication")
|
||||
enable_llm_disambiguation: bool = Field(False, description="Toggle LLM-driven disambiguation for same-name different-type entities")
|
||||
enable_llm_fallback_only_on_borderline: bool = Field(True, description="Trigger LLM dedup only when borderline pairs are detected in fuzzy stage")
|
||||
|
||||
# Fuzzy match thresholds
|
||||
fuzzy_name_threshold_strict: float = Field(0.90, ge=0, le=1, description="Strict threshold for name similarity")
|
||||
fuzzy_type_threshold_strict: float = Field(0.75, ge=0, le=1, description="Strict threshold for type similarity")
|
||||
fuzzy_overall_threshold: float = Field(0.82, ge=0, le=1, description="Overall similarity threshold to merge")
|
||||
|
||||
# Specialized thresholds when type is UNKNOWN
|
||||
fuzzy_unknown_type_name_threshold: float = Field(0.92, ge=0, le=1, description="Name threshold when any entity type is UNKNOWN")
|
||||
fuzzy_unknown_type_type_threshold: float = Field(0.50, ge=0, le=1, description="Type threshold when any entity type is UNKNOWN")
|
||||
|
||||
# Weighted scoring components for overall similarity
|
||||
name_weight: float = Field(0.50, ge=0, le=1, description="Weight of name similarity in overall score")
|
||||
desc_weight: float = Field(0.30, ge=0, le=1, description="Weight of description similarity in overall score")
|
||||
type_weight: float = Field(0.20, ge=0, le=1, description="Weight of type similarity in overall score")
|
||||
context_bonus: float = Field(0.03, ge=0, le=0.2, description="Bonus added to score when entities co-occur in same statements")
|
||||
|
||||
# Borderline range for LLM fallback triggering
|
||||
llm_fallback_floor: float = Field(0.76, ge=0, le=1, description="Lower bound of overall score to consider as borderline for LLM fallback")
|
||||
llm_fallback_ceiling: float = Field(0.82, ge=0, le=1, description="Upper bound (below merge threshold) of overall score for LLM fallback")
|
||||
|
||||
# LLM iterative dedup parameters
|
||||
llm_block_size: int = Field(50, ge=1, le=500, description="Entities per block for LLM dedup")
|
||||
llm_block_concurrency: int = Field(4, ge=1, le=64, description="Concurrent blocks processed by LLM")
|
||||
llm_pair_concurrency: int = Field(4, ge=1, le=64, description="Concurrent pairwise decisions per block")
|
||||
llm_max_rounds: int = Field(3, ge=1, le=10, description="Maximum LLM iterative dedup rounds")
|
||||
|
||||
|
||||
class ExtractionPipelineConfig(BaseModel):
|
||||
"""Configuration for the entire extraction pipeline.
|
||||
|
||||
This model combines all configuration components for the complete
|
||||
extraction pipeline, including statement extraction, triplet extraction,
|
||||
temporal extraction, deduplication, and forgetting mechanisms.
|
||||
|
||||
Attributes:
|
||||
statement_extraction: Configuration for statement extraction
|
||||
triplet_extraction: Configuration for triplet extraction
|
||||
temporal_extraction: Configuration for temporal extraction
|
||||
deduplication: Configuration for entity deduplication
|
||||
forgetting_engine: Configuration for forgetting engine
|
||||
"""
|
||||
statement_extraction: StatementExtractionConfig = Field(default_factory=StatementExtractionConfig)
|
||||
triplet_extraction: TripletExtractionConfig = Field(default_factory=TripletExtractionConfig)
|
||||
temporal_extraction: TemporalExtractionConfig = Field(default_factory=TemporalExtractionConfig)
|
||||
deduplication: DedupConfig = Field(default_factory=DedupConfig)
|
||||
forgetting_engine: ForgettingEngineConfig = Field(default_factory=ForgettingEngineConfig)
|
||||
0
api/app/core/memory/src/__init__.py
Normal file
0
api/app/core/memory/src/__init__.py
Normal file
0
api/app/core/memory/src/llm_tools/__init__.py
Normal file
0
api/app/core/memory/src/llm_tools/__init__.py
Normal file
330
api/app/core/memory/src/llm_tools/chunker_client.py
Normal file
330
api/app/core/memory/src/llm_tools/chunker_client.py
Normal file
@@ -0,0 +1,330 @@
|
||||
from typing import Any, List
|
||||
import re
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
# Fix tokenizer parallelism warning
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
from chonkie import (
|
||||
SemanticChunker,
|
||||
RecursiveChunker,
|
||||
RecursiveRules,
|
||||
LateChunker,
|
||||
NeuralChunker,
|
||||
SentenceChunker,
|
||||
TokenChunker,
|
||||
)
|
||||
|
||||
from app.core.memory.models.config_models import ChunkerConfig
|
||||
from app.core.memory.models.message_models import DialogData, Chunk
|
||||
try:
|
||||
from app.core.memory.src.llm_tools.openai_client import OpenAIClient
|
||||
except Exception:
|
||||
# 在测试或无可用依赖(如 langfuse)环境下,允许惰性导入
|
||||
OpenAIClient = Any
|
||||
|
||||
|
||||
class LLMChunker:
|
||||
"""基于LLM的智能分块策略"""
|
||||
def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000):
|
||||
self.llm_client = llm_client
|
||||
self.chunk_size = chunk_size
|
||||
|
||||
async def __call__(self, text: str) -> List[Any]:
|
||||
# 使用LLM分析文本结构并进行智能分块
|
||||
prompt = f"""
|
||||
请将以下文本分割成语义连贯的段落。每个段落应该围绕一个主题,长度大约在{self.chunk_size}字符左右。
|
||||
请以JSON格式返回结果,包含chunks数组,每个chunk有text字段。
|
||||
|
||||
文本内容:
|
||||
{text[:5000]}
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个专业的文本分析助手,擅长将长文本分割成语义连贯的段落。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
try:
|
||||
# 使用异步的 achat 方法
|
||||
if hasattr(self.llm_client, 'achat'):
|
||||
response = await self.llm_client.achat(messages)
|
||||
else:
|
||||
# 如果没有异步方法,使用同步方法并转换为异步
|
||||
response = await asyncio.to_thread(self.llm_client.chat, messages)
|
||||
|
||||
# 检查响应格式并提取内容
|
||||
if hasattr(response, 'choices') and len(response.choices) > 0:
|
||||
content = response.choices[0].message.content
|
||||
elif hasattr(response, 'content'):
|
||||
content = response.content
|
||||
else:
|
||||
content = str(response)
|
||||
|
||||
# 解析LLM响应
|
||||
if "```json" in content:
|
||||
json_str = content.split("```json")[1].split("```")[0].strip()
|
||||
elif "```" in content:
|
||||
json_str = content.split("```")[1].split("```")[0].strip()
|
||||
else:
|
||||
json_str = content
|
||||
|
||||
result = json.loads(json_str)
|
||||
|
||||
class SimpleChunk:
|
||||
def __init__(self, text, index):
|
||||
self.text = text
|
||||
self.start_index = index * 100 # 近似位置
|
||||
self.end_index = (index + 1) * 100
|
||||
|
||||
return [SimpleChunk(chunk["text"], i) for i, chunk in enumerate(result.get("chunks", []))]
|
||||
|
||||
except Exception as e:
|
||||
print(f"LLM分块失败: {e}")
|
||||
# 失败时返回空列表,外层会处理回退方案
|
||||
return []
|
||||
|
||||
|
||||
class HybridChunker:
|
||||
"""混合分块策略:先按结构分块,再按语义合并"""
|
||||
def __init__(self, semantic_threshold: float = 0.8, base_chunk_size: int = 300):
|
||||
self.semantic_threshold = semantic_threshold
|
||||
self.base_chunk_size = base_chunk_size
|
||||
self.base_chunker = TokenChunker(tokenizer="character", chunk_size=base_chunk_size)
|
||||
self.semantic_chunker = SemanticChunker(threshold=semantic_threshold)
|
||||
|
||||
def __call__(self, text: str) -> List[Any]:
|
||||
# 先用基础分块
|
||||
base_chunks = self.base_chunker(text)
|
||||
|
||||
# 如果文本不长,直接返回基础分块
|
||||
if len(base_chunks) <= 3:
|
||||
return base_chunks
|
||||
|
||||
# 对基础分块进行语义合并
|
||||
combined_text = " ".join([chunk.text for chunk in base_chunks])
|
||||
return self.semantic_chunker(combined_text)
|
||||
|
||||
|
||||
class ChunkerClient:
|
||||
def __init__(self, chunker_config: ChunkerConfig, llm_client: OpenAIClient = None):
|
||||
self.chunker_config = chunker_config
|
||||
self.embedding_model = chunker_config.embedding_model
|
||||
self.chunk_size = chunker_config.chunk_size
|
||||
self.threshold = chunker_config.threshold
|
||||
self.language = chunker_config.language
|
||||
self.skip_window = chunker_config.skip_window
|
||||
self.min_sentences = chunker_config.min_sentences
|
||||
self.min_characters_per_chunk = chunker_config.min_characters_per_chunk
|
||||
self.llm_client = llm_client
|
||||
|
||||
# 可选参数(从配置中安全获取,提供默认值)
|
||||
self.chunk_overlap = getattr(chunker_config, 'chunk_overlap', 0)
|
||||
self.min_sentences_per_chunk = getattr(chunker_config, 'min_sentences_per_chunk', 1)
|
||||
self.min_characters_per_sentence = getattr(chunker_config, 'min_characters_per_sentence', 12)
|
||||
self.delim = getattr(chunker_config, 'delim', [".", "!", "?", "\n"])
|
||||
self.include_delim = getattr(chunker_config, 'include_delim', "prev")
|
||||
self.tokenizer_or_token_counter = getattr(chunker_config, 'tokenizer_or_token_counter', "character")
|
||||
|
||||
# 初始化具体分块器策略
|
||||
if chunker_config.chunker_strategy == "TokenChunker":
|
||||
self.chunker = TokenChunker(
|
||||
tokenizer=self.tokenizer_or_token_counter,
|
||||
chunk_size=self.chunk_size,
|
||||
chunk_overlap=self.chunk_overlap,
|
||||
)
|
||||
elif chunker_config.chunker_strategy == "SemanticChunker":
|
||||
self.chunker = SemanticChunker(
|
||||
embedding_model=self.embedding_model,
|
||||
threshold=self.threshold,
|
||||
chunk_size=self.chunk_size,
|
||||
min_sentences=self.min_sentences,
|
||||
)
|
||||
elif chunker_config.chunker_strategy == "RecursiveChunker":
|
||||
self.chunker = RecursiveChunker(
|
||||
rules=RecursiveRules(),
|
||||
min_characters_per_chunk=self.min_characters_per_chunk or 50,
|
||||
chunk_size=self.chunk_size,
|
||||
)
|
||||
elif chunker_config.chunker_strategy == "LateChunker":
|
||||
self.chunker = LateChunker(
|
||||
embedding_model=self.embedding_model,
|
||||
chunk_size=self.chunk_size,
|
||||
rules=RecursiveRules(),
|
||||
min_characters_per_chunk=self.min_characters_per_chunk,
|
||||
)
|
||||
elif chunker_config.chunker_strategy == "NeuralChunker":
|
||||
self.chunker = NeuralChunker(
|
||||
model=self.embedding_model,
|
||||
min_characters_per_chunk=self.min_characters_per_chunk,
|
||||
)
|
||||
elif chunker_config.chunker_strategy == "LLMChunker":
|
||||
if not llm_client:
|
||||
raise ValueError("LLMChunker requires an LLM client")
|
||||
self.chunker = LLMChunker(llm_client, self.chunk_size)
|
||||
elif chunker_config.chunker_strategy == "HybridChunker":
|
||||
self.chunker = HybridChunker(
|
||||
semantic_threshold=self.threshold,
|
||||
base_chunk_size=self.chunk_size,
|
||||
)
|
||||
elif chunker_config.chunker_strategy == "SentenceChunker":
|
||||
# 某些 chonkie 版本的 SentenceChunker 不支持 tokenizer_or_token_counter 参数
|
||||
# 为了兼容不同版本,这里仅传递广泛支持的参数
|
||||
self.chunker = SentenceChunker(
|
||||
chunk_size=self.chunk_size,
|
||||
chunk_overlap=self.chunk_overlap,
|
||||
min_sentences_per_chunk=self.min_sentences_per_chunk,
|
||||
min_characters_per_sentence=self.min_characters_per_sentence,
|
||||
delim=self.delim,
|
||||
include_delim=self.include_delim,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown chunker strategy: {chunker_config.chunker_strategy}")
|
||||
|
||||
async def generate_chunks(self, dialogue: DialogData):
|
||||
"""
|
||||
生成分块,支持异步操作
|
||||
"""
|
||||
try:
|
||||
# 预处理文本:确保对话标记格式统一
|
||||
content = dialogue.content
|
||||
content = content.replace('AI:', 'AI:').replace('用户:', '用户:') # 统一冒号
|
||||
content = re.sub(r'(\n\s*)+\n', '\n\n', content) # 合并多个空行
|
||||
|
||||
if hasattr(self.chunker, '__call__') and not asyncio.iscoroutinefunction(self.chunker.__call__):
|
||||
# 同步分块器
|
||||
chunks = self.chunker(content)
|
||||
else:
|
||||
# 异步分块器(如LLMChunker)
|
||||
chunks = await self.chunker(content)
|
||||
|
||||
# 过滤空块和过小的块
|
||||
valid_chunks = []
|
||||
for c in chunks:
|
||||
chunk_text = getattr(c, 'text', str(c)) if not isinstance(c, str) else c
|
||||
if isinstance(chunk_text, str) and len(chunk_text.strip()) >= (self.min_characters_per_chunk or 50):
|
||||
valid_chunks.append(c)
|
||||
|
||||
dialogue.chunks = [
|
||||
Chunk(
|
||||
content=c.text if hasattr(c, 'text') else str(c),
|
||||
metadata={
|
||||
"start_index": getattr(c, "start_index", None),
|
||||
"end_index": getattr(c, "end_index", None),
|
||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||
},
|
||||
)
|
||||
for c in valid_chunks
|
||||
]
|
||||
return dialogue
|
||||
|
||||
except Exception as e:
|
||||
print(f"分块失败: {e}")
|
||||
|
||||
# 改进的后备方案:尝试按对话回合分割
|
||||
try:
|
||||
# 简单的按对话分割
|
||||
dialogue_pattern = r'(AI:|用户:)(.*?)(?=AI:|用户:|$)'
|
||||
matches = re.findall(dialogue_pattern, dialogue.content, re.DOTALL)
|
||||
|
||||
class SimpleChunk:
|
||||
def __init__(self, text, start_index, end_index):
|
||||
self.text = text
|
||||
self.start_index = start_index
|
||||
self.end_index = end_index
|
||||
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
current_start = 0
|
||||
|
||||
for match in matches:
|
||||
speaker, ct = match[0], match[1].strip()
|
||||
turn_text = f"{speaker} {ct}"
|
||||
|
||||
if len(current_chunk) + len(turn_text) > (self.chunk_size or 500):
|
||||
if current_chunk:
|
||||
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
|
||||
current_chunk = turn_text
|
||||
current_start = dialogue.content.find(turn_text, current_start)
|
||||
else:
|
||||
current_chunk += ("\n" + turn_text) if current_chunk else turn_text
|
||||
|
||||
if current_chunk:
|
||||
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
|
||||
|
||||
dialogue.chunks = [
|
||||
Chunk(
|
||||
content=c.text,
|
||||
metadata={
|
||||
"start_index": c.start_index,
|
||||
"end_index": c.end_index,
|
||||
"chunker_strategy": "DialogueTurnFallback",
|
||||
},
|
||||
)
|
||||
for c in chunks
|
||||
]
|
||||
|
||||
except Exception:
|
||||
# 最后的手段:单一大块
|
||||
dialogue.chunks = [Chunk(
|
||||
content=dialogue.content,
|
||||
metadata={"chunker_strategy": "SingleChunkFallback"},
|
||||
)]
|
||||
|
||||
return dialogue
|
||||
|
||||
def evaluate_chunking(self, dialogue: DialogData) -> dict:
|
||||
"""
|
||||
评估分块质量
|
||||
"""
|
||||
if not getattr(dialogue, 'chunks', None):
|
||||
return {}
|
||||
|
||||
chunks = dialogue.chunks
|
||||
total_chars = sum(len(chunk.content) for chunk in chunks)
|
||||
avg_chunk_size = total_chars / len(chunks)
|
||||
|
||||
# 计算各种指标
|
||||
chunk_sizes = [len(chunk.content) for chunk in chunks]
|
||||
|
||||
metrics = {
|
||||
"strategy": self.chunker_config.chunker_strategy,
|
||||
"num_chunks": len(chunks),
|
||||
"total_characters": total_chars,
|
||||
"avg_chunk_size": avg_chunk_size,
|
||||
"min_chunk_size": min(chunk_sizes),
|
||||
"max_chunk_size": max(chunk_sizes),
|
||||
"chunk_size_std": np.std(chunk_sizes) if len(chunk_sizes) > 1 else 0,
|
||||
"coverage_ratio": total_chars / len(dialogue.content) if dialogue.content else 0,
|
||||
}
|
||||
|
||||
return metrics
|
||||
|
||||
def save_chunking_results(self, dialogue: DialogData, output_path: str):
|
||||
"""
|
||||
保存分块结果到文件,文件名包含策略名称
|
||||
"""
|
||||
strategy_name = self.chunker_config.chunker_strategy
|
||||
# 在文件名中添加策略名称
|
||||
base_name, ext = os.path.splitext(output_path)
|
||||
strategy_output_path = f"{base_name}_{strategy_name}{ext}"
|
||||
|
||||
with open(strategy_output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(f"=== Chunking Strategy: {strategy_name} ===\n")
|
||||
f.write(f"Total chunks: {len(dialogue.chunks)}\n")
|
||||
f.write(f"Total characters: {sum(len(chunk.content) for chunk in dialogue.chunks)}\n")
|
||||
f.write("=" * 60 + "\n\n")
|
||||
|
||||
for i, chunk in enumerate(dialogue.chunks):
|
||||
f.write(f"Chunk {i+1}:\n")
|
||||
f.write(f"Size: {len(chunk.content)} characters\n")
|
||||
if hasattr(chunk, 'metadata') and 'start_index' in chunk.metadata:
|
||||
f.write(f"Position: {chunk.metadata.get('start_index')}-{chunk.metadata.get('end_index')}\n")
|
||||
f.write(f"Content: {chunk.content}\n")
|
||||
f.write("-" * 40 + "\n\n")
|
||||
|
||||
print(f"Chunking results saved to: {strategy_output_path}")
|
||||
return strategy_output_path
|
||||
22
api/app/core/memory/src/llm_tools/embedder_client.py
Normal file
22
api/app/core/memory/src/llm_tools/embedder_client.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
class EmbedderClient(ABC):
|
||||
def __init__(self, model_config: RedBearModelConfig):
|
||||
self.config = model_config
|
||||
|
||||
self.model_name = model_config.model_name
|
||||
self.provider = model_config.provider
|
||||
self.api_key = model_config.api_key
|
||||
self.base_url = model_config.base_url
|
||||
self.max_retries = model_config.max_retries
|
||||
# self.dimension = model_config.dimension
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def response(
|
||||
self,
|
||||
messages: List[str],
|
||||
) -> List[str]:
|
||||
pass
|
||||
37
api/app/core/memory/src/llm_tools/llm_client.py
Normal file
37
api/app/core/memory/src/llm_tools/llm_client.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any
|
||||
from pydantic import BaseModel
|
||||
from app.core.memory.models.config_models import LLMConfig
|
||||
|
||||
"""
|
||||
model_name: str
|
||||
provider: str
|
||||
api_key: str
|
||||
base_url: Optional[str] = None
|
||||
timeout: float = 30.0 # 请求超时时间(秒)
|
||||
max_retries: int = 3 # 最大重试次数
|
||||
concurrency: int = 5 # 并发限流
|
||||
extra_params: Dict[str, Any] = {}
|
||||
"""
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
class LLMClient(ABC):
|
||||
def __init__(self, model_config: RedBearModelConfig):
|
||||
self.config = model_config
|
||||
|
||||
self.model_name = self.config.model_name
|
||||
self.provider = self.config.provider
|
||||
self.api_key = self.config.api_key
|
||||
self.base_url = self.config.base_url
|
||||
self.max_retries = self.config.max_retries
|
||||
|
||||
@abstractmethod
|
||||
def chat(self, messages: List[Dict[str, str]]) -> Any:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def response_structured(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
response_model: type[BaseModel],
|
||||
) -> type[BaseModel]:
|
||||
pass
|
||||
224
api/app/core/memory/src/llm_tools/openai_client.py
Normal file
224
api/app/core/memory/src/llm_tools/openai_client.py
Normal file
@@ -0,0 +1,224 @@
|
||||
import asyncio
|
||||
from typing import List, Dict, Any
|
||||
import json
|
||||
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.output_parsers import PydanticOutputParser
|
||||
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models.llm import RedBearLLM
|
||||
from app.core.memory.src.llm_tools.llm_client import LLMClient
|
||||
# from app.core.memory.utils.config.definitions import LANGFUSE_ENABLED
|
||||
LANGFUSE_ENABLED=False
|
||||
|
||||
class OpenAIClient(LLMClient):
|
||||
def __init__(self, model_config: RedBearModelConfig, type_: str = "chat"):
|
||||
super().__init__(model_config)
|
||||
|
||||
# Initialize Langfuse callback handler if enabled
|
||||
self.langfuse_handler = None
|
||||
if LANGFUSE_ENABLED:
|
||||
try:
|
||||
from langfuse.langchain import CallbackHandler
|
||||
self.langfuse_handler = CallbackHandler()
|
||||
except ImportError:
|
||||
# Langfuse not installed, continue without tracing
|
||||
pass
|
||||
except Exception as e:
|
||||
# Log error but don't fail initialization
|
||||
import logging
|
||||
logging.warning(f"Failed to initialize Langfuse handler: {e}")
|
||||
|
||||
# Initialize RedBearLLM client
|
||||
self.client = RedBearLLM(RedBearModelConfig(
|
||||
model_name=self.model_name,
|
||||
provider=self.provider,
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
max_retries=self.max_retries,
|
||||
), type=type_)
|
||||
|
||||
async def chat(self, messages: List[Dict[str, str]]) -> Any:
|
||||
template = """{messages}"""
|
||||
# ChatPromptTemplate
|
||||
prompt = ChatPromptTemplate.from_template(template)
|
||||
chain = prompt | self.client
|
||||
|
||||
# Add Langfuse callback if available
|
||||
config = {}
|
||||
if self.langfuse_handler:
|
||||
config["callbacks"] = [self.langfuse_handler]
|
||||
|
||||
response = await chain.ainvoke({"messages": messages}, config=config)
|
||||
# print(f"OpenAIClient response ======>:\n {response}")
|
||||
return response
|
||||
|
||||
async def response_structured(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
response_model: type[BaseModel],
|
||||
) -> type[BaseModel]:
|
||||
# Build a simple prompt pipeline that sends messages to the underlying LLM
|
||||
question_text = "\n\n".join([str(m.get("content", "")) for m in messages])
|
||||
|
||||
# Prepare config with Langfuse callback if available
|
||||
config = {}
|
||||
if self.langfuse_handler:
|
||||
config["callbacks"] = [self.langfuse_handler]
|
||||
|
||||
# Primary: enforce schema with PydanticOutputParser if available
|
||||
if PydanticOutputParser is not None:
|
||||
try:
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
# 使用正确的属性路径:self.config.timeout(从LLMClient基类继承)
|
||||
# logger.info(f"开始LLM结构化输出请求 (模型: {self.model_name}, 超时: {self.config.timeout}秒)")
|
||||
|
||||
parser = PydanticOutputParser(pydantic_object=response_model)
|
||||
format_instructions = parser.get_format_instructions()
|
||||
prompt = ChatPromptTemplate.from_template("{question}\n{format_instructions}")
|
||||
chain = prompt | self.client | parser
|
||||
parsed = await chain.ainvoke({
|
||||
"question": question_text,
|
||||
"format_instructions": format_instructions,
|
||||
})
|
||||
|
||||
# logger.info(f"LLM结构化输出请求成功完成")
|
||||
return parsed
|
||||
except Exception as e:
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(f"PydanticOutputParser失败,尝试备用方法: {str(e)}")
|
||||
# Fall through to alternative structured methods
|
||||
pass
|
||||
|
||||
# Fallback path: create plain prompt for other structured methods
|
||||
template = """{question}"""
|
||||
prompt = ChatPromptTemplate.from_template(template)
|
||||
|
||||
# Try LangChain structured output if available on the underlying client
|
||||
try:
|
||||
with_so = getattr(self.client, "with_structured_output", None)
|
||||
|
||||
if callable(with_so):
|
||||
try:
|
||||
structured_chain = prompt | with_so(response_model, strict=True)
|
||||
parsed = await structured_chain.ainvoke({"question": question_text}, config=config)
|
||||
# parsed may already be a pydantic model or a dict
|
||||
try:
|
||||
return response_model.model_validate(parsed)
|
||||
except Exception:
|
||||
try:
|
||||
# If it's already a pydantic instance (LangChain returns model), return it
|
||||
if hasattr(parsed, "model_dump"):
|
||||
return parsed
|
||||
return response_model.model_validate_json(json.dumps(parsed))
|
||||
except Exception:
|
||||
# Fall through to manual parsing below
|
||||
pass
|
||||
except NotImplementedError:
|
||||
# The underlying model doesn't support structured output, fall through
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
f"Model {self.model_name} doesn't support with_structured_output, falling back to manual parsing")
|
||||
pass
|
||||
except Exception as e:
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(f"Structured output attempt failed: {e}, falling back to manual parsing")
|
||||
|
||||
# Final fallback: manual parsing with plain LLM response
|
||||
try:
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"Using manual parsing fallback for model {self.model_name}")
|
||||
|
||||
# Create a prompt that asks for JSON output
|
||||
json_prompt = ChatPromptTemplate.from_template(
|
||||
"{question}\n\n"
|
||||
"Please respond with a valid JSON object that matches this schema:\n"
|
||||
"{schema}\n\n"
|
||||
"Response (JSON only):"
|
||||
)
|
||||
|
||||
# Get the schema from the response model
|
||||
schema = response_model.model_json_schema()
|
||||
|
||||
chain = json_prompt | self.client
|
||||
response = await chain.ainvoke({
|
||||
"question": question_text,
|
||||
"schema": json.dumps(schema, indent=2)
|
||||
}, config=config)
|
||||
|
||||
# Extract JSON from response
|
||||
response_text = str(response.content if hasattr(response, 'content') else response)
|
||||
|
||||
# Try to find JSON in the response
|
||||
import re
|
||||
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
|
||||
if json_match:
|
||||
json_str = json_match.group(0)
|
||||
try:
|
||||
parsed_dict = json.loads(json_str)
|
||||
return response_model.model_validate(parsed_dict)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# If JSON parsing fails, try to create a minimal valid response
|
||||
logger.warning(f"Failed to parse JSON from LLM response, creating minimal response")
|
||||
|
||||
# Create a minimal response based on the schema
|
||||
return self._create_minimal_response(response_model)
|
||||
|
||||
except Exception as fallback_error:
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"Manual parsing fallback also failed: {fallback_error}")
|
||||
# Return minimal response as last resort
|
||||
return self._create_minimal_response(response_model)
|
||||
|
||||
def _create_minimal_response(self, response_model: type[BaseModel]) -> BaseModel:
|
||||
"""Create a minimal valid response based on the model schema."""
|
||||
try:
|
||||
minimal_response = {}
|
||||
|
||||
for field_name, field_info in response_model.model_fields.items():
|
||||
# Check if field has a default value
|
||||
if hasattr(field_info, 'default') and field_info.default is not None:
|
||||
minimal_response[field_name] = field_info.default
|
||||
else:
|
||||
# Create default based on field type
|
||||
field_type = field_info.annotation
|
||||
|
||||
# Handle nested BaseModel
|
||||
if hasattr(field_type, '__bases__') and BaseModel in field_type.__bases__:
|
||||
minimal_response[field_name] = self._create_minimal_response(field_type)
|
||||
elif field_type == str:
|
||||
minimal_response[field_name] = "信息不足,无法回答"
|
||||
elif field_type == int:
|
||||
minimal_response[field_name] = 0
|
||||
elif field_type == float:
|
||||
minimal_response[field_name] = 0.0
|
||||
elif field_type == bool:
|
||||
minimal_response[field_name] = False
|
||||
elif field_type == list:
|
||||
minimal_response[field_name] = []
|
||||
elif field_type == dict:
|
||||
minimal_response[field_name] = {}
|
||||
else:
|
||||
minimal_response[field_name] = None
|
||||
|
||||
return response_model.model_validate(minimal_response)
|
||||
|
||||
except Exception as e:
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"Failed to create minimal response: {e}")
|
||||
# Last resort: try to create with just required fields
|
||||
try:
|
||||
return response_model()
|
||||
except Exception:
|
||||
# If even that fails, raise the original error
|
||||
raise ValueError(f"Unable to create minimal response for {response_model.__name__}") from e
|
||||
26
api/app/core/memory/src/llm_tools/openai_embedder.py
Normal file
26
api/app/core/memory/src/llm_tools/openai_embedder.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from typing import List
|
||||
|
||||
from app.core.memory.src.llm_tools.embedder_client import EmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
# from app.models.models_model import ModelType
|
||||
from app.core.models.embedding import RedBearEmbeddings
|
||||
|
||||
|
||||
class OpenAIEmbedderClient(EmbedderClient):
|
||||
def __init__(self, model_config: RedBearModelConfig):
|
||||
super().__init__(model_config)
|
||||
|
||||
async def response(
|
||||
self,
|
||||
messages: List[str],
|
||||
) -> List[List[float]]:
|
||||
texts: List[str] = [str(m) for m in messages if m is not None]
|
||||
|
||||
model = RedBearEmbeddings(RedBearModelConfig(
|
||||
model_name=self.model_name,
|
||||
provider=self.provider,
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
))
|
||||
embeddings = await model.aembed_documents(texts)
|
||||
return embeddings
|
||||
980
api/app/core/memory/src/search.py
Normal file
980
api/app/core/memory/src/search.py
Normal file
@@ -0,0 +1,980 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dotenv import load_dotenv
|
||||
from datetime import datetime
|
||||
import math
|
||||
from app.core.logging_config import get_memory_logger
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.graph_search import (
|
||||
search_graph_by_embedding, search_graph,
|
||||
search_graph_by_temporal, search_graph_by_keyword_temporal,
|
||||
search_graph_by_chunk_id
|
||||
)
|
||||
from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.models.config_models import TemporalSearchParams
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config, get_pipeline_config
|
||||
from app.core.memory.utils.data.time_utils import normalize_date_safe
|
||||
from app.core.memory.models.variate_config import ForgettingEngineConfig
|
||||
from app.core.memory.utils.config.definitions import CONFIG, RUNTIME_CONFIG
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine
|
||||
from app.core.memory.utils.data.text_utils import extract_plain_query
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.llm.llm_utils import get_reranker_client
|
||||
load_dotenv()
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
def _parse_datetime(value: Any) -> Optional[datetime]:
|
||||
"""Parse ISO `created_at` strings of the form 'YYYY-MM-DDTHH:MM:SS.ssssss'."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, datetime):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
s = value.strip()
|
||||
if not s:
|
||||
return None
|
||||
try:
|
||||
return datetime.fromisoformat(s)
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score") -> List[Dict[str, Any]]:
|
||||
"""Normalize scores using z-score normalization followed by sigmoid transformation."""
|
||||
if not results:
|
||||
return results
|
||||
|
||||
# Extract scores, ensuring they are numeric and not None
|
||||
scores = []
|
||||
for item in results:
|
||||
if score_field in item:
|
||||
score = item.get(score_field)
|
||||
if score is not None and isinstance(score, (int, float)):
|
||||
scores.append(float(score))
|
||||
else:
|
||||
scores.append(0.0) # Default for None or non-numeric values
|
||||
|
||||
if not scores:
|
||||
return results
|
||||
|
||||
if len(scores) == 1:
|
||||
# Single score, set to 1.0
|
||||
for item in results:
|
||||
if score_field in item:
|
||||
item[f"normalized_{score_field}"] = 1.0
|
||||
return results
|
||||
|
||||
# Calculate mean and standard deviation
|
||||
mean_score = sum(scores) / len(scores)
|
||||
variance = sum((score - mean_score) ** 2 for score in scores) / len(scores)
|
||||
std_dev = math.sqrt(variance)
|
||||
|
||||
if std_dev == 0:
|
||||
# All scores are the same, set them to 1.0
|
||||
for item in results:
|
||||
if score_field in item:
|
||||
item[f"normalized_{score_field}"] = 1.0
|
||||
else:
|
||||
for item in results:
|
||||
if score_field in item:
|
||||
score = item[score_field]
|
||||
# Handle None or non-numeric scores
|
||||
if score is None or not isinstance(score, (int, float)):
|
||||
score = 0.0
|
||||
# Calculate z-score
|
||||
z_score = (score - mean_score) / std_dev
|
||||
# Transform to positive range using sigmoid function
|
||||
normalized = 1 / (1 + math.exp(-z_score))
|
||||
item[f"normalized_{score_field}"] = normalized
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def rerank_hybrid_results(
|
||||
keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||
embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||
alpha: float = 0.6,
|
||||
limit: int = 10
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Rerank hybrid search results by combining BM25 and embedding scores.
|
||||
|
||||
Args:
|
||||
keyword_results: Results from keyword/BM25 search
|
||||
embedding_results: Results from embedding search
|
||||
alpha: Weight for BM25 scores (1-alpha for embedding scores)
|
||||
limit: Maximum number of results to return per category
|
||||
|
||||
Returns:
|
||||
Reranked results with combined scores
|
||||
"""
|
||||
reranked = {}
|
||||
|
||||
for category in ["statements", "chunks", "entities","summaries"]:
|
||||
keyword_items = keyword_results.get(category, [])
|
||||
embedding_items = embedding_results.get(category, [])
|
||||
|
||||
# Normalize scores within each search type
|
||||
keyword_items = normalize_scores(keyword_items, "score")
|
||||
embedding_items = normalize_scores(embedding_items, "score")
|
||||
|
||||
# Create a combined pool of unique items
|
||||
combined_items = {}
|
||||
|
||||
# Add keyword results with BM25 scores
|
||||
for item in keyword_items:
|
||||
item_id = item.get("id") or item.get("uuid")
|
||||
if item_id:
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
combined_items[item_id]["embedding_score"] = 0 # Default
|
||||
|
||||
# Add or update with embedding results
|
||||
for item in embedding_items:
|
||||
item_id = item.get("id") or item.get("uuid")
|
||||
if item_id:
|
||||
if item_id in combined_items:
|
||||
# Update existing item with embedding score
|
||||
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
else:
|
||||
# New item from embedding search only
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = 0 # Default
|
||||
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
|
||||
# Calculate combined scores and rank
|
||||
for item_id, item in combined_items.items():
|
||||
bm25_score = item.get("bm25_score", 0)
|
||||
embedding_score = item.get("embedding_score", 0)
|
||||
|
||||
# Combined score: weighted average of normalized scores
|
||||
combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
|
||||
item["combined_score"] = combined_score
|
||||
|
||||
# Keep original score for reference
|
||||
if "score" not in item and bm25_score > 0:
|
||||
item["score"] = bm25_score
|
||||
elif "score" not in item and embedding_score > 0:
|
||||
item["score"] = embedding_score
|
||||
|
||||
# Sort by combined score and limit results
|
||||
sorted_items = sorted(
|
||||
combined_items.values(),
|
||||
key=lambda x: x.get("combined_score", 0),
|
||||
reverse=True
|
||||
)[:limit]
|
||||
|
||||
reranked[category] = sorted_items
|
||||
|
||||
return reranked
|
||||
|
||||
def rerank_with_forgetting_curve(
|
||||
keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||
embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||
alpha: float = 0.6,
|
||||
limit: int = 10,
|
||||
forgetting_config: ForgettingEngineConfig | None = None,
|
||||
now: datetime | None = None,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Rerank hybrid results with a forgetting curve applied to combined scores.
|
||||
|
||||
The forgetting curve reduces scores for older memories or weaker connections.
|
||||
|
||||
Args:
|
||||
keyword_results: Results from keyword/BM25 search
|
||||
embedding_results: Results from embedding search
|
||||
alpha: Weight for BM25 scores (1-alpha for embedding scores)
|
||||
limit: Maximum number of results to return per category
|
||||
forgetting_config: Configuration for the forgetting engine
|
||||
now: Optional current time override for testing
|
||||
|
||||
Returns:
|
||||
Reranked results with combined and final scores (after forgetting)
|
||||
"""
|
||||
engine = ForgettingEngine(forgetting_config or ForgettingEngineConfig())
|
||||
now_dt = now or datetime.now()
|
||||
|
||||
reranked: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
for category in ["statements", "chunks", "entities","summaries"]:
|
||||
keyword_items = keyword_results.get(category, [])
|
||||
embedding_items = embedding_results.get(category, [])
|
||||
|
||||
# Normalize scores within each search type
|
||||
keyword_items = normalize_scores(keyword_items, "score")
|
||||
embedding_items = normalize_scores(embedding_items, "score")
|
||||
|
||||
combined_items: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# Combine two result sets by ID
|
||||
for src_items, is_embedding in (
|
||||
(keyword_items, False), (embedding_items, True)
|
||||
):
|
||||
for item in src_items:
|
||||
item_id = item.get("id") or item.get("uuid")
|
||||
if not item_id:
|
||||
continue
|
||||
existing = combined_items.get(item_id)
|
||||
if not existing:
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = 0
|
||||
combined_items[item_id]["embedding_score"] = 0
|
||||
# Update normalized score from the right source
|
||||
if is_embedding:
|
||||
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
else:
|
||||
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
|
||||
# Calculate scores and apply forgetting weights
|
||||
for item_id, item in combined_items.items():
|
||||
bm25_score = float(item.get("bm25_score", 0) or 0)
|
||||
embedding_score = float(item.get("embedding_score", 0) or 0)
|
||||
combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
|
||||
|
||||
# Estimate time elapsed in days
|
||||
dt = _parse_datetime(item.get("created_at"))
|
||||
if dt is None:
|
||||
time_elapsed_days = 0.0
|
||||
else:
|
||||
time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
|
||||
|
||||
# Memory strength (currently set to default value)
|
||||
memory_strength = 1.0
|
||||
forgetting_weight = engine.calculate_weight(
|
||||
time_elapsed=time_elapsed_days, memory_strength=memory_strength
|
||||
)
|
||||
# print(f"Forgetting weight for {item_id}: {forgetting_weight}")
|
||||
# print(f"Time elapsed days for {item_id}: {time_elapsed_days}")
|
||||
final_score = combined_score * forgetting_weight
|
||||
item["combined_score"] = final_score
|
||||
|
||||
sorted_items = sorted(
|
||||
combined_items.values(), key=lambda x: x.get("combined_score", 0), reverse=True
|
||||
)[:limit]
|
||||
|
||||
reranked[category] = sorted_items
|
||||
|
||||
return reranked
|
||||
|
||||
|
||||
def log_search_query(query_text: str, search_type: str, group_id: str | None, limit: int, include: List[str], log_file: str = "search_log.txt"):
|
||||
"""Log search query information to file"""
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
# Ensure the query text is plain and clean before logging
|
||||
cleaned_query = extract_plain_query(query_text)
|
||||
log_entry = {
|
||||
"timestamp": timestamp,
|
||||
# "query": query_text,
|
||||
"query": cleaned_query,
|
||||
"search_type": search_type,
|
||||
"group_id": group_id,
|
||||
"limit": limit,
|
||||
"include": include
|
||||
}
|
||||
|
||||
# Append to log file
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
|
||||
|
||||
logger.info(f"Search logged: {query_text} ({search_type})")
|
||||
|
||||
|
||||
def _remove_keys_recursive(obj: Any, keys_to_remove: List[str]) -> Any:
|
||||
"""Remove specified keys recursively from dict/list structures (in place)."""
|
||||
try:
|
||||
if isinstance(obj, dict):
|
||||
for k in keys_to_remove:
|
||||
if k in obj:
|
||||
obj.pop(k, None)
|
||||
for v in list(obj.values()):
|
||||
_remove_keys_recursive(v, keys_to_remove)
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
_remove_keys_recursive(item, keys_to_remove)
|
||||
except Exception:
|
||||
# Be defensive: never fail search because of sanitization
|
||||
pass
|
||||
return obj
|
||||
|
||||
|
||||
def apply_reranker_placeholder(
|
||||
results: Dict[str, List[Dict[str, Any]]],
|
||||
query_text: str,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Placeholder for a cross-encoder reranker.
|
||||
If config enables reranker, annotate items with a final_score equal to combined_score
|
||||
and keep ordering. This is a no-op reranker to be replaced later.
|
||||
"""
|
||||
try:
|
||||
rc = (RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {}))
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to load reranker config: {e}")
|
||||
rc = {}
|
||||
if not rc or not rc.get("enabled", False):
|
||||
return results
|
||||
|
||||
top_k = int(rc.get("top_k", 100))
|
||||
model_name = rc.get("model", "placeholder")
|
||||
|
||||
for cat, items in results.items():
|
||||
head = items[:top_k]
|
||||
for it in head:
|
||||
base = float(it.get("combined_score", it.get("score", 0.0)) or 0.0)
|
||||
it["final_score"] = base
|
||||
it["reranker_model"] = model_name
|
||||
# Keep overall order by final_score if present, otherwise combined/score
|
||||
results[cat] = sorted(
|
||||
items,
|
||||
key=lambda x: float(x.get("final_score", x.get("combined_score", x.get("score", 0.0)) or 0.0)),
|
||||
reverse=True,
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
async def apply_llm_reranker(
|
||||
results: Dict[str, List[Dict[str, Any]]],
|
||||
query_text: str,
|
||||
reranker_client: Optional[Any] = None,
|
||||
llm_weight: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Apply LLM-based reranking to search results.
|
||||
|
||||
Args:
|
||||
results: Search results organized by category
|
||||
query_text: Original search query
|
||||
reranker_client: Optional pre-initialized reranker client
|
||||
llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM)
|
||||
top_k: Maximum number of items to rerank per category
|
||||
batch_size: Number of items to process concurrently
|
||||
|
||||
Returns:
|
||||
Reranked results with final_score and reranker_model fields
|
||||
"""
|
||||
# Load reranker configuration from runtime.json
|
||||
try:
|
||||
rc = RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {})
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to load reranker config: {e}")
|
||||
rc = {}
|
||||
|
||||
# Check if reranking is enabled
|
||||
enabled = rc.get("enabled", False)
|
||||
if not enabled:
|
||||
logger.debug("LLM reranking is disabled in configuration")
|
||||
return results
|
||||
|
||||
# Load configuration parameters with defaults
|
||||
llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5)
|
||||
top_k = top_k if top_k is not None else rc.get("top_k", 20)
|
||||
batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5)
|
||||
|
||||
# Initialize reranker client if not provided
|
||||
if reranker_client is None:
|
||||
try:
|
||||
reranker_client = get_reranker_client()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking")
|
||||
return results
|
||||
|
||||
# Get model name for metadata
|
||||
model_name = getattr(reranker_client, 'model_name', 'unknown')
|
||||
|
||||
# Process each category
|
||||
reranked_results = {}
|
||||
for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
items = results.get(category, [])
|
||||
if not items:
|
||||
reranked_results[category] = []
|
||||
continue
|
||||
|
||||
# Select top K items by combined_score for reranking
|
||||
sorted_items = sorted(
|
||||
items,
|
||||
key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
top_items = sorted_items[:top_k]
|
||||
remaining_items = sorted_items[top_k:]
|
||||
|
||||
# Extract text content from each item
|
||||
def extract_text(item: Dict[str, Any]) -> str:
|
||||
"""Extract text content from a result item."""
|
||||
# Try different text fields based on category
|
||||
text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or ""
|
||||
return str(text).strip()
|
||||
|
||||
# Batch items for concurrent processing
|
||||
batches = []
|
||||
for i in range(0, len(top_items), batch_size):
|
||||
batch = top_items[i:i + batch_size]
|
||||
batches.append(batch)
|
||||
|
||||
# Process batches concurrently
|
||||
async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Process a batch of items with LLM relevance scoring."""
|
||||
scored_batch = []
|
||||
|
||||
for item in batch:
|
||||
item_text = extract_text(item)
|
||||
|
||||
# Skip items with no text
|
||||
if not item_text:
|
||||
item_copy = item.copy()
|
||||
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
item_copy["final_score"] = combined_score
|
||||
item_copy["llm_relevance_score"] = 0.0
|
||||
item_copy["reranker_model"] = model_name
|
||||
scored_batch.append(item_copy)
|
||||
continue
|
||||
|
||||
# Create relevance scoring prompt
|
||||
prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0.
|
||||
|
||||
Query: {query_text}
|
||||
|
||||
Result: {item_text}
|
||||
|
||||
Respond with only a number between 0.0 and 1.0, where:
|
||||
- 0.0 means completely irrelevant
|
||||
- 1.0 means perfectly relevant
|
||||
|
||||
Relevance score:"""
|
||||
|
||||
# Send request to LLM
|
||||
try:
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
response = await reranker_client.chat(messages)
|
||||
|
||||
# Parse LLM response to extract relevance score
|
||||
response_text = str(response.content if hasattr(response, 'content') else response).strip()
|
||||
|
||||
# Try to extract a float from the response
|
||||
try:
|
||||
# Remove any non-numeric characters except decimal point
|
||||
import re
|
||||
score_match = re.search(r'(\d+\.?\d*)', response_text)
|
||||
if score_match:
|
||||
llm_score = float(score_match.group(1))
|
||||
# Clamp to [0.0, 1.0]
|
||||
llm_score = max(0.0, min(1.0, llm_score))
|
||||
else:
|
||||
raise ValueError("No numeric score found in response")
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}")
|
||||
llm_score = None
|
||||
|
||||
# Calculate final score
|
||||
item_copy = item.copy()
|
||||
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
|
||||
if llm_score is not None:
|
||||
final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score
|
||||
item_copy["llm_relevance_score"] = llm_score
|
||||
else:
|
||||
# Use combined_score as fallback
|
||||
final_score = combined_score
|
||||
item_copy["llm_relevance_score"] = combined_score
|
||||
|
||||
item_copy["final_score"] = final_score
|
||||
item_copy["reranker_model"] = model_name
|
||||
scored_batch.append(item_copy)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing item in LLM reranking: {e}, using combined_score")
|
||||
item_copy = item.copy()
|
||||
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
item_copy["final_score"] = combined_score
|
||||
item_copy["llm_relevance_score"] = combined_score
|
||||
item_copy["reranker_model"] = model_name
|
||||
scored_batch.append(item_copy)
|
||||
|
||||
return scored_batch
|
||||
|
||||
# Process all batches concurrently
|
||||
try:
|
||||
batch_tasks = [process_batch(batch) for batch in batches]
|
||||
batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
|
||||
|
||||
# Merge batch results
|
||||
scored_items = []
|
||||
for result in batch_results:
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"Batch processing failed: {result}")
|
||||
continue
|
||||
scored_items.extend(result)
|
||||
|
||||
# Add remaining items (not in top K) with their combined_score as final_score
|
||||
for item in remaining_items:
|
||||
item_copy = item.copy()
|
||||
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
item_copy["final_score"] = combined_score
|
||||
item_copy["reranker_model"] = model_name
|
||||
scored_items.append(item_copy)
|
||||
|
||||
# Sort all items by final_score in descending order
|
||||
scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True)
|
||||
reranked_results[category] = scored_items
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results")
|
||||
# Return original items with combined_score as final_score
|
||||
for item in items:
|
||||
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
item["final_score"] = combined_score
|
||||
item["reranker_model"] = model_name
|
||||
reranked_results[category] = items
|
||||
|
||||
return reranked_results
|
||||
|
||||
|
||||
async def run_hybrid_search(
|
||||
query_text: str,
|
||||
search_type: str,
|
||||
group_id: str | None,
|
||||
limit: int,
|
||||
include: List[str],
|
||||
output_path: str | None,
|
||||
rerank_alpha: float = 0.6,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
):
|
||||
"""
|
||||
|
||||
Run search with specified type: 'keyword', 'embedding', or 'hybrid'
|
||||
"""
|
||||
# Start overall timing
|
||||
search_start_time = time.time()
|
||||
latency_metrics = {}
|
||||
|
||||
# Clean and normalize the incoming query before use/logging
|
||||
query_text = extract_plain_query(query_text)
|
||||
|
||||
# Validate query is not empty after cleaning
|
||||
if not query_text or not query_text.strip():
|
||||
logger.warning(f"Empty query after cleaning, returning empty results")
|
||||
return {
|
||||
"keyword_search": {},
|
||||
"embedding_search": {},
|
||||
"reranked_results": {},
|
||||
"combined_summary": {
|
||||
"total_keyword_results": 0,
|
||||
"total_embedding_results": 0,
|
||||
"total_reranked_results": 0,
|
||||
"search_query": "",
|
||||
"search_timestamp": datetime.now().isoformat(),
|
||||
"error": "Empty query"
|
||||
}
|
||||
}
|
||||
|
||||
# Log the search query
|
||||
log_search_query(query_text, search_type, group_id, limit, include)
|
||||
|
||||
connector = Neo4jConnector()
|
||||
results = {}
|
||||
|
||||
try:
|
||||
keyword_task = None
|
||||
embedding_task = None
|
||||
|
||||
if search_type in ["keyword", "hybrid"]:
|
||||
# Keyword-based search
|
||||
logger.info("Starting keyword search...")
|
||||
keyword_start = time.time()
|
||||
keyword_task = asyncio.create_task(
|
||||
search_graph(
|
||||
connector=connector,
|
||||
q=query_text,
|
||||
group_id=group_id,
|
||||
limit=limit,
|
||||
include=include
|
||||
)
|
||||
)
|
||||
|
||||
if search_type in ["embedding", "hybrid"]:
|
||||
# Embedding-based search
|
||||
logger.info("Starting embedding search...")
|
||||
embedding_start = time.time()
|
||||
|
||||
# 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig
|
||||
config_load_start = time.time()
|
||||
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
|
||||
rb_config = RedBearModelConfig(
|
||||
model_name=embedder_config_dict["model_name"],
|
||||
provider=embedder_config_dict["provider"],
|
||||
api_key=embedder_config_dict["api_key"],
|
||||
base_url=embedder_config_dict["base_url"],
|
||||
type="llm"
|
||||
)
|
||||
config_load_time = time.time() - config_load_start
|
||||
logger.info(f"Config loading took {config_load_time:.4f}s")
|
||||
|
||||
# Init embedder
|
||||
embedder_init_start = time.time()
|
||||
embedder = OpenAIEmbedderClient(model_config=rb_config)
|
||||
embedder_init_time = time.time() - embedder_init_start
|
||||
logger.info(f"Embedder init took {embedder_init_time:.4f}s")
|
||||
|
||||
embedding_task = asyncio.create_task(
|
||||
search_graph_by_embedding(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=query_text,
|
||||
group_id=group_id,
|
||||
limit=limit,
|
||||
include=include,
|
||||
)
|
||||
)
|
||||
|
||||
if keyword_task:
|
||||
keyword_results = await keyword_task
|
||||
keyword_latency = time.time() - keyword_start
|
||||
latency_metrics["keyword_search_latency"] = round(keyword_latency, 4)
|
||||
logger.info(f"Keyword search completed in {keyword_latency:.4f}s")
|
||||
if search_type == "keyword":
|
||||
results = keyword_results
|
||||
else:
|
||||
results["keyword_search"] = keyword_results
|
||||
|
||||
if embedding_task:
|
||||
embedding_results = await embedding_task
|
||||
embedding_latency = time.time() - embedding_start
|
||||
latency_metrics["embedding_search_latency"] = round(embedding_latency, 4)
|
||||
logger.info(f"Embedding search completed in {embedding_latency:.4f}s")
|
||||
if search_type == "embedding":
|
||||
results = embedding_results
|
||||
else:
|
||||
results["embedding_search"] = embedding_results
|
||||
|
||||
# Merge and rank results for hybrid search
|
||||
if search_type == "hybrid":
|
||||
results["combined_summary"] = {
|
||||
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
|
||||
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
|
||||
"search_query": query_text,
|
||||
"search_timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Apply reranking (optionally with forgetting curve)
|
||||
rerank_start = time.time()
|
||||
if use_forgetting_rerank:
|
||||
# Load forgetting parameters from pipeline config
|
||||
try:
|
||||
pc = get_pipeline_config()
|
||||
forgetting_cfg = pc.forgetting_engine
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to load forgetting config, using defaults: {e}")
|
||||
forgetting_cfg = ForgettingEngineConfig()
|
||||
reranked_results = rerank_with_forgetting_curve(
|
||||
keyword_results=keyword_results,
|
||||
embedding_results=embedding_results,
|
||||
alpha=rerank_alpha,
|
||||
limit=limit,
|
||||
forgetting_config=forgetting_cfg,
|
||||
)
|
||||
else:
|
||||
reranked_results = rerank_hybrid_results(
|
||||
keyword_results=keyword_results,
|
||||
embedding_results=embedding_results,
|
||||
alpha=rerank_alpha, # Configurable weight for BM25 vs embedding
|
||||
limit=limit
|
||||
)
|
||||
rerank_latency = time.time() - rerank_start
|
||||
latency_metrics["reranking_latency"] = round(rerank_latency, 4)
|
||||
logger.info(f"Reranking completed in {rerank_latency:.4f}s")
|
||||
|
||||
# Optional: apply reranker placeholder if enabled via config
|
||||
reranked_results = apply_reranker_placeholder(reranked_results, query_text)
|
||||
|
||||
# Apply LLM reranking if enabled
|
||||
llm_rerank_applied = False
|
||||
if use_llm_rerank:
|
||||
try:
|
||||
reranked_results = await apply_llm_reranker(
|
||||
results=reranked_results,
|
||||
query_text=query_text,
|
||||
)
|
||||
llm_rerank_applied = True
|
||||
logger.info("LLM reranking applied successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM reranking failed: {e}, using previous scores")
|
||||
|
||||
results["reranked_results"] = reranked_results
|
||||
results["combined_summary"] = {
|
||||
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
|
||||
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
|
||||
"total_reranked_results": sum(len(v) if isinstance(v, list) else 0 for v in reranked_results.values()),
|
||||
"search_query": query_text,
|
||||
"search_timestamp": datetime.now().isoformat(),
|
||||
"reranking_alpha": rerank_alpha,
|
||||
"forgetting_rerank": use_forgetting_rerank,
|
||||
"llm_rerank": llm_rerank_applied,
|
||||
}
|
||||
|
||||
# Calculate total latency
|
||||
total_latency = time.time() - search_start_time
|
||||
latency_metrics["total_latency"] = round(total_latency, 4)
|
||||
|
||||
# Add latency metrics to results
|
||||
if "combined_summary" in results:
|
||||
results["combined_summary"]["latency_metrics"] = latency_metrics
|
||||
else:
|
||||
results["latency_metrics"] = latency_metrics
|
||||
|
||||
logger.info(f"Total search completed in {total_latency:.4f}s")
|
||||
logger.info(f"Latency breakdown: {latency_metrics}")
|
||||
|
||||
# Sanitize results: drop large/unused fields
|
||||
_remove_keys_recursive(results, ["name_embedding"]) # drop entity name embeddings from outputs
|
||||
|
||||
# print(json.dumps(results, ensure_ascii=False, indent=2, default=str))
|
||||
|
||||
# Save to file
|
||||
output_path = output_path or "search_results.json"
|
||||
out_dir = os.path.dirname(output_path)
|
||||
if out_dir:
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(results, f, ensure_ascii=False, indent=2, default=str)
|
||||
logger.info(f"Search results saved to: {output_path}")
|
||||
|
||||
# Log search completion with result count
|
||||
if search_type == "hybrid":
|
||||
result_counts = {
|
||||
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in keyword_results.items()},
|
||||
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in embedding_results.items()}
|
||||
}
|
||||
else:
|
||||
result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()}
|
||||
|
||||
completion_log = {
|
||||
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"query": query_text,
|
||||
"search_type": search_type,
|
||||
"status": "completed",
|
||||
"result_counts": result_counts,
|
||||
"output_file": output_path,
|
||||
"latency_metrics": latency_metrics
|
||||
}
|
||||
|
||||
with open("search_log.txt", "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(completion_log, ensure_ascii=False) + "\n")
|
||||
|
||||
return results
|
||||
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
async def search_by_temporal(
|
||||
group_id: Optional[str] = "test",
|
||||
apply_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
):
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
|
||||
- Matches statements created between start_date and end_date
|
||||
- Optionally filters by group_id
|
||||
- Returns up to 'limit' statements
|
||||
"""
|
||||
connector = Neo4jConnector()
|
||||
if start_date:
|
||||
start_date = normalize_date_safe(start_date)
|
||||
if end_date:
|
||||
end_date = normalize_date_safe(end_date)
|
||||
|
||||
params = TemporalSearchParams.model_validate({
|
||||
"group_id": group_id,
|
||||
"apply_id": apply_id,
|
||||
"user_id": user_id,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"valid_date": valid_date,
|
||||
"invalid_date": invalid_date,
|
||||
"limit": limit,
|
||||
})
|
||||
statements = await search_graph_by_temporal(
|
||||
connector=connector,
|
||||
group_id=params.group_id,
|
||||
apply_id=params.apply_id,
|
||||
user_id=params.user_id,
|
||||
start_date=params.start_date,
|
||||
end_date=params.end_date,
|
||||
valid_date=params.valid_date,
|
||||
invalid_date=params.invalid_date,
|
||||
limit=params.limit
|
||||
)
|
||||
return {"statements": statements}
|
||||
|
||||
|
||||
async def search_by_keyword_temporal(
|
||||
query_text: str,
|
||||
group_id: Optional[str] = "test",
|
||||
apply_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
):
|
||||
"""
|
||||
Temporal keyword search across Statements.
|
||||
"""
|
||||
connector = Neo4jConnector()
|
||||
if start_date:
|
||||
start_date = normalize_date_safe(start_date)
|
||||
if end_date:
|
||||
end_date = normalize_date_safe(end_date)
|
||||
if valid_date:
|
||||
valid_date = normalize_date_safe(valid_date)
|
||||
if invalid_date:
|
||||
invalid_date = normalize_date_safe(invalid_date)
|
||||
|
||||
params = TemporalSearchParams.model_validate({
|
||||
"group_id": group_id,
|
||||
"apply_id": apply_id,
|
||||
"user_id": user_id,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"valid_date": valid_date,
|
||||
"invalid_date": invalid_date,
|
||||
"limit": limit,
|
||||
})
|
||||
statements = await search_graph_by_keyword_temporal(
|
||||
connector=connector,
|
||||
query_text=query_text,
|
||||
group_id=params.group_id,
|
||||
apply_id=params.apply_id,
|
||||
user_id=params.user_id,
|
||||
start_date=params.start_date,
|
||||
end_date=params.end_date,
|
||||
valid_date=params.valid_date,
|
||||
invalid_date=params.invalid_date,
|
||||
limit=params.limit
|
||||
)
|
||||
return {"statements": statements}
|
||||
|
||||
|
||||
async def search_chunk_by_chunk_id(
|
||||
chunk_id: str,
|
||||
group_id: Optional[str] = "test",
|
||||
limit: int = 1,
|
||||
):
|
||||
"""
|
||||
Search for Chunks by chunk_id.
|
||||
"""
|
||||
connector = Neo4jConnector()
|
||||
chunks = await search_graph_by_chunk_id(
|
||||
connector=connector,
|
||||
chunk_id=chunk_id,
|
||||
group_id=group_id,
|
||||
limit=limit
|
||||
)
|
||||
return {"chunks": chunks}
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the hybrid graph search CLI.
|
||||
|
||||
Parses command line arguments and executes search with specified parameters.
|
||||
Supports keyword, embedding, and hybrid search modes.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Hybrid graph search with keyword and embedding options")
|
||||
parser.add_argument(
|
||||
"--query", "-q", required=True, help="Free-text query to search"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--search-type",
|
||||
"-t",
|
||||
choices=["keyword", "embedding", "hybrid"],
|
||||
default="hybrid",
|
||||
help="Search type: keyword (text matching), embedding (semantic), or hybrid (both) (default: hybrid)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-name",
|
||||
"-m",
|
||||
default="openai/nomic-embed-text:v1.5",
|
||||
help="Embedding config name from config.json (default: openai/nomic-embed-text:v1.5)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--group-id",
|
||||
"-g",
|
||||
default=None,
|
||||
help="Optional group_id to filter results (default: None)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--limit",
|
||||
"-k",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Max number of results per type (default: 5)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include",
|
||||
"-i",
|
||||
nargs="+",
|
||||
default=["statements", "chunks", "entities", "summaries"],
|
||||
choices=["statements", "chunks", "entities", "summaries"],
|
||||
help="Which targets to search for embedding search (default: statements chunks entities summaries)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
"-o",
|
||||
default="search_results.json",
|
||||
help="Path to save the search results JSON (default: search_results.json)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rerank-alpha",
|
||||
"-a",
|
||||
type=float,
|
||||
default=0.6,
|
||||
help="Weight for BM25 scores in reranking (0.0-1.0, higher values favor keyword search) (default: 0.6)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--forgetting-rerank",
|
||||
action="store_true",
|
||||
help="Apply forgetting curve during reranking for hybrid search.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm-rerank",
|
||||
action="store_true",
|
||||
help="Apply LLM-based reranking for hybrid search.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(
|
||||
run_hybrid_search(
|
||||
query_text=args.query,
|
||||
search_type=args.search_type,
|
||||
group_id=args.group_id,
|
||||
limit=args.limit,
|
||||
include=args.include,
|
||||
output_path=args.output,
|
||||
rerank_alpha=args.rerank_alpha,
|
||||
use_forgetting_rerank=args.forgetting_rerank,
|
||||
use_llm_rerank=args.llm_rerank,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
8
api/app/core/memory/storage_services/__init__.py
Normal file
8
api/app/core/memory/storage_services/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
存储服务模块
|
||||
|
||||
包含三大引擎:
|
||||
1. 萃取引擎(Extraction Engine)- 知识提取、预处理、去重消歧
|
||||
2. 遗忘引擎(Forgetting Engine)- 记忆遗忘机制
|
||||
3. 自我反思引擎(Reflection Engine)- 自我反思和优化
|
||||
"""
|
||||
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
萃取引擎(Extraction Engine)
|
||||
|
||||
负责从对话数据中提取结构化知识,包括:
|
||||
- 数据预处理
|
||||
- 知识提取(分块、陈述句、三元组、时间信息、嵌入向量)
|
||||
- 去重消歧
|
||||
"""
|
||||
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
数据预处理模块 - 负责对话数据的清洗、转换和预处理
|
||||
|
||||
包含:
|
||||
- data_preprocessor: 数据预处理器 - 读取、清洗和转换对话数据
|
||||
- data_pruning: 语义剪枝器 - 过滤与场景不相关的内容
|
||||
- data_chunker: 数据分块器 - 将对话分割成可处理的片段
|
||||
"""
|
||||
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import DataPreprocessor
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner
|
||||
|
||||
__all__ = ['DataPreprocessor', 'SemanticPruner']
|
||||
@@ -0,0 +1,54 @@
|
||||
"""
|
||||
数据分块器 - 将对话分割成可处理的片段
|
||||
|
||||
功能:
|
||||
- 支持多种分块策略(递归分块、语义分块、LLM分块等)
|
||||
- 根据对话长度和内容特征进行智能分块
|
||||
- 保持对话上下文的连贯性
|
||||
|
||||
注意:此模块当前为占位符,具体实现将在后续任务中完成。
|
||||
分块功能目前在 app/core/memory/llm_tools/chunker_client.py 中实现。
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from app.core.memory.models.message_models import DialogData, Chunk
|
||||
|
||||
|
||||
class DataChunker:
|
||||
"""数据分块器 - 将长对话分割成多个可处理的片段"""
|
||||
|
||||
def __init__(self, chunker_strategy: str = "RecursiveChunker"):
|
||||
"""
|
||||
初始化数据分块器
|
||||
|
||||
Args:
|
||||
chunker_strategy: 分块策略名称
|
||||
"""
|
||||
self.chunker_strategy = chunker_strategy
|
||||
|
||||
async def chunk_dialog(self, dialog: DialogData) -> List[Chunk]:
|
||||
"""
|
||||
将对话分割成多个块
|
||||
|
||||
Args:
|
||||
dialog: 对话数据
|
||||
|
||||
Returns:
|
||||
分块列表
|
||||
|
||||
Note:
|
||||
当前此功能在 app/core/memory/llm_tools/chunker_client.py 中实现
|
||||
"""
|
||||
raise NotImplementedError("数据分块功能将在后续任务中实现")
|
||||
|
||||
async def chunk_dialogs(self, dialogs: List[DialogData]) -> List[DialogData]:
|
||||
"""
|
||||
批量处理多个对话的分块
|
||||
|
||||
Args:
|
||||
dialogs: 对话数据列表
|
||||
|
||||
Returns:
|
||||
包含分块信息的对话数据列表
|
||||
"""
|
||||
raise NotImplementedError("数据分块功能将在后续任务中实现")
|
||||
@@ -0,0 +1,785 @@
|
||||
"""
|
||||
数据预处理器 - 支持多种格式的对话数据读取、清洗和预处理
|
||||
|
||||
功能:
|
||||
- 支持多种文件格式:JSON、CSV、Excel、TXT
|
||||
- 自动检测文件编码
|
||||
- 清洗和标准化对话数据
|
||||
- 转换为 DialogData 对象
|
||||
"""
|
||||
|
||||
import json
|
||||
import csv
|
||||
import pandas as pd
|
||||
import re
|
||||
import os
|
||||
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
|
||||
|
||||
|
||||
class DataPreprocessor:
|
||||
"""数据预处理器类,支持多种格式的对话数据读取、清洗和预处理。"""
|
||||
|
||||
def __init__(self, input_file_path: str = None, output_file_path: str = None):
|
||||
"""
|
||||
初始化数据预处理器。
|
||||
|
||||
Args:
|
||||
input_file_path: 输入文件路径(可选,可后续通过set_input_path设置)
|
||||
output_file_path: 输出文件路径(可选,可后续通过set_output_path设置)
|
||||
|
||||
注意:您可以通过以下方式指定输入输出路径:
|
||||
1. 初始化时传入参数
|
||||
2. 调用set_input_path()和set_output_path()方法
|
||||
3. 在preprocess()方法中直接传入路径参数
|
||||
"""
|
||||
self.input_file_path = input_file_path or r"src\extracted_statements.txt"
|
||||
self.output_file_path = output_file_path or r"src\data_preprocessing\out-file\extracted_statements-pre.txt"
|
||||
self.supported_formats = ['.json', '.csv', '.txt', '.xlsx', '.tsv']
|
||||
|
||||
def set_input_path(self, input_path: str) -> None:
|
||||
"""
|
||||
设置输入文件路径。
|
||||
|
||||
Args:
|
||||
input_path: 输入文件的完整路径
|
||||
"""
|
||||
self.input_file_path = input_path
|
||||
|
||||
def set_output_path(self, output_path: str) -> None:
|
||||
"""
|
||||
设置输出文件路径。
|
||||
|
||||
Args:
|
||||
output_path: 输出文件的完整路径
|
||||
"""
|
||||
self.output_file_path = output_path
|
||||
|
||||
def get_file_format(self, file_path: str) -> str:
|
||||
"""
|
||||
获取文件格式。
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
|
||||
Returns:
|
||||
文件扩展名(小写)
|
||||
"""
|
||||
return Path(file_path).suffix.lower()
|
||||
|
||||
def _detect_encoding(self, file_path: str) -> str:
|
||||
"""
|
||||
检测文件编码,使用多种方法确保准确性。
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
|
||||
Returns:
|
||||
检测到的编码格式
|
||||
"""
|
||||
# 常见编码列表,按优先级排序
|
||||
encodings_to_try = ['utf-8', 'gbk', 'gb2312', 'utf-16', 'latin-1']
|
||||
|
||||
# 首先尝试使用chardet检测
|
||||
try:
|
||||
import chardet
|
||||
with open(file_path, 'rb') as f:
|
||||
raw_data = f.read(10000) # 读取前10KB进行检测
|
||||
result = chardet.detect(raw_data)
|
||||
detected_encoding = result.get('encoding')
|
||||
confidence = result.get('confidence', 0)
|
||||
|
||||
# 如果检测置信度较高,使用检测结果
|
||||
if detected_encoding and confidence > 0.7:
|
||||
return detected_encoding
|
||||
except ImportError:
|
||||
print("警告: chardet库未安装,使用备用编码检测方法")
|
||||
except Exception as e:
|
||||
print(f"chardet检测失败: {e},使用备用方法")
|
||||
|
||||
# 备用方法:尝试不同编码读取文件开头
|
||||
for encoding in encodings_to_try:
|
||||
try:
|
||||
with open(file_path, 'r', encoding=encoding) as f:
|
||||
f.read(1000) # 尝试读取前1000个字符
|
||||
return encoding
|
||||
except (UnicodeDecodeError, UnicodeError):
|
||||
continue
|
||||
|
||||
# 如果所有编码都失败,返回utf-8作为最后选择
|
||||
return 'utf-8'
|
||||
|
||||
def _read_json(self, data_path: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
读取JSON格式的对话数据,支持标准JSON和JSONL格式。
|
||||
|
||||
Args:
|
||||
data_path: JSON文件路径
|
||||
|
||||
Returns:
|
||||
解析后的数据列表
|
||||
"""
|
||||
encoding = self._detect_encoding(data_path)
|
||||
content = None
|
||||
|
||||
# 尝试使用检测到的编码读取文件
|
||||
encodings_to_try = [encoding, 'utf-8', 'gbk', 'gb2312', 'latin-1']
|
||||
|
||||
for enc in encodings_to_try:
|
||||
try:
|
||||
with open(data_path, 'r', encoding=enc) as f:
|
||||
content = f.read().strip()
|
||||
print(f"成功使用编码 {enc} 读取文件")
|
||||
break
|
||||
except (UnicodeDecodeError, UnicodeError) as e:
|
||||
print(f"编码 {enc} 读取失败: {e}")
|
||||
continue
|
||||
|
||||
if content is None:
|
||||
raise ValueError(f"无法使用任何编码读取文件: {data_path}")
|
||||
|
||||
try:
|
||||
|
||||
# 尝试解析为标准JSON
|
||||
try:
|
||||
data = json.loads(content)
|
||||
if isinstance(data, dict):
|
||||
return [data]
|
||||
elif isinstance(data, list):
|
||||
return data
|
||||
else:
|
||||
raise ValueError(f"不支持的JSON数据结构: {type(data)}")
|
||||
except json.JSONDecodeError as e:
|
||||
# 如果标准JSON解析失败,尝试JSONL格式(每行一个JSON对象)
|
||||
print(f"标准JSON解析失败: {e},尝试JSONL格式...")
|
||||
data_list = []
|
||||
lines = content.split('\n')
|
||||
|
||||
for line_num, line in enumerate(lines, 1):
|
||||
line = line.strip()
|
||||
if line: # 跳过空行
|
||||
try:
|
||||
json_obj = json.loads(line)
|
||||
data_list.append(json_obj)
|
||||
except json.JSONDecodeError as line_error:
|
||||
# 如果是单行巨大JSON数组,可能需要特殊处理
|
||||
if line_num == 1 and len(lines) == 1:
|
||||
print(f"检测到单行大型JSON,尝试分块解析...")
|
||||
# 对于超大单行JSON,尝试使用json.JSONDecoder进行流式解析
|
||||
try:
|
||||
decoder = json.JSONDecoder()
|
||||
idx = 0
|
||||
while idx < len(line):
|
||||
line = line[idx:].lstrip()
|
||||
if not line:
|
||||
break
|
||||
try:
|
||||
obj, end_idx = decoder.raw_decode(line)
|
||||
if isinstance(obj, list):
|
||||
data_list.extend(obj)
|
||||
elif isinstance(obj, dict):
|
||||
data_list.append(obj)
|
||||
idx += end_idx
|
||||
except json.JSONDecodeError:
|
||||
break
|
||||
except Exception as decode_error:
|
||||
print(f"分块解析也失败: {decode_error}")
|
||||
else:
|
||||
print(f"警告: 第{line_num}行JSON解析失败: {line_error}")
|
||||
continue
|
||||
|
||||
return data_list
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"读取JSON文件时发生错误: {e}")
|
||||
|
||||
def _read_csv(self, data_path: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
读取CSV格式的对话数据。
|
||||
|
||||
Args:
|
||||
data_path: CSV文件路径
|
||||
|
||||
Returns:
|
||||
解析后的数据列表
|
||||
"""
|
||||
encoding = self._detect_encoding(data_path)
|
||||
encodings_to_try = [encoding, 'utf-8', 'gbk', 'gb2312', 'latin-1']
|
||||
|
||||
for enc in encodings_to_try:
|
||||
try:
|
||||
# 尝试不同的分隔符
|
||||
separators = [',', '\t', ';', '|']
|
||||
df = None
|
||||
|
||||
for sep in separators:
|
||||
try:
|
||||
df = pd.read_csv(data_path, encoding=enc, sep=sep)
|
||||
if len(df.columns) > 1: # 如果成功分割出多列,则认为找到了正确的分隔符
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if df is None:
|
||||
df = pd.read_csv(data_path, encoding=enc)
|
||||
|
||||
print(f"成功使用编码 {enc} 读取CSV文件")
|
||||
return df.to_dict('records')
|
||||
|
||||
except (UnicodeDecodeError, UnicodeError) as e:
|
||||
print(f"编码 {enc} 读取CSV失败: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
if enc == encodings_to_try[-1]: # 最后一个编码也失败了
|
||||
raise ValueError(f"读取CSV文件失败: {e}")
|
||||
continue
|
||||
|
||||
raise ValueError(f"无法使用任何编码读取CSV文件: {data_path}")
|
||||
|
||||
def _read_excel(self, data_path: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
读取Excel格式的对话数据。
|
||||
|
||||
Args:
|
||||
data_path: Excel文件路径
|
||||
|
||||
Returns:
|
||||
解析后的数据列表
|
||||
"""
|
||||
try:
|
||||
df = pd.read_excel(data_path)
|
||||
return df.to_dict('records')
|
||||
except Exception as e:
|
||||
raise ValueError(f"读取Excel文件失败: {e}")
|
||||
|
||||
def _read_text(self, data_path: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
读取纯文本格式的对话数据。
|
||||
|
||||
Args:
|
||||
data_path: 文本文件路径
|
||||
|
||||
Returns:
|
||||
解析后的数据列表
|
||||
"""
|
||||
encoding = self._detect_encoding(data_path)
|
||||
encodings_to_try = [encoding, 'utf-8', 'gbk', 'gb2312', 'latin-1']
|
||||
content = None
|
||||
|
||||
# 尝试使用不同编码读取文件
|
||||
for enc in encodings_to_try:
|
||||
try:
|
||||
with open(data_path, 'r', encoding=enc) as f:
|
||||
content = f.read()
|
||||
print(f"成功使用编码 {enc} 读取文本文件")
|
||||
break
|
||||
except (UnicodeDecodeError, UnicodeError) as e:
|
||||
print(f"编码 {enc} 读取文本失败: {e}")
|
||||
continue
|
||||
|
||||
if content is None:
|
||||
raise ValueError(f"无法使用任何编码读取文本文件: {data_path}")
|
||||
|
||||
try:
|
||||
|
||||
# 尝试解析不同的文本格式
|
||||
lines = content.strip().split('\n')
|
||||
|
||||
# 格式1: 每行一个对话轮次,格式为 "角色: 内容" 或 "角色:内容"
|
||||
messages = []
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# 尝试匹配 "角色: 内容" 或 "角色:内容" 格式
|
||||
match = re.match(r'^([^::]+)[::]\s*(.+)$', line)
|
||||
if match:
|
||||
role, msg = match.groups()
|
||||
messages.append({'role': role.strip(), 'msg': msg.strip()})
|
||||
else:
|
||||
# 如果不匹配,则作为用户消息处理
|
||||
messages.append({'role': 'User', 'msg': line})
|
||||
|
||||
if messages:
|
||||
return [{'context': {'msgs': messages}}]
|
||||
else:
|
||||
# 如果没有解析出消息,则将整个文本作为一条消息
|
||||
return [{'context': {'msgs': [{'role': 'User', 'msg': content}]}}]
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"读取文本文件失败: {e}")
|
||||
|
||||
def read_data(self, data_path: str = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
根据文件格式自动选择合适的读取方法。
|
||||
|
||||
Args:
|
||||
data_path: 数据文件路径(如果为None,则使用初始化时设置的路径)
|
||||
|
||||
Returns:
|
||||
解析后的原始数据列表
|
||||
"""
|
||||
if data_path is None:
|
||||
data_path = self.input_file_path
|
||||
|
||||
if not data_path:
|
||||
raise ValueError("请指定输入文件路径")
|
||||
|
||||
if not os.path.exists(data_path):
|
||||
raise FileNotFoundError(f"文件不存在: {data_path}")
|
||||
|
||||
file_format = self.get_file_format(data_path)
|
||||
|
||||
if file_format == '.json':
|
||||
return self._read_json(data_path)
|
||||
elif file_format == '.csv':
|
||||
return self._read_csv(data_path)
|
||||
elif file_format in ['.xlsx', '.xls']:
|
||||
return self._read_excel(data_path)
|
||||
elif file_format in ['.txt', '.tsv']:
|
||||
return self._read_text(data_path)
|
||||
else:
|
||||
raise ValueError(f"不支持的文件格式: {file_format}。支持的格式: {self.supported_formats}")
|
||||
|
||||
def _clean_text(self, text: str) -> str:
|
||||
"""
|
||||
增强的文本清洗函数。
|
||||
"""
|
||||
if not text or not isinstance(text, str):
|
||||
return ""
|
||||
|
||||
# 1. 移除消息中的角色标识(支持英文冒号":"与中文冒号":")
|
||||
text = re.sub(r'^(用户|AI|user|ai|assistant|bot|助手|机器人)[::]\s*', '', text, flags=re.IGNORECASE)
|
||||
|
||||
# 2. 移除URL链接
|
||||
text = re.sub(r'https?://[^\s]+', '', text)
|
||||
text = re.sub(r'www\.[^\s]+', '', text)
|
||||
|
||||
# 3. 移除HTML标签
|
||||
text = re.sub(r'<[^>]+>', '', text)
|
||||
|
||||
# 4. 移除乱码和控制字符
|
||||
text = re.sub(r'[<5B>]+', '', text)
|
||||
text = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', text)
|
||||
|
||||
# 5. 标点符号规范化
|
||||
# 将连续的感叹号(中英文)替换为一个句号
|
||||
text = re.sub(r'[!!]+', '。', text)
|
||||
# 将连续的句点/省略号(中英文)替换为一个句号
|
||||
text = re.sub(r'(…{1,}|\.{2,}|。{2,})', '。', text)
|
||||
# 将英文句点统一为中文句号(避免残留英文句点影响断句)
|
||||
text = re.sub(r'\.', '。', text)
|
||||
# 将连续的逗号(中英文)规范为一个中文逗号
|
||||
text = re.sub(r'[,,]{2,}', ',', text)
|
||||
# 将英文逗号统一为中文逗号
|
||||
text = re.sub(r',', ',', text)
|
||||
|
||||
# 6. 规范化空白字符
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
text = text.strip()
|
||||
|
||||
return text
|
||||
|
||||
def _parse_message_content(self, content: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
增强的消息内容解析。
|
||||
"""
|
||||
messages = []
|
||||
|
||||
# 先清洗内容
|
||||
cleaned_content = self._clean_text(content)
|
||||
|
||||
if not cleaned_content:
|
||||
return messages
|
||||
|
||||
# 检查是否为有效消息(至少包含中文或英文单词)
|
||||
if not re.search(r'[\u4e00-\u9fff\w]', cleaned_content):
|
||||
return messages
|
||||
|
||||
# 根据内容特征判断角色(更智能的角色识别)
|
||||
if re.search(r'(你好|嗨|早上好|晚上好|请问|谢谢|抱歉)', cleaned_content):
|
||||
role = 'User'
|
||||
elif re.search(r'(很高兴|建议|推荐|可以帮助|请提供)', cleaned_content):
|
||||
role = 'Assistant'
|
||||
else:
|
||||
role = 'User' # 默认
|
||||
|
||||
messages.append({'role': role, 'msg': cleaned_content})
|
||||
|
||||
return messages
|
||||
|
||||
def _filter_empty_messages(self, messages: List[ConversationMessage]) -> List[ConversationMessage]:
|
||||
"""
|
||||
更严格的空消息过滤。
|
||||
"""
|
||||
filtered = []
|
||||
for msg in messages:
|
||||
# 检查消息是否有效
|
||||
if (msg.msg and
|
||||
isinstance(msg.msg, str) and
|
||||
len(msg.msg.strip()) >= 2 and # 至少2个字符
|
||||
re.search(r'[\u4e00-\u9fff\w]', msg.msg)): # 包含有效字符
|
||||
filtered.append(msg)
|
||||
return filtered
|
||||
|
||||
|
||||
def _normalize_role(self, role: str) -> str:
|
||||
"""
|
||||
标准化角色名称。
|
||||
|
||||
Args:
|
||||
role: 原始角色名称
|
||||
|
||||
Returns:
|
||||
标准化后的角色名称
|
||||
"""
|
||||
if not role or not isinstance(role, str):
|
||||
return "User"
|
||||
|
||||
role = role.strip().lower()
|
||||
|
||||
# 用户角色的各种表示
|
||||
user_roles = ['user', 'human', '用户', '人类', 'customer', '客户', 'u']
|
||||
# AI角色的各种表示
|
||||
ai_roles = ['assistant', 'ai', 'bot', 'chatbot', '助手', '机器人', 'system', 'a']
|
||||
|
||||
if role in user_roles:
|
||||
return "User"
|
||||
elif role in ai_roles:
|
||||
return "Assistant"
|
||||
else:
|
||||
return "User" # 默认为用户
|
||||
|
||||
def clean_data(self, raw_data: List[Dict[str, Any]], skip_cleaning: bool = True) -> List[DialogData]:
|
||||
"""
|
||||
清洗原始数据并转换为DialogData对象。
|
||||
|
||||
Args:
|
||||
raw_data: 原始数据列表
|
||||
skip_cleaning: 是否跳过数据清洗,直接转换为DialogData对象(默认False)
|
||||
|
||||
Returns:
|
||||
清洗后的DialogData对象列表
|
||||
"""
|
||||
if skip_cleaning:
|
||||
print("跳过数据清洗步骤,直接转换数据...")
|
||||
return self._convert_to_dialog_data(raw_data)
|
||||
|
||||
cleaned_dialogs = []
|
||||
|
||||
for i, item in enumerate(raw_data):
|
||||
conv_date: Optional[str] = None
|
||||
try:
|
||||
# 提取对话消息
|
||||
messages = []
|
||||
|
||||
# 处理不同的数据结构
|
||||
if 'content' in item and isinstance(item['content'], list):
|
||||
# 新格式:dialog_release_zh.json格式,content是字符串数组
|
||||
content_list = item['content']
|
||||
for j, content_text in enumerate(content_list):
|
||||
# 交替分配角色:偶数索引为用户,奇数索引为AI
|
||||
role = 'User' if j % 2 == 0 else 'Assistant'
|
||||
normalized_role = self._normalize_role(role)
|
||||
|
||||
# 清洗消息内容
|
||||
cleaned_content = self._clean_text(str(content_text))
|
||||
|
||||
# 过滤空消息
|
||||
if cleaned_content:
|
||||
messages.append(ConversationMessage(role=normalized_role, msg=cleaned_content))
|
||||
|
||||
elif 'context' in item and isinstance(item['context'], dict) and 'msgs' in item['context']:
|
||||
# 标准格式:context是字典且包含msgs
|
||||
raw_messages = item['context']['msgs']
|
||||
elif 'context' in item and isinstance(item['context'], str):
|
||||
# testdata.json格式:context是字符串,需要解析对话内容
|
||||
context_text = item['context']
|
||||
# 从context文本中解析绝对日期并存入conv_date(格式:YYYY-MM-DD)
|
||||
m = re.search(r"(\d{4})年(\d{1,2})月(\d{1,2})日", context_text)
|
||||
if m:
|
||||
y, mo, d = int(m.group(1)), int(m.group(2)), int(m.group(3))
|
||||
conv_date = f"{y:04d}-{mo:02d}-{d:02d}"
|
||||
else:
|
||||
m = re.search(r"(\d{4})[-/](\d{1,2})[-/](\d{1,2})", context_text)
|
||||
if m:
|
||||
y, mo, d = int(m.group(1)), int(m.group(2)), int(m.group(3))
|
||||
conv_date = f"{y:04d}-{mo:02d}-{d:02d}"
|
||||
messages = self._parse_context_string(context_text)
|
||||
elif 'messages' in item:
|
||||
# 另一种常见格式
|
||||
raw_messages = item['messages']
|
||||
elif 'conversation' in item:
|
||||
# 对话格式
|
||||
raw_messages = item['conversation']
|
||||
else:
|
||||
# 尝试直接解析
|
||||
raw_messages = [item] if 'role' in item and 'msg' in item else []
|
||||
|
||||
# 如果messages还是空的,说明需要处理raw_messages
|
||||
if not messages and 'raw_messages' in locals():
|
||||
# 清洗每条消息
|
||||
for msg_data in raw_messages:
|
||||
if isinstance(msg_data, dict):
|
||||
role = self._normalize_role(msg_data.get('role', 'User'))
|
||||
content = msg_data.get('msg', msg_data.get('content', msg_data.get('message', '')))
|
||||
|
||||
# 清洗消息内容
|
||||
cleaned_content = self._clean_text(str(content))
|
||||
|
||||
# 过滤空消息
|
||||
if cleaned_content:
|
||||
messages.append(ConversationMessage(role=role, msg=cleaned_content))
|
||||
|
||||
# 过滤空对话
|
||||
if not messages:
|
||||
continue
|
||||
|
||||
# 去重相邻的重复消息
|
||||
deduplicated_messages = []
|
||||
for msg in messages:
|
||||
if not deduplicated_messages or (
|
||||
deduplicated_messages[-1].role != msg.role or
|
||||
deduplicated_messages[-1].msg != msg.msg
|
||||
):
|
||||
deduplicated_messages.append(msg)
|
||||
|
||||
# 创建DialogData对象
|
||||
context = ConversationContext(msgs=deduplicated_messages)
|
||||
# 获取对话ID,优先使用dialog_id,然后是ref_id、id,最后生成默认ID
|
||||
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
|
||||
|
||||
|
||||
# 获取group_id,如果不存在则生成默认值
|
||||
group_id = item.get('group_id', f'group_default_{i}')
|
||||
user_id = item.get('user_id', f'user_default_{i}')
|
||||
apply_id = item.get('apply_id', f'apply_default_{i}')
|
||||
|
||||
|
||||
# 构建元数据,附加解析到的会话日期
|
||||
metadata = {
|
||||
**item.get('metadata', {}),
|
||||
'document_id': str(item.get('document_id', 'unknown')) if item.get('document_id') is not None else 'unknown',
|
||||
'original_format': 'dialog_release_zh' if 'content' in item and isinstance(item['content'], list) else 'testdata'
|
||||
}
|
||||
if conv_date:
|
||||
metadata['conversation_date'] = conv_date
|
||||
metadata['publication_date'] = conv_date
|
||||
|
||||
dialog_data = DialogData(
|
||||
context=context,
|
||||
ref_id=dialog_id,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
cleaned_dialogs.append(dialog_data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"警告: 处理第{i+1}条数据时出错: {e}")
|
||||
continue
|
||||
|
||||
return cleaned_dialogs
|
||||
|
||||
def _convert_to_dialog_data(self, raw_data: List[Dict[str, Any]]) -> List[DialogData]:
|
||||
"""
|
||||
直接将原始数据转换为DialogData对象,不进行清洗。
|
||||
|
||||
Args:
|
||||
raw_data: 原始数据列表
|
||||
|
||||
Returns:
|
||||
DialogData对象列表
|
||||
"""
|
||||
dialog_list = []
|
||||
|
||||
for i, item in enumerate(raw_data):
|
||||
try:
|
||||
messages = []
|
||||
|
||||
# 处理不同的数据结构
|
||||
if 'content' in item and isinstance(item['content'], list):
|
||||
content_list = item['content']
|
||||
for j, content_text in enumerate(content_list):
|
||||
role = 'User' if j % 2 == 0 else 'Assistant'
|
||||
if content_text:
|
||||
messages.append(ConversationMessage(role=role, msg=str(content_text)))
|
||||
|
||||
elif 'context' in item and isinstance(item['context'], dict) and 'msgs' in item['context']:
|
||||
raw_messages = item['context']['msgs']
|
||||
for msg_data in raw_messages:
|
||||
if isinstance(msg_data, dict):
|
||||
role = msg_data.get('role', 'User')
|
||||
content = msg_data.get('msg', msg_data.get('content', msg_data.get('message', '')))
|
||||
if content:
|
||||
messages.append(ConversationMessage(role=role, msg=str(content)))
|
||||
|
||||
elif 'context' in item and isinstance(item['context'], str):
|
||||
# 尝试解析结构化对话,如果失败则作为单条用户消息处理
|
||||
messages = self._parse_context_string(item['context'])
|
||||
if not messages:
|
||||
# 如果没有解析出结构化消息,将整个context作为用户消息
|
||||
context_text = item['context'].strip()
|
||||
if context_text:
|
||||
messages.append(ConversationMessage(role='User', msg=context_text))
|
||||
|
||||
elif 'messages' in item:
|
||||
raw_messages = item['messages']
|
||||
for msg_data in raw_messages:
|
||||
if isinstance(msg_data, dict):
|
||||
role = msg_data.get('role', 'User')
|
||||
content = msg_data.get('msg', msg_data.get('content', msg_data.get('message', '')))
|
||||
if content:
|
||||
messages.append(ConversationMessage(role=role, msg=str(content)))
|
||||
|
||||
if not messages:
|
||||
continue
|
||||
|
||||
context = ConversationContext(msgs=messages)
|
||||
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
|
||||
group_id = item.get('group_id', f'group_default_{i}')
|
||||
user_id = item.get('user_id', f'user_default_{i}')
|
||||
apply_id = item.get('apply_id', f'apply_default_{i}')
|
||||
|
||||
metadata = {
|
||||
**item.get('metadata', {}),
|
||||
'document_id': str(item.get('document_id', 'unknown')) if item.get('document_id') is not None else 'unknown',
|
||||
'original_format': 'raw'
|
||||
}
|
||||
|
||||
dialog_data = DialogData(
|
||||
context=context,
|
||||
ref_id=dialog_id,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
dialog_list.append(dialog_data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"警告: 转换第{i+1}条数据时出错: {e}")
|
||||
continue
|
||||
|
||||
return dialog_list
|
||||
|
||||
def _parse_context_string(self, context_text: str) -> List[ConversationMessage]:
|
||||
"""
|
||||
解析context字符串中的对话内容。
|
||||
|
||||
Args:
|
||||
context_text: 包含对话的字符串
|
||||
|
||||
Returns:
|
||||
解析后的ConversationMessage列表
|
||||
"""
|
||||
messages = []
|
||||
|
||||
# 使用正则表达式匹配对话模式
|
||||
# 匹配 "User: 内容" / "用户: 内容" 或 "Assistant: 内容" / "AI: 内容" 格式
|
||||
pattern = r'(User|用户|Assistant|AI|user|assistant)[::]\s*([^\n]+(?:\n(?!(?:User|用户|Assistant|AI|user|assistant)[::])[^\n]*)*?)'
|
||||
matches = re.findall(pattern, context_text, re.MULTILINE | re.DOTALL | re.IGNORECASE)
|
||||
|
||||
for role, content in matches:
|
||||
# 标准化角色名称
|
||||
normalized_role = self._normalize_role(role)
|
||||
|
||||
# 清洗消息内容
|
||||
cleaned_content = self._clean_text(content.strip())
|
||||
|
||||
# 过滤空消息
|
||||
if cleaned_content:
|
||||
messages.append(ConversationMessage(role=normalized_role, msg=cleaned_content))
|
||||
|
||||
return messages
|
||||
|
||||
def save_data(self, dialog_data_list: List[DialogData], output_path: str = None) -> None:
|
||||
"""
|
||||
保存处理后的数据。
|
||||
|
||||
Args:
|
||||
dialog_data_list: DialogData对象列表
|
||||
output_path: 输出文件路径(如果为None,则使用初始化时设置的路径)
|
||||
"""
|
||||
if output_path is None:
|
||||
output_path = self.output_file_path
|
||||
|
||||
if not output_path:
|
||||
raise ValueError("请指定输出文件路径")
|
||||
|
||||
# 确保输出目录存在
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
|
||||
# 转换为可序列化的格式
|
||||
serializable_data = []
|
||||
for dialog in dialog_data_list:
|
||||
serializable_data.append({
|
||||
'id': dialog.id,
|
||||
'ref_id': dialog.ref_id,
|
||||
'created_at': dialog.created_at.isoformat(),
|
||||
'context': {
|
||||
'msgs': [{'role': msg.role, 'msg': msg.msg} for msg in dialog.context.msgs]
|
||||
},
|
||||
'metadata': dialog.metadata,
|
||||
'chunks': []
|
||||
})
|
||||
|
||||
# 保存为JSON格式
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(serializable_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f"数据已保存到: {output_path}")
|
||||
|
||||
def preprocess(self, input_path: str = None, output_path: str = None, skip_cleaning: bool = True, indices: Optional[List[int]] = None) -> List[DialogData]:
|
||||
"""
|
||||
完整的数据预处理流程。
|
||||
|
||||
Args:
|
||||
input_path: 输入文件路径(可选)
|
||||
output_path: 输出文件路径(可选)
|
||||
skip_cleaning: 是否跳过数据清洗步骤(默认False)
|
||||
indices: 要处理的数据索引列表(可选)
|
||||
|
||||
Returns:
|
||||
处理后的DialogData对象列表
|
||||
"""
|
||||
print("开始数据预处理...")
|
||||
|
||||
# 读取原始数据
|
||||
print("正在读取数据...")
|
||||
raw_data = self.read_data(input_path)
|
||||
print(f"成功读取 {len(raw_data)} 条原始数据")
|
||||
|
||||
# 根据索引筛选数据
|
||||
if indices:
|
||||
selected = [raw_data[i] for i in indices if 0 <= i < len(raw_data)]
|
||||
if selected:
|
||||
raw_data = selected
|
||||
print(f"根据索引 {indices} 筛选后,保留 {len(raw_data)} 条数据")
|
||||
else:
|
||||
print(f"警告: 提供的索引 {indices} 筛选为空,处理全部 {len(raw_data)} 条数据")
|
||||
|
||||
# 清洗数据
|
||||
if skip_cleaning:
|
||||
print("跳过数据清洗步骤...")
|
||||
cleaned_data = self.clean_data(raw_data, skip_cleaning=True)
|
||||
else:
|
||||
print("正在清洗数据...")
|
||||
cleaned_data = self.clean_data(raw_data, skip_cleaning=False)
|
||||
print(f"处理完成,得到 {len(cleaned_data)} 条有效对话")
|
||||
|
||||
# 保存数据(如果指定了输出路径)
|
||||
if output_path or self.output_file_path:
|
||||
print("正在保存数据...")
|
||||
self.save_data(cleaned_data, output_path)
|
||||
|
||||
print("数据预处理完成!")
|
||||
return cleaned_data
|
||||
@@ -0,0 +1,573 @@
|
||||
"""
|
||||
语义剪枝器 - 在预处理与分块之间过滤与场景不相关内容
|
||||
|
||||
功能:
|
||||
- 对话级一次性抽取判定相关性
|
||||
- 仅对"不相关对话"的消息按比例删除
|
||||
- 重要信息(时间、编号、金额、联系方式、地址等)优先保留
|
||||
"""
|
||||
|
||||
import os
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext
|
||||
from app.core.memory.models.config_models import PruningConfig
|
||||
from app.core.memory.utils.config.config_utils import get_pruning_config
|
||||
from app.core.memory.utils.prompt.prompt_utils import prompt_env, log_prompt_rendering, log_template_rendering
|
||||
|
||||
|
||||
class DialogExtractionResponse(BaseModel):
|
||||
"""对话级一次性抽取的结构化返回,用于加速剪枝。
|
||||
|
||||
- is_related:对话与场景的相关性判定。
|
||||
- times / ids / amounts / contacts / addresses / keywords:重要信息片段,用来在不相关对话中保留关键消息。
|
||||
"""
|
||||
is_related: bool = Field(...)
|
||||
times: List[str] = Field(default_factory=list)
|
||||
ids: List[str] = Field(default_factory=list)
|
||||
amounts: List[str] = Field(default_factory=list)
|
||||
contacts: List[str] = Field(default_factory=list)
|
||||
addresses: List[str] = Field(default_factory=list)
|
||||
keywords: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class SemanticPruner:
|
||||
"""语义剪枝:在预处理与分块之间过滤与场景不相关内容。
|
||||
|
||||
采用对话级一次性抽取判定相关性;仅对"不相关对话"的消息按比例删除,
|
||||
重要信息(时间、编号、金额、联系方式、地址等)优先保留。
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[PruningConfig] = None, llm_client=None):
|
||||
cfg_dict = get_pruning_config() if config is None else config.model_dump()
|
||||
self.config = PruningConfig.model_validate(cfg_dict)
|
||||
self.llm_client = llm_client
|
||||
# Load Jinja2 template
|
||||
self.template = prompt_env.get_template("extracat_Pruning.jinja2")
|
||||
# 对话抽取缓存:避免同一对话重复调用 LLM / 重复渲染
|
||||
self._dialog_extract_cache: dict[str, DialogExtractionResponse] = {}
|
||||
# 运行日志:收集关键终端输出,便于写入 JSON
|
||||
self.run_logs: List[str] = []
|
||||
# 采用顺序处理,移除并发配置以简化与稳定执行
|
||||
|
||||
def _is_important_message(self, message: ConversationMessage) -> bool:
|
||||
"""基于启发式规则识别重要信息消息,优先保留。
|
||||
|
||||
- 含日期/时间(如YYYY-MM-DD、HH:MM、2024年11月10日、上午/下午)。
|
||||
- 含编号/ID/订单号/申请号/账号/电话/金额等关键字段。
|
||||
- 关键词:"时间"、"日期"、"编号"、"订单"、"流水"、"金额"、"¥"、"元"、"电话"、"手机号"、"邮箱"、"地址"。
|
||||
"""
|
||||
import re
|
||||
text = message.msg.strip()
|
||||
if not text:
|
||||
return False
|
||||
patterns = [
|
||||
r"\b\d{4}-\d{1,2}-\d{1,2}\b",
|
||||
r"\b\d{1,2}:\d{2}\b",
|
||||
r"\d{4}年\d{1,2}月\d{1,2}日",
|
||||
r"上午|下午|AM|PM",
|
||||
r"订单号|工单|申请号|编号|ID|账号|账户",
|
||||
r"电话|手机号|微信|QQ|邮箱",
|
||||
r"地址|地点",
|
||||
r"金额|费用|价格|¥|¥|\d+元",
|
||||
r"时间|日期|有效期|截止",
|
||||
]
|
||||
for p in patterns:
|
||||
if re.search(p, text, flags=re.IGNORECASE):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _importance_score(self, message: ConversationMessage) -> int:
|
||||
"""为重要消息打分,用于在保留比例内优先保留更关键的内容。
|
||||
|
||||
简单启发:匹配到的类别越多、越关键分值越高。
|
||||
"""
|
||||
import re
|
||||
text = message.msg.strip()
|
||||
score = 0
|
||||
weights = [
|
||||
(r"\b\d{4}-\d{1,2}-\d{1,2}\b", 3),
|
||||
(r"\b\d{1,2}:\d{2}\b", 2),
|
||||
(r"\d{4}年\d{1,2}月\d{1,2}日", 3),
|
||||
(r"订单号|工单|申请号|编号|ID|账号|账户", 4),
|
||||
(r"电话|手机号|微信|QQ|邮箱", 3),
|
||||
(r"地址|地点", 2),
|
||||
(r"金额|费用|价格|¥|¥|\d+元", 4),
|
||||
(r"时间|日期|有效期|截止", 2),
|
||||
]
|
||||
for p, w in weights:
|
||||
if re.search(p, text, flags=re.IGNORECASE):
|
||||
score += w
|
||||
return score
|
||||
|
||||
def _is_filler_message(self, message: ConversationMessage) -> bool:
|
||||
"""检测典型寒暄/口头禅/确认类短消息,用于跳过LLM分类以加速。
|
||||
|
||||
满足以下之一视为填充消息:
|
||||
- 纯标点或长度很短(<= 4 个汉字或 <= 8 个字符)且不包含数字或关键实体;
|
||||
- 常见词:你好/您好/在吗/嗯/嗯嗯/哦/好的/好/行/可以/不可以/谢谢/拜拜/再见/哈哈/呵呵/哈哈哈/。。。/??。
|
||||
"""
|
||||
import re
|
||||
t = message.msg.strip()
|
||||
if not t:
|
||||
return True
|
||||
# 常见填充语
|
||||
fillers = [
|
||||
"你好", "您好", "在吗", "嗯", "嗯嗯", "哦", "好的", "好", "行", "可以", "不可以", "谢谢",
|
||||
"拜拜", "再见", "哈哈", "呵呵", "哈哈哈", "。。。", "??", "??"
|
||||
]
|
||||
if t in fillers:
|
||||
return True
|
||||
# 长度与字符类型判断
|
||||
if len(t) <= 8:
|
||||
# 非数字、无关键实体的短文本
|
||||
if not re.search(r"[0-9]", t) and not self._is_important_message(message):
|
||||
# 主要是标点或简单确认词
|
||||
if re.fullmatch(r"[。!?,.!?…·\s]+", t) or t in fillers:
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _extract_dialog_important(self, dialog_text: str) -> DialogExtractionResponse:
|
||||
"""对话级一次性抽取:从整段对话中提取重要信息并判定相关性。
|
||||
|
||||
- 仅使用 LLM 结构化输出;
|
||||
"""
|
||||
# 缓存命中则直接返回(场景+内容作为键)
|
||||
cache_key = f"{self.config.pruning_scene}:" + hashlib.sha1(dialog_text.encode("utf-8")).hexdigest()
|
||||
if cache_key in self._dialog_extract_cache:
|
||||
return self._dialog_extract_cache[cache_key]
|
||||
|
||||
rendered = self.template.render(pruning_scene=self.config.pruning_scene, dialog_text=dialog_text)
|
||||
log_template_rendering("extracat_Pruning.jinja2", {"pruning_scene": self.config.pruning_scene})
|
||||
log_prompt_rendering("pruning-extract", rendered)
|
||||
|
||||
# 强制使用 LLM;移除正则回退
|
||||
if not self.llm_client:
|
||||
raise RuntimeError("llm_client 未配置;请配置 LLM 以进行结构化抽取。")
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个严谨的场景抽取助手,只输出严格 JSON。"},
|
||||
{"role": "user", "content": rendered},
|
||||
]
|
||||
try:
|
||||
ex = await self.llm_client.response_structured(messages, DialogExtractionResponse)
|
||||
self._dialog_extract_cache[cache_key] = ex
|
||||
return ex
|
||||
except Exception as e:
|
||||
raise RuntimeError("LLM 结构化抽取失败;请检查 LLM 配置或重试。") from e
|
||||
|
||||
def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool:
|
||||
"""判断消息是否包含任意抽取到的重要片段。"""
|
||||
if not tokens:
|
||||
return False
|
||||
t = message.msg
|
||||
return any(tok and (tok in t) for tok in tokens)
|
||||
|
||||
async def prune_dialog(self, dialog: DialogData) -> DialogData:
|
||||
"""单对话剪枝:使用一次性对话抽取,避免逐条消息 LLM 调用。
|
||||
|
||||
流程:
|
||||
- 对整段对话进行抽取与相关性判定;若相关则不剪;
|
||||
- 若不相关:用抽取到的重要片段 + 简单启发识别重要消息,按比例删除不相关消息,优先删除不重要,再删除重要(但重要最多按比例)。
|
||||
- 删除策略:不重要消息按出现顺序删除(确定性、无随机)。
|
||||
"""
|
||||
if not self.config.pruning_switch:
|
||||
return dialog
|
||||
|
||||
proportion = float(self.config.pruning_threshold)
|
||||
extraction = await self._extract_dialog_important(dialog.content)
|
||||
if extraction.is_related:
|
||||
# 相关对话不剪枝
|
||||
return dialog
|
||||
|
||||
# 在不相关对话中,识别重要/不重要消息
|
||||
tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords
|
||||
msgs = dialog.context.msgs
|
||||
imp_unrel_msgs: List[ConversationMessage] = []
|
||||
unimp_unrel_msgs: List[ConversationMessage] = []
|
||||
for m in msgs:
|
||||
if self._msg_matches_tokens(m, tokens) or self._is_important_message(m):
|
||||
imp_unrel_msgs.append(m)
|
||||
else:
|
||||
unimp_unrel_msgs.append(m)
|
||||
# 计算总删除目标数量
|
||||
total_unrel = len(msgs)
|
||||
delete_target = int(total_unrel * proportion)
|
||||
if proportion > 0 and total_unrel > 0 and delete_target == 0:
|
||||
delete_target = 1
|
||||
imp_del_cap = min(int(len(imp_unrel_msgs) * proportion), len(imp_unrel_msgs))
|
||||
unimp_del_cap = len(unimp_unrel_msgs)
|
||||
max_capacity = max(0, len(msgs) - 1)
|
||||
max_deletable = min(imp_del_cap + unimp_del_cap, max_capacity)
|
||||
delete_target = min(delete_target, max_deletable)
|
||||
# 删除配额分配
|
||||
del_unimp = min(delete_target, unimp_del_cap)
|
||||
rem = delete_target - del_unimp
|
||||
del_imp = min(rem, imp_del_cap)
|
||||
|
||||
# 选取删除集合
|
||||
unimp_delete_ids = []
|
||||
imp_delete_ids = []
|
||||
if del_unimp > 0:
|
||||
# 按出现顺序选取前 del_unimp 条不重要消息进行删除(确定性、可复现)
|
||||
unimp_delete_ids = [id(m) for m in unimp_unrel_msgs[:del_unimp]]
|
||||
if del_imp > 0:
|
||||
imp_sorted = sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))
|
||||
imp_delete_ids = [id(m) for m in imp_sorted[:del_imp]]
|
||||
|
||||
# 统计实际删除数量(重要/不重要)
|
||||
actual_unimp_deleted = 0
|
||||
actual_imp_deleted = 0
|
||||
kept_msgs = []
|
||||
delete_targets = set(unimp_delete_ids) | set(imp_delete_ids)
|
||||
for m in msgs:
|
||||
mid = id(m)
|
||||
if mid in delete_targets:
|
||||
if mid in set(unimp_delete_ids) and actual_unimp_deleted < del_unimp:
|
||||
actual_unimp_deleted += 1
|
||||
continue
|
||||
if mid in set(imp_delete_ids) and actual_imp_deleted < del_imp:
|
||||
actual_imp_deleted += 1
|
||||
continue
|
||||
kept_msgs.append(m)
|
||||
if not kept_msgs and msgs:
|
||||
kept_msgs = [msgs[0]]
|
||||
|
||||
deleted_total = actual_unimp_deleted + actual_imp_deleted
|
||||
self._log(
|
||||
f"[剪枝-对话] 对话ID={dialog.id} 总消息={len(msgs)} 删除目标={delete_target} 实删={deleted_total} 保留={len(kept_msgs)}"
|
||||
)
|
||||
|
||||
dialog.context = ConversationContext(msgs=kept_msgs)
|
||||
return dialog
|
||||
|
||||
async def prune_dataset(self, dialogs: List[DialogData]) -> List[DialogData]:
|
||||
"""数据集层面:全局消息级剪枝,保留所有对话。
|
||||
|
||||
- 仅在"不相关对话"的范围内执行消息剪枝;相关对话不动。
|
||||
- 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留。
|
||||
- 删除总量 = 阈值 * 全部不相关可删消息数,按可删容量比例分配;顺序删除。
|
||||
- 保证每段对话至少保留1条消息,不会删除整段对话。
|
||||
"""
|
||||
# 如果剪枝功能关闭,直接返回原始数据集。
|
||||
if not self.config.pruning_switch:
|
||||
return dialogs
|
||||
|
||||
# 阈值保护:最高0.9
|
||||
proportion = float(self.config.pruning_threshold)
|
||||
if proportion > 0.9:
|
||||
print(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9")
|
||||
proportion = 0.9
|
||||
if proportion < 0.0:
|
||||
proportion = 0.0
|
||||
evaluated_dialogs = [] # list of dicts: {dialog, is_related}
|
||||
|
||||
self._log(
|
||||
f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch}"
|
||||
)
|
||||
# 对话级相关性分类(一次性对整段对话文本进行判断,顺序执行并复用缓存)
|
||||
evaluated_dialogs = []
|
||||
for idx, dd in enumerate(dialogs):
|
||||
try:
|
||||
ex = await self._extract_dialog_important(dd.content)
|
||||
evaluated_dialogs.append({
|
||||
"dialog": dd,
|
||||
"is_related": bool(ex.is_related),
|
||||
"index": idx,
|
||||
"extraction": ex
|
||||
})
|
||||
except Exception:
|
||||
evaluated_dialogs.append({
|
||||
"dialog": dd,
|
||||
"is_related": True,
|
||||
"index": idx,
|
||||
"extraction": None
|
||||
})
|
||||
|
||||
# 统计相关 / 不相关对话
|
||||
not_related_dialogs = [d for d in evaluated_dialogs if not d["is_related"]]
|
||||
related_dialogs = [d for d in evaluated_dialogs if d["is_related"]]
|
||||
self._log(
|
||||
f"[剪枝-数据集] 相关对话数={len(related_dialogs)} 不相关对话数={len(not_related_dialogs)}"
|
||||
)
|
||||
|
||||
# 简洁打印第几段对话相关/不相关(索引基于1)
|
||||
def _fmt_indices(items, cap: int = 10):
|
||||
inds = [i["index"] + 1 for i in items]
|
||||
if len(inds) <= cap:
|
||||
return inds
|
||||
# 超过上限时只打印前cap个,并标注总数
|
||||
return inds[:cap] + ["...", f"共{len(inds)}个"]
|
||||
|
||||
rel_inds = _fmt_indices(related_dialogs)
|
||||
nrel_inds = _fmt_indices(not_related_dialogs)
|
||||
self._log(f"[剪枝-数据集] 相关对话:第{rel_inds}段;不相关对话:第{nrel_inds}段")
|
||||
|
||||
result: List[DialogData] = []
|
||||
if not_related_dialogs:
|
||||
# 为每个不相关对话进行一次性抽取,识别重要/不重要(避免逐条 LLM)
|
||||
per_dialog_info = {}
|
||||
total_unrelated = 0
|
||||
total_capacity = 0
|
||||
for d in not_related_dialogs:
|
||||
dd = d["dialog"]
|
||||
extraction = d.get("extraction")
|
||||
if extraction is None:
|
||||
extraction = await self._extract_dialog_important(dd.content)
|
||||
# 合并所有重要标记
|
||||
tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords
|
||||
msgs = dd.context.msgs
|
||||
# 分类消息
|
||||
imp_unrel_msgs = [m for m in msgs if self._msg_matches_tokens(m, tokens) or self._is_important_message(m)]
|
||||
unimp_unrel_msgs = [m for m in msgs if m not in imp_unrel_msgs]
|
||||
# 重要消息按重要性排序
|
||||
imp_sorted_ids = [id(m) for m in sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))]
|
||||
info = {
|
||||
"dialog": dd,
|
||||
"total_msgs": len(msgs),
|
||||
"unrelated_count": len(msgs),
|
||||
"imp_ids_sorted": imp_sorted_ids,
|
||||
"unimp_ids": [id(m) for m in unimp_unrel_msgs],
|
||||
}
|
||||
per_dialog_info[d["index"]] = info
|
||||
total_unrelated += info["unrelated_count"]
|
||||
# 全局删除配额:比例作用于全部不相关消息(重要+不重要)
|
||||
global_delete = int(total_unrelated * proportion)
|
||||
if proportion > 0 and total_unrelated > 0 and global_delete == 0:
|
||||
global_delete = 1
|
||||
# 每段的最大可删容量:不重要全部 + 重要最多删除 floor(len(重要)*比例),且至少保留1条消息
|
||||
capacities = []
|
||||
for d in not_related_dialogs:
|
||||
idx = d["index"]
|
||||
info = per_dialog_info[idx]
|
||||
# 统计重要数量
|
||||
imp_count = len(info["imp_ids_sorted"])
|
||||
unimp_count = len(info["unimp_ids"])
|
||||
imp_cap = int(imp_count * proportion)
|
||||
cap = min(unimp_count + imp_cap, max(0, info["total_msgs"] - 1))
|
||||
capacities.append(cap)
|
||||
total_capacity = sum(capacities)
|
||||
if global_delete > total_capacity:
|
||||
print(f"[剪枝-数据集] 不相关消息总数={total_unrelated},目标删除={global_delete},最大可删={total_capacity}(重要消息按比例保留)。将按最大可删执行。")
|
||||
global_delete = total_capacity
|
||||
|
||||
# 配额分配:按不相关消息占比分配到各对话,但不超过各自容量
|
||||
alloc = []
|
||||
for i, d in enumerate(not_related_dialogs):
|
||||
idx = d["index"]
|
||||
info = per_dialog_info[idx]
|
||||
share = int(global_delete * (info["unrelated_count"] / total_unrelated)) if total_unrelated > 0 else 0
|
||||
alloc.append(min(share, capacities[i]))
|
||||
allocated = sum(alloc)
|
||||
rem = global_delete - allocated
|
||||
turn = 0
|
||||
while rem > 0 and turn < 100000:
|
||||
progressed = False
|
||||
for i in range(len(not_related_dialogs)):
|
||||
if rem <= 0:
|
||||
break
|
||||
if alloc[i] < capacities[i]:
|
||||
alloc[i] += 1
|
||||
rem -= 1
|
||||
progressed = True
|
||||
if not progressed:
|
||||
break
|
||||
turn += 1
|
||||
|
||||
# 应用删除:相关对话不动;不相关按分配先删不重要,再删重要(低分优先)
|
||||
total_deleted_confirm = 0
|
||||
for d in evaluated_dialogs:
|
||||
dd = d["dialog"]
|
||||
msgs = dd.context.msgs
|
||||
original = len(msgs)
|
||||
if d["is_related"]:
|
||||
result.append(dd)
|
||||
continue
|
||||
idx_in_unrel = next((k for k, x in enumerate(not_related_dialogs) if x["index"] == d["index"]), None)
|
||||
if idx_in_unrel is None:
|
||||
result.append(dd)
|
||||
continue
|
||||
quota = alloc[idx_in_unrel]
|
||||
info = per_dialog_info[d["index"]]
|
||||
# 计算本对话重要最多可删数量
|
||||
imp_count = len(info["imp_ids_sorted"])
|
||||
imp_del_cap = int(imp_count * proportion)
|
||||
# 先构造顺序删除的"不重要ID集合"(按出现顺序前 quota 条)
|
||||
unimp_delete_ids = set(info["unimp_ids"][:min(quota, len(info["unimp_ids"]))])
|
||||
del_unimp = min(quota, len(unimp_delete_ids))
|
||||
rem_quota = quota - del_unimp
|
||||
# 再从重要里选低分优先的删除ID(不超过 imp_del_cap)
|
||||
imp_delete_ids = set(info["imp_ids_sorted"][:min(rem_quota, imp_del_cap)])
|
||||
deleted_here = 0
|
||||
actual_unimp_deleted = 0
|
||||
actual_imp_deleted = 0
|
||||
kept = []
|
||||
for m in msgs:
|
||||
mid = id(m)
|
||||
if mid in unimp_delete_ids and actual_unimp_deleted < del_unimp:
|
||||
actual_unimp_deleted += 1
|
||||
deleted_here += 1
|
||||
continue
|
||||
if mid in imp_delete_ids and actual_imp_deleted < len(imp_delete_ids):
|
||||
actual_imp_deleted += 1
|
||||
deleted_here += 1
|
||||
continue
|
||||
kept.append(m)
|
||||
if not kept and msgs:
|
||||
kept = [msgs[0]]
|
||||
dd.context.msgs = kept
|
||||
total_deleted_confirm += deleted_here
|
||||
self._log(
|
||||
f"[剪枝-对话] 对话 {d['index']+1} 总消息={original} 分配删除={quota} 实删={deleted_here} 保留={len(kept)}"
|
||||
)
|
||||
result.append(dd)
|
||||
self._log(f"[剪枝-数据集] 全局消息级顺序剪枝完成,总删除 {total_deleted_confirm} 条(不相关消息,重要按比例保留)。")
|
||||
else:
|
||||
# 全部相关:不执行剪枝
|
||||
result = [d["dialog"] for d in evaluated_dialogs]
|
||||
self._log(f"[剪枝-数据集] 剩余对话数={len(result)}")
|
||||
|
||||
# 将本次剪枝阶段的终端输出保存为 JSON 文件(仅在剪枝器内部完成)
|
||||
try:
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
log_output_path = settings.get_memory_output_path("pruned_terminal.json")
|
||||
# 去除日志前缀标签(如 [剪枝-数据集]、[剪枝-对话])后再解析为结构化字段保存
|
||||
sanitized_logs = [self._sanitize_log_line(l) for l in self.run_logs]
|
||||
payload = self._parse_logs_to_structured(sanitized_logs)
|
||||
with open(log_output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(payload, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
self._log(f"[剪枝-数据集] 保存终端输出日志失败:{e}")
|
||||
|
||||
# Safety: avoid empty dataset
|
||||
if not result:
|
||||
print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
|
||||
return dialogs
|
||||
return result
|
||||
|
||||
def _log(self, msg: str) -> None:
|
||||
"""记录日志并打印到终端。"""
|
||||
try:
|
||||
self.run_logs.append(msg)
|
||||
except Exception:
|
||||
# 任何异常都不影响打印
|
||||
pass
|
||||
print(msg)
|
||||
|
||||
def _sanitize_log_line(self, line: str) -> str:
|
||||
"""移除行首的方括号标签前缀,例如 [剪枝-数据集] 或 [剪枝-对话]。"""
|
||||
try:
|
||||
return re.sub(r"^\[[^\]]+\]\s*", "", line)
|
||||
except Exception:
|
||||
return line
|
||||
|
||||
def _parse_logs_to_structured(self, logs: List[str]) -> dict:
|
||||
"""将已去前缀的日志列表解析为结构化 JSON,便于数据对接。"""
|
||||
summary = {
|
||||
"scene": self.config.pruning_scene,
|
||||
"dialog_total": None,
|
||||
"deletion_ratio": None,
|
||||
"enabled": None,
|
||||
"related_count": None,
|
||||
"unrelated_count": None,
|
||||
"related_indices": [],
|
||||
"unrelated_indices": [],
|
||||
"total_deleted_messages": None,
|
||||
"remaining_dialogs": None,
|
||||
}
|
||||
dialogs = []
|
||||
|
||||
# 解析函数
|
||||
def parse_int(value: str) -> Optional[int]:
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def parse_float(value: str) -> Optional[float]:
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def parse_indices(s: str) -> List[int]:
|
||||
s = s.strip()
|
||||
if not s:
|
||||
return []
|
||||
parts = [p.strip() for p in s.split(",") if p.strip()]
|
||||
out: List[int] = []
|
||||
for p in parts:
|
||||
try:
|
||||
out.append(int(p))
|
||||
except Exception:
|
||||
pass
|
||||
return out
|
||||
|
||||
# 正则
|
||||
re_header = re.compile(r"对话总数=(\d+)\s+场景=([^\s]+)\s+删除比例=([0-9.]+)\s+开关=(True|False)")
|
||||
re_counts = re.compile(r"相关对话数=(\d+)\s+不相关对话数=(\d+)")
|
||||
re_indices = re.compile(r"相关对话:第\[(.*?)\]段;不相关对话:第\[(.*?)\]段")
|
||||
re_dialog = re.compile(r"对话\s+(\d+)\s+总消息=(\d+)\s+分配删除=(\d+)\s+实删=(\d+)\s+保留=(\d+)")
|
||||
re_total_del = re.compile(r"总删除\s+(\d+)\s+条")
|
||||
re_remaining = re.compile(r"剩余对话数=(\d+)")
|
||||
|
||||
for line in logs:
|
||||
# 第一行:总览
|
||||
m = re_header.search(line)
|
||||
if m:
|
||||
summary["dialog_total"] = parse_int(m.group(1))
|
||||
# 顶层 scene 依配置,这里不覆盖,但也可校验 m.group(2)
|
||||
summary["deletion_ratio"] = parse_float(m.group(3))
|
||||
summary["enabled"] = True if m.group(4) == "True" else False
|
||||
continue
|
||||
|
||||
# 第二行:相关/不相关数量
|
||||
m = re_counts.search(line)
|
||||
if m:
|
||||
summary["related_count"] = parse_int(m.group(1))
|
||||
summary["unrelated_count"] = parse_int(m.group(2))
|
||||
continue
|
||||
|
||||
# 第三行:相关/不相关索引
|
||||
m = re_indices.search(line)
|
||||
if m:
|
||||
summary["related_indices"] = parse_indices(m.group(1))
|
||||
summary["unrelated_indices"] = parse_indices(m.group(2))
|
||||
continue
|
||||
|
||||
# 对话级统计
|
||||
m = re_dialog.search(line)
|
||||
if m:
|
||||
dialogs.append({
|
||||
"index": parse_int(m.group(1)),
|
||||
"total_messages": parse_int(m.group(2)),
|
||||
"quota_delete": parse_int(m.group(3)),
|
||||
"actual_deleted": parse_int(m.group(4)),
|
||||
"kept": parse_int(m.group(5)),
|
||||
})
|
||||
continue
|
||||
|
||||
# 全局删除总数
|
||||
m = re_total_del.search(line)
|
||||
if m:
|
||||
summary["total_deleted_messages"] = parse_int(m.group(1))
|
||||
continue
|
||||
|
||||
# 剩余对话数
|
||||
m = re_remaining.search(line)
|
||||
if m:
|
||||
summary["remaining_dialogs"] = parse_int(m.group(1))
|
||||
continue
|
||||
|
||||
return {
|
||||
"scene": summary["scene"],
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"summary": {k: v for k, v in summary.items() if k != "scene"},
|
||||
"dialogs": dialogs,
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
去重消歧模块
|
||||
|
||||
提供实体去重和消歧功能,包括:
|
||||
- 基础去重和消歧(精确匹配、模糊匹配)
|
||||
- LLM 实体去重
|
||||
- 第二层去重(与 Neo4j 数据库联合去重)
|
||||
- 两阶段去重(完整的去重流程)
|
||||
"""
|
||||
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
|
||||
deduplicate_entities_and_edges,
|
||||
accurate_match,
|
||||
fuzzy_match,
|
||||
LLM_decision,
|
||||
LLM_disamb_decision,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm import (
|
||||
llm_dedup_entities,
|
||||
llm_dedup_entities_iterative_blocks,
|
||||
llm_disambiguate_pairs_iterative,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import (
|
||||
second_layer_dedup_and_merge_with_neo4j,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
|
||||
dedup_layers_and_merge_and_return,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"deduplicate_entities_and_edges",
|
||||
"accurate_match",
|
||||
"fuzzy_match",
|
||||
"LLM_decision",
|
||||
"LLM_disamb_decision",
|
||||
"llm_dedup_entities",
|
||||
"llm_dedup_entities_iterative_blocks",
|
||||
"llm_disambiguate_pairs_iterative",
|
||||
"second_layer_dedup_and_merge_with_neo4j",
|
||||
"dedup_layers_and_merge_and_return",
|
||||
]
|
||||
@@ -0,0 +1,784 @@
|
||||
"""
|
||||
去重功能函数
|
||||
"""
|
||||
from app.core.memory.models.variate_config import DedupConfig
|
||||
from typing import List, Dict, Tuple
|
||||
from app.core.memory.models.graph_models import(
|
||||
StatementEntityEdge,
|
||||
EntityEntityEdge,
|
||||
ExtractedEntityNode
|
||||
)
|
||||
import os
|
||||
from datetime import datetime
|
||||
import difflib # 提供字符串相似度计算工具
|
||||
import asyncio
|
||||
import importlib
|
||||
import re
|
||||
# 模块级属性融合工具函数(统一行为)
|
||||
def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
|
||||
# 强弱连接合并
|
||||
can_strength = (getattr(canonical, "connect_strength", "") or "").lower()
|
||||
inc_strength = (getattr(ent, "connect_strength", "") or "").lower()
|
||||
pair = {can_strength, inc_strength} - {""}
|
||||
if pair:
|
||||
if "both" in pair or pair == {"strong", "weak"}:
|
||||
canonical.connect_strength = "both"
|
||||
elif pair == {"strong"}:
|
||||
canonical.connect_strength = "strong"
|
||||
elif pair == {"weak"}:
|
||||
canonical.connect_strength = "weak"
|
||||
else:
|
||||
canonical.connect_strength = next(iter(pair))
|
||||
|
||||
# 别名合并(去重保序)
|
||||
try:
|
||||
existing = getattr(canonical, "aliases", []) or []
|
||||
incoming = getattr(ent, "aliases", []) or []
|
||||
seen = set()
|
||||
merged_list: List[str] = []
|
||||
for x in existing + incoming:
|
||||
xn = (x or "").strip()
|
||||
if xn and xn not in seen:
|
||||
seen.add(xn)
|
||||
merged_list.append(x)
|
||||
canonical.aliases = merged_list
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 描述与事实摘要(保留更长者)
|
||||
try:
|
||||
desc_a = getattr(canonical, "description", "") or ""
|
||||
desc_b = getattr(ent, "description", "") or ""
|
||||
if len(desc_b) > len(desc_a):
|
||||
canonical.description = desc_b
|
||||
# 合并事实摘要:统一保留一个“实体: name”行,来源行去重保序
|
||||
fact_a = getattr(canonical, "fact_summary", "") or ""
|
||||
fact_b = getattr(ent, "fact_summary", "") or ""
|
||||
def _extract_sources(txt: str) -> List[str]:
|
||||
sources: List[str] = []
|
||||
if not txt:
|
||||
return sources
|
||||
for line in str(txt).splitlines():
|
||||
ln = line.strip()
|
||||
# 支持“来源:”或“来源:”前缀
|
||||
m = re.match(r"^来源[::]\s*(.+)$", ln)
|
||||
if m:
|
||||
content = m.group(1).strip()
|
||||
if content:
|
||||
sources.append(content)
|
||||
# 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失
|
||||
if not sources and txt.strip():
|
||||
sources.append(txt.strip())
|
||||
return sources
|
||||
try:
|
||||
src_a = _extract_sources(fact_a)
|
||||
src_b = _extract_sources(fact_b)
|
||||
seen = set()
|
||||
merged_sources: List[str] = []
|
||||
for s in src_a + src_b:
|
||||
if s and s not in seen:
|
||||
seen.add(s)
|
||||
merged_sources.append(s)
|
||||
if merged_sources:
|
||||
name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
|
||||
canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
|
||||
elif fact_b and not fact_a:
|
||||
canonical.fact_summary = fact_b
|
||||
except Exception:
|
||||
# 兜底:若解析失败,保留较长文本
|
||||
if len(fact_b) > len(fact_a):
|
||||
canonical.fact_summary = fact_b
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 名称向量补全
|
||||
try:
|
||||
emb_a = getattr(canonical, "name_embedding", []) or []
|
||||
emb_b = getattr(ent, "name_embedding", []) or []
|
||||
if not emb_a and emb_b:
|
||||
canonical.name_embedding = emb_b
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 时间范围合并
|
||||
try:
|
||||
# 统一使用 created_at / expired_at
|
||||
if getattr(ent, "created_at", None) and getattr(canonical, "created_at", None) and ent.created_at < canonical.created_at:
|
||||
canonical.created_at = ent.created_at
|
||||
if getattr(ent, "expired_at", None) and getattr(canonical, "expired_at", None):
|
||||
if canonical.expired_at is None:
|
||||
canonical.expired_at = ent.expired_at
|
||||
elif ent.expired_at and ent.expired_at > canonical.expired_at:
|
||||
canonical.expired_at = ent.expired_at
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def accurate_match(
|
||||
entity_nodes: List[ExtractedEntityNode]
|
||||
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
|
||||
"""
|
||||
精确匹配:按 (group_id, name, entity_type) 合并实体并建立重定向与合并记录。
|
||||
返回: (deduped_entities, id_redirect, exact_merge_map)
|
||||
"""
|
||||
exact_merge_map: Dict[str, Dict] = {}
|
||||
canonical_map: Dict[str, ExtractedEntityNode] = {}
|
||||
id_redirect: Dict[str, str] = {}
|
||||
|
||||
# 1) 构建规范实体映射(按名称+类型+group 精确匹配)
|
||||
for ent in entity_nodes:
|
||||
name_norm = (getattr(ent, "name", "") or "").strip()
|
||||
type_norm = (getattr(ent, "entity_type", "") or "").strip()
|
||||
key = f"{getattr(ent, 'group_id', None)}|{name_norm}|{type_norm}"
|
||||
# 为避免跨业务组误并,明确以 group_id 为范围边界
|
||||
if key not in canonical_map:
|
||||
canonical_map[key] = ent
|
||||
id_redirect[getattr(ent, "id")] = getattr(ent, "id")
|
||||
continue
|
||||
canonical = canonical_map[key]
|
||||
|
||||
# 执行精确属性与强弱合并,并建立重定向
|
||||
_merge_attribute(canonical, ent)
|
||||
id_redirect[getattr(ent, "id")] = getattr(canonical, "id")
|
||||
# 记录精确匹配的合并项(使用规范化键,避免外层变量误用)
|
||||
try:
|
||||
k = f"{getattr(canonical, 'group_id')}|{(getattr(canonical, 'name') or '').strip()}|{(getattr(canonical, 'entity_type') or '').strip()}"
|
||||
if k not in exact_merge_map:
|
||||
exact_merge_map[k] = {
|
||||
"canonical_id": getattr(canonical, "id"),
|
||||
"group_id": getattr(canonical, "group_id"),
|
||||
"name": getattr(canonical, "name"),
|
||||
"entity_type": getattr(canonical, "entity_type"),
|
||||
"merged_ids": set(),
|
||||
}
|
||||
exact_merge_map[k]["merged_ids"].add(getattr(ent, "id"))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
deduped_entities = list(canonical_map.values())
|
||||
return deduped_entities, id_redirect, exact_merge_map
|
||||
|
||||
def fuzzy_match(
|
||||
deduped_entities: List[ExtractedEntityNode],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
id_redirect: Dict[str, str],
|
||||
config: DedupConfig | None = None,
|
||||
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], List[str]]:
|
||||
"""
|
||||
模糊匹配:在精确匹配之后,基于名称/类型相似度与上下文共现,进一步融合高相似实体。
|
||||
返回: (updated_entities, updated_redirect, fuzzy_merge_records)
|
||||
"""
|
||||
fuzzy_merge_records: List[str] = []
|
||||
|
||||
def _normalize_text(s: str) -> str:
|
||||
try:
|
||||
return re.sub(r"\s+", " ", re.sub(r"[^\w\u4e00-\u9fff]+", " ", (s or "").lower())).strip()
|
||||
except Exception:
|
||||
return str(s).lower().strip()
|
||||
|
||||
def _tokenize(s: str) -> List[str]:
|
||||
norm = _normalize_text(s)
|
||||
tokens = re.findall(r"[\u4e00-\u9fff]+|[a-z0-9]+", norm)
|
||||
return tokens
|
||||
|
||||
def _jaccard(a_tokens: List[str], b_tokens: List[str]) -> float:
|
||||
try:
|
||||
set_a, set_b = set(a_tokens), set(b_tokens)
|
||||
if not set_a and not set_b:
|
||||
return 0.0
|
||||
inter = len(set_a & set_b)
|
||||
union = len(set_a | set_b)
|
||||
return inter / union if union > 0 else 0.0
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
def _cosine(a: List[float], b: List[float]) -> float:
|
||||
try:
|
||||
if not a or not b or len(a) != len(b):
|
||||
return 0.0
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
na = sum(x * x for x in a) ** 0.5
|
||||
nb = sum(y * y for y in b) ** 0.5
|
||||
if na == 0 or nb == 0:
|
||||
return 0.0
|
||||
return dot / (na * nb)
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
def _name_similarity(e1: ExtractedEntityNode, e2: ExtractedEntityNode):
|
||||
emb_sim = _cosine(getattr(e1, "name_embedding", []) or [], getattr(e2, "name_embedding", []) or [])
|
||||
tokens1 = set(_tokenize(getattr(e1, "name", "") or ""))
|
||||
tokens2 = set(_tokenize(getattr(e2, "name", "") or ""))
|
||||
aliases1 = getattr(e1, "aliases", []) or []
|
||||
aliases2 = getattr(e2, "aliases", []) or []
|
||||
alias_tokens1 = set(tokens1)
|
||||
alias_tokens2 = set(tokens2)
|
||||
for a in aliases1:
|
||||
alias_tokens1 |= set(_tokenize(a))
|
||||
for a in aliases2:
|
||||
alias_tokens2 |= set(_tokenize(a))
|
||||
j_primary = _jaccard(list(tokens1), list(tokens2))
|
||||
j_alias = _jaccard(list(alias_tokens1), list(alias_tokens2))
|
||||
s_name = 0.6 * emb_sim + 0.2 * j_primary + 0.2 * j_alias
|
||||
return s_name, emb_sim, j_primary, j_alias
|
||||
|
||||
def _desc_similarity(e1: ExtractedEntityNode, e2: ExtractedEntityNode):
|
||||
"""
|
||||
计算实体描述的相似度(Jaccard + SequenceMatcher)
|
||||
返回: (相似度得分, Jaccard 相似度(词重合), SequenceMatcher 相似度(序列相似))
|
||||
"""
|
||||
d1 = getattr(e1, "description", "") or ""
|
||||
d2 = getattr(e2, "description", "") or ""
|
||||
if not d1 and not d2:
|
||||
return 0.0, 0.0, 0.0
|
||||
t1 = _tokenize(d1)
|
||||
t2 = _tokenize(d2)
|
||||
j = _jaccard(t1, t2)
|
||||
try:
|
||||
seq = difflib.SequenceMatcher(None, _normalize_text(d1), _normalize_text(d2)).ratio()
|
||||
except Exception:
|
||||
seq = 0.0
|
||||
# 平衡词重合与序列相似(更鲁棒)
|
||||
s_desc = 0.5 * j + 0.5 * seq
|
||||
return s_desc, j, seq
|
||||
|
||||
def _canonicalize_type(t: str) -> str: # 扩展类型同义归一
|
||||
t = (t or "").strip()
|
||||
if not t:
|
||||
return ""
|
||||
t_up = t.upper()
|
||||
TYPE_ALIASES = {
|
||||
"PERSON": {"人物", "人", "个人", "人名", "PERSON", "PEOPLE", "INDIVIDUAL"},
|
||||
"ORG": {"组织", "ORG"},
|
||||
"COMPANY": {"公司", "企业", "COMPANY"},
|
||||
"INSTITUTION": {"机构", "INSTITUTION"},
|
||||
"LOCATION": {"地点", "位置", "LOCATION"},
|
||||
"CITY": {"城市", "CITY"},
|
||||
"COUNTRY": {"国家", "COUNTRY"},
|
||||
"EVENT": {"事件", "EVENT"},
|
||||
# 扩展活动与技能近义,统一到 ACTIVITY,便于本地模糊匹配
|
||||
"ACTIVITY": {"活动", "技术活动", "技能", "ACTIVITY", "SKILL"},
|
||||
"PRODUCT": {"产品", "商品", "物品", "OBJECT", "PRODUCT"},
|
||||
"TOOL": {"工具", "TOOL"},
|
||||
"SOFTWARE": {"软件", "SOFTWARE"},
|
||||
"FOOD": {"食品", "食物", "FOOD"},
|
||||
"INGREDIENT": {"食材", "配料", "原料", "INGREDIENT"},
|
||||
"SWEETMEATS": {"甜点", "甜品", "甜食", "SWEETMEATS"},
|
||||
# 统一本地与 LLM 阶段:将 EQUIPMENT/装备 映射为 APPLIANCE
|
||||
"APPLIANCE": {"设备", "器材", "摄影器材", "摄影设备", "电器", "烤箱", "装备","镜头", "EQUIPMENT", "APPLIANCE"},
|
||||
"ART": {"艺术", "艺术形式", "ART"},
|
||||
"FLOWER": {"花卉", "鲜花", "FLOWER"},
|
||||
"PLANT": {"植物", "PLANT"},
|
||||
"AGENT": {"AI助手", "助手", "人工智能助手", "智能助手", "智能体", "Agent", "AGENTA"},
|
||||
"ROLE": {"角色", "ROLE"},
|
||||
"SCENE_ELEMENT": {"场景元素", "SCENE_ELEMENT"},
|
||||
"UNKNOWN": {"UNKNOWN", "未知", "不明"},
|
||||
}
|
||||
for canon, aliases in TYPE_ALIASES.items():
|
||||
if t_up in {a.upper() for a in aliases}:
|
||||
return canon
|
||||
return t_up
|
||||
|
||||
def _type_similarity(t1: str, t2: str) -> float:
|
||||
import difflib
|
||||
c1 = _canonicalize_type(t1)
|
||||
c2 = _canonicalize_type(t2)
|
||||
if not c1 or not c2:
|
||||
return 0.0
|
||||
if c1 == c2:
|
||||
return 0.5 if c1 == "UNKNOWN" else 1.0
|
||||
if c1 == "UNKNOWN" or c2 == "UNKNOWN":
|
||||
return 0.5
|
||||
sim_table = {
|
||||
("ORG", "COMPANY"): 0.9, ("COMPANY", "ORG"): 0.9,
|
||||
("ORG", "INSTITUTION"): 0.85, ("INSTITUTION", "ORG"): 0.85,
|
||||
("LOCATION", "CITY"): 0.9, ("CITY", "LOCATION"): 0.9,
|
||||
("LOCATION", "COUNTRY"): 0.9, ("COUNTRY", "LOCATION"): 0.9,
|
||||
("EVENT", "ACTIVITY"): 0.8, ("ACTIVITY", "EVENT"): 0.8,
|
||||
("PRODUCT", "TOOL"): 0.8, ("TOOL", "PRODUCT"): 0.8,
|
||||
("PRODUCT", "SOFTWARE"): 0.8, ("SOFTWARE", "PRODUCT"): 0.8,
|
||||
("FOOD", "SWEETMEATS"): 0.8, ("SWEETMEATS", "FOOD"): 0.8,
|
||||
("INGREDIENT", "FOOD"): 0.85, ("FOOD", "INGREDIENT"): 0.85,
|
||||
("APPLIANCE", "TOOL"): 0.8, ("TOOL", "APPLIANCE"): 0.8,
|
||||
("APPLIANCE", "PRODUCT"): 0.7, ("PRODUCT", "APPLIANCE"): 0.7,
|
||||
("FLOWER", "PLANT"): 0.9, ("PLANT", "FLOWER"): 0.9,
|
||||
("AGENT", "SOFTWARE"): 0.85, ("SOFTWARE", "AGENT"): 0.85,
|
||||
("AGENT", "PRODUCT"): 0.7, ("PRODUCT", "AGENT"): 0.7,
|
||||
("AGENT", "ROLE"): 0.9, ("ROLE", "AGENT"): 0.9,
|
||||
("SCENE_ELEMENT", "PRODUCT"): 0.6, ("PRODUCT", "SCENE_ELEMENT"): 0.6,
|
||||
}
|
||||
base = sim_table.get((c1, c2), 0.0)
|
||||
if base:
|
||||
return base
|
||||
t1n = (t1 or "").strip().lower()
|
||||
t2n = (t2 or "").strip().lower()
|
||||
seq_ratio = difflib.SequenceMatcher(None, t1n, t2n).ratio()
|
||||
return seq_ratio * 0.6
|
||||
# 阈值与权重设定(从配置读取;若无配置则使用 DedupConfig 的默认值)
|
||||
_defaults = DedupConfig()
|
||||
T_NAME_STRICT = (config.fuzzy_name_threshold_strict if config is not None else _defaults.fuzzy_name_threshold_strict)
|
||||
T_TYPE_STRICT = (config.fuzzy_type_threshold_strict if config is not None else _defaults.fuzzy_type_threshold_strict)
|
||||
T_OVERALL = (config.fuzzy_overall_threshold if config is not None else _defaults.fuzzy_overall_threshold)
|
||||
UNKNOWN_NAME_T = (config.fuzzy_unknown_type_name_threshold if config is not None else _defaults.fuzzy_unknown_type_name_threshold)
|
||||
UNKNOWN_TYPE_T = (config.fuzzy_unknown_type_type_threshold if config is not None else _defaults.fuzzy_unknown_type_type_threshold)
|
||||
W_NAME = (config.name_weight if config is not None else _defaults.name_weight)
|
||||
W_DESC = (config.desc_weight if config is not None else _defaults.desc_weight)
|
||||
W_TYPE = (config.type_weight if config is not None else _defaults.type_weight)
|
||||
CTX_BONUS = (config.context_bonus if config is not None else _defaults.context_bonus) # 上下文共现加分
|
||||
FALL_FLOOR = (config.llm_fallback_floor if config is not None else _defaults.llm_fallback_floor)
|
||||
FALL_CEIL = (config.llm_fallback_ceiling if config is not None else _defaults.llm_fallback_ceiling)
|
||||
|
||||
|
||||
i = 0
|
||||
while i < len(deduped_entities):
|
||||
a = deduped_entities[i]
|
||||
j = i + 1
|
||||
while j < len(deduped_entities):
|
||||
b = deduped_entities[j]
|
||||
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
|
||||
j += 1
|
||||
continue
|
||||
# 上下文共现
|
||||
try:
|
||||
sources_a = {e.source for e in statement_entity_edges if getattr(e, "target", None) == getattr(a, "id", None)}
|
||||
sources_b = {e.source for e in statement_entity_edges if getattr(e, "target", None) == getattr(b, "id", None)}
|
||||
co_ctx = bool(sources_a & sources_b)
|
||||
except Exception:
|
||||
co_ctx = False
|
||||
s_name, emb_sim, j_primary, j_alias = _name_similarity(a, b)
|
||||
s_desc, j_desc, seq_desc = _desc_similarity(a, b)
|
||||
s_type = _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None))
|
||||
unknown_present = (
|
||||
str(getattr(a, "entity_type", "")).upper() == "UNKNOWN"
|
||||
or str(getattr(b, "entity_type", "")).upper() == "UNKNOWN"
|
||||
)
|
||||
tn = UNKNOWN_NAME_T if unknown_present else T_NAME_STRICT
|
||||
tn = min(tn, 0.88) if co_ctx else tn
|
||||
type_threshold = UNKNOWN_TYPE_T if unknown_present else T_TYPE_STRICT
|
||||
tover = T_OVERALL
|
||||
a_cs = (getattr(a, "connect_strength", "") or "").lower()
|
||||
b_cs = (getattr(b, "connect_strength", "") or "").lower()
|
||||
if a_cs in ("strong", "both") or b_cs in ("strong", "both"):
|
||||
tover = 0.80
|
||||
# 综合评分:名称、描述、类型加权 + 上下文加分
|
||||
overall = W_NAME * s_name + W_DESC * s_desc + W_TYPE * s_type + (CTX_BONUS if co_ctx else 0.0)
|
||||
|
||||
if s_name >= tn and s_type >= type_threshold and overall >= tover:
|
||||
_merge_attribute(a, b)
|
||||
try:
|
||||
fuzzy_merge_records.append(
|
||||
f"[模糊] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type}) | s_name={s_name:.3f}, s_desc={s_desc:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, ctx={co_ctx}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
# 用于处理合并实体后,Statement节点下方无挂载边的情况 后续考虑将其代码逻辑统一由关系去重消歧管理
|
||||
# 建立 ID 重定向:将合并实体 b 的 ID 指向规范实体 a 的 ID
|
||||
try:
|
||||
canonical_id = id_redirect.get(getattr(a, "id", None), getattr(a, "id", None))
|
||||
losing_id = getattr(b, "id", None)
|
||||
if losing_id and canonical_id:
|
||||
id_redirect[losing_id] = canonical_id
|
||||
# 扁平化可能的重定向链:凡是映射到 b.id 的,统一指向 a.id
|
||||
for k, v in list(id_redirect.items()):
|
||||
if v == losing_id:
|
||||
id_redirect[k] = canonical_id
|
||||
except Exception:
|
||||
pass
|
||||
deduped_entities.pop(j)
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
if s_name >= tn and s_type >= type_threshold and (FALL_FLOOR <= overall < tover) and (overall <= FALL_CEIL):
|
||||
fuzzy_merge_records.append(
|
||||
f"[边界] {a.id}<->{b.id} ({a.group_id}|{a.name}|{a.entity_type} ~ {b.group_id}|{b.name}|{b.entity_type}) | s_name={s_name:.3f}, s_desc={s_desc:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, ctx={co_ctx}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
j += 1
|
||||
i += 1
|
||||
|
||||
return deduped_entities, id_redirect, fuzzy_merge_records
|
||||
|
||||
async def LLM_decision( # 决策中包含去重和消歧的功能
|
||||
deduped_entities: List[ExtractedEntityNode],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
id_redirect: Dict[str, str],
|
||||
config: DedupConfig | None = None,
|
||||
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], List[str]]:
|
||||
"""
|
||||
基于迭代分块并发的 LLM 判定,生成实体重定向并在本地应用融合。
|
||||
返回 (updated_entities, updated_redirect, llm_records)。
|
||||
- 仅在配置 enable_llm_dedup_blockwise 为 True 时启用;
|
||||
若未提供配置,则使用 DedupConfig 的默认值作为回退。
|
||||
- 内部调用 llm_dedup_entities_iterative_blocks 获取 pairwise 的重定向映射。
|
||||
- 将映射应用到 deduped_entities 与 id_redirect,并记录融合日志。
|
||||
"""
|
||||
llm_records: List[str] = []
|
||||
try:
|
||||
# 优先使用运行时配置;若未提供配置,使用模型默认值,不再回退到环境变量
|
||||
enable_switch = (
|
||||
bool(config.enable_llm_dedup_blockwise) if config is not None else DedupConfig().enable_llm_dedup_blockwise
|
||||
)
|
||||
if not enable_switch:
|
||||
return deduped_entities, id_redirect, llm_records
|
||||
# 从配置读取 LLM 迭代参数;若无配置则使用 DedupConfig 的默认值
|
||||
_defaults = DedupConfig()
|
||||
block_size = (config.llm_block_size if config is not None else _defaults.llm_block_size)
|
||||
block_concurrency = (config.llm_block_concurrency if config is not None else _defaults.llm_block_concurrency)
|
||||
pair_concurrency = (config.llm_pair_concurrency if config is not None else _defaults.llm_pair_concurrency)
|
||||
max_rounds = (config.llm_max_rounds if config is not None else _defaults.llm_max_rounds)
|
||||
|
||||
# 动态导入 llm 客户端(统一从 app.core.memory.utils.llm_utils 获取)
|
||||
try:
|
||||
llm_utils_mod = importlib.import_module("app.core.memory.utils.llm_utils")
|
||||
get_llm_client_fn = getattr(llm_utils_mod, "get_llm_client")
|
||||
except Exception:
|
||||
get_llm_client_fn = lambda: None
|
||||
|
||||
try:
|
||||
llm_mod = importlib.import_module("app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm")
|
||||
llm_fn = getattr(llm_mod, "llm_dedup_entities_iterative_blocks")
|
||||
except Exception:
|
||||
raise RuntimeError("LLM 模块加载失败:deduplication.entity_dedup_llm 缺少 llm_dedup_entities_iterative_blocks")
|
||||
|
||||
# 获取 LLM 客户端,若环境未配置或抛错则回退为 None
|
||||
try:
|
||||
llm_client = get_llm_client_fn()
|
||||
except Exception:
|
||||
llm_client = None
|
||||
|
||||
llm_redirect, llm_records = await llm_fn(
|
||||
entity_nodes=deduped_entities,
|
||||
statement_entity_edges=statement_entity_edges,
|
||||
entity_entity_edges=entity_entity_edges,
|
||||
llm_client=llm_client,
|
||||
block_size=block_size,
|
||||
block_concurrency=block_concurrency,
|
||||
pair_concurrency=pair_concurrency,
|
||||
max_rounds=max_rounds,
|
||||
)
|
||||
except Exception as e:
|
||||
# 记录错误,不中断主流程
|
||||
llm_records.append(f"[LLM错误] 迭代分块执行失败: {e}")
|
||||
return deduped_entities, id_redirect, llm_records
|
||||
|
||||
# 若存在 LLM 的重定向,应用到实体与映射
|
||||
# 确保实体集合与 id_redirect 完整反映 LLM 的合并结果;否则后续边重定向不会指向规范 ID,实体仍然重复
|
||||
if llm_redirect:
|
||||
entity_by_id: Dict[str, ExtractedEntityNode] = {e.id: e for e in deduped_entities}
|
||||
for losing_id, canonical_id in list(llm_redirect.items()):
|
||||
if losing_id == canonical_id:
|
||||
continue
|
||||
a = entity_by_id.get(canonical_id)
|
||||
b = entity_by_id.get(losing_id)
|
||||
if not a or not b: # 若不存在 a 或 b,可能已在精确或模糊阶段合并,在之前阶段合并之后,不会再处理但是处于审计的目的会记录
|
||||
continue
|
||||
_merge_attribute(a, b)
|
||||
# ID 重定向
|
||||
try:
|
||||
id_redirect[b.id] = a.id
|
||||
for k, v in list(id_redirect.items()):
|
||||
if v == b.id:
|
||||
id_redirect[k] = a.id
|
||||
except Exception:
|
||||
pass
|
||||
# 记录 LLM 融合日志
|
||||
try:
|
||||
llm_records.append(
|
||||
f"[LLM融合] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})"
|
||||
)
|
||||
# 详细的“同类名称相似”记录改由 LLM 去重模块统一生成以携带 conf/reason
|
||||
except Exception:
|
||||
pass
|
||||
# 移除 losing 实体
|
||||
try:
|
||||
if b in deduped_entities:
|
||||
deduped_entities.remove(b)
|
||||
entity_by_id.pop(b.id, None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return deduped_entities, id_redirect, llm_records
|
||||
|
||||
async def LLM_disamb_decision(
|
||||
deduped_entities: List[ExtractedEntityNode],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
id_redirect: Dict[str, str],
|
||||
config: DedupConfig | None = None,
|
||||
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], set[tuple[str, str]], List[str]]:
|
||||
"""
|
||||
预消歧阶段:对“同名但类型不同”的实体对调用LLM进行消歧,
|
||||
产出:需阻断的实体对(blocked_pairs)与必要的合并(merge_redirect)。
|
||||
返回 (updated_entities, updated_redirect, blocked_pairs, disamb_records)。
|
||||
- 仅在配置开关 enable_llm_disambiguation 为 True 时启用;否则返回空阻断列表。
|
||||
"""
|
||||
disamb_records: List[str] = []
|
||||
blocked_pairs: set[tuple[str, str]] = set()
|
||||
try:
|
||||
enable_switch = (
|
||||
config.enable_llm_disambiguation
|
||||
if config is not None
|
||||
else DedupConfig().enable_llm_disambiguation
|
||||
)
|
||||
if not bool(enable_switch):
|
||||
return deduped_entities, id_redirect, blocked_pairs, disamb_records
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm import llm_disambiguate_pairs_iterative
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
merge_redirect, block_list, disamb_records = await llm_disambiguate_pairs_iterative(
|
||||
entity_nodes=deduped_entities,
|
||||
statement_entity_edges=statement_entity_edges,
|
||||
entity_entity_edges=entity_entity_edges,
|
||||
llm_client=llm_client,
|
||||
)
|
||||
|
||||
# 应用LLM消歧的合并建议
|
||||
if merge_redirect:
|
||||
entity_by_id: Dict[str, ExtractedEntityNode] = {e.id: e for e in deduped_entities}
|
||||
for losing_id, canonical_id in list(merge_redirect.items()):
|
||||
if losing_id == canonical_id:
|
||||
continue
|
||||
a = entity_by_id.get(canonical_id)
|
||||
b = entity_by_id.get(losing_id)
|
||||
if not a or not b:
|
||||
continue
|
||||
_merge_attribute(a, b)
|
||||
id_redirect[b.id] = a.id
|
||||
for k, v in list(id_redirect.items()):
|
||||
if v == b.id:
|
||||
id_redirect[k] = a.id
|
||||
try:
|
||||
disamb_records.append(
|
||||
f"[DISAMB合并应用] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if b in deduped_entities:
|
||||
deduped_entities.remove(b)
|
||||
entity_by_id.pop(b.id, None)
|
||||
except Exception:
|
||||
pass
|
||||
# 保存阻断对
|
||||
try:
|
||||
blocked_pairs = {tuple(sorted(p)) for p in (block_list or [])}
|
||||
except Exception:
|
||||
blocked_pairs = set()
|
||||
except Exception as e:
|
||||
disamb_records.append(f"[DISAMB错误] 消歧执行失败: {e}")
|
||||
return deduped_entities, id_redirect, blocked_pairs, disamb_records
|
||||
|
||||
return deduped_entities, id_redirect, blocked_pairs, disamb_records
|
||||
|
||||
async def deduplicate_entities_and_edges(
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
report_stage: str = "第一层去重消歧",
|
||||
report_append: bool = False,
|
||||
report_stage_notes: List[str] | None = None,
|
||||
dedup_config: DedupConfig | None = None,
|
||||
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
|
||||
"""
|
||||
主流程:依次执行精确匹配、模糊匹配与(可选)LLM 决策融合,随后对边做重定向与去重。之后再处理边,是关系去重和消歧
|
||||
返回:去重后的实体、语句→实体边、实体↔实体边。
|
||||
"""
|
||||
local_llm_records: List[str] = [] # 作为“审计日志”的本地收集器 初始化,保留为了之后对于LLM决策追溯
|
||||
# 1) 精确匹配
|
||||
deduped_entities, id_redirect, exact_merge_map = accurate_match(entity_nodes)
|
||||
|
||||
# 1.5) LLM 决策消歧:阻断同名不同类型的高相似对,并应用必要的合并
|
||||
deduped_entities, id_redirect, blocked_pairs, disamb_records = await LLM_disamb_decision(
|
||||
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config
|
||||
)
|
||||
|
||||
# 2) 模糊匹配(本地规则)
|
||||
deduped_entities, id_redirect, fuzzy_merge_records = fuzzy_match(
|
||||
deduped_entities, statement_entity_edges, id_redirect, config=dedup_config
|
||||
)
|
||||
|
||||
# 3) LLM 决策(仅按配置开关)
|
||||
try:
|
||||
enable_switch = (
|
||||
dedup_config.enable_llm_dedup_blockwise
|
||||
if dedup_config is not None
|
||||
else DedupConfig().enable_llm_dedup_blockwise
|
||||
)
|
||||
should_trigger_llm = bool(enable_switch)
|
||||
# 将触发信息写入阶段备注,便于输出报告审计
|
||||
if report_stage_notes is None:
|
||||
report_stage_notes = []
|
||||
report_stage_notes.append(f"LLM触发: {'是' if should_trigger_llm else '否'}")
|
||||
except Exception:
|
||||
should_trigger_llm = False
|
||||
|
||||
if should_trigger_llm:
|
||||
deduped_entities, id_redirect, llm_decision_records = await LLM_decision(
|
||||
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config
|
||||
)
|
||||
else:
|
||||
llm_decision_records = []
|
||||
# 累加 LLM 记录 把 LLM_decision 返回的日志 llm_decision_records 追加到 local_llm_records
|
||||
try:
|
||||
local_llm_records.extend(llm_decision_records or [])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# 在主流程这里 这里是之后关系去重和消歧的地方,方法可以写在其他地方
|
||||
# 此处统一对边进行处理,使用累积的 id_redirect 把边的 source/target 改成规范ID
|
||||
# 4) 边重定向与去重
|
||||
# 4.1 语句→实体边:重复时优先保留 strong
|
||||
stmt_ent_map: Dict[str, StatementEntityEdge] = {}
|
||||
for edge in statement_entity_edges:
|
||||
new_target = id_redirect.get(edge.target, edge.target)
|
||||
edge.target = new_target
|
||||
key = f"{edge.source}_{edge.target}"
|
||||
if key not in stmt_ent_map:
|
||||
stmt_ent_map[key] = edge
|
||||
else:
|
||||
existing = stmt_ent_map[key]
|
||||
old_strength = getattr(existing, "connect_strength", "")
|
||||
new_strength = getattr(edge, "connect_strength", "")
|
||||
if old_strength != "strong" and new_strength == "strong":
|
||||
stmt_ent_map[key] = edge
|
||||
|
||||
# 4.2 实体↔实体边:按 source_target 去重(无强弱属性)
|
||||
ent_ent_map: Dict[str, EntityEntityEdge] = {}
|
||||
for edge in entity_entity_edges:
|
||||
new_source = id_redirect.get(edge.source, edge.source)
|
||||
new_target = id_redirect.get(edge.target, edge.target)
|
||||
edge.source = new_source
|
||||
edge.target = new_target
|
||||
key = f"{edge.source}_{edge.target}"
|
||||
if key not in ent_ent_map:
|
||||
ent_ent_map[key] = edge
|
||||
|
||||
|
||||
_write_dedup_fusion_report(
|
||||
exact_merge_map=exact_merge_map,
|
||||
fuzzy_merge_records=fuzzy_merge_records,
|
||||
local_llm_records=local_llm_records,
|
||||
disamb_records=disamb_records,
|
||||
stage_label=report_stage,
|
||||
append=report_append,
|
||||
stage_notes=report_stage_notes,
|
||||
)
|
||||
|
||||
return deduped_entities, list(stmt_ent_map.values()), list(ent_ent_map.values())
|
||||
|
||||
# 独立模块:去重融合报告写入(与实体/边的计算解耦)
|
||||
def _write_dedup_fusion_report(
|
||||
exact_merge_map: Dict[str, Dict],
|
||||
fuzzy_merge_records: List[str],
|
||||
local_llm_records: List[str],
|
||||
disamb_records: List[str] | None = None,
|
||||
stage_label: str | None = None,
|
||||
append: bool = False,
|
||||
stage_notes: List[str] | None = None,
|
||||
):
|
||||
try:
|
||||
# 使用全局配置的输出路径
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
out_path = settings.get_memory_output_path("dedup_entity_output.txt")
|
||||
report_lines: List[str] = []
|
||||
if not append:
|
||||
report_lines.append(f"去重融合报告 - {datetime.now().isoformat()}")
|
||||
report_lines.append("")
|
||||
if stage_label:
|
||||
# 追加写入时,在阶段标题前增加一个空行以增强分隔
|
||||
if append:
|
||||
report_lines.append("")
|
||||
report_lines.append(f"=== {stage_label} ===")
|
||||
report_lines.append("")
|
||||
# 阶段注释:在标题下追加,如候选数、是否跳过等
|
||||
if stage_notes:
|
||||
for note in stage_notes:
|
||||
try:
|
||||
report_lines.append(str(note))
|
||||
except Exception:
|
||||
pass
|
||||
report_lines.append("")
|
||||
# 精确
|
||||
report_lines.append("精确匹配去重:")
|
||||
aggregated_exact_lines: List[str] = []
|
||||
try:
|
||||
for k, info in (exact_merge_map or {}).items():
|
||||
merged_ids = sorted(list(info.get("merged_ids", set())))
|
||||
if merged_ids:
|
||||
aggregated_exact_lines.append(
|
||||
f"[精确] 键 {k} 规范实体 {info.get('canonical_id')} 名称 '{info.get('name')}' 类型 {info.get('entity_type')} <- 合并实体IDs {', '.join(merged_ids)}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
report_lines.extend(aggregated_exact_lines if aggregated_exact_lines else ["无合并项"])
|
||||
report_lines.append("")
|
||||
# 消歧
|
||||
report_lines.append("LLM 决策消歧:")
|
||||
try:
|
||||
# 仅展示阻断项,过滤掉合并与合并应用
|
||||
disamb_block_only = [
|
||||
line for line in (disamb_records or [])
|
||||
if str(line).startswith("[DISAMB阻断]") or str(line).startswith("[DISAMB异常阻断]")
|
||||
]
|
||||
except Exception:
|
||||
disamb_block_only = disamb_records or []
|
||||
report_lines.extend(disamb_block_only if disamb_block_only else ["未执行或无阻断/合并项"])
|
||||
report_lines.append("")
|
||||
# 模糊
|
||||
report_lines.append("模糊匹配去重:")
|
||||
report_lines.extend(fuzzy_merge_records if fuzzy_merge_records else ["未执行或无合并项"])
|
||||
report_lines.append("")
|
||||
# LLM
|
||||
report_lines.append("LLM 决策去重:")
|
||||
try:
|
||||
# 仅保留 LLM 的“去重判定”记录,排除“合并指令/融合落地”
|
||||
def _is_llm_dedup_record(s: str) -> bool:
|
||||
try:
|
||||
text = str(s)
|
||||
return "[LLM去重]" in text
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
llm_dedup_only = [
|
||||
line for line in (local_llm_records or [])
|
||||
if _is_llm_dedup_record(str(line))
|
||||
]
|
||||
# 同名类型相似的 LLM 去重记录可能来源于消歧阶段,将其也纳入展示
|
||||
try:
|
||||
llm_dedup_only.extend([
|
||||
line for line in (disamb_records or [])
|
||||
if _is_llm_dedup_record(str(line))
|
||||
])
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
llm_dedup_only = []
|
||||
# 输出前移除块前缀(如 "[LLM块0] "),并对重复记录去重(保序)
|
||||
try:
|
||||
import re as _re
|
||||
def _strip_block_prefix(s: str) -> str:
|
||||
try:
|
||||
return _re.sub(r"^\[LLM块\d+\]\s*", "", str(s))
|
||||
except Exception:
|
||||
return str(s)
|
||||
stripped = [ _strip_block_prefix(line) for line in (llm_dedup_only or []) ]
|
||||
seen = set()
|
||||
deduped_ordered = []
|
||||
for line in stripped:
|
||||
if line not in seen:
|
||||
seen.add(line)
|
||||
deduped_ordered.append(line)
|
||||
llm_dedup_only = deduped_ordered
|
||||
except Exception:
|
||||
pass
|
||||
report_lines.extend(llm_dedup_only if llm_dedup_only else ["未执行或无合并项"])
|
||||
with open(out_path, ("a" if append else "w"), encoding="utf-8") as f:
|
||||
f.write("\n".join(report_lines) + "\n")
|
||||
except Exception:
|
||||
# 静默失败,避免影响主流程
|
||||
pass
|
||||
@@ -0,0 +1,689 @@
|
||||
"""
|
||||
用于实体去重,基于LLM的决策
|
||||
提供“LLM判定逻辑”的核心实现与并发控制。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import difflib
|
||||
from typing import List, Tuple, Dict
|
||||
import anyio
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.models.graph_models import ExtractedEntityNode, StatementEntityEdge, EntityEntityEdge
|
||||
from app.core.memory.models.dedup_models import EntityDedupDecision, EntityDisambDecision
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_entity_dedup_prompt
|
||||
|
||||
|
||||
# --- 类型同义归并与相似度 ---
|
||||
_TYPE_ALIASES_UPPER: Dict[str, set[str]] = {
|
||||
# 设备/器材类近义:统一到 EQUIPMENT
|
||||
"EQUIPMENT": {s.upper() for s in {"设备", "器材", "摄影器材", "装备", "工具", "APPLIANCE", "TOOL"}},
|
||||
# 活动/技能近义:统一到 ACTIVITY,放宽“技术活动/技能”的同类判断
|
||||
"ACTIVITY": {s.upper() for s in {"活动", "技术活动", "技能", "ACTIVITY", "SKILL"}},
|
||||
# 常见类别,按需扩展
|
||||
"PERSON": {s.upper() for s in {"人物", "人", "个人", "人名", "PERSON"}},
|
||||
"LOCATION": {s.upper() for s in {"地点", "位置", "LOCATION", "城市", "CITY", "国家", "COUNTRY"}},
|
||||
"SOFTWARE": {s.upper() for s in {"软件", "SOFTWARE"}},
|
||||
"EVENT": {s.upper() for s in {"事件", "EVENT"}},
|
||||
}
|
||||
|
||||
def _canonicalize_type(t: str | None) -> str:
|
||||
u = (str(t or "").strip().upper())
|
||||
if not u or u == "UNKNOWN":
|
||||
return "UNKNOWN"
|
||||
for canon, aliases in _TYPE_ALIASES_UPPER.items():
|
||||
if u in aliases:
|
||||
return canon
|
||||
return u # 未知类型直接返回自身(保守兼容)
|
||||
|
||||
def _type_similarity(t1: str | None, t2: str | None) -> float:
|
||||
c1, c2 = _canonicalize_type(t1), _canonicalize_type(t2)
|
||||
if c1 == c2:
|
||||
return 1.0
|
||||
if c1 == "UNKNOWN" or c2 == "UNKNOWN":
|
||||
return 0.6 # 任一未知,给中等相似度,允许模型结合描述判断
|
||||
return 0.0
|
||||
|
||||
def _simple_type_ok(t1: str | None, t2: str | None) -> bool:
|
||||
"""类型门控:
|
||||
- 允许同类(含近义归并后同类)或任一 UNKNOWN/空;
|
||||
- 其余不同类不放行(例如 PERSON vs EQUIPMENT)。
|
||||
"""
|
||||
c1, c2 = _canonicalize_type(t1), _canonicalize_type(t2)
|
||||
if c1 == "UNKNOWN" or c2 == "UNKNOWN":
|
||||
return True
|
||||
return c1 == c2
|
||||
|
||||
|
||||
def _name_embed_sim(a: List[float] | None, b: List[float] | None) -> float: # 计算实体名称嵌入向量的余弦相似度
|
||||
a = a or []
|
||||
b = b or []
|
||||
if not a or not b or len(a) != len(b):
|
||||
return 0.0
|
||||
try:
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
na = (sum(x * x for x in a)) ** 0.5
|
||||
nb = (sum(y * y for y in b)) ** 0.5
|
||||
if na > 0 and nb > 0:
|
||||
return dot / (na * nb)
|
||||
except Exception:
|
||||
pass
|
||||
return 0.0
|
||||
|
||||
|
||||
def _name_text_sim(name1: str, name2: str) -> float: # 计算实体名称文本的字符串相似度
|
||||
name1 = (name1 or "").strip().lower()
|
||||
name2 = (name2 or "").strip().lower()
|
||||
if not name1 or not name2:
|
||||
return 0.0
|
||||
return difflib.SequenceMatcher(None, name1, name2).ratio()
|
||||
|
||||
|
||||
def _co_occurrence(statement_edges: List[StatementEntityEdge], a_id: str, b_id: str) -> bool: # 判断两个实体是否在同一陈述中 “同现”
|
||||
try:
|
||||
sources_a = {e.source for e in statement_edges if getattr(e, "target", None) == a_id}
|
||||
sources_b = {e.source for e in statement_edges if getattr(e, "target", None) == b_id}
|
||||
return bool(sources_a & sources_b)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _relation_statements(entity_edges: List[EntityEntityEdge], a_id: str, b_id: str) -> List[str]: # 提取两个实体间的所有关联语句
|
||||
stmts: List[str] = []
|
||||
for e in entity_edges:
|
||||
if (getattr(e, "source", None) == a_id and getattr(e, "target", None) == b_id) or (
|
||||
getattr(e, "source", None) == b_id and getattr(e, "target", None) == a_id
|
||||
):
|
||||
s_text = getattr(e, "statement", None) or ""
|
||||
r_type = getattr(e, "relation_type", None) or ""
|
||||
if s_text or r_type:
|
||||
stmts.append(f"{r_type}: {s_text}".strip(': '))
|
||||
return stmts
|
||||
|
||||
|
||||
def _choose_canonical(a: ExtractedEntityNode, b: ExtractedEntityNode) -> int: # 选择 “规范实体”(合并时保留的实体)
|
||||
# 0 for a, 1 for b
|
||||
# 1. 第一优先级:按“连接强度”排序(连接强度越高,实体越可靠)
|
||||
cs_a = (getattr(a, "connect_strength", "") or "").lower()
|
||||
cs_b = (getattr(b, "connect_strength", "") or "").lower()
|
||||
prio = {"strong": 3, "both": 3, "weak": 1, "": 0}
|
||||
if prio.get(cs_a, 0) != prio.get(cs_b, 0):
|
||||
return 0 if prio.get(cs_a, 0) > prio.get(cs_b, 0) else 1
|
||||
# pick longer description/fact_summary
|
||||
# 2. 第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整)
|
||||
desc_a = (getattr(a, "description", "") or "")
|
||||
desc_b = (getattr(b, "description", "") or "")
|
||||
fact_a = (getattr(a, "fact_summary", "") or "")
|
||||
fact_b = (getattr(b, "fact_summary", "") or "")
|
||||
score_a = len(desc_a) + len(fact_a)
|
||||
score_b = len(desc_b) + len(fact_b)
|
||||
if score_a != score_b:
|
||||
return 0 if score_a >= score_b else 1
|
||||
return 0
|
||||
|
||||
# _judge_pair(单对实体的 LLM 判断) 已经有分块迭代的函数内容是否还需要单对LLM判断--这是已经创建的工具服务于分块迭代的函数
|
||||
async def _judge_pair(
|
||||
llm_client: OpenAIClient,
|
||||
a: ExtractedEntityNode,
|
||||
b: ExtractedEntityNode,
|
||||
statement_edges: List[StatementEntityEdge],
|
||||
entity_edges: List[EntityEntityEdge],
|
||||
) -> Tuple[EntityDedupDecision, Dict]:
|
||||
# 1. 计算实体名称的核心相似度指标
|
||||
name_text_sim = _name_text_sim(getattr(a, "name", ""), getattr(b, "name", ""))
|
||||
name_embed_sim = _name_embed_sim(getattr(a, "name_embedding", []), getattr(b, "name_embedding", []))
|
||||
# 2. 判断名称是否存在“包含关系”(如“苹果公司”包含“苹果”)
|
||||
name_contains = False
|
||||
try:
|
||||
n1 = (getattr(a, "name", "") or "").strip().lower()
|
||||
n2 = (getattr(b, "name", "") or "").strip().lower()
|
||||
name_contains = bool(n1 and n2 and (n1 in n2 or n2 in n1))
|
||||
except Exception:
|
||||
pass
|
||||
# 3. 构建LLM判断的“上下文信息”(规则层计算的所有特征) 判断上下文特征有助于实体消歧首先判断的类型关系
|
||||
ctx = {
|
||||
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None),
|
||||
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
||||
"type_similarity": _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
||||
"name_text_sim": name_text_sim,
|
||||
"name_embed_sim": name_embed_sim,
|
||||
"name_contains": name_contains,
|
||||
"co_occurrence": _co_occurrence(statement_edges, getattr(a, "id", None), getattr(b, "id", None)),
|
||||
"relation_statements": _relation_statements(entity_edges, getattr(a, "id", None), getattr(b, "id", None)),
|
||||
}
|
||||
|
||||
entity_a = {
|
||||
"name": getattr(a, "name", None),
|
||||
"entity_type": getattr(a, "entity_type", None),
|
||||
"description": getattr(a, "description", None),
|
||||
"aliases": getattr(a, "aliases", None) or [],
|
||||
"fact_summary": getattr(a, "fact_summary", None),
|
||||
"connect_strength": getattr(a, "connect_strength", None),
|
||||
}
|
||||
entity_b = {
|
||||
"name": getattr(b, "name", None),
|
||||
"entity_type": getattr(b, "entity_type", None),
|
||||
"description": getattr(b, "description", None),
|
||||
"aliases": getattr(b, "aliases", None) or [],
|
||||
"fact_summary": getattr(b, "fact_summary", None),
|
||||
"connect_strength": getattr(b, "connect_strength", None),
|
||||
}
|
||||
# 5. 渲染LLM提示词(用工具函数填充模板,包含实体信息、上下文、输出格式)
|
||||
prompt = render_entity_dedup_prompt(
|
||||
entity_a=entity_a,
|
||||
entity_b=entity_b,
|
||||
context=ctx,
|
||||
json_schema=EntityDedupDecision.model_json_schema(),
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You judge whether two entities are the same. Return valid JSON only."},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
decision = await llm_client.response_structured(messages, EntityDedupDecision)
|
||||
return decision, ctx
|
||||
|
||||
# 消歧场景(同名不同类型)下的LLM判断
|
||||
async def _judge_pair_disamb(
|
||||
llm_client: OpenAIClient,
|
||||
a: ExtractedEntityNode,
|
||||
b: ExtractedEntityNode,
|
||||
statement_edges: List[StatementEntityEdge],
|
||||
entity_edges: List[EntityEntityEdge],
|
||||
) -> Tuple[EntityDisambDecision, Dict]:
|
||||
name_text_sim = _name_text_sim(getattr(a, "name", ""), getattr(b, "name", ""))
|
||||
name_embed_sim = _name_embed_sim(getattr(a, "name_embedding", []), getattr(b, "name_embedding", []))
|
||||
name_contains = False
|
||||
try:
|
||||
n1 = (getattr(a, "name", "") or "").strip().lower()
|
||||
n2 = (getattr(b, "name", "") or "").strip().lower()
|
||||
name_contains = bool(n1 and n2 and (n1 in n2 or n2 in n1))
|
||||
except Exception:
|
||||
pass
|
||||
ctx = {
|
||||
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None),
|
||||
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
||||
"name_text_sim": name_text_sim,
|
||||
"name_embed_sim": name_embed_sim,
|
||||
"name_contains": name_contains,
|
||||
"co_occurrence": _co_occurrence(statement_edges, getattr(a, "id", None), getattr(b, "id", None)),
|
||||
"relation_statements": _relation_statements(entity_edges, getattr(a, "id", None), getattr(b, "id", None)),
|
||||
}
|
||||
entity_a = {
|
||||
"name": getattr(a, "name", None),
|
||||
"entity_type": getattr(a, "entity_type", None),
|
||||
"description": getattr(a, "description", None),
|
||||
"aliases": getattr(a, "aliases", None) or [],
|
||||
"fact_summary": getattr(a, "fact_summary", None),
|
||||
"connect_strength": getattr(a, "connect_strength", None),
|
||||
}
|
||||
entity_b = {
|
||||
"name": getattr(b, "name", None),
|
||||
"entity_type": getattr(b, "entity_type", None),
|
||||
"description": getattr(b, "description", None),
|
||||
"aliases": getattr(b, "aliases", None) or [],
|
||||
"fact_summary": getattr(b, "fact_summary", None),
|
||||
"connect_strength": getattr(b, "connect_strength", None),
|
||||
}
|
||||
prompt = render_entity_dedup_prompt(
|
||||
entity_a=entity_a,
|
||||
entity_b=entity_b,
|
||||
context=ctx,
|
||||
json_schema=EntityDisambDecision.model_json_schema(),
|
||||
disambiguation_mode=True,
|
||||
)
|
||||
messages = [
|
||||
{"role": "system", "content": "You disambiguate same-name different-type entities. Return valid JSON only."},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
decision = await llm_client.response_structured(messages, EntityDisambDecision)
|
||||
return decision, ctx
|
||||
|
||||
# llm_dedup_entities(单轮实体去重)
|
||||
async def llm_dedup_entities( # 保留对偶判断作为子流程,是为了保证高精度、可审计、可复用和行为一致性
|
||||
# 对偶判断让每次决策只聚焦于一对实体,信息维度清晰,噪声更低,模型更容易给出稳定的“是否同一实体”与“规范方”选择。
|
||||
# 考虑是否将其保留
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
llm_client: OpenAIClient,
|
||||
max_concurrency: int = 4,
|
||||
auto_merge_threshold: float = 0.90,
|
||||
co_ctx_threshold: float = 0.83,
|
||||
) -> Tuple[Dict[str, str], List[str]]:
|
||||
"""
|
||||
Use LLM to assist fuzzy deduplication among candidate entity pairs and
|
||||
produce an `id_redirect` mapping plus audit log records.
|
||||
|
||||
Parameters:
|
||||
- entity_nodes: deduplication input entities
|
||||
- statement_entity_edges: edges from statements to entities (for co-occurrence context)
|
||||
- entity_entity_edges: relational edges between entities (for relation statements)
|
||||
- llm_client: configured async client used to call the model
|
||||
- max_concurrency: semaphore limit for concurrent LLM calls (default 4)
|
||||
- auto_merge_threshold: confidence threshold to auto-merge without co-occurrence (default 0.90)
|
||||
- co_ctx_threshold: slightly lower threshold when co-occurrence is detected (default 0.83)
|
||||
|
||||
Returns:
|
||||
- id_redirect_updates: dict of losing_id -> canonical_id decided by LLM
|
||||
- records: textual logs for decisions, errors, and non-merges
|
||||
|
||||
Notes:
|
||||
- Candidate generation uses simple gates: same group, type compatible, and
|
||||
name similarity or containment, optionally lowered threshold with co-occurrence.
|
||||
- The higher-level pipeline should call this async function upstream, then
|
||||
pass the resulting mapping and records into `deduplicate_entities_and_edges`
|
||||
via `llm_redirect` and `llm_records` to apply merges synchronously before
|
||||
edge redirection.
|
||||
"""
|
||||
# 1. 构建“候选实体对”(用规则层筛选,减少LLM调用量,提高效率)
|
||||
# Build candidate pairs: simple gates
|
||||
candidates: List[Tuple[int, int]] = []
|
||||
for i in range(len(entity_nodes)):
|
||||
a = entity_nodes[i]
|
||||
for j in range(i + 1, len(entity_nodes)):
|
||||
b = entity_nodes[j]
|
||||
# 规则1:必须属于同一组(group_id相同,不同组的实体不重复)
|
||||
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
|
||||
continue
|
||||
# 规则2:类型必须兼容(调用_simple_type_ok判断)
|
||||
if not _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)):
|
||||
continue
|
||||
# 规则3:名称相似度达标(文本/嵌入相似度取最大值)
|
||||
txt_sim = _name_text_sim(getattr(a, "name", ""), getattr(b, "name", ""))
|
||||
emb_sim = _name_embed_sim(getattr(a, "name_embedding", []), getattr(b, "name_embedding", []))
|
||||
# 规则4:名称是否包含(如“苹果公司”和“苹果”)
|
||||
contains = False
|
||||
try:
|
||||
n1 = (getattr(a, "name", "") or "").strip().lower()
|
||||
n2 = (getattr(b, "name", "") or "").strip().lower()
|
||||
contains = bool(n1 and n2 and (n1 in n2 or n2 in n1))
|
||||
except Exception:
|
||||
pass
|
||||
# 规则5:是否同现(同现的实体更可能重复,降低相似度阈值)
|
||||
co_ctx = _co_occurrence(statement_entity_edges, getattr(a, "id", None), getattr(b, "id", None))
|
||||
sim = max(txt_sim, emb_sim)
|
||||
# 候选对筛选条件:满足任一即加入(减少漏判)
|
||||
if (sim >= 0.80) or (co_ctx and sim >= 0.75) or contains:
|
||||
candidates.append((i, j))
|
||||
|
||||
# Use anyio for cross-compatibility with asyncio and trio
|
||||
results = []
|
||||
async with anyio.create_task_group() as tg:
|
||||
result_list = [None] * len(candidates)
|
||||
|
||||
async def _wrapped(idx: int, i: int, j: int):
|
||||
try:
|
||||
result_list[idx] = await _judge_pair(llm_client, entity_nodes[i], entity_nodes[j], statement_entity_edges, entity_entity_edges)
|
||||
except Exception as e:
|
||||
result_list[idx] = e
|
||||
|
||||
# Limit concurrency using semaphore
|
||||
sem = anyio.Semaphore(max_concurrency)
|
||||
|
||||
async def _limited_wrapped(idx: int, i: int, j: int):
|
||||
async with sem:
|
||||
await _wrapped(idx, i, j)
|
||||
|
||||
for idx, (i, j) in enumerate(candidates):
|
||||
tg.start_soon(_limited_wrapped, idx, i, j)
|
||||
|
||||
results = result_list
|
||||
|
||||
id_redirect_updates: Dict[str, str] = {}
|
||||
records: List[str] = []
|
||||
for idx, res in enumerate(results):
|
||||
if isinstance(res, Exception):
|
||||
i, j = candidates[idx]
|
||||
a = entity_nodes[i]
|
||||
b = entity_nodes[j]
|
||||
records.append(f"[LLM异常] pair ({a.id},{b.id}) -> {res}")
|
||||
continue
|
||||
decision, ctx = res
|
||||
i, j = candidates[idx]
|
||||
a = entity_nodes[i]
|
||||
b = entity_nodes[j]
|
||||
th = auto_merge_threshold if not ctx.get("co_occurrence") else co_ctx_threshold
|
||||
if decision.same_entity and decision.confidence >= th:
|
||||
canon_idx = decision.canonical_idx if decision.canonical_idx in (0, 1) else _choose_canonical(a, b)
|
||||
canon = a if canon_idx == 0 else b
|
||||
other = b if canon_idx == 0 else a
|
||||
id_redirect_updates[getattr(other, "id")] = getattr(canon, "id")
|
||||
records.append(
|
||||
f"[LLM合并] 规范实体 {canon.id} 名称 '{getattr(canon, 'name', '')}' <- 合并实体 {other.id} 名称 '{getattr(other, 'name', '')}' | conf={decision.confidence:.3f}, th={th:.3f}, co_ctx={ctx.get('co_occurrence')}"
|
||||
)
|
||||
# 若类型相同且名称高度相似/包含关系,补充“同类名称相似”记录,格式与报告要求一致(名称后带类型)
|
||||
try:
|
||||
type_same = (getattr(a, "entity_type", None) == getattr(b, "entity_type", None)) and getattr(a, "entity_type", None) is not None
|
||||
name_sim = max(float(ctx.get("name_text_sim", 0.0)), float(ctx.get("name_embed_sim", 0.0)))
|
||||
name_contains = bool(ctx.get("name_contains", False))
|
||||
if type_same and (name_sim >= 0.80 or name_contains):
|
||||
name_a = (getattr(a, "name", "") or "").strip()
|
||||
name_b = (getattr(b, "name", "") or "").strip()
|
||||
type_a = getattr(a, "entity_type", "")
|
||||
type_b = getattr(b, "entity_type", "")
|
||||
records.append(
|
||||
f"[LLM去重] 同类名称相似 {name_a}({type_a})|{name_b}({type_b}) | conf={decision.confidence:.2f} | reason={decision.reason}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
records.append(
|
||||
f"[LLM不合并] A={a.id} B={b.id} | same={decision.same_entity} conf={decision.confidence:.3f} co_ctx={ctx.get('co_occurrence')}"
|
||||
)
|
||||
|
||||
return id_redirect_updates, records
|
||||
|
||||
# 迭代分块去重,这才是重点
|
||||
async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
||||
entity_nodes: List[ExtractedEntityNode], # 待去重实体列表(需先经过精确去重),LLM决策属于模糊匹配下
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
llm_client: OpenAIClient,
|
||||
block_size: int = 50,
|
||||
block_concurrency: int = 4,
|
||||
pair_concurrency: int = 4,
|
||||
max_rounds: int = 3,
|
||||
auto_merge_threshold: float = 0.90,
|
||||
co_ctx_threshold: float = 0.83,
|
||||
shuffle_each_round: bool = True, # 每轮是否打乱实体顺序(避免同一块内实体重复,提高覆盖度)
|
||||
) -> Tuple[Dict[str, str], List[str]]: # 返回:全局ID映射、全局审计日志
|
||||
"""
|
||||
Iteratively deduplicate entities using LLM in block-wise concurrent rounds.
|
||||
|
||||
Process:
|
||||
- Partition the input entities (post exact + local fuzzy stage) into blocks per round.
|
||||
- Run LLM pairwise decisions concurrently *within each block*, and also run multiple blocks concurrently.
|
||||
- Apply merges from all blocks, collapse to canonical set, re-partition, and repeat until no new merges or max_rounds reached.
|
||||
|
||||
Parameters:
|
||||
- entity_nodes: entities to deduplicate (should already be exact/fuzzy merged candidates)
|
||||
- statement_entity_edges: statement→entity edges for co-occurrence context
|
||||
- entity_entity_edges: entity↔entity relational edges for relation statements context
|
||||
- llm_client: initialized async client
|
||||
- block_size: target number of entities per block (default 50)
|
||||
- block_concurrency: how many blocks to process concurrently (default 4)
|
||||
- pair_concurrency: concurrency for pairwise LLM calls inside each block (default 4)
|
||||
- max_rounds: upper bound for iterative passes (default 3)
|
||||
- auto_merge_threshold: decision confidence for auto-merge when no co-occurrence (default 0.90)
|
||||
- co_ctx_threshold: lower threshold when co-occurrence is detected (default 0.83)
|
||||
- shuffle_each_round: whether to shuffle entities within group_id each round to vary block composition
|
||||
|
||||
Returns:
|
||||
- global_redirect: dict losing_id -> canonical_id accumulated across rounds
|
||||
- records: textual logs including per-round/per-block summaries and per-pair decisions
|
||||
"""
|
||||
import asyncio
|
||||
import random
|
||||
# 初始化全局日志和全局ID映射(存储所有轮次的结果)
|
||||
records: List[str] = []
|
||||
global_redirect: Dict[str, str] = {}
|
||||
|
||||
# Helper: resolve final canonical id following redirect chain
|
||||
# 辅助函数1:_resolve:递归解析实体的“最终规范ID”(处理ID映射链,如a→b→c,返回c)
|
||||
def _resolve(id_: str) -> str:
|
||||
while id_ in global_redirect and global_redirect[id_] != id_: # 若ID在映射中且未指向自身
|
||||
id_ = global_redirect[id_] # 递归替换为映射的ID
|
||||
return id_ # 返回最终规范ID
|
||||
## 这里辅助函数没有看懂
|
||||
|
||||
# Helper: collapse nodes to canonical representatives per current global_redirect
|
||||
# 辅助函数2:_collapse_nodes:根据全局ID映射,“折叠”实体列表(保留每个规范ID对应的实体)
|
||||
def _collapse_nodes(nodes: List[ExtractedEntityNode]) -> List[ExtractedEntityNode]:
|
||||
by_id: Dict[str, ExtractedEntityNode] = {e.id: e for e in nodes} # 实体ID→实体的映射
|
||||
keep: Dict[str, ExtractedEntityNode] = {} # 存储需保留的规范实体
|
||||
for e in nodes:
|
||||
cid = _resolve(e.id) # 解析e的最终规范ID
|
||||
# 优先保留by_id中已存在的规范实体(若有),否则保留第一个遇到的实体
|
||||
if cid in by_id:
|
||||
keep[cid] = by_id[cid]
|
||||
else:
|
||||
keep[cid] = keep.get(cid, e)
|
||||
return list(keep.values())
|
||||
|
||||
def _partition_blocks(nodes: List[ExtractedEntityNode]) -> List[List[ExtractedEntityNode]]:
|
||||
"""
|
||||
按 group_id 分块,避免跨组实体在同一块,减少无效候选对
|
||||
|
||||
Args:
|
||||
nodes: 实体节点列表
|
||||
|
||||
Returns:
|
||||
分块后的实体列表
|
||||
"""
|
||||
groups: Dict[str, List[ExtractedEntityNode]] = {}
|
||||
for e in nodes:
|
||||
gid = getattr(e, "group_id", None)
|
||||
groups.setdefault(str(gid), []).append(e)
|
||||
blocks: List[List[ExtractedEntityNode]] = []
|
||||
for gid, arr in groups.items():
|
||||
if shuffle_each_round:
|
||||
random.shuffle(arr)
|
||||
# chunk into block_size
|
||||
for i in range(0, len(arr), max(1, block_size)):
|
||||
blocks.append(arr[i:i + max(1, block_size)])
|
||||
return blocks
|
||||
|
||||
# Semaphore for block-level concurrency
|
||||
# 初始化块级并发信号量(控制同时处理的块数量)
|
||||
block_sem = asyncio.Semaphore(max(1, block_concurrency))
|
||||
|
||||
# 辅助函数4:_run_one_block:处理单个块的去重(调用llm_dedup_entities)
|
||||
async def _run_one_block(block_idx: int, block_nodes: List[ExtractedEntityNode]):
|
||||
async with block_sem:
|
||||
# Delegate to existing pairwise function with limited concurrency per block
|
||||
id_map, recs = await llm_dedup_entities(
|
||||
entity_nodes=block_nodes,
|
||||
statement_entity_edges=statement_entity_edges,
|
||||
entity_entity_edges=entity_entity_edges,
|
||||
llm_client=llm_client,
|
||||
max_concurrency=pair_concurrency,
|
||||
auto_merge_threshold=auto_merge_threshold,
|
||||
co_ctx_threshold=co_ctx_threshold,
|
||||
)
|
||||
# Prefix block index in records for readability
|
||||
prefixed = [f"[LLM块{block_idx}] {line}" for line in recs]
|
||||
return id_map, prefixed
|
||||
|
||||
# Iterative rounds
|
||||
# 核心:迭代分块去重(多轮处理)
|
||||
current_nodes: List[ExtractedEntityNode] = list(entity_nodes)
|
||||
round_idx = 1
|
||||
while round_idx <= max(1, max_rounds):
|
||||
# Collapse nodes to canonical reps before each round to avoid redundant comparisons
|
||||
# 步骤1:折叠实体(合并已确定的重复实体,减少后续计算量)
|
||||
current_nodes = _collapse_nodes(current_nodes)
|
||||
# 步骤2:分块(按group_id分块,避免跨组处理)
|
||||
blocks = _partition_blocks(current_nodes)
|
||||
if not blocks: # 无块可处理(实体已全部折叠),退出循环
|
||||
break
|
||||
# 步骤3:记录当前轮次的基本信息(轮次、块数、块大小)
|
||||
records.append(f"[LLM批次] 轮次 {round_idx} 预计处理块数 {len(blocks)} 每块大小≈{block_size}")
|
||||
|
||||
# Run all blocks concurrently with block-level semaphore
|
||||
# 步骤4:并发处理所有块(创建块处理任务,批量执行)
|
||||
results = [None] * len(blocks)
|
||||
async with anyio.create_task_group() as tg:
|
||||
async def _run_block_wrapper(idx: int, block: List[ExtractedEntityNode]):
|
||||
try:
|
||||
results[idx] = await _run_one_block(idx, block)
|
||||
except Exception as e:
|
||||
results[idx] = e
|
||||
|
||||
for i in range(len(blocks)):
|
||||
tg.start_soon(_run_block_wrapper, i, blocks[i])
|
||||
|
||||
# Collect and normalize redirects from blocks
|
||||
# 步骤5:合并块结果到全局映射和日志
|
||||
merged_this_round = 0
|
||||
for bi, res in enumerate(results):
|
||||
if isinstance(res, Exception):
|
||||
records.append(f"[LLM块异常] 轮次 {round_idx} 块 {bi} -> {res}")
|
||||
continue
|
||||
id_map, recs = res
|
||||
records.extend(recs)
|
||||
# Normalize with current global redirects
|
||||
for losing, canon in id_map.items():
|
||||
losing_final = _resolve(losing)
|
||||
canon_final = _resolve(canon)
|
||||
if losing_final == canon_final:
|
||||
continue
|
||||
# Apply mapping and ensure chain consistency
|
||||
global_redirect[losing_final] = canon_final
|
||||
merged_this_round += 1
|
||||
records.append(f"[LLM批次] 轮次 {round_idx} 块数 {len(blocks)} 新合并 {merged_this_round}")
|
||||
|
||||
if merged_this_round == 0:
|
||||
break
|
||||
|
||||
# Prepare nodes for next round: collapse canonical set
|
||||
current_nodes = _collapse_nodes(current_nodes)
|
||||
round_idx += 1
|
||||
|
||||
return global_redirect, records
|
||||
|
||||
|
||||
# LLM 消歧:同名不同类型的实体对,输出合并建议与阻断对列表
|
||||
async def llm_disambiguate_pairs_iterative(
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
llm_client: OpenAIClient,
|
||||
max_concurrency: int = 4,
|
||||
merge_conf_threshold: float = 0.88,
|
||||
block_conf_threshold: float = 0.60,
|
||||
) -> Tuple[Dict[str, str], List[Tuple[str, str]], List[str]]:
|
||||
"""
|
||||
Disambiguate same-name different-type pairs using LLM.
|
||||
|
||||
Returns:
|
||||
- merge_redirect: dict losing_id -> canonical_id for merges decided by LLM
|
||||
- block_pairs: list of sorted (id1, id2) pairs to block from fuzzy/heuristic merges
|
||||
- records: textual logs for audit
|
||||
"""
|
||||
records: List[str] = []
|
||||
merge_redirect: Dict[str, str] = {}
|
||||
block_pairs: List[Tuple[str, str]] = []
|
||||
|
||||
def _is_typed(t: str) -> bool:
|
||||
t = (t or "").strip().upper()
|
||||
return bool(t) and t not in {"UNKNOWN", "UNDEFINED", ""}
|
||||
|
||||
candidates: List[Tuple[int, int]] = []
|
||||
n = len(entity_nodes)
|
||||
for i in range(n):
|
||||
for j in range(i + 1, n):
|
||||
a = entity_nodes[i]
|
||||
b = entity_nodes[j]
|
||||
# 必须同组
|
||||
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
|
||||
continue
|
||||
ta = getattr(a, "entity_type", None)
|
||||
tb = getattr(b, "entity_type", None)
|
||||
# 必须不同类型且两者均为已定义类型
|
||||
if ta == tb:
|
||||
continue
|
||||
if not (_is_typed(ta) and _is_typed(tb)):
|
||||
continue
|
||||
# 严格“同名不同义”:名称需严格相同(大小写与首尾空格忽略)
|
||||
try:
|
||||
na = (getattr(a, "name", "") or "").strip().lower()
|
||||
nb = (getattr(b, "name", "") or "").strip().lower()
|
||||
except Exception:
|
||||
na, nb = "", ""
|
||||
if not na or not nb:
|
||||
continue
|
||||
if na == nb:
|
||||
candidates.append((i, j))
|
||||
|
||||
if not candidates:
|
||||
return merge_redirect, block_pairs, records
|
||||
|
||||
# Use anyio for cross-compatibility with asyncio and trio
|
||||
judged = [None] * len(candidates)
|
||||
async with anyio.create_task_group() as tg:
|
||||
async def _wrapped(idx: int, i: int, j: int):
|
||||
try:
|
||||
judged[idx] = await _judge_pair_disamb(llm_client, entity_nodes[i], entity_nodes[j], statement_entity_edges, entity_entity_edges)
|
||||
except Exception as e:
|
||||
judged[idx] = e
|
||||
|
||||
# Limit concurrency using semaphore
|
||||
sem = anyio.Semaphore(max_concurrency)
|
||||
|
||||
async def _limited_wrapped(idx: int, i: int, j: int):
|
||||
async with sem:
|
||||
await _wrapped(idx, i, j)
|
||||
|
||||
for idx, (i, j) in enumerate(candidates):
|
||||
tg.start_soon(_limited_wrapped, idx, i, j)
|
||||
for k, res in enumerate(judged):
|
||||
i, j = candidates[k]
|
||||
a = entity_nodes[i]
|
||||
b = entity_nodes[j]
|
||||
a_id = getattr(a, "id", None) or ""
|
||||
b_id = getattr(b, "id", None) or ""
|
||||
if isinstance(res, Exception):
|
||||
records.append(f"[DISAMB错误] 对({a_id},{b_id})调用失败: {res}")
|
||||
block_pairs.append(tuple(sorted((a_id, b_id))))
|
||||
continue
|
||||
decision, ctx = res
|
||||
try:
|
||||
if decision.should_merge and decision.confidence >= merge_conf_threshold:
|
||||
can_idx = 0 if decision.canonical_idx == 0 else 1
|
||||
canonical = a if can_idx == 0 else b
|
||||
losing = b if can_idx == 0 else a
|
||||
merge_redirect[getattr(losing, "id", "")] = getattr(canonical, "id", "")
|
||||
records.append(
|
||||
f"[DISAMB合并] {getattr(losing,'id','')} -> {getattr(canonical,'id','')} | conf={decision.confidence:.2f} | reason={decision.reason} | suggested_type={decision.suggested_type or ''}"
|
||||
)
|
||||
# 追加 LLM 决策去重记录,以便下方报告展示到“LLM 决策去重”区块
|
||||
records.append(
|
||||
f"[LLM去重] 同名类型相似 {getattr(a,'name','')}({getattr(a,'entity_type','')})|{getattr(b,'name','')}({getattr(b,'entity_type','')}) | conf={decision.confidence:.2f} | reason={decision.reason}"
|
||||
)
|
||||
else:
|
||||
# Fallback:同名且类型不同,但语义高度相似且未要求阻断,按“同名类型相似”进行合并
|
||||
name_a = (getattr(a, "name", "") or "").strip().lower()
|
||||
name_b = (getattr(b, "name", "") or "").strip().lower()
|
||||
def _strength_rank(x: str) -> int:
|
||||
s = (x or "").strip().lower()
|
||||
return {"strong": 3, "both": 2, "weak": 1}.get(s, 0)
|
||||
if (
|
||||
name_a and name_b and name_a == name_b
|
||||
and (not decision.block_pair)
|
||||
and decision.confidence >= max(0.80, block_conf_threshold)
|
||||
):
|
||||
# 选择规范实体:优先使用 canonical_idx;否则根据连接强度挑选更强者
|
||||
if decision.canonical_idx in (0, 1):
|
||||
canonical = a if decision.canonical_idx == 0 else b
|
||||
losing = b if decision.canonical_idx == 0 else a
|
||||
else:
|
||||
sa = _strength_rank(getattr(a, "connect_strength", None))
|
||||
sb = _strength_rank(getattr(b, "connect_strength", None))
|
||||
canonical = a if sa >= sb else b
|
||||
losing = b if sa >= sb else a
|
||||
merge_redirect[getattr(losing, "id", "")] = getattr(canonical, "id", "")
|
||||
# 消歧合并审计
|
||||
records.append(
|
||||
f"[DISAMB合并] {getattr(losing,'id','')} -> {getattr(canonical,'id','')} | conf={decision.confidence:.2f} | reason={decision.reason} | suggested_type={decision.suggested_type or ''}"
|
||||
)
|
||||
# 追加 LLM 决策去重记录(同名类型相似)
|
||||
records.append(
|
||||
f"[LLM去重] 同名类型相似 {getattr(a,'name','')}({getattr(a,'entity_type','')})|{getattr(b,'name','')}({getattr(b,'entity_type','')}) | conf={decision.confidence:.2f} | reason={decision.reason}"
|
||||
)
|
||||
else:
|
||||
if decision.block_pair or decision.confidence >= block_conf_threshold:
|
||||
block_pairs.append(tuple(sorted((a_id, b_id))))
|
||||
# 仅保留阻断条目在预筛选报告,包含实体名称与类型,便于人读
|
||||
records.append(
|
||||
f"[DISAMB阻断] {getattr(a,'name','')}({getattr(a,'entity_type','')})|{getattr(b,'name','')}({getattr(b,'entity_type','')}) | conf={decision.confidence:.2f} | reason={decision.reason} || block_pair={decision.block_pair}"
|
||||
)
|
||||
except Exception:
|
||||
block_pairs.append(tuple(sorted((a_id, b_id))))
|
||||
# 异常情况也以阻断形式记录,包含名称便于定位
|
||||
records.append(
|
||||
f"[DISAMB异常阻断] {getattr(a,'name','')}({getattr(a,'entity_type','')})|{getattr(b,'name','')}({getattr(b,'entity_type','')}) | ids=({a_id},{b_id})"
|
||||
)
|
||||
|
||||
return merge_redirect, block_pairs, records
|
||||
@@ -0,0 +1,149 @@
|
||||
# 导入 Python 的annotations特性,允许在类型注解中使用尚未定义的类(支持 “向前引用”),提升代码中类型注解的灵活性。
|
||||
# 这是什么意思? 该类的属性的类型是这个类本身(递归定义)?
|
||||
"""
|
||||
这段代码是 “第二层去重消歧” 的核心实现,逻辑可分为四步:
|
||||
1.从第一层去重消歧后的实体中提取核心信息,作为索引查询 Neo4j 中同组的候选实体;
|
||||
2.对候选实体去重并转换为统一模型;
|
||||
3.构建预重定向关系(第一层实体 ID→数据库实体 ID),确保优先使用数据库 ID;
|
||||
4.合并数据库候选实体与第一层实体,调用去重函数完成最终融合,返回结果。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Dict, Any, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector # 导入 Neo4j 数据库连接器类,用于与 Neo4j 数据库进行交互
|
||||
from app.repositories.neo4j.graph_search import get_dedup_candidates_for_entities # 导入ge函数,用于从 Neo4j 中检索与输入实体可能重复的候选实体(去重的核心检索逻辑)。
|
||||
from app.core.memory.models.graph_models import ExtractedEntityNode, StatementEntityEdge, EntityEntityEdge
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges, _write_dedup_fusion_report # 导入报告写入以在跳过时追加说明
|
||||
from app.core.memory.models.variate_config import DedupConfig
|
||||
|
||||
|
||||
def _parse_dt(val: Any) -> datetime: # 定义内部辅助函数_parse_dt,用于将任意类型的输入值解析为datetime对象(处理实体节点中的时间字段)
|
||||
if isinstance(val, datetime):
|
||||
return val
|
||||
if isinstance(val, str) and val:
|
||||
try:
|
||||
return datetime.fromisoformat(val) # 使用fromisoformat方法将 ISO 格式的字符串(如 "2023-10-01T12:00:00")解析为datetime对象
|
||||
except Exception:
|
||||
pass
|
||||
# Fallback: now; upstream should provide real times
|
||||
return datetime.now()
|
||||
|
||||
|
||||
def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
|
||||
"""
|
||||
将 Neo4j 返回的数据库记录转换为 ExtractedEntityNode 模型对象
|
||||
|
||||
Args:
|
||||
row: Neo4j 查询返回的记录字典
|
||||
|
||||
Returns:
|
||||
ExtractedEntityNode: 实体节点对象
|
||||
|
||||
Note:
|
||||
从数据库中查询到的内容是 JSON 格式的字符串,需要先解析为 Python 对象
|
||||
"""
|
||||
return ExtractedEntityNode(
|
||||
id=row.get("id"),
|
||||
name=row.get("name") or "",
|
||||
group_id=row.get("group_id") or "",
|
||||
user_id=row.get("user_id") or "",
|
||||
apply_id=row.get("apply_id") or "",
|
||||
created_at=_parse_dt(row.get("created_at")),
|
||||
expired_at=_parse_dt(row.get("expired_at")) if row.get("expired_at") else None,
|
||||
entity_idx=int(row.get("entity_idx") or 0),
|
||||
statement_id=row.get("statement_id") or "",
|
||||
entity_type=row.get("entity_type") or "",
|
||||
description=row.get("description") or "",
|
||||
aliases=row.get("aliases") or [],
|
||||
name_embedding=row.get("name_embedding") or [],
|
||||
fact_summary=row.get("fact_summary") or "",
|
||||
connect_strength=row.get("connect_strength") or "",
|
||||
)
|
||||
|
||||
|
||||
async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑,与 Neo4j 中同组实体联合去重
|
||||
connector: Neo4jConnector,
|
||||
group_id: str, # 用于定位neo4j中同一组的实体,确保只在同组内去重
|
||||
entity_nodes: List[ExtractedEntityNode], # 输入的实体节点列表,包含待去重的实体
|
||||
statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系
|
||||
entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系
|
||||
dedup_config: DedupConfig | None = None,
|
||||
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
|
||||
"""
|
||||
第二层去重消歧:
|
||||
- 以第一层结果为索引,检索相同 group_id 下的 DB 候选实体
|
||||
- 将 DB 候选与当前实体集合联合,按既有精确/模糊/LLM 决策进行融合
|
||||
- 返回融合后的实体与重定向后的边(边已指向规范 ID,优先 DB ID)
|
||||
"""
|
||||
if not entity_nodes:
|
||||
return entity_nodes, statement_entity_edges, entity_entity_edges
|
||||
|
||||
# 构造批量行并检索候选(精确/别名 + CONTAINS 召回)
|
||||
# 将第一层去重消歧的结果作为索引,批量查询DB候选实体
|
||||
incoming_rows: List[Dict[str, Any]] = [ # 定义 包含第一层实体的核心信息(用于数据库查询)
|
||||
{"id": e.id, "name": e.name, "entity_type": e.entity_type} for e in entity_nodes # 对entity_nodes中的每个实体e,提取id(实体 ID)、name(名称)、entity_type(类型),构造字典作为查询条件。
|
||||
|
||||
]
|
||||
candidates_map = await get_dedup_candidates_for_entities( # 从 Neo4j 中查询候选实体,并将结果赋值给candidates_map(等待异步操作完成)。
|
||||
connector=connector, group_id=group_id,
|
||||
entities=incoming_rows, # 传入参数:第一层实体的核心信息(作为查询索引)
|
||||
use_contains_fallback=True # 传入参数:启用 “包含关系” 作为匹配失败的降级策略(若精确匹配无结果,用包含关系召回候选),与src\database\cypher_queries.py的307产生联动
|
||||
)
|
||||
|
||||
# 拉平候选,转为模型(按 DB 节点优先)
|
||||
db_candidate_rows: List[Dict[str, Any]] = [] # 存储去重后的数据库候选实体记录(行)
|
||||
seen_db_ids: set[str] = set() # 集合,用于记录已处理的数据库实体 ID(避免重复添加同一实体)
|
||||
for _, rows in candidates_map.items():
|
||||
for r in rows:
|
||||
rid = r.get("id")
|
||||
if rid and rid not in seen_db_ids: # 如果rid存在且未被处理
|
||||
seen_db_ids.add(rid) # 将rid加入seen_db_ids,标记为已处理
|
||||
db_candidate_rows.append(r) # 将该记录r添加到db_candidate_rows(确保数据库实体唯一)
|
||||
|
||||
db_candidate_models: List[ExtractedEntityNode] = []
|
||||
for r in db_candidate_rows: # db_candidate_rows:去重后的数据库候选实体记录(行)
|
||||
try:
|
||||
m = _row_to_entity(r) # 调用_row_to_entity函数,将数据库记录r转换为实体模型m
|
||||
db_candidate_models.append(m) # m添加到db_candidate_models
|
||||
except Exception:
|
||||
# 忽略无法解析的记录
|
||||
pass
|
||||
|
||||
# 若 DB 候选为空:跳过二层融合,直接返回第一层结果,并在报告中标注候选数
|
||||
candidate_count = len(db_candidate_models)
|
||||
if candidate_count == 0:
|
||||
try:
|
||||
_write_dedup_fusion_report(
|
||||
exact_merge_map={},
|
||||
fuzzy_merge_records=[],
|
||||
local_llm_records=[],
|
||||
disamb_records=[],
|
||||
stage_label="第二层去重消歧",
|
||||
append=True,
|
||||
stage_notes=[f"候选数:{candidate_count}(DB 为空则标注跳过)"],
|
||||
)
|
||||
except Exception:
|
||||
# 报告写入失败不影响主流程
|
||||
pass
|
||||
return entity_nodes, statement_entity_edges, entity_entity_edges
|
||||
|
||||
# 联合集合(DB 在前,确保规范 ID 优先使用 DB ID)
|
||||
# 将从 DB 检索到的候选实体与第一层去重消歧的实体合并,作为输入继续调用去重方法。
|
||||
# 由于按顺序遍历,规范实体将优先选择位于前面的 DB 节点,因此无需显式预重定向。
|
||||
union_entities: List[ExtractedEntityNode] = db_candidate_models + list(entity_nodes)
|
||||
|
||||
# 融合(内部执行精确/模糊/LLM 决策;随后再做边重定向与去重)
|
||||
fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges = await deduplicate_entities_and_edges(
|
||||
union_entities,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
report_stage="第二层去重消歧",
|
||||
report_append=True,
|
||||
dedup_config=dedup_config,
|
||||
)
|
||||
|
||||
return fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges
|
||||
@@ -0,0 +1,106 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import second_layer_dedup_and_merge_with_neo4j
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.models.graph_models import (
|
||||
DialogueNode,
|
||||
ChunkNode,
|
||||
StatementNode,
|
||||
ExtractedEntityNode,
|
||||
StatementChunkEdge,
|
||||
StatementEntityEdge,
|
||||
EntityEntityEdge,
|
||||
)
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
|
||||
|
||||
async def dedup_layers_and_merge_and_return(
|
||||
dialogue_nodes: List[DialogueNode],
|
||||
chunk_nodes: List[ChunkNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
dialog_data_list: List[DialogData],
|
||||
pipeline_config: Optional[ExtractionPipelineConfig] = None,
|
||||
connector: Optional[Neo4jConnector] = None,
|
||||
) -> Tuple[
|
||||
List[DialogueNode],
|
||||
List[ChunkNode],
|
||||
List[StatementNode],
|
||||
List[ExtractedEntityNode],
|
||||
List[StatementChunkEdge],
|
||||
List[StatementEntityEdge],
|
||||
List[EntityEntityEdge],
|
||||
]:
|
||||
"""
|
||||
执行两层实体去重与融合:
|
||||
- 第一层:精确/模糊/LLM 决策去重
|
||||
- 第二层:与 Neo4j 同组实体联合去重与融合(依赖传入的 connector)
|
||||
返回融合后的实体与边,同时保留原始的对话、片段与语句节点与边。
|
||||
"""
|
||||
|
||||
# 默认从 runtime.json 加载管线配置,避免回退到环境变量
|
||||
if pipeline_config is None:
|
||||
try:
|
||||
pipeline_config = get_pipeline_config()
|
||||
except Exception:
|
||||
pipeline_config = None
|
||||
|
||||
# 先探测 group_id,决定报告写入策略
|
||||
group_id: Optional[str] = None
|
||||
for dd in dialog_data_list:
|
||||
group_id = getattr(dd, "group_id", None)
|
||||
if group_id:
|
||||
break
|
||||
|
||||
# 第一层去重消歧
|
||||
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges = await deduplicate_entities_and_edges(
|
||||
entity_nodes,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
report_stage="第一层去重消歧",
|
||||
report_append=False,
|
||||
dedup_config=(pipeline_config.deduplication if pipeline_config else None),
|
||||
)
|
||||
|
||||
# 初始化第二层融合结果为第一层结果
|
||||
fused_entity_nodes = dedup_entity_nodes
|
||||
fused_statement_entity_edges = dedup_statement_entity_edges
|
||||
fused_entity_entity_edges = dedup_entity_entity_edges
|
||||
|
||||
# 第二层去重消歧:与 Neo4j 中同组实体联合融合
|
||||
try:
|
||||
if group_id:
|
||||
if connector:
|
||||
fused_entity_nodes, fused_statement_entity_edges, fused_entity_entity_edges = await second_layer_dedup_and_merge_with_neo4j(
|
||||
connector=connector,
|
||||
group_id=group_id,
|
||||
entity_nodes=dedup_entity_nodes,
|
||||
statement_entity_edges=dedup_statement_entity_edges,
|
||||
entity_entity_edges=dedup_entity_entity_edges,
|
||||
dedup_config=(pipeline_config.deduplication if pipeline_config else None),
|
||||
)
|
||||
else:
|
||||
print("Skip second-layer dedup: missing connector")
|
||||
else:
|
||||
print("Skip second-layer dedup: missing group_id")
|
||||
except Exception as e:
|
||||
print(f"Second-layer dedup failed: {e}")
|
||||
|
||||
return (
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
fused_entity_nodes,
|
||||
statement_chunk_edges,
|
||||
fused_statement_entity_edges,
|
||||
fused_entity_entity_edges,
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
知识提取模块
|
||||
|
||||
包含以下提取器:
|
||||
- DialogueChunker: 对话分块
|
||||
- StatementExtractor: 陈述句提取
|
||||
- TripletExtractor: 三元组提取
|
||||
- TemporalExtractor: 时间信息提取
|
||||
- EmbeddingGenerator: 嵌入向量生成
|
||||
- MemorySummaryGenerator: 记忆摘要生成
|
||||
"""
|
||||
@@ -0,0 +1,103 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.models.message_models import DialogData, Chunk
|
||||
from app.core.memory.models.config_models import ChunkerConfig
|
||||
from app.core.memory.llm_tools.chunker_client import ChunkerClient
|
||||
from app.core.memory.utils.config.config_utils import get_chunker_config
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
class DialogueChunker:
|
||||
"""A class that processes dialogues and fills them with chunks based on a specified strategy.
|
||||
|
||||
This class encapsulates the chunking process, allowing for easy configuration and application
|
||||
of different chunking strategies to dialogue data.
|
||||
"""
|
||||
|
||||
def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client=None):
|
||||
"""Initialize the DialogueChunker with a specific chunking strategy.
|
||||
|
||||
Args:
|
||||
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
||||
Options include: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker
|
||||
"""
|
||||
self.chunker_strategy = chunker_strategy
|
||||
chunker_config_dict = get_chunker_config(chunker_strategy)
|
||||
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
|
||||
# 对于 LLMChunker,需要传入 llm_client
|
||||
if self.chunker_config.chunker_strategy == "LLMChunker":
|
||||
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
|
||||
else:
|
||||
self.chunker_client = ChunkerClient(self.chunker_config)
|
||||
|
||||
async def process_dialogue(self, dialogue: DialogData) -> list[Chunk]:
|
||||
"""Process a dialogue by generating chunks and adding them to the DialogData object.
|
||||
|
||||
Args:
|
||||
dialogue: The DialogData object to process
|
||||
|
||||
Returns:
|
||||
A list of Chunk objects
|
||||
"""
|
||||
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
|
||||
# Defensive fallback: ensure at least one chunk is returned for non-empty content
|
||||
try:
|
||||
chunks = result_dialogue.chunks
|
||||
except Exception:
|
||||
chunks = []
|
||||
|
||||
if not chunks or len(chunks) == 0:
|
||||
# If the dialogue has content, return a single fallback chunk built from messages
|
||||
content_str = getattr(result_dialogue, "content", "") or getattr(dialogue, "content", "")
|
||||
if content_str and len(content_str.strip()) > 0:
|
||||
fallback_chunk = Chunk.from_messages(
|
||||
dialogue.context.msgs,
|
||||
metadata={
|
||||
"fallback": "single_chunk",
|
||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||
"source": "DialogueChunkerFallback",
|
||||
},
|
||||
)
|
||||
return [fallback_chunk]
|
||||
# No content: return empty list
|
||||
return []
|
||||
|
||||
return chunks
|
||||
|
||||
def save_chunking_results(self, dialogue: DialogData, output_path: Optional[str] = None) -> str:
|
||||
"""Save the chunking results to a file and return the output path.
|
||||
|
||||
Args:
|
||||
dialogue: The processed DialogData object with chunks
|
||||
output_path: Optional path to save the output (default: chunker_output_{strategy}.txt)
|
||||
|
||||
Returns:
|
||||
The path where the output was saved
|
||||
"""
|
||||
if not output_path:
|
||||
output_path = os.path.join(os.path.dirname(__file__), "..", "..",
|
||||
f"chunker_output_{self.chunker_strategy.lower()}.txt")
|
||||
|
||||
output_lines = []
|
||||
output_lines.append(f"=== Chunking Results ({self.chunker_strategy}) ===")
|
||||
output_lines.append(f"Dialogue ID: {dialogue.ref_id}")
|
||||
output_lines.append(f"Original conversation has {len(dialogue.context.msgs)} messages")
|
||||
output_lines.append(f"Total characters: {len(dialogue.content)}")
|
||||
|
||||
output_lines.append(f"Generated {len(dialogue.chunks)} chunks:")
|
||||
for i, chunk in enumerate(dialogue.chunks):
|
||||
output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters")
|
||||
output_lines.append(f" Content preview: {chunk.content}...")
|
||||
if chunk.metadata:
|
||||
output_lines.append(f" Metadata: {chunk.metadata}")
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(output_lines))
|
||||
|
||||
logger.info(f"Chunking results saved to: {output_path}")
|
||||
return output_path
|
||||
|
||||
|
||||
@@ -0,0 +1,307 @@
|
||||
"""
|
||||
嵌入向量生成器
|
||||
|
||||
为陈述句、分块、对话和实体生成嵌入向量,用于语义搜索。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Tuple
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
|
||||
|
||||
class EmbeddingGenerator:
|
||||
"""嵌入向量生成器"""
|
||||
|
||||
def __init__(self, embedding_id: str):
|
||||
"""初始化嵌入向量生成器
|
||||
|
||||
Args:
|
||||
embedding_id: 嵌入模型 ID
|
||||
"""
|
||||
embedder_config = get_embedder_config(embedding_id)
|
||||
self.embedder_client = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(embedder_config),
|
||||
)
|
||||
|
||||
async def _generate_embeddings(self, texts: List[str], batch_size: int = 100) -> List[List[float]]:
|
||||
"""生成一批文本的嵌入向量(支持分批并行)
|
||||
|
||||
Args:
|
||||
texts: 文本列表
|
||||
batch_size: 每批处理的文本数量(默认 100)
|
||||
|
||||
Returns:
|
||||
嵌入向量列表
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
# 如果文本数量小于批次大小,直接处理
|
||||
if len(texts) <= batch_size:
|
||||
return await self.embedder_client.response(texts)
|
||||
|
||||
# 分批并行处理
|
||||
print(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理")
|
||||
batches = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)]
|
||||
print(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本")
|
||||
|
||||
# 并行发送所有批次
|
||||
batch_results = await asyncio.gather(*[
|
||||
self.embedder_client.response(batch) for batch in batches
|
||||
])
|
||||
|
||||
# 合并结果
|
||||
embeddings = []
|
||||
for batch_result in batch_results:
|
||||
embeddings.extend(batch_result)
|
||||
|
||||
print(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量")
|
||||
return embeddings
|
||||
|
||||
async def generate_statement_embeddings(
|
||||
self,
|
||||
chunked_dialogs: List[DialogData]
|
||||
) -> List[Dict[str, List[float]]]:
|
||||
"""为所有对话中的陈述句生成嵌入向量
|
||||
|
||||
Args:
|
||||
chunked_dialogs: 包含分块和陈述句的对话列表
|
||||
|
||||
Returns:
|
||||
每个对话的陈述句嵌入向量映射列表
|
||||
"""
|
||||
print("\n=== 生成陈述句嵌入向量 ===")
|
||||
|
||||
# 收集所有陈述句
|
||||
all_statements = []
|
||||
statement_to_dialog_chunk_map = []
|
||||
|
||||
for d_idx, dialog in enumerate(chunked_dialogs):
|
||||
chunks = dialog.chunks
|
||||
if asyncio.iscoroutine(chunks):
|
||||
chunks = await chunks
|
||||
for c_idx, chunk in enumerate(chunks):
|
||||
for s_idx, stmt in enumerate(chunk.statements):
|
||||
all_statements.append(stmt.statement)
|
||||
statement_to_dialog_chunk_map.append((d_idx, c_idx, s_idx))
|
||||
|
||||
# 批量生成嵌入向量
|
||||
stmt_embeddings = await self._generate_embeddings(all_statements)
|
||||
|
||||
# 创建映射
|
||||
stmt_embedding_maps = [{} for _ in chunked_dialogs]
|
||||
for idx, embedding in enumerate(stmt_embeddings):
|
||||
d_idx, c_idx, s_idx = statement_to_dialog_chunk_map[idx]
|
||||
stmt_id = chunked_dialogs[d_idx].chunks[c_idx].statements[s_idx].id
|
||||
stmt_embedding_maps[d_idx][stmt_id] = embedding
|
||||
|
||||
print(f"为 {len(all_statements)} 个陈述句生成了嵌入向量")
|
||||
return stmt_embedding_maps
|
||||
|
||||
async def generate_chunk_embeddings(
|
||||
self,
|
||||
chunked_dialogs: List[DialogData]
|
||||
) -> List[Dict[str, List[float]]]:
|
||||
"""为所有对话中的分块生成嵌入向量
|
||||
|
||||
Args:
|
||||
chunked_dialogs: 包含分块的对话列表
|
||||
|
||||
Returns:
|
||||
每个对话的分块嵌入向量映射列表
|
||||
"""
|
||||
print("\n=== 生成分块嵌入向量 ===")
|
||||
|
||||
# 收集所有分块
|
||||
all_chunks = []
|
||||
chunk_to_dialog_map = []
|
||||
|
||||
for d_idx, dialog in enumerate(chunked_dialogs):
|
||||
for c_idx, chunk in enumerate(dialog.chunks):
|
||||
all_chunks.append(chunk.content)
|
||||
chunk_to_dialog_map.append((d_idx, c_idx))
|
||||
|
||||
# 批量生成嵌入向量
|
||||
chunk_embeddings = await self._generate_embeddings(all_chunks)
|
||||
|
||||
# 创建映射
|
||||
chunk_embedding_maps = [{} for _ in chunked_dialogs]
|
||||
for idx, embedding in enumerate(chunk_embeddings):
|
||||
d_idx, c_idx = chunk_to_dialog_map[idx]
|
||||
chunk_id = chunked_dialogs[d_idx].chunks[c_idx].id
|
||||
chunk_embedding_maps[d_idx][chunk_id] = embedding
|
||||
|
||||
print(f"为 {len(all_chunks)} 个分块生成了嵌入向量")
|
||||
return chunk_embedding_maps
|
||||
|
||||
async def generate_dialog_embeddings(
|
||||
self,
|
||||
chunked_dialogs: List[DialogData]
|
||||
) -> List[List[float]]:
|
||||
"""为对话生成嵌入向量(当前跳过,返回空列表)
|
||||
|
||||
Args:
|
||||
chunked_dialogs: 对话列表
|
||||
|
||||
Returns:
|
||||
对话嵌入向量列表(当前为空)
|
||||
"""
|
||||
# 跳过对话嵌入向量生成,但保持正确的长度
|
||||
return [[] for _ in chunked_dialogs]
|
||||
|
||||
async def generate_all_embeddings(
|
||||
self,
|
||||
chunked_dialogs: List[DialogData]
|
||||
) -> Tuple[
|
||||
List[Dict[str, List[float]]],
|
||||
List[Dict[str, List[float]]],
|
||||
List[List[float]]
|
||||
]:
|
||||
"""生成所有类型的嵌入向量
|
||||
|
||||
Args:
|
||||
chunked_dialogs: 包含分块和陈述句的对话列表
|
||||
|
||||
Returns:
|
||||
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表)
|
||||
"""
|
||||
print("\n=== 生成所有嵌入向量 ===")
|
||||
|
||||
# 并发生成陈述句和分块嵌入向量
|
||||
stmt_embedding_maps, chunk_embedding_maps = await asyncio.gather(
|
||||
self.generate_statement_embeddings(chunked_dialogs),
|
||||
self.generate_chunk_embeddings(chunked_dialogs)
|
||||
)
|
||||
|
||||
# 对话嵌入向量(当前跳过)
|
||||
dialog_embeddings = await self.generate_dialog_embeddings(chunked_dialogs)
|
||||
|
||||
print(
|
||||
f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量"
|
||||
)
|
||||
|
||||
return stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings
|
||||
|
||||
async def generate_entity_embeddings(
|
||||
self,
|
||||
triplet_maps: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""为三元组中的实体生成嵌入向量
|
||||
|
||||
Args:
|
||||
triplet_maps: 三元组映射列表
|
||||
|
||||
Returns:
|
||||
更新后的三元组映射列表(实体包含嵌入向量)
|
||||
"""
|
||||
print("\n=== 生成实体嵌入向量 ===")
|
||||
|
||||
entity_texts: List[str] = []
|
||||
entity_refs: List[Any] = []
|
||||
|
||||
# 收集所有实体
|
||||
for trip_map in triplet_maps:
|
||||
for _, triplet_info in trip_map.items():
|
||||
entities = getattr(triplet_info, "entities", None)
|
||||
if not entities:
|
||||
continue
|
||||
for ent in entities:
|
||||
text = getattr(ent, "name", None) or getattr(ent, "description", None)
|
||||
if text:
|
||||
entity_texts.append(text)
|
||||
entity_refs.append(ent)
|
||||
|
||||
if not entity_texts:
|
||||
print("没有找到需要生成嵌入向量的实体")
|
||||
return triplet_maps
|
||||
|
||||
# 批量生成嵌入向量
|
||||
embeddings = await self._generate_embeddings(entity_texts)
|
||||
|
||||
# 打印前几个嵌入向量的维度
|
||||
for i in range(min(5, len(embeddings))):
|
||||
print(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}")
|
||||
|
||||
# 将嵌入向量赋值给实体
|
||||
for ent, emb in zip(entity_refs, embeddings):
|
||||
setattr(ent, "name_embedding", emb)
|
||||
|
||||
print(f"为 {len(entity_refs)} 个实体生成了嵌入向量")
|
||||
return triplet_maps
|
||||
|
||||
|
||||
# 保持向后兼容的函数接口
|
||||
async def embedding_generation(
|
||||
chunked_dialogs: List[DialogData],
|
||||
embedding_id: str
|
||||
) -> Tuple[
|
||||
List[Dict[str, List[float]]],
|
||||
List[Dict[str, List[float]]],
|
||||
List[List[float]]
|
||||
]:
|
||||
"""生成陈述句、分块和对话的嵌入向量(向后兼容接口)
|
||||
|
||||
Args:
|
||||
chunked_dialogs: 包含分块和陈述句的对话列表
|
||||
embedding_id: 嵌入模型 ID
|
||||
|
||||
Returns:
|
||||
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表)
|
||||
"""
|
||||
generator = EmbeddingGenerator(embedding_id)
|
||||
return await generator.generate_all_embeddings(chunked_dialogs)
|
||||
|
||||
|
||||
async def generate_entity_embeddings_from_triplets(
|
||||
triplet_maps: List[Dict[str, Any]],
|
||||
embedding_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""为三元组中的实体生成嵌入向量(向后兼容接口)
|
||||
|
||||
Args:
|
||||
triplet_maps: 三元组映射列表
|
||||
embedding_id: 嵌入模型 ID
|
||||
|
||||
Returns:
|
||||
更新后的三元组映射列表(实体包含嵌入向量)
|
||||
"""
|
||||
generator = EmbeddingGenerator(embedding_id)
|
||||
return await generator.generate_entity_embeddings(triplet_maps)
|
||||
|
||||
|
||||
async def embedding_generation_all(
|
||||
chunked_dialogs: List[DialogData],
|
||||
triplet_maps: List[Dict[str, Any]],
|
||||
embedding_id: str
|
||||
) -> Tuple[
|
||||
List[Dict[str, List[float]]],
|
||||
List[Dict[str, List[float]]],
|
||||
List[List[float]],
|
||||
List[Dict[str, Any]]
|
||||
]:
|
||||
"""生成所有类型的嵌入向量(向后兼容接口)
|
||||
|
||||
Args:
|
||||
chunked_dialogs: 包含分块和陈述句的对话列表
|
||||
triplet_maps: 三元组映射列表
|
||||
embedding_id: 嵌入模型 ID
|
||||
|
||||
Returns:
|
||||
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表, 更新后的三元组映射列表)
|
||||
"""
|
||||
print("\n=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===")
|
||||
|
||||
generator = EmbeddingGenerator(embedding_id)
|
||||
|
||||
# 生成陈述句、分块和对话的嵌入向量
|
||||
stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings = await generator.generate_all_embeddings(
|
||||
chunked_dialogs
|
||||
)
|
||||
|
||||
# 生成实体嵌入向量
|
||||
updated_triplet_maps = await generator.generate_entity_embeddings(triplet_maps)
|
||||
|
||||
return stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings, updated_triplet_maps
|
||||
@@ -0,0 +1,117 @@
|
||||
import os
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
from app.core.memory.models.graph_models import MemorySummaryNode
|
||||
from app.core.memory.models.base_response import RobustLLMResponse
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
class MemorySummaryResponse(RobustLLMResponse):
|
||||
"""Structured response for summary generation per chunk.
|
||||
|
||||
This model ensures the LLM returns a valid, non-empty summary.
|
||||
Inherits robust validation from RobustLLMResponse.
|
||||
"""
|
||||
summary: str = Field(
|
||||
...,
|
||||
description="Concise memory summary for a single chunk. Must be a meaningful, non-empty string.",
|
||||
min_length=1,
|
||||
max_length=5000
|
||||
)
|
||||
|
||||
|
||||
async def _process_chunk_summary(
|
||||
dialog: DialogData,
|
||||
chunk,
|
||||
llm_client,
|
||||
embedder: OpenAIEmbedderClient,
|
||||
) -> Optional[MemorySummaryNode]:
|
||||
"""Process a single chunk to generate a memory summary node."""
|
||||
# Skip empty chunks
|
||||
if not chunk.content or not chunk.content.strip():
|
||||
return None
|
||||
|
||||
try:
|
||||
# Render prompt via Jinja2 for a single chunk
|
||||
prompt_content = await render_memory_summary_prompt(
|
||||
chunk_texts=chunk.content,
|
||||
json_schema=MemorySummaryResponse.model_json_schema(),
|
||||
max_words=200,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are an expert memory summarizer."},
|
||||
{"role": "user", "content": prompt_content},
|
||||
]
|
||||
|
||||
# Generate structured summary with the existing LLM client
|
||||
structured = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=MemorySummaryResponse,
|
||||
)
|
||||
summary_text = structured.summary.strip()
|
||||
|
||||
# Embed the summary
|
||||
embedding = (await embedder.response([summary_text]))[0]
|
||||
|
||||
# Build node per chunk
|
||||
node = MemorySummaryNode(
|
||||
id=uuid4().hex,
|
||||
name=f"MemorySummaryChunk_{chunk.id}",
|
||||
group_id=dialog.group_id,
|
||||
user_id=dialog.user_id,
|
||||
apply_id=dialog.apply_id,
|
||||
run_id=dialog.run_id, # 使用 dialog 的 run_id
|
||||
created_at=datetime.now(),
|
||||
expired_at=datetime(9999, 12, 31),
|
||||
dialog_id=dialog.id,
|
||||
chunk_ids=[chunk.id],
|
||||
content=summary_text,
|
||||
summary_embedding=embedding,
|
||||
metadata={"ref_id": dialog.ref_id},
|
||||
config_id=dialog.config_id, # 添加 config_id
|
||||
)
|
||||
return node
|
||||
|
||||
except Exception as e:
|
||||
# Log the error but continue processing other chunks
|
||||
logger.warning(f"Failed to generate summary for chunk {chunk.id} in dialog {dialog.id}: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
async def Memory_summary_generation(
|
||||
chunked_dialogs: List[DialogData],
|
||||
llm_client,
|
||||
embedding_id,
|
||||
) -> List[MemorySummaryNode]:
|
||||
"""Generate memory summaries per chunk, embed them, and return nodes."""
|
||||
embedder_cfg_dict = get_embedder_config(embedding_id)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(embedder_cfg_dict),
|
||||
)
|
||||
|
||||
# Collect all tasks for parallel processing
|
||||
tasks = []
|
||||
for dialog in chunked_dialogs:
|
||||
for chunk in dialog.chunks:
|
||||
tasks.append(_process_chunk_summary(dialog, chunk, llm_client, embedder))
|
||||
|
||||
# Process all chunks in parallel
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
|
||||
# Filter out None values (failed or empty chunks)
|
||||
nodes = [node for node in results if node is not None]
|
||||
|
||||
return nodes
|
||||
@@ -0,0 +1,301 @@
|
||||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.memory.models.message_models import DialogData, Statement
|
||||
#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
|
||||
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS, StatementType, TemporalInfo
|
||||
|
||||
from app.core.memory.models.variate_config import StatementExtractionConfig
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt
|
||||
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS, StatementType, TemporalInfo, RelevenceInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ExtractedStatement(BaseModel):
|
||||
"""Schema for extracted statement from LLM"""
|
||||
statement: str = Field(..., description="The extracted statement text")
|
||||
statement_type: str = Field(..., description="FACT, OPINION,SUGGESTION or PREDICTION")
|
||||
temporal_type: str = Field(..., description="STATIC, DYNAMIC, ATEMPORAL")
|
||||
relevence: str = Field(..., description="RELEVANT or IRRELEVANT")
|
||||
|
||||
# 统一使用 StatementExtractionResponse 作为 LLM 的结构化返回(仅语句)
|
||||
class StatementExtractionResponse(BaseModel):
|
||||
statements: List[ExtractedStatement] = Field(default_factory=list, description="List of extracted statements")
|
||||
|
||||
class StatementExtractor:
|
||||
"""Class for extracting statements from dialog chunks using LLM (relations separated)"""
|
||||
|
||||
def __init__(self, llm_client: Any, config: StatementExtractionConfig = None):
|
||||
# 避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
|
||||
"""Initialize the StatementExtractor with an LLM client and configuration
|
||||
|
||||
Args:
|
||||
llm_client: OpenAIClient instance for processing LLM requests
|
||||
config: StatementExtractionConfig for controlling extraction behavior
|
||||
"""
|
||||
self.llm_client = llm_client
|
||||
self.config = config or StatementExtractionConfig()
|
||||
|
||||
async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
|
||||
"""Process a single chunk and return extracted statements
|
||||
|
||||
Args:
|
||||
chunk: Chunk object to process
|
||||
group_id: Group ID to assign to all statements in this chunk
|
||||
dialogue_content: Full dialogue content to provide as context
|
||||
|
||||
Returns:
|
||||
List of ExtractedStatement objects extracted from the chunk
|
||||
"""
|
||||
# Prepare the chunk content for processing
|
||||
chunk_content = chunk.content
|
||||
|
||||
# Render the prompt using helper function
|
||||
prompt_content = await render_statement_extraction_prompt(
|
||||
chunk_content=chunk_content,
|
||||
definitions=LABEL_DEFINITIONS,
|
||||
json_schema=ExtractedStatement.model_json_schema(),
|
||||
granularity=self.config.statement_granularity,
|
||||
include_dialogue_context=self.config.include_dialogue_context,
|
||||
dialogue_content=dialogue_content,
|
||||
max_dialogue_chars=self.config.max_dialogue_context_chars
|
||||
)
|
||||
|
||||
# Simple system message
|
||||
system_content = "You are an expert at extracting and labeling atomic statements from conversational text. Return valid JSON conforming to the schema."
|
||||
|
||||
# Create messages for LLM
|
||||
messages = [
|
||||
{"role": "system", "content": system_content},
|
||||
{"role": "user", "content": prompt_content}
|
||||
]
|
||||
|
||||
try:
|
||||
# Get structured response from LLM (statements only)
|
||||
response = await self.llm_client.response_structured(messages, StatementExtractionResponse)
|
||||
# Defensive: ensure response has the expected structure
|
||||
if not hasattr(response, "statements") or response.statements is None:
|
||||
logger.warning("Invalid structured response: missing 'statements'. Returning empty list for this chunk.")
|
||||
return []
|
||||
|
||||
# Convert extracted statements to Statement objects
|
||||
chunk_statements = []
|
||||
for extracted_stmt in response.statements:
|
||||
# Normalize and correct enums defensively
|
||||
stmt_type_str = str(extracted_stmt.statement_type).strip().upper()
|
||||
temporal_type_str = str(extracted_stmt.temporal_type).strip().upper()
|
||||
relevence_str = str(extracted_stmt.relevence).strip().upper()
|
||||
|
||||
# Convert strings to enum types with fallback defaults
|
||||
try:
|
||||
stmt_type = StatementType[stmt_type_str] if stmt_type_str in StatementType.__members__ else StatementType.FACT
|
||||
except (KeyError, ValueError):
|
||||
stmt_type = StatementType.FACT
|
||||
|
||||
try:
|
||||
temporal_type = TemporalInfo[temporal_type_str] if temporal_type_str in TemporalInfo.__members__ else TemporalInfo.STATIC
|
||||
except (KeyError, ValueError):
|
||||
temporal_type = TemporalInfo.STATIC
|
||||
|
||||
try:
|
||||
relevence_info = RelevenceInfo[relevence_str] if relevence_str in RelevenceInfo.__members__ else RelevenceInfo.RELEVANT
|
||||
except (KeyError, ValueError):
|
||||
relevence_info = RelevenceInfo.RELEVANT
|
||||
|
||||
chunk_statement = Statement(
|
||||
statement=extracted_stmt.statement,
|
||||
stmt_type=stmt_type,
|
||||
temporal_info=temporal_type,
|
||||
relevence_info=relevence_info,
|
||||
chunk_id=chunk.id,
|
||||
group_id=group_id,
|
||||
)
|
||||
chunk_statements.append(chunk_statement)
|
||||
|
||||
# 分离强弱关系分类:不在句子提取阶段进行,也不写入 chunk.metadata
|
||||
return chunk_statements
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing chunk: {e}", exc_info=True)
|
||||
# Return empty list to indicate failure for this chunk
|
||||
return []
|
||||
|
||||
async def extract_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> List[List[Statement]]:
|
||||
"""Extract statements from a DialogData object.
|
||||
|
||||
Args:
|
||||
dialog_data: The DialogData object containing chunks.
|
||||
limit_chunks: Optional limit on the number of chunks to process.
|
||||
"""
|
||||
# Determine how many chunks to process
|
||||
chunks_to_process = dialog_data.chunks[:limit_chunks] if limit_chunks else dialog_data.chunks
|
||||
|
||||
logger.info(f"Processing {len(chunks_to_process)} chunks for statement extraction")
|
||||
|
||||
# Process all chunks concurrently, passing the group_id and dialogue content from dialog_data
|
||||
dialogue_content = dialog_data.content if self.config.include_dialogue_context else None
|
||||
results = await asyncio.gather(
|
||||
*[self._extract_statements(chunk, dialog_data.group_id, dialogue_content) for chunk in chunks_to_process],
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
# Filter out exceptions and return valid results
|
||||
valid_results = []
|
||||
for result in results:
|
||||
if isinstance(result, list) and result is not None:
|
||||
valid_results.append(result)
|
||||
else:
|
||||
print(f"Error in statement extraction: {result}")
|
||||
valid_results.append([])
|
||||
|
||||
return valid_results
|
||||
|
||||
def save_statements(self, statements: List[Statement], output_path: str = None) -> str:
|
||||
"""Save the extracted statements to a file and return the output path.
|
||||
|
||||
Args:
|
||||
statements: List of Statement objects to save
|
||||
output_path: Optional path to save the output (default: statement_extraction.txt)
|
||||
|
||||
Returns:
|
||||
The path where the output was saved
|
||||
"""
|
||||
# 使用全局配置的输出路径
|
||||
if not output_path:
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
output_path = settings.get_memory_output_path("statement_extraction.txt")
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write(f"Extracted Statements ({len(statements)} total)\n")
|
||||
f.write("=" * 50 + "\n\n")
|
||||
|
||||
for i, statement in enumerate(statements, 1):
|
||||
f.write(f"Statement {i}:\n")
|
||||
f.write(f"Id: {statement.id}\n")
|
||||
f.write(f"Group Id: {statement.group_id}\n")
|
||||
f.write(f"Content: {statement.statement}\n")
|
||||
f.write(f"Type: {statement.stmt_type.value}\n")
|
||||
f.write(f"Temporal Info: {statement.temporal_info.value}\n")
|
||||
f.write(f"Created At: {datetime.now()}\n")
|
||||
f.write(f"Expired At: {None}\n")
|
||||
f.write(f"Valid At: {statement.temporal_validity.valid_at if statement.temporal_validity else None}\n")
|
||||
f.write(f"Invalid At: {statement.temporal_validity.invalid_at if statement.temporal_validity else None}\n")
|
||||
f.write(f"Chunk Id: {statement.chunk_id}\n")
|
||||
# add relevance information to satisfy tests
|
||||
if hasattr(statement, "relevence_info") and statement.relevence_info is not None:
|
||||
f.write(f"Relevence Info: {statement.relevence_info.value}\n")
|
||||
f.write("-" * 30 + "\n\n")
|
||||
|
||||
print(f"Extracted {len(statements)} statements and saved to {output_path}")
|
||||
return output_path
|
||||
|
||||
def save_relations(self, dialogs: List[DialogData], output_path: str = None) -> str:
|
||||
"""按对话分组聚合强/弱关系并写入 TXT 文件。
|
||||
- 每个对话单独成段:输出该对话的 `Dialog ID`、`Group ID`、`Content`
|
||||
- 在该对话段内再分为 Strong Relations / Weak Relations 两部分
|
||||
- Strong: 逐条输出 `Chunk ID` 与 `Triple`
|
||||
- Weak: 逐条输出 `Chunk ID` 与 `Entity`
|
||||
"""
|
||||
print("\n=== Relations Classify ===")
|
||||
|
||||
# 使用全局配置的输出路径
|
||||
if not output_path:
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
output_path = settings.get_memory_output_path("relations_output.txt")
|
||||
# output_path = os.path.join(os.path.dirname(__file__), "..", "relations_output.txt")
|
||||
|
||||
dialog_sections: List[Dict[str, Any]] = []
|
||||
total_strong = 0
|
||||
total_weak = 0
|
||||
|
||||
for dialog in dialogs:
|
||||
strong_relations: List[Dict[str, Any]] = []
|
||||
weak_relations: List[Dict[str, Any]] = []
|
||||
|
||||
for chunk in dialog.chunks or []:
|
||||
# 基于三元组/实体推导强弱关系
|
||||
for stmt in chunk.statements or []:
|
||||
te = getattr(stmt, "triplet_extraction_info", None)
|
||||
if not te:
|
||||
continue
|
||||
trips = getattr(te, "triplets", []) or []
|
||||
ents = getattr(te, "entities", []) or []
|
||||
|
||||
# Strong: 逐条输出三元组
|
||||
if trips:
|
||||
for trip in trips:
|
||||
subj = getattr(trip, "subject_name", "")
|
||||
pred = str(getattr(trip, "predicate", ""))
|
||||
obj = getattr(trip, "object_name", "")
|
||||
triple_str = f"({subj}, {pred}, {obj})"
|
||||
strong_relations.append({
|
||||
"chunk_id": chunk.id,
|
||||
"triple": triple_str,
|
||||
})
|
||||
else:
|
||||
# Weak: 无三元组但有实体
|
||||
for ent in ents:
|
||||
name = getattr(ent, "name", "")
|
||||
desc = getattr(ent, "description", "") or ""
|
||||
entity_str = f"{name}: {desc}" if desc else name
|
||||
if name:
|
||||
weak_relations.append({
|
||||
"chunk_id": chunk.id,
|
||||
"entity": entity_str,
|
||||
})
|
||||
|
||||
total_strong += len(strong_relations)
|
||||
total_weak += len(weak_relations)
|
||||
|
||||
dialog_sections.append({
|
||||
"dialog_id": dialog.ref_id,
|
||||
"group_id": dialog.group_id,
|
||||
"content": dialog.content if getattr(dialog, "content", None) else "",
|
||||
"strong": strong_relations,
|
||||
"weak": weak_relations,
|
||||
})
|
||||
|
||||
try:
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write(f"Relations Extraction (grouped by dialogs, strong: {total_strong}, weak: {total_weak})\n")
|
||||
f.write("=" * 50 + "\n\n")
|
||||
|
||||
for idx, section in enumerate(dialog_sections, 1):
|
||||
f.write(f"Dialog {idx}:\n")
|
||||
f.write(f"Dialog ID: {section.get('dialog_id', '')}\n")
|
||||
f.write(f"Group ID: {section.get('group_id', '')}\n")
|
||||
f.write("Content:\n")
|
||||
f.write(f"{section.get('content', '')}\n")
|
||||
f.write("-" * 40 + "\n\n")
|
||||
|
||||
# Strong Relations for this dialog
|
||||
strong_list = section.get("strong", [])
|
||||
f.write(f"Strong Relations ({len(strong_list)} total)\n")
|
||||
f.write("-" * 30 + "\n\n")
|
||||
for i, item in enumerate(strong_list, 1):
|
||||
f.write(f"Item {i}:\n")
|
||||
f.write(f"Chunk ID: {item.get('chunk_id', '')}\n")
|
||||
f.write(f"Triple: {item.get('triple', '')}\n")
|
||||
f.write("-" * 30 + "\n\n")
|
||||
|
||||
# Weak Relations for this dialog
|
||||
weak_list = section.get("weak", [])
|
||||
f.write(f"Weak Relations ({len(weak_list)} total)\n")
|
||||
f.write("-" * 30 + "\n\n")
|
||||
for i, item in enumerate(weak_list, 1):
|
||||
f.write(f"Item {i}:\n")
|
||||
f.write(f"Chunk ID: {item.get('chunk_id', '')}\n")
|
||||
f.write(f"Entity: {item.get('entity', '')}\n")
|
||||
f.write("-" * 30 + "\n\n")
|
||||
|
||||
print(f"Saved relations to {output_path}")
|
||||
return output_path
|
||||
except Exception as e:
|
||||
print(f"Failed to save relations to {output_path}: {e}")
|
||||
return output_path
|
||||
@@ -0,0 +1,222 @@
|
||||
import os
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from app.core.memory.src.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.models.message_models import DialogData, Statement, TemporalValidityRange
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_temporal_extraction_prompt
|
||||
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS, TemporalInfo
|
||||
from app.core.memory.utils.log.logging_utils import prompt_logger
|
||||
|
||||
|
||||
class RawTemporalRange(BaseModel):
|
||||
"""Schema for the raw temporal range extracted by the LLM."""
|
||||
|
||||
valid_at: Optional[str] = Field(
|
||||
None, description="The start date and time of the validity range in ISO 8601 format."
|
||||
)
|
||||
invalid_at: Optional[str] = Field(
|
||||
None, description="The end date and time of the validity range in ISO 8601 format."
|
||||
)
|
||||
|
||||
|
||||
class TemporalExtractor:
|
||||
"""
|
||||
Extracts temporal validity ranges from statements using an LLM.
|
||||
"""
|
||||
|
||||
def __init__(self, llm_client: OpenAIClient):
|
||||
"""
|
||||
Initializes the TemporalExtractor.
|
||||
|
||||
Args:
|
||||
llm_client (OpenAIClient): The OpenAI client to use for LLM calls.
|
||||
"""
|
||||
self.llm_client = llm_client
|
||||
|
||||
async def _extract_temporal_ranges(
|
||||
self, statement: Statement, ref_dates: dict[str, Any]
|
||||
) -> TemporalValidityRange:
|
||||
"""
|
||||
Extracts the temporal range for a single statement.
|
||||
|
||||
Args:
|
||||
statement (Statement): The statement to process.
|
||||
ref_dates (dict[str, Any]): Reference dates for context.
|
||||
|
||||
Returns:
|
||||
TemporalValidityRange: The extracted temporal validity range.
|
||||
"""
|
||||
if not ref_dates:
|
||||
ref_dates = {"today": datetime.now().strftime("%Y-%m-%d")}
|
||||
|
||||
if statement.temporal_info == TemporalInfo.ATEMPORAL:
|
||||
return TemporalValidityRange(valid_at=None, invalid_at=None)
|
||||
|
||||
temporal_guide = LABEL_DEFINITIONS["temporal_labelling"]
|
||||
statement_guide = LABEL_DEFINITIONS["statement_labelling"]
|
||||
|
||||
# Log start and input context
|
||||
try:
|
||||
prompt_logger.info(f"[Temporal] Started - statement_id={statement.id}")
|
||||
prompt_logger.debug(
|
||||
f"[Temporal] Input statement=\"{statement.statement}\" ref_dates={ref_dates}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
prompt_content = await render_temporal_extraction_prompt(
|
||||
ref_dates=ref_dates,
|
||||
statement=statement.model_dump(),
|
||||
temporal_guide=temporal_guide,
|
||||
statement_guide=statement_guide,
|
||||
json_schema=RawTemporalRange.model_json_schema(),
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are an expert at extracting temporal validity ranges from statements. Follow the provided instructions carefully and return valid JSON.",
|
||||
},
|
||||
{"role": "user", "content": prompt_content},
|
||||
]
|
||||
|
||||
try:
|
||||
response = await self.llm_client.response_structured(
|
||||
messages, RawTemporalRange
|
||||
)
|
||||
if response:
|
||||
# Log raw structured response
|
||||
try:
|
||||
prompt_logger.debug(
|
||||
f"[Temporal] Raw structured response - statement_id={statement.id}: valid_at={response.valid_at}, invalid_at={response.invalid_at}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return TemporalValidityRange(
|
||||
valid_at=response.valid_at, invalid_at=response.invalid_at
|
||||
)
|
||||
except Exception as e:
|
||||
try:
|
||||
prompt_logger.warning(
|
||||
f"[Temporal] Failed to process statement_id={statement.id}. Error: {e}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return TemporalValidityRange(valid_at=None, invalid_at=None)
|
||||
|
||||
from typing import Dict, Tuple
|
||||
|
||||
async def extract_temporal_ranges(
|
||||
self, dialog_data: DialogData, ref_dates: Optional[dict[str, Any]] = None
|
||||
) -> Dict[str, TemporalValidityRange]:
|
||||
"""
|
||||
Extracts temporal ranges for statements in the dialog_data.
|
||||
|
||||
Args:
|
||||
dialog_data (DialogData): The dialog data containing chunks with statements to process.
|
||||
ref_dates (Optional[dict[str, Any]]): Reference dates for context.
|
||||
|
||||
Returns:
|
||||
Dict[str, TemporalValidityRange]: A dictionary mapping statement IDs to their temporal ranges.
|
||||
"""
|
||||
if ref_dates is None:
|
||||
ref_dates = {}
|
||||
|
||||
statement_temporal_map = {}
|
||||
|
||||
# Header (match legacy format)
|
||||
try:
|
||||
prompt_logger.info("")
|
||||
prompt_logger.info("=== TEMPORAL EXTRACTION RESULTS ===")
|
||||
prompt_logger.info(
|
||||
f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Collect all statements with their IDs
|
||||
all_tasks = []
|
||||
statement_ids = []
|
||||
|
||||
for chunk in dialog_data.chunks:
|
||||
if not chunk.statements:
|
||||
continue
|
||||
|
||||
for statement in chunk.statements:
|
||||
if statement.temporal_info == TemporalInfo.ATEMPORAL:
|
||||
# Log skipped
|
||||
try:
|
||||
prompt_logger.info(
|
||||
f"[Temporal] Skipped ATEMPORAL - statement_id={statement.id}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
statement_temporal_map[statement.id] = TemporalValidityRange(
|
||||
valid_at=None, invalid_at=None
|
||||
)
|
||||
continue
|
||||
all_tasks.append(self._extract_temporal_ranges(statement, ref_dates))
|
||||
statement_ids.append(statement.id)
|
||||
|
||||
# Process all statements concurrently
|
||||
results = await asyncio.gather(*all_tasks, return_exceptions=True)
|
||||
|
||||
# Map results back to statement IDs
|
||||
for i, result in enumerate(results):
|
||||
statement_id = statement_ids[i]
|
||||
if isinstance(result, TemporalValidityRange):
|
||||
statement_temporal_map[statement_id] = result
|
||||
else:
|
||||
try:
|
||||
prompt_logger.warning(
|
||||
f"[Temporal] Failed to process statement_id={statement_id}. Error: {result}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
statement_temporal_map[statement_id] = TemporalValidityRange(
|
||||
valid_at=None, invalid_at=None
|
||||
)
|
||||
|
||||
# Summary (match legacy completion line)
|
||||
try:
|
||||
extracted_count = sum(
|
||||
1
|
||||
for v in statement_temporal_map.values()
|
||||
if (v.valid_at is not None or v.invalid_at is not None)
|
||||
)
|
||||
prompt_logger.info(
|
||||
f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)} completed, extracted_valid_ranges={extracted_count}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return statement_temporal_map
|
||||
|
||||
def save_temporal_extractions_to_file(
|
||||
self, dialog_data: DialogData, output_path: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Saves the extracted temporal data to a text file.
|
||||
|
||||
Args:
|
||||
dialog_data (DialogData): The dialog data containing the statements with temporal data.
|
||||
output_path (str): The path to the output file.
|
||||
"""
|
||||
if not output_path:
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
output_path = settings.get_memory_output_path("extracted_temporal_data.txt")
|
||||
with open(output_path, "w") as f:
|
||||
for chunk in dialog_data.chunks:
|
||||
f.write(f"Chunk: {chunk.content}\n")
|
||||
for statement in chunk.statements:
|
||||
f.write(f" - Statement: {statement.statement}\n")
|
||||
if statement.temporal_validity:
|
||||
f.write(f" - Valid At: {statement.temporal_validity.valid_at}\n")
|
||||
f.write(f" - Invalid At: {statement.temporal_validity.invalid_at}\n")
|
||||
else:
|
||||
f.write(f" - Temporal Validity: Not Extracted\n")
|
||||
f.write("\n")
|
||||
@@ -0,0 +1,223 @@
|
||||
import os
|
||||
import asyncio
|
||||
from typing import List, Dict
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.src.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
|
||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
|
||||
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
||||
from app.core.memory.models.message_models import DialogData, Statement
|
||||
from app.core.memory.utils.log.logging_utils import prompt_logger
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
|
||||
class TripletExtractor:
|
||||
"""Extracts knowledge triplets and entities from statements using LLM"""
|
||||
|
||||
def __init__(self, llm_client: OpenAIClient):
|
||||
"""Initialize the TripletExtractor with an LLM client
|
||||
|
||||
Args:
|
||||
llm_client: OpenAIClient instance for processing
|
||||
"""
|
||||
self.llm_client = llm_client
|
||||
|
||||
async def _extract_triplets(self, statement: Statement, chunk_content: str) -> TripletExtractionResponse:
|
||||
"""Process a single statement and return extracted triplets and entities"""
|
||||
# Render the prompt using helper function
|
||||
# Log start and input context similar to legacy logs
|
||||
try:
|
||||
prompt_logger.info(f"[Triplet] Started - statement_id={statement.id}")
|
||||
prompt_logger.debug(f"[Triplet] Input statement=\"{statement.statement}\"")
|
||||
except Exception:
|
||||
# Avoid breaking flow due to logging issues
|
||||
pass
|
||||
|
||||
prompt_content = await render_triplet_extraction_prompt(
|
||||
statement=statement.statement,
|
||||
chunk_content=chunk_content,
|
||||
json_schema=TripletExtractionResponse.model_json_schema(),
|
||||
predicate_instructions=PREDICATE_DEFINITIONS
|
||||
)
|
||||
|
||||
# Create messages for LLM
|
||||
messages = [
|
||||
{"role": "system", "content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
|
||||
{"role": "user", "content": prompt_content}
|
||||
]
|
||||
|
||||
try:
|
||||
# Get structured response from LLM
|
||||
response = await self.llm_client.response_structured(messages, TripletExtractionResponse)
|
||||
# Filter triplets to only allowed predicates from ontology
|
||||
# 这里过滤掉了不在 Predicate 枚举中的谓语 但是容易造成谓语太严格,有点语句的谓语没有在枚举中,就被判断为弱关系
|
||||
allowed_predicates = {p.value for p in Predicate}
|
||||
filtered_triplets = [t for t in response.triplets if getattr(t, "predicate", "") in allowed_predicates]
|
||||
# 仅保留predicate ∈ Predicate 的三元组,其余全部剔除
|
||||
|
||||
# Create new triplets with statement_id set during creation
|
||||
updated_triplets = []
|
||||
for triplet in filtered_triplets: # 仅保留 predicate ∈ Predicate 的三元组
|
||||
updated_triplet = triplet.model_copy(update={"statement_id": statement.id})
|
||||
updated_triplets.append(updated_triplet)
|
||||
|
||||
# Log completion and per-item details to match legacy format
|
||||
try:
|
||||
prompt_logger.info(
|
||||
f"[Triplet] Completed - statement_id={statement.id}, triplets={len(updated_triplets)}, entities={len(response.entities)}"
|
||||
)
|
||||
for i, t in enumerate(updated_triplets, 1):
|
||||
prompt_logger.debug(
|
||||
f"[Triplet] Triplet #{i}: ({t.subject_name}) - {t.predicate} - ({t.object_name}) value={t.value if t.value is not None else 'None'}"
|
||||
)
|
||||
for i, e in enumerate(response.entities, 1):
|
||||
prompt_logger.debug(
|
||||
f"[Triplet] Entity #{i}: id={getattr(e, 'entity_idx', None)} name={getattr(e, 'name', None)} type={getattr(e, 'type', None)} desc={getattr(e, 'description', None)}"
|
||||
)
|
||||
except Exception:
|
||||
print(f"Error logging triplet details: {e}")
|
||||
pass
|
||||
|
||||
# Return new response with updated triplets
|
||||
return TripletExtractionResponse(
|
||||
triplets=updated_triplets,
|
||||
entities=response.entities
|
||||
)
|
||||
# # Set statement_id for each triplet to establish parent relationship
|
||||
# for triplet in response.triplets:
|
||||
# triplet.statement_id = statement.id
|
||||
|
||||
# return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing statement: {e}", exc_info=True)
|
||||
return TripletExtractionResponse(triplets=[], entities=[])
|
||||
|
||||
async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[str, TripletExtractionResponse]:
|
||||
"""Extract triplets and entities from statements
|
||||
|
||||
Args:
|
||||
dialog_data: DialogData object to process
|
||||
limit_chunks: Number of chunks to process
|
||||
|
||||
Returns:
|
||||
Dict[str, TripletExtractionResponse]: Dictionary mapping statement IDs to their triplet responses
|
||||
"""
|
||||
# Collect all statements from the specified chunks
|
||||
all_statements = []
|
||||
chunks_to_process = dialog_data.chunks[:limit_chunks] if limit_chunks else dialog_data.chunks
|
||||
|
||||
for chunk in chunks_to_process:
|
||||
all_statements.extend(chunk.statements)
|
||||
|
||||
logger.info(f"Processing {len(all_statements)} statements for triplet extraction...")
|
||||
try:
|
||||
prompt_logger.info(
|
||||
f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}, statements_to_process={len(all_statements)}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Prepare tasks and statement IDs
|
||||
tasks = []
|
||||
statement_ids = []
|
||||
|
||||
for chunk in chunks_to_process:
|
||||
for statement in chunk.statements:
|
||||
tasks.append(self._extract_triplets(statement, chunk.content))
|
||||
statement_ids.append(statement.id)
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Map results to statement IDs
|
||||
statement_triplet_map = {}
|
||||
for i, result in enumerate(results):
|
||||
statement_id = statement_ids[i]
|
||||
if isinstance(result, TripletExtractionResponse):
|
||||
statement_triplet_map[statement_id] = result
|
||||
else:
|
||||
logger.error(f"Error in triplet extraction for statement {statement_id}: {result}", exc_info=True)
|
||||
statement_triplet_map[statement_id] = TripletExtractionResponse(triplets=[], entities=[])
|
||||
|
||||
# Dialog-level summary and details (match legacy format)
|
||||
try:
|
||||
# Flatten totals
|
||||
all_triplets = []
|
||||
all_entities_with_stmt = []
|
||||
for sid, resp in statement_triplet_map.items():
|
||||
for t in resp.triplets:
|
||||
all_triplets.append(t)
|
||||
for e in resp.entities:
|
||||
all_entities_with_stmt.append((sid, e))
|
||||
|
||||
prompt_logger.info(
|
||||
f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)} completed, total_triplets={len(all_triplets)}, total_entities={len(all_entities_with_stmt)}"
|
||||
)
|
||||
|
||||
# Triplets Detail section
|
||||
prompt_logger.info("\n--- Triplets Detail ---")
|
||||
for i, t in enumerate(all_triplets, 1):
|
||||
prompt_logger.info(
|
||||
f"[Triplet] #{i} statement_id={getattr(t, 'statement_id', None)} subject=({getattr(t, 'subject_name', None)}:{getattr(t, 'subject_id', None)}) predicate={getattr(t, 'predicate', None)} object=({getattr(t, 'object_name', None)}:{getattr(t, 'object_id', None)}) value={getattr(t, 'value', None) if getattr(t, 'value', None) is not None else 'None'}"
|
||||
)
|
||||
|
||||
# Entities Detail section
|
||||
prompt_logger.info("\n--- Entities Detail ---")
|
||||
for i, (sid, e) in enumerate(all_entities_with_stmt, 1):
|
||||
prompt_logger.info(
|
||||
f"[Entity] #{i} statement_id={sid} id={getattr(e, 'entity_idx', None)} name={getattr(e, 'name', None)} type={getattr(e, 'type', None)} desc={getattr(e, 'description', None)}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return statement_triplet_map
|
||||
|
||||
def save_triplets(self, triplet_responses: List[TripletExtractionResponse], output_path: str = None) -> str:
|
||||
"""Save extracted triplets and entities to a file
|
||||
|
||||
Args:
|
||||
triplet_responses: List of TripletExtractionResponse objects
|
||||
output_path: Optional path to save the results
|
||||
|
||||
Returns:
|
||||
Path where the triplets were saved
|
||||
"""
|
||||
if output_path is None:
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
output_path = settings.get_memory_output_path("extracted_triplets.txt")
|
||||
|
||||
# Flatten all triplets and entities
|
||||
all_triplets = []
|
||||
all_entities = []
|
||||
|
||||
for response in triplet_responses:
|
||||
all_triplets.extend(response.triplets)
|
||||
all_entities.extend(response.entities)
|
||||
|
||||
# Save to file
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write(f"=== EXTRACTED TRIPLETS ({len(all_triplets)} total) ===\n\n")
|
||||
for i, triplet in enumerate(all_triplets, 1):
|
||||
f.write(f"Triplet {i}:\n")
|
||||
f.write(f" Subject: {triplet.subject_name} (ID: {triplet.subject_id})\n")
|
||||
f.write(f" Predicate: {triplet.predicate}\n")
|
||||
f.write(f" Object: {triplet.object_name} (ID: {triplet.object_id})\n")
|
||||
if triplet.value:
|
||||
f.write(f" Value: {triplet.value}\n")
|
||||
f.write("\n")
|
||||
|
||||
f.write(f"\n=== EXTRACTED ENTITIES ({len(all_entities)} total) ===\n\n")
|
||||
for i, entity in enumerate(all_entities, 1):
|
||||
f.write(f"Entity {i}:\n")
|
||||
f.write(f" ID: {entity.entity_idx}\n")
|
||||
f.write(f" Name: {entity.name}\n")
|
||||
f.write(f" Type: {entity.type}\n")
|
||||
f.write(f" Description: {entity.description}\n")
|
||||
f.write("\n")
|
||||
|
||||
logger.info(f"Saved {len(all_triplets)} triplets and {len(all_entities)} entities to: {output_path}")
|
||||
return output_path
|
||||
@@ -0,0 +1,528 @@
|
||||
"""
|
||||
提取流水线工具函数
|
||||
|
||||
该模块提供知识提取流水线的辅助工具函数,包括:
|
||||
1. 解析和格式化提取结果
|
||||
2. 生成提取结果汇总报告
|
||||
3. 导出测试输入文档
|
||||
|
||||
这些函数主要用于:
|
||||
- 解析三元组和实体信息
|
||||
- 统计去重和消歧效果
|
||||
- 生成可读的结果报告
|
||||
|
||||
作者:Memory Refactoring Team
|
||||
原路径:app/core/memory/src/pipeline_help.py(已迁移)
|
||||
迁移日期:2025-11-22
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
from datetime import datetime
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def _parse_triplets_from_file(filepath):
|
||||
"""解析三元组文件,返回三元组列表"""
|
||||
triplets = []
|
||||
if not os.path.exists(filepath):
|
||||
return triplets
|
||||
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
lines = content.split('\n')
|
||||
current_triplet = {}
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line.startswith('Triplet '):
|
||||
if current_triplet:
|
||||
triplets.append(current_triplet)
|
||||
current_triplet = {}
|
||||
elif line.startswith('Subject:'):
|
||||
subject = line.replace('Subject:', '').strip()
|
||||
subject = subject.split('(ID:')[0].strip()
|
||||
current_triplet['subject'] = subject
|
||||
elif line.startswith('Predicate:'):
|
||||
predicate = line.replace('Predicate:', '').strip()
|
||||
current_triplet['predicate'] = predicate
|
||||
elif line.startswith('Object:'):
|
||||
obj = line.replace('Object:', '').strip()
|
||||
obj = obj.split('(ID:')[0].strip()
|
||||
current_triplet['object'] = obj
|
||||
|
||||
if current_triplet:
|
||||
triplets.append(current_triplet)
|
||||
except Exception as e:
|
||||
print(f"解析三元组文件失败: {e}")
|
||||
|
||||
return triplets
|
||||
|
||||
|
||||
def _parse_entities_from_triplets(filepath):
|
||||
"""从三元组文件中解析实体信息,按类型分组"""
|
||||
entities_by_type = defaultdict(list)
|
||||
|
||||
if not os.path.exists(filepath):
|
||||
return entities_by_type
|
||||
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
if '=== EXTRACTED ENTITIES' in content:
|
||||
entity_section = content.split('=== EXTRACTED ENTITIES')[1]
|
||||
lines = entity_section.split('\n')
|
||||
|
||||
current_entity = {}
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line.startswith('Entity '):
|
||||
if current_entity and 'name' in current_entity and 'type' in current_entity:
|
||||
entities_by_type[current_entity['type']].append(current_entity['name'])
|
||||
current_entity = {}
|
||||
elif line.startswith('Name:'):
|
||||
name = line.replace('Name:', '').strip()
|
||||
current_entity['name'] = name
|
||||
elif line.startswith('Type:'):
|
||||
entity_type = line.replace('Type:', '').strip()
|
||||
current_entity['type'] = entity_type
|
||||
|
||||
if current_entity and 'name' in current_entity and 'type' in current_entity:
|
||||
entities_by_type[current_entity['type']].append(current_entity['name'])
|
||||
|
||||
# 去重
|
||||
for entity_type in entities_by_type:
|
||||
entities_by_type[entity_type] = list(set(entities_by_type[entity_type]))
|
||||
except Exception as e:
|
||||
print(f"解析实体信息失败: {e}")
|
||||
|
||||
return entities_by_type
|
||||
|
||||
|
||||
def _format_predicate(predicate):
|
||||
"""格式化谓词为中文"""
|
||||
predicate_map = {
|
||||
'COLLABORATES_WITH': '同事',
|
||||
'MENTIONS': '提到',
|
||||
'DEVELOPED': '开发',
|
||||
'PART_OF': '参与',
|
||||
'LOCATED_IN': '位于',
|
||||
'WORKS_AT': '工作于',
|
||||
'PURCHASED': '购买',
|
||||
'INTERESTED_IN': '感兴趣'
|
||||
}
|
||||
return predicate_map.get(predicate, predicate.lower().replace('_', ' '))
|
||||
|
||||
|
||||
def _write_extracted_result_summary(
|
||||
chunk_nodes,
|
||||
pipeline_output_dir: str,
|
||||
):
|
||||
"""
|
||||
汇总生成 logs/memory-output/extracted_result.json,包含:
|
||||
- 提取实体数(从 extracted_entities_edges.txt 的 ENTITY 行计数)
|
||||
- 去重后合并个数(统计 dedup_entity_output.txt 的精确/模糊/LLM合并记录)
|
||||
- 实体消歧次数(统计阻断与合并应用,并输出同名实体“消歧成功”)
|
||||
- 记忆片段数(chunk_nodes 的数量)
|
||||
- 关系三元组数(从 extracted_triplets.txt 标题获取总数)
|
||||
"""
|
||||
os.makedirs(pipeline_output_dir, exist_ok=True)
|
||||
result_path = os.path.join(pipeline_output_dir, "extracted_result.json")
|
||||
entities_edges_path = os.path.join(pipeline_output_dir, "extracted_entities_edges.txt")
|
||||
dedup_report_path = os.path.join(pipeline_output_dir, "dedup_entity_output.txt")
|
||||
triplets_path = os.path.join(pipeline_output_dir, "extracted_triplets.txt")
|
||||
|
||||
# 1) 提取实体数
|
||||
extracted_entity_count = 0
|
||||
# 初始提取的名称计数(用于“出现X次”的基础计数)
|
||||
initial_name_counts: dict[str, int] = {}
|
||||
try:
|
||||
with open(entities_edges_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip().startswith("ENTITY:"):
|
||||
extracted_entity_count += 1
|
||||
# 解析 name 字段
|
||||
try:
|
||||
m = re.search(r"\{\s*\"id\"\s*:\s*\"[^\"]*\"\s*,\s*\"name\"\s*:\s*\"([^\"]+)\"", line)
|
||||
if m:
|
||||
nm = m.group(1).strip()
|
||||
if nm:
|
||||
initial_name_counts[nm] = initial_name_counts.get(nm, 0) + 1
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 2) 去重后合并个数 & 3) 实体消歧次数(含成功名称)
|
||||
exact_merge_total = 0
|
||||
fuzzy_merge_total = 0
|
||||
llm_merge_total = 0
|
||||
disamb_block_total = 0
|
||||
# 记录成功区分的消歧对(阻断的左右实体及类型)
|
||||
disamb_success_pairs: list[tuple[str, str, str, str]] = []
|
||||
# 在外部定义这些字典,确保后续代码可以访问
|
||||
dedup_impact: dict[tuple[str, str], int] = {}
|
||||
# 第二层精准合并新增:包含自合并(自合并视为"比较两个实体后合并为一")
|
||||
second_layer_exact_additions: dict[tuple[str, str], int] = {}
|
||||
# LLM 同名类型相似:按名称计一次出现(代表两个实体合并为一)
|
||||
llm_same_name_additions: dict[str, int] = {}
|
||||
|
||||
try:
|
||||
with open(dedup_report_path, "r", encoding="utf-8") as f:
|
||||
current_layer: str | None = None
|
||||
for raw in f:
|
||||
line = raw.strip()
|
||||
if line.startswith("=== 第一层去重消歧 ==="):
|
||||
current_layer = "第一层去重消歧"
|
||||
continue
|
||||
if line.startswith("=== 第二层去重消歧 ==="):
|
||||
current_layer = "第二层去重消歧"
|
||||
continue
|
||||
# 精确合并:统计“合并实体IDs”数量
|
||||
if line.startswith("[精确] ") and "合并实体IDs" in line:
|
||||
try:
|
||||
# 先提取规范ID(用于第二层去重统计)
|
||||
canonical_id = ""
|
||||
id_match = re.search(r"规范实体\s+([0-9a-f]{40})", line)
|
||||
if id_match:
|
||||
canonical_id = id_match.group(1).strip()
|
||||
|
||||
# 提取名称、类型和合并实体IDs
|
||||
m = re.search(r"名称\s+'([^']+)'\s+类型\s+(\S+)\s+<-\s+合并实体IDs\s+(.+)$", line)
|
||||
if m:
|
||||
name = m.group(1).strip()
|
||||
ent_type = m.group(2).strip()
|
||||
ids_part = m.group(3).strip()
|
||||
else:
|
||||
# 退化解析:如果上式失败,回退到简单切分
|
||||
canonical_id = ""
|
||||
name = ""
|
||||
ent_type = ""
|
||||
ids_part = line.split("合并实体IDs", 1)[1].lstrip("::").strip()
|
||||
id_list = [i.strip() for i in ids_part.split(",") if i.strip()]
|
||||
exact_merge_total += len(id_list)
|
||||
if name and ent_type:
|
||||
key = (name, ent_type)
|
||||
dedup_impact[key] = dedup_impact.get(key, 0) + len(id_list)
|
||||
# 在第二层:统计新增出现次数(包含自合并,视为两实体比较后合并为一,至少+1)
|
||||
if current_layer == "第二层去重消歧":
|
||||
try:
|
||||
non_self = len([i for i in id_list if i != canonical_id]) if canonical_id else len(id_list)
|
||||
except Exception:
|
||||
non_self = len(id_list)
|
||||
add_cnt = non_self if non_self > 0 else 1
|
||||
second_layer_exact_additions[key] = second_layer_exact_additions.get(key, 0) + add_cnt
|
||||
except Exception:
|
||||
pass
|
||||
# 模糊合并:每条记录算一次合并
|
||||
elif line.startswith("[模糊] ") and "<- 合并实体" in line:
|
||||
fuzzy_merge_total += 1
|
||||
# 解析括号中的三元组 (group|name|type)
|
||||
try:
|
||||
m = re.search(r"规范实体[^\(]*\(([^|]+)\|([^|]+)\|([^\)]+)\)", line)
|
||||
if m:
|
||||
name = m.group(2).strip()
|
||||
ent_type = m.group(3).strip()
|
||||
key = (name, ent_type)
|
||||
dedup_impact[key] = dedup_impact.get(key, 0) + 1
|
||||
except Exception:
|
||||
pass
|
||||
# LLM 决策合并:每条记录算一次合并(包含 LLM融合/LLM合并 以及 “同名类型相似”的 LLM 去重)
|
||||
elif (line.startswith("[LLM融合]") or line.startswith("[LLM合并]")) and "<- 合并实体" in line:
|
||||
llm_merge_total += 1
|
||||
try:
|
||||
m = re.search(r"规范实体[^\(]*\(([^|]+)\|([^|]+)\|([^\)]+)\)", line)
|
||||
if m:
|
||||
name = m.group(2).strip()
|
||||
ent_type = m.group(3).strip()
|
||||
key = (name, ent_type)
|
||||
dedup_impact[key] = dedup_impact.get(key, 0) + 1
|
||||
except Exception:
|
||||
pass
|
||||
elif line.startswith("[LLM去重]"):
|
||||
# 例如:[LLM去重] 同名类型相似 A(TypeA)|B(TypeB) | conf=... | reason=...
|
||||
# 这类记录同样属于 LLM 决策的去重合并,计入 LLM 合并总数
|
||||
llm_merge_total += 1
|
||||
# 若同名类型相似(名称相同),按“名称”计一次出现(两实体合并为一)
|
||||
try:
|
||||
m = re.search(r"同名类型相似\s*([^((]+)[((][^))]+[))]\|([^((]+)[((][^))]+[))]", line)
|
||||
if m:
|
||||
left = m.group(1).strip()
|
||||
right = m.group(2).strip()
|
||||
if left and right and left == right:
|
||||
llm_same_name_additions[left] = llm_same_name_additions.get(left, 0) + 1
|
||||
except Exception:
|
||||
pass
|
||||
# 可选:解析名称与类型,当前不用于后续统计输出,保持简单
|
||||
# 若未来需要统计影响,可以解析左右两侧名称/类型并分别+1
|
||||
# 消歧阻断计数:仅统计 [DISAMB阻断],忽略异常阻断与合并应用
|
||||
elif line.startswith("[DISAMB阻断]"):
|
||||
disamb_block_total += 1
|
||||
# 解析形如:
|
||||
# [DISAMB阻断] A(TypeA)|B(TypeB) | conf=... | reason=... || block_pair=True
|
||||
try:
|
||||
m = re.search(r"\[DISAMB阻断\]\s*([^((]+)[((]([^))]+)[))]\|([^((]+)[((]([^))]+)[))]", line)
|
||||
if m:
|
||||
left_name = m.group(1).strip()
|
||||
left_type = m.group(2).strip()
|
||||
right_name = m.group(3).strip()
|
||||
right_type = m.group(4).strip()
|
||||
disamb_success_pairs.append((left_name, left_type, right_name, right_type))
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
total_merged_count = exact_merge_total + fuzzy_merge_total + llm_merge_total
|
||||
disamb_total = disamb_block_total
|
||||
|
||||
# 4) 记忆片段数(分块器生成的 chunk 数量)
|
||||
memory_chunk_count = 0
|
||||
try:
|
||||
memory_chunk_count = len(chunk_nodes) if chunk_nodes is not None else 0
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 5) 关系三元组数(从文件头部“EXTRACTED TRIPLETS (N total)”解析)
|
||||
triplet_count = 0
|
||||
try:
|
||||
with open(triplets_path, "r", encoding="utf-8") as f:
|
||||
head = f.readline()
|
||||
m = re.search(r"EXTRACTED\s+TRIPLETS\s*\((\d+)\s+total\)", head)
|
||||
if m:
|
||||
triplet_count = int(m.group(1))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 写入结果文件
|
||||
# 构建 JSON 结构(字段顺序按用户需求组织:先“实体去重的影响”,后“实体消歧的效果”)
|
||||
readable_path = os.path.join(pipeline_output_dir, "extracted_result_readable.txt")
|
||||
summary_json = {
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"entities": {
|
||||
"extracted_count": extracted_entity_count,
|
||||
},
|
||||
"dedup": {
|
||||
"total_merged_count": total_merged_count,
|
||||
"breakdown": {
|
||||
"exact": exact_merge_total,
|
||||
"fuzzy": fuzzy_merge_total,
|
||||
"llm": llm_merge_total,
|
||||
},
|
||||
"impact": [
|
||||
{
|
||||
"name": nm,
|
||||
"type": tp,
|
||||
"appear_count": (initial_name_counts.get(nm, 0)
|
||||
+ second_layer_exact_additions.get((nm, tp), 0)
|
||||
+ llm_same_name_additions.get(nm, 0)) if (initial_name_counts.get(nm, 0)
|
||||
+ second_layer_exact_additions.get((nm, tp), 0)
|
||||
+ llm_same_name_additions.get(nm, 0)) > 0 else merge_cnt,
|
||||
"merge_count": merge_cnt,
|
||||
}
|
||||
for (nm, tp), merge_cnt in (dedup_impact.items() if 'dedup_impact' in locals() else [])
|
||||
],
|
||||
},
|
||||
"disambiguation": {
|
||||
"block_count": disamb_block_total,
|
||||
"effects": [
|
||||
{
|
||||
"left": {"name": ln, "type": lt},
|
||||
"right": {"name": rn, "type": rt},
|
||||
"result": "成功区分"
|
||||
}
|
||||
for (ln, lt, rn, rt) in disamb_success_pairs
|
||||
],
|
||||
},
|
||||
"memory": {"chunks": memory_chunk_count},
|
||||
"triplets": {"count": triplet_count},
|
||||
"core_entities": [], # 将在下面填充
|
||||
"triplet_samples": [], # 将在下面填充
|
||||
}
|
||||
|
||||
# 解析实体和三元组数据(用于JSON和文本输出)
|
||||
entities_by_type = _parse_entities_from_triplets(triplets_path)
|
||||
triplets_list = _parse_triplets_from_file(triplets_path)
|
||||
|
||||
# 类型翻译映射
|
||||
type_translation = {
|
||||
'Person': '人物',
|
||||
'Organization': '组织',
|
||||
'Location': '地点',
|
||||
'Product': '产品',
|
||||
'Event': '事件',
|
||||
'Technology': '技术',
|
||||
'Activity': '活动',
|
||||
'Exercise': '运动'
|
||||
}
|
||||
|
||||
# 构建核心实体数据(按类型分组)
|
||||
core_entities_data = []
|
||||
for entity_type, entities in sorted(entities_by_type.items(), key=lambda x: -len(x[1])):
|
||||
type_name_cn = type_translation.get(entity_type, entity_type)
|
||||
core_entities_data.append({
|
||||
"type": entity_type,
|
||||
"type_cn": type_name_cn,
|
||||
"count": len(entities),
|
||||
"entities": entities[:5] # 最多显示5个
|
||||
})
|
||||
summary_json["core_entities"] = core_entities_data
|
||||
|
||||
# 构建三元组示例数据
|
||||
triplet_samples = []
|
||||
display_count = min(7, len(triplets_list))
|
||||
for i in range(display_count):
|
||||
triplet = triplets_list[i]
|
||||
predicate_cn = _format_predicate(triplet.get('predicate', ''))
|
||||
triplet_samples.append({
|
||||
"subject": triplet.get('subject', ''),
|
||||
"predicate": triplet.get('predicate', ''),
|
||||
"predicate_cn": predicate_cn,
|
||||
"object": triplet.get('object', '')
|
||||
})
|
||||
summary_json["triplet_samples"] = triplet_samples
|
||||
|
||||
# 写 JSON 到 extracted_result.json(满足"以 json 格式输出并为 .json 文件"的要求)
|
||||
with open(result_path, "w", encoding="utf-8") as f:
|
||||
json.dump(summary_json, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 额外生成可读版文本,模块顺序调整
|
||||
lines: list[str] = []
|
||||
lines.append(f"结果汇总 - {datetime.now().isoformat()}")
|
||||
lines.append("")
|
||||
# 提取实体数模块
|
||||
lines.append("提取实体数:")
|
||||
lines.append(f"总计 {extracted_entity_count} 个")
|
||||
lines.append(f"去重后合并个数:{total_merged_count} (精确={exact_merge_total},模糊={fuzzy_merge_total},LLM={llm_merge_total})")
|
||||
lines.append("")
|
||||
# 实体消歧次数模块
|
||||
lines.append("实体消歧次数:")
|
||||
lines.append(f"总计 {disamb_total} 次(阻断={disamb_block_total})")
|
||||
lines.append("")
|
||||
# 记忆片段数模块
|
||||
lines.append("记忆片段数:")
|
||||
lines.append(f"总计 {memory_chunk_count} 条")
|
||||
lines.append("")
|
||||
# 关系三元组数模块
|
||||
lines.append("关系三元组数:")
|
||||
lines.append(f"总计 {triplet_count} 条")
|
||||
lines.append("")
|
||||
|
||||
# 新增模块1:提取的核心实体(去重后)
|
||||
lines.append("提取的核心实体(去重后):")
|
||||
lines.append("")
|
||||
# 从 extracted_triplets.txt 解析去重后的实体并按类型分组
|
||||
entities_by_type = _parse_entities_from_triplets(triplets_path)
|
||||
type_translation = {
|
||||
'Person': '人物',
|
||||
'Organization': '组织',
|
||||
'Location': '地点',
|
||||
'Product': '产品',
|
||||
'Event': '事件',
|
||||
'Technology': '技术',
|
||||
'Activity': '活动',
|
||||
'Exercise': '运动'
|
||||
}
|
||||
for entity_type, entities in sorted(entities_by_type.items(), key=lambda x: -len(x[1])):
|
||||
type_name = type_translation.get(entity_type, entity_type)
|
||||
count = len(entities)
|
||||
lines.append(f"{type_name}({count}):")
|
||||
# 最多显示5个实体
|
||||
display_entities = entities[:5]
|
||||
for entity in display_entities:
|
||||
lines.append(f" • {entity}")
|
||||
lines.append("")
|
||||
|
||||
# 新增模块2:提取的关系三元组(部分)
|
||||
lines.append("提取的关系三元组(部分):")
|
||||
lines.append("")
|
||||
# 从 extracted_triplets.txt 读取三元组
|
||||
triplets = _parse_triplets_from_file(triplets_path)
|
||||
display_count = min(7, len(triplets))
|
||||
for i in range(display_count):
|
||||
triplet = triplets[i]
|
||||
predicate_cn = _format_predicate(triplet['predicate'])
|
||||
lines.append(f" • ({triplet['subject']}, {predicate_cn}, {triplet['object']})")
|
||||
lines.append("")
|
||||
lines.append(f"... 共{triplet_count}条关系三元组")
|
||||
lines.append("")
|
||||
|
||||
# 实体去重的影响模块(先输出)
|
||||
if dedup_impact:
|
||||
lines.append("实体去重的影响:")
|
||||
# 出现次数 = 初始提取次数 + 第二层精准合并新增次数(包含自合并至少+1) + LLM同名类型相似按名称的新增次数
|
||||
# 若某名称初始未出现但发生了合并(少见),退化为使用合并次数
|
||||
for (nm, tp), merge_cnt in dedup_impact.items():
|
||||
init_cnt = initial_name_counts.get(nm, 0)
|
||||
add_cnt = second_layer_exact_additions.get((nm, tp), 0)
|
||||
llm_add = llm_same_name_additions.get(nm, 0)
|
||||
appear_cnt = init_cnt + add_cnt + llm_add
|
||||
if appear_cnt <= 0:
|
||||
appear_cnt = merge_cnt
|
||||
lines.append(f"[{nm}]出现{appear_cnt}次 → 合并为1个类型是[{tp}]的实体")
|
||||
lines.append("")
|
||||
|
||||
# 新增模块:实体消歧的效果(后输出,来源于 dedup_entity_output.txt 的 DISAMB阻断 记录)
|
||||
if disamb_success_pairs:
|
||||
lines.append("实体消歧的效果:")
|
||||
for left_name, left_type, right_name, right_type in disamb_success_pairs:
|
||||
lines.append(f"{left_name}({left_type}) vs {right_name}({right_type}) → 成功区分。")
|
||||
lines.append("")
|
||||
|
||||
with open(readable_path, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(lines))
|
||||
|
||||
def export_test_input_doc(
|
||||
entity_nodes,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
):
|
||||
"""将提取出的实体与两类边导出到 extracted_entities_edges.txt。
|
||||
|
||||
保持与 extraction_pipeline.py 原本本地函数一致的行为与输出格式。
|
||||
"""
|
||||
try:
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
out_path = settings.get_memory_output_path("extracted_entities_edges.txt")
|
||||
|
||||
def _to_dict(m):
|
||||
d = m.model_dump()
|
||||
for k, v in list(d.items()):
|
||||
if isinstance(v, datetime):
|
||||
d[k] = v.isoformat()
|
||||
return d
|
||||
|
||||
def _entity_to_dict(e):
|
||||
return {
|
||||
"id": getattr(e, "id"),
|
||||
"name": getattr(e, "name"),
|
||||
"entity_type": getattr(e, "entity_type"),
|
||||
"description": getattr(e, "description"),
|
||||
}
|
||||
|
||||
with open(out_path, "w", encoding="utf-8") as f:
|
||||
header_time = entity_nodes[0].created_at.isoformat()
|
||||
f.write(
|
||||
f"=== TEST EXTRACTED ENTITIES === (created_at: {header_time})\n"
|
||||
)
|
||||
for e in entity_nodes:
|
||||
f.write(
|
||||
"ENTITY: " + json.dumps(_entity_to_dict(e), ensure_ascii=False) + "\n"
|
||||
)
|
||||
|
||||
f.write("\n=== TEST STATEMENT-ENTITY EDGES ===\n")
|
||||
for se in statement_entity_edges:
|
||||
f.write("SE_EDGE: " + json.dumps(_to_dict(se), ensure_ascii=False) + "\n")
|
||||
|
||||
f.write("\n=== TEST ENTITY-ENTITY EDGES ===\n")
|
||||
for ee in entity_entity_edges:
|
||||
f.write("EE_EDGE: " + json.dumps(_to_dict(ee), ensure_ascii=False) + "\n")
|
||||
|
||||
print(f"Exported extracted entities & edges to: {out_path}")
|
||||
except Exception as e:
|
||||
print(f"Failed to export test input doc: {e}")
|
||||
@@ -0,0 +1,8 @@
|
||||
"""遗忘引擎模块
|
||||
|
||||
该模块实现记忆的遗忘机制,基于改进的艾宾浩斯遗忘曲线。
|
||||
"""
|
||||
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine
|
||||
|
||||
__all__ = ["ForgettingEngine"]
|
||||
@@ -0,0 +1,271 @@
|
||||
"""遗忘引擎实现
|
||||
|
||||
该模块实现基于改进的艾宾浩斯遗忘曲线的记忆遗忘机制。
|
||||
|
||||
遗忘曲线公式:
|
||||
R(t, S) = offset + (1 - offset) * exp(-λ_time * t / (λ_mem * S))
|
||||
|
||||
其中:
|
||||
- R: 记忆保持率 (0 到 1)
|
||||
- t: 自学习以来经过的时间
|
||||
- S: 记忆强度(值越高表示记忆越强)
|
||||
- offset: 最小保持率(防止完全遗忘)
|
||||
- λ_time: 控制时间效应的 Lambda 参数
|
||||
- λ_mem: 控制记忆强度效应的 Lambda 参数
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from app.core.memory.models.variate_config import ForgettingEngineConfig
|
||||
|
||||
|
||||
class ForgettingEngine:
|
||||
"""遗忘引擎 - 实现记忆遗忘机制
|
||||
|
||||
该引擎基于改进的艾宾浩斯遗忘曲线计算记忆保持率,
|
||||
结合时间衰减和记忆强度因素,支持可配置的遗忘行为。
|
||||
|
||||
Attributes:
|
||||
config: 遗忘引擎配置
|
||||
offset: 最小保持率(防止完全遗忘)
|
||||
lambda_time: 控制时间衰减效应的参数
|
||||
lambda_mem: 控制记忆强度效应的参数
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[ForgettingEngineConfig] = None):
|
||||
"""初始化遗忘引擎
|
||||
|
||||
Args:
|
||||
config: ForgettingEngineConfig 实例,包含遗忘参数配置
|
||||
"""
|
||||
if config is None:
|
||||
config = ForgettingEngineConfig()
|
||||
|
||||
self.config = config
|
||||
self.offset = config.offset
|
||||
self.lambda_time = config.lambda_time
|
||||
self.lambda_mem = config.lambda_mem
|
||||
|
||||
def forgetting_curve(self, t: float, S: float) -> float:
|
||||
"""使用改进的艾宾浩斯遗忘曲线计算记忆保持率
|
||||
|
||||
公式: R = offset + (1-offset) * e^(-λ_time * t / (λ_mem * S))
|
||||
|
||||
Args:
|
||||
t: 自学习以来经过的时间
|
||||
S: 记忆的相对强度
|
||||
|
||||
Returns:
|
||||
记忆保持率,值在 0 到 1 之间
|
||||
"""
|
||||
if S <= 0:
|
||||
return self.offset
|
||||
|
||||
exponent = -self.lambda_time * t / (self.lambda_mem * S)
|
||||
retention = self.offset + (1 - self.offset) * math.exp(exponent)
|
||||
|
||||
# 确保保持率在 0 到 1 之间
|
||||
return max(0.0, min(1.0, retention))
|
||||
|
||||
def calculate_forgetting_score(self, time_elapsed: float, memory_strength: float) -> float:
|
||||
"""计算记忆项的遗忘分数
|
||||
|
||||
遗忘分数 = 1 - 保持率,值越高表示越容易被遗忘
|
||||
|
||||
Args:
|
||||
time_elapsed: 自记忆创建/最后访问以来的时间
|
||||
memory_strength: 记忆强度(值越高表示越难忘记)
|
||||
|
||||
Returns:
|
||||
遗忘分数,值在 0 到 1 之间
|
||||
"""
|
||||
retention = self.forgetting_curve(time_elapsed, memory_strength)
|
||||
return 1.0 - retention
|
||||
|
||||
def calculate_weight(self, time_elapsed: float, memory_strength: float) -> float:
|
||||
"""计算记忆项的权重(即保持率)
|
||||
|
||||
Args:
|
||||
time_elapsed: 自记忆创建/最后访问以来的时间
|
||||
memory_strength: 记忆强度(值越高表示越难忘记)
|
||||
|
||||
Returns:
|
||||
权重值,值在 0 到 1 之间
|
||||
"""
|
||||
return self.forgetting_curve(time_elapsed, memory_strength)
|
||||
|
||||
def apply_forgetting_weights(
|
||||
self,
|
||||
items: List[dict],
|
||||
time_key: str = 'time_elapsed',
|
||||
strength_key: str = 'strength'
|
||||
) -> List[dict]:
|
||||
"""为记忆项列表应用遗忘权重
|
||||
|
||||
Args:
|
||||
items: 包含记忆项的字典列表
|
||||
time_key: 每个项中时间经过的键名
|
||||
strength_key: 每个项中记忆强度的键名
|
||||
|
||||
Returns:
|
||||
添加了 'forgetting_weight' 字段的项列表
|
||||
"""
|
||||
weighted_items = []
|
||||
|
||||
for item in items:
|
||||
item_copy = item.copy()
|
||||
time_elapsed = item.get(time_key, 0)
|
||||
strength = item.get(strength_key, 1.0)
|
||||
|
||||
weight = self.calculate_weight(time_elapsed, strength)
|
||||
item_copy['forgetting_weight'] = weight
|
||||
|
||||
weighted_items.append(item_copy)
|
||||
|
||||
return weighted_items
|
||||
|
||||
def mark_items_for_forgetting(
|
||||
self,
|
||||
items: List[dict],
|
||||
forgetting_threshold: float = 0.5,
|
||||
time_key: str = 'time_elapsed',
|
||||
strength_key: str = 'strength'
|
||||
) -> tuple[List[dict], List[dict]]:
|
||||
"""标记应该被遗忘的记忆项
|
||||
|
||||
Args:
|
||||
items: 包含记忆项的字典列表
|
||||
forgetting_threshold: 遗忘阈值,遗忘分数超过此值的项将被标记
|
||||
time_key: 每个项中时间经过的键名
|
||||
strength_key: 每个项中记忆强度的键名
|
||||
|
||||
Returns:
|
||||
元组 (应保留的项列表, 应遗忘的项列表)
|
||||
"""
|
||||
to_keep = []
|
||||
to_forget = []
|
||||
|
||||
for item in items:
|
||||
time_elapsed = item.get(time_key, 0)
|
||||
strength = item.get(strength_key, 1.0)
|
||||
|
||||
forgetting_score = self.calculate_forgetting_score(time_elapsed, strength)
|
||||
|
||||
item_copy = item.copy()
|
||||
item_copy['forgetting_score'] = forgetting_score
|
||||
|
||||
if forgetting_score > forgetting_threshold:
|
||||
to_forget.append(item_copy)
|
||||
else:
|
||||
to_keep.append(item_copy)
|
||||
|
||||
return to_keep, to_forget
|
||||
|
||||
def get_forgetting_statistics(
|
||||
self,
|
||||
items: List[dict],
|
||||
forgetting_threshold: float = 0.5,
|
||||
time_key: str = 'time_elapsed',
|
||||
strength_key: str = 'strength'
|
||||
) -> Dict[str, Any]:
|
||||
"""获取记忆项的遗忘统计信息
|
||||
|
||||
Args:
|
||||
items: 包含记忆项的字典列表
|
||||
forgetting_threshold: 遗忘阈值
|
||||
time_key: 每个项中时间经过的键名
|
||||
strength_key: 每个项中记忆强度的键名
|
||||
|
||||
Returns:
|
||||
包含统计信息的字典:
|
||||
- total_items: 总项数
|
||||
- items_to_keep: 应保留的项数
|
||||
- items_to_forget: 应遗忘的项数
|
||||
- forgetting_rate: 遗忘率
|
||||
- average_retention: 平均保持率
|
||||
- average_forgetting_score: 平均遗忘分数
|
||||
"""
|
||||
if not items:
|
||||
return {
|
||||
"total_items": 0,
|
||||
"items_to_keep": 0,
|
||||
"items_to_forget": 0,
|
||||
"forgetting_rate": 0.0,
|
||||
"average_retention": 0.0,
|
||||
"average_forgetting_score": 0.0
|
||||
}
|
||||
|
||||
to_keep, to_forget = self.mark_items_for_forgetting(
|
||||
items, forgetting_threshold, time_key, strength_key
|
||||
)
|
||||
|
||||
total = len(items)
|
||||
keep_count = len(to_keep)
|
||||
forget_count = len(to_forget)
|
||||
|
||||
# 计算平均保持率和遗忘分数
|
||||
total_retention = 0.0
|
||||
total_forgetting_score = 0.0
|
||||
|
||||
for item in items:
|
||||
time_elapsed = item.get(time_key, 0)
|
||||
strength = item.get(strength_key, 1.0)
|
||||
|
||||
retention = self.calculate_weight(time_elapsed, strength)
|
||||
forgetting_score = self.calculate_forgetting_score(time_elapsed, strength)
|
||||
|
||||
total_retention += retention
|
||||
total_forgetting_score += forgetting_score
|
||||
|
||||
avg_retention = total_retention / total
|
||||
avg_forgetting_score = total_forgetting_score / total
|
||||
|
||||
return {
|
||||
"total_items": total,
|
||||
"items_to_keep": keep_count,
|
||||
"items_to_forget": forget_count,
|
||||
"forgetting_rate": forget_count / total,
|
||||
"average_retention": avg_retention,
|
||||
"average_forgetting_score": avg_forgetting_score
|
||||
}
|
||||
|
||||
def calculate_time_elapsed_days(
|
||||
self,
|
||||
created_at: datetime,
|
||||
current_time: Optional[datetime] = None
|
||||
) -> float:
|
||||
"""计算经过的天数
|
||||
|
||||
Args:
|
||||
created_at: 创建时间
|
||||
current_time: 当前时间,如果为 None 则使用当前系统时间
|
||||
|
||||
Returns:
|
||||
经过的天数(浮点数)
|
||||
"""
|
||||
if current_time is None:
|
||||
current_time = datetime.now()
|
||||
|
||||
time_diff = current_time - created_at
|
||||
return time_diff.total_seconds() / (24 * 3600)
|
||||
|
||||
def calculate_time_elapsed_hours(
|
||||
self,
|
||||
created_at: datetime,
|
||||
current_time: Optional[datetime] = None
|
||||
) -> float:
|
||||
"""计算经过的小时数
|
||||
|
||||
Args:
|
||||
created_at: 创建时间
|
||||
current_time: 当前时间,如果为 None 则使用当前系统时间
|
||||
|
||||
Returns:
|
||||
经过的小时数(浮点数)
|
||||
"""
|
||||
if current_time is None:
|
||||
current_time = datetime.now()
|
||||
|
||||
time_diff = current_time - created_at
|
||||
return time_diff.total_seconds() / 3600
|
||||
@@ -0,0 +1,251 @@
|
||||
"""
|
||||
Memory Strength Calculator based on ACT-R Theory
|
||||
|
||||
This module implements the Base-Level Activation equation from ACT-R
|
||||
(Adaptive Control of Thought-Rational) cognitive architecture.
|
||||
|
||||
Formula: B(i) = ln(Σ(t_k^(-d)))
|
||||
|
||||
Where:
|
||||
- B(i): Base-level activation score
|
||||
- t_k: Time since the k-th access
|
||||
- d: Decay parameter (typically 0.5)
|
||||
- n: Number of accesses
|
||||
|
||||
Reference: Anderson, J. R. (2007). How Can the Human Mind Occur in the Physical Universe?
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
class MemoryStrengthCalculator:
|
||||
"""
|
||||
Calculate memory strength using ACT-R base-level activation formula.
|
||||
"""
|
||||
|
||||
def __init__(self, decay_parameter: float = 0.5, time_unit: str = "seconds"):
|
||||
"""
|
||||
Initialize the memory strength calculator.
|
||||
|
||||
Args:
|
||||
decay_parameter: The decay rate (d). Typically 0.5 for human memory.
|
||||
Higher values = faster forgetting.
|
||||
time_unit: Unit for time calculations. Options: 'seconds', 'minutes',
|
||||
'hours', 'days'. Default is 'seconds'.
|
||||
"""
|
||||
self.decay_parameter = decay_parameter
|
||||
self.time_unit = time_unit
|
||||
self._time_multipliers = {
|
||||
"seconds": 1,
|
||||
"minutes": 60,
|
||||
"hours": 3600,
|
||||
"days": 86400,
|
||||
}
|
||||
|
||||
def calculate_activation(
|
||||
self, access_times: List[datetime], current_time: Optional[datetime] = None
|
||||
) -> float:
|
||||
"""
|
||||
Calculate the base-level activation B(i) for a memory item.
|
||||
|
||||
Args:
|
||||
access_times: List of datetime objects representing when the memory
|
||||
was accessed (most recent first or in any order).
|
||||
current_time: The current time for calculation. If None, uses datetime.now().
|
||||
|
||||
Returns:
|
||||
float: The base-level activation score B(i).
|
||||
Higher values indicate stronger, more retrievable memories.
|
||||
|
||||
Raises:
|
||||
ValueError: If access_times is empty or contains invalid data.
|
||||
"""
|
||||
if not access_times:
|
||||
raise ValueError("access_times cannot be empty")
|
||||
|
||||
if current_time is None:
|
||||
current_time = datetime.now()
|
||||
|
||||
# Calculate time differences in specified units
|
||||
time_diffs = []
|
||||
for access_time in access_times:
|
||||
diff_seconds = (current_time - access_time).total_seconds()
|
||||
if diff_seconds < 0:
|
||||
raise ValueError(f"Access time {access_time} is in the future")
|
||||
|
||||
# Convert to specified time unit
|
||||
diff = diff_seconds / self._time_multipliers[self.time_unit]
|
||||
|
||||
# Avoid division by zero for very recent accesses
|
||||
# Use a small epsilon (0.01 time units)
|
||||
diff = max(diff, 0.01)
|
||||
time_diffs.append(diff)
|
||||
|
||||
# Calculate B(i) = ln(Σ(t_k^(-d)))
|
||||
sum_power_law = sum(t ** (-self.decay_parameter) for t in time_diffs)
|
||||
activation = math.log(sum_power_law)
|
||||
|
||||
return activation
|
||||
|
||||
def calculate_activation_from_intervals(
|
||||
self, time_intervals: List[float]
|
||||
) -> float:
|
||||
"""
|
||||
Calculate activation directly from time intervals (in the configured time unit).
|
||||
|
||||
Args:
|
||||
time_intervals: List of time intervals since each access.
|
||||
E.g., [1.0, 3.5, 7.2] means accessed 1, 3.5, and 7.2 time units ago.
|
||||
|
||||
Returns:
|
||||
float: The base-level activation score B(i).
|
||||
"""
|
||||
if not time_intervals:
|
||||
raise ValueError("time_intervals cannot be empty")
|
||||
|
||||
# Ensure no zero or negative intervals
|
||||
safe_intervals = [max(t, 0.01) for t in time_intervals]
|
||||
|
||||
sum_power_law = sum(t ** (-self.decay_parameter) for t in safe_intervals)
|
||||
activation = math.log(sum_power_law)
|
||||
|
||||
return activation
|
||||
|
||||
def calculate_memory_strength(self, activation: float) -> float:
|
||||
"""
|
||||
Convert activation score to memory strength S(i) = e^(B(i)).
|
||||
|
||||
This converts the log-space activation to linear space,
|
||||
suitable for use in the Ebbinghaus forgetting curve.
|
||||
|
||||
Args:
|
||||
activation: The base-level activation B(i).
|
||||
|
||||
Returns:
|
||||
float: Memory strength S(i) in linear space.
|
||||
"""
|
||||
return math.exp(activation)
|
||||
|
||||
def calculate_retention_probability(
|
||||
self,
|
||||
activation: float,
|
||||
time_since_last_access: float,
|
||||
decay_rate: float = 0.01,
|
||||
offset: float = 0.1,
|
||||
) -> float:
|
||||
"""
|
||||
Calculate retention probability using the unified Ebbinghaus-ACT-R formula.
|
||||
|
||||
Formula: R(i) = offset + (1-offset) * exp(-λ * t / Σ(t_k^(-d)))
|
||||
|
||||
Args:
|
||||
activation: The base-level activation B(i).
|
||||
time_since_last_access: Time since last access (in configured time units).
|
||||
decay_rate: Lambda (λ) parameter controlling forgetting speed.
|
||||
offset: Baseline retention rate (minimum memory strength).
|
||||
|
||||
Returns:
|
||||
float: Retention probability between 0 and 1.
|
||||
"""
|
||||
memory_strength = self.calculate_memory_strength(activation)
|
||||
|
||||
# Unified formula: R(i) = offset + (1-offset) * exp(-λ * t / S(i))
|
||||
retention = offset + (1 - offset) * math.exp(
|
||||
-decay_rate * time_since_last_access / memory_strength
|
||||
)
|
||||
|
||||
return retention
|
||||
|
||||
def should_retain(
|
||||
self,
|
||||
access_times: List[datetime],
|
||||
threshold: float = 0.5,
|
||||
current_time: Optional[datetime] = None,
|
||||
decay_rate: float = 0.01,
|
||||
offset: float = 0.1,
|
||||
) -> tuple[bool, float, float]:
|
||||
"""
|
||||
Determine if a memory should be retained based on its strength.
|
||||
|
||||
Args:
|
||||
access_times: List of access timestamps.
|
||||
threshold: Retention probability threshold (default 0.5 = 50%).
|
||||
current_time: Current time for calculation.
|
||||
decay_rate: Lambda parameter for forgetting curve.
|
||||
offset: Baseline retention rate.
|
||||
|
||||
Returns:
|
||||
tuple: (should_retain: bool, retention_probability: float, activation: float)
|
||||
"""
|
||||
if current_time is None:
|
||||
current_time = datetime.now()
|
||||
|
||||
activation = self.calculate_activation(access_times, current_time)
|
||||
|
||||
# Time since last access
|
||||
last_access = max(access_times)
|
||||
time_since_last = (current_time - last_access).total_seconds() / self._time_multipliers[self.time_unit]
|
||||
time_since_last = max(time_since_last, 0.01)
|
||||
|
||||
retention_prob = self.calculate_retention_probability(
|
||||
activation, time_since_last, decay_rate, offset
|
||||
)
|
||||
|
||||
return (retention_prob >= threshold, retention_prob, activation)
|
||||
|
||||
|
||||
# Convenience functions for quick calculations
|
||||
def calculate_activation(
|
||||
access_times: List[datetime],
|
||||
decay_parameter: float = 0.5,
|
||||
current_time: Optional[datetime] = None,
|
||||
) -> float:
|
||||
"""
|
||||
Quick function to calculate activation without creating a calculator instance.
|
||||
|
||||
Args:
|
||||
access_times: List of access timestamps.
|
||||
decay_parameter: Decay rate (default 0.5).
|
||||
current_time: Current time (default now).
|
||||
|
||||
Returns:
|
||||
float: Base-level activation B(i).
|
||||
"""
|
||||
calculator = MemoryStrengthCalculator(decay_parameter=decay_parameter)
|
||||
return calculator.calculate_activation(access_times, current_time)
|
||||
|
||||
|
||||
def calculate_retention(
|
||||
access_times: List[datetime],
|
||||
decay_parameter: float = 0.5,
|
||||
decay_rate: float = 0.01,
|
||||
offset: float = 0.1,
|
||||
current_time: Optional[datetime] = None,
|
||||
) -> float:
|
||||
"""
|
||||
Quick function to calculate retention probability.
|
||||
|
||||
Args:
|
||||
access_times: List of access timestamps.
|
||||
decay_parameter: ACT-R decay parameter (default 0.5).
|
||||
decay_rate: Ebbinghaus decay rate lambda (default 0.01).
|
||||
offset: Baseline retention (default 0.1).
|
||||
current_time: Current time (default now).
|
||||
|
||||
Returns:
|
||||
float: Retention probability between 0 and 1.
|
||||
"""
|
||||
calculator = MemoryStrengthCalculator(decay_parameter=decay_parameter)
|
||||
activation = calculator.calculate_activation(access_times, current_time)
|
||||
|
||||
if current_time is None:
|
||||
current_time = datetime.now()
|
||||
|
||||
last_access = max(access_times)
|
||||
time_since_last = (current_time - last_access).total_seconds()
|
||||
|
||||
return calculator.calculate_retention_probability(
|
||||
activation, time_since_last, decay_rate, offset
|
||||
)
|
||||
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
自我反思引擎模块
|
||||
|
||||
该模块实现了记忆系统的自我反思功能,包括:
|
||||
- 基于时间的反思
|
||||
- 基于事实的反思(冲突检测)
|
||||
- 综合反思
|
||||
- 反思结果应用
|
||||
"""
|
||||
|
||||
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
|
||||
ReflectionEngine,
|
||||
ReflectionConfig,
|
||||
ReflectionResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ReflectionEngine",
|
||||
"ReflectionConfig",
|
||||
"ReflectionResult",
|
||||
]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user