新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

This commit is contained in:
lixinyue
2026-01-21 19:37:03 +08:00
parent afcf12ebc9
commit 4a4931bee2
84 changed files with 1193 additions and 1190 deletions

0
api/app/__init__.py Normal file
View File

View File

@@ -125,7 +125,7 @@ async def write_server(
Write service endpoint - processes write operations synchronously Write service endpoint - processes write operations synchronously
Args: Args:
user_input: Write request containing message and group_id user_input: Write request containing message and end_user_id
Returns: Returns:
Response with write operation status Response with write operation status
@@ -160,14 +160,11 @@ async def write_server(
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
storage_type = 'neo4j' storage_type = 'neo4j'
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
try: try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
result = await memory_agent_service.write_memory( result = await memory_agent_service.write_memory(
user_input.group_id, user_input.end_user_id,
messages_list, # 传递结构化消息列表 user_input.message,
config_id, config_id,
db, db,
storage_type, storage_type,
@@ -196,7 +193,7 @@ async def write_server_async(
Async write service endpoint - enqueues write processing to Celery Async write service endpoint - enqueues write processing to Celery
Args: Args:
user_input: Write request containing message and group_id user_input: Write request containing message and end_user_id
Returns: Returns:
Task ID for tracking async operation Task ID for tracking async operation
@@ -224,12 +221,9 @@ async def write_server_async(
if knowledge: user_rag_memory_id = str(knowledge.id) if knowledge: user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
try: try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
task = celery_app.send_task( task = celery_app.send_task(
"app.core.memory.agent.write_message", "app.core.memory.agent.write_message",
args=[user_input.group_id, messages_list, config_id, storage_type, user_rag_memory_id] args=[user_input.end_user_id, user_input.message, config_id, storage_type, user_rag_memory_id]
) )
api_logger.info(f"Write task queued: {task.id}") api_logger.info(f"Write task queued: {task.id}")
@@ -255,7 +249,7 @@ async def read_server(
- "2": Direct answer based on context - "2": Direct answer based on context
Args: Args:
user_input: Read request with message, history, search_switch, and group_id user_input: Read request with message, history, search_switch, and end_user_id
Returns: Returns:
Response with query answer Response with query answer
@@ -279,12 +273,13 @@ async def read_server(
name="USER_RAG_MERORY", name="USER_RAG_MERORY",
workspace_id=workspace_id workspace_id=workspace_id
) )
if knowledge: user_rag_memory_id = str(knowledge.id) if knowledge:
user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Read service: group={user_input.group_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}") api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
try: try:
result = await memory_agent_service.read_memory( result = await memory_agent_service.read_memory(
user_input.group_id, user_input.end_user_id,
user_input.message, user_input.message,
user_input.history, user_input.history,
user_input.search_switch, user_input.search_switch,
@@ -297,7 +292,7 @@ async def read_server(
retrieve_info = result['answer'] retrieve_info = result['answer']
history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id) history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id)
query = user_input.message query = user_input.message
# 调用 memory_agent_service 的方法生成最终答案 # 调用 memory_agent_service 的方法生成最终答案
result['answer'] = await memory_agent_service.generate_summary_from_retrieve( result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
retrieve_info=retrieve_info, retrieve_info=retrieve_info,
@@ -403,7 +398,7 @@ async def read_server_async(
try: try:
task = celery_app.send_task( task = celery_app.send_task(
"app.core.memory.agent.read_message", "app.core.memory.agent.read_message",
args=[user_input.group_id, user_input.message, user_input.history, user_input.search_switch, args=[user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch,
config_id, storage_type, user_rag_memory_id] config_id, storage_type, user_rag_memory_id]
) )
api_logger.info(f"Read task queued: {task.id}") api_logger.info(f"Read task queued: {task.id}")
@@ -447,7 +442,7 @@ async def get_read_task_result(
return success( return success(
data={ data={
"result": task_result.get("result"), "result": task_result.get("result"),
"group_id": task_result.get("group_id"), "end_user_id": task_result.get("end_user_id"),
"elapsed_time": task_result.get("elapsed_time"), "elapsed_time": task_result.get("elapsed_time"),
"task_id": task_id "task_id": task_id
}, },
@@ -524,7 +519,7 @@ async def get_write_task_result(
return success( return success(
data={ data={
"result": task_result.get("result"), "result": task_result.get("result"),
"group_id": task_result.get("group_id"), "end_user_id": task_result.get("end_user_id"),
"elapsed_time": task_result.get("elapsed_time"), "elapsed_time": task_result.get("elapsed_time"),
"task_id": task_id "task_id": task_id
}, },
@@ -578,16 +573,16 @@ async def status_type(
Determine the type of user message (read or write) Determine the type of user message (read or write)
Args: Args:
user_input: Request containing user message and group_id user_input: Request containing user message and end_user_id
Returns: Returns:
Type classification result Type classification result
""" """
api_logger.info(f"Status type check requested for group {user_input.group_id}") api_logger.info(f"Status type check requested for group {user_input.end_user_id}")
try: try:
# 获取标准化的消息列表 # 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input) messages_list = memory_agent_service.get_messages_list(user_input)
# 将消息列表转换为字符串用于分类 # 将消息列表转换为字符串用于分类
# 只取最后一条用户消息进行分类 # 只取最后一条用户消息进行分类
last_user_message = "" last_user_message = ""
@@ -595,13 +590,13 @@ async def status_type(
if msg.get('role') == 'user': if msg.get('role') == 'user':
last_user_message = msg.get('content', '') last_user_message = msg.get('content', '')
break break
if not last_user_message: if not last_user_message:
# 如果没有用户消息,使用所有消息的内容 # 如果没有用户消息,使用所有消息的内容
last_user_message = " ".join([msg.get('content', '') for msg in messages_list]) last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
result = await memory_agent_service.classify_message_type( result = await memory_agent_service.classify_message_type(
last_user_message, user_input.message,
user_input.config_id, user_input.config_id,
db db
) )
@@ -624,7 +619,7 @@ async def get_knowledge_type_stats_api(
会对缺失类型补 0返回字典形式。 会对缺失类型补 0返回字典形式。
可选按状态过滤。 可选按状态过滤。
- 知识库类型根据当前用户的 current_workspace_id 过滤 - 知识库类型根据当前用户的 current_workspace_id 过滤
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (group_id) 过滤 - memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤
- 如果用户没有当前工作空间或未提供 end_user_id对应的统计返回 0 - 如果用户没有当前工作空间或未提供 end_user_id对应的统计返回 0
""" """
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}") api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
@@ -697,7 +692,7 @@ async def get_user_profile_api(
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
获取工作空间下Popular Memory Tags,包含: 获取用户详情,包含:
- name: 用户名字(直接使用 end_user_id - name: 用户名字(直接使用 end_user_id
- tags: 3个用户特征标签从语句和实体中LLM总结 - tags: 3个用户特征标签从语句和实体中LLM总结
- hot_tags: 4个热门记忆标签 - hot_tags: 4个热门记忆标签

View File

@@ -39,7 +39,7 @@ async def write_memory_api_service(
Stores memory content for the specified end user using the Memory API Service. Stores memory content for the specified end user using the Memory API Service.
""" """
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}") logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
memory_api_service = MemoryAPIService(db) memory_api_service = MemoryAPIService(db)
@@ -50,6 +50,7 @@ async def write_memory_api_service(
config_id=payload.config_id, config_id=payload.config_id,
storage_type=payload.storage_type, storage_type=payload.storage_type,
user_rag_memory_id=payload.user_rag_memory_id, user_rag_memory_id=payload.user_rag_memory_id,
tenant_id=api_key_auth.tenant_id,
) )
logger.info(f"Memory write successful for end_user: {payload.end_user_id}") logger.info(f"Memory write successful for end_user: {payload.end_user_id}")

View File

@@ -145,41 +145,38 @@ class LangChainAgent:
messages.append(HumanMessage(content=user_content)) messages.append(HumanMessage(content=user_content))
return messages return messages
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用 async def term_memory_save(self,messages,end_user_end,aimessages):
# async def term_memory_save(self,messages,end_user_end,aimessages): '''短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j'''
# '''短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j''' end_user_end=f"Term_{end_user_end}"
# end_user_end=f"Term_{end_user_end}" print(messages)
# print(messages) print(aimessages)
# print(aimessages) session_id = store.save_session(
# session_id = store.save_session( userid=end_user_end,
# userid=end_user_end, messages=messages,
# messages=messages, apply_id=end_user_end,
# apply_id=end_user_end, end_user_id=end_user_end,
# group_id=end_user_end, aimessages=aimessages
# aimessages=aimessages )
# ) store.delete_duplicate_sessions()
# store.delete_duplicate_sessions() # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
# # logger.info(f'Redis_Agent:{end_user_end};{session_id}') return session_id
# return session_id async def term_memory_redis_read(self,end_user_end):
end_user_end = f"Term_{end_user_end}"
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用 history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
# async def term_memory_redis_read(self,end_user_end): # logger.info(f'Redis_Agent:{end_user_end};{history}')
# end_user_end = f"Term_{end_user_end}" messagss_list=[]
# history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end) retrieved_content=[]
# # logger.info(f'Redis_Agent:{end_user_end};{history}') for messages in history:
# messagss_list=[] query = messages.get("Query")
# retrieved_content=[] aimessages = messages.get("Answer")
# for messages in history: messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
# query = messages.get("Query") retrieved_content.append({query: aimessages})
# aimessages = messages.get("Answer") return messagss_list,retrieved_content
# messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
# retrieved_content.append({query: aimessages})
# return messagss_list,retrieved_content
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id): async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
""" """
写入记忆(支持结构化消息) 写入记忆(支持结构化消息)
Args: Args:
storage_type: 存储类型 (neo4j/rag) storage_type: 存储类型 (neo4j/rag)
end_user_id: 终端用户ID end_user_id: 终端用户ID
@@ -188,7 +185,7 @@ class LangChainAgent:
user_rag_memory_id: RAG 记忆ID user_rag_memory_id: RAG 记忆ID
actual_end_user_id: 实际用户ID actual_end_user_id: 实际用户ID
actual_config_id: 配置ID actual_config_id: 配置ID
逻辑说明: 逻辑说明:
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变 - RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
- Neo4j 模式:使用结构化消息列表 - Neo4j 模式:使用结构化消息列表
@@ -204,20 +201,20 @@ class LangChainAgent:
else: else:
# Neo4j 模式:使用结构化消息列表 # Neo4j 模式:使用结构化消息列表
structured_messages = [] structured_messages = []
# 始终添加用户消息(如果不为空) # 始终添加用户消息(如果不为空)
if user_message: if user_message:
structured_messages.append({"role": "user", "content": user_message}) structured_messages.append({"role": "user", "content": user_message})
# 只有当 AI 回复不为空时才添加 assistant 消息 # 只有当 AI 回复不为空时才添加 assistant 消息
if ai_message: if ai_message:
structured_messages.append({"role": "assistant", "content": ai_message}) structured_messages.append({"role": "assistant", "content": ai_message})
# 如果没有消息,直接返回 # 如果没有消息,直接返回
if not structured_messages: if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}") logger.warning(f"No messages to write for user {actual_end_user_id}")
return return
# 调用 Celery 任务,传递结构化消息列表 # 调用 Celery 任务,传递结构化消息列表
# 数据流: # 数据流:
# 1. structured_messages 传递给 write_message_task # 1. structured_messages 传递给 write_message_task

View File

@@ -35,10 +35,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
"""问题分解节点""" """问题分解节点"""
# 从状态中获取数据 # 从状态中获取数据
content = state.get('data', '') content = state.get('data', '')
group_id = state.get('group_id', '') end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
history = await SessionService(store).get_history(group_id, group_id, group_id) history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
# 生成 JSON schema 以指导 LLM 输出正确格式 # 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema() json_schema = ProblemExtensionResponse.model_json_schema()
@@ -140,7 +140,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
start = time.time() start = time.time()
content = state.get('data', '') content = state.get('data', '')
data = state.get('spit_data', '')['context'] data = state.get('spit_data', '')['context']
group_id = state.get('group_id', '') end_user_id = state.get('end_user_id', '')
storage_type = state.get('storage_type', '') storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '') user_rag_memory_id = state.get('user_rag_memory_id', '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
@@ -156,7 +156,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
databasets = {} databasets = {}
data = [] data = []
history = await SessionService(store).get_history(group_id, group_id, group_id) history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
# 生成 JSON schema 以指导 LLM 输出正确格式 # 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema() json_schema = ProblemExtensionResponse.model_json_schema()

View File

@@ -52,9 +52,9 @@ async def rag_config(state):
return kb_config return kb_config
async def rag_knowledge(state,question): async def rag_knowledge(state,question):
kb_config = await rag_config(state) kb_config = await rag_config(state)
group_id = state.get('group_id', '') end_user_id = state.get('end_user_id', '')
user_rag_memory_id=state.get("user_rag_memory_id",'') user_rag_memory_id=state.get("user_rag_memory_id",'')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(group_id)]) retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
try: try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
clean_content = '\n\n'.join(retrieval_knowledge) clean_content = '\n\n'.join(retrieval_knowledge)
@@ -159,7 +159,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
problem_extension=state.get('problem_extension', '')['context'] problem_extension=state.get('problem_extension', '')['context']
storage_type=state.get('storage_type', '') storage_type=state.get('storage_type', '')
user_rag_memory_id=state.get('user_rag_memory_id', '') user_rag_memory_id=state.get('user_rag_memory_id', '')
group_id=state.get('group_id', '') end_user_id=state.get('end_user_id', '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
original=state.get('data', '') original=state.get('data', '')
problem_list=[] problem_list=[]
@@ -172,7 +172,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
try: try:
# Prepare search parameters based on storage type # Prepare search parameters based on storage type
search_params = { search_params = {
"group_id": group_id, "end_user_id": end_user_id,
"question": question, "question": question,
"return_raw_results": True "return_raw_results": True
} }
@@ -263,13 +263,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
async def retrieve(state: ReadState) -> ReadState: async def retrieve(state: ReadState) -> ReadState:
# 从state中获取group_id # 从state中获取end_user_id
import time import time
start=time.time() start=time.time()
problem_extension = state.get('problem_extension', '')['context'] problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '') storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '') user_rag_memory_id = state.get('user_rag_memory_id', '')
group_id = state.get('group_id', '') end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
original = state.get('data', '') original = state.get('data', '')
problem_list = [] problem_list = []
@@ -295,13 +295,13 @@ async def retrieve(state: ReadState) -> ReadState:
temperature=0.2, temperature=0.2,
) )
time_retrieval_tool = create_time_retrieval_tool(group_id) time_retrieval_tool = create_time_retrieval_tool(end_user_id)
search_params = { "group_id": group_id, "return_raw_results": True } search_params = { "end_user_id": end_user_id, "return_raw_results": True }
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params) hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
agent = create_agent( agent = create_agent(
llm, llm,
tools=[time_retrieval_tool,hybrid_retrieval], tools=[time_retrieval_tool,hybrid_retrieval],
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的group_id是: {group_id}" system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
) )
# 创建异步任务处理单个问题 # 创建异步任务处理单个问题

View File

@@ -34,8 +34,8 @@ class SummaryNodeService(LLMServiceMixin):
summary_service = SummaryNodeService() summary_service = SummaryNodeService()
async def summary_history(state: ReadState) -> ReadState: async def summary_history(state: ReadState) -> ReadState:
group_id = state.get("group_id", '') end_user_id = state.get("end_user_id", '')
history = await SessionService(store).get_history(group_id, group_id, group_id) history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
return history return history
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str: async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
@@ -122,12 +122,12 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
async def summary_redis_save(state: ReadState,aimessages) -> ReadState: async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
data = state.get("data", '') data = state.get("data", '')
group_id = state.get("group_id", '') end_user_id = state.get("end_user_id", '')
await SessionService(store).save_session( await SessionService(store).save_session(
user_id=group_id, user_id=end_user_id,
query=data, query=data,
apply_id=group_id, apply_id=end_user_id,
group_id=group_id, end_user_id=end_user_id,
ai_response=aimessages ai_response=aimessages
) )
await SessionService(store).cleanup_duplicates() await SessionService(store).cleanup_duplicates()
@@ -175,11 +175,11 @@ async def Input_Summary(state: ReadState) -> ReadState:
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
user_rag_memory_id=state.get("user_rag_memory_id",'') user_rag_memory_id=state.get("user_rag_memory_id",'')
data=state.get("data", '') data=state.get("data", '')
group_id=state.get("group_id", '') end_user_id=state.get("end_user_id", '')
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
history = await summary_history( state) history = await summary_history( state)
search_params = { search_params = {
"group_id": group_id, "end_user_id": end_user_id,
"question": data, "question": data,
"return_raw_results": True, "return_raw_results": True,
"include": ["summaries"] # Only search summary nodes for faster performance "include": ["summaries"] # Only search summary nodes for faster performance

View File

@@ -62,12 +62,12 @@ async def Verify(state: ReadState):
logger.info("=== Verify 节点开始执行 ===") logger.info("=== Verify 节点开始执行 ===")
try: try:
content = state.get('data', '') content = state.get('data', '')
group_id = state.get('group_id', '') end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., group_id={group_id}") logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
history = await SessionService(store).get_history(group_id, group_id, group_id) history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
logger.info(f"Verify: 获取历史记录完成history length={len(history)}") logger.info(f"Verify: 获取历史记录完成history length={len(history)}")
retrieve = state.get("retrieve", {}) retrieve = state.get("retrieve", {})

View File

@@ -9,47 +9,36 @@ async def write_node(state: WriteState) -> WriteState:
Write data to the database/file system. Write data to the database/file system.
Args: Args:
state: WriteState containing messages, group_id, and memory_config content: Data content to write
end_user_id: End user identifier
memory_config: MemoryConfig object containing all configuration
Returns: Returns:
dict: Contains 'write_result' with status and data fields dict: Contains 'status', 'saved_to', and 'data' fields
""" """
messages = state.get('messages', []) content=state.get('data','')
group_id = state.get('group_id', '') end_user_id=state.get('end_user_id','')
memory_config = state.get('memory_config', '') memory_config=state.get('memory_config', '')
# Convert LangChain messages to structured format expected by write()
structured_messages = []
for msg in messages:
if hasattr(msg, 'type') and hasattr(msg, 'content'):
# Map LangChain message types to role names
role = 'user' if msg.type == 'human' else 'assistant' if msg.type == 'ai' else msg.type
structured_messages.append({
"role": role,
"content": msg.content # content is now guaranteed to be a string
})
try: try:
result = await write( result=await write(
messages=structured_messages, content=content,
user_id=group_id, end_user_id=end_user_id,
apply_id=group_id,
group_id=group_id,
memory_config=memory_config, memory_config=memory_config,
) )
logger.info(f"Write completed successfully! Config: {memory_config.config_name}") logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
write_result = { write_result= {
"status": "success", "status": "success",
"data": structured_messages, "data": content,
"config_id": memory_config.config_id, "config_id": memory_config.config_id,
"config_name": memory_config.config_name, "config_name": memory_config.config_name,
} }
return {"write_result": write_result} return {"write_result":write_result}
except Exception as e: except Exception as e:
logger.error(f"Data_write failed: {e}", exc_info=True) logger.error(f"Data_write failed: {e}", exc_info=True)
write_result = { write_result= {
"status": "error", "status": "error",
"message": str(e), "message": str(e),
} }

View File

@@ -79,7 +79,7 @@ async def make_read_graph():
async def main(): async def main():
"""主函数 - 运行工作流""" """主函数 - 运行工作流"""
message = "昨天有什么好看的电影" message = "昨天有什么好看的电影"
group_id = '88a459f5_text09' # 组ID end_user_id = '88a459f5_text09' # 组ID
storage_type = 'neo4j' # 存储类型 storage_type = 'neo4j' # 存储类型
search_switch = '1' # 搜索开关 search_switch = '1' # 搜索开关
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
@@ -95,9 +95,9 @@ async def main():
start=time.time() start=time.time()
try: try:
async with make_read_graph() as graph: async with make_read_graph() as graph:
config = {"configurable": {"thread_id": group_id}} config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段 # 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"group_id":group_id initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config} ,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
# 获取节点更新信息 # 获取节点更新信息
_intermediate_outputs = [] _intermediate_outputs = []

View File

@@ -48,11 +48,11 @@ def extract_tool_message_content(response):
class TimeRetrievalInput(BaseModel): class TimeRetrievalInput(BaseModel):
"""时间检索工具的输入模式""" """时间检索工具的输入模式"""
context: str = Field(description="用户输入的查询内容") context: str = Field(description="用户输入的查询内容")
group_id: str = Field(default="88a459f5_text09", description="组ID用于过滤搜索结果") end_user_id: str = Field(default="88a459f5_text09", description="组ID用于过滤搜索结果")
def create_time_retrieval_tool(group_id: str): def create_time_retrieval_tool(end_user_id: str):
""" """
创建一个带有特定group_id的TimeRetrieval工具同步版本用于按时间范围搜索语句(Statements) 创建一个带有特定end_user_id的TimeRetrieval工具同步版本用于按时间范围搜索语句(Statements)
""" """
def clean_temporal_result_fields(data): def clean_temporal_result_fields(data):
@@ -93,26 +93,26 @@ def create_time_retrieval_tool(group_id: str):
return data return data
@tool @tool
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, group_id_param: str = None, clean_output: bool = True) -> str: def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
""" """
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段 优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
显式接收参数: 显式接收参数:
- context: 查询上下文内容 - context: 查询上下文内容
- start_date: 开始时间可选格式YYYY-MM-DD - start_date: 开始时间可选格式YYYY-MM-DD
- end_date: 结束时间可选格式YYYY-MM-DD - end_date: 结束时间可选格式YYYY-MM-DD
- group_id_param: 组ID可选用于覆盖默认组ID - end_user_id_param: 组ID可选用于覆盖默认组ID
- clean_output: 是否清理输出中的元数据字段 - clean_output: 是否清理输出中的元数据字段
-end_date 需要根据用户的描述获取结束的时间输出格式用strftime("%Y-%m-%d") -end_date 需要根据用户的描述获取结束的时间输出格式用strftime("%Y-%m-%d")
""" """
async def _async_search(): async def _async_search():
# 使用传入的参数或默认值 # 使用传入的参数或默认值
actual_group_id = group_id_param or group_id actual_end_user_id = end_user_id_param or end_user_id
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d") actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d") actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
# 基本时间搜索 # 基本时间搜索
results = await search_by_temporal( results = await search_by_temporal(
group_id=actual_group_id, end_user_id=actual_end_user_id,
start_date=actual_start_date, start_date=actual_start_date,
end_date=actual_end_date, end_date=actual_end_date,
limit=10 limit=10
@@ -147,7 +147,7 @@ def create_time_retrieval_tool(group_id: str):
# 关键词时间搜索 # 关键词时间搜索
results = await search_by_keyword_temporal( results = await search_by_keyword_temporal(
query_text=context, query_text=context,
group_id=group_id, end_user_id=end_user_id,
start_date=actual_start_date, start_date=actual_start_date,
end_date=actual_end_date, end_date=actual_end_date,
limit=15 limit=15
@@ -172,7 +172,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
Args: Args:
memory_config: 内存配置对象 memory_config: 内存配置对象
**search_params: 搜索参数,包含group_id, limit, include等 **search_params: 搜索参数,包含end_user_id, limit, include等
""" """
def clean_result_fields(data): def clean_result_fields(data):
@@ -211,7 +211,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
context: str, context: str,
search_type: str = "hybrid", search_type: str = "hybrid",
limit: int = 10, limit: int = 10,
group_id: str = None, end_user_id: str = None,
rerank_alpha: float = 0.6, rerank_alpha: float = 0.6,
use_forgetting_rerank: bool = False, use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False, use_llm_rerank: bool = False,
@@ -224,7 +224,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
context: 查询内容 context: 查询内容
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid') search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
limit: 结果数量限制 limit: 结果数量限制
group_id: 组ID用于过滤搜索结果 end_user_id: 组ID用于过滤搜索结果
rerank_alpha: 重排序权重参数 rerank_alpha: 重排序权重参数
use_forgetting_rerank: 是否使用遗忘重排序 use_forgetting_rerank: 是否使用遗忘重排序
use_llm_rerank: 是否使用LLM重排序 use_llm_rerank: 是否使用LLM重排序
@@ -238,7 +238,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
final_params = { final_params = {
"query_text": context, "query_text": context,
"search_type": search_type, "search_type": search_type,
"group_id": group_id or search_params.get("group_id"), "end_user_id": end_user_id or search_params.get("end_user_id"),
"limit": limit or search_params.get("limit", 10), "limit": limit or search_params.get("limit", 10),
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]), "include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
"output_path": None, # 不保存到文件 "output_path": None, # 不保存到文件
@@ -291,7 +291,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
context: str, context: str,
search_type: str = "hybrid", search_type: str = "hybrid",
limit: int = 10, limit: int = 10,
group_id: str = None, end_user_id: str = None,
clean_output: bool = True clean_output: bool = True
) -> str: ) -> str:
""" """
@@ -301,7 +301,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
context: 查询内容 context: 查询内容
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid') search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
limit: 结果数量限制 limit: 结果数量限制
group_id: 组ID用于过滤搜索结果 end_user_id: 组ID用于过滤搜索结果
clean_output: 是否清理输出中的元数据字段 clean_output: 是否清理输出中的元数据字段
""" """
async def _async_search(): async def _async_search():
@@ -311,7 +311,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
"context": context, "context": context,
"search_type": search_type, "search_type": search_type,
"limit": limit, "limit": limit,
"group_id": group_id, "end_user_id": end_user_id,
"clean_output": clean_output "clean_output": clean_output
}) })

View File

@@ -14,6 +14,7 @@ from app.db import get_db
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning) warnings.filterwarnings("ignore", category=RuntimeWarning)
@@ -26,12 +27,18 @@ async def make_write_graph():
""" """
Create a write graph workflow for memory operations. Create a write graph workflow for memory operations.
The workflow directly processes messages from the initial state Args:
and saves them to Neo4j storage. user_id: User identifier
tools: MCP tools loaded from session
apply_id: Application identifier
end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration
""" """
workflow = StateGraph(WriteState) workflow = StateGraph(WriteState)
workflow.add_node("content_input", content_input_write)
workflow.add_node("save_neo4j", write_node) workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "save_neo4j") workflow.add_edge(START, "content_input")
workflow.add_edge("content_input", "save_neo4j")
workflow.add_edge("save_neo4j", END) workflow.add_edge("save_neo4j", END)
graph = workflow.compile() graph = workflow.compile()
@@ -42,7 +49,7 @@ async def make_write_graph():
async def main(): async def main():
"""主函数 - 运行工作流""" """主函数 - 运行工作流"""
message = "今天周一" message = "今天周一"
group_id = 'new_2025test1103' # 组ID end_user_id = 'new_2025test1103' # 组ID
# 获取数据库会话 # 获取数据库会话
@@ -54,9 +61,9 @@ async def main():
) )
try: try:
async with make_write_graph() as graph: async with make_write_graph() as graph:
config = {"configurable": {"thread_id": group_id}} config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段 # 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, "memory_config": memory_config} initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config}
# 获取节点更新信息 # 获取节点更新信息
async for update_event in graph.astream( async for update_event in graph.astream(

View File

@@ -24,7 +24,7 @@ class ParameterBuilder:
tool_call_id: str, tool_call_id: str,
search_switch: str, search_switch: str,
apply_id: str, apply_id: str,
group_id: str, end_user_id: str,
storage_type: Optional[str] = None, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None user_rag_memory_id: Optional[str] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
@@ -44,7 +44,7 @@ class ParameterBuilder:
tool_call_id: Extracted tool call identifier tool_call_id: Extracted tool call identifier
search_switch: Search routing parameter search_switch: Search routing parameter
apply_id: Application identifier apply_id: Application identifier
group_id: Group identifier end_user_id: Group identifier
storage_type: Storage type for the workspace (optional) storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional) user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
@@ -55,7 +55,7 @@ class ParameterBuilder:
base_args = { base_args = {
"usermessages": tool_call_id, "usermessages": tool_call_id,
"apply_id": apply_id, "apply_id": apply_id,
"group_id": group_id "end_user_id": end_user_id
} }
# Always add storage_type and user_rag_memory_id (with defaults if None) # Always add storage_type and user_rag_memory_id (with defaults if None)

View File

@@ -91,7 +91,7 @@ class SearchService:
async def execute_hybrid_search( async def execute_hybrid_search(
self, self,
group_id: str, end_user_id: str,
question: str, question: str,
limit: int = 5, limit: int = 5,
search_type: str = "hybrid", search_type: str = "hybrid",
@@ -105,7 +105,7 @@ class SearchService:
Execute hybrid search and return clean content. Execute hybrid search and return clean content.
Args: Args:
group_id: Group identifier for filtering results end_user_id: Group identifier for filtering results
question: Search query text question: Search query text
limit: Maximum number of results to return (default: 5) limit: Maximum number of results to return (default: 5)
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid") search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
@@ -130,7 +130,7 @@ class SearchService:
answer = await run_hybrid_search( answer = await run_hybrid_search(
query_text=cleaned_query, query_text=cleaned_query,
search_type=search_type, search_type=search_type,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include, include=include,
output_path=output_path, output_path=output_path,
@@ -186,7 +186,7 @@ class SearchService:
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Search failed for query '{question}' in group '{group_id}': {e}", f"Search failed for query '{question}' in group '{end_user_id}': {e}",
exc_info=True exc_info=True
) )
# Return empty results on failure # Return empty results on failure

View File

@@ -59,7 +59,7 @@ class SessionService:
self, self,
user_id: str, user_id: str,
apply_id: str, apply_id: str,
group_id: str end_user_id: str
) -> List[dict]: ) -> List[dict]:
""" """
Retrieve conversation history from Redis. Retrieve conversation history from Redis.
@@ -67,20 +67,20 @@ class SessionService:
Args: Args:
user_id: User identifier user_id: User identifier
apply_id: Application identifier apply_id: Application identifier
group_id: Group identifier end_user_id: Group identifier
Returns: Returns:
List of conversation history items with Query and Answer keys List of conversation history items with Query and Answer keys
Returns empty list if no history found or on error Returns empty list if no history found or on error
""" """
try: try:
history = self.store.find_user_apply_group(user_id, apply_id, group_id) history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
# Validate history structure # Validate history structure
if not isinstance(history, list): if not isinstance(history, list):
logger.warning( logger.warning(
f"Invalid history format for user {user_id}, " f"Invalid history format for user {user_id}, "
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}" f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
) )
return [] return []
@@ -89,7 +89,7 @@ class SessionService:
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Failed to retrieve history for user {user_id}, " f"Failed to retrieve history for user {user_id}, "
f"apply {apply_id}, group {group_id}: {e}", f"apply {apply_id}, group {end_user_id}: {e}",
exc_info=True exc_info=True
) )
# Return empty list on error to allow execution to continue # Return empty list on error to allow execution to continue
@@ -100,7 +100,7 @@ class SessionService:
user_id: str, user_id: str,
query: str, query: str,
apply_id: str, apply_id: str,
group_id: str, end_user_id: str,
ai_response: str ai_response: str
) -> Optional[str]: ) -> Optional[str]:
""" """
@@ -110,7 +110,7 @@ class SessionService:
user_id: User identifier user_id: User identifier
query: User query/message query: User query/message
apply_id: Application identifier apply_id: Application identifier
group_id: Group identifier end_user_id: Group identifier
ai_response: AI response/answer ai_response: AI response/answer
Returns: Returns:
@@ -131,7 +131,7 @@ class SessionService:
userid=user_id, userid=user_id,
messages=query, messages=query,
apply_id=apply_id, apply_id=apply_id,
group_id=group_id, end_user_id=end_user_id,
aimessages=ai_response aimessages=ai_response
) )
@@ -152,7 +152,7 @@ class SessionService:
Duplicates are identified by matching: Duplicates are identified by matching:
- sessionid - sessionid
- user_id (id field) - user_id (id field)
- group_id - end_user_id
- messages - messages
- aimessages - aimessages

View File

@@ -9,65 +9,56 @@ from app.core.memory.models.message_models import DialogData, ConversationContex
async def get_chunked_dialogs( async def get_chunked_dialogs(
chunker_strategy: str = "RecursiveChunker", chunker_strategy: str = "RecursiveChunker",
group_id: str = "group_1", end_user_id: str = "group_1",
user_id: str = "user1", content: str = "这是用户的输入",
apply_id: str = "applyid",
messages: list = None,
ref_id: str = "wyl_20251027", ref_id: str = "wyl_20251027",
config_id: str = None config_id: str = None
) -> List[DialogData]: ) -> List[DialogData]:
"""Generate chunks from structured messages using the specified chunker strategy. """Generate chunks from all test data entries using the specified chunker strategy.
Args: Args:
chunker_strategy: The chunking strategy to use (default: RecursiveChunker) chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
group_id: Group identifier end_user_id: End user identifier
user_id: User identifier content: Dialog content
apply_id: Application identifier
messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference identifier ref_id: Reference identifier
config_id: Configuration ID for processing config_id: Configuration ID for processing
Returns: Returns:
List of DialogData objects with generated chunks List of DialogData objects with generated chunks for each test entry
""" """
from app.core.logging_config import get_agent_logger dialog_data_list = []
logger = get_agent_logger(__name__) messages = []
if not messages or not isinstance(messages, list) or len(messages) == 0: messages.append(ConversationMessage(role="用户", msg=content))
raise ValueError("messages parameter must be a non-empty list")
# Create DialogData
conversation_messages = [] conversation_context = ConversationContext(msgs=messages)
# Create DialogData with end_user_id
for idx, msg in enumerate(messages):
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
role = msg['role']
content = msg['content']
if role not in ['user', 'assistant']:
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
if content.strip():
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
if not conversation_messages:
raise ValueError("Message list cannot be empty after filtering")
conversation_context = ConversationContext(msgs=conversation_messages)
dialog_data = DialogData( dialog_data = DialogData(
context=conversation_context, context=conversation_context,
ref_id=ref_id, ref_id=ref_id,
group_id=group_id, end_user_id=end_user_id,
user_id=user_id,
apply_id=apply_id,
config_id=config_id config_id=config_id
) )
# Create DialogueChunker and process the dialogue
chunker = DialogueChunker(chunker_strategy) chunker = DialogueChunker(chunker_strategy)
extracted_chunks = await chunker.process_dialogue(dialog_data) extracted_chunks = await chunker.process_dialogue(dialog_data)
dialog_data.chunks = extracted_chunks dialog_data.chunks = extracted_chunks
logger.info(f"DialogData created with {len(extracted_chunks)} chunks")
return [dialog_data] dialog_data_list.append(dialog_data)
# Convert to dict with datetime serialized
def serialize_datetime(obj):
if isinstance(obj, datetime):
return obj.isoformat()
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
combined_output = [dd.model_dump() for dd in dialog_data_list]
print(dialog_data_list)
# with open(os.path.join(os.path.dirname(__file__), "chunker_test_output.txt"), "w", encoding="utf-8") as f:
# json.dump(combined_output, f, ensure_ascii=False, indent=4, default=serialize_datetime)
return dialog_data_list

View File

@@ -12,13 +12,11 @@ class WriteState(TypedDict):
Langgrapg Writing TypedDict Langgrapg Writing TypedDict
''' '''
messages: Annotated[list[AnyMessage], add_messages] messages: Annotated[list[AnyMessage], add_messages]
user_id:str end_user_id: str
apply_id:str
group_id:str
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}] errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
memory_config: object memory_config: object
write_result: dict write_result: dict
data:str data: str
class ReadState(TypedDict): class ReadState(TypedDict):
""" """
@@ -28,7 +26,7 @@ class ReadState(TypedDict):
messages: 消息列表,支持自动追加 messages: 消息列表,支持自动追加
loop_count: 遍历次数 loop_count: 遍历次数
search_switch: 搜索类型开关 search_switch: 搜索类型开关
group_id: 组标识 end_user_id: 组标识
config_id: 配置ID用于过滤结果 config_id: 配置ID用于过滤结果
data: 从content_input_node传递的内容数据 data: 从content_input_node传递的内容数据
spit_data: 从Split_The_Problem传递的分解结果 spit_data: 从Split_The_Problem传递的分解结果
@@ -39,7 +37,7 @@ class ReadState(TypedDict):
messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式 messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式
loop_count: int loop_count: int
search_switch: str search_switch: str
group_id: str end_user_id: str
config_id: str config_id: str
data: str # 新增字段用于传递内容 data: str # 新增字段用于传递内容
spit_data: dict # 新增字段用于传递问题分解结果 spit_data: dict # 新增字段用于传递问题分解结果

View File

