Merge branch 'refs/heads/develop' into fix/memory_bug_fix

This commit is contained in:
lixinyue
2026-01-20 10:47:32 +08:00
107 changed files with 4845 additions and 6224 deletions

View File

@@ -3,6 +3,12 @@ Celery Worker entry point
Start the Celery worker with: celery -A app.celery_worker worker --loglevel=info
"""
from app.celery_app import celery_app
from app.core.logging_config import LoggingConfig, get_logger
# Initialize logging system for Celery worker
LoggingConfig.setup_logging()
logger = get_logger(__name__)
logger.info("Celery worker logging initialized")
# Import task modules so their tasks get registered
import app.tasks

View File

@@ -267,7 +267,7 @@ async def generate_emotion_suggestions(
"""生成个性化情绪建议调用LLM并缓存
Args:
request: 包含 group_id、可选的 config_id 和 force_refresh
request: 包含 end_user_id
db: 数据库会话
current_user: 当前用户
@@ -275,47 +275,22 @@ async def generate_emotion_suggestions(
新生成的个性化情绪建议响应
"""
try:
# Validate config_id (if provided)
# Get the config associated with the end user
config_id = request.config_id
if config_id is None:
# If no config_id was provided, try to fetch the user's associated config
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(request.group_id, db)
config_id = connected_config.get("memory_config_id")
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "Unable to get the user's associated config", str(e))
else:
# If config_id was provided, validate it
from app.services.memory_config_service import MemoryConfigService
try:
config_service = MemoryConfigService(db)
config = config_service.get_config_by_id(config_id)
if not config:
return fail(BizCode.INVALID_PARAMETER, "Invalid config ID", f"Config {config_id} does not exist")
except Exception as e:
return fail(BizCode.INVALID_PARAMETER, "Config ID validation failed", str(e))
api_logger.info(
f"用户 {current_user.username} 请求生成个性化情绪建议",
extra={
"group_id": request.group_id,
"config_id": config_id
"end_user_id": request.end_user_id
}
)
# Call the service layer to generate suggestions
data = await emotion_service.generate_emotion_suggestions(
end_user_id=request.group_id,
end_user_id=request.end_user_id,
db=db
)
# Save to cache
await emotion_service.save_suggestions_cache(
end_user_id=request.group_id,
end_user_id=request.end_user_id,
suggestions_data=data,
db=db,
expires_hours=24
@@ -324,7 +299,7 @@ async def generate_emotion_suggestions(
api_logger.info(
"个性化建议生成成功",
extra={
"group_id": request.group_id,
"end_user_id": request.end_user_id,
"suggestions_count": len(data.get("suggestions", []))
}
)
@@ -334,7 +309,7 @@ async def generate_emotion_suggestions(
except Exception as e:
api_logger.error(
f"生成个性化建议失败: {str(e)}",
extra={"group_id": request.group_id},
extra={"end_user_id": request.end_user_id},
exc_info=True
)
raise HTTPException(

View File

@@ -147,6 +147,7 @@ class Settings:
# Celery configuration (internal)
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))

View File

@@ -1,16 +0,0 @@
"""
LangGraph Graph package for memory agent.
This package provides the LangGraph workflow orchestrator with modular
node implementations, routing logic, and state management.
Package structure:
- read_graph: Main graph factory for read operations
- write_graph: Main graph factory for write operations
- nodes: LangGraph node implementations
- routing: State routing logic
- state: State management utilities
"""
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
__all__ = ['make_read_graph']

View File

@@ -4,7 +4,7 @@ LangGraph node implementations.
This module contains custom node implementations for the LangGraph workflow.
"""
from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode
from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message
__all__ = ["ToolExecutionNode", "create_input_message"]
# from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode
# from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message
#
# __all__ = ["ToolExecutionNode", "create_input_message"]

View File

@@ -0,0 +1,16 @@
from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
def content_input_node(state: ReadState) -> ReadState:
"""开始节点 - 提取内容并保持状态信息"""
content = state['messages'][0].content if state.get('messages') else ''
# 返回内容并保持所有状态信息
return {"data": content}
def content_input_write(state: WriteState) -> WriteState:
"""开始节点 - 提取内容并保持状态信息"""
content = state['messages'][0].content if state.get('messages') else ''
# 返回内容并保持所有状态信息
return {"data": content}

View File

@@ -1,150 +0,0 @@
"""
Input node for LangGraph workflow entry point.
This module provides the create_input_message function which processes initial
user input with multimodal support and creates the first tool call message.
"""
import logging
import re
import uuid
from datetime import datetime
from typing import Any, Dict
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
from app.schemas.memory_config_schema import MemoryConfig
from langchain_core.messages import AIMessage
logger = logging.getLogger(__name__)
async def create_input_message(
state: Dict[str, Any],
tool_name: str,
session_id: str,
search_switch: str,
apply_id: str,
group_id: str,
multimodal_processor: MultimodalProcessor,
memory_config: MemoryConfig,
) -> Dict[str, Any]:
"""
Create initial tool call message from user input.
This function:
1. Extracts the last message content from state
2. Processes multimodal inputs (images/audio) using the multimodal processor
3. Generates a unique message ID
4. Extracts namespace from session_id
5. Handles verified_data extraction for backward compatibility
6. Returns AIMessage with complete tool_calls structure
Args:
state: LangGraph state dictionary containing messages
tool_name: Name of the tool to invoke (typically "Split_The_Problem")
session_id: Session identifier (format: "call_id_{namespace}")
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
multimodal_processor: Processor for handling image/audio inputs
memory_config: MemoryConfig object containing all configuration
Returns:
State update with AIMessage containing tool_call
Examples:
>>> state = {"messages": [HumanMessage(content="What is AI?")]}
>>> result = await create_input_message(
... state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor, config
... )
>>> result["messages"][0].tool_calls[0]["name"]
'Split_The_Problem'
"""
messages = state.get("messages", [])
# Extract last message content
if messages:
last_message = messages[-1].content if hasattr(messages[-1], 'content') else str(messages[-1])
else:
logger.warning("[create_input_message] No messages in state, using empty string")
last_message = ""
logger.debug(f"[create_input_message] Original input: {last_message[:100]}...")
# Process multimodal input (images/audio)
try:
processed_content = await multimodal_processor.process_input(last_message)
if processed_content != last_message:
logger.info(
f"[create_input_message] Multimodal processing converted input "
f"from {len(last_message)} to {len(processed_content)} chars"
)
last_message = processed_content
except Exception as e:
logger.error(
f"[create_input_message] Multimodal processing failed: {e}",
exc_info=True
)
# Continue with original content
# Generate unique message ID
uuid_str = uuid.uuid4()
time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# Extract namespace from session_id
# Expected format: "call_id_{namespace}" or similar
try:
namespace = str(session_id).split('_id_')[1]
except (IndexError, AttributeError):
logger.warning(
f"[create_input_message] Could not extract namespace from session_id: {session_id}"
)
namespace = "unknown"
# Handle verified_data extraction (backward compatibility)
# This regex-based extraction is kept for compatibility with existing data formats
if 'verified_data' in str(last_message):
try:
messages_last = str(last_message).replace('\\n', '').replace('\\', '')
query_match = re.findall(r'"query": "(.*?)",', messages_last)
if query_match:
last_message = query_match[0]
logger.debug(
f"[create_input_message] Extracted query from verified_data: {last_message}"
)
except Exception as e:
logger.warning(
f"[create_input_message] Failed to extract query from verified_data: {e}"
)
# Construct tool call message
tool_call_id = f"{session_id}_{uuid_str}"
logger.info(
f"[create_input_message] Creating tool call for '{tool_name}' "
f"with ID: {tool_call_id}"
)
# Build tool arguments
tool_args = {
"sentence": last_message,
"sessionid": session_id,
"messages_id": str(uuid_str),
"search_switch": search_switch,
"apply_id": apply_id,
"group_id": group_id,
"memory_config": memory_config,
}
return {
"messages": [
AIMessage(
content="",
tool_calls=[{
"name": tool_name,
"args": tool_args,
"id": tool_call_id
}]
)
]
}

View File

@@ -0,0 +1,237 @@
import json
import time
from app.core.logging_config import get_agent_logger
from app.db import get_db
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_,
ReadState,
)
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
db_session = next(get_db())
logger = get_agent_logger(__name__)
class ProblemNodeService(LLMServiceMixin):
"""问题处理节点服务类"""
def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)
# Create a global service instance
problem_service = ProblemNodeService()
async def Split_The_Problem(state: ReadState) -> ReadState:
"""问题分解节点"""
# 从状态中获取数据
content = state.get('data', '')
group_id = state.get('group_id', '')
memory_config = state.get('memory_config', None)
history = await SessionService(store).get_history(group_id, group_id, group_id)
system_prompt = await problem_service.template_service.render_template(
template_name='problem_breakdown_prompt.jinja2',
operation_name='split_the_problem',
history=history,
sentence=content
)
try:
# Use the optimized LLM service
structured = await problem_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=ProblemExtensionResponse,
fallback_value=[]
)
# More detailed logging
logger.info(f"Split_The_Problem: processing problem decomposition, content length: {len(content)}")
# Validate the structured response
if not structured or not hasattr(structured, 'root'):
logger.warning("Split_The_Problem: structured response is empty or malformed")
split_result = json.dumps([], ensure_ascii=False)
elif not structured.root:
logger.warning("Split_The_Problem: structured response root is empty")
split_result = json.dumps([], ensure_ascii=False)
else:
split_result = json.dumps(
[item.model_dump() for item in structured.root],
ensure_ascii=False
)
split_result_dict = []
for index, item in enumerate(json.loads(split_result)):
split_data = {
"id": f"Q{index+1}",
"question": item['extended_question'],
"type": item['type'],
"reason": item['reason']
}
split_result_dict.append(split_data)
logger.info(f"Split_The_Problem: 成功生成 {len(structured.root) if structured.root else 0} 个分解项")
result = {
"context": split_result,
"original": content,
"_intermediate": {
"type": "problem_split",
"title": "问题拆分",
"data": split_result_dict,
"original_query": content
}
}
except Exception as e:
logger.error(
f"Split_The_Problem failed: {e}",
exc_info=True
)
# Provide more detailed error info
error_details = {
"error_type": type(e).__name__,
"error_message": str(e),
"content_length": len(content),
"llm_model_id": memory_config.llm_model_id if memory_config else None
}
logger.error(f"Split_The_Problem error details: {error_details}")
# Create a default empty result
result = {
"context": json.dumps([], ensure_ascii=False),
"original": content,
"error": str(e),
"_intermediate": {
"type": "problem_split",
"title": "问题拆分",
"data": [],
"original_query": content,
"error": error_details
}
}
# Return the updated state, including the spit_data field
return {"spit_data": result}
async def Problem_Extension(state: ReadState) -> ReadState:
"""问题扩展节点"""
# 获取原始数据和分解结果
start = time.time()
content = state.get('data', '')
data = state.get('spit_data', '')['context']
group_id = state.get('group_id', '')
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
memory_config = state.get('memory_config', None)
datasets = {}
try:
data = json.loads(data)
for i in data:
datasets[i['extended_question']] = i['type']
except (json.JSONDecodeError, KeyError, TypeError) as e:
logger.error(f"Problem_Extension: 数据解析失败: {e}")
# 使用空字典作为fallback
databasets = {}
data = []
history = await SessionService(store).get_history(group_id, group_id, group_id)
system_prompt = await problem_service.template_service.render_template(
template_name='Problem_Extension_prompt.jinja2',
operation_name='problem_extension',
history=history,
questions=datasets
)
try:
# Use the optimized LLM service
response_content = await problem_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=ProblemExtensionResponse,
fallback_value=[]
)
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
# 验证结构化响应
if not response_content or not hasattr(response_content, 'root'):
logger.warning("Problem_Extension: 结构化响应为空或格式不正确")
aggregated_dict = {}
elif not response_content.root:
logger.warning("Problem_Extension: 结构化响应的root为空")
aggregated_dict = {}
else:
# Aggregate results by original question
aggregated_dict = {}
for item in response_content.root:
try:
key = getattr(item, "original_question", None) or (
item.get("original_question") if isinstance(item, dict) else None
)
value = getattr(item, "extended_question", None) or (
item.get("extended_question") if isinstance(item, dict) else None
)
if not key or not value:
logger.warning(f"Problem_Extension: 跳过无效项: key={key}, value={value}")
continue
aggregated_dict.setdefault(key, []).append(value)
except Exception as item_error:
logger.warning(f"Problem_Extension: 处理项目时出错: {item_error}")
continue
logger.info(f"Problem_Extension: 成功生成 {len(aggregated_dict)} 个扩展问题组")
except Exception as e:
logger.error(
f"LLM call failed for Problem_Extension: {e}",
exc_info=True
)
# Provide more detailed error info
error_details = {
"error_type": type(e).__name__,
"error_message": str(e),
"questions_count": len(databasets),
"llm_model_id": memory_config.llm_model_id if memory_config else None
}
logger.error(f"Problem_Extension error details: {error_details}")
aggregated_dict = {}
logger.info("Problem extension")
logger.info(f"Problem extension result: {aggregated_dict}")
# Emit intermediate output for frontend
print(time.time() - start)
result = {
"context": aggregated_dict,
"original": data,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "problem_extension",
"title": "问题扩展",
"data": aggregated_dict,
"original_query": content,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
return {"problem_extension": result}

View File

@@ -0,0 +1,417 @@
# ===== Standard library =====
import asyncio
import json
import os
import time
# ===== Third-party and application imports =====
from langchain.agents import create_agent
from langchain_openai import ChatOpenAI
from app.core.logging_config import get_agent_logger
from app.db import get_db, get_db_context
from app.schemas import model_schema
from app.services.memory_config_service import MemoryConfigService
from app.services.model_service import ModelConfigService
from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
ReadState,
deduplicate_entries,
merge_to_key_value_pairs,
)
from app.core.memory.agent.langgraph_graph.tools.tool import (
create_hybrid_retrieval_tool_sync,
create_time_retrieval_tool,
extract_tool_message_content,
)
from app.core.rag.nlp.search import knowledge_retrieval
logger = get_agent_logger(__name__)
db = next(get_db())
async def rag_config(state):
user_rag_memory_id = state.get('user_rag_memory_id', '')
kb_config = {
"knowledge_bases": [
{
"kb_id": user_rag_memory_id,
"similarity_threshold": 0.7,
"vector_similarity_weight": 0.5,
"top_k": 10,
"retrieve_type": "participle"
}
],
"merge_strategy": "weight",
"reranker_id": os.getenv('reranker_id'),
"reranker_top_k": 10
}
return kb_config
async def rag_knowledge(state, question):
kb_config = await rag_config(state)
group_id = state.get('group_id', '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(group_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
clean_content = '\n\n'.join(retrieval_knowledge)
cleaned_query = question
raw_results = clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except Exception:
retrieval_knowledge = []
clean_content = ''
raw_results = ''
cleaned_query = question
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
return retrieval_knowledge, clean_content, cleaned_query, raw_results
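# Contract of rag_knowledge, as its callers unpack it: a 4-tuple of
# (retrieval_knowledge: list[str], clean_content: str, cleaned_query: str, raw_results: str);
# on a retrieval error it degrades to ([], '', question, '').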
async def llm_information(state: ReadState) -> model_schema.ModelConfig:
memory_config = state.get('memory_config', None)
model_id = memory_config.llm_model_id
tenant_id = memory_config.tenant_id
# Use the existing memory_config where possible instead of re-querying the database,
# and use thread-safe database access for the model lookup
with get_db_context() as db:
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id)
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
return result_pydantic
async def clean_databases(data) -> str:
"""
Simplified cleaner for database search results.
Args:
data: search-result payload (JSON string or dict)
Returns:
Cleaned content as a plain-text string
"""
try:
# Parse the JSON string if needed
if isinstance(data, str):
try:
data = json.loads(data)
except json.JSONDecodeError:
return data
if not isinstance(data, dict):
return str(data)
# Get the result payload
results = data.get('results', data)
if not isinstance(results, dict):
return str(results)
# Collect all content
content_list = []
# Process reranked results
reranked = results.get('reranked_results', {})
if reranked:
for category in ['summaries', 'statements', 'chunks', 'entities']:
items = reranked.get(category, [])
if isinstance(items, list):
content_list.extend(items)
# Process time-search results
time_search = results.get('time_search', {})
if time_search:
if isinstance(time_search, dict):
statements = time_search.get('statements', time_search.get('time_search', []))
if isinstance(statements, list):
content_list.extend(statements)
elif isinstance(time_search, list):
content_list.extend(time_search)
# Extract text content
text_parts = []
for item in content_list:
if isinstance(item, dict):
text = item.get('statement') or item.get('content', '')
if text:
text_parts.append(text)
elif isinstance(item, str):
text_parts.append(item)
return '\n'.join(text_parts).strip()
except Exception as e:
logger.error(f"clean_databases failed: {e}", exc_info=True)
return str(data)
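# Quick self-check of the cleaner above (input shape assumed from the parsing logic):
#   sample = {"results": {"reranked_results": {"statements": [{"statement": "Alice likes tea"}]},
#                         "time_search": {"statements": [{"statement": "Met Bob on Monday"}]}}}
#   await clean_databases(sample)  # -> "Alice likes tea\nMet Bob on Monday"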
async def retrieve_nodes(state: ReadState) -> ReadState:
'''
Retrieve answers for each expanded question (non-agent variant).
'''
problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
group_id = state.get('group_id', '')
memory_config = state.get('memory_config', None)
original = state.get('data', '')
problem_list = []
for key, values in problem_extension.items():
for data in values:
problem_list.append(data)
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
# Async helper that processes a single question
async def process_question_nodes(idx, question):
try:
# Prepare search parameters based on storage type
search_params = {
"group_id": group_id,
"question": question,
"return_raw_results": True
}
if storage_type == "rag" and user_rag_memory_id:
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question)
else:
clean_content, cleaned_query, raw_results = await SearchService().execute_hybrid_search(
**search_params, memory_config=memory_config
)
return {
"Query_small": cleaned_query,
"Result_small": clean_content,
"_intermediate": {
"type": "search_result",
"query": cleaned_query,
"raw_results": raw_results,
"index": idx + 1,
"total": len(problem_list)
}
}
except Exception as e:
logger.error(
f"Retrieve: hybrid_search failed for question '{question}': {e}",
exc_info=True
)
# Return empty result for this question
return {
"Query_small": question,
"Result_small": "",
"_intermediate": {
"type": "search_result",
"query": question,
"raw_results": [],
"index": idx + 1,
"total": len(problem_list)
}
}
# Process all questions concurrently
tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)]
databases_answer = await asyncio.gather(*tasks)
databases_data = {
"Query": original,
"Expansion_issue": databases_anser
}
# Collect intermediate outputs before deduplication
intermediate_outputs = []
for item in databases_answer:
if '_intermediate' in item:
intermediate_outputs.append(item['_intermediate'])
# Deduplicate and merge results
deduplicated_data = deduplicate_entries(databases_data['Expansion_issue'])
deduplicated_data_merged = merge_to_key_value_pairs(
deduplicated_data,
'Query_small',
'Result_small'
)
# Restructure for Verify/Retrieve_Summary compatibility
keys, val = [], []
for item in deduplicated_data_merged:
for items_key, items_value in item.items():
keys.append(items_key)
val.append(items_value)
send_verify = []
for i, j in zip(keys, val, strict=False):
if j != ['']:
send_verify.append({
"Query_small": i,
"Answer_Small": j
})
dup_databases = {
"Query": original,
"Expansion_issue": send_verify,
"_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs
}
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
return {'retrieve': dup_databases}
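# Shape sketch (assumed from usage above): merge_to_key_value_pairs collapses
# [{"Query_small": "q1", "Result_small": "a"}, {"Query_small": "q1", "Result_small": "b"}]
# into [{"q1": ["a", "b"]}], which the loop above flattens into
# [{"Query_small": "q1", "Answer_Small": ["a", "b"]}] for Verify/Retrieve_Summary.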
async def retrieve(state: ReadState) -> ReadState:
# Read inputs from state
start = time.time()
problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
group_id = state.get('group_id', '')
memory_config = state.get('memory_config', None)
original = state.get('data', '')
problem_list = []
for key, values in problem_extension.items():
for data in values:
problem_list.append(data)
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
databases_answer = []
async def get_llm_info():
with get_db_context() as db:  # synchronous database context manager
config_service = MemoryConfigService(db)
return await llm_information(state)
llm_config = await get_llm_info()
api_key_obj = llm_config.api_keys[0]
api_key = api_key_obj.api_key
api_base = api_key_obj.api_base
model_name = api_key_obj.model_name
llm = ChatOpenAI(
model=model_name,
api_key=api_key,
base_url=api_base,
temperature=0.2,
)
time_retrieval_tool = create_time_retrieval_tool(group_id)
search_params = { "group_id": group_id, "return_raw_results": True }
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
agent = create_agent(
llm,
tools=[time_retrieval_tool,hybrid_retrieval],
system_prompt=f"我是检索专家可以根据适合的工具进行检索。当前使用的group_id是: {group_id}"
)
# Async helper that processes a single question
# Semaphore limiting concurrency (at most 5 concurrent database/agent operations)
SEMAPHORE = asyncio.Semaphore(5)
async def process_question(idx, question):
async with SEMAPHORE:  # limit concurrency
try:
if storage_type == "rag" and user_rag_memory_id:
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question)
else:
cleaned_query = question
# Run the synchronous agent.invoke in a thread pool
response = await asyncio.get_event_loop().run_in_executor(
None,
lambda: agent.invoke({"messages": question})
)
tool_results = extract_tool_message_content(response)
if tool_results is None:
raw_results = []
clean_content = ''
else:
raw_results = tool_results['content']
clean_content = await clean_databases(raw_results)
try:
raw_results = raw_results['results']
except Exception:
raw_results = []
return {
"Query_small": cleaned_query,
"Result_small": clean_content,
"_intermediate": {
"type": "search_result",
"query": cleaned_query,
"raw_results": raw_results,
"index": idx + 1,
"total": len(problem_list)
}
}
except Exception as e:
logger.error(
f"Retrieve: hybrid_search failed for question '{question}': {e}",
exc_info=True
)
# Return empty result for this question
return {
"Query_small": question,
"Result_small": "",
"_intermediate": {
"type": "search_result",
"query": question,
"raw_results": [],
"index": idx + 1,
"total": len(problem_list)
}
}
# Process all questions concurrently
tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)]
databases_answer = await asyncio.gather(*tasks)
databases_data = {
"Query": original,
"Expansion_issue": databases_anser
}
# Collect intermediate outputs before deduplication
intermediate_outputs = []
for item in databases_answer:
if '_intermediate' in item:
intermediate_outputs.append(item['_intermediate'])
# Deduplicate and merge results
deduplicated_data = deduplicate_entries(databases_data['Expansion_issue'])
deduplicated_data_merged = merge_to_key_value_pairs(
deduplicated_data,
'Query_small',
'Result_small'
)
# Restructure for Verify/Retrieve_Summary compatibility
keys, val = [], []
for item in deduplicated_data_merged:
for items_key, items_value in item.items():
keys.append(items_key)
val.append(items_value)
send_verify = []
for i, j in zip(keys, val, strict=False):
if j != ['']:
send_verify.append({
"Query_small": i,
"Answer_Small": j
})
dup_databases = {
"Query": original,
"Expansion_issue": send_verify,
"_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs
}
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
return {'retrieve': dup_databases}

View File

@@ -0,0 +1,303 @@
import time
from app.core.logging_config import get_agent_logger, log_time
from app.db import get_db
from app.core.memory.agent.models.summary_models import (
RetrieveSummaryResponse,
SummaryResponse,
)
from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_,
ReadState,
)
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
logger = get_agent_logger(__name__)
db_session = next(get_db())
class SummaryNodeService(LLMServiceMixin):
"""总结节点服务类"""
def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)
# Create a global service instance
summary_service = SummaryNodeService()
async def summary_history(state: ReadState) -> ReadState:
group_id = state.get("group_id", '')
history = await SessionService(store).get_history(group_id, group_id, group_id)
return history
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model, search_mode) -> str:
"""
Enhanced summary_llm with better error handling and data validation.
"""
data = state.get("data", '')
# Build the system prompt
if str(search_mode) == "0":
system_prompt = await summary_service.template_service.render_template(
template_name=template_name,
operation_name=operation_name,
data=retrieve_info,
query=data
)
else:
system_prompt = await summary_service.template_service.render_template(
template_name=template_name,
operation_name=operation_name,
query=data,
history=history,
retrieve_info=retrieve_info
)
try:
# Use the optimized LLM service for structured output
structured = await summary_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=response_model,
fallback_value=None
)
# Validate the structured response
if structured is None:
logger.warning("LLM returned None; using the default answer")
return "Insufficient information to answer"
# Extract the answer based on the operation type
if operation_name == "summary":
aimessages = getattr(structured, 'query_answer', None) or "Insufficient information to answer"
else:
# Handle RetrieveSummaryResponse
if hasattr(structured, 'data') and structured.data:
aimessages = getattr(structured.data, 'query_answer', None) or "Insufficient information to answer"
else:
logger.warning("Structured response is missing the data field")
aimessages = "Insufficient information to answer"
# Ensure the answer is not empty
if not aimessages or aimessages.strip() == "":
aimessages = "Insufficient information to answer"
return aimessages
except Exception as e:
logger.error(f"结构化输出失败: {e}", exc_info=True)
# 尝试非结构化输出作为fallback
try:
logger.info("尝试非结构化输出作为fallback")
response = await summary_service.call_llm_simple(
state=state,
db_session=db_session,
system_prompt=system_prompt,
fallback_message="信息不足,无法回答"
)
if response and response.strip():
# Lightly clean the response
cleaned_response = response.strip()
# Strip possible JSON code fences
if cleaned_response.startswith('```'):
lines = cleaned_response.split('\n')
cleaned_response = '\n'.join(lines[1:-1])
return cleaned_response
else:
return "Insufficient information to answer"
except Exception as fallback_error:
logger.error(f"Fallback also failed: {fallback_error}")
return "Insufficient information to answer"
async def summary_redis_save(state: ReadState, aimessages) -> None:
data = state.get("data", '')
group_id = state.get("group_id", '')
await SessionService(store).save_session(
user_id=group_id,
query=data,
apply_id=group_id,
group_id=group_id,
ai_response=aimessages
)
await SessionService(store).cleanup_duplicates()
logger.info(f"sessionid: {aimessages} 写入成功")
async def summary_prompt(state: ReadState, aimessages, raw_results) -> tuple:
storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
data = state.get("data", '')
input_summary = {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "input_summary",
"title": "快速答案",
"summary": aimessages,
"query": data,
"raw_results": raw_results,
"search_mode": "quick_search",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
retrieve = {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "retrieval_summary",
"title":"快速检索",
"summary": aimessages,
"query": data,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
return input_summary, retrieve
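# summary_prompt returns a pair: index 0 is the quick-answer (input_summary) payload and
# index 1 is the retrieval-summary payload; callers pick one via summary_result[0] or [1].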
async def Input_Summary(state: ReadState) -> ReadState:
start = time.time()
storage_type = state.get("storage_type", '')
memory_config = state.get('memory_config', None)
user_rag_memory_id = state.get("user_rag_memory_id", '')
data = state.get("data", '')
group_id = state.get("group_id", '')
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
history = await summary_history(state)
search_params = {
"group_id": group_id,
"question": data,
"return_raw_results": True
}
try:
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config)
except Exception as e:
logger.error(f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True)
retrieve_info, question, raw_results = "", data, []
try:
# aimessages = await summary_llm(state, history, retrieve_info, 'Retrieve_Summary_prompt.jinja2',
#                                'input_summary', RetrieveSummaryResponse)
# logger.info(f"Quick answer summary ==>> {storage_type}--{user_rag_memory_id}--{aimessages}")
summary_result = await summary_prompt(state, retrieve_info, retrieve_info)
summary = summary_result[0]
except Exception as e:
logger.error( f"Input_Summary failed: {e}", exc_info=True )
summary= {
"status": "fail",
"summary_result": "信息不足,无法回答",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}
duration = time.time() - start
log_time('Retrieval', duration)
return {"summary": summary}
async def Retrieve_Summary(state: ReadState) -> ReadState:
retrieve = state.get("retrieve", '')
history = await summary_history(state)
retrieve = retrieve.get("Expansion_issue", [])
start = time.time()
retrieve_info_str = []
for data in retrieve:
if data == '':
retrieve_info_str = ''
else:
for key, value in data.items():
if key == 'Answer_Small':
for i in value:
retrieve_info_str.append(i)
retrieve_info_str = list(set(retrieve_info_str))
retrieve_info_str = '\n'.join(retrieve_info_str)
aimessages = await summary_llm(state, history, retrieve_info_str,
'Retrieve_Summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
# Save to the session history only when the answer is informative and non-empty
if 'Insufficient information to answer' not in str(aimessages) and str(aimessages) != "":
await summary_redis_save(state, aimessages)
if aimessages == '':
aimessages = 'Insufficient information to answer'
logger.info(f"Summary after retrieval: {aimessages}")
duration = time.time() - start
log_time('Retrieval summary', duration)
# Await the coroutine first, then index into the returned pair
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1]
return {"summary":summary}
async def Summary(state: ReadState) -> ReadState:
start = time.time()
query = state.get("data", '')
verify = state.get("verify", '')
verify_expansion_issue = verify.get("verified_data", '')
retrieve_info_str = ''
for data in verify_expansion_issue:
for key, value in data.items():
if key == 'answer_small':
for i in value:
retrieve_info_str += i + '\n'
history = await summary_history(state)
data = {
"query": query,
"history": history,
"retrieve_info": retrieve_info_str
}
aimessages = await summary_llm(state, history, data,
'summary_prompt.jinja2', 'summary', SummaryResponse, 0)
# Save to the session history only when the answer is informative and non-empty
if 'Insufficient information to answer' not in str(aimessages) and str(aimessages) != "":
await summary_redis_save(state, aimessages)
if aimessages == '':
aimessages = 'Insufficient information to answer'
duration = time.time() - start
log_time('Summary', duration)
# Await the coroutine first, then index into the returned pair
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1]
return {"summary":summary}
async def Summary_fails(state: ReadState) -> ReadState:
storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
result = {
"status": "success",
"summary_result": "No relevant data",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
return {"summary": result}

View File

@@ -1,234 +0,0 @@
"""
Tool execution node for LangGraph workflow.
This module provides the ToolExecutionNode class which wraps tool execution
with parameter transformation logic using the ParameterBuilder service.
"""
import logging
import time
from typing import Any, Callable, Dict
from app.core.memory.agent.langgraph_graph.state.extractors import (
extract_content_payload,
extract_tool_call_id,
)
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
from app.schemas.memory_config_schema import MemoryConfig
from langchain_core.messages import AIMessage
from langgraph.prebuilt import ToolNode
logger = logging.getLogger(__name__)
class ToolExecutionNode:
"""
Custom LangGraph node that wraps tool execution with parameter transformation.
This node extracts content from previous tool results, transforms parameters
based on tool type using ParameterBuilder, and invokes the tool with the
correct argument structure.
Attributes:
tool_node: LangGraph ToolNode wrapping the actual tool
id: Node identifier for message IDs
tool_name: Name of the tool being executed
namespace: Namespace for session management
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
parameter_builder: Service for building tool-specific arguments
memory_config: MemoryConfig object containing all configuration
"""
def __init__(
self,
tool: Callable,
node_id: str,
namespace: str,
search_switch: str,
apply_id: str,
group_id: str,
parameter_builder: ParameterBuilder,
storage_type: str,
user_rag_memory_id: str,
memory_config: MemoryConfig,
):
"""
Initialize the tool execution node.
Args:
tool: The tool function to execute
node_id: Identifier for this node (used in message IDs)
namespace: Namespace for session management
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
parameter_builder: Service for building tool-specific arguments
storage_type: Storage type for the workspace
user_rag_memory_id: User RAG memory identifier
memory_config: MemoryConfig object containing all configuration
"""
self.tool_node = ToolNode([tool])
self.id = node_id
self.tool_name = tool.name if hasattr(tool, 'name') else str(tool)
self.namespace = namespace
self.search_switch = search_switch
self.apply_id = apply_id
self.group_id = group_id
self.parameter_builder = parameter_builder
self.storage_type = storage_type
self.user_rag_memory_id = user_rag_memory_id
self.memory_config = memory_config
logger.info(
f"[ToolExecutionNode] Initialized node '{self.id}' for tool '{self.tool_name}'"
)
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
"""
Execute the tool with transformed parameters.
This method:
1. Extracts the last message from state
2. Extracts tool call ID using state extractors
3. Extracts content payload using state extractors
4. Builds tool arguments using parameter builder
5. Constructs AIMessage with tool_calls
6. Invokes the tool and returns the result
Args:
state: LangGraph state dictionary
Returns:
Updated state with tool result in messages
"""
messages = state.get("messages", [])
logger.debug( self.tool_name)
if not messages:
logger.warning(f"[ToolExecutionNode] {self.id} - No messages in state")
return {"messages": [AIMessage(content="Error: No messages in state")]}
last_message = messages[-1]
logger.debug(
f"[ToolExecutionNode] {self.id} - Processing message at {time.time()}"
)
try:
# Extract tool call ID using state extractors
tool_call_id = extract_tool_call_id(last_message)
logger.debug(f"[ToolExecutionNode] {self.id} - Extracted tool_call_id: {tool_call_id}")
except ValueError as e:
logger.error(
f"[ToolExecutionNode] {self.id} - Failed to extract tool call ID: {e}"
)
return {"messages": [AIMessage(content=f"Error: {str(e)}")]}
try:
# Extract content payload using state extractors
content = extract_content_payload(last_message)
logger.debug(
f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}, content_keys: {list(content.keys()) if isinstance(content, dict) else 'N/A'}"
)
# Log raw message content for debugging
if hasattr(last_message, 'content'):
raw = last_message.content
logger.debug(f"[ToolExecutionNode] {self.id} - Raw message content (first 500 chars): {str(raw)[:500]}")
except Exception as e:
logger.error(
f"[ToolExecutionNode] {self.id} - Failed to extract content: {e}",
exc_info=True
)
content = {}
try:
# Build tool arguments using parameter builder
tool_args = self.parameter_builder.build_tool_args(
tool_name=self.tool_name,
content=content,
tool_call_id=tool_call_id,
search_switch=self.search_switch,
apply_id=self.apply_id,
group_id=self.group_id,
memory_config=self.memory_config,
storage_type=self.storage_type,
user_rag_memory_id=self.user_rag_memory_id,
)
logger.debug(
f"[ToolExecutionNode] {self.id} - Built tool args with keys: {list(tool_args.keys())}"
)
except Exception as e:
logger.error(
f"[ToolExecutionNode] {self.id} - Failed to build tool args: {e}",
exc_info=True
)
return {"messages": [AIMessage(content=f"Error building arguments: {str(e)}")]}
# Construct tool input message
tool_input = {
"messages": [
AIMessage(
content="",
tool_calls=[{
"name": self.tool_name,
"args": tool_args,
"id": f"{self.id}_{tool_call_id}",
}]
)
]
}
try:
# Invoke the tool
result = await self.tool_node.ainvoke(tool_input)
logger.debug(
f"[ToolExecutionNode] {self.id} - Tool execution completed"
)
# Check for error in tool response
error_entry = None
if result and "messages" in result:
for msg in result["messages"]:
if hasattr(msg, 'content'):
try:
import json
content = msg.content
if isinstance(content, str):
parsed = json.loads(content)
if isinstance(parsed, dict) and "error" in parsed:
error_msg = parsed["error"]
logger.warning(
f"[ToolExecutionNode] {self.id} - Tool returned error: {error_msg}"
)
error_entry = {"tool": self.tool_name, "error": error_msg, "node_id": self.id}
except (json.JSONDecodeError, TypeError):
pass
# Return result with error tracking if error was found
if error_entry:
result["errors"] = [error_entry]
return result
except Exception as e:
logger.error(
f"[ToolExecutionNode] {self.id} - Tool execution failed: {e}",
exc_info=True
)
# Track error in state and return error message
from langchain_core.messages import ToolMessage
error_entry = {"tool": self.tool_name, "error": str(e), "node_id": self.id}
return {
"messages": [
ToolMessage(
content=f"Error executing tool: {str(e)}",
tool_call_id=f"{self.id}_{tool_call_id}"
)
],
"errors": [error_entry]
}

View File

@@ -0,0 +1,85 @@
from app.core.logging_config import get_agent_logger
from app.db import get_db
from app.core.memory.agent.models.verification_models import VerificationResult
from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_,
ReadState,
)
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
db_session = next(get_db())
logger = get_agent_logger(__name__)
class VerificationNodeService(LLMServiceMixin):
"""验证节点服务类"""
def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)
# Create a global service instance
verification_service = VerificationNodeService()
async def Verify_prompt(state: ReadState, messages_deal):
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
data = state.get('data', '')
Verify_result = {
"status": messages_deal.split_result,
"verified_data": messages_deal.expansion_issue,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "verification",
"title": "Data Verification",
"result": messages_deal.split_result,
"reason": messages_deal.reason,
"query": data,
"verified_count": len(messages_deal.expansion_issue),
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
return Verify_result
async def Verify(state: ReadState):
content = state.get('data', '')
group_id = state.get('group_id', '')
memory_config = state.get('memory_config', None)
history = await SessionService(store).get_history(group_id, group_id, group_id)
retrieve = state.get("retrieve", '')
retrieve = retrieve.get("Expansion_issue", [])
messages = {
"Query": content,
"Expansion_issue": retrieve
}
system_prompt = await verification_service.template_service.render_template(
template_name='split_verify_prompt.jinja2',
operation_name='split_verify_prompt',
history=history,
sentence=messages
)
# Use the optimized LLM service
structured = await verification_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=VerificationResult,
fallback_value={
"split_result": "fail",
"expansion_issue": [],
"reason": "验证失败"
}
)
result = await Verify_prompt(state, structured)
return {"verify": result}

View File

@@ -0,0 +1,50 @@
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.utils.write_tools import write
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
async def write_node(state: WriteState) -> WriteState:
"""
Write data to the database/file system.
Args:
state: WriteState containing 'data' (the content to write), 'group_id',
and 'memory_config' (a MemoryConfig object)
Returns:
dict: {"write_result": ...} with 'status' and related fields
"""
content = state.get('data', '')
group_id = state.get('group_id', '')
memory_config = state.get('memory_config', '')
try:
result = await write(
content=content,
user_id=group_id,
apply_id=group_id,
group_id=group_id,
memory_config=memory_config,
)
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
write_result = {
"status": "success",
"data": content,
"config_id": memory_config.config_id,
"config_name": memory_config.config_name,
}
return {"write_result": write_result}
except Exception as e:
logger.error(f"Data_write failed: {e}", exc_info=True)
write_result = {
"status": "error",
"message": str(e),
}
return {"write_result": write_result}

View File

@@ -1,469 +1,177 @@
import json
import os
import re
import time
import warnings
#!/usr/bin/env python3
from contextlib import asynccontextmanager
from typing import Literal
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.nodes import (
ToolExecutionNode,
create_input_message,
)
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
from app.core.memory.agent.utils.llm_tools import COUNTState, ReadState
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
from app.schemas.memory_config_schema import MemoryConfig
from dotenv import load_dotenv
from langchain_core.messages import AIMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.constants import END, START
from langchain_core.messages import HumanMessage
from langgraph.constants import START, END
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
logger = get_agent_logger(__name__)
warnings.filterwarnings("ignore", category=RuntimeWarning)
load_dotenv()
redishost=os.getenv("REDISHOST")
redisport=os.getenv('REDISPORT')
redisdb=os.getenv('REDISDB')
redispassword=os.getenv('REDISPASSWORD')
counter = COUNTState(limit=3)
# Update loop count in workflow
async def update_loop_count(state):
"""Update loop counter"""
current_count = state.get("loop_count", 0)
return {"loop_count": current_count + 1}
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
messages = state["messages"]
from app.db import get_db
from app.services.memory_config_service import MemoryConfigService
# Add boundary check
if not messages:
return END
counter.add(1) # Increment by 1
from app.core.memory.agent.utils.llm_tools import ReadState
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_node
from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
Split_The_Problem,
Problem_Extension,
)
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
retrieve,
)
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
Input_Summary,
Retrieve_Summary,
Summary_fails,
Summary,
)
from app.core.memory.agent.langgraph_graph.nodes.verification_nodes import Verify
from app.core.memory.agent.langgraph_graph.routing.routers import (
Split_continue,
Retrieve_continue,
Verify_continue,
)
loop_count = counter.get_total()
logger.debug(f"[should_continue] Current loop count: {loop_count}")
last_message = messages[-1]
last_message_str = str(last_message).replace('\\', '')
status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str)
logger.debug(f"Status tools: {status_tools}")
if "success" in status_tools:
counter.reset()
return "Summary"
elif "failed" in status_tools:
if loop_count < 2: # Maximum loop count is 3
return "content_input"
else:
counter.reset()
return "Summary_fails"
else:
# Add default return value to avoid returning None
counter.reset()
return "Summary" # Default based on business requirements
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
"""
Determine routing based on search_switch value.
Args:
state: State dictionary containing search_switch
Returns:
Next node to execute
"""
# Direct dictionary access instead of regex parsing
search_switch = state.get("search_switch")
# Handle case where search_switch might be in messages
if search_switch is None and "messages" in state:
messages = state.get("messages", [])
if messages:
last_message = messages[-1]
# Try to extract from tool_calls args
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
for tool_call in last_message.tool_calls:
if isinstance(tool_call, dict) and "args" in tool_call:
search_switch = tool_call["args"].get("search_switch")
break
# Convert to string for comparison if needed
if search_switch is not None:
search_switch = str(search_switch)
if search_switch == '0':
return 'Verify'
elif search_switch == '1':
return 'Retrieve_Summary'
# Add default return value to avoid returning None
return 'Retrieve_Summary' # Default based on business logic
def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]:
"""
Determine routing based on search_switch value.
Args:
state: State dictionary containing search_switch
Returns:
Next node to execute
"""
logger.debug(f"Split_continue state: {state}")
# Direct dictionary access instead of regex parsing
search_switch = state.get("search_switch")
# Handle case where search_switch might be in messages
if search_switch is None and "messages" in state:
messages = state.get("messages", [])
if messages:
last_message = messages[-1]
# Try to extract from tool_calls args
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
for tool_call in last_message.tool_calls:
if isinstance(tool_call, dict) and "args" in tool_call:
search_switch = tool_call["args"].get("search_switch")
break
# Convert to string for comparison if needed
if search_switch is not None:
search_switch = str(search_switch)
if search_switch == '2':
return 'Input_Summary'
return 'Split_The_Problem' # Default case
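# Routing summary (from the routers above): search_switch '2' -> Input_Summary (quick answer);
# after Retrieve, '0' -> Verify and '1' -> Retrieve_Summary; Verify_continue then routes
# "success" -> Summary, "failed" -> content_input (retry, up to the loop limit) or Summary_fails.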
class ProblemExtensionNode:
def __init__(self, tool, id, namespace, search_switch, apply_id, group_id, storage_type="", user_rag_memory_id=""):
self.tool_node = ToolNode([tool])
self.id = id
self.tool_name = tool.name if hasattr(tool, 'name') else str(tool)
self.namespace = namespace
self.search_switch = search_switch
self.apply_id = apply_id
self.group_id = group_id
self.storage_type = storage_type
self.user_rag_memory_id = user_rag_memory_id
async def __call__(self, state):
messages = state["messages"]
last_message = messages[-1] if messages else ""
logger.debug(f"ProblemExtensionNode {self.id} - Current time: {time.time()} - Message: {last_message}")
if self.tool_name == 'Input_Summary':
tool_call = re.findall("'id': '(.*?)'", str(last_message))[0]
else:
tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1]
# Try to extract actual content payload from previous tool result
raw_msg = last_message.content if hasattr(last_message, 'content') else str(last_message)
extracted_payload = None
# Capture ToolMessage content field (supports single/double quotes), avoid greedy matching
m = re.search(r"content=(?:\"|\')(.*?)(?:\"|\'),\s*name=", raw_msg, flags=re.S)
if m:
extracted_payload = m.group(1)
else:
# Fallback: use raw string directly
extracted_payload = raw_msg
# Try to parse content as JSON first
try:
content = json.loads(extracted_payload)
except Exception:
# Try to extract JSON fragment from text and parse
parsed = None
candidates = re.findall(r"[\[{].*[\]}]", extracted_payload, flags=re.S)
for cand in candidates:
try:
parsed = json.loads(cand)
break
except Exception:
continue
# If still fails, use raw string as content
content = parsed if parsed is not None else extracted_payload
# Build correct parameters based on tool name
tool_args = {}
if self.tool_name == "Verify":
# Verify tool requires context and usermessages parameters
if isinstance(content, dict):
tool_args["context"] = content
else:
tool_args["context"] = {"content": content}
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == "Retrieve":
# Retrieve tool requires context and usermessages parameters
if isinstance(content, dict):
tool_args["context"] = content
else:
tool_args["context"] = {"content": content}
tool_args["usermessages"] = str(tool_call)
tool_args["search_switch"] = str(self.search_switch)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == "Summary":
# Summary tool requires string type context parameter
if isinstance(content, dict):
# Convert dict to JSON string
tool_args["context"] = json.dumps(content, ensure_ascii=False)
else:
tool_args["context"] = str(content)
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == "Summary_fails":
# Summary_fails tool requires string type context parameter
if isinstance(content, dict):
# Convert dict to JSON string
tool_args["context"] = json.dumps(content, ensure_ascii=False)
else:
tool_args["context"] = str(content)
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == 'Input_Summary':
tool_args["context"] = str(last_message)
tool_args["usermessages"] = str(tool_call)
tool_args["search_switch"] = str(self.search_switch)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
tool_args["storage_type"] = getattr(self, 'storage_type', "")
tool_args["user_rag_memory_id"] = getattr(self, 'user_rag_memory_id', "")
elif self.tool_name == 'Retrieve_Summary':
# Retrieve_Summary expects dict directly, not JSON string
# content might be a JSON string, try to parse it
if isinstance(content, str):
try:
parsed_content = json.loads(content)
# Check if it has a "context" key
if isinstance(parsed_content, dict) and "context" in parsed_content:
tool_args["context"] = parsed_content["context"]
else:
tool_args["context"] = parsed_content
except json.JSONDecodeError:
# If parsing fails, wrap the string
tool_args["context"] = {"content": content}
elif isinstance(content, dict):
# Check if content has a "context" key that needs unwrapping
if "context" in content:
tool_args["context"] = content["context"]
else:
tool_args["context"] = content
else:
tool_args["context"] = {"content": str(content)}
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
else:
# Other tools use context parameter
if isinstance(content, dict):
tool_args["context"] = content
else:
tool_args["context"] = {"content": content}
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
tool_input = {
"messages": [
AIMessage(
content="",
tool_calls=[{
"name": self.tool_name,
"args": tool_args,
"id": self.id + f"{tool_call}",
}]
)
]
}
result = await self.tool_node.ainvoke(tool_input)
result_text = str(result)
return {"messages": [AIMessage(content=result_text)]}
@asynccontextmanager
async def make_read_graph(namespace, tools, search_switch, apply_id, group_id, memory_config: MemoryConfig, storage_type=None, user_rag_memory_id=None):
"""
Create a read graph workflow for memory operations.
Args:
namespace: Namespace identifier
tools: MCP tools loaded from session
search_switch: Search mode switch ("0", "1", or "2")
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type (optional)
user_rag_memory_id: User RAG memory ID (optional)
"""
memory = InMemorySaver()
tool = [i.name for i in tools]
logger.info(f"Initializing read graph with tools: {tool}")
logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})")
# Extract tool functions
Split_The_Problem_ = next((t for t in tools if t.name == "Split_The_Problem"), None)
Problem_Extension_ = next((t for t in tools if t.name == "Problem_Extension"), None)
Retrieve_ = next((t for t in tools if t.name == "Retrieve"), None)
Verify_ = next((t for t in tools if t.name == "Verify"), None)
Summary_ = next((t for t in tools if t.name == "Summary"), None)
Summary_fails_ = next((t for t in tools if t.name == "Summary_fails"), None)
Retrieve_Summary_ = next((t for t in tools if t.name == "Retrieve_Summary"), None)
Input_Summary_ = next((t for t in tools if t.name == "Input_Summary"), None)
# Instantiate services
parameter_builder = ParameterBuilder()
multimodal_processor = MultimodalProcessor()
# Create nodes using new modular components
Split_The_Problem_node = ToolNode([Split_The_Problem_])
Problem_Extension_node = ToolExecutionNode(
tool=Problem_Extension_,
node_id="Problem_Extension_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
    )
    Retrieve_node = ToolExecutionNode(
        tool=Retrieve_,
        node_id="Retrieve_id",
        namespace=namespace,
        search_switch=search_switch,
        apply_id=apply_id,
        group_id=group_id,
        parameter_builder=parameter_builder,
        storage_type=storage_type,
        user_rag_memory_id=user_rag_memory_id,
        memory_config=memory_config,
    )
    Verify_node = ToolExecutionNode(
        tool=Verify_,
        node_id="Verify_id",
        namespace=namespace,
        search_switch=search_switch,
        apply_id=apply_id,
        group_id=group_id,
        parameter_builder=parameter_builder,
        storage_type=storage_type,
        user_rag_memory_id=user_rag_memory_id,
        memory_config=memory_config,
    )
    Summary_node = ToolExecutionNode(
        tool=Summary_,
        node_id="Summary_id",
        namespace=namespace,
        search_switch=search_switch,
        apply_id=apply_id,
        group_id=group_id,
        parameter_builder=parameter_builder,
        storage_type=storage_type,
        user_rag_memory_id=user_rag_memory_id,
        memory_config=memory_config,
    )
    Summary_fails_node = ToolExecutionNode(
        tool=Summary_fails_,
        node_id="Summary_fails_id",
        namespace=namespace,
        search_switch=search_switch,
        apply_id=apply_id,
        group_id=group_id,
        parameter_builder=parameter_builder,
        storage_type=storage_type,
        user_rag_memory_id=user_rag_memory_id,
        memory_config=memory_config,
    )
    Retrieve_Summary_node = ToolExecutionNode(
        tool=Retrieve_Summary_,
        node_id="Retrieve_Summary_id",
        namespace=namespace,
        search_switch=search_switch,
        apply_id=apply_id,
        group_id=group_id,
        parameter_builder=parameter_builder,
        storage_type=storage_type,
        user_rag_memory_id=user_rag_memory_id,
        memory_config=memory_config,
    )
    Input_Summary_node = ToolExecutionNode(
        tool=Input_Summary_,
        node_id="Input_Summary_id",
        namespace=namespace,
        search_switch=search_switch,
        apply_id=apply_id,
        group_id=group_id,
        parameter_builder=parameter_builder,
        storage_type=storage_type,
        user_rag_memory_id=user_rag_memory_id,
        memory_config=memory_config,
    )

    async def content_input_node(state):
        state_search_switch = state.get("search_switch", search_switch)
        tool_name = "Input_Summary" if state_search_switch == '2' else "Split_The_Problem"
        session_prefix = "input_summary_call_id" if state_search_switch == '2' else "split_call_id"
        return await create_input_message(
            state=state,
            tool_name=tool_name,
            session_id=f"{session_prefix}_{namespace}",
            search_switch=search_switch,
            apply_id=apply_id,
            group_id=group_id,
            multimodal_processor=multimodal_processor,
            memory_config=memory_config,
        )

    # Build workflow graph
    workflow = StateGraph(ReadState)
    workflow.add_node("content_input", content_input_node)
    workflow.add_node("Split_The_Problem", Split_The_Problem_node)
    workflow.add_node("Problem_Extension", Problem_Extension_node)
    workflow.add_node("Retrieve", Retrieve_node)
    workflow.add_node("Verify", Verify_node)
    workflow.add_node("Summary", Summary_node)
    workflow.add_node("Summary_fails", Summary_fails_node)
    workflow.add_node("Retrieve_Summary", Retrieve_Summary_node)
    workflow.add_node("Input_Summary", Input_Summary_node)
    # Add edges using imported routers
    workflow.add_edge(START, "content_input")
    workflow.add_conditional_edges("content_input", Split_continue)
    workflow.add_edge("Input_Summary", END)
    workflow.add_edge("Split_The_Problem", "Problem_Extension")
    workflow.add_edge("Problem_Extension", "Retrieve")
    workflow.add_conditional_edges("Retrieve", Retrieve_continue)
    workflow.add_edge("Retrieve_Summary", END)
    workflow.add_conditional_edges("Verify", Verify_continue)
    workflow.add_edge("Summary_fails", END)
    workflow.add_edge("Summary", END)
    graph = workflow.compile(checkpointer=memory)
    yield graph


@asynccontextmanager
async def make_read_graph():
    """Create and yield the LangGraph read workflow."""
    try:
        # Build workflow graph
        workflow = StateGraph(ReadState)
        workflow.add_node("content_input", content_input_node)
        workflow.add_node("Split_The_Problem", Split_The_Problem)
        workflow.add_node("Problem_Extension", Problem_Extension)
        workflow.add_node("Input_Summary", Input_Summary)
        # workflow.add_node("Retrieve", retrieve_nodes)
        workflow.add_node("Retrieve", retrieve)
        workflow.add_node("Verify", Verify)
        workflow.add_node("Retrieve_Summary", Retrieve_Summary)
        workflow.add_node("Summary", Summary)
        workflow.add_node("Summary_fails", Summary_fails)
        # Add edges
        workflow.add_edge(START, "content_input")
        workflow.add_conditional_edges("content_input", Split_continue)
        workflow.add_edge("Input_Summary", END)
        workflow.add_edge("Split_The_Problem", "Problem_Extension")
        workflow.add_edge("Problem_Extension", "Retrieve")
        workflow.add_conditional_edges("Retrieve", Retrieve_continue)
        workflow.add_edge("Retrieve_Summary", END)
        workflow.add_conditional_edges("Verify", Verify_continue)
        workflow.add_edge("Summary_fails", END)
        workflow.add_edge("Summary", END)
        # workflow.add_edge("Retrieve", END)
        # Compile the workflow
        graph = workflow.compile()
        yield graph
    except Exception as e:
        print(f"Failed to create workflow: {e}")
        raise
    finally:
        print("Workflow creation finished")


async def main():
    """Main entry point - run the workflow."""
    message = "昨天有什么好看的电影"
    group_id = '88a459f5_text09'  # group ID
    storage_type = 'neo4j'  # storage type
    search_switch = '1'  # search mode switch
    user_rag_memory_id = 'wwwwwwww'  # user RAG memory ID
    # Acquire a database session
    db_session = next(get_db())
    config_service = MemoryConfigService(db_session)
    memory_config = config_service.load_memory_config(
        config_id=17,  # integer config ID
        service_name="MemoryAgentService"
    )
    import time
    start = time.time()
    try:
        async with make_read_graph() as graph:
            config = {"configurable": {"thread_id": group_id}}
            # Initial state - must contain every required field
            initial_state = {
                "messages": [HumanMessage(content=message)],
                "search_switch": search_switch,
                "group_id": group_id,
                "storage_type": storage_type,
                "user_rag_memory_id": user_rag_memory_id,
                "memory_config": memory_config,
            }
            # Collect per-node update events
            _intermediate_outputs = []
            summary = ''
            async for update_event in graph.astream(
                initial_state,
                stream_mode="updates",
                config=config
            ):
                for node_name, node_data in update_event.items():
                    print(f"Processing node: {node_name}")
                    # Handle the return structures of the different Summary nodes
                    if 'Summary' in node_name:
                        if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
                            summary = node_data['InputSummary']['summary_result']
                        elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']:
                            summary = node_data['RetrieveSummary']['summary_result']
                        elif 'summary' in node_data and 'summary_result' in node_data['summary']:
                            summary = node_data['summary']['summary_result']
                        elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']:
                            summary = node_data['SummaryFails']['summary_result']
                    # Split_The_Problem node
                    spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
                    if spit_data:
                        _intermediate_outputs.append(spit_data)
                    # Problem_Extension node
                    problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
                    if problem_extension:
                        _intermediate_outputs.append(problem_extension)
                    # Retrieve node
                    retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
                    if retrieve_node:
                        _intermediate_outputs.extend(retrieve_node)
                    # Verify node
                    verify_n = node_data.get('verify', {}).get('_intermediate', None)
                    if verify_n:
                        _intermediate_outputs.append(verify_n)
                    # Summary node
                    summary_n = node_data.get('summary', {}).get('_intermediate', None)
                    if summary_n:
                        _intermediate_outputs.append(summary_n)
            # # Filter out empty values
            # _intermediate_outputs = [item for item in _intermediate_outputs if item]
            #
            # # Optimize the search results
            # print("=== Optimizing search results ===")
            # optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
            # result = reorder_output_results(optimized_outputs)
            # # Save the optimized results to a file
            # with open('_intermediate_outputs_optimized.json', 'w', encoding='utf-8') as f:
            #     import json
            #     f.write(json.dumps(result, indent=4, ensure_ascii=False))
            #
            print("=== Final summary ===")
            print(summary)
    except Exception as e:
        import traceback
        traceback.print_exc()
    end = time.time()
    print(100 * 'y')
    print(f"Total elapsed time: {end - start}s")
    print(100 * 'y')
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View File

@@ -1,13 +0,0 @@
"""LangGraph routing logic."""
from app.core.memory.agent.langgraph_graph.routing.routers import (
Verify_continue,
Retrieve_continue,
Split_continue,
)
__all__ = [
"Verify_continue",
"Retrieve_continue",
"Split_continue",
]

View File

@@ -1,123 +1,62 @@
"""
Routing functions for LangGraph conditional edges.
This module provides routing functions that determine the next node to execute
based on state values. All functions return Literal types for type safety.
"""
import logging
import re
from typing import Literal

from app.core.memory.agent.langgraph_graph.state.extractors import extract_search_switch
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState

logger = logging.getLogger(__name__)

# Global counter for Verify routing
counter = COUNTState(limit=3)


def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
    """
    Determine routing after Verify node based on verification result.
    This function checks the verification result in the last message and routes to:
    - Summary: if verification succeeded
    - content_input: if verification failed and retry limit not reached
    - Summary_fails: if verification failed and retry limit reached
    Args:
        state: LangGraph state containing messages
    Returns:
        Next node name as Literal type
    """
    messages = state.get("messages", [])
    # Boundary check
    if not messages:
        logger.warning("[Verify_continue] No messages in state, defaulting to Summary")
        counter.reset()
        return "Summary"
    # Increment counter
    counter.add(1)
    loop_count = counter.get_total()
    logger.debug(f"[Verify_continue] Current loop count: {loop_count}")
    # Extract verification result from last message
    last_message = messages[-1]
    last_message_str = str(last_message).replace('\\', '')
    status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str)
    logger.debug(f"[Verify_continue] Status tools: {status_tools}")
    # Route based on verification result
    if "success" in status_tools:
        counter.reset()
        return "Summary"
    elif "failed" in status_tools:
        if loop_count < 2:  # Max retry count is 2
            return "content_input"
        else:
            counter.reset()
            return "Summary_fails"
    else:
        # Default to Summary if status is unclear
        counter.reset()
        return "Summary"


def Retrieve_continue(state: dict) -> Literal["Verify", "Retrieve_Summary"]:
    """
    Determine routing after Retrieve node based on search_switch value.
    This function routes based on the search_switch parameter:
    - search_switch == '0': Route to Verify (verification needed)
    - search_switch == '1': Route to Retrieve_Summary (direct summary)
    Args:
        state: LangGraph state dictionary
    Returns:
        Next node name as Literal type
    """
    search_switch = extract_search_switch(state)
    logger.debug(f"[Retrieve_continue] search_switch: {search_switch}")
    if search_switch == '0':
        return 'Verify'
    elif search_switch == '1':
        return 'Retrieve_Summary'
    # Default to Retrieve_Summary
    logger.debug("[Retrieve_continue] No valid search_switch, defaulting to Retrieve_Summary")
    return 'Retrieve_Summary'


def Split_continue(state: dict) -> Literal["Split_The_Problem", "Input_Summary"]:
    """
    Determine routing after content_input node based on search_switch value.
    This function routes based on the search_switch parameter:
    - search_switch == '2': Route to Input_Summary (direct input summary)
    - Otherwise: Route to Split_The_Problem (problem decomposition)
    Args:
        state: LangGraph state dictionary
    Returns:
        Next node name as Literal type
    """
    logger.debug(f"[Split_continue] state keys: {state.keys()}")
    search_switch = extract_search_switch(state)
    logger.debug(f"[Split_continue] search_switch: {search_switch}")
    if search_switch == '2':
        return 'Input_Summary'
    # Default to Split_The_Problem
    return 'Split_The_Problem'


from typing import Literal

from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState

logger = get_agent_logger(__name__)
counter = COUNTState(limit=3)


def Split_continue(state: ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
    """
    Determine routing based on search_switch value.
    Args:
        state: State dictionary containing search_switch
    Returns:
        Next node to execute
    """
    logger.debug(f"Split_continue state: {state}")
    search_switch = state.get('search_switch', '')
    if search_switch is not None:
        search_switch = str(search_switch)
    if search_switch == '2':
        return 'Input_Summary'
    return 'Split_The_Problem'  # default case


def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
    """
    Determine routing based on search_switch value.
    Args:
        state: State dictionary containing search_switch
    Returns:
        Next node to execute
    """
    search_switch = state.get('search_switch', '')
    if search_switch is not None:
        search_switch = str(search_switch)
    if search_switch == '0':
        return 'Verify'
    elif search_switch == '1':
        return 'Retrieve_Summary'
    return 'Retrieve_Summary'  # Default based on business logic


def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
    counter.add(1)
    status = state.get('verify', {}).get('status', '')
    loop_count = counter.get_total()
    logger.debug(f"[Verify_continue] Current loop count: {loop_count}")
    logger.debug(f"[Verify_continue] status: {status}")
    if "success" in status:
        counter.reset()
        return "Summary"
    elif "failed" in status:
        if loop_count < 2:  # Max retry count is 2
            return "content_input"
        else:
            counter.reset()
            return "Summary_fails"
# else:
# # Add default return value to avoid returning None
# counter.reset()
# return "Summary" # Default based on business requirements
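
# NOTE: `counter` is module-level state, so the verify retry count is shared by
# every graph run in this process; concurrent sessions can advance or reset
# each other's count.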

View File

@@ -1,13 +0,0 @@
"""LangGraph state management utilities."""
from app.core.memory.agent.langgraph_graph.state.extractors import (
extract_search_switch,
extract_tool_call_id,
extract_content_payload,
)
__all__ = [
"extract_search_switch",
"extract_tool_call_id",
"extract_content_payload",
]

View File

@@ -1,179 +0,0 @@
"""
State extraction utilities for type-safe access to LangGraph state values.
This module provides utility functions for extracting values from LangGraph state
dictionaries with proper error handling and sensible defaults.
"""
import json
import logging
from typing import Any, Optional
logger = logging.getLogger(__name__)
def extract_search_switch(state: dict) -> Optional[str]:
"""
    Extract search_switch from state or messages.

    Checks the top-level "search_switch" key first, then scans messages from
    newest to oldest (tool_call args, then JSON content). Returns the value
    as a string, or None if it cannot be found.
    """
search_switch = state.get("search_switch")
if search_switch is not None:
return str(search_switch)
# Try to extract from messages
messages = state.get("messages", [])
if not messages:
return None
    # Search from the newest message backwards
    for message in reversed(messages):
        # Try to extract from tool_calls
        if hasattr(message, "tool_calls") and message.tool_calls:
            for tool_call in message.tool_calls:
                if isinstance(tool_call, dict):
                    # From the tool_call's args
                    if "args" in tool_call and isinstance(tool_call["args"], dict):
                        search_switch = tool_call["args"].get("search_switch")
                        if search_switch is not None:
                            return str(search_switch)
                    # Directly from the tool_call
                    search_switch = tool_call.get("search_switch")
                    if search_switch is not None:
                        return str(search_switch)
        # Try to extract from content (if it is JSON)
        if hasattr(message, "content"):
            try:
                if isinstance(message.content, str):
                    content_data = json.loads(message.content)
                    if isinstance(content_data, dict):
                        search_switch = content_data.get("search_switch")
                        if search_switch is not None:
                            return str(search_switch)
            except (json.JSONDecodeError, ValueError):
                pass
return None
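
# Usage sketch: routing functions call this on the LangGraph state, e.g.
#   extract_search_switch({"search_switch": 1}) returns '1';
#   if the key is absent the value is mined from the newest message's
#   tool_calls or JSON content, and None is returned when nothing matches.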
def extract_tool_call_id(message: Any) -> str:
"""
Extract tool call ID from message using structured attributes.
This function extracts the tool call ID from a message object, handling both
direct attribute access and tool_calls list structures.
Args:
message: Message object (typically ToolMessage or AIMessage)
Returns:
Tool call ID as string
Raises:
ValueError: If tool call ID cannot be extracted
Examples:
>>> message = ToolMessage(content="...", tool_call_id="call_123")
>>> extract_tool_call_id(message)
'call_123'
"""
# Try direct attribute access for ToolMessage
if hasattr(message, "tool_call_id"):
tool_call_id = message.tool_call_id
if tool_call_id:
return str(tool_call_id)
# Try extracting from tool_calls list for AIMessage
if hasattr(message, "tool_calls") and message.tool_calls:
tool_call = message.tool_calls[0]
if isinstance(tool_call, dict) and "id" in tool_call:
return str(tool_call["id"])
# Try extracting from id attribute
if hasattr(message, "id"):
message_id = message.id
if message_id:
return str(message_id)
# If all else fails, raise an error
raise ValueError(f"Could not extract tool call ID from message: {type(message)}")
def extract_content_payload(message: Any) -> Any:
"""
Extract content payload from ToolMessage, parsing JSON if needed.
This function extracts the content from a message and attempts to parse it as JSON
if it appears to be a JSON string. It handles various message formats and provides
sensible fallbacks.
Args:
message: Message object (typically ToolMessage)
Returns:
Parsed content (dict, list, or str)
Examples:
>>> message = ToolMessage(content='{"key": "value"}')
>>> extract_content_payload(message)
{'key': 'value'}
>>> message = ToolMessage(content='plain text')
>>> extract_content_payload(message)
'plain text'
"""
# Extract raw content
# For ToolMessages (responses from tools), extract from content
if hasattr(message, "content"):
raw_content = message.content
logger.info(f"extract_content_payload: raw_content type={type(raw_content)}, value={str(raw_content)[:500]}")
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
if isinstance(raw_content, list):
for block in raw_content:
if isinstance(block, dict) and block.get('type') == 'text':
raw_content = block.get('text', '')
logger.info(f"extract_content_payload: extracted text from MCP format: {str(raw_content)[:300]}")
break
# If content is empty and this is an AIMessage with tool_calls,
# extract from args (this handles the initial tool call from content_input)
if not raw_content and hasattr(message, "tool_calls") and message.tool_calls:
tool_call = message.tool_calls[0]
if isinstance(tool_call, dict) and "args" in tool_call:
return tool_call["args"]
else:
raw_content = str(message)
# If content is already a dict or list, return it directly
if isinstance(raw_content, (dict, list)):
logger.info(f"extract_content_payload: returning raw dict/list with keys={list(raw_content.keys()) if isinstance(raw_content, dict) else 'list'}")
return raw_content
# Try to parse as JSON
if isinstance(raw_content, str):
# First, try direct JSON parsing
try:
parsed = json.loads(raw_content)
logger.info(f"extract_content_payload: parsed JSON, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}")
return parsed
except (json.JSONDecodeError, ValueError):
pass
# If that fails, try to extract JSON from the string
# This handles cases where the content is embedded in a larger string
import re
json_candidates = re.findall(r'[\[{].*[\]}]', raw_content, flags=re.DOTALL)
for candidate in json_candidates:
try:
parsed = json.loads(candidate)
logger.info(f"extract_content_payload: parsed JSON from candidate, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}")
return parsed
except (json.JSONDecodeError, ValueError):
continue
# If all parsing attempts fail, return the raw content
logger.info(f"extract_content_payload: returning raw content (parsing failed)")
return raw_content

View File

@@ -0,0 +1,320 @@
import asyncio
import json
from datetime import datetime, timedelta
from langchain.tools import tool
from pydantic import BaseModel, Field
from app.core.memory.src.search import (
search_by_temporal,
search_by_keyword_temporal,
)
def extract_tool_message_content(response):
    """Extract the ToolMessage content and tool name from an agent response."""
    messages = response.get('messages', [])
    for message in messages:
        if hasattr(message, 'tool_call_id') and hasattr(message, 'content'):
            # This is a ToolMessage
            tool_content = message.content
            tool_name = None
            # Try to determine the tool name
            if hasattr(message, 'name'):
                tool_name = message.name
            elif hasattr(message, 'tool_name'):
                tool_name = message.tool_name
            try:
                # Parse the JSON content
                parsed_content = json.loads(tool_content)
                return {
                    'tool_name': tool_name,
                    'content': parsed_content
                }
            except json.JSONDecodeError:
                # Not JSON; return the content unchanged
                return {
                    'tool_name': tool_name,
                    'content': tool_content
                }
    return None


class TimeRetrievalInput(BaseModel):
    """Input schema for the time-retrieval tool."""
    context: str = Field(description="User query content")
    group_id: str = Field(default="88a459f5_text09", description="Group ID used to filter search results")
def create_time_retrieval_tool(group_id: str):
    """
    Create a TimeRetrieval tool bound to a specific group_id (synchronous version),
    used to search statements within a time range.
    """
    def clean_temporal_result_fields(data):
        """
        Strip unwanted fields from temporal search results and reshape the structure.
        Args:
            data: Data to clean
        Returns:
            Cleaned data
        """
        # Fields to filter out
        fields_to_remove = {
            'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
            'valid_at', 'invalid_at', 'statement_ids'
        }
        if isinstance(data, dict):
            cleaned = {}
            for key, value in data.items():
                if key == 'statements' and isinstance(value, dict) and 'statements' in value:
                    # Rewrite statements: {"statements": [...]} as results: {"time_search": [...]}
                    cleaned_value = clean_temporal_result_fields(value)
                    # Rename the inner statements key to time_search
                    if 'statements' in cleaned_value:
                        cleaned['results'] = {
                            'time_search': cleaned_value['statements']
                        }
                    else:
                        cleaned['results'] = cleaned_value
                elif key not in fields_to_remove:
                    cleaned[key] = clean_temporal_result_fields(value)
            return cleaned
        elif isinstance(data, list):
            return [clean_temporal_result_fields(item) for item in data]
        else:
            return data

    @tool
    def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, group_id_param: str = None, clean_output: bool = True) -> str:
        """
        Optimized time-retrieval tool that searches only by time range (synchronous
        version) and automatically filters out unneeded metadata fields.
        Explicit parameters:
        - context: query context
        - start_date: start date (optional, format YYYY-MM-DD)
        - end_date: end date (optional, format YYYY-MM-DD)
        - group_id_param: group ID (optional, overrides the default group ID)
        - clean_output: whether to strip metadata fields from the output
        - end_date should be derived from the user's description and formatted with strftime("%Y-%m-%d")
        """
        async def _async_search():
            # Use the provided parameters or fall back to defaults
            actual_group_id = group_id_param or group_id
            actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
            actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
            # Basic temporal search
            results = await search_by_temporal(
                group_id=actual_group_id,
                start_date=actual_start_date,
                end_date=actual_end_date,
                limit=10
            )
            # Strip unwanted fields from the results
            if clean_output:
                cleaned_results = clean_temporal_result_fields(results)
            else:
                cleaned_results = results
            return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
        return asyncio.run(_async_search())
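
    # NOTE: asyncio.run() starts a fresh event loop, so these synchronous
    # wrappers must not be called from inside an already-running loop; use an
    # async tool variant in that case.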
    @tool
    def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str:
        """
        Optimized keyword + time-range retrieval tool (synchronous version) that
        automatically filters out unneeded metadata fields.
        Explicit parameters:
        - context: query content
        - days_back: how many days to look back (default 7)
        - start_date: start date (optional, format YYYY-MM-DD)
        - end_date: end date (optional, format YYYY-MM-DD)
        - clean_output: whether to strip metadata fields from the output
        - end_date should be derived from the user's description and formatted with strftime("%Y-%m-%d")
        """
        async def _async_search():
            actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
            actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
            # Keyword + temporal search
            results = await search_by_keyword_temporal(
                query_text=context,
                group_id=group_id,
                start_date=actual_start_date,
                end_date=actual_end_date,
                limit=15
            )
            # Strip unwanted fields from the results
            if clean_output:
                cleaned_results = clean_temporal_result_fields(results)
            else:
                cleaned_results = results
            return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
        return asyncio.run(_async_search())

    return TimeRetrievalWithGroupId
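
# Usage sketch (assumed call pattern, not part of the original API surface):
#   time_tool = create_time_retrieval_tool("88a459f5_text09")
#   result_json = time_tool.invoke({"context": "recent movies",
#                                   "start_date": "2026-01-12",
#                                   "end_date": "2026-01-19"})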
def create_hybrid_retrieval_tool_async(memory_config, **search_params):
    """
    Create a hybrid retrieval tool that uses run_hybrid_search, with an
    optimized output format and unwanted fields filtered out.
    Args:
        memory_config: memory configuration object
        **search_params: search parameters such as group_id, limit, include, etc.
    """
    def clean_result_fields(data):
        """
        Recursively strip unwanted fields from results.
        Args:
            data: data to clean (dict, list, or any other type)
        Returns:
            Cleaned data
        """
        # Fields to filter out
        fields_to_remove = {
            'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
            'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
            'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary'
        }
        if isinstance(data, dict):
            # Clean a dict
            cleaned = {}
            for key, value in data.items():
                if key not in fields_to_remove:
                    cleaned[key] = clean_result_fields(value)  # recurse into nested data
            return cleaned
        elif isinstance(data, list):
            # Clean every element of a list
            return [clean_result_fields(item) for item in data]
        else:
            # Return anything else unchanged
            return data
    @tool
    async def HybridSearch(
        context: str,
        search_type: str = "hybrid",
        limit: int = 10,
        group_id: str = None,
        rerank_alpha: float = 0.6,
        use_forgetting_rerank: bool = False,
        use_llm_rerank: bool = False,
        clean_output: bool = True  # whether to strip metadata fields from the output
    ) -> str:
        """
        Optimized hybrid retrieval tool supporting keyword, vector, and hybrid
        search, with unneeded metadata fields filtered out automatically.
        Args:
            context: query content
            search_type: search type ('keyword', 'embedding', 'hybrid')
            limit: maximum number of results
            group_id: group ID used to filter search results
            rerank_alpha: re-ranking weight parameter
            use_forgetting_rerank: whether to apply forgetting-based re-ranking
            use_llm_rerank: whether to apply LLM re-ranking
            clean_output: whether to strip metadata fields from the output
        """
        try:
            # Import the run_hybrid_search function
            from app.core.memory.src.search import run_hybrid_search
            # Merge parameters, preferring the explicitly passed ones
            final_params = {
                "query_text": context,
                "search_type": search_type,
                "group_id": group_id or search_params.get("group_id"),
                "limit": limit or search_params.get("limit", 10),
                "include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
                "output_path": None,  # do not save to a file
                "memory_config": memory_config,
                "rerank_alpha": rerank_alpha,
                "use_forgetting_rerank": use_forgetting_rerank,
                "use_llm_rerank": use_llm_rerank
            }
            # Run the hybrid search
            raw_results = await run_hybrid_search(**final_params)
            # Strip unwanted fields from the results
            if clean_output:
                cleaned_results = clean_result_fields(raw_results)
            else:
                cleaned_results = raw_results
            # Format the returned results
            formatted_results = {
                "search_query": context,
                "search_type": search_type,
                "results": cleaned_results
            }
            return json.dumps(formatted_results, ensure_ascii=False, indent=2, default=str)
        except Exception as e:
            error_result = {
                "error": f"Hybrid search failed: {str(e)}",
                "search_query": context,
                "search_type": search_type,
                "timestamp": datetime.now().isoformat()
            }
            return json.dumps(error_result, ensure_ascii=False, indent=2)
return HybridSearch
def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
    """
    Create the synchronous variant of the hybrid retrieval tool, with an
    optimized output format and unwanted fields filtered out.
    Args:
        memory_config: memory configuration object
        **search_params: search parameters
    """
    @tool
    def HybridSearchSync(
        context: str,
        search_type: str = "hybrid",
        limit: int = 10,
        group_id: str = None,
        clean_output: bool = True
    ) -> str:
        """
        Optimized hybrid retrieval tool (synchronous version) that automatically
        filters out unneeded metadata fields.
        Args:
            context: query content
            search_type: search type ('keyword', 'embedding', 'hybrid')
            limit: maximum number of results
            group_id: group ID used to filter search results
            clean_output: whether to strip metadata fields from the output
        """
        async def _async_search():
            # Create the async tool and execute it
            async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
            return await async_tool.ainvoke({
                "context": context,
                "search_type": search_type,
                "limit": limit,
                "group_id": group_id,
                "clean_output": clean_output
            })
        return asyncio.run(_async_search())
    return HybridSearchSync
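
# Usage sketch (hypothetical group id; the async variant assumes a running event loop):
#   hybrid_tool = create_hybrid_retrieval_tool_async(memory_config, group_id="g1", limit=10)
#   result_json = await hybrid_tool.ainvoke({"context": "user query", "search_type": "hybrid"})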

View File

@@ -1,30 +1,32 @@
import asyncio
import json
import sys
import warnings
from contextlib import asynccontextmanager
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState
from app.schemas.memory_config_schema import MemoryConfig
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__)
if sys.platform.startswith("win"):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
@asynccontextmanager
async def make_write_graph(user_id, tools, apply_id, group_id, memory_config: MemoryConfig):
async def make_write_graph():
"""
Create a write graph workflow for memory operations.
Args:
user_id: User identifier
tools: MCP tools loaded from session
@@ -32,43 +34,8 @@ async def make_write_graph(user_id, tools, apply_id, group_id, memory_config: Me
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
"""
logger.info("Loading MCP tools: %s", [t.name for t in tools])
logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})")
data_write_tool = next((t for t in tools if t.name == "Data_write"), None)
if not data_write_tool:
logger.error("Data_write tool not found", exc_info=True)
raise ValueError("Data_write tool not found")
write_node = ToolNode([data_write_tool])
async def call_model(state):
messages = state["messages"]
last_message = messages[-1]
content = last_message[1] if isinstance(last_message, tuple) else last_message.content
# Call Data_write directly with memory_config
write_params = {
"content": content,
"apply_id": apply_id,
"group_id": group_id,
"user_id": user_id,
"memory_config": memory_config,
}
logger.debug(f"Passing memory_config to Data_write: {memory_config.config_id}")
write_result = await data_write_tool.ainvoke(write_params)
if isinstance(write_result, dict):
result_content = write_result.get("data", str(write_result))
else:
result_content = str(write_result)
logger.info("Write content: %s", result_content)
return {"messages": [AIMessage(content=result_content)]}
workflow = StateGraph(WriteState)
workflow.add_node("content_input", call_model)
workflow.add_node("content_input", content_input_write)
workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "content_input")
workflow.add_edge("content_input", "save_neo4j")
@@ -76,5 +43,45 @@ async def make_write_graph(user_id, tools, apply_id, group_id, memory_config: Me
graph = workflow.compile()
yield graph
async def main():
    """Main entry point - run the workflow."""
    message = "今天周一"
    group_id = 'new_2025test1103'  # group ID
    # Acquire a database session
    db_session = next(get_db())
    config_service = MemoryConfigService(db_session)
    memory_config = config_service.load_memory_config(
        config_id=17,  # integer config ID
        service_name="MemoryAgentService"
    )
    try:
        async with make_write_graph() as graph:
            config = {"configurable": {"thread_id": group_id}}
            # Initial state - must contain every required field
            initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, "memory_config": memory_config}
            # Consume per-node update events
            async for update_event in graph.astream(
                initial_state,
                stream_mode="updates",
                config=config
            ):
                for node_name, node_data in update_event.items():
                    if node_name == 'save_neo4j':
                        status = node_data.get('write_result')['status']
                        print(status)
    except Exception as e:
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    import asyncio
    asyncio.run(main())

View File

@@ -1,28 +0,0 @@
"""
MCP Server package for memory agent.
This package provides the FastMCP server implementation with context-based
dependency injection for tool functions.
Package structure:
- server: FastMCP server initialization and context setup
- tools: MCP tool implementations
- models: Pydantic response models
- services: Business logic services
"""
# from app.core.memory.agent.mcp_server.server import (
# mcp,
# initialize_context,
# main,
# get_context_resource
# )
# # Import tools to register them (but don't export them)
# from app.core.memory.agent.mcp_server import tools
# __all__ = [
# 'mcp',
# 'initialize_context',
# 'main',
# 'get_context_resource',
# ]

View File

@@ -1,11 +0,0 @@
"""
MCP Server Instance
This module contains the FastMCP server instance that is shared across all modules.
It's in a separate file to avoid circular import issues.
"""
from mcp.server.fastmcp import FastMCP
# Initialize FastMCP server instance
# This instance is shared across all tool modules
mcp = FastMCP('data_flow')

View File

@@ -1,159 +0,0 @@
"""
MCP Server initialization with FastMCP context setup.
This module initializes the FastMCP server and registers shared resources
in the context for dependency injection into tool functions.
"""
import os
import sys
from app.core.config import settings
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.services.search_service import SearchService
from app.core.memory.agent.mcp_server.services.session_service import SessionService
from app.core.memory.agent.mcp_server.services.template_service import TemplateService
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.redis_tool import store
logger = get_agent_logger(__name__)
def get_context_resource(ctx, resource_name: str):
"""
Helper function to retrieve a resource from the FastMCP context.
Args:
ctx: FastMCP Context object (passed to tool functions)
resource_name: Name of the resource to retrieve
Returns:
The requested resource
Raises:
AttributeError: If the resource doesn't exist
Example:
@mcp.tool()
async def my_tool(ctx: Context):
template_service = get_context_resource(ctx, 'template_service')
llm_client = get_context_resource(ctx, 'llm_client')
"""
if not hasattr(ctx, 'fastmcp') or ctx.fastmcp is None:
raise RuntimeError("Context does not have fastmcp attribute")
if not hasattr(ctx.fastmcp, resource_name):
raise AttributeError(
f"Resource '{resource_name}' not found in context. "
f"Available resources: {[k for k in dir(ctx.fastmcp) if not k.startswith('_')]}"
)
return getattr(ctx.fastmcp, resource_name)
def initialize_context():
"""
Initialize and register shared resources in FastMCP context.
This function sets up all shared resources that will be available
to tool functions via dependency injection through the context parameter.
Resources are stored as attributes on the FastMCP instance and can be
accessed via ctx.fastmcp in tool functions.
Resources registered:
- session_store: RedisSessionStore for session management
- llm_client: LLM client for structured API calls
- app_settings: Application settings (renamed to avoid conflict with FastMCP settings)
- template_service: Service for template rendering
- search_service: Service for hybrid search
- session_service: Service for session operations
"""
try:
# Register Redis session store
logger.info("Registering session_store in context")
mcp.session_store = store
# Note: LLM client is NOT loaded at server startup
# It should be loaded dynamically when needed, with config_id passed explicitly
# to make_write_graph or make_read_graph functions
logger.info("LLM client will be loaded dynamically with config_id when needed")
mcp.llm_client = None # Placeholder - actual client loaded per-request with config_id
# Register application settings (renamed to avoid conflict with FastMCP's settings)
logger.info("Registering app_settings in context")
mcp.app_settings = settings
# Register template service
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
# logger.info(f"Registering template_service in context with root: {template_root}")
template_service = TemplateService(template_root)
mcp.template_service = template_service
# Register search service
# logger.info("Registering search_service in context")
search_service = SearchService()
mcp.search_service = search_service
# Register session service
# logger.info("Registering session_service in context")
session_service = SessionService(store)
mcp.session_service = session_service
# logger.info("All context resources registered successfully")
except Exception as e:
logger.error(f"Failed to initialize context: {e}", exc_info=True)
raise
def main():
"""
Main entry point for the MCP server.
Initializes context and starts the server with SSE transport.
"""
try:
logger.info("Starting MCP server initialization")
# Initialize context resources
initialize_context()
# Import and register tools (imports trigger tool registration)
from app.core.memory.agent.mcp_server.tools import ( # noqa: F401
data_tools,
problem_tools,
retrieval_tools,
summary_tools,
verification_tools,
)
# Tools are registered via imports above
# Get MCP port from environment (default: 8081)
mcp_port = int(os.getenv("MCP_PORT", "8081"))
logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport")
# Configure DNS rebinding protection for Docker container compatibility
from mcp.server.fastmcp.server import TransportSecuritySettings
# Disable DNS rebinding protection to allow Docker container hostnames
# This allows containers to connect using service names like 'mcp-server'
mcp.settings.transport_security = TransportSecuritySettings(
enable_dns_rebinding_protection=False,
)
logger.info("DNS rebinding protection: disabled for Docker container compatibility")
# logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport")
# Run the server with SSE transport for HTTP connections
import uvicorn
app = mcp.sse_app()
uvicorn.run(app, host=settings.SERVER_IP, port=mcp_port, log_level="info")
except Exception as e:
logger.error(f"Failed to start MCP server: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -1,27 +0,0 @@
"""
MCP Tools module.
This module contains all MCP tool implementations organized by functionality.
Tools are organized into the following modules:
- problem_tools: Question segmentation and extension
- retrieval_tools: Database and context retrieval
- verification_tools: Data verification
- summary_tools: Summarization and summary retrieval
- data_tools: Data type differentiation and writing
"""
# Import all tool modules to register them with the MCP server
from . import problem_tools
from . import retrieval_tools
from . import verification_tools
from . import summary_tools
from . import data_tools
__all__ = [
'problem_tools',
'retrieval_tools',
'verification_tools',
'summary_tools',
'data_tools',
]

View File

@@ -1,155 +0,0 @@
"""
Data Tools for data type differentiation and writing.
This module contains MCP tools for distinguishing data types and writing data.
"""
import os
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.models.retrieval_models import (
DistinguishTypeResponse,
)
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.write_tools import write
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.schemas.memory_config_schema import MemoryConfig
from mcp.server.fastmcp import Context
logger = get_agent_logger(__name__)
@mcp.tool()
async def Data_type_differentiation(
ctx: Context,
context: str,
memory_config: MemoryConfig,
) -> dict:
"""
Distinguish the type of data (read or write).
Args:
ctx: FastMCP context for dependency injection
context: Text to analyze for type differentiation
memory_config: MemoryConfig object containing LLM configuration
Returns:
dict: Contains 'context' with the original text and 'type' field
"""
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
# Get LLM client from memory_config using factory pattern
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Render template
try:
system_prompt = await template_service.render_template(
template_name='distinguish_types_prompt.jinja2',
operation_name='status_typle',
user_query=context
)
except Exception as e:
logger.error(
f"Template rendering failed for Data_type_differentiation: {e}",
exc_info=True
)
return {
"type": "error",
"message": f"Prompt rendering failed: {str(e)}"
}
# Call LLM with structured response
try:
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=DistinguishTypeResponse
)
result = structured.model_dump()
# Add context to result
result["context"] = context
return result
except Exception as e:
logger.error(
f"LLM call failed for Data_type_differentiation: {e}",
exc_info=True
)
return {
"context": context,
"type": "error",
"message": f"LLM call failed: {str(e)}"
}
except Exception as e:
logger.error(
f"Data_type_differentiation failed: {e}",
exc_info=True
)
return {
"context": context,
"type": "error",
"message": str(e)
}
@mcp.tool()
async def Data_write(
ctx: Context,
content: str,
user_id: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
) -> dict:
"""
Write data to the database/file system.
Args:
ctx: FastMCP context for dependency injection
content: Data content to write
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
Returns:
dict: Contains 'status', 'saved_to', and 'data' fields
"""
try:
# Ensure output directory exists
os.makedirs("data_output", exist_ok=True)
file_path = os.path.join("data_output", "user_data.csv")
# Write data - clients are constructed inside write() from memory_config
await write(
content=content,
user_id=user_id,
apply_id=apply_id,
group_id=group_id,
memory_config=memory_config,
)
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
return {
"status": "success",
"saved_to": file_path,
"data": content,
"config_id": memory_config.config_id,
"config_name": memory_config.config_name,
}
except Exception as e:
logger.error(f"Data_write failed: {e}", exc_info=True)
return {
"status": "error",
"message": str(e),
}

View File

@@ -1,304 +0,0 @@
"""
Problem Tools for question segmentation and extension.
This module contains MCP tools for breaking down and extending user questions.
LLM clients are constructed from MemoryConfig when needed.
"""
import json
import time
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.models.problem_models import (
ProblemBreakdownResponse,
ProblemExtensionResponse,
)
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.schemas.memory_config_schema import MemoryConfig
from mcp.server.fastmcp import Context
logger = get_agent_logger(__name__)
@mcp.tool()
async def Split_The_Problem(
ctx: Context,
sentence: str,
sessionid: str,
messages_id: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
) -> dict:
"""
Segment the dialogue or sentence into sub-problems.
Args:
ctx: FastMCP context for dependency injection
sentence: Original sentence to split
sessionid: Session identifier
messages_id: Message identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
Returns:
dict: Contains 'context' (JSON string of split results) and 'original' sentence
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Extract user ID from session
user_id = session_service.resolve_user_id(sessionid)
# Get conversation history
history = await session_service.get_history(user_id, apply_id, group_id)
# Override with empty list for now (as in original)
history = []
# Render template
try:
system_prompt = await template_service.render_template(
template_name='problem_breakdown_prompt.jinja2',
operation_name='split_the_problem',
history=history,
sentence=sentence
)
except Exception as e:
logger.error(
f"Template rendering failed for Split_The_Problem: {e}",
exc_info=True
)
return {
"context": json.dumps([], ensure_ascii=False),
"original": sentence,
"error": f"Prompt rendering failed: {str(e)}"
}
# Call LLM with structured response
try:
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=ProblemBreakdownResponse
)
# Handle RootModel response with .root attribute access
if structured is None:
# LLM returned None, use empty list as fallback
split_result = json.dumps([], ensure_ascii=False)
elif hasattr(structured, 'root') and structured.root is not None:
split_result = json.dumps(
[item.model_dump() for item in structured.root],
ensure_ascii=False
)
elif isinstance(structured, list):
# Fallback: treat structured itself as the list
split_result = json.dumps(
[item.model_dump() for item in structured],
ensure_ascii=False
)
else:
# Last resort: use empty list
split_result = json.dumps([], ensure_ascii=False)
except Exception as e:
logger.error(
f"LLM call failed for Split_The_Problem: {e}",
exc_info=True
)
split_result = json.dumps([], ensure_ascii=False)
logger.info("Problem splitting")
logger.info(f"Problem split result: {split_result}")
# Emit intermediate output for frontend
result = {
"context": split_result,
"original": sentence,
"_intermediate": {
"type": "problem_split",
"data": json.loads(split_result) if split_result else [],
"original_query": sentence
}
}
return result
except Exception as e:
logger.error(
f"Split_The_Problem failed: {e}",
exc_info=True
)
return {
"context": json.dumps([], ensure_ascii=False),
"original": sentence,
"error": str(e)
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Problem splitting', duration)
@mcp.tool()
async def Problem_Extension(
ctx: Context,
context: dict,
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = "",
) -> dict:
"""
Extend the problem with additional sub-questions.
Args:
ctx: FastMCP context for dependency injection
context: Dictionary containing split problem results
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'context' (aggregated questions) and 'original' question
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID from usermessages
from app.core.memory.agent.utils.messages_tool import Resolve_username
sessionid = Resolve_username(usermessages)
# Get conversation history
history = await session_service.get_history(sessionid, apply_id, group_id)
# Override with empty list for now (as in original)
history = []
# Process context to extract questions
extent_quest, original = await Problem_Extension_messages_deal(context)
# Format questions for template rendering
questions_formatted = []
for msg in extent_quest:
if msg.get("role") == "user":
questions_formatted.append(msg.get("content", ""))
# Render template
try:
system_prompt = await template_service.render_template(
template_name='Problem_Extension_prompt.jinja2',
operation_name='problem_extension',
history=history,
questions=questions_formatted
)
except Exception as e:
logger.error(
f"Template rendering failed for Problem_Extension: {e}",
exc_info=True
)
return {
"context": {},
"original": original,
"error": f"Prompt rendering failed: {str(e)}"
}
# Call LLM with structured response
try:
response_content = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=ProblemExtensionResponse
)
# Aggregate results by original question
aggregated_dict = {}
for item in response_content.root:
key = getattr(item, "original_question", None) or (
item.get("original_question") if isinstance(item, dict) else None
)
value = getattr(item, "extended_question", None) or (
item.get("extended_question") if isinstance(item, dict) else None
)
if not key or not value:
continue
aggregated_dict.setdefault(key, []).append(value)
except Exception as e:
logger.error(
f"LLM call failed for Problem_Extension: {e}",
exc_info=True
)
aggregated_dict = {}
logger.info("Problem extension")
logger.info(f"Problem extension result: {aggregated_dict}")
# Emit intermediate output for frontend
result = {
"context": aggregated_dict,
"original": original,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "problem_extension",
"data": aggregated_dict,
"original_query": original,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
return result
except Exception as e:
logger.error(
f"Problem_Extension failed: {e}",
exc_info=True
)
return {
"context": {},
"original": context.get("original", ""),
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Problem extension', duration)

View File

@@ -1,294 +0,0 @@
"""
Retrieval Tools for database and context retrieval.
This module contains MCP tools for retrieving data using hybrid search.
"""
import os
import time
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.llm_tools import (
deduplicate_entries,
merge_to_key_value_pairs,
)
from app.core.memory.agent.utils.messages_tool import Retriev_messages_deal
from app.core.rag.nlp.search import knowledge_retrieval
from app.schemas.memory_config_schema import MemoryConfig
from dotenv import load_dotenv
from mcp.server.fastmcp import Context
load_dotenv()
logger = get_agent_logger(__name__)
@mcp.tool()
async def Retrieve(
ctx: Context,
context,
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = "",
) -> dict:
"""
Retrieve data from the database using hybrid search.
Args:
ctx: FastMCP context for dependency injection
context: Dictionary or string containing query information
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
user_rag_memory_id: User RAG memory identifier
Returns:
dict: Contains 'context' with Query and Expansion_issue results
"""
kb_config = {
"knowledge_bases": [
{
"kb_id": user_rag_memory_id,
"similarity_threshold": 0.7,
"vector_similarity_weight": 0.5,
"top_k": 10,
"retrieve_type": "participle"
}
],
"merge_strategy": "weight",
"reranker_id": os.getenv('reranker_id'),
"reranker_top_k": 10
}
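    # kb_config follows the knowledge_retrieval() input contract assumed here:
    # one knowledge-base entry per RAG memory id, results merged by weight and
    # re-ranked with the reranker id taken from the environment.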
start = time.time()
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
logger.info(f"Retrieve: context type={type(context)}, context={str(context)[:500]}")
try:
# Extract services from context
search_service = get_context_resource(ctx, 'search_service')
databases_anser = []
# Handle both dict and string context
if isinstance(context, dict):
# Process dict context with extended questions
all_items = []
logger.info(f"Retrieve: context keys={list(context.keys())}")
content, original = await Retriev_messages_deal(context)
logger.info(f"Retrieve: after Retriev_messages_deal - content_type={type(content)}, content={str(content)[:300]}")
logger.info(f"Retrieve: original='{original[:100] if original else 'EMPTY'}'")
if not original:
logger.warning(f"Retrieve: original query is empty! context={context}")
# Extract all query items from content
# content is like {original_question: [extended_questions...], ...}
for key, values in content.items():
if isinstance(values, list):
all_items.extend(values)
elif isinstance(values, str):
all_items.append(values)
elif values is not None:
# Fallback: convert non-empty non-list values to string
all_items.append(str(values))
# Execute search for each question
for idx, question in enumerate(all_items):
try:
# Prepare search parameters based on storage type
search_params = {
"group_id": group_id,
"question": question,
"return_raw_results": True
}
# Add storage-specific parameters
if storage_type == "rag" and user_rag_memory_id:
                        retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(group_id)])
                        try:
                            retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
                            clean_content = '\n\n'.join(retrieval_knowledge)
                            cleaned_query = question
                            raw_results = clean_content
                            logger.info(f"Using RAG storage with memory_id={user_rag_memory_id}")
                        except Exception:
                            clean_content = ''
                            raw_results = ''
                            cleaned_query = question
                            logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
else:
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(
**search_params, memory_config=memory_config
)
databases_anser.append({
"Query_small": cleaned_query,
"Result_small": clean_content,
"_intermediate": {
"type": "search_result",
"query": cleaned_query,
"raw_results": raw_results,
"index": idx + 1,
"total": len(all_items)
}
})
except Exception as e:
logger.error(
f"Retrieve: hybrid_search failed for question '{question}': {e}",
exc_info=True
)
# Continue with empty result for this question
databases_anser.append({
"Query_small": question,
"Result_small": ""
})
# Build initial database data structure
databases_data = {
"Query": original,
"Expansion_issue": databases_anser
}
# Collect intermediate outputs before deduplication
intermediate_outputs = []
for item in databases_anser:
if '_intermediate' in item:
intermediate_outputs.append(item['_intermediate'])
# Deduplicate and merge results
deduplicated_data = deduplicate_entries(databases_data['Expansion_issue'])
deduplicated_data_merged = merge_to_key_value_pairs(
deduplicated_data,
'Query_small',
'Result_small'
)
# Restructure for Verify/Retrieve_Summary compatibility
keys, val = [], []
for item in deduplicated_data_merged:
for items_key, items_value in item.items():
keys.append(items_key)
val.append(items_value)
send_verify = []
for i, j in zip(keys, val, strict=False):
send_verify.append({
"Query_small": i,
"Answer_Small": j
})
dup_databases = {
"Query": original,
"Expansion_issue": send_verify,
"_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs
}
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
else:
# Handle string context (simple query)
query = str(context).strip()
try:
# Prepare search parameters based on storage type
search_params = {
"group_id": group_id,
"question": query,
"return_raw_results": True
}
# Add storage-specific parameters
if storage_type == "rag" and user_rag_memory_id:
                    retrieve_chunks_result = knowledge_retrieval(query, kb_config, [str(group_id)])
                    try:
                        retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
                        clean_content = '\n\n'.join(retrieval_knowledge)
                        cleaned_query = query
                        raw_results = clean_content
                        logger.info(f"Using RAG storage with memory_id={user_rag_memory_id}")
                    except Exception:
                        clean_content = ''
                        raw_results = ''
                        cleaned_query = query
                        logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
else:
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(
**search_params, memory_config=memory_config
)
# Keep structure for Verify/Retrieve_Summary compatibility
dup_databases = {
"Query": cleaned_query,
"Expansion_issue": [{
"Query_small": cleaned_query,
"Answer_Small": clean_content,
"_intermediate": {
"type": "search_result",
"query": cleaned_query,
"raw_results": raw_results,
"index": 1,
"total": 1
}
}]
}
except Exception as e:
logger.error(
f"Retrieve: hybrid_search failed for query '{query}': {e}",
exc_info=True
)
# Return empty results on failure
dup_databases = {
"Query": query,
"Expansion_issue": []
}
logger.info(
f"Retrieval: {storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, "
f"Expansion_issue count={len(dup_databases.get('Expansion_issue', []))}"
)
# Build result with intermediate outputs
result = {
"context": dup_databases,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
# Add intermediate outputs list if they exist
intermediate_outputs = dup_databases.get('_intermediate_outputs', [])
if intermediate_outputs:
result['_intermediates'] = intermediate_outputs
logger.info(f"Adding {len(intermediate_outputs)} intermediate outputs to result")
else:
logger.warning("No intermediate outputs found in dup_databases")
return result
except Exception as e:
logger.error(
f"Retrieve failed: {e}",
exc_info=True
)
return {
"context": {
"Query": "",
"Expansion_issue": []
},
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Retrieval', duration)

View File

@@ -1,640 +0,0 @@
"""
Summary Tools for data summarization.
This module contains MCP tools for summarizing retrieved data and generating responses.
LLM clients are constructed from MemoryConfig when needed.
"""
import json
import os
import re
import time
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.models.summary_models import (
RetrieveSummaryResponse,
SummaryResponse,
)
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.messages_tool import (
Resolve_username,
Summary_messages_deal,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.schemas.memory_config_schema import MemoryConfig
from dotenv import load_dotenv
from mcp.server.fastmcp import Context
load_dotenv()
logger = get_agent_logger(__name__)
@mcp.tool()
async def Summary(
ctx: Context,
context: str,
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = "",
) -> dict:
"""
Summarize the verified data.
Args:
ctx: FastMCP context for dependency injection
context: JSON string containing verified data
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'status' and 'summary_result'
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID
sessionid = Resolve_username(usermessages)
# Process context to extract answer and query
answer_small, query = await Summary_messages_deal(context)
start_time = time.time()
history = await session_service.get_history(sessionid, apply_id, group_id)
end_time = time.time()
logger.info(f"Summary - Redis history lookup took {end_time - start_time:.3f}s")
data = {
"query": query,
"history": history,
"retrieve_info": answer_small
}
except Exception as e:
logger.error(
f"Summary: initialization failed: {e}",
exc_info=True
)
return {
"status": "error",
"summary_result": "信息不足,无法回答"
}
try:
# Render template
system_prompt = await template_service.render_template(
template_name='summary_prompt.jinja2',
operation_name='summary',
data=data,
query=query
)
except Exception as e:
logger.error(
f"Template rendering failed for Summary: {e}",
exc_info=True
)
return {
"status": "error",
"message": f"Prompt rendering failed: {str(e)}"
}
try:
# Call LLM with structured response
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=SummaryResponse
)
aimessages = structured.query_answer or ""
except Exception as e:
logger.error(
f"LLM call failed for Summary: {e}",
exc_info=True
)
aimessages = ""
try:
# Save session
if aimessages != "":
await session_service.save_session(
user_id=sessionid,
query=query,
apply_id=apply_id,
group_id=group_id,
ai_response=aimessages
)
logger.info(f"sessionid: {aimessages} 写入成功")
except Exception as e:
logger.error(
f"sessionid: {sessionid} 写入失败,错误信息:{str(e)}",
exc_info=True
)
return {
"status": "error",
"message": str(e)
}
# Cleanup duplicate sessions
await session_service.cleanup_duplicates()
# Use fallback if empty
if aimessages == '':
aimessages = '信息不足,无法回答'
logger.info(f"Summary after verification: {aimessages}")
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Summary', duration)
return {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
@mcp.tool()
async def Retrieve_Summary(
ctx: Context,
context: dict,
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = "",
) -> dict:
"""
Summarize data directly from retrieval results.
Args:
ctx: FastMCP context for dependency injection
context: Dictionary containing Query and Expansion_issue from Retrieve
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'status' and 'summary_result'
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID
sessionid = Resolve_username(usermessages)
# Handle both 'content' and 'context' keys (LangGraph uses 'content')
logger.debug(f"Retrieve_Summary: raw context type={type(context)}, keys={list(context.keys()) if isinstance(context, dict) else 'N/A'}")
if isinstance(context, dict):
if "content" in context:
inner = context["content"]
# If it's a JSON string, parse it
if isinstance(inner, str):
try:
parsed = json.loads(inner)
logger.info("Retrieve_Summary: successfully parsed JSON")
except json.JSONDecodeError:
# Try unescaping first
try:
unescaped = inner.encode('utf-8').decode('unicode_escape')
parsed = json.loads(unescaped)
logger.info("Retrieve_Summary: parsed after unescaping")
except (json.JSONDecodeError, UnicodeDecodeError) as e:
logger.error(
f"Retrieve_Summary: parsing failed even after unescape: {e}"
)
context_dict = {"Query": "", "Expansion_issue": []}
parsed = None
if parsed is not None:
# Unwrap the optional 'context' wrapper (an empty dict is still valid)
if isinstance(parsed, dict) and "context" in parsed:
context_dict = parsed["context"]
else:
context_dict = parsed
elif isinstance(inner, dict):
context_dict = inner
else:
context_dict = {"Query": "", "Expansion_issue": []}
elif "context" in context:
context_dict = context["context"] if isinstance(context["context"], dict) else context
else:
context_dict = context
else:
context_dict = {"Query": "", "Expansion_issue": []}
query = context_dict.get("Query", "")
expansion_issue = context_dict.get("Expansion_issue", [])
logger.debug(f"Retrieve_Summary: query='{query}', expansion_issue count={len(expansion_issue)}")
logger.debug(f"Retrieve_Summary: expansion_issue={expansion_issue[:2] if expansion_issue else 'empty'}")
# Extract retrieve_info from expansion_issue
retrieve_info = []
for item in expansion_issue:
# Extract the 'Answer_Small' field when present
answer = None
if isinstance(item, dict):
if "Answer_Small" in item:
answer = item["Answer_Small"]
if answer is not None:
# Handle both string and list formats
if isinstance(answer, list):
# Join list of characters/strings into a single string
retrieve_info.append(''.join(str(x) for x in answer))
elif isinstance(answer, str):
retrieve_info.append(answer)
else:
retrieve_info.append(str(answer))
# Join all retrieve_info into a single string
retrieve_info_str = '\n\n'.join(retrieve_info) if retrieve_info else ""
start_time = time.time()
history = await session_service.get_history(sessionid, apply_id, group_id)
end_time = time.time()
logger.info(f"Retrieve_Summary - Redis history lookup took {end_time - start_time:.3f}s")
except Exception as e:
logger.error(
f"Retrieve_Summary: initialization failed: {e}",
exc_info=True
)
return {
"status": "error",
"summary_result": "信息不足,无法回答"
}
try:
# Render template
system_prompt = await template_service.render_template(
template_name='Retrieve_Summary_prompt.jinja2',
operation_name='retrieve_summary',
query=query,
history=history,
retrieve_info=retrieve_info_str
)
except Exception as e:
logger.error(
f"Template rendering failed for Retrieve_Summary: {e}",
exc_info=True
)
return {
"status": "error",
"message": f"Prompt rendering failed: {str(e)}"
}
try:
# Call LLM with structured response
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=RetrieveSummaryResponse
)
# Handle case where structured response might be None or incomplete
if structured and hasattr(structured, 'data') and structured.data:
aimessages = structured.data.query_answer or ""
else:
logger.warning("Structured response is None or incomplete, using default message")
aimessages = "信息不足,无法回答"
# Save the session only when a real answer was produced
if aimessages and '信息不足,无法回答' not in str(aimessages):
await session_service.save_session(
user_id=sessionid,
query=query,
apply_id=apply_id,
group_id=group_id,
ai_response=aimessages
)
logger.info(f"Session {sessionid}: AI response saved successfully")
except Exception as e:
logger.error(
f"Retrieve_Summary: LLM call failed: {e}",
exc_info=True
)
aimessages = ""
# Cleanup duplicate sessions
await session_service.cleanup_duplicates()
# Use fallback if empty
if aimessages == '':
aimessages = '信息不足,无法回答'
logger.info(f"Summary after retrieval: {aimessages}")
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Retrieval summary', duration)
# Emit intermediate output for frontend
return {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "retrieval_summary",
"summary": aimessages,
"query": query,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
@mcp.tool()
async def Input_Summary(
ctx: Context,
context: str,
usermessages: str,
search_switch: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = "",
) -> dict:
"""
Generate a quick summary for direct input without verification.
Args:
ctx: FastMCP context for dependency injection
context: String containing the input sentence
usermessages: User messages identifier
search_switch: Search switch value for routing ('2' for summaries only)
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
user_rag_memory_id: User RAG memory identifier
Returns:
dict: Contains 'query_answer' with the summary result
"""
start = time.time()
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
try:
# Extract services from context
session_service = get_context_resource(ctx, "session_service")
search_service = get_context_resource(ctx, "search_service")
# Resolve session ID
sessionid = Resolve_username(usermessages) or ""
sessionid = sessionid.replace('call_id_', '')
start_time = time.time()
history = await session_service.get_history(
str(sessionid),
str(apply_id),
str(group_id)
)
end_time = time.time()
logger.info(f"Input_Summary - Redis history lookup took {end_time - start_time:.3f}s")
# Log the raw context for debugging
logger.info(f"Input_Summary: Received context type={type(context)}, value={context[:200] if isinstance(context, str) else context}")
# Extract sentence from context
# Context can be a string or might contain the sentence in various formats
try:
# Try to parse as JSON first
if isinstance(context, str) and (context.startswith('{') or context.startswith('[')):
try:
# json is already imported at module level
context_dict = json.loads(context)
if isinstance(context_dict, dict):
query = context_dict.get('sentence', context_dict.get('content', context))
else:
query = context
except json.JSONDecodeError:
# Not valid JSON, try regex
match = re.search(r"'sentence':\s*['\"]?(.*?)['\"]?\s*,", context)
query = match.group(1) if match else context
else:
query = context
except Exception as e:
logger.warning(f"Failed to extract query from context: {e}")
query = context
# Clean query
query = str(query).strip().strip("\"'")
logger.debug(f"Input_Summary: Extracted query='{query}' from context type={type(context)}")
# Execute search based on search_switch and storage_type
try:
logger.info(f"search_switch: {search_switch}, storage_type: {storage_type}")
# Prepare search parameters based on storage type
search_params = {
"group_id": group_id,
"question": query,
"return_raw_results": True
}
# Add storage-specific parameters
# Retrieval
if search_switch == '2':
search_params["include"] = ["summaries"]
if storage_type == "rag" and user_rag_memory_id:
raw_results = []
retrieve_info = ""
kb_config = {
"knowledge_bases": [
{
"kb_id": user_rag_memory_id,
"similarity_threshold": 0.7,
"vector_similarity_weight": 0.5,
"top_k": 10,
"retrieve_type": "participle"
}
],
"merge_strategy": "weight",
"reranker_id": os.getenv('reranker_id'),
"reranker_top_k": 10
}
retrieve_chunks_result = knowledge_retrieval(query, kb_config, [str(group_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
retrieve_info = '\n\n'.join(retrieval_knowledge)
raw_results = [retrieve_info]
logger.info(f"Input_Summary: Using RAG storage with memory_id={user_rag_memory_id}")
except Exception:
retrieve_info = ''
raw_results = ['']
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
else:
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(
**search_params, memory_config=memory_config
)
logger.info("Input_Summary: Using summary for retrieval")
else:
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(
**search_params, memory_config=memory_config
)
except Exception as e:
logger.error(
f"Input_Summary: hybrid_search failed, using empty results: {e}",
exc_info=True
)
retrieve_info, question, raw_results = "", query, []
# Return retrieved information directly without LLM processing
# Use the raw retrieved info as the answer
aimessages = retrieve_info if retrieve_info else "信息不足,无法回答"
logger.info(f"Quick answer (no LLM): {storage_type}--{user_rag_memory_id}--{aimessages[:500]}...")
# Emit intermediate output for frontend
return {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "input_summary",
"title": "快速答案",
"summary": aimessages,
"query": query,
"raw_results": raw_results,
"search_mode": "quick_search",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
except Exception as e:
logger.error(
f"Input_Summary failed: {e}",
exc_info=True
)
return {
"status": "fail",
"summary_result": "信息不足,无法回答",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Retrieval', duration)
@mcp.tool()
async def Summary_fails(
ctx: Context,
context: str,
usermessages: str,
apply_id: str,
group_id: str,
storage_type: str = "",
user_rag_memory_id: str = ""
) -> dict:
"""
Handle workflow failure when summary cannot be generated.
Args:
ctx: FastMCP context for dependency injection
context: Failure context string
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'query_answer' with failure message
"""
try:
# Extract services from context
session_service = get_context_resource(ctx, 'session_service')
# Parse session ID from usermessages
usermessages_parts = usermessages.split('_')[1:]
sessionid = '_'.join(usermessages_parts[:-1])
# Cleanup duplicate sessions
await session_service.cleanup_duplicates()
logger.info("没有相关数据")
logger.debug(f"Summary_fails called with apply_id: {apply_id}, group_id: {group_id}")
return {
"status": "success",
"summary_result": "没有相关数据",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
except Exception as e:
logger.error(
f"Summary_fails failed: {e}",
exc_info=True
)
return {
"status": "fail",
"summary_result": "没有相关数据",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}
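The context normalization in Retrieve_Summary above accepts three payload shapes. A condensed sketch of the same logic (without the unicode-unescape fallback), with made-up payloads:
import json

def unwrap(context: dict) -> dict:
    # Accept {'content': <json-str>}, {'content': <dict>} or {'context': <dict>}
    inner = context.get("content", context.get("context", context))
    if isinstance(inner, str):
        inner = json.loads(inner)
    if isinstance(inner, dict) and "context" in inner:
        inner = inner["context"]
    return inner

payloads = [
    {"content": json.dumps({"context": {"Query": "q", "Expansion_issue": []}})},
    {"content": {"Query": "q", "Expansion_issue": []}},
    {"context": {"Query": "q", "Expansion_issue": []}},
]
assert all(unwrap(p)["Query"] == "q" for p in payloads)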

View File

@@ -1,174 +0,0 @@
"""
Verification Tools for data verification.
This module contains MCP tools for verifying retrieved data.
"""
import time
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.messages_tool import (
Resolve_username,
Retrieve_verify_tool_messages_deal,
Verify_messages_deal,
)
from app.core.memory.agent.utils.verify_tool import VerifyTool
from app.schemas.memory_config_schema import MemoryConfig
from jinja2 import Template
from mcp.server.fastmcp import Context
logger = get_agent_logger(__name__)
@mcp.tool()
async def Verify(
ctx: Context,
context: dict,
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = ""
) -> dict:
"""
Verify the retrieved data.
Args:
ctx: FastMCP context for dependency injection
context: Dictionary containing query and expansion issues
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'status' and 'verified_data' with verification results
"""
start = time.time()
try:
# Extract services from context
session_service = get_context_resource(ctx, 'session_service')
# Load verification prompt template
file_path = PROJECT_ROOT_ + '/agent/utils/prompt/split_verify_prompt.jinja2'
# Read template file directly (VerifyTool expects raw template content)
from app.core.memory.agent.utils.messages_tool import read_template_file
system_prompt = await read_template_file(file_path)
# Resolve session ID
sessionid = Resolve_username(usermessages)
# Get conversation history
history = await session_service.get_history(sessionid, apply_id, group_id)
template = Template(system_prompt)
system_prompt = template.render(history=history, sentence=context)
# Process context to extract query and results
Query_small, Result_small, query = await Verify_messages_deal(context)
# Build query list for verification
query_list = []
for query_small, answer in zip(Query_small, Result_small, strict=False):
query_list.append({
'Query_small': query_small,
'Answer_Small': answer
})
messages = {
"Query": query,
"Expansion_issue": query_list
}
# Call verification workflow with LLM model ID from memory_config
verify_tool = VerifyTool(
system_prompt=system_prompt,
verify_data=messages,
llm_model_id=str(memory_config.llm_model_id)
)
verify_result = await verify_tool.verify()
# Parse LLM verification result with error handling
try:
messages_deal = await Retrieve_verify_tool_messages_deal(
verify_result,
history,
query
)
except Exception as e:
logger.error(
f"Retrieve_verify_tool_messages_deal parsing failed: {e}",
exc_info=True
)
# Fallback to avoid 500 errors
messages_deal = {
"data": {
"query": query,
"expansion_issue": []
},
"split_result": "failed",
"reason": str(e),
"history": history,
}
logger.info(f"Verification result: {messages_deal}")
# Emit intermediate output for frontend
return {
"status": "success",
"verified_data": messages_deal,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "verification",
"title": "Data Verification",
"result": messages_deal.get("split_result", "unknown"),
"reason": messages_deal.get("reason", ""),
"query": query,
"verified_count": len(query_list),
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
except Exception as e:
logger.error(
f"Verify failed: {e}",
exc_info=True
)
return {
"status": "error",
"message": str(e),
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"verified_data": {
"data": {
"query": "",
"expansion_issue": []
},
"split_result": "failed",
"reason": str(e),
"history": [],
}
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Verification', duration)
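The payload handed to VerifyTool above has this shape (values illustrative):
messages = {
    "Query": "original question",
    "Expansion_issue": [
        {"Query_small": "sub-question 1", "Answer_Small": "retrieved answer 1"},
        {"Query_small": "sub-question 2", "Answer_Small": "retrieved answer 2"},
    ],
}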

View File

@@ -1,114 +0,0 @@
import os
import sys
import traceback
import requests
# from qcloud_cos import CosConfig, CosS3Client
# from qcloud_cos.cos_exception import CosClientError, CosServiceError
# from config.paths import BASE_DIR
BASE_DIR = os.path.dirname(os.path.realpath(sys.argv[0]))
class OSSUploader:
"""对象存储文件上传工具类"""
def __init__(self, env):
api = {
"test": "https://testlingqi.redbearai.com/api/user/file/common/upload/v2/anon",
"prod": "https://lingqi.redbearai.com/api/user/file/common/upload/v2/anon"
}
self.api = api.get(env, "https://testlingqi.redbearai.com/api/user/file/common/upload/v2/anon")
self.privacy = "false"
self.headers = {
"User-Agent": 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) '
'AppleWebKit/537.36 (KHTML, like Gecko)'
' Chrome/133.0.6833.84 Safari/537.36'
}
@staticmethod
def _generate_object_key(file_path, prefix='xhs_'):
"""
Generate the object-storage key.
:param file_path: local file path
:param prefix: storage prefix used to categorize objects
:return: generated object key
"""
# Object name is the local file name (e.g. "<md5>.<ext>")
filename = os.path.basename(file_path)
# Combine prefix and file name into the full object key
return f"{prefix}{filename}"
def upload_image(self, file_name, prefix='jd_'):
"""
上传文件到COS并返回可访问的URL
:param file_url: 文件路径
:param file_name: 文件名称
:param media_type: 文件类型
:param prefix: 存储前缀,用于分类存储
:return: 文件访问URL
"""
# Build the local file path
file_path = os.path.join(BASE_DIR, file_name)
# response = requests.get(url, headers=self.headers, stream=True)
# if response.status_code == 200:
# with open(file_path, "wb") as f:
# for chunk in response.iter_content(1024): # 分块写入,避免内存占用过大
# f.write(chunk)
# else:
# raise Exception(f"文件下载失败,{file_name}")
# Generate the object key
object_key = self._generate_object_key(file_path, prefix + file_name.split('.')[-1])
try:
upload_response = requests.post(
self.api,
data={
"privacy": self.privacy,
"fileName": object_key,
}
)
if upload_response.status_code != 200:
raise Exception('上传接口请求失败')
resp = upload_response.json()
name = resp["data"]["name"]
file_url = resp["data"]["path"]
policy = resp["data"]["policy"]
with open(file_path, 'rb') as f:
oss_push_resp = requests.post(
policy["host"],
files={
"key": policy["dir"],
"OSSAccessKeyId": policy["accessid"],
"name": name,
"policy": policy["policy"],
"success_action_status": 200,
"signature": policy["signature"],
"file": f,
}
)
if oss_push_resp.status_code == 200:
return file_url
raise Exception("OSS上传失败")
except Exception:
raise Exception(f"上传失败: \n{traceback.format_exc()}")
finally:
print('upload attempt finished')  # runs on success and failure alike
# os.remove(file_path)
if __name__ == '__main__':
cos_uploader = OSSUploader("prod")
url = cos_uploader.upload_image('./example01.jpg')
print(url)
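The upload is a standard two-phase policy upload: the application API returns a signed policy, then the file is POSTed directly to the OSS host. A minimal sketch of phase two, using the same field names as the code above (the policy dict is whatever the API returned):
import requests

def post_to_oss(policy: dict, name: str, file_path: str) -> None:
    """Multipart POST straight to the OSS host using the signed policy."""
    with open(file_path, "rb") as f:
        resp = requests.post(
            policy["host"],
            files={
                "key": policy["dir"],
                "OSSAccessKeyId": policy["accessid"],
                "name": name,
                "policy": policy["policy"],
                "success_action_status": 200,
                "signature": policy["signature"],
                "file": f,
            },
        )
    if resp.status_code != 200:
        raise Exception("OSS upload failed")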

View File

@@ -1,121 +0,0 @@
import asyncio
import re
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_, picture_model_requests, Picture_recognize, Voice_recognize
from app.core.memory.agent.utils.messages_tool import read_template_file
import requests
import json
import os
import time
# file_urls = [
# "https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_female2.wav",
# "https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_male2.wav",
# ]
class Vico_recognition:
def __init__(self, file_urls):
self.api_key = ''
self.backend_model_name = ''
self.api_base = ''
self.file_urls = file_urls
# Submit the transcription task with the list of file URLs to transcribe
async def submit_task(self) -> str:
self.api_key, self.backend_model_name, self.api_base = await Voice_recognize()
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"X-DashScope-Async": "enable",
}
data = {
"model": self.backend_model_name,
"input": {"file_urls": self.file_urls},
"parameters": {
"channel_id": [0],
"vocabulary_id": "vocab-Xxxx",
},
}
# Transcription service URL
service_url = (
"https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription"
)
response = requests.post(
service_url, headers=headers, data=json.dumps(data)
)
# Inspect the response
if response.status_code == 200:
return response.json()["output"]["task_id"]
else:
print("task failed!")
print(response.json())
return None
async def download_transcription_result(self, transcription_url):
"""
Args:
transcription_url (str): URL of the transcription result file
Returns:
dict: transcription result content
"""
try:
response = requests.get(transcription_url)
response.raise_for_status()
return response.json()
except Exception as e:
print(f"Failed to download transcription result: {e}")
return None
# Poll the task status until it completes
async def wait_for_complete(self, task_id):
self.api_key, self.backend_model_name, self.api_base = await Voice_recognize()
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"X-DashScope-Async": "enable",
}
pending = True
while pending:
# Task status query URL (polled via GET)
service_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}"
response = requests.get(
service_url, headers=headers
)
if response.status_code == 200:
status = response.json()['output']['task_status']
if status == 'SUCCEEDED':
print("task succeeded!")
pending = False
return response.json()['output']['results']
elif status == 'RUNNING' or status == 'PENDING':
pass
else:
print("task failed!")
pending = False
else:
print("query failed!")
pending = False
await asyncio.sleep(0.1)
async def run(self):
self.api_key, self.backend_model_name, self.api_base = await Voice_recognize()
task_id = await self.submit_task()
result = await self.wait_for_complete(task_id)
result_context = []
for i in result:
transcription_url = i['transcription_url']
print(f"Transcription URL: {transcription_url}")
# Download and collect the transcription text
content = await self.download_transcription_result(transcription_url)
if content:
content = json.dumps(content, indent=2, ensure_ascii=False)
context = re.findall(r'"text": "(.*?)"', content)
if context:
result_context.append(context[0])
return ''.join(result_context)
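Typical usage of the transcription helper (the audio URL is illustrative):
import asyncio

recognizer = Vico_recognition(["https://example.com/sample.wav"])  # illustrative URL
text = asyncio.run(recognizer.run())
print(text)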

View File

@@ -0,0 +1,277 @@
"""
Optimized LLM service for consolidating and unifying LLM calls.
"""
import asyncio
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.core.logging_config import get_agent_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.llm_tools.openai_client import OpenAIClient
T = TypeVar('T', bound=BaseModel)
logger = get_agent_logger(__name__)
class OptimizedLLMService:
"""
优化的LLM服务类提供统一的LLM调用接口
特性:
1. 客户端复用 - 避免重复创建LLM客户端
2. 批量处理 - 支持并发处理多个请求
3. 错误处理 - 统一的错误处理和降级策略
4. 性能优化 - 缓存和连接池优化
"""
def __init__(self, db_session: Session):
self.db_session = db_session
self.client_factory = MemoryClientFactory(db_session)
self._client_cache: Dict[str, OpenAIClient] = {}
def _get_cached_client(self, llm_model_id: str) -> OpenAIClient:
"""Return a cached LLM client, creating it on first use."""
if llm_model_id not in self._client_cache:
self._client_cache[llm_model_id] = self.client_factory.get_llm_client(llm_model_id)
return self._client_cache[llm_model_id]
async def structured_response(
self,
llm_model_id: str,
system_prompt: str,
response_model: Type[T],
user_message: Optional[str] = None,
fallback_value: Optional[Any] = None
) -> T:
"""
统一的结构化响应接口
Args:
llm_model_id: LLM模型ID
system_prompt: 系统提示词
response_model: 响应模型类
user_message: 用户消息(可选)
fallback_value: 失败时的降级值
Returns:
结构化响应对象
"""
try:
llm_client = self._get_cached_client(llm_model_id)
messages = [{"role": "system", "content": system_prompt}]
if user_message:
messages.append({"role": "user", "content": user_message})
logger.debug(f"LLM调用: model={llm_model_id}, prompt_length={len(system_prompt)}")
structured = await llm_client.response_structured(
messages=messages,
response_model=response_model
)
if structured is None:
logger.warning(f"LLM返回None使用降级值")
return self._create_fallback_response(response_model, fallback_value)
return structured
except Exception as e:
logger.error(f"结构化响应失败: {e}", exc_info=True)
return self._create_fallback_response(response_model, fallback_value)
async def batch_structured_response(
self,
llm_model_id: str,
requests: List[Dict[str, Any]],
response_model: Type[T],
max_concurrent: int = 5
) -> List[T]:
"""
批量处理结构化响应
Args:
llm_model_id: LLM模型ID
requests: 请求列表每个请求包含system_prompt等参数
response_model: 响应模型类
max_concurrent: 最大并发数
Returns:
结构化响应列表
"""
semaphore = asyncio.Semaphore(max_concurrent)
async def process_single_request(request: Dict[str, Any]) -> T:
async with semaphore:
return await self.structured_response(
llm_model_id=llm_model_id,
system_prompt=request.get('system_prompt', ''),
response_model=response_model,
user_message=request.get('user_message'),
fallback_value=request.get('fallback_value')
)
tasks = [process_single_request(req) for req in requests]
return await asyncio.gather(*tasks)
async def simple_response(
self,
llm_model_id: str,
system_prompt: str,
user_message: Optional[str] = None,
fallback_message: str = "信息不足,无法回答"
) -> str:
"""
简单的文本响应接口
Args:
llm_model_id: LLM模型ID
system_prompt: 系统提示词
user_message: 用户消息(可选)
fallback_message: 失败时的降级消息
Returns:
响应文本
"""
try:
llm_client = self._get_cached_client(llm_model_id)
messages = [{"role": "system", "content": system_prompt}]
if user_message:
messages.append({"role": "user", "content": user_message})
response = await llm_client.response(messages=messages)
if not response or not response.strip():
return fallback_message
return response.strip()
except Exception as e:
logger.error(f"简单响应失败: {e}", exc_info=True)
return fallback_message
def _create_fallback_response(self, response_model: Type[T], fallback_value: Optional[Any]) -> T:
"""创建降级响应"""
try:
if fallback_value is not None:
if isinstance(fallback_value, response_model):
return fallback_value
elif isinstance(fallback_value, dict):
return response_model(**fallback_value)
# 尝试创建空的响应模型
if hasattr(response_model, 'root'):
# RootModel类型
return response_model([])
else:
# 普通BaseModel类型
return response_model()
except Exception as e:
logger.error(f"创建降级响应失败: {e}")
# 最后的降级策略
if hasattr(response_model, 'root'):
return response_model([])
else:
return response_model()
def clear_cache(self):
"""清理客户端缓存"""
self._client_cache.clear()
class LLMServiceMixin:
"""
LLM服务混入类为节点提供便捷的LLM调用方法
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._llm_service: Optional[OptimizedLLMService] = None
def get_llm_service(self, db_session: Session) -> OptimizedLLMService:
"""获取LLM服务实例"""
if self._llm_service is None:
self._llm_service = OptimizedLLMService(db_session)
return self._llm_service
async def call_llm_structured(
self,
state: Dict[str, Any],
db_session: Session,
system_prompt: str,
response_model: Type[T],
user_message: Optional[str] = None,
fallback_value: Optional[Any] = None
) -> T:
"""
便捷的结构化LLM调用方法
Args:
state: 状态字典包含memory_config
db_session: 数据库会话
system_prompt: 系统提示词
response_model: 响应模型类
user_message: 用户消息(可选)
fallback_value: 失败时的降级值
Returns:
结构化响应对象
"""
memory_config = state.get('memory_config')
if not memory_config:
raise ValueError("State中缺少memory_config")
llm_model_id = memory_config.llm_model_id
if not llm_model_id:
raise ValueError("Memory config中缺少llm_model_id")
llm_service = self.get_llm_service(db_session)
return await llm_service.structured_response(
llm_model_id=llm_model_id,
system_prompt=system_prompt,
response_model=response_model,
user_message=user_message,
fallback_value=fallback_value
)
async def call_llm_simple(
self,
state: Dict[str, Any],
db_session: Session,
system_prompt: str,
user_message: Optional[str] = None,
fallback_message: str = "信息不足,无法回答"
) -> str:
"""
便捷的简单LLM调用方法
Args:
state: 状态字典包含memory_config
db_session: 数据库会话
system_prompt: 系统提示词
user_message: 用户消息(可选)
fallback_message: 失败时的降级消息
Returns:
响应文本
"""
memory_config = state.get('memory_config')
if not memory_config:
raise ValueError("State中缺少memory_config")
llm_model_id = memory_config.llm_model_id
if not llm_model_id:
raise ValueError("Memory config中缺少llm_model_id")
llm_service = self.get_llm_service(db_session)
return await llm_service.simple_response(
llm_model_id=llm_model_id,
system_prompt=system_prompt,
user_message=user_message,
fallback_message=fallback_message
)
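A usage sketch for the service; the Pydantic model, model id, and session handling are illustrative:
from pydantic import BaseModel

class Answer(BaseModel):
    query_answer: str = ""

async def demo(db_session) -> str:
    service = OptimizedLLMService(db_session)
    result = await service.structured_response(
        llm_model_id="llm-123",  # illustrative model id
        system_prompt="Answer in one sentence.",
        response_model=Answer,
        user_message="What does this service cache?",
        fallback_value={"query_answer": "信息不足,无法回答"},
    )
    return result.query_answer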

View File

@@ -4,22 +4,19 @@ Parameter Builder for constructing tool call arguments.
This service provides tool-specific parameter transformation logic
to build correct arguments for each tool type.
"""
from typing import Any, Dict, Optional
from app.core.logging_config import get_agent_logger
from app.schemas.memory_config_schema import MemoryConfig
logger = get_agent_logger(__name__)
class ParameterBuilder:
"""Service for building tool call arguments based on tool type."""
def __init__(self):
"""Initialize the parameter builder."""
logger.info("ParameterBuilder initialized")
def build_tool_args(
self,
tool_name: str,
@@ -28,9 +25,8 @@ class ParameterBuilder:
search_switch: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
user_rag_memory_id: Optional[str] = None
) -> Dict[str, Any]:
"""
Build tool arguments based on tool type.
@@ -49,7 +45,6 @@ class ParameterBuilder:
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
@@ -60,19 +55,18 @@ class ParameterBuilder:
base_args = {
"usermessages": tool_call_id,
"apply_id": apply_id,
"group_id": group_id,
"memory_config": memory_config,
"group_id": group_id
}
# Always add storage_type and user_rag_memory_id (with defaults if None)
base_args["storage_type"] = storage_type if storage_type is not None else ""
base_args["user_rag_memory_id"] = user_rag_memory_id if user_rag_memory_id is not None else ""
# Tool-specific argument construction
if tool_name in ["Verify", "Summary", "Summary_fails", "Retrieve_Summary", "Problem_Extension"]:
# These tools expect dict context
if tool_name in ["Verify","Summary", "Summary_fails",'Retrieve_Summary']:
# Verify expects dict context
return {
"context": content if isinstance(content, dict) else {"content": content},
"context": content if isinstance(content, dict) else {},
**base_args
}

View File

@@ -4,31 +4,21 @@ Search Service for executing hybrid search and processing results.
This service provides clean search result processing with content extraction
and deduplication.
"""
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import List, Tuple, Optional
from app.core.logging_config import get_agent_logger
from app.core.memory.src.search import run_hybrid_search
from app.core.memory.utils.data.text_utils import escape_lucene_query
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
logger = get_agent_logger(__name__)
class SearchService:
"""Service for executing hybrid search and processing results."""
def __init__(self, memory_config: "MemoryConfig" = None):
"""
Initialize the search service.
Args:
memory_config: Optional MemoryConfig for embedding model configuration.
If not provided, must be passed to execute_hybrid_search.
"""
self.memory_config = memory_config
def __init__(self):
"""Initialize the search service."""
logger.info("SearchService initialized")
def extract_content_from_result(self, result: dict) -> str:
@@ -103,49 +93,40 @@ class SearchService:
self,
group_id: str,
question: str,
limit: int = 15,
limit: int = 5,
search_type: str = "hybrid",
include: Optional[List[str]] = None,
rerank_alpha: float = 0.6,
activation_boost_factor: float = 0.8,
rerank_alpha: float = 0.4,
output_path: str = "search_results.json",
return_raw_results: bool = False,
memory_config: "MemoryConfig" = None,
memory_config = None
) -> Tuple[str, str, Optional[dict]]:
"""
Execute hybrid search with two-stage ranking.
Stage 1: Filter by content relevance (BM25 + Embedding)
Stage 2: Rerank by activation values (ACTR)
Execute hybrid search and return clean content.
Args:
group_id: Group identifier for filtering
group_id: Group identifier for filtering results
question: Search query text
limit: Max results per category (default: 15)
search_type: "hybrid", "keyword", or "embedding" (default: "hybrid")
include: Result types (default: ["statements", "chunks", "entities", "summaries"])
rerank_alpha: BM25 weight (default: 0.6)
activation_boost_factor: Activation impact on memory strength (default: 0.8)
output_path: JSON output path (default: "search_results.json")
return_raw_results: Return full metadata (default: False)
memory_config: MemoryConfig for embedding model
limit: Maximum number of results to return (default: 5)
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
include: List of result types to include (default: ["statements", "chunks", "entities", "summaries"])
rerank_alpha: Weight for BM25 scores in reranking (default: 0.4)
output_path: Path to save search results (default: "search_results.json")
return_raw_results: If True, also return the raw search results as third element (default: False)
memory_config: Memory configuration object (required)
Returns:
Tuple[str, str, Optional[dict]]: (clean_content, cleaned_query, raw_results)
Tuple of (clean_content, cleaned_query, raw_results)
raw_results is None if return_raw_results=False
"""
if include is None:
include = ["statements", "chunks", "entities", "summaries"]
# Use provided memory_config or fall back to instance config
config = memory_config or self.memory_config
if not config:
raise ValueError("memory_config is required for search - either pass it to __init__ or execute_hybrid_search")
# Clean query
cleaned_query = self.clean_query(question)
try:
# Execute search using memory_config
# Execute search
answer = await run_hybrid_search(
query_text=cleaned_query,
search_type=search_type,
@@ -153,9 +134,8 @@ class SearchService:
limit=limit,
include=include,
output_path=output_path,
memory_config=config,
rerank_alpha=rerank_alpha,
activation_boost_factor=activation_boost_factor,
memory_config=memory_config,
rerank_alpha=rerank_alpha
)
# Extract results based on search type and include parameter

View File

@@ -3,12 +3,22 @@ Template Service for loading and rendering Jinja2 templates.
This service provides centralized template management with caching and error handling.
"""
import os
from functools import lru_cache
from typing import Optional
from jinja2 import Environment, FileSystemLoader, Template, TemplateNotFound
from app.core.logging_config import get_agent_logger, log_prompt_rendering
from jinja2 import (
Environment,
FileSystemLoader,
Template,
TemplateNotFound,
)
from app.core.logging_config import (
get_agent_logger,
log_prompt_rendering,
)
logger = get_agent_logger(__name__)

View File

@@ -1,7 +0,0 @@
"""Agent utilities."""
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
__all__ = [
"MultimodalProcessor",
]

View File

@@ -0,0 +1,56 @@
import asyncio
from typing import Dict
from app.core.memory.utils.llm.llm_utils import get_llm_client_fast
from app.db import get_db
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
class LLMClientPool:
"""LLM客户端连接池"""
def __init__(self, max_size: int = 5):
self.max_size = max_size
self.pools: Dict[str, asyncio.Queue] = {}
self.active_clients: Dict[str, int] = {}
async def get_client(self, llm_model_id: str):
"""获取LLM客户端"""
if llm_model_id not in self.pools:
self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size)
self.active_clients[llm_model_id] = 0
pool = self.pools[llm_model_id]
try:
# 尝试从池中获取客户端
client = pool.get_nowait()
logger.debug(f"从池中获取LLM客户端: {llm_model_id}")
return client
except asyncio.QueueEmpty:
# 池为空,创建新客户端
if self.active_clients[llm_model_id] < self.max_size:
db_session = next(get_db())
client = get_llm_client_fast(llm_model_id, db_session)
self.active_clients[llm_model_id] += 1
logger.debug(f"创建新LLM客户端: {llm_model_id}")
return client
else:
# 等待可用客户端
logger.debug(f"等待LLM客户端可用: {llm_model_id}")
return await pool.get()
async def return_client(self, llm_model_id: str, client):
"""归还LLM客户端到池中"""
if llm_model_id in self.pools:
try:
self.pools[llm_model_id].put_nowait(client)
logger.debug(f"归还LLM客户端到池: {llm_model_id}")
except asyncio.QueueFull:
# 池已满,丢弃客户端
self.active_clients[llm_model_id] -= 1
logger.debug(f"池已满丢弃LLM客户端: {llm_model_id}")
# 全局客户端池
llm_client_pool = LLMClientPool()
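Borrow and return must always be paired or the pool leaks capacity. A usage sketch (the model id is illustrative, and the pooled client is assumed to expose the same async response() API used by OptimizedLLMService above):
async def demo() -> str:
    client = await llm_client_pool.get_client("llm-123")  # illustrative model id
    try:
        # assumes the pooled client exposes an async response() method
        return await client.response(messages=[{"role": "user", "content": "hi"}])
    finally:
        await llm_client_pool.return_client("llm-123", client)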

View File

@@ -1,40 +1,12 @@
import asyncio
import json
import logging
import os
from collections import defaultdict
from typing import Annotated, TypedDict
from app.core.memory.agent.utils.messages_tool import read_template_file
from app.core.memory.utils.config.config_utils import (
get_picture_config,
get_voice_config,
)
# Removed global variable imports - use dependency injection instead
from dotenv import load_dotenv
from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages
from openai import OpenAI
PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
logger = logging.getLogger(__name__)
load_dotenv()
async def picture_model_requests(image_url):
'''
Run image recognition on the given image URL.
Args:
image_url: URL of the image to recognize
Returns:
Recognized statement text
'''
file_path = PROJECT_ROOT_ + '/agent/utils/prompt/Template_for_image_recognition_prompt.jinja2'
system_prompt = await read_template_file(file_path)
result = await Picture_recognize(image_url, system_prompt)
return result
class WriteState(TypedDict):
'''
LangGraph write-workflow TypedDict
@@ -44,39 +16,69 @@ class WriteState(TypedDict):
apply_id:str
group_id:str
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
memory_config: object
write_result: dict
data:str
class ReadState(TypedDict):
'''
LangGraph read-workflow TypedDict
name:
id: user id
loop_count: traversal count
search_switch: search type
config_id: configuration id for filtering results
errors: list of errors that occurred during workflow execution
'''
messages: Annotated[list[AnyMessage], add_messages] # append-mode message accumulation
name: str
id: str
loop_count: int
"""
LangGraph 工作流状态定义
Attributes:
messages: 消息列表,支持自动追加
loop_count: 遍历次数
search_switch: 搜索类型开关
group_id: 组标识
config_id: 配置ID用于过滤结果
data: 从content_input_node传递的内容数据
spit_data: 从Split_The_Problem传递的分解结果
tool_calls: 工具调用请求列表
tool_results: 工具执行结果列表
memory_config: 内存配置对象
"""
messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式
loop_count: int
search_switch: str
user_id: str
apply_id: str
group_id: str
config_id: str
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
data: str # content handed between nodes
spit_data: dict # problem-decomposition result
problem_extension: dict
storage_type: str
user_rag_memory_id: str
llm_id: str
embedding_id: str
memory_config: object # memory configuration object
retrieve: dict
RetrieveSummary: dict
InputSummary: dict
verify: dict
SummaryFails: dict
summary: dict
class COUNTState:
'''
Counts traversals where workflow dialogue retrieval recalled no correct message.
'''
"""
Counter for workflow dialogue retrieval content.
Records how many traversals of workflow dialogue retrieval recalled no correct message.
"""
def __init__(self, limit: int = 5):
"""
Initialize the counter.
Args:
limit: maximum count limit (default 5)
"""
self.total: int = 0 # current accumulated value
self.limit: int = limit # upper bound
def add(self, value: int = 1):
"""Accumulate; stay at the maximum once the limit is reached."""
def add(self, value: int = 1) -> None:
"""
Accumulate; stay at the maximum once the limit is reached.
Args:
value: amount to add (default 1)
"""
self.total += value
print(f"[COUNTState] current value: {self.total}")
if self.total >= self.limit:
@@ -84,21 +86,19 @@ class COUNTState:
self.total = self.limit # stop growing once the limit is reached
def get_total(self) -> int:
"""Return the current accumulated value."""
"""
Return the current accumulated value.
Returns:
Current accumulated value
"""
return self.total
def reset(self):
def reset(self) -> None:
"""Manually reset the accumulated value."""
self.total = 0
print("[COUNTState] reset to 0")
def merge_to_key_value_pairs(data, query_key, result_key):
grouped = defaultdict(list)
for item in data:
grouped[item[query_key]].append(item[result_key])
return [{key: values} for key, values in grouped.items()]
def deduplicate_entries(entries):
seen = set()
deduped = []
@@ -109,70 +109,37 @@ def deduplicate_entries(entries):
deduped.append(entry)
return deduped
def merge_to_key_value_pairs(data, query_key, result_key):
grouped = defaultdict(list)
for item in data:
grouped[item[query_key]].append(item[result_key])
return [{key: values} for key, values in grouped.items()]
async def Picture_recognize(image_path, PROMPT_TICKET_EXTRACTION, picture_model_name: str) -> str:
def convert_extended_question_to_question(data):
"""
Updated to eliminate global variables in favor of explicit parameters.
Recursively convert 'extended_question' fields in the data to 'question' fields.
Args:
image_path: Path to image file
PROMPT_TICKET_EXTRACTION: Extraction prompt
picture_model_name: Picture model name (required, no longer from global variables)
data: data to convert (may be a dict, list, or other type)
Returns:
Converted data
"""
try:
model_config = get_picture_config(picture_model_name)
except Exception as e:
err = f"LLM配置不可用{str(e)}。请检查 config.json 和 runtime.json。"
logger.error(err)
return err
api_key = os.getenv(model_config["api_key"]) # read this backend's API key from the environment
backend_model_name = model_config["llm_name"].split("/")[-1]
api_base=model_config['api_base']
logger.info(f"model_name: {backend_model_name}")
logger.info(f"api_key set: {'yes' if api_key else 'no'}")
logger.info(f"base_url: {model_config['api_base']}")
client = OpenAI(
api_key=api_key, base_url=api_base,
)
completion = client.chat.completions.create(
model=backend_model_name,
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url":image_path,
},
{"type": "text",
"text": PROMPT_TICKET_EXTRACTION}
]
}
])
picture_text = completion.choices[0].message.content
picture_text = picture_text.replace('```json', '').replace('```', '')
picture_text = json.loads(picture_text)
return (picture_text['statement'])
async def Voice_recognize(voice_model_name: str):
"""
Updated to eliminate global variables in favor of explicit parameters.
Args:
voice_model_name: Voice model name (required, no longer from global variables)
"""
try:
model_config = get_voice_config(voice_model_name)
except Exception as e:
err = f"LLM配置不可用{str(e)}。请检查 config.json 和 runtime.json。"
logger.error(err)
return err
api_key = os.getenv(model_config["api_key"]) # read this backend's API key from the environment
backend_model_name = model_config["llm_name"].split("/")[-1]
api_base = model_config['api_base']
return api_key,backend_model_name,api_base
if isinstance(data, dict):
# Build a new dict for the converted data
converted = {}
for key, value in data.items():
if key == 'extended_question':
# Rename extended_question to question
converted['question'] = convert_extended_question_to_question(value)
else:
# Recurse into other fields
converted[key] = convert_extended_question_to_question(value)
return converted
elif isinstance(data, list):
# Recurse into each list element
return [convert_extended_question_to_question(item) for item in data]
else:
# Other types are returned unchanged
return data
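The recursive rename behaves like this on nested input (illustrative):
data = {"extended_question": [{"extended_question": "q1", "type": "fact"}]}
converted = convert_extended_question_to_question(data)
# converted == {"question": [{"question": "q1", "type": "fact"}]}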

View File

@@ -1,33 +0,0 @@
import os
from app.core.config import settings
def get_mcp_server_config():
"""
Get the MCP server configuration.
Uses MCP_SERVER_URL environment variable if set (for Docker),
otherwise falls back to SERVER_IP and MCP_PORT (for local development).
"""
# Get MCP port from environment (default: 8081)
mcp_port = os.getenv("MCP_PORT", "8081")
# In Docker: MCP_SERVER_URL=http://mcp-server:8081
# In local dev: uses SERVER_IP (127.0.0.1 or localhost)
mcp_server_url = os.getenv("MCP_SERVER_URL")
if mcp_server_url:
# Docker environment: use full URL from environment
base_url = mcp_server_url
else:
# Local development: build URL from SERVER_IP and MCP_PORT
base_url = f"http://{settings.SERVER_IP}:{mcp_port}"
mcp_server_config = {
"data_flow": {
"url": f"{base_url}/sse",
"transport": "sse",
"timeout": 15000,
"sse_read_timeout": 15000,
}
}
return mcp_server_config
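Resolution order at a glance (values illustrative):
import os

os.environ["MCP_SERVER_URL"] = "http://mcp-server:8081"  # Docker-style override
config = get_mcp_server_config()
assert config["data_flow"]["url"] == "http://mcp-server:8081/sse"
# Without MCP_SERVER_URL, the URL falls back to
# f"http://{settings.SERVER_IP}:{MCP_PORT or 8081}/sse"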

View File

@@ -1,260 +0,0 @@
import json
import logging
import re
from typing import Any, List
from app.core.logging_config import get_agent_logger
from langchain_core.messages import AnyMessage
logger = get_agent_logger(__name__)
def _to_openai_messages(msgs: List[AnyMessage]) -> List[dict]:
out = []
for m in msgs:
if hasattr(m, "content"):
out.append({"role": "user", "content": getattr(m, "content", "")})
elif isinstance(m, dict) and "role" in m and "content" in m:
out.append(m)
else:
out.append({"role": "user", "content": str(m)})
return out
def _extract_content(resp: Any) -> str:
"""Extract LLM content and sanitize to raw JSON/text.
- Supports both object and dict response shapes.
- Removes leading role labels (e.g., "Assistant:").
- Strips Markdown code fences like ```json ... ```.
- Attempts to isolate the first valid JSON array/object block when extra text is present.
"""
def _to_text(r: Any) -> str:
try:
# 对象形式: resp.choices[0].message.content
if hasattr(r, "choices") and getattr(r, "choices", None):
msg = r.choices[0].message
if hasattr(msg, "content"):
return msg.content
if isinstance(msg, dict) and "content" in msg:
return msg["content"]
# 字典形式: resp["choices"][0]["message"]["content"]
if isinstance(r, dict):
return r.get("choices", [{}])[0].get("message", {}).get("content", "")
except Exception:
pass
return str(r)
def _clean_text(text: str) -> str:
s = str(text).strip()
# 移除可能的角色前缀
s = re.sub(r"^\s*(Assistant|assistant)\s*:\s*", "", s)
# 提取 ```json ... ``` 代码块
m = re.search(r"```json\s*(.*?)\s*```", s, flags=re.S | re.I)
if m:
s = m.group(1).strip()
# 如果仍然包含多余文本,尝试截取第一个 JSON 数组/对象片段
if not (s.startswith("{") or s.startswith("[")):
left = s.find("[")
right = s.rfind("]")
if left != -1 and right != -1 and right > left:
s = s[left:right + 1].strip()
else:
left = s.find("{")
right = s.rfind("}")
if left != -1 and right != -1 and right > left:
s = s[left:right + 1].strip()
return s
raw = _to_text(resp)
return _clean_text(raw)
def Resolve_username(usermessages):
'''
Extract username
Args:
usermessages: user name
Returns:
'''
usermessages = usermessages.split('_')[1:]
sessionid = '_'.join(usermessages[:-1])
return sessionid
# TODO: USE app.core.memory.src.utils.render_template instead
async def read_template_file(template_path: str) -> str:
"""
读取模板文件
Args:
template_path: 模板文件路径
Returns:
模板内容字符串
Note:
建议使用 app.core.memory.utils.template_render 中的统一模板渲染功能
"""
try:
with open(template_path, "r", encoding="utf-8") as f:
return f.read()
except FileNotFoundError:
logger.error(f"模板文件未找到: {template_path}")
raise
except IOError as e:
logger.error(f"读取模板文件失败: {template_path}, 错误: {str(e)}", exc_info=True)
raise
async def Problem_Extension_messages_deal(context):
'''
Extract data
Args:
context:
Returns:
'''
extent_quest = []
original = context.get('original', '')
messages = context.get('context', '')
# Handle empty or non-string messages
if not messages:
return extent_quest, original
if isinstance(messages, str):
try:
messages = json.loads(messages)
except json.JSONDecodeError:
# If JSON parsing fails, return empty list
return extent_quest, original
if isinstance(messages, list):
for message in messages:
question = message.get('question', '')
type = message.get('type', '')
extent_quest.append({"role": "user", "content": f"问题:{question};问题类型:{type}"})
return extent_quest, original
async def Retriev_messages_deal(context):
'''
Extract data
Args:
context:
Returns:
'''
logger.info(f"Retriev_messages_deal input: type={type(context)}, value={str(context)[:500]}")
if isinstance(context, dict):
logger.info(f"Retriev_messages_deal: context is dict with keys={list(context.keys())}")
if 'context' in context or 'original' in context:
content = context.get('context', {})
original = context.get('original', '')
logger.info(f"Retriev_messages_deal output: content_type={type(content)}, content={str(content)[:300]}, original='{original[:50] if original else ''}'")
return content, original
# Return empty defaults if context is not a dict or doesn't have expected keys
logger.warning(f"Retriev_messages_deal: context missing expected keys, returning empty defaults")
return {}, ''
async def Verify_messages_deal(context):
'''
Extract data
Args:
context:
Returns:
'''
query = context['context']['Query']
Query_small_list = context['context']['Expansion_issue']
Result_small = []
Query_small = []
for i in Query_small_list:
Result_small.append(i['Answer_Small'][0])
Query_small.append(i['Query_small'])
return Query_small, Result_small, query
async def Summary_messages_deal(context):
'''
Extract data
Args:
context:
Returns:
'''
messages = str(context).replace('\\n', '').replace('\n', '').replace('\\', '')
query = re.findall(r'"query": (.*?),', messages)[0]
query = query.replace('[', '').replace(']', '').strip()
matches = re.findall(r'"answer_small"\s*:\s*"(\[.*?\])"', messages)
answer_small_texts = []
for m in matches:
try:
parsed = json.loads(m)
for item in parsed:
answer_small_texts.append(item.strip().replace('\\', '').replace('[', '').replace(']', ''))
except Exception:
answer_small_texts.append(m.strip().replace('\\', '').replace('[', '').replace(']', ''))
return answer_small_texts, query
async def VerifyTool_messages_deal(context):
'''
Extract data
Args:
context:
Returns:
'''
messages = str(context).replace('\\n', '').replace('\n', '').replace('\\', '')
content_messages = messages.split('"context":')[1].replace('""', '"')
messages = str(content_messages).split("name='Retrieve'")[0]
query = re.findall('"Query": "(.*?)"', messages)[0]
Query_small = re.findall('"Query_small": "(.*?)"', messages)
Result_small = re.findall('"Result_small": "(.*?)"', messages)
return Query_small, Result_small, query
async def Retrieve_Summary_messages_deal(context):
pass
async def Retrieve_verify_tool_messages_deal(context, history, query):
'''
Extract data
Args:
context:
Returns:
'''
results = []
# 统一转为字符串,避免 None 或非字符串导致正则报错
text = str(context)
blocks = re.findall(r'\{(.*?)\}', text, flags=re.S)
for block in blocks:
query_small = re.search(r'"Query_small"\s*:\s*"([^"]*)"', block)
answer_small = re.search(r'"Answer_Small"\s*:\s*(\[[^\]]*\])', block)
status = re.search(r'"status"\s*:\s*"([^"]*)"', block)
query_answer = re.search(r'"Query_answer"\s*:\s*"([^"]*)"', block)
results.append({
"query_small": query_small.group(1) if query_small else None,
"answer_small": answer_small.group(1) if answer_small else None,
# 将缺失的 status 统一为空字符串,后续用字符串判定,避免 NoneType 错误
"status": status.group(1) if status else "",
"query_answer": query_answer.group(1) if query_answer else None
})
result = []
for r in results:
# 统一按字符串判定状态,兼容大小写和缺失情况
status_str = str(r.get('status', '')).strip().lower()
if status_str == 'false':
continue
else:
result.append(r)
split_result = 'failed' if not result else 'success'
result = {"data": {"query": query, "expansion_issue": result}, "split_result": split_result, "reason": "",
"history": history}
return result
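Resolve_username drops the leading and trailing underscore-delimited tokens of a tool-call id, e.g.:
Resolve_username("call_user_42_abc123")
# split('_')[1:]    -> ["user", "42", "abc123"]
# join all but last -> "user_42"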

View File

@@ -0,0 +1,194 @@
from typing import List, Dict, Any
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
async def read_template_file(template_path: str) -> str:
"""
Read a template file.
Args:
template_path: template file path
Returns:
Template content as a string
Note:
Prefer the unified template rendering in app.core.memory.utils.template_render
"""
try:
with open(template_path, "r", encoding="utf-8") as f:
return f.read()
except FileNotFoundError:
logger.error(f"Template file not found: {template_path}")
raise
except IOError as e:
logger.error(f"Failed to read template file: {template_path}, error: {str(e)}", exc_info=True)
raise
def reorder_output_results(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Reorder output results so that retrieval_summary entries come last.
Args:
results: original list of output results
Returns:
Reordered result list
"""
retrieval_summaries = []
other_results = []
# Separate retrieval_summary results from the rest
for result in results:
if 'summary' in result.get('type', ''):
retrieval_summaries.append(result)
else:
other_results.append(result)
# Append retrieval summaries at the end
return other_results + retrieval_summaries
def optimize_search_results(intermediate_outputs):
"""
优化检索结果,合并多个搜索结果,过滤空结果,统一格式
Args:
intermediate_outputs: 原始的中间输出列表
Returns:
优化后的检索结果列表
"""
optimized_results = []
for item in intermediate_outputs:
if not item or item == [] or item == {}:
continue
# 检查是否是搜索结果类型
if isinstance(item, dict) and item.get('type') == 'search_result':
raw_results = item.get('raw_results', {})
# 如果 raw_results 为空,跳过
if not raw_results or raw_results == [] or raw_results == {}:
continue
# 创建优化后的结果结构
optimized_item = {
"type": "search_result",
"title": f"检索结果 ({item.get('index', 1)}/{item.get('total', 1)})",
"query": item.get('query', ''),
"raw_results": {},
"index": item.get('index', 1),
"total": item.get('total', 1)
}
# Merge every non-empty search type into a single raw_results dict
merged_raw_results = {}
for search_type in ('time_search', 'keyword_search', 'embedding_search', 'combined_summary', 'reranked_results'):
if raw_results.get(search_type):
merged_raw_results[search_type] = raw_results[search_type]
# Only keep items whose merged results are non-empty
if merged_raw_results:
optimized_item['raw_results'] = merged_raw_results
optimized_results.append(optimized_item)
else:
# Not a search result; pass it through unchanged
optimized_results.append(item)
return optimized_results
def merge_multiple_search_results(intermediate_outputs):
"""
Merge multiple search results into a single unified search result.
Args:
intermediate_outputs: raw list of intermediate outputs
Returns:
Merged result list
"""
search_results = []
other_results = []
# Separate search results from other results
for item in intermediate_outputs:
if isinstance(item, dict) and item.get('type') == 'search_result':
raw_results = item.get('raw_results', {})
# Keep only search results that have content
if raw_results and raw_results != [] and raw_results != {}:
search_results.append(item)
else:
other_results.append(item)
# No search results: return the input unchanged
if not search_results:
return intermediate_outputs
# A single search result only needs format optimization
if len(search_results) == 1:
optimized = optimize_search_results(search_results)
return other_results + optimized
# Merge multiple search results
merged_raw_results = {}
all_queries = []
for result in search_results:
query = result.get('query', '')
if query:
all_queries.append(query)
raw_results = result.get('raw_results', {})
# Merge results per search type
for search_type in ['time_search', 'keyword_search', 'embedding_search', 'combined_summary',
'reranked_results']:
if search_type in raw_results and raw_results[search_type]:
if search_type not in merged_raw_results:
merged_raw_results[search_type] = raw_results[search_type]
else:
# Dicts are merged key by key
if isinstance(raw_results[search_type], dict) and isinstance(merged_raw_results[search_type], dict):
for key, value in raw_results[search_type].items():
if key not in merged_raw_results[search_type]:
merged_raw_results[search_type][key] = value
elif isinstance(value, list) and isinstance(merged_raw_results[search_type][key], list):
merged_raw_results[search_type][key].extend(value)
elif isinstance(raw_results[search_type], list):
if isinstance(merged_raw_results[search_type], list):
merged_raw_results[search_type].extend(raw_results[search_type])
else:
merged_raw_results[search_type] = raw_results[search_type]
# Build the merged result
if merged_raw_results:
merged_result = {
"type": "search_result",
"title": f"合并检索结果 (共{len(search_results)}个查询)",
"query": " | ".join(all_queries),
"raw_results": merged_raw_results,
"index": 1,
"total": 1
}
return other_results + [merged_result]
return other_results
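A quick illustration with hypothetical payloads: two search_result items collapse into one, and list-valued search types are concatenated (note the merge extends the first item's lists in place).

outputs = [
    {"type": "search_result", "query": "q1", "raw_results": {"keyword_search": ["a"]}},
    {"type": "search_result", "query": "q2", "raw_results": {"keyword_search": ["b"], "time_search": ["t"]}},
    {"type": "note", "text": "unrelated"},
]
merged = merge_multiple_search_results(outputs)
# merged[-1]["query"] == "q1 | q2"
# merged[-1]["raw_results"] == {"keyword_search": ["a", "b"], "time_search": ["t"]}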

View File

@@ -1,38 +0,0 @@
# project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# sys.path.insert(0, project_root)
# load_dotenv()
# async def llm_client_chat(messages: List[dict]) -> str:
# """使用 OpenAI 兼容接口进行对话,返回内容字符串。"""
# try:
# cfg = get_model_config(SELECTED_LLM_ID)
# rb_config = RedBearModelConfig(
# model_name=cfg["model_name"],
# provider=cfg["provider"],
# api_key=cfg["api_key"],
# base_url=cfg["base_url"],
# )
# client = OpenAIClient(model_config=rb_config, type_="chat")
# except Exception as e:
# logger.error(f"获取模型配置失败:{e}")
# err = f"获取模型配置失败:{str(e)}。请检查!!!"
# return err
# try:
# response = await client.chat(messages)
# print(f"model_tool's llm_client_chat response ======>:\n {response}")
# return _extract_content(response)
# # return _extract_content(result)
# except Exception as e:
# logger.error(f"LLM调用失败{str(e)}。请检查 model_name、api_key、api_base 是否正确。")
# return f"LLM调用失败{str(e)}。请检查 model_name、api_key、api_base 是否正确。"
# async def main(image_url):
# await llm_client_chat(image_url)
#
# # 运行主函数
# asyncio.run(main(['https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_male2.wav']))
#

View File

@@ -1,131 +0,0 @@
"""
Multimodal input processor for handling image and audio content.
This module provides utilities for detecting and processing multimodal inputs
(images and audio files) by converting them to text using appropriate models.
"""
import logging
from typing import List
from app.core.memory.agent.multimodal.speech_model import Vico_recognition
from app.core.memory.agent.utils.llm_tools import picture_model_requests
logger = logging.getLogger(__name__)
class MultimodalProcessor:
"""
Processor for handling multimodal inputs (images and audio).
This class detects image and audio file paths in input content and converts
them to text using appropriate recognition models.
"""
# Supported file extensions
IMAGE_EXTENSIONS = ['.jpg', '.png']
AUDIO_EXTENSIONS = [
'aac', 'amr', 'avi', 'flac', 'flv', 'm4a', 'mkv', 'mov',
'mp3', 'mp4', 'mpeg', 'ogg', 'opus', 'wav', 'webm', 'wma', 'wmv'
]
def __init__(self):
"""Initialize the multimodal processor."""
pass
def is_image(self, content: str) -> bool:
"""
Check if content is an image file path.
Args:
content: Input string to check
Returns:
True if content ends with a supported image extension
Examples:
>>> processor = MultimodalProcessor()
>>> processor.is_image("photo.jpg")
True
>>> processor.is_image("document.pdf")
False
"""
if not isinstance(content, str):
return False
content_lower = content.lower()
return any(content_lower.endswith(ext) for ext in self.IMAGE_EXTENSIONS)
def is_audio(self, content: str) -> bool:
"""
Check if content is an audio file path.
Args:
content: Input string to check
Returns:
True if content ends with a supported audio extension
Examples:
>>> processor = MultimodalProcessor()
>>> processor.is_audio("recording.mp3")
True
>>> processor.is_audio("video.mp4")
True
>>> processor.is_audio("document.txt")
False
"""
if not isinstance(content, str):
return False
content_lower = content.lower()
return any(content_lower.endswith(f'.{ext}') for ext in self.AUDIO_EXTENSIONS)
async def process_input(self, content: str) -> str:
"""
Process input content, converting images/audio to text if needed.
This method detects if the input is an image or audio file and converts
it to text using the appropriate recognition model. If processing fails
or the content is not multimodal, it returns the original content.
Args:
content: Input string (may be file path or regular text)
Returns:
Text content (original or converted from image/audio)
Examples:
>>> processor = MultimodalProcessor()
>>> await processor.process_input("photo.jpg")
"Recognized text from image..."
>>> await processor.process_input("Hello world")
"Hello world"
"""
if not isinstance(content, str):
logger.warning(f"[MultimodalProcessor] Content is not a string: {type(content)}")
return str(content)
try:
# Check for image input
if self.is_image(content):
logger.info(f"[MultimodalProcessor] Detected image input: {content}")
result = await picture_model_requests(content)
logger.info(f"[MultimodalProcessor] Image recognition result: {result[:100]}...")
return result
# Check for audio input
if self.is_audio(content):
logger.info(f"[MultimodalProcessor] Detected audio input: {content}")
result = await Vico_recognition([content]).run()
logger.info(f"[MultimodalProcessor] Audio recognition result: {result[:100]}...")
return result
except Exception as e:
logger.error(f"[MultimodalProcessor] Error processing multimodal input: {e}", exc_info=True)
logger.info("[MultimodalProcessor] Falling back to original content")
return content
# Return original content if not multimodal
return content
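A minimal sketch of how the processor is typically driven; the file names are placeholders, and real image/audio inputs additionally require the recognition models to be configured.

import asyncio

processor = MultimodalProcessor()
print(processor.is_image("photo.jpg"))   # True
print(processor.is_audio("clip.mp3"))    # True
# Plain text passes through unchanged
print(asyncio.run(processor.process_input("Hello world")))  # "Hello world"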

View File

@@ -0,0 +1,56 @@
import time
import json
from collections import defaultdict
from typing import Dict, List
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
class ProblemExtensionMonitor:
"""Problem_Extension性能监控器"""
def __init__(self):
self.metrics = defaultdict(list)
self.slow_queries = []
self.error_count = 0
def record_execution(self, duration: float, question_count: int, success: bool):
"""记录执行指标"""
self.metrics['durations'].append(duration)
self.metrics['question_counts'].append(question_count)
if not success:
self.error_count += 1
# Track slow queries (longer than 10 seconds)
if duration > 10.0:
self.slow_queries.append({
'duration': duration,
'question_count': question_count,
'timestamp': time.time()
})
def get_stats(self) -> Dict:
"""获取统计信息"""
durations = self.metrics['durations']
if not durations:
return {"message": "暂无数据"}
return {
"total_executions": len(durations),
"avg_duration": sum(durations) / len(durations),
"max_duration": max(durations),
"min_duration": min(durations),
"slow_queries_count": len(self.slow_queries),
"error_rate": self.error_count / len(durations) if durations else 0,
"recent_slow_queries": self.slow_queries[-5:] # 最近5个慢查询
}
def log_stats(self):
"""记录统计信息到日志"""
stats = self.get_stats()
logger.info(f"Problem_Extension性能统计: {json.dumps(stats, indent=2)}")
# Global monitor instance
performance_monitor = ProblemExtensionMonitor()
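A usage sketch for the monitor (timings and counts are illustrative): callers time each Problem_Extension run, record the outcome, and periodically dump aggregate stats.

import time

start = time.time()
# ... run one Problem_Extension step here ...
performance_monitor.record_execution(duration=time.time() - start, question_count=3, success=True)
performance_monitor.log_stats()  # logs totals, avg/max/min duration, error rate, recent slow queries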

View File

@@ -0,0 +1,81 @@
你是一个高效的问题拆分助手,任务是根据用户提供的原始问题和问题类型,生成可操作的扩展问题,用于精确回答原问题。请严格遵循以下规则:
角色:
- 你是“问题拆分专家”,专注于逻辑、信息完整性和可操作性。
- 你能够结合【历史信息】、【上下文】、【背景知识】进行分析,以保持问题拆分的连贯性和相关性。
- 如果历史信息或上下文与当前问题无关,可忽略。
---
### 历史信息参考
在生成扩展问题时,你可以参考以下历史数据(如果提供):
- 历史对话或任务的主题;
- 历史中出现的关键实体(时间、人物、地点、研究主题等);
- 历史中已解答的问题(避免重复);
- 历史推理链(保持逻辑一致性)。
> 如果没有提供历史信息,则仅根据当前输入问题进行分析。
输入历史信息内容:{{history}}
## User Input
{% if questions is string %}
{{ questions }}
{% else %}
{% for question in questions %}
- {{ question }}
{% endfor %}
{% endif %}
需求:
- 如果问题是单跳问题(单步可答),直接保留原问题提取重要提问部分作为拆分/扩展问题。
- 如果问题是多跳问题(需多个信息点才能回答),对问题进行扩展拆分。
- 扩展问题必须完整覆盖原问题的所有关键要素,包括时间、主体、动作、目标等,不得遗漏。
- 扩展问题不得冗余:避免重复询问相同信息或过度拆分同一主题。
- 扩展问题必须高度相关:每个子问题直接服务于原问题,不引入未提及的新概念、人物或细节。
- 扩展问题必须可操作:每个子问题能在有限资源下独立解答。
- 子问题数量不超过4个。
- 拆分问题的时候可以考虑输入的历史内容,以保持逻辑连贯。
比如:输入历史信息内容:[{'Query': '4月27日我和你推荐过一本书,书名是什么', 'Answer': '张曼玉推荐了《小王子》'}]
拆分问题4月27日我和你推荐过一本书书名是什么可以拆分为4月27日张曼玉推荐过一本书书名是什么
输出要求:
- 仅输出 JSON 数组,不要包含任何解释或代码块。
- 每个元素包含:
- `original_question`: 原始问题
- `extended_question`: 扩展后的问题
- `type`: 类型(事实检索/澄清/定义/比较/行动建议)
- `reason`: 生成该扩展问题的简短理由
- 使用标准 ASCII 双引号,无换行;确保字符串正确关闭并以逗号分隔。
示例:
输入:
[
"问题:今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?;问题类型:多跳",
]
输出:
[
{
"original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?",
"extended_question": "今年诺贝尔物理学奖的获奖者有哪些人?",
"type": "多跳",
"reason": "输出原问题的关键要素"
},
{
"original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?",
"extended_question": "今年诺贝尔物理学奖的获奖者是因哪些具体贡献获奖的?",
"type": "多跳",
"reason": "输出原问题的关键要素"
}
]
**Output format**
**CRITICAL JSON FORMATTING REQUIREMENTS:**
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
3. Ensure all JSON strings are properly closed and comma-separated
4. Do not include line breaks within JSON string values
The output language should always be the same as the input language.{{ json_schema }}

View File

@@ -1,13 +1,10 @@
# 角色
你是一个专业的问答助手,擅长基于检索信息和历史对话回答用户问题。
# 任务
根据提供的上下文信息回答用户的问题。
# 输入信息
- 历史对话:{{history}}
- 检索信息:{{retrieve_info}}
## User Query
{{query}}

View File

@@ -9,8 +9,8 @@
3. 判断Answer_Small和Query_Small之间分析出来的关系状态
4. 如果是True保留否则不要相对应的问题和回答
5. 输出,需要严格按照模版
输入:{{history}}
历史消息:{"history":{{sentence}}}
输入:{{sentence}}
历史消息:{"history":{{history}}}
### 第一步 获取用户的输入
获取用户的输入提取对应的Query_Small和Answer_Small
### 第二步 分析验证

View File

@@ -0,0 +1,169 @@
"""
Session Service for managing user sessions and conversation history.
This service provides clean Redis interactions with error handling and
session management utilities.
"""
from typing import List, Optional
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.redis_tool import RedisSessionStore
logger = get_agent_logger(__name__)
class SessionService:
"""Service for managing user sessions and conversation history."""
def __init__(self, store: RedisSessionStore):
"""
Initialize the session service.
Args:
store: Redis session store instance
"""
self.store = store
logger.info("SessionService initialized")
def resolve_user_id(self, session_string: str) -> str:
"""
Extract user ID from session string.
Handles formats like:
- 'call_id_user123' -> 'user123'
- 'prefix_id_user456_suffix' -> 'user456_suffix'
Args:
session_string: Session identifier string
Returns:
Extracted user ID
"""
try:
# Split by '_id_' and take everything after it
parts = session_string.split('_id_')
if len(parts) > 1:
return parts[1]
# Fallback: return original string
return session_string
except Exception as e:
logger.warning(
f"Failed to parse user ID from session string '{session_string}': {e}"
)
return session_string
async def get_history(
self,
user_id: str,
apply_id: str,
group_id: str
) -> List[dict]:
"""
Retrieve conversation history from Redis.
Args:
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
Returns:
List of conversation history items with Query and Answer keys
Returns empty list if no history found or on error
"""
try:
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
# Validate history structure
if not isinstance(history, list):
logger.warning(
f"Invalid history format for user {user_id}, "
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
)
return []
return history
except Exception as e:
logger.error(
f"Failed to retrieve history for user {user_id}, "
f"apply {apply_id}, group {group_id}: {e}",
exc_info=True
)
# Return empty list on error to allow execution to continue
return []
async def save_session(
self,
user_id: str,
query: str,
apply_id: str,
group_id: str,
ai_response: str
) -> Optional[str]:
"""
Save conversation turn to Redis.
Args:
user_id: User identifier
query: User query/message
apply_id: Application identifier
group_id: Group identifier
ai_response: AI response/answer
Returns:
Session ID if successful, None on error
"""
try:
# Validate required fields
if not user_id:
logger.warning("Cannot save session: user_id is empty")
return None
if not query:
logger.warning("Cannot save session: query is empty")
return None
# Save session
session_id = self.store.save_session(
userid=user_id,
messages=query,
apply_id=apply_id,
group_id=group_id,
aimessages=ai_response
)
logger.info(f"Session saved successfully: {session_id}")
return session_id
except Exception as e:
logger.error(
f"Failed to save session for user {user_id}: {e}",
exc_info=True
)
return None
async def cleanup_duplicates(self) -> int:
"""
Remove duplicate session entries.
Duplicates are identified by matching:
- sessionid
- user_id (id field)
- group_id
- messages
- aimessages
Returns:
Number of duplicate sessions deleted
"""
try:
deleted_count = self.store.delete_duplicate_sessions()
logger.info(f"Cleaned up {deleted_count} duplicate sessions")
return deleted_count
except Exception as e:
logger.error(f"Failed to cleanup duplicate sessions: {e}", exc_info=True)
return 0
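A minimal usage sketch, assuming a configured RedisSessionStore instance named store; the apply/group identifiers are placeholders.

import asyncio

service = SessionService(store)
user_id = service.resolve_user_id("call_id_user123")   # -> "user123"
history = asyncio.run(service.get_history(user_id, apply_id="app1", group_id="g1"))
asyncio.run(service.save_session(user_id, "hello", "app1", "g1", "hi there"))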

View File

@@ -0,0 +1,117 @@
"""
Template Service for loading and rendering Jinja2 templates.
This service provides centralized template management with caching and error handling.
"""
# Standard library imports
import os
from functools import lru_cache
from jinja2 import Environment, FileSystemLoader, Template, TemplateNotFound
from app.core.logging_config import get_agent_logger, log_prompt_rendering
logger = get_agent_logger(__name__)
class TemplateRenderError(Exception):
"""Exception raised when template rendering fails."""
def __init__(self, template_name: str, error: Exception, variables: dict):
self.template_name = template_name
self.error = error
self.variables = variables
super().__init__(
f"Failed to render template '{template_name}': {str(error)}"
)
class TemplateService:
"""Service for loading and rendering Jinja2 templates with caching."""
def __init__(self, template_root: str):
"""
Initialize the template service.
Args:
template_root: Root directory containing template files
"""
self.template_root = template_root
self.env = Environment(
loader=FileSystemLoader(template_root),
autoescape=False # Disable autoescape for prompt templates
)
logger.info(f"TemplateService initialized with root: {template_root}")
@lru_cache(maxsize=128)
def _load_template(self, template_name: str) -> Template:
"""
Load a template from disk with caching.
Args:
template_name: Relative path to template file
Returns:
Loaded Jinja2 Template object
Raises:
TemplateNotFound: If template file doesn't exist
"""
try:
return self.env.get_template(template_name)
except TemplateNotFound as e:
expected_path = os.path.join(self.template_root, template_name)
logger.error(
f"Template not found: {template_name}. "
f"Expected path: {expected_path}"
)
raise
async def render_template(
self,
template_name: str,
operation_name: str,
**variables
) -> str:
"""
Load and render a Jinja2 template.
Args:
template_name: Relative path to template file
operation_name: Name for logging (e.g., "split_the_problem")
**variables: Template variables to render
Returns:
Rendered template string
Raises:
TemplateRenderError: If template loading or rendering fails
"""
try:
# Load template (cached)
template = self._load_template(template_name)
# Render template
rendered = template.render(**variables)
# Log rendered prompt
log_prompt_rendering(operation_name, rendered)
return rendered
except TemplateNotFound as e:
logger.error(
f"Template rendering failed for {operation_name} "
f"({template_name}): Template not found",
exc_info=True
)
raise TemplateRenderError(template_name, e, variables)
except Exception as e:
logger.error(
f"Template rendering failed for {operation_name} "
f"({template_name}): {e}",
exc_info=True
)
raise TemplateRenderError(template_name, e, variables)
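A usage sketch; the template root and file name below are hypothetical and must match real files on disk.

import asyncio

service = TemplateService(template_root="app/core/memory/agent/templates")  # hypothetical root
prompt = asyncio.run(service.render_template(
    "split_the_problem.jinja2",        # hypothetical template file
    operation_name="split_the_problem",
    questions=["What changed?"],
    history="[]",
))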

View File

@@ -1,10 +1,9 @@
"""
Type classification utility for distinguishing read/write operations.
"""
from app.core.config import settings
from app.core.logging_config import get_agent_logger, log_prompt_rendering
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.messages_tool import read_template_file
from app.core.memory.agent.utils.messages_tools import read_template_file
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from jinja2 import Template

View File

@@ -1,49 +0,0 @@
import os
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy.orm import Session
import logging
import json
from app.db import get_db
from app.models.retrieval_info import RetrievalInfo
logger = logging.getLogger(__name__)
async def write_to_database(host_id: uuid.UUID, data: Any) -> str:
"""
将数据写入数据库
:param host_id: 宿主 ID
:param data: 要写入的数据
:return: 写入数据库的结果
"""
# 从数据库会话中获取会话
db: Session = next(get_db())
try:
if isinstance(data, (dict, list)):
serialized = json.dumps(data, ensure_ascii=False)
elif isinstance(data, str):
serialized = data
else:
serialized = str(data)
new_retrieval_info = RetrievalInfo(
# host_id=host_id,
host_id=uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122"),
retrieve_info=serialized,
created_at=datetime.now()
)
db.add(new_retrieval_info)
db.commit()
logger.info(f"success to write data to database, host_id: {host_id}, retrieve_info: {serialized}")
return "success to write data to database"
except Exception as e:
db.rollback()
logger.error(f"failed to write data to database, host_id: {host_id}, retrieve_info: {data}, error: {e}")
raise e
finally:
try:
db.close()
except Exception:
pass

View File

@@ -7,14 +7,12 @@ pipeline. Only MemoryConfig is needed - clients are constructed internally.
import time
from datetime import datetime
from dotenv import load_dotenv
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
ExtractionOrchestrator,
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
memory_summary_generation,
)
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.log.logging_utils import log_time
from app.db import get_db_context
@@ -23,7 +21,7 @@ from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig
from dotenv import load_dotenv
load_dotenv()

View File

@@ -34,5 +34,4 @@ class EmotionSuggestionsRequest(BaseModel):
class EmotionGenerateSuggestionsRequest(BaseModel):
"""生成个性化情绪建议请求"""
group_id: str = Field(..., description="ID")
config_id: Optional[int] = Field(None, description="配置ID用于指定LLM模型")
end_user_id: str = Field(..., description="终端用户ID")

File diff suppressed because it is too large Load Diff

View File

@@ -1,27 +1,27 @@
import asyncio
import trio
import json
import os
import re
import time
import uuid
from datetime import datetime, timezone
from math import ceil
from typing import Any, Dict, List, Optional
import re
import redis
import requests
import trio
# Import a unified Celery instance
from app.celery_app import celery_app
from app.core.config import settings
from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
from app.core.rag.llm.chat_model import Base
from app.core.rag.llm.cv_model import QWenCV
from app.core.rag.llm.embedding_model import OpenAIEmbed
from app.core.rag.llm.sequence2txt_model import QWenSeq2txt
from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb
from app.core.rag.prompts.generator import question_proposal
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
ElasticSearchVectorFactory,
@@ -486,6 +486,10 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
Raises:
Exception on failure
"""
from app.core.logging_config import get_logger
logger = get_logger(__name__)
logger.info(f"[CELERY WRITE] Starting write task - group_id={group_id}, config_id={config_id}, storage_type={storage_type}")
start_time = time.time()
# Resolve config_id if None
@@ -506,8 +510,14 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
async def _run() -> str:
db = next(get_db())
try:
logger.info(f"[CELERY WRITE] Executing MemoryAgentService.write_memory")
service = MemoryAgentService()
return await service.write_memory(group_id, message, actual_config_id, db, storage_type, user_rag_memory_id)
result = await service.write_memory(group_id, message, actual_config_id, db, storage_type, user_rag_memory_id)
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
return result
except Exception as e:
logger.error(f"[CELERY WRITE] Write failed: {e}", exc_info=True)
raise
finally:
db.close()
@@ -532,6 +542,8 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
return {
"status": "SUCCESS",
"result": result,
@@ -548,6 +560,9 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
detailed_error = "; ".join(error_messages)
else:
detailed_error = str(e)
logger.error(f"[CELERY WRITE] Task failed - elapsed_time={elapsed_time:.2f}s, error={detailed_error}", exc_info=True)
return {
"status": "FAILURE",
"error": detailed_error,

View File

@@ -1,32 +1,5 @@
version: '3.9'
services:
# MCP Server - standalone service
mcp-server:
image: redbear-mem-open:latest
container_name: mcp-server
ports:
- "8081:8081" # MCP server port
env_file:
- .env
environment:
- SERVER_IP=0.0.0.0 # Bind to all interfaces
volumes:
- ./files:/files
- /etc/localtime:/etc/localtime:ro
command: python -m app.core.memory.agent.mcp_server.server
healthcheck:
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8081/sse')"]
interval: 10s
timeout: 5s
retries: 5
start_period: 30s
restart: unless-stopped
networks:
- default
- celery
# FastAPI application - connects to MCP server
# FastAPI application
api:
image: redbear-mem-open:latest
container_name: api
@@ -35,37 +8,31 @@ services:
env_file:
- .env
environment:
- MCP_SERVER_URL=http://mcp-server:8081 # Back to using container name
- SERVER_IP=0.0.0.0 # Ensure MCP server binds to all interfaces
- SERVER_IP=0.0.0.0
# If the code still requires MCP_SERVER_URL, keep it commented out or point it at a placeholder
# - MCP_SERVER_URL=
volumes:
- ./files:/files
- /etc/localtime:/etc/localtime:ro
command: uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload --log-level debug
depends_on:
mcp-server:
condition: service_healthy
restart: unless-stopped
networks:
- default
- celery
# Celery worker - connects to MCP server
# Celery worker
worker:
image: redbear-mem-open:latest
container_name: worker
env_file:
- .env
environment:
- MCP_SERVER_URL=http://mcp-server:8081 # Back to using container name
volumes:
- ./files:/files
- /etc/localtime:/etc/localtime:ro
command: celery -A app.celery_worker.celery_app worker --loglevel=info
depends_on:
mcp-server:
condition: service_healthy
restart: unless-stopped
networks:
- celery
networks:
celery:
celery:

View File

@@ -18,174 +18,180 @@ import type { TestParams } from '@/views/MemoryConversation'
import type { EndUser } from '@/views/UserMemoryDetail/types'
import { handleSSE, type SSEMessage } from '@/utils/stream'
// 记忆对话
// Memory conversation
export const readService = (query: TestParams) => {
return request.post('/memory/read_service', query)
}
/****************** 记忆看板 相关接口 *******************************/
// 记忆看板-记忆总量
/****************** Memory Dashboard APIs *******************************/
// Memory Dashboard - Total memory count
export const getTotalMemoryCount = () => {
return request.get(`/dashboard/total_memory_count`)
}
// 记忆看板-知识库类型分布
// Memory Dashboard - Knowledge base type distribution
export const getKbTypes = () => {
return request.get(`/memory/stats/types`)
}
// 记忆看板-热门记忆标签
// Memory Dashboard - Hot memory tags
export const getHotMemoryTags = () => {
return request.get(`/memory-storage/analytics/hot_memory_tags`)
}
// 记忆看板-最近活动统计
// Memory Dashboard - Recent activity statistics
export const getRecentActivityStats = () => {
return request.get(`/memory-storage/analytics/recent_activity_stats`)
}
// 记忆看板-记忆增长趋势
// Memory Dashboard - Memory growth trend
export const getMemoryIncrement = (limit: number) => {
return request.get(`/dashboard/memory_increment`, { limit })
}
// 记忆看板-API调用趋势
// Memory Dashboard - API call trend
export const getApiTrend = () => {
return request.get(`/dashboard/api_increment`)
}
// 记忆看板-总数据
// Memory Dashboard - Total data
export const getDashboardData = () => {
return request.get(`/dashboard/dashboard_data`)
}
/*************** end 记忆看板 相关接口 ******************************/
/*************** end Memory Dashboard APIs ******************************/
/****************** 用户记忆 相关接口 *******************************/
/****************** User Memory APIs *******************************/
export const userMemoryListUrl = '/dashboard/end_users'
export const getUserMemoryList = () => {
return request.get(userMemoryListUrl)
}
// 用户记忆-用户记忆总量
// User Memory - Total end users
export const getTotalEndUsers = () => {
return request.get(`/dashboard/total_end_users`)
}
// 用户记忆-用户详情
// User Memory - User profile
export const getUserProfile = (end_user_id: string) => {
return request.get(`/memory/analytics/user_profile`, { end_user_id })
}
// 用户记忆-记忆洞察
// User Memory - Memory insight
export const getMemoryInsightReport = (end_user_id: string) => {
return request.get(`/memory-storage/analytics/memory_insight/report`, { end_user_id })
}
// 用户记忆-用户摘要
// User Memory - User summary
export const getUserSummary = (end_user_id: string) => {
return request.get(`/memory-storage/analytics/user_summary`, { end_user_id })
}
// 记忆分类
// Memory classification
export const getNodeStatistics = (end_user_id: string) => {
return request.get(`/memory-storage/analytics/node_statistics`, { end_user_id })
}
// 基本信息
// Basic information
export const getEndUserProfile = (end_user_id: string) => {
return request.get(`/memory-storage/read_end_user/profile`, { end_user_id })
}
export const updatedEndUserProfile = (values: EndUser) => {
return request.post(`/memory-storage/updated_end_user/profile`, values)
}
// 用户记忆-关系网络
// User Memory - Relationship network
export const getMemorySearchEdges = (end_user_id: string) => {
return request.get(`/memory-storage/analytics/graph_data`, { end_user_id })
}
// 用户记忆-用户兴趣分布
// User Memory - User interest distribution
export const getHotMemoryTagsByUser = (end_user_id: string) => {
return request.get(`/memory/analytics/hot_memory_tags/by_user`, { end_user_id })
}
// 用户记忆-记忆总量
// User Memory - Total memory count
export const getTotalMemoryCountByUser = (end_user_id: string) => {
return request.get(`/memory-storage/search`, { end_user_id })
}
// RAG 用户记忆-记忆总量
// RAG User Memory - Total memory count
export const getTotalRagMemoryCountByUser = (end_user_id: string) => {
return request.get(`/dashboard/current_user_rag_total_num`, { end_user_id })
}
// RAG 用户记忆-用户摘要
// RAG User Memory - User summary
export const getChunkSummaryTag = (end_user_id: string) => {
return request.get(`/dashboard/chunk_summary_tag`, { end_user_id })
}
// RAG 用户记忆-记忆洞察
// RAG User Memory - Memory insight
export const getChunkInsight = (end_user_id: string) => {
return request.get(`/dashboard/chunk_insight`, { end_user_id })
}
// RAG 用户记忆-存储内容
// RAG User Memory - Storage content
export const getRagContent = (end_user_id: string) => {
return request.get(`/dashboard/rag_content`, { end_user_id, limit: 20 })
}
// 情感分布分析
// Emotion distribution analysis
export const getWordCloud = (group_id: string) => {
return request.post(`/memory/emotion-memory/wordcloud`, { group_id, limit: 20 })
}
// 高频情绪关键词
// High-frequency emotion keywords
export const getEmotionTags = (group_id: string) => {
return request.post(`/memory/emotion-memory/tags`, { group_id, limit: 20 })
}
// 情绪健康指数
// Emotion health index
export const getEmotionHealth = (group_id: string) => {
return request.post(`/memory/emotion-memory/health`, { group_id, limit: 20 })
}
// 个性化建议
// Personalized suggestions
export const getEmotionSuggestions = (group_id: string) => {
return request.post(`/memory/emotion-memory/suggestions`, { group_id, limit: 20 })
}
export const generateSuggestions = (end_user_id: string) => {
return request.post(`/memory/emotion-memory/generate_suggestions`, { end_user_id })
}
export const analyticsRefresh = (end_user_id: string) => {
return request.post('/memory-storage/analytics/generate_cache', { end_user_id })
}
// 遗忘
// Forgetting stats
export const getForgetStats = (group_id: string) => {
return request.get(`/memory/forget-memory/stats`, { group_id })
}
// 隐性记忆-偏好
// Implicit Memory - Preferences
export const getImplicitPreferences = (end_user_id: string) => {
return request.get(`/memory/implicit-memory/preferences/${end_user_id}`)
}
// 隐性记忆-核心特质
// Implicit Memory - Core traits
export const getImplicitPortrait = (end_user_id: string) => {
return request.get(`/memory/implicit-memory/portrait/${end_user_id}`)
}
// 隐性记忆-兴趣领域分布
// Implicit Memory - Interest areas distribution
export const getImplicitInterestAreas = (end_user_id: string) => {
return request.get(`/memory/implicit-memory/interest-areas/${end_user_id}`)
}
// 隐性记忆-用户习惯分析
// Implicit Memory - User habits analysis
export const getImplicitHabits = (end_user_id: string) => {
return request.get(`/memory/implicit-memory/habits/${end_user_id}`)
}
// 短期记忆
export const generateProfile = (end_user_id: string) => {
return request.post(`/memory/implicit-memory/generate_profile`, { end_user_id })
}
// Short-term memory
export const getShortTerm = (end_user_id: string) => {
return request.get(`/memory/short/short_term`, { end_user_id })
}
// 感知记忆-视觉记忆
// Perceptual Memory - Visual memory
export const getPerceptualLastVisual = (end_user: string) => {
return request.get(`/memory/perceptual/${end_user}/last_visual`)
}
// 感知记忆-音频记忆
// Perceptual Memory - Audio memory
export const getPerceptualLastListen = (end_user: string) => {
return request.get(`/memory/perceptual/${end_user}/last_listen`)
}
// 感知记忆-文本记忆
// Perceptual Memory - Text memory
export const getPerceptualLastText = (end_user: string) => {
return request.get(`/memory/perceptual/${end_user}/last_text`)
}
// 感知记忆-感知记忆时间线
// Perceptual Memory - Perceptual memory timeline
export const getPerceptualTimeline = (end_user: string) => {
return request.get(`/memory/perceptual/${end_user}/timeline`)
}
// 情景记忆-总览
// Episodic Memory - Overview
export const getEpisodicOverview = (data: { end_user_id: string; time_range: string; episodic_type: string; } ) => {
return request.post(`/memory/episodic-memory/overview`, data)
}
export const getEpisodicDetail = (data: { end_user_id: string; summary_id: string; } ) => {
return request.post(`/memory/episodic-memory/details`, data)
}
// 关系演化
// Relationship evolution
export const getRelationshipEvolution = (data: { id: string; label: string; } ) => {
return request.get(`/memory-storage/memory_space/relationship_evolution`, data)
}
// 共同记忆时间线
// Shared memory timeline
export const getTimelineMemories = (data: { id: string; label: string; }) => {
return request.get(`/memory-storage/memory_space/timeline_memories`, data)
}
@@ -207,72 +213,72 @@ export const getConversationDetail = (end_user: string, conversation_id: string)
export const forgetTrigger = (data: { max_merge_batch_size: number; min_days_since_access: number; end_user_id: string;}) => {
return request.post(`/memory/forget-memory/trigger`, data)
}
/*************** end 用户记忆 相关接口 ******************************/
/*************** end User Memory APIs ******************************/
/****************** 记忆管理 相关接口 *******************************/
// 记忆管理-获取所有配置
/****************** Memory Management APIs *******************************/
// Memory Management - Get all configurations
export const memoryConfigListUrl = '/memory-storage/read_all_config'
export const getMemoryConfigList = () => {
return request.get(memoryConfigListUrl)
}
// 记忆管理-创建配置
// Memory Management - Create configuration
export const createMemoryConfig = (values: MemoryFormData) => {
return request.post('/memory-storage/create_config', values)
}
// 记忆管理-更新配置
// Memory Management - Update configuration
export const updateMemoryConfig = (values: MemoryFormData) => {
return request.post('/memory-storage/update_config', values)
}
// 记忆管理-删除配置
// Memory Management - Delete configuration
export const deleteMemoryConfig = (config_id: number) => {
return request.delete(`/memory-storage/delete_config?config_id=${config_id}`)
}
// 遗忘引擎-获取配置
// Forgetting Engine - Get configuration
export const getMemoryForgetConfig = (config_id: number | string) => {
return request.get('/memory/forget-memory/read_config', { config_id })
}
// 遗忘引擎-更新配置
// Forgetting Engine - Update configuration
export const updateMemoryForgetConfig = (values: ForgetConfigForm) => {
return request.post('/memory/forget-memory/update_config', values)
}
// 记忆萃取引擎-获取配置
// Memory Extraction Engine - Get configuration
export const getMemoryExtractionConfig = (config_id: number | string) => {
return request.get('/memory-storage/read_config_extracted', { config_id: config_id })
}
// 记忆萃取引擎-更新配置
// Memory Extraction Engine - Update configuration
export const updateMemoryExtractionConfig = (values: ExtractionConfigForm) => {
return request.post('/memory-storage/update_config_extracted', values)
}
// 记忆萃取引擎-试运行
// Memory Extraction Engine - Pilot run
export const pilotRunMemoryExtractionConfig = (values: { config_id: number | string; dialogue_text: string; }, onMessage?: (data: SSEMessage[]) => void) => {
return handleSSE('/memory-storage/pilot_run', values, onMessage)
}
// 情绪引擎-获取配置
// Emotion Engine - Get configuration
export const getMemoryEmotionConfig = (config_id: number | string) => {
return request.get('/memory/emotion/read_config', { config_id: config_id })
}
// 情绪引擎-更新配置
// Emotion Engine - Update configuration
export const updateMemoryEmotionConfig = (values: EmotionConfig) => {
return request.post('/memory/emotion/updated_config', values)
}
// 反思引擎-获取配置
// Reflection Engine - Get configuration
export const getMemoryReflectionConfig = (config_id: number | string) => {
return request.get('/memory/reflection/configs', { config_id: config_id })
}
// 反思引擎-更新配置
// Reflection Engine - Update configuration
export const updateMemoryReflectionConfig = (values: SelfReflectionEngineConfig) => {
return request.post('/memory/reflection/save', values)
}
// 反思引擎-试运行
// Reflection Engine - Pilot run
export const pilotRunMemoryReflectionConfig = (values: { config_id: number | string; language_type: string; }) => {
return request.get('/memory/reflection/run', values)
}
/*************** end 记忆管理 相关接口 ******************************/
/*************** end Memory Management APIs ******************************/
/****************** API参数 相关接口 *******************************/
/****************** API Parameters APIs *******************************/
export const getMemoryApi = () => {
return request.get('/memory/docs/api')
}
/*************** end API参数 相关接口 ******************************/
/*************** end API Parameters APIs ******************************/

View File

@@ -1,10 +1,13 @@
import { useEffect, useState, type FC, type Key } from 'react';
import { Select } from 'antd'
import type { SelectProps, DefaultOptionType } from 'antd/es/select'
import { Select } from 'antd';
import type { SelectProps, DefaultOptionType } from 'antd/es/select';
import { useTranslation } from 'react-i18next';
import { request } from '@/utils/request';
// 定义API响应类型
interface OptionType {
[key: string]: Key | string | number;
}
interface ApiResponse<T> {
items?: T[];
}
@@ -20,19 +23,16 @@ interface CustomSelectProps extends Omit<SelectProps, 'filterOption'> {
format?: (items: OptionType[]) => OptionType[];
showSearch?: boolean;
optionFilterProp?: string;
// 其他SelectProps属性
onChange?: SelectProps<Key, DefaultOptionType>['onChange'];
value?: SelectProps<Key, DefaultOptionType>['value'];
disabled?: boolean;
style?: React.CSSProperties;
className?: string;
filterOption?: (inputValue: string, option?: DefaultOptionType) => boolean;
}
interface OptionType {
[key: string]: Key | string | number;
}
const defaultFilterOption = (inputValue: string, option?: DefaultOptionType): boolean => {
if (!option || !inputValue) return true;
const label = String(option.children || option.label || '');
return label.toLowerCase().includes(inputValue.toLowerCase());
};
const CustomSelect: FC<CustomSelectProps> = ({
onChange,
url,
params,
valueKey = 'value',
@@ -42,42 +42,37 @@ const CustomSelect: FC<CustomSelectProps> = ({
allTitle,
format,
showSearch = false,
optionFilterProp = 'label',
filterOption,
...props
}) => {
const { t } = useTranslation();
const [options, setOptions] = useState<OptionType[]>([]);
// 默认模糊搜索函数
const defaultFilterOption = (inputValue: string, option?: DefaultOptionType) => {
if (!option || !inputValue) return true;
const label = String(option.children || option.label || '');
return label.toLowerCase().includes(inputValue.toLowerCase());
};
// 组件挂载时获取初始数据
const [options, setOptions] = useState<OptionType[]>([]);
useEffect(() => {
request.get<ApiResponse<OptionType>>(url, params).then((res) => {
const data = res;
setOptions(Array.isArray(data) ? data || [] : Array.isArray(data?.items) ? data.items || [] : []);
const data = Array.isArray(res) ? res : res?.items || [];
setOptions(data);
});
}, []);
}, [url, params]);
const displayOptions = format ? format(options) : options;
return (
<Select
placeholder={placeholder ? placeholder : t('common.select')}
onChange={onChange}
<Select
placeholder={placeholder || t('common.select')}
defaultValue={hasAll ? null : undefined}
showSearch={showSearch}
filterOption={filterOption || defaultFilterOption}
{...props}
>
{hasAll && (<Select.Option>{allTitle || t('common.all')}</Select.Option>)}
{(format ? format(options) : options)?.map(option => (
{hasAll && <Select.Option value={null}>{allTitle || t('common.all')}</Select.Option>}
{displayOptions.map((option) => (
<Select.Option key={option[valueKey]} value={option[valueKey]}>
{String(option[labelKey])}
</Select.Option>
))}
</Select>
);
}
};
export default CustomSelect;

View File

@@ -1967,6 +1967,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
value: 'Value',
addCase: 'Add Condition',
addVariable: 'Add Variables',
output: 'Output Variable'
},
clear: 'Clear',

View File

@@ -658,8 +658,8 @@ export const zh = {
priority: '结构化整合',
addTool: '添加工具',
tool: '工具',
variableConfig: '配置变量'
},
// 角色管理相关翻译
role: {
roleManagement: '角色管理',
roleId: '角色ID',
@@ -2061,6 +2061,7 @@ export const zh = {
value: '值',
addCase: '添加条件',
addVariable: '添加变量',
output: '输出变量'
},
clear: '清空',

View File

@@ -332,21 +332,6 @@
}
]
},
{
"id": 19,
"parent": 0,
"code": "member",
"label": "成员管理",
"i18nKey": "menu.memberManagement",
"path": "/member",
"enable": true,
"display": true,
"level": 1,
"sort": 0,
"icon": null,
"iconActive": null,
"subs": null
},
{
"id": 10,
"parent": 0,
@@ -377,6 +362,21 @@
"iconActive": null,
"subs": null
},
{
"id": 19,
"parent": 0,
"code": "member",
"label": "成员管理",
"i18nKey": "menu.memberManagement",
"path": "/member",
"enable": true,
"display": true,
"level": 1,
"sort": 0,
"icon": null,
"iconActive": null,
"subs": null
},
{
"id": 12,
"parent": 0,

View File

@@ -1,8 +1,47 @@
import { message } from 'antd';
import i18n from '@/i18n'
import { cookieUtils } from './request'
import { refreshToken } from '@/api/user'
import { clearAuthData } from './auth'
const API_PREFIX = '/api'
// Token refresh state
let isRefreshing = false;
let refreshPromise: Promise<string> | null = null;
// Refresh token function for SSE
const refreshTokenForSSE = async (): Promise<string> => {
if (isRefreshing && refreshPromise) {
return refreshPromise;
}
isRefreshing = true;
refreshPromise = (async () => {
try {
const refresh_token = cookieUtils.get('refreshToken');
if (!refresh_token) {
throw new Error(i18n.t('common.refreshTokenNotExist'));
}
const response: any = await refreshToken();
const newToken = response.access_token;
cookieUtils.set('authToken', newToken);
return newToken;
} catch (error) {
clearAuthData();
message.warning(i18n.t('common.loginExpired'));
if (!window.location.hash.includes('#/login')) {
window.location.href = `/#/login`;
}
throw error;
} finally {
isRefreshing = false;
refreshPromise = null;
}
})();
return refreshPromise;
};
export interface SSEMessage {
event?: string
data?: string | object
@@ -66,62 +105,66 @@ function parseDataContent(dataContent: string): string | object {
}
}
const makeSSERequest = async (url: string, data: any, token: string, config = { headers: {} }) => {
return fetch(`${API_PREFIX}${url}`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Authorization': `Bearer ${token}`,
...config.headers,
},
body: JSON.stringify(data)
});
};
export const handleSSE = async (url: string, data: any, onMessage?: (data: SSEMessage[]) => void, config = { headers: {} }) => {
try {
const token = cookieUtils.get('authToken');
const response = await fetch(`${API_PREFIX}${url}`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Authorization': `Bearer ${token}`,
...config.headers,
},
body: JSON.stringify(data)
});
let token = cookieUtils.get('authToken');
let response = await makeSSERequest(url, data, token || '', config);
const { status } = response
switch(status) {
switch (response.status) {
case 401:
if (url?.includes('/public')) {
return message.warning(i18n.t('common.publicApiCannotRefreshToken'));
}
window.location.href = `/#/login`;
break;
default:
if (!response.body) throw new Error('No response body');
const reader = response.body.getReader();
const decoder = new TextDecoder();
let buffer = ''; // 添加缓冲区来处理不完整的消息
while (true) {
const { done, value } = await reader.read();
if (done) break;
const chunk = decoder.decode(value, { stream: true });
buffer += chunk;
// 处理完整的事件
const events = buffer.split('\n\n');
buffer = events.pop() || ''; // 保留最后一个可能不完整的事件
for (const event of events) {
if (event.trim() && onMessage) {
onMessage(parseSSEToJSON(event) ?? {});
}
}
}
// 处理剩余的缓冲区内容
if (buffer.trim() && onMessage) {
onMessage(parseSSEToJSON(buffer) ?? {});
try {
const newToken = await refreshTokenForSSE();
response = await makeSSERequest(url, data, newToken, config);
} catch (refreshError) {
return;
}
break;
}
if (!response.body) throw new Error('No response body');
const reader = response.body.getReader();
const decoder = new TextDecoder();
let buffer = ''; // Buffer partial SSE messages across reads
while (true) {
const { done, value } = await reader.read();
if (done) break;
const chunk = decoder.decode(value, { stream: true });
buffer += chunk;
// Process complete events
const events = buffer.split('\n\n');
buffer = events.pop() || ''; // Keep the last, possibly incomplete, event
for (const event of events) {
if (event.trim() && onMessage) {
onMessage(parseSSEToJSON(event) ?? {});
}
}
}
// Flush whatever remains in the buffer
if (buffer.trim() && onMessage) {
onMessage(parseSSEToJSON(buffer) ?? {});
}
} catch (error) {
console.error('Request failed:', error);
throw error;
}
}
};

View File

@@ -13,26 +13,25 @@ import type {
Config,
ModelConfig,
AgentRef,
KnowledgeBase,
KnowledgeConfig,
Variable,
MemoryConfig,
AiPromptModalRef,
Source,
ToolOption
ChatVariableConfigModalRef
} from './types'
import type { Variable } from './components/VariableList/types'
import type { KnowledgeConfig } from './components/Knowledge/types'
import type { Model } from '@/views/ModelManagement/types'
import { getModelList } from '@/api/models';
import { saveAgentConfig } from '@/api/application'
import Knowledge from './components/Knowledge'
import VariableList from './components/VariableList'
import Knowledge from './components/Knowledge/Knowledge'
import VariableList from './components/VariableList/VariableList'
import { getApplicationConfig } from '@/api/application'
import { getKnowledgeBaseList } from '@/api/knowledgeBase'
import { memoryConfigListUrl } from '@/api/memory'
import CustomSelect from '@/components/CustomSelect'
import aiPrompt from '@/assets/images/application/aiPrompt.png'
import AiPromptModal from './components/AiPromptModal'
import ToolList from './components/ToolList'
import ToolList from './components/ToolList/ToolList'
import ChatVariableConfigModal from './components/ChatVariableConfigModal';
const DescWrapper: FC<{desc: string, className?: string}> = ({desc, className}) => {
return (
@@ -66,7 +65,7 @@ const SwitchWrapper: FC<{ title: string, desc?: string, name: string | string[];
</div>
)
}
const SelectWrapper: FC<{ title: string, desc: string, name: string, url: string }> = ({ title, desc, name, url }) => {
const SelectWrapper: FC<{ title: string, desc: string, name: string | string[], url: string }> = ({ title, desc, name, url }) => {
const { t } = useTranslation();
return (
<>
@@ -77,6 +76,7 @@ const SelectWrapper: FC<{ title: string, desc: string, name: string, url: string
className="rb:mb-0!"
>
<CustomSelect
placeholder={t('common.pleaseSelect')}
url={url}
hasAll={false}
valueKey='config_id'
@@ -99,54 +99,22 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
const [modelList, setModelList] = useState<Model[]>([])
const [defaultModel, setDefaultModel] = useState<Model | null>(null)
const [chatList, setChatList] = useState<ChatData[]>([])
const [formData, setFormData] = useState<{
default_model_config_id?: string,
model_parameters?: Config['model_parameters'],
tools: ToolOption[],
} | null>(null)
const values = Form.useWatch<{
memoryEnabled: boolean;
memory_content?: string | number;
} & Config>([], form)
const [knowledgeConfig, setKnowledgeConfig] = useState<KnowledgeConfig>({ knowledge_bases: [] })
const [variableList, setVariableList] = useState<Variable[]>([])
const values = Form.useWatch<Config>([], form)
const [isSave, setIsSave] = useState(false)
const initialized = useRef(false)
const [toolList, setToolList] = useState<ToolOption[]>([])
// 初始化完成标记
useEffect(() => {
if (data && values && formData) {
if (data) {
initialized.current = true
}
}, [data, values, formData])
}, [data])
useEffect(() => {
if (!initialized.current) return
if (isSave) return
setIsSave(true)
}, [knowledgeConfig])
useEffect(() => {
if (!initialized.current) return
if (isSave) return
setIsSave(true)
}, [variableList])
useEffect(() => {
if (!initialized.current) return
if (isSave) return
setIsSave(true)
}, [formData])
useEffect(() => {
if (!initialized.current) return
if (isSave) return
setIsSave(true)
}, [values])
useEffect(() => {
if (!initialized.current) return
if (isSave) return
setIsSave(true)
}, [toolList])
useEffect(() => {
getModels()
@@ -157,68 +125,19 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
setLoading(true)
getApplicationConfig(id as string).then(res => {
const response = res as Config
setData({
...response,
tools: Array.isArray(response.tools) ? response.tools : []
})
const { memory, tools } = response
let allTools = Array.isArray(response.tools) ? response.tools : []
form.setFieldsValue({
...response,
memoryEnabled: memory?.enabled || false,
memory_content: memory?.memory_content ? Number(memory?.memory_content) : undefined,
tools: Array.isArray(tools) ? tools : []
tools: allTools
})
setFormData({
default_model_config_id: response.default_model_config_id,
model_parameters: response.model_parameters || {},
tools: Array.isArray(tools) ? tools : []
setData({
...response,
tools: allTools
})
if (response?.knowledge_retrieval?.knowledge_bases?.length) {
getDefaultKnowledgeList(response)
}
if (response?.tools?.length) {
setToolList(response?.tools)
}
}).finally(() => {
setLoading(false)
})
}
const getDefaultKnowledgeList = (data: Config) => {
if (!data || !data.knowledge_retrieval || !data.knowledge_retrieval?.knowledge_bases?.length) {
return
}
const initialList = [...(data?.knowledge_retrieval?.knowledge_bases || [])]
getKnowledgeBaseList(undefined, {
kb_ids: initialList.map(vo => vo.kb_id).join(','),
page: 1,
pagesize: 100,
})
.then(res => {
const list = res.items || []
const knowledge_bases: KnowledgeBase[] = list.map(item => {
const filterItem = initialList.find(vo => vo.kb_id === item.id)
return {
...item,
...filterItem
}
})
setKnowledgeConfig(prev => ({
...prev,
knowledge_bases: [...knowledge_bases]
}))
setData((prev) => {
prev = prev as Config
const knowledge_retrieval: KnowledgeConfig = {
...(prev?.knowledge_retrieval || {}),
knowledge_bases: [...knowledge_bases]
}
return {
...(prev || {}),
knowledge_retrieval
}
})
})
}
const refresh = (vo: ModelConfig, type: Source) => {
if (type === 'model') {
@@ -227,15 +146,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
default_model_config_id,
model_parameters: {...rest}
})
setFormData((prevState) => {
const prev = prevState as Config
return {
...(prev || {}),
default_model_config_id,
model_parameters: {...rest}
};
})
if (default_model_config_id === formData?.default_model_config_id) {
if (default_model_config_id === values?.default_model_config_id) {
setChatList([{
label: vo.label || '',
model_config_id: default_model_config_id || '',
@@ -279,24 +190,20 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
// 保存Agent配置
const handleSave = (flag = true) => {
if (!isSave || !data) return Promise.resolve()
const { memoryEnabled, memory_content, ...rest } = values
const { knowledge_bases = [], ...knowledgeRest } = knowledgeConfig || {}
const { memory, knowledge_retrieval, tools, ...rest } = values
const { knowledge_bases = [], ...knowledgeRest } = knowledge_retrieval || {}
const { memory_content } = memory || {}
// 从原数据中获取memory的其他必要属性
const originalMemory = data.memory || ({} as MemoryConfig)
const params: Config = {
...data,
...rest,
...(formData || {}),
memory: {
...originalMemory,
enabled: memoryEnabled,
...memory,
memory_content: memory_content ? String(memory_content) : '',
max_history: originalMemory.max_history || '',
},
variables: variableList || [],
knowledge_retrieval: knowledge_bases.length > 0 ? {
...data.knowledge_retrieval,
...knowledgeRest,
@@ -305,14 +212,12 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
...(item.config || {})
}))
} as KnowledgeConfig : null,
tools: toolList.map(vo => ({
tools: tools.map(vo => ({
tool_id: vo.tool_id,
operation: vo.operation,
enabled: vo.enabled
}))
}
console.log('params', rest, params)
return new Promise((resolve, reject) => {
saveAgentConfig(data.app_id, params)
@@ -338,8 +243,8 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
modelConfigModalRef.current?.handleOpen('chat')
}
useEffect(() => {
if (formData?.default_model_config_id && modelList.length > 0) {
const filterValue = modelList.find(item => item.id === formData.default_model_config_id)
if (values?.default_model_config_id && modelList.length > 0) {
const filterValue = modelList.find(item => item.id === values.default_model_config_id)
setDefaultModel(filterValue as Model | null)
setChatList([{
label: filterValue?.name || '',
@@ -348,7 +253,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
list: []
}])
}
}, [modelList, formData?.default_model_config_id])
}, [modelList, values?.default_model_config_id])
useImperativeHandle(ref, () => ({
handleSave
@@ -360,8 +265,31 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
}
const updatePrompt = (value: string) => {
form.setFieldValue('system_prompt', value)
const variables = value.match(/\{\{([^}]+)\}\}/g)?.map(match => match.slice(2, -2)) || []
const uniqueVariables = [...new Set(variables)]
const newVariableList: Variable[] = uniqueVariables.map((name, index) => ({
index,
type: 'text',
name,
display_name: name,
required: false
}))
updateVariableList(newVariableList)
}
const updateVariableList = (list: Variable[]) => {
form.setFieldValue('variables', [...list])
setChatVariables([...list])
}
const chatVariableConfigModalRef = useRef<ChatVariableConfigModalRef>(null)
const [chatVariables, setChatVariables] = useState<Variable[]>([])
const handleOpenVariableConfig = () => {
chatVariableConfigModalRef.current?.handleOpen(chatVariables)
}
const handleSaveChatVariable = (values: Variable[]) => {
setChatVariables(values)
}
return (
<>
{loading && <Spin fullscreen></Spin>}
@@ -379,8 +307,9 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
</Space>
</div>
<Form form={form}>
<Form.Item name="default_model_config_id" hidden noStyle></Form.Item>
<Form.Item name="model_parameters" hidden noStyle></Form.Item>
<Space size={16} direction="vertical" style={{ width: '100%' }}>
{/* 提示词 */}
<Card title={t('application.promptConfiguration')}>
<div className="rb:flex rb:items-center rb:justify-between rb:mb-2.75">
<div className="rb:font-medium rb:leading-5">
@@ -406,36 +335,31 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
</Form.Item>
</Card>
{/* 知识库 */}
<Knowledge
data={data?.knowledge_retrieval || { knowledge_bases: [] }}
onUpdate={setKnowledgeConfig}
/>
<Form.Item name="knowledge_retrieval" noStyle>
<Knowledge />
</Form.Item>
{/* 记忆配置 */}
<Card title={t('application.memoryConfiguration')}>
<Space size={24} direction='vertical' style={{ width: '100%' }}>
<SwitchWrapper title="dialogueHistoricalMemory" desc="dialogueHistoricalMemoryDesc" name="memoryEnabled" />
<SwitchWrapper title="dialogueHistoricalMemory" desc="dialogueHistoricalMemoryDesc" name={['memory', 'enabled']} />
<SelectWrapper
title="selectMemoryContent"
desc="selectMemoryContentDesc"
name="memory_content"
name={['memory', 'memory_content']}
url={memoryConfigListUrl}
/>
</Space>
</Card>
{/* 变量配置 */}
<VariableList
data={data?.variables}
onUpdate={setVariableList}
/>
<Form.Item name="variables">
<VariableList />
</Form.Item>
{/* 工具配置 */}
<ToolList
data={data?.tools || []}
onUpdate={setToolList}
/>
<Form.Item name="tools">
<ToolList />
</Form.Item>
</Space>
</Form>
</Col>
@@ -444,6 +368,9 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
{t('application.debuggingAndPreview')}
<Space size={10}>
<Button type="primary" ghost onClick={handleOpenVariableConfig}>
{t('application.variableConfig')}
</Button>
<Button type="primary" ghost onClick={handleAddModel}>
+ {t('application.addModel')}
</Button>
@@ -463,7 +390,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
<ModelConfigModal
modelList={modelList}
data={formData as Config}
data={values}
ref={modelConfigModalRef}
refresh={refresh}
/>
@@ -472,6 +399,10 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
defaultModel={defaultModel}
refresh={updatePrompt}
/>
<ChatVariableConfigModal
ref={chatVariableConfigModalRef}
refresh={handleSaveChatVariable}
/>
</>
);
});

View File

@@ -3,18 +3,21 @@ import RbCard from '@/components/RbCard/Card'
interface CardProps {
title?: string | ReactNode;
subTitle?: string | ReactNode;
children: ReactNode;
extra?: ReactNode;
}
const Card: FC<CardProps> = ({
title,
subTitle,
children,
extra,
}) => {
return (
<RbCard
title={title}
subTitle={subTitle}
extra={extra}
headerType="borderL"
headerClassName="rb:before:bg-[#155EEF]! rb:before:h-[19px]"

View File

@@ -0,0 +1,101 @@
import { forwardRef, useImperativeHandle, useState } from 'react';
import { Form, Input, InputNumber } from 'antd';
import { useTranslation } from 'react-i18next';
import type { ChatVariableConfigModalRef } from '../types'
import type { Variable } from './VariableList/types'
import RbModal from '@/components/RbModal'
interface VariableEditModalProps {
refresh: (values: Variable[]) => void;
}
const ChatVariableConfigModal = forwardRef<ChatVariableConfigModalRef, VariableEditModalProps>(({
refresh,
}, ref) => {
const { t } = useTranslation();
const [visible, setVisible] = useState(false);
const [form] = Form.useForm<{variables: Variable[]}>();
const [loading, setLoading] = useState(false)
const [initialValues, setInitialValues] = useState<Variable[]>([])
// Cancel handler: close the modal and reset form state
const handleClose = () => {
setVisible(false);
form.resetFields();
setLoading(false)
};
const handleOpen = (values: Variable[]) => {
setVisible(true);
form.setFieldsValue({variables: values})
setInitialValues([...values])
};
// Save handler: validate the form, pass the values up, then close
const handleSave = () => {
form.validateFields().then((values) => {
refresh([
...(values?.variables ?? []),
])
handleClose()
})
}
// Methods exposed to the parent component
useImperativeHandle(ref, () => ({
handleOpen,
handleClose
}));
return (
<RbModal
title={t('application.variableConfig')}
open={visible}
onCancel={handleClose}
okText={t('common.save')}
onOk={handleSave}
confirmLoading={loading}
>
<Form
form={form}
layout="horizontal"
scrollToFirstError={{ behavior: 'instant', block: 'end', focus: true }}
>
<Form.List name="variables">
{(fields) => (
<>
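{/* One Form.Item per declared variable; the input widget follows the variable's type */}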
{fields.map(({ name }, index) => {
const field = initialValues[index]
return (
<Form.Item
key={name}
name={[name, 'value']}
label={`${field.name}·${field.display_name}`}
rules={[
{ required: field.required, message: t('common.pleaseEnter') },
]}
>
{
field.type === 'text' && <Input placeholder={t('common.pleaseEnter')} />
}
{
field.type === 'number' && <InputNumber placeholder={t('common.pleaseEnter')} className="rb:w-full!" onChange={(value) => form.setFieldValue(['variables', name, 'value'], value)} />
}
{
field.type === 'paragraph' && <Input.TextArea placeholder={t('common.pleaseEnter')} />
}
</Form.Item>
)
})}
</>
)}
</Form.List>
</Form>
</RbModal>
);
});
export default ChatVariableConfigModal;

View File

@@ -2,7 +2,6 @@ import { type FC, useRef, useState, useEffect } from 'react'
import { useTranslation } from 'react-i18next'
import { Space, Button, List } from 'antd'
import knowledgeEmpty from '@/assets/images/application/knowledgeEmpty.svg'
import Card from './Card'
import type {
KnowledgeConfigForm,
KnowledgeConfig,
@@ -11,14 +10,16 @@ import type {
KnowledgeModalRef,
KnowledgeConfigModalRef,
KnowledgeGlobalConfigModalRef,
} from '../types'
} from './types'
import Empty from '@/components/Empty'
import KnowledgeListModal from './KnowledgeListModal'
import KnowledgeConfigModal from './KnowledgeConfigModal'
import KnowledgeGlobalConfigModal from './KnowledgeGlobalConfigModal'
import Tag from '@/components/Tag'
import { getKnowledgeBaseList } from '@/api/knowledgeBase'
import Card from '../Card'
const Knowledge: FC<{data: KnowledgeConfig; onUpdate: (config: KnowledgeConfig) => void}> = ({data, onUpdate}) => {
const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfig) => void}> = ({value = {knowledge_bases: []}, onChange}) => {
const { t } = useTranslation()
const knowledgeModalRef = useRef<KnowledgeModalRef>(null)
const knowledgeConfigModalRef = useRef<KnowledgeConfigModalRef>(null)
@@ -27,12 +28,31 @@ const Knowledge: FC<{data: KnowledgeConfig; onUpdate: (config: KnowledgeConfig)
const [editConfig, setEditConfig] = useState<KnowledgeConfig>({} as KnowledgeConfig)
useEffect(() => {
if (data) {
setEditConfig({ ...(data || {}) })
const knowledge_bases = [...(data.knowledge_bases || [])]
setKnowledgeList(knowledge_bases)
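// Deep-compare via JSON so identical values echoed back by the form don't trigger a re-render loop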
if (value && JSON.stringify(value) !== JSON.stringify(editConfig)) {
setEditConfig({ ...(value || {}) })
const knowledge_bases = [...(value.knowledge_bases || [])]
// Check whether any knowledge_bases entry is missing its name field
const basesWithoutName = knowledge_bases.filter(base => !base.name)
if (basesWithoutName.length > 0) {
// Fetch the full knowledge base details from the API
getKnowledgeBaseList().then(res => {
const fullBases = knowledge_bases.map(base => {
if (!base.name) {
const fullBase = res.items.find((item: any) => item.id === base.kb_id)
return fullBase ? { ...base, ...fullBase } : base
}
return base
})
setKnowledgeList(fullBases)
}).catch(() => {
setKnowledgeList(knowledge_bases)
})
} else {
setKnowledgeList(knowledge_bases)
}
}
}, [data])
}, [value])
const handleKnowledgeConfig = () => {
knowledgeGlobalConfigModalRef.current?.handleOpen()
@@ -43,7 +63,7 @@ const Knowledge: FC<{data: KnowledgeConfig; onUpdate: (config: KnowledgeConfig)
const handleDeleteKnowledge = (id: string) => {
const list = knowledgeList.filter(item => item.id !== id)
setKnowledgeList([...list])
onUpdate({
onChange && onChange({
...editConfig,
knowledge_bases: [...list],
})
@@ -65,7 +85,7 @@ const Knowledge: FC<{data: KnowledgeConfig; onUpdate: (config: KnowledgeConfig)
list = [...values as KnowledgeBase[]]
}
setKnowledgeList([...list])
onUpdate({
onChange && onChange({
...editConfig,
knowledge_bases: [...list],
})
@@ -77,14 +97,14 @@ const Knowledge: FC<{data: KnowledgeConfig; onUpdate: (config: KnowledgeConfig)
config: {...values as KnowledgeConfigForm}
}
setKnowledgeList([...list])
onUpdate({
onChange && onChange({
...editConfig,
knowledge_bases: [...list],
})
} else if (type === 'rerankerConfig') {
const rerankerValues = values as RerankerConfig
setEditConfig(prev => ({ ...prev, ...rerankerValues }))
onUpdate({
onChange && onChange({
...editConfig,
...rerankerValues,
reranker_id: rerankerValues.rerank_model ? rerankerValues.reranker_id : undefined,
@@ -93,55 +113,54 @@ const Knowledge: FC<{data: KnowledgeConfig; onUpdate: (config: KnowledgeConfig)
}
}
return (
<Card
<Card
title={t('application.knowledgeBaseAssociation')}
extra={
<Button style={{padding: '0 8px', height: '24px'}} onClick={() => handleKnowledgeConfig()}>{t('application.globalConfig')}</Button>
<Space>
<Button style={{ padding: '0 8px', height: '24px' }} onClick={handleKnowledgeConfig}>{t('workflow.config.knowledge-retrieval.recallConfig')}</Button>
<Button style={{ padding: '0 8px', height: '24px' }} onClick={handleAddKnowledge}>+</Button>
</Space>
}
>
<div className="rb:flex rb:items-center rb:justify-between rb:mb-3">
<div className="rb:font-medium rb:leading-5">{t('application.associatedKnowledgeBase')}</div>
<Button style={{padding: '0 8px', height: '24px'}} onClick={handleAddKnowledge}>+{t('application.addKnowledgeBase')}</Button>
</div>
{knowledgeList.length === 0
? <Empty url={knowledgeEmpty} size={88} subTitle={t('application.knowledgeEmpty')} />
:
<List
grid={{ gutter: 12, column: 1 }}
dataSource={knowledgeList}
renderItem={(item) => (
<List.Item>
<div key={item.id} className="rb:flex rb:items-center rb:justify-between rb:p-[12px_16px] rb:bg-[#FBFDFF] rb:border rb:border-[#DFE4ED] rb:rounded-lg">
<div className="rb:font-medium rb:leading-4">
{item.name}
<Tag color={item.status === 1 ? 'success' : item.status === 0 ? 'default' : 'error'} className="rb:ml-2">
{item.status === 1 ? t('common.enable') : item.status === 0 ? t('common.disabled') : t('common.deleted')}
</Tag>
<div className="rb:mt-1 rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-5">{t('application.contains', {include_count: item.doc_num})}</div>
renderItem={(item) => {
if (!item.id) return null
return (
<List.Item>
<div key={item.id} className="rb:flex rb:items-center rb:justify-between rb:p-[12px_16px] rb:bg-[#FBFDFF] rb:border rb:border-[#DFE4ED] rb:rounded-lg">
<div className="rb:font-medium rb:leading-4">
{item.name}
<Tag color={item.status === 1 ? 'success' : item.status === 0 ? 'default' : 'error'} className="rb:ml-2">
{item.status === 1 ? t('common.enable') : item.status === 0 ? t('common.disabled') : t('common.deleted')}
</Tag>
<div className="rb:mt-1 rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-5">{t('application.contains', {include_count: item.doc_num})}</div>
</div>
<Space size={12}>
<div
className="rb:w-6 rb:h-6 rb:cursor-pointer rb:bg-[url('@/assets/images/editBorder.svg')] rb:hover:bg-[url('@/assets/images/editBg.svg')]"
onClick={() => handleEditKnowledge(item)}
></div>
<div
className="rb:w-6 rb:h-6 rb:cursor-pointer rb:bg-[url('@/assets/images/deleteBorder.svg')] rb:hover:bg-[url('@/assets/images/deleteBg.svg')]"
onClick={() => handleDeleteKnowledge(item.id)}
></div>
</Space>
</div>
<Space size={12}>
<div
className="rb:w-6 rb:h-6 rb:cursor-pointer rb:bg-[url('@/assets/images/editBorder.svg')] rb:hover:bg-[url('@/assets/images/editBg.svg')]"
onClick={() => handleEditKnowledge(item)}
></div>
<div
className="rb:w-6 rb:h-6 rb:cursor-pointer rb:bg-[url('@/assets/images/deleteBorder.svg')] rb:hover:bg-[url('@/assets/images/deleteBg.svg')]"
onClick={() => handleDeleteKnowledge(item.id)}
></div>
</Space>
</div>
</List.Item>
)}
</List.Item>
)
}}
/>
}
{/* Global settings */}
<KnowledgeGlobalConfigModal
data={editConfig}
ref={knowledgeGlobalConfigModalRef}
refresh={refresh}
/>
{/* Knowledge base list */}
<KnowledgeListModal
ref={knowledgeModalRef}
selectedList={knowledgeList}

View File

@@ -2,7 +2,7 @@ import { forwardRef, useEffect, useImperativeHandle, useState } from 'react';
import { Form, Select, InputNumber } from 'antd';
import { useTranslation } from 'react-i18next';
import type { KnowledgeConfigModalRef, KnowledgeBase, KnowledgeConfigForm } from '../types'
import type { KnowledgeConfigModalRef, KnowledgeBase, KnowledgeConfigForm, RetrieveType } from './types'
import RbModal from '@/components/RbModal'
import RbSlider from '@/components/RbSlider'
import { formatDateTime } from '@/utils/format';
@@ -12,7 +12,7 @@ const FormItem = Form.Item;
interface KnowledgeConfigModalProps {
refresh: (values: KnowledgeConfigForm, type: 'knowledgeConfig') => void;
}
const retrieveTypes = ['participle', 'semantic', 'hybrid']
const retrieveTypes: RetrieveType[] = ['participle', 'semantic', 'hybrid']
const KnowledgeConfigModal = forwardRef<KnowledgeConfigModalRef, KnowledgeConfigModalProps>(({
refresh,
@@ -33,8 +33,11 @@ const KnowledgeConfigModal = forwardRef<KnowledgeConfigModalRef, KnowledgeConfig
const handleOpen = (data: KnowledgeBase) => {
form.setFieldsValue({
retrieve_type: retrieveTypes[0],
retrieve_type: data?.config?.retrieve_type || retrieveTypes[0],
kb_id: data.id,
top_k: data?.config?.top_k || 5,
similarity_threshold: data?.config?.similarity_threshold || 0.5,
vector_similarity_weight: data?.config?.vector_similarity_weight || 0.5,
...(data || {}),
...(data?.config || {}),
})
@@ -62,12 +65,10 @@ const KnowledgeConfigModal = forwardRef<KnowledgeConfigModalRef, KnowledgeConfig
useEffect(() => {
if (values?.retrieve_type) {
const initialValues = Object.keys(values).map(key => {
return {
[key as keyof KnowledgeConfigForm]: (key === 'kb_id' || key === 'retrieve_type') ? values[key] : undefined
}
})
form.resetFields(initialValues)
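// Reset every field except kb_id and retrieve_type when the retrieval type changes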
const fieldsToReset = Object.keys(values).filter(key =>
key !== 'kb_id' && key !== 'retrieve_type'
) as (keyof KnowledgeConfigForm)[];
form.resetFields(fieldsToReset);
}
}, [values?.retrieve_type])
@@ -84,12 +85,12 @@ const KnowledgeConfigModal = forwardRef<KnowledgeConfigModalRef, KnowledgeConfig
layout="vertical"
>
{data && (
<div className="rb:mb-[24px] rb:flex rb:items-center rb:justify-between rb:border rb:rounded-[8px] rb:p-[17px_16px] rb:cursor-pointer rb:bg-[#F0F3F8] rb:border-[#DFE4ED] rb:text-[#212332]">
<div className="rb:text-[16px] rb:leading-[22px]">
<div className="rb:mb-6 rb:flex rb:items-center rb:justify-between rb:border rb:rounded-lg rb:p-[17px_16px] rb:cursor-pointer rb:bg-[#F0F3F8] rb:border-[#DFE4ED] rb:text-[#212332]">
<div className="rb:text-[16px] rb:leading-5.5">
{data.name}
<div className="rb:text-[12px] rb:leading-[16px] rb:text-[#5B6167] rb:mt-[8px]">{t('application.contains', {include_count: data.doc_num})}</div>
<div className="rb:text-[12px] rb:leading-4 rb:text-[#5B6167] rb:mt-2">{t('application.contains', {include_count: data.doc_num})}</div>
</div>
<div className="rb:text-[12px] rb:leading-[16px] rb:text-[#5B6167]">{formatDateTime(data.updated_at, 'YYYY-MM-DD HH:mm:ss')}</div>
<div className="rb:text-[12px] rb:leading-4 rb:text-[#5B6167]">{formatDateTime(data.updated_at, 'YYYY-MM-DD HH:mm:ss')}</div>
</div>
)}
<FormItem name="kb_id" hidden />
@@ -114,8 +115,14 @@ const KnowledgeConfigModal = forwardRef<KnowledgeConfigModalRef, KnowledgeConfig
label={t('application.top_k')}
rules={[{ required: true, message: t('common.pleaseEnter') }]}
extra={t('application.top_k_desc')}
initialValue={5}
>
<InputNumber style={{ width: '100%' }} />
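{/* Clamp top_k to 1-20 and write changes back to the form explicitly */}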
<InputNumber
style={{ width: '100%' }}
min={1}
max={20}
onChange={(value) => form.setFieldValue('top_k', value)}
/>
</FormItem>
{/* Semantic similarity threshold (similarity_threshold) */}
{values?.retrieve_type === 'semantic' && (
@@ -123,6 +130,7 @@ const KnowledgeConfigModal = forwardRef<KnowledgeConfigModalRef, KnowledgeConfig
name="similarity_threshold"
label={t('application.similarity_threshold')}
extra={t('application.similarity_threshold_desc')}
initialValue={0.5}
>
<RbSlider
max={1.0}
@@ -137,6 +145,7 @@ const KnowledgeConfigModal = forwardRef<KnowledgeConfigModalRef, KnowledgeConfig
name="vector_similarity_weight"
label={t('application.vector_similarity_weight')}
extra={t('application.vector_similarity_weight_desc')}
initialValue={0.5}
>
<RbSlider
max={1.0}
@@ -152,6 +161,7 @@ const KnowledgeConfigModal = forwardRef<KnowledgeConfigModalRef, KnowledgeConfig
name="similarity_threshold"
label={t('application.similarity_threshold')}
extra={t('application.similarity_threshold_desc1')}
initialValue={0.5}
>
<RbSlider
max={1.0}
@@ -163,6 +173,7 @@ const KnowledgeConfigModal = forwardRef<KnowledgeConfigModalRef, KnowledgeConfig
name="vector_similarity_weight"
label={t('application.vector_similarity_weight')}
extra={t('application.vector_similarity_weight_desc1')}
initialValue={0.5}
>
<RbSlider
max={1.0}

View File

@@ -2,7 +2,7 @@ import { forwardRef, useImperativeHandle, useState, useEffect } from 'react';
import { Form, InputNumber, Switch } from 'antd';
import { useTranslation } from 'react-i18next';
import type { RerankerConfig, KnowledgeGlobalConfigModalRef } from '../types'
import type { RerankerConfig, KnowledgeGlobalConfigModalRef } from './types'
import RbModal from '@/components/RbModal'
import CustomSelect from '@/components/CustomSelect'
import { getModelListUrl } from '@/api/models'
@@ -71,18 +71,18 @@ const KnowledgeGlobalConfigModal = forwardRef<KnowledgeGlobalConfigModalRef, Kno
form={form}
layout="vertical"
>
<div className="rb:text-[#5B6167] rb:mb-[24px]">{t('application.globalConfigDesc')}</div>
<div className="rb:text-[#5B6167] rb:mb-6">{t('application.globalConfigDesc')}</div>
{/* Result reranking */}
<div className="rb:flex rb:items-center rb:justify-between rb:my-[24px]">
<div className="rb:text-[14px] rb:font-medium rb:leading-[20px]">
<div className="rb:flex rb:items-center rb:justify-between rb:my-6">
<div className="rb:text-[14px] rb:font-medium rb:leading-5">
{t('application.rerankModel')}
<div className="rb:mt-[4px] rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-[16px]">{t('application.rerankModelDesc')}</div>
<div className="rb:mt-1 rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-4">{t('application.rerankModelDesc')}</div>
</div>
<FormItem
name="rerank_model"
valuePropName="checked"
className="rb:mb-[0px]!"
className="rb:mb-0!"
>
<Switch />
</FormItem>
@@ -110,7 +110,12 @@ const KnowledgeGlobalConfigModal = forwardRef<KnowledgeGlobalConfigModalRef, Kno
rules={[{ required: true, message: t('common.pleaseEnter') }]}
extra={t('application.reranker_top_k_desc')}
>
<InputNumber style={{ width: '100%' }} min={1} max={20} />
<InputNumber
style={{ width: '100%' }}
min={1}
max={20}
onChange={(value) => form.setFieldValue('reranker_top_k', value)}
/>
</FormItem>
</>}
</Form>

View File

@@ -2,7 +2,7 @@ import { forwardRef, useEffect, useImperativeHandle, useState } from 'react';
import { Space, List } from 'antd';
import { useTranslation } from 'react-i18next';
import clsx from 'clsx'
import type { KnowledgeModalRef, KnowledgeBase } from '../types'
import type { KnowledgeModalRef, KnowledgeBase } from './types'
import type { KnowledgeBaseListItem } from '@/views/KnowledgeBase/types'
import RbModal from '@/components/RbModal'
import { getKnowledgeBaseList } from '@/api/knowledgeBase'
@@ -39,12 +39,13 @@ const KnowledgeListModal = forwardRef<KnowledgeModalRef, KnowledgeModalProps>(({
setQuery({})
setSelectedIds([])
setSelectedRows([])
getList()
};
useEffect(() => {
getList()
}, [query.keywords])
if (visible) {
getList()
}
}, [query.keywords, visible])
const getList = () => {
getKnowledgeBaseList(undefined, {
...query,
@@ -124,15 +125,15 @@ const KnowledgeListModal = forwardRef<KnowledgeModalRef, KnowledgeModalProps>(({
dataSource={filterList}
renderItem={(item: KnowledgeBase) => (
<List.Item>
<div key={item.id} className={clsx("rb:flex rb:items-center rb:justify-between rb:border rb:rounded-[8px] rb:p-[17px_16px] rb:cursor-pointer rb:hover:bg-[#F0F3F8]", {
<div key={item.id} className={clsx("rb:flex rb:items-center rb:justify-between rb:border rb:rounded-lg rb:p-[17px_16px] rb:cursor-pointer rb:hover:bg-[#F0F3F8]", {
"rb:bg-[rgba(21,94,239,0.06)] rb:border-[#155EEF] rb:text-[#155EEF]": selectedIds.includes(item.id),
"rb:border-[#DFE4ED] rb:text-[#212332]": !selectedIds.includes(item.id),
})} onClick={() => handleSelect(item)}>
<div className="rb:text-[16px] rb:leading-[22px]">
<div className="rb:text-[16px] rb:leading-5.5">
{item.name}
<div className="rb:text-[12px] rb:leading-[16px] rb:text-[#5B6167] rb:mt-[8px]">{t('application.contains', {include_count: item.doc_num})}</div>
<div className="rb:text-[12px] rb:leading-4 rb:text-[#5B6167] rb:mt-2">{t('application.contains', {include_count: item.doc_num})}</div>
</div>
<div className="rb:text-[12px] rb:leading-[16px] rb:text-[#5B6167]">{formatDateTime(item.created_at, 'YYYY-MM-DD HH:mm:ss')}</div>
<div className="rb:text-[12px] rb:leading-4 rb:text-[#5B6167]">{formatDateTime(item.created_at, 'YYYY-MM-DD HH:mm:ss')}</div>
</div>
</List.Item>
)}

View File

@@ -0,0 +1,30 @@
import type { KnowledgeBaseListItem } from '@/views/KnowledgeBase/types'
export interface RerankerConfig {
rerank_model?: boolean | undefined;
reranker_id?: string | undefined;
reranker_top_k?: number | undefined;
}
export type RetrieveType = 'participle' | 'semantic' | 'hybrid'
export interface KnowledgeConfigForm {
kb_id?: string;
similarity_threshold?: number;
vector_similarity_weight?: number;
top_k?: number;
retrieve_type?: RetrieveType;
}
export interface KnowledgeBase extends KnowledgeBaseListItem, KnowledgeConfigForm {
config?: KnowledgeConfigForm
}
export interface KnowledgeConfig extends RerankerConfig {
knowledge_bases: KnowledgeBase[];
}
export interface KnowledgeConfigModalRef {
handleOpen: (data: KnowledgeBase) => void;
}
export interface KnowledgeGlobalConfigModalRef {
handleOpen: () => void;
}
export interface KnowledgeModalRef {
handleOpen: (config?: KnowledgeConfig[]) => void;
}

View File

@@ -1,22 +1,22 @@
import { type FC, useRef, useState, useEffect } from 'react'
import { useTranslation } from 'react-i18next'
import { Space, Button, List, Switch } from 'antd'
import Card from './Card'
import Card from '../Card'
import type {
ToolModalRef,
ToolOption
} from '../types'
} from './types'
import Empty from '@/components/Empty'
import ToolModal from './ToolModal'
import { getToolMethods, getToolDetail } from '@/api/tools'
const ToolList: FC<{ data: ToolOption[]; onUpdate: (config: ToolOption[]) => void}> = ({data, onUpdate}) => {
const ToolList: FC<{ value?: ToolOption[]; onChange?: (config: ToolOption[]) => void}> = ({value, onChange}) => {
const { t } = useTranslation()
const toolModalRef = useRef<ToolModalRef>(null)
const [toolList, setToolList] = useState<ToolOption[]>([])
useEffect(() => {
if (data) {
const processedData = data.map(async (item) => {
if (value) {
const processedData = value.map(async (item) => {
if (!item.label && item.tool_id) {
try {
const [toolDetail, methods] = await Promise.all([
@@ -77,7 +77,7 @@ const ToolList: FC<{ data: ToolOption[]; onUpdate: (config: ToolOption[]) => voi
Promise.all(processedData).then(setToolList)
}
}, [data])
}, [value])
const handleAddTool = () => {
toolModalRef.current?.handleOpen()
@@ -85,12 +85,12 @@ const ToolList: FC<{ data: ToolOption[]; onUpdate: (config: ToolOption[]) => voi
const updateTools = (tool: ToolOption) => {
const list = [...toolList, tool]
setToolList(list)
onUpdate(list)
onChange && onChange(list)
}
const handleDeleteTool = (index: number) => {
const list = toolList.filter((_item, idx) => idx !== index)
setToolList([...list])
onUpdate(list)
onChange && onChange(list)
}
const handleChangeEnabled = (index: number) => {
const list = toolList.map((item, idx) => {
@@ -103,7 +103,7 @@ const ToolList: FC<{ data: ToolOption[]; onUpdate: (config: ToolOption[]) => voi
return item
})
setToolList([...list])
onUpdate(list)
onChange && onChange(list)
}
return (
<Card
@@ -112,7 +112,6 @@ const ToolList: FC<{ data: ToolOption[]; onUpdate: (config: ToolOption[]) => voi
<Button style={{ padding: '0 8px', height: '24px' }} onClick={handleAddTool}>+{t('application.addTool')}</Button>
}
>
{toolList.length === 0
? <Empty size={88} />
:

View File

@@ -0,0 +1,26 @@
export interface ToolOption {
value?: string | number | null;
label?: React.ReactNode;
description?: string;
children?: ToolOption[];
isLeaf?: boolean;
method_id?: string;
operation?: string;
parameters?: Parameter[];
tool_id?: string;
enabled?: boolean;
}
export interface Parameter {
name: string;
type: string;
description: string;
required: boolean;
default: any;
enum: null | string[];
minimum: number;
maximum: number;
pattern: null | string;
}
export interface ToolModalRef {
handleOpen: () => void;
}

View File

@@ -1,131 +0,0 @@
import { type FC, useRef, useState, useEffect } from 'react'
import { useTranslation } from 'react-i18next'
import { Space, Button, Switch } from 'antd'
import variablesEmpty from '@/assets/images/application/variablesEmpty.svg'
import Card from './Card'
import Table from '@/components/Table';
import type { Variable, VariableEditModalRef } from '../types'
import Empty from '@/components/Empty'
import VariableEditModal from './VariableEditModal'
interface VariableListProps {
data?: Variable[];
onUpdate: (data: Variable[]) => void;
}
const VariableList: FC<VariableListProps> = ({data = [], onUpdate}) => {
const { t } = useTranslation()
const variableEditModalRef = useRef<VariableEditModalRef>(null)
const [variableList, setVariableList] = useState<Variable[]>([])
const [maxIndex, setMaxIndex] = useState(0)
useEffect(() => {
if (!data || data.length === 0) return
const list = data.map((item, index) => ({
...item,
index
}))
setVariableList(list)
onUpdate(list)
setMaxIndex(list.length)
}, [data])
const handleAddVariable = () => {
variableEditModalRef.current?.handleOpen()
}
const handleSaveVariable = (value: Variable) => {
if (value.index !== undefined && value.index >= 0) {
const index = variableList.findIndex(item => item.index === value.index)
if (index !== -1) {
const newData = [...variableList]
newData[index] = value
setVariableList([...newData])
onUpdate([...newData])
}
} else {
const list = [...variableList, {
index: maxIndex + 1,
...value
}]
setVariableList(list)
onUpdate([...list])
setMaxIndex(maxIndex + 1)
}
}
const handleDeleteVariable = (index: number) => {
const list = variableList.filter((_, i) => i !== index)
setVariableList(list)
onUpdate([...list])
}
return (
<Card title={t('application.variableConfiguration')}>
<div className="rb:flex rb:items-center rb:justify-between rb:mb-[11px]">
<div className="rb:font-medium rb:leading-[20px]">
{t('application.VariableManagement')}
<span className="rb:font-regular rb:text-[12px] rb:text-[#5B6167]"> ({t('application.VariableManagementDesc')})</span>
</div>
<Button style={{padding: '0 8px', height: '24px'}} onClick={handleAddVariable}>+{t('application.addVariables')}</Button>
</div>
{/* List */}
{variableList.length > 0
? (
<div className="rb:mt-[12px]">
<Table
rowKey="index"
pagination={false}
columns={[
{
title: t('application.variableType'),
dataIndex: 'type',
key: 'type',
render: (type) => t(`application.${type}`)
},
{
title: t('application.variableKey'),
dataIndex: 'name',
key: 'name',
},
{
title: t('application.variableName'),
dataIndex: 'display_name',
key: 'display_name',
},
{
title: t('application.optional'),
dataIndex: 'required',
key: 'required',
render: (required) => <Switch checked={!required} disabled />
},
{
title: t('common.operation'),
key: 'action',
render: (_, record, index: number) => (
<Space size="middle">
<Button
type="link"
onClick={() => variableEditModalRef.current?.handleOpen(record as Variable)}
>
{t('common.edit')}
</Button>
<Button type="link" danger onClick={() => handleDeleteVariable(index)}>
{t('common.delete')}
</Button>
</Space>
),
},
]}
initialData={variableList as unknown as Record<string, unknown>[]}
emptySize={88}
/>
</div>
)
: <Empty url={variablesEmpty} size={88} subTitle={t('application.variablesEmpty')} />
}
<VariableEditModal
ref={variableEditModalRef}
refreshTable={handleSaveVariable}
/>
</Card>
)
}
export default VariableList

View File

@@ -2,7 +2,7 @@ import { forwardRef, useImperativeHandle, useState } from 'react';
import { Form, Input } from 'antd';
import { useTranslation } from 'react-i18next';
import type { ApiExtensionModalData, ApiExtensionModalRef } from '../types'
import type { ApiExtensionModalData, ApiExtensionModalRef } from './types'
import RbModal from '@/components/RbModal'
const FormItem = Form.Item;

View File

@@ -2,7 +2,7 @@ import { forwardRef, useImperativeHandle, useState, useRef } from 'react';
import { Form, Input, Select, InputNumber, Checkbox, Tag, Divider, Button } from 'antd';
import { useTranslation } from 'react-i18next';
import type { ApiExtensionModalRef, Variable, VariableEditModalRef } from '../types'
import type { ApiExtensionModalRef, Variable, VariableEditModalRef } from './types'
import RbModal from '@/components/RbModal'
import SortableList from '@/components/SortableList'
import ApiExtensionModal from './ApiExtensionModal'
@@ -137,7 +137,14 @@ const VariableEditModal = forwardRef<VariableEditModalRef, VariableEditModalProp
{ pattern: /^[a-zA-Z_][a-zA-Z0-9_]*$/, message: t('application.invalidVariableName') },
]}
>
<Input placeholder={t('common.enter')} />
<Input
placeholder={t('common.enter')}
onBlur={(e) => {
if (!form.getFieldValue('display_name')) {
form.setFieldValue('display_name', e.target.value)
}
}}
/>
</FormItem>
{/* Display name */}
<FormItem

View File

@@ -0,0 +1,110 @@
import { type FC, useRef } from 'react'
import { useTranslation } from 'react-i18next'
import { Space, Button, Switch, Form } from 'antd'
import variablesEmpty from '@/assets/images/application/variablesEmpty.svg'
import Card from '../Card'
import Table from '@/components/Table';
import type { Variable, VariableEditModalRef } from './types'
import Empty from '@/components/Empty'
import VariableEditModal from './VariableEditModal'
interface VariableListProps {
value?: Variable[];
onChange?: (value: Variable[]) => void;
}
const VariableList: FC<VariableListProps> = ({value = [], onChange}) => {
const { t } = useTranslation()
const variableEditModalRef = useRef<VariableEditModalRef>(null)
const handleAddVariable = () => {
variableEditModalRef.current?.handleOpen()
}
const handleSaveVariable = (variable: Variable) => {
const newList = [...(value || [])]
if (variable.index !== undefined && variable.index >= 0) {
const index = newList.findIndex(item => item.index === variable.index)
if (index !== -1) {
newList[index] = variable
}
} else {
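// New variable: a timestamp serves as a unique, stable index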
newList.push({ ...variable, index: Date.now() })
}
onChange?.(newList)
}
return (
<Card
title={<>
{t('application.variableConfiguration')}
<span className="rb:font-regular rb:text-[12px] rb:text-[#5B6167]"> ({t('application.VariableManagementDesc')})</span>
</>}
extra={<Button style={{ padding: '0 8px', height: '24px' }} onClick={handleAddVariable}>+ {t('application.addVariables')}</Button>}
>
<Form.List name="variables" initialValue={value}>
{(fields, { remove }) => {
return (
<>
{fields.length > 0 ? (
<div className="rb:mt-3">
<Table
rowKey="index"
pagination={false}
columns={[
{
title: t('application.variableType'),
dataIndex: 'type',
key: 'type',
render: (type) => t(`application.${type}`)
},
{
title: t('application.variableKey'),
dataIndex: 'name',
key: 'name',
},
{
title: t('application.variableName'),
dataIndex: 'display_name',
key: 'display_name',
},
{
title: t('application.optional'),
dataIndex: 'required',
key: 'required',
render: (required) => <Switch checked={!required} disabled />
},
{
title: t('common.operation'),
key: 'action',
render: (_, record, index: number) => (
<Space size="middle">
<Button
type="link"
onClick={() => variableEditModalRef.current?.handleOpen(record as Variable)}
>
{t('common.edit')}
</Button>
<Button type="link" danger onClick={() => remove(index)}>
{t('common.delete')}
</Button>
</Space>
),
},
]}
initialData={value as unknown as Record<string, unknown>[]}
emptySize={88}
/>
</div>
) : (
<Empty url={variablesEmpty} size={88} subTitle={t('application.variablesEmpty')} />
)}
</>
)
}}
</Form.List>
<VariableEditModal
ref={variableEditModalRef}
refreshTable={handleSaveVariable}
/>
</Card>
)
}
export default VariableList

View File

@@ -0,0 +1,28 @@
export interface Variable {
index?: number;
name: string;
display_name: string;
type: string;
required: boolean;
max_length?: number;
description?: string;
key?: string;
default_value?: string;
options?: string[];
api_extension?: string;
hidden?: boolean;
value?: any;
}
export interface VariableEditModalRef {
handleOpen: (values?: Variable) => void;
}
export interface ApiExtensionModalData {
name: string;
apiEndpoint: string;
apiKey: string;
}
export interface ApiExtensionModalRef {
handleOpen: () => void;
}

View File

@@ -1,4 +1,6 @@
import type { KnowledgeBaseListItem } from '@/views/KnowledgeBase/types'
import type { KnowledgeConfig } from './components/Knowledge/types'
import type { Variable } from './components/VariableList/types'
import type { ToolOption } from './components/ToolList/types'
import type { ChatItem } from '@/components/Chat/types'
import type { GraphRef } from '@/views/Workflow/types';
import type { ApiKey } from '@/views/ApiKeyManagement/types'
@@ -14,55 +16,6 @@ export interface ModelConfig {
n: number;
stop?: string;
}
/*************** 知识库相关 ******************/
export interface RerankerConfig {
rerank_model?: boolean | undefined;
reranker_id?: string | undefined;
reranker_top_k?: number | undefined;
}
export interface KnowledgeConfigForm {
kb_id?: string;
similarity_threshold?: number;
vector_similarity_weight?: number;
top_k?: number;
retrieve_type?: 'participle' | 'semantic' | 'hybrid';
}
export interface KnowledgeBase extends KnowledgeBaseListItem, KnowledgeConfigForm {
config?: KnowledgeConfigForm
}
export interface KnowledgeConfig extends RerankerConfig {
knowledge_bases: KnowledgeBase[];
}
export interface KnowledgeConfigModalRef {
handleOpen: (data: KnowledgeBase) => void;
}
export interface KnowledgeGlobalConfigModalRef {
handleOpen: () => void;
}
/*********** end 知识库相关 ******************/
/*************** 变量相关 ******************/
export interface Variable {
index?: number;
name: string;
display_name: string;
type: string;
required: boolean;
max_length?: number;
description?: string;
key: string;
default_value?: string;
options?: string[];
api_extension?: string;
hidden?: boolean;
}
export interface VariableEditModalRef {
handleOpen: (values?: Variable) => void;
}
/*********** end 变量相关 ******************/
export interface MemoryConfig {
enabled: boolean;
memory_content?: string;
@@ -131,17 +84,6 @@ export interface ModelConfigModalData {
export interface AiPromptModalRef {
handleOpen: () => void;
}
export interface KnowledgeModalRef {
handleOpen: (config?: KnowledgeConfig[]) => void;
}
export interface ApiExtensionModalData {
name: string;
apiEndpoint: string;
apiKey: string;
}
export interface ApiExtensionModalRef {
handleOpen: () => void;
}
export interface ChatData {
label?: string;
model_config_id?: string;
@@ -206,30 +148,6 @@ export interface AiPromptForm {
message?: string;
current_prompt?: string;
}
export interface ToolModalRef {
handleOpen: () => void;
}
export interface ToolOption {
value?: string | number | null;
label?: React.ReactNode;
description?: string;
children?: ToolOption[];
isLeaf?: boolean;
method_id?: string;
operation?: string;
parameters?: Parameter[];
tool_id?: string;
enabled?: boolean;
}
export interface Parameter {
name: string;
type: string;
description: string;
required: boolean;
default: any;
enum: null | string[];
minimum: number;
maximum: number;
pattern: null | string;
export interface ChatVariableConfigModalRef {
handleOpen: (values: Variable[]) => void;
}

View File

@@ -121,7 +121,7 @@ const EmotionTags: FC = () => {
})}
</div>
</div>
: <Empty size={88} className="rb:h-full" />
: <Empty size={88} className="rb:h-full rb:mb-4" />
}
</RbCard>
)

View File

@@ -1,4 +1,4 @@
import { type FC, useEffect, useState } from 'react'
import { useEffect, useState, forwardRef, useImperativeHandle } from 'react'
import { useTranslation } from 'react-i18next'
import { useParams } from 'react-router-dom'
import { Skeleton, Space, Progress } from 'antd';
@@ -20,7 +20,7 @@ interface HabitsItem {
specific_examples: string[];
}
const Habits: FC = () => {
const Habits = forwardRef<{ handleRefresh: () => void; }>((_props, ref) => {
const { t } = useTranslation()
const { id } = useParams()
const [loading, setLoading] = useState<boolean>(false)
@@ -43,6 +43,9 @@ const Habits: FC = () => {
setLoading(false)
})
}
useImperativeHandle(ref, () => ({
handleRefresh: getData
}));
return (
<>
@@ -80,5 +83,5 @@ const Habits: FC = () => {
</RbCard>
</>
)
}
})
export default Habits

View File

@@ -1,4 +1,4 @@
import { type FC, useEffect, useState } from 'react'
import { useEffect, useState, forwardRef, useImperativeHandle } from 'react'
import { useTranslation } from 'react-i18next'
import { useParams } from 'react-router-dom'
import { Skeleton, Progress } from 'antd';
@@ -23,7 +23,7 @@ interface InterestAreasItem {
art: Item;
}
const InterestAreas: FC = () => {
const InterestAreas = forwardRef<{ handleRefresh: () => void; }>((_props, ref) => {
const { t } = useTranslation()
const { id } = useParams()
const [loading, setLoading] = useState<boolean>(false)
@@ -47,6 +47,9 @@ const InterestAreas: FC = () => {
})
}
useImperativeHandle(ref, () => ({
handleRefresh: getData
}));
return (
<RbCard
title={t('implicitDetail.interestAreas')}
@@ -70,5 +73,5 @@ const InterestAreas: FC = () => {
}
</RbCard>
)
}
})
export default InterestAreas

View File

@@ -1,4 +1,4 @@
import { type FC, useEffect, useState } from 'react'
import { useEffect, useState, forwardRef, useImperativeHandle } from 'react'
import { useTranslation } from 'react-i18next'
import { useParams } from 'react-router-dom'
import { Skeleton, Progress } from 'antd';
@@ -25,7 +25,7 @@ interface PortraitItem {
literature: Item;
}
const Portrait: FC = () => {
const Portrait = forwardRef<{ handleRefresh: () => void; }>((_props, ref) => {
const { t } = useTranslation()
const { id } = useParams()
const [loading, setLoading] = useState<boolean>(false)
@@ -49,6 +49,9 @@ const Portrait: FC = () => {
})
}
useImperativeHandle(ref, () => ({
handleRefresh: getData
}));
return (
<RbCard
title={t('implicitDetail.portrait')}
@@ -73,5 +76,5 @@ const Portrait: FC = () => {
}
</RbCard>
)
}
})
export default Portrait

View File

@@ -1,4 +1,4 @@
import { type FC, useEffect, useState, useRef, useMemo } from 'react'
import { useEffect, useState, useRef, useMemo, forwardRef, useImperativeHandle } from 'react'
import { useTranslation } from 'react-i18next'
import { useParams } from 'react-router-dom'
import { Row, Col, Skeleton } from 'antd'
@@ -31,7 +31,7 @@ const generateCategoryColors = (categories: string[]) => {
return colors
}
const Preferences: FC = () => {
const Preferences = forwardRef<{ handleRefresh: () => void; }>((_props, ref) => {
const { t } = useTranslation()
const { id } = useParams()
const chartRef = useRef<HTMLDivElement>(null)
@@ -138,6 +138,9 @@ const Preferences: FC = () => {
return selectedWord !== null && data[selectedWord].tag_name ? <>{data[selectedWord].tag_name}{t('implicitDetail.preferencesDetail')}</> : ''
}, [selectedWord, data, t])
useImperativeHandle(ref, () => ({
handleRefresh: getData
}));
return (
<>
<div className="rb:bg-[rgba(21,94,239,0.12)] rb:px-4 rb:py-2.5 rb:font-medium rb:leading-5 rb:mb-4 rb:mt-6 rb:rounded-md">{t('forgetDetail.overviewTitle')}</div>
@@ -184,6 +187,6 @@ const Preferences: FC = () => {
</Row>
</>
)
}
})
export default Preferences

View File

@@ -1,4 +1,4 @@
import { type FC, useEffect, useState } from 'react'
import { useEffect, useState, forwardRef, useImperativeHandle } from 'react'
import { useTranslation } from 'react-i18next'
import { useParams } from 'react-router-dom'
@@ -18,7 +18,7 @@ interface Suggestions {
actionable_steps: string[];
}>;
}
const Suggestions: FC = () => {
const Suggestions = forwardRef<{ handleRefresh: () => void; }>((_props, ref) => {
const { t } = useTranslation()
const { id } = useParams()
const [suggestions, setSuggestions] = useState<Suggestions | null>(null)
@@ -37,6 +37,9 @@ const Suggestions: FC = () => {
})
}
useImperativeHandle(ref, () => ({
handleRefresh: getSuggestionData
}));
return (
<RbCard
title={t('statementDetail.suggestions')}
@@ -64,6 +67,6 @@ const Suggestions: FC = () => {
}
</RbCard>
)
}
})
export default Suggestions

View File

@@ -1,34 +1,57 @@
import { type FC } from 'react'
import { forwardRef, useImperativeHandle, useRef } from 'react'
import { useTranslation } from 'react-i18next'
import { Row, Col } from 'antd'
import { useParams } from 'react-router-dom'
import Preferences from '../components/Preferences'
import Portrait from '../components/Portrait'
import InterestAreas from '../components/InterestAreas'
import Habits from '../components/Habits'
import {
generateProfile,
} from '@/api/memory'
const ImplicitDetail: FC = () => {
const ImplicitDetail = forwardRef<{ handleRefresh: () => void; }>((_props, ref) => {
const { t } = useTranslation()
const { id } = useParams()
const preferencesRef = useRef<{ handleRefresh: () => void; }>(null)
const portraitRef = useRef<{ handleRefresh: () => void; }>(null)
const interestAreasRef = useRef<{ handleRefresh: () => void; }>(null)
const habitsRef = useRef<{ handleRefresh: () => void; }>(null)
const handleRefresh = () => {
if (!id) return
generateProfile(id)
.then(() => {
preferencesRef.current?.handleRefresh()
portraitRef.current?.handleRefresh()
interestAreasRef.current?.handleRefresh()
habitsRef.current?.handleRefresh()
})
}
useImperativeHandle(ref, () => ({
handleRefresh
}));
return (
<div className="rb:h-full rb:max-w-266 rb:mx-auto">
<div className="rb:text-[#5B6167] rb:leading-5 rb:mt-3">{t('implicitDetail.title')}</div>
<Preferences />
<Preferences ref={preferencesRef} />
<div className="rb:bg-[rgba(21,94,239,0.12)] rb:px-3 rb:py-2.5 rb:font-medium rb:leading-5 rb:mb-4 rb:mt-6 rb:rounded-md">{t('implicitDetail.portraitTitle')}</div>
<div className="rb:my-3 rb:text-[#5B6167] rb:leading-5">{t('implicitDetail.portraitSubTitle')}</div>
<Row gutter={[16, 16]} className="rb:mt-4">
<Col span={12}>
<Portrait />
<Portrait ref={portraitRef} />
</Col>
<Col span={12}>
<InterestAreas />
<InterestAreas ref={interestAreasRef} />
</Col>
</Row>
<Habits />
<Habits ref={habitsRef} />
</div>
)
}
})
export default ImplicitDetail

View File

@@ -1,13 +1,27 @@
import { type FC } from 'react'
import { forwardRef, useImperativeHandle, useRef } from 'react'
import { Row, Col, Space } from 'antd';
import { useParams } from 'react-router-dom'
import WordCloud from '../components/WordCloud'
import EmotionTags from '../components/EmotionTags'
import Health from '../components/Health'
import Suggestions from '../components/Suggestions'
import { generateSuggestions } from '@/api/memory'
const StatementDetail: FC = () => {
const StatementDetail = forwardRef((_props, ref) => {
const { id } = useParams()
const suggestionsRef = useRef<{ handleRefresh: () => void; }>(null)
const handleRefresh = () => {
if (!id) return
generateSuggestions(id)
.then(() => {
suggestionsRef.current?.handleRefresh()
})
}
useImperativeHandle(ref, () => ({
handleRefresh
}));
return (
<Row gutter={[16, 16]}>
<Col span={12}>
@@ -18,10 +32,10 @@ const StatementDetail: FC = () => {
</Space>
</Col>
<Col span={12}>
<Suggestions />
<Suggestions ref={suggestionsRef} />
</Col>
</Row>
)
}
})
export default StatementDetail

View File

@@ -24,6 +24,8 @@ const Detail: FC = () => {
const navigate = useNavigate()
const [name, setName] = useState<string>('')
const forgetDetailRef = useRef<{ handleRefresh: () => void }>(null)
const statementDetailRef = useRef<{ handleRefresh: () => void }>(null)
const implicitDetailRef = useRef<{ handleRefresh: () => void }>(null)
useEffect(() => {
if (!id) return
@@ -45,7 +47,17 @@ const Detail: FC = () => {
navigate(`/user-memory/detail/${id}/${key}`, { replace: true })
}
const handleRefresh = () => {
forgetDetailRef.current?.handleRefresh()
switch(type) {
case 'FORGET_MEMORY':
forgetDetailRef.current?.handleRefresh()
break;
case 'EMOTIONAL_MEMORY':
statementDetailRef.current?.handleRefresh()
break
case 'IMPLICIT_MEMORY':
implicitDetailRef.current?.handleRefresh()
break
}
}
if (type === 'GRAPH') {
@@ -67,16 +79,16 @@ const Detail: FC = () => {
</div>
</Dropdown>
}
extra={type === 'FORGET_MEMORY' &&
extra={['FORGET_MEMORY', 'EMOTIONAL_MEMORY', 'IMPLICIT_MEMORY'].includes(type as string) &&
<Button type="primary" ghost className="rb:group rb:h-6! rb:px-2!" onClick={handleRefresh}>
<img src={refreshIcon} className="rb:w-4 rb:h-4" />
{t('common.refresh')}
</Button>}
/>
<div className="rb:h-[calc(100vh-64px)] rb:overflow-y-auto rb:py-3 rb:px-4">
{type === 'EMOTIONAL_MEMORY' && <StatementDetail />}
{type === 'EMOTIONAL_MEMORY' && <StatementDetail ref={statementDetailRef} />}
{type === 'FORGET_MEMORY' && <ForgetDetail ref={forgetDetailRef} />}
{type === 'IMPLICIT_MEMORY' && <ImplicitDetail />}
{type === 'IMPLICIT_MEMORY' && <ImplicitDetail ref={implicitDetailRef} />}
{type === 'SHORT_TERM_MEMORY' && <ShortTermDetail />}
{type === 'PERCEPTUAL_MEMORY' && <PerceptualDetail />}
{type === 'EPISODIC_MEMORY' && <EpisodicDetail />}

View File

@@ -16,6 +16,7 @@ import InitialValuePlugin from './plugin/InitialValuePlugin';
import CommandPlugin from './plugin/CommandPlugin';
import Jinja2HighlightPlugin from './plugin/Jinja2HighlightPlugin';
import LineNumberPlugin from './plugin/LineNumberPlugin';
import BlurPlugin from './plugin/BlurPlugin';
import { VariableNode } from './nodes/VariableNode'
interface LexicalEditorProps {
@@ -113,8 +114,10 @@ const Editor: FC<LexicalEditorProps> =({
display: flex;
align-items: flex-start;
}
.editor-content-with-numbers {
.editor-content-wrapper {
flex: 1;
}
.editor-content-with-numbers {
white-space: pre-wrap;
}
.editor-content-with-numbers p {
@@ -174,18 +177,20 @@ const Editor: FC<LexicalEditorProps> =({
<div className="line-numbers">
<div>1</div>
</div>
<ContentEditable
className="editor-content-with-numbers"
style={{
minHeight: minheight,
padding: '4px 0',
outline: 'none',
resize: 'none',
fontSize: fontSize,
lineHeight: lineHeight,
border: 'none',
}}
/>
<div className="editor-content-wrapper">
<ContentEditable
className="editor-content-with-numbers"
style={{
minHeight: minheight,
padding: '4px 0',
outline: 'none',
resize: 'none',
fontSize: fontSize,
lineHeight: lineHeight,
border: 'none',
}}
/>
</div>
</div>
) : (
<ContentEditable
@@ -207,8 +212,8 @@ const Editor: FC<LexicalEditorProps> =({
style={{
minHeight: placeHolderMinheight,
position: 'absolute',
top: variant === 'borderless' ? '0' : '6px',
left: enableJinja2 ? '59px' : (variant === 'borderless' ? '0' : '11px'),
top: enableJinja2 ? '4px' : variant === 'borderless' ? '0' : '6px',
left: enableJinja2 ? '16px' : (variant === 'borderless' ? '0' : '11px'),
color: '#A8A9AA',
fontSize: fontSize,
lineHeight: placeHolderMinheight,
@@ -227,6 +232,7 @@ const Editor: FC<LexicalEditorProps> =({
<AutocompletePlugin options={options} enableJinja2={enableJinja2} />
<CharacterCountPlugin setCount={(count) => { setCount(count) }} onChange={onChange} />
<InitialValuePlugin value={value} options={options} enableJinja2={enableJinja2} />
{enableJinja2 && <BlurPlugin />}
</div>
</LexicalComposer>
);

View File

@@ -36,7 +36,7 @@ const VariableComponent: React.FC<{ nodeKey: NodeKey; data: Suggestion }> = ({
return (
<span
onClick={handleClick}
className={clsx('rb:border rb:rounded-md rb:bg-white rb:text-[12px] rb:inline-flex rb:items-center rb:py-0.5 rb:px-1.5 rb:mx-0.5 rb:cursor-pointer', {
className={clsx('rb:border rb:rounded-md rb:bg-white rb:text-[10px] rb:inline-flex rb:items-center rb:py-0 rb:px-1.5 rb:mx-0.5 rb:cursor-pointer', {
'rb:border-[#155EEF]': isSelected,
'rb:border-[#DFE4ED]': !isSelected
})}

View File

@@ -1,6 +1,6 @@
import { useEffect, useState, type FC } from 'react';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { $getSelection, $isRangeSelection } from 'lexical';
import { $getSelection, $isRangeSelection, $isTextNode } from 'lexical';
import { INSERT_VARIABLE_COMMAND } from '../commands';
import type { NodeProperties } from '../../../types'
@@ -96,7 +96,9 @@ const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }>
const textAfter = nodeText.substring(anchorOffset);
const newText = textBefore + `{{${suggestion.value}}}` + textAfter;
anchorNode.setTextContent(newText);
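// Guard: only TextNode instances support setTextContent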
if ($isTextNode(anchorNode)) {
anchorNode.setTextContent(newText);
}
// Move the caret to just after the inserted text
const newOffset = textBefore.length + `{{${suggestion.value}}}`.length;
@@ -129,6 +131,8 @@ const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }>
}
return (
<div
data-autocomplete-popup="true"
onMouseDown={(e) => e.preventDefault()}
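{/* preventDefault keeps focus in the editor while the popup is clicked */}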
style={{
position: 'fixed',
top: popupPosition.top,

View File

@@ -0,0 +1,33 @@
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { useEffect } from 'react';
import { $setSelection } from 'lexical';
export default function BlurPlugin() {
const [editor] = useLexicalComposerContext();
useEffect(() => {
return editor.registerRootListener((rootElement) => {
if (rootElement) {
const handleBlur = (e: FocusEvent) => {
// Ignore blur events triggered by clicking the autocomplete popup
const target = e.target as HTMLElement;
if (target?.closest('[data-autocomplete-popup="true"]')) {
return;
}
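// Clear the Lexical selection so nothing stays highlighted once focus leaves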
editor.update(() => {
$setSelection(null);
});
};
rootElement.addEventListener('blur', handleBlur);
return () => {
rootElement.removeEventListener('blur', handleBlur);
};
}
});
}, [editor]);
return null;
}

Some files were not shown because too many files have changed in this diff