Merge branch 'release/v0.2.3' into develop

# Conflicts:
#	api/app/core/agent/langchain_agent.py
#	api/app/core/memory/agent/langgraph_graph/write_graph.py
#	api/app/repositories/neo4j/graph_saver.py
#	api/app/services/draft_run_service.py
This commit is contained in:
Mark
2026-02-06 14:48:50 +08:00
45 changed files with 973 additions and 850 deletions

View File

@@ -196,6 +196,11 @@ def update_config(
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 校验至少有一个字段需要更新
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
try:
svc = DataConfigService(db)

View File

@@ -52,6 +52,7 @@ from app.services.ontology_service import OntologyService
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.utils.validation.owl_validator import OWLValidator
from app.services.model_service import ModelConfigService
from app.repositories.ontology_scene_repository import OntologySceneRepository
api_logger = get_api_logger()
@@ -116,27 +117,35 @@ def _get_ontology_service(
detail=f"找不到指定的LLM模型: {llm_id}"
)
# 验证模型配置了API密钥
if not model_config.api_keys:
logger.error(f"Model {llm_id} has no API key configuration")
# 通过 Repository 获取可用的 API Key负载均衡逻辑由 Repository 处理)
from app.repositories.model_repository import ModelApiKeyRepository
api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id)
if not api_keys:
logger.error(f"Model {llm_id} has no active API key")
raise HTTPException(
status_code=400,
detail="指定的LLM模型没有配置API密钥"
detail="指定的LLM模型没有可用的API密钥"
)
api_key_config = api_keys[0]
api_key_config = model_config.api_keys[0]
is_composite = getattr(model_config, 'is_composite', False)
logger.info(
f"Using specified model - user: {current_user.id}, "
f"model_id: {llm_id}, model_name: {api_key_config.model_name}"
f"model_id: {llm_id}, model_name: {api_key_config.model_name}, "
f"is_composite: {is_composite}, api_key_id: {api_key_config.id}"
)
# 创建模型配置对象
from app.core.models.base import RedBearModelConfig
# 对于组合模型,使用 API Key 的 provider否则使用 model_config 的 provider
actual_provider = api_key_config.provider if is_composite else (
getattr(model_config, 'provider', None) or "openai"
)
llm_model_config = RedBearModelConfig(
model_name=api_key_config.model_name,
provider=model_config.provider if hasattr(model_config, 'provider') else "openai",
provider=actual_provider,
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
max_retries=3,
@@ -648,6 +657,46 @@ async def delete_scene(
return fail(BizCode.INTERNAL_ERROR, "场景删除失败", str(e))
@router.get("/scenes/simple", response_model=ApiResponse)
async def get_scenes_simple(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取场景简单列表(轻量级,用于下拉选择)
仅返回 scene_id 和 scene_name不加载关联数据响应速度快。
适用于前端下拉选择场景的场景。
Args:
db: 数据库会话
current_user: 当前用户
Returns:
ApiResponse: 包含场景简单列表
Examples:
GET /scenes/simple
返回: {"data": [{"scene_id": "xxx", "scene_name": "场景1"}, ...]}
"""
api_logger.info(f"Simple scene list requested by user {current_user.id}")
try:
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
repo = OntologySceneRepository(db)
scenes = repo.get_simple_list(workspace_id)
api_logger.info(f"Simple scene list retrieved: {len(scenes)} scenes")
return success(data=scenes, msg="查询成功")
except Exception as e:
api_logger.error(f"Failed to get simple scene list: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
@router.get("/scenes", response_model=ApiResponse)
async def get_scenes(
workspace_id: Optional[str] = None,

View File

@@ -7,30 +7,21 @@ LangChain Agent 封装
- 支持流式输出
- 使用 RedBearLLM 支持多提供商
"""
import os
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages, format_parsing, messages_parse
from app.core.memory.agent.langgraph_graph.write_graph import long_term_storage
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.db import get_db
from app.core.logging_config import get_business_logger
from app.core.memory.agent.utils.redis_tool import store
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from langchain.agents import create_agent
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
from app.utils.config_utils import resolve_config_id
logger = get_business_logger()
@@ -289,105 +280,6 @@ class LangChainAgent:
return content_parts
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
db = next(get_db())
#TODO: 魔法数字
scope=6
try:
repo = LongTermMemoryRepository(db)
await long_term_storage(long_term_type="chunk", langchain_messages=long_term_messages,
memory_config=actual_config_id, end_user_id=end_user_id, scope=scope)
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
# Handle case where no session exists in Redis (returns False)
if not result or result is False:
logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update")
return
if type=="chunk" or type=="aggregate":
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
if len(chunk_data)==scope:
repo.upsert(end_user_id, chunk_data)
logger.info(f'写入短长期:')
else:
# TODO: This branch handles type="time" strategy, currently unused.
# Will be activated when time-based long-term storage is implemented.
# TODO: 魔法数字 - extract 5 to a constant
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
# Handle case where no session exists in Redis (returns False or empty)
if not long_time_data or long_time_data is False:
logger.debug(f"No recent sessions in Redis for user {end_user_id}")
return
long_messages = await messages_parse(long_time_data)
repo.upsert(end_user_id, long_messages)
logger.info(f'写入短长期:')
finally:
db.close()
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
"""
写入记忆(支持结构化消息)
Args:
storage_type: 存储类型 (neo4j/rag)
end_user_id: 终端用户ID
user_message: 用户消息内容
ai_message: AI 回复内容
user_rag_memory_id: RAG 记忆ID
actual_end_user_id: 实际用户ID
actual_config_id: 配置ID
逻辑说明:
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
- Neo4j 模式:使用结构化消息列表
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
2. 如果只有 user_message创建单条用户消息 [user](用于历史记忆场景)
3. 每条消息会被转换为独立的 Chunk保留 speaker 字段
"""
db = next(get_db())
try:
actual_config_id=resolve_config_id(actual_config_id, db)
if storage_type == "rag":
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
else:
# Neo4j 模式:使用结构化消息列表
structured_messages = []
# 始终添加用户消息(如果不为空)
if user_message:
structured_messages.append({"role": "user", "content": user_message})
# 只有当 AI 回复不为空时才添加 assistant 消息
if ai_message:
structured_messages.append({"role": "assistant", "content": ai_message})
# 如果没有消息,直接返回
if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # end_user_id: 用户ID
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
actual_config_id, # config_id: 配置ID
storage_type, # storage_type: "neo4j"
user_rag_memory_id # user_rag_memory_id: RAG记忆IDNeo4j模式下不使用
)
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
finally:
db.close()
async def chat(
self,
message: str,
@@ -520,14 +412,7 @@ class LangChainAgent:
elapsed_time = time.time() - start_time
if memory_flag:
long_term_messages=await agent_chat_messages(message_chat,content)
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id)
response = {
"content": content,
"model": self.model_name,
@@ -710,15 +595,7 @@ class LangChainAgent:
yield total_tokens
break
if memory_flag:
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
long_term_messages = await agent_chat_messages(message_chat, full_content)
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk")
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id)
except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
raise

View File

@@ -1,8 +1,9 @@
import json
import os
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format, format_parsing
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
@@ -10,46 +11,108 @@ from app.core.memory.agent.utils.redis_tool import write_store
from app.core.memory.agent.utils.redis_tool import count_store
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.db import get_db_context, get_db
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from app.utils.config_utils import resolve_config_id
logger = get_agent_logger(__name__)
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
actual_config_id, long_term_messages=[]):
"""
写入记忆(支持结构化消息)
async def write_messages(end_user_id,langchain_messages,memory_config):
'''
写入数据到neo4j
Args:
Args:
storage_type: 存储类型 (neo4j/rag)
end_user_id: 终端用户ID
memory_config: 内存配置对象
langchain_messages原始数据LIST
'''
user_message: 用户消息内容
ai_message: AI 回复内容
user_rag_memory_id: RAG 记忆ID
actual_end_user_id: 实际用户ID
actual_config_id: 配置ID
逻辑说明:
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
- Neo4j 模式:使用结构化消息列表
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
2. 如果只有 user_message创建单条用户消息 [user](用于历史记忆场景)
3. 每条消息会被转换为独立的 Chunk保留 speaker 字段
"""
db = next(get_db())
try:
actual_config_id = resolve_config_id(actual_config_id, db)
# Neo4j 模式:使用结构化消息列表
structured_messages = []
# 始终添加用户消息(如果不为空)
if isinstance(user_message, str) and user_message.strip() != "":
structured_messages.append({"role": "user", "content": user_message})
# 只有当 AI 回复不为空时才添加 assistant 消息
if isinstance(ai_message, str) and ai_message.strip() != "":
structured_messages.append({"role": "assistant", "content": ai_message})
# 如果提供了 long_term_messages使用它替代 structured_messages
if long_term_messages and isinstance(long_term_messages, list):
structured_messages = long_term_messages
elif long_term_messages and isinstance(long_term_messages, str):
# 如果是 JSON 字符串,先解析
try:
structured_messages = json.loads(long_term_messages)
except json.JSONDecodeError:
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
# 如果没有消息,直接返回
if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
logger.info(
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # end_user_id: 用户ID
structured_messages, # message: JSON 字符串格式的消息列表
str(actual_config_id), # config_id: 配置ID字符串
storage_type, # storage_type: "neo4j"
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆IDNeo4j模式下不使用
)
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
finally:
db.close()
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
with get_db_context() as db_session:
repo = LongTermMemoryRepository(db_session)
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
if len(chunk_data)==scope:
repo.upsert(end_user_id, chunk_data)
logger.info(f'---------写入短长期-----------')
else:
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
long_messages = await messages_parse(long_time_data)
repo.upsert(end_user_id, long_messages)
logger.info(f'写入短长期:')
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {
"messages": langchain_messages,
"end_user_id": end_user_id,
"memory_config": memory_config
}
# 获取节点更新信息
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
for node_name, node_data in update_event.items():
if 'save_neo4j' == node_name:
massages = node_data
# TODO删除
massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
print(contents)
except Exception as e:
import traceback
traceback.print_exc()
'''根据窗口'''
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
'''
@@ -61,25 +124,26 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
scope窗口大小
'''
scope=scope
redis_messages = []
is_end_user_id = count_store.get_sessions_count(end_user_id)
if is_end_user_id is not False:
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
redis_messages = count_store.get_sessions_count(end_user_id)[1]
if is_end_user_id and int(is_end_user_id) != int(scope):
print(is_end_user_id)
is_end_user_id += 1
langchain_messages += redis_messages
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
elif int(is_end_user_id) == int(scope):
print('写入长期记忆并且设置为0')
print(is_end_user_id)
formatted_messages = await chat_data_format(redis_messages)
print(100*'-')
print(formatted_messages)
print(100*'-')
await write_messages(end_user_id, formatted_messages, memory_config)
count_store.update_sessions_count(end_user_id, 0, '')
logger.info('写入长期记忆NEO4J')
formatted_messages = (redis_messages)
# 获取 config_id如果 memory_config 是对象,提取 config_id否则直接使用
if hasattr(memory_config, 'config_id'):
config_id = memory_config.config_id
else:
config_id = memory_config
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
config_id, formatted_messages)
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
@@ -93,12 +157,15 @@ async def memory_long_term_storage(end_user_id,memory_config,time):
memory_config: 内存配置对象
'''
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
# Handle case where no session exists in Redis (returns False or empty)
if not long_time_data or long_time_data is False:
return
format_messages = await chat_data_format(long_time_data)
format_messages = (long_time_data)
messages=[]
memory_config=memory_config.config_id
for i in format_messages:
message=json.loads(i['Query'])
messages+= message
if format_messages!=[]:
await write_messages(end_user_id, format_messages, memory_config)
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
memory_config, messages)
'''聚合判断'''
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
"""
@@ -109,13 +176,12 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
memory_config: 内存配置对象
"""
try:
# 1. 获取历史会话数据(使用新方法)
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
# Handle case where no session exists in Redis (returns False or empty)
if not result or result is False:
history = await format_parsing(result)
if not result:
history = []
else:
history = await format_parsing(result)
@@ -154,7 +220,8 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
}
if not structured.is_same_event:
logger.info(result_dict)
await write_messages(end_user_id, output_value, memory_config)
await write("neo4j", end_user_id, "", "", None, end_user_id,
memory_config.config_id, output_value)
return result_dict
except Exception as e:

View File

@@ -186,10 +186,11 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
清理后的数据
"""
# 需要过滤的字段列表
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
fields_to_remove = {
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
'user_id', 'statement_ids', 'updated_at',"chunk_ids","fact_summary"
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
}
if isinstance(data, dict):

View File

@@ -1,8 +1,6 @@
import json
from langchain_core.messages import HumanMessage, AIMessage
async def format_parsing(messages: list,type:str='string'):
"""
格式化解析消息列表
@@ -26,13 +24,13 @@ async def format_parsing(messages: list,type:str='string'):
role = content['role']
content = content['content']
if type == "string":
if role == 'human':
if role == 'human' or role=="user":
content = '用户:' + content
else:
content = 'AI:' + content
result.append(content)
if type == "dict":
if role == 'human':
if type == "dict" :
if role == 'human' or role=="user":
user.append( content)
else:
ai.append(content)
@@ -57,33 +55,7 @@ async def messages_parse(messages: list | dict):
for key, values in zip(user, ai):
database.append({key, values})
return database
async def chat_data_format(messages: list | dict):
"""
将消息格式化为 LangChain 消息格式
Args:
messages: 消息列表或字典
Returns:
LangChain 消息列表
"""
langchain_messages = []
if isinstance(messages, list):
for msg in messages:
if 'role' in msg.keys():
if msg['role'] == 'user':
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content']))
if "Query" in msg.keys():
langchain_messages.append(HumanMessage(content=msg['Query']))
langchain_messages.append(AIMessage(content=msg['Answer']))
if isinstance(messages, dict):
if messages['type'] == 'human':
langchain_messages.append(HumanMessage(content=messages['content']))
elif messages['type'] == 'ai':
langchain_messages.append(AIMessage(content=messages['content']))
return langchain_messages
async def agent_chat_messages(user_content,ai_content):
messages = [

View File

@@ -1,13 +1,18 @@
import asyncio
import json
import sys
import warnings
from contextlib import asynccontextmanager
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from app.db import get_db, get_db_context
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__)
@@ -37,76 +42,61 @@ async def make_write_graph():
yield graph
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
end_user_id: str = '', scope: int = 6):
"""Dispatch long-term memory storage to Celery background tasks.
Args:
long_term_type: Storage strategy - 'chunk' (window), 'time', or 'aggregate'
langchain_messages: List of messages to store
memory_config: Memory configuration ID (string)
end_user_id: End user identifier
scope: Window size for 'chunk' strategy (default: 6)
"""
from app.tasks import (
long_term_storage_window_task,
# TODO: Uncomment when implemented
# long_term_storage_time_task,
# long_term_storage_aggregate_task,
)
from app.core.logging_config import get_logger
logger = get_logger(__name__)
# Convert config to string if needed
config_id = str(memory_config) if memory_config else ''
if long_term_type == 'chunk':
# Strategy 1: Window-based batching (6 rounds of dialogue)
logger.info(f"[LONG_TERM] Dispatching window task - end_user_id={end_user_id}, scope={scope}")
long_term_storage_window_task.delay(
end_user_id=end_user_id,
langchain_messages=langchain_messages,
config_id=config_id,
scope=scope
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
from app.core.memory.agent.utils.redis_tool import write_store
write_store.save_session_write(end_user_id, (langchain_messages))
# 获取数据库会话
with get_db_context() as db_session:
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=memory_config, # 改为整数
service_name="MemoryAgentService"
)
# TODO: Uncomment when time-based strategy is fully implemented
# elif long_term_type == 'time':
# # Strategy 2: Time-based retrieval
# logger.info(f"[LONG_TERM] Dispatching time task - end_user_id={end_user_id}")
# long_term_storage_time_task.delay(
# end_user_id=end_user_id,
# config_id=config_id,
# time_window=5
# )
# TODO: Uncomment when aggregate strategy is fully implemented
# elif long_term_type == 'aggregate':
# # Strategy 3: Aggregate judgment (deduplication)
# logger.info(f"[LONG_TERM] Dispatching aggregate task - end_user_id={end_user_id}")
# long_term_storage_aggregate_task.delay(
# end_user_id=end_user_id,
# langchain_messages=langchain_messages,
# config_id=config_id
# )
if long_term_type=='chunk':
'''方案一:对话窗口6轮对话'''
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
if long_term_type=='time':
"""时间"""
await memory_long_term_storage(end_user_id, memory_config,5)
if long_term_type=='aggregate':
"""方案三:聚合判断"""
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
else:
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
long_term_messages = await agent_chat_messages(message_chat, aimessages)
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
# async def main():
# """主函数 - 运行工作流"""
# langchain_messages = [
# {
# "role": "user",
# "content": "今天周五好开心啊"
# "content": "今天周五去爬山"
# },
# {
# "role": "assistant",
# "content": "你也这么觉得,我也是耶"
# "content": "耶"
# }
#
# ]
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
# # await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
# result=await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
#
#
#
# if __name__ == "__main__":

View File

@@ -294,6 +294,7 @@ class RedisCountStore:
"""
session_id = str(uuid.uuid4())
key = generate_session_key(session_id, key_type="count")
index_key = f'session:count:index:{end_user_id}' # 索引键
pipe = self.r.pipeline()
pipe.hset(key, mapping={
@@ -304,6 +305,10 @@ class RedisCountStore:
"starttime": get_current_timestamp()
})
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
# 创建索引end_user_id -> session_id 映射
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
result = pipe.execute()
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
@@ -320,31 +325,47 @@ class RedisCountStore:
list 或 False: 如果找到返回 [count, messages],否则返回 False
"""
try:
search_pattern = 'session:count:*'
# 使用索引键快速查找
index_key = f'session:count:index:{end_user_id}'
for key in self.r.keys(search_pattern):
data = self.r.hgetall(key)
if not data:
continue
if data.get('end_user_id') == end_user_id:
count = data.get('count')
messages_str = data.get('messages')
if count is not None:
messages = deserialize_messages(messages_str)
return [int(count), messages]
# 检查索引键类型,避免 WRONGTYPE 错误
try:
key_type = self.r.type(index_key)
if key_type != 'string' and key_type != 'none':
self.r.delete(index_key)
return False
except Exception as type_error:
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
session_id = self.r.get(index_key)
if not session_id:
return False
# 直接获取数据
key = generate_session_key(session_id, key_type="count")
data = self.r.hgetall(key)
if not data:
# 索引存在但数据不存在,清理索引
self.r.delete(index_key)
return False
count = data.get('count')
messages_str = data.get('messages')
if count is not None:
messages = deserialize_messages(messages_str)
return [int(count), messages]
return False
except Exception as e:
print(f"[get_sessions_count] 查询失败: {e}")
return False
def update_sessions_count(self, end_user_id: str, new_count: int,
messages: Any) -> bool:
"""
通过 end_user_id 修改访问次数统计
通过 end_user_id 修改访问次数统计(优化版:使用索引)
Args:
end_user_id: 终端用户ID
@@ -355,23 +376,39 @@ class RedisCountStore:
bool: 更新成功返回 True未找到记录返回 False
"""
try:
# 使用索引键快速查找
index_key = f'session:count:index:{end_user_id}'
# 检查索引键类型,避免 WRONGTYPE 错误
try:
key_type = self.r.type(index_key)
if key_type != 'string' and key_type != 'none':
# 索引键类型错误,删除并返回 False
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
self.r.delete(index_key)
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False
except Exception as type_error:
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
session_id = self.r.get(index_key)
if not session_id:
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False
# 直接更新数据
key = generate_session_key(session_id, key_type="count")
messages_str = serialize_messages(messages)
search_pattern = 'session:count:*'
for key in self.r.keys(search_pattern):
data = self.r.hgetall(key)
if not data:
continue
if data.get('end_user_id') == end_user_id:
self.r.hset(key, 'count', int(new_count))
self.r.hset(key, 'messages', messages_str)
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
return True
pipe = self.r.pipeline()
pipe.hset(key, 'count', int(new_count))
pipe.hset(key, 'messages', messages_str)
result = pipe.execute()
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
return True
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False
except Exception as e:
print(f"[update_sessions_count] 更新失败: {e}")
return False

View File

@@ -4,6 +4,7 @@ Write Tools for Memory Knowledge Extraction Pipeline
This module provides the main write function for executing the knowledge extraction
pipeline. Only MemoryConfig is needed - clients are constructed internally.
"""
import asyncio
import time
from datetime import datetime
@@ -124,23 +125,48 @@ async def write(
except Exception as e:
logger.error(f"Error creating indexes: {e}", exc_info=True)
# 添加死锁重试机制
max_retries = 3
retry_delay = 1 # 秒
for attempt in range(max_retries):
try:
success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=all_dialogue_nodes,
chunk_nodes=all_chunk_nodes,
statement_nodes=all_statement_nodes,
entity_nodes=all_entity_nodes,
statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges,
connector=neo4j_connector
)
if success:
logger.info("Successfully saved all data to Neo4j")
break
else:
logger.warning("Failed to save some data to Neo4j")
if attempt < max_retries - 1:
logger.info(f"Retrying... (attempt {attempt + 2}/{max_retries})")
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
except Exception as e:
error_msg = str(e)
# 检查是否是死锁错误
if "DeadlockDetected" in error_msg or "deadlock" in error_msg.lower():
if attempt < max_retries - 1:
logger.warning(f"Deadlock detected, retrying... (attempt {attempt + 2}/{max_retries})")
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
else:
logger.error(f"Failed after {max_retries} attempts due to deadlock: {e}")
raise
else:
# 非死锁错误,直接抛出
raise
try:
success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=all_dialogue_nodes,
chunk_nodes=all_chunk_nodes,
statement_nodes=all_statement_nodes,
entity_nodes=all_entity_nodes,
statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges,
connector=neo4j_connector
)
if success:
logger.info("Successfully saved all data to Neo4j")
else:
logger.warning("Failed to save some data to Neo4j")
finally:
await neo4j_connector.close()
except Exception as e:
logger.error(f"Error closing Neo4j connector: {e}")
log_time("Neo4j Database Save", time.time() - step_start, log_file)

View File

@@ -413,7 +413,8 @@ class ExtractedEntityNode(Node):
description="Entity aliases - alternative names for this entity"
)
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
fact_summary: str = Field(default="", description="Summary of the fact about this entity")
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")

View File

@@ -134,42 +134,45 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
if len(desc_b) > len(desc_a):
canonical.description = desc_b
# 合并事实摘要:统一保留一个“实体: name”行来源行去重保序
fact_a = getattr(canonical, "fact_summary", "") or ""
fact_b = getattr(ent, "fact_summary", "") or ""
def _extract_sources(txt: str) -> List[str]:
sources: List[str] = []
if not txt:
return sources
for line in str(txt).splitlines():
ln = line.strip()
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_a = getattr(canonical, "fact_summary", "") or ""
# fact_b = getattr(ent, "fact_summary", "") or ""
# def _extract_sources(txt: str) -> List[str]:
# sources: List[str] = []
# if not txt:
# return sources
# for line in str(txt).splitlines():
# ln = line.strip()
# 支持“来源:”或“来源:”前缀
m = re.match(r"^来源[:]\s*(.+)$", ln)
if m:
content = m.group(1).strip()
if content:
sources.append(content)
# m = re.match(r"^来源[:]\s*(.+)$", ln)
# if m:
# content = m.group(1).strip()
# if content:
# sources.append(content)
# 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失
if not sources and txt.strip():
sources.append(txt.strip())
return sources
# if not sources and txt.strip():
# sources.append(txt.strip())
# return sources
try:
src_a = _extract_sources(fact_a)
src_b = _extract_sources(fact_b)
seen = set()
merged_sources: List[str] = []
for s in src_a + src_b:
if s and s not in seen:
seen.add(s)
merged_sources.append(s)
if merged_sources:
name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
elif fact_b and not fact_a:
canonical.fact_summary = fact_b
# src_a = _extract_sources(fact_a)
# src_b = _extract_sources(fact_b)
# seen = set()
# merged_sources: List[str] = []
# for s in src_a + src_b:
# if s and s not in seen:
# seen.add(s)
# merged_sources.append(s)
# if merged_sources:
# name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
# canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
# elif fact_b and not fact_a:
# canonical.fact_summary = fact_b
pass
except Exception:
# 兜底:若解析失败,保留较长文本
if len(fact_b) > len(fact_a):
canonical.fact_summary = fact_b
# if len(fact_b) > len(fact_a):
# canonical.fact_summary = fact_b
pass
except Exception:
pass

View File

@@ -145,10 +145,13 @@ def _choose_canonical(a: ExtractedEntityNode, b: ExtractedEntityNode) -> int: #
# 2. 第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整)
desc_a = (getattr(a, "description", "") or "")
desc_b = (getattr(b, "description", "") or "")
fact_a = (getattr(a, "fact_summary", "") or "")
fact_b = (getattr(b, "fact_summary", "") or "")
score_a = len(desc_a) + len(fact_a)
score_b = len(desc_b) + len(fact_b)
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_a = (getattr(a, "fact_summary", "") or "")
# fact_b = (getattr(b, "fact_summary", "") or "")
# score_a = len(desc_a) + len(fact_a)
# score_b = len(desc_b) + len(fact_b)
score_a = len(desc_a)
score_b = len(desc_b)
if score_a != score_b:
return 0 if score_a >= score_b else 1
return 0
@@ -189,7 +192,8 @@ async def _judge_pair(
"entity_type": getattr(a, "entity_type", None),
"description": getattr(a, "description", None),
"aliases": getattr(a, "aliases", None) or [],
"fact_summary": getattr(a, "fact_summary", None),
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# "fact_summary": getattr(a, "fact_summary", None),
"connect_strength": getattr(a, "connect_strength", None),
}
entity_b = {
@@ -197,7 +201,8 @@ async def _judge_pair(
"entity_type": getattr(b, "entity_type", None),
"description": getattr(b, "description", None),
"aliases": getattr(b, "aliases", None) or [],
"fact_summary": getattr(b, "fact_summary", None),
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# "fact_summary": getattr(b, "fact_summary", None),
"connect_strength": getattr(b, "connect_strength", None),
}
# 5. 渲染LLM提示词用工具函数填充模板包含实体信息、上下文、输出格式
@@ -248,7 +253,8 @@ async def _judge_pair_disamb(
"entity_type": getattr(a, "entity_type", None),
"description": getattr(a, "description", None),
"aliases": getattr(a, "aliases", None) or [],
"fact_summary": getattr(a, "fact_summary", None),
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# "fact_summary": getattr(a, "fact_summary", None),
"connect_strength": getattr(a, "connect_strength", None),
}
entity_b = {
@@ -256,7 +262,8 @@ async def _judge_pair_disamb(
"entity_type": getattr(b, "entity_type", None),
"description": getattr(b, "description", None),
"aliases": getattr(b, "aliases", None) or [],
"fact_summary": getattr(b, "fact_summary", None),
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# "fact_summary": getattr(b, "fact_summary", None),
"connect_strength": getattr(b, "connect_strength", None),
}
prompt = render_entity_dedup_prompt(

View File

@@ -72,7 +72,8 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
description=row.get("description") or "",
aliases=row.get("aliases") or [],
name_embedding=row.get("name_embedding") or [],
fact_summary=row.get("fact_summary") or "",
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary=row.get("fact_summary") or "",
connect_strength=row.get("connect_strength") or "",
)

View File

@@ -1088,7 +1088,8 @@ class ExtractionOrchestrator:
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
example=getattr(entity, 'example', ''), # 新增:传递示例字段
fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
name_embedding=getattr(entity, 'name_embedding', None),

View File

@@ -296,7 +296,9 @@ def resolve_alias_cycles(entities: List[Any], cycles: Dict[str, Set[str]]) -> Li
key=lambda eid: (
_strength_rank(eid),
len(getattr(entity_by_id.get(eid), 'description', '') or ''),
len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
0 # 临时占位
),
reverse=True
)

View File

@@ -9,7 +9,8 @@
- 类型: "{{ entity_a.entity_type | default('') }}"
- 描述: "{{ entity_a.description | default('') }}"
- 别名: {{ entity_a.aliases | default([]) }}
- 摘要: "{{ entity_a.fact_summary | default('') }}"
{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #}
{# - 摘要: "{{ entity_a.fact_summary | default('') }}" #}
- 连接强弱: "{{ entity_a.connect_strength | default('') }}"
实体B:
@@ -17,7 +18,8 @@
- 类型: "{{ entity_b.entity_type | default('') }}"
- 描述: "{{ entity_b.description | default('') }}"
- 别名: {{ entity_b.aliases | default([]) }}
- 摘要: "{{ entity_b.fact_summary | default('') }}"
{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #}
{# - 摘要: "{{ entity_b.fact_summary | default('') }}" #}
- 连接强弱: "{{ entity_b.connect_strength | default('') }}"
上下文:

View File

@@ -28,7 +28,9 @@ from app.core.rag.common.float_utils import get_float
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
from app.core.rag.llm.chat_model import Base
from app.core.rag.llm.embedding_model import OpenAIEmbed
import logging
logger = logging.getLogger(__name__)
def knowledge_retrieval(
query: str,
@@ -62,7 +64,15 @@ def knowledge_retrieval(
merge_strategy = config.get("merge_strategy", "weight")
reranker_id = config.get("reranker_id")
reranker_top_k = config.get("reranker_top_k", 1024)
use_graph = config.get("use_graph", "false").lower() == "true"
# use_graph = config.get("use_graph", "false").lower() == "true"
use_graph_value = config.get("use_graph", False)
if isinstance(use_graph_value, bool):
use_graph = use_graph_value
elif isinstance(use_graph_value, str):
use_graph = use_graph_value.lower() in ("true", "1", "yes")
else:
use_graph = False
file_names_filter = []
if user_ids:
@@ -159,13 +169,29 @@ def knowledge_retrieval(
# Use the specified reranker for re-ranking
if reranker_id:
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
# use graph
try:
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
except Exception as rerank_error:
# If reranker fails, log warning and continue with original results
logger.warning(
"Reranker failed, falling back to original results",
extra={
"reranker_id": reranker_id,
"query": query,
"doc_count": len(all_results),
"error": str(rerank_error),
},
)
if use_graph:
from app.core.rag.common.settings import kg_retriever
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
if doc:
all_results.insert(0, doc)
try:
from app.core.rag.common.settings import kg_retriever
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
if doc:
all_results.insert(0, doc)
except Exception as graph_error:
print(f"Failed to retrieve from knowledge graph: {str(graph_error)}")
return all_results
except Exception as e:

View File

@@ -25,6 +25,18 @@ class ParameterExtractorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config: ParameterExtractorNodeConfig | None = None
self.response_metadata = {}
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
if self.response_metadata:
usage = self.response_metadata.get('token_usage')
if usage:
return {
"prompt_tokens": usage.get('prompt_tokens', 0),
"completion_tokens": usage.get('completion_tokens', 0),
"total_tokens": usage.get('total_tokens', 0)
}
return None
def _output_types(self) -> dict[str, VariableType]:
outputs = {}
@@ -180,6 +192,7 @@ class ParameterExtractorNode(BaseNode):
])
model_resp = await llm.ainvoke(messages)
self.response_metadata = model_resp.response_metadata
result = json_repair.repair_json(model_resp.content, return_objects=True)
logger.info(f"node: {self.node_id} get params:{result}")

View File

@@ -25,6 +25,18 @@ class QuestionClassifierNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: QuestionClassifierNodeConfig | None = None
self.category_to_case_map = {}
self.response_metadata = {}
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
if self.response_metadata:
usage = self.response_metadata.get('token_usage')
if usage:
return {
"prompt_tokens": usage.get('prompt_tokens', 0),
"completion_tokens": usage.get('completion_tokens', 0),
"total_tokens": usage.get('total_tokens', 0)
}
return None
def _output_types(self) -> dict[str, VariableType]:
return {
@@ -120,6 +132,7 @@ class QuestionClassifierNode(BaseNode):
response = await llm.ainvoke(messages)
result = response.content.strip()
self.response_metadata = response.response_metadata
if result in category_names:
category = result

View File

@@ -86,7 +86,8 @@ class MemoryConfigRepository:
n.description AS description,
n.entity_type AS entity_type,
n.name AS name,
COALESCE(n.fact_summary, '') AS fact_summary,
// TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
// COALESCE(n.fact_summary, '') AS fact_summary,
n.end_user_id AS end_user_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
@@ -279,6 +280,9 @@ class MemoryConfigRepository:
if update.config_desc is not None:
db_config.config_desc = update.config_desc
has_update = True
if update.scene_id is not None:
db_config.scene_id = update.scene_id
has_update = True
if not has_update:
raise ValueError("No fields to update")
@@ -650,28 +654,32 @@ class MemoryConfigRepository:
raise
@staticmethod
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[MemoryConfig]:
"""获取所有配置参数
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[Tuple[MemoryConfig, Optional[str]]]:
"""获取所有配置参数,包含关联的场景名称
Args:
db: 数据库会话
workspace_id: 工作空间ID用于过滤查询结果
Returns:
List[MemoryConfig]: 配置列表
List[Tuple[MemoryConfig, Optional[str]]]: 配置列表,每项为 (配置对象, 场景名称)
"""
from app.models.ontology_scene import OntologyScene
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
try:
query = db.query(MemoryConfig)
query = db.query(MemoryConfig, OntologyScene.scene_name).outerjoin(
OntologyScene, MemoryConfig.scene_id == OntologyScene.scene_id
)
if workspace_id:
query = query.filter(MemoryConfig.workspace_id == workspace_id)
configs = query.order_by(desc(MemoryConfig.updated_at)).all()
results = query.order_by(desc(MemoryConfig.updated_at)).all()
db_logger.debug(f"配置列表查询成功: 数量={len(configs)}")
return configs
db_logger.debug(f"配置列表查询成功: 数量={len(results)}")
return results
except Exception as e:
db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}")

View File

@@ -79,7 +79,8 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
try:
edges: List[dict] = []
for s in summaries:
for chunk_id in getattr(s, "chunk_ids", []) or []:
chunk_ids = getattr(s, "chunk_ids", []) or []
for chunk_id in chunk_ids:
edges.append({
"summary_id": s.id,
"chunk_id": chunk_id,
@@ -91,12 +92,11 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
if not edges:
return []
result = await connector.execute_query(
MEMORY_SUMMARY_STATEMENT_EDGE_SAVE,
edges=edges
)
created = [record.get("uuid") for record in result] if result else []
return created
except Exception:
except Exception as e:
return None

View File

@@ -217,8 +217,10 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
summaries=flattened
)
created_ids = [record.get("uuid") for record in result]
print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
return created_ids
except Exception:
except Exception as e:
print(f"Failed to save MemorySummary nodes to Neo4j: {e}")
return None

View File

@@ -101,10 +101,11 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
e.name_embedding = CASE
WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding
ELSE e.name_embedding END,
e.fact_summary = CASE
WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> ''
AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary))
THEN entity.fact_summary ELSE e.fact_summary END,
// TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
// e.fact_summary = CASE
// WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> ''
// AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary))
// THEN entity.fact_summary ELSE e.fact_summary END,
e.connect_strength = CASE
WHEN entity.connect_strength IS NULL OR entity.connect_strength = '' THEN e.connect_strength
ELSE CASE
@@ -321,7 +322,8 @@ RETURN e.id AS id,
e.description AS description,
e.aliases AS aliases,
e.name_embedding AS name_embedding,
COALESCE(e.fact_summary, '') AS fact_summary,
// TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
// COALESCE(e.fact_summary, '') AS fact_summary,
e.connect_strength AS connect_strength,
collect(DISTINCT s.id) AS statement_ids,
collect(DISTINCT c.id) AS chunk_ids,
@@ -1002,3 +1004,58 @@ RETURN DISTINCT
x.statement as statement,x.created_at as created_at
"""
Graph_Node_query = """
MATCH (n:MemorySummary)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) AS id,
labels(n) AS labels,
properties(n) AS properties,
0 AS priority
LIMIT $limit
UNION ALL
MATCH (n:Dialogue)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) AS id,
labels(n) AS labels,
properties(n) AS properties,
1 AS priority
LIMIT 1
UNION ALL
MATCH (n:Statement)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) AS id,
labels(n) AS labels,
properties(n) AS properties,
1 AS priority
LIMIT $limit
UNION ALL
MATCH (n:ExtractedEntity)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) AS id,
labels(n) AS labels,
properties(n) AS properties,
2 AS priority
LIMIT $limit
UNION ALL
MATCH (n:Chunk)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) AS id,
labels(n) AS labels,
properties(n) AS properties,
3 AS priority
LIMIT $limit
"""

View File

@@ -21,7 +21,8 @@ from app.core.memory.models.graph_models import (
ExtractedEntityNode,
EntityEntityEdge,
)
import logging
logger = logging.getLogger(__name__)
async def save_entities_and_relationships(
entity_nodes: List[ExtractedEntityNode],
entity_entity_edges: List[EntityEntityEdge],
@@ -41,8 +42,8 @@ async def save_entities_and_relationships(
'statement': edge.statement,
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
'created_at': edge.created_at.isoformat(),
'expired_at': edge.expired_at.isoformat(),
'created_at': edge.created_at.isoformat() if edge.created_at else None,
'expired_at': edge.expired_at.isoformat() if edge.expired_at else None,
'run_id': edge.run_id,
'end_user_id': edge.end_user_id,
}
@@ -147,14 +148,14 @@ async def save_statement_entity_edges(
async def save_dialog_and_statements_to_neo4j(
dialogue_nodes: List[DialogueNode],
chunk_nodes: List[ChunkNode],
statement_nodes: List[StatementNode],
entity_nodes: List[ExtractedEntityNode],
entity_edges: List[EntityEntityEdge],
statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge],
connector: Neo4jConnector
dialogue_nodes: List[DialogueNode],
chunk_nodes: List[ChunkNode],
statement_nodes: List[StatementNode],
entity_nodes: List[ExtractedEntityNode],
entity_edges: List[EntityEntityEdge],
statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge],
connector: Neo4jConnector
) -> bool:
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
@@ -171,40 +172,126 @@ async def save_dialog_and_statements_to_neo4j(
Returns:
bool: True if successful, False otherwise
"""
try:
# Save all dialogue nodes in batch
dialogue_uuids = await add_dialogue_nodes(dialogue_nodes, connector)
if dialogue_uuids:
# 定义事务函数,将所有写操作放在一个事务中
async def _save_all_in_transaction(tx):
"""在单个事务中执行所有保存操作,避免死锁"""
results = {}
# 1. Save all dialogue nodes in batch
if dialogue_nodes:
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE
dialogue_data = [node.model_dump() for node in dialogue_nodes]
result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data)
dialogue_uuids = [record["uuid"] async for record in result]
results['dialogues'] = dialogue_uuids
print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
else:
print("Failed to save dialogues to Neo4j")
return False
# Save all chunk nodes in batch
await save_chunk_nodes(chunk_nodes, connector)
# 2. Save all chunk nodes in batch
if chunk_nodes:
from app.repositories.neo4j.cypher_queries import CHUNK_NODE_SAVE
chunk_data = [node.model_dump() for node in chunk_nodes]
result = await tx.run(CHUNK_NODE_SAVE, chunks=chunk_data)
chunk_uuids = [record["uuid"] async for record in result]
results['chunks'] = chunk_uuids
logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j")
# Save all statement nodes in batch
# 3. Save all statement nodes in batch
if statement_nodes:
statement_uuids = await add_statement_nodes(statement_nodes, connector)
if statement_uuids:
print(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j")
else:
print("Failed to save statement nodes to Neo4j")
return False
else:
print("No statement nodes to save")
from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE
statement_data = [node.model_dump() for node in statement_nodes]
result = await tx.run(STATEMENT_NODE_SAVE, statements=statement_data)
statement_uuids = [record["uuid"] async for record in result]
results['statements'] = statement_uuids
logger.info(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j")
# Save entities and relationships
await save_entities_and_relationships(entity_nodes, entity_edges, connector)
print("Successfully saved entities and relationships to Neo4j")
# 4. Save entities
if entity_nodes:
from app.repositories.neo4j.cypher_queries import EXTRACTED_ENTITY_NODE_SAVE
entity_data = [entity.model_dump() for entity in entity_nodes]
result = await tx.run(EXTRACTED_ENTITY_NODE_SAVE, entities=entity_data)
entity_uuids = [record["uuid"] async for record in result]
results['entities'] = entity_uuids
logger.info(f"Successfully saved {len(entity_uuids)} entity nodes to Neo4j")
# Save new edges
await save_statement_chunk_edges(statement_chunk_edges, connector)
await save_statement_entity_edges(statement_entity_edges, connector)
# 5. Create entity relationships
if entity_edges:
from app.repositories.neo4j.cypher_queries import ENTITY_RELATIONSHIP_SAVE
relationship_data = []
for edge in entity_edges:
relationship_data.append({
'source_id': edge.source,
'target_id': edge.target,
'predicate': edge.relation_type,
'statement_id': edge.source_statement_id,
'value': edge.relation_value,
'statement': edge.statement,
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
'created_at': edge.created_at.isoformat() if edge.created_at else None,
'expired_at': edge.expired_at.isoformat() if edge.expired_at else None,
'run_id': edge.run_id,
'end_user_id': edge.end_user_id,
})
result = await tx.run(ENTITY_RELATIONSHIP_SAVE, relationships=relationship_data)
rel_uuids = [record["uuid"] async for record in result]
results['entity_relationships'] = rel_uuids
logger.info(f"Successfully saved {len(rel_uuids)} entity relationships to Neo4j")
# 6. Save statement-chunk edges
if statement_chunk_edges:
from app.repositories.neo4j.cypher_queries import CHUNK_STATEMENT_EDGE_SAVE
sc_edge_data = []
for edge in statement_chunk_edges:
sc_edge_data.append({
"id": edge.id,
"source": edge.source,
"target": edge.target,
"created_at": edge.created_at.isoformat() if edge.created_at else None,
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
"run_id": edge.run_id,
"end_user_id": edge.end_user_id,
})
result = await tx.run(CHUNK_STATEMENT_EDGE_SAVE, chunk_statement_edges=sc_edge_data)
sc_uuids = [record["uuid"] async for record in result]
results['statement_chunk_edges'] = sc_uuids
logger.info(f"Successfully saved {len(sc_uuids)} statement-chunk edges to Neo4j")
# 7. Save statement-entity edges
if statement_entity_edges:
from app.repositories.neo4j.cypher_queries import STATEMENT_ENTITY_EDGE_SAVE
se_edge_data = []
for edge in statement_entity_edges:
se_edge_data.append({
"source": edge.source,
"target": edge.target,
"created_at": edge.created_at.isoformat() if edge.created_at else None,
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
"run_id": edge.run_id,
"end_user_id": edge.end_user_id,
"connect_strength": getattr(edge, "connect_strength", "strong"),
})
result = await tx.run(STATEMENT_ENTITY_EDGE_SAVE, relationships=se_edge_data)
se_uuids = [record["uuid"] async for record in result]
results['statement_entity_edges'] = se_uuids
logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j")
return results
try:
# 使用显式写事务执行所有操作,避免死锁
results = await connector.execute_write_transaction(_save_all_in_transaction)
summary = {
key: len(value)
for key, value in results.items()
if isinstance(value, (list, tuple, set))
}
logger.info("Transaction completed. Summary: %s", summary)
logger.debug("Full transaction results: %r", results)
return True
except Exception as e:
logger.error(f"Neo4j integration error: {e}", exc_info=True)
print(f"Neo4j integration error: {e}")
print("Continuing without database storage...")
return False
return False

View File

@@ -392,3 +392,48 @@ class OntologySceneRepository:
exc_info=True
)
raise
def get_simple_list(self, workspace_id: UUID) -> List[dict]:
"""获取场景简单列表仅包含scene_id和scene_name用于下拉选择
这是一个轻量级查询不加载关联的classes响应速度快。
Args:
workspace_id: 工作空间ID
Returns:
List[dict]: 场景简单列表每项包含scene_id和scene_name
Examples:
>>> repo = OntologySceneRepository(db)
>>> scenes = repo.get_simple_list(workspace_id)
>>> # [{"scene_id": "xxx", "scene_name": "场景1"}, ...]
"""
try:
logger.debug(f"Getting simple scene list for workspace: {workspace_id}")
# 只查询需要的字段,不加载关联数据
results = self.db.query(
OntologyScene.scene_id,
OntologyScene.scene_name
).filter(
OntologyScene.workspace_id == workspace_id
).order_by(
OntologyScene.updated_at.desc()
).all()
scenes = [
{"scene_id": str(r.scene_id), "scene_name": r.scene_name}
for r in results
]
logger.info(f"Found {len(scenes)} scenes (simple list) in workspace {workspace_id}")
return scenes
except Exception as e:
logger.error(
f"Failed to get simple scene list: {str(e)}",
exc_info=True
)
raise

View File

@@ -1,3 +1,4 @@
from abc import ABC
from typing import Optional
from pydantic import BaseModel
@@ -14,4 +15,15 @@ class UserInput(BaseModel):
class Write_UserInput(BaseModel):
messages: list[dict]
end_user_id: str
config_id: Optional[str] = None
config_id: Optional[str] = None
class AgentMemory_Long_Term(ABC):
"""长期记忆配置常量"""
STORAGE_NEO4J = "neo4j"
STORAGE_RAG = "rag"
STRATEGY_AGGREGATE = "aggregate"
STRATEGY_CHUNK = "chunk"
STRATEGY_TIME = "time"
DEFAULT_SCOPE = 6

View File

@@ -248,8 +248,9 @@ class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
config_id: Union[uuid.UUID, int, str] = None
config_name: str = Field("配置名称", description="配置名称(字符串)")
config_desc: str = Field("配置描述", description="配置描述(字符串)")
config_name: Optional[str] = Field(None, description="配置名称(字符串)")
config_desc: Optional[str] = Field(None, description="配置描述(字符串)")
scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID")
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型

View File

@@ -114,6 +114,8 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
result = task_service.get_task_memory_read_result(task.id)
status = result.get("status")
logger.info(f"读取任务状态:{status}")
if memory_content:
memory_content = memory_content['answer']
finally:
db.close()
@@ -127,11 +129,6 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
"content_length": len(str(memory_content))
}
)
# 检查是否有有效内容
if not memory_content or str(memory_content).strip() == "" or "answer" in str(memory_content) and str(memory_content).count("''") > 0:
return "未找到相关的历史记忆。请直接回答用户的问题,不要再次调用此工具。"
return f"检索到以下历史记忆:\n\n{memory_content}"
except Exception as e:
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})

View File

@@ -183,11 +183,11 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# --- Read All ---
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
configs = MemoryConfigRepository.get_all(self.db, workspace_id)
results = MemoryConfigRepository.get_all(self.db, workspace_id)
# 将 ORM 对象转换为字典列表
data_list = []
for config in configs:
for config, scene_name in results:
# 安全地转换 user_id 为 int
config_id_old = None
if config.config_id_old:
@@ -209,7 +209,8 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"end_user_id": config.end_user_id,
"config_id_old": config_id_old,
"apply_id": config.apply_id,
"scene_id": config.scene_id,
"scene_id": str(config.scene_id) if config.scene_id else None,
"scene_name": scene_name, # 新增:场景名称
"llm_id": config.llm_id,
"embedding_id": config.embedding_id,
"rerank_id": config.rerank_id,
@@ -637,10 +638,9 @@ async def analytics_recent_activity_stats() -> Dict[str, Any]:
if m < 1:
latest_relative = "刚刚"
elif m < 60:
latest_relative = f"{m}分钟"
latest_relative = "一会"
else:
h = int(m // 60)
latest_relative = f"{h}小时前" if h < 24 else f"{int(h // 24)}天前"
latest_relative = "较早前"
except Exception:
pass

View File

@@ -15,6 +15,7 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.conversation_repository import ConversationRepository
from app.repositories.end_user_repository import EndUserRepository
from app.repositories.neo4j.cypher_queries import Graph_Node_query
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
from app.services.implicit_memory_service import ImplicitMemoryService
@@ -1521,7 +1522,6 @@ async def analytics_graph_data(
user_uuid = uuid.UUID(end_user_id)
repo = EndUserRepository(db)
end_user = repo.get_by_id(user_uuid)
if not end_user:
logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户")
return {
@@ -1575,21 +1575,11 @@ async def analytics_graph_data(
}
else:
# 查询所有节点
node_query = """
MATCH (n)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) as id,
labels(n)[0] as label,
properties(n) as properties
LIMIT $limit
"""
node_query=Graph_Node_query
node_params = {
"end_user_id": end_user_id,
"limit": limit
}
# 执行节点查询
node_results = await _neo4j_connector.execute_query(node_query, **node_params)
@@ -1600,9 +1590,9 @@ async def analytics_graph_data(
for record in node_results:
node_id = record["id"]
node_label = record["label"]
node_labels = record.get("labels", [])
node_label = node_labels[0] if node_labels else "Unknown"
node_props = record["properties"]
# 根据节点类型提取需要的属性字段
filtered_props = await _extract_node_properties(node_label, node_props,node_id)

View File

@@ -5,42 +5,68 @@ Shared utilities for configuration handling to avoid circular imports.
"""
from uuid import UUID
from sqlalchemy.orm import Session
import uuid as uuid_module
def resolve_config_id(config_id: UUID | int|str, db: Session) -> UUID:
def resolve_config_id(config_id: UUID | int | str, db: Session) -> UUID:
"""
解析 config_id如果是整数则通过 config_id_old 查找对应的 UUID
解析 config_id支持 UUID、UUID字符串、整数等多种格式
Args:
config_id: 配置IDUUID 或整数)
config_id: 配置IDUUID、UUID字符串 整数)
db: 数据库会话
Returns:
UUID: 解析后的配置ID
Raises:
ValueError: 当找不到对应的配置时
ValueError: 当找不到对应的配置时或格式无效时
"""
from app.models.memory_config_model import MemoryConfig
if isinstance(config_id, UUID):
# 1. 如果已经是 UUID 类型,直接返回
if isinstance(config_id, UUID):
return config_id
if isinstance(config_id, str) and len(config_id)<=6:
memory_config = db.query(MemoryConfig).filter(
MemoryConfig.config_id_old == int(config_id)
).first()
print(memory_config)
if not memory_config:
raise ValueError(f"STR 未找到 config_id_old={config_id} 对应的配置")
return memory_config.config_id
# 2. 如果是字符串类型
if isinstance(config_id, str):
config_id_stripped = config_id.strip()
# 2.1 尝试解析为 UUID标准 UUID 字符串长度为 36
try:
return uuid_module.UUID(config_id_stripped)
except ValueError:
pass
# 2.2 尝试解析为整数(用于查询 config_id_old
try:
old_id = int(config_id_stripped)
if old_id > 0:
memory_config = db.query(MemoryConfig).filter(
MemoryConfig.config_id_old == old_id
).first()
if not memory_config:
raise ValueError(f"未找到 config_id_old={old_id} 对应的配置")
return memory_config.config_id
except ValueError:
pass
# 2.3 无法解析的字符串格式
raise ValueError(f"无效的 config_id 格式: '{config_id}'(必须是 UUID 或正整数)")
# 3. 如果是整数类型,通过 config_id_old 查找
if isinstance(config_id, int):
if config_id <= 0:
raise ValueError(f"config_id 必须是正整数: {config_id}")
memory_config = db.query(MemoryConfig).filter(
MemoryConfig.config_id_old == config_id
).first()
if not memory_config:
raise ValueError(f"INT 未找到 config_id_old={config_id} 对应的配置")
raise ValueError(f"未找到 config_id_old={config_id} 对应的配置")
return memory_config.config_id
return config_id
# 4. 不支持的类型
raise ValueError(f"不支持的 config_id 类型: {type(config_id).__name__}")

View File

@@ -13,6 +13,14 @@
"@antv/layout": "^1.2.14-beta.8",
"@antv/x6": "^3.0.1",
"@antv/x6-react-shape": "^3.0.1",
"@codemirror/lang-cpp": "^6.0.3",
"@codemirror/lang-java": "^6.0.2",
"@codemirror/lang-javascript": "^6.2.4",
"@codemirror/lang-python": "^6.2.1",
"@codemirror/lang-rust": "^6.0.2",
"@codemirror/state": "^6.5.4",
"@codemirror/theme-one-dark": "^6.1.3",
"@codemirror/view": "^6.39.12",
"@dnd-kit/core": "^6.3.1",
"@dnd-kit/modifiers": "^9.0.0",
"@dnd-kit/sortable": "^10.0.0",
@@ -25,6 +33,7 @@
"antd": "^5.27.4",
"axios": "^1.12.2",
"clsx": "^2.1.1",
"codemirror": "^6.0.2",
"copy-to-clipboard": "^3.3.3",
"crypto-js": "^4.2.0",
"dayjs": "^1.11.18",
@@ -55,6 +64,7 @@
"@tailwindcss/postcss": "^4.1.14",
"@tailwindcss/typography": "^0.5.19",
"@tailwindcss/vite": "^4.1.14",
"@types/codemirror": "^5.60.17",
"@types/crypto-js": "^4.2.2",
"@types/js-yaml": "^4.0.9",
"@types/node": "^24.6.0",

View File

@@ -8,6 +8,7 @@ import { request } from '@/utils/request'
import type { Query, OntologyModalData, OntologyClassModalData, OntologyClassExtractModalData, OntologyExportModalData } from '@/views/Ontology/types'
// Scene list
export const getOntologyScenesSimpleUrl = '/memory/ontology/scenes/simple'
export const getOntologyScenesUrl = '/memory/ontology/scenes'
export const getOntologyScenesList = (data: Query) => {
return request.get(getOntologyScenesUrl, data)

View File

@@ -0,0 +1,150 @@
/*
* @Author: ZhaoYing
* @Date: 2026-02-04 17:20:52
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-04 17:20:52
*/
import { useEffect, useRef, useMemo } from 'react';
import { EditorView, basicSetup } from 'codemirror';
import { EditorState } from '@codemirror/state';
import { python } from '@codemirror/lang-python';
import { javascript } from '@codemirror/lang-javascript';
import { java } from '@codemirror/lang-java';
import { cpp } from '@codemirror/lang-cpp';
import { rust } from '@codemirror/lang-rust';
import { oneDark } from '@codemirror/theme-one-dark';
/**
* Props for the CodeMirrorEditor component
* @property {string} value - The initial code content to display in the editor
* @property {string} language - Programming language for syntax highlighting (python, python3, javascript, typescript, java, cpp, c, rust)
* @property {function} onChange - Callback function triggered when editor content changes, receives the new code value
* @property {string} theme - Editor theme, either 'light' or 'dark'
* @property {boolean} readOnly - Whether the editor is read-only
* @property {string} height - Custom height for the editor
* @property {string} size - Predefined size preset: 'default' (120px min-height, 14px font) or 'small' (60px min-height, 12px font)
*/
interface CodeMirrorEditorProps {
value?: string;
language?: 'python' | 'python3' | 'javascript' | 'typescript' | 'java' | 'cpp' | 'c' | 'rust';
onChange?: (value: string) => void;
theme?: 'light' | 'dark';
readOnly?: boolean;
height?: string;
size?: 'default' | 'small';
}
/**
* Map of language identifiers to their corresponding CodeMirror language extensions
* Supports multiple programming languages with syntax highlighting
*/
const languageExtensions: Record<string, any> = {
python: python(),
python3: python(),
javascript: javascript(),
typescript: javascript({ typescript: true }),
java: java(),
cpp: cpp(),
c: cpp(),
rust: rust(),
};
/**
* CodeMirrorEditor - A React wrapper component for CodeMirror 6 editor
* Provides a code editor with syntax highlighting, theme support, and customizable sizing
* Used in workflow code execution nodes for editing Python and JavaScript code
*/
const CodeMirrorEditor = ({
value = '',
language = 'javascript',
onChange,
theme = 'light',
readOnly = false,
size,
}: CodeMirrorEditorProps) => {
// Reference to the DOM element that will contain the editor
const editorRef = useRef<HTMLDivElement>(null);
// Reference to the CodeMirror EditorView instance
const viewRef = useRef<EditorView | null>(null);
/**
* Initialize CodeMirror editor when component mounts or when language/theme/readOnly changes
* Sets up extensions for syntax highlighting, change listeners, and theme
*/
useEffect(() => {
if (!editorRef.current) return;
// Get the appropriate language extension, fallback to JavaScript if not found
const langExtension = languageExtensions[language] || languageExtensions.javascript;
// Configure editor extensions
const extensions = [
basicSetup, // Basic editor features (line numbers, bracket matching, etc.)
langExtension, // Language-specific syntax highlighting
// Listen for document changes and trigger onChange callback
EditorView.updateListener.of((update) => {
if (update.docChanged && onChange) {
onChange(update.state.doc.toString());
}
}),
EditorState.readOnly.of(readOnly), // Set read-only mode
];
// Apply dark theme if specified
if (theme === 'dark') {
extensions.push(oneDark);
}
// Create editor state with initial value and extensions
const state = EditorState.create({
doc: value,
extensions,
});
// Create and mount the editor view
viewRef.current = new EditorView({
state,
parent: editorRef.current,
});
// Cleanup: destroy editor instance when component unmounts or dependencies change
return () => {
viewRef.current?.destroy();
};
}, [language, theme, readOnly]);
/**
* Update editor content when the value prop changes externally
* Only updates if the new value differs from current editor content
*/
useEffect(() => {
if (viewRef.current && value !== viewRef.current.state.doc.toString()) {
viewRef.current.dispatch({
changes: {
from: 0,
to: viewRef.current.state.doc.length,
insert: value,
},
});
}
}, [value]);
// Calculate minimum height based on size prop: small (60px) or default (120px)
const minHeight = useMemo(() => {
return `${size === 'small' ? 60 : 120}px`
}, [size])
// Calculate font size based on size prop: small (12px) or default (14px)
const fontSize = useMemo(() => {
return `${size === 'small' ? 12 : 14}px`
}, [size])
// Calculate line height based on size prop: small (16px) or default (20px)
const lineHeight = useMemo(() => {
return `${size === 'small' ? 16 : 20}px`
}, [size])
return <div ref={editorRef} style={{ minHeight, fontSize, lineHeight }} />;
};
export default CodeMirrorEditor;

View File

@@ -81,7 +81,7 @@ const components = {
audio: ({ src, ...props }: any) => <AudioBlock node={{ children: [{ properties: { src: src || '' } }] }} {...props} />,
a: ({ href, children, ...props }: any) => <Link href={href || '#'} {...props}>{children}</Link>,
button: ({ children }: any) => <RbButton node={{ children }}>{[children]}</RbButton>,
table: ({ children, ...props }: any) => <table className="rb:border rb:border-[#D9D9D9] rb:mb-2" {...props}>{children}</table>,
table: ({ children, ...props }: any) => <div className="rb:overflow-x-auto rb:max-w-full"><table className="rb:border rb:border-[#D9D9D9] rb:mb-2" {...props}>{children}</table></div>,
tr: ({ children, ...props }: any) => <tr className="rb:border rb:border-[#D9D9D9]" {...props}>{children}</tr>,
th: ({ children, ...props }: any) => <th className="rb:border rb:border-[#D9D9D9] rb:px-2 rb:py-1 rb:text-left rb:font-bold" {...props}>{children}</th>,
td: ({ children, ...props }: any) => <td className="rb:border rb:border-[#D9D9D9] rb:px-2 rb:py-1 rb:text-left" {...props}>{children}</td>,

View File

@@ -180,4 +180,9 @@ body {
.x6-node foreignObject > body {
min-height: 100%;
max-height: 100%;
}
.ͼ2 .cm-gutters {
background-color: #FFFFFF;
border: none;
}

View File

@@ -140,7 +140,7 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi
title={t('application.knowledgeBaseAssociation')}
extra={
<Space>
<Button style={{ padding: '0 8px', height: '24px' }} onClick={handleKnowledgeConfig}>{t('workflow.config.knowledge-retrieval.recallConfig')}</Button>
<Button style={{ padding: '0 8px', height: '24px' }} onClick={handleKnowledgeConfig}>{t('application.globalConfig')}</Button>
<Button style={{ padding: '0 8px', height: '24px' }} onClick={handleAddKnowledge}>+</Button>
</Space>
}

View File

@@ -16,7 +16,7 @@ import { useTranslation } from 'react-i18next';
import type { MemoryFormData, Memory, MemoryFormRef } from '../types';
import RbModal from '@/components/RbModal'
import { createMemoryConfig, updateMemoryConfig } from '@/api/memory'
import { getOntologyScenesUrl } from '@/api/ontology'
import { getOntologyScenesSimpleUrl } from '@/api/ontology'
import CustomSelect from '@/components/CustomSelect';
const FormItem = Form.Item;
@@ -129,8 +129,7 @@ const MemoryForm = forwardRef<MemoryFormRef, MemoryFormProps>(({
>
<CustomSelect
placeholder={t('common.pleaseSelect')}
url={getOntologyScenesUrl}
params={{ pagesize: 100, page: 1 }}
url={getOntologyScenesSimpleUrl}
hasAll={false}
valueKey='scene_id'
labelKey="scene_name"

View File

@@ -112,7 +112,7 @@ const MemoryManagement: React.FC = () => {
title={item.config_name}
>
<Tooltip title={item.config_desc}>
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.25 rb:font-regular rb:-mt-1 rb:wrap-break-word rb:line-clamp-1">{item.config_desc}</div>
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.25 rb:font-regular rb:-mt-1 rb:wrap-break-word rb:line-clamp-1 rb:h-[17px]">{item.config_desc}</div>
</Tooltip>
<RbAlert className="rb:mt-3 ">
<div className={clsx("rb:flex rb:gap-5 rb:font-regular rb:text-[14px]")}>

View File

@@ -103,9 +103,9 @@ const MultiKeyConfigModal = forwardRef<MultiKeyConfigModalRef, MultiKeyConfigMod
{model.api_keys && model.api_keys.length > 0 && (
<div className="rb:mb-4">
{model.api_keys.map((key) => (
<div key={key.id} className="rb:flex rb:items-center rb:justify-between rb:p-3 rb:bg-[#F5F6F7] rb:rounded-lg rb:mb-2">
<div>
<div className="rb:text-[#1D2129] rb:text-[14px] rb:font-medium">{key.api_key}</div>
<div key={key.id} className="rb:flex rb:gap-3 rb:items-center rb:justify-between rb:p-3 rb:bg-[#F5F6F7] rb:rounded-lg rb:mb-2">
<div className="rb:flex-1">
<div className="rb:text-[#1D2129] rb:text-[14px] rb:font-medium rb:break-all">{key.api_key}</div>
<div className="rb:text-[#5B6167] rb:text-[12px] rb:mt-1">{key.api_base}</div>
</div>
<Button type="primary" danger ghost onClick={() => handleDelete(key.id)}>{t('common.remove')}</Button>

View File

@@ -15,8 +15,6 @@ import CharacterCountPlugin from './plugin/CharacterCountPlugin'
import InitialValuePlugin from './plugin/InitialValuePlugin';
import CommandPlugin from './plugin/CommandPlugin';
import Jinja2HighlightPlugin from './plugin/Jinja2HighlightPlugin';
import Python3HighlightPlugin from './plugin/Python3HighlightPlugin';
import JavaScriptHighlightPlugin from './plugin/JavaScriptHighlightPlugin';
import LineNumberPlugin from './plugin/LineNumberPlugin';
import BlurPlugin from './plugin/BlurPlugin';
import { VariableNode } from './nodes/VariableNode'
@@ -32,7 +30,7 @@ export interface LexicalEditorProps {
lineHeight?: number;
size?: 'default' | 'small';
type?: 'input' | 'textarea',
language?: 'string' | 'jinja2' | 'python3' | 'javascript'
language?: 'string' | 'jinja2'
}
const theme = {
@@ -67,7 +65,7 @@ const Editor: FC<LexicalEditorProps> =({
const [enableLineNumbers, setEnableLineNumbers] = useState(false)
useEffect(() => {
const needsLineNumbers = language === 'jinja2' || language === 'python3' || language === 'javascript';
const needsLineNumbers = language === 'jinja2';
setEnableJinja2(language === 'jinja2');
setEnableLineNumbers(needsLineNumbers);
@@ -237,13 +235,11 @@ const Editor: FC<LexicalEditorProps> =({
<HistoryPlugin />
<CommandPlugin />
{language === 'jinja2' && <Jinja2HighlightPlugin />}
{language === 'python3' && <Python3HighlightPlugin />}
{language === 'javascript' && <JavaScriptHighlightPlugin />}
{enableLineNumbers && <LineNumberPlugin />}
<AutocompletePlugin options={options} enableJinja2={enableJinja2} />
<CharacterCountPlugin setCount={(count) => { setCount(count) }} onChange={onChange} />
<InitialValuePlugin key={language} value={value} options={options} enableLineNumbers={enableLineNumbers} />
{enableLineNumbers && <BlurPlugin />}
<InitialValuePlugin value={value} options={options} enableLineNumbers={enableLineNumbers} />
{enableJinja2 && <BlurPlugin />}
</div>
</LexicalComposer>
);

View File

@@ -1,182 +0,0 @@
import { useEffect, useRef } from 'react';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { TextNode, $createTextNode, $getSelection, $isRangeSelection, COMMAND_PRIORITY_LOW, PASTE_COMMAND } from 'lexical';
const JS_KEYWORDS = new Set([
'async', 'await', 'break', 'case', 'catch', 'class', 'const', 'continue', 'debugger', 'default',
'delete', 'do', 'else', 'export', 'extends', 'finally', 'for', 'function', 'if', 'import',
'in', 'instanceof', 'let', 'new', 'return', 'super', 'switch', 'this', 'throw', 'try',
'typeof', 'var', 'void', 'while', 'with', 'yield', 'true', 'false', 'null', 'undefined'
]);
const JavaScriptHighlightPlugin = () => {
const [editor] = useLexicalComposerContext();
const isPastingRef = useRef(false);
useEffect(() => {
return editor.registerCommand(
PASTE_COMMAND,
() => {
isPastingRef.current = true;
setTimeout(() => {
isPastingRef.current = false;
}, 100);
return false;
},
COMMAND_PRIORITY_LOW
);
}, [editor]);
useEffect(() => {
return editor.registerNodeTransform(TextNode, (textNode: TextNode) => {
if (isPastingRef.current) return;
const text = textNode.getTextContent();
if (textNode.hasFormat('code')) return;
if (!needsHighlight(text)) return;
if (textNode.getStyle()) return;
const parent = textNode.getParent();
if (!parent) return;
const selection = $getSelection();
let selectionOffset = null;
if ($isRangeSelection(selection)) {
const anchor = selection.anchor;
if (anchor.getNode() === textNode) {
selectionOffset = anchor.offset;
}
}
const tokens = tokenizeJavaScript(text);
if (tokens.length <= 1) return;
const newNodes = tokens.map(token => {
const newNode = $createTextNode(token.text);
newNode.toggleFormat('code');
switch (token.type) {
case 'keyword':
newNode.setStyle('color: #d73a49; font-weight: 600;');
break;
case 'string':
newNode.setStyle('color: #032f62;');
break;
case 'comment':
newNode.setStyle('color: #6a737d; font-style: italic;');
break;
case 'number':
newNode.setStyle('color: #005cc5; font-weight: 500;');
break;
case 'function':
newNode.setStyle('color: #6f42c1; font-weight: 500;');
break;
}
return newNode;
});
if (newNodes.length > 1) {
textNode.replace(newNodes[0]);
for (let i = 1; i < newNodes.length; i++) {
newNodes[i - 1].insertAfter(newNodes[i]);
}
if (selectionOffset !== null && $isRangeSelection(selection)) {
let currentOffset = 0;
for (const node of newNodes) {
const nodeLength = node.getTextContent().length;
if (currentOffset + nodeLength >= selectionOffset) {
node.select(selectionOffset - currentOffset, selectionOffset - currentOffset);
break;
}
currentOffset += nodeLength;
}
}
}
});
}, [editor]);
return null;
};
function needsHighlight(text: string): boolean {
return /[a-zA-Z0-9_/"'`]/.test(text);
}
function tokenizeJavaScript(text: string): Array<{text: string, type: string}> {
const tokens: Array<{text: string, type: string}> = [];
let i = 0;
while (i < text.length) {
// Single-line comments
if (text.slice(i, i + 2) === '//') {
let start = i;
while (i < text.length && text[i] !== '\n') i++;
tokens.push({ text: text.slice(start, i), type: 'comment' });
continue;
}
// Multi-line comments
if (text.slice(i, i + 2) === '/*') {
let start = i;
i += 2;
while (i < text.length && text.slice(i, i + 2) !== '*/') i++;
if (i < text.length) i += 2;
tokens.push({ text: text.slice(start, i), type: 'comment' });
continue;
}
// Strings
if (text[i] === '"' || text[i] === "'" || text[i] === '`') {
const quote = text[i];
let start = i++;
while (i < text.length) {
if (text[i] === quote && text[i - 1] !== '\\') {
i++;
break;
}
i++;
}
tokens.push({ text: text.slice(start, i), type: 'string' });
continue;
}
// Numbers
if (/\d/.test(text[i])) {
let start = i;
while (i < text.length && /[\d.]/.test(text[i])) i++;
tokens.push({ text: text.slice(start, i), type: 'number' });
continue;
}
// Keywords and identifiers
if (/[a-zA-Z_$]/.test(text[i])) {
let start = i;
while (i < text.length && /[a-zA-Z0-9_$]/.test(text[i])) i++;
const word = text.slice(start, i);
if (JS_KEYWORDS.has(word)) {
tokens.push({ text: word, type: 'keyword' });
} else if (i < text.length && text[i] === '(') {
tokens.push({ text: word, type: 'function' });
} else {
tokens.push({ text: word, type: 'text' });
}
continue;
}
// Other characters
let start = i;
while (i < text.length && !/[a-zA-Z0-9_$/"'`]/.test(text[i])) i++;
if (start < i) {
tokens.push({ text: text.slice(start, i), type: 'text' });
}
}
return tokens;
}
export default JavaScriptHighlightPlugin;

View File

@@ -1,177 +0,0 @@
import { useEffect, useRef } from 'react';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { TextNode, $createTextNode, $getSelection, $isRangeSelection, COMMAND_PRIORITY_LOW, PASTE_COMMAND } from 'lexical';
const PYTHON_KEYWORDS = new Set([
'False', 'None', 'True', 'and', 'as', 'assert', 'async', 'await', 'break', 'class', 'continue',
'def', 'del', 'elif', 'else', 'except', 'finally', 'for', 'from', 'global', 'if', 'import',
'in', 'is', 'lambda', 'nonlocal', 'not', 'or', 'pass', 'raise', 'return', 'try', 'while',
'with', 'yield'
]);
const Python3HighlightPlugin = () => {
const [editor] = useLexicalComposerContext();
const isPastingRef = useRef(false);
useEffect(() => {
return editor.registerCommand(
PASTE_COMMAND,
() => {
isPastingRef.current = true;
setTimeout(() => {
isPastingRef.current = false;
}, 100);
return false;
},
COMMAND_PRIORITY_LOW
);
}, [editor]);
useEffect(() => {
return editor.registerNodeTransform(TextNode, (textNode: TextNode) => {
if (isPastingRef.current) return;
const text = textNode.getTextContent();
if (textNode.hasFormat('code')) return;
if (textNode.getStyle()) return;
if (!needsHighlight(text)) return;
const parent = textNode.getParent();
if (!parent) return;
const selection = $getSelection();
let selectionOffset = null;
if ($isRangeSelection(selection)) {
const anchor = selection.anchor;
if (anchor.getNode() === textNode) {
selectionOffset = anchor.offset;
}
}
const tokens = tokenizePython(text);
if (tokens.length <= 1) return;
const newNodes = tokens.map(token => {
const newNode = $createTextNode(token.text);
newNode.toggleFormat('code');
switch (token.type) {
case 'keyword':
newNode.setStyle('color: #d73a49; font-weight: 600;');
break;
case 'string':
newNode.setStyle('color: #032f62;');
break;
case 'comment':
newNode.setStyle('color: #6a737d; font-style: italic;');
break;
case 'number':
newNode.setStyle('color: #005cc5; font-weight: 500;');
break;
case 'function':
newNode.setStyle('color: #6f42c1; font-weight: 500;');
break;
}
return newNode;
});
if (newNodes.length > 1) {
textNode.replace(newNodes[0]);
for (let i = 1; i < newNodes.length; i++) {
newNodes[i - 1].insertAfter(newNodes[i]);
}
if (selectionOffset !== null && $isRangeSelection(selection)) {
let currentOffset = 0;
for (const node of newNodes) {
const nodeLength = node.getTextContent().length;
if (currentOffset + nodeLength >= selectionOffset) {
node.select(selectionOffset - currentOffset, selectionOffset - currentOffset);
break;
}
currentOffset += nodeLength;
}
}
}
});
}, [editor]);
return null;
};
function needsHighlight(text: string): boolean {
return /[a-zA-Z0-9_#"']/.test(text);
}
function tokenizePython(text: string): Array<{text: string, type: string}> {
const tokens: Array<{text: string, type: string}> = [];
let i = 0;
while (i < text.length) {
// Comments
if (text[i] === '#') {
let start = i;
while (i < text.length && text[i] !== '\n') i++;
tokens.push({ text: text.slice(start, i), type: 'comment' });
continue;
}
// Strings
if (text[i] === '"' || text[i] === "'") {
const quote = text[i];
let start = i++;
const isTriple = text.slice(start, start + 3) === quote.repeat(3);
if (isTriple) i += 2;
while (i < text.length) {
if (isTriple && text.slice(i, i + 3) === quote.repeat(3)) {
i += 3;
break;
} else if (!isTriple && text[i] === quote && text[i - 1] !== '\\') {
i++;
break;
}
i++;
}
tokens.push({ text: text.slice(start, i), type: 'string' });
continue;
}
// Numbers
if (/\d/.test(text[i])) {
let start = i;
while (i < text.length && /[\d.]/.test(text[i])) i++;
tokens.push({ text: text.slice(start, i), type: 'number' });
continue;
}
// Keywords and identifiers
if (/[a-zA-Z_]/.test(text[i])) {
let start = i;
while (i < text.length && /[a-zA-Z0-9_]/.test(text[i])) i++;
const word = text.slice(start, i);
if (PYTHON_KEYWORDS.has(word)) {
tokens.push({ text: word, type: 'keyword' });
} else if (i < text.length && text[i] === '(') {
tokens.push({ text: word, type: 'function' });
} else {
tokens.push({ text: word, type: 'text' });
}
continue;
}
// Other characters
let start = i;
while (i < text.length && !/[a-zA-Z0-9_#"']/.test(text[i])) i++;
if (start < i) {
tokens.push({ text: text.slice(start, i), type: 'text' });
}
}
return tokens;
}
export default Python3HighlightPlugin;

View File

@@ -5,8 +5,8 @@ import { Node } from '@antv/x6'
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
import MappingList from '../MappingList'
import Editor from '../../Editor'
import OutputList from './OutputList'
import CodeMirrorEditor from '@/components/CodeMirrorEditor';
interface MappingItem {
name?: string
@@ -110,7 +110,10 @@ const CodeExecution: FC<CodeExecutionProps> = ({ options }) => {
<Form.Item noStyle shouldUpdate={(prev, curr) => prev.language !== curr.language}>
{() => (
<Form.Item name="code" noStyle>
<Editor size="small" language={form.getFieldValue('language')} />
<CodeMirrorEditor
language={form.getFieldValue('language')}
size="small"
/>
</Form.Item>
)}
</Form.Item>

View File

@@ -126,7 +126,7 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi
<div
className="rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/workflow/recall.svg')] rb:group-hover:bg-[url('@/assets/images/workflow/recall_hover.svg')]"
></div>
{t('workflow.config.knowledge-retrieval.recallConfig')}
{t('application.globalConfig')}
</Button>
</div>