@@ -28,7 +28,7 @@ class RedisSessionStore:
return text return text
# 修改后的 save_session 方法 # 修改后的 save_session 方法
def save_session(self, userid, messages, aimessages, apply_id, group_id): def save_session(self, userid, messages, aimessages, apply_id, end_user_id):
""" """
写入一条会话数据,返回 session_id 写入一条会话数据,返回 session_id
优化版本确保写入时间不超过1秒 优化版本确保写入时间不超过1秒
@@ -46,7 +46,7 @@ class RedisSessionStore:
"id": self.uudi, "id": self.uudi,
"sessionid": userid, "sessionid": userid,
"apply_id": apply_id, "apply_id": apply_id,
"group_id": group_id, "end_user_id": end_user_id,
"messages": messages, "messages": messages,
"aimessages": aimessages, "aimessages": aimessages,
"starttime": starttime "starttime": starttime
@@ -67,7 +67,7 @@ class RedisSessionStore:
def save_sessions_batch(self, sessions_data): def save_sessions_batch(self, sessions_data):
""" """
批量写入多条会话数据,返回 session_id 列表 批量写入多条会话数据,返回 session_id 列表
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, group_id sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id
优化版本:批量操作,大幅提升性能 优化版本:批量操作,大幅提升性能
""" """
try: try:
@@ -83,7 +83,7 @@ class RedisSessionStore:
"id": self.uudi, "id": self.uudi,
"sessionid": session.get('userid'), "sessionid": session.get('userid'),
"apply_id": session.get('apply_id'), "apply_id": session.get('apply_id'),
"group_id": session.get('group_id'), "end_user_id": session.get('end_user_id'),
"messages": session.get('messages'), "messages": session.get('messages'),
"aimessages": session.get('aimessages'), "aimessages": session.get('aimessages'),
"starttime": starttime "starttime": starttime
@@ -108,9 +108,9 @@ class RedisSessionStore:
data = self.r.hgetall(key) data = self.r.hgetall(key)
return data if data else None return data if data else None
def get_session_apply_group(self, sessionid, apply_id, group_id): def get_session_apply_group(self, sessionid, apply_id, end_user_id):
""" """
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据 根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据
""" """
result_items = [] result_items = []
@@ -124,7 +124,7 @@ class RedisSessionStore:
# 检查三个条件是否都匹配 # 检查三个条件是否都匹配
if (data.get('sessionid') == sessionid and if (data.get('sessionid') == sessionid and
data.get('apply_id') == apply_id and data.get('apply_id') == apply_id and
data.get('group_id') == group_id): data.get('end_user_id') == end_user_id):
result_items.append(data) result_items.append(data)
return result_items return result_items
@@ -172,7 +172,7 @@ class RedisSessionStore:
def delete_duplicate_sessions(self): def delete_duplicate_sessions(self):
""" """
删除重复会话数据,条件: 删除重复会话数据,条件:
"sessionid""user_id""group_id""messages""aimessages" 五个字段都相同的只保留一个,其他删除 "sessionid""user_id""end_user_id""messages""aimessages" 五个字段都相同的只保留一个,其他删除
优化版本:使用 pipeline 批量操作确保在1秒内完成 优化版本:使用 pipeline 批量操作确保在1秒内完成
""" """
import time import time
@@ -202,12 +202,12 @@ class RedisSessionStore:
# 获取五个字段的值 # 获取五个字段的值
sessionid = data.get('sessionid', '') sessionid = data.get('sessionid', '')
user_id = data.get('id', '') user_id = data.get('id', '')
group_id = data.get('group_id', '') end_user_id = data.get('end_user_id', '')
messages = data.get('messages', '') messages = data.get('messages', '')
aimessages = data.get('aimessages', '') aimessages = data.get('aimessages', '')
# 用五元组作为唯一标识 # 用五元组作为唯一标识
identifier = (sessionid, user_id, group_id, messages, aimessages) identifier = (sessionid, user_id, end_user_id, messages, aimessages)
if identifier in seen: if identifier in seen:
# 重复,标记为待删除 # 重复,标记为待删除
@@ -248,9 +248,9 @@ class RedisSessionStore:
result_items = [] result_items = []
return (result_items) return (result_items)
def find_user_apply_group(self, sessionid, apply_id, group_id): def find_user_apply_group(self, sessionid, apply_id, end_user_id):
""" """
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据返回最新的6条 根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据返回最新的6条
""" """
import time import time
start_time = time.time() start_time = time.time()
@@ -276,7 +276,7 @@ class RedisSessionStore:
# 检查是否符合三个条件 # 检查是否符合三个条件
if (data.get('apply_id') == apply_id and if (data.get('apply_id') == apply_id and
data.get('group_id') == group_id): data.get('end_user_id') == end_user_id):
# 支持模糊匹配 sessionid 或者完全匹配 # 支持模糊匹配 sessionid 或者完全匹配
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid: if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
matched_items.append({ matched_items.append({

View File

@@ -59,7 +59,7 @@ class SessionService:
self, self,
user_id: str, user_id: str,
apply_id: str, apply_id: str,
group_id: str end_user_id: str
) -> List[dict]: ) -> List[dict]:
""" """
Retrieve conversation history from Redis. Retrieve conversation history from Redis.
@@ -67,20 +67,20 @@ class SessionService:
Args: Args:
user_id: User identifier user_id: User identifier
apply_id: Application identifier apply_id: Application identifier
group_id: Group identifier end_user_id: Group identifier
Returns: Returns:
List of conversation history items with Query and Answer keys List of conversation history items with Query and Answer keys
Returns empty list if no history found or on error Returns empty list if no history found or on error
""" """
try: try:
history = self.store.find_user_apply_group(user_id, apply_id, group_id) history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
# Validate history structure # Validate history structure
if not isinstance(history, list): if not isinstance(history, list):
logger.warning( logger.warning(
f"Invalid history format for user {user_id}, " f"Invalid history format for user {user_id}, "
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}" f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
) )
return [] return []
@@ -89,7 +89,7 @@ class SessionService:
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Failed to retrieve history for user {user_id}, " f"Failed to retrieve history for user {user_id}, "
f"apply {apply_id}, group {group_id}: {e}", f"apply {apply_id}, group {end_user_id}: {e}",
exc_info=True exc_info=True
) )
# Return empty list on error to allow execution to continue # Return empty list on error to allow execution to continue
@@ -100,7 +100,7 @@ class SessionService:
user_id: str, user_id: str,
query: str, query: str,
apply_id: str, apply_id: str,
group_id: str, end_user_id: str,
ai_response: str ai_response: str
) -> Optional[str]: ) -> Optional[str]:
""" """
@@ -110,7 +110,7 @@ class SessionService:
user_id: User identifier user_id: User identifier
query: User query/message query: User query/message
apply_id: Application identifier apply_id: Application identifier
group_id: Group identifier end_user_id: Group identifier
ai_response: AI response/answer ai_response: AI response/answer
Returns: Returns:
@@ -131,7 +131,7 @@ class SessionService:
userid=user_id, userid=user_id,
messages=query, messages=query,
apply_id=apply_id, apply_id=apply_id,
group_id=group_id, end_user_id=end_user_id,
aimessages=ai_response aimessages=ai_response
) )
@@ -152,7 +152,7 @@ class SessionService:
Duplicates are identified by matching: Duplicates are identified by matching:
- sessionid - sessionid
- user_id (id field) - user_id (id field)
- group_id - end_user_id
- messages - messages
- aimessages - aimessages

View File

