diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index a34c781f..f75e3432 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -11,7 +11,8 @@ import os import time from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence - +from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages, format_parsing, messages_parse +from app.core.memory.agent.langgraph_graph.write_graph import long_term_storage from app.db import get_db from app.core.logging_config import get_business_logger from app.core.memory.agent.utils.redis_tool import store @@ -145,38 +146,33 @@ class LangChainAgent: user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}" messages.append(HumanMessage(content=user_content)) - return messages -# TODO 乐力齐 - 累积多组对话批量写入功能已禁用 - # async def term_memory_save(self,messages,end_user_end,aimessages): - # '''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j''' - # end_user_end=f"Term_{end_user_end}" - # print(messages) - # print(aimessages) - # session_id = store.save_session( - # userid=end_user_end, - # messages=messages, - # apply_id=end_user_end, - # end_user_id=end_user_end, - # aimessages=aimessages - # ) - # store.delete_duplicate_sessions() - # # logger.info(f'Redis_Agent:{end_user_end};{session_id}') - # return session_id -# TODO 乐力齐 - 累积多组对话批量写入功能已禁用 - # async def term_memory_redis_read(self,end_user_end): - # end_user_end = f"Term_{end_user_end}" - # history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end) - # # logger.info(f'Redis_Agent:{end_user_end};{history}') - # messagss_list=[] - # retrieved_content=[] - # for messages in history: - # query = messages.get("Query") - # aimessages = messages.get("Answer") - # messagss_list.append(f'用户:{query}。AI回复:{aimessages}') - # retrieved_content.append({query: aimessages}) - # return messagss_list,retrieved_content + async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type): + db = next(get_db()) + scope=6 + + try: + repo = LongTermMemoryRepository(db) + await long_term_storage(long_term_type="chunk", langchain_messages=long_term_messages, + memory_config=actual_config_id, end_user_id=end_user_id, scope=scope) + + from app.core.memory.agent.utils.redis_tool import write_store + result = write_store.get_session_by_userid(end_user_id) + if type=="chunk" or type=="aggregate": + data = await format_parsing(result, "dict") + chunk_data = data[:scope] + if len(chunk_data)==scope: + repo.upsert(end_user_id, chunk_data) + logger.info(f'写入短长期:') + else: + long_time_data = write_store.find_user_recent_sessions(end_user_id, 5) + long_messages = await messages_parse(long_time_data) + repo.upsert(end_user_id, long_messages) + logger.info(f'写入短长期:') + finally: + db.close() + async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id): """ 写入记忆(支持结构化消息) @@ -224,14 +220,6 @@ class LangChainAgent: logger.warning(f"No messages to write for user {actual_end_user_id}") return - # 调用 Celery 任务,传递结构化消息列表 - # 数据流: - # 1. structured_messages 传递给 write_message_task - # 2. write_message_task 调用 memory_agent_service.write_memory - # 3. write_memory 调用 write_tools.write,传递 messages 参数 - # 4. write_tools.write 调用 get_chunked_dialogs,传递 messages 参数 - # 5. get_chunked_dialogs 为每条消息创建独立的 Chunk,设置 speaker 字段 - # 6. 每个 Chunk 保存到 Neo4j,包含 speaker 字段 logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") write_id = write_message_task.delay( actual_end_user_id, # end_user_id: 用户ID @@ -288,30 +276,6 @@ class LangChainAgent: actual_end_user_id = end_user_id if end_user_id is not None else "unknown" logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') -# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码 -# history_term_memory_result = await self.term_memory_redis_read(end_user_id) -# history_term_memory = history_term_memory_result[0] -# db_for_memory = next(get_db()) -# if memory_flag: -# if len(history_term_memory)>=4 and storage_type != "rag": -# history_term_memory = ';'.join(history_term_memory) -# retrieved_content = history_term_memory_result[1] -# print(retrieved_content) -# # 为长期记忆操作获取新的数据库连接 -# try: -# repo = LongTermMemoryRepository(db_for_memory) -# repo.upsert(end_user_id, retrieved_content) -# logger.info( -# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') -# except Exception as e: -# logger.error(f"Failed to write to LongTermMemory: {e}") -# raise -# finally: -# db_for_memory.close() - -# # 长期记忆写入( -# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id) -# # 注意:不在这里写入用户消息,等 AI 回复后一起写入 try: # 准备消息列表 messages = self._prepare_messages(message, history, context) @@ -339,10 +303,11 @@ class LangChainAgent: elapsed_time = time.time() - start_time if memory_flag: + long_term_messages=await agent_chat_messages(message_chat,content) # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id) - # TODO 乐力齐 - 累积多组对话批量写入功能已禁用 - # await self.term_memory_save(message_chat, end_user_id, content) + '''长期''' + await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk") response = { "content": content, "model": self.model_name, @@ -410,25 +375,7 @@ class LangChainAgent: db.close() except Exception as e: logger.warning(f"Failed to get db session: {e}") -# # TODO 乐力齐 -# history_term_memory_result = await self.term_memory_redis_read(end_user_id) -# history_term_memory = history_term_memory_result[0] -# if memory_flag: -# if len(history_term_memory) >= 4 and storage_type != "rag": -# history_term_memory = ';'.join(history_term_memory) -# retrieved_content = history_term_memory_result[1] -# db_for_memory = next(get_db()) -# try: -# repo = LongTermMemoryRepository(db_for_memory) -# repo.upsert(end_user_id, retrieved_content) -# logger.info( -# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') -# # 长期记忆写入 -# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id) -# except Exception as e: -# logger.error(f"Failed to write to long term memory: {e}") -# finally: -# db_for_memory.close() + # 注意:不在这里写入用户消息,等 AI 回复后一起写入 try: @@ -483,9 +430,9 @@ class LangChainAgent: logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件") if memory_flag: # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + long_term_messages = await agent_chat_messages(message_chat, full_content) await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id) - # TODO 乐力齐 - 累积多组对话批量写入功能已禁用 - # await self.term_memory_save(message_chat, end_user_id, full_content) + await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk") except Exception as e: logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py new file mode 100644 index 00000000..d6fbbb38 --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -0,0 +1,165 @@ +import os + +from app.core.logging_config import get_agent_logger +from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format, format_parsing +from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph + +from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel +from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ +from app.core.memory.agent.utils.redis_tool import write_store +from app.core.memory.agent.utils.redis_tool import count_store +from app.core.memory.agent.utils.template_tools import TemplateService +from app.core.memory.utils.llm.llm_utils import MemoryClientFactory +from app.db import get_db_context +logger = get_agent_logger(__name__) +template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') + + +async def write_messages(end_user_id,langchain_messages,memory_config): + ''' + 写入数据到neo4j: + Args: + end_user_id: 终端用户ID + memory_config: 内存配置对象 + langchain_messages:原始数据LIST + ''' + try: + + async with make_write_graph() as graph: + config = {"configurable": {"thread_id": end_user_id}} + # 初始状态 - 包含所有必要字段 + initial_state = { + "messages": langchain_messages, + "end_user_id": end_user_id, + "memory_config": memory_config + } + + # 获取节点更新信息 + async for update_event in graph.astream( + initial_state, + stream_mode="updates", + config=config + ): + for node_name, node_data in update_event.items(): + if 'save_neo4j' == node_name: + massages = node_data + massagesstatus = massages.get('write_result')['status'] + contents = massages.get('write_result') + print(contents) + except Exception as e: + import traceback + traceback.print_exc() +'''根据窗口''' +async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): + ''' + 根据窗口获取redis数据,写入neo4j: + Args: + end_user_id: 终端用户ID + memory_config: 内存配置对象 + langchain_messages:原始数据LIST + scope:窗口大小 + ''' + scope=scope + is_end_user_id = count_store.get_sessions_count(end_user_id) + if is_end_user_id is not False: + is_end_user_id = count_store.get_sessions_count(end_user_id)[0] + redis_messages = count_store.get_sessions_count(end_user_id)[1] + if is_end_user_id and int(is_end_user_id) != int(scope): + print(is_end_user_id) + is_end_user_id += 1 + langchain_messages += redis_messages + count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages) + elif int(is_end_user_id) == int(scope): + print('写入长期记忆,并且设置为0') + print(is_end_user_id) + formatted_messages = await chat_data_format(redis_messages) + print(100*'-') + print(formatted_messages) + print(100*'-') + await write_messages(end_user_id, formatted_messages, memory_config) + count_store.update_sessions_count(end_user_id, 0, '') + else: + count_store.save_sessions_count(end_user_id, 1, langchain_messages) + + +"""根据时间""" +async def memory_long_term_storage(end_user_id,memory_config,time): + ''' + 根据时间获取redis数据,写入neo4j: + Args: + end_user_id: 终端用户ID + memory_config: 内存配置对象 + ''' + long_time_data = write_store.find_user_recent_sessions(end_user_id, time) + format_messages = await chat_data_format(long_time_data) + if format_messages!=[]: + await write_messages(end_user_id, format_messages, memory_config) +'''聚合判断''' +async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict: + """ + 聚合判断函数:判断输入句子和历史消息是否描述同一事件 + + Args: + end_user_id: 终端用户ID + ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] + memory_config: 内存配置对象 + """ + + try: + # 1. 获取历史会话数据(使用新方法) + result = write_store.get_all_sessions_by_end_user_id(end_user_id) + history = await format_parsing(result) + if not result: + history = [] + else: + history = await format_parsing(result) + json_schema = WriteAggregateModel.model_json_schema() + template_service = TemplateService(template_root) + system_prompt = await template_service.render_template( + template_name='write_aggregate_judgment.jinja2', + operation_name='aggregate_judgment', + history=history, + sentence=ori_messages, + json_schema=json_schema + ) + with get_db_context() as db_session: + factory = MemoryClientFactory(db_session) + llm_client = factory.get_llm_client(memory_config.llm_model_id) + messages = [ + { + "role": "user", + "content": system_prompt + } + ] + structured = await llm_client.response_structured( + messages=messages, + response_model=WriteAggregateModel + ) + output_value = structured.output + if isinstance(output_value, list): + output_value = [ + {"role": msg.role, "content": msg.content} + for msg in output_value + ] + + result_dict = { + "is_same_event": structured.is_same_event, + "output": output_value + } + if not structured.is_same_event: + logger.info(result_dict) + await write_messages(end_user_id, output_value, memory_config) + return result_dict + + except Exception as e: + print(f"[aggregate_judgment] 发生错误: {e}") + import traceback + traceback.print_exc() + + return { + "is_same_event": False, + "output": ori_messages, + "messages": ori_messages, + "history": history if 'history' in locals() else [], + "error": str(e) + } \ No newline at end of file diff --git a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py new file mode 100644 index 00000000..a1fb8226 --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py @@ -0,0 +1,100 @@ +import json + +from langchain_core.messages import HumanMessage, AIMessage + + +async def format_parsing(messages: list,type:str='string'): + """ + 格式化解析消息列表 + + Args: + messages: 消息列表 + type: 返回类型 ('string' 或 'dict') + + Returns: + 格式化后的消息列表 + """ + result = [] + user=[] + ai=[] + + for message in messages: + hstory_messages = message['messages'] + for history_messag in hstory_messages.strip().splitlines(): + history_messag = json.loads(history_messag) + for content in history_messag: + role = content['role'] + content = content['content'] + if type == "string": + if role == 'human': + content = '用户:' + content + else: + content = 'AI:' + content + result.append(content) + if type == "dict": + if role == 'human': + user.append( content) + else: + ai.append(content) + if type == "dict": + for key,values in zip(user,ai): + result.append({key:values}) + return result + +async def messages_parse(messages: list | dict): + user=[] + ai=[] + database=[] + for message in messages: + Query = message['Query'] + Query = json.loads(Query) + for data in Query: + role = data['role'] + if role == "human": + user.append(data['content']) + if role == "ai": + ai.append(data['content']) + for key, values in zip(user, ai): + database.append({key, values}) + return database +async def chat_data_format(messages: list | dict): + """ + 将消息格式化为 LangChain 消息格式 + + Args: + messages: 消息列表或字典 + + Returns: + LangChain 消息列表 + """ + langchain_messages = [] + if isinstance(messages, list): + for msg in messages: + if 'role' in msg.keys(): + if msg['role'] == 'user': + langchain_messages.append(HumanMessage(content=msg['content'])) + elif msg['role'] == 'assistant': + langchain_messages.append(AIMessage(content=msg['content'])) + if "Query" in msg.keys(): + langchain_messages.append(HumanMessage(content=msg['Query'])) + langchain_messages.append(AIMessage(content=msg['Answer'])) + if isinstance(messages, dict): + if messages['type'] == 'human': + langchain_messages.append(HumanMessage(content=messages['content'])) + elif messages['type'] == 'ai': + langchain_messages.append(AIMessage(content=messages['content'])) + return langchain_messages + +async def agent_chat_messages(user_content,ai_content): + messages = [ + { + "role": "user", + "content": f"{user_content}" + }, + { + "role": "assistant", + "content": f"{ai_content}" + } + + ] + return messages diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index 8b5de444..5101fa29 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -1,20 +1,17 @@ import asyncio +import json import sys import warnings from contextlib import asynccontextmanager - - -from langchain_core.messages import HumanMessage from langgraph.constants import END, START from langgraph.graph import StateGraph - +from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, chat_data_format, messages_parse from app.db import get_db from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node -from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write from app.services.memory_config_service import MemoryConfigService warnings.filterwarnings("ignore", category=RuntimeWarning) @@ -34,14 +31,6 @@ async def make_write_graph(): end_user_id: Group identifier memory_config: MemoryConfig object containing all configuration """ - # workflow = StateGraph(WriteState) - # workflow.add_node("content_input", content_input_write) - # workflow.add_node("save_neo4j", write_node) - # workflow.add_edge(START, "content_input") - # workflow.add_edge("content_input", "save_neo4j") - # workflow.add_edge("save_neo4j", END) - # - # graph = workflow.compile() workflow = StateGraph(WriteState) workflow.add_node("save_neo4j", write_node) workflow.add_edge(START, "save_neo4j") @@ -51,43 +40,56 @@ async def make_write_graph(): yield graph - -async def main(): - """主函数 - 运行工作流""" - message = "今天周一" - end_user_id = 'new_2025test1103' # 组ID - - +async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6): + from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment + from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format + from app.core.memory.agent.utils.redis_tool import write_store + write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages)) # 获取数据库会话 db_session = next(get_db()) config_service = MemoryConfigService(db_session) memory_config = config_service.load_memory_config( - config_id=17, # 改为整数 + config_id="08ed205c-0f05-49c3-8e0c-a580d28f5fd4", # 改为整数 service_name="MemoryAgentService" ) - try: - async with make_write_graph() as graph: - config = {"configurable": {"thread_id": end_user_id}} - # 初始状态 - 包含所有必要字段 - initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config} + if long_term_type=='chunk': + '''方案一:对话窗口6轮对话''' + await window_dialogue(end_user_id,langchain_messages,memory_config,scope) + if long_term_type=='time': + """时间""" + await memory_long_term_storage(end_user_id, memory_config,5) + if long_term_type=='aggregate': - # 获取节点更新信息 - async for update_event in graph.astream( - initial_state, - stream_mode="updates", - config=config - ): - for node_name, node_data in update_event.items(): - if 'save_neo4j'==node_name: - massages=node_data - massages=massages.get('write_result')['status'] - print(massages) # | 更新数据: {node_data} + """方案三:聚合判断""" + await aggregate_judgment(end_user_id, langchain_messages, memory_config) - except Exception as e: - import traceback - traceback.print_exc() - - -if __name__ == "__main__": - import asyncio - asyncio.run(main()) \ No newline at end of file +# +# async def main(): +# """主函数 - 运行工作流""" +# langchain_messages = [ +# { +# "role": "user", +# "content": "今天周五好开心啊" +# }, +# { +# "role": "assistant", +# "content": "你也这么觉得,我也是耶" +# } +# +# ] +# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID +# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4" +# # await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) +# from app.core.memory.agent.utils.redis_tool import write_store +# result=write_store.get_session_by_userid(end_user_id) +# data=await format_parsing(result,"dict") +# chunk_data=data[:6] +# +# long_time_data = write_store.find_user_recent_sessions(end_user_id, 240) +# long_=await messages_parse(long_time_data) +# print(long_) +# +# +# if __name__ == "__main__": +# import asyncio +# asyncio.run(main()) \ No newline at end of file diff --git a/api/app/core/memory/agent/models/write_aggregate_model.py b/api/app/core/memory/agent/models/write_aggregate_model.py new file mode 100644 index 00000000..fd423314 --- /dev/null +++ b/api/app/core/memory/agent/models/write_aggregate_model.py @@ -0,0 +1,28 @@ +"""Pydantic models for write aggregate judgment operations.""" + +from typing import List, Union +from pydantic import BaseModel, Field + + +class MessageItem(BaseModel): + """Individual message item in conversation.""" + + role: str = Field(..., description="角色:user 或 assistant") + content: str = Field(..., description="消息内容") + + +class WriteAggregateResponse(BaseModel): + """Response model for aggregate judgment containing judgment result and output.""" + + is_same_event: bool = Field( + ..., + description="是否是同一事件。True表示是同一事件,False表示不同事件" + ) + output: Union[List[MessageItem], bool] = Field( + ..., + description="如果is_same_event为True,返回False;如果is_same_event为False,返回消息列表" + ) + + +# 为了保持向后兼容,保留旧的类名作为别名 +WriteAggregateModel = WriteAggregateResponse diff --git a/api/app/core/memory/agent/utils/prompt/write_aggregate_judgment.jinja2 b/api/app/core/memory/agent/utils/prompt/write_aggregate_judgment.jinja2 new file mode 100644 index 00000000..fb0247aa --- /dev/null +++ b/api/app/core/memory/agent/utils/prompt/write_aggregate_judgment.jinja2 @@ -0,0 +1,57 @@ +输入句子:{{sentence}} +历史消息:{{history}} + +# 你的角色 +你是一个擅长事件聚合与语义判断的专家。 + +# 你的任务 +结合历史消息和输入句子,判断它们是否在描述**同一件事件或同一事件链**。 + +以下情况视为"同一事件"(需要返回 is_same_event=True, output=False): +- 描述的是同一个具体事件或事实 +- 存在明显的因果关系、前后发展关系 +- 是对同一事件的补充、解释、追问或延展 +- 逻辑上属于同一语境下的连续讨论 + +以下情况视为"不同事件"(需要返回 is_same_event=False, output=消息列表): +- 话题不同,事件主体不同 +- 时间、地点、对象明显不同 +- 只是语义相似,但并非同一具体事件 +- 无直接事件、因果或逻辑关联 + +# 输出规则(非常重要) +你必须按照以下JSON格式输出: + +**如果是同一事件:** +```json +{ + "is_same_event": true, + "output": false +} +``` + +**如果不是同一事件:** +```json +{ + "is_same_event": false, + "output": [ + { + "role": "user", + "content": "输入句子的内容" + }, + { + "role": "assistant", + "content": "对应的回复内容" + } + ] +} +``` + +# JSON Schema +{{json_schema}} + +# 注意事项 +- 必须严格按照上述格式输出 +- output 字段:如果是同一事件返回 false,如果不是同一事件返回完整的消息列表 +- 消息列表必须包含 role 和 content 字段 +- 不要输出任何解释、分析或多余内容 diff --git a/api/app/core/memory/agent/utils/redis_base.py b/api/app/core/memory/agent/utils/redis_base.py new file mode 100644 index 00000000..59bac109 --- /dev/null +++ b/api/app/core/memory/agent/utils/redis_base.py @@ -0,0 +1,186 @@ +import json +from typing import Any, List, Dict, Optional +from datetime import datetime, timedelta + + +def serialize_messages(messages: Any) -> str: + """ + 将消息序列化为 JSON 字符串,支持 LangChain 消息对象 + + Args: + messages: 可以是 list、dict、string 或 LangChain 消息对象列表 + + Returns: + str: JSON 字符串 + """ + if isinstance(messages, str): + return messages + + if isinstance(messages, (list, tuple)): + # 检查是否是 LangChain 消息对象列表 + serialized_list = [] + for msg in messages: + if hasattr(msg, 'type') and hasattr(msg, 'content'): + # LangChain 消息对象 + serialized_list.append({ + 'type': msg.type, + 'content': msg.content, + 'role': getattr(msg, 'role', msg.type) + }) + elif isinstance(msg, dict): + serialized_list.append(msg) + else: + serialized_list.append(str(msg)) + return json.dumps(serialized_list, ensure_ascii=False) + + if isinstance(messages, dict): + return json.dumps(messages, ensure_ascii=False) + + # 其他类型转为字符串 + return str(messages) + + +def deserialize_messages(messages_str: str) -> Any: + """ + 将 JSON 字符串反序列化为原始格式 + + Args: + messages_str: JSON 字符串 + + Returns: + 反序列化后的对象(list、dict 或 string) + """ + if not messages_str: + return [] + + try: + return json.loads(messages_str) + except (json.JSONDecodeError, TypeError): + return messages_str + + +def fix_encoding(text: str) -> str: + """ + 修复错误编码的文本 + + Args: + text: 需要修复的文本 + + Returns: + str: 修复后的文本 + """ + if not text or not isinstance(text, str): + return text + try: + # 尝试修复 Latin-1 误编码为 UTF-8 的情况 + return text.encode('latin-1').decode('utf-8') + except (UnicodeDecodeError, UnicodeEncodeError): + # 如果修复失败,返回原文本 + return text + + +def format_session_data(data: Dict[str, Any], include_time: bool = False) -> Dict[str, Any]: + """ + 格式化会话数据为统一的输出格式 + + Args: + data: 原始会话数据 + include_time: 是否包含时间字段 + + Returns: + Dict: 格式化后的数据 {"Query": "...", "Answer": "...", "starttime": "..."} + """ + result = { + "Query": fix_encoding(data.get('messages', '')), + "Answer": fix_encoding(data.get('aimessages', '')) + } + + if include_time: + result["starttime"] = data.get('starttime', '') + + return result + + +def filter_by_time_range(items: List[Dict], minutes: int) -> List[Dict]: + """ + 根据时间范围过滤数据 + + Args: + items: 包含 starttime 字段的数据列表 + minutes: 时间范围(分钟) + + Returns: + List[Dict]: 过滤后的数据列表 + """ + time_threshold = datetime.now() - timedelta(minutes=minutes) + time_threshold_str = time_threshold.strftime("%Y-%m-%d %H:%M:%S") + + filtered_items = [] + for item in items: + starttime = item.get('starttime', '') + if starttime and starttime >= time_threshold_str: + filtered_items.append(item) + + return filtered_items + + +def sort_and_limit_results(items: List[Dict], limit: int = 6, + remove_time: bool = True) -> List[Dict]: + """ + 对结果进行排序、限制数量并移除时间字段 + + Args: + items: 数据列表 + limit: 最大返回数量 + remove_time: 是否移除 starttime 字段 + + Returns: + List[Dict]: 处理后的数据列表 + """ + # 按时间降序排序(最新的在前) + items.sort(key=lambda x: x.get('starttime', ''), reverse=True) + + # 限制数量 + result_items = items[:limit] + + # 移除 starttime 字段 + if remove_time: + for item in result_items: + item.pop('starttime', None) + + # 如果结果少于1条,返回空列表 + if len(result_items) < 1: + return [] + + return result_items + + +def generate_session_key(session_id: str, key_type: str = "session") -> str: + """ + 生成 Redis key + + Args: + session_id: 会话ID + key_type: key 类型 ("session", "read", "write", "count") + + Returns: + str: Redis key + """ + if key_type == "count": + return f"session:count:{session_id}" + elif key_type == "write": + return f"session:write:{session_id}" + elif key_type == "session" or key_type == "read": + return f"session:{session_id}" + else: + return f"session:{session_id}" + + +def get_current_timestamp() -> str: + """ + 获取当前时间戳字符串 + + Returns: + str: 格式化的时间字符串 "YYYY-MM-DD HH:MM:SS" + """ + return datetime.now().strftime("%Y-%m-%d %H:%M:%S") \ No newline at end of file diff --git a/api/app/core/memory/agent/utils/redis_tool.py b/api/app/core/memory/agent/utils/redis_tool.py index 505545b3..b61319e5 100644 --- a/api/app/core/memory/agent/utils/redis_tool.py +++ b/api/app/core/memory/agent/utils/redis_tool.py @@ -1,11 +1,36 @@ import redis import uuid -from datetime import datetime from app.core.config import settings +from typing import List, Dict, Any, Optional, Union + +from app.core.memory.agent.utils.redis_base import ( + serialize_messages, + deserialize_messages, + fix_encoding, + format_session_data, + filter_by_time_range, + sort_and_limit_results, + generate_session_key, + get_current_timestamp +) -class RedisSessionStore: + + +class RedisWriteStore: + """Redis Write 类型存储类,用于管理 save_session_write 相关的数据""" + def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): + """ + 初始化 Redis 连接 + + Args: + host: Redis 主机地址 + port: Redis 端口 + db: Redis 数据库编号 + password: Redis 密码 + session_id: 会话ID + """ self.r = redis.Redis( host=host, port=port, @@ -16,32 +41,400 @@ class RedisSessionStore: ) self.uudi = session_id - def _fix_encoding(self, text): - """修复错误编码的文本""" - if not text or not isinstance(text, str): - return text - try: - # 尝试修复 Latin-1 误编码为 UTF-8 的情况 - return text.encode('latin-1').decode('utf-8') - except (UnicodeDecodeError, UnicodeEncodeError): - # 如果修复失败,返回原文本 - return text - - # 修改后的 save_session 方法 - def save_session(self, userid, messages, aimessages, apply_id, end_user_id): + def save_session_write(self, userid: str, messages: str) -> str: """ 写入一条会话数据,返回 session_id - 优化版本:确保写入时间不超过1秒 + + Args: + userid: 用户ID + messages: 用户消息 + + Returns: + str: 新生成的 session_id """ try: - session_id = str(uuid.uuid4()) # 为每次会话生成新的 ID - starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - key = f"session:{session_id}" # 使用新生成的 session_id 作为 key + messages = serialize_messages(messages) + session_id = str(uuid.uuid4()) + key = generate_session_key(session_id, key_type="write") - # 使用 pipeline 批量写入,减少网络往返 pipe = self.r.pipeline() + pipe.hset(key, mapping={ + "id": self.uudi, + "sessionid": userid, + "messages": messages, + "starttime": get_current_timestamp() + }) + result = pipe.execute() - # 直接写入数据,decode_responses=True 已经处理了编码 + print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}") + return session_id + except Exception as e: + print(f"[save_session_write] 保存会话失败: {e}") + raise e + + def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]: + """ + 通过 save_session_write 的 userid 获取 sessionid 和 messages + + Args: + userid: 用户ID (对应 sessionid 字段) + + Returns: + List[Dict] 或 False: 如果找到数据返回 [{"sessionid": "...", "messages": "..."}, ...],否则返回 False + """ + try: + # 只查询 write 类型的 key + keys = self.r.keys('session:write:*') + if not keys: + return False + + # 批量获取数据 + pipe = self.r.pipeline() + for key in keys: + pipe.hgetall(key) + all_data = pipe.execute() + + # 筛选符合 userid 的数据 + results = [] + for key, data in zip(keys, all_data): + if not data: + continue + + # 从 write 类型读取,匹配 sessionid 字段 + if data.get('sessionid') == userid: + # 从 key 中提取 session_id: session:write:{session_id} + session_id = key.split(':')[-1] + results.append({ + "sessionid": session_id, + "messages": fix_encoding(data.get('messages', '')) + }) + + if not results: + return False + + print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据") + return results + except Exception as e: + print(f"[get_session_by_userid] 查询失败: {e}") + return False + + def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]: + """ + 通过 end_user_id 获取所有 write 类型的会话数据 + + Args: + end_user_id: 终端用户ID (对应 sessionid 字段) + + Returns: + List[Dict] 或 False: 如果找到数据返回完整的会话信息列表,否则返回 False + + 返回格式: + [ + { + "session_id": "uuid", + "id": "...", + "sessionid": "end_user_id", + "messages": "...", + "starttime": "timestamp" + }, + ... + ] + """ + try: + # 只查询 write 类型的 key + keys = self.r.keys('session:write:*') + if not keys: + print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话") + return False + + # 批量获取数据 + pipe = self.r.pipeline() + for key in keys: + pipe.hgetall(key) + all_data = pipe.execute() + + # 筛选符合 end_user_id 的数据 + results = [] + for key, data in zip(keys, all_data): + if not data: + continue + + # 从 write 类型读取,匹配 sessionid 字段 + if data.get('sessionid') == end_user_id: + # 从 key 中提取 session_id: session:write:{session_id} + session_id = key.split(':')[-1] + + # 构建完整的会话信息 + session_info = { + "session_id": session_id, + "id": data.get('id', ''), + "sessionid": data.get('sessionid', ''), + "messages": fix_encoding(data.get('messages', '')), + "starttime": data.get('starttime', '') + } + results.append(session_info) + + if not results: + print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据") + return False + + # 按时间排序(最新的在前) + results.sort(key=lambda x: x.get('starttime', ''), reverse=True) + + print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据") + return results + except Exception as e: + print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}") + import traceback + traceback.print_exc() + return False + + def find_user_recent_sessions(self, userid: str, + minutes: int = 5) -> List[Dict[str, str]]: + """ + 根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据 + + Args: + userid: 用户ID (对应 sessionid 字段) + minutes: 查询最近几分钟的数据,默认5分钟 + + Returns: + List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...] + """ + import time + start_time = time.time() + + # 只查询 write 类型的 key + keys = self.r.keys('session:write:*') + if not keys: + print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") + return [] + + # 批量获取数据 + pipe = self.r.pipeline() + for key in keys: + pipe.hgetall(key) + all_data = pipe.execute() + + # 筛选符合 userid 的数据 + matched_items = [] + for data in all_data: + if not data: + continue + + # 从 write 类型读取,匹配 sessionid 字段 + if data.get('sessionid') == userid and data.get('starttime'): + # write 类型没有 aimessages,所以 Answer 为空 + matched_items.append({ + "Query": fix_encoding(data.get('messages', '')), + "Answer": "", + "starttime": data.get('starttime', '') + }) + + # 根据时间范围过滤 + filtered_items = filter_by_time_range(matched_items, minutes) + # 排序并移除时间字段 + result_items = sort_and_limit_results(filtered_items, limit=None) + print(result_items) + + elapsed_time = time.time() - start_time + print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, " + f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") + + return result_items + + def delete_all_write_sessions(self) -> int: + """ + 删除所有 write 类型的会话 + + Returns: + int: 删除的数量 + """ + keys = self.r.keys('session:write:*') + if keys: + return self.r.delete(*keys) + return 0 + + +class RedisCountStore: + """Redis Count 类型存储类,用于管理访问次数统计相关的数据""" + + def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): + """ + 初始化 Redis 连接 + + Args: + host: Redis 主机地址 + port: Redis 端口 + db: Redis 数据库编号 + password: Redis 密码 + session_id: 会话ID + """ + self.r = redis.Redis( + host=host, + port=port, + db=db, + password=password, + decode_responses=True, + encoding='utf-8' + ) + self.uudi = session_id + + def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str: + """ + 保存用户访问次数统计 + + Args: + end_user_id: 终端用户ID + count: 访问次数 + messages: 消息内容 + + Returns: + str: 新生成的 session_id + """ + session_id = str(uuid.uuid4()) + key = generate_session_key(session_id, key_type="count") + + pipe = self.r.pipeline() + pipe.hset(key, mapping={ + "id": self.uudi, + "end_user_id": end_user_id, + "count": int(count), + "messages": serialize_messages(messages), + "starttime": get_current_timestamp() + }) + pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期 + result = pipe.execute() + + print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}") + return session_id + + def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]: + """ + 通过 end_user_id 查询访问次数统计 + + Args: + end_user_id: 终端用户ID + + Returns: + list 或 False: 如果找到返回 [count, messages],否则返回 False + """ + try: + search_pattern = 'session:count:*' + + for key in self.r.keys(search_pattern): + data = self.r.hgetall(key) + + if not data: + continue + + if data.get('end_user_id') == end_user_id: + count = data.get('count') + messages_str = data.get('messages') + + if count is not None: + messages = deserialize_messages(messages_str) + return [int(count), messages] + + return False + except Exception as e: + print(f"[get_sessions_count] 查询失败: {e}") + return False + + def update_sessions_count(self, end_user_id: str, new_count: int, + messages: Any) -> bool: + """ + 通过 end_user_id 修改访问次数统计 + + Args: + end_user_id: 终端用户ID + new_count: 新的 count 值 + messages: 消息内容 + + Returns: + bool: 更新成功返回 True,未找到记录返回 False + """ + try: + messages_str = serialize_messages(messages) + search_pattern = 'session:count:*' + + for key in self.r.keys(search_pattern): + data = self.r.hgetall(key) + + if not data: + continue + + if data.get('end_user_id') == end_user_id: + self.r.hset(key, 'count', int(new_count)) + self.r.hset(key, 'messages', messages_str) + print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}") + return True + + print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") + return False + except Exception as e: + print(f"[update_sessions_count] 更新失败: {e}") + return False + + def delete_all_count_sessions(self) -> int: + """ + 删除所有 count 类型的会话 + + Returns: + int: 删除的数量 + """ + keys = self.r.keys('session:count:*') + if keys: + return self.r.delete(*keys) + return 0 + + +class RedisSessionStore: + """Redis 会话存储类,用于管理会话数据""" + + def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): + """ + 初始化 Redis 连接 + + Args: + host: Redis 主机地址 + port: Redis 端口 + db: Redis 数据库编号 + password: Redis 密码 + session_id: 会话ID + """ + self.r = redis.Redis( + host=host, + port=port, + db=db, + password=password, + decode_responses=True, + encoding='utf-8' + ) + self.uudi = session_id + + # ==================== 写入操作 ==================== + + def save_session(self, userid: str, messages: str, aimessages: str, + apply_id: str, end_user_id: str) -> str: + """ + 写入一条会话数据,返回 session_id + + Args: + userid: 用户ID + messages: 用户消息 + aimessages: AI回复消息 + apply_id: 应用ID + end_user_id: 终端用户ID + + Returns: + str: 新生成的 session_id + """ + try: + session_id = str(uuid.uuid4()) + key = generate_session_key(session_id, key_type="read") + + pipe = self.r.pipeline() pipe.hset(key, mapping={ "id": self.uudi, "sessionid": userid, @@ -49,177 +442,195 @@ class RedisSessionStore: "end_user_id": end_user_id, "messages": messages, "aimessages": aimessages, - "starttime": starttime + "starttime": get_current_timestamp() }) - - # 可选:设置过期时间(例如30天),避免数据无限增长 - # pipe.expire(key, 30 * 24 * 60 * 60) - - # 执行批量操作 result = pipe.execute() - print(f"保存结果: {result[0]}, session_id: {session_id}") - return session_id # 返回新生成的 session_id + print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}") + return session_id except Exception as e: - print(f"保存会话失败: {e}") + print(f"[save_session] 保存会话失败: {e}") raise e - def save_sessions_batch(self, sessions_data): - """ - 批量写入多条会话数据,返回 session_id 列表 - sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id - 优化版本:批量操作,大幅提升性能 - """ - try: - session_ids = [] - pipe = self.r.pipeline() - - for session in sessions_data: - session_id = str(uuid.uuid4()) - starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - key = f"session:{session_id}" - - pipe.hset(key, mapping={ - "id": self.uudi, - "sessionid": session.get('userid'), - "apply_id": session.get('apply_id'), - "end_user_id": session.get('end_user_id'), - "messages": session.get('messages'), - "aimessages": session.get('aimessages'), - "starttime": starttime - }) - - session_ids.append(session_id) - - # 一次性执行所有写入操作 - results = pipe.execute() - print(f"批量保存完成: {len(session_ids)} 条记录") - return session_ids - except Exception as e: - print(f"批量保存会话失败: {e}") - raise e - - # ---------------- 读取 ---------------- - def get_session(self, session_id): + # ==================== 读取操作 ==================== + + def get_session(self, session_id: str) -> Optional[Dict[str, Any]]: """ 读取一条会话数据 + + Args: + session_id: 会话ID + + Returns: + Dict 或 None: 会话数据 """ - key = f"session:{session_id}" + key = generate_session_key(session_id) data = self.r.hgetall(key) return data if data else None - def get_session_apply_group(self, sessionid, apply_id, end_user_id): + def get_all_sessions(self) -> Dict[str, Dict[str, Any]]: """ - 根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据 - """ - result_items = [] - - # 遍历所有会话数据 - for key in self.r.keys('session:*'): - data = self.r.hgetall(key) - - if not data: - continue - - # 检查三个条件是否都匹配 - if (data.get('sessionid') == sessionid and - data.get('apply_id') == apply_id and - data.get('end_user_id') == end_user_id): - result_items.append(data) - - return result_items - - def get_all_sessions(self): - """ - 获取所有会话数据 + 获取所有会话数据(不包括 count 和 write 类型) + + Returns: + Dict: 所有会话数据,key 为 session_id """ sessions = {} for key in self.r.keys('session:*'): - sid = key.split(':')[1] - sessions[sid] = self.get_session(sid) + # 排除 count 和 write 类型的 key + if ':count:' not in key and ':write:' not in key: + sid = key.split(':')[1] + sessions[sid] = self.get_session(sid) return sessions - # ---------------- 更新 ---------------- - def update_session(self, session_id, field, value): + def find_user_apply_group(self, sessionid: str, apply_id: str, + end_user_id: str) -> List[Dict[str, str]]: + """ + 根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条 + + Args: + sessionid: 会话ID(支持模糊匹配) + apply_id: 应用ID + end_user_id: 终端用户ID + + Returns: + List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...] + """ + import time + start_time = time.time() + + keys = self.r.keys('session:*') + if not keys: + print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") + return [] + + # 批量获取数据 + pipe = self.r.pipeline() + for key in keys: + # 排除 count 和 write 类型 + if ':count:' not in key and ':write:' not in key: + pipe.hgetall(key) + all_data = pipe.execute() + + # 筛选符合条件的数据 + matched_items = [] + for data in all_data: + if not data: + continue + + if (data.get('apply_id') == apply_id and + data.get('end_user_id') == end_user_id): + # 支持模糊匹配或完全匹配 sessionid + if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid: + matched_items.append(format_session_data(data, include_time=True)) + + # 排序、限制数量并移除时间字段 + result_items = sort_and_limit_results(matched_items, limit=6) + + elapsed_time = time.time() - start_time + print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") + + return result_items + + # ==================== 更新操作 ==================== + + def update_session(self, session_id: str, field: str, value: Any) -> bool: """ 更新单个字段 - 优化版本:使用 pipeline 减少网络往返 + + Args: + session_id: 会话ID + field: 字段名 + value: 字段值 + + Returns: + bool: 是否更新成功 """ - key = f"session:{session_id}" + key = generate_session_key(session_id) pipe = self.r.pipeline() pipe.exists(key) pipe.hset(key, field, value) results = pipe.execute() - return bool(results[0]) # 返回 key 是否存在 + return bool(results[0]) - # ---------------- 删除 ---------------- - def delete_session(self, session_id): + # ==================== 删除操作 ==================== + + def delete_session(self, session_id: str) -> int: """ 删除单条会话 + + Args: + session_id: 会话ID + + Returns: + int: 删除的数量 """ - key = f"session:{session_id}" + key = generate_session_key(session_id) return self.r.delete(key) - def delete_all_sessions(self): + def delete_all_sessions(self) -> int: """ - 删除所有会话 + 删除所有会话(不包括 count 和 write 类型) + + Returns: + int: 删除的数量 """ keys = self.r.keys('session:*') - if keys: - return self.r.delete(*keys) + # 过滤掉 count 和 write 类型 + keys_to_delete = [k for k in keys if ':count:' not in k and ':write:' not in k] + if keys_to_delete: + return self.r.delete(*keys_to_delete) return 0 - def delete_duplicate_sessions(self): + def delete_duplicate_sessions(self) -> int: """ - 删除重复会话数据,条件: - "sessionid"、"user_id"、"end_user_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除 - 优化版本:使用 pipeline 批量操作,确保在1秒内完成 + 删除重复会话数据(不包括 count 和 write 类型) + 条件:sessionid、user_id、end_user_id、messages、aimessages 五个字段都相同的只保留一个 + + Returns: + int: 删除的数量 """ import time start_time = time.time() - # 第一步:使用 pipeline 批量获取所有 key keys = self.r.keys('session:*') - if not keys: print("[delete_duplicate_sessions] 没有会话数据") return 0 - # 第二步:使用 pipeline 批量获取所有数据 + # 批量获取所有数据 pipe = self.r.pipeline() for key in keys: - pipe.hgetall(key) + # 排除 count 和 write 类型 + if ':count:' not in key and ':write:' not in key: + pipe.hgetall(key) all_data = pipe.execute() - # 第三步:在内存中识别重复数据 - seen = {} # 用字典记录:identifier -> key(保留第一个出现的 key) - keys_to_delete = [] # 需要删除的 key 列表 + # 识别重复数据 + seen = {} + keys_to_delete = [] - for key, data in zip(keys, all_data, strict=False): + for key, data in zip([k for k in keys if ':count:' not in k and ':write:' not in k], all_data, strict=False): if not data: continue - # 获取五个字段的值 - sessionid = data.get('sessionid', '') - user_id = data.get('id', '') - end_user_id = data.get('end_user_id', '') - messages = data.get('messages', '') - aimessages = data.get('aimessages', '') - # 用五元组作为唯一标识 - identifier = (sessionid, user_id, end_user_id, messages, aimessages) + identifier = ( + data.get('sessionid', ''), + data.get('id', ''), + data.get('end_user_id', ''), + data.get('messages', ''), + data.get('aimessages', '') + ) if identifier in seen: - # 重复,标记为待删除 keys_to_delete.append(key) else: - # 第一次出现,记录 seen[identifier] = key - # 第四步:使用 pipeline 批量删除重复的 key + # 批量删除重复的 key deleted_count = 0 if keys_to_delete: - # 分批删除,避免单次操作过大 batch_size = 1000 for i in range(0, len(keys_to_delete), batch_size): batch = keys_to_delete[i:i + batch_size] @@ -233,79 +644,28 @@ class RedisSessionStore: print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒") return deleted_count - def find_user_session(self, sessionid): - user_id = sessionid - - result_items = [] - for key, values in store.get_all_sessions().items(): - history = {} - if user_id == str(values['sessionid']): - history["Query"] = values['messages'] - history["Answer"] = values['aimessages'] - result_items.append(history) - - if len(result_items) <= 1: - result_items = [] - return (result_items) - - def find_user_apply_group(self, sessionid, apply_id, end_user_id): - """ - 根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据,返回最新的6条 - """ - import time - start_time = time.time() - # 使用 pipeline 批量获取数据,提高性能 - keys = self.r.keys('session:*') - - if not keys: - print(f"查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") - return [] - - # 使用 pipeline 批量获取所有 hash 数据 - pipe = self.r.pipeline() - for key in keys: - pipe.hgetall(key) - all_data = pipe.execute() - - # 解析并筛选符合条件的数据 - matched_items = [] - for data in all_data: - if not data: - continue - - # 检查是否符合三个条件 - - if (data.get('apply_id') == apply_id and - data.get('end_user_id') == end_user_id): - # 支持模糊匹配 sessionid 或者完全匹配 - if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid: - matched_items.append({ - "Query": self._fix_encoding(data.get('messages')), - "Answer": self._fix_encoding(data.get('aimessages')), - "starttime": data.get('starttime', '') - }) - # 按时间降序排序(最新的在前) - matched_items.sort(key=lambda x: x.get('starttime', ''), reverse=True) - # 只保留最新的6条 - result_items = matched_items[:6] - # # 移除 starttime 字段 - for item in result_items: - item.pop('starttime', None) - - # 如果结果少于等于1条,返回空列表 - if len(result_items) <= 1: - result_items = [] - - elapsed_time = time.time() - start_time - print(f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") - - return result_items - +# 全局实例 store = RedisSessionStore( host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=settings.REDIS_DB, password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None, session_id=str(uuid.uuid4()) -) \ No newline at end of file +) + +write_store = RedisWriteStore( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + db=settings.REDIS_DB, + password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None, + session_id=str(uuid.uuid4()) +) + +count_store = RedisCountStore( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + db=settings.REDIS_DB, + password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None, + session_id=str(uuid.uuid4()) +) diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 09410091..ddaed685 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -12,8 +12,8 @@ class KnowledgeBaseConfig(BaseModel): kb_id: str = Field(..., description="知识库ID") top_k: int = Field(default=3, ge=1, le=20, description="检索返回的文档数量") similarity_threshold: float = Field(default=0.7, ge=0.0, le=1.0, description="相似度阈值") - strategy: str = Field(default="hybrid", description="检索策略: hybrid | bm25 | dense") - weight: float = Field(default=1.0, ge=0.0, le=1.0, description="知识库权重(用于多知识库融合)") + # strategy: str = Field(default="hybrid", description="检索策略: hybrid | bm25 | dense") + # weight: float = Field(default=1.0, ge=0.0, le=1.0, description="知识库权重(用于多知识库融合)") vector_similarity_weight: float = Field(default=0.5, ge=0.0, le=1.0, description="向量相似度权重") retrieve_type: str = Field(default="hybrid", description="检索方式participle| semantic|hybrid")