Add/develop memory (#264)

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 新增长期记忆功能

* 新增长期记忆功能

* 新增长期记忆功能

* 知识库检索多余字段

* 长期
This commit is contained in:
lixinyue11
2026-02-02 11:50:23 +08:00
committed by GitHub
parent 4b8b6fe407
commit 4e837cb90c
9 changed files with 1176 additions and 331 deletions

View File

@@ -11,7 +11,8 @@ import os
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages, format_parsing, messages_parse
from app.core.memory.agent.langgraph_graph.write_graph import long_term_storage
from app.db import get_db
from app.core.logging_config import get_business_logger
from app.core.memory.agent.utils.redis_tool import store
@@ -145,38 +146,33 @@ class LangChainAgent:
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
messages.append(HumanMessage(content=user_content))
return messages
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# async def term_memory_save(self,messages,end_user_end,aimessages):
# '''短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j'''
# end_user_end=f"Term_{end_user_end}"
# print(messages)
# print(aimessages)
# session_id = store.save_session(
# userid=end_user_end,
# messages=messages,
# apply_id=end_user_end,
# end_user_id=end_user_end,
# aimessages=aimessages
# )
# store.delete_duplicate_sessions()
# # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
# return session_id
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# async def term_memory_redis_read(self,end_user_end):
# end_user_end = f"Term_{end_user_end}"
# history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
# # logger.info(f'Redis_Agent:{end_user_end};{history}')
# messagss_list=[]
# retrieved_content=[]
# for messages in history:
# query = messages.get("Query")
# aimessages = messages.get("Answer")
# messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
# retrieved_content.append({query: aimessages})
# return messagss_list,retrieved_content
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
db = next(get_db())
scope=6
try:
repo = LongTermMemoryRepository(db)
await long_term_storage(long_term_type="chunk", langchain_messages=long_term_messages,
memory_config=actual_config_id, end_user_id=end_user_id, scope=scope)
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
if type=="chunk" or type=="aggregate":
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
if len(chunk_data)==scope:
repo.upsert(end_user_id, chunk_data)
logger.info(f'写入短长期:')
else:
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
long_messages = await messages_parse(long_time_data)
repo.upsert(end_user_id, long_messages)
logger.info(f'写入短长期:')
finally:
db.close()
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
"""
写入记忆(支持结构化消息)
@@ -224,14 +220,6 @@ class LangChainAgent:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
# 调用 Celery 任务,传递结构化消息列表
# 数据流:
# 1. structured_messages 传递给 write_message_task
# 2. write_message_task 调用 memory_agent_service.write_memory
# 3. write_memory 调用 write_tools.write传递 messages 参数
# 4. write_tools.write 调用 get_chunked_dialogs传递 messages 参数
# 5. get_chunked_dialogs 为每条消息创建独立的 Chunk设置 speaker 字段
# 6. 每个 Chunk 保存到 Neo4j包含 speaker 字段
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # end_user_id: 用户ID
@@ -288,30 +276,6 @@ class LangChainAgent:
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
# history_term_memory = history_term_memory_result[0]
# db_for_memory = next(get_db())
# if memory_flag:
# if len(history_term_memory)>=4 and storage_type != "rag":
# history_term_memory = ';'.join(history_term_memory)
# retrieved_content = history_term_memory_result[1]
# print(retrieved_content)
# # 为长期记忆操作获取新的数据库连接
# try:
# repo = LongTermMemoryRepository(db_for_memory)
# repo.upsert(end_user_id, retrieved_content)
# logger.info(
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
# except Exception as e:
# logger.error(f"Failed to write to LongTermMemory: {e}")
# raise
# finally:
# db_for_memory.close()
# # 长期记忆写入(
# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id)
# # 注意:不在这里写入用户消息,等 AI 回复后一起写入
try:
# 准备消息列表
messages = self._prepare_messages(message, history, context)
@@ -339,10 +303,11 @@ class LangChainAgent:
elapsed_time = time.time() - start_time
if memory_flag:
long_term_messages=await agent_chat_messages(message_chat,content)
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# await self.term_memory_save(message_chat, end_user_id, content)
'''长期'''
await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
response = {
"content": content,
"model": self.model_name,
@@ -410,25 +375,7 @@ class LangChainAgent:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
# # TODO 乐力齐
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
# history_term_memory = history_term_memory_result[0]
# if memory_flag:
# if len(history_term_memory) >= 4 and storage_type != "rag":
# history_term_memory = ';'.join(history_term_memory)
# retrieved_content = history_term_memory_result[1]
# db_for_memory = next(get_db())
# try:
# repo = LongTermMemoryRepository(db_for_memory)
# repo.upsert(end_user_id, retrieved_content)
# logger.info(
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
# # 长期记忆写入
# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id)
# except Exception as e:
# logger.error(f"Failed to write to long term memory: {e}")
# finally:
# db_for_memory.close()
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
try:
@@ -483,9 +430,9 @@ class LangChainAgent:
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
if memory_flag:
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
long_term_messages = await agent_chat_messages(message_chat, full_content)
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# await self.term_memory_save(message_chat, end_user_id, full_content)
await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk")
except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)

View File

@@ -0,0 +1,165 @@
import os
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format, format_parsing
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.redis_tool import write_store
from app.core.memory.agent.utils.redis_tool import count_store
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
logger = get_agent_logger(__name__)
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
async def write_messages(end_user_id,langchain_messages,memory_config):
'''
写入数据到neo4j
Args:
end_user_id: 终端用户ID
memory_config: 内存配置对象
langchain_messages原始数据LIST
'''
try:
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {
"messages": langchain_messages,
"end_user_id": end_user_id,
"memory_config": memory_config
}
# 获取节点更新信息
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
for node_name, node_data in update_event.items():
if 'save_neo4j' == node_name:
massages = node_data
massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
print(contents)
except Exception as e:
import traceback
traceback.print_exc()
'''根据窗口'''
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
'''
根据窗口获取redis数据,写入neo4j
Args:
end_user_id: 终端用户ID
memory_config: 内存配置对象
langchain_messages原始数据LIST
scope窗口大小
'''
scope=scope
is_end_user_id = count_store.get_sessions_count(end_user_id)
if is_end_user_id is not False:
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
redis_messages = count_store.get_sessions_count(end_user_id)[1]
if is_end_user_id and int(is_end_user_id) != int(scope):
print(is_end_user_id)
is_end_user_id += 1
langchain_messages += redis_messages
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
elif int(is_end_user_id) == int(scope):
print('写入长期记忆并且设置为0')
print(is_end_user_id)
formatted_messages = await chat_data_format(redis_messages)
print(100*'-')
print(formatted_messages)
print(100*'-')
await write_messages(end_user_id, formatted_messages, memory_config)
count_store.update_sessions_count(end_user_id, 0, '')
else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
"""根据时间"""
async def memory_long_term_storage(end_user_id,memory_config,time):
'''
根据时间获取redis数据,写入neo4j
Args:
end_user_id: 终端用户ID
memory_config: 内存配置对象
'''
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
format_messages = await chat_data_format(long_time_data)
if format_messages!=[]:
await write_messages(end_user_id, format_messages, memory_config)
'''聚合判断'''
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
"""
聚合判断函数:判断输入句子和历史消息是否描述同一事件
Args:
end_user_id: 终端用户ID
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
memory_config: 内存配置对象
"""
try:
# 1. 获取历史会话数据(使用新方法)
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
history = await format_parsing(result)
if not result:
history = []
else:
history = await format_parsing(result)
json_schema = WriteAggregateModel.model_json_schema()
template_service = TemplateService(template_root)
system_prompt = await template_service.render_template(
template_name='write_aggregate_judgment.jinja2',
operation_name='aggregate_judgment',
history=history,
sentence=ori_messages,
json_schema=json_schema
)
with get_db_context() as db_session:
factory = MemoryClientFactory(db_session)
llm_client = factory.get_llm_client(memory_config.llm_model_id)
messages = [
{
"role": "user",
"content": system_prompt
}
]
structured = await llm_client.response_structured(
messages=messages,
response_model=WriteAggregateModel
)
output_value = structured.output
if isinstance(output_value, list):
output_value = [
{"role": msg.role, "content": msg.content}
for msg in output_value
]
result_dict = {
"is_same_event": structured.is_same_event,
"output": output_value
}
if not structured.is_same_event:
logger.info(result_dict)
await write_messages(end_user_id, output_value, memory_config)
return result_dict
except Exception as e:
print(f"[aggregate_judgment] 发生错误: {e}")
import traceback
traceback.print_exc()
return {
"is_same_event": False,
"output": ori_messages,
"messages": ori_messages,
"history": history if 'history' in locals() else [],
"error": str(e)
}

View File

@@ -0,0 +1,100 @@
import json
from langchain_core.messages import HumanMessage, AIMessage
async def format_parsing(messages: list,type:str='string'):
"""
格式化解析消息列表
Args:
messages: 消息列表
type: 返回类型 ('string''dict')
Returns:
格式化后的消息列表
"""
result = []
user=[]
ai=[]
for message in messages:
hstory_messages = message['messages']
for history_messag in hstory_messages.strip().splitlines():
history_messag = json.loads(history_messag)
for content in history_messag:
role = content['role']
content = content['content']
if type == "string":
if role == 'human':
content = '用户:' + content
else:
content = 'AI:' + content
result.append(content)
if type == "dict":
if role == 'human':
user.append( content)
else:
ai.append(content)
if type == "dict":
for key,values in zip(user,ai):
result.append({key:values})
return result
async def messages_parse(messages: list | dict):
user=[]
ai=[]
database=[]
for message in messages:
Query = message['Query']
Query = json.loads(Query)
for data in Query:
role = data['role']
if role == "human":
user.append(data['content'])
if role == "ai":
ai.append(data['content'])
for key, values in zip(user, ai):
database.append({key, values})
return database
async def chat_data_format(messages: list | dict):
"""
将消息格式化为 LangChain 消息格式
Args:
messages: 消息列表或字典
Returns:
LangChain 消息列表
"""
langchain_messages = []
if isinstance(messages, list):
for msg in messages:
if 'role' in msg.keys():
if msg['role'] == 'user':
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content']))
if "Query" in msg.keys():
langchain_messages.append(HumanMessage(content=msg['Query']))
langchain_messages.append(AIMessage(content=msg['Answer']))
if isinstance(messages, dict):
if messages['type'] == 'human':
langchain_messages.append(HumanMessage(content=messages['content']))
elif messages['type'] == 'ai':
langchain_messages.append(AIMessage(content=messages['content']))
return langchain_messages
async def agent_chat_messages(user_content,ai_content):
messages = [
{
"role": "user",
"content": f"{user_content}"
},
{
"role": "assistant",
"content": f"{ai_content}"
}
]
return messages

View File

@@ -1,20 +1,17 @@
import asyncio
import json
import sys
import warnings
from contextlib import asynccontextmanager
from langchain_core.messages import HumanMessage
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, chat_data_format, messages_parse
from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning)
@@ -34,14 +31,6 @@ async def make_write_graph():
end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration
"""
# workflow = StateGraph(WriteState)
# workflow.add_node("content_input", content_input_write)
# workflow.add_node("save_neo4j", write_node)
# workflow.add_edge(START, "content_input")
# workflow.add_edge("content_input", "save_neo4j")
# workflow.add_edge("save_neo4j", END)
#
# graph = workflow.compile()
workflow = StateGraph(WriteState)
workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "save_neo4j")
@@ -51,43 +40,56 @@ async def make_write_graph():
yield graph
async def main():
"""主函数 - 运行工作流"""
message = "今天周一"
end_user_id = 'new_2025test1103' # 组ID
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
from app.core.memory.agent.utils.redis_tool import write_store
write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
# 获取数据库会话
db_session = next(get_db())
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=17, # 改为整数
config_id="08ed205c-0f05-49c3-8e0c-a580d28f5fd4", # 改为整数
service_name="MemoryAgentService"
)
try:
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config}
if long_term_type=='chunk':
'''方案一:对话窗口6轮对话'''
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
if long_term_type=='time':
"""时间"""
await memory_long_term_storage(end_user_id, memory_config,5)
if long_term_type=='aggregate':
# 获取节点更新信息
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
for node_name, node_data in update_event.items():
if 'save_neo4j'==node_name:
massages=node_data
massages=massages.get('write_result')['status']
print(massages) # | 更新数据: {node_data}
"""方案三:聚合判断"""
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
except Exception as e:
import traceback
traceback.print_exc()
if __name__ == "__main__":
import asyncio
asyncio.run(main())
#
# async def main():
# """主函数 - 运行工作流"""
# langchain_messages = [
# {
# "role": "user",
# "content": "今天周五好开心啊"
# },
# {
# "role": "assistant",
# "content": "你也这么觉得,我也是耶"
# }
#
# ]
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
# # await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
# from app.core.memory.agent.utils.redis_tool import write_store
# result=write_store.get_session_by_userid(end_user_id)
# data=await format_parsing(result,"dict")
# chunk_data=data[:6]
#
# long_time_data = write_store.find_user_recent_sessions(end_user_id, 240)
# long_=await messages_parse(long_time_data)
# print(long_)
#
#
# if __name__ == "__main__":
# import asyncio
# asyncio.run(main())