@@ -29,9 +29,7 @@ logger = get_agent_logger(__name__)
async def write( async def write(
user_id: str, end_user_id: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig, memory_config: MemoryConfig,
messages: list, messages: list,
ref_id: str = "wyl20251027", ref_id: str = "wyl20251027",
@@ -40,9 +38,7 @@ async def write(
Execute the complete knowledge extraction pipeline. Execute the complete knowledge extraction pipeline.
Args: Args:
user_id: User identifier end_user_id: End user identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration memory_config: MemoryConfig object containing all configuration
messages: Structured message list [{"role": "user", "content": "..."}, ...] messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference ID, defaults to "wyl20251027" ref_id: Reference ID, defaults to "wyl20251027"
@@ -58,7 +54,7 @@ async def write(
logger.info(f"LLM model: {memory_config.llm_model_name}") logger.info(f"LLM model: {memory_config.llm_model_name}")
logger.info(f"Embedding model: {memory_config.embedding_model_name}") logger.info(f"Embedding model: {memory_config.embedding_model_name}")
logger.info(f"Chunker strategy: {chunker_strategy}") logger.info(f"Chunker strategy: {chunker_strategy}")
logger.info(f"Group ID: {group_id}") logger.info(f"End User ID: {end_user_id}")
# Construct clients from memory_config using factory pattern with db session # Construct clients from memory_config using factory pattern with db session
with get_db_context() as db: with get_db_context() as db:
@@ -83,9 +79,7 @@ async def write(
step_start = time.time() step_start = time.time()
chunked_dialogs = await get_chunked_dialogs( chunked_dialogs = await get_chunked_dialogs(
chunker_strategy=chunker_strategy, chunker_strategy=chunker_strategy,
group_id=group_id, end_user_id=end_user_id,
user_id=user_id,
apply_id=apply_id,
messages=messages, messages=messages,
ref_id=ref_id, ref_id=ref_id,
config_id=config_id, config_id=config_id,

View File

@@ -16,13 +16,13 @@ class FilteredTags(BaseModel):
"""用于接收LLM筛选后的核心标签列表的模型。""" """用于接收LLM筛选后的核心标签列表的模型。"""
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。") meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]: async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
""" """
使用LLM筛选标签列表仅保留具有代表性的核心名词。 使用LLM筛选标签列表仅保留具有代表性的核心名词。
Args: Args:
tags: 原始标签列表 tags: 原始标签列表
group_id: 用户组ID用于获取配置 end_user_id: 用户组ID用于获取配置
Returns: Returns:
筛选后的标签列表 筛选后的标签列表
@@ -37,12 +37,12 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
get_end_user_connected_config, get_end_user_connected_config,
) )
connected_config = get_end_user_connected_config(group_id, db) connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id") config_id = connected_config.get("memory_config_id")
if not config_id: if not config_id:
raise ValueError( raise ValueError(
f"No memory_config_id found for group_id: {group_id}. " f"No memory_config_id found for end_user_id: {end_user_id}. "
"Please ensure the user has a valid memory configuration." "Please ensure the user has a valid memory configuration."
) )
@@ -87,7 +87,7 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
async def get_raw_tags_from_db( async def get_raw_tags_from_db(
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: str, end_user_id: str,
limit: int, limit: int,
by_user: bool = False by_user: bool = False
) -> List[Tuple[str, int]]: ) -> List[Tuple[str, int]]:
@@ -99,9 +99,9 @@ async def get_raw_tags_from_db(
Args: Args:
connector: Neo4j连接器实例 connector: Neo4j连接器实例
group_id: 如果by_user=False则为group_id如果by_user=True则为user_id end_user_id: 如果by_user=False则为end_user_id如果by_user=True则为user_id
limit: 返回的标签数量限制 limit: 返回的标签数量限制
by_user: 是否按user_id查询默认Falsegroup_id查询 by_user: 是否按user_id查询默认Falseend_user_id查询
Returns: Returns:
List[Tuple[str, int]]: 标签名称和频率的元组列表 List[Tuple[str, int]]: 标签名称和频率的元组列表
@@ -119,7 +119,7 @@ async def get_raw_tags_from_db(
else: else:
query = ( query = (
"MATCH (e:ExtractedEntity) " "MATCH (e:ExtractedEntity) "
"WHERE e.group_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude " "WHERE e.end_user_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
"RETURN e.name AS name, count(e) AS frequency " "RETURN e.name AS name, count(e) AS frequency "
"ORDER BY frequency DESC " "ORDER BY frequency DESC "
"LIMIT $limit" "LIMIT $limit"
@@ -128,44 +128,44 @@ async def get_raw_tags_from_db(
# 使用项目的Neo4jConnector执行查询 # 使用项目的Neo4jConnector执行查询
results = await connector.execute_query( results = await connector.execute_query(
query, query,
id=group_id, id=end_user_id,
limit=limit, limit=limit,
names_to_exclude=names_to_exclude names_to_exclude=names_to_exclude
) )
return [(record["name"], record["frequency"]) for record in results] return [(record["name"], record["frequency"]) for record in results]
async def get_hot_memory_tags(group_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]: async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
""" """
获取原始标签然后使用LLM进行筛选返回最终的热门标签列表。 获取原始标签然后使用LLM进行筛选返回最终的热门标签列表。
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。 查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
Args: Args:
group_id: 必需参数。如果by_user=False则为group_id如果by_user=True则为user_id end_user_id: 必需参数。如果by_user=False则为end_user_id如果by_user=True则为user_id
limit: 返回的标签数量限制 limit: 返回的标签数量限制
by_user: 是否按user_id查询默认Falsegroup_id查询 by_user: 是否按user_id查询默认Falseend_user_id查询
Raises: Raises:
ValueError: 如果group_id未提供或为空 ValueError: 如果end_user_id未提供或为空
""" """
# 验证group_id必须提供且不为空 # 验证end_user_id必须提供且不为空
if not group_id or not group_id.strip(): if not end_user_id or not end_user_id.strip():
raise ValueError( raise ValueError(
"group_id is required. Please provide a valid group_id or user_id." "end_user_id is required. Please provide a valid end_user_id or user_id."
) )
# 使用项目的Neo4jConnector # 使用项目的Neo4jConnector
connector = Neo4jConnector() connector = Neo4jConnector()
try: try:
# 1. 从数据库获取原始排名靠前的标签 # 1. 从数据库获取原始排名靠前的标签
raw_tags_with_freq = await get_raw_tags_from_db(connector, group_id, limit, by_user=by_user) raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user)
if not raw_tags_with_freq: if not raw_tags_with_freq:
return [] return []
raw_tag_names = [tag for tag, freq in raw_tags_with_freq] raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签 # 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id) meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, end_user_id)
# 3. 根据LLM的筛选结果构建最终的标签列表保留原始频率和顺序 # 3. 根据LLM的筛选结果构建最终的标签列表保留原始频率和顺序
final_tags = [] final_tags = []

View File

@@ -75,8 +75,8 @@ class MemoryDataSource:
start_date = time_range.start_date if time_range else None start_date = time_range.start_date if time_range else None
end_date = time_range.end_date if time_range else None end_date = time_range.end_date if time_range else None
summary_dicts = await self.memory_summary_repo.find_by_group_id( summary_dicts = await self.memory_summary_repo.find_by_end_user_id(
group_id=user_id, end_user_id=user_id,
limit=limit, limit=limit,
start_date=start_date, start_date=start_date,
end_date=end_date end_date=end_date

View File

@@ -41,7 +41,7 @@ DIALOGUE_EMBEDDING_SEARCH = """
WITH $embedding AS q WITH $embedding AS q
MATCH (d:Dialogue) MATCH (d:Dialogue)
WHERE d.dialog_embedding IS NOT NULL WHERE d.dialog_embedding IS NOT NULL
AND ($group_id IS NULL OR d.group_id = $group_id) AND ($end_user_id IS NULL OR d.end_user_id = $end_user_id)
WITH d, q, d.dialog_embedding AS v WITH d, q, d.dialog_embedding AS v
WITH d, WITH d,
reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot, reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot,
@@ -50,7 +50,7 @@ WITH d,
WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score
WHERE score > $threshold WHERE score > $threshold
RETURN d.id AS dialog_id, RETURN d.id AS dialog_id,
d.group_id AS group_id, d.end_user_id AS end_user_id,
d.content AS content, d.content AS content,
d.created_at AS created_at, d.created_at AS created_at,
d.expired_at AS expired_at, d.expired_at AS expired_at,

View File

@@ -36,7 +36,7 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def ingest_contexts_via_full_pipeline( async def ingest_contexts_via_full_pipeline(
contexts: List[str], contexts: List[str],
group_id: str, end_user_id: str,
chunker_strategy: str | None = None, chunker_strategy: str | None = None,
embedding_name: str | None = None, embedding_name: str | None = None,
save_chunk_output: bool = False, save_chunk_output: bool = False,
@@ -48,7 +48,7 @@ async def ingest_contexts_via_full_pipeline(
This function mirrors the steps in main(), but starts from raw text contexts. This function mirrors the steps in main(), but starts from raw text contexts.
Args: Args:
contexts: List of dialogue texts, each containing lines like "role: message". contexts: List of dialogue texts, each containing lines like "role: message".
group_id: Group ID to assign to generated DialogData and graph nodes. end_user_id: Group ID to assign to generated DialogData and graph nodes.
chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY. chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY.
embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID. embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID.
save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging. save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging.
@@ -109,7 +109,7 @@ async def ingest_contexts_via_full_pipeline(
dialog = DialogData( dialog = DialogData(
context=context_model, context=context_model,
ref_id=f"pipeline_item_{idx}", ref_id=f"pipeline_item_{idx}",
group_id=group_id, end_user_id=end_user_id,
user_id="default_user", user_id="default_user",
apply_id="default_application", apply_id="default_application",
) )
@@ -318,16 +318,16 @@ async def handle_context_processing(args):
print("No contexts provided for processing.") print("No contexts provided for processing.")
return False return False
return await main_from_contexts(contexts, args.context_group_id) return await main_from_contexts(contexts, args.context_end_user_id)
async def main_from_contexts(contexts: List[str], group_id: str): async def main_from_contexts(contexts: List[str], end_user_id: str):
"""Run the pipeline from provided dialogue contexts instead of test data.""" """Run the pipeline from provided dialogue contexts instead of test data."""
print("=== Running pipeline from provided contexts ===") print("=== Running pipeline from provided contexts ===")
success = await ingest_contexts_via_full_pipeline( success = await ingest_contexts_via_full_pipeline(
contexts=contexts, contexts=contexts,
group_id=group_id, end_user_id=end_user_id,
chunker_strategy=SELECTED_CHUNKER_STRATEGY, chunker_strategy=SELECTED_CHUNKER_STRATEGY,
embedding_name=SELECTED_EMBEDDING_ID, embedding_name=SELECTED_EMBEDDING_ID,
save_chunk_output=True save_chunk_output=True

View File

@@ -47,7 +47,7 @@ from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.definitions import ( from app.core.memory.utils.definitions import (
PROJECT_ROOT, PROJECT_ROOT,
SELECTED_EMBEDDING_ID, SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID, SELECTED_end_user_id,
SELECTED_LLM_ID, SELECTED_LLM_ID,
) )
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -59,7 +59,7 @@ from app.services.memory_config_service import MemoryConfigService
async def run_locomo_benchmark( async def run_locomo_benchmark(
sample_size: int = 20, sample_size: int = 20,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
search_type: str = "hybrid", search_type: str = "hybrid",
search_limit: int = 12, search_limit: int = 12,
context_char_budget: int = 8000, context_char_budget: int = 8000,
@@ -85,7 +85,7 @@ async def run_locomo_benchmark(
Args: Args:
sample_size: Number of QA pairs to evaluate (from first conversation) sample_size: Number of QA pairs to evaluate (from first conversation)
group_id: Database group ID for retrieval (uses default if None) end_user_id: Database group ID for retrieval (uses default if None)
search_type: "keyword", "embedding", or "hybrid" search_type: "keyword", "embedding", or "hybrid"
search_limit: Max documents to retrieve per query search_limit: Max documents to retrieve per query
context_char_budget: Max characters for context context_char_budget: Max characters for context
@@ -96,8 +96,8 @@ async def run_locomo_benchmark(
Returns: Returns:
Dictionary with evaluation results including metrics, timing, and samples Dictionary with evaluation results including metrics, timing, and samples
""" """
# Use default group_id if not provided # Use default end_user_id if not provided
group_id = group_id or SELECTED_GROUP_ID end_user_id = end_user_id or SELECTED_end_user_id
# Determine data path # Determine data path
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json") data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
@@ -110,7 +110,7 @@ async def run_locomo_benchmark(
print(f"{'='*60}") print(f"{'='*60}")
print("📊 Configuration:") print("📊 Configuration:")
print(f" Sample size: {sample_size}") print(f" Sample size: {sample_size}")
print(f" Group ID: {group_id}") print(f" Group ID: {end_user_id}")
print(f" Search type: {search_type}") print(f" Search type: {search_type}")
print(f" Search limit: {search_limit}") print(f" Search limit: {search_limit}")
print(f" Context budget: {context_char_budget} chars") print(f" Context budget: {context_char_budget} chars")
@@ -134,7 +134,7 @@ async def run_locomo_benchmark(
# Step 2: Extract conversations and ingest if needed # Step 2: Extract conversations and ingest if needed
if skip_ingest: if skip_ingest:
print("⏭️ Skipping data ingestion (using existing data in Neo4j)") print("⏭️ Skipping data ingestion (using existing data in Neo4j)")
print(f" Group ID: {group_id}\n") print(f" Group ID: {end_user_id}\n")
else: else:
print("💾 Checking database ingestion...") print("💾 Checking database ingestion...")
try: try:
@@ -142,10 +142,10 @@ async def run_locomo_benchmark(
print(f"📝 Extracted {len(conversations)} conversations") print(f"📝 Extracted {len(conversations)} conversations")
# Always ingest for now (ingestion check not implemented) # Always ingest for now (ingestion check not implemented)
print(f"🔄 Ingesting conversations into group '{group_id}'...") print(f"🔄 Ingesting conversations into group '{end_user_id}'...")
success = await ingest_conversations_if_needed( success = await ingest_conversations_if_needed(
conversations=conversations, conversations=conversations,
group_id=group_id, end_user_id=end_user_id,
reset=reset_group reset=reset_group
) )
@@ -224,7 +224,7 @@ async def run_locomo_benchmark(
try: try:
retrieved_info = await retrieve_relevant_information( retrieved_info = await retrieve_relevant_information(
question=question, question=question,
group_id=group_id, end_user_id=end_user_id,
search_type=search_type, search_type=search_type,
search_limit=search_limit, search_limit=search_limit,
connector=connector, connector=connector,
@@ -409,7 +409,7 @@ async def run_locomo_benchmark(
"sample_size": len(qa_items), "sample_size": len(qa_items),
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
"params": { "params": {
"group_id": group_id, "end_user_id": end_user_id,
"search_type": search_type, "search_type": search_type,
"search_limit": search_limit, "search_limit": search_limit,
"context_char_budget": context_char_budget, "context_char_budget": context_char_budget,
@@ -467,7 +467,7 @@ def main():
help="Number of QA pairs to evaluate" help="Number of QA pairs to evaluate"
) )
parser.add_argument( parser.add_argument(
"--group_id", "--end_user_id",
type=str, type=str,
default=None, default=None,
help="Database group ID for retrieval (uses default if not specified)" help="Database group ID for retrieval (uses default if not specified)"
@@ -516,7 +516,7 @@ def main():
# Run benchmark # Run benchmark
result = asyncio.run(run_locomo_benchmark( result = asyncio.run(run_locomo_benchmark(
sample_size=args.sample_size, sample_size=args.sample_size,
group_id=args.group_id, end_user_id=args.end_user_id,
search_type=args.search_type, search_type=args.search_type,
search_limit=args.search_limit, search_limit=args.search_limit,
context_char_budget=args.context_char_budget, context_char_budget=args.context_char_budget,

View File

@@ -555,7 +555,7 @@ async def run_enhanced_evaluation():
search_results = await run_hybrid_search( search_results = await run_hybrid_search(
query_text=q, query_text=q,
search_type="hybrid", search_type="hybrid",
group_id="locomo_sk", end_user_id="locomo_sk",
limit=20, limit=20,
include=["statements", "chunks", "entities", "summaries"], include=["statements", "chunks", "entities", "summaries"],
alpha=0.6, # BM25权重 alpha=0.6, # BM25权重

View File

@@ -348,7 +348,7 @@ def select_and_format_information(
async def retrieve_relevant_information( async def retrieve_relevant_information(
question: str, question: str,
group_id: str, end_user_id: str,
search_type: str, search_type: str,
search_limit: int, search_limit: int,
connector: Any, connector: Any,
@@ -368,7 +368,7 @@ async def retrieve_relevant_information(
Args: Args:
question: Question to search for question: Question to search for
group_id: Database group ID (identifies which conversation memory to search) end_user_id: Database group ID (identifies which conversation memory to search)
search_type: "keyword", "embedding", or "hybrid" search_type: "keyword", "embedding", or "hybrid"
search_limit: Max memory pieces to retrieve search_limit: Max memory pieces to retrieve
connector: Neo4j connector instance connector: Neo4j connector instance
@@ -396,7 +396,7 @@ async def retrieve_relevant_information(
connector=connector, connector=connector,
embedder_client=embedder, embedder_client=embedder,
query_text=question, query_text=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], include=["chunks", "statements", "entities", "summaries"],
) )
@@ -455,7 +455,7 @@ async def retrieve_relevant_information(
search_results = await search_graph( search_results = await search_graph(
connector=connector, connector=connector,
q=question, q=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit limit=search_limit
) )
@@ -491,7 +491,7 @@ async def retrieve_relevant_information(
search_results = await run_hybrid_search( search_results = await run_hybrid_search(
query_text=question, query_text=question,
search_type=search_type, search_type=search_type,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], include=["chunks", "statements", "entities", "summaries"],
output_path=None, output_path=None,
@@ -524,7 +524,7 @@ async def retrieve_relevant_information(
connector=connector, connector=connector,
embedder_client=embedder, embedder_client=embedder,
query_text=question, query_text=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], include=["chunks", "statements", "entities", "summaries"],
) )
@@ -584,7 +584,7 @@ async def retrieve_relevant_information(
async def ingest_conversations_if_needed( async def ingest_conversations_if_needed(
conversations: List[str], conversations: List[str],
group_id: str, end_user_id: str,
reset: bool = False reset: bool = False
) -> bool: ) -> bool:
""" """
@@ -603,7 +603,7 @@ async def ingest_conversations_if_needed(
Args: Args:
conversations: List of raw conversation texts from LoCoMo dataset conversations: List of raw conversation texts from LoCoMo dataset
Example: ["User: I went to Paris. AI: When was that?", ...] Example: ["User: I went to Paris. AI: When was that?", ...]
group_id: Target group ID for database storage end_user_id: Target group ID for database storage
reset: Whether to clear existing data first (not implemented in wrapper) reset: Whether to clear existing data first (not implemented in wrapper)
Returns: Returns:
@@ -617,7 +617,7 @@ async def ingest_conversations_if_needed(
try: try:
success = await ingest_contexts_via_full_pipeline( success = await ingest_contexts_via_full_pipeline(
contexts=conversations, contexts=conversations,
group_id=group_id, end_user_id=end_user_id,
save_chunk_output=True save_chunk_output=True
) )
return success return success

View File

@@ -30,7 +30,7 @@ from app.core.memory.storage_services.search import run_hybrid_search
from app.core.memory.utils.config.definitions import ( from app.core.memory.utils.config.definitions import (
PROJECT_ROOT, PROJECT_ROOT,
SELECTED_EMBEDDING_ID, SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID, SELECTED_end_user_id,
SELECTED_LLM_ID, SELECTED_LLM_ID,
) )
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -249,7 +249,7 @@ def get_search_params_by_category(category: str):
async def run_locomo_eval( async def run_locomo_eval(
sample_size: int = 1, sample_size: int = 1,
group_id: str | None = None, end_user_id: str | None = None,
search_limit: int = 8, search_limit: int = 8,
context_char_budget: int = 4000, # 保持默认值不变 context_char_budget: int = 4000, # 保持默认值不变
llm_temperature: float = 0.0, llm_temperature: float = 0.0,
@@ -262,7 +262,7 @@ async def run_locomo_eval(
) -> Dict[str, Any]: ) -> Dict[str, Any]:
# 函数内部使用三路检索逻辑,但保持参数签名不变 # 函数内部使用三路检索逻辑,但保持参数签名不变
group_id = group_id or SELECTED_GROUP_ID end_user_id = end_user_id or SELECTED_end_user_id
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json") data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
if not os.path.exists(data_path): if not os.path.exists(data_path):
data_path = os.path.join(os.getcwd(), "data", "locomo10.json") data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
@@ -340,7 +340,7 @@ async def run_locomo_eval(
# 关键修复:强制重新摄入纯净的对话数据 # 关键修复:强制重新摄入纯净的对话数据
print("🔄 强制重新摄入纯净的对话数据...") print("🔄 强制重新摄入纯净的对话数据...")
await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True) await ingest_contexts_via_full_pipeline(contents, end_user_id, save_chunk_output=True)
# 使用异步LLM客户端 # 使用异步LLM客户端
with get_db_context() as db: with get_db_context() as db:
@@ -405,7 +405,7 @@ async def run_locomo_eval(
connector=connector, connector=connector,
embedder_client=embedder, embedder_client=embedder,
query_text=q, query_text=q,
group_id=group_id, end_user_id=end_user_id,
limit=adjusted_limit, limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"], # 修复:使用正确的类型 include=["chunks", "statements", "entities", "summaries"], # 修复:使用正确的类型
) )
@@ -456,7 +456,7 @@ async def run_locomo_eval(
search_results = await search_graph( search_results = await search_graph(
connector=connector, connector=connector,
q=q, q=q,
group_id=group_id, end_user_id=end_user_id,
limit=adjusted_limit limit=adjusted_limit
) )
dialogs = search_results.get("dialogues", []) dialogs = search_results.get("dialogues", [])
@@ -486,7 +486,7 @@ async def run_locomo_eval(
search_results = await run_hybrid_search( search_results = await run_hybrid_search(
query_text=q, query_text=q,
search_type=search_type, search_type=search_type,
group_id=group_id, end_user_id=end_user_id,
limit=adjusted_limit, limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"], include=["chunks", "statements", "entities", "summaries"],
output_path=None, output_path=None,
@@ -524,7 +524,7 @@ async def run_locomo_eval(
connector=connector, connector=connector,
embedder_client=embedder, embedder_client=embedder,
query_text=q, query_text=q,
group_id=group_id, end_user_id=end_user_id,
limit=adjusted_limit, limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"], include=["chunks", "statements", "entities", "summaries"],
) )
@@ -597,7 +597,7 @@ async def run_locomo_eval(
"dialogues": [ "dialogues": [
{ {
"uuid": d.get("uuid", ""), "uuid": d.get("uuid", ""),
"group_id": d.get("group_id", ""), "end_user_id": d.get("end_user_id", ""),
"content": d.get("content", "")[:200] + "..." if len(d.get("content", "")) > 200 else d.get("content", ""), "content": d.get("content", "")[:200] + "..." if len(d.get("content", "")) > 200 else d.get("content", ""),
"score": d.get("score", 0.0) "score": d.get("score", 0.0)
} }
@@ -795,7 +795,7 @@ async def run_locomo_eval(
}, },
"samples": samples, "samples": samples,
"params": { "params": {
"group_id": group_id, "end_user_id": end_user_id,
"search_limit": search_limit, "search_limit": search_limit,
"context_char_budget": context_char_budget, "context_char_budget": context_char_budget,
"search_type": search_type, "search_type": search_type,
@@ -825,7 +825,7 @@ async def run_locomo_eval(
def main(): def main():
parser = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search") parser = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search")
parser.add_argument("--sample_size", type=int, default=1, help="Number of samples to evaluate") parser.add_argument("--sample_size", type=int, default=1, help="Number of samples to evaluate")
parser.add_argument("--group_id", type=str, default=None, help="Group ID for retrieval") parser.add_argument("--end_user_id", type=str, default=None, help="Group ID for retrieval")
parser.add_argument("--search_limit", type=int, default=8, help="Search limit per query") parser.add_argument("--search_limit", type=int, default=8, help="Search limit per query")
parser.add_argument("--context_char_budget", type=int, default=12000, help="Max characters for context") parser.add_argument("--context_char_budget", type=int, default=12000, help="Max characters for context")
parser.add_argument("--llm_temperature", type=float, default=0.0, help="LLM temperature") parser.add_argument("--llm_temperature", type=float, default=0.0, help="LLM temperature")
@@ -841,7 +841,7 @@ def main():
result = asyncio.run(run_locomo_eval( result = asyncio.run(run_locomo_eval(
sample_size=args.sample_size, sample_size=args.sample_size,
group_id=args.group_id, end_user_id=args.end_user_id,
search_limit=args.search_limit, search_limit=args.search_limit,
context_char_budget=args.context_char_budget, context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature, llm_temperature=args.llm_temperature,

View File

@@ -523,11 +523,11 @@ def generate_query_keywords_cn(question: str) -> List[str]:
# 通过别名匹配进行实体关键词检索多token合并 # 通过别名匹配进行实体关键词检索多token合并
async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], group_id: str | None, limit: int) -> List[Dict[str, Any]]: async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], end_user_id: str | None, limit: int) -> List[Dict[str, Any]]:
results: List[Dict[str, Any]] = [] results: List[Dict[str, Any]] = []
try: try:
for tok in tokens: for tok in tokens:
rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, group_id=group_id, limit=limit) rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, end_user_id=end_user_id, limit=limit)
if rows: if rows:
results.extend(rows) results.extend(rows)
except Exception: except Exception:
@@ -547,15 +547,15 @@ async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[st
# 通过对话/陈述中的entity_ids反查实体名称 # 通过对话/陈述中的entity_ids反查实体名称
_FETCH_ENTITIES_BY_IDS = """ _FETCH_ENTITIES_BY_IDS = """
MATCH (e:ExtractedEntity) MATCH (e:ExtractedEntity)
WHERE e.id IN $ids AND ($group_id IS NULL OR e.group_id = $group_id) WHERE e.id IN $ids AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
""" """
async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], group_id: str | None) -> List[Dict[str, Any]]: async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], end_user_id: str | None) -> List[Dict[str, Any]]:
if not ids: if not ids:
return [] return []
try: try:
rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), group_id=group_id) rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), end_user_id=end_user_id)
return rows or [] return rows or []
except Exception: except Exception:
return [] return []
@@ -565,18 +565,18 @@ async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], grou
_TIME_ENTITY_SEARCH = """ _TIME_ENTITY_SEARCH = """
MATCH (e:ExtractedEntity) MATCH (e:ExtractedEntity)
WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern
AND ($group_id IS NULL OR e.group_id = $group_id) AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
LIMIT $limit LIMIT $limit
""" """
async def _search_time_entities(connector: Neo4jConnector, group_id: str | None, limit: int = 5) -> List[Dict[str, Any]]: async def _search_time_entities(connector: Neo4jConnector, end_user_id: str | None, limit: int = 5) -> List[Dict[str, Any]]:
"""专门搜索时间相关的实体""" """专门搜索时间相关的实体"""
try: try:
date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*" date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*"
rows = await connector.execute_query(_TIME_ENTITY_SEARCH, rows = await connector.execute_query(_TIME_ENTITY_SEARCH,
date_pattern=date_pattern, date_pattern=date_pattern,
group_id=group_id, end_user_id=end_user_id,
limit=limit) limit=limit)
return rows or [] return rows or []
except Exception: except Exception:
@@ -623,7 +623,7 @@ def _resolve_relative_times_cn_en(text: str, anchor: datetime) -> str:
async def run_longmemeval_test( async def run_longmemeval_test(
sample_size: int = 3, sample_size: int = 3,
group_id: str = "longmemeval_zh_bak_3", end_user_id: str = "longmemeval_zh_bak_3",
search_limit: int = 8, search_limit: int = 8,
context_char_budget: int = 4000, context_char_budget: int = 4000,
llm_temperature: float = 0.0, llm_temperature: float = 0.0,
@@ -677,13 +677,13 @@ async def run_longmemeval_test(
contexts.extend(selected) contexts.extend(selected)
print(f"📥 摄入 {len(contexts)} 个上下文到数据库") print(f"📥 摄入 {len(contexts)} 个上下文到数据库")
if reset_group_before_ingest and group_id: if reset_group_before_ingest and end_user_id:
try: try:
_tmp_conn = Neo4jConnector() _tmp_conn = Neo4jConnector()
await _tmp_conn.delete_group(group_id) await _tmp_conn.delete_group(end_user_id)
print(f"🧹 已清空组 {group_id} 的历史图数据") print(f"🧹 已清空组 {end_user_id} 的历史图数据")
except Exception as _e: except Exception as _e:
print(f"⚠️ 清空组数据失败(忽略继续): {group_id} - {_e}") print(f"⚠️ 清空组数据失败(忽略继续): {end_user_id} - {_e}")
finally: finally:
try: try:
await _tmp_conn.close() await _tmp_conn.close()
@@ -695,7 +695,7 @@ async def run_longmemeval_test(
else: else:
await _ingest_fn( await _ingest_fn(
contexts, contexts,
group_id, end_user_id,
save_chunk_output=save_chunk_output, save_chunk_output=save_chunk_output,
save_chunk_output_path=save_chunk_output_path, save_chunk_output_path=save_chunk_output_path,
) )
@@ -750,7 +750,7 @@ async def run_longmemeval_test(
connector=connector, connector=connector,
embedder_client=embedder, embedder_client=embedder,
query_text=question, query_text=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], include=["chunks", "statements", "entities", "summaries"],
) )
@@ -795,7 +795,7 @@ async def run_longmemeval_test(
search_results = await search_graph( search_results = await search_graph(
connector=connector, connector=connector,
q=question, q=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
) )
chunks = search_results.get("chunks", []) chunks = search_results.get("chunks", [])
@@ -830,7 +830,7 @@ async def run_longmemeval_test(
connector=connector, connector=connector,
embedder_client=embedder, embedder_client=embedder,
query_text=question, query_text=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], include=["chunks", "statements", "entities", "summaries"],
) )
@@ -848,7 +848,7 @@ async def run_longmemeval_test(
kw_res = await search_graph( kw_res = await search_graph(
connector=connector, connector=connector,
q=question, q=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
) )
if isinstance(kw_res, dict): if isinstance(kw_res, dict):
@@ -859,7 +859,7 @@ async def run_longmemeval_test(
# 时间推理问题的特殊处理 # 时间推理问题的特殊处理
if is_temporal: if is_temporal:
# 专门搜索时间实体 # 专门搜索时间实体
time_entities = await _search_time_entities(connector, group_id, search_limit//2) time_entities = await _search_time_entities(connector, end_user_id, search_limit//2)
if time_entities: if time_entities:
kw_entities.extend(time_entities) kw_entities.extend(time_entities)
# 添加时间相关关键词检索 # 添加时间相关关键词检索
@@ -869,7 +869,7 @@ async def run_longmemeval_test(
time_res = await search_graph( time_res = await search_graph(
connector=connector, connector=connector,
q=tk, q=tk,
group_id=group_id, end_user_id=end_user_id,
limit=2, limit=2,
) )
if isinstance(time_res, dict): if isinstance(time_res, dict):
@@ -880,7 +880,7 @@ async def run_longmemeval_test(
# 中文关键词拆分后做别名匹配 # 中文关键词拆分后做别名匹配
cn_tokens = _extract_cn_tokens(question) cn_tokens = _extract_cn_tokens(question)
alias_entities = await _search_entities_by_aliases(connector, cn_tokens, group_id, search_limit) alias_entities = await _search_entities_by_aliases(connector, cn_tokens, end_user_id, search_limit)
if alias_entities: if alias_entities:
kw_entities.extend(alias_entities) kw_entities.extend(alias_entities)
@@ -894,7 +894,7 @@ async def run_longmemeval_test(
except Exception: except Exception:
pass pass
if ids: if ids:
id_entities = await _fetch_entities_by_ids(connector, ids, group_id) id_entities = await _fetch_entities_by_ids(connector, ids, end_user_id)
if id_entities: if id_entities:
kw_entities.extend(id_entities) kw_entities.extend(id_entities)
@@ -908,7 +908,7 @@ async def run_longmemeval_test(
sub_res = await search_graph( sub_res = await search_graph(
connector=connector, connector=connector,
q=str(kw), q=str(kw),
group_id=group_id, end_user_id=end_user_id,
limit=max(3, search_limit // 2), limit=max(3, search_limit // 2),
) )
if isinstance(sub_res, dict): if isinstance(sub_res, dict):
@@ -927,7 +927,7 @@ async def run_longmemeval_test(
opt_res = await search_graph( opt_res = await search_graph(
connector=connector, connector=connector,
q=str(opt), q=str(opt),
group_id=group_id, end_user_id=end_user_id,
limit=max(3, search_limit // 2), limit=max(3, search_limit // 2),
) )
if isinstance(opt_res, dict): if isinstance(opt_res, dict):

View File

@@ -498,11 +498,11 @@ def smart_context_selection(contexts: List[str], question: str, max_chars: int =
# 通过别名匹配进行实体关键词检索多token合并 # 通过别名匹配进行实体关键词检索多token合并
async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], group_id: str | None, limit: int) -> List[Dict[str, Any]]: async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], end_user_id: str | None, limit: int) -> List[Dict[str, Any]]:
results: List[Dict[str, Any]] = [] results: List[Dict[str, Any]] = []
try: try:
for tok in tokens: for tok in tokens:
rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, group_id=group_id, limit=limit) rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, end_user_id=end_user_id, limit=limit)
if rows: if rows:
results.extend(rows) results.extend(rows)
except Exception: except Exception:
@@ -522,15 +522,15 @@ async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[st
# 通过对话/陈述中的entity_ids反查实体名称 # 通过对话/陈述中的entity_ids反查实体名称
_FETCH_ENTITIES_BY_IDS = """ _FETCH_ENTITIES_BY_IDS = """
MATCH (e:ExtractedEntity) MATCH (e:ExtractedEntity)
WHERE e.id IN $ids AND ($group_id IS NULL OR e.group_id = $group_id) WHERE e.id IN $ids AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
""" """
async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], group_id: str | None) -> List[Dict[str, Any]]: async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], end_user_id: str | None) -> List[Dict[str, Any]]:
if not ids: if not ids:
return [] return []
try: try:
rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), group_id=group_id) rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), end_user_id=end_user_id)
return rows or [] return rows or []
except Exception: except Exception:
return [] return []
@@ -540,18 +540,18 @@ async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], grou
_TIME_ENTITY_SEARCH = """ _TIME_ENTITY_SEARCH = """
MATCH (e:ExtractedEntity) MATCH (e:ExtractedEntity)
WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern
AND ($group_id IS NULL OR e.group_id = $group_id) AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
LIMIT $limit LIMIT $limit
""" """
async def _search_time_entities(connector: Neo4jConnector, group_id: str | None, limit: int = 5) -> List[Dict[str, Any]]: async def _search_time_entities(connector: Neo4jConnector, end_user_id: str | None, limit: int = 5) -> List[Dict[str, Any]]:
"""专门搜索时间相关的实体""" """专门搜索时间相关的实体"""
try: try:
date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*" date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*"
rows = await connector.execute_query(_TIME_ENTITY_SEARCH, rows = await connector.execute_query(_TIME_ENTITY_SEARCH,
date_pattern=date_pattern, date_pattern=date_pattern,
group_id=group_id, end_user_id=end_user_id,
limit=limit) limit=limit)
return rows or [] return rows or []
except Exception: except Exception:
@@ -559,25 +559,25 @@ async def _search_time_entities(connector: Neo4jConnector, group_id: str | None,
# 技术术语专门检索 # 技术术语专门检索
async def _search_tech_terms(connector: Neo4jConnector, question: str, group_id: str | None, limit: int = 3) -> List[Dict[str, Any]]: async def _search_tech_terms(connector: Neo4jConnector, question: str, end_user_id: str | None, limit: int = 3) -> List[Dict[str, Any]]:
"""专门搜索技术术语相关的实体""" """专门搜索技术术语相关的实体"""
tech_entities = [] tech_entities = []
try: try:
# GPS相关 # GPS相关
if any(term in question for term in ["GPS", "导航", "定位系统"]): if any(term in question for term in ["GPS", "导航", "定位系统"]):
gps_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="GPS", group_id=group_id, limit=limit) gps_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="GPS", end_user_id=end_user_id, limit=limit)
if gps_rows: if gps_rows:
tech_entities.extend(gps_rows) tech_entities.extend(gps_rows)
# 活动相关 # 活动相关
if any(term in question for term in ["工作坊", "研讨会", "网络研讨会"]): if any(term in question for term in ["工作坊", "研讨会", "网络研讨会"]):
workshop_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="工作坊", group_id=group_id, limit=limit) workshop_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="工作坊", end_user_id=end_user_id, limit=limit)
if workshop_rows: if workshop_rows:
tech_entities.extend(workshop_rows) tech_entities.extend(workshop_rows)
# 时间顺序相关 # 时间顺序相关
if any(term in question for term in ["", "", "第一个"]): if any(term in question for term in ["", "", "第一个"]):
time_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="第一次", group_id=group_id, limit=limit) time_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="第一次", end_user_id=end_user_id, limit=limit)
if time_rows: if time_rows:
tech_entities.extend(time_rows) tech_entities.extend(time_rows)
@@ -627,7 +627,7 @@ def _resolve_relative_times_cn_en(text: str, anchor: datetime) -> str:
async def run_longmemeval_test( async def run_longmemeval_test(
sample_size: int = 3, sample_size: int = 3,
group_id: str = "longmemeval_zh_bak_2", end_user_id: str = "longmemeval_zh_bak_2",
search_limit: int = 8, search_limit: int = 8,
context_char_budget: int = 4000, context_char_budget: int = 4000,
llm_temperature: float = 0.0, llm_temperature: float = 0.0,
@@ -707,7 +707,7 @@ async def run_longmemeval_test(
connector=connector, connector=connector,
embedder_client=embedder, embedder_client=embedder,
query_text=question, query_text=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
include=["dialogues", "statements", "entities"], include=["dialogues", "statements", "entities"],
) )
@@ -746,7 +746,7 @@ async def run_longmemeval_test(
search_results = await search_graph( search_results = await search_graph(
connector=connector, connector=connector,
q=question, q=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
) )
dialogs = search_results.get("dialogues", []) dialogs = search_results.get("dialogues", [])
@@ -776,7 +776,7 @@ async def run_longmemeval_test(
connector=connector, connector=connector,
embedder_client=embedder, embedder_client=embedder,
query_text=question, query_text=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
include=["dialogues", "statements", "entities"], include=["dialogues", "statements", "entities"],
) )
@@ -792,7 +792,7 @@ async def run_longmemeval_test(
kw_res = await search_graph( kw_res = await search_graph(
connector=connector, connector=connector,
q=question, q=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
) )
if isinstance(kw_res, dict): if isinstance(kw_res, dict):
@@ -801,14 +801,14 @@ async def run_longmemeval_test(
kw_entities = kw_res.get("entities", []) or [] kw_entities = kw_res.get("entities", []) or []
# 技术术语专门检索 # 技术术语专门检索
tech_entities = await _search_tech_terms(connector, question, group_id, search_limit//2) tech_entities = await _search_tech_terms(connector, question, end_user_id, search_limit//2)
if tech_entities: if tech_entities:
kw_entities.extend(tech_entities) kw_entities.extend(tech_entities)
# 时间推理问题的特殊处理 # 时间推理问题的特殊处理
if is_temporal: if is_temporal:
# 专门搜索时间实体 # 专门搜索时间实体
time_entities = await _search_time_entities(connector, group_id, search_limit//2) time_entities = await _search_time_entities(connector, end_user_id, search_limit//2)
if time_entities: if time_entities:
kw_entities.extend(time_entities) kw_entities.extend(time_entities)
# 添加时间相关关键词检索 # 添加时间相关关键词检索
@@ -818,7 +818,7 @@ async def run_longmemeval_test(
time_res = await search_graph( time_res = await search_graph(
connector=connector, connector=connector,
q=tk, q=tk,
group_id=group_id, end_user_id=end_user_id,
limit=2, limit=2,
) )
if isinstance(time_res, dict): if isinstance(time_res, dict):
@@ -829,7 +829,7 @@ async def run_longmemeval_test(
# 中文关键词拆分后做别名匹配 # 中文关键词拆分后做别名匹配
cn_tokens = generate_query_keywords_cn(question) # 使用增强版关键词提取 cn_tokens = generate_query_keywords_cn(question) # 使用增强版关键词提取
alias_entities = await _search_entities_by_aliases(connector, cn_tokens, group_id, search_limit) alias_entities = await _search_entities_by_aliases(connector, cn_tokens, end_user_id, search_limit)
if alias_entities: if alias_entities:
kw_entities.extend(alias_entities) kw_entities.extend(alias_entities)
@@ -843,7 +843,7 @@ async def run_longmemeval_test(
except Exception: except Exception:
pass pass
if ids: if ids:
id_entities = await _fetch_entities_by_ids(connector, ids, group_id) id_entities = await _fetch_entities_by_ids(connector, ids, end_user_id)
if id_entities: if id_entities:
kw_entities.extend(id_entities) kw_entities.extend(id_entities)
@@ -857,7 +857,7 @@ async def run_longmemeval_test(
sub_res = await search_graph( sub_res = await search_graph(
connector=connector, connector=connector,
q=str(kw), q=str(kw),
group_id=group_id, end_user_id=end_user_id,
limit=max(3, search_limit // 2), limit=max(3, search_limit // 2),
) )
if isinstance(sub_res, dict): if isinstance(sub_res, dict):
@@ -876,7 +876,7 @@ async def run_longmemeval_test(
opt_res = await search_graph( opt_res = await search_graph(
connector=connector, connector=connector,
q=str(opt), q=str(opt),
group_id=group_id, end_user_id=group_id,
limit=max(3, search_limit // 2), limit=max(3, search_limit // 2),
) )
if isinstance(opt_res, dict): if isinstance(opt_res, dict):

View File

@@ -27,7 +27,7 @@ from app.core.memory.storage_services.search import run_hybrid_search
from app.core.memory.utils.config.definitions import ( from app.core.memory.utils.config.definitions import (
PROJECT_ROOT, PROJECT_ROOT,
SELECTED_EMBEDDING_ID, SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID, SELECTED_end_user_id,
SELECTED_LLM_ID, SELECTED_LLM_ID,
) )
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -135,8 +135,8 @@ def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any
return merged return merged
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]: async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
group_id = group_id or SELECTED_GROUP_ID end_user_id = end_user_id or SELECTED_end_user_id
# Load data # Load data
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl") data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
if not os.path.exists(data_path): if not os.path.exists(data_path):
@@ -147,7 +147,7 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
# 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入 # 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入
# 说明memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略 # 说明memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略
contexts: List[str] = [build_context_from_dialog(item) for item in items] contexts: List[str] = [build_context_from_dialog(item) for item in items]
await ingest_contexts_via_full_pipeline(contexts, group_id) await ingest_contexts_via_full_pipeline(contexts, end_user_id)
# LLM client (使用异步调用) # LLM client (使用异步调用)
with get_db_context() as db: with get_db_context() as db:
@@ -173,7 +173,7 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
results = await run_hybrid_search( results = await run_hybrid_search(
query_text=question, query_text=question,
search_type=search_type, search_type=search_type,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
include=["dialogues", "statements", "entities"], include=["dialogues", "statements", "entities"],
output_path=None, output_path=None,
@@ -298,7 +298,7 @@ def main():
load_dotenv() load_dotenv()
parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen") parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen")
parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量") parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量")
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id默认取 runtime.json") parser.add_argument("--group-id", type=str, default=None, help="可选 end_user_id默认取 runtime.json")
parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数") parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数")
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算") parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度") parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
@@ -309,7 +309,7 @@ def main():
result = asyncio.run( result = asyncio.run(
run_memsciqa_eval( run_memsciqa_eval(
sample_size=args.sample_size, sample_size=args.sample_size,
group_id=args.group_id, end_user_id=args.end_user_id,
search_limit=args.search_limit, search_limit=args.search_limit,
context_char_budget=args.context_char_budget, context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature, llm_temperature=args.llm_temperature,

View File

@@ -33,7 +33,7 @@ from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.definitions import ( from app.core.memory.utils.config.definitions import (
PROJECT_ROOT, PROJECT_ROOT,
SELECTED_EMBEDDING_ID, SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID, SELECTED_end_user_id,
SELECTED_LLM_ID, SELECTED_LLM_ID,
) )
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -198,7 +198,7 @@ def load_dataset_memsciqa(data_path: str) -> List[Dict[str, Any]]:
async def run_memsciqa_test( async def run_memsciqa_test(
sample_size: int = 3, sample_size: int = 3,
group_id: str | None = None, end_user_id: str | None = None,
search_limit: int = 8, search_limit: int = 8,
context_char_budget: int = 4000, context_char_budget: int = 4000,
llm_temperature: float = 0.0, llm_temperature: float = 0.0,
@@ -216,7 +216,7 @@ async def run_memsciqa_test(
""" """
# 默认使用指定的 memsci 组 ID # 默认使用指定的 memsci 组 ID
group_id = group_id or "group_memsci" end_user_id = end_user_id or "group_memsci"
# 数据路径解析(项目根与当前工作目录兜底) # 数据路径解析(项目根与当前工作目录兜底)
if not data_path: if not data_path:
@@ -282,7 +282,7 @@ async def run_memsciqa_test(
connector=connector, connector=connector,
embedder_client=embedder, embedder_client=embedder,
query_text=question, query_text=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
) )
@@ -291,7 +291,7 @@ async def run_memsciqa_test(
results = await search_graph( results = await search_graph(
connector=connector, connector=connector,
q=question, q=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
) )
@@ -499,7 +499,7 @@ async def run_memsciqa_test(
}, },
"samples": samples, "samples": samples,
"params": { "params": {
"group_id": group_id, "end_user_id": end_user_id,
"search_limit": search_limit, "search_limit": search_limit,
"context_char_budget": context_char_budget, "context_char_budget": context_char_budget,
"llm_temperature": llm_temperature, "llm_temperature": llm_temperature,
@@ -542,7 +542,7 @@ def main():
result = asyncio.run( result = asyncio.run(
run_memsciqa_test( run_memsciqa_test(
sample_size=sample_size, sample_size=sample_size,
group_id=args.group_id, end_user_id=args.end_user_id,
search_limit=args.search_limit, search_limit=args.search_limit,
context_char_budget=args.context_char_budget, context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature, llm_temperature=args.llm_temperature,

View File

@@ -15,7 +15,7 @@ except Exception:
return None return None
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, PROJECT_ROOT from app.core.memory.utils.config.definitions import SELECTED_end_user_id, PROJECT_ROOT
from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval
from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test
@@ -26,7 +26,7 @@ async def run(
dataset: str, dataset: str,
sample_size: int, sample_size: int,
reset_group: bool, reset_group: bool,
group_id: str | None, end_user_id: str | None,
judge_model: str | None = None, judge_model: str | None = None,
search_limit: int | None = None, search_limit: int | None = None,
context_char_budget: int | None = None, context_char_budget: int | None = None,
@@ -37,17 +37,17 @@ async def run(
max_contexts_per_item: int | None = None, max_contexts_per_item: int | None = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
# 恢复原始风格:统一入口做路由,并沿用各数据集既有默认 # 恢复原始风格:统一入口做路由,并沿用各数据集既有默认
group_id = group_id or SELECTED_GROUP_ID end_user_id = end_user_id or SELECTED_end_user_id
if reset_group: if reset_group:
connector = Neo4jConnector() connector = Neo4jConnector()
try: try:
await connector.delete_group(group_id) await connector.delete_group(end_user_id)
finally: finally:
await connector.close() await connector.close()
if dataset == "locomo": if dataset == "locomo":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id} kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
if search_limit is not None: if search_limit is not None:
kwargs["search_limit"] = search_limit kwargs["search_limit"] = search_limit
if context_char_budget is not None: if context_char_budget is not None:
@@ -61,7 +61,7 @@ async def run(
return await run_locomo_eval(**kwargs) return await run_locomo_eval(**kwargs)
if dataset == "memsciqa": if dataset == "memsciqa":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id} kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
if search_limit is not None: if search_limit is not None:
kwargs["search_limit"] = search_limit kwargs["search_limit"] = search_limit
if context_char_budget is not None: if context_char_budget is not None:
@@ -75,7 +75,7 @@ async def run(
return await run_memsciqa_eval(**kwargs) return await run_memsciqa_eval(**kwargs)
if dataset == "longmemeval": if dataset == "longmemeval":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id} kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
if search_limit is not None: if search_limit is not None:
kwargs["search_limit"] = search_limit kwargs["search_limit"] = search_limit
if context_char_budget is not None: if context_char_budget is not None:
@@ -99,8 +99,8 @@ def main():
parser = argparse.ArgumentParser(description="统一评估入口memsciqa / longmemeval / locomo") parser = argparse.ArgumentParser(description="统一评估入口memsciqa / longmemeval / locomo")
parser.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True) parser.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True)
parser.add_argument("--sample-size", type=int, default=1, help="先用一条数据跑通") parser.add_argument("--sample-size", type=int, default=1, help="先用一条数据跑通")
parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 group_id 的图数据") parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 end_user_id 的图数据")
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id默认取 runtime.json") parser.add_argument("--group-id", type=str, default=None, help="可选 end_user_id默认取 runtime.json")
parser.add_argument("--judge-model", type=str, default=None, help="可选longmemeval 判别式评测模型名") parser.add_argument("--judge-model", type=str, default=None, help="可选longmemeval 判别式评测模型名")
parser.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)") parser.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)")
parser.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)") parser.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)")
@@ -117,7 +117,7 @@ def main():
args.dataset, args.dataset,
args.sample_size, args.sample_size,
args.reset_group, args.reset_group,
args.group_id, args.end_user_id,
args.judge_model, args.judge_model,
args.search_limit, args.search_limit,
args.context_char_budget, args.context_char_budget,

View File

@@ -72,7 +72,7 @@ class TemporalSearchParams(BaseModel):
"""Parameters for temporal search queries in the knowledge graph. """Parameters for temporal search queries in the knowledge graph.
Attributes: Attributes:
group_id: Group ID to filter search results (default: 'test') end_user_id: Group ID to filter search results (default: 'test')
apply_id: Application ID to filter search results apply_id: Application ID to filter search results
user_id: User ID to filter search results user_id: User ID to filter search results
start_date: Start date for temporal filtering (format: 'YYYY-MM-DD') start_date: Start date for temporal filtering (format: 'YYYY-MM-DD')
@@ -81,7 +81,7 @@ class TemporalSearchParams(BaseModel):
invalid_date: Date when memory should be invalid (format: 'YYYY-MM-DD') invalid_date: Date when memory should be invalid (format: 'YYYY-MM-DD')
limit: Maximum number of results to return (default: 3) limit: Maximum number of results to return (default: 3)
""" """
group_id: Optional[str] = Field("test", description="The group ID to filter the search.") end_user_id: Optional[str] = Field("test", description="The group ID to filter the search.")
apply_id: Optional[str] = Field(None, description="The apply ID to filter the search.") apply_id: Optional[str] = Field(None, description="The apply ID to filter the search.")
user_id: Optional[str] = Field(None, description="The user ID to filter the search.") user_id: Optional[str] = Field(None, description="The user ID to filter the search.")
start_date: Optional[str] = Field(None, description="The start date for the search.") start_date: Optional[str] = Field(None, description="The start date for the search.")

View File

@@ -103,9 +103,7 @@ class Edge(BaseModel):
id: Unique identifier for the edge id: Unique identifier for the edge
source: ID of the source node source: ID of the source node
target: ID of the target node target: ID of the target node
group_id: Group ID for multi-tenancy end_user_id: End user ID for multi-tenancy
user_id: User ID for user-specific data
apply_id: Application ID for application-specific data
run_id: Unique identifier for the pipeline run that created this edge run_id: Unique identifier for the pipeline run that created this edge
created_at: Timestamp when the edge was created (system perspective) created_at: Timestamp when the edge was created (system perspective)
expired_at: Optional timestamp when the edge expires (system perspective) expired_at: Optional timestamp when the edge expires (system perspective)
@@ -113,9 +111,7 @@ class Edge(BaseModel):
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.") id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.")
source: str = Field(..., description="The ID of the source node.") source: str = Field(..., description="The ID of the source node.")
target: str = Field(..., description="The ID of the target node.") target: str = Field(..., description="The ID of the target node.")
group_id: str = Field(..., description="The group ID of the edge.") end_user_id: str = Field(..., description="The end user ID of the edge.")
user_id: str = Field(..., description="The user ID of the edge.")
apply_id: str = Field(..., description="The apply ID of the edge.")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.") created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.") expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
@@ -185,18 +181,14 @@ class Node(BaseModel):
Attributes: Attributes:
id: Unique identifier for the node id: Unique identifier for the node
name: Name of the node name: Name of the node
group_id: Group ID for multi-tenancy end_user_id: End user ID for multi-tenancy
user_id: User ID for user-specific data
apply_id: Application ID for application-specific data
run_id: Unique identifier for the pipeline run that created this node run_id: Unique identifier for the pipeline run that created this node
created_at: Timestamp when the node was created (system perspective) created_at: Timestamp when the node was created (system perspective)
expired_at: Optional timestamp when the node expires (system perspective) expired_at: Optional timestamp when the node expires (system perspective)
""" """
id: str = Field(..., description="The unique identifier for the node.") id: str = Field(..., description="The unique identifier for the node.")
name: str = Field(..., description="The name of the node.") name: str = Field(..., description="The name of the node.")
group_id: str = Field(..., description="The group ID of the node.") end_user_id: str = Field(..., description="The end user ID of the node.")
user_id: str = Field(..., description="The user ID of the edge.")
apply_id: str = Field(..., description="The apply ID of the edge.")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(..., description="The valid time of the node from system perspective.") created_at: datetime = Field(..., description="The valid time of the node from system perspective.")
expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.") expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.")

View File

@@ -55,7 +55,7 @@ class Statement(BaseModel):
Attributes: Attributes:
id: Unique identifier for the statement id: Unique identifier for the statement
chunk_id: ID of the parent chunk this statement belongs to chunk_id: ID of the parent chunk this statement belongs to
group_id: Optional group ID for multi-tenancy end_user_id: Optional group ID for multi-tenancy
statement: The actual statement text content statement: The actual statement text content
speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses) speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses)
statement_embedding: Optional embedding vector for the statement statement_embedding: Optional embedding vector for the statement
@@ -73,7 +73,7 @@ class Statement(BaseModel):
""" """
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.") id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.")
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.") chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.") end_user_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
statement: str = Field(..., description="The text content of the statement.") statement: str = Field(..., description="The text content of the statement.")
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses") speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.") statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.")
@@ -159,9 +159,7 @@ class DialogData(BaseModel):
context: Full conversation context context: Full conversation context
dialog_embedding: Optional embedding vector for the entire dialog dialog_embedding: Optional embedding vector for the entire dialog
ref_id: Reference ID linking to external dialog system ref_id: Reference ID linking to external dialog system
group_id: Group ID for multi-tenancy end_user_id: End user ID for multi-tenancy
user_id: User ID for user-specific data
apply_id: Application ID for application-specific data
created_at: Timestamp when the dialog was created created_at: Timestamp when the dialog was created
expired_at: Timestamp when the dialog expires (default: far future) expired_at: Timestamp when the dialog expires (default: far future)
metadata: Additional metadata as key-value pairs metadata: Additional metadata as key-value pairs
@@ -175,9 +173,7 @@ class DialogData(BaseModel):
context: ConversationContext = Field(..., description="The full conversation context as a single string.") context: ConversationContext = Field(..., description="The full conversation context as a single string.")
dialog_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the dialog.") dialog_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the dialog.")
ref_id: str = Field(..., description="Refer to external dialog id. This is used to link to the original dialog.") ref_id: str = Field(..., description="Refer to external dialog id. This is used to link to the original dialog.")
group_id: str = Field(default=..., description="Group ID of dialogue data") end_user_id: str = Field(default=..., description="End user ID of dialogue data")
user_id: str = Field(..., description="USER ID of dialogue data")
apply_id: str = Field(..., description="APPLY ID of dialogue data")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.") created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.")
expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.") expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.")
@@ -256,5 +252,5 @@ class DialogData(BaseModel):
""" """
for chunk in self.chunks: for chunk in self.chunks:
for statement in chunk.statements: for statement in chunk.statements:
if statement.group_id is None: if statement.end_user_id is None:
statement.group_id = self.group_id statement.end_user_id = self.end_user_id

View File

@@ -6,6 +6,7 @@ import os
import time import time
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
from uuid import UUID
if TYPE_CHECKING: if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig from app.schemas.memory_config_schema import MemoryConfig
@@ -396,13 +397,13 @@ def rerank_with_activation(
return reranked return reranked
def log_search_query(query_text: str, search_type: str, group_id: str | None, limit: int, include: List[str], log_file: str = None): def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None):
"""Log search query information using the logger. """Log search query information using the logger.
Args: Args:
query_text: The search query text query_text: The search query text
search_type: Type of search (keyword, embedding, hybrid) search_type: Type of search (keyword, embedding, hybrid)
group_id: Group identifier for filtering end_user_id: Group identifier for filtering
limit: Maximum number of results limit: Maximum number of results
include: List of result types to include include: List of result types to include
log_file: Deprecated parameter, kept for backward compatibility log_file: Deprecated parameter, kept for backward compatibility
@@ -413,7 +414,7 @@ def log_search_query(query_text: str, search_type: str, group_id: str | None, li
# Log using the standard logger # Log using the standard logger
logger.info( logger.info(
f"Search query: query='{cleaned_query}', type={search_type}, " f"Search query: query='{cleaned_query}', type={search_type}, "
f"group_id={group_id}, limit={limit}, include={include}" f"end_user_id={end_user_id}, limit={limit}, include={include}"
) )
@@ -672,7 +673,7 @@ def apply_reranker_placeholder(
async def run_hybrid_search( async def run_hybrid_search(
query_text: str, query_text: str,
search_type: str, search_type: str,
group_id: str | None, end_user_id: str | None,
limit: int, limit: int,
include: List[str], include: List[str],
output_path: str | None, output_path: str | None,
@@ -692,6 +693,9 @@ async def run_hybrid_search(
# Start overall timing # Start overall timing
search_start_time = time.time() search_start_time = time.time()
latency_metrics = {} latency_metrics = {}
print(100*'-')
print(memory_config)
print(100 * '-')
logger.info(f"using embedding_id:{memory_config.embedding_model_id}...") logger.info(f"using embedding_id:{memory_config.embedding_model_id}...")
# Clean and normalize the incoming query before use/logging # Clean and normalize the incoming query before use/logging
@@ -715,7 +719,7 @@ async def run_hybrid_search(
} }
# Log the search query # Log the search query
log_search_query(query_text, search_type, group_id, limit, include) log_search_query(query_text, search_type, end_user_id, limit, include)
connector = Neo4jConnector() connector = Neo4jConnector()
results = {} results = {}
@@ -732,7 +736,7 @@ async def run_hybrid_search(
search_graph( search_graph(
connector=connector, connector=connector,
q=query_text, q=query_text,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include include=include
) )
@@ -769,7 +773,7 @@ async def run_hybrid_search(
connector=connector, connector=connector,
embedder_client=embedder, embedder_client=embedder,
query_text=query_text, query_text=query_text,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include, include=include,
) )
@@ -916,9 +920,7 @@ async def run_hybrid_search(
async def search_by_temporal( async def search_by_temporal(
group_id: Optional[str] = "test", end_user_id: Optional[str] = "test",
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
valid_date: Optional[str] = None, valid_date: Optional[str] = None,
@@ -929,7 +931,7 @@ async def search_by_temporal(
Temporal search across Statements. Temporal search across Statements.
- Matches statements created between start_date and end_date - Matches statements created between start_date and end_date
- Optionally filters by group_id - Optionally filters by end_user_id
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
connector = Neo4jConnector() connector = Neo4jConnector()
@@ -939,9 +941,7 @@ async def search_by_temporal(
end_date = normalize_date_safe(end_date) end_date = normalize_date_safe(end_date)
params = TemporalSearchParams.model_validate({ params = TemporalSearchParams.model_validate({
"group_id": group_id, "end_user_id": end_user_id,
"apply_id": apply_id,
"user_id": user_id,
"start_date": start_date, "start_date": start_date,
"end_date": end_date, "end_date": end_date,
"valid_date": valid_date, "valid_date": valid_date,
@@ -950,9 +950,7 @@ async def search_by_temporal(
}) })
statements = await search_graph_by_temporal( statements = await search_graph_by_temporal(
connector=connector, connector=connector,
group_id=params.group_id, end_user_id=params.end_user_id,
apply_id=params.apply_id,
user_id=params.user_id,
start_date=params.start_date, start_date=params.start_date,
end_date=params.end_date, end_date=params.end_date,
valid_date=params.valid_date, valid_date=params.valid_date,
@@ -964,9 +962,7 @@ async def search_by_temporal(
async def search_by_keyword_temporal( async def search_by_keyword_temporal(
query_text: str, query_text: str,
group_id: Optional[str] = "test", end_user_id: Optional[str] = "test",
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
valid_date: Optional[str] = None, valid_date: Optional[str] = None,
@@ -987,9 +983,7 @@ async def search_by_keyword_temporal(
invalid_date = normalize_date_safe(invalid_date) invalid_date = normalize_date_safe(invalid_date)
params = TemporalSearchParams.model_validate({ params = TemporalSearchParams.model_validate({
"group_id": group_id, "end_user_id": end_user_id,
"apply_id": apply_id,
"user_id": user_id,
"start_date": start_date, "start_date": start_date,
"end_date": end_date, "end_date": end_date,
"valid_date": valid_date, "valid_date": valid_date,
@@ -999,9 +993,7 @@ async def search_by_keyword_temporal(
statements = await search_graph_by_keyword_temporal( statements = await search_graph_by_keyword_temporal(
connector=connector, connector=connector,
query_text=query_text, query_text=query_text,
group_id=params.group_id, end_user_id=params.end_user_id,
apply_id=params.apply_id,
user_id=params.user_id,
start_date=params.start_date, start_date=params.start_date,
end_date=params.end_date, end_date=params.end_date,
valid_date=params.valid_date, valid_date=params.valid_date,
@@ -1013,7 +1005,7 @@ async def search_by_keyword_temporal(
async def search_chunk_by_chunk_id( async def search_chunk_by_chunk_id(
chunk_id: str, chunk_id: str,
group_id: Optional[str] = "test", end_user_id: Optional[str] = "test",
limit: int = 1, limit: int = 1,
): ):
""" """
@@ -1023,8 +1015,68 @@ async def search_chunk_by_chunk_id(
chunks = await search_graph_by_chunk_id( chunks = await search_graph_by_chunk_id(
connector=connector, connector=connector,
chunk_id=chunk_id, chunk_id=chunk_id,
group_id=group_id, end_user_id=end_user_id,
limit=limit limit=limit
) )
return {"chunks": chunks} return {"chunks": chunks}
if __name__ == '__main__':
# 测试混合检索功能
from app.schemas.memory_config_schema import MemoryConfig
from app.db import get_db
from app.services.memory_config_service import MemoryConfigService
# 从数据库获取真实配置
db = next(get_db())
try:
config_service = MemoryConfigService(db)
# 使用 config_id=17 获取配置
memory_config = config_service.load_memory_config(config_id=17)
if not memory_config:
print("错误:找不到 config_id=17 的配置")
print("请先在数据库中创建配置,或修改 config_id")
exit(1)
print(f"✓ 成功加载配置: {memory_config.config_name}")
print(f" - Workspace: {memory_config.workspace_name}")
print(f" - LLM Model: {memory_config.llm_model_name}")
print(f" - Embedding Model: {memory_config.embedding_model_name}")
print(f" - Storage Type: {memory_config.storage_type}")
print()
# 修改这里的参数进行测试
test_end_user_id = "021886bc-fab9-4fd5-b607-497b262e0381" # 修改为你的 end_user_id
test_query = "小明擅长什么?" # 修改为你的查询
print(f"开始测试检索...")
print(f" - Query: {test_query}")
print(f" - End User ID: {test_end_user_id}")
print(f" - Search Type: hybrid")
print()
results = asyncio.run(run_hybrid_search(
query_text=test_query,
search_type="hybrid", # 可选: "keyword", "embedding", "hybrid"
end_user_id=test_end_user_id,
limit=10,
include=["statements", "entities", "chunks", "summaries"],
output_path=None,
memory_config=memory_config,
rerank_alpha=0.6,
use_forgetting_rerank=False,
use_llm_rerank=False
))
print("=" * 80)
print("检索结果:")
print("=" * 80)
print(results)
except Exception as e:
print(f"错误: {e}")
import traceback
traceback.print_exc()
finally:
db.close()

View File

@@ -555,8 +555,8 @@ class DataPreprocessor:
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}'))) dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
# 获取group_id如果不存在则生成默认值 # 获取end_user_id如果不存在则生成默认值
group_id = item.get('group_id', f'group_default_{i}') end_user_id = item.get('end_user_id', f'group_default_{i}')
user_id = item.get('user_id', f'user_default_{i}') user_id = item.get('user_id', f'user_default_{i}')
apply_id = item.get('apply_id', f'apply_default_{i}') apply_id = item.get('apply_id', f'apply_default_{i}')
@@ -574,7 +574,7 @@ class DataPreprocessor:
dialog_data = DialogData( dialog_data = DialogData(
context=context, context=context,
ref_id=dialog_id, ref_id=dialog_id,
group_id=group_id, end_user_id=end_user_id,
user_id=user_id, user_id=user_id,
apply_id=apply_id, apply_id=apply_id,
metadata=metadata metadata=metadata
@@ -644,7 +644,7 @@ class DataPreprocessor:
context = ConversationContext(msgs=messages) context = ConversationContext(msgs=messages)
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}'))) dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
group_id = item.get('group_id', f'group_default_{i}') end_user_id = item.get('end_user_id', f'group_default_{i}')
user_id = item.get('user_id', f'user_default_{i}') user_id = item.get('user_id', f'user_default_{i}')
apply_id = item.get('apply_id', f'apply_default_{i}') apply_id = item.get('apply_id', f'apply_default_{i}')
@@ -657,7 +657,7 @@ class DataPreprocessor:
dialog_data = DialogData( dialog_data = DialogData(
context=context, context=context,
ref_id=dialog_id, ref_id=dialog_id,
group_id=group_id, end_user_id=end_user_id,
user_id=user_id, user_id=user_id,
apply_id=apply_id, apply_id=apply_id,
metadata=metadata metadata=metadata

View File

@@ -199,7 +199,7 @@ def accurate_match(
entity_nodes: List[ExtractedEntityNode] entity_nodes: List[ExtractedEntityNode]
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]: ) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
""" """
精确匹配:按 (group_id, name, entity_type) 合并实体并建立重定向与合并记录。 精确匹配:按 (end_user_id, name, entity_type) 合并实体并建立重定向与合并记录。
返回: (deduped_entities, id_redirect, exact_merge_map) 返回: (deduped_entities, id_redirect, exact_merge_map)
""" """
exact_merge_map: Dict[str, Dict] = {} exact_merge_map: Dict[str, Dict] = {}
@@ -210,8 +210,8 @@ def accurate_match(
for ent in entity_nodes: for ent in entity_nodes:
name_norm = (getattr(ent, "name", "") or "").strip() name_norm = (getattr(ent, "name", "") or "").strip()
type_norm = (getattr(ent, "entity_type", "") or "").strip() type_norm = (getattr(ent, "entity_type", "") or "").strip()
key = f"{getattr(ent, 'group_id', None)}|{name_norm}|{type_norm}" key = f"{getattr(ent, 'end_user_id', None)}|{name_norm}|{type_norm}"
# 为避免跨业务组误并,明确以 group_id 为范围边界 # 为避免跨业务组误并,明确以 end_user_id 为范围边界
if key not in canonical_map: if key not in canonical_map:
canonical_map[key] = ent canonical_map[key] = ent
id_redirect[ent.id] = ent.id id_redirect[ent.id] = ent.id
@@ -223,11 +223,11 @@ def accurate_match(
id_redirect[ent.id] = canonical.id id_redirect[ent.id] = canonical.id
# 记录精确匹配的合并项(使用规范化键,避免外层变量误用) # 记录精确匹配的合并项(使用规范化键,避免外层变量误用)
try: try:
k = f"{canonical.group_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}" k = f"{canonical.end_user_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
if k not in exact_merge_map: if k not in exact_merge_map:
exact_merge_map[k] = { exact_merge_map[k] = {
"canonical_id": canonical.id, "canonical_id": canonical.id,
"group_id": canonical.group_id, "end_user_id": canonical.end_user_id,
"name": canonical.name, "name": canonical.name,
"entity_type": canonical.entity_type, "entity_type": canonical.entity_type,
"merged_ids": set(), "merged_ids": set(),
@@ -596,7 +596,7 @@ def fuzzy_match(
b = deduped_entities[j] b = deduped_entities[j]
# 跳过不同业务组的实体 # 跳过不同业务组的实体
if getattr(a, "group_id", None) != getattr(b, "group_id", None): if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
j += 1 j += 1
continue continue
@@ -671,7 +671,7 @@ def fuzzy_match(
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]" merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]" merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
fuzzy_merge_records.append( fuzzy_merge_records.append(
f"{merge_reason} 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type}) | " f"{merge_reason} 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type}) | "
f"s_name={s_name:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, exact_alias={has_exact_match}" f"s_name={s_name:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, exact_alias={has_exact_match}"
) )
except Exception: except Exception:
@@ -779,7 +779,7 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
# 记录 LLM 融合日志 # 记录 LLM 融合日志
try: try:
llm_records.append( llm_records.append(
f"[LLM融合] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})" f"[LLM融合] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})"
) )
# 详细的“同类名称相似”记录改由 LLM 去重模块统一生成以携带 conf/reason # 详细的“同类名称相似”记录改由 LLM 去重模块统一生成以携带 conf/reason
except Exception: except Exception:
@@ -847,7 +847,7 @@ async def LLM_disamb_decision(
id_redirect[k] = a.id id_redirect[k] = a.id
try: try:
disamb_records.append( disamb_records.append(
f"[DISAMB合并应用] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})" f"[DISAMB合并应用] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})"
) )
except Exception: except Exception:
pass pass

View File

@@ -174,7 +174,7 @@ async def _judge_pair(
pass pass
# 3. 构建LLM判断的“上下文信息”规则层计算的所有特征 判断上下文特征有助于实体消歧首先判断的类型关系 # 3. 构建LLM判断的“上下文信息”规则层计算的所有特征 判断上下文特征有助于实体消歧首先判断的类型关系
ctx = { ctx = {
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None), "same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None),
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)), "type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
"type_similarity": _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None)), "type_similarity": _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
"name_text_sim": name_text_sim, "name_text_sim": name_text_sim,
@@ -235,7 +235,7 @@ async def _judge_pair_disamb(
except Exception: except Exception:
pass pass
ctx = { ctx = {
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None), "same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None),
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)), "type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
"name_text_sim": name_text_sim, "name_text_sim": name_text_sim,
"name_embed_sim": name_embed_sim, "name_embed_sim": name_embed_sim,
@@ -317,8 +317,8 @@ async def llm_dedup_entities( # 保留对偶判断作为子流程,是为了
a = entity_nodes[i] a = entity_nodes[i]
for j in range(i + 1, len(entity_nodes)): for j in range(i + 1, len(entity_nodes)):
b = entity_nodes[j] b = entity_nodes[j]
# 规则1必须属于同一组group_id相同不同组的实体不重复 # 规则1必须属于同一组end_user_id相同不同组的实体不重复
if getattr(a, "group_id", None) != getattr(b, "group_id", None): if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
continue continue
# 规则2类型必须兼容调用_simple_type_ok判断 # 规则2类型必须兼容调用_simple_type_ok判断
if not _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)): if not _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)):
@@ -474,7 +474,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
- max_rounds: upper bound for iterative passes (default 3) - max_rounds: upper bound for iterative passes (default 3)
- auto_merge_threshold: decision confidence for auto-merge when no co-occurrence (default 0.90) - auto_merge_threshold: decision confidence for auto-merge when no co-occurrence (default 0.90)
- co_ctx_threshold: lower threshold when co-occurrence is detected (default 0.83) - co_ctx_threshold: lower threshold when co-occurrence is detected (default 0.83)
- shuffle_each_round: whether to shuffle entities within group_id each round to vary block composition - shuffle_each_round: whether to shuffle entities within end_user_id each round to vary block composition
Returns: Returns:
- global_redirect: dict losing_id -> canonical_id accumulated across rounds - global_redirect: dict losing_id -> canonical_id accumulated across rounds
@@ -509,7 +509,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
def _partition_blocks(nodes: List[ExtractedEntityNode]) -> List[List[ExtractedEntityNode]]: def _partition_blocks(nodes: List[ExtractedEntityNode]) -> List[List[ExtractedEntityNode]]:
""" """
group_id 分块,避免跨组实体在同一块,减少无效候选对 end_user_id 分块,避免跨组实体在同一块,减少无效候选对
Args: Args:
nodes: 实体节点列表 nodes: 实体节点列表
@@ -519,7 +519,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
""" """
groups: Dict[str, List[ExtractedEntityNode]] = {} groups: Dict[str, List[ExtractedEntityNode]] = {}
for e in nodes: for e in nodes:
gid = getattr(e, "group_id", None) gid = getattr(e, "end_user_id", None)
groups.setdefault(str(gid), []).append(e) groups.setdefault(str(gid), []).append(e)
blocks: List[List[ExtractedEntityNode]] = [] blocks: List[List[ExtractedEntityNode]] = []
for gid, arr in groups.items(): for gid, arr in groups.items():
@@ -559,7 +559,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
# Collapse nodes to canonical reps before each round to avoid redundant comparisons # Collapse nodes to canonical reps before each round to avoid redundant comparisons
# 步骤1折叠实体合并已确定的重复实体减少后续计算量 # 步骤1折叠实体合并已确定的重复实体减少后续计算量
current_nodes = _collapse_nodes(current_nodes) current_nodes = _collapse_nodes(current_nodes)
# 步骤2分块group_id分块避免跨组处理 # 步骤2分块end_user_id分块避免跨组处理
blocks = _partition_blocks(current_nodes) blocks = _partition_blocks(current_nodes)
if not blocks: # 无块可处理(实体已全部折叠),退出循环 if not blocks: # 无块可处理(实体已全部折叠),退出循环
break break
@@ -645,7 +645,7 @@ async def llm_disambiguate_pairs_iterative(
a = entity_nodes[i] a = entity_nodes[i]
b = entity_nodes[j] b = entity_nodes[j]
# 必须同组 # 必须同组
if getattr(a, "group_id", None) != getattr(b, "group_id", None): if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
continue continue
ta = getattr(a, "entity_type", None) ta = getattr(a, "entity_type", None)
tb = getattr(b, "entity_type", None) tb = getattr(b, "entity_type", None)

View File

@@ -61,7 +61,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
return ExtractedEntityNode( return ExtractedEntityNode(
id=row.get("id"), id=row.get("id"),
name=row.get("name") or "", name=row.get("name") or "",
group_id=row.get("group_id") or "", end_user_id=row.get("end_user_id") or "",
user_id=row.get("user_id") or "", user_id=row.get("user_id") or "",
apply_id=row.get("apply_id") or "", apply_id=row.get("apply_id") or "",
created_at=_parse_dt(row.get("created_at")), created_at=_parse_dt(row.get("created_at")),
@@ -79,7 +79,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑,与 Neo4j 中同组实体联合去重 async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑,与 Neo4j 中同组实体联合去重
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: str, # 用于定位neo4j中同一组的实体确保只在同组内去重 end_user_id: str, # 用于定位neo4j中同一组的实体确保只在同组内去重
entity_nodes: List[ExtractedEntityNode], # 输入的实体节点列表,包含待去重的实体 entity_nodes: List[ExtractedEntityNode], # 输入的实体节点列表,包含待去重的实体
statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系 statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系
entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系 entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系
@@ -88,7 +88,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]: ) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
""" """
第二层去重消歧: 第二层去重消歧:
- 以第一层结果为索引,检索相同 group_id 下的 DB 候选实体 - 以第一层结果为索引,检索相同 end_user_id 下的 DB 候选实体
- 将 DB 候选与当前实体集合联合,按既有精确/模糊/LLM 决策进行融合 - 将 DB 候选与当前实体集合联合,按既有精确/模糊/LLM 决策进行融合
- 返回融合后的实体与重定向后的边(边已指向规范 ID优先 DB ID - 返回融合后的实体与重定向后的边(边已指向规范 ID优先 DB ID
""" """
@@ -102,7 +102,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
] ]
candidates_map = await get_dedup_candidates_for_entities( # 从 Neo4j 中查询候选实体并将结果赋值给candidates_map等待异步操作完成 candidates_map = await get_dedup_candidates_for_entities( # 从 Neo4j 中查询候选实体并将结果赋值给candidates_map等待异步操作完成
connector=connector, group_id=group_id, connector=connector, end_user_id=end_user_id,
entities=incoming_rows, # 传入参数:第一层实体的核心信息(作为查询索引) entities=incoming_rows, # 传入参数:第一层实体的核心信息(作为查询索引)
use_contains_fallback=True # 传入参数:启用 “包含关系” 作为匹配失败的降级策略若精确匹配无结果用包含关系召回候选与src\database\cypher_queries.py的307产生联动 use_contains_fallback=True # 传入参数:启用 “包含关系” 作为匹配失败的降级策略若精确匹配无结果用包含关系召回候选与src\database\cypher_queries.py的307产生联动
) )

View File

@@ -57,11 +57,11 @@ async def dedup_layers_and_merge_and_return(
if pipeline_config is None: if pipeline_config is None:
raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return") raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return")
# 先探测 group_id决定报告写入策略 # 先探测 end_user_id决定报告写入策略
group_id: Optional[str] = None end_user_id: Optional[str] = None
for dd in dialog_data_list: for dd in dialog_data_list:
group_id = getattr(dd, "group_id", None) end_user_id = getattr(dd, "end_user_id", None)
if group_id: if end_user_id:
break break
# 第一层去重消歧 # 第一层去重消歧
@@ -82,11 +82,11 @@ async def dedup_layers_and_merge_and_return(
# 第二层去重消歧:与 Neo4j 中同组实体联合融合 # 第二层去重消歧:与 Neo4j 中同组实体联合融合
try: try:
if group_id: if end_user_id:
if connector: if connector:
fused_entity_nodes, fused_statement_entity_edges, fused_entity_entity_edges = await second_layer_dedup_and_merge_with_neo4j( fused_entity_nodes, fused_statement_entity_edges, fused_entity_entity_edges = await second_layer_dedup_and_merge_with_neo4j(
connector=connector, connector=connector,
group_id=group_id, end_user_id=end_user_id,
entity_nodes=dedup_entity_nodes, entity_nodes=dedup_entity_nodes,
statement_entity_edges=dedup_statement_entity_edges, statement_entity_edges=dedup_statement_entity_edges,
entity_entity_edges=dedup_entity_entity_edges, entity_entity_edges=dedup_entity_entity_edges,
@@ -96,7 +96,7 @@ async def dedup_layers_and_merge_and_return(
else: else:
print("Skip second-layer dedup: missing connector") print("Skip second-layer dedup: missing connector")
else: else:
print("Skip second-layer dedup: missing group_id") print("Skip second-layer dedup: missing end_user_id")
except Exception as e: except Exception as e:
print(f"Second-layer dedup failed: {e}") print(f"Second-layer dedup failed: {e}")

View File

@@ -287,7 +287,7 @@ class ExtractionOrchestrator:
for d_idx, dialog in enumerate(dialog_data_list): for d_idx, dialog in enumerate(dialog_data_list):
dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None
for c_idx, chunk in enumerate(dialog.chunks): for c_idx, chunk in enumerate(dialog.chunks):
all_chunks.append((chunk, dialog.group_id, dialogue_content)) all_chunks.append((chunk, dialog.end_user_id, dialogue_content))
chunk_metadata.append((d_idx, c_idx)) chunk_metadata.append((d_idx, c_idx))
logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取") logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取")
@@ -299,9 +299,9 @@ class ExtractionOrchestrator:
# 全局并行处理所有分块 # 全局并行处理所有分块
async def extract_for_chunk(chunk_data, chunk_index): async def extract_for_chunk(chunk_data, chunk_index):
nonlocal completed_chunks nonlocal completed_chunks
chunk, group_id, dialogue_content = chunk_data chunk, end_user_id, dialogue_content = chunk_data
try: try:
statements = await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content) statements = await self.statement_extractor._extract_statements(chunk, end_user_id, dialogue_content)
# 流式输出:每提取完一个分块的陈述句,立即发送进度 # 流式输出:每提取完一个分块的陈述句,立即发送进度
# 注意:只在试运行模式下发送陈述句详情,正式模式不发送 # 注意:只在试运行模式下发送陈述句详情,正式模式不发送
@@ -992,9 +992,7 @@ class ExtractionOrchestrator:
id=dialog_data.id, id=dialog_data.id,
name=f"Dialog_{dialog_data.id}", # 添加必需的 name 字段 name=f"Dialog_{dialog_data.id}", # 添加必需的 name 字段
ref_id=dialog_data.ref_id, ref_id=dialog_data.ref_id,
group_id=dialog_data.group_id, end_user_id=dialog_data.end_user_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
content=dialog_data.context.content if dialog_data.context else "", content=dialog_data.context.content if dialog_data.context else "",
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None, dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None,
@@ -1012,9 +1010,7 @@ class ExtractionOrchestrator:
id=chunk.id, id=chunk.id,
name=f"Chunk_{chunk.id}", # 添加必需的 name 字段 name=f"Chunk_{chunk.id}", # 添加必需的 name 字段
dialog_id=dialog_data.id, dialog_id=dialog_data.id,
group_id=dialog_data.group_id, end_user_id=dialog_data.end_user_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
content=chunk.content, content=chunk.content,
chunk_embedding=chunk.chunk_embedding, chunk_embedding=chunk.chunk_embedding,
@@ -1035,9 +1031,7 @@ class ExtractionOrchestrator:
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段 stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段 temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
group_id=dialog_data.group_id, end_user_id=dialog_data.end_user_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
statement=statement.statement, statement=statement.statement,
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段 speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
@@ -1060,9 +1054,7 @@ class ExtractionOrchestrator:
statement_chunk_edge = StatementChunkEdge( statement_chunk_edge = StatementChunkEdge(
source=statement.id, source=statement.id,
target=chunk.id, target=chunk.id,
group_id=dialog_data.group_id, end_user_id=dialog_data.end_user_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
) )
@@ -1095,9 +1087,7 @@ class ExtractionOrchestrator:
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
name_embedding=getattr(entity, 'name_embedding', None), name_embedding=getattr(entity, 'name_embedding', None),
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记 is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
group_id=dialog_data.group_id, end_user_id=dialog_data.end_user_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at, expired_at=dialog_data.expired_at,
@@ -1112,9 +1102,7 @@ class ExtractionOrchestrator:
source=statement.id, source=statement.id,
target=entity.id, target=entity.id,
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong',
group_id=dialog_data.group_id, end_user_id=dialog_data.end_user_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
) )
@@ -1134,9 +1122,7 @@ class ExtractionOrchestrator:
relation_type=triplet.predicate, relation_type=triplet.predicate,
statement=statement.statement, statement=statement.statement,
source_statement_id=statement.id, source_statement_id=statement.id,
group_id=dialog_data.group_id, end_user_id=dialog_data.end_user_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at, expired_at=dialog_data.expired_at,
@@ -1763,14 +1749,14 @@ class ExtractionOrchestrator:
async def get_chunked_dialogs( async def get_chunked_dialogs(
chunker_strategy: str = "RecursiveChunker", chunker_strategy: str = "RecursiveChunker",
group_id: str = "group_1", end_user_id: str = "group_1",
indices: Optional[List[int]] = None, indices: Optional[List[int]] = None,
) -> List[DialogData]: ) -> List[DialogData]:
"""从测试数据生成分块对话 """从测试数据生成分块对话
Args: Args:
chunker_strategy: 分块策略(默认: RecursiveChunker chunker_strategy: 分块策略(默认: RecursiveChunker
group_id: 组ID end_user_id: 组ID
indices: 要处理的数据索引列表(可选) indices: 要处理的数据索引列表(可选)
Returns: Returns:
@@ -1834,7 +1820,7 @@ async def get_chunked_dialogs(
dialog_data = DialogData( dialog_data = DialogData(
context=conversation_context, context=conversation_context,
ref_id=data['id'], ref_id=data['id'],
group_id=group_id, end_user_id=end_user_id,
metadata=dialog_metadata, metadata=dialog_metadata,
) )
@@ -1936,7 +1922,7 @@ async def get_chunked_dialogs_from_preprocessed(
async def get_chunked_dialogs_with_preprocessing( async def get_chunked_dialogs_with_preprocessing(
chunker_strategy: str = "RecursiveChunker", chunker_strategy: str = "RecursiveChunker",
group_id: str = "default", end_user_id: str = "default",
user_id: str = "default", user_id: str = "default",
apply_id: str = "default", apply_id: str = "default",
indices: Optional[List[int]] = None, indices: Optional[List[int]] = None,
@@ -1948,7 +1934,7 @@ async def get_chunked_dialogs_with_preprocessing(
Args: Args:
chunker_strategy: 分块策略 chunker_strategy: 分块策略
group_id: 组ID end_user_id: 组ID
user_id: 用户ID user_id: 用户ID
apply_id: 应用ID apply_id: 应用ID
indices: 要处理的数据索引列表 indices: 要处理的数据索引列表
@@ -1976,11 +1962,9 @@ async def get_chunked_dialogs_with_preprocessing(
indices=indices, indices=indices,
) )
# 设置 group_id, user_id, apply_id # 设置 end_user_id
for dd in preprocessed_data: for dd in preprocessed_data:
dd.group_id = group_id dd.end_user_id = end_user_id
dd.user_id = user_id
dd.apply_id = apply_id
# 步骤2: 语义剪枝 # 步骤2: 语义剪枝
try: try:

View File

@@ -193,9 +193,9 @@ async def _process_chunk_summary(
node = MemorySummaryNode( node = MemorySummaryNode(
id=uuid4().hex, id=uuid4().hex,
name=title if title else f"MemorySummaryChunk_{chunk.id}", name=title if title else f"MemorySummaryChunk_{chunk.id}",
group_id=dialog.group_id, end_user_id=dialog.end_user_id,
user_id=dialog.user_id, user_id=dialog.end_user_id,
apply_id=dialog.apply_id, apply_id=dialog.end_user_id,
run_id=dialog.run_id, # 使用 dialog 的 run_id run_id=dialog.run_id, # 使用 dialog 的 run_id
created_at=datetime.now(), created_at=datetime.now(),
expired_at=datetime(9999, 12, 31), expired_at=datetime(9999, 12, 31),

View File

@@ -82,12 +82,12 @@ class StatementExtractor:
logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty") logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty")
return None return None
async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]: async def _extract_statements(self, chunk, end_user_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
"""Process a single chunk and return extracted statements """Process a single chunk and return extracted statements
Args: Args:
chunk: Chunk object to process chunk: Chunk object to process
group_id: Group ID to assign to all statements in this chunk end_user_id: Group ID to assign to all statements in this chunk
dialogue_content: Full dialogue content to provide as context dialogue_content: Full dialogue content to provide as context
Returns: Returns:
@@ -158,7 +158,7 @@ class StatementExtractor:
temporal_info=temporal_type, temporal_info=temporal_type,
relevence_info=relevence_info, relevence_info=relevence_info,
chunk_id=chunk.id, chunk_id=chunk.id,
group_id=group_id, end_user_id=end_user_id,
speaker=chunk_speaker, speaker=chunk_speaker,
) )
@@ -184,10 +184,10 @@ class StatementExtractor:
logger.info(f"Processing {len(chunks_to_process)} chunks for statement extraction") logger.info(f"Processing {len(chunks_to_process)} chunks for statement extraction")
# Process all chunks concurrently, passing the group_id and dialogue content from dialog_data # Process all chunks concurrently, passing the end_user_id and dialogue content from dialog_data
dialogue_content = dialog_data.content if self.config.include_dialogue_context else None dialogue_content = dialog_data.content if self.config.include_dialogue_context else None
results = await asyncio.gather( results = await asyncio.gather(
*[self._extract_statements(chunk, dialog_data.group_id, dialogue_content) for chunk in chunks_to_process], *[self._extract_statements(chunk, dialog_data.end_user_id, dialogue_content) for chunk in chunks_to_process],
return_exceptions=True return_exceptions=True
) )
@@ -225,7 +225,7 @@ class StatementExtractor:
for i, statement in enumerate(statements, 1): for i, statement in enumerate(statements, 1):
f.write(f"Statement {i}:\n") f.write(f"Statement {i}:\n")
f.write(f"Id: {statement.id}\n") f.write(f"Id: {statement.id}\n")
f.write(f"Group Id: {statement.group_id}\n") f.write(f"Group Id: {statement.end_user_id}\n")
f.write(f"Content: {statement.statement}\n") f.write(f"Content: {statement.statement}\n")
f.write(f"Type: {statement.stmt_type.value}\n") f.write(f"Type: {statement.stmt_type.value}\n")
f.write(f"Temporal Info: {statement.temporal_info.value}\n") f.write(f"Temporal Info: {statement.temporal_info.value}\n")
@@ -298,7 +298,7 @@ class StatementExtractor:
dialog_sections.append({ dialog_sections.append({
"dialog_id": dialog.ref_id, "dialog_id": dialog.ref_id,
"group_id": dialog.group_id, "end_user_id": dialog.end_user_id,
"content": dialog.content if getattr(dialog, "content", None) else "", "content": dialog.content if getattr(dialog, "content", None) else "",
"strong": strong_relations, "strong": strong_relations,
"weak": weak_relations, "weak": weak_relations,
@@ -312,7 +312,7 @@ class StatementExtractor:
for idx, section in enumerate(dialog_sections, 1): for idx, section in enumerate(dialog_sections, 1):
f.write(f"Dialog {idx}:\n") f.write(f"Dialog {idx}:\n")
f.write(f"Dialog ID: {section.get('dialog_id', '')}\n") f.write(f"Dialog ID: {section.get('dialog_id', '')}\n")
f.write(f"Group ID: {section.get('group_id', '')}\n") f.write(f"Group ID: {section.get('end_user_id', '')}\n")
f.write("Content:\n") f.write("Content:\n")
f.write(f"{section.get('content', '')}\n") f.write(f"{section.get('content', '')}\n")
f.write("-" * 40 + "\n\n") f.write("-" * 40 + "\n\n")

View File

@@ -132,7 +132,7 @@ class TemporalExtractor:
prompt_logger.info("") prompt_logger.info("")
prompt_logger.info("=== TEMPORAL EXTRACTION RESULTS ===") prompt_logger.info("=== TEMPORAL EXTRACTION RESULTS ===")
prompt_logger.info( prompt_logger.info(
f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}" f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}"
) )
except Exception: except Exception:
pass pass

View File

@@ -116,7 +116,7 @@ class TripletExtractor:
logger.info(f"Processing {len(all_statements)} statements for triplet extraction...") logger.info(f"Processing {len(all_statements)} statements for triplet extraction...")
try: try:
prompt_logger.info( prompt_logger.info(
f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}, statements_to_process={len(all_statements)}" f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}, statements_to_process={len(all_statements)}"
) )
except Exception: except Exception:
pass pass

View File

@@ -75,7 +75,7 @@ class AccessHistoryManager:
self, self,
node_id: str, node_id: str,
node_label: str, node_label: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_time: Optional[datetime] = None current_time: Optional[datetime] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
@@ -91,7 +91,7 @@ class AccessHistoryManager:
Args: Args:
node_id: 节点ID node_id: 节点ID
node_label: 节点标签Statement, ExtractedEntity, MemorySummary node_label: 节点标签Statement, ExtractedEntity, MemorySummary
group_id: 组ID可选用于过滤 end_user_id: 组ID可选用于过滤
current_time: 当前时间(可选,默认使用系统时间) current_time: 当前时间(可选,默认使用系统时间)
Returns: Returns:
@@ -123,7 +123,7 @@ class AccessHistoryManager:
for attempt in range(self.max_retries): for attempt in range(self.max_retries):
try: try:
# 步骤1读取当前节点状态 # 步骤1读取当前节点状态
node_data = await self._fetch_node(node_id, node_label, group_id) node_data = await self._fetch_node(node_id, node_label, end_user_id)
if not node_data: if not node_data:
raise ValueError( raise ValueError(
@@ -142,7 +142,7 @@ class AccessHistoryManager:
node_id=node_id, node_id=node_id,
node_label=node_label, node_label=node_label,
update_data=update_data, update_data=update_data,
group_id=group_id end_user_id=end_user_id
) )
logger.info( logger.info(
@@ -172,7 +172,7 @@ class AccessHistoryManager:
self, self,
node_ids: List[str], node_ids: List[str],
node_label: str, node_label: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_time: Optional[datetime] = None current_time: Optional[datetime] = None
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
@@ -184,7 +184,7 @@ class AccessHistoryManager:
Args: Args:
node_ids: 节点ID列表 node_ids: 节点ID列表
node_label: 节点标签(所有节点必须是同一类型) node_label: 节点标签(所有节点必须是同一类型)
group_id: 组ID可选 end_user_id: 组ID可选
current_time: 当前时间(可选) current_time: 当前时间(可选)
Returns: Returns:
@@ -202,7 +202,7 @@ class AccessHistoryManager:
task = self.record_access( task = self.record_access(
node_id=node_id, node_id=node_id,
node_label=node_label, node_label=node_label,
group_id=group_id, end_user_id=end_user_id,
current_time=current_time current_time=current_time
) )
tasks.append(task) tasks.append(task)
@@ -235,7 +235,7 @@ class AccessHistoryManager:
self, self,
node_id: str, node_id: str,
node_label: str, node_label: str,
group_id: Optional[str] = None end_user_id: Optional[str] = None
) -> Tuple[ConsistencyCheckResult, Optional[str]]: ) -> Tuple[ConsistencyCheckResult, Optional[str]]:
""" """
检查节点数据的一致性 检查节点数据的一致性
@@ -249,14 +249,14 @@ class AccessHistoryManager:
Args: Args:
node_id: 节点ID node_id: 节点ID
node_label: 节点标签 node_label: 节点标签
group_id: 组ID可选 end_user_id: 组ID可选
Returns: Returns:
Tuple[ConsistencyCheckResult, Optional[str]]: Tuple[ConsistencyCheckResult, Optional[str]]:
- 一致性检查结果枚举 - 一致性检查结果枚举
- 错误描述(如果不一致) - 错误描述(如果不一致)
""" """
node_data = await self._fetch_node(node_id, node_label, group_id) node_data = await self._fetch_node(node_id, node_label, end_user_id)
if not node_data: if not node_data:
return ConsistencyCheckResult.CONSISTENT, None return ConsistencyCheckResult.CONSISTENT, None
@@ -305,7 +305,7 @@ class AccessHistoryManager:
async def check_batch_consistency( async def check_batch_consistency(
self, self,
node_label: str, node_label: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 1000 limit: int = 1000
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
@@ -313,7 +313,7 @@ class AccessHistoryManager:
Args: Args:
node_label: 节点标签 node_label: 节点标签
group_id: 组ID可选 end_user_id: 组ID可选
limit: 检查的最大节点数 limit: 检查的最大节点数
Returns: Returns:
@@ -329,16 +329,16 @@ class AccessHistoryManager:
MATCH (n:{node_label}) MATCH (n:{node_label})
WHERE n.access_history IS NOT NULL WHERE n.access_history IS NOT NULL
""" """
if group_id: if end_user_id:
query += " AND n.group_id = $group_id" query += " AND n.end_user_id = $end_user_id"
query += """ query += """
RETURN n.id as id RETURN n.id as id
LIMIT $limit LIMIT $limit
""" """
params = {"limit": limit} params = {"limit": limit}
if group_id: if end_user_id:
params["group_id"] = group_id params["end_user_id"] = end_user_id
results = await self.connector.execute_query(query, **params) results = await self.connector.execute_query(query, **params)
node_ids = [r['id'] for r in results] node_ids = [r['id'] for r in results]
@@ -351,7 +351,7 @@ class AccessHistoryManager:
result, message = await self.check_consistency( result, message = await self.check_consistency(
node_id=node_id, node_id=node_id,
node_label=node_label, node_label=node_label,
group_id=group_id end_user_id=end_user_id
) )
if result == ConsistencyCheckResult.CONSISTENT: if result == ConsistencyCheckResult.CONSISTENT:
@@ -387,7 +387,7 @@ class AccessHistoryManager:
self, self,
node_id: str, node_id: str,
node_label: str, node_label: str,
group_id: Optional[str] = None end_user_id: Optional[str] = None
) -> bool: ) -> bool:
""" """
自动修复节点的数据不一致问题 自动修复节点的数据不一致问题
@@ -401,7 +401,7 @@ class AccessHistoryManager:
Args: Args:
node_id: 节点ID node_id: 节点ID
node_label: 节点标签 node_label: 节点标签
group_id: 组ID可选 end_user_id: 组ID可选
Returns: Returns:
bool: 修复成功返回True否则返回False bool: 修复成功返回True否则返回False
@@ -411,7 +411,7 @@ class AccessHistoryManager:
result, message = await self.check_consistency( result, message = await self.check_consistency(
node_id=node_id, node_id=node_id,
node_label=node_label, node_label=node_label,
group_id=group_id end_user_id=end_user_id
) )
if result == ConsistencyCheckResult.CONSISTENT: if result == ConsistencyCheckResult.CONSISTENT:
@@ -419,7 +419,7 @@ class AccessHistoryManager:
return True return True
# 获取节点数据 # 获取节点数据
node_data = await self._fetch_node(node_id, node_label, group_id) node_data = await self._fetch_node(node_id, node_label, end_user_id)
if not node_data: if not node_data:
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]") logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
return False return False
@@ -457,8 +457,8 @@ class AccessHistoryManager:
query = f""" query = f"""
MATCH (n:{node_label} {{id: $node_id}}) MATCH (n:{node_label} {{id: $node_id}})
""" """
if group_id: if end_user_id:
query += " WHERE n.group_id = $group_id" query += " WHERE n.end_user_id = $end_user_id"
query += """ query += """
SET n += $repair_data SET n += $repair_data
RETURN n RETURN n
@@ -468,8 +468,8 @@ class AccessHistoryManager:
'node_id': node_id, 'node_id': node_id,
'repair_data': repair_data 'repair_data': repair_data
} }
if group_id: if end_user_id:
params['group_id'] = group_id params['end_user_id'] = end_user_id
await self.connector.execute_query(query, **params) await self.connector.execute_query(query, **params)
@@ -491,7 +491,7 @@ class AccessHistoryManager:
self, self,
node_id: str, node_id: str,
node_label: str, node_label: str,
group_id: Optional[str] = None end_user_id: Optional[str] = None
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
""" """
获取节点数据 获取节点数据
@@ -499,7 +499,7 @@ class AccessHistoryManager:
Args: Args:
node_id: 节点ID node_id: 节点ID
node_label: 节点标签 node_label: 节点标签
group_id: 组ID可选 end_user_id: 组ID可选
Returns: Returns:
Optional[Dict[str, Any]]: 节点数据如果不存在返回None Optional[Dict[str, Any]]: 节点数据如果不存在返回None
@@ -507,8 +507,8 @@ class AccessHistoryManager:
query = f""" query = f"""
MATCH (n:{node_label} {{id: $node_id}}) MATCH (n:{node_label} {{id: $node_id}})
""" """
if group_id: if end_user_id:
query += " WHERE n.group_id = $group_id" query += " WHERE n.end_user_id = $end_user_id"
query += """ query += """
RETURN n.id as id, RETURN n.id as id,
n.importance_score as importance_score, n.importance_score as importance_score,
@@ -519,8 +519,8 @@ class AccessHistoryManager:
""" """
params = {'node_id': node_id} params = {'node_id': node_id}
if group_id: if end_user_id:
params['group_id'] = group_id params['end_user_id'] = end_user_id
results = await self.connector.execute_query(query, **params) results = await self.connector.execute_query(query, **params)
@@ -585,7 +585,7 @@ class AccessHistoryManager:
node_id: str, node_id: str,
node_label: str, node_label: str,
update_data: Dict[str, Any], update_data: Dict[str, Any],
group_id: Optional[str] = None end_user_id: Optional[str] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
原子性更新节点(使用乐观锁) 原子性更新节点(使用乐观锁)
@@ -597,7 +597,7 @@ class AccessHistoryManager:
node_id: 节点ID node_id: 节点ID
node_label: 节点标签 node_label: 节点标签
update_data: 更新数据 update_data: 更新数据
group_id: 组ID可选 end_user_id: 组ID可选
Returns: Returns:
Dict[str, Any]: 更新后的节点数据 Dict[str, Any]: 更新后的节点数据
@@ -606,13 +606,13 @@ class AccessHistoryManager:
RuntimeError: 如果更新失败或发生版本冲突 RuntimeError: 如果更新失败或发生版本冲突
""" """
# 定义事务函数 # 定义事务函数
async def update_transaction(tx, node_id, node_label, update_data, group_id): async def update_transaction(tx, node_id, node_label, update_data, end_user_id):
# 步骤1读取当前节点并获取版本号 # 步骤1读取当前节点并获取版本号
read_query = f""" read_query = f"""
MATCH (n:{node_label} {{id: $node_id}}) MATCH (n:{node_label} {{id: $node_id}})
""" """
if group_id: if end_user_id:
read_query += " WHERE n.group_id = $group_id" read_query += " WHERE n.end_user_id = $end_user_id"
read_query += """ read_query += """
RETURN n.id as id, RETURN n.id as id,
n.version as version, n.version as version,
@@ -624,8 +624,8 @@ class AccessHistoryManager:
""" """
read_params = {'node_id': node_id} read_params = {'node_id': node_id}
if group_id: if end_user_id:
read_params['group_id'] = group_id read_params['end_user_id'] = end_user_id
read_result = await tx.run(read_query, **read_params) read_result = await tx.run(read_query, **read_params)
current_node = await read_result.single() current_node = await read_result.single()
@@ -656,8 +656,8 @@ class AccessHistoryManager:
# 构建 WHERE 子句 # 构建 WHERE 子句
where_conditions = [] where_conditions = []
if group_id: if end_user_id:
where_conditions.append("n.group_id = $group_id") where_conditions.append("n.end_user_id = $end_user_id")
# 添加版本检查 # 添加版本检查
if current_version > 0: if current_version > 0:
@@ -695,8 +695,8 @@ class AccessHistoryManager:
'last_access_time': update_data['last_access_time'], 'last_access_time': update_data['last_access_time'],
'access_count': update_data['access_count'] 'access_count': update_data['access_count']
} }
if group_id: if end_user_id:
update_params['group_id'] = group_id update_params['end_user_id'] = end_user_id
update_result = await tx.run(update_query, **update_params) update_result = await tx.run(update_query, **update_params)
updated_node = await update_result.single() updated_node = await update_result.single()
@@ -720,7 +720,7 @@ class AccessHistoryManager:
node_id=node_id, node_id=node_id,
node_label=node_label, node_label=node_label,
update_data=update_data, update_data=update_data,
group_id=group_id end_user_id=end_user_id
) )
return result return result
except Exception as e: except Exception as e:

View File

@@ -66,7 +66,7 @@ class ForgettingScheduler:
async def run_forgetting_cycle( async def run_forgetting_cycle(
self, self,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
max_merge_batch_size: int = 100, max_merge_batch_size: int = 100,
min_days_since_access: int = 30, min_days_since_access: int = 30,
config_id: Optional[int] = None, config_id: Optional[int] = None,
@@ -77,7 +77,7 @@ class ForgettingScheduler:
Args: Args:
group_id: 组 ID可选用于过滤特定组的节点 end_user_id: 组 ID可选用于过滤特定组的节点
max_merge_batch_size: 单次最大融合节点对数(默认 100 max_merge_batch_size: 单次最大融合节点对数(默认 100
min_days_since_access: 最小未访问天数(默认 30 天) min_days_since_access: 最小未访问天数(默认 30 天)
config_id: 配置ID可选用于获取 llm_id config_id: 配置ID可选用于获取 llm_id
@@ -107,19 +107,19 @@ class ForgettingScheduler:
start_time_iso = start_time.isoformat() start_time_iso = start_time.isoformat()
logger.info( logger.info(
f"开始遗忘周期: group_id={group_id}, " f"开始遗忘周期: end_user_id={end_user_id}, "
f"max_batch={max_merge_batch_size}, " f"max_batch={max_merge_batch_size}, "
f"min_days={min_days_since_access}" f"min_days={min_days_since_access}"
) )
try: try:
# 步骤1统计遗忘前的节点数量 # 步骤1统计遗忘前的节点数量
nodes_before = await self._count_knowledge_nodes(group_id) nodes_before = await self._count_knowledge_nodes(end_user_id)
logger.info(f"遗忘前节点总数: {nodes_before}") logger.info(f"遗忘前节点总数: {nodes_before}")
# 步骤2识别可遗忘的节点对 # 步骤2识别可遗忘的节点对
forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes( forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes(
group_id=group_id, end_user_id=end_user_id,
min_days_since_access=min_days_since_access min_days_since_access=min_days_since_access
) )
@@ -213,7 +213,7 @@ class ForgettingScheduler:
'statement_text': pair['statement_text'], 'statement_text': pair['statement_text'],
'statement_activation': pair['statement_activation'], 'statement_activation': pair['statement_activation'],
'statement_importance': pair['statement_importance'], 'statement_importance': pair['statement_importance'],
'group_id': group_id 'end_user_id': end_user_id
} }
entity_node = { entity_node = {
@@ -222,7 +222,7 @@ class ForgettingScheduler:
'entity_type': pair['entity_type'], 'entity_type': pair['entity_type'],
'entity_activation': pair['entity_activation'], 'entity_activation': pair['entity_activation'],
'entity_importance': pair['entity_importance'], 'entity_importance': pair['entity_importance'],
'group_id': group_id 'end_user_id': end_user_id
} }
# 融合节点 # 融合节点
@@ -262,7 +262,7 @@ class ForgettingScheduler:
continue continue
# 步骤6统计遗忘后的节点数量 # 步骤6统计遗忘后的节点数量
nodes_after = await self._count_knowledge_nodes(group_id) nodes_after = await self._count_knowledge_nodes(end_user_id)
logger.info(f"遗忘后节点总数: {nodes_after}") logger.info(f"遗忘后节点总数: {nodes_after}")
# 步骤7生成遗忘报告 # 步骤7生成遗忘报告
@@ -315,7 +315,7 @@ class ForgettingScheduler:
async def _count_knowledge_nodes( async def _count_knowledge_nodes(
self, self,
group_id: Optional[str] = None end_user_id: Optional[str] = None
) -> int: ) -> int:
""" """
统计知识层节点总数 统计知识层节点总数
@@ -323,7 +323,7 @@ class ForgettingScheduler:
统计 Statement、ExtractedEntity 和 MemorySummary 节点的总数。 统计 Statement、ExtractedEntity 和 MemorySummary 节点的总数。
Args: Args:
group_id: 组 ID可选用于过滤特定组的节点 end_user_id: 组 ID可选用于过滤特定组的节点
Returns: Returns:
int: 知识层节点总数 int: 知识层节点总数
@@ -333,16 +333,16 @@ class ForgettingScheduler:
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary) WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
""" """
if group_id: if end_user_id:
query += " AND n.group_id = $group_id" query += " AND n.end_user_id = $end_user_id"
query += """ query += """
RETURN count(n) as total RETURN count(n) as total
""" """
params = {} params = {}
if group_id: if end_user_id:
params['group_id'] = group_id end_user_id['end_user_id'] = end_user_id
results = await self.connector.execute_query(query, **params) results = await self.connector.execute_query(query, **params)

View File

@@ -90,7 +90,7 @@ class ForgettingStrategy:
async def find_forgettable_nodes( async def find_forgettable_nodes(
self, self,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
min_days_since_access: int = 30 min_days_since_access: int = 30
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
@@ -102,7 +102,7 @@ class ForgettingStrategy:
3. Statement 和 Entity 之间存在关系边 3. Statement 和 Entity 之间存在关系边
Args: Args:
group_id: 组 ID可选用于过滤特定组的节点 end_user_id: 组 ID可选用于过滤特定组的节点
min_days_since_access: 最小未访问天数(默认 30 天) min_days_since_access: 最小未访问天数(默认 30 天)
Returns: Returns:
@@ -136,8 +136,8 @@ class ForgettingStrategy:
AND (e.entity_type IS NULL OR e.entity_type <> 'Person') AND (e.entity_type IS NULL OR e.entity_type <> 'Person')
""" """
if group_id: if end_user_id:
query += " AND s.group_id = $group_id AND e.group_id = $group_id" query += " AND s.end_user_id = $end_user_id AND e.end_user_id = $end_user_id"
query += """ query += """
RETURN s.id as statement_id, RETURN s.id as statement_id,
@@ -159,8 +159,8 @@ class ForgettingStrategy:
'threshold': self.forgetting_threshold, 'threshold': self.forgetting_threshold,
'cutoff_time': cutoff_time_iso 'cutoff_time': cutoff_time_iso
} }
if group_id: if end_user_id:
params['group_id'] = group_id params['end_user_id'] = end_user_id
results = await self.connector.execute_query(query, **params) results = await self.connector.execute_query(query, **params)
@@ -247,8 +247,8 @@ class ForgettingStrategy:
entity_activation = entity_node['entity_activation'] entity_activation = entity_node['entity_activation']
entity_importance = entity_node['entity_importance'] entity_importance = entity_node['entity_importance']
# 获取 group_id从 statement 或 entity 节点) # 获取 end_user_id从 statement 或 entity 节点)
group_id = statement_node.get('group_id') or entity_node.get('group_id') end_user_id = statement_node.get('end_user_id') or entity_node.get('end_user_id')
# 生成摘要内容 # 生成摘要内容
summary_text = await self._generate_summary( summary_text = await self._generate_summary(
@@ -325,7 +325,7 @@ class ForgettingStrategy:
last_access_time: $current_time, last_access_time: $current_time,
access_count: 1, access_count: 1,
version: 1, version: 1,
group_id: $group_id, end_user_id: $end_user_id,
created_at: datetime($current_time), created_at: datetime($current_time),
merged_at: datetime($current_time) merged_at: datetime($current_time)
}) })
@@ -423,7 +423,7 @@ class ForgettingStrategy:
'inherited_activation': inherited_activation, 'inherited_activation': inherited_activation,
'inherited_importance': inherited_importance, 'inherited_importance': inherited_importance,
'current_time': current_time_iso, 'current_time': current_time_iso,
'group_id': group_id 'end_user_id': end_user_id
} }
try: try:

View File

@@ -37,7 +37,7 @@ __all__ = [
async def run_hybrid_search( async def run_hybrid_search(
query_text: str, query_text: str,
search_type: str = "hybrid", search_type: str = "hybrid",
group_id: str | None = None, end_user_id: str | None = None,
apply_id: str | None = None, apply_id: str | None = None,
user_id: str | None = None, user_id: str | None = None,
limit: int = 50, limit: int = 50,
@@ -54,7 +54,7 @@ async def run_hybrid_search(
Args: Args:
query_text: 查询文本 query_text: 查询文本
search_type: 搜索类型("hybrid", "keyword", "semantic" search_type: 搜索类型("hybrid", "keyword", "semantic"
group_id: 组ID过滤 end_user_id: 组ID过滤
apply_id: 应用ID过滤 apply_id: 应用ID过滤
user_id: 用户ID过滤 user_id: 用户ID过滤
limit: 每个类别的最大结果数 limit: 每个类别的最大结果数
@@ -104,7 +104,7 @@ async def run_hybrid_search(
# 执行搜索 # 执行搜索
result = await strategy.search( result = await strategy.search(
query_text=query_text, query_text=query_text,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include, include=include,
alpha=alpha, alpha=alpha,

View File

@@ -77,7 +77,7 @@
# async def search( # async def search(
# self, # self,
# query_text: str, # query_text: str,
# group_id: Optional[str] = None, # end_user_id: Optional[str] = None,
# limit: int = 50, # limit: int = 50,
# include: Optional[List[str]] = None, # include: Optional[List[str]] = None,
# **kwargs # **kwargs
@@ -86,7 +86,7 @@
# Args: # Args:
# query_text: 查询文本 # query_text: 查询文本
# group_id: 可选的组ID过滤 # end_user_id: 可选的组ID过滤
# limit: 每个类别的最大结果数 # limit: 每个类别的最大结果数
# include: 要包含的搜索类别列表 # include: 要包含的搜索类别列表
# **kwargs: 其他搜索参数如alpha, use_forgetting_curve # **kwargs: 其他搜索参数如alpha, use_forgetting_curve
@@ -94,7 +94,7 @@
# Returns: # Returns:
# SearchResult: 搜索结果对象 # SearchResult: 搜索结果对象
# """ # """
# logger.info(f"执行混合搜索: query='{query_text}', group_id={group_id}, limit={limit}") # logger.info(f"执行混合搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
# # 从kwargs中获取参数 # # 从kwargs中获取参数
# alpha = kwargs.get("alpha", self.alpha) # alpha = kwargs.get("alpha", self.alpha)
@@ -107,14 +107,14 @@
# # 并行执行关键词搜索和语义搜索 # # 并行执行关键词搜索和语义搜索
# keyword_result = await self.keyword_strategy.search( # keyword_result = await self.keyword_strategy.search(
# query_text=query_text, # query_text=query_text,
# group_id=group_id, # end_user_id=end_user_id,
# limit=limit, # limit=limit,
# include=include_list # include=include_list
# ) # )
# semantic_result = await self.semantic_strategy.search( # semantic_result = await self.semantic_strategy.search(
# query_text=query_text, # query_text=query_text,
# group_id=group_id, # end_user_id=end_user_id,
# limit=limit, # limit=limit,
# include=include_list # include=include_list
# ) # )
@@ -139,7 +139,7 @@
# metadata = self._create_metadata( # metadata = self._create_metadata(
# query_text=query_text, # query_text=query_text,
# search_type="hybrid", # search_type="hybrid",
# group_id=group_id, # end_user_id=end_user_id,
# limit=limit, # limit=limit,
# include=include_list, # include=include_list,
# alpha=alpha, # alpha=alpha,
@@ -165,7 +165,7 @@
# metadata=self._create_metadata( # metadata=self._create_metadata(
# query_text=query_text, # query_text=query_text,
# search_type="hybrid", # search_type="hybrid",
# group_id=group_id, # end_user_id=end_user_id,
# limit=limit, # limit=limit,
# error=str(e) # error=str(e)
# ) # )

View File

@@ -44,7 +44,7 @@ class KeywordSearchStrategy(SearchStrategy):
async def search( async def search(
self, self,
query_text: str, query_text: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
include: Optional[List[str]] = None, include: Optional[List[str]] = None,
**kwargs **kwargs
@@ -53,7 +53,7 @@ class KeywordSearchStrategy(SearchStrategy):
Args: Args:
query_text: 查询文本 query_text: 查询文本
group_id: 可选的组ID过滤 end_user_id: 可选的组ID过滤
limit: 每个类别的最大结果数 limit: 每个类别的最大结果数
include: 要包含的搜索类别列表 include: 要包含的搜索类别列表
**kwargs: 其他搜索参数 **kwargs: 其他搜索参数
@@ -61,7 +61,7 @@ class KeywordSearchStrategy(SearchStrategy):
Returns: Returns:
SearchResult: 搜索结果对象 SearchResult: 搜索结果对象
""" """
logger.info(f"执行关键词搜索: query='{query_text}', group_id={group_id}, limit={limit}") logger.info(f"执行关键词搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
# 获取有效的搜索类别 # 获取有效的搜索类别
include_list = self._get_include_list(include) include_list = self._get_include_list(include)
@@ -75,7 +75,7 @@ class KeywordSearchStrategy(SearchStrategy):
results_dict = await search_graph( results_dict = await search_graph(
connector=self.connector, connector=self.connector,
q=query_text, q=query_text,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include_list include=include_list
) )
@@ -84,7 +84,7 @@ class KeywordSearchStrategy(SearchStrategy):
metadata = self._create_metadata( metadata = self._create_metadata(
query_text=query_text, query_text=query_text,
search_type="keyword", search_type="keyword",
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include_list include=include_list
) )
@@ -115,7 +115,7 @@ class KeywordSearchStrategy(SearchStrategy):
metadata=self._create_metadata( metadata=self._create_metadata(
query_text=query_text, query_text=query_text,
search_type="keyword", search_type="keyword",
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
error=str(e) error=str(e)
) )

View File

@@ -58,7 +58,7 @@ class SearchStrategy(ABC):
async def search( async def search(
self, self,
query_text: str, query_text: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
include: Optional[List[str]] = None, include: Optional[List[str]] = None,
**kwargs **kwargs
@@ -67,7 +67,7 @@ class SearchStrategy(ABC):
Args: Args:
query_text: 查询文本 query_text: 查询文本
group_id: 可选的组ID过滤 end_user_id: 可选的组ID过滤
limit: 每个类别的最大结果数 limit: 每个类别的最大结果数
include: 要包含的搜索类别列表statements, chunks, entities, summaries include: 要包含的搜索类别列表statements, chunks, entities, summaries
**kwargs: 其他搜索参数 **kwargs: 其他搜索参数
@@ -81,7 +81,7 @@ class SearchStrategy(ABC):
self, self,
query_text: str, query_text: str,
search_type: str, search_type: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
**kwargs **kwargs
) -> Dict[str, Any]: ) -> Dict[str, Any]:
@@ -90,7 +90,7 @@ class SearchStrategy(ABC):
Args: Args:
query_text: 查询文本 query_text: 查询文本
search_type: 搜索类型 search_type: 搜索类型
group_id: 组ID end_user_id: 组ID
limit: 结果限制 limit: 结果限制
**kwargs: 其他元数据 **kwargs: 其他元数据
@@ -100,7 +100,7 @@ class SearchStrategy(ABC):
metadata = { metadata = {
"query": query_text, "query": query_text,
"search_type": search_type, "search_type": search_type,
"group_id": group_id, "end_user_id": end_user_id,
"limit": limit, "limit": limit,
"timestamp": datetime.now().isoformat() "timestamp": datetime.now().isoformat()
} }

View File

@@ -85,7 +85,7 @@ class SemanticSearchStrategy(SearchStrategy):
async def search( async def search(
self, self,
query_text: str, query_text: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
include: Optional[List[str]] = None, include: Optional[List[str]] = None,
**kwargs **kwargs
@@ -94,7 +94,7 @@ class SemanticSearchStrategy(SearchStrategy):
Args: Args:
query_text: 查询文本 query_text: 查询文本
group_id: 可选的组ID过滤 end_user_id: 可选的组ID过滤
limit: 每个类别的最大结果数 limit: 每个类别的最大结果数
include: 要包含的搜索类别列表 include: 要包含的搜索类别列表
**kwargs: 其他搜索参数 **kwargs: 其他搜索参数
@@ -102,7 +102,7 @@ class SemanticSearchStrategy(SearchStrategy):
Returns: Returns:
SearchResult: 搜索结果对象 SearchResult: 搜索结果对象
""" """
logger.info(f"执行语义搜索: query='{query_text}', group_id={group_id}, limit={limit}") logger.info(f"执行语义搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
# 获取有效的搜索类别 # 获取有效的搜索类别
include_list = self._get_include_list(include) include_list = self._get_include_list(include)
@@ -119,7 +119,7 @@ class SemanticSearchStrategy(SearchStrategy):
connector=self.connector, connector=self.connector,
embedder_client=self.embedder_client, embedder_client=self.embedder_client,
query_text=query_text, query_text=query_text,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include_list include=include_list
) )
@@ -128,7 +128,7 @@ class SemanticSearchStrategy(SearchStrategy):
metadata = self._create_metadata( metadata = self._create_metadata(
query_text=query_text, query_text=query_text,
search_type="semantic", search_type="semantic",
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include_list include=include_list
) )
@@ -159,7 +159,7 @@ class SemanticSearchStrategy(SearchStrategy):
metadata=self._create_metadata( metadata=self._create_metadata(
query_text=query_text, query_text=query_text,
search_type="semantic", search_type="semantic",
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
error=str(e) error=str(e)
) )

View File

@@ -23,7 +23,7 @@ async def _load_(data: List[Any]) -> List[Dict]:
target_keys = [ target_keys = [
"id", "id",
"statement", "statement",
"group_id", "end_user_id",
"chunk_id", "chunk_id",
"created_at", "created_at",
"expired_at", "expired_at",
@@ -75,7 +75,7 @@ async def get_data(result):
""" """
EXCLUDE_FIELDS = { EXCLUDE_FIELDS = {
"user_id", "user_id",
"group_id", "end_user_id",
"entity_type", "entity_type",
"connect_strength", "connect_strength",
"relationship_type", "relationship_type",

View File

@@ -62,7 +62,7 @@ class ConfigAuditLogger:
self, self,
config_id: str, config_id: str,
user_id: Optional[str] = None, user_id: Optional[str] = None,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
success: bool = True, success: bool = True,
details: Optional[Dict[str, Any]] = None details: Optional[Dict[str, Any]] = None
): ):
@@ -72,14 +72,14 @@ class ConfigAuditLogger:
Args: Args:
config_id: 配置 ID config_id: 配置 ID
user_id: 用户 ID可选 user_id: 用户 ID可选
group_id: 组 ID可选 end_user_id: 组 ID可选
success: 是否成功 success: 是否成功
details: 详细信息(可选) details: 详细信息(可选)
""" """
result = "SUCCESS" if success else "FAILED" result = "SUCCESS" if success else "FAILED"
msg = ( msg = (
f"CONFIG_LOAD config_id={config_id} " f"CONFIG_LOAD config_id={config_id} "
f"user={user_id or 'N/A'} group={group_id or 'N/A'} " f"user={user_id or 'N/A'} group={end_user_id or 'N/A'} "
f"result={result}" f"result={result}"
) )
if details: if details:
@@ -121,7 +121,7 @@ class ConfigAuditLogger:
self, self,
operation: str, operation: str,
config_id: str, config_id: str,
group_id: str, end_user_id: str,
success: bool = True, success: bool = True,
duration: Optional[float] = None, duration: Optional[float] = None,
error: Optional[str] = None, error: Optional[str] = None,
@@ -133,7 +133,7 @@ class ConfigAuditLogger:
Args: Args:
operation: 操作类型WRITE, READ 等) operation: 操作类型WRITE, READ 等)
config_id: 配置 ID config_id: 配置 ID
group_id: 组 ID end_user_id: 组 ID
success: 是否成功 success: 是否成功
duration: 操作耗时(秒) duration: 操作耗时(秒)
error: 错误信息(可选) error: 错误信息(可选)
@@ -142,7 +142,7 @@ class ConfigAuditLogger:
result = "SUCCESS" if success else "FAILED" result = "SUCCESS" if success else "FAILED"
msg = ( msg = (
f"{operation.upper()} config_id={config_id} " f"{operation.upper()} config_id={config_id} "
f"group={group_id} result={result}" f"group={end_user_id} result={result}"
) )
if duration is not None: if duration is not None:
msg += f" duration={duration:.2f}s" msg += f" duration={duration:.2f}s"

View File

@@ -4,7 +4,7 @@ from enum import StrEnum, auto
class Field(StrEnum): class Field(StrEnum):
CONTENT_KEY = "page_content" CONTENT_KEY = "page_content"
METADATA_KEY = "metadata" METADATA_KEY = "metadata"
GROUP_KEY = "group_id" GROUP_KEY = "end_user_id"
VECTOR = auto() VECTOR = auto()
# Sparse Vector aims to support full text search # Sparse Vector aims to support full text search
SPARSE_VECTOR = auto() SPARSE_VECTOR = auto()

View File

@@ -32,7 +32,7 @@ async def add_chunk_statement_edges(chunks: List[Chunk], connector: Neo4jConnect
"id": stable_edge_id, "id": stable_edge_id,
"source": chunk.id, "source": chunk.id,
"target": stmt.id, "target": stmt.id,
"group_id": getattr(stmt, 'group_id', None), "end_user_id": getattr(stmt, 'end_user_id', None),
"user_id":getattr(stmt, 'user_id', None), "user_id":getattr(stmt, 'user_id', None),
"apply_id": getattr(stmt, 'apply_id', None), "apply_id": getattr(stmt, 'apply_id', None),
"run_id": getattr(stmt, 'run_id', None) or getattr(chunk, 'run_id', None), "run_id": getattr(stmt, 'run_id', None) or getattr(chunk, 'run_id', None),
@@ -83,7 +83,7 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
edges.append({ edges.append({
"summary_id": s.id, "summary_id": s.id,
"chunk_id": chunk_id, "chunk_id": chunk_id,
"group_id": s.group_id, "end_user_id": s.end_user_id,
"run_id": s.run_id, "run_id": s.run_id,
"created_at": s.created_at.isoformat() if s.created_at else None, "created_at": s.created_at.isoformat() if s.created_at else None,
"expired_at": s.expired_at.isoformat() if s.expired_at else None, "expired_at": s.expired_at.isoformat() if s.expired_at else None,

View File

@@ -6,10 +6,10 @@ from app.core.memory.models.graph_models import DialogueNode, StatementNode, Chu
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def delete_all_nodes(group_id: str, connector: Neo4jConnector): async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
"""Delete all nodes in the database.""" """Delete all nodes in the database."""
result = await connector.execute_query(f"MATCH (n {{group_id: '{group_id}'}}) DETACH DELETE n") result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n")
print(f"All group_id: {group_id} node and edge deleted successfully") print(f"All end_user_id: {end_user_id} node and edge deleted successfully")
return result return result
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]: async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
@@ -32,9 +32,7 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
for dialogue in dialogues: for dialogue in dialogues:
flattened_dialogues.append({ flattened_dialogues.append({
"id": dialogue.id, "id": dialogue.id,
"group_id": dialogue.group_id, "end_user_id": dialogue.end_user_id,
"user_id": dialogue.user_id,
"apply_id": dialogue.apply_id,
"run_id": dialogue.run_id, "run_id": dialogue.run_id,
"ref_id": dialogue.ref_id, "ref_id": dialogue.ref_id,
"name": dialogue.name, "name": dialogue.name,
@@ -79,9 +77,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
flattened_statement = { flattened_statement = {
"id": statement.id, "id": statement.id,
"name": statement.name, "name": statement.name,
"group_id": statement.group_id, "end_user_id": statement.end_user_id,
"user_id": statement.user_id,
"apply_id": statement.apply_id,
"run_id": statement.run_id, "run_id": statement.run_id,
"chunk_id": statement.chunk_id, "chunk_id": statement.chunk_id,
# "created_at": statement.created_at.isoformat(), # "created_at": statement.created_at.isoformat(),
@@ -154,9 +150,7 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
flattened_chunk = { flattened_chunk = {
"id": chunk.id, "id": chunk.id,
"name": chunk.name, "name": chunk.name,
"group_id": chunk.group_id, "end_user_id": chunk.end_user_id,
"user_id": chunk.user_id,
"apply_id": chunk.apply_id,
"run_id": chunk.run_id, "run_id": chunk.run_id,
"created_at": chunk.created_at.isoformat() if chunk.created_at else None, "created_at": chunk.created_at.isoformat() if chunk.created_at else None,
"expired_at": chunk.expired_at.isoformat() if chunk.expired_at else None, "expired_at": chunk.expired_at.isoformat() if chunk.expired_at else None,
@@ -206,9 +200,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
flattened.append({ flattened.append({
"id": s.id, "id": s.id,
"name": s.name, "name": s.name,
"group_id": s.group_id, "end_user_id": s.end_user_id,
"user_id": s.user_id,
"apply_id": s.apply_id,
"run_id": s.run_id, "run_id": s.run_id,
"created_at": s.created_at.isoformat() if s.created_at else None, "created_at": s.created_at.isoformat() if s.created_at else None,
"expired_at": s.expired_at.isoformat() if s.expired_at else None, "expired_at": s.expired_at.isoformat() if s.expired_at else None,

View File

@@ -152,7 +152,7 @@ class BaseNeo4jRepository(BaseRepository[T]):
Example: Example:
>>> results = await repository.find( >>> results = await repository.find(
... {"group_id": "group_123", "user_id": "user_456"}, ... {"end_user_id": "group_123", "user_id": "user_456"},
... limit=50 ... limit=50
... ) ... )
""" """

View File

@@ -3,9 +3,7 @@ DIALOGUE_NODE_SAVE = """
UNWIND $dialogues AS dialogue UNWIND $dialogues AS dialogue
MERGE (n:Dialogue {id: dialogue.id}) MERGE (n:Dialogue {id: dialogue.id})
SET n.uuid = coalesce(n.uuid, dialogue.id), SET n.uuid = coalesce(n.uuid, dialogue.id),
n.group_id = dialogue.group_id, n.end_user_id = dialogue.end_user_id,
n.user_id = dialogue.user_id,
n.apply_id = dialogue.apply_id,
n.run_id = dialogue.run_id, n.run_id = dialogue.run_id,
n.ref_id = dialogue.ref_id, n.ref_id = dialogue.ref_id,
n.created_at = dialogue.created_at, n.created_at = dialogue.created_at,
@@ -22,9 +20,7 @@ SET s += {
id: statement.id, id: statement.id,
run_id: statement.run_id, run_id: statement.run_id,
chunk_id: statement.chunk_id, chunk_id: statement.chunk_id,
group_id: statement.group_id, end_user_id: statement.end_user_id,
user_id: statement.user_id,
apply_id: statement.apply_id,
stmt_type: statement.stmt_type, stmt_type: statement.stmt_type,
statement: statement.statement, statement: statement.statement,
emotion_intensity: statement.emotion_intensity, emotion_intensity: statement.emotion_intensity,
@@ -54,9 +50,7 @@ MERGE (c:Chunk {id: chunk.id})
SET c += { SET c += {
id: chunk.id, id: chunk.id,
name: chunk.name, name: chunk.name,
group_id: chunk.group_id, end_user_id: chunk.end_user_id,
user_id: chunk.user_id,
apply_id: chunk.apply_id,
run_id: chunk.run_id, run_id: chunk.run_id,
created_at: chunk.created_at, created_at: chunk.created_at,
expired_at: chunk.expired_at, expired_at: chunk.expired_at,
@@ -76,9 +70,7 @@ EXTRACTED_ENTITY_NODE_SAVE = """
UNWIND $entities AS entity UNWIND $entities AS entity
MERGE (e:ExtractedEntity {id: entity.id}) MERGE (e:ExtractedEntity {id: entity.id})
SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity.name ELSE e.name END, SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity.name ELSE e.name END,
e.group_id = CASE WHEN entity.group_id IS NOT NULL AND entity.group_id <> '' THEN entity.group_id ELSE e.group_id END, e.end_user_id = CASE WHEN entity.end_user_id IS NOT NULL AND entity.end_user_id <> '' THEN entity.end_user_id ELSE e.end_user_id END,
e.user_id = CASE WHEN entity.user_id IS NOT NULL AND entity.user_id <> '' THEN entity.user_id ELSE e.user_id END,
e.apply_id = CASE WHEN entity.apply_id IS NOT NULL AND entity.apply_id <> '' THEN entity.apply_id ELSE e.apply_id END,
e.run_id = CASE WHEN entity.run_id IS NOT NULL AND entity.run_id <> '' THEN entity.run_id ELSE e.run_id END, e.run_id = CASE WHEN entity.run_id IS NOT NULL AND entity.run_id <> '' THEN entity.run_id ELSE e.run_id END,
e.created_at = CASE e.created_at = CASE
WHEN entity.created_at IS NOT NULL AND (e.created_at IS NULL OR entity.created_at < e.created_at) WHEN entity.created_at IS NOT NULL AND (e.created_at IS NULL OR entity.created_at < e.created_at)
@@ -134,9 +126,9 @@ RETURN e.id AS uuid
# Add back ENTITY_RELATIONSHIP_SAVE to be used by graph_saver.save_entities_and_relationships # Add back ENTITY_RELATIONSHIP_SAVE to be used by graph_saver.save_entities_and_relationships
ENTITY_RELATIONSHIP_SAVE = """ ENTITY_RELATIONSHIP_SAVE = """
UNWIND $relationships AS rel UNWIND $relationships AS rel
// Match entities by stable id within group, do not constrain by run_id // Match entities by stable id within end_user_id, do not constrain by run_id
MATCH (subject:ExtractedEntity {id: rel.source_id, group_id: rel.group_id}) MATCH (subject:ExtractedEntity {id: rel.source_id, end_user_id: rel.end_user_id})
MATCH (object:ExtractedEntity {id: rel.target_id, group_id: rel.group_id}) MATCH (object:ExtractedEntity {id: rel.target_id, end_user_id: rel.end_user_id})
// Avoid duplicate edges across runs for the same endpoints // Avoid duplicate edges across runs for the same endpoints
MERGE (subject)-[r:EXTRACTED_RELATIONSHIP]->(object) MERGE (subject)-[r:EXTRACTED_RELATIONSHIP]->(object)
SET r.predicate = rel.predicate, SET r.predicate = rel.predicate,
@@ -148,7 +140,7 @@ SET r.predicate = rel.predicate,
r.created_at = rel.created_at, r.created_at = rel.created_at,
r.expired_at = rel.expired_at, r.expired_at = rel.expired_at,
r.run_id = rel.run_id, r.run_id = rel.run_id,
r.group_id = rel.group_id r.end_user_id = rel.end_user_id
RETURN elementId(r) AS uuid RETURN elementId(r) AS uuid
""" """
@@ -160,7 +152,7 @@ UNWIND $weak_entities AS entity
MERGE (e:ExtractedEntity {id: entity.id, run_id: entity.run_id}) MERGE (e:ExtractedEntity {id: entity.id, run_id: entity.run_id})
SET e += { SET e += {
name: entity.name, name: entity.name,
group_id: entity.group_id, end_user_id: entity.end_user_id,
run_id: entity.run_id, run_id: entity.run_id,
description: entity.description, description: entity.description,
chunk_id: entity.chunk_id, chunk_id: entity.chunk_id,
@@ -175,11 +167,11 @@ RETURN e.id AS id
SAVE_STRONG_TRIPLE_ENTITIES = """ SAVE_STRONG_TRIPLE_ENTITIES = """
UNWIND $items AS item UNWIND $items AS item
MERGE (s:ExtractedEntity {id: item.source_id, run_id: item.run_id}) MERGE (s:ExtractedEntity {id: item.source_id, run_id: item.run_id})
SET s += {name: item.subject, group_id: item.group_id, run_id: item.run_id} SET s += {name: item.subject, end_user_id: item.end_user_id, run_id: item.run_id}
// Independent strong flag // Independent strong flag
SET s.is_strong = true SET s.is_strong = true
MERGE (o:ExtractedEntity {id: item.target_id, run_id: item.run_id}) MERGE (o:ExtractedEntity {id: item.target_id, run_id: item.run_id})
SET o += {name: item.object, group_id: item.group_id, run_id: item.run_id} SET o += {name: item.object, end_user_id: item.end_user_id, run_id: item.run_id}
// Independent strong flag // Independent strong flag
SET o.is_strong = true SET o.is_strong = true
""" """
@@ -194,7 +186,7 @@ DIALOGUE_STATEMENT_EDGE_SAVE = """
// 仅按端点去重,关系属性可更新 // 仅按端点去重,关系属性可更新
MERGE (dialogue)-[e:MENTIONS]->(statement) MERGE (dialogue)-[e:MENTIONS]->(statement)
SET e.uuid = edge.id, SET e.uuid = edge.id,
e.group_id = edge.group_id, e.end_user_id = edge.end_user_id,
e.created_at = edge.created_at, e.created_at = edge.created_at,
e.expired_at = edge.expired_at e.expired_at = edge.expired_at
RETURN e.uuid AS uuid RETURN e.uuid AS uuid
@@ -208,7 +200,7 @@ CHUNK_STATEMENT_EDGE_SAVE = """
MATCH (statement:Statement {id: edge.source, run_id: edge.run_id}) MATCH (statement:Statement {id: edge.source, run_id: edge.run_id})
MATCH (chunk:Chunk {id: edge.target, run_id: edge.run_id}) MATCH (chunk:Chunk {id: edge.target, run_id: edge.run_id})
MERGE (chunk)-[e:CONTAINS {id: edge.id}]->(statement) MERGE (chunk)-[e:CONTAINS {id: edge.id}]->(statement)
SET e.group_id = edge.group_id, SET e.end_user_id = edge.end_user_id,
e.run_id = edge.run_id, e.run_id = edge.run_id,
e.created_at = edge.created_at, e.created_at = edge.created_at,
e.expired_at = edge.expired_at e.expired_at = edge.expired_at
@@ -218,13 +210,12 @@ CHUNK_STATEMENT_EDGE_SAVE = """
STATEMENT_ENTITY_EDGE_SAVE = """ STATEMENT_ENTITY_EDGE_SAVE = """
UNWIND $relationships AS rel UNWIND $relationships AS rel
// Statement nodes are per-run; keep run_id constraint on statements // Statement nodes are per-run; keep run_id constraint on statements
// Statement nodes are per-run; keep run_id constraint on statements
MATCH (statement:Statement {id: rel.source, run_id: rel.run_id}) MATCH (statement:Statement {id: rel.source, run_id: rel.run_id})
// Entities are shared across runs within a group; do not constrain by run_id // Entities are shared across runs within end_user_id; do not constrain by run_id
MATCH (entity:ExtractedEntity {id: rel.target, group_id: rel.group_id}) MATCH (entity:ExtractedEntity {id: rel.target, end_user_id: rel.end_user_id})
// Avoid duplicate edges across runs for same endpoints // Avoid duplicate edges across runs for same endpoints
MERGE (statement)-[r:REFERENCES_ENTITY]->(entity) MERGE (statement)-[r:REFERENCES_ENTITY]->(entity)
SET r.group_id = rel.group_id, SET r.end_user_id = rel.end_user_id,
r.run_id = rel.run_id, r.run_id = rel.run_id,
r.created_at = rel.created_at, r.created_at = rel.created_at,
r.expired_at = rel.expired_at, r.expired_at = rel.expired_at,
@@ -236,10 +227,10 @@ ENTITY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('entity_embedding_index', $limit * 100, $embedding) CALL db.index.vector.queryNodes('entity_embedding_index', $limit * 100, $embedding)
YIELD node AS e, score YIELD node AS e, score
WHERE e.name_embedding IS NOT NULL WHERE e.name_embedding IS NOT NULL
AND ($group_id IS NULL OR e.group_id = $group_id) AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id, RETURN e.id AS id,
e.name AS name, e.name AS name,
e.group_id AS group_id, e.end_user_id AS end_user_id,
e.entity_type AS entity_type, e.entity_type AS entity_type,
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value, COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
COALESCE(e.importance_score, 0.5) AS importance_score, COALESCE(e.importance_score, 0.5) AS importance_score,
@@ -254,10 +245,10 @@ STATEMENT_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('statement_embedding_index', $limit * 100, $embedding) CALL db.index.vector.queryNodes('statement_embedding_index', $limit * 100, $embedding)
YIELD node AS s, score YIELD node AS s, score
WHERE s.statement_embedding IS NOT NULL WHERE s.statement_embedding IS NOT NULL
AND ($group_id IS NULL OR s.group_id = $group_id) AND ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
RETURN s.id AS id, RETURN s.id AS id,
s.statement AS statement, s.statement AS statement,
s.group_id AS group_id, s.end_user_id AS end_user_id,
s.chunk_id AS chunk_id, s.chunk_id AS chunk_id,
s.created_at AS created_at, s.created_at AS created_at,
s.expired_at AS expired_at, s.expired_at AS expired_at,
@@ -277,9 +268,9 @@ CHUNK_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('chunk_embedding_index', $limit * 100, $embedding) CALL db.index.vector.queryNodes('chunk_embedding_index', $limit * 100, $embedding)
YIELD node AS c, score YIELD node AS c, score
WHERE c.chunk_embedding IS NOT NULL WHERE c.chunk_embedding IS NOT NULL
AND ($group_id IS NULL OR c.group_id = $group_id) AND ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
RETURN c.id AS chunk_id, RETURN c.id AS chunk_id,
c.group_id AS group_id, c.end_user_id AS end_user_id,
c.content AS content, c.content AS content,
c.dialog_id AS dialog_id, c.dialog_id AS dialog_id,
COALESCE(c.activation_value, 0.5) AS activation_value, COALESCE(c.activation_value, 0.5) AS activation_value,
@@ -292,12 +283,12 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_KEYWORD = """ SEARCH_STATEMENTS_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
WHERE ($group_id IS NULL OR s.group_id = $group_id) WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity) OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN s.id AS id, RETURN s.id AS id,
s.statement AS statement, s.statement AS statement,
s.group_id AS group_id, s.end_user_id AS end_user_id,
s.chunk_id AS chunk_id, s.chunk_id AS chunk_id,
s.created_at AS created_at, s.created_at AS created_at,
s.expired_at AS expired_at, s.expired_at AS expired_at,
@@ -316,15 +307,13 @@ LIMIT $limit
# 查询实体名称包含指定字符串的实体 # 查询实体名称包含指定字符串的实体
SEARCH_ENTITIES_BY_NAME = """ SEARCH_ENTITIES_BY_NAME = """
CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score
WHERE ($group_id IS NULL OR e.group_id = $group_id) WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
RETURN e.id AS id, RETURN e.id AS id,
e.name AS name, e.name AS name,
e.group_id AS group_id, e.end_user_id AS end_user_id,
e.entity_type AS entity_type, e.entity_type AS entity_type,
e.apply_id AS apply_id,
e.user_id AS user_id,
e.created_at AS created_at, e.created_at AS created_at,
e.expired_at AS expired_at, e.expired_at AS expired_at,
e.entity_idx AS entity_idx, e.entity_idx AS entity_idx,
@@ -347,11 +336,11 @@ LIMIT $limit
SEARCH_CHUNKS_BY_CONTENT = """ SEARCH_CHUNKS_BY_CONTENT = """
CALL db.index.fulltext.queryNodes("chunksFulltext", $q) YIELD node AS c, score CALL db.index.fulltext.queryNodes("chunksFulltext", $q) YIELD node AS c, score
WHERE ($group_id IS NULL OR c.group_id = $group_id) WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement) OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity) OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN c.id AS chunk_id, RETURN c.id AS chunk_id,
c.group_id AS group_id, c.end_user_id AS end_user_id,
c.content AS content, c.content AS content,
c.dialog_id AS dialog_id, c.dialog_id AS dialog_id,
c.sequence_number AS sequence_number, c.sequence_number AS sequence_number,
@@ -413,10 +402,10 @@ LIMIT $limit
SEARCH_DIALOGUE_BY_DIALOG_ID = """ SEARCH_DIALOGUE_BY_DIALOG_ID = """
MATCH (d:Dialogue) MATCH (d:Dialogue)
WHERE ($group_id IS NULL OR d.group_id = $group_id) WHERE ($end_user_id IS NULL OR d.end_user_id = $end_user_id)
AND d.id = $dialog_id AND d.id = $dialog_id
RETURN d.id AS dialog_id, RETURN d.id AS dialog_id,
d.group_id AS group_id, d.end_user_id AS end_user_id,
d.content AS content, d.content AS content,
d.created_at AS created_at, d.created_at AS created_at,
d.expired_at AS expired_at d.expired_at AS expired_at
@@ -426,10 +415,10 @@ LIMIT $limit
SEARCH_CHUNK_BY_CHUNK_ID = """ SEARCH_CHUNK_BY_CHUNK_ID = """
MATCH (c:Chunk) MATCH (c:Chunk)
WHERE ($group_id IS NULL OR c.group_id = $group_id) WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
AND c.id = $chunk_id AND c.id = $chunk_id
RETURN c.id AS chunk_id, RETURN c.id AS chunk_id,
c.group_id AS group_id, c.end_user_id AS end_user_id,
c.content AS content, c.content AS content,
c.dialog_id AS dialog_id, c.dialog_id AS dialog_id,
c.created_at AS created_at, c.created_at AS created_at,
@@ -441,18 +430,14 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_TEMPORAL = """ SEARCH_STATEMENTS_BY_TEMPORAL = """
MATCH (s:Statement) MATCH (s:Statement)
WHERE ($group_id IS NULL OR s.group_id = $group_id) WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
AND ($apply_id IS NULL OR s.apply_id = $apply_id)
AND ($user_id IS NULL OR s.user_id = $user_id)
AND ((($start_date IS NULL OR datetime(s.created_at) >= datetime($start_date)) AND ((($start_date IS NULL OR datetime(s.created_at) >= datetime($start_date))
AND ($end_date IS NULL OR datetime(s.created_at) <= datetime($end_date))) AND ($end_date IS NULL OR datetime(s.created_at) <= datetime($end_date)))
OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date))) OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
AND ($invalid_date IS NULL OR (s.invalid_at IS NOT NULL AND datetime(s.invalid_at) <= datetime($invalid_date))))) AND ($invalid_date IS NULL OR (s.invalid_at IS NOT NULL AND datetime(s.invalid_at) <= datetime($invalid_date)))))
RETURN s.id AS id, RETURN s.id AS id,
s.statement AS statement, s.statement AS statement,
s.group_id AS group_id, s.end_user_id AS end_user_id,
s.apply_id AS apply_id,
s.user_id AS user_id,
s.chunk_id AS chunk_id, s.chunk_id AS chunk_id,
s.created_at AS created_at, s.created_at AS created_at,
s.valid_at AS valid_at, s.valid_at AS valid_at,
@@ -468,9 +453,7 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL = """ SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL = """
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
WHERE ($group_id IS NULL OR s.group_id = $group_id) WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
AND ($apply_id IS NULL OR s.apply_id = $apply_id)
AND ($user_id IS NULL OR s.user_id = $user_id)
AND ((($start_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) >= datetime($start_date))) AND ((($start_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) >= datetime($start_date)))
AND ($end_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) <= datetime($end_date)))) AND ($end_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) <= datetime($end_date))))
OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date))) OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
@@ -479,9 +462,7 @@ OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity) OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN s.id AS id, RETURN s.id AS id,
s.statement AS statement, s.statement AS statement,
s.group_id AS group_id, s.end_user_id AS end_user_id,
s.apply_id AS apply_id,
s.user_id AS user_id,
s.chunk_id AS chunk_id, s.chunk_id AS chunk_id,
s.created_at AS created_at, s.created_at AS created_at,
s.valid_at AS valid_at, s.valid_at AS valid_at,
@@ -499,15 +480,11 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_CREATED_AT = """ SEARCH_STATEMENTS_BY_CREATED_AT = """
MATCH (n:Statement) MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id) WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) = date($created_at)) AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) = date($created_at))
RETURN n.id AS id, RETURN n.id AS id,
n.statement AS statement, n.statement AS statement,
n.group_id AS group_id, n.end_user_id AS end_user_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id, n.chunk_id AS chunk_id,
n.created_at AS created_at, n.created_at AS created_at,
n.valid_at AS valid_at, n.valid_at AS valid_at,
@@ -519,15 +496,11 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_VALID_AT = """ SEARCH_STATEMENTS_BY_VALID_AT = """
MATCH (n:Statement) MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id) WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) = date($valid_at)) AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) = date($valid_at))
RETURN n.id AS id, RETURN n.id AS id,
n.statement AS statement, n.statement AS statement,
n.group_id AS group_id, n.end_user_id AS end_user_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id, n.chunk_id AS chunk_id,
n.created_at AS created_at, n.created_at AS created_at,
n.valid_at AS valid_at, n.valid_at AS valid_at,
@@ -539,15 +512,11 @@ LIMIT $limit
SEARCH_STATEMENTS_G_CREATED_AT = """ SEARCH_STATEMENTS_G_CREATED_AT = """
MATCH (n:Statement) MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id) WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) = date($created_at)) AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) = date($created_at))
RETURN n.id AS id, RETURN n.id AS id,
n.statement AS statement, n.statement AS statement,
n.group_id AS group_id, n.end_user_id AS end_user_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id, n.chunk_id AS chunk_id,
n.created_at AS created_at, n.created_at AS created_at,
n.valid_at AS valid_at, n.valid_at AS valid_at,
@@ -559,15 +528,11 @@ LIMIT $limit
SEARCH_STATEMENTS_L_CREATED_AT = """ SEARCH_STATEMENTS_L_CREATED_AT = """
MATCH (n:Statement) MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id) WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) < date($created_at)) AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) < date($created_at))
RETURN n.id AS id, RETURN n.id AS id,
n.statement AS statement, n.statement AS statement,
n.group_id AS group_id, n.end_user_id AS end_user_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id, n.chunk_id AS chunk_id,
n.created_at AS created_at, n.created_at AS created_at,
n.valid_at AS valid_at, n.valid_at AS valid_at,
@@ -579,15 +544,11 @@ LIMIT $limit
SEARCH_STATEMENTS_G_VALID_AT = """ SEARCH_STATEMENTS_G_VALID_AT = """
MATCH (n:Statement) MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id) WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) > date($valid_at)) AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) > date($valid_at))
RETURN n.id AS id, RETURN n.id AS id,
n.statement AS statement, n.statement AS statement,
n.group_id AS group_id, n.end_user_id AS end_user_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id, n.chunk_id AS chunk_id,
n.created_at AS created_at, n.created_at AS created_at,
n.valid_at AS valid_at, n.valid_at AS valid_at,
@@ -599,15 +560,11 @@ LIMIT $limit
SEARCH_STATEMENTS_L_VALID_AT = """ SEARCH_STATEMENTS_L_VALID_AT = """
MATCH (n:Statement) MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id) WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) < date($valid_at)) AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) < date($valid_at))
RETURN n.id AS id, RETURN n.id AS id,
n.statement AS statement, n.statement AS statement,
n.group_id AS group_id, n.end_user_id AS end_user_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id, n.chunk_id AS chunk_id,
n.created_at AS created_at, n.created_at AS created_at,
n.valid_at AS valid_at, n.valid_at AS valid_at,
@@ -665,18 +622,18 @@ LIMIT $limit
# 根据id修改句子的invalid_at的值 # 根据id修改句子的invalid_at的值
UPDATE_STATEMENT_INVALID_AT = """ UPDATE_STATEMENT_INVALID_AT = """
MATCH (n:Statement {group_id: $group_id, id: $id}) MATCH (n:Statement {end_user_id: $end_user_id, id: $id})
SET n.invalid_at = $new_invalid_at SET n.invalid_at = $new_invalid_at
""" """
# MemorySummary keyword search using fulltext index # MemorySummary keyword search using fulltext index
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """ SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("summariesFulltext", $q) YIELD node AS m, score CALL db.index.fulltext.queryNodes("summariesFulltext", $q) YIELD node AS m, score
WHERE ($group_id IS NULL OR m.group_id = $group_id) WHERE ($end_user_id IS NULL OR m.end_user_id = $end_user_id)
OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement) OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement)
RETURN m.id AS id, RETURN m.id AS id,
m.name AS name, m.name AS name,
m.group_id AS group_id, m.end_user_id AS end_user_id,
m.dialog_id AS dialog_id, m.dialog_id AS dialog_id,
m.chunk_ids AS chunk_ids, m.chunk_ids AS chunk_ids,
m.content AS content, m.content AS content,
@@ -695,10 +652,10 @@ MEMORY_SUMMARY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('summary_embedding_index', $limit * 100, $embedding) CALL db.index.vector.queryNodes('summary_embedding_index', $limit * 100, $embedding)
YIELD node AS m, score YIELD node AS m, score
WHERE m.summary_embedding IS NOT NULL WHERE m.summary_embedding IS NOT NULL
AND ($group_id IS NULL OR m.group_id = $group_id) AND ($end_user_id IS NULL OR m.end_user_id = $end_user_id)
RETURN m.id AS id, RETURN m.id AS id,
m.name AS name, m.name AS name,
m.group_id AS group_id, m.end_user_id AS end_user_id,
m.dialog_id AS dialog_id, m.dialog_id AS dialog_id,
m.chunk_ids AS chunk_ids, m.chunk_ids AS chunk_ids,
m.content AS content, m.content AS content,
@@ -718,9 +675,7 @@ MERGE (m:MemorySummary {id: summary.id})
SET m += { SET m += {
id: summary.id, id: summary.id,
name: summary.name, name: summary.name,
group_id: summary.group_id, end_user_id: summary.end_user_id,
user_id: summary.user_id,
apply_id: summary.apply_id,
run_id: summary.run_id, run_id: summary.run_id,
created_at: summary.created_at, created_at: summary.created_at,
expired_at: summary.expired_at, expired_at: summary.expired_at,
@@ -814,7 +769,7 @@ RETURN count(losing) as deleted
neo4j_statement_part = ''' neo4j_statement_part = '''
MATCH (n:Statement) MATCH (n:Statement)
WHERE n.group_id = "{}" WHERE n.end_user_id = "{}"
AND datetime(n.created_at) >= datetime() - duration('P3D') AND datetime(n.created_at) >= datetime() - duration('P3D')
RETURN RETURN
n.statement as statement_name, n.statement as statement_name,
@@ -824,7 +779,7 @@ RETURN
''' '''
neo4j_statement_all = ''' neo4j_statement_all = '''
MATCH (n:Statement) MATCH (n:Statement)
WHERE n.group_id = "{}" WHERE n.end_user_id = "{}"
RETURN RETURN
n.statement as statement_name, n.statement as statement_name,
n.id as statement_id n.id as statement_id
@@ -832,7 +787,7 @@ RETURN
''' '''
neo4j_query_part = """ neo4j_query_part = """
MATCH (n)-[r]-(m:ExtractedEntity) MATCH (n)-[r]-(m:ExtractedEntity)
WHERE n.group_id = "{}" WHERE n.end_user_id = "{}"
AND datetime(n.created_at) >= datetime() - duration('P3D') AND datetime(n.created_at) >= datetime() - duration('P3D')
WITH DISTINCT m WITH DISTINCT m
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity) OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
@@ -853,7 +808,7 @@ neo4j_query_part = """
""" """
neo4j_query_all = """ neo4j_query_all = """
MATCH (n)-[r]-(m:ExtractedEntity) MATCH (n)-[r]-(m:ExtractedEntity)
WHERE n.group_id = "{}" WHERE n.end_user_id = "{}"
WITH DISTINCT m WITH DISTINCT m
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity) OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
RETURN RETURN
@@ -1027,14 +982,14 @@ RETURN DISTINCT
Memory_Space_User=""" Memory_Space_User="""
MATCH (n)-[r]->(m) MATCH (n)-[r]->(m)
WHERE n.group_id = $group_id AND m.name="用户" WHERE n.end_user_id = $end_user_id AND m.name="用户"
return DISTINCT elementId(m) as id return DISTINCT elementId(m) as id
""" """
Memory_Space_Entity=""" Memory_Space_Entity="""
MATCH (n)-[]-(m) MATCH (n)-[]-(m)
WHERE elementId(m) = $id AND m.entity_type = "Person" WHERE elementId(m) = $id AND m.entity_type = "Person"
RETURN RETURN
DISTINCT m.name as name,m.group_id as group_id DISTINCT m.name as name,m.end_user_id as end_user_id
""" """
Memory_Space_Associative=""" Memory_Space_Associative="""
MATCH (u)-[]-(x)-[]-(h) MATCH (u)-[]-(x)-[]-(h)

