feat: Add base project structure with API and web components

This commit is contained in:
Ke Sun
2025-12-02 20:28:01 +08:00
parent f3de6d6cc9
commit c1adc62ec6
817 changed files with 111226 additions and 106 deletions

View File

View File

@@ -0,0 +1,35 @@
from pydantic import BaseModel
from app.core.agent.agent_chat import Agent_chat
from app.core.logging_config import get_business_logger
from fastapi import APIRouter, Depends, HTTPException
from app.dependencies import workspace_access_guard
from app.services.agent_server import config,ChatRequest
router = APIRouter(prefix="/Test", tags=["Apps"])
logger = get_business_logger()
class CombinedRequest(BaseModel):
config_base: config
agent_config: ChatRequest
@router.post("", summary="uuid")
async def agent_chat(
config_base: CombinedRequest
):
chat_config=config_base.agent_config
chat_base=config_base.config_base
request = ChatRequest(
end_user_id=chat_config.end_user_id,
message=chat_config.message,
search_switch=chat_config.search_switch,
kb_ids=chat_config.kb_ids,
similarity_threshold=chat_config.similarity_threshold,
vector_similarity_weight=chat_config.vector_similarity_weight,
top_k=chat_config.top_k,
hybrid=chat_config.hybrid,
token=chat_config.token
)
chat_result=await Agent_chat(chat_base).chat(request)
return chat_result

View File

@@ -0,0 +1,109 @@
import asyncio
import os
import time
from typing import Dict, Any, List
from app.core.logging_config import get_business_logger
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.services.api_resquests_server import messages_type, write_messages
from app.services.agent_server import ChatRequest, tool_memory, create_dynamic_agent, tool_Retrieval
logger = get_business_logger()
class Agent_chat:
def __init__(self,config_data: dict):
self.prompt_message = render_prompt_message(
config_data.template_str,
PromptMessageRole.USER,
config_data.params
)
self.prompt = self.prompt_message.get_text_content()
self.model_configs = config_data.model_configs
self.history_memory = config_data.history_memory
self.knowledge_base = config_data.knowledge_base
logger.info(f"渲染结果:{self.prompt_message.get_text_content()}" )
async def run_agent(self,agent, end_user_id:str, user_prompt:str, model_name:str):
response = agent.invoke(
{
"messages": [
{
"role": "user",
"content": user_prompt
}
]
},
{"configurable": {"thread_id": f'{model_name}_{end_user_id}'}},
)
outputs = []
for msg in response["messages"]:
if hasattr(msg, "tool_calls") and msg.tool_calls:
outputs.append({
"role": "assistant",
"tool_calls": [
{"name": t["name"], "arguments": t["args"]}
for t in msg.tool_calls
]
})
elif hasattr(msg, "content") and msg.content:
outputs.append({
"role": msg.__class__.__name__.lower().replace("message", ""),
"content": msg.content
})
ai_messages=[msg['content'] for msg in outputs if msg["role"] == "ai"]
return {"model_name": model_name, "end_user_id": end_user_id, "response": ai_messages}
async def chat(self,req: ChatRequest) -> Dict[str, Any]:
end_user_id = req.end_user_id # 用 user_id 作为对话线程标识
start=time.time()
user_prompt = req.message
'''判断是都写入redis数据库'''
messags_type = await messages_type(req.message,end_user_id)
messags_type=messags_type['data']
if messags_type=='question':
writer_result=await write_messages(f'{end_user_id}', req.message)
logger.info(f'判断类型写入耗时:{time.time() - start},{writer_result}')
'''history_memory'''
if self.history_memory==True:
tool_result =await tool_memory(req)
if tool_result!='' :tool_result=tool_result['data']
if tool_result!='' :self.prompt=self.prompt+f''',历史消息:{tool_result},结合历史消息'''
logger.info(f"记忆科学消耗时间:{time.time()-start},工具调用结果:{tool_result}")
'''baidu'''
'''knowledge_base'''
if self.knowledge_base == True:
retrieval_result=await tool_Retrieval(req)
retrieval_knowledge = [i['page_content'] for i in retrieval_result['data']]
retrieval_knowledge=','.join(retrieval_knowledge)
logger.info(f"检索消耗时间:{time.time()-start},{retrieval_knowledge}")
if retrieval_knowledge!='' :self.prompt=self.prompt+f",知识库检索内容:{retrieval_knowledge},结合检索结果"
self.prompt=self.prompt+f'给出最合适的答案,确保答案的完整性,只保留用户的问题的回答,不额外输出提示语'
logger.info(f"用户输入:{user_prompt}")
logger.info(f"系统prompt{self.prompt}")
AGENTS = {
cfg["name"]: await create_dynamic_agent(cfg["name"], cfg["moder_id"], self.prompt, req.token)
for cfg in self.model_configs
}
tasks=[
self.run_agent(agent, end_user_id, user_prompt, model_name)
for model_name, agent in AGENTS.items()
]
# 并行运行
results = await asyncio.gather(*tasks)
result=[]
for i in results:
result.append(i)
chat_result=(f"最终耗时:{time.time()-start},{result}")
return chat_result

View File

@@ -0,0 +1,347 @@
"""
LangChain Agent 封装
使用 LangChain 1.x 标准方式
- 使用 create_agent 创建 agent graph
- 支持工具调用循环
- 支持流式输出
- 使用 RedBearLLM 支持多提供商
"""
import os
import time
import asyncio
from typing import Dict, Any, List, Optional, AsyncGenerator, Sequence
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from langchain_core.tools import BaseTool
from langchain.agents import create_agent
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
from app.core.logging_config import get_business_logger
from app.services.memory_agent_service import MemoryAgentService
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
logger = get_business_logger()
class LangChainAgent:
def __init__(
self,
model_name: str,
api_key: str,
provider: str = "openai",
api_base: Optional[str] = None,
temperature: float = 0.7,
max_tokens: int = 2000,
system_prompt: Optional[str] = None,
tools: Optional[Sequence[BaseTool]] = None,
streaming: bool = False
):
"""初始化 LangChain Agent
Args:
model_name: 模型名称
api_key: API Key
provider: 提供商openai, xinference, gpustack, ollama, dashscope
api_base: API 基础 URL
temperature: 温度参数
max_tokens: 最大 token 数
system_prompt: 系统提示词
tools: 工具列表(可选,框架自动走 ReAct 循环)
streaming: 是否启用流式输出(默认 True
"""
self.model_name = model_name
self.provider = provider
self.system_prompt = system_prompt or "你是一个专业的AI助手"
self.tools = tools or []
self.streaming = streaming
# 创建 RedBearLLM支持多提供商
model_config = RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base,
extra_params={
"temperature": temperature,
"max_tokens": max_tokens,
"streaming": streaming # 使用参数控制流式
}
)
self.llm = RedBearLLM(model_config, type=ModelType.CHAT)
# 获取底层模型用于真正的流式调用
self._underlying_llm = self.llm._model if hasattr(self.llm, '_model') else self.llm
# 确保底层模型也启用流式
if streaming and hasattr(self._underlying_llm, 'streaming'):
self._underlying_llm.streaming = True
# 使用 create_agent 创建 agent graphLangChain 1.x 标准方式)
# 无论是否有工具,都使用 agent 统一处理
self.agent = create_agent(
model=self.llm,
tools=self.tools if self.tools else None,
system_prompt=self.system_prompt
)
logger.info(
f"LangChain Agent 初始化完成",
extra={
"model": model_name,
"provider": provider,
"has_api_base": bool(api_base),
"temperature": temperature,
"streaming": streaming,
"tool_count": len(self.tools),
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
"tool_count": len(self.tools)
}
)
def _prepare_messages(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None
) -> List[BaseMessage]:
"""准备消息列表
Args:
message: 用户消息
history: 历史消息列表
context: 上下文信息
Returns:
List[BaseMessage]: 消息列表
"""
messages = []
# 添加系统提示词
messages.append(SystemMessage(content=self.system_prompt))
# 添加历史消息
if history:
for msg in history:
if msg["role"] == "user":
messages.append(HumanMessage(content=msg["content"]))
elif msg["role"] == "assistant":
messages.append(AIMessage(content=msg["content"]))
# 添加当前用户消息
user_content = message
if context:
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
messages.append(HumanMessage(content=user_content))
return messages
async def chat(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
end_user_id: Optional[str] = None,
config_id: Optional[str] = None, # 添加这个参数
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]:
"""执行对话
Args:
message: 用户消息
history: 历史消息列表 [{"role": "user/assistant", "content": "..."}]
context: 上下文信息(如知识库检索结果)
Returns:
Dict: 包含 content 和元数据的字典
"""
start_time = time.time()
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
if storage_type == "rag":
await write_rag(end_user_id, message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
else:
if config_id==None:
actual_config_id = os.getenv("config_id")
else:actual_config_id=config_id
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
write_id = write_message_task.delay(actual_end_user_id, message, actual_config_id,storage_type,user_rag_memory_id)
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'Agent:{actual_end_user_id};{write_status}')
try:
# 准备消息列表
messages = self._prepare_messages(message, history, context)
logger.debug(
f"准备调用 LangChain Agent",
extra={
"has_context": bool(context),
"has_history": bool(history),
"has_tools": bool(self.tools),
"message_count": len(messages)
}
)
# 统一使用 agent.invoke 调用
result = await self.agent.ainvoke({"messages": messages})
# 获取最后的 AI 消息
output_messages = result.get("messages", [])
content = ""
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
content = msg.content
break
elapsed_time = time.time() - start_time
if storage_type == "rag":
await write_rag(end_user_id, message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
else:
write_id = write_message_task.delay(actual_end_user_id, content, actual_config_id, storage_type, user_rag_memory_id)
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'Agent:{actual_end_user_id};{write_status}')
response = {
"content": content,
"model": self.model_name,
"elapsed_time": elapsed_time,
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}
logger.debug(
f"Agent 调用完成",
extra={
"elapsed_time": elapsed_time,
"content_length": len(response["content"])
}
)
return response
except Exception as e:
logger.error(f"Agent 调用失败", extra={"error": str(e)})
raise
async def chat_stream(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
end_user_id:Optional[str] = None,
config_id: Optional[str] = None,
storage_type:Optional[str] = None,
user_rag_memory_id:Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""执行流式对话
Args:
message: 用户消息
history: 历史消息列表
context: 上下文信息
Yields:
str: 消息内容块
"""
logger.info("=" * 80)
logger.info(f" chat_stream 方法开始执行")
logger.info(f" Message: {message[:100]}")
logger.info(f" Has tools: {bool(self.tools)}")
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
logger.info("=" * 80)
start_time = time.time()
if storage_type == "rag":
await write_rag(end_user_id, message, user_rag_memory_id)
else:
if config_id==None:
actual_config_id = os.getenv("config_id")
else:actual_config_id=config_id
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
write_id = write_message_task.delay(actual_end_user_id, message, actual_config_id,storage_type,user_rag_memory_id)
try:
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'Agent:{actual_end_user_id};{write_status}')
except Exception as e:
logger.error(f"Agent 记忆用户输入出错", extra={"error": str(e)})
try:
# 准备消息列表
messages = self._prepare_messages(message, history, context)
logger.debug(
f"准备流式调用has_tools={bool(self.tools)}, message_count={len(messages)}"
)
chunk_count = 0
yielded_content = False
# 统一使用 agent 的 astream_events 实现流式输出
logger.debug("使用 Agent astream_events 实现流式输出")
try:
async for event in self.agent.astream_events(
{"messages": messages},
version="v2"
):
chunk_count += 1
kind = event.get("event")
# 处理所有可能的流式事件
if kind == "on_chat_model_stream":
# LLM 流式输出
chunk = event.get("data", {}).get("chunk")
if chunk and hasattr(chunk, "content") and chunk.content:
yield chunk.content
yielded_content = True
elif kind == "on_llm_stream":
# 另一种 LLM 流式事件
chunk = event.get("data", {}).get("chunk")
if chunk:
if hasattr(chunk, "content") and chunk.content:
yield chunk.content
yielded_content = True
elif isinstance(chunk, str):
yield chunk
yielded_content = True
# 记录工具调用(可选)
elif kind == "on_tool_start":
logger.debug(f"工具调用开始: {event.get('name')}")
elif kind == "on_tool_end":
logger.debug(f"工具调用结束: {event.get('name')}")
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
raise
except Exception as e:
logger.error("=" * 80)
logger.error(f"chat_stream 异常: {str(e)}")
logger.error("=" * 80, exc_info=True)
raise
finally:
logger.info("=" * 80)
logger.info(f"chat_stream 方法执行结束")
logger.info("=" * 80)

View File

@@ -0,0 +1,56 @@
"""API Key 工具函数"""
import secrets
import hashlib
from app.models.api_key_model import ApiKeyType
def generate_api_key(key_type: ApiKeyType) -> tuple[str, str, str]:
"""生成 API Key
Args:
key_type: API Key 类型
Returns:
tuple: (api_key, key_hash, key_prefix)
"""
# 前缀映射
prefix_map = {
ApiKeyType.APP: "sk-app-",
ApiKeyType.RAG: "sk-rag-",
ApiKeyType.MEMORY: "sk-mem-",
ApiKeyType.GENERAL: "sk-gen-",
}
prefix = prefix_map[key_type]
random_string = secrets.token_urlsafe(32)[:32] # 32 字符
api_key = f"{prefix}{random_string}"
# 生成哈希值存储
key_hash = hash_api_key(api_key)
return api_key, key_hash, prefix
def hash_api_key(api_key: str) -> str:
"""对 API Key 进行哈希
Args:
api_key: API Key 明文
Returns:
str: 哈希值
"""
return hashlib.sha256(api_key.encode()).hexdigest()
def verify_api_key(api_key: str, key_hash: str) -> bool:
"""验证 API Key
Args:
api_key: API Key 明文
key_hash: 存储的哈希值
Returns:
bool: 是否匹配
"""
return hash_api_key(api_key) == key_hash

View File

@@ -0,0 +1,47 @@
"""
Compensation Transaction Handler
Handles operations that cannot be rolled back (like file system operations).
"""
from typing import List, Callable
from app.core.logging_config import get_logger
logger = get_logger(__name__)
class CompensationHandler:
"""补偿事务处理器,用于处理无法回滚的操作"""
def __init__(self):
self._compensations: List[Callable] = []
def register(self, compensation: Callable):
"""
注册补偿操作
Args:
compensation: 补偿操作的可调用对象
"""
self._compensations.append(compensation)
logger.debug(f"Registered compensation operation: {compensation.__name__ if hasattr(compensation, '__name__') else 'lambda'}")
def execute(self):
"""执行所有补偿操作(按注册的逆序执行)"""
if not self._compensations:
logger.debug("No compensation operations to execute")
return
logger.info(f"Executing {len(self._compensations)} compensation operations")
for compensation in reversed(self._compensations):
try:
compensation()
logger.debug(f"Compensation operation executed successfully")
except Exception as e:
logger.error(f"补偿操作失败: {e}", exc_info=True)
def clear(self):
"""清空补偿操作"""
count = len(self._compensations)
self._compensations.clear()
if count > 0:
logger.debug(f"Cleared {count} compensation operations")

237
api/app/core/config.py Normal file
View File

@@ -0,0 +1,237 @@
import os
import json
from pathlib import Path
from typing import Dict, Any, Optional
from dotenv import load_dotenv
load_dotenv()
class Settings:
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
# API Keys Configuration
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "")
# Neo4j Configuration (记忆系统数据库)
NEO4J_URI: str = os.getenv("NEO4J_URI", "bolt://1.94.111.67:7687")
NEO4J_USERNAME: str = os.getenv("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD: str = os.getenv("NEO4J_PASSWORD", "")
# Database configuration (Postgres)
DB_HOST: str = os.getenv("DB_HOST", "127.0.0.1")
DB_PORT: int = int(os.getenv("DB_PORT", "5432"))
DB_USER: str = os.getenv("DB_USER", "postgres")
DB_PASSWORD: str = os.getenv("DB_PASSWORD", "password")
DB_NAME: str = os.getenv("DB_NAME", "redbear-mem")
DB_AUTO_UPGRADE = os.getenv("DB_AUTO_UPGRADE", "false").lower() == "true"
# Redis configuration
REDIS_HOST: str = os.getenv("REDIS_HOST", "127.0.0.1")
REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
REDIS_DB: int = int(os.getenv("REDIS_DB", "1"))
REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "")
# ElasticSearch configuration
ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1")
ELASTICSEARCH_PORT: int = int(os.getenv("ELASTICSEARCH_PORT", "9200"))
ELASTICSEARCH_USERNAME: str = os.getenv("ELASTICSEARCH_USERNAME", "elastic")
ELASTICSEARCH_PASSWORD: str = os.getenv("ELASTICSEARCH_PASSWORD", "")
ELASTICSEARCH_VERIFY_CERTS: bool = os.getenv("ELASTICSEARCH_VERIFY_CERTS", "False").lower() == "true"
ELASTICSEARCH_CA_CERTS: str = os.getenv("ELASTICSEARCH_CA_CERTS", "")
ELASTICSEARCH_REQUEST_TIMEOUT: int = int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", "100000"))
ELASTICSEARCH_RETRY_ON_TIMEOUT: bool = os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", "True").lower() == "true"
ELASTICSEARCH_MAX_RETRIES: int = int(os.getenv("ELASTICSEARCH_MAX_RETRIES", "10"))
# Xinference configuration
XINFERENCE_URL: str = os.getenv("XINFERENCE_URL", "http://127.0.0.1")
# LangSmith configuration
LANGCHAIN_TRACING_V2: bool = os.getenv("LANGCHAIN_TRACING_V2", "false").lower() == "true"
LANGCHAIN_TRACING: bool = os.getenv("LANGCHAIN_TRACING", "false").lower() == "true"
LANGCHAIN_API_KEY: str = os.getenv("LANGCHAIN_API_KEY", "")
LANGCHAIN_ENDPOINT: str = os.getenv("LANGCHAIN_ENDPOINT", "")
# LLM Request Configuration
LLM_TIMEOUT: float = float(os.getenv("LLM_TIMEOUT", "120.0"))
LLM_MAX_RETRIES: int = int(os.getenv("LLM_MAX_RETRIES", "2"))
# JWT Token Configuration
SECRET_KEY: str = os.getenv("SECRET_KEY", "a_default_secret_key_that_is_long_and_random")
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
REFRESH_TOKEN_EXPIRE_DAYS: int = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7"))
# Single Sign-On configuration
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
# File Upload
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
# VOLC ASR settings
VOLC_APP_KEY: str = os.getenv("VOLC_APP_KEY", "")
VOLC_ACCESS_KEY: str = os.getenv("VOLC_ACCESS_KEY", "")
VOLC_SUBMIT_URL: str = os.getenv("VOLC_SUBMIT_URL", "https://openspeech.bytedance.com/api/v3/auc/bigmodel/submit")
VOLC_QUERY_URL: str = os.getenv("VOLC_QUERY_URL", "https://openspeech.bytedance.com/api/v3/auc/bigmodel/query")
# Langfuse configuration
LANGFUSE_PUBLIC_KEY: str = os.getenv("LANGFUSE_PUBLIC_KEY", "")
LANGFUSE_SECRET_KEY: str = os.getenv("LANGFUSE_SECRET_KEY", "")
LANGFUSE_HOST: str = os.getenv("LANGFUSE_HOST", "")
# Server Configuration
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
# ========================================================================
# Internal Configuration (not in .env, used by application code)
# ========================================================================
# Superuser settings (internal defaults)
FIRST_SUPERUSER_EMAIL: str = os.getenv("FIRST_SUPERUSER_EMAIL", "admin@example.com")
FIRST_SUPERUSER_USERNAME: str = os.getenv("FIRST_SUPERUSER_USERNAME", "admin")
FIRST_SUPERUSER_PASSWORD: str = os.getenv("FIRST_SUPERUSER_PASSWORD", "admin_password")
# Generic File Upload (internal)
GENERIC_FILE_PATH: str = os.getenv("GENERIC_FILE_PATH", "/uploads")
ENABLE_FILE_COMPRESSION: bool = os.getenv("ENABLE_FILE_COMPRESSION", "false").lower() == "true"
ENABLE_VIRUS_SCAN: bool = os.getenv("ENABLE_VIRUS_SCAN", "false").lower() == "true"
FILE_ACCESS_URL_PREFIX: str = os.getenv("FILE_ACCESS_URL_PREFIX", "http://localhost:8000/api/files")
# Frontend URL for workspace invitations (internal)
WEB_URL: str = os.getenv("WEB_URL", "http://localhost:3000")
# CORS configuration (internal)
CORS_ORIGINS: list[str] = [
origin.strip()
for origin in os.getenv("CORS_ORIGINS", "").split(",")
if origin.strip()
]
# Logging settings
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
LOG_FILE_PATH: str = os.getenv("LOG_FILE_PATH", "logs/app.log")
LOG_MAX_SIZE: int = int(os.getenv("LOG_MAX_SIZE", "10485760")) # 10MB
LOG_BACKUP_COUNT: int = int(os.getenv("LOG_BACKUP_COUNT", "5"))
LOG_TO_CONSOLE: bool = os.getenv("LOG_TO_CONSOLE", "true").lower() == "true"
LOG_TO_FILE: bool = os.getenv("LOG_TO_FILE", "true").lower() == "true"
# Sensitive Data Filtering
ENABLE_SENSITIVE_DATA_FILTER: bool = os.getenv("ENABLE_SENSITIVE_DATA_FILTER", "true").lower() == "true"
# Memory Module Logging
PROMPT_LOG_LEVEL: str = os.getenv("PROMPT_LOG_LEVEL", "INFO")
ENABLE_TEMPLATE_LOGGING: bool = os.getenv("ENABLE_TEMPLATE_LOGGING", "false").lower() == "true"
TIMING_LOG_FILE: str = os.getenv("TIMING_LOG_FILE", "logs/time.log")
TIMING_LOG_TO_CONSOLE: bool = os.getenv("TIMING_LOG_TO_CONSOLE", "true").lower() == "true"
AGENT_LOG_FILE: str = os.getenv("AGENT_LOG_FILE", "logs/agent_service.log")
AGENT_LOG_MAX_SIZE: int = int(os.getenv("AGENT_LOG_MAX_SIZE", "5242880")) # 5MB
AGENT_LOG_BACKUP_COUNT: int = int(os.getenv("AGENT_LOG_BACKUP_COUNT", "20"))
# Log Streaming Configuration
LOG_STREAM_KEEPALIVE_INTERVAL: int = int(os.getenv("LOG_STREAM_KEEPALIVE_INTERVAL", "300")) # 5 minutes
LOG_STREAM_MAX_CONNECTIONS: int = int(os.getenv("LOG_STREAM_MAX_CONNECTIONS", "10"))
LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192")) # 8KB
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
# Celery configuration (internal)
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
# Memory Module Configuration (internal)
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
MEMORY_CONFIG_FILE: str = os.getenv("MEMORY_CONFIG_FILE", "config.json")
MEMORY_RUNTIME_FILE: str = os.getenv("MEMORY_RUNTIME_FILE", "runtime.json")
MEMORY_DBRUN_FILE: str = os.getenv("MEMORY_DBRUN_FILE", "dbrun.json")
def get_memory_output_path(self, filename: str = "") -> str:
"""
Get the full path for memory module output files.
Args:
filename: Optional filename to append to the output directory
Returns:
Full path to the output file or directory
"""
base_path = Path(self.MEMORY_OUTPUT_DIR)
if filename:
return str(base_path / filename)
return str(base_path)
def get_memory_config_path(self, config_file: str = "") -> str:
"""
Get the full path for memory module configuration files.
Args:
config_file: Optional config filename (defaults to MEMORY_CONFIG_FILE)
Returns:
Full path to the config file
"""
if not config_file:
config_file = self.MEMORY_CONFIG_FILE
return str(Path(self.MEMORY_CONFIG_DIR) / config_file)
def load_memory_config(self) -> Dict[str, Any]:
"""
Load memory module configuration from config.json.
Returns:
Dictionary containing memory configuration
"""
config_path = self.get_memory_config_path(self.MEMORY_CONFIG_FILE)
try:
with open(config_path, "r", encoding="utf-8") as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError) as e:
print(f"Warning: Memory config file not found or malformed at {config_path}. Error: {e}")
return {}
def load_memory_runtime_config(self) -> Dict[str, Any]:
"""
Load memory module runtime configuration from runtime.json.
Returns:
Dictionary containing runtime configuration
"""
runtime_path = self.get_memory_config_path(self.MEMORY_RUNTIME_FILE)
try:
with open(runtime_path, "r", encoding="utf-8") as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError) as e:
print(f"Warning: Memory runtime config not found or malformed at {runtime_path}. Error: {e}")
return {"selections": {}}
def load_memory_dbrun_config(self) -> Dict[str, Any]:
"""
Load memory module database run configuration from dbrun.json.
Returns:
Dictionary containing dbrun configuration
"""
dbrun_path = self.get_memory_config_path(self.MEMORY_DBRUN_FILE)
try:
with open(dbrun_path, "r", encoding="utf-8") as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError) as e:
print(f"Warning: Memory dbrun config not found or malformed at {dbrun_path}. Error: {e}")
return {"selections": {}}
def ensure_memory_output_dir(self) -> None:
"""
Ensure the memory output directory exists.
Creates the directory if it doesn't exist.
"""
output_dir = Path(self.MEMORY_OUTPUT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)
settings = Settings()

130
api/app/core/error_codes.py Normal file
View File

@@ -0,0 +1,130 @@
from enum import IntEnum
class BizCode(IntEnum):
# 通用1xxx
OK = 0
BAD_REQUEST = 1000
VALIDATION_FAILED = 1001
MISSING_PARAMETER = 1002
INVALID_PARAMETER = 1003
# 认证/鉴权2xxx/3xxx
UNAUTHORIZED = 2001
TOKEN_INVALID = 2002
TOKEN_EXPIRED = 2003
TOKEN_BLACKLISTED = 2004
PASSWORD_ERROR = 2005
LOGIN_FAILED = 2006
FORBIDDEN = 3001
TENANT_NOT_FOUND = 3002
WORKSPACE_NO_ACCESS = 3003
WORKSPACE_INVITE_NOT_FOUND = 3004
# 资源4xxx
NOT_FOUND = 4000
USER_NOT_FOUND = 4001
WORKSPACE_NOT_FOUND = 4002
MODEL_NOT_FOUND = 4003
KNOWLEDGE_NOT_FOUND = 4004
DOCUMENT_NOT_FOUND = 4005
FILE_NOT_FOUND = 4006
APP_NOT_FOUND = 4007
RELEASE_NOT_FOUND = 4008
# 冲突/状态5xxx
DUPLICATE_NAME = 5001
RESOURCE_ALREADY_EXISTS = 5002
VERSION_ALREADY_EXISTS = 5003
STATE_CONFLICT = 5004
# 应用发布6xxx
PUBLISH_FAILED = 6001
NO_DRAFT_TO_PUBLISH = 6002
ROLLBACK_TARGET_NOT_FOUND = 6003
APP_TYPE_NOT_SUPPORTED = 6004
AGENT_CONFIG_MISSING = 6005
SHARE_DISABLED = 6006
INVALID_PASSWORD = 6007
PASSWORD_REQUIRED = 6008
EMBED_NOT_ALLOWED = 6009
PERMISSION_DENIED = 6010
INVALID_CONVERSATION = 6011
# 模型7xxx
MODEL_CONFIG_INVALID = 7001
API_KEY_MISSING = 7002
PROVIDER_NOT_SUPPORTED = 7003
LLM_ERROR = 7004
EMBEDDING_ERROR = 7005
# 文件/解析8xxx
FILE_READ_ERROR = 8001
PARSER_NOT_SUPPORTED = 8002
CHUNKING_FAILED = 8003
# RAG/知识9xxx
INDEX_BUILD_FAILED = 9001
EMBEDDING_FAILED = 9002
SEARCH_FAILED = 9003
# 系统100xx
INTERNAL_ERROR = 10001
DB_ERROR = 10002
SERVICE_UNAVAILABLE = 10003
RATE_LIMITED = 10004
# 建议的HTTP状态映射如需在异常处理器中使用
HTTP_MAPPING = {
BizCode.OK: 200,
BizCode.LOGIN_FAILED: 200,
BizCode.BAD_REQUEST: 400,
BizCode.VALIDATION_FAILED: 400,
BizCode.MISSING_PARAMETER: 400,
BizCode.INVALID_PARAMETER: 400,
BizCode.UNAUTHORIZED: 401,
BizCode.TOKEN_INVALID: 401,
BizCode.TOKEN_EXPIRED: 401,
BizCode.TOKEN_BLACKLISTED: 401,
BizCode.FORBIDDEN: 403,
BizCode.TENANT_NOT_FOUND: 404,
BizCode.WORKSPACE_NO_ACCESS: 403,
BizCode.NOT_FOUND: 404,
BizCode.USER_NOT_FOUND: 200,
BizCode.WORKSPACE_NOT_FOUND: 404,
BizCode.MODEL_NOT_FOUND: 404,
BizCode.KNOWLEDGE_NOT_FOUND: 404,
BizCode.DOCUMENT_NOT_FOUND: 404,
BizCode.FILE_NOT_FOUND: 404,
BizCode.APP_NOT_FOUND: 404,
BizCode.RELEASE_NOT_FOUND: 404,
BizCode.DUPLICATE_NAME: 409,
BizCode.RESOURCE_ALREADY_EXISTS: 409,
BizCode.VERSION_ALREADY_EXISTS: 409,
BizCode.STATE_CONFLICT: 409,
BizCode.PUBLISH_FAILED: 500,
BizCode.NO_DRAFT_TO_PUBLISH: 400,
BizCode.ROLLBACK_TARGET_NOT_FOUND: 404,
BizCode.APP_TYPE_NOT_SUPPORTED: 400,
BizCode.AGENT_CONFIG_MISSING: 400,
BizCode.SHARE_DISABLED: 403,
BizCode.INVALID_PASSWORD: 401,
BizCode.PASSWORD_REQUIRED: 401,
BizCode.EMBED_NOT_ALLOWED: 403,
BizCode.PERMISSION_DENIED: 403,
BizCode.INVALID_CONVERSATION: 400,
BizCode.MODEL_CONFIG_INVALID: 400,
BizCode.API_KEY_MISSING: 400,
BizCode.PROVIDER_NOT_SUPPORTED: 400,
BizCode.LLM_ERROR: 500,
BizCode.EMBEDDING_ERROR: 500,
BizCode.FILE_READ_ERROR: 500,
BizCode.PARSER_NOT_SUPPORTED: 400,
BizCode.CHUNKING_FAILED: 500,
BizCode.INDEX_BUILD_FAILED: 500,
BizCode.EMBEDDING_FAILED: 500,
BizCode.SEARCH_FAILED: 500,
BizCode.INTERNAL_ERROR: 500,
BizCode.DB_ERROR: 500,
BizCode.SERVICE_UNAVAILABLE: 503,
BizCode.RATE_LIMITED: 429,
}

View File

@@ -0,0 +1,86 @@
"""
业务异常定义
"""
from typing import Any, Dict, Optional
from app.core.error_codes import BizCode
class BusinessException(Exception):
"""业务逻辑异常基类"""
def __init__(
self,
message: str,
code: BizCode | int | None = None,
context: Optional[Dict[str, Any]] = None,
cause: Optional[Exception] = None
):
self.message = message
self.code = code if code is not None else BizCode.BAD_REQUEST
# Make a copy of context to avoid modifying the original dict
self.context = dict(context) if context else {}
self.cause = cause
super().__init__(self.message)
def __str__(self) -> str:
ctx = f", context={self.context}" if self.context else ""
code_name = self.code.name if isinstance(self.code, BizCode) else str(self.code)
return f"{code_name}: {self.message}{ctx}"
class ValidationException(BusinessException):
"""数据验证异常"""
def __init__(self, message: str, field: str = None, **kwargs):
context = {"field": field} if field else {}
if "context" in kwargs:
context.update(kwargs.pop("context"))
super().__init__(message, BizCode.VALIDATION_FAILED, context, **kwargs)
class AuthenticationException(BusinessException):
"""认证异常"""
def __init__(self, message: str = "认证失败", **kwargs):
super().__init__(message, BizCode.UNAUTHORIZED, **kwargs)
class AuthorizationException(BusinessException):
"""授权异常"""
def __init__(self, message: str = "权限不足", **kwargs):
super().__init__(message, BizCode.FORBIDDEN, **kwargs)
class ResourceNotFoundException(BusinessException):
"""资源未找到异常"""
def __init__(self, resource_type: str, resource_id: str = None, **kwargs):
message = f"{resource_type} 不存在"
context = {"resource_type": resource_type}
if resource_id:
context["resource_id"] = resource_id
if "context" in kwargs:
context.update(kwargs.pop("context"))
super().__init__(message, BizCode.FILE_NOT_FOUND, context, **kwargs)
class DuplicateResourceException(BusinessException):
"""资源重复异常"""
def __init__(self, message: str = "资源已存在", **kwargs):
super().__init__(message, BizCode.DUPLICATE_NAME, **kwargs)
class FileUploadException(BusinessException):
"""文件上传异常"""
def __init__(self, message: str, **kwargs):
super().__init__(message, BizCode.FILE_READ_ERROR, **kwargs)
class PermissionDeniedException(BusinessException):
"""权限拒绝异常"""
def __init__(self, message: str = "权限不足", **kwargs):
super().__init__(message, BizCode.FORBIDDEN, **kwargs)

View File

@@ -0,0 +1,633 @@
import logging
import logging.handlers
import os
from pathlib import Path
from typing import Optional
from app.core.config import settings
from app.core.sensitive_filter import SensitiveDataFilter
class SensitiveDataLoggingFilter(logging.Filter):
"""日志过滤器:自动过滤敏感信息"""
def filter(self, record: logging.LogRecord) -> bool:
"""
过滤日志记录中的敏感信息
Args:
record: 日志记录
Returns:
True表示允许记录False表示拒绝
"""
# 过滤消息中的敏感信息
if hasattr(record, 'msg') and isinstance(record.msg, str):
record.msg = SensitiveDataFilter.filter_string(record.msg)
# 过滤参数中的敏感信息
if hasattr(record, 'args') and record.args:
if isinstance(record.args, dict):
record.args = SensitiveDataFilter.filter_dict(record.args)
elif isinstance(record.args, (list, tuple)):
record.args = tuple(
SensitiveDataFilter.filter_string(str(arg)) if isinstance(arg, str) else arg
for arg in record.args
)
return True
class LoggingConfig:
"""全局日志配置类"""
_initialized = False
_memory_loggers_initialized = False
_prompt_logger = None
_template_logger = None
_timing_logger = None
_agent_loggers = {}
@classmethod
def setup_logging(cls) -> None:
"""初始化全局日志配置"""
if cls._initialized:
return
# 创建日志目录
log_dir = Path(settings.LOG_FILE_PATH).parent
log_dir.mkdir(parents=True, exist_ok=True)
# 配置根日志器
root_logger = logging.getLogger()
root_logger.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
# 清除现有处理器
root_logger.handlers.clear()
# 创建格式化器
formatter = logging.Formatter(
fmt=settings.LOG_FORMAT,
datefmt='%Y-%m-%d %H:%M:%S'
)
# 创建敏感信息过滤器
sensitive_filter = SensitiveDataLoggingFilter()
# 控制台处理器
if settings.LOG_TO_CONSOLE:
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
console_handler.addFilter(sensitive_filter)
root_logger.addHandler(console_handler)
# 文件处理器(带轮转)
if settings.LOG_TO_FILE:
file_handler = logging.handlers.RotatingFileHandler(
filename=settings.LOG_FILE_PATH,
maxBytes=settings.LOG_MAX_SIZE,
backupCount=5,
encoding='utf-8'
)
file_handler.setFormatter(formatter)
file_handler.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
file_handler.addFilter(sensitive_filter)
root_logger.addHandler(file_handler)
cls._initialized = True
# Initialize memory module logging
cls.setup_memory_logging()
# 记录初始化完成
logger = logging.getLogger(__name__)
logger.info("全局日志系统初始化完成")
@classmethod
def setup_memory_logging(cls) -> None:
"""Initialize memory module specific loggers.
Called automatically by setup_logging() or can be called independently.
Sets up:
- Prompt logger with timestamped files
- Template logger with conditional file output
- Timing logger with dual output (file + console)
- Agent logger factory with concurrent handlers
"""
if cls._memory_loggers_initialized:
return
# Create logs directory if it doesn't exist
log_dir = Path("logs")
try:
log_dir.mkdir(parents=True, exist_ok=True)
except OSError as e:
print(f"Warning: Could not create log directory: {e}")
# Continue with console-only logging
# Initialize memory-specific loggers
# These will be created lazily when first requested via factory functions
# This method just marks the system as ready for memory logging
cls._memory_loggers_initialized = True
def get_logger(name: Optional[str] = None) -> logging.Logger:
"""获取日志器实例
Args:
name: 日志器名称,默认为调用模块名
Returns:
配置好的日志器实例
"""
return logging.getLogger(name)
def get_auth_logger() -> logging.Logger:
"""获取认证专用日志器"""
return logging.getLogger("auth")
def get_security_logger() -> logging.Logger:
"""获取安全专用日志器"""
return logging.getLogger("security")
def get_api_logger() -> logging.Logger:
"""获取API专用日志器"""
return logging.getLogger("api")
def get_db_logger() -> logging.Logger:
"""获取数据库专用日志器"""
return logging.getLogger("database")
def get_business_logger() -> logging.Logger:
"""获取业务逻辑专用日志器"""
return logging.getLogger("business")
def get_prompt_logger() -> logging.Logger:
"""Get the prompt logger for memory module.
Returns a logger configured for prompt rendering output with:
- Logger name: memory.prompts
- Output: logs/prompt_logs-{timestamp}.log
- Level: Configurable via PROMPT_LOG_LEVEL setting (default: INFO)
- Handler: FileHandler (no console output)
The logger is cached after first creation for performance.
Returns:
Logger configured for prompt rendering output
Example:
>>> logger = get_prompt_logger()
>>> logger.info("=== RENDERED EXTRACTION PROMPT ===\\n%s", prompt_content)
"""
# Return cached logger if already initialized
if LoggingConfig._prompt_logger is not None:
return LoggingConfig._prompt_logger
# Ensure memory logging is initialized
if not LoggingConfig._memory_loggers_initialized:
LoggingConfig.setup_memory_logging()
# Create prompt logger
logger = logging.getLogger("memory.prompts")
logger.setLevel(getattr(logging, settings.PROMPT_LOG_LEVEL.upper()))
logger.propagate = False # Don't propagate to root logger (no console output)
# Create timestamped log file
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
log_file = Path("logs/prompts/") / f"prompt_logs-{timestamp}.log"
# Ensure log directory exists
log_file.parent.mkdir(parents=True, exist_ok=True)
# Create file handler
file_handler = logging.FileHandler(
filename=str(log_file),
encoding='utf-8'
)
# Create formatter
formatter = logging.Formatter(
fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
file_handler.setFormatter(formatter)
# Add handler to logger
logger.addHandler(file_handler)
# Cache the logger
LoggingConfig._prompt_logger = logger
return logger
def get_template_logger() -> logging.Logger:
"""Get the template logger for memory module.
Returns a logger configured for template rendering information with:
- Logger name: memory.templates
- Output: logs/prompt_templates.log (only when ENABLE_TEMPLATE_LOGGING is True)
- Level: INFO
- Handler: FileHandler when enabled, NullHandler when disabled
The logger is cached after first creation for performance.
Returns:
Logger configured for template rendering info
Example:
>>> logger = get_template_logger()
>>> logger.info("Rendering template: %s with context keys: %s",
... template_name, list(context.keys()))
"""
# Return cached logger if already initialized
if LoggingConfig._template_logger is not None:
return LoggingConfig._template_logger
# Ensure memory logging is initialized
if not LoggingConfig._memory_loggers_initialized:
LoggingConfig.setup_memory_logging()
# Create template logger
logger = logging.getLogger("memory.templates")
logger.setLevel(logging.INFO)
logger.propagate = False # Don't propagate to root logger
# Add appropriate handler based on configuration
if settings.ENABLE_TEMPLATE_LOGGING:
# Create log file path
log_file = Path("logs") / "prompt_templates.log"
# Ensure log directory exists
log_file.parent.mkdir(parents=True, exist_ok=True)
# Create file handler
file_handler = logging.FileHandler(
filename=str(log_file),
encoding='utf-8'
)
# Create formatter
formatter = logging.Formatter(
fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
file_handler.setFormatter(formatter)
# Add handler to logger
logger.addHandler(file_handler)
else:
# Use NullHandler when template logging is disabled
null_handler = logging.NullHandler()
logger.addHandler(null_handler)
# Cache the logger
LoggingConfig._template_logger = logger
return logger
def log_prompt_rendering(prompt_type: str, content: str) -> None:
"""Log rendered prompt content.
Logs the rendered prompt with a formatted header and separator for easy
identification in log files. This is useful for debugging LLM interactions
and understanding what prompts are being sent.
Args:
prompt_type: Type of prompt (e.g., 'statement_extraction', 'triplet_extraction')
content: The rendered prompt text
Example:
>>> log_prompt_rendering("extraction", "Extract entities from: Hello world")
# Logs:
# === RENDERED EXTRACTION PROMPT ===
# Extract entities from: Hello world
# =====================================
"""
logger = get_prompt_logger()
# Format the log entry with header and separator
separator = "=" * 50
header = f"=== RENDERED {prompt_type.upper()} PROMPT ==="
log_message = f"\n{header}\n{content}\n{separator}\n"
logger.info(log_message)
def log_template_rendering(template_name: str, context: dict | None = None) -> None:
"""Log template rendering information.
Logs the template name and context keys for debugging template rendering.
This function is wrapped in try-except to ensure it never breaks application
flow, even if logging fails.
Args:
template_name: Name of the Jinja2 template being rendered
context: Optional context dictionary with template variables
Example:
>>> log_template_rendering("extract_triplet.jinja2", {"text": "...", "ontology": "..."})
# Logs: Rendering template: extract_triplet.jinja2 with context keys: ['text', 'ontology']
>>> log_template_rendering("system.jinja2")
# Logs: Rendering template: system.jinja2 with no context
"""
try:
logger = get_template_logger()
if context is not None:
context_keys = list(context.keys())
logger.info(f"Rendering template: {template_name} with context keys: {context_keys}")
else:
logger.info(f"Rendering template: {template_name} with no context")
except Exception:
# Never break application flow due to logging issues
# Silently ignore any logging errors
pass
def get_timing_logger() -> logging.Logger:
"""Get the timing logger for memory module.
Returns a logger configured for performance timing with:
- Logger name: memory.timing
- Output: Configurable via TIMING_LOG_FILE setting (default: logs/time.log)
- Level: INFO
- Handlers: FileHandler + optional StreamHandler for console output
- Console output: Controlled by TIMING_LOG_TO_CONSOLE setting (default: True)
The logger is cached after first creation for performance.
Returns:
Logger configured for performance timing
Example:
>>> logger = get_timing_logger()
>>> logger.info("[2025-11-18 10:30:45] Extraction: 2.34 seconds")
"""
# Return cached logger if already initialized
if LoggingConfig._timing_logger is not None:
return LoggingConfig._timing_logger
# Ensure memory logging is initialized
if not LoggingConfig._memory_loggers_initialized:
LoggingConfig.setup_memory_logging()
# Create timing logger
logger = logging.getLogger("memory.timing")
logger.setLevel(logging.INFO)
logger.propagate = False # Don't propagate to root logger
# Create formatter
formatter = logging.Formatter(
fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
# Add file handler
log_file = Path(settings.TIMING_LOG_FILE)
# Ensure log directory exists
log_file.parent.mkdir(parents=True, exist_ok=True)
file_handler = logging.FileHandler(
filename=str(log_file),
encoding='utf-8'
)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# Add console handler if enabled
if settings.TIMING_LOG_TO_CONSOLE:
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# Cache the logger
LoggingConfig._timing_logger = logger
return logger
def log_time(step_name: str, duration: float, log_file: str = "logs/time.log") -> None:
"""Log timing information for performance tracking.
Logs timing information to both file and console (console output is always shown
for backward compatibility). The file output includes a timestamp and full details,
while console output shows a concise checkmark format.
Args:
step_name: Name of the operation being timed
duration: Duration in seconds
log_file: Optional custom log file path (default: logs/time.log)
Example:
>>> log_time("Knowledge Extraction", 2.34)
# File logs: [2025-11-18 10:30:45] Knowledge Extraction: 2.34 seconds
# Console prints: ✓ Knowledge Extraction: 2.34s
>>> log_time("Database Query", 0.15, "logs/custom_time.log")
# Logs to custom file and console
"""
from datetime import datetime
# Format timestamp
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# Format timing entry for file
log_entry = f"[{timestamp}] {step_name}: {duration:.2f} seconds\n"
# Write to file with error handling
try:
log_path = Path(log_file)
log_path.parent.mkdir(parents=True, exist_ok=True)
with open(log_path, "a", encoding="utf-8") as f:
f.write(log_entry)
except IOError as e:
# Fallback to console only if file write fails
print(f"Warning: Could not write to timing log: {e}")
# Always print to console (backward compatible behavior)
print(f"{step_name}: {duration:.2f}s")
def get_agent_logger(name: str = "agent_service",
console_level: str = "INFO",
file_level: str = "DEBUG") -> logging.Logger:
"""Get an agent logger with concurrent file handling.
Returns a logger configured for agent operations with:
- Logger name: memory.agent.{name}
- Output: Configurable via AGENT_LOG_FILE setting (default: logs/agent_service.log)
- Console level: Configurable (default: INFO)
- File level: Configurable (default: DEBUG)
- Handler: ConcurrentRotatingFileHandler for multi-process support
- Rotation: Configurable via AGENT_LOG_MAX_SIZE (default: 5MB) and
AGENT_LOG_BACKUP_COUNT (default: 20)
The logger is cached by name after first creation for performance.
Supports concurrent writes from multiple processes.
Args:
name: Logger name for namespacing (default: "agent_service")
console_level: Log level for console output (default: "INFO")
file_level: Log level for file output (default: "DEBUG")
Returns:
Logger configured for agent operations
Example:
>>> logger = get_agent_logger("my_agent")
>>> logger.info("Agent operation started")
>>> logger.debug("Detailed agent state information")
>>> logger = get_agent_logger("custom_agent", console_level="WARNING", file_level="INFO")
>>> logger.warning("This appears in console and file")
>>> logger.info("This only appears in file")
"""
# Return cached logger if already initialized
if name in LoggingConfig._agent_loggers:
return LoggingConfig._agent_loggers[name]
# Ensure memory logging is initialized
if not LoggingConfig._memory_loggers_initialized:
LoggingConfig.setup_memory_logging()
# Create agent logger with namespaced name
logger_name = f"memory.agent.{name}"
logger = logging.getLogger(logger_name)
logger.setLevel(logging.DEBUG) # Set to DEBUG to allow both handlers to filter
logger.propagate = False # Don't propagate to root logger
# Create formatter
formatter = logging.Formatter(
fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
# Add console handler
console_handler = logging.StreamHandler()
console_handler.setLevel(getattr(logging, console_level.upper()))
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# Add concurrent rotating file handler
try:
from concurrent_log_handler import ConcurrentRotatingFileHandler
except ImportError:
# Fall back to standard RotatingFileHandler if concurrent handler not available
from logging.handlers import RotatingFileHandler as ConcurrentRotatingFileHandler
print("Warning: concurrent-log-handler not available, using standard RotatingFileHandler")
# Create log file path
log_file = Path(settings.AGENT_LOG_FILE)
# Ensure log directory exists
log_file.parent.mkdir(parents=True, exist_ok=True)
# Create file handler with rotation
file_handler = ConcurrentRotatingFileHandler(
filename=str(log_file),
maxBytes=settings.AGENT_LOG_MAX_SIZE,
backupCount=settings.AGENT_LOG_BACKUP_COUNT,
encoding='utf-8'
)
file_handler.setLevel(getattr(logging, file_level.upper()))
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# Cache the logger
LoggingConfig._agent_loggers[name] = logger
return logger
def get_named_logger(name: str) -> logging.Logger:
"""Backward compatible alias for get_agent_logger.
This function maintains backward compatibility with existing code that uses
the get_named_logger pattern from the agent logger module.
Args:
name: Logger name for namespacing
Returns:
Logger configured for agent operations
Example:
>>> logger = get_named_logger("my_agent")
>>> logger.info("Agent operation started")
"""
return get_agent_logger(name)
def get_memory_logger(name: Optional[str] = None) -> logging.Logger:
"""Get a standard logger for memory module components.
Returns a logger configured for memory module components that inherits
the root logger's configuration (handlers, formatters, and level). This
provides consistent logging behavior across the memory module while
maintaining the ability to filter and identify memory-specific logs.
The logger uses the 'memory' namespace:
- If name is provided: logger name is 'memory.{module_name}'
- If name is None: logger name is 'memory'
The logger inherits all handlers and formatters from the root logger,
ensuring consistent output format and destinations (console, file, etc.).
Args:
name: Optional logger name, typically __name__ from the calling module.
If provided, creates a namespaced logger under 'memory.{name}'.
If None, returns the base 'memory' logger.
Returns:
Logger configured for memory module operations with root logger inheritance
Example:
>>> # In app/core/memory/src/search.py
>>> logger = get_memory_logger(__name__)
>>> logger.info("Starting search operation")
# Logs: [timestamp] - memory.app.core.memory.src.search - INFO - Starting search operation
>>> # Get base memory logger
>>> logger = get_memory_logger()
>>> logger.debug("Memory module initialized")
# Logs: [timestamp] - memory - DEBUG - Memory module initialized
>>> # In app/core/memory/src/knowledge_extraction/triplet_extraction.py
>>> logger = get_memory_logger(__name__)
>>> logger.error("Extraction failed", exc_info=True)
# Logs error with full traceback
"""
# Ensure memory logging is initialized
if not LoggingConfig._memory_loggers_initialized:
LoggingConfig.setup_memory_logging()
# Construct logger name with memory namespace
if name is not None:
logger_name = f"memory.{name}"
else:
logger_name = "memory"
# Get logger - it will inherit from root logger configuration
logger = logging.getLogger(logger_name)
# The logger automatically inherits handlers, formatters, and level from root logger
# through Python's logging hierarchy, so no additional configuration is needed
return logger

View File

View File

View File

@@ -0,0 +1,16 @@
"""
LangGraph Graph package for memory agent.
This package provides the LangGraph workflow orchestrator with modular
node implementations, routing logic, and state management.
Package structure:
- read_graph: Main graph factory for read operations
- write_graph: Main graph factory for write operations
- nodes: LangGraph node implementations
- routing: State routing logic
- state: State management utilities
"""
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
__all__ = ['make_read_graph']

View File

@@ -0,0 +1,10 @@
"""
LangGraph node implementations.
This module contains custom node implementations for the LangGraph workflow.
"""
from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode
from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message
__all__ = ["ToolExecutionNode", "create_input_message"]

View File

@@ -0,0 +1,144 @@
"""
Input node for LangGraph workflow entry point.
This module provides the create_input_message function which processes initial
user input with multimodal support and creates the first tool call message.
"""
import logging
import re
import uuid
from datetime import datetime
from typing import Dict, Any
from langchain_core.messages import AIMessage
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
logger = logging.getLogger(__name__)
async def create_input_message(
state: Dict[str, Any],
tool_name: str,
session_id: str,
search_switch: str,
apply_id: str,
group_id: str,
multimodal_processor: MultimodalProcessor
) -> Dict[str, Any]:
"""
Create initial tool call message from user input.
This function:
1. Extracts the last message content from state
2. Processes multimodal inputs (images/audio) using the multimodal processor
3. Generates a unique message ID
4. Extracts namespace from session_id
5. Handles verified_data extraction for backward compatibility
6. Returns AIMessage with complete tool_calls structure
Args:
state: LangGraph state dictionary containing messages
tool_name: Name of the tool to invoke (typically "Split_The_Problem")
session_id: Session identifier (format: "call_id_{namespace}")
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
multimodal_processor: Processor for handling image/audio inputs
Returns:
State update with AIMessage containing tool_call
Examples:
>>> state = {"messages": [HumanMessage(content="What is AI?")]}
>>> result = await create_input_message(
... state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor
... )
>>> result["messages"][0].tool_calls[0]["name"]
'Split_The_Problem'
"""
messages = state.get("messages", [])
# Extract last message content
if messages:
last_message = messages[-1].content if hasattr(messages[-1], 'content') else str(messages[-1])
else:
logger.warning("[create_input_message] No messages in state, using empty string")
last_message = ""
logger.debug(f"[create_input_message] Original input: {last_message[:100]}...")
# Process multimodal input (images/audio)
try:
processed_content = await multimodal_processor.process_input(last_message)
if processed_content != last_message:
logger.info(
f"[create_input_message] Multimodal processing converted input "
f"from {len(last_message)} to {len(processed_content)} chars"
)
last_message = processed_content
except Exception as e:
logger.error(
f"[create_input_message] Multimodal processing failed: {e}",
exc_info=True
)
# Continue with original content
# Generate unique message ID
uuid_str = uuid.uuid4()
time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# Extract namespace from session_id
# Expected format: "call_id_{namespace}" or similar
try:
namespace = str(session_id).split('_id_')[1]
except (IndexError, AttributeError):
logger.warning(
f"[create_input_message] Could not extract namespace from session_id: {session_id}"
)
namespace = "unknown"
# Handle verified_data extraction (backward compatibility)
# This regex-based extraction is kept for compatibility with existing data formats
if 'verified_data' in str(last_message):
try:
messages_last = str(last_message).replace('\\n', '').replace('\\', '')
query_match = re.findall(r'"query": "(.*?)",', messages_last)
if query_match:
last_message = query_match[0]
logger.debug(
f"[create_input_message] Extracted query from verified_data: {last_message}"
)
except Exception as e:
logger.warning(
f"[create_input_message] Failed to extract query from verified_data: {e}"
)
# Construct tool call message
tool_call_id = f"{session_id}_{uuid_str}"
logger.info(
f"[create_input_message] Creating tool call for '{tool_name}' "
f"with ID: {tool_call_id}"
)
return {
"messages": [
AIMessage(
content="",
tool_calls=[{
"name": tool_name,
"args": {
"sentence": last_message,
"sessionid": session_id,
"messages_id": str(uuid_str),
"search_switch": search_switch,
"apply_id": apply_id,
"group_id": group_id
},
"id": tool_call_id
}]
)
]
}

View File

@@ -0,0 +1,199 @@
"""
Tool execution node for LangGraph workflow.
This module provides the ToolExecutionNode class which wraps tool execution
with parameter transformation logic using the ParameterBuilder service.
"""
import logging
import time
from typing import Any, Callable, Dict
from langchain_core.messages import AIMessage
from langgraph.prebuilt import ToolNode
from app.core.memory.agent.langgraph_graph.state.extractors import (
extract_tool_call_id,
extract_content_payload
)
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
logger = logging.getLogger(__name__)
class ToolExecutionNode:
"""
Custom LangGraph node that wraps tool execution with parameter transformation.
This node extracts content from previous tool results, transforms parameters
based on tool type using ParameterBuilder, and invokes the tool with the
correct argument structure.
Attributes:
tool_node: LangGraph ToolNode wrapping the actual tool
id: Node identifier for message IDs
tool_name: Name of the tool being executed
namespace: Namespace for session management
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
parameter_builder: Service for building tool-specific arguments
"""
def __init__(
self,
tool: Callable,
node_id: str,
namespace: str,
search_switch: str,
apply_id: str,
group_id: str,
parameter_builder: ParameterBuilder,
storage_type:str,
user_rag_memory_id:str
):
"""
Initialize the tool execution node.
Args:
tool: The tool function to execute
node_id: Identifier for this node (used in message IDs)
namespace: Namespace for session management
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
parameter_builder: Service for building tool-specific arguments
"""
self.tool_node = ToolNode([tool])
self.id = node_id
self.tool_name = tool.name if hasattr(tool, 'name') else str(tool)
self.namespace = namespace
self.search_switch = search_switch
self.apply_id = apply_id
self.group_id = group_id
self.parameter_builder = parameter_builder
self.storage_type=storage_type
self.user_rag_memory_id=user_rag_memory_id
logger.info(
f"[ToolExecutionNode] Initialized node '{self.id}' for tool '{self.tool_name}'"
)
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
"""
Execute the tool with transformed parameters.
This method:
1. Extracts the last message from state
2. Extracts tool call ID using state extractors
3. Extracts content payload using state extractors
4. Builds tool arguments using parameter builder
5. Constructs AIMessage with tool_calls
6. Invokes the tool and returns the result
Args:
state: LangGraph state dictionary
Returns:
Updated state with tool result in messages
"""
messages = state.get("messages", [])
logger.debug( self.tool_name)
if not messages:
logger.warning(f"[ToolExecutionNode] {self.id} - No messages in state")
return {"messages": [AIMessage(content="Error: No messages in state")]}
last_message = messages[-1]
logger.debug(
f"[ToolExecutionNode] {self.id} - Processing message at {time.time()}"
)
try:
# Extract tool call ID using state extractors
tool_call_id = extract_tool_call_id(last_message)
logger.debug(f"[ToolExecutionNode] {self.id} - Extracted tool_call_id: {tool_call_id}")
except ValueError as e:
logger.error(
f"[ToolExecutionNode] {self.id} - Failed to extract tool call ID: {e}"
)
return {"messages": [AIMessage(content=f"Error: {str(e)}")]}
try:
# Extract content payload using state extractors
content = extract_content_payload(last_message)
logger.debug(
f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}"
)
except Exception as e:
logger.error(
f"[ToolExecutionNode] {self.id} - Failed to extract content: {e}",
exc_info=True
)
content = {}
try:
# Build tool arguments using parameter builder
tool_args = self.parameter_builder.build_tool_args(
tool_name=self.tool_name,
content=content,
tool_call_id=tool_call_id,
search_switch=self.search_switch,
apply_id=self.apply_id,
group_id=self.group_id,
storage_type=self.storage_type,
user_rag_memory_id=self.user_rag_memory_id
)
logger.debug(
f"[ToolExecutionNode] {self.id} - Built tool args with keys: {list(tool_args.keys())}"
)
except Exception as e:
logger.error(
f"[ToolExecutionNode] {self.id} - Failed to build tool args: {e}",
exc_info=True
)
return {"messages": [AIMessage(content=f"Error building arguments: {str(e)}")]}
# Construct tool input message
tool_input = {
"messages": [
AIMessage(
content="",
tool_calls=[{
"name": self.tool_name,
"args": tool_args,
"id": f"{self.id}_{tool_call_id}",
}]
)
]
}
try:
# Invoke the tool
result = await self.tool_node.ainvoke(tool_input)
logger.debug(
f"[ToolExecutionNode] {self.id} - Tool execution completed"
)
# Return the result directly - it already contains the messages list
return result
except Exception as e:
logger.error(
f"[ToolExecutionNode] {self.id} - Tool execution failed: {e}",
exc_info=True
)
# Return error as ToolMessage to maintain message chain consistency
from langchain_core.messages import ToolMessage
return {
"messages": [
ToolMessage(
content=f"Error executing tool: {str(e)}",
tool_call_id=f"{self.id}_{tool_call_id}"
)
]
}

View File

@@ -0,0 +1,508 @@
import asyncio
import io
import json
import logging
import os
import re
import time
import uuid
import warnings
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Literal
from dotenv import load_dotenv
from langchain_core.messages import AIMessage
from langgraph.constants import START, END
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
from functools import partial
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
from langgraph.checkpoint.memory import InMemorySaver
from app.core.memory.agent.utils.redis_tool import store
from app.core.logging_config import get_agent_logger
# Import new modular components
from app.core.memory.agent.langgraph_graph.nodes import ToolExecutionNode, create_input_message
from app.core.memory.agent.langgraph_graph.routing.routers import (
Verify_continue,
Retrieve_continue,
Split_continue
)
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
logger = get_agent_logger(__name__)
warnings.filterwarnings("ignore", category=RuntimeWarning)
load_dotenv()
redishost=os.getenv("REDISHOST")
redisport=os.getenv('REDISPORT')
redisdb=os.getenv('REDISDB')
redispassword=os.getenv('REDISPASSWORD')
counter = COUNTState(limit=3)
# 在工作流中添加循环计数更新
async def update_loop_count(state):
"""更新循环计数器"""
current_count = state.get("loop_count", 0)
return {"loop_count": current_count + 1}
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
messages = state["messages"]
# 添加边界检查
if not messages:
return END
counter.add(1) # 累加 1
loop_count = counter.get_total()
logger.debug(f"[should_continue] 当前循环次数: {loop_count}")
last_message = messages[-1]
last_message_str = str(last_message).replace('\\', '')
status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str)
logger.debug(f"Status tools: {status_tools}")
if "success" in status_tools:
counter.reset()
return "Summary"
elif "failed" in status_tools:
if loop_count < 2: # 最大循环次数 3
return "content_input"
else:
counter.reset()
return "Summary_fails"
else:
# 添加默认返回值,避免返回 None
counter.reset()
return "Summary" # 或根据业务需求选择合适的默认值
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
"""
Determine routing based on search_switch value.
Args:
state: State dictionary containing search_switch
Returns:
Next node to execute
"""
# Direct dictionary access instead of regex parsing
search_switch = state.get("search_switch")
# Handle case where search_switch might be in messages
if search_switch is None and "messages" in state:
messages = state.get("messages", [])
if messages:
last_message = messages[-1]
# Try to extract from tool_calls args
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
for tool_call in last_message.tool_calls:
if isinstance(tool_call, dict) and "args" in tool_call:
search_switch = tool_call["args"].get("search_switch")
break
# Convert to string for comparison if needed
if search_switch is not None:
search_switch = str(search_switch)
if search_switch == '0':
return 'Verify'
elif search_switch == '1':
return 'Retrieve_Summary'
# 添加默认返回值,避免返回 None
return 'Retrieve_Summary' # 或根据业务逻辑选择合适的默认值
def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]:
"""
Determine routing based on search_switch value.
Args:
state: State dictionary containing search_switch
Returns:
Next node to execute
"""
logger.debug(f"Split_continue state: {state}")
# Direct dictionary access instead of regex parsing
search_switch = state.get("search_switch")
# Handle case where search_switch might be in messages
if search_switch is None and "messages" in state:
messages = state.get("messages", [])
if messages:
last_message = messages[-1]
# Try to extract from tool_calls args
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
for tool_call in last_message.tool_calls:
if isinstance(tool_call, dict) and "args" in tool_call:
search_switch = tool_call["args"].get("search_switch")
break
# Convert to string for comparison if needed
if search_switch is not None:
search_switch = str(search_switch)
if search_switch == '2':
return 'Input_Summary'
return 'Split_The_Problem' # 默认情况
# 在 input_sentence 函数中修改参数名称
async def input_sentence(state, name, id, search_switch,apply_id,group_id):
messages = state["messages"]
last_message = messages[-1].content if messages else ""
if last_message.endswith('.jpg') or last_message.endswith('.png'):
last_message=await picture_model_requests(last_message)
if any(last_message.endswith(ext) for ext in audio_extensions):
last_message=await Vico_recognition([last_message]).run()
logger.debug(f"Audio recognition result: {last_message}")
uuid_str = uuid.uuid4()
time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
namespace = str(id).split('_id_')[1]
if 'verified_data' in str(last_message):
messages_last = str(last_message).replace('\\n', '').replace('\\', '')
last_message = re.findall(r'"query": "(.*?)",', str(messages_last))[0]
return {
"messages": [
AIMessage(
content="",
tool_calls=[{
"name": name,
"args": {
"sentence": last_message,
'sessionid': id,
'messages_id': str(uuid_str),
"search_switch": search_switch, # 正确地将 search_switch 放入 args 中
"apply_id":apply_id,
"group_id":group_id
},
"id": id + f'_{uuid_str}'
}]
)
]
}
class ProblemExtensionNode:
def __init__(self, tool, id, namespace, search_switch, apply_id, group_id, storage_type="", user_rag_memory_id=""):
self.tool_node = ToolNode([tool])
self.id = id
self.tool_name = tool.name if hasattr(tool, 'name') else str(tool)
self.namespace = namespace
self.search_switch = search_switch
self.apply_id = apply_id
self.group_id = group_id
self.storage_type = storage_type
self.user_rag_memory_id = user_rag_memory_id
async def __call__(self, state):
messages = state["messages"]
last_message = messages[-1] if messages else ""
logger.debug(f"ProblemExtensionNode {self.id} - 当前时间: {time.time()} - Message: {last_message}")
if self.tool_name=='Input_Summary':
tool_call =re.findall(f"'id': '(.*?)'",str(last_message))[0]
else:tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1]
# try:
# content = json.loads(last_message.content) if hasattr(last_message, 'content') else last_message
# except:
# content = last_message.content if hasattr(last_message, 'content') else str(last_message)
# 尝试从上一工具的结果中提取实际的内容载荷(而不是整个对象的字符串表示)
raw_msg = last_message.content if hasattr(last_message, 'content') else str(last_message)
extracted_payload = None
# 捕获 ToolMessage 的 content 字段(支持单/双引号),并避免贪婪匹配
m = re.search(r"content=(?:\"|\')(.*?)(?:\"|\'),\s*name=", raw_msg, flags=re.S)
if m:
extracted_payload = m.group(1)
else:
# 回退:直接尝试使用原始字符串
extracted_payload = raw_msg
# 优先尝试将内容解析为 JSON
try:
content = json.loads(extracted_payload)
except Exception:
# 尝试从文本中提取 JSON 片段再解析
parsed = None
candidates = re.findall(r"[\[{].*[\]}]", extracted_payload, flags=re.S)
for cand in candidates:
try:
parsed = json.loads(cand)
break
except Exception:
continue
# 如果仍然失败,则以原始字符串作为内容
content = parsed if parsed is not None else extracted_payload
# 根据工具名称构建正确的参数
tool_args = {}
if self.tool_name == "Verify":
# Verify工具需要context和usermessages参数
if isinstance(content, dict):
tool_args["context"] = content
else:
tool_args["context"] = {"content": content}
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == "Retrieve":
# Retrieve工具需要context和usermessages参数
if isinstance(content, dict):
tool_args["context"] = content
else:
tool_args["context"] = {"content": content}
tool_args["usermessages"] = str(tool_call)
tool_args["search_switch"] = str(self.search_switch)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == "Summary":
# Summary工具需要字符串类型的context参数
if isinstance(content, dict):
# 将字典转换为JSON字符串
tool_args["context"] = json.dumps(content, ensure_ascii=False)
else:
tool_args["context"] = str(content)
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == "Summary_fails":
# Summary工具需要字符串类型的context参数
if isinstance(content, dict):
# 将字典转换为JSON字符串
tool_args["context"] = json.dumps(content, ensure_ascii=False)
else:
tool_args["context"] = str(content)
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name=='Input_Summary':
tool_args["context"] =str(last_message)
tool_args["usermessages"] = str(tool_call)
tool_args["search_switch"] = str(self.search_switch)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
tool_args["storage_type"] = getattr(self, 'storage_type', "")
tool_args["user_rag_memory_id"] = getattr(self, 'user_rag_memory_id', "")
elif self.tool_name=='Retrieve_Summary' :
# Retrieve_Summary expects dict directly, not JSON string
# content might be a JSON string, try to parse it
if isinstance(content, str):
try:
parsed_content = json.loads(content)
# Check if it has a "context" key
if isinstance(parsed_content, dict) and "context" in parsed_content:
tool_args["context"] = parsed_content["context"]
else:
tool_args["context"] = parsed_content
except json.JSONDecodeError:
# If parsing fails, wrap the string
tool_args["context"] = {"content": content}
elif isinstance(content, dict):
# Check if content has a "context" key that needs unwrapping
if "context" in content:
tool_args["context"] = content["context"]
else:
tool_args["context"] = content
else:
tool_args["context"] = {"content": str(content)}
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
else:
# 其他工具使用context参数
if isinstance(content, dict):
tool_args["context"] = content
else:
tool_args["context"] = {"content": content}
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
tool_input = {
"messages": [
AIMessage(
content="",
tool_calls=[{
"name": self.tool_name,
"args": tool_args,
"id": self.id + f"{tool_call}",
}]
)
]
}
result = await self.tool_node.ainvoke(tool_input)
result_text = str(result)
return {"messages": [AIMessage(content=result_text)]}
@asynccontextmanager
async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config_id=None,storage_type=None,user_rag_memory_id=None):
memory = InMemorySaver()
tool=[i.name for i in tools ]
logger.info(f"Initializing read graph with tools: {tool}")
if config_id:
logger.info(f"使用配置 ID: {config_id}")
# Extract tool functions
Split_The_Problem_ = next((t for t in tools if t.name == "Split_The_Problem"), None)
Problem_Extension_ = next((t for t in tools if t.name == "Problem_Extension"), None)
Retrieve_ = next((t for t in tools if t.name == "Retrieve"), None)
Verify_ = next((t for t in tools if t.name == "Verify"), None)
Summary_ = next((t for t in tools if t.name == "Summary"), None)
Summary_fails_ = next((t for t in tools if t.name == "Summary_fails"), None)
Retrieve_Summary_ = next((t for t in tools if t.name == "Retrieve_Summary"), None)
Input_Summary_ = next((t for t in tools if t.name == "Input_Summary"), None)
# Instantiate services
parameter_builder = ParameterBuilder()
multimodal_processor = MultimodalProcessor()
# Create nodes using new modular components
Split_The_Problem_node = ToolNode([Split_The_Problem_])
Problem_Extension_node = ToolExecutionNode(
tool=Problem_Extension_,
node_id="Problem_Extension_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
Retrieve_node = ToolExecutionNode(
tool=Retrieve_,
node_id="Retrieve_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
Verify_node = ToolExecutionNode(
tool=Verify_,
node_id="Verify_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
Summary_node = ToolExecutionNode(
tool=Summary_,
node_id="Summary_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
Summary_fails_node = ToolExecutionNode(
tool=Summary_fails_,
node_id="Summary_fails_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
Retrieve_Summary_node = ToolExecutionNode(
tool=Retrieve_Summary_,
node_id="Retrieve_Summary_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
Input_Summary_node = ToolExecutionNode(
tool=Input_Summary_,
node_id="Input_Summary_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
async def content_input_node(state):
state_search_switch = state.get("search_switch", search_switch)
tool_name = "Input_Summary" if state_search_switch == '2' else "Split_The_Problem"
session_prefix = "input_summary_call_id" if state_search_switch == '2' else "split_call_id"
return await create_input_message(
state=state,
tool_name=tool_name,
session_id=f"{session_prefix}_{namespace}",
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
multimodal_processor=multimodal_processor
)
# Build workflow graph
workflow = StateGraph(ReadState)
workflow.add_node("content_input", content_input_node)
workflow.add_node("Split_The_Problem", Split_The_Problem_node)
workflow.add_node("Problem_Extension", Problem_Extension_node)
workflow.add_node("Retrieve", Retrieve_node)
workflow.add_node("Verify", Verify_node)
workflow.add_node("Summary", Summary_node)
workflow.add_node("Summary_fails", Summary_fails_node)
workflow.add_node("Retrieve_Summary", Retrieve_Summary_node)
workflow.add_node("Input_Summary", Input_Summary_node)
# Add edges using imported routers
workflow.add_edge(START, "content_input")
workflow.add_conditional_edges("content_input", Split_continue)
workflow.add_edge("Input_Summary", END)
workflow.add_edge("Split_The_Problem", "Problem_Extension")
workflow.add_edge("Problem_Extension", "Retrieve")
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
workflow.add_edge("Retrieve_Summary", END)
workflow.add_conditional_edges("Verify", Verify_continue)
workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END)
graph = workflow.compile(checkpointer=memory)
yield graph
# 添加到文件末尾或创建新的执行脚本
# 在 memory_agent_service.py 文件中添加以下函数

View File

@@ -0,0 +1,13 @@
"""LangGraph routing logic."""
from app.core.memory.agent.langgraph_graph.routing.routers import (
Verify_continue,
Retrieve_continue,
Split_continue,
)
__all__ = [
"Verify_continue",
"Retrieve_continue",
"Split_continue",
]

View File

@@ -0,0 +1,123 @@
"""
Routing functions for LangGraph conditional edges.
This module provides routing functions that determine the next node to execute
based on state values. All functions return Literal types for type safety.
"""
import logging
import re
from typing import Literal
from app.core.memory.agent.langgraph_graph.state.extractors import extract_search_switch
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
logger = logging.getLogger(__name__)
# Global counter for Verify routing
counter = COUNTState(limit=3)
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
"""
Determine routing after Verify node based on verification result.
This function checks the verification result in the last message and routes to:
- Summary: if verification succeeded
- content_input: if verification failed and retry limit not reached
- Summary_fails: if verification failed and retry limit reached
Args:
state: LangGraph state containing messages
Returns:
Next node name as Literal type
"""
messages = state.get("messages", [])
# Boundary check
if not messages:
logger.warning("[Verify_continue] No messages in state, defaulting to Summary")
counter.reset()
return "Summary"
# Increment counter
counter.add(1)
loop_count = counter.get_total()
logger.debug(f"[Verify_continue] Current loop count: {loop_count}")
# Extract verification result from last message
last_message = messages[-1]
last_message_str = str(last_message).replace('\\', '')
status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str)
logger.debug(f"[Verify_continue] Status tools: {status_tools}")
# Route based on verification result
if "success" in status_tools:
counter.reset()
return "Summary"
elif "failed" in status_tools:
if loop_count < 2: # Max retry count is 2
return "content_input"
else:
counter.reset()
return "Summary_fails"
else:
# Default to Summary if status is unclear
counter.reset()
return "Summary"
def Retrieve_continue(state: dict) -> Literal["Verify", "Retrieve_Summary"]:
"""
Determine routing after Retrieve node based on search_switch value.
This function routes based on the search_switch parameter:
- search_switch == '0': Route to Verify (verification needed)
- search_switch == '1': Route to Retrieve_Summary (direct summary)
Args:
state: LangGraph state dictionary
Returns:
Next node name as Literal type
"""
search_switch = extract_search_switch(state)
logger.debug(f"[Retrieve_continue] search_switch: {search_switch}")
if search_switch == '0':
return 'Verify'
elif search_switch == '1':
return 'Retrieve_Summary'
# Default to Retrieve_Summary
logger.debug("[Retrieve_continue] No valid search_switch, defaulting to Retrieve_Summary")
return 'Retrieve_Summary'
def Split_continue(state: dict) -> Literal["Split_The_Problem", "Input_Summary"]:
"""
Determine routing after content_input node based on search_switch value.
This function routes based on the search_switch parameter:
- search_switch == '2': Route to Input_Summary (direct input summary)
- Otherwise: Route to Split_The_Problem (problem decomposition)
Args:
state: LangGraph state dictionary
Returns:
Next node name as Literal type
"""
logger.debug(f"[Split_continue] state keys: {state.keys()}")
search_switch = extract_search_switch(state)
logger.debug(f"[Split_continue] search_switch: {search_switch}")
if search_switch == '2':
return 'Input_Summary'
# Default to Split_The_Problem
return 'Split_The_Problem'

View File

@@ -0,0 +1,13 @@
"""LangGraph state management utilities."""
from app.core.memory.agent.langgraph_graph.state.extractors import (
extract_search_switch,
extract_tool_call_id,
extract_content_payload,
)
__all__ = [
"extract_search_switch",
"extract_tool_call_id",
"extract_content_payload",
]

View File

@@ -0,0 +1,164 @@
"""
State extraction utilities for type-safe access to LangGraph state values.
This module provides utility functions for extracting values from LangGraph state
dictionaries with proper error handling and sensible defaults.
"""
import json
import logging
from typing import Any, Optional
logger = logging.getLogger(__name__)
def extract_search_switch(state: dict) -> Optional[str]:
"""
Extract search_switch from state or messages.
"""
search_switch = state.get("search_switch")
if search_switch is not None:
return str(search_switch)
# Try to extract from messages
messages = state.get("messages", [])
if not messages:
return None
# 从最新的消息开始查找
for message in reversed(messages):
# 尝试从 tool_calls 中提取
if hasattr(message, "tool_calls") and message.tool_calls:
for tool_call in message.tool_calls:
if isinstance(tool_call, dict):
# 从 tool_call 的 args 中提取
if "args" in tool_call and isinstance(tool_call["args"], dict):
search_switch = tool_call["args"].get("search_switch")
if search_switch is not None:
return str(search_switch)
# 直接从 tool_call 中提取
search_switch = tool_call.get("search_switch")
if search_switch is not None:
return str(search_switch)
# 尝试从 content 中提取(如果是 JSON 格式)
if hasattr(message, "content"):
try:
import json
if isinstance(message.content, str):
content_data = json.loads(message.content)
if isinstance(content_data, dict):
search_switch = content_data.get("search_switch")
if search_switch is not None:
return str(search_switch)
except (json.JSONDecodeError, ValueError):
pass
return None
def extract_tool_call_id(message: Any) -> str:
"""
Extract tool call ID from message using structured attributes.
This function extracts the tool call ID from a message object, handling both
direct attribute access and tool_calls list structures.
Args:
message: Message object (typically ToolMessage or AIMessage)
Returns:
Tool call ID as string
Raises:
ValueError: If tool call ID cannot be extracted
Examples:
>>> message = ToolMessage(content="...", tool_call_id="call_123")
>>> extract_tool_call_id(message)
'call_123'
"""
# Try direct attribute access for ToolMessage
if hasattr(message, "tool_call_id"):
tool_call_id = message.tool_call_id
if tool_call_id:
return str(tool_call_id)
# Try extracting from tool_calls list for AIMessage
if hasattr(message, "tool_calls") and message.tool_calls:
tool_call = message.tool_calls[0]
if isinstance(tool_call, dict) and "id" in tool_call:
return str(tool_call["id"])
# Try extracting from id attribute
if hasattr(message, "id"):
message_id = message.id
if message_id:
return str(message_id)
# If all else fails, raise an error
raise ValueError(f"Could not extract tool call ID from message: {type(message)}")
def extract_content_payload(message: Any) -> Any:
"""
Extract content payload from ToolMessage, parsing JSON if needed.
This function extracts the content from a message and attempts to parse it as JSON
if it appears to be a JSON string. It handles various message formats and provides
sensible fallbacks.
Args:
message: Message object (typically ToolMessage)
Returns:
Parsed content (dict, list, or str)
Examples:
>>> message = ToolMessage(content='{"key": "value"}')
>>> extract_content_payload(message)
{'key': 'value'}
>>> message = ToolMessage(content='plain text')
>>> extract_content_payload(message)
'plain text'
"""
# Extract raw content
# For ToolMessages (responses from tools), extract from content
if hasattr(message, "content"):
raw_content = message.content
# If content is empty and this is an AIMessage with tool_calls,
# extract from args (this handles the initial tool call from content_input)
if not raw_content and hasattr(message, "tool_calls") and message.tool_calls:
tool_call = message.tool_calls[0]
if isinstance(tool_call, dict) and "args" in tool_call:
return tool_call["args"]
else:
raw_content = str(message)
# If content is already a dict or list, return it directly
if isinstance(raw_content, (dict, list)):
return raw_content
# Try to parse as JSON
if isinstance(raw_content, str):
# First, try direct JSON parsing
try:
return json.loads(raw_content)
except (json.JSONDecodeError, ValueError):
pass
# If that fails, try to extract JSON from the string
# This handles cases where the content is embedded in a larger string
import re
json_candidates = re.findall(r'[\[{].*[\]}]', raw_content, flags=re.DOTALL)
for candidate in json_candidates:
try:
return json.loads(candidate)
except (json.JSONDecodeError, ValueError):
continue
# If all parsing attempts fail, return the raw content
return raw_content

View File

@@ -0,0 +1,78 @@
import asyncio
import json
from contextlib import asynccontextmanager
from langgraph.constants import START, END
from langgraph.graph import add_messages, StateGraph
from langgraph.prebuilt import ToolNode
from app.core.memory.agent.utils.llm_tools import WriteState
import warnings
import sys
from langchain_core.messages import AIMessage
from app.core.logging_config import get_agent_logger
warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__)
if sys.platform.startswith("win"):
import asyncio
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
@asynccontextmanager
async def make_write_graph(user_id, tools, apply_id, group_id, config_id=None):
logger.info("加载 MCP 工具: %s", [t.name for t in tools])
if config_id:
logger.info(f"使用配置 ID: {config_id}")
data_type_tool = next((t for t in tools if t.name == "Data_type_differentiation"), None)
data_write_tool = next((t for t in tools if t.name == "Data_write"), None)
if not data_type_tool or not data_write_tool:
logger.error('不存在数据存储工具', exc_info=True)
raise ValueError('不存在数据存储工具')
# ToolNode
write_node = ToolNode([data_write_tool])
async def call_model(state):
messages = state["messages"]
last_message = messages[-1]
result = await data_type_tool.ainvoke({
"context": last_message[1] if isinstance(last_message, tuple) else last_message.content
})
result=json.loads( result)
# 调用 Data_write传递 config_id
write_params = {
"content": result["context"],
"apply_id": apply_id,
"group_id": group_id,
"user_id": user_id
}
# 如果提供了 config_id添加到参数中
if config_id:
write_params["config_id"] = config_id
logger.debug(f"传递 config_id 到 Data_write: {config_id}")
write_result = await data_write_tool.ainvoke(write_params)
if isinstance(write_result, dict):
content = write_result.get("data", str(write_result))
else:
content = str(write_result)
logger.info("写入内容: %s", content)
return {"messages": [AIMessage(content=content)]}
workflow = StateGraph(WriteState)
workflow.add_node("content_input", call_model)
workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "content_input")
workflow.add_edge("content_input", "save_neo4j")
workflow.add_edge("save_neo4j", END)
graph = workflow.compile()
yield graph

View File

@@ -0,0 +1,285 @@
"""
Log Streamer Module
Manages streaming of log file content with file watching and real-time transmission.
"""
import os
import re
import time
import asyncio
from typing import AsyncGenerator, Optional
from pathlib import Path
from app.core.logging_config import get_logger
logger = get_logger(__name__)
class LogStreamer:
"""Manages log file streaming with file watching and content transmission"""
def __init__(self, log_path: str, keepalive_interval: int = 300):
"""
Initialize LogStreamer
Args:
log_path: Path to the log file to stream
keepalive_interval: Interval in seconds for sending keepalive messages (default: 300)
"""
self.log_path = log_path
self.keepalive_interval = keepalive_interval
self.last_position = 0
# Pattern to match and remove timestamp and log level prefix
# Matches: "YYYY-MM-DD HH:MM:SS,mmm - [LEVEL] - module_name - "
# This pattern is comprehensive to handle various log formats
self.pattern = re.compile(
r'^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3} - \[(?:INFO|DEBUG|WARNING|ERROR|CRITICAL)\] - \S+ - '
)
logger.info(f"LogStreamer initialized for {log_path}")
@staticmethod
def clean_log_line(line: str) -> str:
"""
Static method to clean log entry by removing timestamp and log level prefix.
This is the canonical log cleaning method used by both file mode and transmission mode.
Args:
line: Raw log line
Returns:
Cleaned log line without timestamp and log level prefix
"""
# Pattern to match and remove timestamp and log level prefix
# Matches: "YYYY-MM-DD HH:MM:SS,mmm - [LEVEL] - module_name - "
pattern = re.compile(
r'^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3} - \[(?:INFO|DEBUG|WARNING|ERROR|CRITICAL)\] - \S+ - '
)
cleaned = re.sub(pattern, '', line)
return cleaned
def clean_log_entry(self, line: str) -> str:
"""
Clean log entry by removing timestamp and log level prefix.
This instance method delegates to the static method for consistency.
Args:
line: Raw log line
Returns:
Cleaned log line without timestamp and log level prefix
"""
return LogStreamer.clean_log_line(line)
async def send_keepalive(self) -> dict:
"""
Generate keepalive message
Returns:
Keepalive message dict with timestamp
"""
return {
"event": "keepalive",
"data": {
"timestamp": int(time.time())
}
}
async def read_existing_and_stream(self) -> AsyncGenerator[dict, None]:
"""
Read existing log content first, then watch for new content
This method reads all existing content in the file first,
then continues to watch for new content as it's written.
Yields:
Dict messages with event type and data:
- log events: {"event": "log", "data": {"content": "...", "timestamp": ...}}
- keepalive events: {"event": "keepalive", "data": {"timestamp": ...}}
- error events: {"event": "error", "data": {"code": ..., "message": "...", "error": "..."}}
- done events: {"event": "done", "data": {"message": "..."}}
"""
logger.info(f"Starting log stream (read existing) for {self.log_path}")
# Check if file exists
if not os.path.exists(self.log_path):
logger.error(f"Log file not found: {self.log_path}")
yield {
"event": "error",
"data": {
"code": 4006,
"message": "日志文件不存在",
"error": f"File not found: {self.log_path}"
}
}
return
try:
with open(self.log_path, 'r', encoding='utf-8') as f:
# First, read all existing content
for line in f:
if line.strip(): # Skip empty lines
cleaned_line = self.clean_log_entry(line)
yield {
"event": "log",
"data": {
"content": cleaned_line.rstrip('\n'),
"timestamp": int(time.time())
}
}
# Now watch for new content
self.last_position = f.tell()
last_keepalive = time.time()
while True:
line = f.readline()
if line:
cleaned_line = self.clean_log_entry(line)
yield {
"event": "log",
"data": {
"content": cleaned_line.rstrip('\n'),
"timestamp": int(time.time())
}
}
last_keepalive = time.time()
else:
# No new content, check if we need to send keepalive
current_time = time.time()
if current_time - last_keepalive >= self.keepalive_interval:
keepalive_msg = await self.send_keepalive()
yield keepalive_msg
last_keepalive = current_time
# Sleep briefly before checking again
await asyncio.sleep(0.1)
except FileNotFoundError:
logger.error(f"Log file disappeared during streaming: {self.log_path}")
yield {
"event": "error",
"data": {
"code": 4006,
"message": "日志文件在流式传输期间变得不可用",
"error": "File not found during streaming"
}
}
except Exception as e:
logger.error(f"Error during log streaming: {e}", exc_info=True)
yield {
"event": "error",
"data": {
"code": 8001,
"message": "流式传输期间发生错误",
"error": str(e)
}
}
finally:
logger.info(f"Log stream ended for {self.log_path}")
yield {
"event": "done",
"data": {
"message": "流式传输完成"
}
}
async def watch_and_stream(self) -> AsyncGenerator[dict, None]:
"""
Watch log file and stream only new content as it's written
This method starts from the end of the file and only streams
new content that is written after the stream starts.
Yields:
Dict messages with event type and data:
- log events: {"event": "log", "data": {"content": "...", "timestamp": ...}}
- keepalive events: {"event": "keepalive", "data": {"timestamp": ...}}
- error events: {"event": "error", "data": {"code": ..., "message": "...", "error": "..."}}
- done events: {"event": "done", "data": {"message": "..."}}
"""
logger.info(f"Starting log stream (new content only) for {self.log_path}")
# Check if file exists
if not os.path.exists(self.log_path):
logger.error(f"Log file not found: {self.log_path}")
yield {
"event": "error",
"data": {
"code": 4006,
"message": "日志文件不存在",
"error": f"File not found: {self.log_path}"
}
}
return
try:
# Open file and seek to end to start streaming new content
with open(self.log_path, 'r', encoding='utf-8') as f:
# Move to end of file
f.seek(0, os.SEEK_END)
self.last_position = f.tell()
last_keepalive = time.time()
while True:
# Check if file has new content
current_position = f.tell()
# Read new lines if available
line = f.readline()
if line:
# Clean the log entry
cleaned_line = self.clean_log_entry(line)
# Yield log event
yield {
"event": "log",
"data": {
"content": cleaned_line.rstrip('\n'),
"timestamp": int(time.time())
}
}
# Update last keepalive time since we sent data
last_keepalive = time.time()
else:
# No new content, check if we need to send keepalive
current_time = time.time()
if current_time - last_keepalive >= self.keepalive_interval:
keepalive_msg = await self.send_keepalive()
yield keepalive_msg
last_keepalive = current_time
# Sleep briefly before checking again
await asyncio.sleep(0.1)
except FileNotFoundError:
logger.error(f"Log file disappeared during streaming: {self.log_path}")
yield {
"event": "error",
"data": {
"code": 4006,
"message": "日志文件在流式传输期间变得不可用",
"error": "File not found during streaming"
}
}
except Exception as e:
logger.error(f"Error during log streaming: {e}", exc_info=True)
yield {
"event": "error",
"data": {
"code": 8001,
"message": "流式传输期间发生错误",
"error": str(e)
}
}
finally:
logger.info(f"Log stream ended for {self.log_path}")
yield {
"event": "done",
"data": {
"message": "流式传输完成"
}
}

View File

@@ -0,0 +1,32 @@
"""
Agent logger module for backward compatibility.
This module maintains the get_named_logger() function for backward compatibility
while delegating to the centralized logging configuration.
All new code should import directly from app.core.logging_config instead.
"""
__version__ = "0.1.0"
__author__ = "RED_BEAR"
from app.core.logging_config import get_agent_logger
def get_named_logger(name):
"""Get a named logger for agent operations.
This function maintains backward compatibility with existing code.
It delegates to the centralized get_agent_logger() function.
Args:
name: Logger name for namespacing
Returns:
Logger configured for agent operations
Example:
>>> logger = get_named_logger("my_agent")
>>> logger.info("Agent operation started")
"""
return get_agent_logger(name)

View File

@@ -0,0 +1,28 @@
"""
MCP Server package for memory agent.
This package provides the FastMCP server implementation with context-based
dependency injection for tool functions.
Package structure:
- server: FastMCP server initialization and context setup
- tools: MCP tool implementations
- models: Pydantic response models
- services: Business logic services
"""
from app.core.memory.agent.mcp_server.server import (
mcp,
initialize_context,
main,
get_context_resource
)
# Import tools to register them (but don't export them)
from app.core.memory.agent.mcp_server import tools
__all__ = [
'mcp',
'initialize_context',
'main',
'get_context_resource',
]

View File

@@ -0,0 +1,11 @@
"""
MCP Server Instance
This module contains the FastMCP server instance that is shared across all modules.
It's in a separate file to avoid circular import issues.
"""
from mcp.server.fastmcp import FastMCP
# Initialize FastMCP server instance
# This instance is shared across all tool modules
mcp = FastMCP('data_flow')

View File

@@ -0,0 +1,30 @@
"""Pydantic models for MCP server responses."""
from .problem_models import (
ProblemBreakdownItem,
ProblemBreakdownResponse,
ExtendedQuestionItem,
ProblemExtensionResponse,
)
from .summary_models import (
SummaryData,
SummaryResponse,
RetrieveSummaryData,
RetrieveSummaryResponse,
)
from .verification_models import VerificationResult
from .retrieval_models import RetrievalResult, DistinguishTypeResponse
__all__ = [
"ProblemBreakdownItem",
"ProblemBreakdownResponse",
"ExtendedQuestionItem",
"ProblemExtensionResponse",
"SummaryData",
"SummaryResponse",
"RetrieveSummaryData",
"RetrieveSummaryResponse",
"VerificationResult",
"RetrievalResult",
"DistinguishTypeResponse",
]

View File

@@ -0,0 +1,34 @@
"""Pydantic models for problem breakdown and extension operations."""
from typing import List, Optional
from pydantic import BaseModel, Field, RootModel
class ProblemBreakdownItem(BaseModel):
"""Individual item in problem breakdown response."""
id: str
question: str
type: str
reason: Optional[str] = None
class ProblemBreakdownResponse(RootModel[List[ProblemBreakdownItem]]):
"""Response model for problem breakdown containing list of breakdown items."""
pass
class ExtendedQuestionItem(BaseModel):
"""Individual extended question item with reasoning."""
original_question: str = Field(..., description="原始初步问题")
extended_question: str = Field(..., description="扩展后的问题")
type: str = Field(..., description="类型(事实检索 / 澄清 / 定义 / 比较 / 行动建议等)")
reason: str = Field(..., description="生成该扩展问题的理由")
class ProblemExtensionResponse(RootModel[List[ExtendedQuestionItem]]):
"""Response model for problem extension containing list of extended questions."""
pass

View File

@@ -0,0 +1,17 @@
"""Pydantic models for retrieval operations."""
from typing import List, Dict, Any
from pydantic import BaseModel
class RetrievalResult(BaseModel):
"""Result model for retrieval operation."""
Query: str
Expansion_issue: List[Dict[str, Any]]
class DistinguishTypeResponse(BaseModel):
"""Response model for data type differentiation."""
type: str

View File

@@ -0,0 +1,31 @@
"""Pydantic models for summary operations."""
from typing import List
from pydantic import BaseModel, Field
class SummaryData(BaseModel):
"""Data structure for summary input."""
query: str
history: List[str] = Field(default_factory=list)
retrieve_info: List[str] = Field(default_factory=list)
class SummaryResponse(BaseModel):
"""Response model for summary operation."""
data: SummaryData
query_answer: str
class RetrieveSummaryData(BaseModel):
"""Data structure for retrieve summary response."""
query_answer: str = Field(default="")
class RetrieveSummaryResponse(BaseModel):
"""Response model for retrieve summary operation."""
data: RetrieveSummaryData

View File

@@ -0,0 +1,14 @@
"""Pydantic models for verification operations."""
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
class VerificationResult(BaseModel):
"""Result model for verification operation."""
query: str
expansion_issue: List[Dict[str, Any]]
split_result: str
reason: Optional[str] = None
history: List[Dict[str, Any]] = Field(default_factory=list)

View File

@@ -0,0 +1,161 @@
"""
MCP Server initialization with FastMCP context setup.
This module initializes the FastMCP server and registers shared resources
in the context for dependency injection into tool functions.
"""
import os
import sys
from mcp.server.fastmcp import FastMCP
from app.core.config import settings
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.redis_tool import RedisSessionStore, store
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID,reload_configuration_from_database
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.agent.mcp_server.services.template_service import TemplateService
from app.core.memory.agent.mcp_server.services.search_service import SearchService
from app.core.memory.agent.mcp_server.services.session_service import SessionService
from app.core.memory.agent.mcp_server.mcp_instance import mcp
logger = get_agent_logger(__name__)
def get_context_resource(ctx, resource_name: str):
"""
Helper function to retrieve a resource from the FastMCP context.
Args:
ctx: FastMCP Context object (passed to tool functions)
resource_name: Name of the resource to retrieve
Returns:
The requested resource
Raises:
AttributeError: If the resource doesn't exist
Example:
@mcp.tool()
async def my_tool(ctx: Context):
template_service = get_context_resource(ctx, 'template_service')
llm_client = get_context_resource(ctx, 'llm_client')
"""
if not hasattr(ctx, 'fastmcp') or ctx.fastmcp is None:
raise RuntimeError("Context does not have fastmcp attribute")
if not hasattr(ctx.fastmcp, resource_name):
raise AttributeError(
f"Resource '{resource_name}' not found in context. "
f"Available resources: {[k for k in dir(ctx.fastmcp) if not k.startswith('_')]}"
)
return getattr(ctx.fastmcp, resource_name)
def initialize_context():
"""
Initialize and register shared resources in FastMCP context.
This function sets up all shared resources that will be available
to tool functions via dependency injection through the context parameter.
Resources are stored as attributes on the FastMCP instance and can be
accessed via ctx.fastmcp in tool functions.
Resources registered:
- session_store: RedisSessionStore for session management
- llm_client: LLM client for structured API calls
- app_settings: Application settings (renamed to avoid conflict with FastMCP settings)
- template_service: Service for template rendering
- search_service: Service for hybrid search
- session_service: Service for session operations
"""
try:
# Register Redis session store
logger.info("Registering session_store in context")
mcp.session_store = store
# Register LLM client
try:
logger.info(f"Registering llm_client in context with model ID: {SELECTED_LLM_ID}")
llm_client = get_llm_client(SELECTED_LLM_ID)
mcp.llm_client = llm_client
logger.info("llm_client registered successfully")
except Exception as e:
logger.error(f"Failed to register llm_client: {e}", exc_info=True)
# 注册一个 None 值,避免工具调用时找不到资源
mcp.llm_client = None
logger.warning("llm_client set to None due to initialization failure")
# Register application settings (renamed to avoid conflict with FastMCP's settings)
logger.info("Registering app_settings in context")
mcp.app_settings = settings
# Register template service
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
# logger.info(f"Registering template_service in context with root: {template_root}")
template_service = TemplateService(template_root)
mcp.template_service = template_service
# Register search service
# logger.info("Registering search_service in context")
search_service = SearchService()
mcp.search_service = search_service
# Register session service
# logger.info("Registering session_service in context")
session_service = SessionService(store)
mcp.session_service = session_service
# logger.info("All context resources registered successfully")
except Exception as e:
logger.error(f"Failed to initialize context: {e}", exc_info=True)
raise
def main():
"""
Main entry point for the MCP server.
Initializes context and starts the server with SSE transport.
"""
try:
# logger.info("Starting MCP server initialization")
reload_configuration_from_database(config_id=os.getenv("config_id"), force_reload=True)
# Initialize context resources
initialize_context()
# Import and register tools
# logger.info("Importing MCP tools")
from app.core.memory.agent.mcp_server.tools import (
problem_tools,
retrieval_tools,
verification_tools,
summary_tools,
data_tools
)
# logger.info("All MCP tools imported and registered")
# Log registered tools for debugging
import asyncio
tools_list = asyncio.run(mcp.list_tools())
# logger.info(f"Registered {len(tools_list)} MCP tools: {[t.name for t in tools_list]}")
# logger.info(f"Starting MCP server on {settings.SERVER_IP}:8081 with SSE transport")
# Run the server with SSE transport for HTTP connections
# The server will be available at http://127.0.0.1:8081
import uvicorn
app = mcp.sse_app()
uvicorn.run(app, host=settings.SERVER_IP, port=8081, log_level="info")
except Exception as e:
logger.error(f"Failed to start MCP server: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,23 @@
"""
MCP Server Services
This module provides business logic services for the MCP server:
- TemplateService: Template loading and rendering
- SearchService: Search result processing
- SessionService: Session and history management
- ParameterBuilder: Tool parameter construction
"""
from .template_service import TemplateService, TemplateRenderError
from .search_service import SearchService
from .session_service import SessionService
from .parameter_builder import ParameterBuilder
__all__ = [
"TemplateService",
"TemplateRenderError",
"SearchService",
"SessionService",
"ParameterBuilder",
]

View File

@@ -0,0 +1,157 @@
"""
Parameter Builder for constructing tool call arguments.
This service provides tool-specific parameter transformation logic
to build correct arguments for each tool type.
"""
import json
from typing import Any, Dict, Optional
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
class ParameterBuilder:
"""Service for building tool call arguments based on tool type."""
def __init__(self):
"""Initialize the parameter builder."""
logger.info("ParameterBuilder initialized")
def build_tool_args(
self,
tool_name: str,
content: Any,
tool_call_id: str,
search_switch: str,
apply_id: str,
group_id: str,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None
) -> Dict[str, Any]:
"""
Build tool arguments based on tool type.
Different tools expect different argument formats:
- Verify: dict context
- Retrieve: dict context + search_switch
- Summary/Summary_fails: JSON string context
- Retrieve_Summary: unwrap nested context structures
- Input_Summary: raw message string
Args:
tool_name: Name of the tool being invoked
content: Parsed content from previous tool result
tool_call_id: Extracted tool call identifier
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
Returns:
Dictionary of tool arguments ready for invocation
"""
# Base arguments common to most tools
base_args = {
"usermessages": tool_call_id,
"apply_id": apply_id,
"group_id": group_id
}
# Always add storage_type and user_rag_memory_id (with defaults if None)
base_args["storage_type"] = storage_type if storage_type is not None else ""
base_args["user_rag_memory_id"] = user_rag_memory_id if user_rag_memory_id is not None else ""
# Tool-specific argument construction
if tool_name == "Verify":
# Verify expects dict context
return {
"context": content if isinstance(content, dict) else {},
**base_args
}
elif tool_name == "Retrieve":
# Retrieve expects dict context + search_switch
return {
"context": content if isinstance(content, dict) else {},
"search_switch": search_switch,
**base_args
}
elif tool_name in ["Summary", "Summary_fails"]:
# Summary tools expect JSON string context
if isinstance(content, dict):
context_str = json.dumps(content, ensure_ascii=False)
elif isinstance(content, str):
context_str = content
else:
context_str = json.dumps({"data": content}, ensure_ascii=False)
return {
"context": context_str,
**base_args
}
elif tool_name == "Retrieve_Summary":
# Retrieve_Summary needs to unwrap nested context structures
# Handle both 'content' and 'context' keys
context_dict = content
if isinstance(content, dict):
# Check for nested 'content' wrapper
if "content" in content:
inner = content["content"]
# If it's a JSON string, parse it
if isinstance(inner, str):
try:
parsed = json.loads(inner)
# Check if parsed has 'context' wrapper
if isinstance(parsed, dict) and "context" in parsed:
context_dict = parsed["context"]
else:
context_dict = parsed
except json.JSONDecodeError:
logger.warning(
f"Failed to parse JSON content for {tool_name}: {inner[:100]}"
)
context_dict = {"Query": "", "Expansion_issue": []}
elif isinstance(inner, dict):
context_dict = inner
# Check for 'context' wrapper
elif "context" in content:
context_dict = content["context"] if isinstance(content["context"], dict) else content
return {
"context": context_dict,
**base_args
}
elif tool_name == "Input_Summary":
# Input_Summary expects raw message string + search_switch
# Content should be the raw message string
if isinstance(content, dict):
# Try to extract message from dict
message_str = content.get("sentence", str(content))
else:
message_str = str(content)
return {
"context": message_str,
"search_switch": search_switch,
**base_args
}
else:
# Default: pass content as context
logger.warning(
f"Unknown tool name '{tool_name}', using default argument structure"
)
return {
"context": content,
**base_args
}

View File

@@ -0,0 +1,193 @@
"""
Search Service for executing hybrid search and processing results.
This service provides clean search result processing with content extraction
and deduplication.
"""
from typing import List, Tuple, Optional
from app.core.logging_config import get_agent_logger
from app.core.memory.src.search import run_hybrid_search
from app.core.memory.utils.data.text_utils import escape_lucene_query
logger = get_agent_logger(__name__)
class SearchService:
"""Service for executing hybrid search and processing results."""
def __init__(self):
"""Initialize the search service."""
logger.info("SearchService initialized")
def extract_content_from_result(self, result: dict) -> str:
"""
Extract only meaningful content from search results, dropping all metadata.
Extraction rules by node type:
- Statements: extract 'statement' field
- Entities: extract 'name' and 'fact_summary' fields
- Summaries: extract 'content' field
- Chunks: extract 'content' field
Args:
result: Search result dictionary
Returns:
Clean content string without metadata
"""
if not isinstance(result, dict):
return str(result)
content_parts = []
# Statements: extract statement field
if 'statement' in result and result['statement']:
content_parts.append(result['statement'])
# Summaries/Chunks: extract content field
if 'content' in result and result['content']:
content_parts.append(result['content'])
# Entities: extract name and fact_summary (commented out in original)
# if 'name' in result and result['name']:
# content_parts.append(result['name'])
# if result.get('fact_summary'):
# content_parts.append(result['fact_summary'])
# Return concatenated content or empty string
return '\n'.join(content_parts) if content_parts else ""
def clean_query(self, query: str) -> str:
"""
Clean and escape query text for Lucene.
- Removes wrapping quotes
- Removes newlines and carriage returns
- Applies Lucene escaping
Args:
query: Raw query string
Returns:
Cleaned and escaped query string
"""
q = str(query).strip()
# Remove wrapping quotes
if (q.startswith("'") and q.endswith("'")) or (
q.startswith('"') and q.endswith('"')
):
q = q[1:-1]
# Remove newlines and carriage returns
q = q.replace('\r', ' ').replace('\n', ' ').strip()
# Apply Lucene escaping
q = escape_lucene_query(q)
return q
async def execute_hybrid_search(
self,
group_id: str,
question: str,
limit: int = 5,
search_type: str = "hybrid",
include: Optional[List[str]] = None,
rerank_alpha: float = 0.4,
output_path: str = "search_results.json",
return_raw_results: bool = False
) -> Tuple[str, str, Optional[dict]]:
"""
Execute hybrid search and return clean content.
Args:
group_id: Group identifier for filtering results
question: Search query text
limit: Maximum number of results to return (default: 5)
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
include: List of result types to include (default: ["statements", "chunks", "entities", "summaries"])
rerank_alpha: Weight for BM25 scores in reranking (default: 0.4)
output_path: Path to save search results (default: "search_results.json")
return_raw_results: If True, also return the raw search results as third element (default: False)
Returns:
Tuple of (clean_content, cleaned_query, raw_results)
raw_results is None if return_raw_results=False
"""
if include is None:
include = ["statements", "chunks", "entities", "summaries"]
# Clean query
cleaned_query = self.clean_query(question)
try:
# Execute search
answer = await run_hybrid_search(
query_text=cleaned_query,
search_type=search_type,
group_id=group_id,
limit=limit,
include=include,
output_path=output_path,
rerank_alpha=rerank_alpha
)
# Extract results based on search type and include parameter
# Prioritize summaries as they contain synthesized contextual information
answer_list = []
# For hybrid search, use reranked_results
if search_type == "hybrid":
reranked_results = answer.get('reranked_results', {})
# Priority order: summaries first (most contextual), then statements, chunks, entities
priority_order = ['summaries', 'statements', 'chunks', 'entities']
for category in priority_order:
if category in include and category in reranked_results:
category_results = reranked_results[category]
if isinstance(category_results, list):
answer_list.extend(category_results)
else:
# For keyword or embedding search, results are directly in answer dict
# Apply same priority order
priority_order = ['summaries', 'statements', 'chunks', 'entities']
for category in priority_order:
if category in include and category in answer:
category_results = answer[category]
if isinstance(category_results, list):
answer_list.extend(category_results)
# Extract clean content from all results
content_list = [
self.extract_content_from_result(ans)
for ans in answer_list
]
# Filter out empty strings and join with newlines
clean_content = '\n'.join([c for c in content_list if c])
# Log first 200 chars
logger.info(f"检索接口搜索结果==>>:{clean_content[:200]}...")
# Return raw results if requested
if return_raw_results:
return clean_content, cleaned_query, answer
else:
return clean_content, cleaned_query, None
except Exception as e:
logger.error(
f"Search failed for query '{question}' in group '{group_id}': {e}",
exc_info=True
)
# Return empty results on failure
if return_raw_results:
return "", cleaned_query, {}
else:
return "", cleaned_query, None

View File

@@ -0,0 +1,169 @@
"""
Session Service for managing user sessions and conversation history.
This service provides clean Redis interactions with error handling and
session management utilities.
"""
from typing import List, Optional
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.redis_tool import RedisSessionStore
logger = get_agent_logger(__name__)
class SessionService:
"""Service for managing user sessions and conversation history."""
def __init__(self, store: RedisSessionStore):
"""
Initialize the session service.
Args:
store: Redis session store instance
"""
self.store = store
logger.info("SessionService initialized")
def resolve_user_id(self, session_string: str) -> str:
"""
Extract user ID from session string.
Handles formats like:
- 'call_id_user123' -> 'user123'
- 'prefix_id_user456_suffix' -> 'user456_suffix'
Args:
session_string: Session identifier string
Returns:
Extracted user ID
"""
try:
# Split by '_id_' and take everything after it
parts = session_string.split('_id_')
if len(parts) > 1:
return parts[1]
# Fallback: return original string
return session_string
except Exception as e:
logger.warning(
f"Failed to parse user ID from session string '{session_string}': {e}"
)
return session_string
async def get_history(
self,
user_id: str,
apply_id: str,
group_id: str
) -> List[dict]:
"""
Retrieve conversation history from Redis.
Args:
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
Returns:
List of conversation history items with Query and Answer keys
Returns empty list if no history found or on error
"""
try:
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
# Validate history structure
if not isinstance(history, list):
logger.warning(
f"Invalid history format for user {user_id}, "
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
)
return []
return history
except Exception as e:
logger.error(
f"Failed to retrieve history for user {user_id}, "
f"apply {apply_id}, group {group_id}: {e}",
exc_info=True
)
# Return empty list on error to allow execution to continue
return []
async def save_session(
self,
user_id: str,
query: str,
apply_id: str,
group_id: str,
ai_response: str
) -> Optional[str]:
"""
Save conversation turn to Redis.
Args:
user_id: User identifier
query: User query/message
apply_id: Application identifier
group_id: Group identifier
ai_response: AI response/answer
Returns:
Session ID if successful, None on error
"""
try:
# Validate required fields
if not user_id:
logger.warning("Cannot save session: user_id is empty")
return None
if not query:
logger.warning("Cannot save session: query is empty")
return None
# Save session
session_id = self.store.save_session(
userid=user_id,
messages=query,
apply_id=apply_id,
group_id=group_id,
aimessages=ai_response
)
logger.info(f"Session saved successfully: {session_id}")
return session_id
except Exception as e:
logger.error(
f"Failed to save session for user {user_id}: {e}",
exc_info=True
)
return None
async def cleanup_duplicates(self) -> int:
"""
Remove duplicate session entries.
Duplicates are identified by matching:
- sessionid
- user_id (id field)
- group_id
- messages
- aimessages
Returns:
Number of duplicate sessions deleted
"""
try:
deleted_count = self.store.delete_duplicate_sessions()
logger.info(f"Cleaned up {deleted_count} duplicate sessions")
return deleted_count
except Exception as e:
logger.error(f"Failed to cleanup duplicate sessions: {e}", exc_info=True)
return 0

View File

@@ -0,0 +1,116 @@
"""
Template Service for loading and rendering Jinja2 templates.
This service provides centralized template management with caching and error handling.
"""
import os
from functools import lru_cache
from typing import Optional
from jinja2 import Environment, FileSystemLoader, Template, TemplateNotFound
from app.core.logging_config import get_agent_logger, log_prompt_rendering
logger = get_agent_logger(__name__)
class TemplateRenderError(Exception):
"""Exception raised when template rendering fails."""
def __init__(self, template_name: str, error: Exception, variables: dict):
self.template_name = template_name
self.error = error
self.variables = variables
super().__init__(
f"Failed to render template '{template_name}': {str(error)}"
)
class TemplateService:
"""Service for loading and rendering Jinja2 templates with caching."""
def __init__(self, template_root: str):
"""
Initialize the template service.
Args:
template_root: Root directory containing template files
"""
self.template_root = template_root
self.env = Environment(
loader=FileSystemLoader(template_root),
autoescape=False # Disable autoescape for prompt templates
)
logger.info(f"TemplateService initialized with root: {template_root}")
@lru_cache(maxsize=128)
def _load_template(self, template_name: str) -> Template:
"""
Load a template from disk with caching.
Args:
template_name: Relative path to template file
Returns:
Loaded Jinja2 Template object
Raises:
TemplateNotFound: If template file doesn't exist
"""
try:
return self.env.get_template(template_name)
except TemplateNotFound as e:
expected_path = os.path.join(self.template_root, template_name)
logger.error(
f"Template not found: {template_name}. "
f"Expected path: {expected_path}"
)
raise
async def render_template(
self,
template_name: str,
operation_name: str,
**variables
) -> str:
"""
Load and render a Jinja2 template.
Args:
template_name: Relative path to template file
operation_name: Name for logging (e.g., "split_the_problem")
**variables: Template variables to render
Returns:
Rendered template string
Raises:
TemplateRenderError: If template loading or rendering fails
"""
try:
# Load template (cached)
template = self._load_template(template_name)
# Render template
rendered = template.render(**variables)
# Log rendered prompt
log_prompt_rendering(operation_name, rendered)
return rendered
except TemplateNotFound as e:
logger.error(
f"Template rendering failed for {operation_name} "
f"({template_name}): Template not found",
exc_info=True
)
raise TemplateRenderError(template_name, e, variables)
except Exception as e:
logger.error(
f"Template rendering failed for {operation_name} "
f"({template_name}): {e}",
exc_info=True
)
raise TemplateRenderError(template_name, e, variables)

View File

@@ -0,0 +1,27 @@
"""
MCP Tools module.
This module contains all MCP tool implementations organized by functionality.
Tools are organized into the following modules:
- problem_tools: Question segmentation and extension
- retrieval_tools: Database and context retrieval
- verification_tools: Data verification
- summary_tools: Summarization and summary retrieval
- data_tools: Data type differentiation and writing
"""
# Import all tool modules to register them with the MCP server
from . import problem_tools
from . import retrieval_tools
from . import verification_tools
from . import summary_tools
from . import data_tools
__all__ = [
'problem_tools',
'retrieval_tools',
'verification_tools',
'summary_tools',
'data_tools',
]

View File

@@ -0,0 +1,149 @@
"""
Data Tools for data type differentiation and writing.
This module contains MCP tools for distinguishing data types and writing data.
"""
import os
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.mcp_server.models.retrieval_models import DistinguishTypeResponse
from app.core.memory.agent.utils.write_tools import write
logger = get_agent_logger(__name__)
@mcp.tool()
async def Data_type_differentiation(
ctx: Context,
context: str
) -> dict:
"""
Distinguish the type of data (read or write).
Args:
ctx: FastMCP context for dependency injection
context: Text to analyze for type differentiation
Returns:
dict: Contains 'context' with the original text and 'type' field
"""
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
llm_client = get_context_resource(ctx, 'llm_client')
# Render template
try:
system_prompt = await template_service.render_template(
template_name='distinguish_types_prompt.jinja2',
operation_name='status_typle',
user_query=context
)
except Exception as e:
logger.error(
f"Template rendering failed for Data_type_differentiation: {e}",
exc_info=True
)
return {
"type": "error",
"message": f"Prompt rendering failed: {str(e)}"
}
# Call LLM with structured response
try:
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=DistinguishTypeResponse
)
result = structured.model_dump()
# Add context to result
result["context"] = context
return result
except Exception as e:
logger.error(
f"LLM call failed for Data_type_differentiation: {e}",
exc_info=True
)
return {
"context": context,
"type": "error",
"message": f"LLM call failed: {str(e)}"
}
except Exception as e:
logger.error(
f"Data_type_differentiation failed: {e}",
exc_info=True
)
return {
"context": context,
"type": "error",
"message": str(e)
}
@mcp.tool()
async def Data_write(
ctx: Context,
content: str,
user_id: str,
apply_id: str,
group_id: str,
config_id: str
) -> dict:
"""
Write data to the database/file system.
Args:
ctx: FastMCP context for dependency injection
content: Data content to write
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
config_id: Configuration ID for processing (optional, integer)
Returns:
dict: Contains 'status', 'saved_to', and 'data' fields
"""
try:
# Ensure output directory exists
os.makedirs("data_output", exist_ok=True)
file_path = os.path.join("data_output", "user_data.csv")
# Write data using utility function
try:
await write(content, user_id, apply_id, group_id, config_id=config_id)
logger.info(f"写入成功Config ID: {config_id if config_id else 'None'}")
return {
"status": "success",
"saved_to": file_path,
"data": content,
"config_id": config_id
}
except Exception as e:
logger.error(f"写入失败: {e}", exc_info=True)
return {
"status": "error",
"message": str(e)
}
except Exception as e:
logger.error(
f"Data_write failed: {e}",
exc_info=True
)
return {
"status": "error",
"message": str(e)
}

View File

@@ -0,0 +1,293 @@
"""
Problem Tools for question segmentation and extension.
This module contains MCP tools for breaking down and extending user questions.
"""
import json
import time
from typing import List
from pydantic import BaseModel, Field, RootModel
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.mcp_server.models.problem_models import (
ProblemBreakdownItem,
ProblemBreakdownResponse,
ExtendedQuestionItem,
ProblemExtensionResponse
)
from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal
logger = get_agent_logger(__name__)
@mcp.tool()
async def Split_The_Problem(
ctx: Context,
sentence: str,
sessionid: str,
messages_id: str,
apply_id: str,
group_id: str
) -> dict:
"""
Segment the dialogue or sentence into sub-problems.
Args:
ctx: FastMCP context for dependency injection
sentence: Original sentence to split
sessionid: Session identifier
messages_id: Message identifier
apply_id: Application identifier
group_id: Group identifier
Returns:
dict: Contains 'context' (JSON string of split results) and 'original' sentence
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
session_service = get_context_resource(ctx, 'session_service')
llm_client = get_context_resource(ctx, 'llm_client')
# Extract user ID from session
user_id = session_service.resolve_user_id(sessionid)
# Get conversation history
history = await session_service.get_history(user_id, apply_id, group_id)
# Override with empty list for now (as in original)
history = []
# Render template
try:
system_prompt = await template_service.render_template(
template_name='problem_breakdown_prompt.jinja2',
operation_name='split_the_problem',
history=history,
sentence=sentence
)
except Exception as e:
logger.error(
f"Template rendering failed for Split_The_Problem: {e}",
exc_info=True
)
return {
"context": json.dumps([], ensure_ascii=False),
"original": sentence,
"error": f"Prompt rendering failed: {str(e)}"
}
# Call LLM with structured response
try:
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=ProblemBreakdownResponse
)
# Handle RootModel response with .root attribute access
if structured is None:
# LLM returned None, use empty list as fallback
split_result = json.dumps([], ensure_ascii=False)
elif hasattr(structured, 'root') and structured.root is not None:
split_result = json.dumps(
[item.model_dump() for item in structured.root],
ensure_ascii=False
)
elif isinstance(structured, list):
# Fallback: treat structured itself as the list
split_result = json.dumps(
[item.model_dump() for item in structured],
ensure_ascii=False
)
else:
# Last resort: use empty list
split_result = json.dumps([], ensure_ascii=False)
except Exception as e:
logger.error(
f"LLM call failed for Split_The_Problem: {e}",
exc_info=True
)
split_result = json.dumps([], ensure_ascii=False)
logger.info(f"问题拆分")
logger.info(f"问题拆分结果==>>:{split_result}")
# Emit intermediate output for frontend
result = {
"context": split_result,
"original": sentence,
"_intermediate": {
"type": "problem_split",
"data": json.loads(split_result) if split_result else [],
"original_query": sentence
}
}
return result
except Exception as e:
logger.error(
f"Split_The_Problem failed: {e}",
exc_info=True
)
return {
"context": json.dumps([], ensure_ascii=False),
"original": sentence,
"error": str(e)
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('问题拆分', duration)
@mcp.tool()
async def Problem_Extension(
ctx: Context,
context: dict,
usermessages: str,
apply_id: str,
group_id: str,
storage_type: str = "",
user_rag_memory_id: str = ""
) -> dict:
"""
Extend the problem with additional sub-questions.
Args:
ctx: FastMCP context for dependency injection
context: Dictionary containing split problem results
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'context' (aggregated questions) and 'original' question
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
session_service = get_context_resource(ctx, 'session_service')
llm_client = get_context_resource(ctx, 'llm_client')
# Resolve session ID from usermessages
from app.core.memory.agent.utils.messages_tool import Resolve_username
sessionid = Resolve_username(usermessages)
# Get conversation history
history = await session_service.get_history(sessionid, apply_id, group_id)
# Override with empty list for now (as in original)
history = []
# Process context to extract questions
extent_quest, original = await Problem_Extension_messages_deal(context)
# Format questions for template rendering
questions_formatted = []
for msg in extent_quest:
if msg.get("role") == "user":
questions_formatted.append(msg.get("content", ""))
# Render template
try:
system_prompt = await template_service.render_template(
template_name='Problem_Extension_prompt.jinja2',
operation_name='problem_extension',
history=history,
questions=questions_formatted
)
except Exception as e:
logger.error(
f"Template rendering failed for Problem_Extension: {e}",
exc_info=True
)
return {
"context": {},
"original": original,
"error": f"Prompt rendering failed: {str(e)}"
}
# Call LLM with structured response
try:
response_content = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=ProblemExtensionResponse
)
# Aggregate results by original question
aggregated_dict = {}
for item in response_content.root:
key = getattr(item, "original_question", None) or (
item.get("original_question") if isinstance(item, dict) else None
)
value = getattr(item, "extended_question", None) or (
item.get("extended_question") if isinstance(item, dict) else None
)
if not key or not value:
continue
aggregated_dict.setdefault(key, []).append(value)
except Exception as e:
logger.error(
f"LLM call failed for Problem_Extension: {e}",
exc_info=True
)
aggregated_dict = {}
logger.info(f"问题扩展")
logger.info(f"问题扩展==>>:{aggregated_dict}")
# Emit intermediate output for frontend
result = {
"context": aggregated_dict,
"original": original,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "problem_extension",
"data": aggregated_dict,
"original_query": original,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
return result
except Exception as e:
logger.error(
f"Problem_Extension failed: {e}",
exc_info=True
)
return {
"context": {},
"original": context.get("original", ""),
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('问题扩展', duration)

View File

@@ -0,0 +1,282 @@
"""
Retrieval Tools for database and context retrieval.
This module contains MCP tools for retrieving data using hybrid search.
"""
from dotenv import load_dotenv
import os
from app.core.rag.nlp.search import knowledge_retrieval
# 加载.env文件
load_dotenv()
import time
from typing import List
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.llm_tools import deduplicate_entries, merge_to_key_value_pairs
from app.core.memory.agent.utils.messages_tool import Retriev_messages_deal
logger = get_agent_logger(__name__)
@mcp.tool()
async def Retrieve(
ctx: Context,
context,
usermessages: str,
apply_id: str,
group_id: str,
storage_type: str = "",
user_rag_memory_id: str = ""
) -> dict:
"""
Retrieve data from the database using hybrid search.
Args:
ctx: FastMCP context for dependency injection
context: Dictionary or string containing query information
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
user_rag_memory_id: User RAG memory identifier
Returns:
dict: Contains 'context' with Query and Expansion_issue results
"""
kb_config = {
"knowledge_bases": [
{
"kb_id": user_rag_memory_id,
"similarity_threshold": 0.7,
"vector_similarity_weight": 0.5,
"top_k": 10,
"retrieve_type": "participle"
}
],
"merge_strategy": "weight",
"reranker_id": os.getenv('reranker_id'),
"reranker_top_k": 10
}
start = time.time()
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
try:
# Extract services from context
search_service = get_context_resource(ctx, 'search_service')
databases_anser = []
# Handle both dict and string context
if isinstance(context, dict):
# Process dict context with extended questions
all_items = []
content, original = await Retriev_messages_deal(context)
# Extract all query items from content
# content is like {original_question: [extended_questions...], ...}
for key, values in content.items():
if isinstance(values, list):
all_items.extend(values)
elif isinstance(values, str):
all_items.append(values)
elif values is not None:
# Fallback: convert non-empty non-list values to string
all_items.append(str(values))
# Execute search for each question
for idx, question in enumerate(all_items):
try:
# Prepare search parameters based on storage type
search_params = {
"group_id": group_id,
"question": question,
"return_raw_results": True
}
# Add storage-specific parameters
if storage_type == "rag" and user_rag_memory_id:
retrieve_chunks_result = knowledge_retrieval(question, kb_config,[str(group_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
clean_content = '\n\n'.join(retrieval_knowledge)
cleaned_query=question
raw_results=clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except:
clean_content = ''
raw_results=''
cleaned_query = question
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
else:
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(**search_params)
databases_anser.append({
"Query_small": cleaned_query,
"Result_small": clean_content,
"_intermediate": {
"type": "search_result",
"query": cleaned_query,
"raw_results": raw_results,
"index": idx + 1,
"total": len(all_items)
}
})
except Exception as e:
logger.error(
f"Retrieve: hybrid_search failed for question '{question}': {e}",
exc_info=True
)
# Continue with empty result for this question
databases_anser.append({
"Query_small": question,
"Result_small": ""
})
# Build initial database data structure
databases_data = {
"Query": original,
"Expansion_issue": databases_anser
}
# Collect intermediate outputs before deduplication
intermediate_outputs = []
for item in databases_anser:
if '_intermediate' in item:
intermediate_outputs.append(item['_intermediate'])
# Deduplicate and merge results
deduplicated_data = deduplicate_entries(databases_data['Expansion_issue'])
deduplicated_data_merged = merge_to_key_value_pairs(
deduplicated_data,
'Query_small',
'Result_small'
)
# Restructure for Verify/Retrieve_Summary compatibility
keys, val = [], []
for item in deduplicated_data_merged:
for items_key, items_value in item.items():
keys.append(items_key)
val.append(items_value)
send_verify = []
for i, j in zip(keys, val):
send_verify.append({
"Query_small": i,
"Answer_Small": j
})
dup_databases = {
"Query": original,
"Expansion_issue": send_verify,
"_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs
}
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
else:
# Handle string context (simple query)
query = str(context).strip()
try:
# Prepare search parameters based on storage type
search_params = {
"group_id": group_id,
"question": query,
"return_raw_results": True
}
# Add storage-specific parameters
if storage_type == "rag" and user_rag_memory_id:
retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
clean_content = '\n\n'.join(retrieval_knowledge)
cleaned_query = query
raw_results = clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except:
clean_content = ''
raw_results = ''
cleaned_query = query
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
else:
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(**search_params)
# Keep structure for Verify/Retrieve_Summary compatibility
dup_databases = {
"Query": cleaned_query,
"Expansion_issue": [{
"Query_small": cleaned_query,
"Answer_Small": clean_content,
"_intermediate": {
"type": "search_result",
"query": cleaned_query,
"raw_results": raw_results,
"index": 1,
"total": 1
}
}]
}
except Exception as e:
logger.error(
f"Retrieve: hybrid_search failed for query '{query}': {e}",
exc_info=True
)
# Return empty results on failure
dup_databases = {
"Query": query,
"Expansion_issue": []
}
logger.info(
f"检索==>>:{storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, "
f"Expansion_issue count={len(dup_databases.get('Expansion_issue', []))}"
)
# Build result with intermediate outputs
result = {
"context": dup_databases,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
# Add intermediate outputs list if they exist
intermediate_outputs = dup_databases.get('_intermediate_outputs', [])
if intermediate_outputs:
result['_intermediates'] = intermediate_outputs
logger.info(f"Adding {len(intermediate_outputs)} intermediate outputs to result")
else:
logger.warning("No intermediate outputs found in dup_databases")
return result
except Exception as e:
logger.error(
f"Retrieve failed: {e}",
exc_info=True
)
return {
"context": {
"Query": "",
"Expansion_issue": []
},
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('检索', duration)

View File

@@ -0,0 +1,647 @@
"""
Summary Tools for data summarization.
This module contains MCP tools for summarizing retrieved data and generating responses.
"""
import json
import re
import time
from typing import List
from pydantic import BaseModel, Field
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.mcp_server.models.summary_models import (
SummaryData,
SummaryResponse,
RetrieveSummaryData,
RetrieveSummaryResponse
)
from app.core.memory.agent.utils.messages_tool import (
Summary_messages_deal,
Resolve_username
)
from app.core.rag.nlp.search import knowledge_retrieval
from dotenv import load_dotenv
import os
# 加载.env文件
load_dotenv()
logger = get_agent_logger(__name__)
@mcp.tool()
async def Summary(
ctx: Context,
context: str,
usermessages: str,
apply_id: str,
group_id: str,
storage_type: str = "",
user_rag_memory_id: str = ""
) -> dict:
"""
Summarize the verified data.
Args:
ctx: FastMCP context for dependency injection
context: JSON string containing verified data
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'status' and 'summary_result'
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
session_service = get_context_resource(ctx, 'session_service')
llm_client = get_context_resource(ctx, 'llm_client')
# Resolve session ID
sessionid = Resolve_username(usermessages)
# Process context to extract answer and query
answer_small, query = await Summary_messages_deal(context)
# Get conversation history
history = await session_service.get_history(sessionid, apply_id, group_id)
# Override with empty list for now (as in original)
# Prepare data for template
data = {
"query": query,
"history": history,
"retrieve_info": answer_small
}
except Exception as e:
logger.error(
f"Summary: initialization failed: {e}",
exc_info=True
)
return {
"status": "error",
"summary_result": "信息不足,无法回答"
}
try:
# Render template
system_prompt = await template_service.render_template(
template_name='summary_prompt.jinja2',
operation_name='summary',
data=data,
query=query
)
except Exception as e:
logger.error(
f"Template rendering failed for Summary: {e}",
exc_info=True
)
return {
"status": "error",
"message": f"Prompt rendering failed: {str(e)}"
}
try:
# Call LLM with structured response
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=SummaryResponse
)
aimessages = structured.query_answer or ""
except Exception as e:
logger.error(
f"LLM call failed for Summary: {e}",
exc_info=True
)
aimessages = ""
try:
# Save session
if aimessages != "":
await session_service.save_session(
user_id=sessionid,
query=query,
apply_id=apply_id,
group_id=group_id,
ai_response=aimessages
)
logger.info(f"sessionid: {aimessages} 写入成功")
except Exception as e:
logger.error(
f"sessionid: {sessionid} 写入失败,错误信息:{str(e)}",
exc_info=True
)
return {
"status": "error",
"message": str(e)
}
# Cleanup duplicate sessions
await session_service.cleanup_duplicates()
# Use fallback if empty
if aimessages == '':
aimessages = '信息不足,无法回答'
logger.info(f"验证之后的总结==>>:{aimessages}")
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('总结', duration)
return {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
@mcp.tool()
async def Retrieve_Summary(
ctx: Context,
context: dict,
usermessages: str,
apply_id: str,
group_id: str,
storage_type: str = "",
user_rag_memory_id: str = ""
) -> dict:
"""
Summarize data directly from retrieval results.
Args:
ctx: FastMCP context for dependency injection
context: Dictionary containing Query and Expansion_issue from Retrieve
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'status' and 'summary_result'
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
session_service = get_context_resource(ctx, 'session_service')
llm_client = get_context_resource(ctx, 'llm_client')
# Resolve session ID
sessionid = Resolve_username(usermessages)
# Handle both 'content' and 'context' keys (LangGraph uses 'content')
if isinstance(context, dict):
if "content" in context:
inner = context["content"]
# If it's a JSON string, parse it
if isinstance(inner, str):
try:
parsed = json.loads(inner)
logger.info(f"Retrieve_Summary: successfully parsed JSON")
except json.JSONDecodeError:
# Try unescaping first
try:
unescaped = inner.encode('utf-8').decode('unicode_escape')
parsed = json.loads(unescaped)
logger.info(f"Retrieve_Summary: parsed after unescaping")
except (json.JSONDecodeError, UnicodeDecodeError) as e:
logger.error(
f"Retrieve_Summary: parsing failed even after unescape: {e}"
)
context_dict = {"Query": "", "Expansion_issue": []}
parsed = None
if parsed:
# Check if parsed has 'context' wrapper
if isinstance(parsed, dict) and "context" in parsed:
context_dict = parsed["context"]
else:
context_dict = parsed
elif isinstance(inner, dict):
context_dict = inner
else:
context_dict = {"Query": "", "Expansion_issue": []}
elif "context" in context:
context_dict = context["context"] if isinstance(context["context"], dict) else context
else:
context_dict = context
else:
context_dict = {"Query": "", "Expansion_issue": []}
query = context_dict.get("Query", "")
expansion_issue = context_dict.get("Expansion_issue", [])
# Extract retrieve_info from expansion_issue
retrieve_info = []
for item in expansion_issue:
# Check for both Answer_Small and Answer_Samll (typo) for backward compatibility
answer = None
if isinstance(item, dict):
if "Answer_Small" in item:
answer = item["Answer_Small"]
elif "Answer_Samll" in item:
answer = item["Answer_Samll"]
if answer is not None:
# Handle both string and list formats
if isinstance(answer, list):
# Join list of characters/strings into a single string
retrieve_info.append(''.join(str(x) for x in answer))
elif isinstance(answer, str):
retrieve_info.append(answer)
else:
retrieve_info.append(str(answer))
# Join all retrieve_info into a single string
retrieve_info_str = '\n\n'.join(retrieve_info) if retrieve_info else ""
# Get conversation history
history = await session_service.get_history(sessionid, apply_id, group_id)
# Override with empty list for now (as in original)
except Exception as e:
logger.error(
f"Retrieve_Summary: initialization failed: {e}",
exc_info=True
)
return {
"status": "error",
"summary_result": "信息不足,无法回答"
}
try:
# Render template
system_prompt = await template_service.render_template(
template_name='Retrieve_Summary_prompt.jinja2',
operation_name='retrieve_summary',
query=query,
history=history,
retrieve_info=retrieve_info_str
)
except Exception as e:
logger.error(
f"Template rendering failed for Retrieve_Summary: {e}",
exc_info=True
)
return {
"status": "error",
"message": f"Prompt rendering failed: {str(e)}"
}
try:
# Call LLM with structured response
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=RetrieveSummaryResponse
)
# Handle case where structured response might be None or incomplete
if structured and hasattr(structured, 'data') and structured.data:
aimessages = structured.data.query_answer or ""
else:
logger.warning("Structured response is None or incomplete, using default message")
aimessages = "信息不足,无法回答"
# Check for insufficient information response
if '信息不足,无法回答' not in str(aimessages) or str(aimessages)!="":
# Save session
await session_service.save_session(
user_id=sessionid,
query=query,
apply_id=apply_id,
group_id=group_id,
ai_response=aimessages
)
logger.info(f"sessionid: {aimessages} 写入成功")
except Exception as e:
logger.error(
f"Retrieve_Summary: LLM call failed: {e}",
exc_info=True
)
aimessages = ""
# Cleanup duplicate sessions
await session_service.cleanup_duplicates()
# Use fallback if empty
if aimessages == '':
aimessages = '信息不足,无法回答'
logger.info(f"检索之后的总结==>>:{aimessages}")
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('检索总结', duration)
# Emit intermediate output for frontend
return {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "retrieval_summary",
"summary": aimessages,
"query": query,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
@mcp.tool()
async def Input_Summary(
ctx: Context,
context: str,
usermessages: str,
search_switch: str,
apply_id: str,
group_id: str,
storage_type: str = "",
user_rag_memory_id: str = ""
) -> dict:
"""
Generate a quick summary for direct input without verification.
Args:
ctx: FastMCP context for dependency injection
context: String containing the input sentence
usermessages: User messages identifier
search_switch: Search switch value for routing ('2' for summaries only)
apply_id: Application identifier
group_id: Group identifier
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
user_rag_memory_id: User RAG memory identifier
Returns:
dict: Contains 'query_answer' with the summary result
"""
start = time.time()
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
# Initialize variables to avoid UnboundLocalError
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
session_service = get_context_resource(ctx, 'session_service')
llm_client = get_context_resource(ctx, 'llm_client')
search_service = get_context_resource(ctx, 'search_service')
# Check if llm_client is None
if llm_client is None:
error_msg = "LLM client is not available. Please check server configuration and SELECTED_LLM_ID environment variable."
logger.error(error_msg)
return error_msg
# Resolve session ID
sessionid = Resolve_username(usermessages) or ""
sessionid = sessionid.replace('call_id_', '')
# Get conversation history
history = await session_service.get_history(
str(sessionid),
str(apply_id),
str(group_id)
)
# Override with empty list for now (as in original)
# Log the raw context for debugging
logger.info(f"Input_Summary: Received context type={type(context)}, value={context[:200] if isinstance(context, str) else context}")
# Extract sentence from context
# Context can be a string or might contain the sentence in various formats
try:
# Try to parse as JSON first
if isinstance(context, str) and (context.startswith('{') or context.startswith('[')):
try:
import json
context_dict = json.loads(context)
if isinstance(context_dict, dict):
query = context_dict.get('sentence', context_dict.get('content', context))
else:
query = context
except json.JSONDecodeError:
# Not valid JSON, try regex
match = re.search(r"'sentence':\s*['\"]?(.*?)['\"]?\s*,", context)
query = match.group(1) if match else context
else:
query = context
except Exception as e:
logger.warning(f"Failed to extract query from context: {e}")
query = context
# Clean query
query = str(query).strip().strip("\"'")
logger.debug(f"Input_Summary: Extracted query='{query}' from context type={type(context)}")
# Execute search based on search_switch and storage_type
try:
logger.info(f"search_switch: {search_switch}, storage_type: {storage_type}")
# Prepare search parameters based on storage type
search_params = {
"group_id": group_id,
"question": query,
"return_raw_results": True
}
# Add storage-specific parameters
'''检索'''
if search_switch == '2':
search_params["include"] = ["summaries"]
if storage_type == "rag" and user_rag_memory_id:
raw_results = []
retrieve_info = ""
kb_config={
"knowledge_bases": [
{
"kb_id": user_rag_memory_id,
"similarity_threshold": 0.7,
"vector_similarity_weight": 0.5,
"top_k": 10,
"retrieve_type": "participle"
}
],
"merge_strategy": "weight",
"reranker_id":os.getenv('reranker_id'),
"reranker_top_k": 10
}
retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
retrieve_info = '\n\n'.join(retrieval_knowledge)
raw_results=[retrieve_info]
logger.info(f"Input_Summary: Using RAG storage with memory_id={user_rag_memory_id}")
except:
retrieve_info=''
raw_results=['']
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
else:
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(**search_params)
logger.info(f"Input_Summary: 使用 summary 进行检索")
else:
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(**search_params)
except Exception as e:
logger.error(
f"Input_Summary: hybrid_search failed, using empty results: {e}",
exc_info=True
)
retrieve_info, question, raw_results = "", query, []
# Render template
system_prompt = await template_service.render_template(
template_name='Retrieve_Summary_prompt.jinja2',
operation_name='input_summary',
query=query,
history=history,
retrieve_info=retrieve_info
)
# Call LLM with structured response
try:
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=RetrieveSummaryResponse
)
aimessages = structured.data.query_answer or "信息不足,无法回答"
except Exception as e:
logger.error(
f"Input_Summary: response_structured failed, using default answer: {e}",
exc_info=True
)
aimessages = "信息不足,无法回答"
logger.info(f"快速答案总结==>>:{storage_type}--{user_rag_memory_id}--{aimessages}")
# Emit intermediate output for frontend
return {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "input_summary",
"title": "快速答案",
"summary": aimessages,
"query": query,
"raw_results": raw_results,
"search_mode": "quick_search",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
except Exception as e:
logger.error(
f"Input_Summary failed: {e}",
exc_info=True
)
return {
"status": "fail",
"summary_result": "信息不足,无法回答",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('检索', duration)
@mcp.tool()
async def Summary_fails(
ctx: Context,
context: str,
usermessages: str,
apply_id: str,
group_id: str,
storage_type: str = "",
user_rag_memory_id: str = ""
) -> dict:
"""
Handle workflow failure when summary cannot be generated.
Args:
ctx: FastMCP context for dependency injection
context: Failure context string
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'query_answer' with failure message
"""
try:
# Extract services from context
session_service = get_context_resource(ctx, 'session_service')
# Parse session ID from usermessages
usermessages_parts = usermessages.split('_')[1:]
sessionid = '_'.join(usermessages_parts[:-1])
# Cleanup duplicate sessions
await session_service.cleanup_duplicates()
logger.info(f"没有相关数据")
logger.debug(f"Summary_fails called with apply_id: {apply_id}, group_id: {group_id}")
return {
"status": "success",
"summary_result": "没有相关数据",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
except Exception as e:
logger.error(
f"Summary_fails failed: {e}",
exc_info=True
)
return {
"status": "fail",
"summary_result": "没有相关数据",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}

View File

@@ -0,0 +1,169 @@
"""
Verification Tools for data verification.
This module contains MCP tools for verifying retrieved data.
"""
import time
from jinja2 import Template
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.verify_tool import VerifyTool
from app.core.memory.agent.utils.messages_tool import (
Verify_messages_deal,
Retrieve_verify_tool_messages_deal,
Resolve_username
)
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
logger = get_agent_logger(__name__)
@mcp.tool()
async def Verify(
ctx: Context,
context: dict,
usermessages: str,
apply_id: str,
group_id: str,
storage_type: str = "",
user_rag_memory_id: str = ""
) -> dict:
"""
Verify the retrieved data.
Args:
ctx: FastMCP context for dependency injection
context: Dictionary containing query and expansion issues
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'status' and 'verified_data' with verification results
"""
start = time.time()
try:
# Extract services from context
session_service = get_context_resource(ctx, 'session_service')
# Load verification prompt template
file_path = PROJECT_ROOT_ + '/agent/utils/prompt/split_verify_prompt.jinja2'
# Read template file directly (VerifyTool expects raw template content)
from app.core.memory.agent.utils.messages_tool import read_template_file
system_prompt = await read_template_file(file_path)
# Resolve session ID
sessionid = Resolve_username(usermessages)
# Get conversation history
history = await session_service.get_history(sessionid, apply_id, group_id)
template = Template(system_prompt)
system_prompt = template.render(history=history, sentence=context)
# Process context to extract query and results
Query_small, Result_small, query = await Verify_messages_deal(context)
# Build query list for verification
query_list = []
for query_small, anser in zip(Query_small, Result_small):
query_list.append({
'Query_small': query_small,
'Answer_Small': anser
})
messages = {
"Query": query,
"Expansion_issue": query_list
}
# Call verification workflow
verify_tool = VerifyTool(system_prompt, messages)
verify_result = await verify_tool.verify()
# Parse LLM verification result with error handling
try:
messages_deal = await Retrieve_verify_tool_messages_deal(
verify_result,
history,
query
)
except Exception as e:
logger.error(
f"Retrieve_verify_tool_messages_deal parsing failed: {e}",
exc_info=True
)
# Fallback to avoid 500 errors
messages_deal = {
"data": {
"query": query,
"expansion_issue": []
},
"split_result": "failed",
"reason": str(e),
"history": history,
}
logger.info(f"验证==>>:{messages_deal}")
# Emit intermediate output for frontend
return {
"status": "success",
"verified_data": messages_deal,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "verification",
"title": "数据验证",
"result": messages_deal.get("split_result", "unknown"),
"reason": messages_deal.get("reason", ""),
"query": query,
"verified_count": len(query_list),
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
except Exception as e:
logger.error(
f"Verify failed: {e}",
exc_info=True
)
return {
"status": "error",
"message": str(e),
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"verified_data": {
"data": {
"query": "",
"expansion_issue": []
},
"split_result": "failed",
"reason": str(e),
"history": [],
}
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('验证', duration)

View File

@@ -0,0 +1,7 @@
"""Agent utilities."""
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
__all__ = [
"MultimodalProcessor",
]

View File

@@ -0,0 +1,70 @@
import os
import json
from typing import List
from datetime import datetime
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
async def get_chunked_dialogs(
chunker_strategy: str = "RecursiveChunker",
group_id: str = "group_1",
user_id: str = "user1",
apply_id: str = "applyid",
content: str = "这是用户的输入",
ref_id: str = "wyl_20251027",
config_id: str = None
) -> List[DialogData]:
"""Generate chunks from all test data entries using the specified chunker strategy.
Args:
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
group_id: Group identifier
user_id: User identifier
apply_id: Application identifier
content: Dialog content
ref_id: Reference identifier
config_id: Configuration ID for processing
Returns:
List of DialogData objects with generated chunks for each test entry
"""
dialog_data_list = []
messages = []
messages.append(ConversationMessage(role="用户", msg=content))
# Create DialogData
conversation_context = ConversationContext(msgs=messages)
# Create DialogData with group_id based on the entry's id for uniqueness
dialog_data = DialogData(
context=conversation_context,
ref_id=ref_id,
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
config_id=config_id
)
# Create DialogueChunker and process the dialogue
chunker = DialogueChunker(chunker_strategy)
extracted_chunks = await chunker.process_dialogue(dialog_data)
dialog_data.chunks = extracted_chunks
dialog_data_list.append(dialog_data)
# Convert to dict with datetime serialized
def serialize_datetime(obj):
if isinstance(obj, datetime):
return obj.isoformat()
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
combined_output = [dd.model_dump() for dd in dialog_data_list]
print(dialog_data_list)
# with open(os.path.join(os.path.dirname(__file__), "chunker_test_output.txt"), "w", encoding="utf-8") as f:
# json.dump(combined_output, f, ensure_ascii=False, indent=4, default=serialize_datetime)
return dialog_data_list

View File

@@ -0,0 +1,204 @@
import asyncio
import json
from collections import defaultdict
from typing import TypedDict, Annotated
import os
import logging
from jinja2 import Template
from langchain_core.messages import AnyMessage
from dotenv import load_dotenv
from langgraph.graph import add_messages
from openai import OpenAI
from app.core.memory.agent.utils.messages_tool import read_template_file
from app.core.memory.utils.config.config_utils import get_picture_config, get_voice_config
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID, SELECTED_LLM_PICTURE_NAME, SELECTED_LLM_VOICE_NAME
from app.core.models.base import RedBearModelConfig
from app.core.memory.src.llm_tools.openai_client import OpenAIClient
PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
logger = logging.getLogger(__name__)
load_dotenv()
#TODO: Refactor entire picture/voice
# async def LLM_model_request(context,data,query):
# '''
# Agent model request
# Args:
# context:Input request
# data: template parameters
# query:request content
# Returns:
# '''
# template = Template(context)
# system_prompt = template.render(**data)
# llm_client = get_llm_client(SELECTED_LLM_ID)
# result = await llm_client.chat(
# messages=[{"role": "system", "content": system_prompt}] + [{"role": "user", "content": query}]
# )
# return result
async def picture_model_requests(image_url):
'''
Args:
image_url:
Returns:
'''
file_path = PROJECT_ROOT_ + '/agent/utils/prompt/Template_for_image_recognition_prompt.jinja2 '
system_prompt = await read_template_file(file_path)
result = await Picture_recognize(image_url,system_prompt)
return (result)
class WriteState(TypedDict):
'''
Langgrapg Writing TypedDict
'''
messages: Annotated[list[AnyMessage], add_messages]
user_id:str
apply_id:str
group_id:str
class ReadState(TypedDict):
'''
Langgrapg READING TypedDict
name:
id:user id
loop_count:Traverse times
search_switchtype
config_id: configuration id for filtering results
'''
messages: Annotated[list[AnyMessage], add_messages] #消息追加的模式增加消息
name: str
id: str
loop_count:int
search_switch: str
user_id: str
apply_id: str
group_id: str
config_id: str
class COUNTState:
'''
The number of times the workflow dialogue retrieval content has no correct message recall traversal
'''
def __init__(self, limit: int = 5):
self.total: int = 0 # 当前累加值
self.limit: int = limit # 最大上限
def add(self, value: int = 1):
"""累加数字,如果达到上限就保持最大值"""
self.total += value
print(f"[COUNTState] 当前值: {self.total}")
if self.total >= self.limit:
print(f"[COUNTState] 达到上限 {self.limit}")
self.total = self.limit # 达到上限不再增加
def get_total(self) -> int:
"""获取当前累加值"""
return self.total
def reset(self):
"""手动重置累加值"""
self.total = 0
print(f"[COUNTState] 已重置为 0")
# def embed(texts: list[str]) -> list[list[float]]:
# # 这里可以换成 LangChain Embeddings
# return [[float(len(t) % 5), float(len(t) % 3)] for t in texts]
# def export_store_to_json(store, namespace):
# """Export the entire storage content to a JSON file"""
# # 搜索所有存储项
# all_items = store.search(namespace)
# # 整理数据
# export_data = {}
# for item in all_items:
# if hasattr(item, 'key') and hasattr(item, 'value'):
# export_data[item.key] = item.value
# # 保存到文件
# os.makedirs("memory_logs", exist_ok=True)
# with open("memory_logs/full_memory_export.json", "w", encoding="utf-8") as f:
# json.dump(export_data, f, ensure_ascii=False, indent=2)
# print(f"{len(export_data)} 条记忆到 JSON 文件")
def merge_to_key_value_pairs(data, query_key, result_key):
grouped = defaultdict(list)
for item in data:
grouped[item[query_key]].append(item[result_key])
return [{key: values} for key, values in grouped.items()]
def deduplicate_entries(entries):
seen = set()
deduped = []
for entry in entries:
key = (entry['Query_small'], entry['Result_small'])
if key not in seen:
seen.add(key)
deduped.append(entry)
return deduped
async def Picture_recognize(image_path,PROMPT_TICKET_EXTRACTION) -> str:
try:
model_config = get_picture_config(SELECTED_LLM_PICTURE_NAME)
except Exception as e:
err = f"LLM配置不可用{str(e)}。请检查 config.json 和 runtime.json。"
logger.error(err)
return err
api_key = os.getenv(model_config["api_key"]) # 从环境变量读取对应后端的 API key
backend_model_name = model_config["llm_name"].split("/")[-1]
api_base=model_config['api_base']
logger.info(f"model_name: {backend_model_name}")
logger.info(f"api_key set: {'yes' if api_key else 'no'}")
logger.info(f"base_url: {model_config['api_base']}")
client = OpenAI(
api_key=api_key, base_url=api_base,
)
completion = client.chat.completions.create(
model=backend_model_name,
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url":image_path,
},
{"type": "text",
"text": PROMPT_TICKET_EXTRACTION}
]
}
])
picture_text = completion.choices[0].message.content
picture_text = picture_text.replace('```json', '').replace('```', '')
picture_text = json.loads(picture_text)
return (picture_text['statement'])
async def Voice_recognize():
try:
model_config = get_voice_config(SELECTED_LLM_VOICE_NAME)
except Exception as e:
err = f"LLM配置不可用{str(e)}。请检查 config.json 和 runtime.json。"
logger.error(err)
return err
api_key = os.getenv(model_config["api_key"]) # 从环境变量读取对应后端的 API key
backend_model_name = model_config["llm_name"].split("/")[-1]
api_base = model_config['api_base']
return api_key,backend_model_name,api_base

View File

@@ -0,0 +1,15 @@
from app.core.config import settings
def get_mcp_server_config():
"""
Get the MCP server configuration
"""
mcp_server_config = {
"data_flow": {
"url": f"http://{settings.SERVER_IP}:8081/sse", # 你前面的 FastMCP(weather) 服务端口
"transport": "sse",
"timeout": 15000,
"sse_read_timeout": 15000,
}
}
return mcp_server_config

View File

@@ -0,0 +1,239 @@
import json
import logging
import re
from typing import List, Any
from langchain_core.messages import AnyMessage
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
def _to_openai_messages(msgs: List[AnyMessage]) -> List[dict]:
out = []
for m in msgs:
if hasattr(m, "content"):
out.append({"role": "user", "content": getattr(m, "content", "")})
elif isinstance(m, dict) and "role" in m and "content" in m:
out.append(m)
else:
out.append({"role": "user", "content": str(m)})
return out
def _extract_content(resp: Any) -> str:
"""Extract LLM content and sanitize to raw JSON/text.
- Supports both object and dict response shapes.
- Removes leading role labels (e.g., "Assistant:").
- Strips Markdown code fences like ```json ... ```.
- Attempts to isolate the first valid JSON array/object block when extra text is present.
"""
def _to_text(r: Any) -> str:
try:
# 对象形式: resp.choices[0].message.content
if hasattr(r, "choices") and getattr(r, "choices", None):
msg = r.choices[0].message
if hasattr(msg, "content"):
return msg.content
if isinstance(msg, dict) and "content" in msg:
return msg["content"]
# 字典形式: resp["choices"][0]["message"]["content"]
if isinstance(r, dict):
return r.get("choices", [{}])[0].get("message", {}).get("content", "")
except Exception:
pass
return str(r)
def _clean_text(text: str) -> str:
s = str(text).strip()
# 移除可能的角色前缀
s = re.sub(r"^\s*(Assistant|assistant)\s*:\s*", "", s)
# 提取 ```json ... ``` 代码块
m = re.search(r"```json\s*(.*?)\s*```", s, flags=re.S | re.I)
if m:
s = m.group(1).strip()
# 如果仍然包含多余文本,尝试截取第一个 JSON 数组/对象片段
if not (s.startswith("{") or s.startswith("[")):
left = s.find("[")
right = s.rfind("]")
if left != -1 and right != -1 and right > left:
s = s[left:right + 1].strip()
else:
left = s.find("{")
right = s.rfind("}")
if left != -1 and right != -1 and right > left:
s = s[left:right + 1].strip()
return s
raw = _to_text(resp)
return _clean_text(raw)
def Resolve_username(usermessages):
'''
Extract username
Args:
usermessages: user name
Returns:
'''
usermessages = usermessages.split('_')[1:]
sessionid = '_'.join(usermessages[:-1])
return sessionid
# TODO: USE app.core.memory.src.utils.render_template instead
async def read_template_file(template_path: str) -> str:
"""
读取模板文件
Args:
template_path: 模板文件路径
Returns:
模板内容字符串
Note:
建议使用 app.core.memory.utils.template_render 中的统一模板渲染功能
"""
try:
with open(template_path, "r", encoding="utf-8") as f:
return f.read()
except FileNotFoundError:
logger.error(f"模板文件未找到: {template_path}")
raise
except IOError as e:
logger.error(f"读取模板文件失败: {template_path}, 错误: {str(e)}", exc_info=True)
raise
async def Problem_Extension_messages_deal(context):
'''
Extract data
Args:
context:
Returns:
'''
extent_quest = []
original = context.get('original', '')
messages = context.get('context', '')
messages = json.loads(messages)
for message in messages:
question = message.get('question', '')
type = message.get('type', '')
extent_quest.append({"role": "user", "content": f"问题:{question};问题类型:{type}"})
return extent_quest, original
async def Retriev_messages_deal(context):
'''
Extract data
Args:
context:
Returns:
'''
if isinstance(context, dict):
if 'context' in context or 'original' in context:
return context.get('context', {}), context.get('original', '')
return content, original_value
async def Verify_messages_deal(context):
'''
Extract data
Args:
context:
Returns:
'''
query = context['context']['Query']
Query_small_list = context['context']['Expansion_issue']
Result_small = []
Query_small = []
for i in Query_small_list:
Result_small.append(i['Answer_Small'][0])
Query_small.append(i['Query_small'])
return Query_small, Result_small, query
async def Summary_messages_deal(context):
'''
Extract data
Args:
context:
Returns:
'''
messages = str(context).replace('\\n', '').replace('\n', '').replace('\\', '')
query = re.findall(r'"query": (.*?),', messages)[0]
query = query.replace('[', '').replace(']', '').strip()
matches = re.findall(r'"answer_small"\s*:\s*"(\[.*?\])"', messages)
answer_small_texts = []
for m in matches:
try:
parsed = json.loads(m)
for item in parsed:
answer_small_texts.append(item.strip().replace('\\', '').replace('[', '').replace(']', ''))
except Exception:
answer_small_texts.append(m.strip().replace('\\', '').replace('[', '').replace(']', ''))
return answer_small_texts, query
async def VerifyTool_messages_deal(context):
'''
Extract data
Args:
context:
Returns:
'''
messages = str(context).replace('\\n', '').replace('\n', '').replace('\\', '')
content_messages = messages.split('"context":')[1].replace('""', '"')
messages = str(content_messages).split("name='Retrieve'")[0]
query = re.findall(f'"Query": "(.*?)"', messages)[0]
Query_small = re.findall(f'"Query_small": "(.*?)"', messages)
Result_small = re.findall(f'"Result_small": "(.*?)"', messages)
return Query_small, Result_small, query
async def Retrieve_Summary_messages_deal(context):
pass
async def Retrieve_verify_tool_messages_deal(context, history, query):
'''
Extract data
Args:
context:
Returns:
'''
results = []
# 统一转为字符串,避免 None 或非字符串导致正则报错
text = str(context)
blocks = re.findall(r'\{(.*?)\}', text, flags=re.S)
for block in blocks:
query_small = re.search(r'"Query_small"\s*:\s*"([^"]*)"', block)
answer_small = re.search(r'"Answer_Small"\s*:\s*(\[[^\]]*\])', block)
status = re.search(r'"status"\s*:\s*"([^"]*)"', block)
query_answer = re.search(r'"Query_answer"\s*:\s*"([^"]*)"', block)
results.append({
"query_small": query_small.group(1) if query_small else None,
"answer_small": answer_small.group(1) if answer_small else None,
# 将缺失的 status 统一为空字符串,后续用字符串判定,避免 NoneType 错误
"status": status.group(1) if status else "",
"query_answer": query_answer.group(1) if query_answer else None
})
result = []
for r in results:
# 统一按字符串判定状态,兼容大小写和缺失情况
status_str = str(r.get('status', '')).strip().lower()
if status_str == 'false':
continue
else:
result.append(r)
split_result = 'failed' if not result else 'success'
result = {"data": {"query": query, "expansion_issue": result}, "split_result": split_result, "reason": "",
"history": history}
return result

View File

@@ -0,0 +1,38 @@
# project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# sys.path.insert(0, project_root)
# load_dotenv()
# async def llm_client_chat(messages: List[dict]) -> str:
# """使用 OpenAI 兼容接口进行对话,返回内容字符串。"""
# try:
# cfg = get_model_config(SELECTED_LLM_ID)
# rb_config = RedBearModelConfig(
# model_name=cfg["model_name"],
# provider=cfg["provider"],
# api_key=cfg["api_key"],
# base_url=cfg["base_url"],
# )
# client = OpenAIClient(model_config=rb_config, type_="chat")
# except Exception as e:
# logger.error(f"获取模型配置失败:{e}")
# err = f"获取模型配置失败:{str(e)}。请检查!!!"
# return err
# try:
# response = await client.chat(messages)
# print(f"model_tool's llm_client_chat response ======>:\n {response}")
# return _extract_content(response)
# # return _extract_content(result)
# except Exception as e:
# logger.error(f"LLM调用失败{str(e)}。请检查 model_name、api_key、api_base 是否正确。")
# return f"LLM调用失败{str(e)}。请检查 model_name、api_key、api_base 是否正确。"
# async def main(image_url):
# await llm_client_chat(image_url)
#
# # 运行主函数
# asyncio.run(main(['https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_male2.wav']))
#

View File

@@ -0,0 +1,131 @@
"""
Multimodal input processor for handling image and audio content.
This module provides utilities for detecting and processing multimodal inputs
(images and audio files) by converting them to text using appropriate models.
"""
import logging
from typing import List
from app.core.memory.agent.multimodal.speech_model import Vico_recognition
from app.core.memory.agent.utils.llm_tools import picture_model_requests
logger = logging.getLogger(__name__)
class MultimodalProcessor:
"""
Processor for handling multimodal inputs (images and audio).
This class detects image and audio file paths in input content and converts
them to text using appropriate recognition models.
"""
# Supported file extensions
IMAGE_EXTENSIONS = ['.jpg', '.png']
AUDIO_EXTENSIONS = [
'aac', 'amr', 'avi', 'flac', 'flv', 'm4a', 'mkv', 'mov',
'mp3', 'mp4', 'mpeg', 'ogg', 'opus', 'wav', 'webm', 'wma', 'wmv'
]
def __init__(self):
"""Initialize the multimodal processor."""
pass
def is_image(self, content: str) -> bool:
"""
Check if content is an image file path.
Args:
content: Input string to check
Returns:
True if content ends with a supported image extension
Examples:
>>> processor = MultimodalProcessor()
>>> processor.is_image("photo.jpg")
True
>>> processor.is_image("document.pdf")
False
"""
if not isinstance(content, str):
return False
content_lower = content.lower()
return any(content_lower.endswith(ext) for ext in self.IMAGE_EXTENSIONS)
def is_audio(self, content: str) -> bool:
"""
Check if content is an audio file path.
Args:
content: Input string to check
Returns:
True if content ends with a supported audio extension
Examples:
>>> processor = MultimodalProcessor()
>>> processor.is_audio("recording.mp3")
True
>>> processor.is_audio("video.mp4")
True
>>> processor.is_audio("document.txt")
False
"""
if not isinstance(content, str):
return False
content_lower = content.lower()
return any(content_lower.endswith(f'.{ext}') for ext in self.AUDIO_EXTENSIONS)
async def process_input(self, content: str) -> str:
"""
Process input content, converting images/audio to text if needed.
This method detects if the input is an image or audio file and converts
it to text using the appropriate recognition model. If processing fails
or the content is not multimodal, it returns the original content.
Args:
content: Input string (may be file path or regular text)
Returns:
Text content (original or converted from image/audio)
Examples:
>>> processor = MultimodalProcessor()
>>> await processor.process_input("photo.jpg")
"Recognized text from image..."
>>> await processor.process_input("Hello world")
"Hello world"
"""
if not isinstance(content, str):
logger.warning(f"[MultimodalProcessor] Content is not a string: {type(content)}")
return str(content)
try:
# Check for image input
if self.is_image(content):
logger.info(f"[MultimodalProcessor] Detected image input: {content}")
result = await picture_model_requests(content)
logger.info(f"[MultimodalProcessor] Image recognition result: {result[:100]}...")
return result
# Check for audio input
if self.is_audio(content):
logger.info(f"[MultimodalProcessor] Detected audio input: {content}")
result = await Vico_recognition([content]).run()
logger.info(f"[MultimodalProcessor] Audio recognition result: {result[:100]}...")
return result
except Exception as e:
logger.error(f"[MultimodalProcessor] Error processing multimodal input: {e}", exc_info=True)
logger.info(f"[MultimodalProcessor] Falling back to original content")
return content
# Return original content if not multimodal
return content

View File

@@ -0,0 +1,81 @@
你是一个高效的问题拆分助手,任务是根据用户提供的原始问题和问题类型,生成可操作的扩展问题,用于精确回答原问题。请严格遵循以下规则:
角色:
- 你是“问题拆分专家”,专注于逻辑、信息完整性和可操作性。
- 你能够结合【历史信息】、【上下文】、【背景知识】进行分析,以保持问题拆分的连贯性和相关性。
- 如果历史信息或上下文与当前问题无关,可忽略。
---
### 历史信息参考
在生成扩展问题时,你可以参考以下历史数据(如果提供):
- 历史对话或任务的主题;
- 历史中出现的关键实体(时间、人物、地点、研究主题等);
- 历史中已解答的问题(避免重复);
- 历史推理链(保持逻辑一致性)。
> 如果没有提供历史信息,则仅根据当前输入问题进行分析。
输入历史信息内容:{{history}}
## User Input
{% if questions is string %}
{{ questions }}
{% else %}
{% for question in questions %}
- {{ question }}
{% endfor %}
{% endif %}
需求:
- 如果问题是单跳问题(单步可答),直接保留原问题提取重要提问部分作为拆分/扩展问题。
- 如果问题是多跳问题(需多个信息点才能回答),对问题进行扩展拆分。
- 扩展问题必须完整覆盖原问题的所有关键要素,包括时间、主体、动作、目标等,不得遗漏。
- 扩展问题不得冗余:避免重复询问相同信息或过度拆分同一主题。
- 扩展问题必须高度相关:每个子问题直接服务于原问题,不引入未提及的新概念、人物或细节。
- 扩展问题必须可操作:每个子问题能在有限资源下独立解答。
- 子问题数量不超过4个。
- 拆分问题的时候可以考虑输入的历史内容,以保持逻辑连贯。
比如:输入历史信息内容:[{'Query': '4月27日我和你推荐过一本书书名是什么', 'ANswer': '张曼玉推荐了《小王子》'}]
拆分问题4月27日我和你推荐过一本书书名是什么可以拆分为4月27日张曼玉推荐过一本书书名是什么
输出要求:
- 仅输出 JSON 数组,不要包含任何解释或代码块。
- 每个元素包含:
- `original_question`: 原始问题
- `extended_question`: 扩展后的问题
- `type`: 类型(事实检索/澄清/定义/比较/行动建议)
- `reason`: 生成该扩展问题的简短理由
- 使用标准 ASCII 双引号,无换行;确保字符串正确关闭并以逗号分隔。
示例:
输入:
[
"问题:今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?;问题类型:多跳",
]
输出:
[
{
"original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?",
"extended_question": "今年诺贝尔物理学奖的获奖者有哪些人?",
"type": "多跳",
"reason": "输出原问题的关键要素"
},
{
"original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?",
"extended_question": "今年诺贝尔物理学奖的获奖者是因哪些具体贡献获奖的?",
"type": "多跳",
"reason": "输出原问题的关键要素"
}
]
**Output format**
**CRITICAL JSON FORMATTING REQUIREMENTS:**
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
3. Ensure all JSON strings are properly closed and comma-separated
4. Do not include line breaks within JSON string values
The output language should always be the same as the input language.{{ json_schema }}

View File

@@ -0,0 +1,37 @@
# 角色
你是一个专业的问答助手,擅长基于检索信息和历史对话回答用户问题。
# 任务
根据提供的上下文信息回答用户的问题。
# 输入信息
- 历史对话:{{history}}
- 检索信息:{{retrieve_info}}
## User Query
{{query}}
# 回答指南
1. 仔细分析用户的问题
2. 优先使用检索信息中的相关内容回答
3. 结合历史对话提供连贯的回复
4. 如果信息不足:
- 对于简单问候或日常对话,给出自然简短的回复
- 对于复杂问题,诚实说明信息不足
5. 保持回答简洁、相关、自然
6. 使用与问题相同的语言回答
**Output format**
- 直接回答问题,像人类对话一样自然流畅
- 不要提及"检索信息"、"搜索结果"、"根据资料"等技术术语
- 不要解释推理过程或评论信息来源
- 如果只能部分回答问题,先回答能回答的部分,然后说明哪些方面信息不足
- 如果完全无法回答,简洁地说明:"信息不足,无法回答。"
**CRITICAL JSON FORMATTING REQUIREMENTS:**
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
3. Ensure all JSON strings are properly closed and comma-separated
4. Do not include line breaks within JSON string values
The output language should always be the same as the input language.{{ json_schema }}

View File

@@ -0,0 +1,29 @@
# 角色:{#InputSlot placeholder="角色名称" mode="input"#}{#/InputSlot#}
你是一个智能问答助手,任务如下
## 目标:
1. 接收一个字典,格式为 {'问题': [答案列表]}。
2. 接收一个问题(字典中的 key
3. 找到与问题匹配的答案列表。
4. 将答案列表合并成一句自然流畅的话:
- 如果答案有两条使用“是”连接例如“A是B”。
- 如果答案有三条或以上,使用“,并且”“另外”等自然连词,保证句子流畅。
5. 输出内容时只输出合并后的答案,不输出关键点或其他文字。
6. 如果问题未在字典中找到对应答案,请输出:
对不起,我没有找到相关信息。
输出要求:
- 文本形式
---
字典示例:
{
'今天的天气怎么样': ['今天天气很好', '今天是晴天']
}
问题示例:
今天的天气怎么样
输出要求:
今天天气很好,是晴天

View File

@@ -0,0 +1,10 @@
请提图像内的文本
返回数据格式以json方式输出,
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
- 关键的JSON格式要求{"statement":识别出的文本内容}
1.JSON结构仅使用标准ASCII双引号-切勿使用中文引号“”或其他Unicode引号
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
3.确保所有JSON字符串都正确关闭并以逗号分隔
4.JSON字符串值中不包括换行符
5.正确转义的例子“statement”“Zhang Xinhua said\”我非常喜欢这本书\""
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```

View File

@@ -0,0 +1,34 @@
你是一个输入分类助手,负责判断用户输入的意图类型。
## User Input
{{ user_query }}
请你根据以下规则判断:
1. 如果输入是在寻求信息、提问、请求解释、或疑问句(包括隐含的问题),则分类为 "question"。
2. 如果输入是命令、陈述、描述、感叹、或其他类型,不在寻求答案,则分类为 "other"。
只输出:
{
"type": "question"
}
{
"type": "other"
}
示例:
输入:"Python怎么读取文件"
输出:{"type": "question"}
输入:"帮我写个读取文件的函数"
输出:{"type": "other"}
输入:"今天是星期几?"
输出:{"type": "question"}
返回数据格式以json方式输出,
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
- 关键的JSON格式要求{"statement":识别出的文本内容}
1.JSON结构仅使用标准ASCII双引号-切勿使用中文引号“”或其他Unicode引号
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
3.确保所有JSON字符串都正确关闭并以逗号分隔
4.JSON字符串值中不包括换行符
5.正确转义的例子“statement”“Zhang Xinhua said\”我非常喜欢这本书\""
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```

View File

@@ -0,0 +1,160 @@
# 角色:{#InputSlot placeholder="角色名称" mode="input"#}{#/InputSlot#}
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
## 目标:
你需要根据以下类型对输入数据进行分类,并生成相应的拆分策略和示例。
---
### 历史信息参考
在生成扩展问题时,你可以参考以下历史数据(如果提供):
- 历史对话或任务的主题;
- 历史中出现的关键实体(时间、人物、地点、研究主题等);
- 历史中已解答的问题(避免重复);
- 历史推理链(保持逻辑一致性)。
> 如果没有提供历史信息,则仅根据当前输入问题进行分析。
输入历史信息内容:{{history}}
## User Input
{{ sentence }}
## 需求:
1:首先判断类型(单跳、多跳、开放域、时间)。
2:根据类型进行拆分。
3:拆分后的内容需保证信息完整且可独立处理。
4:对每个拆分条目,可附加示例或说明。
5:拆分问题的时候可以考虑输入的历史内容,以保持逻辑连贯。
比如:输入历史信息内容:[{'Query': '4月27日我和你推荐过一本书书名是什么', 'ANswer': '张曼玉推荐了《小王子》'}]
拆分问题4月27日我和你推荐过一本书书名是什么可以拆分为4月27日张曼玉推荐过一本书书名是什么
## 指令:
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
单跳Single-hop
描述:问题或数据只需要通过一步即可得到答案或完成拆分,不依赖其他信息。
拆分策略:直接识别核心信息或关键字段,生成可独立处理的片段。
示例:
输入数据:"请列出今年诺贝尔物理学奖的得主"
拆分结果:[
{
"id": "Q1",
"question": "今年诺贝尔物理学奖得主是谁",
"type": "单跳’",
}
]
注意: 当遇到上下文依赖问题时明确指出缺失的信息类型并且question可填写输入问题
多跳Multi-hop:
描述:问题或数据需要通过多步推理或跨多个信息源才能得到答案。
拆分策略:将问题拆解为多个子问题,每个子问题对应一个独立处理步骤,需要具备推理链条与逻辑连接数量。
示例:
输入数据:"今年诺贝尔物理学奖得主的研究领域及代表性成果"
拆分结果:
[
{
"id": "Q1",
"question": 今年诺贝尔物理学奖得主是谁?",
"type": "多跳’",
},
{
"id": "Q2",
"question": "该得主的研究领域是什么?",
"type": "多跳’",
},
{
"id": "Q3",
"question": "该得主的代表性成果有哪些?",
"type": "多跳’"
}
]
开放域Open-domain:
描述:问题或数据不局限于特定知识库,需要从大范围信息中检索和生成答案,而不是从一个已知的小范围数据源中查找。。
拆分策略:根据主题或关键实体拆分,同时保留上下文以便检索外部知识,问题涉及一般性、常识性、跨学科内容,可能是开放式回答(描述性、推理性、综合性)
需要外部知识检索或推理才能确定,比如:“为什么人类需要睡眠?”、“量子计算与经典计算的主要区别是什么?”。
示例:
输入数据:"介绍量子计算的最新研究进展"
拆分结果:
[
{
"id": "Q1",
"question": 量子计算的基本概念是什么?",
"type": "开放域’",
},
{
"id": "Q2",
"question": "当前量子计算的主要研究方向有哪些?",
"type": "开放域’",
},
{
"id": "Q3",
"question": "近期在量子计算领域有哪些重大进展?",
"type": "开放域’",
}
]
时间Temporal:
描述:问题或数据涉及时间维度,需要按时间顺序或时间点拆分。
拆分策略:根据事件时间或时间段拆分为独立条目或问题。
示例:
输入数据:"列出苹果公司过去五年的重大事件"
拆分结果:
[
{
"id": "Q1",
"question": 苹果公司2019年的重大事件有哪些",
"type": "时间’",
},
{
"id": "Q2",
"question": "苹果公司2020年的重大事件有哪些",
"type": "时间’",
},
{
"id": "Q3",
"question": "苹果公司2021年的重大事件有哪些",
"type": "时间’",
},
{
"id": "Q3",
"question": "苹果公司2022年的重大事件有哪些",
"type": "时间’",
}
,
{
"id": "Q4",
"question": "苹果公司2023年的重大事件有哪些",
"type": "时间’",
}
]
输出要求:
- 每个子问题包括:
- `id`: 子问题编号Q1, Q2...
- `question`: 子问题内容
- `type`: 类型(事实检索 / 澄清 / 定义 / 比较 / 行动建议等)
- `reason`: 拆分的理由(为什么要这样拆)
- 格式案例:
[
{
"id": "Q1",
"question": 量子计算的基本概念是什么?",
"type": "开放域’",
},
{
"id": "Q2",
"question": "当前量子计算的主要研究方向有哪些?",
"type": "开放域’",
},
{
"id": "Q3",
"question": "近期在量子计算领域有哪些重大进展?",
"type": "开放域’",
}
]
- 必须通过json.loads()的格式支持的形式输出
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
- 关键的JSON格式要求
1.JSON结构仅使用标准ASCII双引号-切勿使用中文引号“”或其他Unicode引号
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
3.确保所有JSON字符串都正确关闭并以逗号分隔
4.JSON字符串值中不包括换行符
5.正确转义的例子“statement”“Zhang Xinhua said\”我非常喜欢这本书\""
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```

View File

@@ -0,0 +1,60 @@
# 角色
你是验证专家
你的目标是针对用户的输入Query_Samll字段的提问和Answer_Samll的回答分析是不是回答Query_Samll这个字段的问题
{#以下可以采用先总括,再展开详细说明的方式,描述你希望智能体在每一个步骤如何进行工作,具体的工作步骤数量可以根据实际需求增删#}
## 工作步骤
1. 获取所有的Query_Samll字段和Answer_Samll字段
2. 分析Answer_Samll的回复是不是和Query_Samll有关系
3. 判断Answer_Samll和Query_Samll之间分析出来的关系状态
4. 如果是True保留否则不要相对应的问题和回答
5. 输出,需要严格按照模版
输入:{{history}}
历史消息:{"history":{{sentence}}}
### 第一步 获取用户的输入
获取用户的输入提取对应的Query_Samll和Answer_Samll
### 第二步 分析验证
需要分析Query_Samll和Answer_Samll之间的关系可以参考history字段的内容如果有关系不是答非所问
## 核心验证标准
在评估子问题拆分时必须严格遵循以下标准且验证过程中完全不依赖于子问题的相关信息Answer_Samll
1. 合理性标准(必须全部满足)
- 完整性:每个不同的子问题必须完整覆盖原问题的所有关键要素(如时间、主体、动作、目标等),无遗漏。
- 最小化每个不同的子问题数量应尽可能少通常不超过原问题关键要素数量的2倍建议2-4个避免冗余和不必要拆分。
- 相关性:每个不同的子问题必须直接服务于原问题的解答,不引入无关内容或扩展原问题未提及的主题。
- 可操作性:每个不同的子问题应能在有限资源(如标准工具或合理时间)内独立解答,且难度适中。
- 逻辑性:每个不同的子问题间应有清晰的逻辑关系(如并列、递进、因果),共同构成原问题的解答路径。
2. 不合理拆分的特征(出现任一特征即为不合理):
- 不同的子问题数量超过5个或明显多于必要数量。
- 引入原问题未提及的新主题、人物、细节或个人看法。
- 拆分过于细碎,失去实用价值,无法高效合成原问题答案。
3. 特殊情况说明:
- 每个不同的子问题与原问题相同,需进一步判断:
- 每个不同的子问题不可进一步拆分 → success合理最小化拆分
- 每个不同的子问题能够进一步拆分为更小、更合理的问题 → failed不合理拆分没有最小化
- 每个不同的子问题数量=原问题核心要素数量 → success理想情况
- 每个不同的子问题数量=核心要素数量+1 → success通常合理
### 第三步 添加状态
如果有相关性并且比较高给一个状态TRUE否则给一个FLASE的状态
### 第四步 判断
如果状态是TRUE保留这条数据否则需不需要这条数据
### 第五步 输出格式
按照json的形式输出
{"data":"Query":原来Query的字段"history":原来的history字段
"expansion_issue":以为列表的形式存储验证之后的数据比如[
{"query_small": query_small,
"answer_small": answer_small,,
"status": 回答的结果是否符合query_small填写状态,
"query_answer": answer_small},
{
"query_small": "张曼婷生日是什么时候?",
"answer_small": "张曼婷喜欢绘画。",
"status": "True",
"query_answer": "张曼 婷喜欢绘画。"
},{}......]
,
"split_result":如果expansion_issue是空的列表返回failed不是空列表返回success,
"reason": 为以上分析完之后的结果给一个说明
}

View File

@@ -0,0 +1,57 @@
{# 角色定义 #}
你是专业的问题解答专家,负责根据上下文信息和检索到的所有信息准确回答用户的问题。
{# 输入数据展示 #}
{% if data %}
## 输入数据
上下文信息:
{% for item in data.history %}
- {{ item }}
{% endfor %}
检索到的所有信息:
{% for item in data.retrieve_info %}
- {{ item }}
{% endfor %}
{% endif %}
## User Query
{{ query }}
{# 问题回答标准 #}
## 问题回答核心标准
根据上下文信息(history)和检索到的所有信息(retrieve_info)准确回答用户的问题(query)。注意,若不能根据已有信息回答用户的问题,应直接回复“信息不足,无法回答。”,不能自己编造答案。
- 若能根据已有信息回答用户的问题,应根据上下文信息和检索到的所有信息提供简明扼要的答案。
- 若不能根据已有信息回答用户的问题,应直接回复“信息不足,无法回答。”,不能自己编造答案。
{# 重要提醒 #}
再次提醒,给出问题的答案时,仅根据已有的信息进行回答,不能自己编造答案。
{# 输出格式模板 #}
## 输出格式
严格按照以下JSON格式输出不添加任何其他内容
{
"data": {
"query": "{{ query }}",
"history": [
{% for item in data.history %}
"{{ item | replace('"', '\\"') }}"
{% if not loop.last %},{% endif %}
{% endfor %}
],
"retrieve_info": [
{% for item in data.retrieve_info %}
"{{ item | replace('"', '\\"') }}"
{% if not loop.last %},{% endif %}
{% endfor %}
]
},
"query_answer": "{% if not data.history and not data.retrieve_info %}信息不足,无法回答。{% endif %}"
}
**Output format**
**CRITICAL JSON FORMATTING REQUIREMENTS:**
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
3. Ensure all JSON strings are properly closed and comma-separated
4. Do not include line breaks within JSON string values
The output language should always be the same as the input language.{{ json_schema }}

View File

@@ -0,0 +1,203 @@
import redis
import uuid
from datetime import datetime
from app.core.config import settings
class RedisSessionStore:
def __init__(self, host='localhost', port=6379, db=0, password=None,session_id=''):
self.r = redis.Redis(host=host, port=port, db=db, password=password)
self.uudi=session_id
# 修改后的 save_session 方法
def save_session(self, userid, messages, aimessages, apply_id, group_id):
"""
写入一条会话数据,返回 session_id
"""
try:
session_id = str(uuid.uuid4()) # 为每次会话生成新的 ID
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
key = f"session:{session_id}" # 使用新生成的 session_id 作为 key
# 使用 Hash 存储结构化数据
result = self.r.hset(key, mapping={
"id": self.uudi,
"sessionid": userid,
"apply_id": apply_id,
"group_id": group_id,
"messages": messages,
"aimessages": aimessages,
"starttime": starttime
})
print(f"保存结果: {result}, session_id: {session_id}")
return session_id # 返回新生成的 session_id
except Exception as e:
print(f"保存会话失败: {e}")
raise e
# ---------------- 读取 ----------------
def get_session(self, session_id):
"""
读取一条会话数据
"""
key = f"session:{session_id}"
data = self.r.hgetall(key)
if data:
return {k.decode('utf-8'): v.decode('utf-8') for k, v in data.items()}
return None
def get_session_apply_group(self, sessionid, apply_id, group_id):
"""
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据
"""
result_items = []
# 遍历所有会话数据
for key_bytes in self.r.keys('session:*'):
key = key_bytes.decode('utf-8')
data = self.r.hgetall(key)
if not data:
continue
# 解码数据
decoded_data = {k.decode('utf-8'): v.decode('utf-8') for k, v in data.items()}
# 检查三个条件是否都匹配
if (decoded_data.get('sessionid') == sessionid and
decoded_data.get('apply_id') == apply_id and
decoded_data.get('group_id') == group_id):
result_items.append(decoded_data)
return result_items
def get_all_sessions(self):
"""
获取所有会话数据
"""
sessions = {}
for key in self.r.keys('session:*'):
sid = key.decode('utf-8').split(':')[1]
sessions[sid] = self.get_session(sid)
return sessions
# ---------------- 更新 ----------------
def update_session(self, session_id, field, value):
"""
更新单个字段
"""
key = f"session:{session_id}"
if self.r.exists(key):
self.r.hset(key, field, value)
return True
return False
# ---------------- 删除 ----------------
def delete_session(self, session_id):
"""
删除单条会话
"""
key = f"session:{session_id}"
return self.r.delete(key)
def delete_all_sessions(self):
"""
删除所有会话
"""
keys = self.r.keys('session:*')
if keys:
return self.r.delete(*keys)
return 0
def delete_duplicate_sessions(self):
"""
删除重复会话数据,条件:
"sessionid""user_id""group_id""messages""aimessages" 五个字段都相同的只保留一个,其他删除
"""
seen = set() # 用来记录已出现的唯一组合
deleted_count = 0
for key_bytes in self.r.keys('session:*'):
key = key_bytes.decode('utf-8')
data = self.r.hgetall(key)
if not data:
continue
# 获取五个字段的值并解码
sessionid = data.get(b'sessionid', b'').decode('utf-8')
user_id = data.get(b'id', b'').decode('utf-8') # 对应user_id
group_id = data.get(b'group_id', b'').decode('utf-8')
messages = data.get(b'messages', b'').decode('utf-8')
aimessages = data.get(b'aimessages', b'').decode('utf-8')
# 用五元组作为唯一标识
identifier = (sessionid, user_id, group_id, messages, aimessages)
if identifier in seen:
# 重复,删除该 key
self.r.delete(key)
deleted_count += 1
else:
# 第一次出现,加入 seen
seen.add(identifier)
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}")
return deleted_count
def find_user_session(self,sessionid):
user_id = sessionid
result_items = []
for key, values in store.get_all_sessions().items():
history = {}
if user_id == str(values['sessionid']):
history["Query"] = values['messages']
history["Answer"] = values['aimessages']
result_items.append(history)
if len(result_items) <= 1:
result_items = []
return (result_items)
def find_user_apply_group(self, sessionid, apply_id, group_id):
"""
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据
"""
result_items = []
# 遍历所有会话数据
for key_bytes in self.r.keys('session:*'):
key = key_bytes.decode('utf-8')
data = self.r.hgetall(key)
if not data:
continue
# 解码数据
decoded_data = {k.decode('utf-8'): v.decode('utf-8') for k, v in data.items()}
# 检查三个条件是否都匹配
if (decoded_data.get('sessionid') == sessionid and
decoded_data.get('apply_id') == apply_id and
decoded_data.get('group_id') == group_id):
history = {
"Query": decoded_data.get('messages'),
"Answer": decoded_data.get('aimessages')
}
result_items.append(history)
# 如果结果少于等于1条返回空列表
if len(result_items) <= 1:
result_items = []
return result_items
store = RedisSessionStore(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
session_id=str(uuid.uuid4())
)

View File

@@ -0,0 +1,59 @@
"""
Type classification utility for distinguishing read/write operations.
"""
from jinja2 import Template
from pydantic import BaseModel
from app.core.logging_config import get_agent_logger, log_prompt_rendering
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.messages_tool import read_template_file
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.config import settings
logger = get_agent_logger(__name__)
class DistinguishTypeResponse(BaseModel):
"""Response model for type classification"""
type: str
async def status_typle(messages: str) -> dict:
"""
Classify message type as read or write operation.
Args:
messages: User message to classify
Returns:
dict: Contains 'type' field with classification result
"""
try:
file_path = PROJECT_ROOT_ + '/agent/utils/prompt/distinguish_types_prompt.jinja2'
template_content = await read_template_file(file_path)
template = Template(template_content)
system_prompt = template.render(user_query=messages)
log_prompt_rendering("status_typle", system_prompt)
except Exception as e:
logger.error(f"Template rendering failed for status_typle: {e}", exc_info=True)
return {
"type": "error",
"message": f"Prompt rendering failed: {str(e)}"
}
from app.core.memory.utils.config import definitions as config_defs
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
try:
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=DistinguishTypeResponse
)
return structured.model_dump()
except Exception as e:
logger.error(f"LLM call failed for status_typle: {e}", exc_info=True)
return {
"type": "error",
"message": f"LLM call failed: {str(e)}"
}

View File

@@ -0,0 +1,76 @@
from typing import TypedDict, Annotated, List, Any
from langchain_core.messages import AnyMessage
from langgraph.constants import START, END
from langgraph.graph import StateGraph, add_messages
import asyncio
import json
from dotenv import load_dotenv, find_dotenv
import os
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from langchain_core.messages import HumanMessage
from jinja2 import Environment, FileSystemLoader
from app.core.memory.agent.utils.messages_tool import _to_openai_messages
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID
from app.core.logging_config import get_agent_logger
load_dotenv(find_dotenv())
logger = get_agent_logger(__name__)
def keep_last(_, right):
return right
class State(TypedDict):
user_input: Annotated[dict, keep_last]
messages: Annotated[List[AnyMessage], add_messages]
agent1_response: str
agent2_response: str
agent3_response: str
final_response: str
status: Annotated[str, keep_last]
class VerifyTool:
def __init__(self, system_prompt: str="", verify_data: Any=None):
self.system_prompt = system_prompt
if isinstance(verify_data, str):
self.verify_data = verify_data
else:
try:
self.verify_data = json.dumps(verify_data, ensure_ascii=False)
except Exception:
self.verify_data = str(verify_data)
async def model_1(self, state: State) -> State:
llm_client = get_llm_client(SELECTED_LLM_ID)
response_content = await llm_client.chat(
messages=[{"role": "system", "content": self.system_prompt}] + _to_openai_messages(state["messages"])
)
return {
"agent1_response": response_content,
"status": "processed",
}
def get_graph(self):
graph = StateGraph(State)
graph.add_node("model_1", self.model_1)
graph.add_edge(START, "model_1")
graph.add_edge("model_1", END)
compiled_graph = graph.compile()
return compiled_graph
async def verify(self):
graph = self.get_graph()
initial_state = {
"user_input": self.verify_data,
"messages": [HumanMessage(content=self.verify_data)],
"final_response": "",
"status": ""
}
final_state = await graph.ainvoke(initial_state)
# return final_state["final_response"]
return final_state["agent1_response"]

View File

@@ -0,0 +1,49 @@
import os
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy.orm import Session
import logging
import json
from app.db import get_db
from app.models.retrieval_info import RetrievalInfo
logger = logging.getLogger(__name__)
async def write_to_database(host_id: uuid.UUID, data: Any) -> str:
"""
将数据写入数据库
:param host_id: 宿主 ID
:param data: 要写入的数据
:return: 写入数据库的结果
"""
# 从数据库会话中获取会话
db: Session = next(get_db())
try:
if isinstance(data, (dict, list)):
serialized = json.dumps(data, ensure_ascii=False)
elif isinstance(data, str):
serialized = data
else:
serialized = str(data)
new_retrieval_info = RetrievalInfo(
# host_id=host_id,
host_id=uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122"),
retrieve_info=serialized,
created_at=datetime.now()
)
db.add(new_retrieval_info)
db.commit()
logger.info(f"success to write data to database, host_id: {host_id}, retrieve_info: {serialized}")
return "success to write data to database"
except Exception as e:
db.rollback()
logger.error(f"failed to write data to database, host_id: {host_id}, retrieve_info: {data}, error: {e}")
raise e
finally:
try:
db.close()
except Exception:
pass

View File

@@ -0,0 +1,183 @@
import asyncio
from dotenv import load_dotenv
import time
from datetime import datetime
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
# 使用新的模块化架构
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
embedding_generation_all,
)
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# 导入配置模块(而不是直接导入变量)
from app.core.memory.utils.config import definitions as config_defs
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.log.logging_utils import log_time
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import Memory_summary_generation
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
load_dotenv()
async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id: str = "wyl20251027", config_id: str = None) -> None:
"""
执行完整的知识提取流水线(使用新的 ExtractionOrchestrator
Args:
content: 对话内容
user_id: 用户ID
apply_id: 应用ID
group_id: 组ID
ref_id: 参考ID默认为 "wyl20251027"
config_id: 配置ID用于标记数据处理配置
"""
logger.info("=== MemSci Knowledge Extraction Pipeline ===")
logger.info(f"Using model: {config_defs.SELECTED_LLM_NAME}")
logger.info(f"Using LLM ID: {config_defs.SELECTED_LLM_ID}")
logger.info(f"Using chunker strategy: {config_defs.SELECTED_CHUNKER_STRATEGY}")
logger.info(f"Using group ID: {config_defs.SELECTED_GROUP_ID}")
logger.info(f"Using embedding ID: {config_defs.SELECTED_EMBEDDING_ID}")
logger.info(f"Config ID: {config_id if config_id else 'None'}")
logger.info(f"LANGFUSE_ENABLED: {config_defs.LANGFUSE_ENABLED}")
logger.info(f"AGENTA_ENABLED: {config_defs.AGENTA_ENABLED}")
# Initialize timing log
log_file = "logs/time.log"
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"\n=== Pipeline Run Started: {timestamp} ===\n")
pipeline_start = time.time()
# 初始化客户端
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
# 获取 embedder 配置
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
embedder_config = RedBearModelConfig(**embedder_config_dict)
embedder_client = OpenAIEmbedderClient(embedder_config)
neo4j_connector = Neo4jConnector()
# Step 1: 加载和分块数据
step_start = time.time()
chunked_dialogs = await get_chunked_dialogs(
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
content=content,
ref_id=ref_id,
config_id=config_id,
)
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
# Step 2: 初始化并运行 ExtractionOrchestrator
step_start = time.time()
from app.core.memory.utils.config.config_utils import get_pipeline_config
config = get_pipeline_config()
orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=neo4j_connector,
config=config,
)
# 运行完整的提取流水线
# orchestrator.run returns a flat tuple of 7 values after deduplication
(
all_dialogue_nodes,
all_chunk_nodes,
all_statement_nodes,
all_entity_nodes,
all_statement_chunk_edges,
all_statement_entity_edges,
all_entity_entity_edges,
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
log_time("Extraction Pipeline", time.time() - step_start, log_file)
# Step 8: Save all data to Neo4j database using graph models
step_start = time.time()
# 运行索引创建
from app.repositories.neo4j.create_indexes import create_fulltext_indexes
try:
await create_fulltext_indexes()
except Exception as e:
logger.error(f"Error creating indexes: {e}", exc_info=True)
try:
success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=all_dialogue_nodes,
chunk_nodes=all_chunk_nodes,
statement_nodes=all_statement_nodes,
entity_nodes=all_entity_nodes,
statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges,
connector=neo4j_connector
)
if success:
logger.info("Successfully saved all data to Neo4j")
else:
logger.warning("Failed to save some data to Neo4j")
finally:
await neo4j_connector.close()
log_time("Neo4j Database Save", time.time() - step_start, log_file)
# Step 9: Generate Memory summaries and save to local vector DB and Neo4j
step_start = time.time()
try:
summaries = await Memory_summary_generation(
chunked_dialogs, llm_client=llm_client, embedding_id=config_defs.SELECTED_EMBEDDING_ID
)
# Save memory summaries to Neo4j as nodes
try:
ms_connector = Neo4jConnector()
await add_memory_summary_nodes(summaries, ms_connector)
# Link summaries to statements via chunks for summary→entity queries
await add_memory_summary_statement_edges(summaries, ms_connector)
finally:
try:
await ms_connector.close()
except Exception:
pass
except Exception as e:
logger.error(f"Memory summary step failed: {e}", exc_info=True)
finally:
log_time("Memory Summary (Local Vector DB & Neo4j)", time.time() - step_start, log_file)
# Log total pipeline time
total_time = time.time() - pipeline_start
log_time("TOTAL PIPELINE TIME", total_time, log_file)
# Add completion marker to log
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
logger.info("=== Pipeline Complete ===")
logger.info(f"Total execution time: {total_time:.2f} seconds")
logger.info(f"Timing details saved to: {log_file}")
if __name__ == "__main__":
content = "你好,我是张三,是张曼婷的新朋友。请问张曼婷喜欢什么?"
asyncio.run(write(content, ref_id="wyl20251027"))

View File

@@ -0,0 +1,19 @@
"""
LLM 工具模块
提供 LLM 和 Embedder 客户端的抽象基类和具体实现。
"""
from app.core.memory.llm_tools.llm_client import LLMClient
from app.core.memory.llm_tools.embedder_client import EmbedderClient
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.llm_tools.chunker_client import ChunkerClient
__all__ = [
"LLMClient",
"EmbedderClient",
"OpenAIClient",
"OpenAIEmbedderClient",
"ChunkerClient",
]

View File

@@ -0,0 +1,330 @@
from typing import Any, List
import re
import os
import asyncio
import json
import numpy as np
# Fix tokenizer parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from chonkie import (
SemanticChunker,
RecursiveChunker,
RecursiveRules,
LateChunker,
NeuralChunker,
SentenceChunker,
TokenChunker,
)
from app.core.memory.models.config_models import ChunkerConfig
from app.core.memory.models.message_models import DialogData, Chunk
try:
from app.core.memory.llm_tools.openai_client import OpenAIClient
except Exception:
# 在测试或无可用依赖(如 langfuse环境下允许惰性导入
OpenAIClient = Any
class LLMChunker:
"""基于LLM的智能分块策略"""
def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000):
self.llm_client = llm_client
self.chunk_size = chunk_size
async def __call__(self, text: str) -> List[Any]:
# 使用LLM分析文本结构并进行智能分块
prompt = f"""
请将以下文本分割成语义连贯的段落。每个段落应该围绕一个主题,长度大约在{self.chunk_size}字符左右。
请以JSON格式返回结果包含chunks数组每个chunk有text字段。
文本内容:
{text[:5000]}
"""
messages = [
{"role": "system", "content": "你是一个专业的文本分析助手,擅长将长文本分割成语义连贯的段落。"},
{"role": "user", "content": prompt}
]
try:
# 使用异步的 achat 方法
if hasattr(self.llm_client, 'achat'):
response = await self.llm_client.achat(messages)
else:
# 如果没有异步方法,使用同步方法并转换为异步
response = await asyncio.to_thread(self.llm_client.chat, messages)
# 检查响应格式并提取内容
if hasattr(response, 'choices') and len(response.choices) > 0:
content = response.choices[0].message.content
elif hasattr(response, 'content'):
content = response.content
else:
content = str(response)
# 解析LLM响应
if "```json" in content:
json_str = content.split("```json")[1].split("```")[0].strip()
elif "```" in content:
json_str = content.split("```")[1].split("```")[0].strip()
else:
json_str = content
result = json.loads(json_str)
class SimpleChunk:
def __init__(self, text, index):
self.text = text
self.start_index = index * 100 # 近似位置
self.end_index = (index + 1) * 100
return [SimpleChunk(chunk["text"], i) for i, chunk in enumerate(result.get("chunks", []))]
except Exception as e:
print(f"LLM分块失败: {e}")
# 失败时返回空列表,外层会处理回退方案
return []
class HybridChunker:
"""混合分块策略:先按结构分块,再按语义合并"""
def __init__(self, semantic_threshold: float = 0.8, base_chunk_size: int = 300):
self.semantic_threshold = semantic_threshold
self.base_chunk_size = base_chunk_size
self.base_chunker = TokenChunker(tokenizer="character", chunk_size=base_chunk_size)
self.semantic_chunker = SemanticChunker(threshold=semantic_threshold)
def __call__(self, text: str) -> List[Any]:
# 先用基础分块
base_chunks = self.base_chunker(text)
# 如果文本不长,直接返回基础分块
if len(base_chunks) <= 3:
return base_chunks
# 对基础分块进行语义合并
combined_text = " ".join([chunk.text for chunk in base_chunks])
return self.semantic_chunker(combined_text)
class ChunkerClient:
def __init__(self, chunker_config: ChunkerConfig, llm_client: OpenAIClient = None):
self.chunker_config = chunker_config
self.embedding_model = chunker_config.embedding_model
self.chunk_size = chunker_config.chunk_size
self.threshold = chunker_config.threshold
self.language = chunker_config.language
self.skip_window = chunker_config.skip_window
self.min_sentences = chunker_config.min_sentences
self.min_characters_per_chunk = chunker_config.min_characters_per_chunk
self.llm_client = llm_client
# 可选参数(从配置中安全获取,提供默认值)
self.chunk_overlap = getattr(chunker_config, 'chunk_overlap', 0)
self.min_sentences_per_chunk = getattr(chunker_config, 'min_sentences_per_chunk', 1)
self.min_characters_per_sentence = getattr(chunker_config, 'min_characters_per_sentence', 12)
self.delim = getattr(chunker_config, 'delim', [".", "!", "?", "\n"])
self.include_delim = getattr(chunker_config, 'include_delim', "prev")
self.tokenizer_or_token_counter = getattr(chunker_config, 'tokenizer_or_token_counter', "character")
# 初始化具体分块器策略
if chunker_config.chunker_strategy == "TokenChunker":
self.chunker = TokenChunker(
tokenizer=self.tokenizer_or_token_counter,
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
)
elif chunker_config.chunker_strategy == "SemanticChunker":
self.chunker = SemanticChunker(
embedding_model=self.embedding_model,
threshold=self.threshold,
chunk_size=self.chunk_size,
min_sentences=self.min_sentences,
)
elif chunker_config.chunker_strategy == "RecursiveChunker":
self.chunker = RecursiveChunker(
rules=RecursiveRules(),
min_characters_per_chunk=self.min_characters_per_chunk or 50,
chunk_size=self.chunk_size,
)
elif chunker_config.chunker_strategy == "LateChunker":
self.chunker = LateChunker(
embedding_model=self.embedding_model,
chunk_size=self.chunk_size,
rules=RecursiveRules(),
min_characters_per_chunk=self.min_characters_per_chunk,
)
elif chunker_config.chunker_strategy == "NeuralChunker":
self.chunker = NeuralChunker(
model=self.embedding_model,
min_characters_per_chunk=self.min_characters_per_chunk,
)
elif chunker_config.chunker_strategy == "LLMChunker":
if not llm_client:
raise ValueError("LLMChunker requires an LLM client")
self.chunker = LLMChunker(llm_client, self.chunk_size)
elif chunker_config.chunker_strategy == "HybridChunker":
self.chunker = HybridChunker(
semantic_threshold=self.threshold,
base_chunk_size=self.chunk_size,
)
elif chunker_config.chunker_strategy == "SentenceChunker":
# 某些 chonkie 版本的 SentenceChunker 不支持 tokenizer_or_token_counter 参数
# 为了兼容不同版本,这里仅传递广泛支持的参数
self.chunker = SentenceChunker(
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
min_sentences_per_chunk=self.min_sentences_per_chunk,
min_characters_per_sentence=self.min_characters_per_sentence,
delim=self.delim,
include_delim=self.include_delim,
)
else:
raise ValueError(f"Unknown chunker strategy: {chunker_config.chunker_strategy}")
async def generate_chunks(self, dialogue: DialogData):
"""
生成分块,支持异步操作
"""
try:
# 预处理文本:确保对话标记格式统一
content = dialogue.content
content = content.replace('AI', 'AI:').replace('用户:', '用户:') # 统一冒号
content = re.sub(r'(\n\s*)+\n', '\n\n', content) # 合并多个空行
if hasattr(self.chunker, '__call__') and not asyncio.iscoroutinefunction(self.chunker.__call__):
# 同步分块器
chunks = self.chunker(content)
else:
# 异步分块器如LLMChunker
chunks = await self.chunker(content)
# 过滤空块和过小的块
valid_chunks = []
for c in chunks:
chunk_text = getattr(c, 'text', str(c)) if not isinstance(c, str) else c
if isinstance(chunk_text, str) and len(chunk_text.strip()) >= (self.min_characters_per_chunk or 50):
valid_chunks.append(c)
dialogue.chunks = [
Chunk(
content=c.text if hasattr(c, 'text') else str(c),
metadata={
"start_index": getattr(c, "start_index", None),
"end_index": getattr(c, "end_index", None),
"chunker_strategy": self.chunker_config.chunker_strategy,
},
)
for c in valid_chunks
]
return dialogue
except Exception as e:
print(f"分块失败: {e}")
# 改进的后备方案:尝试按对话回合分割
try:
# 简单的按对话分割
dialogue_pattern = r'(AI:|用户:)(.*?)(?=AI:|用户:|$)'
matches = re.findall(dialogue_pattern, dialogue.content, re.DOTALL)
class SimpleChunk:
def __init__(self, text, start_index, end_index):
self.text = text
self.start_index = start_index
self.end_index = end_index
chunks = []
current_chunk = ""
current_start = 0
for match in matches:
speaker, ct = match[0], match[1].strip()
turn_text = f"{speaker} {ct}"
if len(current_chunk) + len(turn_text) > (self.chunk_size or 500):
if current_chunk:
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
current_chunk = turn_text
current_start = dialogue.content.find(turn_text, current_start)
else:
current_chunk += ("\n" + turn_text) if current_chunk else turn_text
if current_chunk:
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
dialogue.chunks = [
Chunk(
content=c.text,
metadata={
"start_index": c.start_index,
"end_index": c.end_index,
"chunker_strategy": "DialogueTurnFallback",
},
)
for c in chunks
]
except Exception:
# 最后的手段:单一大块
dialogue.chunks = [Chunk(
content=dialogue.content,
metadata={"chunker_strategy": "SingleChunkFallback"},
)]
return dialogue
def evaluate_chunking(self, dialogue: DialogData) -> dict:
"""
评估分块质量
"""
if not getattr(dialogue, 'chunks', None):
return {}
chunks = dialogue.chunks
total_chars = sum(len(chunk.content) for chunk in chunks)
avg_chunk_size = total_chars / len(chunks)
# 计算各种指标
chunk_sizes = [len(chunk.content) for chunk in chunks]
metrics = {
"strategy": self.chunker_config.chunker_strategy,
"num_chunks": len(chunks),
"total_characters": total_chars,
"avg_chunk_size": avg_chunk_size,
"min_chunk_size": min(chunk_sizes),
"max_chunk_size": max(chunk_sizes),
"chunk_size_std": np.std(chunk_sizes) if len(chunk_sizes) > 1 else 0,
"coverage_ratio": total_chars / len(dialogue.content) if dialogue.content else 0,
}
return metrics
def save_chunking_results(self, dialogue: DialogData, output_path: str):
"""
保存分块结果到文件,文件名包含策略名称
"""
strategy_name = self.chunker_config.chunker_strategy
# 在文件名中添加策略名称
base_name, ext = os.path.splitext(output_path)
strategy_output_path = f"{base_name}_{strategy_name}{ext}"
with open(strategy_output_path, 'w', encoding='utf-8') as f:
f.write(f"=== Chunking Strategy: {strategy_name} ===\n")
f.write(f"Total chunks: {len(dialogue.chunks)}\n")
f.write(f"Total characters: {sum(len(chunk.content) for chunk in dialogue.chunks)}\n")
f.write("=" * 60 + "\n\n")
for i, chunk in enumerate(dialogue.chunks):
f.write(f"Chunk {i+1}:\n")
f.write(f"Size: {len(chunk.content)} characters\n")
if hasattr(chunk, 'metadata') and 'start_index' in chunk.metadata:
f.write(f"Position: {chunk.metadata.get('start_index')}-{chunk.metadata.get('end_index')}\n")
f.write(f"Content: {chunk.content}\n")
f.write("-" * 40 + "\n\n")
print(f"Chunking results saved to: {strategy_output_path}")
return strategy_output_path

View File

@@ -0,0 +1,176 @@
"""
Embedder 客户端抽象基类
提供统一的嵌入向量生成接口,支持重试机制和错误处理。
"""
from abc import ABC, abstractmethod
from typing import List, Optional
import asyncio
import logging
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
before_sleep_log,
)
from app.core.models.base import RedBearModelConfig
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
logger = logging.getLogger(__name__)
class EmbedderClientException(BusinessException):
"""Embedder 客户端异常"""
def __init__(self, message: str, code: str = BizCode.EMBEDDING_ERROR):
super().__init__(message, code=code)
class EmbedderClient(ABC):
"""
Embedder 客户端抽象基类
提供统一的嵌入向量生成接口,包括:
- 批量文本嵌入response
- 自动重试机制
- 错误处理
"""
def __init__(self, model_config: RedBearModelConfig):
"""
初始化 Embedder 客户端
Args:
model_config: 模型配置包含模型名称、提供商、API密钥等信息
"""
self.config = model_config
self.model_name = model_config.model_name
self.provider = model_config.provider
self.api_key = model_config.api_key
self.base_url = model_config.base_url
self.max_retries = model_config.max_retries
self.timeout = model_config.timeout
logger.info(
f"初始化 Embedder 客户端: provider={self.provider}, "
f"model={self.model_name}, max_retries={self.max_retries}"
)
@abstractmethod
async def response(
self,
messages: List[str],
**kwargs
) -> List[List[float]]:
"""
生成嵌入向量
Args:
messages: 文本列表
**kwargs: 额外参数
Returns:
嵌入向量列表,每个向量是一个浮点数列表
Raises:
EmbedderClientException: 嵌入向量生成失败
"""
pass
def _create_retry_decorator(self):
"""
创建重试装饰器
Returns:
配置好的 tenacity retry 装饰器
"""
return retry(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=1, min=2, max=10),
retry=retry_if_exception_type((
asyncio.TimeoutError,
ConnectionError,
Exception, # 可以根据需要细化异常类型
)),
before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True,
)
async def response_with_retry(
self,
messages: List[str],
**kwargs
) -> List[List[float]]:
"""
带重试机制的嵌入向量生成接口
Args:
messages: 文本列表
**kwargs: 额外参数
Returns:
嵌入向量列表
Raises:
EmbedderClientException: 重试失败后抛出
"""
retry_decorator = self._create_retry_decorator()
@retry_decorator
async def _response_with_retry():
try:
return await self.response(messages, **kwargs)
except Exception as e:
logger.error(f"嵌入向量生成失败: {e}")
raise EmbedderClientException(f"嵌入向量生成失败: {e}") from e
return await _response_with_retry()
async def embed_single(self, text: str, **kwargs) -> List[float]:
"""
为单个文本生成嵌入向量
Args:
text: 单个文本
**kwargs: 额外参数
Returns:
嵌入向量(浮点数列表)
Raises:
EmbedderClientException: 嵌入向量生成失败
"""
embeddings = await self.response_with_retry([text], **kwargs)
return embeddings[0] if embeddings else []
async def embed_batch(
self,
texts: List[str],
batch_size: int = 100,
**kwargs
) -> List[List[float]]:
"""
批量生成嵌入向量(支持大批量文本)
Args:
texts: 文本列表
batch_size: 每批处理的文本数量
**kwargs: 额外参数
Returns:
嵌入向量列表
Raises:
EmbedderClientException: 嵌入向量生成失败
"""
all_embeddings = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i + batch_size]
batch_embeddings = await self.response_with_retry(batch, **kwargs)
all_embeddings.extend(batch_embeddings)
return all_embeddings

View File

@@ -0,0 +1,187 @@
"""
LLM 客户端抽象基类
提供统一的 LLM 调用接口,支持重试机制和错误处理。
"""
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from pydantic import BaseModel
import asyncio
import logging
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
before_sleep_log,
)
from app.core.models.base import RedBearModelConfig
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
logger = logging.getLogger(__name__)
class LLMClientException(BusinessException):
"""LLM 客户端异常"""
def __init__(self, message: str, code: str = BizCode.LLM_ERROR):
super().__init__(message, code=code)
class LLMClient(ABC):
"""
LLM 客户端抽象基类
提供统一的 LLM 调用接口,包括:
- 聊天接口chat
- 结构化输出接口response_structured
- 自动重试机制
- 错误处理
"""
def __init__(self, model_config: RedBearModelConfig):
"""
初始化 LLM 客户端
Args:
model_config: 模型配置包含模型名称、提供商、API密钥等信息
"""
self.config = model_config
self.model_name = self.config.model_name
self.provider = self.config.provider
self.api_key = self.config.api_key
self.base_url = self.config.base_url
self.max_retries = self.config.max_retries
self.timeout = self.config.timeout
logger.info(
f"初始化 LLM 客户端: provider={self.provider}, "
f"model={self.model_name}, max_retries={self.max_retries}"
)
@abstractmethod
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any:
"""
聊天接口
Args:
messages: 消息列表,每个消息包含 role 和 content
**kwargs: 额外参数
Returns:
LLM 响应内容
Raises:
LLMClientException: LLM 调用失败
"""
pass
@abstractmethod
async def response_structured(
self,
messages: List[Dict[str, str]],
response_model: type[BaseModel],
**kwargs
) -> BaseModel:
"""
结构化输出接口
Args:
messages: 消息列表
response_model: 期望的响应模型类型Pydantic BaseModel
**kwargs: 额外参数
Returns:
解析后的 Pydantic 模型实例
Raises:
LLMClientException: LLM 调用或解析失败
"""
pass
def _create_retry_decorator(self):
"""
创建重试装饰器
Returns:
配置好的 tenacity retry 装饰器
"""
return retry(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=1, min=2, max=10),
retry=retry_if_exception_type((
asyncio.TimeoutError,
ConnectionError,
Exception, # 可以根据需要细化异常类型
)),
before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True,
)
async def chat_with_retry(
self,
messages: List[Dict[str, str]],
**kwargs
) -> Any:
"""
带重试机制的聊天接口
Args:
messages: 消息列表
**kwargs: 额外参数
Returns:
LLM 响应内容
Raises:
LLMClientException: 重试失败后抛出
"""
retry_decorator = self._create_retry_decorator()
@retry_decorator
async def _chat_with_retry():
try:
return await self.chat(messages, **kwargs)
except Exception as e:
logger.error(f"LLM 调用失败: {e}")
raise LLMClientException(f"LLM 调用失败: {e}") from e
return await _chat_with_retry()
async def response_structured_with_retry(
self,
messages: List[Dict[str, str]],
response_model: type[BaseModel],
**kwargs
) -> BaseModel:
"""
带重试机制的结构化输出接口
Args:
messages: 消息列表
response_model: 期望的响应模型类型
**kwargs: 额外参数
Returns:
解析后的 Pydantic 模型实例
Raises:
LLMClientException: 重试失败后抛出
"""
retry_decorator = self._create_retry_decorator()
@retry_decorator
async def _response_structured_with_retry():
try:
return await self.response_structured(
messages,
response_model,
**kwargs
)
except Exception as e:
logger.error(f"LLM 结构化输出失败: {e}")
raise LLMClientException(f"LLM 结构化输出失败: {e}") from e
return await _response_structured_with_retry()

View File

@@ -0,0 +1,198 @@
"""
OpenAI LLM 客户端实现
基于 LangChain 和 RedBearLLM 的 OpenAI 客户端实现。
"""
import asyncio
from typing import List, Dict, Any
import json
import logging
from pydantic import BaseModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from app.core.models.base import RedBearModelConfig
from app.core.models.llm import RedBearLLM
from app.core.memory.llm_tools.llm_client import LLMClient, LLMClientException
from app.core.memory.utils.config.definitions import LANGFUSE_ENABLED
logger = logging.getLogger(__name__)
class OpenAIClient(LLMClient):
"""
OpenAI LLM 客户端实现
基于 LangChain 和 RedBearLLM 的实现,支持:
- 聊天接口
- 结构化输出
- Langfuse 追踪(可选)
"""
def __init__(self, model_config: RedBearModelConfig, type_: str = "chat"):
"""
初始化 OpenAI 客户端
Args:
model_config: 模型配置
type_: 模型类型,"chat""completion"
"""
super().__init__(model_config)
# 初始化 Langfuse 回调处理器(如果启用)
self.langfuse_handler = None
if LANGFUSE_ENABLED:
try:
from langfuse.langchain import CallbackHandler
self.langfuse_handler = CallbackHandler()
logger.info("Langfuse 追踪已启用")
except ImportError:
logger.warning("Langfuse 未安装,跳过追踪功能")
except Exception as e:
logger.warning(f"初始化 Langfuse 处理器失败: {e}")
# 初始化 RedBearLLM 客户端
self.client = RedBearLLM(
RedBearModelConfig(
model_name=self.model_name,
provider=self.provider,
api_key=self.api_key,
base_url=self.base_url,
max_retries=self.max_retries,
timeout=self.timeout,
),
type=type_
)
logger.info(f"OpenAI 客户端初始化完成: type={type_}")
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any:
"""
聊天接口实现
Args:
messages: 消息列表
**kwargs: 额外参数
Returns:
LLM 响应内容
Raises:
LLMClientException: LLM 调用失败
"""
try:
template = """{messages}"""
prompt = ChatPromptTemplate.from_template(template)
chain = prompt | self.client
# 添加 Langfuse 回调(如果可用)
config = {}
if self.langfuse_handler:
config["callbacks"] = [self.langfuse_handler]
response = await chain.ainvoke({"messages": messages}, config=config)
logger.debug(f"LLM 响应成功: {len(str(response))} 字符")
return response
except Exception as e:
logger.error(f"LLM 调用失败: {e}")
raise LLMClientException(f"LLM 调用失败: {e}") from e
async def response_structured(
self,
messages: List[Dict[str, str]],
response_model: type[BaseModel],
**kwargs
) -> BaseModel:
"""
结构化输出接口实现
Args:
messages: 消息列表
response_model: 期望的响应模型类型
**kwargs: 额外参数
Returns:
解析后的 Pydantic 模型实例
Raises:
LLMClientException: LLM 调用或解析失败
"""
try:
# 构建问题文本
question_text = "\n\n".join([
str(m.get("content", "")) for m in messages
])
# 准备配置(包含 Langfuse 回调)
config = {}
if self.langfuse_handler:
config["callbacks"] = [self.langfuse_handler]
# 方法 1: 使用 PydanticOutputParser
if PydanticOutputParser is not None:
try:
parser = PydanticOutputParser(pydantic_object=response_model)
format_instructions = parser.get_format_instructions()
prompt = ChatPromptTemplate.from_template(
"{question}\n{format_instructions}"
)
chain = prompt | self.client | parser
parsed = await chain.ainvoke(
{
"question": question_text,
"format_instructions": format_instructions,
},
config=config
)
logger.debug(f"使用 PydanticOutputParser 解析成功")
return parsed
except Exception as e:
logger.warning(
f"PydanticOutputParser 解析失败,尝试其他方法: {e}"
)
# 方法 2: 使用 LangChain 的 with_structured_output
template = """{question}"""
prompt = ChatPromptTemplate.from_template(template)
try:
with_so = getattr(self.client, "with_structured_output", None)
if callable(with_so):
structured_chain = prompt | with_so(response_model, strict=True)
parsed = await structured_chain.ainvoke(
{"question": question_text},
config=config
)
# 验证并返回结果
try:
return response_model.model_validate(parsed)
except Exception:
# 如果已经是 Pydantic 实例,直接返回
if hasattr(parsed, "model_dump"):
return parsed
# 尝试从 JSON 解析
return response_model.model_validate_json(json.dumps(parsed))
except Exception as e:
logger.error(f"结构化输出失败: {e}")
raise LLMClientException(f"结构化输出失败: {e}") from e
# 如果所有方法都失败,抛出异常
raise LLMClientException(
"无法生成结构化输出,所有解析方法均失败"
)
except LLMClientException:
raise
except Exception as e:
logger.error(f"结构化输出处理失败: {e}")
raise LLMClientException(f"结构化输出处理失败: {e}") from e

View File

@@ -0,0 +1,87 @@
"""
OpenAI Embedder 客户端实现
基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。
"""
from typing import List
import logging
from app.core.memory.llm_tools.embedder_client import (
EmbedderClient,
EmbedderClientException
)
from app.core.models.base import RedBearModelConfig
from app.core.models.embedding import RedBearEmbeddings
logger = logging.getLogger(__name__)
class OpenAIEmbedderClient(EmbedderClient):
"""
OpenAI Embedder 客户端实现
基于 LangChain 和 RedBearEmbeddings 的实现,支持:
- 批量文本嵌入
- 自动重试机制
- 错误处理
"""
def __init__(self, model_config: RedBearModelConfig):
"""
初始化 OpenAI Embedder 客户端
Args:
model_config: 模型配置
"""
super().__init__(model_config)
# 初始化 RedBearEmbeddings 模型
self.model = RedBearEmbeddings(
RedBearModelConfig(
model_name=self.model_name,
provider=self.provider,
api_key=self.api_key,
base_url=self.base_url,
max_retries=self.max_retries,
timeout=self.timeout,
)
)
logger.info("OpenAI Embedder 客户端初始化完成")
async def response(
self,
messages: List[str],
**kwargs
) -> List[List[float]]:
"""
生成嵌入向量实现
Args:
messages: 文本列表
**kwargs: 额外参数
Returns:
嵌入向量列表
Raises:
EmbedderClientException: 嵌入向量生成失败
"""
try:
# 过滤空文本
texts: List[str] = [str(m) for m in messages if m is not None]
if not texts:
logger.warning("输入文本列表为空,返回空结果")
return []
# 生成嵌入向量
embeddings = await self.model.aembed_documents(texts)
logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量")
return embeddings
except Exception as e:
logger.error(f"嵌入向量生成失败: {e}")
raise EmbedderClientException(f"嵌入向量生成失败: {e}") from e

332
api/app/core/memory/main.py Normal file
View File

@@ -0,0 +1,332 @@
"""
MemSci 记忆系统主入口 - 重构版本
该模块是重构后的记忆系统主入口,使用新的模块化架构。
旧版本入口app/core/memory/src/main.py已删除。
主要功能:
1. 协调整个知识提取流水线
2. 支持试运行模式和正常运行模式
3. 使用重构后的 storage_services 模块
4. 提供统一的配置管理和日志记录
作者Lance77
日期2025-11-22
"""
# 必须在最开始禁用 LangSmith 追踪,避免速率限制错误
import os
os.environ["LANGCHAIN_TRACING_V2"] = "false"
os.environ["LANGCHAIN_TRACING"] = "false"
import asyncio
import time
from datetime import datetime
from typing import Optional
from dotenv import load_dotenv
# 导入重构后的模块
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.models.message_models import ConversationMessage, ConversationContext, DialogData
from app.core.memory.models.variate_config import ExtractionPipelineConfig
# 导入数据加载函数
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
get_chunked_dialogs_with_preprocessing,
get_chunked_dialogs_from_preprocessed,
)
# 导入配置模块(而不是直接导入变量)
from app.core.memory.utils.config import definitions as config_defs
from app.core.logging_config import get_memory_logger, log_time
load_dotenv()
logger = get_memory_logger(__name__)
async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
"""
记忆系统主流程 - 重构版本
该函数是重构后的主入口,使用新的模块化架构。
Args:
dialogue_text: 输入的对话文本(可选,用于试运行模式)
is_pilot_run: 是否为试运行模式
- True: 试运行模式,不保存到 Neo4j
- False: 正常运行模式,保存到 Neo4j
工作流程:
1. 初始化客户端和配置
2. 加载或准备数据
3. 执行知识提取流水线
4. 保存结果(正常模式)或输出结果(试运行模式)
"""
print("=" * 60)
print("MemSci 知识提取流水线 - 重构版本")
print("=" * 60)
print(f"运行模式: {'试运行不保存到Neo4j' if is_pilot_run else '正常运行保存到Neo4j'}")
print("Using chunker strategy:", config_defs.SELECTED_CHUNKER_STRATEGY)
print("Using group ID:", config_defs.SELECTED_GROUP_ID)
print("Using model ID:", config_defs.SELECTED_LLM_ID)
print("Using embedding model ID:", config_defs.SELECTED_EMBEDDING_ID)
print("LANGFUSE_ENABLED:", config_defs.LANGFUSE_ENABLED)
print("AGENTA_ENABLED:", config_defs.AGENTA_ENABLED)
print("=" * 60)
# 初始化日志
log_file = "logs/time.log"
os.makedirs(os.path.dirname(log_file), exist_ok=True)
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"\n=== Pipeline Run Started: {timestamp} ({'Pilot Run' if is_pilot_run else 'Normal Run'}) ===\n")
pipeline_start = time.time()
try:
# 步骤 1: 初始化客户端
logger.info("Initializing clients...")
step_start = time.time()
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
# 获取 embedder 配置并转换为 RedBearModelConfig 对象
from app.core.models.base import RedBearModelConfig
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
embedder_config = RedBearModelConfig(**embedder_config_dict)
embedder_client = OpenAIEmbedderClient(embedder_config)
neo4j_connector = Neo4jConnector()
log_time("Client Initialization", time.time() - step_start, log_file)
# 步骤 2: 加载或准备数据
logger.info("Loading data...")
logger.info(f"[MAIN] dialogue_text type={type(dialogue_text)}, length={len(dialogue_text) if dialogue_text else 0}, is_pilot_run={is_pilot_run}")
logger.info(f"[MAIN] dialogue_text preview: {repr(dialogue_text)[:200] if dialogue_text else 'None'}")
logger.info(f"[MAIN] Condition check: dialogue_text={bool(dialogue_text)}, isinstance={isinstance(dialogue_text, str) if dialogue_text else False}, strip={bool(dialogue_text.strip()) if dialogue_text and isinstance(dialogue_text, str) else False}")
step_start = time.time()
if dialogue_text and isinstance(dialogue_text, str) and dialogue_text.strip():
# 试运行模式:处理前端传入的对话文本
logger.info("[MAIN] ✓ Using frontend dialogue text (pilot run mode)")
import re
# 解析对话文本,支持 "用户:" 和 "AI:" 格式
pattern = r"(用户|AI)[:]\s*([^\n]+(?:\n(?!(?:用户|AI)[:])[^\n]*)*?)"
matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL)
messages = [
ConversationMessage(role=r, msg=c.strip())
for r, c in matches if c.strip()
]
# 如果没有匹配到格式化的对话,将整个文本作为用户消息
if not messages:
messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]
# 创建对话上下文和对话数据
context = ConversationContext(msgs=messages)
dialog = DialogData(
context=context,
ref_id="pilot_dialog_1",
group_id=config_defs.SELECTED_GROUP_ID,
user_id=config_defs.SELECTED_USER_ID,
apply_id=config_defs.SELECTED_APPLY_ID,
metadata={"source": "pilot_run", "input_type": "frontend_text"}
)
# 对前端传入的对话进行分块处理
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
data=[dialog],
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
llm_client=llm_client,
)
logger.info(f"Processed frontend dialogue text: {len(messages)} messages")
else:
# 正常运行模式:从 testdata.json 文件加载
logger.warning("[MAIN] ✗ Falling back to testdata.json (dialogue_text not provided or empty)")
logger.info("Loading data from testdata.json...")
test_data_path = os.path.join(
os.path.dirname(__file__), "data", "testdata.json"
)
if not os.path.exists(test_data_path):
raise FileNotFoundError(f"Test data file not found: {test_data_path}")
chunked_dialogs = await get_chunked_dialogs_with_preprocessing(
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
group_id=config_defs.SELECTED_GROUP_ID,
user_id=config_defs.SELECTED_USER_ID,
apply_id=config_defs.SELECTED_APPLY_ID,
indices=config_defs.SELECTED_TEST_DATA_INDICES,
input_data_path=test_data_path,
llm_client=llm_client,
skip_cleaning=True,
)
logger.info(f"Loaded {len(chunked_dialogs)} dialogues from testdata.json")
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
# 步骤 3: 初始化流水线编排器
logger.info("Initializing extraction orchestrator...")
step_start = time.time()
# 从 runtime.json 加载配置(已经过数据库覆写)
from app.core.memory.utils.config.config_utils import get_pipeline_config
config = get_pipeline_config()
logger.info(f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}")
orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=neo4j_connector,
config=config,
)
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
# 步骤 4: 执行知识提取流水线
logger.info("Running extraction pipeline...")
step_start = time.time()
extraction_result = await orchestrator.run(
dialog_data_list=chunked_dialogs,
is_pilot_run=is_pilot_run, # 传递试运行模式标志
)
# 解包 extraction_result tuple
# extraction_result 是一个包含 7 个元素的 tuple:
# (dialogue_nodes, chunk_nodes, statement_nodes, entity_nodes,
# statement_chunk_edges, statement_entity_edges, entity_edges)
(
dialogue_nodes,
chunk_nodes,
statement_nodes,
entity_nodes,
statement_chunk_edges,
statement_entity_edges,
entity_edges,
) = extraction_result
log_time("Extraction Pipeline", time.time() - step_start, log_file)
# 步骤 5: 保存结果或输出结果
if is_pilot_run:
logger.info("Pilot run mode: Skipping Neo4j save")
print("\n试运行模式:跳过 Neo4j 保存,流水线处理完成。")
print("提取结果已生成,可在相关输出中查看。")
else:
logger.info("Normal mode: Saving to Neo4j...")
step_start = time.time()
# 创建索引和约束
try:
from app.repositories.neo4j.create_indexes import (
create_fulltext_indexes,
create_unique_constraints,
)
await create_fulltext_indexes()
await create_unique_constraints()
logger.info("Successfully created indexes and constraints")
except Exception as e:
logger.error(f"Error creating indexes/constraints: {e}")
# 保存数据到 Neo4j
try:
from app.repositories.neo4j.graph_saver import (
save_dialog_and_statements_to_neo4j,
)
success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=dialogue_nodes,
chunk_nodes=chunk_nodes,
statement_nodes=statement_nodes,
entity_nodes=entity_nodes,
statement_chunk_edges=statement_chunk_edges,
statement_entity_edges=statement_entity_edges,
entity_edges=entity_edges,
connector=neo4j_connector,
)
if success:
logger.info("Successfully saved all data to Neo4j")
print("\n✓ 成功保存所有数据到 Neo4j")
else:
logger.warning("Failed to save some data to Neo4j")
print("\n⚠ 部分数据保存到 Neo4j 失败")
except Exception as e:
logger.error(f"Error saving to Neo4j: {e}", exc_info=True)
print(f"\n✗ 保存到 Neo4j 失败: {e}")
log_time("Neo4j Database Save", time.time() - step_start, log_file)
# 步骤 6: 生成记忆摘要(可选)
try:
logger.info("Generating memory summaries...")
step_start = time.time()
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
Memory_summary_generation,
)
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.add_edges import (
add_memory_summary_statement_edges,
)
summaries = await Memory_summary_generation(
chunked_dialogs, llm_client=llm_client, embedding_id=config_defs.SELECTED_EMBEDDING_ID
)
if not is_pilot_run:
# 保存记忆摘要到 Neo4j
ms_connector = Neo4jConnector()
try:
await add_memory_summary_nodes(summaries, ms_connector)
await add_memory_summary_statement_edges(summaries, ms_connector)
finally:
await ms_connector.close()
log_time("Memory Summary Generation", time.time() - step_start, log_file)
except Exception as e:
logger.error(f"Memory summary step failed: {e}", exc_info=True)
except Exception as e:
logger.error(f"Pipeline execution failed: {e}", exc_info=True)
print(f"\n✗ 流水线执行失败: {e}")
raise
finally:
# 清理资源
try:
await neo4j_connector.close()
except Exception:
pass
# 记录总时间
total_time = time.time() - pipeline_start
log_time("TOTAL PIPELINE TIME", total_time, log_file)
# 添加完成标记
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
logger.info("=== Pipeline Complete ===")
logger.info(f"Total execution time: {total_time:.2f} seconds")
logger.info(f"Timing details saved to: {log_file}")
print("\n" + "=" * 60)
print(f"✓ 流水线执行完成")
print(f"✓ 总耗时: {total_time:.2f}")
print(f"✓ 详细日志: {log_file}")
print("=" * 60)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,115 @@
"""Data models for the Memory module.
This package contains all Pydantic models used in the memory system,
including models for messages, dialogues, statements, entities, triplets,
graph nodes/edges, configurations, and deduplication decisions.
"""
# Base response models
from app.core.memory.models.base_response import RobustLLMResponse
# Configuration models
from app.core.memory.models.config_models import (
LLMConfig,
ChunkerConfig,
PruningConfig,
TemporalSearchParams,
)
# Deduplication models
from app.core.memory.models.dedup_models import (
EntityDedupDecision,
EntityDisambDecision,
)
# Graph models (nodes and edges)
from app.core.memory.models.graph_models import (
# Edges
Edge,
ChunkEdge,
ChunkEntityEdge,
ChunkDialogEdge,
StatementChunkEdge,
StatementEntityEdge,
EntityEntityEdge,
# Nodes
Node,
DialogueNode,
StatementNode,
ChunkNode,
ExtractedEntityNode,
MemorySummaryNode,
)
# Message and dialogue models
from app.core.memory.models.message_models import (
ConversationMessage,
TemporalValidityRange,
Statement,
ConversationContext,
Chunk,
DialogData,
)
# Triplet and entity models
from app.core.memory.models.triplet_models import (
Entity,
Triplet,
TripletExtractionResponse,
)
# Variable configuration models
from app.core.memory.models.variate_config import (
StatementExtractionConfig,
ForgettingEngineConfig,
TripletExtractionConfig,
TemporalExtractionConfig,
DedupConfig,
ExtractionPipelineConfig,
)
__all__ = [
# Base response
"RobustLLMResponse",
# Configuration
"LLMConfig",
"ChunkerConfig",
"PruningConfig",
"TemporalSearchParams",
# Deduplication
"EntityDedupDecision",
"EntityDisambDecision",
# Graph edges
"Edge",
"ChunkEdge",
"ChunkEntityEdge",
"ChunkDialogEdge",
"StatementChunkEdge",
"StatementEntityEdge",
"EntityEntityEdge",
# Graph nodes
"Node",
"DialogueNode",
"StatementNode",
"ChunkNode",
"ExtractedEntityNode",
"MemorySummaryNode",
# Messages and dialogues
"ConversationMessage",
"TemporalValidityRange",
"Statement",
"ConversationContext",
"Chunk",
"DialogData",
# Triplets and entities
"Entity",
"Triplet",
"TripletExtractionResponse",
# Variable configuration
"StatementExtractionConfig",
"ForgettingEngineConfig",
"TripletExtractionConfig",
"TemporalExtractionConfig",
"DedupConfig",
"ExtractionPipelineConfig",
]

View File

@@ -0,0 +1,59 @@
"""Base classes for LLM response models with common validators.
This module provides reusable base classes for Pydantic models that handle
common LLM response patterns and edge cases.
Classes:
RobustLLMResponse: Base class for LLM response models with robust validation
"""
from typing import Any
from pydantic import BaseModel, ConfigDict, model_validator
class RobustLLMResponse(BaseModel):
"""Base class for LLM response models with robust validation.
This base class provides:
- Automatic handling of list-wrapped responses (e.g., [{"field": "value"}])
- Ignoring extra fields from LLM output
- Validation on assignment
Usage:
class MyResponse(RobustLLMResponse):
field1: str
field2: int
"""
model_config = ConfigDict(
extra="ignore", # Allow extra fields to be ignored (more forgiving)
validate_assignment=True # Validate on assignment
)
@model_validator(mode='before')
@classmethod
def handle_list_input(cls, data: Any) -> Any:
"""Handle cases where LLM returns a list instead of a dict.
Some LLMs may wrap the response in a list like [{"field": "value"}].
This validator extracts the first item if that happens.
Args:
data: The input data from the LLM
Returns:
The unwrapped data (dict)
Raises:
ValueError: If the input is invalid (empty list, wrong type, etc.)
"""
if isinstance(data, list):
if len(data) == 0:
raise ValueError("Received empty list from LLM")
# Extract first item from list
data = data[0]
if not isinstance(data, dict):
raise ValueError(f"Expected dict or list, got {type(data).__name__}")
return data

View File

@@ -0,0 +1,93 @@
"""Configuration models for Memory module components.
This module contains Pydantic models for configuring various components
of the memory system including LLM, chunking, pruning, and search.
Classes:
LLMConfig: Configuration for LLM client
ChunkerConfig: Configuration for dialogue chunking
PruningConfig: Configuration for semantic pruning
TemporalSearchParams: Parameters for temporal search queries
"""
from typing import Optional
from pydantic import BaseModel, Field
class LLMConfig(BaseModel):
"""Configuration for Large Language Model client.
Attributes:
llm_name: The name of the LLM model to use (e.g., 'gpt-4', 'claude-3')
api_base: Optional base URL for the API endpoint
max_retries: Maximum number of retries for failed API calls (default: 3)
"""
llm_name: str = Field(..., description="The name of the LLM model to use.")
api_base: Optional[str] = Field(None, description="The base URL for the API endpoint.")
max_retries: Optional[int] = Field(3, ge=0, description="The maximum number of retries for API calls.")
class ChunkerConfig(BaseModel):
"""Configuration for dialogue chunking strategy.
Attributes:
chunker_strategy: Name of the chunking strategy (e.g., 'RecursiveChunker', 'SemanticChunker')
embedding_model: Name of the embedding model to use for semantic chunking
chunk_size: Maximum size of each chunk in characters (default: 2048)
threshold: Similarity threshold for semantic chunking (0-1, default: 0.8)
language: Language of the text (default: 'zh' for Chinese)
skip_window: Window size for skip-and-merge strategy (default: 0)
min_sentences: Minimum number of sentences per chunk (default: 1)
min_characters_per_chunk: Minimum characters per chunk (default: 24)
"""
chunker_strategy: str = Field(..., description="The name of the chunker strategy to use.")
embedding_model: str = Field(..., description="The name of the embedding model to use.")
chunk_size: Optional[int] = Field(2048, ge=0, description="The size of each chunk.")
threshold: Optional[float] = Field(0.8, ge=0, le=1, description="The threshold for similarity.")
language: Optional[str] = Field("zh", description="The language of the text.")
skip_window: Optional[int] = Field(0, ge=0, description="The window for skip-and-merge.")
min_sentences: Optional[int] = Field(1, ge=0, description="The minimum number of sentences in each chunk.")
min_characters_per_chunk: Optional[int] = Field(24, ge=0, description="The minimum number of characters in each chunk.")
class PruningConfig(BaseModel):
"""Configuration for semantic pruning of dialogue content.
Attributes:
pruning_switch: Enable or disable semantic pruning
pruning_scene: Scene type for pruning ('education', 'online_service', 'outbound')
pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal)
"""
pruning_switch: bool = Field(False, description="Enable semantic pruning when True.")
pruning_scene: str = Field(
"education",
description="Scene for pruning: one of 'education', 'online_service', 'outbound'.",
)
pruning_threshold: float = Field(
0.5, ge=0.0, le=0.9,
description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).")
class TemporalSearchParams(BaseModel):
"""Parameters for temporal search queries in the knowledge graph.
Attributes:
group_id: Group ID to filter search results (default: 'test')
apply_id: Application ID to filter search results
user_id: User ID to filter search results
start_date: Start date for temporal filtering (format: 'YYYY-MM-DD')
end_date: End date for temporal filtering (format: 'YYYY-MM-DD')
valid_date: Date when memory should be valid (format: 'YYYY-MM-DD')
invalid_date: Date when memory should be invalid (format: 'YYYY-MM-DD')
limit: Maximum number of results to return (default: 3)
"""
group_id: Optional[str] = Field("test", description="The group ID to filter the search.")
apply_id: Optional[str] = Field(None, description="The apply ID to filter the search.")
user_id: Optional[str] = Field(None, description="The user ID to filter the search.")
start_date: Optional[str] = Field(None, description="The start date for the search.")
end_date: Optional[str] = Field(None, description="The end date for the search.")
valid_date: Optional[str] = Field(None, description="The valid date for the search.")
invalid_date: Optional[str] = Field(None, description="The invalid date for the search.")
limit: int = Field(default=3, description="The maximum number of results to return.")

View File

@@ -0,0 +1,52 @@
"""Models for entity deduplication and disambiguation decisions.
This module contains Pydantic models for structured LLM responses
during entity deduplication and disambiguation processes.
Classes:
EntityDedupDecision: Decision model for entity deduplication
EntityDisambDecision: Decision model for entity disambiguation
"""
from typing import Optional
from pydantic import BaseModel, Field
class EntityDedupDecision(BaseModel):
"""Structured decision returned by LLM for entity deduplication.
This model represents the LLM's decision on whether two entities
refer to the same real-world entity and should be merged.
Attributes:
same_entity: Whether the two entities refer to the same real-world entity
confidence: Model confidence in the decision (0.0 to 1.0)
canonical_idx: Index of the canonical entity to keep when merging (0 or 1, -1 if not applicable)
reason: Brief rationale for the decision (1-3 sentences, kept for audit)
"""
same_entity: bool = Field(..., description="Two entities refer to the same entity")
confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence of the decision")
canonical_idx: int = Field(..., description="Index of canonical entity among the pair: 0 or 1; -1 if not applicable")
reason: str = Field(..., description="Short rationale, 1-3 sentences")
class EntityDisambDecision(BaseModel):
"""Structured disambiguation decision for same-name but different-type entities.
This model represents the LLM's decision on whether two entities with
the same name but different types should be merged or kept separate.
Attributes:
should_merge: Whether the two entities should be merged despite type difference
confidence: Model confidence in the decision (0.0 to 1.0)
canonical_idx: Index of the canonical entity to keep when merging (0 or 1, -1 if not applicable)
block_pair: If True, this pair should be blocked from fuzzy/auto merges
suggested_type: Optional unified type to apply when should_merge is True
reason: Brief rationale for audit and analysis (1-3 sentences)
"""
should_merge: bool = Field(..., description="Merge the pair despite type difference")
confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence of the decision")
canonical_idx: int = Field(..., description="Index of canonical entity among the pair: 0 or 1; -1 if not applicable")
block_pair: bool = Field(False, description="Block this pair from fuzzy or heuristic merges")
suggested_type: Optional[str] = Field(None, description="Unified entity type when merging")
reason: str = Field(..., description="Short rationale, 1-3 sentences")

View File

@@ -0,0 +1,304 @@
"""Graph models for Neo4j knowledge graph nodes and edges.
This module contains Pydantic models representing nodes and edges
in the Neo4j knowledge graph, including dialogues, statements,
chunks, entities, and their relationships.
Classes:
Edge: Base class for all graph edges
ChunkEdge: Edge connecting chunks
ChunkEntityEdge: Edge connecting chunks to entities
ChunkDialogEdge: Edge connecting chunks to dialogues
StatementChunkEdge: Edge connecting statements to chunks
StatementEntityEdge: Edge connecting statements to entities
EntityEntityEdge: Edge connecting related entities
Node: Base class for all graph nodes
DialogueNode: Node representing a dialogue
StatementNode: Node representing a statement
ChunkNode: Node representing a conversation chunk
ExtractedEntityNode: Node representing an extracted entity
MemorySummaryNode: Node representing a memory summary
"""
from uuid import uuid4
from datetime import datetime, timezone
from typing import List, Optional
from pydantic import BaseModel, Field, field_validator
import re
from app.core.memory.utils.data.ontology import TemporalInfo
def parse_historical_datetime(v):
"""支持任意年份的日期解析包括历史日期如公元755年
Python datetime 支持公元1年到9999年的日期
此函数手动解析 ISO 8601 格式的日期字符串支持1-4位年份
Args:
v: 日期值(可以是 None、datetime 对象或字符串)
Returns:
datetime 对象或 None
"""
if v is None or isinstance(v, datetime):
return v
if isinstance(v, str):
# 匹配 ISO 8601 格式YYYY-MM-DD 或 YYYY-MM-DDTHH:MM:SS[.ffffff][Z|±HH:MM]
# 支持1-4位年份
pattern = r'^(\d{1,4})-(\d{2})-(\d{2})(?:T(\d{2}):(\d{2}):(\d{2})(?:\.(\d+))?(?:Z|([+-]\d{2}:\d{2}))?)?'
match = re.match(pattern, v)
if match:
try:
year = int(match.group(1))
month = int(match.group(2))
day = int(match.group(3))
hour = int(match.group(4)) if match.group(4) else 0
minute = int(match.group(5)) if match.group(5) else 0
second = int(match.group(6)) if match.group(6) else 0
microsecond = 0
# 处理微秒
if match.group(7):
# 补齐或截断到6位
us_str = match.group(7).ljust(6, '0')[:6]
microsecond = int(us_str)
# 处理时区
tzinfo = None
if 'Z' in v or match.group(8):
tzinfo = timezone.utc
# 创建 datetime 对象
return datetime(year, month, day, hour, minute, second, microsecond, tzinfo=tzinfo)
except (ValueError, OverflowError):
# 日期值无效如月份13、日期32等
return None
# 如果不匹配模式,尝试使用 fromisoformat用于标准格式
try:
return datetime.fromisoformat(v.replace('Z', '+00:00'))
except Exception:
return None
return v
class Edge(BaseModel):
"""Base class for all graph edges in the knowledge graph.
Attributes:
id: Unique identifier for the edge
source: ID of the source node
target: ID of the target node
group_id: Group ID for multi-tenancy
user_id: User ID for user-specific data
apply_id: Application ID for application-specific data
run_id: Unique identifier for the pipeline run that created this edge
created_at: Timestamp when the edge was created (system perspective)
expired_at: Optional timestamp when the edge expires (system perspective)
"""
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.")
source: str = Field(..., description="The ID of the source node.")
target: str = Field(..., description="The ID of the target node.")
group_id: str = Field(..., description="The group ID of the edge.")
user_id: str = Field(..., description="The user ID of the edge.")
apply_id: str = Field(..., description="The apply ID of the edge.")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
class ChunkEdge(Edge):
"""Edge connecting two chunks in sequence."""
pass
class ChunkEntityEdge(Edge):
"""Edge connecting a chunk to an entity mentioned in it."""
pass
class ChunkDialogEdge(Edge):
"""Edge connecting a chunk to its parent dialog.
Attributes:
sequence_number: Order of this chunk within the dialog
"""
sequence_number: int = Field(..., description="Order of this chunk within the dialog")
class StatementChunkEdge(Edge):
"""Edge connecting a statement to its parent chunk."""
pass
class StatementEntityEdge(Edge):
"""Edge connecting a statement to entities extracted from it.
Attributes:
connect_strength: Classification of connection strength ('Strong' or 'Weak')
"""
connect_strength: str = Field(..., description="Strong VS Weak about this statement-entity edge")
class EntityEntityEdge(Edge):
"""Edge connecting related entities (from triplet relationships).
Attributes:
relation_type: Type of relationship as defined in ontology
relation_value: Optional value of the relation
statement: The statement text where this relationship was found
source_statement_id: ID of the statement where this relationship was extracted
valid_at: Optional start date of temporal validity
invalid_at: Optional end date of temporal validity
"""
relation_type: str = Field(..., description="Relation type as defined in ontology")
relation_value: Optional[str] = Field(None, description="Value of the relation")
statement: str = Field(..., description='The statement of the edge.')
source_statement_id: str = Field(..., description="Statement where this relationship was extracted")
valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
@field_validator('valid_at', 'invalid_at', mode='before')
@classmethod
def validate_datetime(cls, v):
"""使用通用的历史日期解析函数"""
return parse_historical_datetime(v)
class Node(BaseModel):
"""Base class for all graph nodes in the knowledge graph.
Attributes:
id: Unique identifier for the node
name: Name of the node
group_id: Group ID for multi-tenancy
user_id: User ID for user-specific data
apply_id: Application ID for application-specific data
run_id: Unique identifier for the pipeline run that created this node
created_at: Timestamp when the node was created (system perspective)
expired_at: Optional timestamp when the node expires (system perspective)
"""
id: str = Field(..., description="The unique identifier for the node.")
name: str = Field(..., description="The name of the node.")
group_id: str = Field(..., description="The group ID of the node.")
user_id: str = Field(..., description="The user ID of the edge.")
apply_id: str = Field(..., description="The apply ID of the edge.")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(..., description="The valid time of the node from system perspective.")
expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.")
class DialogueNode(Node):
"""Node representing a dialogue in the knowledge graph.
Attributes:
ref_id: Reference identifier linking to external dialog system
content: Full dialogue content as text
dialog_embedding: Optional embedding vector for the entire dialogue
config_id: Configuration ID used to process this dialogue
"""
ref_id: str = Field(..., description="Reference identifier of the dialog")
content: str = Field(..., description="Dialogue content")
dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialogue (integer or string)")
class StatementNode(Node):
"""Node representing a statement extracted from dialogue.
Attributes:
chunk_id: ID of the parent chunk this statement belongs to
stmt_type: Type of the statement (from ontology)
temporal_info: Temporal information extracted from the statement
statement: The actual statement text content
connect_strength: Classification of connection strength ('Strong' or 'Weak')
valid_at: Optional start date of temporal validity
invalid_at: Optional end date of temporal validity
statement_embedding: Optional embedding vector for the statement
chunk_embedding: Optional embedding vector for the parent chunk
config_id: Configuration ID used to process this statement
"""
chunk_id: str = Field(..., description="ID of the parent chunk")
stmt_type: str = Field(..., description="Type of the statement")
temporal_info: TemporalInfo = Field(..., description="Temporal information")
statement: str = Field(..., description="The statement text content")
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)")
@field_validator('valid_at', 'invalid_at', mode='before')
@classmethod
def validate_datetime(cls, v):
"""使用通用的历史日期解析函数"""
return parse_historical_datetime(v)
class ChunkNode(Node):
"""Node representing a chunk of conversation in the knowledge graph.
Attributes:
dialog_id: ID of the parent dialog
content: The text content of the chunk
chunk_embedding: Optional embedding vector for the chunk
sequence_number: Order of this chunk within the dialog
metadata: Additional chunk metadata as key-value pairs
"""
dialog_id: str = Field(..., description="ID of the parent dialog")
content: str = Field(..., description="The text content of the chunk")
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
sequence_number: int = Field(..., description="Order of this chunk within the dialog")
metadata: dict = Field(default_factory=dict, description="Additional chunk metadata")
class ExtractedEntityNode(Node):
"""Node representing an extracted entity in the knowledge graph.
Attributes:
entity_idx: Unique numeric identifier for the entity
statement_id: ID of the statement this entity was extracted from
entity_type: Type/category of the entity
description: Textual description of the entity
aliases: Optional list of alternative names for the entity
name_embedding: Optional embedding vector for the entity name
fact_summary: Summary of facts about this entity
connect_strength: Classification of connection strength ('Strong' or 'Weak')
config_id: Configuration ID used to process this entity
"""
entity_idx: int = Field(..., description="Unique identifier for the entity")
statement_id: str = Field(..., description="Statement this entity was extracted from")
entity_type: str = Field(..., description="Type of the entity")
description: str = Field(..., description="Entity description")
aliases: Optional[List[str]] = Field(default_factory=list, description="Entity aliases")
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
fact_summary: str = Field(..., description="Summary of the fact about this entity")
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
class MemorySummaryNode(Node):
"""Node representing a memory summary with vector embedding.
Attributes:
summary_id: Unique identifier for the summary
dialog_id: ID of the parent dialog
chunk_ids: List of chunk IDs used to generate this summary
content: Summary text content
summary_embedding: Optional embedding vector for the summary
metadata: Additional metadata for the summary
config_id: Configuration ID used to process this summary
"""
summary_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for the summary")
dialog_id: str = Field(..., description="ID of the parent dialog")
chunk_ids: List[str] = Field(default_factory=list, description="List of chunk IDs used in the summary")
content: str = Field(..., description="Summary text content")
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)")

View File

@@ -0,0 +1,247 @@
"""Models for dialogue messages, conversations, and statements.
This module contains Pydantic models for representing dialogue data,
including messages, conversation context, chunks, and statements.
Classes:
ConversationMessage: Single message in a conversation
TemporalValidityRange: Temporal validity range for statements
Statement: Statement extracted from dialogue with metadata
ConversationContext: Full conversation history
Chunk: Chunk of conversation text
DialogData: Complete dialogue data structure
"""
from typing import List, Dict, Any, Optional
from pydantic import BaseModel, Field
from uuid import uuid4
from datetime import datetime
from app.core.memory.utils.data.ontology import StatementType, TemporalInfo, RelevenceInfo
from app.core.memory.models.triplet_models import TripletExtractionResponse, Triplet
class ConversationMessage(BaseModel):
"""Represents a single message in a conversation.
Attributes:
role: Role of the speaker (e.g., '用户' for user, 'AI' for assistant)
msg: Text content of the message
"""
role: str = Field(..., description="The role of the speaker (e.g., '用户', 'AI').")
msg: str = Field(..., description="The text content of the message.")
class TemporalValidityRange(BaseModel):
"""Represents the temporal validity range of a statement.
Attributes:
valid_at: Start date of validity in 'YYYY-MM-DD' format (None if not specified)
invalid_at: End date of validity in 'YYYY-MM-DD' format (None if not specified)
"""
valid_at: Optional[str] = Field(
None,
description="The start date of the statement's validity, in 'YYYY-MM-DD' format or 'None'.",
)
invalid_at: Optional[str] = Field(
None,
description="The end date of the statement's validity, in 'YYYY-MM-DD' format or 'None'.",
)
class Statement(BaseModel):
"""Represents a statement extracted from dialogue with metadata.
Attributes:
id: Unique identifier for the statement
chunk_id: ID of the parent chunk this statement belongs to
group_id: Optional group ID for multi-tenancy
statement: The actual statement text content
statement_embedding: Optional embedding vector for the statement
stmt_type: Type of the statement (from ontology)
temporal_info: Temporal information extracted from the statement
relevence_info: Relevance classification (RELEVANT or IRRELEVANT)
connect_strength: Optional connection strength ('Strong' or 'Weak')
temporal_validity: Optional temporal validity range
triplet_extraction_info: Optional triplet extraction results
"""
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.")
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
statement: str = Field(..., description="The text content of the statement.")
statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.")
stmt_type: StatementType = Field(..., description="The type of the statement.")
temporal_info: TemporalInfo = Field(..., description="The temporal information of the statement.")
relevence_info: RelevenceInfo = Field(RelevenceInfo.RELEVANT, description="The relevence information of the statement.")
connect_strength: Optional[str] = Field(None, description="Strong VS Weak about this entity")
temporal_validity: Optional[TemporalValidityRange] = Field(
None, description="The temporal validity range of the statement."
)
triplet_extraction_info: Optional[TripletExtractionResponse] = Field(
None, description="The triplet extraction information of the statement."
)
class ConversationContext(BaseModel):
"""Represents the full conversation history.
Attributes:
msgs: List of messages in the conversation
Properties:
content: Formatted string representation of the conversation
"""
msgs: List[ConversationMessage] = Field(..., description="A list of messages in the conversation.")
@property
def content(self) -> str:
"""Get the content of the conversation as a formatted string.
Returns:
String with format "role: message" for each message, joined by newlines
"""
return "\n".join([f"{msg.role}: {msg.msg}" for msg in self.msgs])
class Chunk(BaseModel):
"""A chunk of text from the conversation context.
Attributes:
id: Unique identifier for the chunk
text: List of messages in the chunk
content: The content of the chunk as a formatted string
statements: List of statements extracted from this chunk
chunk_embedding: Optional embedding vector for the chunk
metadata: Additional metadata as key-value pairs
"""
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the chunk.")
text: List[ConversationMessage] = Field(default_factory=list, description="A list of messages in the chunk.")
content: str = Field(..., description="The content of the chunk as a string.")
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
@classmethod
def from_messages(cls, messages: List[ConversationMessage], metadata: Optional[Dict[str, Any]] = None):
"""Create a chunk from a list of messages.
Args:
messages: List of conversation messages
metadata: Optional metadata dictionary
Returns:
Chunk instance with formatted content
"""
if metadata is None:
metadata = {}
# Generate content from messages
content = "\n".join([f"{msg.role}: {msg.msg}" for msg in messages])
return cls(text=messages, content=content, metadata=metadata)
class DialogData(BaseModel):
"""Represents the complete data structure for a dialog record.
Attributes:
id: Unique identifier for the dialog
context: Full conversation context
dialog_embedding: Optional embedding vector for the entire dialog
ref_id: Reference ID linking to external dialog system
group_id: Group ID for multi-tenancy
user_id: User ID for user-specific data
apply_id: Application ID for application-specific data
created_at: Timestamp when the dialog was created
expired_at: Timestamp when the dialog expires (default: far future)
metadata: Additional metadata as key-value pairs
chunks: List of chunks from the conversation
config_id: Configuration ID used to process this dialog
Properties:
content: Formatted string representation of the dialog
"""
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the dialog.")
context: ConversationContext = Field(..., description="The full conversation context as a single string.")
dialog_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the dialog.")
ref_id: str = Field(..., description="Refer to external dialog id. This is used to link to the original dialog.")
group_id: str = Field(default=..., description="Group ID of dialogue data")
user_id: str = Field(..., description="USER ID of dialogue data")
apply_id: str = Field(..., description="APPLY ID of dialogue data")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.")
expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the dialog.")
chunks: List[Chunk] = Field(default_factory=list, description="A list of chunks from the conversation context.")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialog (integer or string)")
@property
def content(self) -> str:
"""Get the content of the dialog as a formatted string.
Returns:
String representation of the conversation context
"""
return self.context.content
def get_statement_chunk(self, statement_id: str) -> Optional[Chunk]:
"""Find the chunk containing a specific statement.
Args:
statement_id: ID of the statement to find
Returns:
Chunk containing the statement, or None if not found
"""
for chunk in self.chunks:
for statement in chunk.statements:
if statement.id == statement_id:
return chunk
return None
def get_all_statements(self) -> List[Statement]:
"""Get all statements from all chunks.
Returns:
List of all statements in the dialog
"""
all_statements = []
for chunk in self.chunks:
all_statements.extend(chunk.statements)
return all_statements
def get_statement_by_id(self, statement_id: str) -> Optional[Statement]:
"""Find a specific statement by its ID.
Args:
statement_id: ID of the statement to find
Returns:
Statement with the given ID, or None if not found
"""
for chunk in self.chunks:
for statement in chunk.statements:
if statement.id == statement_id:
return statement
return None
def get_triplets_for_statement(self, statement_id: str) -> List[Triplet]:
"""Get all triplets extracted from a specific statement.
Args:
statement_id: ID of the statement
Returns:
List of triplets from the statement, or empty list if none found
"""
statement = self.get_statement_by_id(statement_id)
if statement and statement.triplet_extraction_info:
return statement.triplet_extraction_info.triplets
return []
def assign_group_id_to_statements(self) -> None:
"""Assign this dialog's group_id to all statements in all chunks.
This method updates statements that don't have a group_id set.
"""
for chunk in self.chunks:
for statement in chunk.statements:
if statement.group_id is None:
statement.group_id = self.group_id

View File

@@ -0,0 +1,85 @@
"""Models for knowledge triplets and entities.
This module contains Pydantic models for representing extracted knowledge
in the form of entities and triplets (subject-predicate-object relationships).
Classes:
Entity: Represents an extracted entity
Triplet: Represents a knowledge triplet (subject-predicate-object)
TripletExtractionResponse: Response model containing extracted triplets and entities
"""
from typing import List, Optional
from pydantic import BaseModel, Field, ConfigDict
from uuid import uuid4
class Entity(BaseModel):
"""Represents an extracted entity from dialogue.
Attributes:
id: Unique string identifier for the entity
entity_idx: Numeric index for the entity
name: Name of the entity
name_embedding: Optional embedding vector for the entity name
type: Type/category of the entity (e.g., 'Person', 'Organization')
description: Textual description of the entity
Config:
extra: Ignore extra fields from LLM output
"""
model_config = ConfigDict(extra='ignore')
id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for the entity.")
entity_idx: int = Field(..., description="Unique identifier for the entity")
name: str = Field(..., description="Name of the entity")
name_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the entity name")
type: str = Field(..., description="Type/category of the entity")
description: str = Field(..., description="Description of the entity")
class Triplet(BaseModel):
"""Represents an extracted knowledge triplet (subject-predicate-object).
A triplet represents a relationship between two entities, forming
the basic unit of knowledge in the knowledge graph.
Attributes:
id: Unique string identifier for the triplet
statement_id: Optional ID of the parent statement (set programmatically)
subject_name: Name of the subject entity
subject_id: Numeric ID of the subject entity
predicate: Relationship/predicate between subject and object
object_name: Name of the object entity
object_id: Numeric ID of the object entity
value: Optional additional value or context for the relationship
Config:
extra: Ignore extra fields from LLM output
"""
model_config = ConfigDict(extra='ignore')
id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for the triplet.")
statement_id: Optional[str] = Field(None, description="ID of the parent statement this triplet was extracted from.")
subject_name: str = Field(..., description="Name of the subject entity")
subject_id: int = Field(..., description="ID of the subject entity")
predicate: str = Field(..., description="Relationship/predicate between subject and object")
object_name: str = Field(..., description="Name of the object entity")
object_id: int = Field(..., description="ID of the object entity")
value: Optional[str] = Field(None, description="Additional value or context")
class TripletExtractionResponse(BaseModel):
"""Response model for triplet extraction from LLM.
This model represents the structured output from the LLM when
extracting knowledge triplets and entities from statements.
Attributes:
triplets: List of extracted knowledge triplets
entities: List of extracted entities
Config:
extra: Ignore extra fields from LLM output
"""
model_config = ConfigDict(extra='ignore')
triplets: List[Triplet] = Field(default_factory=list, description="List of extracted triplets")
entities: List[Entity] = Field(default_factory=list, description="List of extracted entities")

View File

@@ -0,0 +1,151 @@
"""Variable configuration models for extraction pipeline components.
This module contains Pydantic models for configuring various aspects
of the extraction pipeline, including statement extraction, triplet extraction,
temporal extraction, deduplication, and forgetting mechanisms.
Classes:
StatementExtractionConfig: Configuration for statement extraction
ForgettingEngineConfig: Configuration for forgetting engine
TripletExtractionConfig: Configuration for triplet extraction
TemporalExtractionConfig: Configuration for temporal extraction
DedupConfig: Configuration for entity deduplication
ExtractionPipelineConfig: Combined configuration for entire pipeline
"""
from typing import Optional
from pydantic import BaseModel, Field
class StatementExtractionConfig(BaseModel):
"""Configuration for statement extraction behavior.
Attributes:
statement_granularity: Granularity level (1-3):
- 1: Split sentences into different statements
- 2: Sentence-level statements
- 3: Combine sentences, shorten long statements
temperature: LLM temperature for statement extraction (0-2, default: 0.1)
include_dialogue_context: Whether to include full dialogue context
max_dialogue_context_chars: Maximum characters from dialogue context (default: 2000)
"""
statement_granularity: Optional[int] = Field(None, ge=1, le=3, description="Granularity of statements to extract, level 1 to 3")
temperature: Optional[float] = Field(0.1, ge=0, le=2, description="LLM temperature for statement extraction")
include_dialogue_context: bool = Field(True, description="Whether to include full dialogue context in extraction")
max_dialogue_context_chars: Optional[int] = Field(2000, ge=100, description="Maximum number of characters to include from dialogue context")
class ForgettingEngineConfig(BaseModel):
"""Configuration for the forgetting engine.
The forgetting engine implements a memory decay mechanism based on
time and memory strength parameters.
Attributes:
offset: Minimum retention level (0-1, prevents complete forgetting, default: 0.1)
lambda_time: Lambda parameter controlling time decay effect (default: 0.1)
lambda_mem: Lambda parameter controlling memory strength effect (default: 1.0)
"""
offset: float = Field(0.1, ge=0.0, le=1.0, description="Minimum retention level (prevents complete forgetting).")
lambda_time: float = Field(0.1, gt=0.0, description="Lambda parameter controlling time effect.")
lambda_mem: float = Field(1.0, gt=0.0, description="Lambda parameter controlling memory strength effect.")
class TripletExtractionConfig(BaseModel):
"""Configuration for triplet extraction behavior.
Attributes:
temperature: LLM temperature for triplet extraction (0-2, default: 0.1)
enable_entity_normalization: Whether to normalize entity names (default: True)
confidence_threshold: Minimum confidence for extracted triplets (0-1, default: 0.7)
"""
temperature: Optional[float] = Field(0.1, ge=0, le=2, description="LLM temperature for triplet extraction")
enable_entity_normalization: bool = Field(True, description="Whether to normalize entity names")
confidence_threshold: Optional[float] = Field(0.7, ge=0, le=1, description="Minimum confidence threshold for extracted triplets")
class TemporalExtractionConfig(BaseModel):
"""Configuration for temporal extraction behavior.
Attributes:
temperature: LLM temperature for temporal extraction (0-2, default: 0.1)
"""
temperature: Optional[float] = Field(0.1, ge=0, le=2, description="LLM temperature for temporal extraction")
class DedupConfig(BaseModel):
"""Configuration for entity deduplication behavior.
This configuration controls the multi-stage deduplication process,
including fuzzy matching, LLM-based deduplication, and disambiguation.
Attributes:
enable_llm_dedup_blockwise: Enable blockwise LLM-driven deduplication (default: False)
enable_llm_disambiguation: Enable LLM disambiguation for same-name different-type entities (default: False)
enable_llm_fallback_only_on_borderline: Only trigger LLM when borderline pairs exist (default: True)
fuzzy_name_threshold_strict: Strict threshold for name similarity (0-1, default: 0.90)
fuzzy_type_threshold_strict: Strict threshold for type similarity (0-1, default: 0.75)
fuzzy_overall_threshold: Overall similarity threshold to merge (0-1, default: 0.82)
fuzzy_unknown_type_name_threshold: Name threshold when entity type is UNKNOWN (0-1, default: 0.92)
fuzzy_unknown_type_type_threshold: Type threshold when entity type is UNKNOWN (0-1, default: 0.50)
name_weight: Weight of name similarity in overall score (0-1, default: 0.50)
desc_weight: Weight of description similarity in overall score (0-1, default: 0.30)
type_weight: Weight of type similarity in overall score (0-1, default: 0.20)
context_bonus: Bonus when entities co-occur in same statements (0-0.2, default: 0.03)
llm_fallback_floor: Lower bound for borderline score (0-1, default: 0.76)
llm_fallback_ceiling: Upper bound for borderline score (0-1, default: 0.82)
llm_block_size: Entities per block for LLM dedup (1-500, default: 50)
llm_block_concurrency: Concurrent blocks processed by LLM (1-64, default: 4)
llm_pair_concurrency: Concurrent pairwise decisions per block (1-64, default: 4)
llm_max_rounds: Maximum LLM iterative dedup rounds (1-10, default: 3)
"""
# LLM deduplication toggles
enable_llm_dedup_blockwise: bool = Field(False, description="Toggle blockwise LLM-driven deduplication")
enable_llm_disambiguation: bool = Field(False, description="Toggle LLM-driven disambiguation for same-name different-type entities")
enable_llm_fallback_only_on_borderline: bool = Field(True, description="Trigger LLM dedup only when borderline pairs are detected in fuzzy stage")
# Fuzzy match thresholds
fuzzy_name_threshold_strict: float = Field(0.90, ge=0, le=1, description="Strict threshold for name similarity")
fuzzy_type_threshold_strict: float = Field(0.75, ge=0, le=1, description="Strict threshold for type similarity")
fuzzy_overall_threshold: float = Field(0.82, ge=0, le=1, description="Overall similarity threshold to merge")
# Specialized thresholds when type is UNKNOWN
fuzzy_unknown_type_name_threshold: float = Field(0.92, ge=0, le=1, description="Name threshold when any entity type is UNKNOWN")
fuzzy_unknown_type_type_threshold: float = Field(0.50, ge=0, le=1, description="Type threshold when any entity type is UNKNOWN")
# Weighted scoring components for overall similarity
name_weight: float = Field(0.50, ge=0, le=1, description="Weight of name similarity in overall score")
desc_weight: float = Field(0.30, ge=0, le=1, description="Weight of description similarity in overall score")
type_weight: float = Field(0.20, ge=0, le=1, description="Weight of type similarity in overall score")
context_bonus: float = Field(0.03, ge=0, le=0.2, description="Bonus added to score when entities co-occur in same statements")
# Borderline range for LLM fallback triggering
llm_fallback_floor: float = Field(0.76, ge=0, le=1, description="Lower bound of overall score to consider as borderline for LLM fallback")
llm_fallback_ceiling: float = Field(0.82, ge=0, le=1, description="Upper bound (below merge threshold) of overall score for LLM fallback")
# LLM iterative dedup parameters
llm_block_size: int = Field(50, ge=1, le=500, description="Entities per block for LLM dedup")
llm_block_concurrency: int = Field(4, ge=1, le=64, description="Concurrent blocks processed by LLM")
llm_pair_concurrency: int = Field(4, ge=1, le=64, description="Concurrent pairwise decisions per block")
llm_max_rounds: int = Field(3, ge=1, le=10, description="Maximum LLM iterative dedup rounds")
class ExtractionPipelineConfig(BaseModel):
"""Configuration for the entire extraction pipeline.
This model combines all configuration components for the complete
extraction pipeline, including statement extraction, triplet extraction,
temporal extraction, deduplication, and forgetting mechanisms.
Attributes:
statement_extraction: Configuration for statement extraction
triplet_extraction: Configuration for triplet extraction
temporal_extraction: Configuration for temporal extraction
deduplication: Configuration for entity deduplication
forgetting_engine: Configuration for forgetting engine
"""
statement_extraction: StatementExtractionConfig = Field(default_factory=StatementExtractionConfig)
triplet_extraction: TripletExtractionConfig = Field(default_factory=TripletExtractionConfig)
temporal_extraction: TemporalExtractionConfig = Field(default_factory=TemporalExtractionConfig)
deduplication: DedupConfig = Field(default_factory=DedupConfig)
forgetting_engine: ForgettingEngineConfig = Field(default_factory=ForgettingEngineConfig)

View File

View File

@@ -0,0 +1,330 @@
from typing import Any, List
import re
import os
import asyncio
import json
import numpy as np
# Fix tokenizer parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from chonkie import (
SemanticChunker,
RecursiveChunker,
RecursiveRules,
LateChunker,
NeuralChunker,
SentenceChunker,
TokenChunker,
)
from app.core.memory.models.config_models import ChunkerConfig
from app.core.memory.models.message_models import DialogData, Chunk
try:
from app.core.memory.src.llm_tools.openai_client import OpenAIClient
except Exception:
# 在测试或无可用依赖(如 langfuse环境下允许惰性导入
OpenAIClient = Any
class LLMChunker:
"""基于LLM的智能分块策略"""
def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000):
self.llm_client = llm_client
self.chunk_size = chunk_size
async def __call__(self, text: str) -> List[Any]:
# 使用LLM分析文本结构并进行智能分块
prompt = f"""
请将以下文本分割成语义连贯的段落。每个段落应该围绕一个主题,长度大约在{self.chunk_size}字符左右。
请以JSON格式返回结果包含chunks数组每个chunk有text字段。
文本内容:
{text[:5000]}
"""
messages = [
{"role": "system", "content": "你是一个专业的文本分析助手,擅长将长文本分割成语义连贯的段落。"},
{"role": "user", "content": prompt}
]
try:
# 使用异步的 achat 方法
if hasattr(self.llm_client, 'achat'):
response = await self.llm_client.achat(messages)
else:
# 如果没有异步方法,使用同步方法并转换为异步
response = await asyncio.to_thread(self.llm_client.chat, messages)
# 检查响应格式并提取内容
if hasattr(response, 'choices') and len(response.choices) > 0:
content = response.choices[0].message.content
elif hasattr(response, 'content'):
content = response.content
else:
content = str(response)
# 解析LLM响应
if "```json" in content:
json_str = content.split("```json")[1].split("```")[0].strip()
elif "```" in content:
json_str = content.split("```")[1].split("```")[0].strip()
else:
json_str = content
result = json.loads(json_str)
class SimpleChunk:
def __init__(self, text, index):
self.text = text
self.start_index = index * 100 # 近似位置
self.end_index = (index + 1) * 100
return [SimpleChunk(chunk["text"], i) for i, chunk in enumerate(result.get("chunks", []))]
except Exception as e:
print(f"LLM分块失败: {e}")
# 失败时返回空列表,外层会处理回退方案
return []
class HybridChunker:
"""混合分块策略:先按结构分块,再按语义合并"""
def __init__(self, semantic_threshold: float = 0.8, base_chunk_size: int = 300):
self.semantic_threshold = semantic_threshold
self.base_chunk_size = base_chunk_size
self.base_chunker = TokenChunker(tokenizer="character", chunk_size=base_chunk_size)
self.semantic_chunker = SemanticChunker(threshold=semantic_threshold)
def __call__(self, text: str) -> List[Any]:
# 先用基础分块
base_chunks = self.base_chunker(text)
# 如果文本不长,直接返回基础分块
if len(base_chunks) <= 3:
return base_chunks
# 对基础分块进行语义合并
combined_text = " ".join([chunk.text for chunk in base_chunks])
return self.semantic_chunker(combined_text)
class ChunkerClient:
def __init__(self, chunker_config: ChunkerConfig, llm_client: OpenAIClient = None):
self.chunker_config = chunker_config
self.embedding_model = chunker_config.embedding_model
self.chunk_size = chunker_config.chunk_size
self.threshold = chunker_config.threshold
self.language = chunker_config.language
self.skip_window = chunker_config.skip_window
self.min_sentences = chunker_config.min_sentences
self.min_characters_per_chunk = chunker_config.min_characters_per_chunk
self.llm_client = llm_client
# 可选参数(从配置中安全获取,提供默认值)
self.chunk_overlap = getattr(chunker_config, 'chunk_overlap', 0)
self.min_sentences_per_chunk = getattr(chunker_config, 'min_sentences_per_chunk', 1)
self.min_characters_per_sentence = getattr(chunker_config, 'min_characters_per_sentence', 12)
self.delim = getattr(chunker_config, 'delim', [".", "!", "?", "\n"])
self.include_delim = getattr(chunker_config, 'include_delim', "prev")
self.tokenizer_or_token_counter = getattr(chunker_config, 'tokenizer_or_token_counter', "character")
# 初始化具体分块器策略
if chunker_config.chunker_strategy == "TokenChunker":
self.chunker = TokenChunker(
tokenizer=self.tokenizer_or_token_counter,
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
)
elif chunker_config.chunker_strategy == "SemanticChunker":
self.chunker = SemanticChunker(
embedding_model=self.embedding_model,
threshold=self.threshold,
chunk_size=self.chunk_size,
min_sentences=self.min_sentences,
)
elif chunker_config.chunker_strategy == "RecursiveChunker":
self.chunker = RecursiveChunker(
rules=RecursiveRules(),
min_characters_per_chunk=self.min_characters_per_chunk or 50,
chunk_size=self.chunk_size,
)
elif chunker_config.chunker_strategy == "LateChunker":
self.chunker = LateChunker(
embedding_model=self.embedding_model,
chunk_size=self.chunk_size,
rules=RecursiveRules(),
min_characters_per_chunk=self.min_characters_per_chunk,
)
elif chunker_config.chunker_strategy == "NeuralChunker":
self.chunker = NeuralChunker(
model=self.embedding_model,
min_characters_per_chunk=self.min_characters_per_chunk,
)
elif chunker_config.chunker_strategy == "LLMChunker":
if not llm_client:
raise ValueError("LLMChunker requires an LLM client")
self.chunker = LLMChunker(llm_client, self.chunk_size)
elif chunker_config.chunker_strategy == "HybridChunker":
self.chunker = HybridChunker(
semantic_threshold=self.threshold,
base_chunk_size=self.chunk_size,
)
elif chunker_config.chunker_strategy == "SentenceChunker":
# 某些 chonkie 版本的 SentenceChunker 不支持 tokenizer_or_token_counter 参数
# 为了兼容不同版本,这里仅传递广泛支持的参数
self.chunker = SentenceChunker(
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
min_sentences_per_chunk=self.min_sentences_per_chunk,
min_characters_per_sentence=self.min_characters_per_sentence,
delim=self.delim,
include_delim=self.include_delim,
)
else:
raise ValueError(f"Unknown chunker strategy: {chunker_config.chunker_strategy}")
async def generate_chunks(self, dialogue: DialogData):
"""
生成分块,支持异步操作
"""
try:
# 预处理文本:确保对话标记格式统一
content = dialogue.content
content = content.replace('AI', 'AI:').replace('用户:', '用户:') # 统一冒号
content = re.sub(r'(\n\s*)+\n', '\n\n', content) # 合并多个空行
if hasattr(self.chunker, '__call__') and not asyncio.iscoroutinefunction(self.chunker.__call__):
# 同步分块器
chunks = self.chunker(content)
else:
# 异步分块器如LLMChunker
chunks = await self.chunker(content)
# 过滤空块和过小的块
valid_chunks = []
for c in chunks:
chunk_text = getattr(c, 'text', str(c)) if not isinstance(c, str) else c
if isinstance(chunk_text, str) and len(chunk_text.strip()) >= (self.min_characters_per_chunk or 50):
valid_chunks.append(c)
dialogue.chunks = [
Chunk(
content=c.text if hasattr(c, 'text') else str(c),
metadata={
"start_index": getattr(c, "start_index", None),
"end_index": getattr(c, "end_index", None),
"chunker_strategy": self.chunker_config.chunker_strategy,
},
)
for c in valid_chunks
]
return dialogue
except Exception as e:
print(f"分块失败: {e}")
# 改进的后备方案:尝试按对话回合分割
try:
# 简单的按对话分割
dialogue_pattern = r'(AI:|用户:)(.*?)(?=AI:|用户:|$)'
matches = re.findall(dialogue_pattern, dialogue.content, re.DOTALL)
class SimpleChunk:
def __init__(self, text, start_index, end_index):
self.text = text
self.start_index = start_index
self.end_index = end_index
chunks = []
current_chunk = ""
current_start = 0
for match in matches:
speaker, ct = match[0], match[1].strip()
turn_text = f"{speaker} {ct}"
if len(current_chunk) + len(turn_text) > (self.chunk_size or 500):
if current_chunk:
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
current_chunk = turn_text
current_start = dialogue.content.find(turn_text, current_start)
else:
current_chunk += ("\n" + turn_text) if current_chunk else turn_text
if current_chunk:
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
dialogue.chunks = [
Chunk(
content=c.text,
metadata={
"start_index": c.start_index,
"end_index": c.end_index,
"chunker_strategy": "DialogueTurnFallback",
},
)
for c in chunks
]
except Exception:
# 最后的手段:单一大块
dialogue.chunks = [Chunk(
content=dialogue.content,
metadata={"chunker_strategy": "SingleChunkFallback"},
)]
return dialogue
def evaluate_chunking(self, dialogue: DialogData) -> dict:
"""
评估分块质量
"""
if not getattr(dialogue, 'chunks', None):
return {}
chunks = dialogue.chunks
total_chars = sum(len(chunk.content) for chunk in chunks)
avg_chunk_size = total_chars / len(chunks)
# 计算各种指标
chunk_sizes = [len(chunk.content) for chunk in chunks]
metrics = {
"strategy": self.chunker_config.chunker_strategy,
"num_chunks": len(chunks),
"total_characters": total_chars,
"avg_chunk_size": avg_chunk_size,
"min_chunk_size": min(chunk_sizes),
"max_chunk_size": max(chunk_sizes),
"chunk_size_std": np.std(chunk_sizes) if len(chunk_sizes) > 1 else 0,
"coverage_ratio": total_chars / len(dialogue.content) if dialogue.content else 0,
}
return metrics
def save_chunking_results(self, dialogue: DialogData, output_path: str):
"""
保存分块结果到文件,文件名包含策略名称
"""
strategy_name = self.chunker_config.chunker_strategy
# 在文件名中添加策略名称
base_name, ext = os.path.splitext(output_path)
strategy_output_path = f"{base_name}_{strategy_name}{ext}"
with open(strategy_output_path, 'w', encoding='utf-8') as f:
f.write(f"=== Chunking Strategy: {strategy_name} ===\n")
f.write(f"Total chunks: {len(dialogue.chunks)}\n")
f.write(f"Total characters: {sum(len(chunk.content) for chunk in dialogue.chunks)}\n")
f.write("=" * 60 + "\n\n")
for i, chunk in enumerate(dialogue.chunks):
f.write(f"Chunk {i+1}:\n")
f.write(f"Size: {len(chunk.content)} characters\n")
if hasattr(chunk, 'metadata') and 'start_index' in chunk.metadata:
f.write(f"Position: {chunk.metadata.get('start_index')}-{chunk.metadata.get('end_index')}\n")
f.write(f"Content: {chunk.content}\n")
f.write("-" * 40 + "\n\n")
print(f"Chunking results saved to: {strategy_output_path}")
return strategy_output_path

View File

@@ -0,0 +1,22 @@
from abc import ABC, abstractmethod
from typing import List
from app.core.models.base import RedBearModelConfig
class EmbedderClient(ABC):
def __init__(self, model_config: RedBearModelConfig):
self.config = model_config
self.model_name = model_config.model_name
self.provider = model_config.provider
self.api_key = model_config.api_key
self.base_url = model_config.base_url
self.max_retries = model_config.max_retries
# self.dimension = model_config.dimension
@abstractmethod
async def response(
self,
messages: List[str],
) -> List[str]:
pass

View File

@@ -0,0 +1,37 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from pydantic import BaseModel
from app.core.memory.models.config_models import LLMConfig
"""
model_name: str
provider: str
api_key: str
base_url: Optional[str] = None
timeout: float = 30.0 # 请求超时时间(秒)
max_retries: int = 3 # 最大重试次数
concurrency: int = 5 # 并发限流
extra_params: Dict[str, Any] = {}
"""
from app.core.models.base import RedBearModelConfig
class LLMClient(ABC):
def __init__(self, model_config: RedBearModelConfig):
self.config = model_config
self.model_name = self.config.model_name
self.provider = self.config.provider
self.api_key = self.config.api_key
self.base_url = self.config.base_url
self.max_retries = self.config.max_retries
@abstractmethod
def chat(self, messages: List[Dict[str, str]]) -> Any:
pass
@abstractmethod
async def response_structured(
self,
messages: List[Dict[str, str]],
response_model: type[BaseModel],
) -> type[BaseModel]:
pass

View File

@@ -0,0 +1,224 @@
import asyncio
from typing import List, Dict, Any
import json
from pydantic import BaseModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from app.core.models.base import RedBearModelConfig
from app.core.models.llm import RedBearLLM
from app.core.memory.src.llm_tools.llm_client import LLMClient
# from app.core.memory.utils.config.definitions import LANGFUSE_ENABLED
LANGFUSE_ENABLED=False
class OpenAIClient(LLMClient):
def __init__(self, model_config: RedBearModelConfig, type_: str = "chat"):
super().__init__(model_config)
# Initialize Langfuse callback handler if enabled
self.langfuse_handler = None
if LANGFUSE_ENABLED:
try:
from langfuse.langchain import CallbackHandler
self.langfuse_handler = CallbackHandler()
except ImportError:
# Langfuse not installed, continue without tracing
pass
except Exception as e:
# Log error but don't fail initialization
import logging
logging.warning(f"Failed to initialize Langfuse handler: {e}")
# Initialize RedBearLLM client
self.client = RedBearLLM(RedBearModelConfig(
model_name=self.model_name,
provider=self.provider,
api_key=self.api_key,
base_url=self.base_url,
max_retries=self.max_retries,
), type=type_)
async def chat(self, messages: List[Dict[str, str]]) -> Any:
template = """{messages}"""
# ChatPromptTemplate
prompt = ChatPromptTemplate.from_template(template)
chain = prompt | self.client
# Add Langfuse callback if available
config = {}
if self.langfuse_handler:
config["callbacks"] = [self.langfuse_handler]
response = await chain.ainvoke({"messages": messages}, config=config)
# print(f"OpenAIClient response ======>:\n {response}")
return response
async def response_structured(
self,
messages: List[Dict[str, str]],
response_model: type[BaseModel],
) -> type[BaseModel]:
# Build a simple prompt pipeline that sends messages to the underlying LLM
question_text = "\n\n".join([str(m.get("content", "")) for m in messages])
# Prepare config with Langfuse callback if available
config = {}
if self.langfuse_handler:
config["callbacks"] = [self.langfuse_handler]
# Primary: enforce schema with PydanticOutputParser if available
if PydanticOutputParser is not None:
try:
import logging
logger = logging.getLogger(__name__)
# 使用正确的属性路径self.config.timeout从LLMClient基类继承
# logger.info(f"开始LLM结构化输出请求 (模型: {self.model_name}, 超时: {self.config.timeout}秒)")
parser = PydanticOutputParser(pydantic_object=response_model)
format_instructions = parser.get_format_instructions()
prompt = ChatPromptTemplate.from_template("{question}\n{format_instructions}")
chain = prompt | self.client | parser
parsed = await chain.ainvoke({
"question": question_text,
"format_instructions": format_instructions,
})
# logger.info(f"LLM结构化输出请求成功完成")
return parsed
except Exception as e:
import logging
logger = logging.getLogger(__name__)
logger.warning(f"PydanticOutputParser失败尝试备用方法: {str(e)}")
# Fall through to alternative structured methods
pass
# Fallback path: create plain prompt for other structured methods
template = """{question}"""
prompt = ChatPromptTemplate.from_template(template)
# Try LangChain structured output if available on the underlying client
try:
with_so = getattr(self.client, "with_structured_output", None)
if callable(with_so):
try:
structured_chain = prompt | with_so(response_model, strict=True)
parsed = await structured_chain.ainvoke({"question": question_text}, config=config)
# parsed may already be a pydantic model or a dict
try:
return response_model.model_validate(parsed)
except Exception:
try:
# If it's already a pydantic instance (LangChain returns model), return it
if hasattr(parsed, "model_dump"):
return parsed
return response_model.model_validate_json(json.dumps(parsed))
except Exception:
# Fall through to manual parsing below
pass
except NotImplementedError:
# The underlying model doesn't support structured output, fall through
import logging
logger = logging.getLogger(__name__)
logger.warning(
f"Model {self.model_name} doesn't support with_structured_output, falling back to manual parsing")
pass
except Exception as e:
import logging
logger = logging.getLogger(__name__)
logger.warning(f"Structured output attempt failed: {e}, falling back to manual parsing")
# Final fallback: manual parsing with plain LLM response
try:
import logging
logger = logging.getLogger(__name__)
logger.info(f"Using manual parsing fallback for model {self.model_name}")
# Create a prompt that asks for JSON output
json_prompt = ChatPromptTemplate.from_template(
"{question}\n\n"
"Please respond with a valid JSON object that matches this schema:\n"
"{schema}\n\n"
"Response (JSON only):"
)
# Get the schema from the response model
schema = response_model.model_json_schema()
chain = json_prompt | self.client
response = await chain.ainvoke({
"question": question_text,
"schema": json.dumps(schema, indent=2)
}, config=config)
# Extract JSON from response
response_text = str(response.content if hasattr(response, 'content') else response)
# Try to find JSON in the response
import re
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
if json_match:
json_str = json_match.group(0)
try:
parsed_dict = json.loads(json_str)
return response_model.model_validate(parsed_dict)
except json.JSONDecodeError:
pass
# If JSON parsing fails, try to create a minimal valid response
logger.warning(f"Failed to parse JSON from LLM response, creating minimal response")
# Create a minimal response based on the schema
return self._create_minimal_response(response_model)
except Exception as fallback_error:
import logging
logger = logging.getLogger(__name__)
logger.error(f"Manual parsing fallback also failed: {fallback_error}")
# Return minimal response as last resort
return self._create_minimal_response(response_model)
def _create_minimal_response(self, response_model: type[BaseModel]) -> BaseModel:
"""Create a minimal valid response based on the model schema."""
try:
minimal_response = {}
for field_name, field_info in response_model.model_fields.items():
# Check if field has a default value
if hasattr(field_info, 'default') and field_info.default is not None:
minimal_response[field_name] = field_info.default
else:
# Create default based on field type
field_type = field_info.annotation
# Handle nested BaseModel
if hasattr(field_type, '__bases__') and BaseModel in field_type.__bases__:
minimal_response[field_name] = self._create_minimal_response(field_type)
elif field_type == str:
minimal_response[field_name] = "信息不足,无法回答"
elif field_type == int:
minimal_response[field_name] = 0
elif field_type == float:
minimal_response[field_name] = 0.0
elif field_type == bool:
minimal_response[field_name] = False
elif field_type == list:
minimal_response[field_name] = []
elif field_type == dict:
minimal_response[field_name] = {}
else:
minimal_response[field_name] = None
return response_model.model_validate(minimal_response)
except Exception as e:
import logging
logger = logging.getLogger(__name__)
logger.error(f"Failed to create minimal response: {e}")
# Last resort: try to create with just required fields
try:
return response_model()
except Exception:
# If even that fails, raise the original error
raise ValueError(f"Unable to create minimal response for {response_model.__name__}") from e

View File

@@ -0,0 +1,26 @@
from typing import List
from app.core.memory.src.llm_tools.embedder_client import EmbedderClient
from app.core.models.base import RedBearModelConfig
# from app.models.models_model import ModelType
from app.core.models.embedding import RedBearEmbeddings
class OpenAIEmbedderClient(EmbedderClient):
def __init__(self, model_config: RedBearModelConfig):
super().__init__(model_config)
async def response(
self,
messages: List[str],
) -> List[List[float]]:
texts: List[str] = [str(m) for m in messages if m is not None]
model = RedBearEmbeddings(RedBearModelConfig(
model_name=self.model_name,
provider=self.provider,
api_key=self.api_key,
base_url=self.base_url,
))
embeddings = await model.aembed_documents(texts)
return embeddings

View File

@@ -0,0 +1,980 @@
import argparse
import asyncio
import json
import os
import time
from typing import List, Dict, Any, Optional
from dotenv import load_dotenv
from datetime import datetime
import math
from app.core.logging_config import get_memory_logger
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.graph_search import (
search_graph_by_embedding, search_graph,
search_graph_by_temporal, search_graph_by_keyword_temporal,
search_graph_by_chunk_id
)
from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.models.config_models import TemporalSearchParams
from app.core.memory.utils.config.config_utils import get_embedder_config, get_pipeline_config
from app.core.memory.utils.data.time_utils import normalize_date_safe
from app.core.memory.models.variate_config import ForgettingEngineConfig
from app.core.memory.utils.config.definitions import CONFIG, RUNTIME_CONFIG
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine
from app.core.memory.utils.data.text_utils import extract_plain_query
from app.core.memory.utils.config import definitions as config_defs
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.llm.llm_utils import get_reranker_client
load_dotenv()
logger = get_memory_logger(__name__)
def _parse_datetime(value: Any) -> Optional[datetime]:
"""Parse ISO `created_at` strings of the form 'YYYY-MM-DDTHH:MM:SS.ssssss'."""
if value is None:
return None
if isinstance(value, datetime):
return value
if isinstance(value, str):
s = value.strip()
if not s:
return None
try:
return datetime.fromisoformat(s)
except Exception:
return None
return None
def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score") -> List[Dict[str, Any]]:
"""Normalize scores using z-score normalization followed by sigmoid transformation."""
if not results:
return results
# Extract scores, ensuring they are numeric and not None
scores = []
for item in results:
if score_field in item:
score = item.get(score_field)
if score is not None and isinstance(score, (int, float)):
scores.append(float(score))
else:
scores.append(0.0) # Default for None or non-numeric values
if not scores:
return results
if len(scores) == 1:
# Single score, set to 1.0
for item in results:
if score_field in item:
item[f"normalized_{score_field}"] = 1.0
return results
# Calculate mean and standard deviation
mean_score = sum(scores) / len(scores)
variance = sum((score - mean_score) ** 2 for score in scores) / len(scores)
std_dev = math.sqrt(variance)
if std_dev == 0:
# All scores are the same, set them to 1.0
for item in results:
if score_field in item:
item[f"normalized_{score_field}"] = 1.0
else:
for item in results:
if score_field in item:
score = item[score_field]
# Handle None or non-numeric scores
if score is None or not isinstance(score, (int, float)):
score = 0.0
# Calculate z-score
z_score = (score - mean_score) / std_dev
# Transform to positive range using sigmoid function
normalized = 1 / (1 + math.exp(-z_score))
item[f"normalized_{score_field}"] = normalized
return results
def rerank_hybrid_results(
keyword_results: Dict[str, List[Dict[str, Any]]],
embedding_results: Dict[str, List[Dict[str, Any]]],
alpha: float = 0.6,
limit: int = 10
) -> Dict[str, List[Dict[str, Any]]]:
"""
Rerank hybrid search results by combining BM25 and embedding scores.
Args:
keyword_results: Results from keyword/BM25 search
embedding_results: Results from embedding search
alpha: Weight for BM25 scores (1-alpha for embedding scores)
limit: Maximum number of results to return per category
Returns:
Reranked results with combined scores
"""
reranked = {}
for category in ["statements", "chunks", "entities","summaries"]:
keyword_items = keyword_results.get(category, [])
embedding_items = embedding_results.get(category, [])
# Normalize scores within each search type
keyword_items = normalize_scores(keyword_items, "score")
embedding_items = normalize_scores(embedding_items, "score")
# Create a combined pool of unique items
combined_items = {}
# Add keyword results with BM25 scores
for item in keyword_items:
item_id = item.get("id") or item.get("uuid")
if item_id:
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
combined_items[item_id]["embedding_score"] = 0 # Default
# Add or update with embedding results
for item in embedding_items:
item_id = item.get("id") or item.get("uuid")
if item_id:
if item_id in combined_items:
# Update existing item with embedding score
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
else:
# New item from embedding search only
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = 0 # Default
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# Calculate combined scores and rank
for item_id, item in combined_items.items():
bm25_score = item.get("bm25_score", 0)
embedding_score = item.get("embedding_score", 0)
# Combined score: weighted average of normalized scores
combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
item["combined_score"] = combined_score
# Keep original score for reference
if "score" not in item and bm25_score > 0:
item["score"] = bm25_score
elif "score" not in item and embedding_score > 0:
item["score"] = embedding_score
# Sort by combined score and limit results
sorted_items = sorted(
combined_items.values(),
key=lambda x: x.get("combined_score", 0),
reverse=True
)[:limit]
reranked[category] = sorted_items
return reranked
def rerank_with_forgetting_curve(
keyword_results: Dict[str, List[Dict[str, Any]]],
embedding_results: Dict[str, List[Dict[str, Any]]],
alpha: float = 0.6,
limit: int = 10,
forgetting_config: ForgettingEngineConfig | None = None,
now: datetime | None = None,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Rerank hybrid results with a forgetting curve applied to combined scores.
The forgetting curve reduces scores for older memories or weaker connections.
Args:
keyword_results: Results from keyword/BM25 search
embedding_results: Results from embedding search
alpha: Weight for BM25 scores (1-alpha for embedding scores)
limit: Maximum number of results to return per category
forgetting_config: Configuration for the forgetting engine
now: Optional current time override for testing
Returns:
Reranked results with combined and final scores (after forgetting)
"""
engine = ForgettingEngine(forgetting_config or ForgettingEngineConfig())
now_dt = now or datetime.now()
reranked: Dict[str, List[Dict[str, Any]]] = {}
for category in ["statements", "chunks", "entities","summaries"]:
keyword_items = keyword_results.get(category, [])
embedding_items = embedding_results.get(category, [])
# Normalize scores within each search type
keyword_items = normalize_scores(keyword_items, "score")
embedding_items = normalize_scores(embedding_items, "score")
combined_items: Dict[str, Dict[str, Any]] = {}
# Combine two result sets by ID
for src_items, is_embedding in (
(keyword_items, False), (embedding_items, True)
):
for item in src_items:
item_id = item.get("id") or item.get("uuid")
if not item_id:
continue
existing = combined_items.get(item_id)
if not existing:
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = 0
combined_items[item_id]["embedding_score"] = 0
# Update normalized score from the right source
if is_embedding:
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
else:
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
# Calculate scores and apply forgetting weights
for item_id, item in combined_items.items():
bm25_score = float(item.get("bm25_score", 0) or 0)
embedding_score = float(item.get("embedding_score", 0) or 0)
combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
# Estimate time elapsed in days
dt = _parse_datetime(item.get("created_at"))
if dt is None:
time_elapsed_days = 0.0
else:
time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
# Memory strength (currently set to default value)
memory_strength = 1.0
forgetting_weight = engine.calculate_weight(
time_elapsed=time_elapsed_days, memory_strength=memory_strength
)
# print(f"Forgetting weight for {item_id}: {forgetting_weight}")
# print(f"Time elapsed days for {item_id}: {time_elapsed_days}")
final_score = combined_score * forgetting_weight
item["combined_score"] = final_score
sorted_items = sorted(
combined_items.values(), key=lambda x: x.get("combined_score", 0), reverse=True
)[:limit]
reranked[category] = sorted_items
return reranked
def log_search_query(query_text: str, search_type: str, group_id: str | None, limit: int, include: List[str], log_file: str = "search_log.txt"):
"""Log search query information to file"""
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# Ensure the query text is plain and clean before logging
cleaned_query = extract_plain_query(query_text)
log_entry = {
"timestamp": timestamp,
# "query": query_text,
"query": cleaned_query,
"search_type": search_type,
"group_id": group_id,
"limit": limit,
"include": include
}
# Append to log file
with open(log_file, "a", encoding="utf-8") as f:
f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
logger.info(f"Search logged: {query_text} ({search_type})")
def _remove_keys_recursive(obj: Any, keys_to_remove: List[str]) -> Any:
"""Remove specified keys recursively from dict/list structures (in place)."""
try:
if isinstance(obj, dict):
for k in keys_to_remove:
if k in obj:
obj.pop(k, None)
for v in list(obj.values()):
_remove_keys_recursive(v, keys_to_remove)
elif isinstance(obj, list):
for item in obj:
_remove_keys_recursive(item, keys_to_remove)
except Exception:
# Be defensive: never fail search because of sanitization
pass
return obj
def apply_reranker_placeholder(
results: Dict[str, List[Dict[str, Any]]],
query_text: str,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Placeholder for a cross-encoder reranker.
If config enables reranker, annotate items with a final_score equal to combined_score
and keep ordering. This is a no-op reranker to be replaced later.
"""
try:
rc = (RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {}))
except Exception as e:
logger.debug(f"Failed to load reranker config: {e}")
rc = {}
if not rc or not rc.get("enabled", False):
return results
top_k = int(rc.get("top_k", 100))
model_name = rc.get("model", "placeholder")
for cat, items in results.items():
head = items[:top_k]
for it in head:
base = float(it.get("combined_score", it.get("score", 0.0)) or 0.0)
it["final_score"] = base
it["reranker_model"] = model_name
# Keep overall order by final_score if present, otherwise combined/score
results[cat] = sorted(
items,
key=lambda x: float(x.get("final_score", x.get("combined_score", x.get("score", 0.0)) or 0.0)),
reverse=True,
)
return results
async def apply_llm_reranker(
results: Dict[str, List[Dict[str, Any]]],
query_text: str,
reranker_client: Optional[Any] = None,
llm_weight: Optional[float] = None,
top_k: Optional[int] = None,
batch_size: Optional[int] = None,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Apply LLM-based reranking to search results.
Args:
results: Search results organized by category
query_text: Original search query
reranker_client: Optional pre-initialized reranker client
llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM)
top_k: Maximum number of items to rerank per category
batch_size: Number of items to process concurrently
Returns:
Reranked results with final_score and reranker_model fields
"""
# Load reranker configuration from runtime.json
try:
rc = RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {})
except Exception as e:
logger.debug(f"Failed to load reranker config: {e}")
rc = {}
# Check if reranking is enabled
enabled = rc.get("enabled", False)
if not enabled:
logger.debug("LLM reranking is disabled in configuration")
return results
# Load configuration parameters with defaults
llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5)
top_k = top_k if top_k is not None else rc.get("top_k", 20)
batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5)
# Initialize reranker client if not provided
if reranker_client is None:
try:
reranker_client = get_reranker_client()
except Exception as e:
logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking")
return results
# Get model name for metadata
model_name = getattr(reranker_client, 'model_name', 'unknown')
# Process each category
reranked_results = {}
for category in ["statements", "chunks", "entities", "summaries"]:
items = results.get(category, [])
if not items:
reranked_results[category] = []
continue
# Select top K items by combined_score for reranking
sorted_items = sorted(
items,
key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0),
reverse=True
)
top_items = sorted_items[:top_k]
remaining_items = sorted_items[top_k:]
# Extract text content from each item
def extract_text(item: Dict[str, Any]) -> str:
"""Extract text content from a result item."""
# Try different text fields based on category
text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or ""
return str(text).strip()
# Batch items for concurrent processing
batches = []
for i in range(0, len(top_items), batch_size):
batch = top_items[i:i + batch_size]
batches.append(batch)
# Process batches concurrently
async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Process a batch of items with LLM relevance scoring."""
scored_batch = []
for item in batch:
item_text = extract_text(item)
# Skip items with no text
if not item_text:
item_copy = item.copy()
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
item_copy["final_score"] = combined_score
item_copy["llm_relevance_score"] = 0.0
item_copy["reranker_model"] = model_name
scored_batch.append(item_copy)
continue
# Create relevance scoring prompt
prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0.
Query: {query_text}
Result: {item_text}
Respond with only a number between 0.0 and 1.0, where:
- 0.0 means completely irrelevant
- 1.0 means perfectly relevant
Relevance score:"""
# Send request to LLM
try:
messages = [{"role": "user", "content": prompt}]
response = await reranker_client.chat(messages)
# Parse LLM response to extract relevance score
response_text = str(response.content if hasattr(response, 'content') else response).strip()
# Try to extract a float from the response
try:
# Remove any non-numeric characters except decimal point
import re
score_match = re.search(r'(\d+\.?\d*)', response_text)
if score_match:
llm_score = float(score_match.group(1))
# Clamp to [0.0, 1.0]
llm_score = max(0.0, min(1.0, llm_score))
else:
raise ValueError("No numeric score found in response")
except (ValueError, AttributeError) as e:
logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}")
llm_score = None
# Calculate final score
item_copy = item.copy()
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
if llm_score is not None:
final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score
item_copy["llm_relevance_score"] = llm_score
else:
# Use combined_score as fallback
final_score = combined_score
item_copy["llm_relevance_score"] = combined_score
item_copy["final_score"] = final_score
item_copy["reranker_model"] = model_name
scored_batch.append(item_copy)
except Exception as e:
logger.warning(f"Error processing item in LLM reranking: {e}, using combined_score")
item_copy = item.copy()
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
item_copy["final_score"] = combined_score
item_copy["llm_relevance_score"] = combined_score
item_copy["reranker_model"] = model_name
scored_batch.append(item_copy)
return scored_batch
# Process all batches concurrently
try:
batch_tasks = [process_batch(batch) for batch in batches]
batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
# Merge batch results
scored_items = []
for result in batch_results:
if isinstance(result, Exception):
logger.warning(f"Batch processing failed: {result}")
continue
scored_items.extend(result)
# Add remaining items (not in top K) with their combined_score as final_score
for item in remaining_items:
item_copy = item.copy()
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
item_copy["final_score"] = combined_score
item_copy["reranker_model"] = model_name
scored_items.append(item_copy)
# Sort all items by final_score in descending order
scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True)
reranked_results[category] = scored_items
except Exception as e:
logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results")
# Return original items with combined_score as final_score
for item in items:
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
item["final_score"] = combined_score
item["reranker_model"] = model_name
reranked_results[category] = items
return reranked_results
async def run_hybrid_search(
query_text: str,
search_type: str,
group_id: str | None,
limit: int,
include: List[str],
output_path: str | None,
rerank_alpha: float = 0.6,
use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False,
):
"""
Run search with specified type: 'keyword', 'embedding', or 'hybrid'
"""
# Start overall timing
search_start_time = time.time()
latency_metrics = {}
# Clean and normalize the incoming query before use/logging
query_text = extract_plain_query(query_text)
# Validate query is not empty after cleaning
if not query_text or not query_text.strip():
logger.warning(f"Empty query after cleaning, returning empty results")
return {
"keyword_search": {},
"embedding_search": {},
"reranked_results": {},
"combined_summary": {
"total_keyword_results": 0,
"total_embedding_results": 0,
"total_reranked_results": 0,
"search_query": "",
"search_timestamp": datetime.now().isoformat(),
"error": "Empty query"
}
}
# Log the search query
log_search_query(query_text, search_type, group_id, limit, include)
connector = Neo4jConnector()
results = {}
try:
keyword_task = None
embedding_task = None
if search_type in ["keyword", "hybrid"]:
# Keyword-based search
logger.info("Starting keyword search...")
keyword_start = time.time()
keyword_task = asyncio.create_task(
search_graph(
connector=connector,
q=query_text,
group_id=group_id,
limit=limit,
include=include
)
)
if search_type in ["embedding", "hybrid"]:
# Embedding-based search
logger.info("Starting embedding search...")
embedding_start = time.time()
# 从数据库读取嵌入器配置(按 ID并构建 RedBearModelConfig
config_load_start = time.time()
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
rb_config = RedBearModelConfig(
model_name=embedder_config_dict["model_name"],
provider=embedder_config_dict["provider"],
api_key=embedder_config_dict["api_key"],
base_url=embedder_config_dict["base_url"],
type="llm"
)
config_load_time = time.time() - config_load_start
logger.info(f"Config loading took {config_load_time:.4f}s")
# Init embedder
embedder_init_start = time.time()
embedder = OpenAIEmbedderClient(model_config=rb_config)
embedder_init_time = time.time() - embedder_init_start
logger.info(f"Embedder init took {embedder_init_time:.4f}s")
embedding_task = asyncio.create_task(
search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=query_text,
group_id=group_id,
limit=limit,
include=include,
)
)
if keyword_task:
keyword_results = await keyword_task
keyword_latency = time.time() - keyword_start
latency_metrics["keyword_search_latency"] = round(keyword_latency, 4)
logger.info(f"Keyword search completed in {keyword_latency:.4f}s")
if search_type == "keyword":
results = keyword_results
else:
results["keyword_search"] = keyword_results
if embedding_task:
embedding_results = await embedding_task
embedding_latency = time.time() - embedding_start
latency_metrics["embedding_search_latency"] = round(embedding_latency, 4)
logger.info(f"Embedding search completed in {embedding_latency:.4f}s")
if search_type == "embedding":
results = embedding_results
else:
results["embedding_search"] = embedding_results
# Merge and rank results for hybrid search
if search_type == "hybrid":
results["combined_summary"] = {
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"search_query": query_text,
"search_timestamp": datetime.now().isoformat()
}
# Apply reranking (optionally with forgetting curve)
rerank_start = time.time()
if use_forgetting_rerank:
# Load forgetting parameters from pipeline config
try:
pc = get_pipeline_config()
forgetting_cfg = pc.forgetting_engine
except Exception as e:
logger.debug(f"Failed to load forgetting config, using defaults: {e}")
forgetting_cfg = ForgettingEngineConfig()
reranked_results = rerank_with_forgetting_curve(
keyword_results=keyword_results,
embedding_results=embedding_results,
alpha=rerank_alpha,
limit=limit,
forgetting_config=forgetting_cfg,
)
else:
reranked_results = rerank_hybrid_results(
keyword_results=keyword_results,
embedding_results=embedding_results,
alpha=rerank_alpha, # Configurable weight for BM25 vs embedding
limit=limit
)
rerank_latency = time.time() - rerank_start
latency_metrics["reranking_latency"] = round(rerank_latency, 4)
logger.info(f"Reranking completed in {rerank_latency:.4f}s")
# Optional: apply reranker placeholder if enabled via config
reranked_results = apply_reranker_placeholder(reranked_results, query_text)
# Apply LLM reranking if enabled
llm_rerank_applied = False
if use_llm_rerank:
try:
reranked_results = await apply_llm_reranker(
results=reranked_results,
query_text=query_text,
)
llm_rerank_applied = True
logger.info("LLM reranking applied successfully")
except Exception as e:
logger.warning(f"LLM reranking failed: {e}, using previous scores")
results["reranked_results"] = reranked_results
results["combined_summary"] = {
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"total_reranked_results": sum(len(v) if isinstance(v, list) else 0 for v in reranked_results.values()),
"search_query": query_text,
"search_timestamp": datetime.now().isoformat(),
"reranking_alpha": rerank_alpha,
"forgetting_rerank": use_forgetting_rerank,
"llm_rerank": llm_rerank_applied,
}
# Calculate total latency
total_latency = time.time() - search_start_time
latency_metrics["total_latency"] = round(total_latency, 4)
# Add latency metrics to results
if "combined_summary" in results:
results["combined_summary"]["latency_metrics"] = latency_metrics
else:
results["latency_metrics"] = latency_metrics
logger.info(f"Total search completed in {total_latency:.4f}s")
logger.info(f"Latency breakdown: {latency_metrics}")
# Sanitize results: drop large/unused fields
_remove_keys_recursive(results, ["name_embedding"]) # drop entity name embeddings from outputs
# print(json.dumps(results, ensure_ascii=False, indent=2, default=str))
# Save to file
output_path = output_path or "search_results.json"
out_dir = os.path.dirname(output_path)
if out_dir:
os.makedirs(out_dir, exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(results, f, ensure_ascii=False, indent=2, default=str)
logger.info(f"Search results saved to: {output_path}")
# Log search completion with result count
if search_type == "hybrid":
result_counts = {
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in keyword_results.items()},
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in embedding_results.items()}
}
else:
result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()}
completion_log = {
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"query": query_text,
"search_type": search_type,
"status": "completed",
"result_counts": result_counts,
"output_file": output_path,
"latency_metrics": latency_metrics
}
with open("search_log.txt", "a", encoding="utf-8") as f:
f.write(json.dumps(completion_log, ensure_ascii=False) + "\n")
return results
finally:
await connector.close()
async def search_by_temporal(
group_id: Optional[str] = "test",
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 1,
):
"""
Temporal search across Statements.
- Matches statements created between start_date and end_date
- Optionally filters by group_id
- Returns up to 'limit' statements
"""
connector = Neo4jConnector()
if start_date:
start_date = normalize_date_safe(start_date)
if end_date:
end_date = normalize_date_safe(end_date)
params = TemporalSearchParams.model_validate({
"group_id": group_id,
"apply_id": apply_id,
"user_id": user_id,
"start_date": start_date,
"end_date": end_date,
"valid_date": valid_date,
"invalid_date": invalid_date,
"limit": limit,
})
statements = await search_graph_by_temporal(
connector=connector,
group_id=params.group_id,
apply_id=params.apply_id,
user_id=params.user_id,
start_date=params.start_date,
end_date=params.end_date,
valid_date=params.valid_date,
invalid_date=params.invalid_date,
limit=params.limit
)
return {"statements": statements}
async def search_by_keyword_temporal(
query_text: str,
group_id: Optional[str] = "test",
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 1,
):
"""
Temporal keyword search across Statements.
"""
connector = Neo4jConnector()
if start_date:
start_date = normalize_date_safe(start_date)
if end_date:
end_date = normalize_date_safe(end_date)
if valid_date:
valid_date = normalize_date_safe(valid_date)
if invalid_date:
invalid_date = normalize_date_safe(invalid_date)
params = TemporalSearchParams.model_validate({
"group_id": group_id,
"apply_id": apply_id,
"user_id": user_id,
"start_date": start_date,
"end_date": end_date,
"valid_date": valid_date,
"invalid_date": invalid_date,
"limit": limit,
})
statements = await search_graph_by_keyword_temporal(
connector=connector,
query_text=query_text,
group_id=params.group_id,
apply_id=params.apply_id,
user_id=params.user_id,
start_date=params.start_date,
end_date=params.end_date,
valid_date=params.valid_date,
invalid_date=params.invalid_date,
limit=params.limit
)
return {"statements": statements}
async def search_chunk_by_chunk_id(
chunk_id: str,
group_id: Optional[str] = "test",
limit: int = 1,
):
"""
Search for Chunks by chunk_id.
"""
connector = Neo4jConnector()
chunks = await search_graph_by_chunk_id(
connector=connector,
chunk_id=chunk_id,
group_id=group_id,
limit=limit
)
return {"chunks": chunks}
def main():
"""Main entry point for the hybrid graph search CLI.
Parses command line arguments and executes search with specified parameters.
Supports keyword, embedding, and hybrid search modes.
"""
parser = argparse.ArgumentParser(description="Hybrid graph search with keyword and embedding options")
parser.add_argument(
"--query", "-q", required=True, help="Free-text query to search"
)
parser.add_argument(
"--search-type",
"-t",
choices=["keyword", "embedding", "hybrid"],
default="hybrid",
help="Search type: keyword (text matching), embedding (semantic), or hybrid (both) (default: hybrid)"
)
parser.add_argument(
"--embedding-name",
"-m",
default="openai/nomic-embed-text:v1.5",
help="Embedding config name from config.json (default: openai/nomic-embed-text:v1.5)",
)
parser.add_argument(
"--group-id",
"-g",
default=None,
help="Optional group_id to filter results (default: None)",
)
parser.add_argument(
"--limit",
"-k",
type=int,
default=5,
help="Max number of results per type (default: 5)",
)
parser.add_argument(
"--include",
"-i",
nargs="+",
default=["statements", "chunks", "entities", "summaries"],
choices=["statements", "chunks", "entities", "summaries"],
help="Which targets to search for embedding search (default: statements chunks entities summaries)"
)
parser.add_argument(
"--output",
"-o",
default="search_results.json",
help="Path to save the search results JSON (default: search_results.json)",
)
parser.add_argument(
"--rerank-alpha",
"-a",
type=float,
default=0.6,
help="Weight for BM25 scores in reranking (0.0-1.0, higher values favor keyword search) (default: 0.6)",
)
parser.add_argument(
"--forgetting-rerank",
action="store_true",
help="Apply forgetting curve during reranking for hybrid search.",
)
parser.add_argument(
"--llm-rerank",
action="store_true",
help="Apply LLM-based reranking for hybrid search.",
)
args = parser.parse_args()
asyncio.run(
run_hybrid_search(
query_text=args.query,
search_type=args.search_type,
group_id=args.group_id,
limit=args.limit,
include=args.include,
output_path=args.output,
rerank_alpha=args.rerank_alpha,
use_forgetting_rerank=args.forgetting_rerank,
use_llm_rerank=args.llm_rerank,
)
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,8 @@
"""
存储服务模块
包含三大引擎:
1. 萃取引擎Extraction Engine- 知识提取、预处理、去重消歧
2. 遗忘引擎Forgetting Engine- 记忆遗忘机制
3. 自我反思引擎Reflection Engine- 自我反思和优化
"""

View File

@@ -0,0 +1,8 @@
"""
萃取引擎Extraction Engine
负责从对话数据中提取结构化知识,包括:
- 数据预处理
- 知识提取(分块、陈述句、三元组、时间信息、嵌入向量)
- 去重消歧
"""

View File

@@ -0,0 +1,13 @@
"""
数据预处理模块 - 负责对话数据的清洗、转换和预处理
包含:
- data_preprocessor: 数据预处理器 - 读取、清洗和转换对话数据
- data_pruning: 语义剪枝器 - 过滤与场景不相关的内容
- data_chunker: 数据分块器 - 将对话分割成可处理的片段
"""
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import DataPreprocessor
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner
__all__ = ['DataPreprocessor', 'SemanticPruner']

View File

@@ -0,0 +1,54 @@
"""
数据分块器 - 将对话分割成可处理的片段
功能:
- 支持多种分块策略递归分块、语义分块、LLM分块等
- 根据对话长度和内容特征进行智能分块
- 保持对话上下文的连贯性
注意:此模块当前为占位符,具体实现将在后续任务中完成。
分块功能目前在 app/core/memory/llm_tools/chunker_client.py 中实现。
"""
from typing import List, Optional
from app.core.memory.models.message_models import DialogData, Chunk
class DataChunker:
"""数据分块器 - 将长对话分割成多个可处理的片段"""
def __init__(self, chunker_strategy: str = "RecursiveChunker"):
"""
初始化数据分块器
Args:
chunker_strategy: 分块策略名称
"""
self.chunker_strategy = chunker_strategy
async def chunk_dialog(self, dialog: DialogData) -> List[Chunk]:
"""
将对话分割成多个块
Args:
dialog: 对话数据
Returns:
分块列表
Note:
当前此功能在 app/core/memory/llm_tools/chunker_client.py 中实现
"""
raise NotImplementedError("数据分块功能将在后续任务中实现")
async def chunk_dialogs(self, dialogs: List[DialogData]) -> List[DialogData]:
"""
批量处理多个对话的分块
Args:
dialogs: 对话数据列表
Returns:
包含分块信息的对话数据列表
"""
raise NotImplementedError("数据分块功能将在后续任务中实现")

View File

@@ -0,0 +1,785 @@
"""
数据预处理器 - 支持多种格式的对话数据读取、清洗和预处理
功能:
- 支持多种文件格式JSON、CSV、Excel、TXT
- 自动检测文件编码
- 清洗和标准化对话数据
- 转换为 DialogData 对象
"""
import json
import csv
import pandas as pd
import re
import os
from typing import List, Dict, Any, Optional, Union
from pathlib import Path
from datetime import datetime
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
class DataPreprocessor:
"""数据预处理器类,支持多种格式的对话数据读取、清洗和预处理。"""
def __init__(self, input_file_path: str = None, output_file_path: str = None):
"""
初始化数据预处理器。
Args:
input_file_path: 输入文件路径可选可后续通过set_input_path设置
output_file_path: 输出文件路径可选可后续通过set_output_path设置
注意:您可以通过以下方式指定输入输出路径:
1. 初始化时传入参数
2. 调用set_input_path()和set_output_path()方法
3. 在preprocess()方法中直接传入路径参数
"""
self.input_file_path = input_file_path or r"src\extracted_statements.txt"
self.output_file_path = output_file_path or r"src\data_preprocessing\out-file\extracted_statements-pre.txt"
self.supported_formats = ['.json', '.csv', '.txt', '.xlsx', '.tsv']
def set_input_path(self, input_path: str) -> None:
"""
设置输入文件路径。
Args:
input_path: 输入文件的完整路径
"""
self.input_file_path = input_path
def set_output_path(self, output_path: str) -> None:
"""
设置输出文件路径。
Args:
output_path: 输出文件的完整路径
"""
self.output_file_path = output_path
def get_file_format(self, file_path: str) -> str:
"""
获取文件格式。
Args:
file_path: 文件路径
Returns:
文件扩展名(小写)
"""
return Path(file_path).suffix.lower()
def _detect_encoding(self, file_path: str) -> str:
"""
检测文件编码,使用多种方法确保准确性。
Args:
file_path: 文件路径
Returns:
检测到的编码格式
"""
# 常见编码列表,按优先级排序
encodings_to_try = ['utf-8', 'gbk', 'gb2312', 'utf-16', 'latin-1']
# 首先尝试使用chardet检测
try:
import chardet
with open(file_path, 'rb') as f:
raw_data = f.read(10000) # 读取前10KB进行检测
result = chardet.detect(raw_data)
detected_encoding = result.get('encoding')
confidence = result.get('confidence', 0)
# 如果检测置信度较高,使用检测结果
if detected_encoding and confidence > 0.7:
return detected_encoding
except ImportError:
print("警告: chardet库未安装使用备用编码检测方法")
except Exception as e:
print(f"chardet检测失败: {e},使用备用方法")
# 备用方法:尝试不同编码读取文件开头
for encoding in encodings_to_try:
try:
with open(file_path, 'r', encoding=encoding) as f:
f.read(1000) # 尝试读取前1000个字符
return encoding
except (UnicodeDecodeError, UnicodeError):
continue
# 如果所有编码都失败返回utf-8作为最后选择
return 'utf-8'
def _read_json(self, data_path: str) -> List[Dict[str, Any]]:
"""
读取JSON格式的对话数据支持标准JSON和JSONL格式。
Args:
data_path: JSON文件路径
Returns:
解析后的数据列表
"""
encoding = self._detect_encoding(data_path)
content = None
# 尝试使用检测到的编码读取文件
encodings_to_try = [encoding, 'utf-8', 'gbk', 'gb2312', 'latin-1']
for enc in encodings_to_try:
try:
with open(data_path, 'r', encoding=enc) as f:
content = f.read().strip()
print(f"成功使用编码 {enc} 读取文件")
break
except (UnicodeDecodeError, UnicodeError) as e:
print(f"编码 {enc} 读取失败: {e}")
continue
if content is None:
raise ValueError(f"无法使用任何编码读取文件: {data_path}")
try:
# 尝试解析为标准JSON
try:
data = json.loads(content)
if isinstance(data, dict):
return [data]
elif isinstance(data, list):
return data
else:
raise ValueError(f"不支持的JSON数据结构: {type(data)}")
except json.JSONDecodeError as e:
# 如果标准JSON解析失败尝试JSONL格式每行一个JSON对象
print(f"标准JSON解析失败: {e}尝试JSONL格式...")
data_list = []
lines = content.split('\n')
for line_num, line in enumerate(lines, 1):
line = line.strip()
if line: # 跳过空行
try:
json_obj = json.loads(line)
data_list.append(json_obj)
except json.JSONDecodeError as line_error:
# 如果是单行巨大JSON数组可能需要特殊处理
if line_num == 1 and len(lines) == 1:
print(f"检测到单行大型JSON尝试分块解析...")
# 对于超大单行JSON尝试使用json.JSONDecoder进行流式解析
try:
decoder = json.JSONDecoder()
idx = 0
while idx < len(line):
line = line[idx:].lstrip()
if not line:
break
try:
obj, end_idx = decoder.raw_decode(line)
if isinstance(obj, list):
data_list.extend(obj)
elif isinstance(obj, dict):
data_list.append(obj)
idx += end_idx
except json.JSONDecodeError:
break
except Exception as decode_error:
print(f"分块解析也失败: {decode_error}")
else:
print(f"警告: 第{line_num}行JSON解析失败: {line_error}")
continue
return data_list
except Exception as e:
raise ValueError(f"读取JSON文件时发生错误: {e}")
def _read_csv(self, data_path: str) -> List[Dict[str, Any]]:
"""
读取CSV格式的对话数据。
Args:
data_path: CSV文件路径
Returns:
解析后的数据列表
"""
encoding = self._detect_encoding(data_path)
encodings_to_try = [encoding, 'utf-8', 'gbk', 'gb2312', 'latin-1']
for enc in encodings_to_try:
try:
# 尝试不同的分隔符
separators = [',', '\t', ';', '|']
df = None
for sep in separators:
try:
df = pd.read_csv(data_path, encoding=enc, sep=sep)
if len(df.columns) > 1: # 如果成功分割出多列,则认为找到了正确的分隔符
break
except Exception:
continue
if df is None:
df = pd.read_csv(data_path, encoding=enc)
print(f"成功使用编码 {enc} 读取CSV文件")
return df.to_dict('records')
except (UnicodeDecodeError, UnicodeError) as e:
print(f"编码 {enc} 读取CSV失败: {e}")
continue
except Exception as e:
if enc == encodings_to_try[-1]: # 最后一个编码也失败了
raise ValueError(f"读取CSV文件失败: {e}")
continue
raise ValueError(f"无法使用任何编码读取CSV文件: {data_path}")
def _read_excel(self, data_path: str) -> List[Dict[str, Any]]:
"""
读取Excel格式的对话数据。
Args:
data_path: Excel文件路径
Returns:
解析后的数据列表
"""
try:
df = pd.read_excel(data_path)
return df.to_dict('records')
except Exception as e:
raise ValueError(f"读取Excel文件失败: {e}")
def _read_text(self, data_path: str) -> List[Dict[str, Any]]:
"""
读取纯文本格式的对话数据。
Args:
data_path: 文本文件路径
Returns:
解析后的数据列表
"""
encoding = self._detect_encoding(data_path)
encodings_to_try = [encoding, 'utf-8', 'gbk', 'gb2312', 'latin-1']
content = None
# 尝试使用不同编码读取文件
for enc in encodings_to_try:
try:
with open(data_path, 'r', encoding=enc) as f:
content = f.read()
print(f"成功使用编码 {enc} 读取文本文件")
break
except (UnicodeDecodeError, UnicodeError) as e:
print(f"编码 {enc} 读取文本失败: {e}")
continue
if content is None:
raise ValueError(f"无法使用任何编码读取文本文件: {data_path}")
try:
# 尝试解析不同的文本格式
lines = content.strip().split('\n')
# 格式1: 每行一个对话轮次,格式为 "角色: 内容" 或 "角色:内容"
messages = []
for line in lines:
line = line.strip()
if not line:
continue
# 尝试匹配 "角色: 内容" 或 "角色:内容" 格式
match = re.match(r'^([^:]+)[:]\s*(.+)$', line)
if match:
role, msg = match.groups()
messages.append({'role': role.strip(), 'msg': msg.strip()})
else:
# 如果不匹配,则作为用户消息处理
messages.append({'role': 'User', 'msg': line})
if messages:
return [{'context': {'msgs': messages}}]
else:
# 如果没有解析出消息,则将整个文本作为一条消息
return [{'context': {'msgs': [{'role': 'User', 'msg': content}]}}]
except Exception as e:
raise ValueError(f"读取文本文件失败: {e}")
def read_data(self, data_path: str = None) -> List[Dict[str, Any]]:
"""
根据文件格式自动选择合适的读取方法。
Args:
data_path: 数据文件路径如果为None则使用初始化时设置的路径
Returns:
解析后的原始数据列表
"""
if data_path is None:
data_path = self.input_file_path
if not data_path:
raise ValueError("请指定输入文件路径")
if not os.path.exists(data_path):
raise FileNotFoundError(f"文件不存在: {data_path}")
file_format = self.get_file_format(data_path)
if file_format == '.json':
return self._read_json(data_path)
elif file_format == '.csv':
return self._read_csv(data_path)
elif file_format in ['.xlsx', '.xls']:
return self._read_excel(data_path)
elif file_format in ['.txt', '.tsv']:
return self._read_text(data_path)
else:
raise ValueError(f"不支持的文件格式: {file_format}。支持的格式: {self.supported_formats}")
def _clean_text(self, text: str) -> str:
"""
增强的文本清洗函数。
"""
if not text or not isinstance(text, str):
return ""
# 1. 移除消息中的角色标识(支持英文冒号":"与中文冒号""
text = re.sub(r'^(用户|AI|user|ai|assistant|bot|助手|机器人)[:]\s*', '', text, flags=re.IGNORECASE)
# 2. 移除URL链接
text = re.sub(r'https?://[^\s]+', '', text)
text = re.sub(r'www\.[^\s]+', '', text)
# 3. 移除HTML标签
text = re.sub(r'<[^>]+>', '', text)
# 4. 移除乱码和控制字符
text = re.sub(r'[<5B>]+', '', text)
text = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', text)
# 5. 标点符号规范化
# 将连续的感叹号(中英文)替换为一个句号
text = re.sub(r'[!]+', '', text)
# 将连续的句点/省略号(中英文)替换为一个句号
text = re.sub(r'(…{1,}|\.{2,}|。{2,})', '', text)
# 将英文句点统一为中文句号(避免残留英文句点影响断句)
text = re.sub(r'\.', '', text)
# 将连续的逗号(中英文)规范为一个中文逗号
text = re.sub(r'[,]{2,}', '', text)
# 将英文逗号统一为中文逗号
text = re.sub(r',', '', text)
# 6. 规范化空白字符
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
def _parse_message_content(self, content: str) -> List[Dict[str, str]]:
"""
增强的消息内容解析。
"""
messages = []
# 先清洗内容
cleaned_content = self._clean_text(content)
if not cleaned_content:
return messages
# 检查是否为有效消息(至少包含中文或英文单词)
if not re.search(r'[\u4e00-\u9fff\w]', cleaned_content):
return messages
# 根据内容特征判断角色(更智能的角色识别)
if re.search(r'(你好|嗨|早上好|晚上好|请问|谢谢|抱歉)', cleaned_content):
role = 'User'
elif re.search(r'(很高兴|建议|推荐|可以帮助|请提供)', cleaned_content):
role = 'Assistant'
else:
role = 'User' # 默认
messages.append({'role': role, 'msg': cleaned_content})
return messages
def _filter_empty_messages(self, messages: List[ConversationMessage]) -> List[ConversationMessage]:
"""
更严格的空消息过滤。
"""
filtered = []
for msg in messages:
# 检查消息是否有效
if (msg.msg and
isinstance(msg.msg, str) and
len(msg.msg.strip()) >= 2 and # 至少2个字符
re.search(r'[\u4e00-\u9fff\w]', msg.msg)): # 包含有效字符
filtered.append(msg)
return filtered
def _normalize_role(self, role: str) -> str:
"""
标准化角色名称。
Args:
role: 原始角色名称
Returns:
标准化后的角色名称
"""
if not role or not isinstance(role, str):
return "User"
role = role.strip().lower()
# 用户角色的各种表示
user_roles = ['user', 'human', '用户', '人类', 'customer', '客户', 'u']
# AI角色的各种表示
ai_roles = ['assistant', 'ai', 'bot', 'chatbot', '助手', '机器人', 'system', 'a']
if role in user_roles:
return "User"
elif role in ai_roles:
return "Assistant"
else:
return "User" # 默认为用户
def clean_data(self, raw_data: List[Dict[str, Any]], skip_cleaning: bool = True) -> List[DialogData]:
"""
清洗原始数据并转换为DialogData对象。
Args:
raw_data: 原始数据列表
skip_cleaning: 是否跳过数据清洗直接转换为DialogData对象默认False
Returns:
清洗后的DialogData对象列表
"""
if skip_cleaning:
print("跳过数据清洗步骤,直接转换数据...")
return self._convert_to_dialog_data(raw_data)
cleaned_dialogs = []
for i, item in enumerate(raw_data):
conv_date: Optional[str] = None
try:
# 提取对话消息
messages = []
# 处理不同的数据结构
if 'content' in item and isinstance(item['content'], list):
# 新格式dialog_release_zh.json格式content是字符串数组
content_list = item['content']
for j, content_text in enumerate(content_list):
# 交替分配角色偶数索引为用户奇数索引为AI
role = 'User' if j % 2 == 0 else 'Assistant'
normalized_role = self._normalize_role(role)
# 清洗消息内容
cleaned_content = self._clean_text(str(content_text))
# 过滤空消息
if cleaned_content:
messages.append(ConversationMessage(role=normalized_role, msg=cleaned_content))
elif 'context' in item and isinstance(item['context'], dict) and 'msgs' in item['context']:
# 标准格式context是字典且包含msgs
raw_messages = item['context']['msgs']
elif 'context' in item and isinstance(item['context'], str):
# testdata.json格式context是字符串需要解析对话内容
context_text = item['context']
# 从context文本中解析绝对日期并存入conv_date格式YYYY-MM-DD
m = re.search(r"(\d{4})年(\d{1,2})月(\d{1,2})日", context_text)
if m:
y, mo, d = int(m.group(1)), int(m.group(2)), int(m.group(3))
conv_date = f"{y:04d}-{mo:02d}-{d:02d}"
else:
m = re.search(r"(\d{4})[-/](\d{1,2})[-/](\d{1,2})", context_text)
if m:
y, mo, d = int(m.group(1)), int(m.group(2)), int(m.group(3))
conv_date = f"{y:04d}-{mo:02d}-{d:02d}"
messages = self._parse_context_string(context_text)
elif 'messages' in item:
# 另一种常见格式
raw_messages = item['messages']
elif 'conversation' in item:
# 对话格式
raw_messages = item['conversation']
else:
# 尝试直接解析
raw_messages = [item] if 'role' in item and 'msg' in item else []
# 如果messages还是空的说明需要处理raw_messages
if not messages and 'raw_messages' in locals():
# 清洗每条消息
for msg_data in raw_messages:
if isinstance(msg_data, dict):
role = self._normalize_role(msg_data.get('role', 'User'))
content = msg_data.get('msg', msg_data.get('content', msg_data.get('message', '')))
# 清洗消息内容
cleaned_content = self._clean_text(str(content))
# 过滤空消息
if cleaned_content:
messages.append(ConversationMessage(role=role, msg=cleaned_content))
# 过滤空对话
if not messages:
continue
# 去重相邻的重复消息
deduplicated_messages = []
for msg in messages:
if not deduplicated_messages or (
deduplicated_messages[-1].role != msg.role or
deduplicated_messages[-1].msg != msg.msg
):
deduplicated_messages.append(msg)
# 创建DialogData对象
context = ConversationContext(msgs=deduplicated_messages)
# 获取对话ID优先使用dialog_id然后是ref_id、id最后生成默认ID
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
# 获取group_id如果不存在则生成默认值
group_id = item.get('group_id', f'group_default_{i}')
user_id = item.get('user_id', f'user_default_{i}')
apply_id = item.get('apply_id', f'apply_default_{i}')
# 构建元数据,附加解析到的会话日期
metadata = {
**item.get('metadata', {}),
'document_id': str(item.get('document_id', 'unknown')) if item.get('document_id') is not None else 'unknown',
'original_format': 'dialog_release_zh' if 'content' in item and isinstance(item['content'], list) else 'testdata'
}
if conv_date:
metadata['conversation_date'] = conv_date
metadata['publication_date'] = conv_date
dialog_data = DialogData(
context=context,
ref_id=dialog_id,
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
metadata=metadata
)
cleaned_dialogs.append(dialog_data)
except Exception as e:
print(f"警告: 处理第{i+1}条数据时出错: {e}")
continue
return cleaned_dialogs
def _convert_to_dialog_data(self, raw_data: List[Dict[str, Any]]) -> List[DialogData]:
"""
直接将原始数据转换为DialogData对象不进行清洗。
Args:
raw_data: 原始数据列表
Returns:
DialogData对象列表
"""
dialog_list = []
for i, item in enumerate(raw_data):
try:
messages = []
# 处理不同的数据结构
if 'content' in item and isinstance(item['content'], list):
content_list = item['content']
for j, content_text in enumerate(content_list):
role = 'User' if j % 2 == 0 else 'Assistant'
if content_text:
messages.append(ConversationMessage(role=role, msg=str(content_text)))
elif 'context' in item and isinstance(item['context'], dict) and 'msgs' in item['context']:
raw_messages = item['context']['msgs']
for msg_data in raw_messages:
if isinstance(msg_data, dict):
role = msg_data.get('role', 'User')
content = msg_data.get('msg', msg_data.get('content', msg_data.get('message', '')))
if content:
messages.append(ConversationMessage(role=role, msg=str(content)))
elif 'context' in item and isinstance(item['context'], str):
# 尝试解析结构化对话,如果失败则作为单条用户消息处理
messages = self._parse_context_string(item['context'])
if not messages:
# 如果没有解析出结构化消息将整个context作为用户消息
context_text = item['context'].strip()
if context_text:
messages.append(ConversationMessage(role='User', msg=context_text))
elif 'messages' in item:
raw_messages = item['messages']
for msg_data in raw_messages:
if isinstance(msg_data, dict):
role = msg_data.get('role', 'User')
content = msg_data.get('msg', msg_data.get('content', msg_data.get('message', '')))
if content:
messages.append(ConversationMessage(role=role, msg=str(content)))
if not messages:
continue
context = ConversationContext(msgs=messages)
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
group_id = item.get('group_id', f'group_default_{i}')
user_id = item.get('user_id', f'user_default_{i}')
apply_id = item.get('apply_id', f'apply_default_{i}')
metadata = {
**item.get('metadata', {}),
'document_id': str(item.get('document_id', 'unknown')) if item.get('document_id') is not None else 'unknown',
'original_format': 'raw'
}
dialog_data = DialogData(
context=context,
ref_id=dialog_id,
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
metadata=metadata
)
dialog_list.append(dialog_data)
except Exception as e:
print(f"警告: 转换第{i+1}条数据时出错: {e}")
continue
return dialog_list
def _parse_context_string(self, context_text: str) -> List[ConversationMessage]:
"""
解析context字符串中的对话内容。
Args:
context_text: 包含对话的字符串
Returns:
解析后的ConversationMessage列表
"""
messages = []
# 使用正则表达式匹配对话模式
# 匹配 "User: 内容" / "用户: 内容" 或 "Assistant: 内容" / "AI: 内容" 格式
pattern = r'(User|用户|Assistant|AI|user|assistant)[:]\s*([^\n]+(?:\n(?!(?:User|用户|Assistant|AI|user|assistant)[:])[^\n]*)*?)'
matches = re.findall(pattern, context_text, re.MULTILINE | re.DOTALL | re.IGNORECASE)
for role, content in matches:
# 标准化角色名称
normalized_role = self._normalize_role(role)
# 清洗消息内容
cleaned_content = self._clean_text(content.strip())
# 过滤空消息
if cleaned_content:
messages.append(ConversationMessage(role=normalized_role, msg=cleaned_content))
return messages
def save_data(self, dialog_data_list: List[DialogData], output_path: str = None) -> None:
"""
保存处理后的数据。
Args:
dialog_data_list: DialogData对象列表
output_path: 输出文件路径如果为None则使用初始化时设置的路径
"""
if output_path is None:
output_path = self.output_file_path
if not output_path:
raise ValueError("请指定输出文件路径")
# 确保输出目录存在
os.makedirs(os.path.dirname(output_path), exist_ok=True)
# 转换为可序列化的格式
serializable_data = []
for dialog in dialog_data_list:
serializable_data.append({
'id': dialog.id,
'ref_id': dialog.ref_id,
'created_at': dialog.created_at.isoformat(),
'context': {
'msgs': [{'role': msg.role, 'msg': msg.msg} for msg in dialog.context.msgs]
},
'metadata': dialog.metadata,
'chunks': []
})
# 保存为JSON格式
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(serializable_data, f, ensure_ascii=False, indent=2)
print(f"数据已保存到: {output_path}")
def preprocess(self, input_path: str = None, output_path: str = None, skip_cleaning: bool = True, indices: Optional[List[int]] = None) -> List[DialogData]:
"""
完整的数据预处理流程。
Args:
input_path: 输入文件路径(可选)
output_path: 输出文件路径(可选)
skip_cleaning: 是否跳过数据清洗步骤默认False
indices: 要处理的数据索引列表(可选)
Returns:
处理后的DialogData对象列表
"""
print("开始数据预处理...")
# 读取原始数据
print("正在读取数据...")
raw_data = self.read_data(input_path)
print(f"成功读取 {len(raw_data)} 条原始数据")
# 根据索引筛选数据
if indices:
selected = [raw_data[i] for i in indices if 0 <= i < len(raw_data)]
if selected:
raw_data = selected
print(f"根据索引 {indices} 筛选后,保留 {len(raw_data)} 条数据")
else:
print(f"警告: 提供的索引 {indices} 筛选为空,处理全部 {len(raw_data)} 条数据")
# 清洗数据
if skip_cleaning:
print("跳过数据清洗步骤...")
cleaned_data = self.clean_data(raw_data, skip_cleaning=True)
else:
print("正在清洗数据...")
cleaned_data = self.clean_data(raw_data, skip_cleaning=False)
print(f"处理完成,得到 {len(cleaned_data)} 条有效对话")
# 保存数据(如果指定了输出路径)
if output_path or self.output_file_path:
print("正在保存数据...")
self.save_data(cleaned_data, output_path)
print("数据预处理完成!")
return cleaned_data

View File

@@ -0,0 +1,573 @@
"""
语义剪枝器 - 在预处理与分块之间过滤与场景不相关内容
功能:
- 对话级一次性抽取判定相关性
- 仅对"不相关对话"的消息按比例删除
- 重要信息(时间、编号、金额、联系方式、地址等)优先保留
"""
import os
import hashlib
import json
import re
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel, Field
from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext
from app.core.memory.models.config_models import PruningConfig
from app.core.memory.utils.config.config_utils import get_pruning_config
from app.core.memory.utils.prompt.prompt_utils import prompt_env, log_prompt_rendering, log_template_rendering
class DialogExtractionResponse(BaseModel):
"""对话级一次性抽取的结构化返回,用于加速剪枝。
- is_related对话与场景的相关性判定。
- times / ids / amounts / contacts / addresses / keywords重要信息片段用来在不相关对话中保留关键消息。
"""
is_related: bool = Field(...)
times: List[str] = Field(default_factory=list)
ids: List[str] = Field(default_factory=list)
amounts: List[str] = Field(default_factory=list)
contacts: List[str] = Field(default_factory=list)
addresses: List[str] = Field(default_factory=list)
keywords: List[str] = Field(default_factory=list)
class SemanticPruner:
"""语义剪枝:在预处理与分块之间过滤与场景不相关内容。
采用对话级一次性抽取判定相关性;仅对"不相关对话"的消息按比例删除,
重要信息(时间、编号、金额、联系方式、地址等)优先保留。
"""
def __init__(self, config: Optional[PruningConfig] = None, llm_client=None):
cfg_dict = get_pruning_config() if config is None else config.model_dump()
self.config = PruningConfig.model_validate(cfg_dict)
self.llm_client = llm_client
# Load Jinja2 template
self.template = prompt_env.get_template("extracat_Pruning.jinja2")
# 对话抽取缓存:避免同一对话重复调用 LLM / 重复渲染
self._dialog_extract_cache: dict[str, DialogExtractionResponse] = {}
# 运行日志:收集关键终端输出,便于写入 JSON
self.run_logs: List[str] = []
# 采用顺序处理,移除并发配置以简化与稳定执行
def _is_important_message(self, message: ConversationMessage) -> bool:
"""基于启发式规则识别重要信息消息,优先保留。
- 含日期/时间如YYYY-MM-DD、HH:MM、2024年11月10日、上午/下午)。
- 含编号/ID/订单号/申请号/账号/电话/金额等关键字段。
- 关键词:"时间""日期""编号""订单""流水""金额""""""电话""手机号""邮箱""地址"
"""
import re
text = message.msg.strip()
if not text:
return False
patterns = [
r"\b\d{4}-\d{1,2}-\d{1,2}\b",
r"\b\d{1,2}:\d{2}\b",
r"\d{4}\d{1,2}月\d{1,2}日",
r"上午|下午|AM|PM",
r"订单号|工单|申请号|编号|ID|账号|账户",
r"电话|手机号|微信|QQ|邮箱",
r"地址|地点",
r"金额|费用|价格|¥|¥|\d+元",
r"时间|日期|有效期|截止",
]
for p in patterns:
if re.search(p, text, flags=re.IGNORECASE):
return True
return False
def _importance_score(self, message: ConversationMessage) -> int:
"""为重要消息打分,用于在保留比例内优先保留更关键的内容。
简单启发:匹配到的类别越多、越关键分值越高。
"""
import re
text = message.msg.strip()
score = 0
weights = [
(r"\b\d{4}-\d{1,2}-\d{1,2}\b", 3),
(r"\b\d{1,2}:\d{2}\b", 2),
(r"\d{4}\d{1,2}月\d{1,2}日", 3),
(r"订单号|工单|申请号|编号|ID|账号|账户", 4),
(r"电话|手机号|微信|QQ|邮箱", 3),
(r"地址|地点", 2),
(r"金额|费用|价格|¥|¥|\d+元", 4),
(r"时间|日期|有效期|截止", 2),
]
for p, w in weights:
if re.search(p, text, flags=re.IGNORECASE):
score += w
return score
def _is_filler_message(self, message: ConversationMessage) -> bool:
"""检测典型寒暄/口头禅/确认类短消息用于跳过LLM分类以加速。
满足以下之一视为填充消息:
- 纯标点或长度很短(<= 4 个汉字或 <= 8 个字符)且不包含数字或关键实体;
- 常见词:你好/您好/在吗/嗯/嗯嗯/哦/好的/好/行/可以/不可以/谢谢/拜拜/再见/哈哈/呵呵/哈哈哈/。。。/??。
"""
import re
t = message.msg.strip()
if not t:
return True
# 常见填充语
fillers = [
"你好", "您好", "在吗", "", "嗯嗯", "", "好的", "", "", "可以", "不可以", "谢谢",
"拜拜", "再见", "哈哈", "呵呵", "哈哈哈", "。。。", "??", ""
]
if t in fillers:
return True
# 长度与字符类型判断
if len(t) <= 8:
# 非数字、无关键实体的短文本
if not re.search(r"[0-9]", t) and not self._is_important_message(message):
# 主要是标点或简单确认词
if re.fullmatch(r"[。!?,.!?…·\s]+", t) or t in fillers:
return True
return False
async def _extract_dialog_important(self, dialog_text: str) -> DialogExtractionResponse:
"""对话级一次性抽取:从整段对话中提取重要信息并判定相关性。
- 仅使用 LLM 结构化输出;
"""
# 缓存命中则直接返回(场景+内容作为键)
cache_key = f"{self.config.pruning_scene}:" + hashlib.sha1(dialog_text.encode("utf-8")).hexdigest()
if cache_key in self._dialog_extract_cache:
return self._dialog_extract_cache[cache_key]
rendered = self.template.render(pruning_scene=self.config.pruning_scene, dialog_text=dialog_text)
log_template_rendering("extracat_Pruning.jinja2", {"pruning_scene": self.config.pruning_scene})
log_prompt_rendering("pruning-extract", rendered)
# 强制使用 LLM移除正则回退
if not self.llm_client:
raise RuntimeError("llm_client 未配置;请配置 LLM 以进行结构化抽取。")
messages = [
{"role": "system", "content": "你是一个严谨的场景抽取助手,只输出严格 JSON。"},
{"role": "user", "content": rendered},
]
try:
ex = await self.llm_client.response_structured(messages, DialogExtractionResponse)
self._dialog_extract_cache[cache_key] = ex
return ex
except Exception as e:
raise RuntimeError("LLM 结构化抽取失败;请检查 LLM 配置或重试。") from e
def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool:
"""判断消息是否包含任意抽取到的重要片段。"""
if not tokens:
return False
t = message.msg
return any(tok and (tok in t) for tok in tokens)
async def prune_dialog(self, dialog: DialogData) -> DialogData:
"""单对话剪枝:使用一次性对话抽取,避免逐条消息 LLM 调用。
流程:
- 对整段对话进行抽取与相关性判定;若相关则不剪;
- 若不相关:用抽取到的重要片段 + 简单启发识别重要消息,按比例删除不相关消息,优先删除不重要,再删除重要(但重要最多按比例)。
- 删除策略:不重要消息按出现顺序删除(确定性、无随机)。
"""
if not self.config.pruning_switch:
return dialog
proportion = float(self.config.pruning_threshold)
extraction = await self._extract_dialog_important(dialog.content)
if extraction.is_related:
# 相关对话不剪枝
return dialog
# 在不相关对话中,识别重要/不重要消息
tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords
msgs = dialog.context.msgs
imp_unrel_msgs: List[ConversationMessage] = []
unimp_unrel_msgs: List[ConversationMessage] = []
for m in msgs:
if self._msg_matches_tokens(m, tokens) or self._is_important_message(m):
imp_unrel_msgs.append(m)
else:
unimp_unrel_msgs.append(m)
# 计算总删除目标数量
total_unrel = len(msgs)
delete_target = int(total_unrel * proportion)
if proportion > 0 and total_unrel > 0 and delete_target == 0:
delete_target = 1
imp_del_cap = min(int(len(imp_unrel_msgs) * proportion), len(imp_unrel_msgs))
unimp_del_cap = len(unimp_unrel_msgs)
max_capacity = max(0, len(msgs) - 1)
max_deletable = min(imp_del_cap + unimp_del_cap, max_capacity)
delete_target = min(delete_target, max_deletable)
# 删除配额分配
del_unimp = min(delete_target, unimp_del_cap)
rem = delete_target - del_unimp
del_imp = min(rem, imp_del_cap)
# 选取删除集合
unimp_delete_ids = []
imp_delete_ids = []
if del_unimp > 0:
# 按出现顺序选取前 del_unimp 条不重要消息进行删除(确定性、可复现)
unimp_delete_ids = [id(m) for m in unimp_unrel_msgs[:del_unimp]]
if del_imp > 0:
imp_sorted = sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))
imp_delete_ids = [id(m) for m in imp_sorted[:del_imp]]
# 统计实际删除数量(重要/不重要)
actual_unimp_deleted = 0
actual_imp_deleted = 0
kept_msgs = []
delete_targets = set(unimp_delete_ids) | set(imp_delete_ids)
for m in msgs:
mid = id(m)
if mid in delete_targets:
if mid in set(unimp_delete_ids) and actual_unimp_deleted < del_unimp:
actual_unimp_deleted += 1
continue
if mid in set(imp_delete_ids) and actual_imp_deleted < del_imp:
actual_imp_deleted += 1
continue
kept_msgs.append(m)
if not kept_msgs and msgs:
kept_msgs = [msgs[0]]
deleted_total = actual_unimp_deleted + actual_imp_deleted
self._log(
f"[剪枝-对话] 对话ID={dialog.id} 总消息={len(msgs)} 删除目标={delete_target} 实删={deleted_total} 保留={len(kept_msgs)}"
)
dialog.context = ConversationContext(msgs=kept_msgs)
return dialog
async def prune_dataset(self, dialogs: List[DialogData]) -> List[DialogData]:
"""数据集层面:全局消息级剪枝,保留所有对话。
- 仅在"不相关对话"的范围内执行消息剪枝;相关对话不动。
- 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留。
- 删除总量 = 阈值 * 全部不相关可删消息数,按可删容量比例分配;顺序删除。
- 保证每段对话至少保留1条消息不会删除整段对话。
"""
# 如果剪枝功能关闭,直接返回原始数据集。
if not self.config.pruning_switch:
return dialogs
# 阈值保护最高0.9
proportion = float(self.config.pruning_threshold)
if proportion > 0.9:
print(f"[剪枝-数据集] 阈值{proportion}超过上限0.9已自动调整为0.9")
proportion = 0.9
if proportion < 0.0:
proportion = 0.0
evaluated_dialogs = [] # list of dicts: {dialog, is_related}
self._log(
f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch}"
)
# 对话级相关性分类(一次性对整段对话文本进行判断,顺序执行并复用缓存)
evaluated_dialogs = []
for idx, dd in enumerate(dialogs):
try:
ex = await self._extract_dialog_important(dd.content)
evaluated_dialogs.append({
"dialog": dd,
"is_related": bool(ex.is_related),
"index": idx,
"extraction": ex
})
except Exception:
evaluated_dialogs.append({
"dialog": dd,
"is_related": True,
"index": idx,
"extraction": None
})
# 统计相关 / 不相关对话
not_related_dialogs = [d for d in evaluated_dialogs if not d["is_related"]]
related_dialogs = [d for d in evaluated_dialogs if d["is_related"]]
self._log(
f"[剪枝-数据集] 相关对话数={len(related_dialogs)} 不相关对话数={len(not_related_dialogs)}"
)
# 简洁打印第几段对话相关/不相关索引基于1
def _fmt_indices(items, cap: int = 10):
inds = [i["index"] + 1 for i in items]
if len(inds) <= cap:
return inds
# 超过上限时只打印前cap个并标注总数
return inds[:cap] + ["...", f"{len(inds)}"]
rel_inds = _fmt_indices(related_dialogs)
nrel_inds = _fmt_indices(not_related_dialogs)
self._log(f"[剪枝-数据集] 相关对话:第{rel_inds}段;不相关对话:第{nrel_inds}")
result: List[DialogData] = []
if not_related_dialogs:
# 为每个不相关对话进行一次性抽取,识别重要/不重要(避免逐条 LLM
per_dialog_info = {}
total_unrelated = 0
total_capacity = 0
for d in not_related_dialogs:
dd = d["dialog"]
extraction = d.get("extraction")
if extraction is None:
extraction = await self._extract_dialog_important(dd.content)
# 合并所有重要标记
tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords
msgs = dd.context.msgs
# 分类消息
imp_unrel_msgs = [m for m in msgs if self._msg_matches_tokens(m, tokens) or self._is_important_message(m)]
unimp_unrel_msgs = [m for m in msgs if m not in imp_unrel_msgs]
# 重要消息按重要性排序
imp_sorted_ids = [id(m) for m in sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))]
info = {
"dialog": dd,
"total_msgs": len(msgs),
"unrelated_count": len(msgs),
"imp_ids_sorted": imp_sorted_ids,
"unimp_ids": [id(m) for m in unimp_unrel_msgs],
}
per_dialog_info[d["index"]] = info
total_unrelated += info["unrelated_count"]
# 全局删除配额:比例作用于全部不相关消息(重要+不重要)
global_delete = int(total_unrelated * proportion)
if proportion > 0 and total_unrelated > 0 and global_delete == 0:
global_delete = 1
# 每段的最大可删容量:不重要全部 + 重要最多删除 floor(len(重要)*比例)且至少保留1条消息
capacities = []
for d in not_related_dialogs:
idx = d["index"]
info = per_dialog_info[idx]
# 统计重要数量
imp_count = len(info["imp_ids_sorted"])
unimp_count = len(info["unimp_ids"])
imp_cap = int(imp_count * proportion)
cap = min(unimp_count + imp_cap, max(0, info["total_msgs"] - 1))
capacities.append(cap)
total_capacity = sum(capacities)
if global_delete > total_capacity:
print(f"[剪枝-数据集] 不相关消息总数={total_unrelated},目标删除={global_delete},最大可删={total_capacity}(重要消息按比例保留)。将按最大可删执行。")
global_delete = total_capacity
# 配额分配:按不相关消息占比分配到各对话,但不超过各自容量
alloc = []
for i, d in enumerate(not_related_dialogs):
idx = d["index"]
info = per_dialog_info[idx]
share = int(global_delete * (info["unrelated_count"] / total_unrelated)) if total_unrelated > 0 else 0
alloc.append(min(share, capacities[i]))
allocated = sum(alloc)
rem = global_delete - allocated
turn = 0
while rem > 0 and turn < 100000:
progressed = False
for i in range(len(not_related_dialogs)):
if rem <= 0:
break
if alloc[i] < capacities[i]:
alloc[i] += 1
rem -= 1
progressed = True
if not progressed:
break
turn += 1
# 应用删除:相关对话不动;不相关按分配先删不重要,再删重要(低分优先)
total_deleted_confirm = 0
for d in evaluated_dialogs:
dd = d["dialog"]
msgs = dd.context.msgs
original = len(msgs)
if d["is_related"]:
result.append(dd)
continue
idx_in_unrel = next((k for k, x in enumerate(not_related_dialogs) if x["index"] == d["index"]), None)
if idx_in_unrel is None:
result.append(dd)
continue
quota = alloc[idx_in_unrel]
info = per_dialog_info[d["index"]]
# 计算本对话重要最多可删数量
imp_count = len(info["imp_ids_sorted"])
imp_del_cap = int(imp_count * proportion)
# 先构造顺序删除的"不重要ID集合"(按出现顺序前 quota 条)
unimp_delete_ids = set(info["unimp_ids"][:min(quota, len(info["unimp_ids"]))])
del_unimp = min(quota, len(unimp_delete_ids))
rem_quota = quota - del_unimp
# 再从重要里选低分优先的删除ID不超过 imp_del_cap
imp_delete_ids = set(info["imp_ids_sorted"][:min(rem_quota, imp_del_cap)])
deleted_here = 0
actual_unimp_deleted = 0
actual_imp_deleted = 0
kept = []
for m in msgs:
mid = id(m)
if mid in unimp_delete_ids and actual_unimp_deleted < del_unimp:
actual_unimp_deleted += 1
deleted_here += 1
continue
if mid in imp_delete_ids and actual_imp_deleted < len(imp_delete_ids):
actual_imp_deleted += 1
deleted_here += 1
continue
kept.append(m)
if not kept and msgs:
kept = [msgs[0]]
dd.context.msgs = kept
total_deleted_confirm += deleted_here
self._log(
f"[剪枝-对话] 对话 {d['index']+1} 总消息={original} 分配删除={quota} 实删={deleted_here} 保留={len(kept)}"
)
result.append(dd)
self._log(f"[剪枝-数据集] 全局消息级顺序剪枝完成,总删除 {total_deleted_confirm} 条(不相关消息,重要按比例保留)。")
else:
# 全部相关:不执行剪枝
result = [d["dialog"] for d in evaluated_dialogs]
self._log(f"[剪枝-数据集] 剩余对话数={len(result)}")
# 将本次剪枝阶段的终端输出保存为 JSON 文件(仅在剪枝器内部完成)
try:
from app.core.config import settings
settings.ensure_memory_output_dir()
log_output_path = settings.get_memory_output_path("pruned_terminal.json")
# 去除日志前缀标签(如 [剪枝-数据集]、[剪枝-对话])后再解析为结构化字段保存
sanitized_logs = [self._sanitize_log_line(l) for l in self.run_logs]
payload = self._parse_logs_to_structured(sanitized_logs)
with open(log_output_path, "w", encoding="utf-8") as f:
json.dump(payload, f, ensure_ascii=False, indent=2)
except Exception as e:
self._log(f"[剪枝-数据集] 保存终端输出日志失败:{e}")
# Safety: avoid empty dataset
if not result:
print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
return dialogs
return result
def _log(self, msg: str) -> None:
"""记录日志并打印到终端。"""
try:
self.run_logs.append(msg)
except Exception:
# 任何异常都不影响打印
pass
print(msg)
def _sanitize_log_line(self, line: str) -> str:
"""移除行首的方括号标签前缀,例如 [剪枝-数据集] 或 [剪枝-对话]。"""
try:
return re.sub(r"^\[[^\]]+\]\s*", "", line)
except Exception:
return line
def _parse_logs_to_structured(self, logs: List[str]) -> dict:
"""将已去前缀的日志列表解析为结构化 JSON便于数据对接。"""
summary = {
"scene": self.config.pruning_scene,
"dialog_total": None,
"deletion_ratio": None,
"enabled": None,
"related_count": None,
"unrelated_count": None,
"related_indices": [],
"unrelated_indices": [],
"total_deleted_messages": None,
"remaining_dialogs": None,
}
dialogs = []
# 解析函数
def parse_int(value: str) -> Optional[int]:
try:
return int(value)
except Exception:
return None
def parse_float(value: str) -> Optional[float]:
try:
return float(value)
except Exception:
return None
def parse_indices(s: str) -> List[int]:
s = s.strip()
if not s:
return []
parts = [p.strip() for p in s.split(",") if p.strip()]
out: List[int] = []
for p in parts:
try:
out.append(int(p))
except Exception:
pass
return out
# 正则
re_header = re.compile(r"对话总数=(\d+)\s+场景=([^\s]+)\s+删除比例=([0-9.]+)\s+开关=(True|False)")
re_counts = re.compile(r"相关对话数=(\d+)\s+不相关对话数=(\d+)")
re_indices = re.compile(r"相关对话:第\[(.*?)\]段;不相关对话:第\[(.*?)\]段")
re_dialog = re.compile(r"对话\s+(\d+)\s+总消息=(\d+)\s+分配删除=(\d+)\s+实删=(\d+)\s+保留=(\d+)")
re_total_del = re.compile(r"总删除\s+(\d+)\s+条")
re_remaining = re.compile(r"剩余对话数=(\d+)")
for line in logs:
# 第一行:总览
m = re_header.search(line)
if m:
summary["dialog_total"] = parse_int(m.group(1))
# 顶层 scene 依配置,这里不覆盖,但也可校验 m.group(2)
summary["deletion_ratio"] = parse_float(m.group(3))
summary["enabled"] = True if m.group(4) == "True" else False
continue
# 第二行:相关/不相关数量
m = re_counts.search(line)
if m:
summary["related_count"] = parse_int(m.group(1))
summary["unrelated_count"] = parse_int(m.group(2))
continue
# 第三行:相关/不相关索引
m = re_indices.search(line)
if m:
summary["related_indices"] = parse_indices(m.group(1))
summary["unrelated_indices"] = parse_indices(m.group(2))
continue
# 对话级统计
m = re_dialog.search(line)
if m:
dialogs.append({
"index": parse_int(m.group(1)),
"total_messages": parse_int(m.group(2)),
"quota_delete": parse_int(m.group(3)),
"actual_deleted": parse_int(m.group(4)),
"kept": parse_int(m.group(5)),
})
continue
# 全局删除总数
m = re_total_del.search(line)
if m:
summary["total_deleted_messages"] = parse_int(m.group(1))
continue
# 剩余对话数
m = re_remaining.search(line)
if m:
summary["remaining_dialogs"] = parse_int(m.group(1))
continue
return {
"scene": summary["scene"],
"timestamp": datetime.now().isoformat(),
"summary": {k: v for k, v in summary.items() if k != "scene"},
"dialogs": dialogs,
}

View File

@@ -0,0 +1,41 @@
"""
去重消歧模块
提供实体去重和消歧功能,包括:
- 基础去重和消歧(精确匹配、模糊匹配)
- LLM 实体去重
- 第二层去重(与 Neo4j 数据库联合去重)
- 两阶段去重(完整的去重流程)
"""
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
deduplicate_entities_and_edges,
accurate_match,
fuzzy_match,
LLM_decision,
LLM_disamb_decision,
)
from app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm import (
llm_dedup_entities,
llm_dedup_entities_iterative_blocks,
llm_disambiguate_pairs_iterative,
)
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import (
second_layer_dedup_and_merge_with_neo4j,
)
from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
dedup_layers_and_merge_and_return,
)
__all__ = [
"deduplicate_entities_and_edges",
"accurate_match",
"fuzzy_match",
"LLM_decision",
"LLM_disamb_decision",
"llm_dedup_entities",
"llm_dedup_entities_iterative_blocks",
"llm_disambiguate_pairs_iterative",
"second_layer_dedup_and_merge_with_neo4j",
"dedup_layers_and_merge_and_return",
]

View File

@@ -0,0 +1,784 @@
"""
去重功能函数
"""
from app.core.memory.models.variate_config import DedupConfig
from typing import List, Dict, Tuple
from app.core.memory.models.graph_models import(
StatementEntityEdge,
EntityEntityEdge,
ExtractedEntityNode
)
import os
from datetime import datetime
import difflib # 提供字符串相似度计算工具
import asyncio
import importlib
import re
# 模块级属性融合工具函数(统一行为)
def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
# 强弱连接合并
can_strength = (getattr(canonical, "connect_strength", "") or "").lower()
inc_strength = (getattr(ent, "connect_strength", "") or "").lower()
pair = {can_strength, inc_strength} - {""}
if pair:
if "both" in pair or pair == {"strong", "weak"}:
canonical.connect_strength = "both"
elif pair == {"strong"}:
canonical.connect_strength = "strong"
elif pair == {"weak"}:
canonical.connect_strength = "weak"
else:
canonical.connect_strength = next(iter(pair))
# 别名合并(去重保序)
try:
existing = getattr(canonical, "aliases", []) or []
incoming = getattr(ent, "aliases", []) or []
seen = set()
merged_list: List[str] = []
for x in existing + incoming:
xn = (x or "").strip()
if xn and xn not in seen:
seen.add(xn)
merged_list.append(x)
canonical.aliases = merged_list
except Exception:
pass
# 描述与事实摘要(保留更长者)
try:
desc_a = getattr(canonical, "description", "") or ""
desc_b = getattr(ent, "description", "") or ""
if len(desc_b) > len(desc_a):
canonical.description = desc_b
# 合并事实摘要:统一保留一个“实体: name”行来源行去重保序
fact_a = getattr(canonical, "fact_summary", "") or ""
fact_b = getattr(ent, "fact_summary", "") or ""
def _extract_sources(txt: str) -> List[str]:
sources: List[str] = []
if not txt:
return sources
for line in str(txt).splitlines():
ln = line.strip()
# 支持“来源:”或“来源:”前缀
m = re.match(r"^来源[:]\s*(.+)$", ln)
if m:
content = m.group(1).strip()
if content:
sources.append(content)
# 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失
if not sources and txt.strip():
sources.append(txt.strip())
return sources
try:
src_a = _extract_sources(fact_a)
src_b = _extract_sources(fact_b)
seen = set()
merged_sources: List[str] = []
for s in src_a + src_b:
if s and s not in seen:
seen.add(s)
merged_sources.append(s)
if merged_sources:
name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
elif fact_b and not fact_a:
canonical.fact_summary = fact_b
except Exception:
# 兜底:若解析失败,保留较长文本
if len(fact_b) > len(fact_a):
canonical.fact_summary = fact_b
except Exception:
pass
# 名称向量补全
try:
emb_a = getattr(canonical, "name_embedding", []) or []
emb_b = getattr(ent, "name_embedding", []) or []
if not emb_a and emb_b:
canonical.name_embedding = emb_b
except Exception:
pass
# 时间范围合并
try:
# 统一使用 created_at / expired_at
if getattr(ent, "created_at", None) and getattr(canonical, "created_at", None) and ent.created_at < canonical.created_at:
canonical.created_at = ent.created_at
if getattr(ent, "expired_at", None) and getattr(canonical, "expired_at", None):
if canonical.expired_at is None:
canonical.expired_at = ent.expired_at
elif ent.expired_at and ent.expired_at > canonical.expired_at:
canonical.expired_at = ent.expired_at
except Exception:
pass
def accurate_match(
entity_nodes: List[ExtractedEntityNode]
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
"""
精确匹配:按 (group_id, name, entity_type) 合并实体并建立重定向与合并记录。
返回: (deduped_entities, id_redirect, exact_merge_map)
"""
exact_merge_map: Dict[str, Dict] = {}
canonical_map: Dict[str, ExtractedEntityNode] = {}
id_redirect: Dict[str, str] = {}
# 1) 构建规范实体映射(按名称+类型+group 精确匹配)
for ent in entity_nodes:
name_norm = (getattr(ent, "name", "") or "").strip()
type_norm = (getattr(ent, "entity_type", "") or "").strip()
key = f"{getattr(ent, 'group_id', None)}|{name_norm}|{type_norm}"
# 为避免跨业务组误并,明确以 group_id 为范围边界
if key not in canonical_map:
canonical_map[key] = ent
id_redirect[getattr(ent, "id")] = getattr(ent, "id")
continue
canonical = canonical_map[key]
# 执行精确属性与强弱合并,并建立重定向
_merge_attribute(canonical, ent)
id_redirect[getattr(ent, "id")] = getattr(canonical, "id")
# 记录精确匹配的合并项(使用规范化键,避免外层变量误用)
try:
k = f"{getattr(canonical, 'group_id')}|{(getattr(canonical, 'name') or '').strip()}|{(getattr(canonical, 'entity_type') or '').strip()}"
if k not in exact_merge_map:
exact_merge_map[k] = {
"canonical_id": getattr(canonical, "id"),
"group_id": getattr(canonical, "group_id"),
"name": getattr(canonical, "name"),
"entity_type": getattr(canonical, "entity_type"),
"merged_ids": set(),
}
exact_merge_map[k]["merged_ids"].add(getattr(ent, "id"))
except Exception:
pass
deduped_entities = list(canonical_map.values())
return deduped_entities, id_redirect, exact_merge_map
def fuzzy_match(
deduped_entities: List[ExtractedEntityNode],
statement_entity_edges: List[StatementEntityEdge],
id_redirect: Dict[str, str],
config: DedupConfig | None = None,
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], List[str]]:
"""
模糊匹配:在精确匹配之后,基于名称/类型相似度与上下文共现,进一步融合高相似实体。
返回: (updated_entities, updated_redirect, fuzzy_merge_records)
"""
fuzzy_merge_records: List[str] = []
def _normalize_text(s: str) -> str:
try:
return re.sub(r"\s+", " ", re.sub(r"[^\w\u4e00-\u9fff]+", " ", (s or "").lower())).strip()
except Exception:
return str(s).lower().strip()
def _tokenize(s: str) -> List[str]:
norm = _normalize_text(s)
tokens = re.findall(r"[\u4e00-\u9fff]+|[a-z0-9]+", norm)
return tokens
def _jaccard(a_tokens: List[str], b_tokens: List[str]) -> float:
try:
set_a, set_b = set(a_tokens), set(b_tokens)
if not set_a and not set_b:
return 0.0
inter = len(set_a & set_b)
union = len(set_a | set_b)
return inter / union if union > 0 else 0.0
except Exception:
return 0.0
def _cosine(a: List[float], b: List[float]) -> float:
try:
if not a or not b or len(a) != len(b):
return 0.0
dot = sum(x * y for x, y in zip(a, b))
na = sum(x * x for x in a) ** 0.5
nb = sum(y * y for y in b) ** 0.5
if na == 0 or nb == 0:
return 0.0
return dot / (na * nb)
except Exception:
return 0.0
def _name_similarity(e1: ExtractedEntityNode, e2: ExtractedEntityNode):
emb_sim = _cosine(getattr(e1, "name_embedding", []) or [], getattr(e2, "name_embedding", []) or [])
tokens1 = set(_tokenize(getattr(e1, "name", "") or ""))
tokens2 = set(_tokenize(getattr(e2, "name", "") or ""))
aliases1 = getattr(e1, "aliases", []) or []
aliases2 = getattr(e2, "aliases", []) or []
alias_tokens1 = set(tokens1)
alias_tokens2 = set(tokens2)
for a in aliases1:
alias_tokens1 |= set(_tokenize(a))
for a in aliases2:
alias_tokens2 |= set(_tokenize(a))
j_primary = _jaccard(list(tokens1), list(tokens2))
j_alias = _jaccard(list(alias_tokens1), list(alias_tokens2))
s_name = 0.6 * emb_sim + 0.2 * j_primary + 0.2 * j_alias
return s_name, emb_sim, j_primary, j_alias
def _desc_similarity(e1: ExtractedEntityNode, e2: ExtractedEntityNode):
"""
计算实体描述的相似度Jaccard + SequenceMatcher
返回: (相似度得分, Jaccard 相似度(词重合), SequenceMatcher 相似度(序列相似))
"""
d1 = getattr(e1, "description", "") or ""
d2 = getattr(e2, "description", "") or ""
if not d1 and not d2:
return 0.0, 0.0, 0.0
t1 = _tokenize(d1)
t2 = _tokenize(d2)
j = _jaccard(t1, t2)
try:
seq = difflib.SequenceMatcher(None, _normalize_text(d1), _normalize_text(d2)).ratio()
except Exception:
seq = 0.0
# 平衡词重合与序列相似(更鲁棒)
s_desc = 0.5 * j + 0.5 * seq
return s_desc, j, seq
def _canonicalize_type(t: str) -> str: # 扩展类型同义归一
t = (t or "").strip()
if not t:
return ""
t_up = t.upper()
TYPE_ALIASES = {
"PERSON": {"人物", "", "个人", "人名", "PERSON", "PEOPLE", "INDIVIDUAL"},
"ORG": {"组织", "ORG"},
"COMPANY": {"公司", "企业", "COMPANY"},
"INSTITUTION": {"机构", "INSTITUTION"},
"LOCATION": {"地点", "位置", "LOCATION"},
"CITY": {"城市", "CITY"},
"COUNTRY": {"国家", "COUNTRY"},
"EVENT": {"事件", "EVENT"},
# 扩展活动与技能近义,统一到 ACTIVITY便于本地模糊匹配
"ACTIVITY": {"活动", "技术活动", "技能", "ACTIVITY", "SKILL"},
"PRODUCT": {"产品", "商品", "物品", "OBJECT", "PRODUCT"},
"TOOL": {"工具", "TOOL"},
"SOFTWARE": {"软件", "SOFTWARE"},
"FOOD": {"食品", "食物", "FOOD"},
"INGREDIENT": {"食材", "配料", "原料", "INGREDIENT"},
"SWEETMEATS": {"甜点", "甜品", "甜食", "SWEETMEATS"},
# 统一本地与 LLM 阶段:将 EQUIPMENT/装备 映射为 APPLIANCE
"APPLIANCE": {"设备", "器材", "摄影器材", "摄影设备", "电器", "烤箱", "装备","镜头", "EQUIPMENT", "APPLIANCE"},
"ART": {"艺术", "艺术形式", "ART"},
"FLOWER": {"花卉", "鲜花", "FLOWER"},
"PLANT": {"植物", "PLANT"},
"AGENT": {"AI助手", "助手", "人工智能助手", "智能助手", "智能体", "Agent", "AGENTA"},
"ROLE": {"角色", "ROLE"},
"SCENE_ELEMENT": {"场景元素", "SCENE_ELEMENT"},
"UNKNOWN": {"UNKNOWN", "未知", "不明"},
}
for canon, aliases in TYPE_ALIASES.items():
if t_up in {a.upper() for a in aliases}:
return canon
return t_up
def _type_similarity(t1: str, t2: str) -> float:
import difflib
c1 = _canonicalize_type(t1)
c2 = _canonicalize_type(t2)
if not c1 or not c2:
return 0.0
if c1 == c2:
return 0.5 if c1 == "UNKNOWN" else 1.0
if c1 == "UNKNOWN" or c2 == "UNKNOWN":
return 0.5
sim_table = {
("ORG", "COMPANY"): 0.9, ("COMPANY", "ORG"): 0.9,
("ORG", "INSTITUTION"): 0.85, ("INSTITUTION", "ORG"): 0.85,
("LOCATION", "CITY"): 0.9, ("CITY", "LOCATION"): 0.9,
("LOCATION", "COUNTRY"): 0.9, ("COUNTRY", "LOCATION"): 0.9,
("EVENT", "ACTIVITY"): 0.8, ("ACTIVITY", "EVENT"): 0.8,
("PRODUCT", "TOOL"): 0.8, ("TOOL", "PRODUCT"): 0.8,
("PRODUCT", "SOFTWARE"): 0.8, ("SOFTWARE", "PRODUCT"): 0.8,
("FOOD", "SWEETMEATS"): 0.8, ("SWEETMEATS", "FOOD"): 0.8,
("INGREDIENT", "FOOD"): 0.85, ("FOOD", "INGREDIENT"): 0.85,
("APPLIANCE", "TOOL"): 0.8, ("TOOL", "APPLIANCE"): 0.8,
("APPLIANCE", "PRODUCT"): 0.7, ("PRODUCT", "APPLIANCE"): 0.7,
("FLOWER", "PLANT"): 0.9, ("PLANT", "FLOWER"): 0.9,
("AGENT", "SOFTWARE"): 0.85, ("SOFTWARE", "AGENT"): 0.85,
("AGENT", "PRODUCT"): 0.7, ("PRODUCT", "AGENT"): 0.7,
("AGENT", "ROLE"): 0.9, ("ROLE", "AGENT"): 0.9,
("SCENE_ELEMENT", "PRODUCT"): 0.6, ("PRODUCT", "SCENE_ELEMENT"): 0.6,
}
base = sim_table.get((c1, c2), 0.0)
if base:
return base
t1n = (t1 or "").strip().lower()
t2n = (t2 or "").strip().lower()
seq_ratio = difflib.SequenceMatcher(None, t1n, t2n).ratio()
return seq_ratio * 0.6
# 阈值与权重设定(从配置读取;若无配置则使用 DedupConfig 的默认值)
_defaults = DedupConfig()
T_NAME_STRICT = (config.fuzzy_name_threshold_strict if config is not None else _defaults.fuzzy_name_threshold_strict)
T_TYPE_STRICT = (config.fuzzy_type_threshold_strict if config is not None else _defaults.fuzzy_type_threshold_strict)
T_OVERALL = (config.fuzzy_overall_threshold if config is not None else _defaults.fuzzy_overall_threshold)
UNKNOWN_NAME_T = (config.fuzzy_unknown_type_name_threshold if config is not None else _defaults.fuzzy_unknown_type_name_threshold)
UNKNOWN_TYPE_T = (config.fuzzy_unknown_type_type_threshold if config is not None else _defaults.fuzzy_unknown_type_type_threshold)
W_NAME = (config.name_weight if config is not None else _defaults.name_weight)
W_DESC = (config.desc_weight if config is not None else _defaults.desc_weight)
W_TYPE = (config.type_weight if config is not None else _defaults.type_weight)
CTX_BONUS = (config.context_bonus if config is not None else _defaults.context_bonus) # 上下文共现加分
FALL_FLOOR = (config.llm_fallback_floor if config is not None else _defaults.llm_fallback_floor)
FALL_CEIL = (config.llm_fallback_ceiling if config is not None else _defaults.llm_fallback_ceiling)
i = 0
while i < len(deduped_entities):
a = deduped_entities[i]
j = i + 1
while j < len(deduped_entities):
b = deduped_entities[j]
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
j += 1
continue
# 上下文共现
try:
sources_a = {e.source for e in statement_entity_edges if getattr(e, "target", None) == getattr(a, "id", None)}
sources_b = {e.source for e in statement_entity_edges if getattr(e, "target", None) == getattr(b, "id", None)}
co_ctx = bool(sources_a & sources_b)
except Exception:
co_ctx = False
s_name, emb_sim, j_primary, j_alias = _name_similarity(a, b)
s_desc, j_desc, seq_desc = _desc_similarity(a, b)
s_type = _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None))
unknown_present = (
str(getattr(a, "entity_type", "")).upper() == "UNKNOWN"
or str(getattr(b, "entity_type", "")).upper() == "UNKNOWN"
)
tn = UNKNOWN_NAME_T if unknown_present else T_NAME_STRICT
tn = min(tn, 0.88) if co_ctx else tn
type_threshold = UNKNOWN_TYPE_T if unknown_present else T_TYPE_STRICT
tover = T_OVERALL
a_cs = (getattr(a, "connect_strength", "") or "").lower()
b_cs = (getattr(b, "connect_strength", "") or "").lower()
if a_cs in ("strong", "both") or b_cs in ("strong", "both"):
tover = 0.80
# 综合评分:名称、描述、类型加权 + 上下文加分
overall = W_NAME * s_name + W_DESC * s_desc + W_TYPE * s_type + (CTX_BONUS if co_ctx else 0.0)
if s_name >= tn and s_type >= type_threshold and overall >= tover:
_merge_attribute(a, b)
try:
fuzzy_merge_records.append(
f"[模糊] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type}) | s_name={s_name:.3f}, s_desc={s_desc:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, ctx={co_ctx}"
)
except Exception:
pass
# 用于处理合并实体后Statement节点下方无挂载边的情况 后续考虑将其代码逻辑统一由关系去重消歧管理
# 建立 ID 重定向:将合并实体 b 的 ID 指向规范实体 a 的 ID
try:
canonical_id = id_redirect.get(getattr(a, "id", None), getattr(a, "id", None))
losing_id = getattr(b, "id", None)
if losing_id and canonical_id:
id_redirect[losing_id] = canonical_id
# 扁平化可能的重定向链:凡是映射到 b.id 的,统一指向 a.id
for k, v in list(id_redirect.items()):
if v == losing_id:
id_redirect[k] = canonical_id
except Exception:
pass
deduped_entities.pop(j)
continue
else:
try:
if s_name >= tn and s_type >= type_threshold and (FALL_FLOOR <= overall < tover) and (overall <= FALL_CEIL):
fuzzy_merge_records.append(
f"[边界] {a.id}<->{b.id} ({a.group_id}|{a.name}|{a.entity_type} ~ {b.group_id}|{b.name}|{b.entity_type}) | s_name={s_name:.3f}, s_desc={s_desc:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, ctx={co_ctx}"
)
except Exception:
pass
j += 1
i += 1
return deduped_entities, id_redirect, fuzzy_merge_records
async def LLM_decision( # 决策中包含去重和消歧的功能
deduped_entities: List[ExtractedEntityNode],
statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge],
id_redirect: Dict[str, str],
config: DedupConfig | None = None,
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], List[str]]:
"""
基于迭代分块并发的 LLM 判定,生成实体重定向并在本地应用融合。
返回 (updated_entities, updated_redirect, llm_records)。
- 仅在配置 enable_llm_dedup_blockwise 为 True 时启用;
若未提供配置,则使用 DedupConfig 的默认值作为回退。
- 内部调用 llm_dedup_entities_iterative_blocks 获取 pairwise 的重定向映射。
- 将映射应用到 deduped_entities 与 id_redirect并记录融合日志。
"""
llm_records: List[str] = []
try:
# 优先使用运行时配置;若未提供配置,使用模型默认值,不再回退到环境变量
enable_switch = (
bool(config.enable_llm_dedup_blockwise) if config is not None else DedupConfig().enable_llm_dedup_blockwise
)
if not enable_switch:
return deduped_entities, id_redirect, llm_records
# 从配置读取 LLM 迭代参数;若无配置则使用 DedupConfig 的默认值
_defaults = DedupConfig()
block_size = (config.llm_block_size if config is not None else _defaults.llm_block_size)
block_concurrency = (config.llm_block_concurrency if config is not None else _defaults.llm_block_concurrency)
pair_concurrency = (config.llm_pair_concurrency if config is not None else _defaults.llm_pair_concurrency)
max_rounds = (config.llm_max_rounds if config is not None else _defaults.llm_max_rounds)
# 动态导入 llm 客户端(统一从 app.core.memory.utils.llm_utils 获取)
try:
llm_utils_mod = importlib.import_module("app.core.memory.utils.llm_utils")
get_llm_client_fn = getattr(llm_utils_mod, "get_llm_client")
except Exception:
get_llm_client_fn = lambda: None
try:
llm_mod = importlib.import_module("app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm")
llm_fn = getattr(llm_mod, "llm_dedup_entities_iterative_blocks")
except Exception:
raise RuntimeError("LLM 模块加载失败deduplication.entity_dedup_llm 缺少 llm_dedup_entities_iterative_blocks")
# 获取 LLM 客户端,若环境未配置或抛错则回退为 None
try:
llm_client = get_llm_client_fn()
except Exception:
llm_client = None
llm_redirect, llm_records = await llm_fn(
entity_nodes=deduped_entities,
statement_entity_edges=statement_entity_edges,
entity_entity_edges=entity_entity_edges,
llm_client=llm_client,
block_size=block_size,
block_concurrency=block_concurrency,
pair_concurrency=pair_concurrency,
max_rounds=max_rounds,
)
except Exception as e:
# 记录错误,不中断主流程
llm_records.append(f"[LLM错误] 迭代分块执行失败: {e}")
return deduped_entities, id_redirect, llm_records
# 若存在 LLM 的重定向,应用到实体与映射
# 确保实体集合与 id_redirect 完整反映 LLM 的合并结果;否则后续边重定向不会指向规范 ID实体仍然重复
if llm_redirect:
entity_by_id: Dict[str, ExtractedEntityNode] = {e.id: e for e in deduped_entities}
for losing_id, canonical_id in list(llm_redirect.items()):
if losing_id == canonical_id:
continue
a = entity_by_id.get(canonical_id)
b = entity_by_id.get(losing_id)
if not a or not b: # 若不存在 a 或 b可能已在精确或模糊阶段合并在之前阶段合并之后不会再处理但是处于审计的目的会记录
continue
_merge_attribute(a, b)
# ID 重定向
try:
id_redirect[b.id] = a.id
for k, v in list(id_redirect.items()):
if v == b.id:
id_redirect[k] = a.id
except Exception:
pass
# 记录 LLM 融合日志
try:
llm_records.append(
f"[LLM融合] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})"
)
# 详细的“同类名称相似”记录改由 LLM 去重模块统一生成以携带 conf/reason
except Exception:
pass
# 移除 losing 实体
try:
if b in deduped_entities:
deduped_entities.remove(b)
entity_by_id.pop(b.id, None)
except Exception:
pass
return deduped_entities, id_redirect, llm_records
async def LLM_disamb_decision(
deduped_entities: List[ExtractedEntityNode],
statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge],
id_redirect: Dict[str, str],
config: DedupConfig | None = None,
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], set[tuple[str, str]], List[str]]:
"""
预消歧阶段对“同名但类型不同”的实体对调用LLM进行消歧
产出:需阻断的实体对(blocked_pairs)与必要的合并(merge_redirect)。
返回 (updated_entities, updated_redirect, blocked_pairs, disamb_records)。
- 仅在配置开关 enable_llm_disambiguation 为 True 时启用;否则返回空阻断列表。
"""
disamb_records: List[str] = []
blocked_pairs: set[tuple[str, str]] = set()
try:
enable_switch = (
config.enable_llm_disambiguation
if config is not None
else DedupConfig().enable_llm_disambiguation
)
if not bool(enable_switch):
return deduped_entities, id_redirect, blocked_pairs, disamb_records
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm import llm_disambiguate_pairs_iterative
from app.core.memory.utils.config import definitions as config_defs
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
merge_redirect, block_list, disamb_records = await llm_disambiguate_pairs_iterative(
entity_nodes=deduped_entities,
statement_entity_edges=statement_entity_edges,
entity_entity_edges=entity_entity_edges,
llm_client=llm_client,
)
# 应用LLM消歧的合并建议
if merge_redirect:
entity_by_id: Dict[str, ExtractedEntityNode] = {e.id: e for e in deduped_entities}
for losing_id, canonical_id in list(merge_redirect.items()):
if losing_id == canonical_id:
continue
a = entity_by_id.get(canonical_id)
b = entity_by_id.get(losing_id)
if not a or not b:
continue
_merge_attribute(a, b)
id_redirect[b.id] = a.id
for k, v in list(id_redirect.items()):
if v == b.id:
id_redirect[k] = a.id
try:
disamb_records.append(
f"[DISAMB合并应用] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})"
)
except Exception:
pass
try:
if b in deduped_entities:
deduped_entities.remove(b)
entity_by_id.pop(b.id, None)
except Exception:
pass
# 保存阻断对
try:
blocked_pairs = {tuple(sorted(p)) for p in (block_list or [])}
except Exception:
blocked_pairs = set()
except Exception as e:
disamb_records.append(f"[DISAMB错误] 消歧执行失败: {e}")
return deduped_entities, id_redirect, blocked_pairs, disamb_records
return deduped_entities, id_redirect, blocked_pairs, disamb_records
async def deduplicate_entities_and_edges(
entity_nodes: List[ExtractedEntityNode],
statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge],
report_stage: str = "第一层去重消歧",
report_append: bool = False,
report_stage_notes: List[str] | None = None,
dedup_config: DedupConfig | None = None,
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
"""
主流程依次执行精确匹配、模糊匹配与可选LLM 决策融合,随后对边做重定向与去重。之后再处理边,是关系去重和消歧
返回:去重后的实体、语句→实体边、实体↔实体边。
"""
local_llm_records: List[str] = [] # 作为“审计日志”的本地收集器 初始化保留为了之后对于LLM决策追溯
# 1) 精确匹配
deduped_entities, id_redirect, exact_merge_map = accurate_match(entity_nodes)
# 1.5) LLM 决策消歧:阻断同名不同类型的高相似对,并应用必要的合并
deduped_entities, id_redirect, blocked_pairs, disamb_records = await LLM_disamb_decision(
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config
)
# 2) 模糊匹配(本地规则)
deduped_entities, id_redirect, fuzzy_merge_records = fuzzy_match(
deduped_entities, statement_entity_edges, id_redirect, config=dedup_config
)
# 3) LLM 决策(仅按配置开关)
try:
enable_switch = (
dedup_config.enable_llm_dedup_blockwise
if dedup_config is not None
else DedupConfig().enable_llm_dedup_blockwise
)
should_trigger_llm = bool(enable_switch)
# 将触发信息写入阶段备注,便于输出报告审计
if report_stage_notes is None:
report_stage_notes = []
report_stage_notes.append(f"LLM触发: {'' if should_trigger_llm else ''}")
except Exception:
should_trigger_llm = False
if should_trigger_llm:
deduped_entities, id_redirect, llm_decision_records = await LLM_decision(
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config
)
else:
llm_decision_records = []
# 累加 LLM 记录 把 LLM_decision 返回的日志 llm_decision_records 追加到 local_llm_records
try:
local_llm_records.extend(llm_decision_records or [])
except Exception:
pass
# 在主流程这里 这里是之后关系去重和消歧的地方,方法可以写在其他地方
# 此处统一对边进行处理,使用累积的 id_redirect 把边的 source/target 改成规范ID
# 4) 边重定向与去重
# 4.1 语句→实体边:重复时优先保留 strong
stmt_ent_map: Dict[str, StatementEntityEdge] = {}
for edge in statement_entity_edges:
new_target = id_redirect.get(edge.target, edge.target)
edge.target = new_target
key = f"{edge.source}_{edge.target}"
if key not in stmt_ent_map:
stmt_ent_map[key] = edge
else:
existing = stmt_ent_map[key]
old_strength = getattr(existing, "connect_strength", "")
new_strength = getattr(edge, "connect_strength", "")
if old_strength != "strong" and new_strength == "strong":
stmt_ent_map[key] = edge
# 4.2 实体↔实体边:按 source_target 去重(无强弱属性)
ent_ent_map: Dict[str, EntityEntityEdge] = {}
for edge in entity_entity_edges:
new_source = id_redirect.get(edge.source, edge.source)
new_target = id_redirect.get(edge.target, edge.target)
edge.source = new_source
edge.target = new_target
key = f"{edge.source}_{edge.target}"
if key not in ent_ent_map:
ent_ent_map[key] = edge
_write_dedup_fusion_report(
exact_merge_map=exact_merge_map,
fuzzy_merge_records=fuzzy_merge_records,
local_llm_records=local_llm_records,
disamb_records=disamb_records,
stage_label=report_stage,
append=report_append,
stage_notes=report_stage_notes,
)
return deduped_entities, list(stmt_ent_map.values()), list(ent_ent_map.values())
# 独立模块:去重融合报告写入(与实体/边的计算解耦)
def _write_dedup_fusion_report(
exact_merge_map: Dict[str, Dict],
fuzzy_merge_records: List[str],
local_llm_records: List[str],
disamb_records: List[str] | None = None,
stage_label: str | None = None,
append: bool = False,
stage_notes: List[str] | None = None,
):
try:
# 使用全局配置的输出路径
from app.core.config import settings
settings.ensure_memory_output_dir()
out_path = settings.get_memory_output_path("dedup_entity_output.txt")
report_lines: List[str] = []
if not append:
report_lines.append(f"去重融合报告 - {datetime.now().isoformat()}")
report_lines.append("")
if stage_label:
# 追加写入时,在阶段标题前增加一个空行以增强分隔
if append:
report_lines.append("")
report_lines.append(f"=== {stage_label} ===")
report_lines.append("")
# 阶段注释:在标题下追加,如候选数、是否跳过等
if stage_notes:
for note in stage_notes:
try:
report_lines.append(str(note))
except Exception:
pass
report_lines.append("")
# 精确
report_lines.append("精确匹配去重:")
aggregated_exact_lines: List[str] = []
try:
for k, info in (exact_merge_map or {}).items():
merged_ids = sorted(list(info.get("merged_ids", set())))
if merged_ids:
aggregated_exact_lines.append(
f"[精确] 键 {k} 规范实体 {info.get('canonical_id')} 名称 '{info.get('name')}' 类型 {info.get('entity_type')} <- 合并实体IDs {', '.join(merged_ids)}"
)
except Exception:
pass
report_lines.extend(aggregated_exact_lines if aggregated_exact_lines else ["无合并项"])
report_lines.append("")
# 消歧
report_lines.append("LLM 决策消歧:")
try:
# 仅展示阻断项,过滤掉合并与合并应用
disamb_block_only = [
line for line in (disamb_records or [])
if str(line).startswith("[DISAMB阻断]") or str(line).startswith("[DISAMB异常阻断]")
]
except Exception:
disamb_block_only = disamb_records or []
report_lines.extend(disamb_block_only if disamb_block_only else ["未执行或无阻断/合并项"])
report_lines.append("")
# 模糊
report_lines.append("模糊匹配去重:")
report_lines.extend(fuzzy_merge_records if fuzzy_merge_records else ["未执行或无合并项"])
report_lines.append("")
# LLM
report_lines.append("LLM 决策去重:")
try:
# 仅保留 LLM 的“去重判定”记录,排除“合并指令/融合落地”
def _is_llm_dedup_record(s: str) -> bool:
try:
text = str(s)
return "[LLM去重]" in text
except Exception:
return False
llm_dedup_only = [
line for line in (local_llm_records or [])
if _is_llm_dedup_record(str(line))
]
# 同名类型相似的 LLM 去重记录可能来源于消歧阶段,将其也纳入展示
try:
llm_dedup_only.extend([
line for line in (disamb_records or [])
if _is_llm_dedup_record(str(line))
])
except Exception:
pass
except Exception:
llm_dedup_only = []
# 输出前移除块前缀(如 "[LLM块0] "),并对重复记录去重(保序)
try:
import re as _re
def _strip_block_prefix(s: str) -> str:
try:
return _re.sub(r"^\[LLM块\d+\]\s*", "", str(s))
except Exception:
return str(s)
stripped = [ _strip_block_prefix(line) for line in (llm_dedup_only or []) ]
seen = set()
deduped_ordered = []
for line in stripped:
if line not in seen:
seen.add(line)
deduped_ordered.append(line)
llm_dedup_only = deduped_ordered
except Exception:
pass
report_lines.extend(llm_dedup_only if llm_dedup_only else ["未执行或无合并项"])
with open(out_path, ("a" if append else "w"), encoding="utf-8") as f:
f.write("\n".join(report_lines) + "\n")
except Exception:
# 静默失败,避免影响主流程
pass

View File

@@ -0,0 +1,689 @@
"""
用于实体去重基于LLM的决策
提供“LLM判定逻辑”的核心实现与并发控制。
"""
import asyncio
import difflib
from typing import List, Tuple, Dict
import anyio
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.models.graph_models import ExtractedEntityNode, StatementEntityEdge, EntityEntityEdge
from app.core.memory.models.dedup_models import EntityDedupDecision, EntityDisambDecision
from app.core.memory.utils.prompt.prompt_utils import render_entity_dedup_prompt
# --- 类型同义归并与相似度 ---
_TYPE_ALIASES_UPPER: Dict[str, set[str]] = {
# 设备/器材类近义:统一到 EQUIPMENT
"EQUIPMENT": {s.upper() for s in {"设备", "器材", "摄影器材", "装备", "工具", "APPLIANCE", "TOOL"}},
# 活动/技能近义:统一到 ACTIVITY放宽“技术活动/技能”的同类判断
"ACTIVITY": {s.upper() for s in {"活动", "技术活动", "技能", "ACTIVITY", "SKILL"}},
# 常见类别,按需扩展
"PERSON": {s.upper() for s in {"人物", "", "个人", "人名", "PERSON"}},
"LOCATION": {s.upper() for s in {"地点", "位置", "LOCATION", "城市", "CITY", "国家", "COUNTRY"}},
"SOFTWARE": {s.upper() for s in {"软件", "SOFTWARE"}},
"EVENT": {s.upper() for s in {"事件", "EVENT"}},
}
def _canonicalize_type(t: str | None) -> str:
u = (str(t or "").strip().upper())
if not u or u == "UNKNOWN":
return "UNKNOWN"
for canon, aliases in _TYPE_ALIASES_UPPER.items():
if u in aliases:
return canon
return u # 未知类型直接返回自身(保守兼容)
def _type_similarity(t1: str | None, t2: str | None) -> float:
c1, c2 = _canonicalize_type(t1), _canonicalize_type(t2)
if c1 == c2:
return 1.0
if c1 == "UNKNOWN" or c2 == "UNKNOWN":
return 0.6 # 任一未知,给中等相似度,允许模型结合描述判断
return 0.0
def _simple_type_ok(t1: str | None, t2: str | None) -> bool:
"""类型门控:
- 允许同类(含近义归并后同类)或任一 UNKNOWN/空;
- 其余不同类不放行(例如 PERSON vs EQUIPMENT
"""
c1, c2 = _canonicalize_type(t1), _canonicalize_type(t2)
if c1 == "UNKNOWN" or c2 == "UNKNOWN":
return True
return c1 == c2
def _name_embed_sim(a: List[float] | None, b: List[float] | None) -> float: # 计算实体名称嵌入向量的余弦相似度
a = a or []
b = b or []
if not a or not b or len(a) != len(b):
return 0.0
try:
dot = sum(x * y for x, y in zip(a, b))
na = (sum(x * x for x in a)) ** 0.5
nb = (sum(y * y for y in b)) ** 0.5
if na > 0 and nb > 0:
return dot / (na * nb)
except Exception:
pass
return 0.0
def _name_text_sim(name1: str, name2: str) -> float: # 计算实体名称文本的字符串相似度
name1 = (name1 or "").strip().lower()
name2 = (name2 or "").strip().lower()
if not name1 or not name2:
return 0.0
return difflib.SequenceMatcher(None, name1, name2).ratio()
def _co_occurrence(statement_edges: List[StatementEntityEdge], a_id: str, b_id: str) -> bool: # 判断两个实体是否在同一陈述中 “同现”
try:
sources_a = {e.source for e in statement_edges if getattr(e, "target", None) == a_id}
sources_b = {e.source for e in statement_edges if getattr(e, "target", None) == b_id}
return bool(sources_a & sources_b)
except Exception:
return False
def _relation_statements(entity_edges: List[EntityEntityEdge], a_id: str, b_id: str) -> List[str]: # 提取两个实体间的所有关联语句
stmts: List[str] = []
for e in entity_edges:
if (getattr(e, "source", None) == a_id and getattr(e, "target", None) == b_id) or (
getattr(e, "source", None) == b_id and getattr(e, "target", None) == a_id
):
s_text = getattr(e, "statement", None) or ""
r_type = getattr(e, "relation_type", None) or ""
if s_text or r_type:
stmts.append(f"{r_type}: {s_text}".strip(': '))
return stmts
def _choose_canonical(a: ExtractedEntityNode, b: ExtractedEntityNode) -> int: # 选择 “规范实体”(合并时保留的实体)
# 0 for a, 1 for b
# 1. 第一优先级:按“连接强度”排序(连接强度越高,实体越可靠)
cs_a = (getattr(a, "connect_strength", "") or "").lower()
cs_b = (getattr(b, "connect_strength", "") or "").lower()
prio = {"strong": 3, "both": 3, "weak": 1, "": 0}
if prio.get(cs_a, 0) != prio.get(cs_b, 0):
return 0 if prio.get(cs_a, 0) > prio.get(cs_b, 0) else 1
# pick longer description/fact_summary
# 2. 第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整)
desc_a = (getattr(a, "description", "") or "")
desc_b = (getattr(b, "description", "") or "")
fact_a = (getattr(a, "fact_summary", "") or "")
fact_b = (getattr(b, "fact_summary", "") or "")
score_a = len(desc_a) + len(fact_a)
score_b = len(desc_b) + len(fact_b)
if score_a != score_b:
return 0 if score_a >= score_b else 1
return 0
# _judge_pair单对实体的 LLM 判断) 已经有分块迭代的函数内容是否还需要单对LLM判断--这是已经创建的工具服务于分块迭代的函数
async def _judge_pair(
llm_client: OpenAIClient,
a: ExtractedEntityNode,
b: ExtractedEntityNode,
statement_edges: List[StatementEntityEdge],
entity_edges: List[EntityEntityEdge],
) -> Tuple[EntityDedupDecision, Dict]:
# 1. 计算实体名称的核心相似度指标
name_text_sim = _name_text_sim(getattr(a, "name", ""), getattr(b, "name", ""))
name_embed_sim = _name_embed_sim(getattr(a, "name_embedding", []), getattr(b, "name_embedding", []))
# 2. 判断名称是否存在“包含关系”(如“苹果公司”包含“苹果”)
name_contains = False
try:
n1 = (getattr(a, "name", "") or "").strip().lower()
n2 = (getattr(b, "name", "") or "").strip().lower()
name_contains = bool(n1 and n2 and (n1 in n2 or n2 in n1))
except Exception:
pass
# 3. 构建LLM判断的“上下文信息”规则层计算的所有特征 判断上下文特征有助于实体消歧首先判断的类型关系
ctx = {
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None),
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
"type_similarity": _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
"name_text_sim": name_text_sim,
"name_embed_sim": name_embed_sim,
"name_contains": name_contains,
"co_occurrence": _co_occurrence(statement_edges, getattr(a, "id", None), getattr(b, "id", None)),
"relation_statements": _relation_statements(entity_edges, getattr(a, "id", None), getattr(b, "id", None)),
}
entity_a = {
"name": getattr(a, "name", None),
"entity_type": getattr(a, "entity_type", None),
"description": getattr(a, "description", None),
"aliases": getattr(a, "aliases", None) or [],
"fact_summary": getattr(a, "fact_summary", None),
"connect_strength": getattr(a, "connect_strength", None),
}
entity_b = {
"name": getattr(b, "name", None),
"entity_type": getattr(b, "entity_type", None),
"description": getattr(b, "description", None),
"aliases": getattr(b, "aliases", None) or [],
"fact_summary": getattr(b, "fact_summary", None),
"connect_strength": getattr(b, "connect_strength", None),
}
# 5. 渲染LLM提示词用工具函数填充模板包含实体信息、上下文、输出格式
prompt = render_entity_dedup_prompt(
entity_a=entity_a,
entity_b=entity_b,
context=ctx,
json_schema=EntityDedupDecision.model_json_schema(),
)
messages = [
{"role": "system", "content": "You judge whether two entities are the same. Return valid JSON only."},
{"role": "user", "content": prompt},
]
decision = await llm_client.response_structured(messages, EntityDedupDecision)
return decision, ctx
# 消歧场景同名不同类型下的LLM判断
async def _judge_pair_disamb(
llm_client: OpenAIClient,
a: ExtractedEntityNode,
b: ExtractedEntityNode,
statement_edges: List[StatementEntityEdge],
entity_edges: List[EntityEntityEdge],
) -> Tuple[EntityDisambDecision, Dict]:
name_text_sim = _name_text_sim(getattr(a, "name", ""), getattr(b, "name", ""))
name_embed_sim = _name_embed_sim(getattr(a, "name_embedding", []), getattr(b, "name_embedding", []))
name_contains = False
try:
n1 = (getattr(a, "name", "") or "").strip().lower()
n2 = (getattr(b, "name", "") or "").strip().lower()
name_contains = bool(n1 and n2 and (n1 in n2 or n2 in n1))
except Exception:
pass
ctx = {
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None),
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
"name_text_sim": name_text_sim,
"name_embed_sim": name_embed_sim,
"name_contains": name_contains,
"co_occurrence": _co_occurrence(statement_edges, getattr(a, "id", None), getattr(b, "id", None)),
"relation_statements": _relation_statements(entity_edges, getattr(a, "id", None), getattr(b, "id", None)),
}
entity_a = {
"name": getattr(a, "name", None),
"entity_type": getattr(a, "entity_type", None),
"description": getattr(a, "description", None),
"aliases": getattr(a, "aliases", None) or [],
"fact_summary": getattr(a, "fact_summary", None),
"connect_strength": getattr(a, "connect_strength", None),
}
entity_b = {
"name": getattr(b, "name", None),
"entity_type": getattr(b, "entity_type", None),
"description": getattr(b, "description", None),
"aliases": getattr(b, "aliases", None) or [],
"fact_summary": getattr(b, "fact_summary", None),
"connect_strength": getattr(b, "connect_strength", None),
}
prompt = render_entity_dedup_prompt(
entity_a=entity_a,
entity_b=entity_b,
context=ctx,
json_schema=EntityDisambDecision.model_json_schema(),
disambiguation_mode=True,
)
messages = [
{"role": "system", "content": "You disambiguate same-name different-type entities. Return valid JSON only."},
{"role": "user", "content": prompt},
]
decision = await llm_client.response_structured(messages, EntityDisambDecision)
return decision, ctx
# llm_dedup_entities单轮实体去重
async def llm_dedup_entities( # 保留对偶判断作为子流程,是为了保证高精度、可审计、可复用和行为一致性
# 对偶判断让每次决策只聚焦于一对实体,信息维度清晰,噪声更低,模型更容易给出稳定的“是否同一实体”与“规范方”选择。
# 考虑是否将其保留
entity_nodes: List[ExtractedEntityNode],
statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge],
llm_client: OpenAIClient,
max_concurrency: int = 4,
auto_merge_threshold: float = 0.90,
co_ctx_threshold: float = 0.83,
) -> Tuple[Dict[str, str], List[str]]:
"""
Use LLM to assist fuzzy deduplication among candidate entity pairs and
produce an `id_redirect` mapping plus audit log records.
Parameters:
- entity_nodes: deduplication input entities
- statement_entity_edges: edges from statements to entities (for co-occurrence context)
- entity_entity_edges: relational edges between entities (for relation statements)
- llm_client: configured async client used to call the model
- max_concurrency: semaphore limit for concurrent LLM calls (default 4)
- auto_merge_threshold: confidence threshold to auto-merge without co-occurrence (default 0.90)
- co_ctx_threshold: slightly lower threshold when co-occurrence is detected (default 0.83)
Returns:
- id_redirect_updates: dict of losing_id -> canonical_id decided by LLM
- records: textual logs for decisions, errors, and non-merges
Notes:
- Candidate generation uses simple gates: same group, type compatible, and
name similarity or containment, optionally lowered threshold with co-occurrence.
- The higher-level pipeline should call this async function upstream, then
pass the resulting mapping and records into `deduplicate_entities_and_edges`
via `llm_redirect` and `llm_records` to apply merges synchronously before
edge redirection.
"""
# 1. 构建“候选实体对”用规则层筛选减少LLM调用量提高效率
# Build candidate pairs: simple gates
candidates: List[Tuple[int, int]] = []
for i in range(len(entity_nodes)):
a = entity_nodes[i]
for j in range(i + 1, len(entity_nodes)):
b = entity_nodes[j]
# 规则1必须属于同一组group_id相同不同组的实体不重复
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
continue
# 规则2类型必须兼容调用_simple_type_ok判断
if not _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)):
continue
# 规则3名称相似度达标文本/嵌入相似度取最大值)
txt_sim = _name_text_sim(getattr(a, "name", ""), getattr(b, "name", ""))
emb_sim = _name_embed_sim(getattr(a, "name_embedding", []), getattr(b, "name_embedding", []))
# 规则4名称是否包含如“苹果公司”和“苹果”
contains = False
try:
n1 = (getattr(a, "name", "") or "").strip().lower()
n2 = (getattr(b, "name", "") or "").strip().lower()
contains = bool(n1 and n2 and (n1 in n2 or n2 in n1))
except Exception:
pass
# 规则5是否同现同现的实体更可能重复降低相似度阈值
co_ctx = _co_occurrence(statement_entity_edges, getattr(a, "id", None), getattr(b, "id", None))
sim = max(txt_sim, emb_sim)
# 候选对筛选条件:满足任一即加入(减少漏判)
if (sim >= 0.80) or (co_ctx and sim >= 0.75) or contains:
candidates.append((i, j))
# Use anyio for cross-compatibility with asyncio and trio
results = []
async with anyio.create_task_group() as tg:
result_list = [None] * len(candidates)
async def _wrapped(idx: int, i: int, j: int):
try:
result_list[idx] = await _judge_pair(llm_client, entity_nodes[i], entity_nodes[j], statement_entity_edges, entity_entity_edges)
except Exception as e:
result_list[idx] = e
# Limit concurrency using semaphore
sem = anyio.Semaphore(max_concurrency)
async def _limited_wrapped(idx: int, i: int, j: int):
async with sem:
await _wrapped(idx, i, j)
for idx, (i, j) in enumerate(candidates):
tg.start_soon(_limited_wrapped, idx, i, j)
results = result_list
id_redirect_updates: Dict[str, str] = {}
records: List[str] = []
for idx, res in enumerate(results):
if isinstance(res, Exception):
i, j = candidates[idx]
a = entity_nodes[i]
b = entity_nodes[j]
records.append(f"[LLM异常] pair ({a.id},{b.id}) -> {res}")
continue
decision, ctx = res
i, j = candidates[idx]
a = entity_nodes[i]
b = entity_nodes[j]
th = auto_merge_threshold if not ctx.get("co_occurrence") else co_ctx_threshold
if decision.same_entity and decision.confidence >= th:
canon_idx = decision.canonical_idx if decision.canonical_idx in (0, 1) else _choose_canonical(a, b)
canon = a if canon_idx == 0 else b
other = b if canon_idx == 0 else a
id_redirect_updates[getattr(other, "id")] = getattr(canon, "id")
records.append(
f"[LLM合并] 规范实体 {canon.id} 名称 '{getattr(canon, 'name', '')}' <- 合并实体 {other.id} 名称 '{getattr(other, 'name', '')}' | conf={decision.confidence:.3f}, th={th:.3f}, co_ctx={ctx.get('co_occurrence')}"
)
# 若类型相同且名称高度相似/包含关系,补充“同类名称相似”记录,格式与报告要求一致(名称后带类型)
try:
type_same = (getattr(a, "entity_type", None) == getattr(b, "entity_type", None)) and getattr(a, "entity_type", None) is not None
name_sim = max(float(ctx.get("name_text_sim", 0.0)), float(ctx.get("name_embed_sim", 0.0)))
name_contains = bool(ctx.get("name_contains", False))
if type_same and (name_sim >= 0.80 or name_contains):
name_a = (getattr(a, "name", "") or "").strip()
name_b = (getattr(b, "name", "") or "").strip()
type_a = getattr(a, "entity_type", "")
type_b = getattr(b, "entity_type", "")
records.append(
f"[LLM去重] 同类名称相似 {name_a}{type_a}|{name_b}{type_b} | conf={decision.confidence:.2f} | reason={decision.reason}"
)
except Exception:
pass
else:
records.append(
f"[LLM不合并] A={a.id} B={b.id} | same={decision.same_entity} conf={decision.confidence:.3f} co_ctx={ctx.get('co_occurrence')}"
)
return id_redirect_updates, records
# 迭代分块去重,这才是重点
async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
entity_nodes: List[ExtractedEntityNode], # 待去重实体列表需先经过精确去重LLM决策属于模糊匹配下
statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge],
llm_client: OpenAIClient,
block_size: int = 50,
block_concurrency: int = 4,
pair_concurrency: int = 4,
max_rounds: int = 3,
auto_merge_threshold: float = 0.90,
co_ctx_threshold: float = 0.83,
shuffle_each_round: bool = True, # 每轮是否打乱实体顺序(避免同一块内实体重复,提高覆盖度)
) -> Tuple[Dict[str, str], List[str]]: # 返回全局ID映射、全局审计日志
"""
Iteratively deduplicate entities using LLM in block-wise concurrent rounds.
Process:
- Partition the input entities (post exact + local fuzzy stage) into blocks per round.
- Run LLM pairwise decisions concurrently *within each block*, and also run multiple blocks concurrently.
- Apply merges from all blocks, collapse to canonical set, re-partition, and repeat until no new merges or max_rounds reached.
Parameters:
- entity_nodes: entities to deduplicate (should already be exact/fuzzy merged candidates)
- statement_entity_edges: statement→entity edges for co-occurrence context
- entity_entity_edges: entity↔entity relational edges for relation statements context
- llm_client: initialized async client
- block_size: target number of entities per block (default 50)
- block_concurrency: how many blocks to process concurrently (default 4)
- pair_concurrency: concurrency for pairwise LLM calls inside each block (default 4)
- max_rounds: upper bound for iterative passes (default 3)
- auto_merge_threshold: decision confidence for auto-merge when no co-occurrence (default 0.90)
- co_ctx_threshold: lower threshold when co-occurrence is detected (default 0.83)
- shuffle_each_round: whether to shuffle entities within group_id each round to vary block composition
Returns:
- global_redirect: dict losing_id -> canonical_id accumulated across rounds
- records: textual logs including per-round/per-block summaries and per-pair decisions
"""
import asyncio
import random
# 初始化全局日志和全局ID映射存储所有轮次的结果
records: List[str] = []
global_redirect: Dict[str, str] = {}
# Helper: resolve final canonical id following redirect chain
# 辅助函数1_resolve递归解析实体的“最终规范ID”处理ID映射链如a→b→c返回c
def _resolve(id_: str) -> str:
while id_ in global_redirect and global_redirect[id_] != id_: # 若ID在映射中且未指向自身
id_ = global_redirect[id_] # 递归替换为映射的ID
return id_ # 返回最终规范ID
## 这里辅助函数没有看懂
# Helper: collapse nodes to canonical representatives per current global_redirect
# 辅助函数2_collapse_nodes根据全局ID映射“折叠”实体列表保留每个规范ID对应的实体
def _collapse_nodes(nodes: List[ExtractedEntityNode]) -> List[ExtractedEntityNode]:
by_id: Dict[str, ExtractedEntityNode] = {e.id: e for e in nodes} # 实体ID→实体的映射
keep: Dict[str, ExtractedEntityNode] = {} # 存储需保留的规范实体
for e in nodes:
cid = _resolve(e.id) # 解析e的最终规范ID
# 优先保留by_id中已存在的规范实体若有否则保留第一个遇到的实体
if cid in by_id:
keep[cid] = by_id[cid]
else:
keep[cid] = keep.get(cid, e)
return list(keep.values())
def _partition_blocks(nodes: List[ExtractedEntityNode]) -> List[List[ExtractedEntityNode]]:
"""
按 group_id 分块,避免跨组实体在同一块,减少无效候选对
Args:
nodes: 实体节点列表
Returns:
分块后的实体列表
"""
groups: Dict[str, List[ExtractedEntityNode]] = {}
for e in nodes:
gid = getattr(e, "group_id", None)
groups.setdefault(str(gid), []).append(e)
blocks: List[List[ExtractedEntityNode]] = []
for gid, arr in groups.items():
if shuffle_each_round:
random.shuffle(arr)
# chunk into block_size
for i in range(0, len(arr), max(1, block_size)):
blocks.append(arr[i:i + max(1, block_size)])
return blocks
# Semaphore for block-level concurrency
# 初始化块级并发信号量(控制同时处理的块数量)
block_sem = asyncio.Semaphore(max(1, block_concurrency))
# 辅助函数4_run_one_block处理单个块的去重调用llm_dedup_entities
async def _run_one_block(block_idx: int, block_nodes: List[ExtractedEntityNode]):
async with block_sem:
# Delegate to existing pairwise function with limited concurrency per block
id_map, recs = await llm_dedup_entities(
entity_nodes=block_nodes,
statement_entity_edges=statement_entity_edges,
entity_entity_edges=entity_entity_edges,
llm_client=llm_client,
max_concurrency=pair_concurrency,
auto_merge_threshold=auto_merge_threshold,
co_ctx_threshold=co_ctx_threshold,
)
# Prefix block index in records for readability
prefixed = [f"[LLM块{block_idx}] {line}" for line in recs]
return id_map, prefixed
# Iterative rounds
# 核心:迭代分块去重(多轮处理)
current_nodes: List[ExtractedEntityNode] = list(entity_nodes)
round_idx = 1
while round_idx <= max(1, max_rounds):
# Collapse nodes to canonical reps before each round to avoid redundant comparisons
# 步骤1折叠实体合并已确定的重复实体减少后续计算量
current_nodes = _collapse_nodes(current_nodes)
# 步骤2分块按group_id分块避免跨组处理
blocks = _partition_blocks(current_nodes)
if not blocks: # 无块可处理(实体已全部折叠),退出循环
break
# 步骤3记录当前轮次的基本信息轮次、块数、块大小
records.append(f"[LLM批次] 轮次 {round_idx} 预计处理块数 {len(blocks)} 每块大小≈{block_size}")
# Run all blocks concurrently with block-level semaphore
# 步骤4并发处理所有块创建块处理任务批量执行
results = [None] * len(blocks)
async with anyio.create_task_group() as tg:
async def _run_block_wrapper(idx: int, block: List[ExtractedEntityNode]):
try:
results[idx] = await _run_one_block(idx, block)
except Exception as e:
results[idx] = e
for i in range(len(blocks)):
tg.start_soon(_run_block_wrapper, i, blocks[i])
# Collect and normalize redirects from blocks
# 步骤5合并块结果到全局映射和日志
merged_this_round = 0
for bi, res in enumerate(results):
if isinstance(res, Exception):
records.append(f"[LLM块异常] 轮次 {round_idx}{bi} -> {res}")
continue
id_map, recs = res
records.extend(recs)
# Normalize with current global redirects
for losing, canon in id_map.items():
losing_final = _resolve(losing)
canon_final = _resolve(canon)
if losing_final == canon_final:
continue
# Apply mapping and ensure chain consistency
global_redirect[losing_final] = canon_final
merged_this_round += 1
records.append(f"[LLM批次] 轮次 {round_idx} 块数 {len(blocks)} 新合并 {merged_this_round}")
if merged_this_round == 0:
break
# Prepare nodes for next round: collapse canonical set
current_nodes = _collapse_nodes(current_nodes)
round_idx += 1
return global_redirect, records
# LLM 消歧:同名不同类型的实体对,输出合并建议与阻断对列表
async def llm_disambiguate_pairs_iterative(
entity_nodes: List[ExtractedEntityNode],
statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge],
llm_client: OpenAIClient,
max_concurrency: int = 4,
merge_conf_threshold: float = 0.88,
block_conf_threshold: float = 0.60,
) -> Tuple[Dict[str, str], List[Tuple[str, str]], List[str]]:
"""
Disambiguate same-name different-type pairs using LLM.
Returns:
- merge_redirect: dict losing_id -> canonical_id for merges decided by LLM
- block_pairs: list of sorted (id1, id2) pairs to block from fuzzy/heuristic merges
- records: textual logs for audit
"""
records: List[str] = []
merge_redirect: Dict[str, str] = {}
block_pairs: List[Tuple[str, str]] = []
def _is_typed(t: str) -> bool:
t = (t or "").strip().upper()
return bool(t) and t not in {"UNKNOWN", "UNDEFINED", ""}
candidates: List[Tuple[int, int]] = []
n = len(entity_nodes)
for i in range(n):
for j in range(i + 1, n):
a = entity_nodes[i]
b = entity_nodes[j]
# 必须同组
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
continue
ta = getattr(a, "entity_type", None)
tb = getattr(b, "entity_type", None)
# 必须不同类型且两者均为已定义类型
if ta == tb:
continue
if not (_is_typed(ta) and _is_typed(tb)):
continue
# 严格“同名不同义”:名称需严格相同(大小写与首尾空格忽略)
try:
na = (getattr(a, "name", "") or "").strip().lower()
nb = (getattr(b, "name", "") or "").strip().lower()
except Exception:
na, nb = "", ""
if not na or not nb:
continue
if na == nb:
candidates.append((i, j))
if not candidates:
return merge_redirect, block_pairs, records
# Use anyio for cross-compatibility with asyncio and trio
judged = [None] * len(candidates)
async with anyio.create_task_group() as tg:
async def _wrapped(idx: int, i: int, j: int):
try:
judged[idx] = await _judge_pair_disamb(llm_client, entity_nodes[i], entity_nodes[j], statement_entity_edges, entity_entity_edges)
except Exception as e:
judged[idx] = e
# Limit concurrency using semaphore
sem = anyio.Semaphore(max_concurrency)
async def _limited_wrapped(idx: int, i: int, j: int):
async with sem:
await _wrapped(idx, i, j)
for idx, (i, j) in enumerate(candidates):
tg.start_soon(_limited_wrapped, idx, i, j)
for k, res in enumerate(judged):
i, j = candidates[k]
a = entity_nodes[i]
b = entity_nodes[j]
a_id = getattr(a, "id", None) or ""
b_id = getattr(b, "id", None) or ""
if isinstance(res, Exception):
records.append(f"[DISAMB错误] 对({a_id},{b_id})调用失败: {res}")
block_pairs.append(tuple(sorted((a_id, b_id))))
continue
decision, ctx = res
try:
if decision.should_merge and decision.confidence >= merge_conf_threshold:
can_idx = 0 if decision.canonical_idx == 0 else 1
canonical = a if can_idx == 0 else b
losing = b if can_idx == 0 else a
merge_redirect[getattr(losing, "id", "")] = getattr(canonical, "id", "")
records.append(
f"[DISAMB合并] {getattr(losing,'id','')} -> {getattr(canonical,'id','')} | conf={decision.confidence:.2f} | reason={decision.reason} | suggested_type={decision.suggested_type or ''}"
)
# 追加 LLM 决策去重记录以便下方报告展示到“LLM 决策去重”区块
records.append(
f"[LLM去重] 同名类型相似 {getattr(a,'name','')}{getattr(a,'entity_type','')}|{getattr(b,'name','')}{getattr(b,'entity_type','')} | conf={decision.confidence:.2f} | reason={decision.reason}"
)
else:
# Fallback同名且类型不同但语义高度相似且未要求阻断按“同名类型相似”进行合并
name_a = (getattr(a, "name", "") or "").strip().lower()
name_b = (getattr(b, "name", "") or "").strip().lower()
def _strength_rank(x: str) -> int:
s = (x or "").strip().lower()
return {"strong": 3, "both": 2, "weak": 1}.get(s, 0)
if (
name_a and name_b and name_a == name_b
and (not decision.block_pair)
and decision.confidence >= max(0.80, block_conf_threshold)
):
# 选择规范实体:优先使用 canonical_idx否则根据连接强度挑选更强者
if decision.canonical_idx in (0, 1):
canonical = a if decision.canonical_idx == 0 else b
losing = b if decision.canonical_idx == 0 else a
else:
sa = _strength_rank(getattr(a, "connect_strength", None))
sb = _strength_rank(getattr(b, "connect_strength", None))
canonical = a if sa >= sb else b
losing = b if sa >= sb else a
merge_redirect[getattr(losing, "id", "")] = getattr(canonical, "id", "")
# 消歧合并审计
records.append(
f"[DISAMB合并] {getattr(losing,'id','')} -> {getattr(canonical,'id','')} | conf={decision.confidence:.2f} | reason={decision.reason} | suggested_type={decision.suggested_type or ''}"
)
# 追加 LLM 决策去重记录(同名类型相似)
records.append(
f"[LLM去重] 同名类型相似 {getattr(a,'name','')}{getattr(a,'entity_type','')}|{getattr(b,'name','')}{getattr(b,'entity_type','')} | conf={decision.confidence:.2f} | reason={decision.reason}"
)
else:
if decision.block_pair or decision.confidence >= block_conf_threshold:
block_pairs.append(tuple(sorted((a_id, b_id))))
# 仅保留阻断条目在预筛选报告,包含实体名称与类型,便于人读
records.append(
f"[DISAMB阻断] {getattr(a,'name','')}{getattr(a,'entity_type','')}|{getattr(b,'name','')}{getattr(b,'entity_type','')} | conf={decision.confidence:.2f} | reason={decision.reason} || block_pair={decision.block_pair}"
)
except Exception:
block_pairs.append(tuple(sorted((a_id, b_id))))
# 异常情况也以阻断形式记录,包含名称便于定位
records.append(
f"[DISAMB异常阻断] {getattr(a,'name','')}{getattr(a,'entity_type','')}|{getattr(b,'name','')}{getattr(b,'entity_type','')} | ids=({a_id},{b_id})"
)
return merge_redirect, block_pairs, records

View File

@@ -0,0 +1,149 @@
# 导入 Python 的annotations特性允许在类型注解中使用尚未定义的类支持 “向前引用”),提升代码中类型注解的灵活性。
# 这是什么意思? 该类的属性的类型是这个类本身(递归定义)?
"""
这段代码是 “第二层去重消歧” 的核心实现,逻辑可分为四步:
1.从第一层去重消歧后的实体中提取核心信息,作为索引查询 Neo4j 中同组的候选实体;
2.对候选实体去重并转换为统一模型;
3.构建预重定向关系(第一层实体 ID→数据库实体 ID确保优先使用数据库 ID
4.合并数据库候选实体与第一层实体,调用去重函数完成最终融合,返回结果。
"""
from __future__ import annotations
from typing import List, Dict, Any, Tuple
from datetime import datetime
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector # 导入 Neo4j 数据库连接器类,用于与 Neo4j 数据库进行交互
from app.repositories.neo4j.graph_search import get_dedup_candidates_for_entities # 导入ge函数用于从 Neo4j 中检索与输入实体可能重复的候选实体(去重的核心检索逻辑)。
from app.core.memory.models.graph_models import ExtractedEntityNode, StatementEntityEdge, EntityEntityEdge
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges, _write_dedup_fusion_report # 导入报告写入以在跳过时追加说明
from app.core.memory.models.variate_config import DedupConfig
def _parse_dt(val: Any) -> datetime: # 定义内部辅助函数_parse_dt用于将任意类型的输入值解析为datetime对象处理实体节点中的时间字段
if isinstance(val, datetime):
return val
if isinstance(val, str) and val:
try:
return datetime.fromisoformat(val) # 使用fromisoformat方法将 ISO 格式的字符串(如 "2023-10-01T12:00:00"解析为datetime对象
except Exception:
pass
# Fallback: now; upstream should provide real times
return datetime.now()
def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
"""
将 Neo4j 返回的数据库记录转换为 ExtractedEntityNode 模型对象
Args:
row: Neo4j 查询返回的记录字典
Returns:
ExtractedEntityNode: 实体节点对象
Note:
从数据库中查询到的内容是 JSON 格式的字符串,需要先解析为 Python 对象
"""
return ExtractedEntityNode(
id=row.get("id"),
name=row.get("name") or "",
group_id=row.get("group_id") or "",
user_id=row.get("user_id") or "",
apply_id=row.get("apply_id") or "",
created_at=_parse_dt(row.get("created_at")),
expired_at=_parse_dt(row.get("expired_at")) if row.get("expired_at") else None,
entity_idx=int(row.get("entity_idx") or 0),
statement_id=row.get("statement_id") or "",
entity_type=row.get("entity_type") or "",
description=row.get("description") or "",
aliases=row.get("aliases") or [],
name_embedding=row.get("name_embedding") or [],
fact_summary=row.get("fact_summary") or "",
connect_strength=row.get("connect_strength") or "",
)
async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑,与 Neo4j 中同组实体联合去重
connector: Neo4jConnector,
group_id: str, # 用于定位neo4j中同一组的实体确保只在同组内去重
entity_nodes: List[ExtractedEntityNode], # 输入的实体节点列表,包含待去重的实体
statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系
entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系
dedup_config: DedupConfig | None = None,
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
"""
第二层去重消歧:
- 以第一层结果为索引,检索相同 group_id 下的 DB 候选实体
- 将 DB 候选与当前实体集合联合,按既有精确/模糊/LLM 决策进行融合
- 返回融合后的实体与重定向后的边(边已指向规范 ID优先 DB ID
"""
if not entity_nodes:
return entity_nodes, statement_entity_edges, entity_entity_edges
# 构造批量行并检索候选(精确/别名 + CONTAINS 召回)
# 将第一层去重消歧的结果作为索引批量查询DB候选实体
incoming_rows: List[Dict[str, Any]] = [ # 定义 包含第一层实体的核心信息(用于数据库查询)
{"id": e.id, "name": e.name, "entity_type": e.entity_type} for e in entity_nodes # 对entity_nodes中的每个实体e提取id实体 ID、name名称、entity_type类型构造字典作为查询条件。
]
candidates_map = await get_dedup_candidates_for_entities( # 从 Neo4j 中查询候选实体并将结果赋值给candidates_map等待异步操作完成
connector=connector, group_id=group_id,
entities=incoming_rows, # 传入参数:第一层实体的核心信息(作为查询索引)
use_contains_fallback=True # 传入参数:启用 “包含关系” 作为匹配失败的降级策略若精确匹配无结果用包含关系召回候选与src\database\cypher_queries.py的307产生联动
)
# 拉平候选,转为模型(按 DB 节点优先)
db_candidate_rows: List[Dict[str, Any]] = [] # 存储去重后的数据库候选实体记录(行)
seen_db_ids: set[str] = set() # 集合,用于记录已处理的数据库实体 ID避免重复添加同一实体
for _, rows in candidates_map.items():
for r in rows:
rid = r.get("id")
if rid and rid not in seen_db_ids: # 如果rid存在且未被处理
seen_db_ids.add(rid) # 将rid加入seen_db_ids标记为已处理
db_candidate_rows.append(r) # 将该记录r添加到db_candidate_rows确保数据库实体唯一
db_candidate_models: List[ExtractedEntityNode] = []
for r in db_candidate_rows: # db_candidate_rows去重后的数据库候选实体记录
try:
m = _row_to_entity(r) # 调用_row_to_entity函数将数据库记录r转换为实体模型m
db_candidate_models.append(m) # m添加到db_candidate_models
except Exception:
# 忽略无法解析的记录
pass
# 若 DB 候选为空:跳过二层融合,直接返回第一层结果,并在报告中标注候选数
candidate_count = len(db_candidate_models)
if candidate_count == 0:
try:
_write_dedup_fusion_report(
exact_merge_map={},
fuzzy_merge_records=[],
local_llm_records=[],
disamb_records=[],
stage_label="第二层去重消歧",
append=True,
stage_notes=[f"候选数:{candidate_count}DB 为空则标注跳过)"],
)
except Exception:
# 报告写入失败不影响主流程
pass
return entity_nodes, statement_entity_edges, entity_entity_edges
# 联合集合DB 在前,确保规范 ID 优先使用 DB ID
# 将从 DB 检索到的候选实体与第一层去重消歧的实体合并,作为输入继续调用去重方法。
# 由于按顺序遍历,规范实体将优先选择位于前面的 DB 节点,因此无需显式预重定向。
union_entities: List[ExtractedEntityNode] = db_candidate_models + list(entity_nodes)
# 融合(内部执行精确/模糊/LLM 决策;随后再做边重定向与去重)
fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges = await deduplicate_entities_and_edges(
union_entities,
statement_entity_edges,
entity_entity_edges,
report_stage="第二层去重消歧",
report_append=True,
dedup_config=dedup_config,
)
return fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges

View File

@@ -0,0 +1,106 @@
from __future__ import annotations
from typing import List, Tuple, Optional
from app.core.memory.models.variate_config import ExtractionPipelineConfig
from app.core.memory.utils.config.config_utils import get_pipeline_config
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import second_layer_dedup_and_merge_with_neo4j
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.models.graph_models import (
DialogueNode,
ChunkNode,
StatementNode,
ExtractedEntityNode,
StatementChunkEdge,
StatementEntityEdge,
EntityEntityEdge,
)
from app.core.memory.models.message_models import DialogData
async def dedup_layers_and_merge_and_return(
dialogue_nodes: List[DialogueNode],
chunk_nodes: List[ChunkNode],
statement_nodes: List[StatementNode],
entity_nodes: List[ExtractedEntityNode],
statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge],
dialog_data_list: List[DialogData],
pipeline_config: Optional[ExtractionPipelineConfig] = None,
connector: Optional[Neo4jConnector] = None,
) -> Tuple[
List[DialogueNode],
List[ChunkNode],
List[StatementNode],
List[ExtractedEntityNode],
List[StatementChunkEdge],
List[StatementEntityEdge],
List[EntityEntityEdge],
]:
"""
执行两层实体去重与融合:
- 第一层:精确/模糊/LLM 决策去重
- 第二层:与 Neo4j 同组实体联合去重与融合(依赖传入的 connector
返回融合后的实体与边,同时保留原始的对话、片段与语句节点与边。
"""
# 默认从 runtime.json 加载管线配置,避免回退到环境变量
if pipeline_config is None:
try:
pipeline_config = get_pipeline_config()
except Exception:
pipeline_config = None
# 先探测 group_id决定报告写入策略
group_id: Optional[str] = None
for dd in dialog_data_list:
group_id = getattr(dd, "group_id", None)
if group_id:
break
# 第一层去重消歧
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges = await deduplicate_entities_and_edges(
entity_nodes,
statement_entity_edges,
entity_entity_edges,
report_stage="第一层去重消歧",
report_append=False,
dedup_config=(pipeline_config.deduplication if pipeline_config else None),
)
# 初始化第二层融合结果为第一层结果
fused_entity_nodes = dedup_entity_nodes
fused_statement_entity_edges = dedup_statement_entity_edges
fused_entity_entity_edges = dedup_entity_entity_edges
# 第二层去重消歧:与 Neo4j 中同组实体联合融合
try:
if group_id:
if connector:
fused_entity_nodes, fused_statement_entity_edges, fused_entity_entity_edges = await second_layer_dedup_and_merge_with_neo4j(
connector=connector,
group_id=group_id,
entity_nodes=dedup_entity_nodes,
statement_entity_edges=dedup_statement_entity_edges,
entity_entity_edges=dedup_entity_entity_edges,
dedup_config=(pipeline_config.deduplication if pipeline_config else None),
)
else:
print("Skip second-layer dedup: missing connector")
else:
print("Skip second-layer dedup: missing group_id")
except Exception as e:
print(f"Second-layer dedup failed: {e}")
return (
dialogue_nodes,
chunk_nodes,
statement_nodes,
fused_entity_nodes,
statement_chunk_edges,
fused_statement_entity_edges,
fused_entity_entity_edges,
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,11 @@
"""
知识提取模块
包含以下提取器:
- DialogueChunker: 对话分块
- StatementExtractor: 陈述句提取
- TripletExtractor: 三元组提取
- TemporalExtractor: 时间信息提取
- EmbeddingGenerator: 嵌入向量生成
- MemorySummaryGenerator: 记忆摘要生成
"""

View File

@@ -0,0 +1,103 @@
import os
from typing import Optional
from app.core.logging_config import get_memory_logger
from app.core.memory.models.message_models import DialogData, Chunk
from app.core.memory.models.config_models import ChunkerConfig
from app.core.memory.llm_tools.chunker_client import ChunkerClient
from app.core.memory.utils.config.config_utils import get_chunker_config
logger = get_memory_logger(__name__)
class DialogueChunker:
"""A class that processes dialogues and fills them with chunks based on a specified strategy.
This class encapsulates the chunking process, allowing for easy configuration and application
of different chunking strategies to dialogue data.
"""
def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client=None):
"""Initialize the DialogueChunker with a specific chunking strategy.
Args:
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
Options include: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker
"""
self.chunker_strategy = chunker_strategy
chunker_config_dict = get_chunker_config(chunker_strategy)
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
# 对于 LLMChunker需要传入 llm_client
if self.chunker_config.chunker_strategy == "LLMChunker":
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
else:
self.chunker_client = ChunkerClient(self.chunker_config)
async def process_dialogue(self, dialogue: DialogData) -> list[Chunk]:
"""Process a dialogue by generating chunks and adding them to the DialogData object.
Args:
dialogue: The DialogData object to process
Returns:
A list of Chunk objects
"""
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
# Defensive fallback: ensure at least one chunk is returned for non-empty content
try:
chunks = result_dialogue.chunks
except Exception:
chunks = []
if not chunks or len(chunks) == 0:
# If the dialogue has content, return a single fallback chunk built from messages
content_str = getattr(result_dialogue, "content", "") or getattr(dialogue, "content", "")
if content_str and len(content_str.strip()) > 0:
fallback_chunk = Chunk.from_messages(
dialogue.context.msgs,
metadata={
"fallback": "single_chunk",
"chunker_strategy": self.chunker_config.chunker_strategy,
"source": "DialogueChunkerFallback",
},
)
return [fallback_chunk]
# No content: return empty list
return []
return chunks
def save_chunking_results(self, dialogue: DialogData, output_path: Optional[str] = None) -> str:
"""Save the chunking results to a file and return the output path.
Args:
dialogue: The processed DialogData object with chunks
output_path: Optional path to save the output (default: chunker_output_{strategy}.txt)
Returns:
The path where the output was saved
"""
if not output_path:
output_path = os.path.join(os.path.dirname(__file__), "..", "..",
f"chunker_output_{self.chunker_strategy.lower()}.txt")
output_lines = []
output_lines.append(f"=== Chunking Results ({self.chunker_strategy}) ===")
output_lines.append(f"Dialogue ID: {dialogue.ref_id}")
output_lines.append(f"Original conversation has {len(dialogue.context.msgs)} messages")
output_lines.append(f"Total characters: {len(dialogue.content)}")
output_lines.append(f"Generated {len(dialogue.chunks)} chunks:")
for i, chunk in enumerate(dialogue.chunks):
output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters")
output_lines.append(f" Content preview: {chunk.content}...")
if chunk.metadata:
output_lines.append(f" Metadata: {chunk.metadata}")
with open(output_path, "w", encoding="utf-8") as f:
f.write("\n".join(output_lines))
logger.info(f"Chunking results saved to: {output_path}")
return output_path

Some files were not shown because too many files have changed in this diff Show More