View File

@@ -0,0 +1,28 @@
"""Pydantic models for write aggregate judgment operations."""
from typing import List, Union
from pydantic import BaseModel, Field
class MessageItem(BaseModel):
"""Individual message item in conversation."""
role: str = Field(..., description="角色user 或 assistant")
content: str = Field(..., description="消息内容")
class WriteAggregateResponse(BaseModel):
"""Response model for aggregate judgment containing judgment result and output."""
is_same_event: bool = Field(
...,
description="是否是同一事件。True表示是同一事件False表示不同事件"
)
output: Union[List[MessageItem], bool] = Field(
...,
description="如果is_same_event为True返回False如果is_same_event为False返回消息列表"
)
# 为了保持向后兼容,保留旧的类名作为别名
WriteAggregateModel = WriteAggregateResponse

View File

@@ -0,0 +1,57 @@
输入句子:{{sentence}}
历史消息:{{history}}
# 你的角色
你是一个擅长事件聚合与语义判断的专家。
# 你的任务
结合历史消息和输入句子,判断它们是否在描述**同一件事件或同一事件链**。
以下情况视为"同一事件"(需要返回 is_same_event=True, output=False
- 描述的是同一个具体事件或事实
- 存在明显的因果关系、前后发展关系
- 是对同一事件的补充、解释、追问或延展
- 逻辑上属于同一语境下的连续讨论
以下情况视为"不同事件"(需要返回 is_same_event=False, output=消息列表):
- 话题不同,事件主体不同
- 时间、地点、对象明显不同
- 只是语义相似,但并非同一具体事件
- 无直接事件、因果或逻辑关联
# 输出规则(非常重要)
你必须按照以下JSON格式输出
**如果是同一事件:**
```json
{
"is_same_event": true,
"output": false
}
```
**如果不是同一事件:**
```json
{
"is_same_event": false,
"output": [
{
"role": "user",
"content": "输入句子的内容"
},
{
"role": "assistant",
"content": "对应的回复内容"
}
]
}
```
# JSON Schema
{{json_schema}}
# 注意事项
- 必须严格按照上述格式输出
- output 字段:如果是同一事件返回 false如果不是同一事件返回完整的消息列表
- 消息列表必须包含 role 和 content 字段
- 不要输出任何解释、分析或多余内容

View File

@@ -0,0 +1,186 @@
import json
from typing import Any, List, Dict, Optional
from datetime import datetime, timedelta
def serialize_messages(messages: Any) -> str:
"""
将消息序列化为 JSON 字符串,支持 LangChain 消息对象
Args:
messages: 可以是 list、dict、string 或 LangChain 消息对象列表
Returns:
str: JSON 字符串
"""
if isinstance(messages, str):
return messages
if isinstance(messages, (list, tuple)):
# 检查是否是 LangChain 消息对象列表
serialized_list = []
for msg in messages:
if hasattr(msg, 'type') and hasattr(msg, 'content'):
# LangChain 消息对象
serialized_list.append({
'type': msg.type,
'content': msg.content,
'role': getattr(msg, 'role', msg.type)
})
elif isinstance(msg, dict):
serialized_list.append(msg)
else:
serialized_list.append(str(msg))
return json.dumps(serialized_list, ensure_ascii=False)
if isinstance(messages, dict):
return json.dumps(messages, ensure_ascii=False)
# 其他类型转为字符串
return str(messages)
def deserialize_messages(messages_str: str) -> Any:
"""
将 JSON 字符串反序列化为原始格式
Args:
messages_str: JSON 字符串
Returns:
反序列化后的对象list、dict 或 string
"""
if not messages_str:
return []
try:
return json.loads(messages_str)
except (json.JSONDecodeError, TypeError):
return messages_str
def fix_encoding(text: str) -> str:
"""
修复错误编码的文本
Args:
text: 需要修复的文本
Returns:
str: 修复后的文本
"""
if not text or not isinstance(text, str):
return text
try:
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
return text.encode('latin-1').decode('utf-8')
except (UnicodeDecodeError, UnicodeEncodeError):
# 如果修复失败,返回原文本
return text
def format_session_data(data: Dict[str, Any], include_time: bool = False) -> Dict[str, Any]:
"""
格式化会话数据为统一的输出格式
Args:
data: 原始会话数据
include_time: 是否包含时间字段
Returns:
Dict: 格式化后的数据 {"Query": "...", "Answer": "...", "starttime": "..."}
"""
result = {
"Query": fix_encoding(data.get('messages', '')),
"Answer": fix_encoding(data.get('aimessages', ''))
}
if include_time:
result["starttime"] = data.get('starttime', '')
return result
def filter_by_time_range(items: List[Dict], minutes: int) -> List[Dict]:
"""
根据时间范围过滤数据
Args:
items: 包含 starttime 字段的数据列表
minutes: 时间范围(分钟)
Returns:
List[Dict]: 过滤后的数据列表
"""
time_threshold = datetime.now() - timedelta(minutes=minutes)
time_threshold_str = time_threshold.strftime("%Y-%m-%d %H:%M:%S")
filtered_items = []
for item in items:
starttime = item.get('starttime', '')
if starttime and starttime >= time_threshold_str:
filtered_items.append(item)
return filtered_items
def sort_and_limit_results(items: List[Dict], limit: int = 6,
remove_time: bool = True) -> List[Dict]:
"""
对结果进行排序、限制数量并移除时间字段
Args:
items: 数据列表
limit: 最大返回数量
remove_time: 是否移除 starttime 字段
Returns:
List[Dict]: 处理后的数据列表
"""
# 按时间降序排序(最新的在前)
items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
# 限制数量
result_items = items[:limit]
# 移除 starttime 字段
if remove_time:
for item in result_items:
item.pop('starttime', None)
# 如果结果少于1条返回空列表
if len(result_items) < 1:
return []
return result_items
def generate_session_key(session_id: str, key_type: str = "session") -> str:
"""
生成 Redis key
Args:
session_id: 会话ID
key_type: key 类型 ("session", "read", "write", "count")
Returns:
str: Redis key
"""
if key_type == "count":
return f"session:count:{session_id}"
elif key_type == "write":
return f"session:write:{session_id}"
elif key_type == "session" or key_type == "read":
return f"session:{session_id}"
else:
return f"session:{session_id}"
def get_current_timestamp() -> str:
"""
获取当前时间戳字符串
Returns:
str: 格式化的时间字符串 "YYYY-MM-DD HH:MM:SS"
"""
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

View File

@@ -1,11 +1,36 @@
import redis
import uuid
from datetime import datetime
from app.core.config import settings
from typing import List, Dict, Any, Optional, Union
from app.core.memory.agent.utils.redis_base import (
serialize_messages,
deserialize_messages,
fix_encoding,
format_session_data,
filter_by_time_range,
sort_and_limit_results,
generate_session_key,
get_current_timestamp
)
class RedisSessionStore:
class RedisWriteStore:
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
"""
初始化 Redis 连接
Args:
host: Redis 主机地址
port: Redis 端口
db: Redis 数据库编号
password: Redis 密码
session_id: 会话ID
"""
self.r = redis.Redis(
host=host,
port=port,
@@ -16,32 +41,400 @@ class RedisSessionStore:
)
self.uudi = session_id
def _fix_encoding(self, text):
"""修复错误编码的文本"""
if not text or not isinstance(text, str):
return text
try:
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
return text.encode('latin-1').decode('utf-8')
except (UnicodeDecodeError, UnicodeEncodeError):
# 如果修复失败,返回原文本
return text
# 修改后的 save_session 方法
def save_session(self, userid, messages, aimessages, apply_id, end_user_id):
def save_session_write(self, userid: str, messages: str) -> str:
"""
写入一条会话数据,返回 session_id
优化版本确保写入时间不超过1秒
Args:
userid: 用户ID
messages: 用户消息
Returns:
str: 新生成的 session_id
"""
try:
session_id = str(uuid.uuid4()) # 为每次会话生成新的 ID
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
key = f"session:{session_id}" # 使用新生成的 session_id 作为 key
messages = serialize_messages(messages)
session_id = str(uuid.uuid4())
key = generate_session_key(session_id, key_type="write")
# 使用 pipeline 批量写入,减少网络往返
pipe = self.r.pipeline()
pipe.hset(key, mapping={
"id": self.uudi,
"sessionid": userid,
"messages": messages,
"starttime": get_current_timestamp()
})
result = pipe.execute()
# 直接写入数据decode_responses=True 已经处理了编码
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
return session_id
except Exception as e:
print(f"[save_session_write] 保存会话失败: {e}")
raise e
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
"""
通过 save_session_write 的 userid 获取 sessionid 和 messages
Args:
userid: 用户ID (对应 sessionid 字段)
Returns:
List[Dict] 或 False: 如果找到数据返回 [{"sessionid": "...", "messages": "..."}, ...],否则返回 False
"""
try:
# 只查询 write 类型的 key
keys = self.r.keys('session:write:*')
if not keys:
return False
# 批量获取数据
pipe = self.r.pipeline()
for key in keys:
pipe.hgetall(key)
all_data = pipe.execute()
# 筛选符合 userid 的数据
results = []
for key, data in zip(keys, all_data):
if not data:
continue
# 从 write 类型读取,匹配 sessionid 字段
if data.get('sessionid') == userid:
# 从 key 中提取 session_id: session:write:{session_id}
session_id = key.split(':')[-1]
results.append({
"sessionid": session_id,
"messages": fix_encoding(data.get('messages', ''))
})
if not results:
return False
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
return results
except Exception as e:
print(f"[get_session_by_userid] 查询失败: {e}")
return False
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
"""
通过 end_user_id 获取所有 write 类型的会话数据
Args:
end_user_id: 终端用户ID (对应 sessionid 字段)
Returns:
List[Dict] 或 False: 如果找到数据返回完整的会话信息列表,否则返回 False
返回格式:
[
{
"session_id": "uuid",
"id": "...",
"sessionid": "end_user_id",
"messages": "...",
"starttime": "timestamp"
},
...
]
"""
try:
# 只查询 write 类型的 key
keys = self.r.keys('session:write:*')
if not keys:
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
return False
# 批量获取数据
pipe = self.r.pipeline()
for key in keys:
pipe.hgetall(key)
all_data = pipe.execute()
# 筛选符合 end_user_id 的数据
results = []
for key, data in zip(keys, all_data):
if not data:
continue
# 从 write 类型读取,匹配 sessionid 字段
if data.get('sessionid') == end_user_id:
# 从 key 中提取 session_id: session:write:{session_id}
session_id = key.split(':')[-1]
# 构建完整的会话信息
session_info = {
"session_id": session_id,
"id": data.get('id', ''),
"sessionid": data.get('sessionid', ''),
"messages": fix_encoding(data.get('messages', '')),
"starttime": data.get('starttime', '')
}
results.append(session_info)
if not results:
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
return False
# 按时间排序(最新的在前)
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
return results
except Exception as e:
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
import traceback
traceback.print_exc()
return False
def find_user_recent_sessions(self, userid: str,
minutes: int = 5) -> List[Dict[str, str]]:
"""
根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据
Args:
userid: 用户ID (对应 sessionid 字段)
minutes: 查询最近几分钟的数据默认5分钟
Returns:
List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...]
"""
import time
start_time = time.time()
# 只查询 write 类型的 key
keys = self.r.keys('session:write:*')
if not keys:
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return []
# 批量获取数据
pipe = self.r.pipeline()
for key in keys:
pipe.hgetall(key)
all_data = pipe.execute()
# 筛选符合 userid 的数据
matched_items = []
for data in all_data:
if not data:
continue
# 从 write 类型读取,匹配 sessionid 字段
if data.get('sessionid') == userid and data.get('starttime'):
# write 类型没有 aimessages所以 Answer 为空
matched_items.append({
"Query": fix_encoding(data.get('messages', '')),
"Answer": "",
"starttime": data.get('starttime', '')
})
# 根据时间范围过滤
filtered_items = filter_by_time_range(matched_items, minutes)
# 排序并移除时间字段
result_items = sort_and_limit_results(filtered_items, limit=None)
print(result_items)
elapsed_time = time.time() - start_time
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items
def delete_all_write_sessions(self) -> int:
"""
删除所有 write 类型的会话
Returns:
int: 删除的数量
"""
keys = self.r.keys('session:write:*')
if keys:
return self.r.delete(*keys)
return 0
class RedisCountStore:
"""Redis Count 类型存储类,用于管理访问次数统计相关的数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
"""
初始化 Redis 连接
Args:
host: Redis 主机地址
port: Redis 端口
db: Redis 数据库编号
password: Redis 密码
session_id: 会话ID
"""
self.r = redis.Redis(
host=host,
port=port,
db=db,
password=password,
decode_responses=True,
encoding='utf-8'
)
self.uudi = session_id
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
"""
保存用户访问次数统计
Args:
end_user_id: 终端用户ID
count: 访问次数
messages: 消息内容
Returns:
str: 新生成的 session_id
"""
session_id = str(uuid.uuid4())
key = generate_session_key(session_id, key_type="count")
pipe = self.r.pipeline()
pipe.hset(key, mapping={
"id": self.uudi,
"end_user_id": end_user_id,
"count": int(count),
"messages": serialize_messages(messages),
"starttime": get_current_timestamp()
})
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
result = pipe.execute()
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
return session_id
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
"""
通过 end_user_id 查询访问次数统计
Args:
end_user_id: 终端用户ID
Returns:
list 或 False: 如果找到返回 [count, messages],否则返回 False
"""
try:
search_pattern = 'session:count:*'
for key in self.r.keys(search_pattern):
data = self.r.hgetall(key)
if not data:
continue
if data.get('end_user_id') == end_user_id:
count = data.get('count')
messages_str = data.get('messages')
if count is not None:
messages = deserialize_messages(messages_str)
return [int(count), messages]
return False
except Exception as e:
print(f"[get_sessions_count] 查询失败: {e}")
return False
def update_sessions_count(self, end_user_id: str, new_count: int,
messages: Any) -> bool:
"""
通过 end_user_id 修改访问次数统计
Args:
end_user_id: 终端用户ID
new_count: 新的 count 值
messages: 消息内容
Returns:
bool: 更新成功返回 True未找到记录返回 False
"""
try:
messages_str = serialize_messages(messages)
search_pattern = 'session:count:*'
for key in self.r.keys(search_pattern):
data = self.r.hgetall(key)
if not data:
continue
if data.get('end_user_id') == end_user_id:
self.r.hset(key, 'count', int(new_count))
self.r.hset(key, 'messages', messages_str)
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
return True
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False
except Exception as e:
print(f"[update_sessions_count] 更新失败: {e}")
return False
def delete_all_count_sessions(self) -> int:
"""
删除所有 count 类型的会话
Returns:
int: 删除的数量
"""
keys = self.r.keys('session:count:*')
if keys:
return self.r.delete(*keys)
return 0
class RedisSessionStore:
"""Redis 会话存储类,用于管理会话数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
"""
初始化 Redis 连接
Args:
host: Redis 主机地址
port: Redis 端口
db: Redis 数据库编号
password: Redis 密码
session_id: 会话ID
"""
self.r = redis.Redis(
host=host,
port=port,
db=db,
password=password,
decode_responses=True,
encoding='utf-8'
)
self.uudi = session_id
# ==================== 写入操作 ====================
def save_session(self, userid: str, messages: str, aimessages: str,
apply_id: str, end_user_id: str) -> str:
"""
写入一条会话数据,返回 session_id
Args:
userid: 用户ID
messages: 用户消息
aimessages: AI回复消息
apply_id: 应用ID
end_user_id: 终端用户ID
Returns:
str: 新生成的 session_id
"""
try:
session_id = str(uuid.uuid4())
key = generate_session_key(session_id, key_type="read")
pipe = self.r.pipeline()
pipe.hset(key, mapping={
"id": self.uudi,
"sessionid": userid,
@@ -49,177 +442,195 @@ class RedisSessionStore:
"end_user_id": end_user_id,
"messages": messages,
"aimessages": aimessages,
"starttime": starttime
"starttime": get_current_timestamp()
})
# 可选设置过期时间例如30天避免数据无限增长
# pipe.expire(key, 30 * 24 * 60 * 60)
# 执行批量操作
result = pipe.execute()
print(f"保存结果: {result[0]}, session_id: {session_id}")
return session_id # 返回新生成的 session_id
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
return session_id
except Exception as e:
print(f"保存会话失败: {e}")
print(f"[save_session] 保存会话失败: {e}")
raise e
def save_sessions_batch(self, sessions_data):
"""
批量写入多条会话数据,返回 session_id 列表
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id
优化版本:批量操作,大幅提升性能
"""
try:
session_ids = []
pipe = self.r.pipeline()
for session in sessions_data:
session_id = str(uuid.uuid4())
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
key = f"session:{session_id}"
pipe.hset(key, mapping={
"id": self.uudi,
"sessionid": session.get('userid'),
"apply_id": session.get('apply_id'),
"end_user_id": session.get('end_user_id'),
"messages": session.get('messages'),
"aimessages": session.get('aimessages'),
"starttime": starttime
})
session_ids.append(session_id)
# 一次性执行所有写入操作
results = pipe.execute()
print(f"批量保存完成: {len(session_ids)} 条记录")
return session_ids
except Exception as e:
print(f"批量保存会话失败: {e}")
raise e
# ---------------- 读取 ----------------
def get_session(self, session_id):
# ==================== 读取操作 ====================
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
"""
读取一条会话数据
Args:
session_id: 会话ID
Returns:
Dict 或 None: 会话数据
"""
key = f"session:{session_id}"
key = generate_session_key(session_id)
data = self.r.hgetall(key)
return data if data else None
def get_session_apply_group(self, sessionid, apply_id, end_user_id):
def get_all_sessions(self) -> Dict[str, Dict[str, Any]]:
"""
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据
"""
result_items = []
# 遍历所有会话数据
for key in self.r.keys('session:*'):
data = self.r.hgetall(key)
if not data:
continue
# 检查三个条件是否都匹配
if (data.get('sessionid') == sessionid and
data.get('apply_id') == apply_id and
data.get('end_user_id') == end_user_id):
result_items.append(data)
return result_items
def get_all_sessions(self):
"""
获取所有会话数据
获取所有会话数据(不包括 count 和 write 类型)
Returns:
Dict: 所有会话数据key 为 session_id
"""
sessions = {}
for key in self.r.keys('session:*'):
sid = key.split(':')[1]
sessions[sid] = self.get_session(sid)
# 排除 count 和 write 类型的 key
if ':count:' not in key and ':write:' not in key:
sid = key.split(':')[1]
sessions[sid] = self.get_session(sid)
return sessions
# ---------------- 更新 ----------------
def update_session(self, session_id, field, value):
def find_user_apply_group(self, sessionid: str, apply_id: str,
end_user_id: str) -> List[Dict[str, str]]:
"""
根据 sessionid、apply_id 和 end_user_id 查询会话数据返回最新的6条
Args:
sessionid: 会话ID支持模糊匹配
apply_id: 应用ID
end_user_id: 终端用户ID
Returns:
List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...]
"""
import time
start_time = time.time()
keys = self.r.keys('session:*')
if not keys:
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return []
# 批量获取数据
pipe = self.r.pipeline()
for key in keys:
# 排除 count 和 write 类型
if ':count:' not in key and ':write:' not in key:
pipe.hgetall(key)
all_data = pipe.execute()
# 筛选符合条件的数据
matched_items = []
for data in all_data:
if not data:
continue
if (data.get('apply_id') == apply_id and
data.get('end_user_id') == end_user_id):
# 支持模糊匹配或完全匹配 sessionid
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
matched_items.append(format_session_data(data, include_time=True))
# 排序、限制数量并移除时间字段
result_items = sort_and_limit_results(matched_items, limit=6)
elapsed_time = time.time() - start_time
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items
# ==================== 更新操作 ====================
def update_session(self, session_id: str, field: str, value: Any) -> bool:
"""
更新单个字段
优化版本:使用 pipeline 减少网络往返
Args:
session_id: 会话ID
field: 字段名
value: 字段值
Returns:
bool: 是否更新成功
"""
key = f"session:{session_id}"
key = generate_session_key(session_id)
pipe = self.r.pipeline()
pipe.exists(key)
pipe.hset(key, field, value)
results = pipe.execute()
return bool(results[0]) # 返回 key 是否存在
return bool(results[0])
# ---------------- 删除 ----------------
def delete_session(self, session_id):
# ==================== 删除操作 ====================
def delete_session(self, session_id: str) -> int:
"""
删除单条会话
Args:
session_id: 会话ID
Returns:
int: 删除的数量
"""
key = f"session:{session_id}"
key = generate_session_key(session_id)
return self.r.delete(key)
def delete_all_sessions(self):
def delete_all_sessions(self) -> int:
"""
删除所有会话
删除所有会话(不包括 count 和 write 类型)
Returns:
int: 删除的数量
"""
keys = self.r.keys('session:*')
if keys:
return self.r.delete(*keys)
# 过滤掉 count 和 write 类型
keys_to_delete = [k for k in keys if ':count:' not in k and ':write:' not in k]
if keys_to_delete:
return self.r.delete(*keys_to_delete)
return 0
def delete_duplicate_sessions(self):
def delete_duplicate_sessions(self) -> int:
"""
删除重复会话数据,条件:
"sessionid""user_id""end_user_id""messages""aimessages" 五个字段都相同的只保留一个,其他删除
优化版本:使用 pipeline 批量操作确保在1秒内完成
删除重复会话数据(不包括 count 和 write 类型)
条件:sessionid、user_id、end_user_id、messages、aimessages 五个字段都相同的只保留一个
Returns:
int: 删除的数量
"""
import time
start_time = time.time()
# 第一步:使用 pipeline 批量获取所有 key
keys = self.r.keys('session:*')
if not keys:
print("[delete_duplicate_sessions] 没有会话数据")
return 0
# 第二步:使用 pipeline 批量获取所有数据
# 批量获取所有数据
pipe = self.r.pipeline()
for key in keys:
pipe.hgetall(key)
# 排除 count 和 write 类型
if ':count:' not in key and ':write:' not in key:
pipe.hgetall(key)
all_data = pipe.execute()
# 第三步:在内存中识别重复数据
seen = {} # 用字典记录identifier -> key保留第一个出现的 key
keys_to_delete = [] # 需要删除的 key 列表
# 识别重复数据
seen = {}
keys_to_delete = []
for key, data in zip(keys, all_data, strict=False):
for key, data in zip([k for k in keys if ':count:' not in k and ':write:' not in k], all_data, strict=False):
if not data:
continue
# 获取五个字段的值
sessionid = data.get('sessionid', '')
user_id = data.get('id', '')
end_user_id = data.get('end_user_id', '')
messages = data.get('messages', '')
aimessages = data.get('aimessages', '')
# 用五元组作为唯一标识
identifier = (sessionid, user_id, end_user_id, messages, aimessages)
identifier = (
data.get('sessionid', ''),
data.get('id', ''),
data.get('end_user_id', ''),
data.get('messages', ''),
data.get('aimessages', '')
)
if identifier in seen:
# 重复,标记为待删除
keys_to_delete.append(key)
else:
# 第一次出现,记录
seen[identifier] = key
# 第四步:使用 pipeline 批量删除重复的 key
# 批量删除重复的 key
deleted_count = 0
if keys_to_delete:
# 分批删除,避免单次操作过大
batch_size = 1000
for i in range(0, len(keys_to_delete), batch_size):
batch = keys_to_delete[i:i + batch_size]
@@ -233,79 +644,28 @@ class RedisSessionStore:
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}")
return deleted_count
def find_user_session(self, sessionid):
user_id = sessionid
result_items = []
for key, values in store.get_all_sessions().items():
history = {}
if user_id == str(values['sessionid']):
history["Query"] = values['messages']
history["Answer"] = values['aimessages']
result_items.append(history)
if len(result_items) <= 1:
result_items = []
return (result_items)
def find_user_apply_group(self, sessionid, apply_id, end_user_id):
"""
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据返回最新的6条
"""
import time
start_time = time.time()
# 使用 pipeline 批量获取数据,提高性能
keys = self.r.keys('session:*')
if not keys:
print(f"查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return []
# 使用 pipeline 批量获取所有 hash 数据
pipe = self.r.pipeline()
for key in keys:
pipe.hgetall(key)
all_data = pipe.execute()
# 解析并筛选符合条件的数据
matched_items = []
for data in all_data:
if not data:
continue
# 检查是否符合三个条件
if (data.get('apply_id') == apply_id and
data.get('end_user_id') == end_user_id):
# 支持模糊匹配 sessionid 或者完全匹配
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
matched_items.append({
"Query": self._fix_encoding(data.get('messages')),
"Answer": self._fix_encoding(data.get('aimessages')),
"starttime": data.get('starttime', '')
})
# 按时间降序排序(最新的在前)
matched_items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
# 只保留最新的6条
result_items = matched_items[:6]
# # 移除 starttime 字段
for item in result_items:
item.pop('starttime', None)
# 如果结果少于等于1条返回空列表
if len(result_items) <= 1:
result_items = []
elapsed_time = time.time() - start_time
print(f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items
# 全局实例
store = RedisSessionStore(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
session_id=str(uuid.uuid4())
)
)
write_store = RedisWriteStore(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
session_id=str(uuid.uuid4())
)
count_store = RedisCountStore(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
session_id=str(uuid.uuid4())
)

View File

@@ -12,8 +12,8 @@ class KnowledgeBaseConfig(BaseModel):
kb_id: str = Field(..., description="知识库ID")
top_k: int = Field(default=3, ge=1, le=20, description="检索返回的文档数量")
similarity_threshold: float = Field(default=0.7, ge=0.0, le=1.0, description="相似度阈值")
strategy: str = Field(default="hybrid", description="检索策略: hybrid | bm25 | dense")
weight: float = Field(default=1.0, ge=0.0, le=1.0, description="知识库权重(用于多知识库融合)")
# strategy: str = Field(default="hybrid", description="检索策略: hybrid | bm25 | dense")
# weight: float = Field(default=1.0, ge=0.0, le=1.0, description="知识库权重(用于多知识库融合)")
vector_similarity_weight: float = Field(default=0.5, ge=0.0, le=1.0, description="向量相似度权重")
retrieve_type: str = Field(default="hybrid", description="检索方式participle semantichybrid")