View File

@@ -19,7 +19,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
"""对话仓储 """对话仓储
管理对话节点的创建、查询、更新和删除操作。 管理对话节点的创建、查询、更新和删除操作。
提供按group_id、user_id、ref_id等条件查询对话的方法。 提供按end_user_id、user_id、ref_id等条件查询对话的方法。
Attributes: Attributes:
connector: Neo4j连接器实例 connector: Neo4j连接器实例
@@ -54,17 +54,17 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
return DialogueNode(**n) return DialogueNode(**n)
async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[DialogueNode]: async def find_by_end_user_id(self, end_user_id: str, limit: int = 100) -> List[DialogueNode]:
"""根据group_id查询对话 """根据end_user_id查询对话
Args: Args:
group_id: 组ID end_user_id: 组ID
limit: 返回结果的最大数量 limit: 返回结果的最大数量
Returns: Returns:
List[DialogueNode]: 对话列表 List[DialogueNode]: 对话列表
""" """
return await self.find({"group_id": group_id}, limit=limit) return await self.find({"end_user_id": end_user_id}, limit=limit)
async def find_by_user_id(self, user_id: str, limit: int = 100) -> List[DialogueNode]: async def find_by_user_id(self, user_id: str, limit: int = 100) -> List[DialogueNode]:
"""根据user_id查询对话 """根据user_id查询对话
@@ -94,14 +94,14 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
async def find_by_group_and_user( async def find_by_group_and_user(
self, self,
group_id: str, end_user_id: str,
user_id: str, user_id: str,
limit: int = 100 limit: int = 100
) -> List[DialogueNode]: ) -> List[DialogueNode]:
"""根据group_id和user_id查询对话 """根据end_user_id和user_id查询对话
Args: Args:
group_id: 组ID end_user_id: 组ID
user_id: 用户ID user_id: 用户ID
limit: 返回结果的最大数量 limit: 返回结果的最大数量
@@ -109,20 +109,20 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
List[DialogueNode]: 对话列表 List[DialogueNode]: 对话列表
""" """
return await self.find( return await self.find(
{"group_id": group_id, "user_id": user_id}, {"end_user_id": end_user_id, "user_id": user_id},
limit=limit limit=limit
) )
async def find_recent_dialogs( async def find_recent_dialogs(
self, self,
group_id: str, end_user_id: str,
days: int = 7, days: int = 7,
limit: int = 100 limit: int = 100
) -> List[DialogueNode]: ) -> List[DialogueNode]:
"""查询最近的对话 """查询最近的对话
Args: Args:
group_id: 组ID end_user_id: 组ID
days: 查询最近多少天的对话 days: 查询最近多少天的对话
limit: 返回结果的最大数量 limit: 返回结果的最大数量
@@ -131,7 +131,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
""" """
query = f""" query = f"""
MATCH (n:{self.node_label}) MATCH (n:{self.node_label})
WHERE n.group_id = $group_id WHERE n.end_user_id = $end_user_id
AND n.created_at >= datetime() - duration({{days: $days}}) AND n.created_at >= datetime() - duration({{days: $days}})
RETURN n RETURN n
ORDER BY n.created_at DESC ORDER BY n.created_at DESC
@@ -139,7 +139,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
""" """
results = await self.connector.execute_query( results = await self.connector.execute_query(
query, query,
group_id=group_id, end_user_id=end_user_id,
days=days, days=days,
limit=limit limit=limit
) )
@@ -164,16 +164,16 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
async def find_by_config_and_group( async def find_by_config_and_group(
self, self,
config_id: str, config_id: str,
group_id: str, end_user_id: str,
limit: int = 100 limit: int = 100
) -> List[DialogueNode]: ) -> List[DialogueNode]:
"""根据config_id和group_id查询对话 """根据config_id和end_user_id查询对话
支持按配置ID和组ID同时过滤,确保只返回使用特定配置处理的对话。 支持按配置ID和组ID同时过滤,确保只返回使用特定配置处理的对话。
Args: Args:
config_id: 配置ID config_id: 配置ID
group_id: 组ID end_user_id: 组ID
limit: 返回结果的最大数量 limit: 返回结果的最大数量
Returns: Returns:

