Merge branch 'release/v0.2.3' into develop

# Conflicts:
#	api/app/core/agent/langchain_agent.py
#	api/app/core/memory/agent/langgraph_graph/write_graph.py
#	api/app/repositories/neo4j/graph_saver.py
#	api/app/services/draft_run_service.py
This commit is contained in:
Mark
2026-02-06 14:48:50 +08:00
45 changed files with 973 additions and 850 deletions

View File

@@ -196,6 +196,11 @@ def update_config(
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 校验至少有一个字段需要更新
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}") api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
try: try:
svc = DataConfigService(db) svc = DataConfigService(db)

View File

@@ -52,6 +52,7 @@ from app.services.ontology_service import OntologyService
from app.core.memory.llm_tools.openai_client import OpenAIClient from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.utils.validation.owl_validator import OWLValidator from app.core.memory.utils.validation.owl_validator import OWLValidator
from app.services.model_service import ModelConfigService from app.services.model_service import ModelConfigService
from app.repositories.ontology_scene_repository import OntologySceneRepository
api_logger = get_api_logger() api_logger = get_api_logger()
@@ -116,27 +117,35 @@ def _get_ontology_service(
detail=f"找不到指定的LLM模型: {llm_id}" detail=f"找不到指定的LLM模型: {llm_id}"
) )
# 验证模型配置了API密钥 # 通过 Repository 获取可用的 API Key负载均衡逻辑由 Repository 处理)
if not model_config.api_keys: from app.repositories.model_repository import ModelApiKeyRepository
logger.error(f"Model {llm_id} has no API key configuration") api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id)
if not api_keys:
logger.error(f"Model {llm_id} has no active API key")
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="指定的LLM模型没有配置API密钥" detail="指定的LLM模型没有可用的API密钥"
) )
api_key_config = api_keys[0]
api_key_config = model_config.api_keys[0] is_composite = getattr(model_config, 'is_composite', False)
logger.info( logger.info(
f"Using specified model - user: {current_user.id}, " f"Using specified model - user: {current_user.id}, "
f"model_id: {llm_id}, model_name: {api_key_config.model_name}" f"model_id: {llm_id}, model_name: {api_key_config.model_name}, "
f"is_composite: {is_composite}, api_key_id: {api_key_config.id}"
) )
# 创建模型配置对象 # 创建模型配置对象
from app.core.models.base import RedBearModelConfig from app.core.models.base import RedBearModelConfig
# 对于组合模型,使用 API Key 的 provider否则使用 model_config 的 provider
actual_provider = api_key_config.provider if is_composite else (
getattr(model_config, 'provider', None) or "openai"
)
llm_model_config = RedBearModelConfig( llm_model_config = RedBearModelConfig(
model_name=api_key_config.model_name, model_name=api_key_config.model_name,
provider=model_config.provider if hasattr(model_config, 'provider') else "openai", provider=actual_provider,
api_key=api_key_config.api_key, api_key=api_key_config.api_key,
base_url=api_key_config.api_base, base_url=api_key_config.api_base,
max_retries=3, max_retries=3,
@@ -648,6 +657,46 @@ async def delete_scene(
return fail(BizCode.INTERNAL_ERROR, "场景删除失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "场景删除失败", str(e))
@router.get("/scenes/simple", response_model=ApiResponse)
async def get_scenes_simple(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return a lightweight scene list, intended for drop-down selectors.

    The list is whatever ``OntologySceneRepository.get_simple_list`` yields
    for the caller's current workspace (per the original description: only
    ``scene_id`` and ``scene_name``, without loading related data).

    Args:
        db: Database session.
        current_user: The authenticated user; its ``current_workspace_id``
            selects the workspace to query.

    Returns:
        ApiResponse: A success payload wrapping the scene list, or a failure
        payload when the user has no current workspace or the query raises.

    Examples:
        GET /scenes/simple
        -> {"data": [{"scene_id": "xxx", "scene_name": "..."}, ...]}
    """
    api_logger.info(f"Simple scene list requested by user {current_user.id}")

    try:
        workspace_id = current_user.current_workspace_id
        if workspace_id:
            # Happy path first: query the repository and wrap the result.
            scene_items = OntologySceneRepository(db).get_simple_list(workspace_id)
            api_logger.info(f"Simple scene list retrieved: {len(scene_items)} scenes")
            return success(data=scene_items, msg="查询成功")

        # No workspace selected: report a bad request instead of querying.
        api_logger.warning(f"User {current_user.id} has no current workspace")
        return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
    except Exception as exc:
        # Endpoint boundary: log with traceback and return a failure payload.
        api_logger.error(f"Failed to get simple scene list: {str(exc)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "查询失败", str(exc))
@router.get("/scenes", response_model=ApiResponse) @router.get("/scenes", response_model=ApiResponse)
async def get_scenes( async def get_scenes(
workspace_id: Optional[str] = None, workspace_id: Optional[str] = None,

View File

@@ -7,30 +7,21 @@ LangChain Agent 封装
- 支持流式输出 - 支持流式输出
- 使用 RedBearLLM 支持多提供商 - 使用 RedBearLLM 支持多提供商
""" """
import os
import time import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages, format_parsing, messages_parse from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.core.memory.agent.langgraph_graph.write_graph import long_term_storage
from app.db import get_db from app.db import get_db
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.memory.agent.utils.redis_tool import store
from app.core.models import RedBearLLM, RedBearModelConfig from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType from app.models.models_model import ModelType
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.services.memory_agent_service import ( from app.services.memory_agent_service import (
get_end_user_connected_config, get_end_user_connected_config,
) )
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from langchain.agents import create_agent from langchain.agents import create_agent
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from app.utils.config_utils import resolve_config_id
logger = get_business_logger() logger = get_business_logger()
@@ -289,105 +280,6 @@ class LangChainAgent:
return content_parts return content_parts
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
db = next(get_db())
#TODO: 魔法数字
scope=6
try:
repo = LongTermMemoryRepository(db)
await long_term_storage(long_term_type="chunk", langchain_messages=long_term_messages,
memory_config=actual_config_id, end_user_id=end_user_id, scope=scope)
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
# Handle case where no session exists in Redis (returns False)
if not result or result is False:
logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update")
return
if type=="chunk" or type=="aggregate":
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
if len(chunk_data)==scope:
repo.upsert(end_user_id, chunk_data)
logger.info(f'写入短长期:')
else:
# TODO: This branch handles type="time" strategy, currently unused.
# Will be activated when time-based long-term storage is implemented.
# TODO: 魔法数字 - extract 5 to a constant
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
# Handle case where no session exists in Redis (returns False or empty)
if not long_time_data or long_time_data is False:
logger.debug(f"No recent sessions in Redis for user {end_user_id}")
return
long_messages = await messages_parse(long_time_data)
repo.upsert(end_user_id, long_messages)
logger.info(f'写入短长期:')
finally:
db.close()
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
"""
写入记忆(支持结构化消息)
Args:
storage_type: 存储类型 (neo4j/rag)
end_user_id: 终端用户ID
user_message: 用户消息内容
ai_message: AI 回复内容
user_rag_memory_id: RAG 记忆ID
actual_end_user_id: 实际用户ID
actual_config_id: 配置ID
逻辑说明:
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
- Neo4j 模式:使用结构化消息列表
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
2. 如果只有 user_message创建单条用户消息 [user](用于历史记忆场景)
3. 每条消息会被转换为独立的 Chunk保留 speaker 字段
"""
db = next(get_db())
try:
actual_config_id=resolve_config_id(actual_config_id, db)
if storage_type == "rag":
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
else:
# Neo4j 模式:使用结构化消息列表
structured_messages = []
# 始终添加用户消息(如果不为空)
if user_message:
structured_messages.append({"role": "user", "content": user_message})
# 只有当 AI 回复不为空时才添加 assistant 消息
if ai_message:
structured_messages.append({"role": "assistant", "content": ai_message})
# 如果没有消息,直接返回
if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # end_user_id: 用户ID
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
actual_config_id, # config_id: 配置ID
storage_type, # storage_type: "neo4j"
user_rag_memory_id # user_rag_memory_id: RAG记忆IDNeo4j模式下不使用
)
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
finally:
db.close()
async def chat( async def chat(
self, self,
message: str, message: str,
@@ -520,14 +412,7 @@ class LangChainAgent:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
if memory_flag: if memory_flag:
long_term_messages=await agent_chat_messages(message_chat,content) await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id)
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
response = { response = {
"content": content, "content": content,
"model": self.model_name, "model": self.model_name,
@@ -710,15 +595,7 @@ class LangChainAgent:
yield total_tokens yield total_tokens
break break
if memory_flag: if memory_flag:
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable. await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id)
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
long_term_messages = await agent_chat_messages(message_chat, full_content)
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk")
except Exception as e: except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
raise raise

View File

@@ -1,8 +1,9 @@
import json
import os import os
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format, format_parsing from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
@@ -10,46 +11,108 @@ from app.core.memory.agent.utils.redis_tool import write_store
from app.core.memory.agent.utils.redis_tool import count_store from app.core.memory.agent.utils.redis_tool import count_store
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context from app.db import get_db_context, get_db
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from app.utils.config_utils import resolve_config_id
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
    """Write one user/assistant exchange to RAG memory as a single string.

    Args:
        end_user_id: End-user identifier passed through to ``write_rag``.
        user_message: Text of the user's message.
        ai_message: Text of the assistant's reply.
        user_rag_memory_id: RAG memory identifier passed through to ``write_rag``.
    """
    # RAG mode keeps the legacy format: both turns joined into one string.
    transcript = f"user: {user_message}\nassistant: {ai_message}"
    await write_rag(end_user_id, transcript, user_rag_memory_id)
    logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
actual_config_id, long_term_messages=[]):
"""
写入记忆(支持结构化消息)
async def write_messages(end_user_id,langchain_messages,memory_config): Args:
''' storage_type: 存储类型 (neo4j/rag)
写入数据到neo4j
Args:
end_user_id: 终端用户ID end_user_id: 终端用户ID
memory_config: 内存配置对象 user_message: 用户消息内容
langchain_messages原始数据LIST ai_message: AI 回复内容
''' user_rag_memory_id: RAG 记忆ID
actual_end_user_id: 实际用户ID
actual_config_id: 配置ID
逻辑说明:
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
- Neo4j 模式:使用结构化消息列表
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
2. 如果只有 user_message创建单条用户消息 [user](用于历史记忆场景)
3. 每条消息会被转换为独立的 Chunk保留 speaker 字段
"""
db = next(get_db())
try: try:
actual_config_id = resolve_config_id(actual_config_id, db)
# Neo4j 模式:使用结构化消息列表
structured_messages = []
# 始终添加用户消息(如果不为空)
if isinstance(user_message, str) and user_message.strip() != "":
structured_messages.append({"role": "user", "content": user_message})
# 只有当 AI 回复不为空时才添加 assistant 消息
if isinstance(ai_message, str) and ai_message.strip() != "":
structured_messages.append({"role": "assistant", "content": ai_message})
# 如果提供了 long_term_messages使用它替代 structured_messages
if long_term_messages and isinstance(long_term_messages, list):
structured_messages = long_term_messages
elif long_term_messages and isinstance(long_term_messages, str):
# 如果是 JSON 字符串,先解析
try:
structured_messages = json.loads(long_term_messages)
except json.JSONDecodeError:
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
# 如果没有消息,直接返回
if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
logger.info(
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # end_user_id: 用户ID
structured_messages, # message: JSON 字符串格式的消息列表
str(actual_config_id), # config_id: 配置ID字符串
storage_type, # storage_type: "neo4j"
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆IDNeo4j模式下不使用
)
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
finally:
db.close()
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
    """Persist the user's buffered Redis session into the short/long-term store.

    Args:
        long_term_messages: Not read by this function; kept for interface
            compatibility with existing callers.
        actual_config_id: Not read by this function; kept for interface
            compatibility with existing callers.
        end_user_id: End-user identifier used for the Redis lookup and the
            repository upsert.
        type: Storage strategy. ``STRATEGY_CHUNK`` and ``STRATEGY_AGGREGATE``
            take the windowed branch; any other value takes the
            recent-sessions branch.
        scope: Window size; the windowed branch only upserts once ``scope``
            entries are available.

    Returns:
        None.
    """
    # Local import kept as in the original block.
    from app.core.memory.agent.utils.redis_tool import write_store

    windowed_strategies = (
        AgentMemory_Long_Term.STRATEGY_CHUNK,
        AgentMemory_Long_Term.STRATEGY_AGGREGATE,
    )

    with get_db_context() as db_session:
        repo = LongTermMemoryRepository(db_session)
        result = write_store.get_session_by_userid(end_user_id)
        # No session in Redis yet (the lookup yields a falsy value): nothing
        # to persist, so skip instead of failing while parsing an empty result.
        if not result:
            logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update")
            return

        # FIX: the original test was `type == CHUNK or AGGREGATE`, which is
        # always truthy because the right operand is a non-empty constant,
        # making the else branch unreachable. Use a membership test instead.
        if type in windowed_strategies:
            data = await format_parsing(result, "dict")
            chunk_data = data[:scope]
            # Only persist once a full window of `scope` entries is available.
            if len(chunk_data) == scope:
                repo.upsert(end_user_id, chunk_data)
                logger.info(f'---------写入短长期-----------')
        else:
            # Fallback strategy: persist the recent sessions fetched with the
            # fixed argument 5 (its unit is not visible from this block).
            long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
            if not long_time_data:
                logger.debug(f"No recent sessions in Redis for user {end_user_id}")
                return
            long_messages = await messages_parse(long_time_data)
            repo.upsert(end_user_id, long_messages)
            logger.info(f'写入短长期:')
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {
"messages": langchain_messages,
"end_user_id": end_user_id,
"memory_config": memory_config
}
# 获取节点更新信息
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
for node_name, node_data in update_event.items():
if 'save_neo4j' == node_name:
massages = node_data
# TODO删除
massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
print(contents)
except Exception as e:
import traceback
traceback.print_exc()
'''根据窗口''' '''根据窗口'''
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
''' '''
@@ -61,25 +124,26 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
scope窗口大小 scope窗口大小
''' '''
scope=scope scope=scope
redis_messages = []
is_end_user_id = count_store.get_sessions_count(end_user_id) is_end_user_id = count_store.get_sessions_count(end_user_id)
if is_end_user_id is not False: if is_end_user_id is not False:
is_end_user_id = count_store.get_sessions_count(end_user_id)[0] is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
redis_messages = count_store.get_sessions_count(end_user_id)[1] redis_messages = count_store.get_sessions_count(end_user_id)[1]
if is_end_user_id and int(is_end_user_id) != int(scope): if is_end_user_id and int(is_end_user_id) != int(scope):
print(is_end_user_id)
is_end_user_id += 1 is_end_user_id += 1
langchain_messages += redis_messages langchain_messages += redis_messages
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages) count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
elif int(is_end_user_id) == int(scope): elif int(is_end_user_id) == int(scope):
print('写入长期记忆并且设置为0') logger.info('写入长期记忆NEO4J')
print(is_end_user_id) formatted_messages = (redis_messages)
formatted_messages = await chat_data_format(redis_messages) # 获取 config_id如果 memory_config 是对象,提取 config_id否则直接使用
print(100*'-') if hasattr(memory_config, 'config_id'):
print(formatted_messages) config_id = memory_config.config_id
print(100*'-') else:
await write_messages(end_user_id, formatted_messages, memory_config) config_id = memory_config
count_store.update_sessions_count(end_user_id, 0, '')
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
config_id, formatted_messages)
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
else: else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages) count_store.save_sessions_count(end_user_id, 1, langchain_messages)
@@ -93,12 +157,15 @@ async def memory_long_term_storage(end_user_id,memory_config,time):
memory_config: 内存配置对象 memory_config: 内存配置对象
''' '''
long_time_data = write_store.find_user_recent_sessions(end_user_id, time) long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
# Handle case where no session exists in Redis (returns False or empty) format_messages = (long_time_data)
if not long_time_data or long_time_data is False: messages=[]
return memory_config=memory_config.config_id
format_messages = await chat_data_format(long_time_data) for i in format_messages:
message=json.loads(i['Query'])
messages+= message
if format_messages!=[]: if format_messages!=[]:
await write_messages(end_user_id, format_messages, memory_config) await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
memory_config, messages)
'''聚合判断''' '''聚合判断'''
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict: async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
""" """
@@ -109,13 +176,12 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
memory_config: 内存配置对象 memory_config: 内存配置对象
""" """
try: try:
# 1. 获取历史会话数据(使用新方法) # 1. 获取历史会话数据(使用新方法)
result = write_store.get_all_sessions_by_end_user_id(end_user_id) result = write_store.get_all_sessions_by_end_user_id(end_user_id)
history = await format_parsing(result)
# Handle case where no session exists in Redis (returns False or empty) if not result:
if not result or result is False:
history = [] history = []
else: else:
history = await format_parsing(result) history = await format_parsing(result)
@@ -154,7 +220,8 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
} }
if not structured.is_same_event: if not structured.is_same_event:
logger.info(result_dict) logger.info(result_dict)
await write_messages(end_user_id, output_value, memory_config) await write("neo4j", end_user_id, "", "", None, end_user_id,
memory_config.config_id, output_value)
return result_dict return result_dict
except Exception as e: except Exception as e:

View File

@@ -186,10 +186,11 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
清理后的数据 清理后的数据
""" """
# 需要过滤的字段列表 # 需要过滤的字段列表
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
fields_to_remove = { fields_to_remove = {
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids', 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id', 'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
'user_id', 'statement_ids', 'updated_at',"chunk_ids","fact_summary" 'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
} }
if isinstance(data, dict): if isinstance(data, dict):

View File

@@ -1,8 +1,6 @@
import json import json
from langchain_core.messages import HumanMessage, AIMessage from langchain_core.messages import HumanMessage, AIMessage
async def format_parsing(messages: list,type:str='string'): async def format_parsing(messages: list,type:str='string'):
""" """
格式化解析消息列表 格式化解析消息列表
@@ -26,13 +24,13 @@ async def format_parsing(messages: list,type:str='string'):
role = content['role'] role = content['role']
content = content['content'] content = content['content']
if type == "string": if type == "string":
if role == 'human': if role == 'human' or role=="user":
content = '用户:' + content content = '用户:' + content
else: else:
content = 'AI:' + content content = 'AI:' + content
result.append(content) result.append(content)
if type == "dict": if type == "dict" :
if role == 'human': if role == 'human' or role=="user":
user.append( content) user.append( content)
else: else:
ai.append(content) ai.append(content)
@@ -57,33 +55,7 @@ async def messages_parse(messages: list | dict):
for key, values in zip(user, ai): for key, values in zip(user, ai):
database.append({key, values}) database.append({key, values})
return database return database
async def chat_data_format(messages: list | dict):
"""
将消息格式化为 LangChain 消息格式
Args:
messages: 消息列表或字典
Returns:
LangChain 消息列表
"""
langchain_messages = []
if isinstance(messages, list):
for msg in messages:
if 'role' in msg.keys():
if msg['role'] == 'user':
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content']))
if "Query" in msg.keys():
langchain_messages.append(HumanMessage(content=msg['Query']))
langchain_messages.append(AIMessage(content=msg['Answer']))
if isinstance(messages, dict):
if messages['type'] == 'human':
langchain_messages.append(HumanMessage(content=messages['content']))
elif messages['type'] == 'ai':
langchain_messages.append(AIMessage(content=messages['content']))
return langchain_messages
async def agent_chat_messages(user_content,ai_content): async def agent_chat_messages(user_content,ai_content):
messages = [ messages = [

View File

@@ -1,13 +1,18 @@
import asyncio import asyncio
import json
import sys import sys
import warnings import warnings
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from langgraph.constants import END, START from langgraph.constants import END, START
from langgraph.graph import StateGraph from langgraph.graph import StateGraph
from app.db import get_db, get_db_context
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning) warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
@@ -37,76 +42,61 @@ async def make_write_graph():
yield graph yield graph
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '', from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
end_user_id: str = '', scope: int = 6): from app.core.memory.agent.utils.redis_tool import write_store
"""Dispatch long-term memory storage to Celery background tasks. write_store.save_session_write(end_user_id, (langchain_messages))
# 获取数据库会话
Args: with get_db_context() as db_session:
long_term_type: Storage strategy - 'chunk' (window), 'time', or 'aggregate' config_service = MemoryConfigService(db_session)
langchain_messages: List of messages to store memory_config = config_service.load_memory_config(
memory_config: Memory configuration ID (string) config_id=memory_config, # 改为整数
end_user_id: End user identifier service_name="MemoryAgentService"
scope: Window size for 'chunk' strategy (default: 6)
"""
from app.tasks import (
long_term_storage_window_task,
# TODO: Uncomment when implemented
# long_term_storage_time_task,
# long_term_storage_aggregate_task,
)
from app.core.logging_config import get_logger
logger = get_logger(__name__)
# Convert config to string if needed
config_id = str(memory_config) if memory_config else ''
if long_term_type == 'chunk':
# Strategy 1: Window-based batching (6 rounds of dialogue)
logger.info(f"[LONG_TERM] Dispatching window task - end_user_id={end_user_id}, scope={scope}")
long_term_storage_window_task.delay(
end_user_id=end_user_id,
langchain_messages=langchain_messages,
config_id=config_id,
scope=scope
) )
# TODO: Uncomment when time-based strategy is fully implemented if long_term_type=='chunk':
# elif long_term_type == 'time': '''方案一:对话窗口6轮对话'''
# # Strategy 2: Time-based retrieval await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
# logger.info(f"[LONG_TERM] Dispatching time task - end_user_id={end_user_id}") if long_term_type=='time':
# long_term_storage_time_task.delay( """时间"""
# end_user_id=end_user_id, await memory_long_term_storage(end_user_id, memory_config,5)
# config_id=config_id, if long_term_type=='aggregate':
# time_window=5 """方案三:聚合判断"""
# ) await aggregate_judgment(end_user_id, langchain_messages, memory_config)
# TODO: Uncomment when aggregate strategy is fully implemented
# elif long_term_type == 'aggregate':
# # Strategy 3: Aggregate judgment (deduplication)
# logger.info(f"[LONG_TERM] Dispatching aggregate task - end_user_id={end_user_id}") async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
# long_term_storage_aggregate_task.delay( from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
# end_user_id=end_user_id, from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
# langchain_messages=langchain_messages, from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
# config_id=config_id if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
# ) await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
else:
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
long_term_messages = await agent_chat_messages(message_chat, aimessages)
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
# async def main(): # async def main():
# """主函数 - 运行工作流""" # """主函数 - 运行工作流"""
# langchain_messages = [ # langchain_messages = [
# { # {
# "role": "user", # "role": "user",
# "content": "今天周五好开心啊" # "content": "今天周五去爬山"
# }, # },
# { # {
# "role": "assistant", # "role": "assistant",
# "content": "你也这么觉得,我也是耶" # "content": "耶"
# } # }
# #
# ] # ]
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID # end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4" # memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
# # await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) # await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
# result=await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) #
# #
# #
# if __name__ == "__main__": # if __name__ == "__main__":

View File

@@ -294,6 +294,7 @@ class RedisCountStore:
""" """
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
key = generate_session_key(session_id, key_type="count") key = generate_session_key(session_id, key_type="count")
index_key = f'session:count:index:{end_user_id}' # 索引键
pipe = self.r.pipeline() pipe = self.r.pipeline()
pipe.hset(key, mapping={ pipe.hset(key, mapping={
@@ -304,6 +305,10 @@ class RedisCountStore:
"starttime": get_current_timestamp() "starttime": get_current_timestamp()
}) })
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期 pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
# 创建索引end_user_id -> session_id 映射
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
result = pipe.execute() result = pipe.execute()
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}") print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
@@ -320,31 +325,47 @@ class RedisCountStore:
list 或 False: 如果找到返回 [count, messages],否则返回 False list 或 False: 如果找到返回 [count, messages],否则返回 False
""" """
try: try:
search_pattern = 'session:count:*' # 使用索引键快速查找
index_key = f'session:count:index:{end_user_id}'
for key in self.r.keys(search_pattern): # 检查索引键类型,避免 WRONGTYPE 错误
data = self.r.hgetall(key) try:
key_type = self.r.type(index_key)
if not data: if key_type != 'string' and key_type != 'none':
continue self.r.delete(index_key)
return False
if data.get('end_user_id') == end_user_id: except Exception as type_error:
count = data.get('count') print(f"[get_sessions_count] 检查键类型失败: {type_error}")
messages_str = data.get('messages')
session_id = self.r.get(index_key)
if count is not None:
messages = deserialize_messages(messages_str) if not session_id:
return [int(count), messages] return False
# 直接获取数据
key = generate_session_key(session_id, key_type="count")
data = self.r.hgetall(key)
if not data:
# 索引存在但数据不存在,清理索引
self.r.delete(index_key)
return False
count = data.get('count')
messages_str = data.get('messages')
if count is not None:
messages = deserialize_messages(messages_str)
return [int(count), messages]
return False return False
except Exception as e: except Exception as e:
print(f"[get_sessions_count] 查询失败: {e}") print(f"[get_sessions_count] 查询失败: {e}")
return False return False
def update_sessions_count(self, end_user_id: str, new_count: int, def update_sessions_count(self, end_user_id: str, new_count: int,
messages: Any) -> bool: messages: Any) -> bool:
""" """
通过 end_user_id 修改访问次数统计 通过 end_user_id 修改访问次数统计(优化版:使用索引)
Args: Args:
end_user_id: 终端用户ID end_user_id: 终端用户ID
@@ -355,23 +376,39 @@ class RedisCountStore:
bool: 更新成功返回 True未找到记录返回 False bool: 更新成功返回 True未找到记录返回 False
""" """
try: try:
# 使用索引键快速查找
index_key = f'session:count:index:{end_user_id}'
# 检查索引键类型,避免 WRONGTYPE 错误
try:
key_type = self.r.type(index_key)
if key_type != 'string' and key_type != 'none':
# 索引键类型错误,删除并返回 False
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
self.r.delete(index_key)
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False
except Exception as type_error:
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
session_id = self.r.get(index_key)
if not session_id:
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False
# 直接更新数据
key = generate_session_key(session_id, key_type="count")
messages_str = serialize_messages(messages) messages_str = serialize_messages(messages)
search_pattern = 'session:count:*'
for key in self.r.keys(search_pattern): pipe = self.r.pipeline()
data = self.r.hgetall(key) pipe.hset(key, 'count', int(new_count))
pipe.hset(key, 'messages', messages_str)
if not data: result = pipe.execute()
continue
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
if data.get('end_user_id') == end_user_id: return True
self.r.hset(key, 'count', int(new_count))
self.r.hset(key, 'messages', messages_str)
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
return True
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False
except Exception as e: except Exception as e:
print(f"[update_sessions_count] 更新失败: {e}") print(f"[update_sessions_count] 更新失败: {e}")
return False return False

View File

@@ -4,6 +4,7 @@ Write Tools for Memory Knowledge Extraction Pipeline
This module provides the main write function for executing the knowledge extraction This module provides the main write function for executing the knowledge extraction
pipeline. Only MemoryConfig is needed - clients are constructed internally. pipeline. Only MemoryConfig is needed - clients are constructed internally.
""" """
import asyncio
import time import time
from datetime import datetime from datetime import datetime
@@ -124,23 +125,48 @@ async def write(
except Exception as e: except Exception as e:
logger.error(f"Error creating indexes: {e}", exc_info=True) logger.error(f"Error creating indexes: {e}", exc_info=True)
# 添加死锁重试机制
max_retries = 3
retry_delay = 1 # 秒
for attempt in range(max_retries):
try:
success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=all_dialogue_nodes,
chunk_nodes=all_chunk_nodes,
statement_nodes=all_statement_nodes,
entity_nodes=all_entity_nodes,
statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges,
connector=neo4j_connector
)
if success:
logger.info("Successfully saved all data to Neo4j")
break
else:
logger.warning("Failed to save some data to Neo4j")
if attempt < max_retries - 1:
logger.info(f"Retrying... (attempt {attempt + 2}/{max_retries})")
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
except Exception as e:
error_msg = str(e)
# 检查是否是死锁错误
if "DeadlockDetected" in error_msg or "deadlock" in error_msg.lower():
if attempt < max_retries - 1:
logger.warning(f"Deadlock detected, retrying... (attempt {attempt + 2}/{max_retries})")
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
else:
logger.error(f"Failed after {max_retries} attempts due to deadlock: {e}")
raise
else:
# 非死锁错误,直接抛出
raise
try: try:
success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=all_dialogue_nodes,
chunk_nodes=all_chunk_nodes,
statement_nodes=all_statement_nodes,
entity_nodes=all_entity_nodes,
statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges,
connector=neo4j_connector
)
if success:
logger.info("Successfully saved all data to Neo4j")
else:
logger.warning("Failed to save some data to Neo4j")
finally:
await neo4j_connector.close() await neo4j_connector.close()
except Exception as e:
logger.error(f"Error closing Neo4j connector: {e}")
log_time("Neo4j Database Save", time.time() - step_start, log_file) log_time("Neo4j Database Save", time.time() - step_start, log_file)

View File

@@ -413,7 +413,8 @@ class ExtractedEntityNode(Node):
description="Entity aliases - alternative names for this entity" description="Entity aliases - alternative names for this entity"
) )
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector") name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
fact_summary: str = Field(default="", description="Summary of the fact about this entity") # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
connect_strength: str = Field(..., description="Strong VS Weak about this entity") connect_strength: str = Field(..., description="Strong VS Weak about this entity")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)") config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")

View File

@@ -134,42 +134,45 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
if len(desc_b) > len(desc_a): if len(desc_b) > len(desc_a):
canonical.description = desc_b canonical.description = desc_b
# 合并事实摘要:统一保留一个“实体: name”行来源行去重保序 # 合并事实摘要:统一保留一个“实体: name”行来源行去重保序
fact_a = getattr(canonical, "fact_summary", "") or "" # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
fact_b = getattr(ent, "fact_summary", "") or "" # fact_a = getattr(canonical, "fact_summary", "") or ""
def _extract_sources(txt: str) -> List[str]: # fact_b = getattr(ent, "fact_summary", "") or ""
sources: List[str] = [] # def _extract_sources(txt: str) -> List[str]:
if not txt: # sources: List[str] = []
return sources # if not txt:
for line in str(txt).splitlines(): # return sources
ln = line.strip() # for line in str(txt).splitlines():
# ln = line.strip()
# 支持“来源:”或“来源:”前缀 # 支持“来源:”或“来源:”前缀
m = re.match(r"^来源[:]\s*(.+)$", ln) # m = re.match(r"^来源[:]\s*(.+)$", ln)
if m: # if m:
content = m.group(1).strip() # content = m.group(1).strip()
if content: # if content:
sources.append(content) # sources.append(content)
# 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失 # 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失
if not sources and txt.strip(): # if not sources and txt.strip():
sources.append(txt.strip()) # sources.append(txt.strip())
return sources # return sources
try: try:
src_a = _extract_sources(fact_a) # src_a = _extract_sources(fact_a)
src_b = _extract_sources(fact_b) # src_b = _extract_sources(fact_b)
seen = set() # seen = set()
merged_sources: List[str] = [] # merged_sources: List[str] = []
for s in src_a + src_b: # for s in src_a + src_b:
if s and s not in seen: # if s and s not in seen:
seen.add(s) # seen.add(s)
merged_sources.append(s) # merged_sources.append(s)
if merged_sources: # if merged_sources:
name_line = f"实体: {getattr(canonical, 'name', '')}".strip() # name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources]) # canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
elif fact_b and not fact_a: # elif fact_b and not fact_a:
canonical.fact_summary = fact_b # canonical.fact_summary = fact_b
pass
except Exception: except Exception:
# 兜底:若解析失败,保留较长文本 # 兜底:若解析失败,保留较长文本
if len(fact_b) > len(fact_a): # if len(fact_b) > len(fact_a):
canonical.fact_summary = fact_b # canonical.fact_summary = fact_b
pass
except Exception: except Exception:
pass pass

View File

@@ -145,10 +145,13 @@ def _choose_canonical(a: ExtractedEntityNode, b: ExtractedEntityNode) -> int: #
# 2. 第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整) # 2. 第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整)
desc_a = (getattr(a, "description", "") or "") desc_a = (getattr(a, "description", "") or "")
desc_b = (getattr(b, "description", "") or "") desc_b = (getattr(b, "description", "") or "")
fact_a = (getattr(a, "fact_summary", "") or "") # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
fact_b = (getattr(b, "fact_summary", "") or "") # fact_a = (getattr(a, "fact_summary", "") or "")
score_a = len(desc_a) + len(fact_a) # fact_b = (getattr(b, "fact_summary", "") or "")
score_b = len(desc_b) + len(fact_b) # score_a = len(desc_a) + len(fact_a)
# score_b = len(desc_b) + len(fact_b)
score_a = len(desc_a)
score_b = len(desc_b)
if score_a != score_b: if score_a != score_b:
return 0 if score_a >= score_b else 1 return 0 if score_a >= score_b else 1
return 0 return 0
@@ -189,7 +192,8 @@ async def _judge_pair(
"entity_type": getattr(a, "entity_type", None), "entity_type": getattr(a, "entity_type", None),
"description": getattr(a, "description", None), "description": getattr(a, "description", None),
"aliases": getattr(a, "aliases", None) or [], "aliases": getattr(a, "aliases", None) or [],
"fact_summary": getattr(a, "fact_summary", None), # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# "fact_summary": getattr(a, "fact_summary", None),
"connect_strength": getattr(a, "connect_strength", None), "connect_strength": getattr(a, "connect_strength", None),
} }
entity_b = { entity_b = {
@@ -197,7 +201,8 @@ async def _judge_pair(
"entity_type": getattr(b, "entity_type", None), "entity_type": getattr(b, "entity_type", None),
"description": getattr(b, "description", None), "description": getattr(b, "description", None),
"aliases": getattr(b, "aliases", None) or [], "aliases": getattr(b, "aliases", None) or [],
"fact_summary": getattr(b, "fact_summary", None), # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# "fact_summary": getattr(b, "fact_summary", None),
"connect_strength": getattr(b, "connect_strength", None), "connect_strength": getattr(b, "connect_strength", None),
} }
# 5. 渲染LLM提示词用工具函数填充模板包含实体信息、上下文、输出格式 # 5. 渲染LLM提示词用工具函数填充模板包含实体信息、上下文、输出格式
@@ -248,7 +253,8 @@ async def _judge_pair_disamb(
"entity_type": getattr(a, "entity_type", None), "entity_type": getattr(a, "entity_type", None),
"description": getattr(a, "description", None), "description": getattr(a, "description", None),
"aliases": getattr(a, "aliases", None) or [], "aliases": getattr(a, "aliases", None) or [],
"fact_summary": getattr(a, "fact_summary", None), # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# "fact_summary": getattr(a, "fact_summary", None),
"connect_strength": getattr(a, "connect_strength", None), "connect_strength": getattr(a, "connect_strength", None),
} }
entity_b = { entity_b = {
@@ -256,7 +262,8 @@ async def _judge_pair_disamb(
"entity_type": getattr(b, "entity_type", None), "entity_type": getattr(b, "entity_type", None),
"description": getattr(b, "description", None), "description": getattr(b, "description", None),
"aliases": getattr(b, "aliases", None) or [], "aliases": getattr(b, "aliases", None) or [],
"fact_summary": getattr(b, "fact_summary", None), # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# "fact_summary": getattr(b, "fact_summary", None),
"connect_strength": getattr(b, "connect_strength", None), "connect_strength": getattr(b, "connect_strength", None),
} }
prompt = render_entity_dedup_prompt( prompt = render_entity_dedup_prompt(

View File

@@ -72,7 +72,8 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
description=row.get("description") or "", description=row.get("description") or "",
aliases=row.get("aliases") or [], aliases=row.get("aliases") or [],
name_embedding=row.get("name_embedding") or [], name_embedding=row.get("name_embedding") or [],
fact_summary=row.get("fact_summary") or "", # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary=row.get("fact_summary") or "",
connect_strength=row.get("connect_strength") or "", connect_strength=row.get("connect_strength") or "",
) )

View File

@@ -1088,7 +1088,8 @@ class ExtractionOrchestrator:
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
description=getattr(entity, 'description', ''), # 添加必需的 description 字段 description=getattr(entity, 'description', ''), # 添加必需的 description 字段
example=getattr(entity, 'example', ''), # 新增:传递示例字段 example=getattr(entity, 'example', ''), # 新增:传递示例字段
fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段 # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
name_embedding=getattr(entity, 'name_embedding', None), name_embedding=getattr(entity, 'name_embedding', None),

View File

@@ -296,7 +296,9 @@ def resolve_alias_cycles(entities: List[Any], cycles: Dict[str, Set[str]]) -> Li
key=lambda eid: ( key=lambda eid: (
_strength_rank(eid), _strength_rank(eid),
len(getattr(entity_by_id.get(eid), 'description', '') or ''), len(getattr(entity_by_id.get(eid), 'description', '') or ''),
len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '') # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
0 # 临时占位
), ),
reverse=True reverse=True
) )

View File

@@ -9,7 +9,8 @@
- 类型: "{{ entity_a.entity_type | default('') }}" - 类型: "{{ entity_a.entity_type | default('') }}"
- 描述: "{{ entity_a.description | default('') }}" - 描述: "{{ entity_a.description | default('') }}"
- 别名: {{ entity_a.aliases | default([]) }} - 别名: {{ entity_a.aliases | default([]) }}
- 摘要: "{{ entity_a.fact_summary | default('') }}" {# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #}
{# - 摘要: "{{ entity_a.fact_summary | default('') }}" #}
- 连接强弱: "{{ entity_a.connect_strength | default('') }}" - 连接强弱: "{{ entity_a.connect_strength | default('') }}"
实体B: 实体B:
@@ -17,7 +18,8 @@
- 类型: "{{ entity_b.entity_type | default('') }}" - 类型: "{{ entity_b.entity_type | default('') }}"
- 描述: "{{ entity_b.description | default('') }}" - 描述: "{{ entity_b.description | default('') }}"
- 别名: {{ entity_b.aliases | default([]) }} - 别名: {{ entity_b.aliases | default([]) }}
- 摘要: "{{ entity_b.fact_summary | default('') }}" {# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #}
{# - 摘要: "{{ entity_b.fact_summary | default('') }}" #}
- 连接强弱: "{{ entity_b.connect_strength | default('') }}" - 连接强弱: "{{ entity_b.connect_strength | default('') }}"
上下文: 上下文:

View File

@@ -28,7 +28,9 @@ from app.core.rag.common.float_utils import get_float
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
from app.core.rag.llm.chat_model import Base from app.core.rag.llm.chat_model import Base
from app.core.rag.llm.embedding_model import OpenAIEmbed from app.core.rag.llm.embedding_model import OpenAIEmbed
import logging
logger = logging.getLogger(__name__)
def knowledge_retrieval( def knowledge_retrieval(
query: str, query: str,
@@ -62,7 +64,15 @@ def knowledge_retrieval(
merge_strategy = config.get("merge_strategy", "weight") merge_strategy = config.get("merge_strategy", "weight")
reranker_id = config.get("reranker_id") reranker_id = config.get("reranker_id")
reranker_top_k = config.get("reranker_top_k", 1024) reranker_top_k = config.get("reranker_top_k", 1024)
use_graph = config.get("use_graph", "false").lower() == "true" # use_graph = config.get("use_graph", "false").lower() == "true"
use_graph_value = config.get("use_graph", False)
if isinstance(use_graph_value, bool):
use_graph = use_graph_value
elif isinstance(use_graph_value, str):
use_graph = use_graph_value.lower() in ("true", "1", "yes")
else:
use_graph = False
file_names_filter = [] file_names_filter = []
if user_ids: if user_ids:
@@ -159,13 +169,29 @@ def knowledge_retrieval(
# Use the specified reranker for re-ranking # Use the specified reranker for re-ranking
if reranker_id: if reranker_id:
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k) try:
# use graph return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
except Exception as rerank_error:
# If reranker fails, log warning and continue with original results
logger.warning(
"Reranker failed, falling back to original results",
extra={
"reranker_id": reranker_id,
"query": query,
"doc_count": len(all_results),
"error": str(rerank_error),
},
)
if use_graph: if use_graph:
from app.core.rag.common.settings import kg_retriever try:
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model) from app.core.rag.common.settings import kg_retriever
if doc: doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
all_results.insert(0, doc) if doc:
all_results.insert(0, doc)
except Exception as graph_error:
print(f"Failed to retrieve from knowledge graph: {str(graph_error)}")
return all_results return all_results
except Exception as e: except Exception as e:

View File

@@ -25,6 +25,18 @@ class ParameterExtractorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config)
self.typed_config: ParameterExtractorNodeConfig | None = None self.typed_config: ParameterExtractorNodeConfig | None = None
self.response_metadata = {}
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
if self.response_metadata:
usage = self.response_metadata.get('token_usage')
if usage:
return {
"prompt_tokens": usage.get('prompt_tokens', 0),
"completion_tokens": usage.get('completion_tokens', 0),
"total_tokens": usage.get('total_tokens', 0)
}
return None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
outputs = {} outputs = {}
@@ -180,6 +192,7 @@ class ParameterExtractorNode(BaseNode):
]) ])
model_resp = await llm.ainvoke(messages) model_resp = await llm.ainvoke(messages)
self.response_metadata = model_resp.response_metadata
result = json_repair.repair_json(model_resp.content, return_objects=True) result = json_repair.repair_json(model_resp.content, return_objects=True)
logger.info(f"node: {self.node_id} get params:{result}") logger.info(f"node: {self.node_id} get params:{result}")

View File

@@ -25,6 +25,18 @@ class QuestionClassifierNode(BaseNode):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config)
self.typed_config: QuestionClassifierNodeConfig | None = None self.typed_config: QuestionClassifierNodeConfig | None = None
self.category_to_case_map = {} self.category_to_case_map = {}
self.response_metadata = {}
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
if self.response_metadata:
usage = self.response_metadata.get('token_usage')
if usage:
return {
"prompt_tokens": usage.get('prompt_tokens', 0),
"completion_tokens": usage.get('completion_tokens', 0),
"total_tokens": usage.get('total_tokens', 0)
}
return None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
return { return {
@@ -120,6 +132,7 @@ class QuestionClassifierNode(BaseNode):
response = await llm.ainvoke(messages) response = await llm.ainvoke(messages)
result = response.content.strip() result = response.content.strip()
self.response_metadata = response.response_metadata
if result in category_names: if result in category_names:
category = result category = result

View File

@@ -86,7 +86,8 @@ class MemoryConfigRepository:
n.description AS description, n.description AS description,
n.entity_type AS entity_type, n.entity_type AS entity_type,
n.name AS name, n.name AS name,
COALESCE(n.fact_summary, '') AS fact_summary, // TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
// COALESCE(n.fact_summary, '') AS fact_summary,
n.end_user_id AS end_user_id, n.end_user_id AS end_user_id,
n.apply_id AS apply_id, n.apply_id AS apply_id,
n.user_id AS user_id, n.user_id AS user_id,
@@ -279,6 +280,9 @@ class MemoryConfigRepository:
if update.config_desc is not None: if update.config_desc is not None:
db_config.config_desc = update.config_desc db_config.config_desc = update.config_desc
has_update = True has_update = True
if update.scene_id is not None:
db_config.scene_id = update.scene_id
has_update = True
if not has_update: if not has_update:
raise ValueError("No fields to update") raise ValueError("No fields to update")
@@ -650,28 +654,32 @@ class MemoryConfigRepository:
raise raise
@staticmethod @staticmethod
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[MemoryConfig]: def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[Tuple[MemoryConfig, Optional[str]]]:
"""获取所有配置参数 """获取所有配置参数,包含关联的场景名称
Args: Args:
db: 数据库会话 db: 数据库会话
workspace_id: 工作空间ID用于过滤查询结果 workspace_id: 工作空间ID用于过滤查询结果
Returns: Returns:
List[MemoryConfig]: 配置列表 List[Tuple[MemoryConfig, Optional[str]]]: 配置列表,每项为 (配置对象, 场景名称)
""" """
from app.models.ontology_scene import OntologyScene
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}") db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
try: try:
query = db.query(MemoryConfig) query = db.query(MemoryConfig, OntologyScene.scene_name).outerjoin(
OntologyScene, MemoryConfig.scene_id == OntologyScene.scene_id
)
if workspace_id: if workspace_id:
query = query.filter(MemoryConfig.workspace_id == workspace_id) query = query.filter(MemoryConfig.workspace_id == workspace_id)
configs = query.order_by(desc(MemoryConfig.updated_at)).all() results = query.order_by(desc(MemoryConfig.updated_at)).all()
db_logger.debug(f"配置列表查询成功: 数量={len(configs)}") db_logger.debug(f"配置列表查询成功: 数量={len(results)}")
return configs return results
except Exception as e: except Exception as e:
db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}") db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}")

View File

@@ -79,7 +79,8 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
try: try:
edges: List[dict] = [] edges: List[dict] = []
for s in summaries: for s in summaries:
for chunk_id in getattr(s, "chunk_ids", []) or []: chunk_ids = getattr(s, "chunk_ids", []) or []
for chunk_id in chunk_ids:
edges.append({ edges.append({
"summary_id": s.id, "summary_id": s.id,
"chunk_id": chunk_id, "chunk_id": chunk_id,
@@ -91,12 +92,11 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
if not edges: if not edges:
return [] return []
result = await connector.execute_query( result = await connector.execute_query(
MEMORY_SUMMARY_STATEMENT_EDGE_SAVE, MEMORY_SUMMARY_STATEMENT_EDGE_SAVE,
edges=edges edges=edges
) )
created = [record.get("uuid") for record in result] if result else [] created = [record.get("uuid") for record in result] if result else []
return created return created
except Exception: except Exception as e:
return None return None

View File

@@ -217,8 +217,10 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
summaries=flattened summaries=flattened
) )
created_ids = [record.get("uuid") for record in result] created_ids = [record.get("uuid") for record in result]
print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
return created_ids return created_ids
except Exception: except Exception as e:
print(f"Failed to save MemorySummary nodes to Neo4j: {e}")
return None return None

View File

@@ -101,10 +101,11 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
e.name_embedding = CASE e.name_embedding = CASE
WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding
ELSE e.name_embedding END, ELSE e.name_embedding END,
e.fact_summary = CASE // TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> '' // e.fact_summary = CASE
AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary)) // WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> ''
THEN entity.fact_summary ELSE e.fact_summary END, // AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary))
// THEN entity.fact_summary ELSE e.fact_summary END,
e.connect_strength = CASE e.connect_strength = CASE
WHEN entity.connect_strength IS NULL OR entity.connect_strength = '' THEN e.connect_strength WHEN entity.connect_strength IS NULL OR entity.connect_strength = '' THEN e.connect_strength
ELSE CASE ELSE CASE
@@ -321,7 +322,8 @@ RETURN e.id AS id,
e.description AS description, e.description AS description,
e.aliases AS aliases, e.aliases AS aliases,
e.name_embedding AS name_embedding, e.name_embedding AS name_embedding,
COALESCE(e.fact_summary, '') AS fact_summary, // TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
// COALESCE(e.fact_summary, '') AS fact_summary,
e.connect_strength AS connect_strength, e.connect_strength AS connect_strength,
collect(DISTINCT s.id) AS statement_ids, collect(DISTINCT s.id) AS statement_ids,
collect(DISTINCT c.id) AS chunk_ids, collect(DISTINCT c.id) AS chunk_ids,
@@ -1002,3 +1004,58 @@ RETURN DISTINCT
x.statement as statement,x.created_at as created_at x.statement as statement,x.created_at as created_at
""" """
Graph_Node_query = """
MATCH (n:MemorySummary)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) AS id,
labels(n) AS labels,
properties(n) AS properties,
0 AS priority
LIMIT $limit
UNION ALL
MATCH (n:Dialogue)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) AS id,
labels(n) AS labels,
properties(n) AS properties,
1 AS priority
LIMIT 1
UNION ALL
MATCH (n:Statement)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) AS id,
labels(n) AS labels,
properties(n) AS properties,
1 AS priority
LIMIT $limit
UNION ALL
MATCH (n:ExtractedEntity)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) AS id,
labels(n) AS labels,
properties(n) AS properties,
2 AS priority
LIMIT $limit
UNION ALL
MATCH (n:Chunk)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) AS id,
labels(n) AS labels,
properties(n) AS properties,
3 AS priority
LIMIT $limit
"""

View File

@@ -21,7 +21,8 @@ from app.core.memory.models.graph_models import (
ExtractedEntityNode, ExtractedEntityNode,
EntityEntityEdge, EntityEntityEdge,
) )
import logging
logger = logging.getLogger(__name__)
async def save_entities_and_relationships( async def save_entities_and_relationships(
entity_nodes: List[ExtractedEntityNode], entity_nodes: List[ExtractedEntityNode],
entity_entity_edges: List[EntityEntityEdge], entity_entity_edges: List[EntityEntityEdge],
@@ -41,8 +42,8 @@ async def save_entities_and_relationships(
'statement': edge.statement, 'statement': edge.statement,
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None, 'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None, 'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
'created_at': edge.created_at.isoformat(), 'created_at': edge.created_at.isoformat() if edge.created_at else None,
'expired_at': edge.expired_at.isoformat(), 'expired_at': edge.expired_at.isoformat() if edge.expired_at else None,
'run_id': edge.run_id, 'run_id': edge.run_id,
'end_user_id': edge.end_user_id, 'end_user_id': edge.end_user_id,
} }
@@ -147,14 +148,14 @@ async def save_statement_entity_edges(
async def save_dialog_and_statements_to_neo4j( async def save_dialog_and_statements_to_neo4j(
dialogue_nodes: List[DialogueNode], dialogue_nodes: List[DialogueNode],
chunk_nodes: List[ChunkNode], chunk_nodes: List[ChunkNode],
statement_nodes: List[StatementNode], statement_nodes: List[StatementNode],
entity_nodes: List[ExtractedEntityNode], entity_nodes: List[ExtractedEntityNode],
entity_edges: List[EntityEntityEdge], entity_edges: List[EntityEntityEdge],
statement_chunk_edges: List[StatementChunkEdge], statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge], statement_entity_edges: List[StatementEntityEdge],
connector: Neo4jConnector connector: Neo4jConnector
) -> bool: ) -> bool:
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
@@ -171,40 +172,126 @@ async def save_dialog_and_statements_to_neo4j(
Returns: Returns:
bool: True if successful, False otherwise bool: True if successful, False otherwise
""" """
try:
# Save all dialogue nodes in batch # 定义事务函数,将所有写操作放在一个事务中
dialogue_uuids = await add_dialogue_nodes(dialogue_nodes, connector) async def _save_all_in_transaction(tx):
if dialogue_uuids: """在单个事务中执行所有保存操作,避免死锁"""
results = {}
# 1. Save all dialogue nodes in batch
if dialogue_nodes:
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE
dialogue_data = [node.model_dump() for node in dialogue_nodes]
result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data)
dialogue_uuids = [record["uuid"] async for record in result]
results['dialogues'] = dialogue_uuids
print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}") print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
else:
print("Failed to save dialogues to Neo4j")
return False
# Save all chunk nodes in batch # 2. Save all chunk nodes in batch
await save_chunk_nodes(chunk_nodes, connector) if chunk_nodes:
from app.repositories.neo4j.cypher_queries import CHUNK_NODE_SAVE
chunk_data = [node.model_dump() for node in chunk_nodes]
result = await tx.run(CHUNK_NODE_SAVE, chunks=chunk_data)
chunk_uuids = [record["uuid"] async for record in result]
results['chunks'] = chunk_uuids
logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j")
# Save all statement nodes in batch # 3. Save all statement nodes in batch
if statement_nodes: if statement_nodes:
statement_uuids = await add_statement_nodes(statement_nodes, connector) from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE
if statement_uuids: statement_data = [node.model_dump() for node in statement_nodes]
print(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j") result = await tx.run(STATEMENT_NODE_SAVE, statements=statement_data)
else: statement_uuids = [record["uuid"] async for record in result]
print("Failed to save statement nodes to Neo4j") results['statements'] = statement_uuids
return False logger.info(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j")
else:
print("No statement nodes to save")
# Save entities and relationships # 4. Save entities
await save_entities_and_relationships(entity_nodes, entity_edges, connector) if entity_nodes:
print("Successfully saved entities and relationships to Neo4j") from app.repositories.neo4j.cypher_queries import EXTRACTED_ENTITY_NODE_SAVE
entity_data = [entity.model_dump() for entity in entity_nodes]
result = await tx.run(EXTRACTED_ENTITY_NODE_SAVE, entities=entity_data)
entity_uuids = [record["uuid"] async for record in result]
results['entities'] = entity_uuids
logger.info(f"Successfully saved {len(entity_uuids)} entity nodes to Neo4j")
# Save new edges # 5. Create entity relationships
await save_statement_chunk_edges(statement_chunk_edges, connector) if entity_edges:
await save_statement_entity_edges(statement_entity_edges, connector) from app.repositories.neo4j.cypher_queries import ENTITY_RELATIONSHIP_SAVE
relationship_data = []
for edge in entity_edges:
relationship_data.append({
'source_id': edge.source,
'target_id': edge.target,
'predicate': edge.relation_type,
'statement_id': edge.source_statement_id,
'value': edge.relation_value,
'statement': edge.statement,
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
'created_at': edge.created_at.isoformat() if edge.created_at else None,
'expired_at': edge.expired_at.isoformat() if edge.expired_at else None,
'run_id': edge.run_id,
'end_user_id': edge.end_user_id,
})
result = await tx.run(ENTITY_RELATIONSHIP_SAVE, relationships=relationship_data)
rel_uuids = [record["uuid"] async for record in result]
results['entity_relationships'] = rel_uuids
logger.info(f"Successfully saved {len(rel_uuids)} entity relationships to Neo4j")
# 6. Save statement-chunk edges
if statement_chunk_edges:
from app.repositories.neo4j.cypher_queries import CHUNK_STATEMENT_EDGE_SAVE
sc_edge_data = []
for edge in statement_chunk_edges:
sc_edge_data.append({
"id": edge.id,
"source": edge.source,
"target": edge.target,
"created_at": edge.created_at.isoformat() if edge.created_at else None,
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
"run_id": edge.run_id,
"end_user_id": edge.end_user_id,
})
result = await tx.run(CHUNK_STATEMENT_EDGE_SAVE, chunk_statement_edges=sc_edge_data)
sc_uuids = [record["uuid"] async for record in result]
results['statement_chunk_edges'] = sc_uuids
logger.info(f"Successfully saved {len(sc_uuids)} statement-chunk edges to Neo4j")
# 7. Save statement-entity edges
if statement_entity_edges:
from app.repositories.neo4j.cypher_queries import STATEMENT_ENTITY_EDGE_SAVE
se_edge_data = []
for edge in statement_entity_edges:
se_edge_data.append({
"source": edge.source,
"target": edge.target,
"created_at": edge.created_at.isoformat() if edge.created_at else None,
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
"run_id": edge.run_id,
"end_user_id": edge.end_user_id,
"connect_strength": getattr(edge, "connect_strength", "strong"),
})
result = await tx.run(STATEMENT_ENTITY_EDGE_SAVE, relationships=se_edge_data)
se_uuids = [record["uuid"] async for record in result]
results['statement_entity_edges'] = se_uuids
logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j")
return results
try:
# 使用显式写事务执行所有操作,避免死锁
results = await connector.execute_write_transaction(_save_all_in_transaction)
summary = {
key: len(value)
for key, value in results.items()
if isinstance(value, (list, tuple, set))
}
logger.info("Transaction completed. Summary: %s", summary)
logger.debug("Full transaction results: %r", results)
return True return True
except Exception as e: except Exception as e:
logger.error(f"Neo4j integration error: {e}", exc_info=True)
print(f"Neo4j integration error: {e}") print(f"Neo4j integration error: {e}")
print("Continuing without database storage...") print("Continuing without database storage...")
return False return False

View File

@@ -392,3 +392,48 @@ class OntologySceneRepository:
exc_info=True exc_info=True
) )
raise raise
def get_simple_list(self, workspace_id: UUID) -> List[dict]:
    """Return a lightweight list of scenes for a workspace.

    Only the ``scene_id`` and ``scene_name`` columns are selected, so no
    related objects are loaded. Rows are ordered by ``updated_at``
    descending.

    Args:
        workspace_id: ID of the workspace whose scenes are listed.

    Returns:
        A list of dicts, each with ``scene_id`` (converted to ``str``)
        and ``scene_name``.

    Raises:
        Exception: Any error raised by the query is logged and re-raised.
    """
    try:
        logger.debug(f"Getting simple scene list for workspace: {workspace_id}")

        # Select just the two columns that are returned to the caller.
        query = self.db.query(OntologyScene.scene_id, OntologyScene.scene_name)
        query = query.filter(OntologyScene.workspace_id == workspace_id)
        rows = query.order_by(OntologyScene.updated_at.desc()).all()

        scenes = []
        for row in rows:
            scenes.append(
                {"scene_id": str(row.scene_id), "scene_name": row.scene_name}
            )

        logger.info(f"Found {len(scenes)} scenes (simple list) in workspace {workspace_id}")
        return scenes
    except Exception as e:
        logger.error(
            f"Failed to get simple scene list: {str(e)}",
            exc_info=True
        )
        raise

View File

@@ -1,3 +1,4 @@
from abc import ABC
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
@@ -14,4 +15,15 @@ class UserInput(BaseModel):
class Write_UserInput(BaseModel): class Write_UserInput(BaseModel):
messages: list[dict] messages: list[dict]
end_user_id: str end_user_id: str
config_id: Optional[str] = None config_id: Optional[str] = None
class AgentMemory_Long_Term(ABC):
    """Constants for long-term memory configuration."""

    # Storage identifiers.
    STORAGE_NEO4J, STORAGE_RAG = "neo4j", "rag"

    # Strategy identifiers.
    STRATEGY_AGGREGATE, STRATEGY_CHUNK, STRATEGY_TIME = "aggregate", "chunk", "time"

    # Default scope value.
    DEFAULT_SCOPE = 6

View File

@@ -248,8 +248,9 @@ class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
config_id: Union[uuid.UUID, int, str] = None config_id: Union[uuid.UUID, int, str] = None
config_name: str = Field("配置名称", description="配置名称(字符串)") config_name: Optional[str] = Field(None, description="配置名称(字符串)")
config_desc: str = Field("配置描述", description="配置描述(字符串)") config_desc: Optional[str] = Field(None, description="配置描述(字符串)")
scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID")
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型

View File

@@ -114,6 +114,8 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
result = task_service.get_task_memory_read_result(task.id) result = task_service.get_task_memory_read_result(task.id)
status = result.get("status") status = result.get("status")
logger.info(f"读取任务状态:{status}") logger.info(f"读取任务状态:{status}")
if memory_content:
memory_content = memory_content['answer']
finally: finally:
db.close() db.close()
@@ -127,11 +129,6 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
"content_length": len(str(memory_content)) "content_length": len(str(memory_content))
} }
) )
# 检查是否有有效内容
if not memory_content or str(memory_content).strip() == "" or "answer" in str(memory_content) and str(memory_content).count("''") > 0:
return "未找到相关的历史记忆。请直接回答用户的问题,不要再次调用此工具。"
return f"检索到以下历史记忆:\n\n{memory_content}" return f"检索到以下历史记忆:\n\n{memory_content}"
except Exception as e: except Exception as e:
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__}) logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})

View File

@@ -183,11 +183,11 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# --- Read All --- # --- Read All ---
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数 def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
configs = MemoryConfigRepository.get_all(self.db, workspace_id) results = MemoryConfigRepository.get_all(self.db, workspace_id)
# 将 ORM 对象转换为字典列表 # 将 ORM 对象转换为字典列表
data_list = [] data_list = []
for config in configs: for config, scene_name in results:
# 安全地转换 user_id 为 int # 安全地转换 user_id 为 int
config_id_old = None config_id_old = None
if config.config_id_old: if config.config_id_old:
@@ -209,7 +209,8 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"end_user_id": config.end_user_id, "end_user_id": config.end_user_id,
"config_id_old": config_id_old, "config_id_old": config_id_old,
"apply_id": config.apply_id, "apply_id": config.apply_id,
"scene_id": config.scene_id, "scene_id": str(config.scene_id) if config.scene_id else None,
"scene_name": scene_name, # 新增:场景名称
"llm_id": config.llm_id, "llm_id": config.llm_id,
"embedding_id": config.embedding_id, "embedding_id": config.embedding_id,
"rerank_id": config.rerank_id, "rerank_id": config.rerank_id,
@@ -637,10 +638,9 @@ async def analytics_recent_activity_stats() -> Dict[str, Any]:
if m < 1: if m < 1:
latest_relative = "刚刚" latest_relative = "刚刚"
elif m < 60: elif m < 60:
latest_relative = f"{m}分钟" latest_relative = "一会"
else: else:
h = int(m // 60) latest_relative = "较早前"
latest_relative = f"{h}小时前" if h < 24 else f"{int(h // 24)}天前"
except Exception: except Exception:
pass pass

View File

@@ -15,6 +15,7 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context from app.db import get_db_context
from app.repositories.conversation_repository import ConversationRepository from app.repositories.conversation_repository import ConversationRepository
from app.repositories.end_user_repository import EndUserRepository from app.repositories.end_user_repository import EndUserRepository
from app.repositories.neo4j.cypher_queries import Graph_Node_query
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
from app.services.implicit_memory_service import ImplicitMemoryService from app.services.implicit_memory_service import ImplicitMemoryService
@@ -1521,7 +1522,6 @@ async def analytics_graph_data(
user_uuid = uuid.UUID(end_user_id) user_uuid = uuid.UUID(end_user_id)
repo = EndUserRepository(db) repo = EndUserRepository(db)
end_user = repo.get_by_id(user_uuid) end_user = repo.get_by_id(user_uuid)
if not end_user: if not end_user:
logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户") logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户")
return { return {
@@ -1575,21 +1575,11 @@ async def analytics_graph_data(
} }
else: else:
# 查询所有节点 # 查询所有节点
node_query = """ node_query=Graph_Node_query
MATCH (n)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) as id,
labels(n)[0] as label,
properties(n) as properties
LIMIT $limit
"""
node_params = { node_params = {
"end_user_id": end_user_id, "end_user_id": end_user_id,
"limit": limit "limit": limit
} }
# 执行节点查询 # 执行节点查询
node_results = await _neo4j_connector.execute_query(node_query, **node_params) node_results = await _neo4j_connector.execute_query(node_query, **node_params)
@@ -1600,9 +1590,9 @@ async def analytics_graph_data(
for record in node_results: for record in node_results:
node_id = record["id"] node_id = record["id"]
node_label = record["label"] node_labels = record.get("labels", [])
node_label = node_labels[0] if node_labels else "Unknown"
node_props = record["properties"] node_props = record["properties"]
# 根据节点类型提取需要的属性字段 # 根据节点类型提取需要的属性字段
filtered_props = await _extract_node_properties(node_label, node_props,node_id) filtered_props = await _extract_node_properties(node_label, node_props,node_id)

View File

@@ -5,42 +5,68 @@ Shared utilities for configuration handling to avoid circular imports.
""" """
from uuid import UUID from uuid import UUID
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
import uuid as uuid_module
def resolve_config_id(config_id: UUID | int | str, db: Session) -> UUID:
    """Resolve a configuration identifier to its UUID.

    Accepted inputs:

    * a ``UUID`` instance, returned unchanged;
    * a string holding a UUID, parsed and returned;
    * a positive integer, or a string holding one, treated as a legacy
      ``config_id_old`` value and looked up in ``MemoryConfig``.

    Args:
        config_id: Configuration ID as a UUID, UUID string, or legacy integer.
        db: Database session; only used when a legacy integer is resolved.

    Returns:
        The resolved configuration UUID.

    Raises:
        ValueError: If the value has an unsupported type or an invalid
            format, or if no configuration matches the legacy integer ID.
    """
    # 1. Already a UUID: nothing to resolve.
    if isinstance(config_id, UUID):
        return config_id

    if isinstance(config_id, str):
        text = config_id.strip()

        # 2.1 Try to parse the string as a UUID first.
        try:
            return UUID(text)
        except ValueError:
            pass

        # 2.2 Otherwise it must be a positive legacy integer. Only the
        # parsing sits inside the try, so the "not found" error raised by
        # the lookup below is not swallowed and reported as a format error.
        try:
            old_id = int(text)
        except ValueError:
            old_id = None
        if old_id is None or old_id <= 0:
            raise ValueError(f"无效的 config_id 格式: '{config_id}'(必须是 UUID 或正整数)")
    elif isinstance(config_id, int):
        # 3. Integer input is a legacy config_id_old value.
        if config_id <= 0:
            raise ValueError(f"config_id 必须是正整数: {config_id}")
        old_id = config_id
    else:
        # 4. Any other type is rejected.
        raise ValueError(f"不支持的 config_id 类型: {type(config_id).__name__}")

    # Imported inside the function (as in the original) and only when a
    # lookup is actually required.
    from app.models.memory_config_model import MemoryConfig

    memory_config = db.query(MemoryConfig).filter(
        MemoryConfig.config_id_old == old_id
    ).first()
    if not memory_config:
        raise ValueError(f"未找到 config_id_old={old_id} 对应的配置")
    return memory_config.config_id

View File

@@ -13,6 +13,14 @@
"@antv/layout": "^1.2.14-beta.8", "@antv/layout": "^1.2.14-beta.8",
"@antv/x6": "^3.0.1", "@antv/x6": "^3.0.1",
"@antv/x6-react-shape": "^3.0.1", "@antv/x6-react-shape": "^3.0.1",
"@codemirror/lang-cpp": "^6.0.3",
"@codemirror/lang-java": "^6.0.2",
"@codemirror/lang-javascript": "^6.2.4",
"@codemirror/lang-python": "^6.2.1",
"@codemirror/lang-rust": "^6.0.2",
"@codemirror/state": "^6.5.4",
"@codemirror/theme-one-dark": "^6.1.3",
"@codemirror/view": "^6.39.12",
"@dnd-kit/core": "^6.3.1", "@dnd-kit/core": "^6.3.1",
"@dnd-kit/modifiers": "^9.0.0", "@dnd-kit/modifiers": "^9.0.0",
"@dnd-kit/sortable": "^10.0.0", "@dnd-kit/sortable": "^10.0.0",
@@ -25,6 +33,7 @@
"antd": "^5.27.4", "antd": "^5.27.4",
"axios": "^1.12.2", "axios": "^1.12.2",
"clsx": "^2.1.1", "clsx": "^2.1.1",
"codemirror": "^6.0.2",
"copy-to-clipboard": "^3.3.3", "copy-to-clipboard": "^3.3.3",
"crypto-js": "^4.2.0", "crypto-js": "^4.2.0",
"dayjs": "^1.11.18", "dayjs": "^1.11.18",
@@ -55,6 +64,7 @@
"@tailwindcss/postcss": "^4.1.14", "@tailwindcss/postcss": "^4.1.14",
"@tailwindcss/typography": "^0.5.19", "@tailwindcss/typography": "^0.5.19",
"@tailwindcss/vite": "^4.1.14", "@tailwindcss/vite": "^4.1.14",
"@types/codemirror": "^5.60.17",
"@types/crypto-js": "^4.2.2", "@types/crypto-js": "^4.2.2",
"@types/js-yaml": "^4.0.9", "@types/js-yaml": "^4.0.9",
"@types/node": "^24.6.0", "@types/node": "^24.6.0",

View File

@@ -8,6 +8,7 @@ import { request } from '@/utils/request'
import type { Query, OntologyModalData, OntologyClassModalData, OntologyClassExtractModalData, OntologyExportModalData } from '@/views/Ontology/types' import type { Query, OntologyModalData, OntologyClassModalData, OntologyClassExtractModalData, OntologyExportModalData } from '@/views/Ontology/types'
// Scene list // Scene list
export const getOntologyScenesSimpleUrl = '/memory/ontology/scenes/simple'
export const getOntologyScenesUrl = '/memory/ontology/scenes' export const getOntologyScenesUrl = '/memory/ontology/scenes'
export const getOntologyScenesList = (data: Query) => { export const getOntologyScenesList = (data: Query) => {
return request.get(getOntologyScenesUrl, data) return request.get(getOntologyScenesUrl, data)

View File

@@ -0,0 +1,150 @@
/*
* @Author: ZhaoYing
* @Date: 2026-02-04 17:20:52
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-04 17:20:52
*/
import { useEffect, useRef, useMemo } from 'react';
import { EditorView, basicSetup } from 'codemirror';
import { EditorState } from '@codemirror/state';
import { python } from '@codemirror/lang-python';
import { javascript } from '@codemirror/lang-javascript';
import { java } from '@codemirror/lang-java';
import { cpp } from '@codemirror/lang-cpp';
import { rust } from '@codemirror/lang-rust';
import { oneDark } from '@codemirror/theme-one-dark';
/**
* Props for the CodeMirrorEditor component
* @property {string} value - The initial code content to display in the editor
* @property {string} language - Programming language for syntax highlighting (python, python3, javascript, typescript, java, cpp, c, rust)
* @property {function} onChange - Callback function triggered when editor content changes, receives the new code value
* @property {string} theme - Editor theme, either 'light' or 'dark'
* @property {boolean} readOnly - Whether the editor is read-only
* @property {string} height - Custom height for the editor
* @property {string} size - Predefined size preset: 'default' (120px min-height, 14px font) or 'small' (60px min-height, 12px font)
*/
interface CodeMirrorEditorProps {
  // Code content shown in the editor; the component defaults it to ''.
  value?: string;
  // Language key used to pick a syntax extension; the component defaults it to 'javascript'.
  language?: 'python' | 'python3' | 'javascript' | 'typescript' | 'java' | 'cpp' | 'c' | 'rust';
  // Called with the full document text whenever the document changes.
  onChange?: (value: string) => void;
  // 'dark' adds the oneDark theme extension; the component defaults to 'light'.
  theme?: 'light' | 'dark';
  // Passed to EditorState.readOnly; the component defaults it to false.
  readOnly?: boolean;
  // Custom height for the editor.
  height?: string;
  // Size preset that drives min-height, font size and line height of the wrapper.
  size?: 'default' | 'small';
}
/**
 * Lookup table from language identifier to a CodeMirror language extension.
 * The extensions are created once, when this module is evaluated, and are
 * looked up by the `language` prop of the editor component.
 */
const languageExtensions: Record<string, any> = {
  python: python(),
  // 'python3' uses the same python() extension as 'python'.
  python3: python(),
  javascript: javascript(),
  // TypeScript is the JavaScript extension with the typescript option enabled.
  typescript: javascript({ typescript: true }),
  java: java(),
  cpp: cpp(),
  // 'c' uses the cpp() extension.
  c: cpp(),
  rust: rust(),
};
/**
* CodeMirrorEditor - A React wrapper component for CodeMirror 6 editor
* Provides a code editor with syntax highlighting, theme support, and customizable sizing
* Used in workflow code execution nodes for editing Python and JavaScript code
*/
const CodeMirrorEditor = ({
  value = '',
  language = 'javascript',
  onChange,
  theme = 'light',
  readOnly = false,
  height,
  size,
}: CodeMirrorEditorProps) => {
  // DOM element that hosts the editor
  const editorRef = useRef<HTMLDivElement>(null);
  // Current CodeMirror EditorView instance (null while no editor is mounted)
  const viewRef = useRef<EditorView | null>(null);
  // Latest onChange callback. The editor is only re-created when
  // language/theme/readOnly/height change, so the update listener reads the
  // callback through this ref instead of closing over a stale one.
  const onChangeRef = useRef(onChange);

  useEffect(() => {
    onChangeRef.current = onChange;
  }, [onChange]);

  /**
   * Create the editor on mount and re-create it whenever language, theme,
   * readOnly or height changes. The current `value` is used as the initial
   * document of each new instance.
   */
  useEffect(() => {
    const parent = editorRef.current;
    if (!parent) return;

    // Fall back to JavaScript when the language key is unknown
    const langExtension = languageExtensions[language] || languageExtensions.javascript;

    const extensions = [
      basicSetup, // Basic editor features (line numbers, bracket matching, etc.)
      langExtension, // Language-specific syntax highlighting
      // Forward document changes to the most recent onChange callback
      EditorView.updateListener.of((update) => {
        if (update.docChanged) {
          onChangeRef.current?.(update.state.doc.toString());
        }
      }),
      EditorState.readOnly.of(readOnly), // Set read-only mode
    ];

    // Apply dark theme if specified
    if (theme === 'dark') {
      extensions.push(oneDark);
    }

    // Apply the optional custom height: fix the editor's height and let the
    // scroller handle overflow so long documents scroll inside the editor.
    if (height) {
      extensions.push(
        EditorView.theme({
          '&': { height },
          '.cm-scroller': { overflow: 'auto' },
        })
      );
    }

    const view = new EditorView({
      state: EditorState.create({ doc: value, extensions }),
      parent,
    });
    viewRef.current = view;

    // Cleanup: destroy this instance and drop the reference so later effects
    // never dispatch to a destroyed view.
    return () => {
      view.destroy();
      if (viewRef.current === view) {
        viewRef.current = null;
      }
    };
  }, [language, theme, readOnly, height]);

  /**
   * Sync the document when the value prop changes externally.
   * Skipped when the editor already holds exactly this content.
   */
  useEffect(() => {
    const view = viewRef.current;
    if (view && value !== view.state.doc.toString()) {
      view.dispatch({
        changes: {
          from: 0,
          to: view.state.doc.length,
          insert: value,
        },
      });
    }
  }, [value]);

  // Minimum height based on size prop: small (60px) or default (120px)
  const minHeight = useMemo(() => {
    return `${size === 'small' ? 60 : 120}px`
  }, [size])
  // Font size based on size prop: small (12px) or default (14px)
  const fontSize = useMemo(() => {
    return `${size === 'small' ? 12 : 14}px`
  }, [size])
  // Line height based on size prop: small (16px) or default (20px)
  const lineHeight = useMemo(() => {
    return `${size === 'small' ? 16 : 20}px`
  }, [size])

  return <div ref={editorRef} style={{ minHeight, fontSize, lineHeight }} />;
};
export default CodeMirrorEditor;

View File

@@ -81,7 +81,7 @@ const components = {
audio: ({ src, ...props }: any) => <AudioBlock node={{ children: [{ properties: { src: src || '' } }] }} {...props} />, audio: ({ src, ...props }: any) => <AudioBlock node={{ children: [{ properties: { src: src || '' } }] }} {...props} />,
a: ({ href, children, ...props }: any) => <Link href={href || '#'} {...props}>{children}</Link>, a: ({ href, children, ...props }: any) => <Link href={href || '#'} {...props}>{children}</Link>,
button: ({ children }: any) => <RbButton node={{ children }}>{[children]}</RbButton>, button: ({ children }: any) => <RbButton node={{ children }}>{[children]}</RbButton>,
table: ({ children, ...props }: any) => <table className="rb:border rb:border-[#D9D9D9] rb:mb-2" {...props}>{children}</table>, table: ({ children, ...props }: any) => <div className="rb:overflow-x-auto rb:max-w-full"><table className="rb:border rb:border-[#D9D9D9] rb:mb-2" {...props}>{children}</table></div>,
tr: ({ children, ...props }: any) => <tr className="rb:border rb:border-[#D9D9D9]" {...props}>{children}</tr>, tr: ({ children, ...props }: any) => <tr className="rb:border rb:border-[#D9D9D9]" {...props}>{children}</tr>,
th: ({ children, ...props }: any) => <th className="rb:border rb:border-[#D9D9D9] rb:px-2 rb:py-1 rb:text-left rb:font-bold" {...props}>{children}</th>, th: ({ children, ...props }: any) => <th className="rb:border rb:border-[#D9D9D9] rb:px-2 rb:py-1 rb:text-left rb:font-bold" {...props}>{children}</th>,
td: ({ children, ...props }: any) => <td className="rb:border rb:border-[#D9D9D9] rb:px-2 rb:py-1 rb:text-left" {...props}>{children}</td>, td: ({ children, ...props }: any) => <td className="rb:border rb:border-[#D9D9D9] rb:px-2 rb:py-1 rb:text-left" {...props}>{children}</td>,

View File

@@ -180,4 +180,9 @@ body {
.x6-node foreignObject > body { .x6-node foreignObject > body {
min-height: 100%; min-height: 100%;
max-height: 100%; max-height: 100%;
}
.ͼ2 .cm-gutters {
background-color: #FFFFFF;
border: none;
} }

View File

@@ -140,7 +140,7 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi
title={t('application.knowledgeBaseAssociation')} title={t('application.knowledgeBaseAssociation')}
extra={ extra={
<Space> <Space>
<Button style={{ padding: '0 8px', height: '24px' }} onClick={handleKnowledgeConfig}>{t('workflow.config.knowledge-retrieval.recallConfig')}</Button> <Button style={{ padding: '0 8px', height: '24px' }} onClick={handleKnowledgeConfig}>{t('application.globalConfig')}</Button>
<Button style={{ padding: '0 8px', height: '24px' }} onClick={handleAddKnowledge}>+</Button> <Button style={{ padding: '0 8px', height: '24px' }} onClick={handleAddKnowledge}>+</Button>
</Space> </Space>
} }

View File

@@ -16,7 +16,7 @@ import { useTranslation } from 'react-i18next';
import type { MemoryFormData, Memory, MemoryFormRef } from '../types'; import type { MemoryFormData, Memory, MemoryFormRef } from '../types';
import RbModal from '@/components/RbModal' import RbModal from '@/components/RbModal'
import { createMemoryConfig, updateMemoryConfig } from '@/api/memory' import { createMemoryConfig, updateMemoryConfig } from '@/api/memory'
import { getOntologyScenesUrl } from '@/api/ontology' import { getOntologyScenesSimpleUrl } from '@/api/ontology'
import CustomSelect from '@/components/CustomSelect'; import CustomSelect from '@/components/CustomSelect';
const FormItem = Form.Item; const FormItem = Form.Item;
@@ -129,8 +129,7 @@ const MemoryForm = forwardRef<MemoryFormRef, MemoryFormProps>(({
> >
<CustomSelect <CustomSelect
placeholder={t('common.pleaseSelect')} placeholder={t('common.pleaseSelect')}
url={getOntologyScenesUrl} url={getOntologyScenesSimpleUrl}
params={{ pagesize: 100, page: 1 }}
hasAll={false} hasAll={false}
valueKey='scene_id' valueKey='scene_id'
labelKey="scene_name" labelKey="scene_name"

View File

@@ -112,7 +112,7 @@ const MemoryManagement: React.FC = () => {
title={item.config_name} title={item.config_name}
> >
<Tooltip title={item.config_desc}> <Tooltip title={item.config_desc}>
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.25 rb:font-regular rb:-mt-1 rb:wrap-break-word rb:line-clamp-1">{item.config_desc}</div> <div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.25 rb:font-regular rb:-mt-1 rb:wrap-break-word rb:line-clamp-1 rb:h-[17px]">{item.config_desc}</div>
</Tooltip> </Tooltip>
<RbAlert className="rb:mt-3 "> <RbAlert className="rb:mt-3 ">
<div className={clsx("rb:flex rb:gap-5 rb:font-regular rb:text-[14px]")}> <div className={clsx("rb:flex rb:gap-5 rb:font-regular rb:text-[14px]")}>

View File

@@ -103,9 +103,9 @@ const MultiKeyConfigModal = forwardRef<MultiKeyConfigModalRef, MultiKeyConfigMod
{model.api_keys && model.api_keys.length > 0 && ( {model.api_keys && model.api_keys.length > 0 && (
<div className="rb:mb-4"> <div className="rb:mb-4">
{model.api_keys.map((key) => ( {model.api_keys.map((key) => (
<div key={key.id} className="rb:flex rb:items-center rb:justify-between rb:p-3 rb:bg-[#F5F6F7] rb:rounded-lg rb:mb-2"> <div key={key.id} className="rb:flex rb:gap-3 rb:items-center rb:justify-between rb:p-3 rb:bg-[#F5F6F7] rb:rounded-lg rb:mb-2">
<div> <div className="rb:flex-1">
<div className="rb:text-[#1D2129] rb:text-[14px] rb:font-medium">{key.api_key}</div> <div className="rb:text-[#1D2129] rb:text-[14px] rb:font-medium rb:break-all">{key.api_key}</div>
<div className="rb:text-[#5B6167] rb:text-[12px] rb:mt-1">{key.api_base}</div> <div className="rb:text-[#5B6167] rb:text-[12px] rb:mt-1">{key.api_base}</div>
</div> </div>
<Button type="primary" danger ghost onClick={() => handleDelete(key.id)}>{t('common.remove')}</Button> <Button type="primary" danger ghost onClick={() => handleDelete(key.id)}>{t('common.remove')}</Button>

View File

@@ -15,8 +15,6 @@ import CharacterCountPlugin from './plugin/CharacterCountPlugin'
import InitialValuePlugin from './plugin/InitialValuePlugin'; import InitialValuePlugin from './plugin/InitialValuePlugin';
import CommandPlugin from './plugin/CommandPlugin'; import CommandPlugin from './plugin/CommandPlugin';
import Jinja2HighlightPlugin from './plugin/Jinja2HighlightPlugin'; import Jinja2HighlightPlugin from './plugin/Jinja2HighlightPlugin';
import Python3HighlightPlugin from './plugin/Python3HighlightPlugin';
import JavaScriptHighlightPlugin from './plugin/JavaScriptHighlightPlugin';
import LineNumberPlugin from './plugin/LineNumberPlugin'; import LineNumberPlugin from './plugin/LineNumberPlugin';
import BlurPlugin from './plugin/BlurPlugin'; import BlurPlugin from './plugin/BlurPlugin';
import { VariableNode } from './nodes/VariableNode' import { VariableNode } from './nodes/VariableNode'
@@ -32,7 +30,7 @@ export interface LexicalEditorProps {
lineHeight?: number; lineHeight?: number;
size?: 'default' | 'small'; size?: 'default' | 'small';
type?: 'input' | 'textarea', type?: 'input' | 'textarea',
language?: 'string' | 'jinja2' | 'python3' | 'javascript' language?: 'string' | 'jinja2'
} }
const theme = { const theme = {
@@ -67,7 +65,7 @@ const Editor: FC<LexicalEditorProps> =({
const [enableLineNumbers, setEnableLineNumbers] = useState(false) const [enableLineNumbers, setEnableLineNumbers] = useState(false)
useEffect(() => { useEffect(() => {
const needsLineNumbers = language === 'jinja2' || language === 'python3' || language === 'javascript'; const needsLineNumbers = language === 'jinja2';
setEnableJinja2(language === 'jinja2'); setEnableJinja2(language === 'jinja2');
setEnableLineNumbers(needsLineNumbers); setEnableLineNumbers(needsLineNumbers);
@@ -237,13 +235,11 @@ const Editor: FC<LexicalEditorProps> =({
<HistoryPlugin /> <HistoryPlugin />
<CommandPlugin /> <CommandPlugin />
{language === 'jinja2' && <Jinja2HighlightPlugin />} {language === 'jinja2' && <Jinja2HighlightPlugin />}
{language === 'python3' && <Python3HighlightPlugin />}
{language === 'javascript' && <JavaScriptHighlightPlugin />}
{enableLineNumbers && <LineNumberPlugin />} {enableLineNumbers && <LineNumberPlugin />}
<AutocompletePlugin options={options} enableJinja2={enableJinja2} /> <AutocompletePlugin options={options} enableJinja2={enableJinja2} />
<CharacterCountPlugin setCount={(count) => { setCount(count) }} onChange={onChange} /> <CharacterCountPlugin setCount={(count) => { setCount(count) }} onChange={onChange} />
<InitialValuePlugin key={language} value={value} options={options} enableLineNumbers={enableLineNumbers} /> <InitialValuePlugin value={value} options={options} enableLineNumbers={enableLineNumbers} />
{enableLineNumbers && <BlurPlugin />} {enableJinja2 && <BlurPlugin />}
</div> </div>
</LexicalComposer> </LexicalComposer>
); );

View File

@@ -1,182 +0,0 @@
import { useEffect, useRef } from 'react';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { TextNode, $createTextNode, $getSelection, $isRangeSelection, COMMAND_PRIORITY_LOW, PASTE_COMMAND } from 'lexical';
const JS_KEYWORDS = new Set([
'async', 'await', 'break', 'case', 'catch', 'class', 'const', 'continue', 'debugger', 'default',
'delete', 'do', 'else', 'export', 'extends', 'finally', 'for', 'function', 'if', 'import',
'in', 'instanceof', 'let', 'new', 'return', 'super', 'switch', 'this', 'throw', 'try',
'typeof', 'var', 'void', 'while', 'with', 'yield', 'true', 'false', 'null', 'undefined'
]);
/**
 * Lexical plugin that colours JavaScript source while the user types.
 *
 * It registers a TextNode transform that replaces a plain text node with one
 * 'code'-formatted text node per token, each carrying an inline style chosen
 * by token type, and then puts the caret back at the same character offset.
 * The component renders nothing.
 */
const JavaScriptHighlightPlugin = () => {
  const [editor] = useLexicalComposerContext();
  // Set for 100 ms after a paste; the transform below does nothing while it is true.
  const isPastingRef = useRef(false);

  useEffect(() => {
    const markPasting = () => {
      isPastingRef.current = true;
      setTimeout(() => {
        isPastingRef.current = false;
      }, 100);
      // Report the command as not handled here.
      return false;
    };
    return editor.registerCommand(PASTE_COMMAND, markPasting, COMMAND_PRIORITY_LOW);
  }, [editor]);

  useEffect(() => {
    // Inline CSS per token type; types without an entry ('text') stay unstyled.
    const styleByTokenType = new Map<string, string>([
      ['keyword', 'color: #d73a49; font-weight: 600;'],
      ['string', 'color: #032f62;'],
      ['comment', 'color: #6a737d; font-style: italic;'],
      ['number', 'color: #005cc5; font-weight: 500;'],
      ['function', 'color: #6f42c1; font-weight: 500;'],
    ]);

    const highlight = (textNode: TextNode) => {
      if (isPastingRef.current) return;
      const source = textNode.getTextContent();
      // Nodes that already carry the code format or an inline style are left
      // alone (presumably because they were produced by an earlier pass).
      if (textNode.hasFormat('code')) return;
      if (!needsHighlight(source)) return;
      if (textNode.getStyle()) return;
      if (!textNode.getParent()) return;

      // Remember the caret offset only when the caret sits inside this node.
      const selection = $getSelection();
      let caretOffset: number | null = null;
      if ($isRangeSelection(selection) && selection.anchor.getNode() === textNode) {
        caretOffset = selection.anchor.offset;
      }

      const tokens = tokenizeJavaScript(source);
      // A single token means there is nothing to split.
      if (tokens.length <= 1) return;

      const pieces = tokens.map(({ text, type }) => {
        const piece = $createTextNode(text);
        piece.toggleFormat('code');
        const css = styleByTokenType.get(type);
        if (css !== undefined) {
          piece.setStyle(css);
        }
        return piece;
      });

      // Swap the original node for the first piece, then chain the rest after it.
      const [head, ...rest] = pieces;
      textNode.replace(head);
      let tail = head;
      for (const piece of rest) {
        tail.insertAfter(piece);
        tail = piece;
      }

      if (caretOffset === null || !$isRangeSelection(selection)) return;

      // Walk the pieces until the one containing the remembered offset is found.
      let remaining = caretOffset;
      for (const piece of pieces) {
        const size = piece.getTextContent().length;
        if (size >= remaining) {
          piece.select(remaining, remaining);
          break;
        }
        remaining -= size;
      }
    };

    return editor.registerNodeTransform(TextNode, highlight);
  }, [editor]);

  return null;
};
/**
 * Cheap pre-filter: report whether `text` contains any character that can
 * begin a JavaScript token the tokenizer styles (identifier, number, string
 * or comment).
 *
 * `$` is included because the tokenizer accepts it as an identifier start;
 * without it, text such as `$(` would never reach the tokenizer.
 *
 * @param text Text content of the node being considered.
 * @returns `true` when tokenizing may yield styled tokens, otherwise `false`.
 */
function needsHighlight(text: string): boolean {
  return /[a-zA-Z0-9_$/"'`]/.test(text);
}
/**
 * Split JavaScript source text into contiguous tokens for syntax colouring.
 *
 * The tokens cover the input without gaps, so joining every `text` field
 * reproduces the input. The scanner always advances by at least one
 * character per iteration, so it terminates on any input (a lone `/`, such
 * as a division operator, is emitted as plain text).
 *
 * @param text Raw source text; may span several lines.
 * @returns Tokens in source order. `type` is one of 'comment', 'string',
 *          'number', 'keyword', 'function' (identifier directly followed by
 *          `(`) or 'text' (everything else).
 */
function tokenizeJavaScript(text: string): Array<{text: string, type: string}> {
  const tokens: Array<{text: string, type: string}> = [];
  const length = text.length;

  // True when the character at `pos` begins a token handled by a dedicated
  // branch below. A '/' only counts when it opens a comment; otherwise it is
  // an operator and belongs to the surrounding plain-text run.
  const startsToken = (pos: number): boolean => {
    const ch = text[pos];
    if (ch === '/') {
      const next = text[pos + 1];
      return next === '/' || next === '*';
    }
    return /[a-zA-Z0-9_$"'`]/.test(ch);
  };

  let i = 0;
  while (i < length) {
    const start = i;
    const ch = text[i];
    const pair = text.slice(i, i + 2);

    // Single-line comment: runs up to (not including) the next newline.
    if (pair === '//') {
      while (i < length && text[i] !== '\n') i++;
      tokens.push({ text: text.slice(start, i), type: 'comment' });
      continue;
    }

    // Multi-line comment: runs through the closing '*/' or to end of input.
    if (pair === '/*') {
      i += 2;
      while (i < length && text.slice(i, i + 2) !== '*/') i++;
      if (i < length) i += 2;
      tokens.push({ text: text.slice(start, i), type: 'comment' });
      continue;
    }

    // String or template literal: runs through the matching quote or to end
    // of input when unterminated.
    if (ch === '"' || ch === "'" || ch === '`') {
      i++;
      while (i < length) {
        if (text[i] === '\\') {
          // Skip the backslash together with the character it escapes, so an
          // escaped backslash right before the closing quote is handled.
          i += 2;
          continue;
        }
        if (text[i] === ch) {
          i++;
          break;
        }
        i++;
      }
      // A trailing backslash can push the index one past the end.
      i = Math.min(i, length);
      tokens.push({ text: text.slice(start, i), type: 'string' });
      continue;
    }

    // Number: a run of digits and dots.
    if (/\d/.test(ch)) {
      while (i < length && /[\d.]/.test(text[i])) i++;
      tokens.push({ text: text.slice(start, i), type: 'number' });
      continue;
    }

    // Keyword, function name or plain identifier.
    if (/[a-zA-Z_$]/.test(ch)) {
      while (i < length && /[a-zA-Z0-9_$]/.test(text[i])) i++;
      const word = text.slice(start, i);
      let type = 'text';
      if (JS_KEYWORDS.has(word)) {
        type = 'keyword';
      } else if (i < length && text[i] === '(') {
        type = 'function';
      }
      tokens.push({ text: word, type });
      continue;
    }

    // Anything else (whitespace, operators, punctuation). The current
    // character is known not to start a token, so consume it unconditionally
    // and keep going until the next token start.
    do {
      i++;
    } while (i < length && !startsToken(i));
    tokens.push({ text: text.slice(start, i), type: 'text' });
  }

  return tokens;
}
export default JavaScriptHighlightPlugin;

View File

@@ -1,177 +0,0 @@
import { useEffect, useRef } from 'react';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { TextNode, $createTextNode, $getSelection, $isRangeSelection, COMMAND_PRIORITY_LOW, PASTE_COMMAND } from 'lexical';
/**
 * Words the Python tokenizer colours as keywords.
 * Looked up with `has()`, so membership tests are constant time.
 */
const PYTHON_KEYWORDS = new Set<string>([
  // Capitalised constants.
  'False', 'None', 'True',
  // Remaining keywords, in alphabetical order.
  'and', 'as', 'assert', 'async', 'await',
  'break', 'class', 'continue', 'def', 'del',
  'elif', 'else', 'except', 'finally', 'for',
  'from', 'global', 'if', 'import', 'in',
  'is', 'lambda', 'nonlocal', 'not', 'or',
  'pass', 'raise', 'return', 'try', 'while',
  'with', 'yield',
]);
/**
 * Lexical plugin that colours Python source while the user types.
 *
 * It registers a TextNode transform that replaces a plain text node with one
 * 'code'-formatted text node per token, each carrying an inline style chosen
 * by token type, and then puts the caret back at the same character offset.
 * The component renders nothing.
 */
const Python3HighlightPlugin = () => {
  const [editor] = useLexicalComposerContext();
  // Set for 100 ms after a paste; the transform below does nothing while it is true.
  const isPastingRef = useRef(false);

  useEffect(() => {
    const markPasting = () => {
      isPastingRef.current = true;
      setTimeout(() => {
        isPastingRef.current = false;
      }, 100);
      // Report the command as not handled here.
      return false;
    };
    return editor.registerCommand(PASTE_COMMAND, markPasting, COMMAND_PRIORITY_LOW);
  }, [editor]);

  useEffect(() => {
    // Inline CSS per token type; types without an entry ('text') stay unstyled.
    const styleByTokenType = new Map<string, string>([
      ['keyword', 'color: #d73a49; font-weight: 600;'],
      ['string', 'color: #032f62;'],
      ['comment', 'color: #6a737d; font-style: italic;'],
      ['number', 'color: #005cc5; font-weight: 500;'],
      ['function', 'color: #6f42c1; font-weight: 500;'],
    ]);

    const highlight = (textNode: TextNode) => {
      if (isPastingRef.current) return;
      const source = textNode.getTextContent();
      // Nodes that already carry the code format or an inline style are left
      // alone (presumably because they were produced by an earlier pass).
      if (textNode.hasFormat('code')) return;
      if (textNode.getStyle()) return;
      if (!needsHighlight(source)) return;
      if (!textNode.getParent()) return;

      // Remember the caret offset only when the caret sits inside this node.
      const selection = $getSelection();
      let caretOffset: number | null = null;
      if ($isRangeSelection(selection) && selection.anchor.getNode() === textNode) {
        caretOffset = selection.anchor.offset;
      }

      const tokens = tokenizePython(source);
      // A single token means there is nothing to split.
      if (tokens.length <= 1) return;

      const pieces = tokens.map(({ text, type }) => {
        const piece = $createTextNode(text);
        piece.toggleFormat('code');
        const css = styleByTokenType.get(type);
        if (css !== undefined) {
          piece.setStyle(css);
        }
        return piece;
      });

      // Swap the original node for the first piece, then chain the rest after it.
      const [head, ...rest] = pieces;
      textNode.replace(head);
      let tail = head;
      for (const piece of rest) {
        tail.insertAfter(piece);
        tail = piece;
      }

      if (caretOffset === null || !$isRangeSelection(selection)) return;

      // Walk the pieces until the one containing the remembered offset is found.
      let remaining = caretOffset;
      for (const piece of pieces) {
        const size = piece.getTextContent().length;
        if (size >= remaining) {
          piece.select(remaining, remaining);
          break;
        }
        remaining -= size;
      }
    };

    return editor.registerNodeTransform(TextNode, highlight);
  }, [editor]);

  return null;
};
/**
 * Cheap pre-filter: report whether `text` contains any character that can
 * begin a Python token the tokenizer styles (identifier, number, string or
 * `#` comment).
 *
 * @param text Text content of the node being considered.
 * @returns `true` when such a character is present, otherwise `false`.
 */
function needsHighlight(text: string): boolean {
  const firstCandidate = text.search(/[a-zA-Z0-9_#"']/);
  return firstCandidate !== -1;
}
/**
 * Split Python source text into contiguous tokens for syntax colouring.
 *
 * The tokens cover the input without gaps, so joining every `text` field
 * reproduces the input.
 *
 * @param text Raw source text; may span several lines.
 * @returns Tokens in source order. `type` is one of 'comment', 'string',
 *          'number', 'keyword', 'function' (identifier directly followed by
 *          `(`) or 'text' (everything else).
 */
function tokenizePython(text: string): Array<{text: string, type: string}> {
  const tokens: Array<{text: string, type: string}> = [];
  const length = text.length;
  let i = 0;

  while (i < length) {
    const start = i;
    const ch = text[i];

    // Comment: runs up to (not including) the next newline.
    if (ch === '#') {
      while (i < length && text[i] !== '\n') i++;
      tokens.push({ text: text.slice(start, i), type: 'comment' });
      continue;
    }

    // String, single- or triple-quoted: runs through the matching delimiter
    // or to end of input when unterminated.
    if (ch === '"' || ch === "'") {
      const tripleQuote = ch.repeat(3);
      const isTriple = text.startsWith(tripleQuote, i);
      const delimiter = isTriple ? tripleQuote : ch;
      i += delimiter.length;
      while (i < length) {
        if (text[i] === '\\') {
          // Skip the backslash together with the character it escapes, so an
          // escaped backslash or escaped quote cannot be mistaken for (or
          // hide) the closing delimiter.
          i += 2;
          continue;
        }
        if (text.startsWith(delimiter, i)) {
          i += delimiter.length;
          break;
        }
        i++;
      }
      // A trailing backslash can push the index one past the end.
      i = Math.min(i, length);
      tokens.push({ text: text.slice(start, i), type: 'string' });
      continue;
    }

    // Number: a run of digits and dots.
    if (/\d/.test(ch)) {
      while (i < length && /[\d.]/.test(text[i])) i++;
      tokens.push({ text: text.slice(start, i), type: 'number' });
      continue;
    }

    // Keyword, function name or plain identifier.
    if (/[a-zA-Z_]/.test(ch)) {
      while (i < length && /[a-zA-Z0-9_]/.test(text[i])) i++;
      const word = text.slice(start, i);
      let type = 'text';
      if (PYTHON_KEYWORDS.has(word)) {
        type = 'keyword';
      } else if (i < length && text[i] === '(') {
        type = 'function';
      }
      tokens.push({ text: word, type });
      continue;
    }

    // Anything else (whitespace, operators, punctuation). The current
    // character is known not to start a token, so consume it unconditionally
    // and keep going until the next token start.
    do {
      i++;
    } while (i < length && !/[a-zA-Z0-9_#"']/.test(text[i]));
    tokens.push({ text: text.slice(start, i), type: 'text' });
  }

  return tokens;
}
export default Python3HighlightPlugin;

View File

@@ -5,8 +5,8 @@ import { Node } from '@antv/x6'
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin' import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
import MappingList from '../MappingList' import MappingList from '../MappingList'
import Editor from '../../Editor'
import OutputList from './OutputList' import OutputList from './OutputList'
import CodeMirrorEditor from '@/components/CodeMirrorEditor';
interface MappingItem { interface MappingItem {
name?: string name?: string
@@ -110,7 +110,10 @@ const CodeExecution: FC<CodeExecutionProps> = ({ options }) => {
<Form.Item noStyle shouldUpdate={(prev, curr) => prev.language !== curr.language}> <Form.Item noStyle shouldUpdate={(prev, curr) => prev.language !== curr.language}>
{() => ( {() => (
<Form.Item name="code" noStyle> <Form.Item name="code" noStyle>
<Editor size="small" language={form.getFieldValue('language')} /> <CodeMirrorEditor
language={form.getFieldValue('language')}
size="small"
/>
</Form.Item> </Form.Item>
)} )}
</Form.Item> </Form.Item>

View File

@@ -126,7 +126,7 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi
<div <div
className="rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/workflow/recall.svg')] rb:group-hover:bg-[url('@/assets/images/workflow/recall_hover.svg')]" className="rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/workflow/recall.svg')] rb:group-hover:bg-[url('@/assets/images/workflow/recall_hover.svg')]"
></div> ></div>
{t('workflow.config.knowledge-retrieval.recallConfig')} {t('application.globalConfig')}
</Button> </Button>
</div> </div>