View File

@@ -40,7 +40,7 @@ class EmotionRepository:
async def get_emotion_tags( async def get_emotion_tags(
self, self,
group_id: str, end_user_id: str,
emotion_type: Optional[str] = None, emotion_type: Optional[str] = None,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
@@ -51,7 +51,7 @@ class EmotionRepository:
查询指定用户的情绪类型分布,包括计数、百分比和平均强度。 查询指定用户的情绪类型分布,包括计数、百分比和平均强度。
Args: Args:
group_id: 用户组ID宿主ID end_user_id: 用户组ID宿主ID
emotion_type: 可选的情绪类型过滤joy/sadness/anger/fear/surprise/neutral emotion_type: 可选的情绪类型过滤joy/sadness/anger/fear/surprise/neutral
start_date: 可选的开始日期ISO格式字符串 start_date: 可选的开始日期ISO格式字符串
end_date: 可选的结束日期ISO格式字符串 end_date: 可选的结束日期ISO格式字符串
@@ -65,8 +65,8 @@ class EmotionRepository:
- avg_intensity: 平均强度 - avg_intensity: 平均强度
""" """
# 构建查询条件 # 构建查询条件
where_clauses = ["s.group_id = $group_id", "s.emotion_type IS NOT NULL"] where_clauses = ["s.end_user_id = $end_user_id", "s.emotion_type IS NOT NULL"]
params = {"group_id": group_id, "limit": limit} params = {"end_user_id": end_user_id, "limit": limit}
if emotion_type: if emotion_type:
where_clauses.append("s.emotion_type = $emotion_type") where_clauses.append("s.emotion_type = $emotion_type")
@@ -119,7 +119,7 @@ class EmotionRepository:
async def get_emotion_wordcloud( async def get_emotion_wordcloud(
self, self,
group_id: str, end_user_id: str,
emotion_type: Optional[str] = None, emotion_type: Optional[str] = None,
limit: int = 50 limit: int = 50
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
@@ -128,7 +128,7 @@ class EmotionRepository:
查询情绪关键词及其频率,用于生成词云可视化。 查询情绪关键词及其频率,用于生成词云可视化。
Args: Args:
group_id: 用户组ID宿主ID end_user_id: 用户组ID宿主ID
emotion_type: 可选的情绪类型过滤 emotion_type: 可选的情绪类型过滤
limit: 返回关键词的最大数量 limit: 返回关键词的最大数量
@@ -140,8 +140,8 @@ class EmotionRepository:
- avg_intensity: 平均强度 - avg_intensity: 平均强度
""" """
# 构建查询条件 # 构建查询条件
where_clauses = ["s.group_id = $group_id", "s.emotion_keywords IS NOT NULL"] where_clauses = ["s.end_user_id = $end_user_id", "s.emotion_keywords IS NOT NULL"]
params = {"group_id": group_id, "limit": limit} params = {"end_user_id": end_user_id, "limit": limit}
if emotion_type: if emotion_type:
where_clauses.append("s.emotion_type = $emotion_type") where_clauses.append("s.emotion_type = $emotion_type")
@@ -186,7 +186,7 @@ class EmotionRepository:
async def get_emotions_in_range( async def get_emotions_in_range(
self, self,
group_id: str, end_user_id: str,
time_range: str = "30d" time_range: str = "30d"
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""获取时间范围内的情绪数据 """获取时间范围内的情绪数据
@@ -194,7 +194,7 @@ class EmotionRepository:
查询指定时间范围内的所有情绪数据,用于健康指数计算。 查询指定时间范围内的所有情绪数据,用于健康指数计算。
Args: Args:
group_id: 用户组ID宿主ID end_user_id: 用户组ID宿主ID
time_range: 时间范围7d/30d/90d time_range: 时间范围7d/30d/90d
Returns: Returns:
@@ -214,7 +214,7 @@ class EmotionRepository:
# 优化的 Cypher 查询:使用字符串比较避免时区问题 # 优化的 Cypher 查询:使用字符串比较避免时区问题
query = """ query = """
MATCH (s:Statement) MATCH (s:Statement)
WHERE s.group_id = $group_id WHERE s.end_user_id = $end_user_id
AND s.emotion_type IS NOT NULL AND s.emotion_type IS NOT NULL
AND s.created_at >= $start_date AND s.created_at >= $start_date
RETURN s.id as statement_id, RETURN s.id as statement_id,

View File

@@ -44,9 +44,7 @@ async def save_entities_and_relationships(
'created_at': edge.created_at.isoformat(), 'created_at': edge.created_at.isoformat(),
'expired_at': edge.expired_at.isoformat(), 'expired_at': edge.expired_at.isoformat(),
'run_id': edge.run_id, 'run_id': edge.run_id,
'group_id': edge.group_id, 'end_user_id': edge.end_user_id,
'user_id': edge.user_id,
'apply_id': edge.apply_id,
} }
all_relationships.append(relationship) all_relationships.append(relationship)
@@ -101,9 +99,7 @@ async def save_statement_chunk_edges(
"id": edge.id, "id": edge.id,
"source": edge.source, "source": edge.source,
"target": edge.target, "target": edge.target,
"group_id": edge.group_id, "end_user_id": edge.end_user_id,
"user_id": edge.user_id,
"apply_id": edge.apply_id,
"run_id": edge.run_id, "run_id": edge.run_id,
"created_at": edge.created_at.isoformat() if edge.created_at else None, "created_at": edge.created_at.isoformat() if edge.created_at else None,
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None, "expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
@@ -132,9 +128,7 @@ async def save_statement_entity_edges(
edge_data = { edge_data = {
"source": edge.source, "source": edge.source,
"target": edge.target, "target": edge.target,
"group_id": edge.group_id, "end_user_id": edge.end_user_id,
"user_id": edge.user_id,
"apply_id": edge.apply_id,
"run_id": edge.run_id, "run_id": edge.run_id,
"connect_strength": edge.connect_strength, "connect_strength": edge.connect_strength,
"created_at": edge.created_at.isoformat() if edge.created_at else None, "created_at": edge.created_at.isoformat() if edge.created_at else None,

View File

@@ -33,7 +33,7 @@ async def _update_activation_values_batch(
connector: Neo4jConnector, connector: Neo4jConnector,
nodes: List[Dict[str, Any]], nodes: List[Dict[str, Any]],
node_label: str, node_label: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
max_retries: int = 3 max_retries: int = 3
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
@@ -46,7 +46,7 @@ async def _update_activation_values_batch(
connector: Neo4j连接器 connector: Neo4j连接器
nodes: 节点列表,每个节点必须包含 'id' 字段 nodes: 节点列表,每个节点必须包含 'id' 字段
node_label: 节点标签Statement, ExtractedEntity, MemorySummary node_label: 节点标签Statement, ExtractedEntity, MemorySummary
group_id: 组ID可选 end_user_id: 组ID可选
max_retries: 最大重试次数 max_retries: 最大重试次数
Returns: Returns:
@@ -97,7 +97,7 @@ async def _update_activation_values_batch(
updated_nodes = await access_manager.record_batch_access( updated_nodes = await access_manager.record_batch_access(
node_ids=unique_node_ids, node_ids=unique_node_ids,
node_label=node_label, node_label=node_label,
group_id=group_id end_user_id=end_user_id
) )
logger.info( logger.info(
@@ -118,7 +118,7 @@ async def _update_activation_values_batch(
async def _update_search_results_activation( async def _update_search_results_activation(
connector: Neo4jConnector, connector: Neo4jConnector,
results: Dict[str, List[Dict[str, Any]]], results: Dict[str, List[Dict[str, Any]]],
group_id: Optional[str] = None end_user_id: Optional[str] = None
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
更新搜索结果中所有知识节点的激活值 更新搜索结果中所有知识节点的激活值
@@ -129,7 +129,7 @@ async def _update_search_results_activation(
Args: Args:
connector: Neo4j连接器 connector: Neo4j连接器
results: 搜索结果字典,包含不同类型节点的列表 results: 搜索结果字典,包含不同类型节点的列表
group_id: 组ID可选 end_user_id: 组ID可选
Returns: Returns:
Dict[str, List[Dict[str, Any]]]: 更新后的搜索结果 Dict[str, List[Dict[str, Any]]]: 更新后的搜索结果
@@ -152,7 +152,7 @@ async def _update_search_results_activation(
connector=connector, connector=connector,
nodes=results[key], nodes=results[key],
node_label=label, node_label=label,
group_id=group_id end_user_id=end_user_id
) )
) )
update_keys.append(key) update_keys.append(key)
@@ -218,7 +218,7 @@ async def _update_search_results_activation(
async def search_graph( async def search_graph(
connector: Neo4jConnector, connector: Neo4jConnector,
q: str, q: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
include: List[str] = None, include: List[str] = None,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
@@ -236,7 +236,7 @@ async def search_graph(
Args: Args:
connector: Neo4j connector connector: Neo4j connector
q: Query text q: Query text
group_id: Optional group filter end_user_id: Optional group filter
limit: Max results per category limit: Max results per category
include: List of categories to search (default: all) include: List of categories to search (default: all)
@@ -254,7 +254,7 @@ async def search_graph(
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD, SEARCH_STATEMENTS_BY_KEYWORD,
q=q, q=q,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
)) ))
task_keys.append("statements") task_keys.append("statements")
@@ -263,7 +263,7 @@ async def search_graph(
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_ENTITIES_BY_NAME, SEARCH_ENTITIES_BY_NAME,
q=q, q=q,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
)) ))
task_keys.append("entities") task_keys.append("entities")
@@ -272,7 +272,7 @@ async def search_graph(
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_CHUNKS_BY_CONTENT, SEARCH_CHUNKS_BY_CONTENT,
q=q, q=q,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
)) ))
task_keys.append("chunks") task_keys.append("chunks")
@@ -281,7 +281,7 @@ async def search_graph(
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
q=q, q=q,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
)) ))
task_keys.append("summaries") task_keys.append("summaries")
@@ -305,19 +305,12 @@ async def search_graph(
results[key] = _deduplicate_results(results[key]) results[key] = _deduplicate_results(results[key])
# 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary # 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary
# Skip activation updates if only searching summaries (optimization) results = await _update_search_results_activation(
needs_activation_update = any( connector=connector,
key in include and key in results and results[key] results=results,
for key in ['statements', 'entities', 'chunks'] end_user_id=end_user_id
) )
if needs_activation_update:
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
)
return results return results
@@ -325,7 +318,7 @@ async def search_graph_by_embedding(
connector: Neo4jConnector, connector: Neo4jConnector,
embedder_client, embedder_client,
query_text: str, query_text: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
include: List[str] = ["statements", "chunks", "entities","summaries"], include: List[str] = ["statements", "chunks", "entities","summaries"],
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
@@ -337,7 +330,7 @@ async def search_graph_by_embedding(
- Computes query embedding with the provided embedder_client - Computes query embedding with the provided embedder_client
- Ranks by cosine similarity in Cypher - Ranks by cosine similarity in Cypher
- Filters by group_id if provided - Filters by end_user_id if provided
- Returns up to 'limit' per included type - Returns up to 'limit' per included type
""" """
import time import time
@@ -346,7 +339,7 @@ async def search_graph_by_embedding(
embed_start = time.time() embed_start = time.time()
embeddings = await embedder_client.response([query_text]) embeddings = await embedder_client.response([query_text])
embed_time = time.time() - embed_start embed_time = time.time() - embed_start
logger.info(f"[PERF] Embedding generation took: {embed_time:.4f}s") print(f"[PERF] Embedding generation took: {embed_time:.4f}s")
if not embeddings or not embeddings[0]: if not embeddings or not embeddings[0]:
return {"statements": [], "chunks": [], "entities": [], "summaries": []} return {"statements": [], "chunks": [], "entities": [], "summaries": []}
@@ -361,7 +354,7 @@ async def search_graph_by_embedding(
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
STATEMENT_EMBEDDING_SEARCH, STATEMENT_EMBEDDING_SEARCH,
embedding=embedding, embedding=embedding,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
)) ))
task_keys.append("statements") task_keys.append("statements")
@@ -371,7 +364,7 @@ async def search_graph_by_embedding(
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
CHUNK_EMBEDDING_SEARCH, CHUNK_EMBEDDING_SEARCH,
embedding=embedding, embedding=embedding,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
)) ))
task_keys.append("chunks") task_keys.append("chunks")
@@ -381,7 +374,7 @@ async def search_graph_by_embedding(
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
ENTITY_EMBEDDING_SEARCH, ENTITY_EMBEDDING_SEARCH,
embedding=embedding, embedding=embedding,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
)) ))
task_keys.append("entities") task_keys.append("entities")
@@ -391,7 +384,7 @@ async def search_graph_by_embedding(
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
MEMORY_SUMMARY_EMBEDDING_SEARCH, MEMORY_SUMMARY_EMBEDDING_SEARCH,
embedding=embedding, embedding=embedding,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
)) ))
task_keys.append("summaries") task_keys.append("summaries")
@@ -400,7 +393,7 @@ async def search_graph_by_embedding(
query_start = time.time() query_start = time.time()
task_results = await asyncio.gather(*tasks, return_exceptions=True) task_results = await asyncio.gather(*tasks, return_exceptions=True)
query_time = time.time() - query_start query_time = time.time() - query_start
logger.info(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s") print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
# Build results dictionary # Build results dictionary
results: Dict[str, List[Dict[str, Any]]] = { results: Dict[str, List[Dict[str, Any]]] = {
@@ -424,28 +417,19 @@ async def search_graph_by_embedding(
results[key] = _deduplicate_results(results[key]) results[key] = _deduplicate_results(results[key])
# 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary # 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary
# Skip activation updates if only searching summaries (optimization) update_start = time.time()
needs_activation_update = any( results = await _update_search_results_activation(
key in include and key in results and results[key] connector=connector,
for key in ['statements', 'entities', 'chunks'] results=results,
end_user_id=end_user_id
) )
update_time = time.time() - update_start
if needs_activation_update: print(f"[PERF] Activation value updates took: {update_time:.4f}s")
update_start = time.time()
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
)
update_time = time.time() - update_start
logger.info(f"[PERF] Activation value updates took: {update_time:.4f}s")
else:
logger.info(f"[PERF] Skipping activation updates (only summaries)")
return results return results
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体 async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: str, end_user_id: str,
entities: List[Dict[str, Any]], entities: List[Dict[str, Any]],
use_contains_fallback: bool = True, use_contains_fallback: bool = True,
batch_size: int = 500, batch_size: int = 500,
@@ -453,7 +437,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
为第二层去重消歧批量检索候选实体(适配新版 cypher_queries 为第二层去重消歧批量检索候选实体(适配新版 cypher_queries
- 使用全文索引查询 `SEARCH_ENTITIES_BY_NAME` 按 (group_id, name) 检索候选; - 使用全文索引查询 `SEARCH_ENTITIES_BY_NAME` 按 (end_user_id, name) 检索候选;
- 保留并发控制与返回结构incoming_id -> [db_entity_props...] - 保留并发控制与返回结构incoming_id -> [db_entity_props...]
- 若提供 `entity_type`,在本地对返回结果做类型过滤; - 若提供 `entity_type`,在本地对返回结果做类型过滤;
- `use_contains_fallback` 保留形参以兼容,必要时可扩展二次查询策略。 - `use_contains_fallback` 保留形参以兼容,必要时可扩展二次查询策略。
@@ -477,7 +461,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
rows = await connector.execute_query( rows = await connector.execute_query(
SEARCH_ENTITIES_BY_NAME, SEARCH_ENTITIES_BY_NAME,
q=name, q=name,
group_id=group_id, end_user_id=end_user_id,
limit=100, limit=100,
) )
except Exception: except Exception:
@@ -501,7 +485,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
rows = await connector.execute_query( rows = await connector.execute_query(
SEARCH_ENTITIES_BY_NAME, SEARCH_ENTITIES_BY_NAME,
q=name.lower(), q=name.lower(),
group_id=group_id, end_user_id=end_user_id,
limit=100, limit=100,
) )
for r in rows: for r in rows:
@@ -532,9 +516,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
async def search_graph_by_keyword_temporal( async def search_graph_by_keyword_temporal(
connector: Neo4jConnector, connector: Neo4jConnector,
query_text: str, query_text: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
valid_date: Optional[str] = None, valid_date: Optional[str] = None,
@@ -547,32 +529,30 @@ async def search_graph_by_keyword_temporal(
INTEGRATED: Updates activation values for Statement nodes before returning results INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements containing query_text created between start_date and end_date - Matches statements containing query_text created between start_date and end_date
- Optionally filters by group_id, apply_id, user_id - Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
if not query_text: if not query_text:
logger.warning(f"query_text cannot be empty") print(f"query_text不能为空")
return {"statements": []} return {"statements": []}
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
q=query_text, q=query_text,
group_id=group_id, end_user_id=end_user_id,
apply_id=apply_id,
user_id=user_id,
start_date=start_date, start_date=start_date,
end_date=end_date, end_date=end_date,
valid_date=valid_date, valid_date=valid_date,
invalid_date=invalid_date, invalid_date=invalid_date,
limit=limit, limit=limit,
) )
logger.debug(f"Temporal keyword search results: {len(statements)} statements found") print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
connector=connector, connector=connector,
results=results, results=results,
group_id=group_id end_user_id=end_user_id
) )
return results return results
@@ -580,9 +560,7 @@ async def search_graph_by_keyword_temporal(
async def search_graph_by_temporal( async def search_graph_by_temporal(
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
valid_date: Optional[str] = None, valid_date: Optional[str] = None,
@@ -595,14 +573,12 @@ async def search_graph_by_temporal(
INTEGRATED: Updates activation values for Statement nodes before returning results INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created between start_date and end_date - Matches statements created between start_date and end_date
- Optionally filters by group_id, apply_id, user_id - Optionally filters by end_user_id
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_TEMPORAL, SEARCH_STATEMENTS_BY_TEMPORAL,
group_id=group_id, end_user_id=end_user_id,
apply_id=apply_id,
user_id=user_id,
start_date=start_date, start_date=start_date,
end_date=end_date, end_date=end_date,
valid_date=valid_date, valid_date=valid_date,
@@ -610,16 +586,16 @@ async def search_graph_by_temporal(
limit=limit, limit=limit,
) )
logger.debug(f"Temporal search query: {SEARCH_STATEMENTS_BY_TEMPORAL}") print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, start_date={start_date}, end_date={end_date}, valid_date={valid_date}, invalid_date={invalid_date}, limit={limit}") print(f"查询参数为:\n{{end_user_id: {end_user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
logger.debug(f"Temporal search results: {len(statements)} statements found") print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
connector=connector, connector=connector,
results=results, results=results,
group_id=group_id end_user_id=end_user_id
) )
return results return results
@@ -628,23 +604,23 @@ async def search_graph_by_temporal(
async def search_graph_by_dialog_id( async def search_graph_by_dialog_id(
connector: Neo4jConnector, connector: Neo4jConnector,
dialog_id: str, dialog_id: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Dialogues. Temporal search across Dialogues.
- Matches dialogues with dialog_id - Matches dialogues with dialog_id
- Optionally filters by group_id - Optionally filters by end_user_id
- Returns up to 'limit' dialogues - Returns up to 'limit' dialogues
""" """
if not dialog_id: if not dialog_id:
logger.warning(f"dialog_id cannot be empty") print(f"dialog_id不能为空")
return {"dialogues": []} return {"dialogues": []}
dialogues = await connector.execute_query( dialogues = await connector.execute_query(
SEARCH_DIALOGUE_BY_DIALOG_ID, SEARCH_DIALOGUE_BY_DIALOG_ID,
group_id=group_id, end_user_id=end_user_id,
dialog_id=dialog_id, dialog_id=dialog_id,
limit=limit, limit=limit,
) )
@@ -654,15 +630,15 @@ async def search_graph_by_dialog_id(
async def search_graph_by_chunk_id( async def search_graph_by_chunk_id(
connector: Neo4jConnector, connector: Neo4jConnector,
chunk_id : str, chunk_id : str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
if not chunk_id: if not chunk_id:
logger.warning(f"chunk_id cannot be empty") print(f"chunk_id不能为空")
return {"chunks": []} return {"chunks": []}
chunks = await connector.execute_query( chunks = await connector.execute_query(
SEARCH_CHUNK_BY_CHUNK_ID, SEARCH_CHUNK_BY_CHUNK_ID,
group_id=group_id, end_user_id=end_user_id,
chunk_id=chunk_id, chunk_id=chunk_id,
limit=limit, limit=limit,
) )
@@ -671,9 +647,9 @@ async def search_graph_by_chunk_id(
async def search_graph_by_created_at( async def search_graph_by_created_at(
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
created_at: Optional[str] = None, created_at: Optional[str] = None,
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
@@ -683,37 +659,37 @@ async def search_graph_by_created_at(
INTEGRATED: Updates activation values for Statement nodes before returning results INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created at created_at - Matches statements created at created_at
- Optionally filters by group_id, apply_id, user_id - Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_CREATED_AT, SEARCH_STATEMENTS_BY_CREATED_AT,
group_id=group_id, end_user_id=end_user_id,
apply_id=apply_id,
user_id=user_id,
created_at=created_at, created_at=created_at,
limit=limit, limit=limit,
) )
logger.debug(f"Search by created_at query: {SEARCH_STATEMENTS_BY_CREATED_AT}") print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}") print(f"查询参数为:\n{{end_user_id: {end_user_id} created_at: {created_at}, limit: {limit}}}")
logger.debug(f"Search results: {len(statements)} statements found") print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
connector=connector, connector=connector,
results=results, results=results,
group_id=group_id end_user_id=end_user_id
) )
return results return results
async def search_graph_by_valid_at( async def search_graph_by_valid_at(
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
valid_at: Optional[str] = None, valid_at: Optional[str] = None,
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
@@ -723,37 +699,37 @@ async def search_graph_by_valid_at(
INTEGRATED: Updates activation values for Statement nodes before returning results INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements valid at valid_at - Matches statements valid at valid_at
- Optionally filters by group_id, apply_id, user_id - Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_VALID_AT, SEARCH_STATEMENTS_BY_VALID_AT,
group_id=group_id, end_user_id=end_user_id,
apply_id=apply_id,
user_id=user_id,
valid_at=valid_at, valid_at=valid_at,
limit=limit, limit=limit,
) )
logger.debug(f"Search by valid_at query: {SEARCH_STATEMENTS_BY_VALID_AT}") print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}") print(f"查询参数为:\n{{end_user_id: {end_user_id} valid_at: {valid_at}, limit: {limit}}}")
logger.debug(f"Search results: {len(statements)} statements found") print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
connector=connector, connector=connector,
results=results, results=results,
group_id=group_id end_user_id=end_user_id
) )
return results return results
async def search_graph_g_created_at( async def search_graph_g_created_at(
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
created_at: Optional[str] = None, created_at: Optional[str] = None,
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
@@ -763,37 +739,37 @@ async def search_graph_g_created_at(
INTEGRATED: Updates activation values for Statement nodes before returning results INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created at created_at - Matches statements created at created_at
- Optionally filters by group_id, apply_id, user_id - Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_G_CREATED_AT, SEARCH_STATEMENTS_G_CREATED_AT,
group_id=group_id, end_user_id=end_user_id,
apply_id=apply_id,
user_id=user_id,
created_at=created_at, created_at=created_at,
limit=limit, limit=limit,
) )
logger.debug(f"Search greater than created_at query: {SEARCH_STATEMENTS_G_CREATED_AT}") print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}") print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
logger.debug(f"Search results: {len(statements)} statements found") print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
connector=connector, connector=connector,
results=results, results=results,
group_id=group_id end_user_id=end_user_id
) )
return results return results
async def search_graph_g_valid_at( async def search_graph_g_valid_at(
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
valid_at: Optional[str] = None, valid_at: Optional[str] = None,
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
@@ -803,37 +779,37 @@ async def search_graph_g_valid_at(
INTEGRATED: Updates activation values for Statement nodes before returning results INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements valid at valid_at - Matches statements valid at valid_at
- Optionally filters by group_id, apply_id, user_id - Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_G_VALID_AT, SEARCH_STATEMENTS_G_VALID_AT,
group_id=group_id, end_user_id=end_user_id,
apply_id=apply_id,
user_id=user_id,
valid_at=valid_at, valid_at=valid_at,
limit=limit, limit=limit,
) )
logger.debug(f"Search greater than valid_at query: {SEARCH_STATEMENTS_G_VALID_AT}") print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}") print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
logger.debug(f"Search results: {len(statements)} statements found") print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
connector=connector, connector=connector,
results=results, results=results,
group_id=group_id end_user_id=end_user_id
) )
return results return results
async def search_graph_l_created_at( async def search_graph_l_created_at(
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
created_at: Optional[str] = None, created_at: Optional[str] = None,
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
@@ -843,37 +819,37 @@ async def search_graph_l_created_at(
INTEGRATED: Updates activation values for Statement nodes before returning results INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created at created_at - Matches statements created at created_at
- Optionally filters by group_id, apply_id, user_id - Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_L_CREATED_AT, SEARCH_STATEMENTS_L_CREATED_AT,
group_id=group_id, end_user_id=end_user_id,
apply_id=apply_id,
user_id=user_id,
created_at=created_at, created_at=created_at,
limit=limit, limit=limit,
) )
logger.debug(f"Search less than created_at query: {SEARCH_STATEMENTS_L_CREATED_AT}") print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}") print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
logger.debug(f"Search results: {len(statements)} statements found") print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
connector=connector, connector=connector,
results=results, results=results,
group_id=group_id end_user_id=end_user_id
) )
return results return results
async def search_graph_l_valid_at( async def search_graph_l_valid_at(
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
valid_at: Optional[str] = None, valid_at: Optional[str] = None,
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
@@ -883,28 +859,28 @@ async def search_graph_l_valid_at(
INTEGRATED: Updates activation values for Statement nodes before returning results INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements valid at valid_at - Matches statements valid at valid_at
- Optionally filters by group_id, apply_id, user_id - Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_L_VALID_AT, SEARCH_STATEMENTS_L_VALID_AT,
group_id=group_id, end_user_id=end_user_id,
apply_id=apply_id,
user_id=user_id,
valid_at=valid_at, valid_at=valid_at,
limit=limit, limit=limit,
) )
logger.debug(f"Search less than valid_at query: {SEARCH_STATEMENTS_L_VALID_AT}") print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}") print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
logger.debug(f"Search results: {len(statements)} statements found") print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
connector=connector, connector=connector,
results=results, results=results,
group_id=group_id end_user_id=end_user_id
) )
return results return results

View File

@@ -18,7 +18,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
"""Memory Summary Repository """Memory Summary Repository
Manages CRUD operations for MemorySummary nodes. Manages CRUD operations for MemorySummary nodes.
Provides methods to query summaries by group_id, user_id, and time ranges. Provides methods to query summaries by end_user_id, user_id, and time ranges.
Attributes: Attributes:
connector: Neo4j connector instance connector: Neo4j connector instance
@@ -51,17 +51,17 @@ class MemorySummaryRepository(BaseNeo4jRepository):
return dict(n) return dict(n)
async def find_by_group_id( async def find_by_end_user_id(
self, self,
group_id: str, end_user_id: str,
limit: int = 1000, limit: int = 1000,
start_date: Optional[datetime] = None, start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None end_date: Optional[datetime] = None
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""Query memory summaries by group_id """Query memory summaries by end_user_id
Args: Args:
group_id: Group ID to filter by end_user_id: Group ID to filter by
limit: Maximum number of results to return limit: Maximum number of results to return
start_date: Optional start date filter start_date: Optional start date filter
end_date: Optional end date filter end_date: Optional end date filter
@@ -71,10 +71,10 @@ class MemorySummaryRepository(BaseNeo4jRepository):
""" """
query = f""" query = f"""
MATCH (n:{self.node_label}) MATCH (n:{self.node_label})
WHERE n.group_id = $group_id WHERE n.end_user_id = $end_user_id
""" """
params = {"group_id": group_id, "limit": limit} params = {"end_user_id": end_user_id, "limit": limit}
# Add date range filters if provided # Add date range filters if provided
if start_date: if start_date:
@@ -139,16 +139,16 @@ class MemorySummaryRepository(BaseNeo4jRepository):
async def find_by_group_and_user( async def find_by_group_and_user(
self, self,
group_id: str, end_user_id: str,
user_id: str, user_id: str,
limit: int = 1000, limit: int = 1000,
start_date: Optional[datetime] = None, start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None end_date: Optional[datetime] = None
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""Query memory summaries by both group_id and user_id """Query memory summaries by both end_user_id and user_id
Args: Args:
group_id: Group ID to filter by end_user_id: Group ID to filter by
user_id: User ID to filter by user_id: User ID to filter by
limit: Maximum number of results to return limit: Maximum number of results to return
start_date: Optional start date filter start_date: Optional start date filter
@@ -159,10 +159,10 @@ class MemorySummaryRepository(BaseNeo4jRepository):
""" """
query = f""" query = f"""
MATCH (n:{self.node_label}) MATCH (n:{self.node_label})
WHERE n.group_id = $group_id AND n.user_id = $user_id WHERE n.end_user_id = $end_user_id AND n.user_id = $user_id
""" """
params = {"group_id": group_id, "user_id": user_id, "limit": limit} params = {"end_user_id": end_user_id, "user_id": user_id, "limit": limit}
# Add date range filters if provided # Add date range filters if provided
if start_date: if start_date:
@@ -184,14 +184,14 @@ class MemorySummaryRepository(BaseNeo4jRepository):
async def find_recent_summaries( async def find_recent_summaries(
self, self,
group_id: str, end_user_id: str,
days: int = 7, days: int = 7,
limit: int = 1000 limit: int = 1000
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""Query recent memory summaries """Query recent memory summaries
Args: Args:
group_id: Group ID to filter by end_user_id: Group ID to filter by
days: Number of recent days to query days: Number of recent days to query
limit: Maximum number of results to return limit: Maximum number of results to return
@@ -200,7 +200,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
""" """
query = f""" query = f"""
MATCH (n:{self.node_label}) MATCH (n:{self.node_label})
WHERE n.group_id = $group_id WHERE n.end_user_id = $end_user_id
AND n.created_at >= datetime() - duration({{days: $days}}) AND n.created_at >= datetime() - duration({{days: $days}})
RETURN n RETURN n
ORDER BY n.created_at DESC ORDER BY n.created_at DESC

View File

@@ -141,14 +141,14 @@ class Neo4jConnector:
async with self.driver.session(database="neo4j") as session: async with self.driver.session(database="neo4j") as session:
return await session.execute_read(transaction_func, **kwargs) return await session.execute_read(transaction_func, **kwargs)
async def delete_group(self, group_id: str): async def delete_group(self, end_user_id: str):
"""删除指定组的所有数据 """删除指定组的所有数据
删除所有属于指定group_id的节点和边。 删除所有属于指定end_user_id的节点和边。
这是一个危险操作,会永久删除数据。 这是一个危险操作,会永久删除数据。
Args: Args:
group_id: 要删除的组ID end_user_id: 要删除的组ID
Example: Example:
>>> connector = Neo4jConnector() >>> connector = Neo4jConnector()
@@ -157,14 +157,14 @@ class Neo4jConnector:
""" """
# 删除节点DETACH DELETE会同时删除相关的边 # 删除节点DETACH DELETE会同时删除相关的边
await self.driver.execute_query( await self.driver.execute_query(
"MATCH (n) WHERE n.group_id = $group_id DETACH DELETE n", "MATCH (n) WHERE n.end_user_id = $end_user_id DETACH DELETE n",
database="neo4j", database="neo4j",
group_id=group_id end_user_id=end_user_id
) )
# 删除独立的边(如果有的话) # 删除独立的边(如果有的话)
await self.driver.execute_query( await self.driver.execute_query(
"MATCH ()-[r]->() WHERE r.group_id = $group_id DELETE r", "MATCH ()-[r]->() WHERE r.end_user_id = $end_user_id DELETE r",
database="neo4j", database="neo4j",
group_id=group_id end_user_id=end_user_id
) )
print(f"Group {group_id} deleted.") print(f"Group {end_user_id} deleted.")

View File

@@ -20,7 +20,7 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]):
"""陈述句仓储 """陈述句仓储
管理陈述句节点的创建、查询、更新和删除操作。 管理陈述句节点的创建、查询、更新和删除操作。
提供按chunk_id、group_id、向量相似度等条件查询陈述句的方法。 提供按chunk_id、end_user_id、向量相似度等条件查询陈述句的方法。
Attributes: Attributes:
connector: Neo4j连接器实例 connector: Neo4j连接器实例

View File

@@ -7,11 +7,11 @@ class UserInput(BaseModel):
message: str message: str
history: list[dict] history: list[dict]
search_switch: str search_switch: str
group_id: str end_user_id: str
config_id: Optional[str] = None config_id: Optional[str] = None
class Write_UserInput(BaseModel): class Write_UserInput(BaseModel):
messages: list[dict] messages: list[dict]
group_id: str end_user_id: str
config_id: Optional[str] = None config_id: Optional[str] = None

View File

@@ -92,7 +92,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
try: try:
memory_content = asyncio.run( memory_content = asyncio.run(
MemoryAgentService().read_memory( MemoryAgentService().read_memory(
group_id=end_user_id, end_user_id=end_user_id,
message=question, message=question,
history=[], history=[],
search_switch="2", search_switch="2",

View File

@@ -75,7 +75,7 @@ class EmotionAnalyticsService:
# 调用仓储层查询 # 调用仓储层查询
tags = await self.emotion_repo.get_emotion_tags( tags = await self.emotion_repo.get_emotion_tags(
group_id=end_user_id, end_user_id=end_user_id,
emotion_type=emotion_type, emotion_type=emotion_type,
start_date=start_date, start_date=start_date,
end_date=end_date, end_date=end_date,
@@ -157,7 +157,7 @@ class EmotionAnalyticsService:
# 调用仓储层查询 # 调用仓储层查询
keywords = await self.emotion_repo.get_emotion_wordcloud( keywords = await self.emotion_repo.get_emotion_wordcloud(
group_id=end_user_id, end_user_id=end_user_id,
emotion_type=emotion_type, emotion_type=emotion_type,
limit=limit limit=limit
) )
@@ -339,7 +339,7 @@ class EmotionAnalyticsService:
# 获取时间范围内的情绪数据 # 获取时间范围内的情绪数据
emotions = await self.emotion_repo.get_emotions_in_range( emotions = await self.emotion_repo.get_emotions_in_range(
group_id=end_user_id, end_user_id=end_user_id,
time_range=time_range time_range=time_range
) )
@@ -519,7 +519,7 @@ class EmotionAnalyticsService:
# 3. 获取情绪数据用于模式分析 # 3. 获取情绪数据用于模式分析
emotions = await self.emotion_repo.get_emotions_in_range( emotions = await self.emotion_repo.get_emotions_in_range(
group_id=end_user_id, end_user_id=end_user_id,
time_range="30d" time_range="30d"
) )
@@ -598,13 +598,13 @@ class EmotionAnalyticsService:
# 查询用户的实体和标签 # 查询用户的实体和标签
query = """ query = """
MATCH (e:Entity) MATCH (e:Entity)
WHERE e.group_id = $group_id WHERE e.end_user_id = $end_user_id
RETURN e.name as name, e.type as type RETURN e.name as name, e.type as type
ORDER BY e.created_at DESC ORDER BY e.created_at DESC
LIMIT 20 LIMIT 20
""" """
entities = await connector.execute_query(query, group_id=end_user_id) entities = await connector.execute_query(query, end_user_id=end_user_id)
# 提取兴趣标签 # 提取兴趣标签
interests = [e["name"] for e in entities if e.get("type") in ["INTEREST", "HOBBY"]][:5] interests = [e["name"] for e in entities if e.get("type") in ["INTEREST", "HOBBY"]][:5]

View File

@@ -27,6 +27,7 @@ from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.memory_short_repository import ShortTermMemoryRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_agent_schema import Write_UserInput from app.schemas.memory_agent_schema import Write_UserInput
from app.schemas.memory_config_schema import ConfigurationError from app.schemas.memory_config_schema import ConfigurationError
@@ -54,25 +55,25 @@ _neo4j_connector = Neo4jConnector()
class MemoryAgentService: class MemoryAgentService:
"""Service for memory agent operations""" """Service for memory agent operations"""
def writer_messages_deal(self, messages, start_time, group_id, config_id, message, context): def writer_messages_deal(self, messages, start_time, end_user_id, config_id, message, context):
duration = time.time() - start_time duration = time.time() - start_time
if str(messages) == 'success': if str(messages) == 'success':
logger.info(f"Write operation successful for group {group_id} with config_id {config_id}") logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
# 记录成功的操作 # 记录成功的操作
if audit_logger: if audit_logger:
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=True, audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True,
duration=duration, details={"message_length": len(message)}) duration=duration, details={"message_length": len(message)})
return context return context
else: else:
logger.warning(f"Write operation failed for group {group_id}") logger.warning(f"Write operation failed for group {end_user_id}")
# 记录失败的操作 # 记录失败的操作
if audit_logger: if audit_logger:
audit_logger.log_operation( audit_logger.log_operation(
operation="WRITE", operation="WRITE",
config_id=config_id, config_id=config_id,
group_id=group_id, end_user_id=end_user_id,
success=False, success=False,
duration=duration, duration=duration,
error=f"写入失败: {messages[:100]}" error=f"写入失败: {messages[:100]}"
@@ -265,13 +266,13 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources") logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic # LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, group_id: str, messages: list[dict], config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str: async def write_memory(self, end_user_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
""" """
Process write operation with config_id Process write operation with config_id
Args: Args:
group_id: Group identifier (also used as end_user_id) end_user_id: Group identifier (also used as end_user_id)
messages: Structured message list [{"role": "user", "content": "..."}, ...] message: Message to write
config_id: Configuration ID from database config_id: Configuration ID from database
db: SQLAlchemy database session db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag) storage_type: Storage type (neo4j or rag)
@@ -286,15 +287,15 @@ class MemoryAgentService:
# Resolve config_id if None using end_user's connected config # Resolve config_id if None using end_user's connected config
if config_id is None: if config_id is None:
try: try:
connected_config = get_end_user_connected_config(group_id, db) connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id") config_id = connected_config.get("memory_config_id")
if config_id is None: if config_id is None:
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.") raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e: except Exception as e:
if "No memory configuration found" in str(e): if "No memory configuration found" in str(e):
raise raise # Re-raise our specific error
logger.error(f"Failed to get connected config for end_user {group_id}: {e}") logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}") raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
import time import time
start_time = time.time() start_time = time.time()
@@ -314,7 +315,7 @@ class MemoryAgentService:
# Log failed operation # Log failed operation
if audit_logger: if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg) audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
raise ValueError(error_msg) raise ValueError(error_msg)
@@ -322,11 +323,11 @@ class MemoryAgentService:
if storage_type == "rag": if storage_type == "rag":
# For RAG storage, convert messages to single string # For RAG storage, convert messages to single string
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
result = await write_rag(group_id, message_text, user_rag_memory_id) result = await write_rag(end_user_id, message_text, user_rag_memory_id)
return result return result
else: else:
async with make_write_graph() as graph: async with make_write_graph() as graph:
config = {"configurable": {"thread_id": group_id}} config = {"configurable": {"thread_id": end_user_id}}
# Convert structured messages to LangChain messages # Convert structured messages to LangChain messages
langchain_messages = [] langchain_messages = []
for msg in messages: for msg in messages:
@@ -339,7 +340,7 @@ class MemoryAgentService:
# 初始状态 - 包含所有必要字段 # 初始状态 - 包含所有必要字段
initial_state = { initial_state = {
"messages": langchain_messages, "messages": langchain_messages,
"group_id": group_id, "end_user_id": end_user_id,
"memory_config": memory_config "memory_config": memory_config
} }
@@ -356,14 +357,14 @@ class MemoryAgentService:
contents = massages.get('write_result') contents = massages.get('write_result')
# Convert messages back to string for logging # Convert messages back to string for logging
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message_text, contents) return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
except Exception as e: except Exception as e:
# Ensure proper error handling and logging # Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}" error_msg = f"Write operation failed: {str(e)}"
logger.error(error_msg) logger.error(error_msg)
if audit_logger: if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg) audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
raise ValueError(error_msg) raise ValueError(error_msg)
@@ -371,15 +372,14 @@ class MemoryAgentService:
async def read_memory( async def read_memory(
self, self,
group_id: str, end_user_id: str,
message: str, message: str,
history: List[Dict], history: List[Dict],
search_switch: str, search_switch: str,
config_id: Optional[str], config_id: Optional[str],
db: Session, db: Session,
storage_type: str, storage_type: str,
user_rag_memory_id: str user_rag_memory_id: str) -> Dict:
) -> Dict:
""" """
Process read operation with config_id Process read operation with config_id
@@ -389,7 +389,7 @@ class MemoryAgentService:
- "2": Direct answer based on context - "2": Direct answer based on context
Args: Args:
group_id: Group identifier (also used as end_user_id) end_user_id: Group identifier (also used as end_user_id)
message: User message message: User message
history: Conversation history history: Conversation history
search_switch: Search mode switch search_switch: Search mode switch
@@ -407,22 +407,22 @@ class MemoryAgentService:
import time import time
start_time = time.time() start_time = time.time()
logger.info(f"[PERF] read_memory started for group_id={group_id}, search_switch={search_switch}") ori_message= message
# Resolve config_id if None using end_user's connected config # Resolve config_id if None using end_user's connected config
if config_id is None: if config_id is None:
try: try:
connected_config = get_end_user_connected_config(group_id, db) connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id") config_id = connected_config.get("memory_config_id")
if config_id is None: if config_id is None:
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.") raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e: except Exception as e:
if "No memory configuration found" in str(e): if "No memory configuration found" in str(e):
raise # Re-raise our specific error raise # Re-raise our specific error
logger.error(f"Failed to get connected config for end_user {group_id}: {e}") logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}") raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
logger.info(f"Read operation for group {group_id} with config_id {config_id}") logger.info(f"Read operation for group {end_user_id} with config_id {config_id}")
# 导入审计日志记录器 # 导入审计日志记录器
try: try:
@@ -431,15 +431,13 @@ class MemoryAgentService:
audit_logger = None audit_logger = None
config_load_start = time.time()
try: try:
config_service = MemoryConfigService(db) config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config( memory_config = config_service.load_memory_config(
config_id=config_id, config_id=config_id,
service_name="MemoryAgentService" service_name="MemoryAgentService"
) )
config_load_time = time.time() - config_load_start logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
logger.info(f"[PERF] Configuration loaded in {config_load_time:.4f}s: {memory_config.config_name}")
except ConfigurationError as e: except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg) logger.error(error_msg)
@@ -450,7 +448,7 @@ class MemoryAgentService:
audit_logger.log_operation( audit_logger.log_operation(
operation="READ", operation="READ",
config_id=config_id, config_id=config_id,
group_id=group_id, end_user_id=end_user_id,
success=False, success=False,
duration=duration, duration=duration,
error=error_msg error=error_msg
@@ -460,16 +458,16 @@ class MemoryAgentService:
# Step 2: Prepare history # Step 2: Prepare history
history.append({"role": "user", "content": message}) history.append({"role": "user", "content": message})
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") logger.debug(f"Group ID:{end_user_id}, Message:{message}, History:{history}, Config ID:{config_id}")
# Step 3: Initialize MCP client and execute read workflow # Step 3: Initialize MCP client and execute read workflow
graph_exec_start = time.time() graph_exec_start = time.time()
try: try:
async with make_read_graph() as graph: async with make_read_graph() as graph:
config = {"configurable": {"thread_id": group_id}} config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段 # 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch, initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
"group_id": group_id "end_user_id": end_user_id
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id, , "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
"memory_config": memory_config} "memory_config": memory_config}
# 获取节点更新信息 # 获取节点更新信息
@@ -565,13 +563,13 @@ class MemoryAgentService:
if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2": if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2":
# 使用 upsert 方法 # 使用 upsert 方法
repo.upsert( repo.upsert(
end_user_id=group_id, end_user_id=end_user_id,
messages=message, messages=message,
aimessages=summary, aimessages=summary,
retrieved_content=retrieved_content, retrieved_content=retrieved_content,
search_switch=str(search_switch) search_switch=str(search_switch)
) )
logger.info(f"成功保存短期记忆: group_id={group_id}, search_switch={search_switch}") logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}")
else: else:
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}") logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
@@ -580,14 +578,12 @@ class MemoryAgentService:
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True) logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
# Log successful operation # Log successful operation
total_time = time.time() - start_time
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
if audit_logger: if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
audit_logger.log_operation( audit_logger.log_operation(
operation="READ", operation="READ",
config_id=config_id, config_id=config_id,
group_id=group_id, end_user_id=end_user_id,
success=True, success=True,
duration=duration duration=duration
) )
@@ -599,14 +595,13 @@ class MemoryAgentService:
except Exception as e: except Exception as e:
# Ensure proper error handling and logging # Ensure proper error handling and logging
error_msg = f"Read operation failed: {str(e)}" error_msg = f"Read operation failed: {str(e)}"
total_time = time.time() - start_time logger.error(error_msg)
logger.error(f"[PERF] read_memory failed after {total_time:.4f}s: {error_msg}")
if audit_logger: if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
audit_logger.log_operation( audit_logger.log_operation(
operation="READ", operation="READ",
config_id=config_id, config_id=config_id,
group_id=group_id, end_user_id=end_user_id,
success=False, success=False,
duration=duration, duration=duration,
error=error_msg error=error_msg
@@ -755,7 +750,7 @@ class MemoryAgentService:
""" """
统计知识库类型分布,包含: 统计知识库类型分布,包含:
1. PostgreSQL 中的知识库类型General, Web, Third-party, Folder根据 workspace_id 过滤) 1. PostgreSQL 中的知识库类型General, Web, Third-party, Folder根据 workspace_id 过滤)
2. Neo4j 中的 memory 类型(仅统计 Chunk 数量,根据 end_user_id/group_id 过滤) 2. Neo4j 中的 memory 类型(仅统计 Chunk 数量,根据 end_user_id/end_user_id 过滤)
3. total: 所有类型的总和 3. total: 所有类型的总和
参数: 参数:
@@ -841,11 +836,11 @@ class MemoryAgentService:
for end_user in end_users: for end_user in end_users:
end_user_id_str = str(end_user.id) end_user_id_str = str(end_user.id)
memory_query = """ memory_query = """
MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN count(n) AS Count MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN count(n) AS Count
""" """
neo4j_result = await _neo4j_connector.execute_query( neo4j_result = await _neo4j_connector.execute_query(
memory_query, memory_query,
group_id=end_user_id_str, end_user_id=end_user_id_str,
) )
chunk_count = neo4j_result[0]["Count"] if neo4j_result else 0 chunk_count = neo4j_result[0]["Count"] if neo4j_result else 0
total_chunks += chunk_count total_chunks += chunk_count
@@ -885,7 +880,7 @@ class MemoryAgentService:
获取指定用户的热门记忆标签 获取指定用户的热门记忆标签
参数: 参数:
- end_user_id: 用户ID可选对应Neo4j中的group_id字段 - end_user_id: 用户ID可选对应Neo4j中的end_user_id字段
- limit: 返回标签数量限制 - limit: 返回标签数量限制
返回格式: 返回格式:
@@ -895,7 +890,7 @@ class MemoryAgentService:
] ]
""" """
try: try:
# by_user=False 表示按 group_id 查询在Neo4j中group_id就是用户维度 # by_user=False 表示按 end_user_id 查询在Neo4j中end_user_id就是用户维度
tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False) tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False)
payload=[] payload=[]
for tag, freq in tags: for tag, freq in tags:
@@ -970,21 +965,21 @@ class MemoryAgentService:
# 查询该用户的语句 # 查询该用户的语句
query = ( query = (
"MATCH (s:Statement) " "MATCH (s:Statement) "
"WHERE ($group_id IS NULL OR s.group_id = $group_id) AND s.statement IS NOT NULL " "WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id) AND s.statement IS NOT NULL "
"RETURN s.statement AS statement " "RETURN s.statement AS statement "
"ORDER BY s.created_at DESC LIMIT 100" "ORDER BY s.created_at DESC LIMIT 100"
) )
rows = await connector.execute_query(query, group_id=end_user_id) rows = await connector.execute_query(query, end_user_id=end_user_id)
statements = [r.get("statement", "") for r in rows if r.get("statement")] statements = [r.get("statement", "") for r in rows if r.get("statement")]
# 查询该用户的热门实体 # 查询该用户的热门实体
entity_query = ( entity_query = (
"MATCH (e:ExtractedEntity) " "MATCH (e:ExtractedEntity) "
"WHERE ($group_id IS NULL OR e.group_id = $group_id) AND e.entity_type <> '人物' AND e.name IS NOT NULL " "WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) AND e.entity_type <> '人物' AND e.name IS NOT NULL "
"RETURN e.name AS name, count(e) AS frequency " "RETURN e.name AS name, count(e) AS frequency "
"ORDER BY frequency DESC LIMIT 20" "ORDER BY frequency DESC LIMIT 20"
) )
entity_rows = await connector.execute_query(entity_query, group_id=end_user_id) entity_rows = await connector.execute_query(entity_query, end_user_id=end_user_id)
entities = [f"{r['name']} ({r['frequency']})" for r in entity_rows] entities = [f"{r['name']} ({r['frequency']})" for r in entity_rows]
await connector.close() await connector.close()
@@ -1037,14 +1032,14 @@ class MemoryAgentService:
names_to_exclude = ['AI', 'Caroline', 'Melanie', 'Jon', 'Gina', '用户', 'AI助手', 'John', 'Maria'] names_to_exclude = ['AI', 'Caroline', 'Melanie', 'Jon', 'Gina', '用户', 'AI助手', 'John', 'Maria']
hot_tag_query = ( hot_tag_query = (
"MATCH (e:ExtractedEntity) " "MATCH (e:ExtractedEntity) "
"WHERE ($group_id IS NULL OR e.group_id = $group_id) AND e.entity_type <> '人物' " "WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) AND e.entity_type <> '人物' "
"AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude " "AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
"RETURN e.name AS name, count(e) AS frequency " "RETURN e.name AS name, count(e) AS frequency "
"ORDER BY frequency DESC LIMIT 4" "ORDER BY frequency DESC LIMIT 4"
) )
hot_tag_rows = await connector.execute_query( hot_tag_rows = await connector.execute_query(
hot_tag_query, hot_tag_query,
group_id=end_user_id, end_user_id=end_user_id,
names_to_exclude=names_to_exclude names_to_exclude=names_to_exclude
) )
await connector.close() await connector.close()
@@ -1190,6 +1185,10 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
"memory_config_id": memory_config_id "memory_config_id": memory_config_id
} }
print(188*'*')
print(result)
print(188 * '*')
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}") logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
return result return result
@@ -1230,10 +1229,10 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 1. 批量查询所有 end_user 及其 app_id # 1. 批量查询所有 end_user 及其 app_id
end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all() end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()
# 创建 end_user_id -> app_id 的映射 # 创建 end_user_id -> app_id 的映射
user_to_app = {str(eu.id): eu.app_id for eu in end_users} user_to_app = {str(eu.id): eu.app_id for eu in end_users}
# 记录未找到的用户 # 记录未找到的用户
found_user_ids = set(user_to_app.keys()) found_user_ids = set(user_to_app.keys())
missing_user_ids = set(end_user_ids) - found_user_ids missing_user_ids = set(end_user_ids) - found_user_ids
@@ -1275,13 +1274,13 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 批量查询 memory_config_name # 批量查询 memory_config_name
config_id_to_name = {} config_id_to_name = {}
if memory_config_ids: if memory_config_ids:
memory_configs = db.query(DataConfig).filter(DataConfig.config_id.in_(memory_config_ids)).all() memory_configs = db.query(MemoryConfig).filter(MemoryConfig.id.in_(memory_config_ids)).all()
config_id_to_name = {str(mc.config_id): mc.config_name for mc in memory_configs} config_id_to_name = {str(mc.id): mc.config_name for mc in memory_configs}
# 4. 构建最终结果 # 4. 构建最终结果
for end_user_id, app_id in user_to_app.items(): for end_user_id, app_id in user_to_app.items():
release = app_to_release.get(app_id) release = app_to_release.get(app_id)
if not release: if not release:
logger.warning(f"No active release found for app: {app_id} (end_user: {end_user_id})") logger.warning(f"No active release found for app: {app_id} (end_user: {end_user_id})")
result[end_user_id] = {"memory_config_id": None, "memory_config_name": None} result[end_user_id] = {"memory_config_id": None, "memory_config_name": None}
@@ -1293,7 +1292,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
# 获取配置名称 # 获取配置名称
memory_config_name = config_id_to_name.get(str(memory_config_id)) if memory_config_id else None memory_config_name = config_id_to_name.get(memory_config_id) if memory_config_id else None
result[end_user_id] = { result[end_user_id] = {
"memory_config_id": memory_config_id, "memory_config_id": memory_config_id,

View File

@@ -25,7 +25,7 @@ class MemoryAPIService:
This service provides a thin layer that: This service provides a thin layer that:
1. Validates end_user exists and belongs to the authorized workspace 1. Validates end_user exists and belongs to the authorized workspace
2. Maps end_user_id to group_id for memory operations 2. Maps end_user_id to end_user_id for memory operations
3. Delegates to MemoryAgentService for actual memory read/write operations 3. Delegates to MemoryAgentService for actual memory read/write operations
""" """
@@ -68,7 +68,7 @@ class MemoryAPIService:
) )
end_user = self.db.query(EndUser).filter(EndUser.id == end_user_uuid).first() end_user = self.db.query(EndUser).filter(EndUser.id == end_user_uuid).first()
if not end_user: if not end_user:
logger.warning(f"End user not found: {end_user_id}") logger.warning(f"End user not found: {end_user_id}")
raise ResourceNotFoundException( raise ResourceNotFoundException(
@@ -115,7 +115,7 @@ class MemoryAPIService:
Args: Args:
workspace_id: Workspace ID for resource validation workspace_id: Workspace ID for resource validation
end_user_id: End user identifier (used as group_id) end_user_id: End user identifier (used as end_user_id)
message: Message content to store message: Message content to store
config_id: Optional memory configuration ID config_id: Optional memory configuration ID
storage_type: Storage backend (neo4j or rag) storage_type: Storage backend (neo4j or rag)
@@ -133,13 +133,12 @@ class MemoryAPIService:
# Validate end_user exists and belongs to workspace # Validate end_user exists and belongs to workspace
self.validate_end_user(end_user_id, workspace_id) self.validate_end_user(end_user_id, workspace_id)
# Use end_user_id as group_id for memory operations # Use end_user_id as end_user_id for memory operations
group_id = end_user_id
try: try:
# Delegate to MemoryAgentService # Delegate to MemoryAgentService
result = await MemoryAgentService().write_memory( result = await MemoryAgentService().write_memory(
group_id=group_id, end_user_id=end_user_id,
message=message, message=message,
config_id=config_id, config_id=config_id,
db=self.db, db=self.db,
@@ -186,7 +185,7 @@ class MemoryAPIService:
Args: Args:
workspace_id: Workspace ID for resource validation workspace_id: Workspace ID for resource validation
end_user_id: End user identifier (used as group_id) end_user_id: End user identifier (used as end_user_id)
message: Query message message: Query message
search_switch: Search mode (0=deep search with verification, 1=deep search, 2=fast search) search_switch: Search mode (0=deep search with verification, 1=deep search, 2=fast search)
config_id: Optional memory configuration ID config_id: Optional memory configuration ID
@@ -205,13 +204,13 @@ class MemoryAPIService:
# Validate end_user exists and belongs to workspace # Validate end_user exists and belongs to workspace
self.validate_end_user(end_user_id, workspace_id) self.validate_end_user(end_user_id, workspace_id)
# Use end_user_id as group_id for memory operations # Use end_user_id as end_user_id for memory operations
group_id = end_user_id
try: try:
# Delegate to MemoryAgentService # Delegate to MemoryAgentService
result = await MemoryAgentService().read_memory( result = await MemoryAgentService().read_memory(
group_id=group_id, end_user_id=end_user_id,
message=message, message=message,
history=[], history=[],
search_switch=search_switch, search_switch=search_switch,

View File

@@ -326,7 +326,7 @@ class MemoryBaseService:
Args: Args:
summary_id: Summary节点的ID summary_id: Summary节点的ID
end_user_id: 终端用户ID (group_id) end_user_id: 终端用户ID (end_user_id)
Returns: Returns:
最大emotion_intensity对应的emotion_type如果没有则返回None 最大emotion_intensity对应的emotion_type如果没有则返回None
@@ -334,7 +334,7 @@ class MemoryBaseService:
try: try:
query = """ query = """
MATCH (s:MemorySummary) MATCH (s:MemorySummary)
WHERE elementId(s) = $summary_id AND s.group_id = $group_id WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id
MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement) MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement)
WHERE stmt.emotion_type IS NOT NULL WHERE stmt.emotion_type IS NOT NULL
AND stmt.emotion_intensity IS NOT NULL AND stmt.emotion_intensity IS NOT NULL
@@ -347,7 +347,7 @@ class MemoryBaseService:
result = await self.neo4j_connector.execute_query( result = await self.neo4j_connector.execute_query(
query, query,
summary_id=summary_id, summary_id=summary_id,
group_id=end_user_id end_user_id=end_user_id
) )
if result and len(result) > 0: if result and len(result) > 0:
@@ -381,10 +381,10 @@ class MemoryBaseService:
if end_user_id: if end_user_id:
query = """ query = """
MATCH (n:MemorySummary) MATCH (n:MemorySummary)
WHERE n.group_id = $group_id WHERE n.end_user_id = $end_user_id
RETURN count(n) as count RETURN count(n) as count
""" """
result = await self.neo4j_connector.execute_query(query, group_id=end_user_id) result = await self.neo4j_connector.execute_query(query, end_user_id=end_user_id)
else: else:
query = """ query = """
MATCH (n:MemorySummary) MATCH (n:MemorySummary)
@@ -423,12 +423,12 @@ class MemoryBaseService:
if end_user_id: if end_user_id:
semantic_query = """ semantic_query = """
MATCH (e:ExtractedEntity) MATCH (e:ExtractedEntity)
WHERE e.group_id = $group_id AND e.is_explicit_memory = true WHERE e.end_user_id = $end_user_id AND e.is_explicit_memory = true
RETURN count(e) as count RETURN count(e) as count
""" """
semantic_result = await self.neo4j_connector.execute_query( semantic_result = await self.neo4j_connector.execute_query(
semantic_query, semantic_query,
group_id=end_user_id end_user_id=end_user_id
) )
else: else:
semantic_query = """ semantic_query = """
@@ -519,7 +519,7 @@ class MemoryBaseService:
""" """
if end_user_id: if end_user_id:
query += " AND n.group_id = $group_id" query += " AND n.end_user_id = $end_user_id"
query += """ query += """
RETURN sum(CASE WHEN n.activation_value IS NOT NULL AND n.activation_value < $threshold THEN 1 ELSE 0 END) as low_activation_nodes RETURN sum(CASE WHEN n.activation_value IS NOT NULL AND n.activation_value < $threshold THEN 1 ELSE 0 END) as low_activation_nodes
@@ -528,7 +528,7 @@ class MemoryBaseService:
# 设置查询参数 # 设置查询参数
params = {'threshold': forgetting_threshold} params = {'threshold': forgetting_threshold}
if end_user_id: if end_user_id:
params['group_id'] = end_user_id params['end_user_id'] = end_user_id
# 执行查询 # 执行查询
result = await self.neo4j_connector.execute_query(query, **params) result = await self.neo4j_connector.execute_query(query, **params)

View File

@@ -717,8 +717,8 @@ class MemoryInteraction:
ori_data= await self.connector.execute_query(Memory_Space_Entity, id=self.id) ori_data= await self.connector.execute_query(Memory_Space_Entity, id=self.id)
if ori_data!=[]: if ori_data!=[]:
# name = ori_data[0]['name'] # name = ori_data[0]['name']
group_id = [i['group_id'] for i in ori_data][0] end_user_id = [i['end_user_id'] for i in ori_data][0]
Space_User = await self.connector.execute_query(Memory_Space_User, group_id=group_id) Space_User = await self.connector.execute_query(Memory_Space_User, end_user_id=end_user_id)
if not Space_User: if not Space_User:
return [] return []
user_id=Space_User[0]['id'] user_id=Space_User[0]['id']

View File

@@ -34,7 +34,7 @@ class MemoryEpisodicService(MemoryBaseService):
Args: Args:
summary_id: Summary节点的ID summary_id: Summary节点的ID
end_user_id: 终端用户ID (group_id) end_user_id: 终端用户ID (end_user_id)
Returns: Returns:
(标题, 类型)元组,如果不存在则返回默认值 (标题, 类型)元组,如果不存在则返回默认值
@@ -43,14 +43,14 @@ class MemoryEpisodicService(MemoryBaseService):
# 查询Summary节点的name(作为title)和memory_type(作为type) # 查询Summary节点的name(作为title)和memory_type(作为type)
query = """ query = """
MATCH (s:MemorySummary) MATCH (s:MemorySummary)
WHERE elementId(s) = $summary_id AND s.group_id = $group_id WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id
RETURN s.name AS title, s.memory_type AS type RETURN s.name AS title, s.memory_type AS type
""" """
result = await self.neo4j_connector.execute_query( result = await self.neo4j_connector.execute_query(
query, query,
summary_id=summary_id, summary_id=summary_id,
group_id=end_user_id end_user_id=end_user_id
) )
if not result or len(result) == 0: if not result or len(result) == 0:
@@ -77,7 +77,7 @@ class MemoryEpisodicService(MemoryBaseService):
Args: Args:
summary_id: Summary节点的ID summary_id: Summary节点的ID
end_user_id: 终端用户ID (group_id) end_user_id: 终端用户ID (end_user_id)
Returns: Returns:
前3个实体的name属性列表 前3个实体的name属性列表
@@ -87,7 +87,7 @@ class MemoryEpisodicService(MemoryBaseService):
# 按activation_value降序排序,返回前3个 # 按activation_value降序排序,返回前3个
query = """ query = """
MATCH (s:MemorySummary) MATCH (s:MemorySummary)
WHERE elementId(s) = $summary_id AND s.group_id = $group_id WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id
MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement) MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement)
MATCH (stmt)-[:REFERENCES_ENTITY]->(entity:ExtractedEntity) MATCH (stmt)-[:REFERENCES_ENTITY]->(entity:ExtractedEntity)
WHERE entity.activation_value IS NOT NULL WHERE entity.activation_value IS NOT NULL
@@ -99,7 +99,7 @@ class MemoryEpisodicService(MemoryBaseService):
result = await self.neo4j_connector.execute_query( result = await self.neo4j_connector.execute_query(
query, query,
summary_id=summary_id, summary_id=summary_id,
group_id=end_user_id end_user_id=end_user_id
) )
# 提取实体名称 # 提取实体名称
@@ -123,7 +123,7 @@ class MemoryEpisodicService(MemoryBaseService):
Args: Args:
summary_id: Summary节点的ID summary_id: Summary节点的ID
end_user_id: 终端用户ID (group_id) end_user_id: 终端用户ID (end_user_id)
Returns: Returns:
所有Statement节点的statement属性内容列表 所有Statement节点的statement属性内容列表
@@ -132,7 +132,7 @@ class MemoryEpisodicService(MemoryBaseService):
# 查询Summary节点指向的所有Statement节点 # 查询Summary节点指向的所有Statement节点
query = """ query = """
MATCH (s:MemorySummary) MATCH (s:MemorySummary)
WHERE elementId(s) = $summary_id AND s.group_id = $group_id WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id
MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement) MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement)
WHERE stmt.statement IS NOT NULL AND stmt.statement <> '' WHERE stmt.statement IS NOT NULL AND stmt.statement <> ''
RETURN stmt.statement AS statement RETURN stmt.statement AS statement
@@ -141,7 +141,7 @@ class MemoryEpisodicService(MemoryBaseService):
result = await self.neo4j_connector.execute_query( result = await self.neo4j_connector.execute_query(
query, query,
summary_id=summary_id, summary_id=summary_id,
group_id=end_user_id end_user_id=end_user_id
) )
# 提取statement内容 # 提取statement内容
@@ -214,12 +214,12 @@ class MemoryEpisodicService(MemoryBaseService):
# 1. 先查询所有情景记忆的总数(不受筛选条件限制) # 1. 先查询所有情景记忆的总数(不受筛选条件限制)
total_all_query = """ total_all_query = """
MATCH (s:MemorySummary) MATCH (s:MemorySummary)
WHERE s.group_id = $group_id WHERE s.end_user_id = $end_user_id
RETURN count(s) AS total_all RETURN count(s) AS total_all
""" """
total_all_result = await self.neo4j_connector.execute_query( total_all_result = await self.neo4j_connector.execute_query(
total_all_query, total_all_query,
group_id=end_user_id end_user_id=end_user_id
) )
total_all = total_all_result[0]["total_all"] if total_all_result else 0 total_all = total_all_result[0]["total_all"] if total_all_result else 0
@@ -229,7 +229,7 @@ class MemoryEpisodicService(MemoryBaseService):
# 3. 构建Cypher查询 # 3. 构建Cypher查询
query = """ query = """
MATCH (s:MemorySummary) MATCH (s:MemorySummary)
WHERE s.group_id = $group_id WHERE s.end_user_id = $end_user_id
""" """
# 添加时间范围过滤 # 添加时间范围过滤
@@ -248,7 +248,7 @@ class MemoryEpisodicService(MemoryBaseService):
ORDER BY s.created_at DESC ORDER BY s.created_at DESC
""" """
params = {"group_id": end_user_id} params = {"end_user_id": end_user_id}
if time_filter: if time_filter:
params["time_filter"] = time_filter params["time_filter"] = time_filter
if title_keyword: if title_keyword:
@@ -333,14 +333,14 @@ class MemoryEpisodicService(MemoryBaseService):
# 1. 查询指定的MemorySummary节点 # 1. 查询指定的MemorySummary节点
query = """ query = """
MATCH (s:MemorySummary) MATCH (s:MemorySummary)
WHERE elementId(s) = $summary_id AND s.group_id = $group_id WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id
RETURN elementId(s) AS id, s.created_at AS created_at RETURN elementId(s) AS id, s.created_at AS created_at
""" """
result = await self.neo4j_connector.execute_query( result = await self.neo4j_connector.execute_query(
query, query,
summary_id=summary_id, summary_id=summary_id,
group_id=end_user_id end_user_id=end_user_id
) )
# 2. 如果节点不存在,返回错误 # 2. 如果节点不存在,返回错误

View File

@@ -60,7 +60,7 @@ class MemoryExplicitService(MemoryBaseService):
# ========== 1. 查询情景记忆MemorySummary节点 ========== # ========== 1. 查询情景记忆MemorySummary节点 ==========
episodic_query = """ episodic_query = """
MATCH (s:MemorySummary) MATCH (s:MemorySummary)
WHERE s.group_id = $group_id WHERE s.end_user_id = $end_user_id
RETURN elementId(s) AS id, RETURN elementId(s) AS id,
s.name AS title, s.name AS title,
s.content AS content, s.content AS content,
@@ -70,7 +70,7 @@ class MemoryExplicitService(MemoryBaseService):
episodic_result = await self.neo4j_connector.execute_query( episodic_result = await self.neo4j_connector.execute_query(
episodic_query, episodic_query,
group_id=end_user_id end_user_id=end_user_id
) )
# 处理情景记忆数据 # 处理情景记忆数据
@@ -96,7 +96,7 @@ class MemoryExplicitService(MemoryBaseService):
# ========== 2. 查询语义记忆ExtractedEntity节点 ========== # ========== 2. 查询语义记忆ExtractedEntity节点 ==========
semantic_query = """ semantic_query = """
MATCH (e:ExtractedEntity) MATCH (e:ExtractedEntity)
WHERE e.group_id = $group_id WHERE e.end_user_id = $end_user_id
AND e.is_explicit_memory = true AND e.is_explicit_memory = true
RETURN elementId(e) AS id, RETURN elementId(e) AS id,
e.name AS name, e.name AS name,
@@ -107,7 +107,7 @@ class MemoryExplicitService(MemoryBaseService):
semantic_result = await self.neo4j_connector.execute_query( semantic_result = await self.neo4j_connector.execute_query(
semantic_query, semantic_query,
group_id=end_user_id end_user_id=end_user_id
) )
# 处理语义记忆数据 # 处理语义记忆数据
@@ -189,7 +189,7 @@ class MemoryExplicitService(MemoryBaseService):
# ========== 1. 先尝试查询情景记忆 ========== # ========== 1. 先尝试查询情景记忆 ==========
episodic_query = """ episodic_query = """
MATCH (s:MemorySummary) MATCH (s:MemorySummary)
WHERE elementId(s) = $memory_id AND s.group_id = $group_id WHERE elementId(s) = $memory_id AND s.end_user_id = $end_user_id
RETURN s.name AS title, RETURN s.name AS title,
s.content AS content, s.content AS content,
s.created_at AS created_at s.created_at AS created_at
@@ -198,7 +198,7 @@ class MemoryExplicitService(MemoryBaseService):
episodic_result = await self.neo4j_connector.execute_query( episodic_result = await self.neo4j_connector.execute_query(
episodic_query, episodic_query,
memory_id=memory_id, memory_id=memory_id,
group_id=end_user_id end_user_id=end_user_id
) )
if episodic_result and len(episodic_result) > 0: if episodic_result and len(episodic_result) > 0:
@@ -229,7 +229,7 @@ class MemoryExplicitService(MemoryBaseService):
semantic_query = """ semantic_query = """
MATCH (e:ExtractedEntity) MATCH (e:ExtractedEntity)
WHERE elementId(e) = $memory_id WHERE elementId(e) = $memory_id
AND e.group_id = $group_id AND e.end_user_id = $end_user_id
AND e.is_explicit_memory = true AND e.is_explicit_memory = true
RETURN e.name AS name, RETURN e.name AS name,
e.description AS core_definition, e.description AS core_definition,
@@ -240,7 +240,7 @@ class MemoryExplicitService(MemoryBaseService):
semantic_result = await self.neo4j_connector.execute_query( semantic_result = await self.neo4j_connector.execute_query(
semantic_query, semantic_query,
memory_id=memory_id, memory_id=memory_id,
group_id=end_user_id end_user_id=end_user_id
) )
if semantic_result and len(semantic_result) > 0: if semantic_result and len(semantic_result) > 0:

View File

@@ -132,7 +132,7 @@ class MemoryForgetService:
async def _get_knowledge_stats( async def _get_knowledge_stats(
self, self,
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
forgetting_threshold: float = 0.3 forgetting_threshold: float = 0.3
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
@@ -140,7 +140,7 @@ class MemoryForgetService:
Args: Args:
connector: Neo4j 连接器 connector: Neo4j 连接器
group_id: 组ID可选 end_user_id: 组ID可选
forgetting_threshold: 遗忘阈值 forgetting_threshold: 遗忘阈值
Returns: Returns:
@@ -152,8 +152,8 @@ class MemoryForgetService:
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary) WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
""" """
if group_id: if end_user_id:
query += " AND n.group_id = $group_id" query += " AND n.end_user_id = $end_user_id"
query += """ query += """
WITH n, WITH n,
@@ -172,8 +172,8 @@ class MemoryForgetService:
""" """
params = {'threshold': forgetting_threshold} params = {'threshold': forgetting_threshold}
if group_id: if end_user_id:
params['group_id'] = group_id params['end_user_id'] = end_user_id
results = await connector.execute_query(query, **params) results = await connector.execute_query(query, **params)
@@ -200,7 +200,7 @@ class MemoryForgetService:
async def _get_pending_forgetting_nodes( async def _get_pending_forgetting_nodes(
self, self,
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: str, end_user_id: str,
forgetting_threshold: float, forgetting_threshold: float,
min_days_since_access: int, min_days_since_access: int,
limit: int = 20 limit: int = 20
@@ -212,7 +212,7 @@ class MemoryForgetService:
Args: Args:
connector: Neo4j 连接器 connector: Neo4j 连接器
group_id: 组ID end_user_id: 组ID
forgetting_threshold: 遗忘阈值 forgetting_threshold: 遗忘阈值
min_days_since_access: 最小未访问天数 min_days_since_access: 最小未访问天数
limit: 返回节点数量限制 limit: 返回节点数量限制
@@ -229,7 +229,7 @@ class MemoryForgetService:
query = """ query = """
MATCH (n) MATCH (n)
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary) WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
AND n.group_id = $group_id AND n.end_user_id = $end_user_id
AND n.activation_value IS NOT NULL AND n.activation_value IS NOT NULL
AND n.activation_value < $threshold AND n.activation_value < $threshold
AND n.last_access_time IS NOT NULL AND n.last_access_time IS NOT NULL
@@ -250,7 +250,7 @@ class MemoryForgetService:
""" """
params = { params = {
'group_id': group_id, 'end_user_id': end_user_id,
'threshold': forgetting_threshold, 'threshold': forgetting_threshold,
'min_access_time_str': min_access_time_str, 'min_access_time_str': min_access_time_str,
'limit': limit 'limit': limit
@@ -291,7 +291,7 @@ class MemoryForgetService:
async def trigger_forgetting_cycle( async def trigger_forgetting_cycle(
self, self,
db: Session, db: Session,
group_id: str, end_user_id: str,
max_merge_batch_size: Optional[int] = None, max_merge_batch_size: Optional[int] = None,
min_days_since_access: Optional[int] = None, min_days_since_access: Optional[int] = None,
config_id: Optional[int] = None config_id: Optional[int] = None
@@ -303,10 +303,10 @@ class MemoryForgetService:
Args: Args:
db: 数据库会话 db: 数据库会话
group_id: 组ID即终端用户ID必填 end_user_id: 组ID即终端用户ID必填
max_merge_batch_size: 最大融合批次大小(可选) max_merge_batch_size: 最大融合批次大小(可选)
min_days_since_access: 最小未访问天数(可选) min_days_since_access: 最小未访问天数(可选)
config_id: 配置ID必填由控制器层通过 group_id 获取) config_id: 配置ID必填由控制器层通过 end_user_id 获取)
Returns: Returns:
dict: 遗忘报告 dict: 遗忘报告
@@ -319,7 +319,7 @@ class MemoryForgetService:
# 运行遗忘周期LLM 客户端将在需要时由 forgetting_strategy 内部获取) # 运行遗忘周期LLM 客户端将在需要时由 forgetting_strategy 内部获取)
report = await forgetting_scheduler.run_forgetting_cycle( report = await forgetting_scheduler.run_forgetting_cycle(
group_id=group_id, end_user_id=end_user_id,
max_merge_batch_size=max_merge_batch_size, max_merge_batch_size=max_merge_batch_size,
min_days_since_access=min_days_since_access, min_days_since_access=min_days_since_access,
config_id=config_id, config_id=config_id,
@@ -338,7 +338,7 @@ class MemoryForgetService:
stats_query = """ stats_query = """
MATCH (n) MATCH (n)
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk) WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
AND n.group_id = $group_id AND n.end_user_id = $end_user_id
RETURN RETURN
count(n) as total_nodes, count(n) as total_nodes,
avg(n.activation_value) as average_activation, avg(n.activation_value) as average_activation,
@@ -347,7 +347,7 @@ class MemoryForgetService:
stats_results = await connector.execute_query( stats_results = await connector.execute_query(
stats_query, stats_query,
group_id=group_id, end_user_id=end_user_id,
threshold=config['forgetting_threshold'] threshold=config['forgetting_threshold']
) )
@@ -364,7 +364,7 @@ class MemoryForgetService:
# 保存历史记录到数据库 # 保存历史记录到数据库
self.history_repository.create( self.history_repository.create(
db=db, db=db,
end_user_id=group_id, end_user_id=end_user_id,
execution_time=execution_time, execution_time=execution_time,
merged_count=report['merged_count'], merged_count=report['merged_count'],
failed_count=report['failed_count'], failed_count=report['failed_count'],
@@ -376,7 +376,7 @@ class MemoryForgetService:
) )
api_logger.info( api_logger.info(
f"已保存遗忘周期历史记录: end_user_id={group_id}, " f"已保存遗忘周期历史记录: end_user_id={end_user_id}, "
f"merged_count={report['merged_count']}" f"merged_count={report['merged_count']}"
) )
@@ -465,7 +465,7 @@ class MemoryForgetService:
async def get_forgetting_stats( async def get_forgetting_stats(
self, self,
db: Session, db: Session,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
config_id: Optional[int] = None config_id: Optional[int] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
@@ -475,7 +475,7 @@ class MemoryForgetService:
Args: Args:
db: 数据库会话 db: 数据库会话
group_id: 组ID可选 end_user_id: 组ID可选
config_id: 配置ID可选用于获取遗忘阈值 config_id: 配置ID可选用于获取遗忘阈值
Returns: Returns:
@@ -493,8 +493,8 @@ class MemoryForgetService:
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk) WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
""" """
if group_id: if end_user_id:
activation_query += " AND n.group_id = $group_id" activation_query += " AND n.end_user_id = $end_user_id"
activation_query += """ activation_query += """
RETURN RETURN
@@ -506,8 +506,8 @@ class MemoryForgetService:
""" """
params = {'threshold': forgetting_threshold} params = {'threshold': forgetting_threshold}
if group_id: if end_user_id:
params['group_id'] = group_id params['end_user_id'] = end_user_id
activation_results = await connector.execute_query(activation_query, **params) activation_results = await connector.execute_query(activation_query, **params)
@@ -539,8 +539,8 @@ class MemoryForgetService:
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk) WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
""" """
if group_id: if end_user_id:
distribution_query += " AND n.group_id = $group_id" distribution_query += " AND n.end_user_id = $end_user_id"
distribution_query += """ distribution_query += """
WITH n, WITH n,
@@ -558,8 +558,8 @@ class MemoryForgetService:
""" """
dist_params = {} dist_params = {}
if group_id: if end_user_id:
dist_params['group_id'] = group_id dist_params['end_user_id'] = end_user_id
distribution_results = await connector.execute_query(distribution_query, **dist_params) distribution_results = await connector.execute_query(distribution_query, **dist_params)
@@ -582,11 +582,11 @@ class MemoryForgetService:
# 获取最近7个日期的历史趋势数据每天取最后一次执行 # 获取最近7个日期的历史趋势数据每天取最后一次执行
recent_trends = [] recent_trends = []
try: try:
if group_id: if end_user_id:
# 查询所有历史记录 # 查询所有历史记录
history_records = self.history_repository.get_recent_by_end_user( history_records = self.history_repository.get_recent_by_end_user(
db=db, db=db,
end_user_id=group_id end_user_id=end_user_id
) )
# 按日期分组(一天可能有多次执行,取最后一次) # 按日期分组(一天可能有多次执行,取最后一次)
@@ -632,7 +632,7 @@ class MemoryForgetService:
# 获取待遗忘节点列表前20个满足遗忘条件的节点 # 获取待遗忘节点列表前20个满足遗忘条件的节点
pending_nodes = [] pending_nodes = []
try: try:
if group_id: if end_user_id:
# 验证 min_days_since_access 配置值 # 验证 min_days_since_access 配置值
min_days = config.get('min_days_since_access') min_days = config.get('min_days_since_access')
if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0: if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0:
@@ -643,7 +643,7 @@ class MemoryForgetService:
pending_nodes = await self._get_pending_forgetting_nodes( pending_nodes = await self._get_pending_forgetting_nodes(
connector=connector, connector=connector,
group_id=group_id, end_user_id=end_user_id,
forgetting_threshold=forgetting_threshold, forgetting_threshold=forgetting_threshold,
min_days_since_access=int(min_days), min_days_since_access=int(min_days),
limit=20 limit=20

View File

@@ -450,12 +450,12 @@ async def create_document_chunk(
return success(data=chunk, msg="文档块创建成功") return success(data=chunk, msg="文档块创建成功")
async def write_rag(group_id, message, user_rag_memory_id): async def write_rag(end_user_id, message, user_rag_memory_id):
""" """
将消息写入 RAG 知识库 将消息写入 RAG 知识库
Args: Args:
group_id: 组ID用作文件标题 end_user_id: 组ID用作文件标题
message: 消息内容 message: 消息内容
user_rag_memory_id: 知识库ID必须是有效的UUID user_rag_memory_id: 知识库ID必须是有效的UUID
@@ -487,10 +487,10 @@ async def write_rag(group_id, message, user_rag_memory_id):
db = next(db_gen) db = next(db_gen)
try: try:
create_data = CustomTextFileCreate(title=group_id, content=message) create_data = CustomTextFileCreate(title=end_user_id, content=message)
current_user = SimpleUser(user_rag_memory_id) current_user = SimpleUser(user_rag_memory_id)
# 检查文档是否已存在 # 检查文档是否已存在
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{group_id}.txt") document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
print('======',document) print('======',document)
api_logger.info(f"查找文档结果: document_id={document}") api_logger.info(f"查找文档结果: document_id={document}")
if document is not None: if document is not None:
@@ -508,7 +508,7 @@ async def write_rag(group_id, message, user_rag_memory_id):
return result return result
else: else:
# 文档不存在,创建新文档 # 文档不存在,创建新文档
api_logger.info(f"文档不存在,创建新文档: group_id={group_id}") api_logger.info(f"文档不存在,创建新文档: end_user_id={end_user_id}")
result = await memory_konwledges_up( result = await memory_konwledges_up(
kb_id=user_rag_memory_id, kb_id=user_rag_memory_id,
parent_id=user_rag_memory_id, parent_id=user_rag_memory_id,
@@ -520,13 +520,13 @@ async def write_rag(group_id, message, user_rag_memory_id):
new_document_id = find_document_id_by_kb_and_filename( new_document_id = find_document_id_by_kb_and_filename(
db=db, db=db,
kb_id=user_rag_memory_id, kb_id=user_rag_memory_id,
file_name=f"{group_id}.txt" file_name=f"{end_user_id}.txt"
) )
if new_document_id: if new_document_id:
await parse_document_by_id(new_document_id, db=db, current_user=current_user) await parse_document_by_id(new_document_id, db=db, current_user=current_user)
else: else:
api_logger.error(f"创建文档后无法找到文档ID: group_id={group_id}") api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}")
return result return result
finally: finally:
# 确保数据库会话被关闭 # 确保数据库会话被关闭

View File

@@ -183,7 +183,7 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"config_name": config.config_name, "config_name": config.config_name,
"config_desc": config.config_desc, "config_desc": config.config_desc,
"workspace_id": str(config.workspace_id) if config.workspace_id else None, "workspace_id": str(config.workspace_id) if config.workspace_id else None,
"group_id": config.group_id, "end_user_id": config.end_user_id,
"user_id": config.user_id, "user_id": config.user_id,
"apply_id": config.apply_id, "apply_id": config.apply_id,
"llm_id": config.llm_id, "llm_id": config.llm_id,
@@ -391,7 +391,7 @@ _neo4j_connector = Neo4jConnector()
async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]: async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
result = await _neo4j_connector.execute_query( result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_DIALOGUE, DataConfigRepository.SEARCH_FOR_DIALOGUE,
group_id=end_user_id, end_user_id=end_user_id,
) )
data = {"search_for": "dialogue", "num": result[0]["num"]} data = {"search_for": "dialogue", "num": result[0]["num"]}
return data return data
@@ -400,7 +400,7 @@ async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
async def search_chunk(end_user_id: Optional[str] = None) -> Dict[str, Any]: async def search_chunk(end_user_id: Optional[str] = None) -> Dict[str, Any]:
result = await _neo4j_connector.execute_query( result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_CHUNK, DataConfigRepository.SEARCH_FOR_CHUNK,
group_id=end_user_id, end_user_id=end_user_id,
) )
data = {"search_for": "chunk", "num": result[0]["num"]} data = {"search_for": "chunk", "num": result[0]["num"]}
return data return data
@@ -409,7 +409,7 @@ async def search_chunk(end_user_id: Optional[str] = None) -> Dict[str, Any]:
async def search_statement(end_user_id: Optional[str] = None) -> Dict[str, Any]: async def search_statement(end_user_id: Optional[str] = None) -> Dict[str, Any]:
result = await _neo4j_connector.execute_query( result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_STATEMENT, DataConfigRepository.SEARCH_FOR_STATEMENT,
group_id=end_user_id, end_user_id=end_user_id,
) )
data = {"search_for": "statement", "num": result[0]["num"]} data = {"search_for": "statement", "num": result[0]["num"]}
return data return data
@@ -418,7 +418,7 @@ async def search_statement(end_user_id: Optional[str] = None) -> Dict[str, Any]:
async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]: async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]:
result = await _neo4j_connector.execute_query( result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_ENTITY, DataConfigRepository.SEARCH_FOR_ENTITY,
group_id=end_user_id, end_user_id=end_user_id,
) )
data = {"search_for": "entity", "num": result[0]["num"]} data = {"search_for": "entity", "num": result[0]["num"]}
return data return data
@@ -427,7 +427,7 @@ async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]:
async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]: async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]:
result = await _neo4j_connector.execute_query( result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_ALL, DataConfigRepository.SEARCH_FOR_ALL,
group_id=end_user_id, end_user_id=end_user_id,
) )
# 检查结果是否为空或长度不足 # 检查结果是否为空或长度不足
@@ -462,7 +462,7 @@ async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, A
""" """
result = await _neo4j_connector.execute_query( result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_ALL, DataConfigRepository.SEARCH_FOR_ALL,
group_id=end_user_id, end_user_id=end_user_id,
) )
# 检查结果是否为空或长度不足 # 检查结果是否为空或长度不足
@@ -493,7 +493,7 @@ async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, A
async def search_detials(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]: async def search_detials(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]:
result = await _neo4j_connector.execute_query( result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_DETIALS, DataConfigRepository.SEARCH_FOR_DETIALS,
group_id=end_user_id, end_user_id=end_user_id,
) )
return result return result
@@ -501,11 +501,32 @@ async def search_detials(end_user_id: Optional[str] = None) -> List[Dict[str, An
async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]: async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]:
result = await _neo4j_connector.execute_query( result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_EDGES, DataConfigRepository.SEARCH_FOR_EDGES,
group_id=end_user_id, end_user_id=end_user_id,
) )
return result return result
async def search_entity_graph(end_user_id: Optional[str] = None) -> Dict[str, Any]:
"""搜索所有实体之间的关系网络group 维度)。"""
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_ENTITY_GRAPH,
end_user_id=end_user_id,
)
# 对source_node 和 target_node 的 fact_summary进行截取只截取前三条的内容需要提取前三条“来源”
for item in result:
source_fact = item["sourceNode"]["fact_summary"]
target_fact = item["targetNode"]["fact_summary"]
# 截取前三条“来源”
item["sourceNode"]["fact_summary"] = source_fact.split("\n")[:4] if source_fact else []
item["targetNode"]["fact_summary"] = target_fact.split("\n")[:4] if target_fact else []
# 与现有返回风格保持一致,携带搜索类型、数量与详情
data = {
"search_for": "entity_graph",
"num": len(result),
"detials": result,
}
return data
async def analytics_hot_memory_tags( async def analytics_hot_memory_tags(
db: Session, db: Session,

View File

@@ -91,7 +91,7 @@ async def run_pilot_extraction(
dialog = DialogData( dialog = DialogData(
context=context, context=context,
ref_id="pilot_dialog_1", ref_id="pilot_dialog_1",
group_id=str(memory_config.workspace_id), end_user_id=str(memory_config.workspace_id),
user_id=str(memory_config.tenant_id), user_id=str(memory_config.tenant_id),
apply_id=str(memory_config.config_id), apply_id=str(memory_config.config_id),
metadata={"source": "pilot_run", "input_type": "frontend_text"}, metadata={"source": "pilot_run", "input_type": "frontend_text"},

View File

@@ -155,10 +155,10 @@ class MemoryInsightHelper:
""" """
query = """ query = """
MATCH (d:Dialogue) MATCH (d:Dialogue)
WHERE d.group_id = $group_id AND d.created_at IS NOT NULL AND d.created_at <> '' WHERE d.end_user_id = $end_user_id AND d.created_at IS NOT NULL AND d.created_at <> ''
RETURN d.created_at AS creation_time RETURN d.created_at AS creation_time
""" """
records = await self.neo4j_connector.execute_query(query, group_id=self.user_id) records = await self.neo4j_connector.execute_query(query, end_user_id=self.user_id)
if not records: if not records:
return [] return []
@@ -211,17 +211,17 @@ class MemoryInsightHelper:
async def get_social_connections(self) -> dict | None: async def get_social_connections(self) -> dict | None:
"""Find the user with whom the most memories are shared.""" """Find the user with whom the most memories are shared."""
query = """ query = """
MATCH (c1:Chunk {group_id: $group_id}) MATCH (c1:Chunk {end_user_id: $end_user_id})
OPTIONAL MATCH (c1)-[:CONTAINS]->(s:Statement) OPTIONAL MATCH (c1)-[:CONTAINS]->(s:Statement)
OPTIONAL MATCH (s)<-[:CONTAINS]-(c2:Chunk) OPTIONAL MATCH (s)<-[:CONTAINS]-(c2:Chunk)
WHERE c1.group_id <> c2.group_id AND s IS NOT NULL AND c2 IS NOT NULL WHERE c1.end_user_id <> c2.end_user_id AND s IS NOT NULL AND c2 IS NOT NULL
WITH c2.group_id AS other_user_id, COUNT(DISTINCT s) AS common_statements WITH c2.end_user_id AS other_user_id, COUNT(DISTINCT s) AS common_statements
WHERE common_statements > 0 WHERE common_statements > 0
RETURN other_user_id, common_statements RETURN other_user_id, common_statements
ORDER BY common_statements DESC ORDER BY common_statements DESC
LIMIT 1 LIMIT 1
""" """
records = await self.neo4j_connector.execute_query(query, group_id=self.user_id) records = await self.neo4j_connector.execute_query(query, end_user_id=self.user_id)
if not records or not records[0].get("other_user_id"): if not records or not records[0].get("other_user_id"):
return None return None
@@ -230,7 +230,7 @@ class MemoryInsightHelper:
time_range_query = """ time_range_query = """
MATCH (c:Chunk) MATCH (c:Chunk)
WHERE c.group_id IN [$user_id, $other_user_id] WHERE c.end_user_id IN [$user_id, $other_user_id]
RETURN min(c.created_at) AS start_time, max(c.created_at) AS end_time RETURN min(c.created_at) AS start_time, max(c.created_at) AS end_time
""" """
time_records = await self.neo4j_connector.execute_query( time_records = await self.neo4j_connector.execute_query(
@@ -294,11 +294,11 @@ class UserSummaryHelper:
"""Fetch recent statements authored by the user/group for context.""" """Fetch recent statements authored by the user/group for context."""
query = ( query = (
"MATCH (s:Statement) " "MATCH (s:Statement) "
"WHERE s.group_id = $group_id AND s.statement IS NOT NULL " "WHERE s.end_user_id = $end_user_id AND s.statement IS NOT NULL "
"RETURN s.statement AS statement, s.created_at AS created_at " "RETURN s.statement AS statement, s.created_at AS created_at "
"ORDER BY created_at DESC LIMIT $limit" "ORDER BY created_at DESC LIMIT $limit"
) )
rows = await self.connector.execute_query(query, group_id=self.user_id, limit=limit) rows = await self.connector.execute_query(query, end_user_id=self.user_id, limit=limit)
records = [] records = []
for r in rows: for r in rows:
try: try:
@@ -1152,7 +1152,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str,
import re import re
# 创建 UserSummaryHelper 实例 # 创建 UserSummaryHelper 实例
user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_GROUP_ID", "group_123")) user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_end_user_id", "group_123"))
try: try:
# 1) 收集上下文数据 # 1) 收集上下文数据
@@ -1273,10 +1273,10 @@ async def analytics_node_statistics(
if end_user_id: if end_user_id:
query = f""" query = f"""
MATCH (n:{node_type}) MATCH (n:{node_type})
WHERE n.group_id = $group_id WHERE n.end_user_id = $end_user_id
RETURN count(n) as count RETURN count(n) as count
""" """
result = await _neo4j_connector.execute_query(query, group_id=end_user_id) result = await _neo4j_connector.execute_query(query, end_user_id=end_user_id)
else: else:
query = f""" query = f"""
MATCH (n:{node_type}) MATCH (n:{node_type})
@@ -1387,10 +1387,10 @@ async def analytics_memory_types(
# 查询 Statement 节点数量 # 查询 Statement 节点数量
query = """ query = """
MATCH (n:Statement) MATCH (n:Statement)
WHERE n.group_id = $group_id WHERE n.end_user_id = $end_user_id
RETURN count(n) as count RETURN count(n) as count
""" """
result = await _neo4j_connector.execute_query(query, group_id=end_user_id) result = await _neo4j_connector.execute_query(query, end_user_id=end_user_id)
statement_count = result[0]["count"] if result and len(result) > 0 else 0 statement_count = result[0]["count"] if result and len(result) > 0 else 0
# 取三分之一作为隐性记忆数量 # 取三分之一作为隐性记忆数量
implicit_count = round(statement_count / 3) implicit_count = round(statement_count / 3)
@@ -1504,7 +1504,7 @@ async def analytics_graph_data(
包含节点、边和统计信息的字典 包含节点、边和统计信息的字典
""" """
try: try:
# 1. 获取 group_id # 1. 获取 end_user_id
user_uuid = uuid.UUID(end_user_id) user_uuid = uuid.UUID(end_user_id)
repo = EndUserRepository(db) repo = EndUserRepository(db)
end_user = repo.get_by_id(user_uuid) end_user = repo.get_by_id(user_uuid)
@@ -1528,7 +1528,7 @@ async def analytics_graph_data(
# 基于中心节点的扩展查询 # 基于中心节点的扩展查询
node_query = f""" node_query = f"""
MATCH path = (center)-[*1..{depth}]-(connected) MATCH path = (center)-[*1..{depth}]-(connected)
WHERE center.group_id = $group_id WHERE center.end_user_id = $end_user_id
AND elementId(center) = $center_node_id AND elementId(center) = $center_node_id
WITH collect(DISTINCT center) + collect(DISTINCT connected) as all_nodes WITH collect(DISTINCT center) + collect(DISTINCT connected) as all_nodes
UNWIND all_nodes as n UNWIND all_nodes as n
@@ -1539,7 +1539,7 @@ async def analytics_graph_data(
LIMIT $limit LIMIT $limit
""" """
node_params = { node_params = {
"group_id": end_user_id, "end_user_id": end_user_id,
"center_node_id": center_node_id, "center_node_id": center_node_id,
"limit": limit "limit": limit
} }
@@ -1547,7 +1547,7 @@ async def analytics_graph_data(
# 按节点类型过滤查询 # 按节点类型过滤查询
node_query = """ node_query = """
MATCH (n) MATCH (n)
WHERE n.group_id = $group_id WHERE n.end_user_id = $end_user_id
AND labels(n)[0] IN $node_types AND labels(n)[0] IN $node_types
RETURN RETURN
elementId(n) as id, elementId(n) as id,
@@ -1556,7 +1556,7 @@ async def analytics_graph_data(
LIMIT $limit LIMIT $limit
""" """
node_params = { node_params = {
"group_id": end_user_id, "end_user_id": end_user_id,
"node_types": node_types, "node_types": node_types,
"limit": limit "limit": limit
} }
@@ -1564,7 +1564,7 @@ async def analytics_graph_data(
# 查询所有节点 # 查询所有节点
node_query = """ node_query = """
MATCH (n) MATCH (n)
WHERE n.group_id = $group_id WHERE n.end_user_id = $end_user_id
RETURN RETURN
elementId(n) as id, elementId(n) as id,
labels(n)[0] as label, labels(n)[0] as label,
@@ -1572,7 +1572,7 @@ async def analytics_graph_data(
LIMIT $limit LIMIT $limit
""" """
node_params = { node_params = {
"group_id": end_user_id, "end_user_id": end_user_id,
"limit": limit "limit": limit
} }

View File

@@ -382,12 +382,12 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
@celery_app.task(name="app.core.memory.agent.read_message", bind=True) @celery_app.task(name="app.core.memory.agent.read_message", bind=True)
def read_message_task(self, group_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]: def read_message_task(self, end_user_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]:
"""Celery task to process a read message via MemoryAgentService. """Celery task to process a read message via MemoryAgentService.
Args: Args:
group_id: Group ID for the memory agent (also used as end_user_id) end_user_id: Group ID for the memory agent (also used as end_user_id)
message: User message to process message: User message to process
history: Conversation history history: Conversation history
search_switch: Search switch parameter search_switch: Search switch parameter
@@ -408,7 +408,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
from app.services.memory_agent_service import get_end_user_connected_config from app.services.memory_agent_service import get_end_user_connected_config
db = next(get_db()) db = next(get_db())
try: try:
connected_config = get_end_user_connected_config(group_id, db) connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id") actual_config_id = connected_config.get("memory_config_id")
finally: finally:
db.close() db.close()
@@ -420,24 +420,42 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
db = next(get_db()) db = next(get_db())
try: try:
service = MemoryAgentService() service = MemoryAgentService()
return await service.read_memory(group_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id) return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id)
finally: finally:
db.close() db.close()
try: try:
result = asyncio.run(_run()) # 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
return { return {
"status": "SUCCESS", "status": "SUCCESS",
"result": result, "result": result,
"group_id": group_id, "end_user_id": end_user_id,
"config_id": config_id, "config_id": config_id,
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
"task_id": self.request.id "task_id": self.request.id
} }
except BaseException as e: except BaseException as e:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
# Handle ExceptionGroup from TaskGroup
if hasattr(e, 'exceptions'): if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages) detailed_error = "; ".join(error_messages)
@@ -446,7 +464,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
return { return {
"status": "FAILURE", "status": "FAILURE",
"error": detailed_error, "error": detailed_error,
"group_id": group_id, "end_user_id": end_user_id,
"config_id": config_id, "config_id": config_id,
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
"task_id": self.request.id "task_id": self.request.id
@@ -454,19 +472,13 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
@celery_app.task(name="app.core.memory.agent.write_message", bind=True) @celery_app.task(name="app.core.memory.agent.write_message", bind=True)
def write_message_task(self, group_id: str, message, config_id: str, storage_type: str, user_rag_memory_id: str) -> Dict[str, Any]: def write_message_task(self, end_user_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]:
"""Celery task to process a write message via MemoryAgentService. """Celery task to process a write message via MemoryAgentService.
支持两种消息格式:
1. 字符串格式向后兼容message="user: xxx\nassistant: yyy"
2. 结构化消息列表推荐message=[{"role": "user", "content": "xxx"}, {"role": "assistant", "content": "yyy"}]
Args: Args:
group_id: Group ID for the memory agent (also used as end_user_id) end_user_id: Group ID for the memory agent (also used as end_user_id)
message: Message to write (str or list[dict]) message: Message to write
config_id: Optional configuration ID config_id: Optional configuration ID
storage_type: Storage type (neo4j/rag)
user_rag_memory_id: RAG memory ID
Returns: Returns:
Dict containing the result and metadata Dict containing the result and metadata
@@ -477,7 +489,7 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ
from app.core.logging_config import get_logger from app.core.logging_config import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
logger.info(f"[CELERY WRITE] Starting write task - group_id={group_id}, config_id={config_id}, storage_type={storage_type}") logger.info(f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}")
start_time = time.time() start_time = time.time()
# Resolve config_id if None # Resolve config_id if None
@@ -487,7 +499,7 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ
from app.services.memory_agent_service import get_end_user_connected_config from app.services.memory_agent_service import get_end_user_connected_config
db = next(get_db()) db = next(get_db())
try: try:
connected_config = get_end_user_connected_config(group_id, db) connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id") actual_config_id = connected_config.get("memory_config_id")
finally: finally:
db.close() db.close()
@@ -500,7 +512,7 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ
try: try:
logger.info(f"[CELERY WRITE] Executing MemoryAgentService.write_memory") logger.info(f"[CELERY WRITE] Executing MemoryAgentService.write_memory")
service = MemoryAgentService() service = MemoryAgentService()
result = await service.write_memory(group_id, message, actual_config_id, db, storage_type, user_rag_memory_id) result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type, user_rag_memory_id)
logger.info(f"[CELERY WRITE] Write completed successfully: {result}") logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
return result return result
except Exception as e: except Exception as e:
@@ -510,7 +522,24 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ
db.close() db.close()
try: try:
result = asyncio.run(_run()) # 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
@@ -518,13 +547,14 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ
return { return {
"status": "SUCCESS", "status": "SUCCESS",
"result": result, "result": result,
"group_id": group_id, "end_user_id": end_user_id,
"config_id": config_id, "config_id": config_id,
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
"task_id": self.request.id "task_id": self.request.id
} }
except BaseException as e: except BaseException as e:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
# Handle ExceptionGroup from TaskGroup
if hasattr(e, 'exceptions'): if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages) detailed_error = "; ".join(error_messages)
@@ -536,7 +566,7 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ
return { return {
"status": "FAILURE", "status": "FAILURE",
"error": detailed_error, "error": detailed_error,
"group_id": group_id, "end_user_id": end_user_id,
"config_id": config_id, "config_id": config_id,
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
"task_id": self.request.id "task_id": self.request.id
@@ -564,53 +594,53 @@ def reflection_timer_task() -> None:
""" """
reflection_engine() reflection_engine()
# unused task
# @celery_app.task(name="app.core.memory.agent.health.check_read_service") @celery_app.task(name="app.core.memory.agent.health.check_read_service")
# def check_read_service_task() -> Dict[str, str]: def check_read_service_task() -> Dict[str, str]:
# """Call read_service and write latest status to Redis. """Call read_service and write latest status to Redis.
# Returns status data dict that gets written to Redis. Returns status data dict that gets written to Redis.
# """ """
# client = redis.Redis( client = redis.Redis(
# host=settings.REDIS_HOST, host=settings.REDIS_HOST,
# port=settings.REDIS_PORT, port=settings.REDIS_PORT,
# db=settings.REDIS_DB, db=settings.REDIS_DB,
# password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None
# ) )
# try: try:
# api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service" api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service"
# payload = { payload = {
# "user_id": "健康检查", "user_id": "健康检查",
# "apply_id": "健康检查", "apply_id": "健康检查",
# "group_id": "健康检查", "end_user_id": "健康检查",
# "message": "你好", "message": "你好",
# "history": [], "history": [],
# "search_switch": "2", "search_switch": "2",
# } }
# resp = requests.post(api_url, json=payload, timeout=15) resp = requests.post(api_url, json=payload, timeout=15)
# ok = resp.status_code == 200 ok = resp.status_code == 200
# status = "Success" if ok else "Fail" status = "Success" if ok else "Fail"
# msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}" msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}"
# error = "" if ok else resp.text error = "" if ok else resp.text
# code = 0 if ok else 500 code = 0 if ok else 500
# except Exception as e: except Exception as e:
# status = "Fail" status = "Fail"
# msg = "接口请求失败" msg = "接口请求失败"
# error = str(e) error = str(e)
# code = 500 code = 500
# data = { data = {
# "status": status, "status": status,
# "msg": msg, "msg": msg,
# "error": error, "error": error,
# "code": str(code), "code": str(code),
# "time": str(int(time.time())), "time": str(int(time.time())),
# } }
# client.hset("memsci:health:read_service", mapping=data) client.hset("memsci:health:read_service", mapping=data)
# client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS)) client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS))
# return data return data
@celery_app.task(name="app.controllers.memory_storage_controller.search_all") @celery_app.task(name="app.controllers.memory_storage_controller.search_all")
@@ -875,7 +905,24 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
} }
try: try:
result = asyncio.run(_run()) # 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
result["elapsed_time"] = elapsed_time result["elapsed_time"] = elapsed_time
result["task_id"] = self.request.id result["task_id"] = self.request.id
@@ -1002,7 +1049,24 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
} }
try: try:
result = asyncio.run(_run()) # 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
result["elapsed_time"] = elapsed_time result["elapsed_time"] = elapsed_time
result["task_id"] = self.request.id result["task_id"] = self.request.id
@@ -1048,7 +1112,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str
# 运行遗忘周期 # 运行遗忘周期
report = await forget_service.trigger_forgetting( report = await forget_service.trigger_forgetting(
db=db, db=db,
group_id=None, # 处理所有组 end_user_id=None, # 处理所有组
config_id=config_id config_id=config_id
) )
@@ -1078,4 +1142,11 @@ def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str
"duration_seconds": duration "duration_seconds": duration
} }
return asyncio.run(_run()) # 运行异步函数
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(_run())
return result
finally:
loop.close()