Merge branch 'develop' into fix/memory-enduser-config

This commit is contained in:
Ke Sun
2026-02-06 15:14:34 +08:00
85 changed files with 4465 additions and 865 deletions

View File

@@ -76,6 +76,7 @@ celery_app.conf.update(
# Document tasks → document_tasks queue (prefork worker)
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},

View File

@@ -9,13 +9,16 @@ from sqlalchemy import or_
from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.rag.common import settings
from app.core.rag.integrations.feishu.client import FeishuAPIClient
from app.core.rag.integrations.yuque.client import YuqueAPIClient
from app.core.rag.llm.chat_model import Base
from app.core.rag.nlp import rag_tokenizer, search
from app.core.rag.prompts.generator import graph_entity_types
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.response_utils import success
from app.core.response_utils import success, fail
from app.db import get_db
from app.dependencies import get_current_user
from app.models import knowledge_model
@@ -484,3 +487,99 @@ async def rebuild_knowledge_graph(
except Exception as e:
api_logger.error(f"Failed to rebuild knowledge graph: knowledge_id={knowledge_id} - {str(e)}")
raise
@router.get("/check/yuque/auth", response_model=ApiResponse)
async def check_yuque_auth(
yuque_user_id: str,
yuque_token: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
check yuque auth info
"""
api_logger.info(f"check yuque auth info, username: {current_user.username}")
try:
api_client = YuqueAPIClient(
user_id=yuque_user_id,
token=yuque_token
)
async with api_client as client:
repos = await client.get_user_repos()
if repos:
return success(data=repos, msg="Successfully auth yuque info")
return fail(BizCode.UNAUTHORIZED, msg="auth yuque info failed", error="user_id or token is incorrect")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"auth yuque info failed: {str(e)}")
raise
@router.get("/check/feishu/auth", response_model=ApiResponse)
async def check_yuque_auth(
feishu_app_id: str,
feishu_app_secret: str,
feishu_folder_token: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
check feishu auth info
"""
api_logger.info(f"check feishu auth info, username: {current_user.username}")
try:
api_client = FeishuAPIClient(
app_id=feishu_app_id,
app_secret=feishu_app_secret
)
async with api_client as client:
files = await client.list_all_folder_files(feishu_folder_token, recursive=True)
if files:
return success(data=files, msg="Successfully auth feishu info")
return fail(BizCode.UNAUTHORIZED, msg="auth feishu info failed", error="app_id or app_secret or feishu_folder_token is incorrect")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"auth feishu info failed: {str(e)}")
raise
@router.post("/{knowledge_id}/sync", response_model=ApiResponse)
async def sync_knowledge(
knowledge_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
sync knowledge base information based on knowledge_id
"""
api_logger.info(f"Obtain details of the knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}")
try:
# 1. Query knowledge base information from the database
api_logger.debug(f"Query knowledge base: {knowledge_id}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
if not db_knowledge:
api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={knowledge_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
# 2. sync knowledge
# from app.tasks import sync_knowledge_for_kb
# sync_knowledge_for_kb(kb_id)
task = celery_app.send_task("app.core.rag.tasks.sync_knowledge_for_kb", args=[knowledge_id])
result = {
"task_id": task.id
}
return success(data=result, msg="Task accepted. sync knowledge is being processed in the background.")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"Failed to sync knowledge: knowledge_id={knowledge_id} - {str(e)}")
raise

View File

@@ -191,6 +191,11 @@ def update_config(
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 校验至少有一个字段需要更新
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
try:
svc = DataConfigService(db)

View File

@@ -52,6 +52,7 @@ from app.services.ontology_service import OntologyService
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.utils.validation.owl_validator import OWLValidator
from app.services.model_service import ModelConfigService
from app.repositories.ontology_scene_repository import OntologySceneRepository
api_logger = get_api_logger()
@@ -116,27 +117,35 @@ def _get_ontology_service(
detail=f"找不到指定的LLM模型: {llm_id}"
)
# 验证模型配置了API密钥
if not model_config.api_keys:
logger.error(f"Model {llm_id} has no API key configuration")
# 通过 Repository 获取可用的 API Key负载均衡逻辑由 Repository 处理)
from app.repositories.model_repository import ModelApiKeyRepository
api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id)
if not api_keys:
logger.error(f"Model {llm_id} has no active API key")
raise HTTPException(
status_code=400,
detail="指定的LLM模型没有配置API密钥"
detail="指定的LLM模型没有可用的API密钥"
)
api_key_config = api_keys[0]
api_key_config = model_config.api_keys[0]
is_composite = getattr(model_config, 'is_composite', False)
logger.info(
f"Using specified model - user: {current_user.id}, "
f"model_id: {llm_id}, model_name: {api_key_config.model_name}"
f"model_id: {llm_id}, model_name: {api_key_config.model_name}, "
f"is_composite: {is_composite}, api_key_id: {api_key_config.id}"
)
# 创建模型配置对象
from app.core.models.base import RedBearModelConfig
# 对于组合模型,使用 API Key 的 provider否则使用 model_config 的 provider
actual_provider = api_key_config.provider if is_composite else (
getattr(model_config, 'provider', None) or "openai"
)
llm_model_config = RedBearModelConfig(
model_name=api_key_config.model_name,
provider=model_config.provider if hasattr(model_config, 'provider') else "openai",
provider=actual_provider,
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
max_retries=3,
@@ -648,6 +657,46 @@ async def delete_scene(
return fail(BizCode.INTERNAL_ERROR, "场景删除失败", str(e))
@router.get("/scenes/simple", response_model=ApiResponse)
async def get_scenes_simple(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取场景简单列表(轻量级,用于下拉选择)
仅返回 scene_id 和 scene_name不加载关联数据响应速度快。
适用于前端下拉选择场景的场景。
Args:
db: 数据库会话
current_user: 当前用户
Returns:
ApiResponse: 包含场景简单列表
Examples:
GET /scenes/simple
返回: {"data": [{"scene_id": "xxx", "scene_name": "场景1"}, ...]}
"""
api_logger.info(f"Simple scene list requested by user {current_user.id}")
try:
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
repo = OntologySceneRepository(db)
scenes = repo.get_simple_list(workspace_id)
api_logger.info(f"Simple scene list retrieved: {len(scenes)} scenes")
return success(data=scenes, msg="查询成功")
except Exception as e:
api_logger.error(f"Failed to get simple scene list: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
@router.get("/scenes", response_model=ApiResponse)
async def get_scenes(
workspace_id: Optional[str] = None,

View File

@@ -7,30 +7,21 @@ LangChain Agent 封装
- 支持流式输出
- 使用 RedBearLLM 支持多提供商
"""
import os
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages, format_parsing, messages_parse
from app.core.memory.agent.langgraph_graph.write_graph import long_term_storage
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.db import get_db
from app.core.logging_config import get_business_logger
from app.core.memory.agent.utils.redis_tool import store
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from langchain.agents import create_agent
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
from app.utils.config_utils import resolve_config_id
logger = get_business_logger()
@@ -289,105 +280,6 @@ class LangChainAgent:
return content_parts
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
db = next(get_db())
#TODO: 魔法数字
scope=6
try:
repo = LongTermMemoryRepository(db)
await long_term_storage(long_term_type="chunk", langchain_messages=long_term_messages,
memory_config=actual_config_id, end_user_id=end_user_id, scope=scope)
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
# Handle case where no session exists in Redis (returns False)
if not result or result is False:
logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update")
return
if type=="chunk" or type=="aggregate":
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
if len(chunk_data)==scope:
repo.upsert(end_user_id, chunk_data)
logger.info(f'写入短长期:')
else:
# TODO: This branch handles type="time" strategy, currently unused.
# Will be activated when time-based long-term storage is implemented.
# TODO: 魔法数字 - extract 5 to a constant
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
# Handle case where no session exists in Redis (returns False or empty)
if not long_time_data or long_time_data is False:
logger.debug(f"No recent sessions in Redis for user {end_user_id}")
return
long_messages = await messages_parse(long_time_data)
repo.upsert(end_user_id, long_messages)
logger.info(f'写入短长期:')
finally:
db.close()
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
"""
写入记忆(支持结构化消息)
Args:
storage_type: 存储类型 (neo4j/rag)
end_user_id: 终端用户ID
user_message: 用户消息内容
ai_message: AI 回复内容
user_rag_memory_id: RAG 记忆ID
actual_end_user_id: 实际用户ID
actual_config_id: 配置ID
逻辑说明:
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
- Neo4j 模式:使用结构化消息列表
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
2. 如果只有 user_message创建单条用户消息 [user](用于历史记忆场景)
3. 每条消息会被转换为独立的 Chunk保留 speaker 字段
"""
db = next(get_db())
try:
actual_config_id=resolve_config_id(actual_config_id, db)
if storage_type == "rag":
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
else:
# Neo4j 模式:使用结构化消息列表
structured_messages = []
# 始终添加用户消息(如果不为空)
if user_message:
structured_messages.append({"role": "user", "content": user_message})
# 只有当 AI 回复不为空时才添加 assistant 消息
if ai_message:
structured_messages.append({"role": "assistant", "content": ai_message})
# 如果没有消息,直接返回
if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # end_user_id: 用户ID
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
actual_config_id, # config_id: 配置ID
storage_type, # storage_type: "neo4j"
user_rag_memory_id # user_rag_memory_id: RAG记忆IDNeo4j模式下不使用
)
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
finally:
db.close()
async def chat(
self,
message: str,
@@ -520,14 +412,7 @@ class LangChainAgent:
elapsed_time = time.time() - start_time
if memory_flag:
long_term_messages=await agent_chat_messages(message_chat,content)
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id)
response = {
"content": content,
"model": self.model_name,
@@ -710,15 +595,7 @@ class LangChainAgent:
yield total_tokens
break
if memory_flag:
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
long_term_messages = await agent_chat_messages(message_chat, full_content)
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk")
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id)
except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
raise

View File

@@ -1,8 +1,9 @@
import json
import os
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format, format_parsing
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
@@ -10,46 +11,108 @@ from app.core.memory.agent.utils.redis_tool import write_store
from app.core.memory.agent.utils.redis_tool import count_store
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.db import get_db_context, get_db
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from app.utils.config_utils import resolve_config_id
logger = get_agent_logger(__name__)
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
actual_config_id, long_term_messages=[]):
"""
写入记忆(支持结构化消息)
async def write_messages(end_user_id,langchain_messages,memory_config):
'''
写入数据到neo4j
Args:
Args:
storage_type: 存储类型 (neo4j/rag)
end_user_id: 终端用户ID
memory_config: 内存配置对象
langchain_messages原始数据LIST
'''
user_message: 用户消息内容
ai_message: AI 回复内容
user_rag_memory_id: RAG 记忆ID
actual_end_user_id: 实际用户ID
actual_config_id: 配置ID
逻辑说明:
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
- Neo4j 模式:使用结构化消息列表
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
2. 如果只有 user_message创建单条用户消息 [user](用于历史记忆场景)
3. 每条消息会被转换为独立的 Chunk保留 speaker 字段
"""
db = next(get_db())
try:
actual_config_id = resolve_config_id(actual_config_id, db)
# Neo4j 模式:使用结构化消息列表
structured_messages = []
# 始终添加用户消息(如果不为空)
if isinstance(user_message, str) and user_message.strip() != "":
structured_messages.append({"role": "user", "content": user_message})
# 只有当 AI 回复不为空时才添加 assistant 消息
if isinstance(ai_message, str) and ai_message.strip() != "":
structured_messages.append({"role": "assistant", "content": ai_message})
# 如果提供了 long_term_messages使用它替代 structured_messages
if long_term_messages and isinstance(long_term_messages, list):
structured_messages = long_term_messages
elif long_term_messages and isinstance(long_term_messages, str):
# 如果是 JSON 字符串,先解析
try:
structured_messages = json.loads(long_term_messages)
except json.JSONDecodeError:
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
# 如果没有消息,直接返回
if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
logger.info(
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # end_user_id: 用户ID
structured_messages, # message: JSON 字符串格式的消息列表
str(actual_config_id), # config_id: 配置ID字符串
storage_type, # storage_type: "neo4j"
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆IDNeo4j模式下不使用
)
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
finally:
db.close()
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
with get_db_context() as db_session:
repo = LongTermMemoryRepository(db_session)
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
if len(chunk_data)==scope:
repo.upsert(end_user_id, chunk_data)
logger.info(f'---------写入短长期-----------')
else:
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
long_messages = await messages_parse(long_time_data)
repo.upsert(end_user_id, long_messages)
logger.info(f'写入短长期:')
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {
"messages": langchain_messages,
"end_user_id": end_user_id,
"memory_config": memory_config
}
# 获取节点更新信息
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
for node_name, node_data in update_event.items():
if 'save_neo4j' == node_name:
massages = node_data
# TODO删除
massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
print(contents)
except Exception as e:
import traceback
traceback.print_exc()
'''根据窗口'''
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
'''
@@ -61,25 +124,26 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
scope窗口大小
'''
scope=scope
redis_messages = []
is_end_user_id = count_store.get_sessions_count(end_user_id)
if is_end_user_id is not False:
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
redis_messages = count_store.get_sessions_count(end_user_id)[1]
if is_end_user_id and int(is_end_user_id) != int(scope):
print(is_end_user_id)
is_end_user_id += 1
langchain_messages += redis_messages
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
elif int(is_end_user_id) == int(scope):
print('写入长期记忆并且设置为0')
print(is_end_user_id)
formatted_messages = await chat_data_format(redis_messages)
print(100*'-')
print(formatted_messages)
print(100*'-')
await write_messages(end_user_id, formatted_messages, memory_config)
count_store.update_sessions_count(end_user_id, 0, '')
logger.info('写入长期记忆NEO4J')
formatted_messages = (redis_messages)
# 获取 config_id如果 memory_config 是对象,提取 config_id否则直接使用
if hasattr(memory_config, 'config_id'):
config_id = memory_config.config_id
else:
config_id = memory_config
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
config_id, formatted_messages)
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
@@ -93,12 +157,15 @@ async def memory_long_term_storage(end_user_id,memory_config,time):
memory_config: 内存配置对象
'''
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
# Handle case where no session exists in Redis (returns False or empty)
if not long_time_data or long_time_data is False:
return
format_messages = await chat_data_format(long_time_data)
format_messages = (long_time_data)
messages=[]
memory_config=memory_config.config_id
for i in format_messages:
message=json.loads(i['Query'])
messages+= message
if format_messages!=[]:
await write_messages(end_user_id, format_messages, memory_config)
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
memory_config, messages)
'''聚合判断'''
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
"""
@@ -109,13 +176,12 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
memory_config: 内存配置对象
"""
try:
# 1. 获取历史会话数据(使用新方法)
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
# Handle case where no session exists in Redis (returns False or empty)
if not result or result is False:
history = await format_parsing(result)
if not result:
history = []
else:
history = await format_parsing(result)
@@ -154,7 +220,8 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
}
if not structured.is_same_event:
logger.info(result_dict)
await write_messages(end_user_id, output_value, memory_config)
await write("neo4j", end_user_id, "", "", None, end_user_id,
memory_config.config_id, output_value)
return result_dict
except Exception as e:

View File

@@ -186,10 +186,11 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
清理后的数据
"""
# 需要过滤的字段列表
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
fields_to_remove = {
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
'user_id', 'statement_ids', 'updated_at',"chunk_ids","fact_summary"
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
}
if isinstance(data, dict):

View File

@@ -1,8 +1,6 @@
import json
from langchain_core.messages import HumanMessage, AIMessage
async def format_parsing(messages: list,type:str='string'):
"""
格式化解析消息列表
@@ -26,13 +24,13 @@ async def format_parsing(messages: list,type:str='string'):
role = content['role']
content = content['content']
if type == "string":
if role == 'human':
if role == 'human' or role=="user":
content = '用户:' + content
else:
content = 'AI:' + content
result.append(content)
if type == "dict":
if role == 'human':
if type == "dict" :
if role == 'human' or role=="user":
user.append( content)
else:
ai.append(content)
@@ -57,33 +55,7 @@ async def messages_parse(messages: list | dict):
for key, values in zip(user, ai):
database.append({key, values})
return database
async def chat_data_format(messages: list | dict):
"""
将消息格式化为 LangChain 消息格式
Args:
messages: 消息列表或字典
Returns:
LangChain 消息列表
"""
langchain_messages = []
if isinstance(messages, list):
for msg in messages:
if 'role' in msg.keys():
if msg['role'] == 'user':
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content']))
if "Query" in msg.keys():
langchain_messages.append(HumanMessage(content=msg['Query']))
langchain_messages.append(AIMessage(content=msg['Answer']))
if isinstance(messages, dict):
if messages['type'] == 'human':
langchain_messages.append(HumanMessage(content=messages['content']))
elif messages['type'] == 'ai':
langchain_messages.append(AIMessage(content=messages['content']))
return langchain_messages
async def agent_chat_messages(user_content,ai_content):
messages = [

View File

@@ -1,13 +1,18 @@
import asyncio
import json
import sys
import warnings
from contextlib import asynccontextmanager
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from app.db import get_db, get_db_context
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__)
@@ -37,76 +42,61 @@ async def make_write_graph():
yield graph
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
end_user_id: str = '', scope: int = 6):
"""Dispatch long-term memory storage to Celery background tasks.
Args:
long_term_type: Storage strategy - 'chunk' (window), 'time', or 'aggregate'
langchain_messages: List of messages to store
memory_config: Memory configuration ID (string)
end_user_id: End user identifier
scope: Window size for 'chunk' strategy (default: 6)
"""
from app.tasks import (
long_term_storage_window_task,
# TODO: Uncomment when implemented
# long_term_storage_time_task,
# long_term_storage_aggregate_task,
)
from app.core.logging_config import get_logger
logger = get_logger(__name__)
# Convert config to string if needed
config_id = str(memory_config) if memory_config else ''
if long_term_type == 'chunk':
# Strategy 1: Window-based batching (6 rounds of dialogue)
logger.info(f"[LONG_TERM] Dispatching window task - end_user_id={end_user_id}, scope={scope}")
long_term_storage_window_task.delay(
end_user_id=end_user_id,
langchain_messages=langchain_messages,
config_id=config_id,
scope=scope
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
from app.core.memory.agent.utils.redis_tool import write_store
write_store.save_session_write(end_user_id, (langchain_messages))
# 获取数据库会话
with get_db_context() as db_session:
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=memory_config, # 改为整数
service_name="MemoryAgentService"
)
# TODO: Uncomment when time-based strategy is fully implemented
# elif long_term_type == 'time':
# # Strategy 2: Time-based retrieval
# logger.info(f"[LONG_TERM] Dispatching time task - end_user_id={end_user_id}")
# long_term_storage_time_task.delay(
# end_user_id=end_user_id,
# config_id=config_id,
# time_window=5
# )
# TODO: Uncomment when aggregate strategy is fully implemented
# elif long_term_type == 'aggregate':
# # Strategy 3: Aggregate judgment (deduplication)
# logger.info(f"[LONG_TERM] Dispatching aggregate task - end_user_id={end_user_id}")
# long_term_storage_aggregate_task.delay(
# end_user_id=end_user_id,
# langchain_messages=langchain_messages,
# config_id=config_id
# )
if long_term_type=='chunk':
'''方案一:对话窗口6轮对话'''
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
if long_term_type=='time':
"""时间"""
await memory_long_term_storage(end_user_id, memory_config,5)
if long_term_type=='aggregate':
"""方案三:聚合判断"""
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
else:
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
long_term_messages = await agent_chat_messages(message_chat, aimessages)
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
# async def main():
# """主函数 - 运行工作流"""
# langchain_messages = [
# {
# "role": "user",
# "content": "今天周五好开心啊"
# "content": "今天周五去爬山"
# },
# {
# "role": "assistant",
# "content": "你也这么觉得,我也是耶"
# "content": "耶"
# }
#
# ]
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
# # await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
# result=await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
#
#
#
# if __name__ == "__main__":

View File

@@ -294,6 +294,7 @@ class RedisCountStore:
"""
session_id = str(uuid.uuid4())
key = generate_session_key(session_id, key_type="count")
index_key = f'session:count:index:{end_user_id}' # 索引键
pipe = self.r.pipeline()
pipe.hset(key, mapping={
@@ -304,6 +305,10 @@ class RedisCountStore:
"starttime": get_current_timestamp()
})
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
# 创建索引end_user_id -> session_id 映射
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
result = pipe.execute()
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
@@ -320,31 +325,47 @@ class RedisCountStore:
list 或 False: 如果找到返回 [count, messages],否则返回 False
"""
try:
search_pattern = 'session:count:*'
# 使用索引键快速查找
index_key = f'session:count:index:{end_user_id}'
for key in self.r.keys(search_pattern):
data = self.r.hgetall(key)
if not data:
continue
if data.get('end_user_id') == end_user_id:
count = data.get('count')
messages_str = data.get('messages')
if count is not None:
messages = deserialize_messages(messages_str)
return [int(count), messages]
# 检查索引键类型,避免 WRONGTYPE 错误
try:
key_type = self.r.type(index_key)
if key_type != 'string' and key_type != 'none':
self.r.delete(index_key)
return False
except Exception as type_error:
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
session_id = self.r.get(index_key)
if not session_id:
return False
# 直接获取数据
key = generate_session_key(session_id, key_type="count")
data = self.r.hgetall(key)
if not data:
# 索引存在但数据不存在,清理索引
self.r.delete(index_key)
return False
count = data.get('count')
messages_str = data.get('messages')
if count is not None:
messages = deserialize_messages(messages_str)
return [int(count), messages]
return False
except Exception as e:
print(f"[get_sessions_count] 查询失败: {e}")
return False
def update_sessions_count(self, end_user_id: str, new_count: int,
messages: Any) -> bool:
"""
通过 end_user_id 修改访问次数统计
通过 end_user_id 修改访问次数统计(优化版:使用索引)
Args:
end_user_id: 终端用户ID
@@ -355,23 +376,39 @@ class RedisCountStore:
bool: 更新成功返回 True未找到记录返回 False
"""
try:
# 使用索引键快速查找
index_key = f'session:count:index:{end_user_id}'
# 检查索引键类型,避免 WRONGTYPE 错误
try:
key_type = self.r.type(index_key)
if key_type != 'string' and key_type != 'none':
# 索引键类型错误,删除并返回 False
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
self.r.delete(index_key)
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False
except Exception as type_error:
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
session_id = self.r.get(index_key)
if not session_id:
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False
# 直接更新数据
key = generate_session_key(session_id, key_type="count")
messages_str = serialize_messages(messages)
search_pattern = 'session:count:*'
for key in self.r.keys(search_pattern):
data = self.r.hgetall(key)
if not data:
continue
if data.get('end_user_id') == end_user_id:
self.r.hset(key, 'count', int(new_count))
self.r.hset(key, 'messages', messages_str)
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
return True
pipe = self.r.pipeline()
pipe.hset(key, 'count', int(new_count))
pipe.hset(key, 'messages', messages_str)
result = pipe.execute()
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
return True
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False
except Exception as e:
print(f"[update_sessions_count] 更新失败: {e}")
return False

View File

@@ -4,6 +4,7 @@ Write Tools for Memory Knowledge Extraction Pipeline
This module provides the main write function for executing the knowledge extraction
pipeline. Only MemoryConfig is needed - clients are constructed internally.
"""
import asyncio
import time
from datetime import datetime
@@ -124,23 +125,48 @@ async def write(
except Exception as e:
logger.error(f"Error creating indexes: {e}", exc_info=True)
# 添加死锁重试机制
max_retries = 3
retry_delay = 1 # 秒
for attempt in range(max_retries):
try:
success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=all_dialogue_nodes,
chunk_nodes=all_chunk_nodes,
statement_nodes=all_statement_nodes,
entity_nodes=all_entity_nodes,
statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges,
connector=neo4j_connector
)
if success:
logger.info("Successfully saved all data to Neo4j")
break
else:
logger.warning("Failed to save some data to Neo4j")
if attempt < max_retries - 1:
logger.info(f"Retrying... (attempt {attempt + 2}/{max_retries})")
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
except Exception as e:
error_msg = str(e)
# 检查是否是死锁错误
if "DeadlockDetected" in error_msg or "deadlock" in error_msg.lower():
if attempt < max_retries - 1:
logger.warning(f"Deadlock detected, retrying... (attempt {attempt + 2}/{max_retries})")
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
else:
logger.error(f"Failed after {max_retries} attempts due to deadlock: {e}")
raise
else:
# 非死锁错误,直接抛出
raise
try:
success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=all_dialogue_nodes,
chunk_nodes=all_chunk_nodes,
statement_nodes=all_statement_nodes,
entity_nodes=all_entity_nodes,
statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges,
connector=neo4j_connector
)
if success:
logger.info("Successfully saved all data to Neo4j")
else:
logger.warning("Failed to save some data to Neo4j")
finally:
await neo4j_connector.close()
except Exception as e:
logger.error(f"Error closing Neo4j connector: {e}")
log_time("Neo4j Database Save", time.time() - step_start, log_file)

View File

@@ -413,7 +413,8 @@ class ExtractedEntityNode(Node):
description="Entity aliases - alternative names for this entity"
)
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
fact_summary: str = Field(default="", description="Summary of the fact about this entity")
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")

View File

@@ -134,42 +134,45 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
if len(desc_b) > len(desc_a):
canonical.description = desc_b
# 合并事实摘要:统一保留一个“实体: name”行来源行去重保序
fact_a = getattr(canonical, "fact_summary", "") or ""
fact_b = getattr(ent, "fact_summary", "") or ""
def _extract_sources(txt: str) -> List[str]:
sources: List[str] = []
if not txt:
return sources
for line in str(txt).splitlines():
ln = line.strip()
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_a = getattr(canonical, "fact_summary", "") or ""
# fact_b = getattr(ent, "fact_summary", "") or ""
# def _extract_sources(txt: str) -> List[str]:
# sources: List[str] = []
# if not txt:
# return sources
# for line in str(txt).splitlines():
# ln = line.strip()
# 支持“来源:”或“来源:”前缀
m = re.match(r"^来源[:]\s*(.+)$", ln)
if m:
content = m.group(1).strip()
if content:
sources.append(content)
# m = re.match(r"^来源[:]\s*(.+)$", ln)
# if m:
# content = m.group(1).strip()
# if content:
# sources.append(content)
# 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失
if not sources and txt.strip():
sources.append(txt.strip())
return sources
# if not sources and txt.strip():
# sources.append(txt.strip())
# return sources
try:
src_a = _extract_sources(fact_a)
src_b = _extract_sources(fact_b)
seen = set()
merged_sources: List[str] = []
for s in src_a + src_b:
if s and s not in seen:
seen.add(s)
merged_sources.append(s)
if merged_sources:
name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
elif fact_b and not fact_a:
canonical.fact_summary = fact_b
# src_a = _extract_sources(fact_a)
# src_b = _extract_sources(fact_b)
# seen = set()
# merged_sources: List[str] = []
# for s in src_a + src_b:
# if s and s not in seen:
# seen.add(s)
# merged_sources.append(s)
# if merged_sources:
# name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
# canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
# elif fact_b and not fact_a:
# canonical.fact_summary = fact_b
pass
except Exception:
# 兜底:若解析失败,保留较长文本
if len(fact_b) > len(fact_a):
canonical.fact_summary = fact_b
# if len(fact_b) > len(fact_a):
# canonical.fact_summary = fact_b
pass
except Exception:
pass

View File

@@ -145,10 +145,13 @@ def _choose_canonical(a: ExtractedEntityNode, b: ExtractedEntityNode) -> int: #
# 2. 第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整)
desc_a = (getattr(a, "description", "") or "")
desc_b = (getattr(b, "description", "") or "")
fact_a = (getattr(a, "fact_summary", "") or "")
fact_b = (getattr(b, "fact_summary", "") or "")
score_a = len(desc_a) + len(fact_a)
score_b = len(desc_b) + len(fact_b)
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_a = (getattr(a, "fact_summary", "") or "")
# fact_b = (getattr(b, "fact_summary", "") or "")
# score_a = len(desc_a) + len(fact_a)
# score_b = len(desc_b) + len(fact_b)
score_a = len(desc_a)
score_b = len(desc_b)
if score_a != score_b:
return 0 if score_a >= score_b else 1
return 0
@@ -189,7 +192,8 @@ async def _judge_pair(
"entity_type": getattr(a, "entity_type", None),
"description": getattr(a, "description", None),
"aliases": getattr(a, "aliases", None) or [],
"fact_summary": getattr(a, "fact_summary", None),
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# "fact_summary": getattr(a, "fact_summary", None),
"connect_strength": getattr(a, "connect_strength", None),
}
entity_b = {
@@ -197,7 +201,8 @@ async def _judge_pair(
"entity_type": getattr(b, "entity_type", None),
"description": getattr(b, "description", None),
"aliases": getattr(b, "aliases", None) or [],
"fact_summary": getattr(b, "fact_summary", None),
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# "fact_summary": getattr(b, "fact_summary", None),
"connect_strength": getattr(b, "connect_strength", None),
}
# 5. 渲染LLM提示词用工具函数填充模板包含实体信息、上下文、输出格式
@@ -248,7 +253,8 @@ async def _judge_pair_disamb(
"entity_type": getattr(a, "entity_type", None),
"description": getattr(a, "description", None),
"aliases": getattr(a, "aliases", None) or [],
"fact_summary": getattr(a, "fact_summary", None),
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# "fact_summary": getattr(a, "fact_summary", None),
"connect_strength": getattr(a, "connect_strength", None),
}
entity_b = {
@@ -256,7 +262,8 @@ async def _judge_pair_disamb(
"entity_type": getattr(b, "entity_type", None),
"description": getattr(b, "description", None),
"aliases": getattr(b, "aliases", None) or [],
"fact_summary": getattr(b, "fact_summary", None),
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# "fact_summary": getattr(b, "fact_summary", None),
"connect_strength": getattr(b, "connect_strength", None),
}
prompt = render_entity_dedup_prompt(

View File

@@ -72,7 +72,8 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
description=row.get("description") or "",
aliases=row.get("aliases") or [],
name_embedding=row.get("name_embedding") or [],
fact_summary=row.get("fact_summary") or "",
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary=row.get("fact_summary") or "",
connect_strength=row.get("connect_strength") or "",
)

View File

@@ -1088,7 +1088,8 @@ class ExtractionOrchestrator:
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
example=getattr(entity, 'example', ''), # 新增:传递示例字段
fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
name_embedding=getattr(entity, 'name_embedding', None),

View File

@@ -296,7 +296,9 @@ def resolve_alias_cycles(entities: List[Any], cycles: Dict[str, Set[str]]) -> Li
key=lambda eid: (
_strength_rank(eid),
len(getattr(entity_by_id.get(eid), 'description', '') or ''),
len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
0 # 临时占位
),
reverse=True
)

View File

@@ -9,7 +9,8 @@
- 类型: "{{ entity_a.entity_type | default('') }}"
- 描述: "{{ entity_a.description | default('') }}"
- 别名: {{ entity_a.aliases | default([]) }}
- 摘要: "{{ entity_a.fact_summary | default('') }}"
{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #}
{# - 摘要: "{{ entity_a.fact_summary | default('') }}" #}
- 连接强弱: "{{ entity_a.connect_strength | default('') }}"
实体B:
@@ -17,7 +18,8 @@
- 类型: "{{ entity_b.entity_type | default('') }}"
- 描述: "{{ entity_b.description | default('') }}"
- 别名: {{ entity_b.aliases | default([]) }}
- 摘要: "{{ entity_b.fact_summary | default('') }}"
{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #}
{# - 摘要: "{{ entity_b.fact_summary | default('') }}" #}
- 连接强弱: "{{ entity_b.connect_strength | default('') }}"
上下文:

View File

View File

@@ -0,0 +1,89 @@
"""Command-line interface for web crawler."""
import argparse
import logging
import sys
from app.core.rag.crawler.web_crawler import WebCrawler
def setup_logging(verbose: bool = False):
"""Set up logging configuration."""
level = logging.DEBUG if verbose else logging.INFO
logging.basicConfig(
level=level,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(sys.stdout)
]
)
def main(entry_url: str,
max_pages: int = 200,
delay_seconds: float = 1.0,
timeout_seconds: int = 10,
user_agent: str = "KnowledgeBaseCrawler/1.0"):
"""Main entry point for the crawler."""
# Create crawler
crawler = WebCrawler(
entry_url=entry_url,
max_pages=max_pages,
delay_seconds=delay_seconds,
timeout_seconds=timeout_seconds,
user_agent=user_agent
)
# Crawl and collect documents
documents = []
try:
for doc in crawler.crawl():
print(f"\n{'=' * 80}")
print(f"URL: {doc.url}")
print(f"Title: {doc.title}")
print(f"Content Length: {doc.content_length} characters")
print(f"Word Count: {doc.metadata.get('word_count', 0)} words")
print(f"{'=' * 80}\n")
documents.append({
'url': doc.url,
'title': doc.title,
'content': doc.content,
'content_length': doc.content_length,
'crawl_timestamp': doc.crawl_timestamp.isoformat(),
'http_status': doc.http_status,
'metadata': doc.metadata
})
except KeyboardInterrupt:
print("\n\nCrawl interrupted by user.")
except Exception as e:
print(f"\n\nError during crawl: {e}")
sys.exit(1)
# Get summary
summary = crawler.get_summary()
print(f"\n{'=' * 80}")
print("CRAWL SUMMARY")
print(f"{'=' * 80}")
print(f"Total Pages Processed: {summary.total_pages_processed}")
print(f"Total Errors: {summary.total_errors}")
print(f"Total Skipped: {summary.total_skipped}")
print(f"Total URLs Discovered: {summary.total_urls_discovered}")
print(f"Duration: {summary.duration_seconds:.2f} seconds")
print(f"documents: {documents}")
if summary.error_breakdown:
print(f"\nError Breakdown:")
for error_type, count in summary.error_breakdown.items():
print(f" {error_type}: {count}")
if __name__ == '__main__':
entry_url = "https://www.xxx.com"
max_pages = 20
delay_seconds = 1.0
timeout_seconds = 10
user_agent = "KnowledgeBaseCrawler/1.0"
main(entry_url, max_pages, delay_seconds, timeout_seconds, user_agent)

View File

@@ -0,0 +1,233 @@
"""Content extractor for web crawler."""
from bs4 import BeautifulSoup
import re
import logging
from app.core.rag.crawler.models import ExtractedContent
logger = logging.getLogger(__name__)
class ContentExtractor:
"""Extract clean, readable text from HTML pages."""
# Tags to remove completely
REMOVE_TAGS = ['script', 'style', 'nav', 'header', 'footer', 'aside']
# Tags that typically contain main content
MAIN_CONTENT_TAGS = ['article', 'main']
# Content extraction tags
CONTENT_TAGS = ['p', 'div', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li', 'td', 'th', 'section']
def is_static_content(self, html: str) -> bool:
"""
Determine if the HTML represents static content.
Detects JavaScript-rendered content by checking for minimal body
with heavy script tag presence.
Args:
html: Raw HTML string
Returns:
bool: True if static, False if JavaScript-rendered
"""
try:
soup = BeautifulSoup(html, 'lxml')
# Count script tags
script_tags = soup.find_all('script')
script_count = len(script_tags)
# Get body content (excluding scripts and styles)
body = soup.find('body')
if not body:
return False
# Remove scripts and styles temporarily for text check
for tag in body.find_all(['script', 'style']):
tag.decompose()
# Get text content
text = body.get_text(strip=True)
text_length = len(text)
# If there's very little text but many scripts, likely JS-rendered
if script_count > 5 and text_length < 200:
logger.warning("Detected JavaScript-rendered content (many scripts, little text)")
return False
# If there's no meaningful text, likely JS-rendered
if text_length < 50:
logger.warning("Detected JavaScript-rendered content (minimal text)")
return False
return True
except Exception as e:
logger.error(f"Error checking if content is static: {e}")
return True # Assume static on error
def extract(self, html: str, url: str) -> ExtractedContent:
"""
Extract clean text content from HTML.
Args:
html: Raw HTML string
url: Source URL (for context)
Returns:
ExtractedContent: Contains title, text, metadata
"""
try:
soup = BeautifulSoup(html, 'lxml')
# Check if content is static
is_static = self.is_static_content(html)
# Extract title
title = self._extract_title(soup)
# Remove unwanted tags
for tag_name in self.REMOVE_TAGS:
for tag in soup.find_all(tag_name):
tag.decompose()
# Extract main content
text = self._extract_main_content(soup)
# Normalize whitespace
text = self._normalize_whitespace(text)
# Count words
word_count = len(text.split())
logger.info(f"Extracted {word_count} words from {url}")
return ExtractedContent(
title=title,
text=text,
is_static=is_static,
word_count=word_count,
metadata={'url': url}
)
except Exception as e:
logger.error(f"Error extracting content from {url}: {e}")
return ExtractedContent(
title=url,
text="",
is_static=False,
word_count=0,
metadata={'url': url, 'error': str(e)}
)
def _extract_title(self, soup: BeautifulSoup) -> str:
"""
Extract title from HTML.
Tries <title> tag first, then first <h1>.
Args:
soup: BeautifulSoup object
Returns:
str: Page title
"""
# Try <title> tag
title_tag = soup.find('title')
if title_tag and title_tag.string:
return title_tag.string.strip()
# Try first <h1>
h1_tag = soup.find('h1')
if h1_tag:
return h1_tag.get_text(strip=True)
# Default to empty string
return ""
def _extract_main_content(self, soup: BeautifulSoup) -> str:
"""
Extract main content from HTML.
Prioritizes semantic HTML5 elements like <article> and <main>.
Args:
soup: BeautifulSoup object
Returns:
str: Extracted text content
"""
# Try to find main content area
main_content = None
# Priority 1: <article> or <main> tags
for tag_name in self.MAIN_CONTENT_TAGS:
main_content = soup.find(tag_name)
if main_content:
logger.debug(f"Found main content in <{tag_name}> tag")
break
# Priority 2: div with role="main"
if not main_content:
main_content = soup.find('div', role='main')
if main_content:
logger.debug("Found main content in div[role='main']")
# Priority 3: Common class/id patterns
if not main_content:
for pattern in ['content', 'main', 'article', 'post']:
main_content = soup.find(['div', 'section'], class_=re.compile(pattern, re.I))
if main_content:
logger.debug(f"Found main content with class pattern '{pattern}'")
break
main_content = soup.find(['div', 'section'], id=re.compile(pattern, re.I))
if main_content:
logger.debug(f"Found main content with id pattern '{pattern}'")
break
# Fallback: use body
if not main_content:
main_content = soup.find('body')
logger.debug("Using <body> as main content (no specific content area found)")
# Extract text from content tags
if main_content:
text_parts = []
for tag in main_content.find_all(self.CONTENT_TAGS):
text = tag.get_text(strip=True)
if text:
text_parts.append(text)
return '\n'.join(text_parts)
return ""
def _normalize_whitespace(self, text: str) -> str:
"""
Normalize whitespace in text.
- Collapse multiple spaces to single space
- Reduce excessive newlines to maximum 2
- Strip leading/trailing whitespace
Args:
text: Text to normalize
Returns:
str: Normalized text
"""
# Collapse multiple spaces to single space
text = re.sub(r' +', ' ', text)
# Reduce excessive newlines to maximum 2
text = re.sub(r'\n{3,}', '\n\n', text)
# Strip leading/trailing whitespace
text = text.strip()
return text

View File

@@ -0,0 +1,302 @@
"""HTTP fetcher for web crawler."""
import requests
import time
import logging
import re
from typing import Optional, Dict
from app.core.rag.crawler.models import FetchResult
logger = logging.getLogger(__name__)
class HTTPFetcher:
"""Handle HTTP requests with retries, error handling, and response validation."""
def __init__(
self,
timeout: int = 10,
max_retries: int = 3,
user_agent: str = "KnowledgeBaseCrawler/1.0"
):
"""
Initialize HTTP fetcher.
Args:
timeout: Request timeout in seconds
max_retries: Maximum number of retry attempts
user_agent: User-Agent header value
"""
self.timeout = timeout
self.max_retries = max_retries
self.user_agent = user_agent
# Create session for connection pooling
self.session = requests.Session()
self.session.headers.update({
'User-Agent': user_agent
})
def fetch(self, url: str) -> FetchResult:
"""
Fetch a URL with retry logic and error handling.
Args:
url: URL to fetch
Returns:
FetchResult: Contains status_code, content, headers, error info
"""
last_error = None
for attempt in range(self.max_retries):
try:
# Calculate backoff delay for retries
if attempt > 0:
backoff_delay = 2 ** (attempt - 1) # 1s, 2s, 4s
logger.info(f"Retry attempt {attempt + 1}/{self.max_retries} for {url} after {backoff_delay}s")
time.sleep(backoff_delay)
# Make HTTP request
response = self.session.get(
url,
timeout=self.timeout,
allow_redirects=True
)
# Handle different status codes
if response.status_code == 429:
# Too Many Requests - backoff and retry
logger.warning(f"429 Too Many Requests for {url}, backing off")
if attempt < self.max_retries - 1:
continue
if response.status_code == 503:
# Service Unavailable - pause and retry
logger.warning(f"503 Service Unavailable for {url}")
if attempt < self.max_retries - 1:
time.sleep(5) # Longer pause for 503
continue
# Success or client error (don't retry 4xx except 429)
if 200 <= response.status_code < 300:
logger.info(f"Successfully fetched {url} (status: {response.status_code})")
# Get correctly encoded content
content = self._get_decoded_content(response)
return FetchResult(
url=url,
final_url=response.url,
status_code=response.status_code,
content=content,
headers=dict(response.headers),
error=None,
success=True
)
elif response.status_code == 404:
logger.info(f"404 Not Found: {url}")
return FetchResult(
url=url,
final_url=response.url,
status_code=response.status_code,
content=None,
headers=dict(response.headers),
error="Not Found",
success=False
)
elif 400 <= response.status_code < 500:
logger.warning(f"Client error {response.status_code} for {url}")
return FetchResult(
url=url,
final_url=response.url,
status_code=response.status_code,
content=None,
headers=dict(response.headers),
error=f"Client error: {response.status_code}",
success=False
)
elif 500 <= response.status_code < 600:
logger.error(f"Server error {response.status_code} for {url}")
last_error = f"Server error: {response.status_code}"
if attempt < self.max_retries - 1:
continue
return FetchResult(
url=url,
final_url=url,
status_code=response.status_code,
content=None,
headers={},
error=last_error,
success=False
)
except requests.exceptions.Timeout:
last_error = "Request timeout"
logger.warning(f"Timeout fetching {url} (attempt {attempt + 1}/{self.max_retries})")
if attempt >= self.max_retries - 1:
break
continue
except requests.exceptions.SSLError as e:
last_error = f"SSL/TLS error: {str(e)}"
logger.error(f"SSL/TLS error for {url}: {e}")
return FetchResult(
url=url,
final_url=url,
status_code=0,
content=None,
headers={},
error=last_error,
success=False
)
except requests.exceptions.ConnectionError as e:
last_error = f"Connection error: {str(e)}"
logger.warning(f"Connection error for {url} (attempt {attempt + 1}/{self.max_retries}): {e}")
if attempt >= self.max_retries - 1:
break
continue
except requests.exceptions.RequestException as e:
last_error = f"Request error: {str(e)}"
logger.error(f"Request error for {url}: {e}")
if attempt >= self.max_retries - 1:
break
continue
# All retries exhausted
logger.error(f"Failed to fetch {url} after {self.max_retries} attempts: {last_error}")
return FetchResult(
url=url,
final_url=url,
status_code=0,
content=None,
headers={},
error=last_error or "Unknown error",
success=False
)
def _get_decoded_content(self, response) -> str:
"""
Get correctly decoded content from response.
Handles encoding detection and fallback strategies:
1. Try encoding from HTML meta tags
2. Try response.encoding (from Content-Type header or detected)
3. Try UTF-8
4. Try common encodings (GB2312, GBK for Chinese, etc.)
5. Fall back to latin-1 with error replacement
Args:
response: requests.Response object
Returns:
str: Decoded content
"""
# Try to detect encoding from HTML meta tags
meta_encoding = self._detect_encoding_from_meta(response.content)
if meta_encoding:
try:
content = response.content.decode(meta_encoding)
logger.info(f"Successfully decoded with meta tag encoding: {meta_encoding}")
return content
except (UnicodeDecodeError, LookupError) as e:
logger.warning(f"Failed to decode with meta encoding {meta_encoding}: {e}")
# Try response.encoding (from Content-Type header or detected by requests)
if response.encoding and response.encoding.lower() != 'iso-8859-1':
# Note: requests defaults to ISO-8859-1 if no charset in Content-Type,
# so we skip it here and try UTF-8 first
try:
return response.text
except (UnicodeDecodeError, LookupError) as e:
logger.warning(f"Failed to decode with detected encoding {response.encoding}: {e}")
# Try UTF-8 first (most common)
try:
return response.content.decode('utf-8')
except UnicodeDecodeError:
logger.debug("UTF-8 decoding failed, trying other encodings")
# Try common encodings for different languages
encodings_to_try = [
'gbk', # Chinese (Simplified)
'gb2312', # Chinese (Simplified, older)
'gb18030', # Chinese (Simplified, extended)
'big5', # Chinese (Traditional)
'shift_jis', # Japanese
'euc-jp', # Japanese
'euc-kr', # Korean
'iso-8859-1', # Western European
'windows-1252', # Windows Western European
'windows-1251', # Cyrillic
]
for encoding in encodings_to_try:
try:
content = response.content.decode(encoding)
logger.info(f"Successfully decoded with {encoding}")
return content
except (UnicodeDecodeError, LookupError):
continue
# Last resort: use latin-1 with error replacement
logger.warning("All encoding attempts failed, using latin-1 with error replacement")
return response.content.decode('latin-1', errors='replace')
def _detect_encoding_from_meta(self, content: bytes) -> Optional[str]:
"""
Detect encoding from HTML meta tags.
Looks for:
- <meta charset="...">
- <meta http-equiv="Content-Type" content="...; charset=...">
Args:
content: Raw response content (bytes)
Returns:
Optional[str]: Detected encoding or None
"""
try:
# Only check first 2KB for performance
head = content[:2048]
# Try to decode as ASCII/Latin-1 to search for meta tags
try:
head_str = head.decode('ascii', errors='ignore')
except:
head_str = head.decode('latin-1', errors='ignore')
# Look for <meta charset="...">
charset_match = re.search(
r'<meta[^>]+charset=["\']?([a-zA-Z0-9_-]+)',
head_str,
re.IGNORECASE
)
if charset_match:
encoding = charset_match.group(1).lower()
logger.debug(f"Found charset in meta tag: {encoding}")
return encoding
# Look for <meta http-equiv="Content-Type" content="...; charset=...">
content_type_match = re.search(
r'<meta[^>]+http-equiv=["\']?content-type["\']?[^>]+content=["\']([^"\']+)',
head_str,
re.IGNORECASE
)
if content_type_match:
content_value = content_type_match.group(1)
charset_match = re.search(r'charset=([a-zA-Z0-9_-]+)', content_value, re.IGNORECASE)
if charset_match:
encoding = charset_match.group(1).lower()
logger.debug(f"Found charset in Content-Type meta: {encoding}")
return encoding
except Exception as e:
logger.debug(f"Error detecting encoding from meta tags: {e}")
return None

View File

@@ -0,0 +1,52 @@
"""Data models for web crawler."""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, Any, Optional
@dataclass
class CrawledDocument:
"""Represents a successfully processed web page with extracted content."""
url: str
title: str
content: str
content_length: int
crawl_timestamp: datetime
http_status: int
metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class FetchResult:
"""Represents the result of an HTTP fetch operation."""
url: str
final_url: str
status_code: int
content: Optional[str]
headers: Dict[str, str]
error: Optional[str]
success: bool
@dataclass
class ExtractedContent:
"""Represents content extracted from HTML."""
title: str
text: str
is_static: bool
word_count: int
metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class CrawlSummary:
"""Represents statistics from a completed crawl."""
total_pages_processed: int
total_errors: int
total_skipped: int
total_urls_discovered: int
start_time: datetime
end_time: datetime
duration_seconds: float
error_breakdown: Dict[str, int] = field(default_factory=dict)

View File

@@ -0,0 +1,57 @@
"""Rate limiter for web crawler."""
import time
import logging
logger = logging.getLogger(__name__)
class RateLimiter:
"""Enforce delays between requests to be polite to servers."""
def __init__(self, delay_seconds: float = 1.0):
"""
Initialize rate limiter.
Args:
delay_seconds: Minimum delay between requests
"""
self.delay_seconds = delay_seconds
self.last_request_time = 0.0
self.max_delay = 60.0 # Cap maximum delay at 60 seconds
def wait(self):
"""
Block until enough time has passed since last request.
Respects the configured delay.
"""
current_time = time.time()
elapsed = current_time - self.last_request_time
if elapsed < self.delay_seconds:
sleep_time = self.delay_seconds - elapsed
logger.debug(f"Rate limiting: sleeping for {sleep_time:.2f} seconds")
time.sleep(sleep_time)
self.last_request_time = time.time()
def set_delay(self, delay_seconds: float):
"""
Update the delay (useful for respecting Crawl-delay from robots.txt).
Args:
delay_seconds: New delay in seconds
"""
self.delay_seconds = min(delay_seconds, self.max_delay)
logger.info(f"Rate limiter delay updated to {self.delay_seconds} seconds")
def backoff(self, multiplier: float = 2.0):
"""
Increase delay exponentially for backoff scenarios (429, 503 responses).
Args:
multiplier: Factor to multiply current delay by
"""
old_delay = self.delay_seconds
self.delay_seconds = min(self.delay_seconds * multiplier, self.max_delay)
logger.warning(f"Rate limiter backing off: {old_delay:.2f}s -> {self.delay_seconds:.2f}s")

View File

@@ -0,0 +1,118 @@
"""Robots.txt parser for web crawler."""
from urllib.robotparser import RobotFileParser
from urllib.parse import urlparse, urljoin
from typing import Optional
import logging
logger = logging.getLogger(__name__)
class RobotsParser:
"""Parse and check robots.txt compliance for URLs."""
def __init__(self, user_agent: str, timeout: int = 10):
"""
Initialize robots.txt parser.
Args:
user_agent: User agent string to check permissions for
timeout: Timeout for fetching robots.txt
"""
self.user_agent = user_agent
self.timeout = timeout
self._parsers = {} # Cache parsers by domain
def _get_robots_url(self, url: str) -> str:
"""
Get the robots.txt URL for a given URL.
Args:
url: URL to get robots.txt for
Returns:
str: robots.txt URL
"""
parsed = urlparse(url)
robots_url = f"{parsed.scheme}://{parsed.netloc}/robots.txt"
return robots_url
def _get_parser(self, url: str) -> RobotFileParser:
"""
Get or create a RobotFileParser for the domain.
Args:
url: URL to get parser for
Returns:
RobotFileParser: Parser for the domain
"""
robots_url = self._get_robots_url(url)
# Return cached parser if available
if robots_url in self._parsers:
return self._parsers[robots_url]
# Create new parser
parser = RobotFileParser()
parser.set_url(robots_url)
try:
# Fetch and parse robots.txt
parser.read()
logger.info(f"Successfully fetched robots.txt from {robots_url}")
except Exception as e:
# If robots.txt cannot be fetched, assume all URLs are allowed
logger.warning(f"Could not fetch robots.txt from {robots_url}: {e}. Assuming all URLs allowed.")
# Create a permissive parser
parser = RobotFileParser()
parser.parse([]) # Empty robots.txt allows everything
# Cache the parser
self._parsers[robots_url] = parser
return parser
def can_fetch(self, url: str) -> bool:
"""
Check if the given URL can be fetched according to robots.txt.
Args:
url: URL to check
Returns:
bool: True if allowed, False if disallowed
"""
try:
parser = self._get_parser(url)
allowed = parser.can_fetch(self.user_agent, url)
if not allowed:
logger.info(f"URL disallowed by robots.txt: {url}")
return allowed
except Exception as e:
logger.error(f"Error checking robots.txt for {url}: {e}")
# On error, assume allowed
return True
def get_crawl_delay(self, url: str) -> Optional[float]:
"""
Get the Crawl-delay directive from robots.txt if present.
Args:
url: URL to get crawl delay for
Returns:
Optional[float]: Delay in seconds, or None if not specified
"""
try:
parser = self._get_parser(url)
delay = parser.crawl_delay(self.user_agent)
if delay is not None:
logger.info(f"Crawl-delay from robots.txt: {delay} seconds")
return delay
except Exception as e:
logger.error(f"Error getting crawl delay for {url}: {e}")
return None

View File

@@ -0,0 +1,171 @@
"""URL normalization and validation for web crawler."""
from typing import Optional, List
from urllib.parse import urlparse, urlunparse, parse_qs, urlencode, urljoin
from bs4 import BeautifulSoup
class URLNormalizer:
"""Normalize and validate URLs for deduplication and domain checking."""
# Common tracking parameters to remove
TRACKING_PARAMS = {
'utm_source', 'utm_medium', 'utm_campaign', 'utm_term', 'utm_content',
'fbclid', 'gclid', 'msclkid', '_ga', 'mc_cid', 'mc_eid'
}
def __init__(self, base_domain: str):
"""
Initialize URL normalizer with base domain.
Args:
base_domain: The domain to use for same-domain checks
"""
parsed = urlparse(base_domain)
self.base_domain = parsed.netloc.lower() # example.com:8000
self.base_scheme = parsed.scheme or 'https' # https
def normalize(self, url: str) -> Optional[str]:
"""
Normalize a URL for deduplication.
Normalization rules:
1. Convert domain to lowercase
2. Remove fragments (#section)
3. Remove default ports (80 for http, 443 for https)
4. Remove trailing slashes (except for root)
5. Sort query parameters alphabetically
6. Remove common tracking parameters
Args:
url: URL to normalize
Returns:
Optional[str]: Normalized URL, or None if invalid
"""
try:
parsed = urlparse(url)
# Validate scheme
if parsed.scheme not in ('http', 'https'):
return None
# Normalize domain to lowercase
netloc = parsed.netloc.lower()
# Remove default ports
if ':' in netloc:
host, port = netloc.rsplit(':', 1)
if (parsed.scheme == 'http' and port == '80') or \
(parsed.scheme == 'https' and port == '443'):
netloc = host
# Normalize path
path = parsed.path
# Remove trailing slash except for root
if path != '/' and path.endswith('/'):
path = path.rstrip('/')
# Ensure path starts with /
if not path:
path = '/'
# Process query parameters
query = ''
if parsed.query:
# Parse query parameters
params = parse_qs(parsed.query, keep_blank_values=True)
# Remove tracking parameters
filtered_params = {
k: v for k, v in params.items()
if k not in self.TRACKING_PARAMS
}
# Sort parameters alphabetically
if filtered_params:
sorted_params = sorted(filtered_params.items())
query = urlencode(sorted_params, doseq=True)
# Reconstruct URL without fragment
normalized = urlunparse((
parsed.scheme,
netloc,
path,
parsed.params,
query,
'' # Remove fragment
))
return normalized
except Exception:
return None
def is_same_domain(self, url: str) -> bool:
"""
Check if URL belongs to the same domain as base_domain.
Args:
url: URL to check
Returns:
bool: True if same domain, False otherwise
"""
try:
parsed = urlparse(url)
domain = parsed.netloc.lower()
# Remove port if present
if ':' in domain:
domain = domain.split(':')[0]
# Check if domains match
return domain == self.base_domain or domain == self.base_domain.split(':')[0]
except Exception:
return False
def extract_links(self, html: str, base_url: str) -> List[str]:
"""
Extract and normalize all links from HTML.
Args:
html: HTML content
base_url: Base URL for resolving relative links
Returns:
List[str]: List of normalized absolute URLs
"""
links = []
try:
soup = BeautifulSoup(html, 'lxml')
# Find all anchor tags
for anchor in soup.find_all('a', href=True):
href = anchor['href']
# Skip empty hrefs
if not href or href.strip() == '':
continue
# Skip javascript: and mailto: links
if href.startswith(('javascript:', 'mailto:', 'tel:')):
continue
normalized_url = None
# Check if href starts with http/https (absolute URL)
if href.startswith(('http://', 'https://')):
if self.is_same_domain(href):
normalized_url = self.normalize(href)
else:
# Convert relative URL to absolute
absolute_url = urljoin(base_url, href)
# Normalize the URL
normalized_url = self.normalize(absolute_url)
if normalized_url:
links.append(normalized_url)
except Exception:
pass
return links

View File

@@ -0,0 +1,215 @@
"""Main web crawler orchestrator."""
from collections import deque
from datetime import datetime
from typing import Iterator, Optional, List, Set
from urllib.parse import urlparse
import logging
from app.core.rag.crawler.url_normalizer import URLNormalizer
from app.core.rag.crawler.robots_parser import RobotsParser
from app.core.rag.crawler.rate_limiter import RateLimiter
from app.core.rag.crawler.http_fetcher import HTTPFetcher
from app.core.rag.crawler.content_extractor import ContentExtractor
from app.core.rag.crawler.models import CrawledDocument, CrawlSummary
logger = logging.getLogger(__name__)
class WebCrawler:
"""Main orchestrator for web crawling."""
def __init__(
self,
entry_url: str,
max_pages: int = 200,
delay_seconds: float = 1.0,
timeout_seconds: int = 10,
user_agent: str = "KnowledgeBaseCrawler/1.0",
include_patterns: Optional[List[str]] = None,
exclude_patterns: Optional[List[str]] = None,
content_extractor: Optional[ContentExtractor] = None
):
"""
Initialize the web crawler.
Args:
entry_url: Starting URL for the crawl
max_pages: Maximum number of pages to crawl (default: 200)
delay_seconds: Delay between requests in seconds (default: 1.0)
timeout_seconds: HTTP request timeout (default: 10)
user_agent: User-Agent header string
include_patterns: List of regex patterns for URLs to include
exclude_patterns: List of regex patterns for URLs to exclude
content_extractor: Custom content extractor (optional)
"""
# Validate entry URL
parsed = urlparse(entry_url)
if not parsed.scheme or not parsed.netloc:
raise ValueError(f"Invalid entry URL: {entry_url}")
self.entry_url = entry_url
self.max_pages = max_pages
self.user_agent = user_agent
# Extract domain from entry URL
self.domain = parsed.netloc
# Initialize components
self.url_normalizer = URLNormalizer(entry_url)
self.robots_parser = RobotsParser(user_agent, timeout_seconds)
self.rate_limiter = RateLimiter(delay_seconds)
self.http_fetcher = HTTPFetcher(timeout_seconds, max_retries=3, user_agent=user_agent)
self.content_extractor = content_extractor or ContentExtractor()
# State management
self.url_queue: deque = deque()
self.visited_urls: Set[str] = set()
self.pages_processed = 0
# Statistics
self.stats = {
'success': 0,
'errors': 0,
'skipped': 0,
'urls_discovered': 0,
'error_breakdown': {}
}
self.start_time: Optional[datetime] = None
self.end_time: Optional[datetime] = None
def crawl(self) -> Iterator[CrawledDocument]:
"""
Execute the crawl and yield documents as they are processed.
Yields:
CrawledDocument: Structured document with extracted content
"""
logger.info(f"Starting crawl from {self.entry_url} (max_pages: {self.max_pages})")
self.start_time = datetime.now()
# Add entry URL to queue
normalized_entry = self.url_normalizer.normalize(self.entry_url)
if normalized_entry:
self.url_queue.append(normalized_entry)
self.stats['urls_discovered'] += 1
# Check robots.txt and update rate limiter if needed
crawl_delay = self.robots_parser.get_crawl_delay(self.entry_url)
if crawl_delay:
self.rate_limiter.set_delay(crawl_delay)
# Main crawl loop
while self.url_queue and self.pages_processed < self.max_pages:
url = self.url_queue.popleft()
# Skip if already visited
if url in self.visited_urls:
continue
# Mark as visited
self.visited_urls.add(url)
# Check robots.txt permission
if not self.robots_parser.can_fetch(url):
logger.info(f"Skipping {url} (disallowed by robots.txt)")
self.stats['skipped'] += 1
continue
# Apply rate limiting
self.rate_limiter.wait()
# Fetch URL
logger.info(f"Fetching {url} ({self.pages_processed + 1}/{self.max_pages})")
fetch_result = self.http_fetcher.fetch(url)
# Handle fetch errors
if not fetch_result.success:
self._record_error(fetch_result.error or "Unknown error")
continue
# Check Content-Type
content_type = fetch_result.headers.get('Content-Type', '').lower()
if not any(substring in content_type for substring in ['text/html', 'application/xhtml+xml']):
logger.warning(f"Skipping {url} (Content-Type: {content_type})")
self.stats['skipped'] += 1
continue
# Extract content
try:
extracted = self.content_extractor.extract(fetch_result.content, url)
# Check if static content
if not extracted.is_static:
logger.warning(f"Skipping {url} (JavaScript-rendered content)")
self.stats['skipped'] += 1
continue
# Create document
document = CrawledDocument(
url=url,
title=extracted.title,
content=extracted.text,
content_length=len(extracted.text),
crawl_timestamp=datetime.now(),
http_status=fetch_result.status_code,
metadata={
'word_count': extracted.word_count,
'final_url': fetch_result.final_url
}
)
# Update statistics
self.pages_processed += 1
self.stats['success'] += 1
# Extract and queue links
links = self.url_normalizer.extract_links(fetch_result.content, url)
for link in links:
if link not in self.visited_urls and self.url_normalizer.is_same_domain(link):
if link not in self.url_queue:
self.url_queue.append(link)
self.stats['urls_discovered'] += 1
# Yield document
yield document
except Exception as e:
logger.error(f"Error processing {url}: {e}")
self._record_error(f"Processing error: {str(e)}")
continue
self.end_time = datetime.now()
logger.info(f"Crawl completed. Processed {self.pages_processed} pages.")
def get_summary(self) -> CrawlSummary:
"""
Get summary statistics after crawl completion.
Returns:
CrawlSummary: Statistics including success/error/skip counts
"""
if not self.start_time:
self.start_time = datetime.now()
if not self.end_time:
self.end_time = datetime.now()
duration = (self.end_time - self.start_time).total_seconds()
return CrawlSummary(
total_pages_processed=self.stats['success'],
total_errors=self.stats['errors'],
total_skipped=self.stats['skipped'],
total_urls_discovered=self.stats['urls_discovered'],
start_time=self.start_time,
end_time=self.end_time,
duration_seconds=duration,
error_breakdown=self.stats['error_breakdown']
)
def _record_error(self, error: str):
"""Record an error in statistics."""
self.stats['errors'] += 1
error_type = error.split(':')[0] if ':' in error else error
self.stats['error_breakdown'][error_type] = \
self.stats['error_breakdown'].get(error_type, 0) + 1

View File

@@ -0,0 +1 @@
"""Integrations package for external services."""

View File

@@ -0,0 +1 @@
"""Feishu integration module for document synchronization."""

View File

@@ -0,0 +1,84 @@
"""Command-line interface for feishu integration."""
import asyncio
import sys
from app.core.rag.integrations.feishu.client import FeishuAPIClient
from app.core.rag.integrations.feishu.models import FileInfo
def main(feishu_app_id: str, # Feishu application ID
feishu_app_secret: str, # Feishu application secret
feishu_folder_token: str, # Feishu Folder Token
save_dir: str, # save file directory
feishu_api_base_url: str = "https://open.feishu.cn/open-apis", # Feishu API base URL
timeout: int = 30, # Request timeout in seconds
max_retries: int = 3, # Maximum number of retries
recursive: bool = True # recursive: Whether to sync subfolders recursively,
):
"""Main entry point for the feishuAPIClient."""
# Create feishuAPIClient
api_client = FeishuAPIClient(
app_id=feishu_app_id,
app_secret=feishu_app_secret,
api_base_url=feishu_api_base_url,
timeout=timeout,
max_retries=max_retries
)
# Get all files from folder
async def async_get_files(api_client: FeishuAPIClient, feishu_folder_token: str):
async with api_client as client:
if recursive:
files = await client.list_all_folder_files(feishu_folder_token, recursive=True)
else:
all_files = []
page_token = None
while True:
files_page, page_token = await client.list_folder_files(
feishu_folder_token, page_token
)
all_files.extend(files_page)
if not page_token:
break
files = all_files
return files
files = asyncio.run(async_get_files(api_client,feishu_folder_token))
# Filter out folders, only sync documents
# documents = [f for f in files if f.type in ["doc", "docx", "sheet", "bitable", "file", "slides"]]
documents = [f for f in files if f.type in ["doc", "docx", "sheet", "bitable", "file"]]
try:
for doc in documents:
print(f"\n{'=' * 80}")
print(f"token: {doc.token}")
print(f"name: {doc.name}")
print(f"type: {doc.type}")
print(f"created_time: {doc.created_time}")
print(f"modified_time: {doc.modified_time}")
print(f"owner_id: {doc.owner_id}")
print(f"url: {doc.url}")
print(f"{'=' * 80}\n")
# download document from Feishu FileInfo
async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, save_dir: str):
async with api_client as client:
file_path = await client.download_document(document=doc, save_dir=save_dir)
return file_path
file_path = asyncio.run(async_download_document(api_client, doc, save_dir))
print(file_path)
except KeyboardInterrupt:
print("\n\nfeishu integration interrupted by user.")
except Exception as e:
print(f"\n\nError during feishu integration: {e}")
sys.exit(1)
if __name__ == '__main__':
feishu_app_id = ""
feishu_app_secret = ""
feishu_folder_token = ""
save_dir = "/Volumes/MacintoshBD/Repository/RedBearAI/MemoryBear/api/files/"
main(feishu_app_id, feishu_app_secret, feishu_folder_token, save_dir)

View File

@@ -0,0 +1,452 @@
"""Feishu API client for document operations."""
import asyncio
import os
import re
from typing import Optional, Tuple, List
from datetime import datetime, timedelta
import httpx
from cachetools import TTLCache
import urllib.parse
from app.core.rag.integrations.feishu.exceptions import (
FeishuAuthError,
FeishuAPIError,
FeishuNotFoundError,
FeishuPermissionError,
FeishuRateLimitError,
FeishuNetworkError,
)
from app.core.rag.integrations.feishu.models import FileInfo
from app.core.rag.integrations.feishu.retry import with_retry
class FeishuAPIClient:
"""Feishu API client for document synchronization."""
def __init__(
self,
app_id: str,
app_secret: str,
api_base_url: str = "https://open.feishu.cn/open-apis",
timeout: int = 30,
max_retries: int = 3
):
"""
Initialize Feishu API client.
Args:
app_id: Feishu application ID
app_secret: Feishu application secret
api_base_url: Feishu API base URL
timeout: Request timeout in seconds
max_retries: Maximum number of retries
"""
self.app_id = app_id
self.app_secret = app_secret
self.api_base_url = api_base_url
self.timeout = timeout
self.max_retries = max_retries
self._http_client: Optional[httpx.AsyncClient] = None
self._token_cache: TTLCache = TTLCache(maxsize=1, ttl=7200 - 300) # 2 hours - 5 minutes
self._token_lock = asyncio.Lock()
async def __aenter__(self):
"""Async context manager entry."""
self._http_client = httpx.AsyncClient(
base_url=self.api_base_url,
timeout=self.timeout,
headers={"Content-Type": "application/json"}
)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
if self._http_client:
await self._http_client.aclose()
async def get_tenant_access_token(self) -> str:
"""
Get tenant access token with caching.
Returns:
Access token string
Raises:
FeishuAuthError: If authentication fails
"""
# Check cache first
cached_token = self._token_cache.get("access_token")
if cached_token:
return cached_token
# Use lock to prevent concurrent token requests
async with self._token_lock:
# Double-check cache after acquiring lock
cached_token = self._token_cache.get("access_token")
if cached_token:
return cached_token
# Request new token
try:
if not self._http_client:
raise FeishuAuthError("HTTP client not initialized")
response = await self._http_client.post(
"/auth/v3/tenant_access_token/internal",
json={
"app_id": self.app_id,
"app_secret": self.app_secret
}
)
data = response.json()
if data.get("code") != 0:
error_msg = data.get("msg", "Unknown error")
raise FeishuAuthError(
f"Authentication failed: {error_msg}",
error_code=str(data.get("code")),
details=data
)
token = data.get("tenant_access_token")
if not token:
raise FeishuAuthError("No access token in response")
# Cache the token
self._token_cache["access_token"] = token
return token
except httpx.HTTPError as e:
raise FeishuAuthError(f"HTTP error during authentication: {str(e)}")
except Exception as e:
if isinstance(e, FeishuAuthError):
raise
raise FeishuAuthError(f"Unexpected error during authentication: {str(e)}")
@with_retry
async def list_folder_files(
self,
folder_token: str,
page_token: Optional[str] = None
) -> Tuple[List[FileInfo], Optional[str]]:
"""
Get list of files in a folder with pagination support.
Args:
folder_token: Folder token
page_token: Page token for pagination
Returns:
Tuple of (list of FileInfo, next page token)
Raises:
FeishuAPIError: If API call fails
FeishuNotFoundError: If folder not found
FeishuPermissionError: If permission denied
"""
try:
token = await self.get_tenant_access_token()
if not self._http_client:
raise FeishuAPIError("HTTP client not initialized")
# Build request parameters
params = {"page_size": 200, "folder_token": folder_token}
if page_token:
params["page_token"] = page_token
# Make API request
response = await self._http_client.get(
f"/drive/v1/files",
params=params,
headers={"Authorization": f"Bearer {token}"}
)
data = response.json()
# print(f"get files: {data}")
# Handle errors
if data.get("code") != 0:
error_code = data.get("code")
error_msg = data.get("msg", "Unknown error")
if error_code == 404 or error_code == 230005:
raise FeishuNotFoundError(
f"Folder not found: {error_msg}",
error_code=str(error_code),
details=data
)
elif error_code == 403 or error_code == 230003:
raise FeishuPermissionError(
f"Permission denied: {error_msg}",
error_code=str(error_code),
details=data
)
else:
raise FeishuAPIError(
f"API error: {error_msg}",
error_code=str(error_code),
details=data
)
# Parse response
files_data = data.get("data", {}).get("files", [])
next_page_token = data.get("data", {}).get("next_page_token", None)
# Convert to FileInfo objects
files = []
for file_data in files_data:
try:
file_info = FileInfo(
token=file_data.get("token", ""),
name=file_data.get("name", ""),
type=file_data.get("type", ""),
created_time=datetime.fromtimestamp(int(file_data.get("created_time", 0))),
modified_time=datetime.fromtimestamp(int(file_data.get("modified_time", 0))),
owner_id=file_data.get("owner_id", ""),
url=file_data.get("url", "")
)
files.append(file_info)
except (ValueError, TypeError) as e:
# Skip invalid file entries
continue
return files, next_page_token
except httpx.HTTPError as e:
raise FeishuAPIError(f"HTTP error: {str(e)}")
except Exception as e:
if isinstance(e, (FeishuAPIError, FeishuNotFoundError, FeishuPermissionError)):
raise
raise FeishuAPIError(f"Unexpected error: {str(e)}")
async def list_all_folder_files(
self,
folder_token: str,
recursive: bool = True
) -> List[FileInfo]:
"""
Get all files in a folder, handling pagination automatically.
Args:
folder_token: Folder token
recursive: Whether to recursively get files from subfolders
Returns:
List of all FileInfo objects
Raises:
FeishuAPIError: If API call fails
"""
all_files = []
page_token = None
# Get all files with pagination
while True:
files, page_token = await self.list_folder_files(folder_token, page_token)
all_files.extend(files)
if not page_token:
break
# Recursively get files from subfolders if requested
if recursive:
subfolders = [f for f in all_files if f.type == "folder"]
for subfolder in subfolders:
try:
subfolder_files = await self.list_all_folder_files(
subfolder.token,
recursive=True
)
all_files.extend(subfolder_files)
except Exception:
# Continue with other folders if one fails
continue
return all_files
@with_retry
async def download_document(
self,
document: FileInfo,
save_dir: str
) -> str:
"""
download document content.
Args:
document: Document FileInfo
save_dir: save dir
Returns:
file_full_path
Raises:
FeishuAPIError: If API call fails
FeishuNotFoundError: If document not found
FeishuPermissionError: If permission denied
"""
try:
token = await self.get_tenant_access_token()
if not self._http_client:
raise FeishuAPIError("HTTP client not initialized")
# Different API endpoints for different document types
if document.type == "doc" or document.type == "docx" or document.type == "sheet" or document.type == "bitable":
return await self._export_file(document, token, save_dir)
elif document.type == "file" or document.type == "slides":
return await self._download_file(document, token, save_dir)
else:
raise FeishuAPIError(f"Unsupported document type: {document.type}")
except Exception as e:
if isinstance(e, (FeishuAPIError, FeishuNotFoundError, FeishuPermissionError)):
raise
raise FeishuAPIError(f"Unexpected error: {str(e)}")
async def _export_file(self, document: FileInfo, access_token: str, save_dir: str) -> str:
"""export file for feishu online file type."""
try:
# 1.创建导出任务
file_extension = "pdf"
match document.type:
case "doc":
file_extension = "doc"
case "docx":
file_extension = "docx"
case "sheet":
file_extension = "xlsx"
case "bitable":
file_extension = "xlsx"
case _:
file_extension = "pdf"
response = await self._http_client.post(
"/drive/v1/export_tasks",
json={
"file_extension": file_extension,
"token": document.token,
"type": document.type
},
headers={"Authorization": f"Bearer {access_token}"}
)
data = response.json()
print(f"1.创建导出任务: {data}")
if data.get("code") != 0:
error_code = data.get("code")
error_msg = data.get("msg", "Unknown error")
raise FeishuAPIError(
f"API error: {error_msg}",
error_code=str(error_code),
details=data
)
ticket = data.get("data", {}).get("ticket", None)
if not ticket:
raise FeishuAuthError("No ticket in response")
# 2.轮序查询导出任务结果
max_retries = 10 # 最大轮询次数
poll_interval = 2 # 每次轮询间隔时间(秒)
file_token = None
for attempt in range(max_retries):
# 查询导出任务
response = await self._http_client.get(
f"/drive/v1/export_tasks/{ticket}",
params={"token": document.token},
headers={"Authorization": f"Bearer {access_token}"}
)
data = response.json()
print(f"2. 尝试查询导出任务结果 (第{attempt + 1}次): {data}")
if data.get("code") != 0:
error_code = data.get("code")
error_msg = data.get("msg", "Unknown error")
raise FeishuAPIError(
f"API error: {error_msg}",
error_code=str(error_code),
details=data,
)
# 检查导出任务结果
file_token = data.get("data", {}).get("result", {}).get("file_token", None)
if file_token:
# 如果导出任务成功生成 file_token则退出轮询
break
# 如果结果还没准备好,等待一段时间再进行下一次轮询
await asyncio.sleep(poll_interval)
if not file_token:
raise FeishuAPIError("Export task did not complete within the allowed time")
# 3.下载导出任务
response = await self._http_client.get(
f"/drive/v1/export_tasks/file/{file_token}/download",
headers={"Authorization": f"Bearer {access_token}"}
)
response.raise_for_status()
print(f'3.下载导出任务: {response.headers.get("Content-Disposition")}')
file_full_path = os.path.join(save_dir, document.name + "." + file_extension)
if os.path.exists(file_full_path):
os.remove(file_full_path) # Delete a single file
with open(file_full_path, "wb") as file:
file.write(response.content)
return file_full_path
except httpx.HTTPError as e:
raise FeishuAPIError(f"HTTP error: {str(e)}")
except Exception as e:
raise FeishuAPIError(f"Unexpected error during file download: {str(e)}")
async def _download_file(self, document: FileInfo, access_token: str, save_dir: str) -> str:
"""download file for file type."""
try:
response = await self._http_client.get(
f"/drive/v1/files/{document.token}/download",
headers={"Authorization": f"Bearer {access_token}"}
)
response.raise_for_status()
filename_header = response.headers.get("Content-Disposition")
# 最终的文件名(初始化为 None
filename = None
if filename_header:
# 优先解析 filename* 格式
match = re.search(r"filename\*=([^']*)''([^;]+)", filename_header)
if match:
# 使用 `filename*` 提取(已编码)
encoding = match.group(1) # 编码部分(如 UTF-8
encoded_filename = match.group(2) # 文件名部分
filename = urllib.parse.unquote(encoded_filename) # 解码 URL 编码的文件名
# 如果 `filename*` 不存在,回退到解析 `filename`
if not filename:
match = re.search(r'filename="([^"]+)"', filename_header)
if match:
filename = match.group(1)
# 如果文件名仍为 None则使用默认文件名
if not filename:
filename = f"{document.name}.pdf"
# 确保文件名合法,替换非法字符
filename = re.sub(r'[\/:*?"<>|]', '_', filename)
file_full_path = os.path.join(save_dir, filename)
if os.path.exists(file_full_path):
os.remove(file_full_path) # Delete a single file
with open(file_full_path, "wb") as file:
file.write(response.content)
return file_full_path
except httpx.HTTPError as e:
raise FeishuAPIError(f"HTTP error: {str(e)}")
except Exception as e:
raise FeishuAPIError(f"Unexpected error during file download: {str(e)}")

View File

@@ -0,0 +1,46 @@
"""Exception classes for Feishu integration."""
class FeishuError(Exception):
"""Base exception for all Feishu-related errors."""
def __init__(self, message: str, error_code: str = None, details: dict = None):
super().__init__(message)
self.message = message
self.error_code = error_code
self.details = details or {}
class FeishuAuthError(FeishuError):
"""Authentication error with Feishu API."""
pass
class FeishuAPIError(FeishuError):
"""General API error from Feishu."""
pass
class FeishuNotFoundError(FeishuError):
"""Resource not found error (404)."""
pass
class FeishuPermissionError(FeishuError):
"""Permission denied error (403)."""
pass
class FeishuRateLimitError(FeishuError):
"""Rate limit exceeded error (429)."""
pass
class FeishuNetworkError(FeishuError):
"""Network-related error (timeout, connection failure)."""
pass
class FeishuDataError(FeishuError):
"""Data parsing or validation error."""
pass

View File

@@ -0,0 +1,17 @@
"""Data models for Feishu integration."""
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, Any, List, Optional
@dataclass
class FileInfo:
"""File information from Feishu."""
token: str
name: str
type: str # doc/docx/sheet/bitable/file/slides/folder
created_time: datetime
modified_time: datetime
owner_id: str
url: str

View File

@@ -0,0 +1,137 @@
"""Retry strategy for Feishu API calls."""
import asyncio
import functools
from typing import Callable, TypeVar
import httpx
from app.core.rag.integrations.feishu.exceptions import (
FeishuAuthError,
FeishuPermissionError,
FeishuNotFoundError,
FeishuRateLimitError,
FeishuNetworkError,
FeishuDataError,
FeishuAPIError,
)
T = TypeVar('T')
class RetryStrategy:
"""Retry strategy for API calls."""
# Retryable error types
RETRYABLE_ERRORS = (
FeishuNetworkError,
FeishuRateLimitError,
httpx.TimeoutException,
httpx.ConnectError,
httpx.ReadError,
)
# Non-retryable error types
NON_RETRYABLE_ERRORS = (
FeishuAuthError,
FeishuPermissionError,
FeishuNotFoundError,
FeishuDataError,
)
# Retry configuration
MAX_RETRIES = 3
BACKOFF_DELAYS = [1, 2, 4] # seconds
@classmethod
def is_retryable(cls, error: Exception) -> bool:
"""Check if an error is retryable."""
# Check for specific retryable errors
if isinstance(error, cls.RETRYABLE_ERRORS):
return True
# Check for non-retryable errors
if isinstance(error, cls.NON_RETRYABLE_ERRORS):
return False
# Check for HTTP status codes
if isinstance(error, httpx.HTTPStatusError):
status_code = error.response.status_code
# Retry on 429 (rate limit), 503 (service unavailable), 502 (bad gateway)
if status_code in [429, 502, 503]:
return True
# Don't retry on 4xx errors (except 429)
if 400 <= status_code < 500:
return False
# Retry on 5xx errors
if 500 <= status_code < 600:
return True
# Check for FeishuAPIError with specific codes
if isinstance(error, FeishuAPIError):
if error.error_code:
# Rate limit error codes
if error.error_code in ["99991400", "99991401"]:
return True
return False
@classmethod
async def execute_with_retry(
cls,
func: Callable[..., T],
*args,
**kwargs
) -> T:
"""
Execute a function with retry logic.
Args:
func: Async function to execute
*args: Positional arguments for the function
**kwargs: Keyword arguments for the function
Returns:
Function result
Raises:
Exception: The last exception if all retries fail
"""
last_exception = None
for attempt in range(cls.MAX_RETRIES + 1):
try:
return await func(*args, **kwargs)
except Exception as e:
last_exception = e
# Don't retry if not retryable
if not cls.is_retryable(e):
raise
# Don't retry if this was the last attempt
if attempt >= cls.MAX_RETRIES:
raise
# Wait before retrying
delay = cls.BACKOFF_DELAYS[attempt] if attempt < len(cls.BACKOFF_DELAYS) else cls.BACKOFF_DELAYS[-1]
await asyncio.sleep(delay)
# Should not reach here, but raise last exception if we do
if last_exception:
raise last_exception
def with_retry(func: Callable[..., T]) -> Callable[..., T]:
"""
Decorator to add retry logic to async functions.
Usage:
@with_retry
async def my_api_call():
...
"""
@functools.wraps(func)
async def wrapper(*args, **kwargs):
return await RetryStrategy.execute_with_retry(func, *args, **kwargs)
return wrapper

View File

@@ -0,0 +1 @@
"""Yuque integration module for document synchronization."""

View File

@@ -0,0 +1,77 @@
"""Main entry point for Yuque integration testing."""
import asyncio
import sys
from app.core.rag.integrations.yuque.client import YuqueAPIClient
from app.core.rag.integrations.yuque.models import YuqueDocInfo
def main(yuque_user_id: str, # yuque User ID
yuque_token: str, # yuque Token
save_dir: str, # save file directory
):
"""Main entry point for the YuqueAPIClient."""
# Create feishuAPIClient
api_client = YuqueAPIClient(
user_id=yuque_user_id,
token=yuque_token
)
# Get all files from all repos
async def async_get_files(api_client: YuqueAPIClient):
async with api_client as client:
print("\n=== Fetching repositories ===")
repos = await client.get_user_repos()
print(f"Found {len(repos)} repositories:")
all_files = []
for repo in repos:
# Get documents from repository
print(f"\n=== Fetching documents from '{repo.name}' ===")
docs = await client.get_repo_docs(repo.id)
all_files.extend(docs)
return all_files
files = asyncio.run(async_get_files(api_client))
try:
for doc in files:
print(f"\n{'=' * 80}")
print(f"id: {doc.id}")
print(f"type: {doc.type}")
print(f"slug: {doc.slug}")
print(f"title: {doc.title}")
print(f"book_id: {doc.book_id}")
# print(f"format: {doc.format}")
# print(f"body: {doc.body}")
# print(f"body_draft: {doc.body_draft}")
# print(f"body_html: {doc.body_html}")
print(f"public: {doc.public}")
print(f"status: {doc.status}")
print(f"created_at: {doc.created_at}")
print(f"updated_at: {doc.updated_at}")
print(f"published_at: {doc.published_at}")
print(f"word_count: {doc.word_count}")
print(f"cover: {doc.cover}")
print(f"description: {doc.description}")
print(f"{'=' * 80}\n")
# download document from Feishu FileInfo
async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, save_dir: str):
async with api_client as client:
file_path = await client.download_document(doc, save_dir)
return file_path
file_path = asyncio.run(async_download_document(api_client, doc, save_dir))
print(file_path)
except KeyboardInterrupt:
print("\n\nfeishu integration interrupted by user.")
except Exception as e:
print(f"\n\nError during feishu integration: {e}")
sys.exit(1)
if __name__ == "__main__":
yuque_user_id = ""
yuque_token = ""
save_dir = "/Volumes/MacintoshBD/Repository/RedBearAI/MemoryBear/api/files/"
main(yuque_user_id, yuque_token, save_dir)

View File

@@ -0,0 +1,544 @@
"""Yuque API client for document operations."""
import os
import re
from typing import Optional, List
from datetime import datetime, timedelta
import httpx
import urllib.parse
import json
from openpyxl import Workbook
from openpyxl.styles import Font, Alignment, PatternFill
from openpyxl.utils import get_column_letter
import zlib
from app.core.rag.integrations.yuque.exceptions import (
YuqueAuthError,
YuqueAPIError,
YuqueNotFoundError,
YuquePermissionError,
YuqueRateLimitError,
YuqueNetworkError,
)
from app.core.rag.integrations.yuque.models import YuqueDocInfo, YuqueRepoInfo
from app.core.rag.integrations.yuque.retry import with_retry
class YuqueAPIClient:
"""Yuque API client for document synchronization."""
def __init__(
self,
user_id: str,
token: str,
api_base_url: str = "https://www.yuque.com/api/v2",
timeout: int = 30,
max_retries: int = 3
):
"""
Initialize Yuque API client.
Args:
user_id: Yuque user ID or login name
token: Yuque personal access token
api_base_url: Yuque API base URL
timeout: Request timeout in seconds
max_retries: Maximum number of retries
"""
self.user_id = user_id
self.token = token
self.api_base_url = api_base_url
self.timeout = timeout
self.max_retries = max_retries
self._http_client: Optional[httpx.AsyncClient] = None
async def __aenter__(self):
"""Async context manager entry."""
self._http_client = httpx.AsyncClient(
base_url=self.api_base_url,
timeout=self.timeout,
headers={
"Content-Type": "application/json",
"X-Auth-Token": self.token,
"User-Agent": "Yuque-Integration-Client"
}
)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
if self._http_client:
await self._http_client.aclose()
def _handle_api_error(self, response: httpx.Response):
"""Handle API error responses."""
try:
data = response.json()
except Exception:
data = {}
status_code = response.status_code
error_msg = data.get("message", "Unknown error")
# Rate limit errors
if status_code == 429:
raise YuqueRateLimitError(
f"Rate limit exceeded: {error_msg}",
error_code=str(status_code),
details=data
)
# Not found errors
elif status_code == 404:
raise YuqueNotFoundError(
f"Resource not found: {error_msg}",
error_code=str(status_code),
details=data
)
# Permission errors
elif status_code == 403:
raise YuquePermissionError(
f"Permission denied: {error_msg}",
error_code=str(status_code),
details=data
)
# Authentication errors
elif status_code == 401:
raise YuqueAuthError(
f"Authentication failed: {error_msg}",
error_code=str(status_code),
details=data
)
# Generic API error
else:
raise YuqueAPIError(
f"API error: {error_msg}",
error_code=str(status_code),
details=data
)
@with_retry
async def get_user_repos(self) -> List[YuqueRepoInfo]:
"""
Get all repositories (知识库) for the user.
Returns:
List of YuqueRepoInfo objects
Raises:
YuqueAPIError: If API call fails
"""
try:
if not self._http_client:
raise YuqueAPIError("HTTP client not initialized")
response = await self._http_client.get(f"/users/{self.user_id}/repos")
if response.status_code != 200:
self._handle_api_error(response)
data = response.json()
repos_data = data.get("data", [])
repos = []
for repo_data in repos_data:
try:
repo = YuqueRepoInfo(
id=repo_data.get("id"),
type=repo_data.get("type", ""),
name=repo_data.get("name", ""),
namespace=repo_data.get("namespace", ""),
slug=repo_data.get("slug", ""),
description=repo_data.get("description"),
public=repo_data.get("public", 0),
items_count=repo_data.get("items_count", 0),
created_at=datetime.fromisoformat(repo_data.get("created_at", "").replace("Z", "+00:00")),
updated_at=datetime.fromisoformat(repo_data.get("updated_at", "").replace("Z", "+00:00"))
)
repos.append(repo)
except (ValueError, TypeError, KeyError) as e:
# Skip invalid repo entries
continue
return repos
except httpx.HTTPError as e:
raise YuqueAPIError(f"HTTP error: {str(e)}")
except Exception as e:
if isinstance(e, (YuqueAPIError, YuqueAuthError)):
raise
raise YuqueAPIError(f"Unexpected error: {str(e)}")
@with_retry
async def get_repo_docs(self, book_id: int) -> List[YuqueDocInfo]:
"""
Get all documents in a repository.
Args:
book_id: repository id
Returns:
List of YuqueDocInfo objects (without body content)
Raises:
YuqueAPIError: If API call fails
"""
try:
if not self._http_client:
raise YuqueAPIError("HTTP client not initialized")
response = await self._http_client.get(f"/repos/{book_id}/docs")
if response.status_code != 200:
self._handle_api_error(response)
data = response.json()
docs_data = data.get("data", [])
docs = []
for doc_data in docs_data:
try:
published_at = doc_data.get("published_at")
doc = YuqueDocInfo(
id=doc_data.get("id"),
type=doc_data.get("type", ""),
slug=doc_data.get("slug", ""),
title=doc_data.get("title", ""),
book_id=doc_data.get("book_id"),
format=doc_data.get("format", "markdown"),
body=None, # Body not included in list API
body_draft=None,
body_html=None,
public=doc_data.get("public", 0),
status=doc_data.get("status", 0),
created_at=datetime.fromisoformat(doc_data.get("created_at", "").replace("Z", "+00:00")),
updated_at=datetime.fromisoformat(doc_data.get("updated_at", "").replace("Z", "+00:00")),
published_at=datetime.fromisoformat(published_at.replace("Z", "+00:00")) if published_at else None,
word_count=doc_data.get("word_count", 0),
cover=doc_data.get("cover"),
description=doc_data.get("description")
)
docs.append(doc)
except (ValueError, TypeError, KeyError) as e:
# Skip invalid doc entries
continue
return docs
except httpx.HTTPError as e:
raise YuqueAPIError(f"HTTP error: {str(e)}")
except Exception as e:
if isinstance(e, (YuqueAPIError, YuqueNotFoundError)):
raise
raise YuqueAPIError(f"Unexpected error: {str(e)}")
@with_retry
async def get_doc_detail(self, id: int) -> YuqueDocInfo:
"""
Get detailed document information including content.
Args:
id: document ID
Returns:
YuqueDocInfo object with full content
Raises:
YuqueAPIError: If API call fails
"""
try:
if not self._http_client:
raise YuqueAPIError("HTTP client not initialized")
response = await self._http_client.get(
f"/repos/docs/{id}",
params={"raw": 1} # Get raw markdown content
)
if response.status_code != 200:
self._handle_api_error(response)
data = response.json()
doc_data = data.get("data", {})
published_at = doc_data.get("published_at")
doc = YuqueDocInfo(
id=doc_data.get("id"),
type=doc_data.get("type", ""),
slug=doc_data.get("slug", ""),
title=doc_data.get("title", ""),
book_id=doc_data.get("book_id"),
format=doc_data.get("format", "markdown"),
body=doc_data.get("body", ""),
body_draft=doc_data.get("body_draft"),
body_html=doc_data.get("body_html"),
public=doc_data.get("public", 0),
status=doc_data.get("status", 0),
created_at=datetime.fromisoformat(doc_data.get("created_at", "").replace("Z", "+00:00")),
updated_at=datetime.fromisoformat(doc_data.get("updated_at", "").replace("Z", "+00:00")),
published_at=datetime.fromisoformat(published_at.replace("Z", "+00:00")) if published_at else None,
word_count=doc_data.get("word_count", 0),
cover=doc_data.get("cover"),
description=doc_data.get("description")
)
return doc
except httpx.HTTPError as e:
raise YuqueAPIError(f"HTTP error: {str(e)}")
except Exception as e:
if isinstance(e, (YuqueAPIError, YuqueNotFoundError)):
raise
raise YuqueAPIError(f"Unexpected error: {str(e)}")
async def download_document(
self,
doc: YuqueDocInfo,
save_dir: str
) -> str:
"""
Download document content to local file.
Args:
doc: Document info (can be without body)
save_dir: Directory to save the file
Returns:
Full path to the saved file
Raises:
YuqueAPIError: If download fails
"""
try:
# Get full document content if not already loaded
if not doc.body:
doc = await self.get_doc_detail(doc.id)
# Sanitize filename
filename = re.sub(r'[\/:*?"<>|]', '_', doc.title)
# Determine file extension based on format
content = doc.body or ""
if doc.format == "markdown":
file_extension = "md"
elif doc.format == "lake":
file_extension = "md" # Save lake format as markdown
elif doc.format == "html":
file_extension = "html"
elif doc.format == "lakesheet":
file_extension = "xlsx"
body_data = json.loads(doc.body)
sheet_data = body_data.get("sheet", "")
try:
sheet_raw = zlib.decompress(bytes(sheet_data, 'latin-1'))
except Exception as e:
print(f"Error decompressing sheet data: {e}")
raise ValueError("Invalid or unsupported sheet data format.")
try:
sheet_text = sheet_raw.decode("utf-8") # 假设是 UTF-8 编码
except UnicodeDecodeError:
sheet_text = sheet_raw.decode("gbk") # 如果 UTF-8 解码失败,尝试 GBK
file_full_path = os.path.join(save_dir, f"{filename}.{file_extension}")
self.generate_excel_from_sheet(sheet_text, file_full_path)
return file_full_path
else:
file_extension = "txt"
file_full_path = os.path.join(save_dir, f"{filename}.{file_extension}")
# Remove existing file if it exists
if os.path.exists(file_full_path):
os.remove(file_full_path)
# Write content to file
with open(file_full_path, "w", encoding="utf-8") as file:
file.write(content)
return file_full_path
except Exception as e:
if isinstance(e, YuqueAPIError):
raise
raise YuqueAPIError(f"Unexpected error during file download: {str(e)}")
def generate_excel_from_sheet(self, sheet_text: str, save_path: str):
"""
将解析的 sheet_text 数据转换为 Excel 文件。
Args:
sheet_text (str): JSON 格式的 sheet 数据。
save_path (str): Excel 文件的保存路径。
"""
try:
# 解析 JSON 数据
sheets = json.loads(sheet_text)
if not isinstance(sheets, list):
raise ValueError("sheet_text must be a JSON array of sheets.")
# 创建一个新的 Excel 工作簿
workbook = Workbook()
for sheet_index, sheet_data in enumerate(sheets):
sheet_name = sheet_data.get("name", f"Sheet{sheet_index + 1}")
row_data = sheet_data.get("data", {})
merge_cells = sheet_data.get("mergeCells", {})
rows_styles = sheet_data.get("rows", [])
cols_styles = sheet_data.get("columns", [])
# 创建 Sheet
if sheet_index == 0:
worksheet = workbook.active
worksheet.title = sheet_name
else:
worksheet = workbook.create_sheet(title=sheet_name)
# 设置列宽
for col_index, col_style in enumerate(cols_styles):
col_width = col_style.get("size", 82.125) / 7.0
col_letter = get_column_letter(col_index + 1) # Excel 列从1开始
worksheet.column_dimensions[col_letter].width = col_width
# 设置行高
for row_index, row_style in enumerate(rows_styles):
row_height = row_style.get("size", 24) / 1.5
worksheet.row_dimensions[row_index + 1].height = row_height
# 写入单元格数据
for r_index, row in row_data.items():
for c_index, cell in row.items():
# 防御性检查:确保行号和列号都是有效的整数
try:
row_number = int(r_index) + 1
col_number = int(c_index) + 1
except ValueError:
print(f"Invalid row or column index: r_index={r_index}, c_index={c_index}")
continue
if col_number < 1 or col_number > 16384: # Excel 最大列数支持到 XFD即 16384 列
print(f"Invalid column index: c_index={c_index}")
continue
cell_obj = worksheet.cell(row=row_number, column=col_number)
# 处理值和公式
cell_value = cell.get("value", "")
if isinstance(cell_value, dict):
# 检查是否为公式
if cell_value.get("class") == "formula" and "formula" in cell_value:
cell_obj.value = f"={cell_value['formula']}" # 写入公式
else:
cell_obj.value = cell_value.get("value", "") # 写入值
else:
cell_obj.value = cell_value # 写入简单值
# 应用样式
style = cell.get("style", {})
self.apply_cell_style(cell_obj, style)
# 合并单元格
for key, merge_def in merge_cells.items():
start_row = merge_def["row"] + 1
start_col = merge_def["col"] + 1
end_row = start_row + merge_def["rowCount"] - 1
end_col = start_col + merge_def["colCount"] - 1
worksheet.merge_cells(
start_row=start_row, start_column=start_col, end_row=end_row, end_column=end_col
)
# 保存 Excel 文件
workbook.save(save_path)
print(f"Excel file successfully saved to: {save_path}")
except Exception as e:
print(f"Error generating Excel file: {e}")
def apply_cell_style(self, cell, style):
"""
应用单元格样式,包括字体、对齐、背景颜色等。
Args:
cell: openpyxl 的单元格对象。
style: 字典格式的样式信息。
"""
# 定义允许的对齐值
allowed_horizontal_alignments = {"general", "left", "center", "centerContinuous", "right", "fill", "justify",
"distributed"}
allowed_vertical_alignments = {"top", "center", "justify", "distributed", "bottom"}
# 处理字体
font = Font(
size=style.get("fontSize", 11),
bold=style.get("fontWeight", False),
italic=style.get("fontStyle", "normal") == "italic",
underline="single" if style.get("underline", False) else None,
color=self.convert_color_to_hex(style.get("color", "#000000")),
)
cell.font = font
# 处理对齐方式
horizontal_alignment = style.get("hAlign", "left")
vertical_alignment = style.get("vAlign", "top")
# 如果对齐值无效,则使用默认值
if horizontal_alignment not in allowed_horizontal_alignments:
horizontal_alignment = "left"
if vertical_alignment not in allowed_vertical_alignments:
vertical_alignment = "top"
alignment = Alignment(
horizontal=horizontal_alignment,
vertical=vertical_alignment,
wrap_text=style.get("overflow") == "wrap",
)
cell.alignment = alignment
# 处理背景颜色
background_color = style.get("backColor", None)
if background_color:
hex_color = self.convert_color_to_hex(background_color)
if hex_color:
cell.fill = PatternFill(
start_color=hex_color,
end_color=hex_color,
fill_type="solid"
)
def convert_color_to_hex(self, color):
"""
将颜色从 `rgba(...)` 或 `rgb(...)` 转换为 aRGB 十六进制格式。
Args:
color (str): 原始颜色字符串,如 `rgba(255,255,0,1.00)` 或 `#FFFFFF`。
Returns:
str: 转换后的颜色字符串(符合 openpyxl 的格式),例如 `FFFF0000`。
"""
try:
if not color:
return None
# 如果是 `#RRGGBB` 或 `#AARRGGBB` 格式,直接返回
if color.startswith("#"):
return color.lstrip("#").upper()
# 如果是 `rgb(...)` 格式,例如 `rgb(255,255,0)`
if color.startswith("rgb("):
rgb_values = color.strip("rgb()").split(",")
red, green, blue = [int(v) for v in rgb_values]
return f"FF{red:02X}{green:02X}{blue:02X}"
# 如果是 `rgba(...)` 格式,例如 `rgba(255,255,0,1.00)`
if color.startswith("rgba("):
rgba_values = color.strip("rgba()").split(",")
red, green, blue = [int(v) for v in rgba_values[:3]]
alpha = float(rgba_values[3])
alpha_hex = int(alpha * 255) # 将透明度转换为 [00, FF]
return f"{alpha_hex:02X}{red:02X}{green:02X}{blue:02X}"
# 返回默认颜色
return None
except Exception as e:
print(f"Error parsing color '{color}': {e}")
return None

View File

@@ -0,0 +1,46 @@
"""Exception classes for Yuque integration."""
class YuqueError(Exception):
"""Base exception for all Yuque-related errors."""
def __init__(self, message: str, error_code: str = None, details: dict = None):
super().__init__(message)
self.message = message
self.error_code = error_code
self.details = details or {}
class YuqueAuthError(YuqueError):
"""Authentication error with Yuque API."""
pass
class YuqueAPIError(YuqueError):
"""General API error from Yuque."""
pass
class YuqueNotFoundError(YuqueError):
"""Resource not found error (404)."""
pass
class YuquePermissionError(YuqueError):
"""Permission denied error (403)."""
pass
class YuqueRateLimitError(YuqueError):
"""Rate limit exceeded error (429)."""
pass
class YuqueNetworkError(YuqueError):
"""Network-related error (timeout, connection failure)."""
pass
class YuqueDataError(YuqueError):
"""Data parsing or validation error."""
pass

View File

@@ -0,0 +1,42 @@
"""Data models for Yuque integration."""
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
@dataclass
class YuqueRepoInfo:
"""Repository (知识库) information from Yuque."""
id: int # 知识库 ID
type: str # 类型 (Book:文档, Design:图集, Sheet:表格, Resource:资源)
name: str # 名称
namespace: str # 完整路径: user/repo format
slug: str # 路径
description: Optional[str] # 简介
public: int # 公开性 (0:私密, 1:公开, 2:企业内公开)
items_count: int # 文档数量
created_at: datetime # 创建时间
updated_at: datetime # 更新时间
@dataclass
class YuqueDocInfo:
"""Document information from Yuque."""
id: int # 文档 ID
type: str # 文档类型 (Doc:普通文档, Sheet:表格, Thread:话题, Board:图集, Table:数据表)
slug: str # 路径
title: str # 标题
book_id: int # 归属知识库 ID
format: str # 内容格式 (markdown:Markdown 格式, lake:语雀 Lake 格式, html:HTML 标准格式, lakesheet:语雀表格)
body: Optional[str] # 正文原始内容
body_draft: Optional[str] # 正文草稿内容
body_html: Optional[str] # 正文 HTML 标准格式内容
public: int # 公开性 (0:私密, 1:公开, 2:企业内公开)
status: int # 状态 (0:草稿, 1:发布)
created_at: datetime # 创建时间
updated_at: datetime # 更新时间
published_at: Optional[datetime] # 发布时间
word_count: int # 内容字数
cover: Optional[str] # 封面
description: Optional[str] # 摘要

View File

@@ -0,0 +1,134 @@
"""Retry strategy for Yuque API calls."""
import asyncio
import functools
from typing import Callable, TypeVar
import httpx
from app.core.rag.integrations.yuque.exceptions import (
YuqueAuthError,
YuquePermissionError,
YuqueNotFoundError,
YuqueRateLimitError,
YuqueNetworkError,
YuqueDataError,
YuqueAPIError,
)
T = TypeVar('T')
class RetryStrategy:
"""Retry strategy for API calls."""
# Retryable error types
RETRYABLE_ERRORS = (
YuqueNetworkError,
YuqueRateLimitError,
httpx.TimeoutException,
httpx.ConnectError,
httpx.ReadError,
)
# Non-retryable error types
NON_RETRYABLE_ERRORS = (
YuqueAuthError,
YuquePermissionError,
YuqueNotFoundError,
YuqueDataError,
)
# Retry configuration
MAX_RETRIES = 3
BACKOFF_DELAYS = [1, 2, 4] # seconds
@classmethod
def is_retryable(cls, error: Exception) -> bool:
"""Check if an error is retryable."""
# Check for specific retryable errors
if isinstance(error, cls.RETRYABLE_ERRORS):
return True
# Check for non-retryable errors
if isinstance(error, cls.NON_RETRYABLE_ERRORS):
return False
# Check for HTTP status codes
if isinstance(error, httpx.HTTPStatusError):
status_code = error.response.status_code
# Retry on 429 (rate limit), 503 (service unavailable), 502 (bad gateway)
if status_code in [429, 502, 503]:
return True
# Don't retry on 4xx errors (except 429)
if 400 <= status_code < 500:
return False
# Retry on 5xx errors
if 500 <= status_code < 600:
return True
# Check for YuqueRateLimitError
if isinstance(error, YuqueRateLimitError):
return True
return False
@classmethod
async def execute_with_retry(
cls,
func: Callable[..., T],
*args,
**kwargs
) -> T:
"""
Execute a function with retry logic.
Args:
func: Async function to execute
*args: Positional arguments for the function
**kwargs: Keyword arguments for the function
Returns:
Function result
Raises:
Exception: The last exception if all retries fail
"""
last_exception = None
for attempt in range(cls.MAX_RETRIES + 1):
try:
return await func(*args, **kwargs)
except Exception as e:
last_exception = e
# Don't retry if not retryable
if not cls.is_retryable(e):
raise
# Don't retry if this was the last attempt
if attempt >= cls.MAX_RETRIES:
raise
# Wait before retrying
delay = cls.BACKOFF_DELAYS[attempt] if attempt < len(cls.BACKOFF_DELAYS) else cls.BACKOFF_DELAYS[-1]
await asyncio.sleep(delay)
# Should not reach here, but raise last exception if we do
if last_exception:
raise last_exception
def with_retry(func: Callable[..., T]) -> Callable[..., T]:
"""
Decorator to add retry logic to async functions.
Usage:
@with_retry
async def my_api_call():
...
"""
@functools.wraps(func)
async def wrapper(*args, **kwargs):
return await RetryStrategy.execute_with_retry(func, *args, **kwargs)
return wrapper

View File

@@ -28,7 +28,9 @@ from app.core.rag.common.float_utils import get_float
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
from app.core.rag.llm.chat_model import Base
from app.core.rag.llm.embedding_model import OpenAIEmbed
import logging
logger = logging.getLogger(__name__)
def knowledge_retrieval(
query: str,
@@ -62,7 +64,15 @@ def knowledge_retrieval(
merge_strategy = config.get("merge_strategy", "weight")
reranker_id = config.get("reranker_id")
reranker_top_k = config.get("reranker_top_k", 1024)
use_graph = config.get("use_graph", "false").lower() == "true"
# use_graph = config.get("use_graph", "false").lower() == "true"
use_graph_value = config.get("use_graph", False)
if isinstance(use_graph_value, bool):
use_graph = use_graph_value
elif isinstance(use_graph_value, str):
use_graph = use_graph_value.lower() in ("true", "1", "yes")
else:
use_graph = False
file_names_filter = []
if user_ids:
@@ -159,13 +169,29 @@ def knowledge_retrieval(
# Use the specified reranker for re-ranking
if reranker_id:
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
# use graph
try:
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
except Exception as rerank_error:
# If reranker fails, log warning and continue with original results
logger.warning(
"Reranker failed, falling back to original results",
extra={
"reranker_id": reranker_id,
"query": query,
"doc_count": len(all_results),
"error": str(rerank_error),
},
)
if use_graph:
from app.core.rag.common.settings import kg_retriever
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
if doc:
all_results.insert(0, doc)
try:
from app.core.rag.common.settings import kg_retriever
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
if doc:
all_results.insert(0, doc)
except Exception as graph_error:
print(f"Failed to retrieve from knowledge graph: {str(graph_error)}")
return all_results
except Exception as e:

View File

@@ -25,6 +25,18 @@ class ParameterExtractorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config: ParameterExtractorNodeConfig | None = None
self.response_metadata = {}
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
if self.response_metadata:
usage = self.response_metadata.get('token_usage')
if usage:
return {
"prompt_tokens": usage.get('prompt_tokens', 0),
"completion_tokens": usage.get('completion_tokens', 0),
"total_tokens": usage.get('total_tokens', 0)
}
return None
def _output_types(self) -> dict[str, VariableType]:
outputs = {}
@@ -180,6 +192,7 @@ class ParameterExtractorNode(BaseNode):
])
model_resp = await llm.ainvoke(messages)
self.response_metadata = model_resp.response_metadata
result = json_repair.repair_json(model_resp.content, return_objects=True)
logger.info(f"node: {self.node_id} get params:{result}")

View File

@@ -25,6 +25,18 @@ class QuestionClassifierNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: QuestionClassifierNodeConfig | None = None
self.category_to_case_map = {}
self.response_metadata = {}
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
if self.response_metadata:
usage = self.response_metadata.get('token_usage')
if usage:
return {
"prompt_tokens": usage.get('prompt_tokens', 0),
"completion_tokens": usage.get('completion_tokens', 0),
"total_tokens": usage.get('total_tokens', 0)
}
return None
def _output_types(self) -> dict[str, VariableType]:
return {
@@ -120,6 +132,7 @@ class QuestionClassifierNode(BaseNode):
response = await llm.ainvoke(messages)
result = response.content.strip()
self.response_metadata = response.response_metadata
if result in category_names:
category = result

View File

@@ -14,4 +14,5 @@ class File(Base):
file_name = Column(String, index=True, nullable=False, comment="file name or folder name,default folder name is /")
file_ext = Column(String, index=True, nullable=False, comment="file extension:folder|pdf")
file_size = Column(Integer, default=0, comment="file size(byte)")
file_url = Column(String, index=True, nullable=True, comment="file comes from a website url")
created_at = Column(DateTime, default=datetime.datetime.now)

View File

@@ -57,6 +57,17 @@ class Knowledge(Base):
parser_id = Column(String, index=True, default="naive", comment="default parser ID")
parser_config = Column(JSON, nullable=False,
default={
"entry_url": "https://ai.redbearai.com",
"max_pages": 20,
"delay_seconds": 1.0,
"timeout_seconds": 10,
"user_agent": "KnowledgeBaseCrawler/1.0",
"yuque_user_id": "User ID",
"yuque_token": "Token",
"feishu_app_id": "App ID",
"feishu_app_secret": "App Secret",
"feishu_folder_token": "Folder Token",
"sync_cron": "30 7 * * 1-5",
"layout_recognize": "DeepDOC",
"chunk_token_num": 128,
"delimiter": "\n",

View File

@@ -86,7 +86,8 @@ class MemoryConfigRepository:
n.description AS description,
n.entity_type AS entity_type,
n.name AS name,
COALESCE(n.fact_summary, '') AS fact_summary,
// TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
// COALESCE(n.fact_summary, '') AS fact_summary,
n.end_user_id AS end_user_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
@@ -279,6 +280,9 @@ class MemoryConfigRepository:
if update.config_desc is not None:
db_config.config_desc = update.config_desc
has_update = True
if update.scene_id is not None:
db_config.scene_id = update.scene_id
has_update = True
if not has_update:
raise ValueError("No fields to update")
@@ -650,28 +654,32 @@ class MemoryConfigRepository:
raise
@staticmethod
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[MemoryConfig]:
"""获取所有配置参数
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[Tuple[MemoryConfig, Optional[str]]]:
"""获取所有配置参数,包含关联的场景名称
Args:
db: 数据库会话
workspace_id: 工作空间ID用于过滤查询结果
Returns:
List[MemoryConfig]: 配置列表
List[Tuple[MemoryConfig, Optional[str]]]: 配置列表,每项为 (配置对象, 场景名称)
"""
from app.models.ontology_scene import OntologyScene
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
try:
query = db.query(MemoryConfig)
query = db.query(MemoryConfig, OntologyScene.scene_name).outerjoin(
OntologyScene, MemoryConfig.scene_id == OntologyScene.scene_id
)
if workspace_id:
query = query.filter(MemoryConfig.workspace_id == workspace_id)
configs = query.order_by(desc(MemoryConfig.updated_at)).all()
results = query.order_by(desc(MemoryConfig.updated_at)).all()
db_logger.debug(f"配置列表查询成功: 数量={len(configs)}")
return configs
db_logger.debug(f"配置列表查询成功: 数量={len(results)}")
return results
except Exception as e:
db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}")

View File

@@ -79,7 +79,8 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
try:
edges: List[dict] = []
for s in summaries:
for chunk_id in getattr(s, "chunk_ids", []) or []:
chunk_ids = getattr(s, "chunk_ids", []) or []
for chunk_id in chunk_ids:
edges.append({
"summary_id": s.id,
"chunk_id": chunk_id,
@@ -91,12 +92,11 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
if not edges:
return []
result = await connector.execute_query(
MEMORY_SUMMARY_STATEMENT_EDGE_SAVE,
edges=edges
)
created = [record.get("uuid") for record in result] if result else []
return created
except Exception:
except Exception as e:
return None

View File

@@ -217,8 +217,10 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
summaries=flattened
)
created_ids = [record.get("uuid") for record in result]
print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
return created_ids
except Exception:
except Exception as e:
print(f"Failed to save MemorySummary nodes to Neo4j: {e}")
return None

View File

@@ -101,10 +101,11 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
e.name_embedding = CASE
WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding
ELSE e.name_embedding END,
e.fact_summary = CASE
WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> ''
AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary))
THEN entity.fact_summary ELSE e.fact_summary END,
// TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
// e.fact_summary = CASE
// WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> ''
// AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary))
// THEN entity.fact_summary ELSE e.fact_summary END,
e.connect_strength = CASE
WHEN entity.connect_strength IS NULL OR entity.connect_strength = '' THEN e.connect_strength
ELSE CASE
@@ -321,7 +322,8 @@ RETURN e.id AS id,
e.description AS description,
e.aliases AS aliases,
e.name_embedding AS name_embedding,
COALESCE(e.fact_summary, '') AS fact_summary,
// TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
// COALESCE(e.fact_summary, '') AS fact_summary,
e.connect_strength AS connect_strength,
collect(DISTINCT s.id) AS statement_ids,
collect(DISTINCT c.id) AS chunk_ids,
@@ -1002,3 +1004,58 @@ RETURN DISTINCT
x.statement as statement,x.created_at as created_at
"""
Graph_Node_query = """
MATCH (n:MemorySummary)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) AS id,
labels(n) AS labels,
properties(n) AS properties,
0 AS priority
LIMIT $limit
UNION ALL
MATCH (n:Dialogue)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) AS id,
labels(n) AS labels,
properties(n) AS properties,
1 AS priority
LIMIT 1
UNION ALL
MATCH (n:Statement)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) AS id,
labels(n) AS labels,
properties(n) AS properties,
1 AS priority
LIMIT $limit
UNION ALL
MATCH (n:ExtractedEntity)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) AS id,
labels(n) AS labels,
properties(n) AS properties,
2 AS priority
LIMIT $limit
UNION ALL
MATCH (n:Chunk)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) AS id,
labels(n) AS labels,
properties(n) AS properties,
3 AS priority
LIMIT $limit
"""

View File

@@ -21,7 +21,8 @@ from app.core.memory.models.graph_models import (
ExtractedEntityNode,
EntityEntityEdge,
)
import logging
logger = logging.getLogger(__name__)
async def save_entities_and_relationships(
entity_nodes: List[ExtractedEntityNode],
entity_entity_edges: List[EntityEntityEdge],
@@ -41,8 +42,8 @@ async def save_entities_and_relationships(
'statement': edge.statement,
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
'created_at': edge.created_at.isoformat(),
'expired_at': edge.expired_at.isoformat(),
'created_at': edge.created_at.isoformat() if edge.created_at else None,
'expired_at': edge.expired_at.isoformat() if edge.expired_at else None,
'run_id': edge.run_id,
'end_user_id': edge.end_user_id,
}
@@ -147,14 +148,14 @@ async def save_statement_entity_edges(
async def save_dialog_and_statements_to_neo4j(
dialogue_nodes: List[DialogueNode],
chunk_nodes: List[ChunkNode],
statement_nodes: List[StatementNode],
entity_nodes: List[ExtractedEntityNode],
entity_edges: List[EntityEntityEdge],
statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge],
connector: Neo4jConnector
dialogue_nodes: List[DialogueNode],
chunk_nodes: List[ChunkNode],
statement_nodes: List[StatementNode],
entity_nodes: List[ExtractedEntityNode],
entity_edges: List[EntityEntityEdge],
statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge],
connector: Neo4jConnector
) -> bool:
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
@@ -171,40 +172,126 @@ async def save_dialog_and_statements_to_neo4j(
Returns:
bool: True if successful, False otherwise
"""
try:
# Save all dialogue nodes in batch
dialogue_uuids = await add_dialogue_nodes(dialogue_nodes, connector)
if dialogue_uuids:
# 定义事务函数,将所有写操作放在一个事务中
async def _save_all_in_transaction(tx):
"""在单个事务中执行所有保存操作,避免死锁"""
results = {}
# 1. Save all dialogue nodes in batch
if dialogue_nodes:
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE
dialogue_data = [node.model_dump() for node in dialogue_nodes]
result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data)
dialogue_uuids = [record["uuid"] async for record in result]
results['dialogues'] = dialogue_uuids
print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
else:
print("Failed to save dialogues to Neo4j")
return False
# Save all chunk nodes in batch
await save_chunk_nodes(chunk_nodes, connector)
# 2. Save all chunk nodes in batch
if chunk_nodes:
from app.repositories.neo4j.cypher_queries import CHUNK_NODE_SAVE
chunk_data = [node.model_dump() for node in chunk_nodes]
result = await tx.run(CHUNK_NODE_SAVE, chunks=chunk_data)
chunk_uuids = [record["uuid"] async for record in result]
results['chunks'] = chunk_uuids
logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j")
# Save all statement nodes in batch
# 3. Save all statement nodes in batch
if statement_nodes:
statement_uuids = await add_statement_nodes(statement_nodes, connector)
if statement_uuids:
print(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j")
else:
print("Failed to save statement nodes to Neo4j")
return False
else:
print("No statement nodes to save")
from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE
statement_data = [node.model_dump() for node in statement_nodes]
result = await tx.run(STATEMENT_NODE_SAVE, statements=statement_data)
statement_uuids = [record["uuid"] async for record in result]
results['statements'] = statement_uuids
logger.info(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j")
# Save entities and relationships
await save_entities_and_relationships(entity_nodes, entity_edges, connector)
print("Successfully saved entities and relationships to Neo4j")
# 4. Save entities
if entity_nodes:
from app.repositories.neo4j.cypher_queries import EXTRACTED_ENTITY_NODE_SAVE
entity_data = [entity.model_dump() for entity in entity_nodes]
result = await tx.run(EXTRACTED_ENTITY_NODE_SAVE, entities=entity_data)
entity_uuids = [record["uuid"] async for record in result]
results['entities'] = entity_uuids
logger.info(f"Successfully saved {len(entity_uuids)} entity nodes to Neo4j")
# Save new edges
await save_statement_chunk_edges(statement_chunk_edges, connector)
await save_statement_entity_edges(statement_entity_edges, connector)
# 5. Create entity relationships
if entity_edges:
from app.repositories.neo4j.cypher_queries import ENTITY_RELATIONSHIP_SAVE
relationship_data = []
for edge in entity_edges:
relationship_data.append({
'source_id': edge.source,
'target_id': edge.target,
'predicate': edge.relation_type,
'statement_id': edge.source_statement_id,
'value': edge.relation_value,
'statement': edge.statement,
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
'created_at': edge.created_at.isoformat() if edge.created_at else None,
'expired_at': edge.expired_at.isoformat() if edge.expired_at else None,
'run_id': edge.run_id,
'end_user_id': edge.end_user_id,
})
result = await tx.run(ENTITY_RELATIONSHIP_SAVE, relationships=relationship_data)
rel_uuids = [record["uuid"] async for record in result]
results['entity_relationships'] = rel_uuids
logger.info(f"Successfully saved {len(rel_uuids)} entity relationships to Neo4j")
# 6. Save statement-chunk edges
if statement_chunk_edges:
from app.repositories.neo4j.cypher_queries import CHUNK_STATEMENT_EDGE_SAVE
sc_edge_data = []
for edge in statement_chunk_edges:
sc_edge_data.append({
"id": edge.id,
"source": edge.source,
"target": edge.target,
"created_at": edge.created_at.isoformat() if edge.created_at else None,
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
"run_id": edge.run_id,
"end_user_id": edge.end_user_id,
})
result = await tx.run(CHUNK_STATEMENT_EDGE_SAVE, chunk_statement_edges=sc_edge_data)
sc_uuids = [record["uuid"] async for record in result]
results['statement_chunk_edges'] = sc_uuids
logger.info(f"Successfully saved {len(sc_uuids)} statement-chunk edges to Neo4j")
# 7. Save statement-entity edges
if statement_entity_edges:
from app.repositories.neo4j.cypher_queries import STATEMENT_ENTITY_EDGE_SAVE
se_edge_data = []
for edge in statement_entity_edges:
se_edge_data.append({
"source": edge.source,
"target": edge.target,
"created_at": edge.created_at.isoformat() if edge.created_at else None,
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
"run_id": edge.run_id,
"end_user_id": edge.end_user_id,
"connect_strength": getattr(edge, "connect_strength", "strong"),
})
result = await tx.run(STATEMENT_ENTITY_EDGE_SAVE, relationships=se_edge_data)
se_uuids = [record["uuid"] async for record in result]
results['statement_entity_edges'] = se_uuids
logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j")
return results
try:
# 使用显式写事务执行所有操作,避免死锁
results = await connector.execute_write_transaction(_save_all_in_transaction)
summary = {
key: len(value)
for key, value in results.items()
if isinstance(value, (list, tuple, set))
}
logger.info("Transaction completed. Summary: %s", summary)
logger.debug("Full transaction results: %r", results)
return True
except Exception as e:
logger.error(f"Neo4j integration error: {e}", exc_info=True)
print(f"Neo4j integration error: {e}")
print("Continuing without database storage...")
return False
return False

View File

@@ -392,3 +392,48 @@ class OntologySceneRepository:
exc_info=True
)
raise
def get_simple_list(self, workspace_id: UUID) -> List[dict]:
"""获取场景简单列表仅包含scene_id和scene_name用于下拉选择
这是一个轻量级查询不加载关联的classes响应速度快。
Args:
workspace_id: 工作空间ID
Returns:
List[dict]: 场景简单列表每项包含scene_id和scene_name
Examples:
>>> repo = OntologySceneRepository(db)
>>> scenes = repo.get_simple_list(workspace_id)
>>> # [{"scene_id": "xxx", "scene_name": "场景1"}, ...]
"""
try:
logger.debug(f"Getting simple scene list for workspace: {workspace_id}")
# 只查询需要的字段,不加载关联数据
results = self.db.query(
OntologyScene.scene_id,
OntologyScene.scene_name
).filter(
OntologyScene.workspace_id == workspace_id
).order_by(
OntologyScene.updated_at.desc()
).all()
scenes = [
{"scene_id": str(r.scene_id), "scene_name": r.scene_name}
for r in results
]
logger.info(f"Found {len(scenes)} scenes (simple list) in workspace {workspace_id}")
return scenes
except Exception as e:
logger.error(
f"Failed to get simple scene list: {str(e)}",
exc_info=True
)
raise

View File

@@ -10,6 +10,8 @@ class FileBase(BaseModel):
file_name: str
file_ext: str
file_size: int
file_url: str | None = None
created_at: datetime.datetime | None = None
class FileCreate(FileBase):
@@ -26,6 +28,7 @@ class FileUpdate(BaseModel):
file_name: str | None = Field(None)
file_ext: str | None = Field(None)
file_size: str | None = Field(None)
file_url: str | None = Field(None)
class File(FileBase):

View File

@@ -1,3 +1,4 @@
from abc import ABC
from typing import Optional
from pydantic import BaseModel
@@ -14,4 +15,15 @@ class UserInput(BaseModel):
class Write_UserInput(BaseModel):
messages: list[dict]
end_user_id: str
config_id: Optional[str] = None
config_id: Optional[str] = None
class AgentMemory_Long_Term(ABC):
"""长期记忆配置常量"""
STORAGE_NEO4J = "neo4j"
STORAGE_RAG = "rag"
STRATEGY_AGGREGATE = "aggregate"
STRATEGY_CHUNK = "chunk"
STRATEGY_TIME = "time"
DEFAULT_SCOPE = 6

View File

@@ -248,8 +248,9 @@ class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
config_id: Union[uuid.UUID, int, str] = None
config_name: str = Field("配置名称", description="配置名称(字符串)")
config_desc: str = Field("配置描述", description="配置描述(字符串)")
config_name: Optional[str] = Field(None, description="配置名称(字符串)")
config_desc: Optional[str] = Field(None, description="配置描述(字符串)")
scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID")
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型

View File

@@ -964,8 +964,15 @@ class AppService:
).order_by(
AgentConfig.updated_at.desc()
)
config = self.db.scalars(stmt).first()
try:
config_memory=config.memory
if 'memory_content' in config_memory:
config.memory['memory_config_id'] = config.memory.pop('memory_content')
except:
logger.debug("记忆配置不存在")
if config:
return config

View File

@@ -114,6 +114,8 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
result = task_service.get_task_memory_read_result(task.id)
status = result.get("status")
logger.info(f"读取任务状态:{status}")
if memory_content:
memory_content = memory_content['answer']
finally:
db.close()
@@ -127,11 +129,6 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
"content_length": len(str(memory_content))
}
)
# 检查是否有有效内容
if not memory_content or str(memory_content).strip() == "" or "answer" in str(memory_content) and str(memory_content).count("''") > 0:
return "未找到相关的历史记忆。请直接回答用户的问题,不要再次调用此工具。"
return f"检索到以下历史记忆:\n\n{memory_content}"
except Exception as e:
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})

View File

@@ -183,11 +183,11 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# --- Read All ---
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
configs = MemoryConfigRepository.get_all(self.db, workspace_id)
results = MemoryConfigRepository.get_all(self.db, workspace_id)
# 将 ORM 对象转换为字典列表
data_list = []
for config in configs:
for config, scene_name in results:
# 安全地转换 user_id 为 int
config_id_old = None
if config.config_id_old:
@@ -209,7 +209,8 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"end_user_id": config.end_user_id,
"config_id_old": config_id_old,
"apply_id": config.apply_id,
"scene_id": config.scene_id,
"scene_id": str(config.scene_id) if config.scene_id else None,
"scene_name": scene_name, # 新增:场景名称
"llm_id": config.llm_id,
"embedding_id": config.embedding_id,
"rerank_id": config.rerank_id,
@@ -637,10 +638,9 @@ async def analytics_recent_activity_stats() -> Dict[str, Any]:
if m < 1:
latest_relative = "刚刚"
elif m < 60:
latest_relative = f"{m}分钟"
latest_relative = "一会"
else:
h = int(m // 60)
latest_relative = f"{h}小时前" if h < 24 else f"{int(h // 24)}天前"
latest_relative = "较早前"
except Exception:
pass

View File

@@ -15,6 +15,7 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.conversation_repository import ConversationRepository
from app.repositories.end_user_repository import EndUserRepository
from app.repositories.neo4j.cypher_queries import Graph_Node_query
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
from app.services.implicit_memory_service import ImplicitMemoryService
@@ -1525,7 +1526,6 @@ async def analytics_graph_data(
user_uuid = uuid.UUID(end_user_id)
repo = EndUserRepository(db)
end_user = repo.get_by_id(user_uuid)
if not end_user:
logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户")
return {
@@ -1579,21 +1579,11 @@ async def analytics_graph_data(
}
else:
# 查询所有节点
node_query = """
MATCH (n)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) as id,
labels(n)[0] as label,
properties(n) as properties
LIMIT $limit
"""
node_query=Graph_Node_query
node_params = {
"end_user_id": end_user_id,
"limit": limit
}
# 执行节点查询
node_results = await _neo4j_connector.execute_query(node_query, **node_params)
@@ -1604,9 +1594,9 @@ async def analytics_graph_data(
for record in node_results:
node_id = record["id"]
node_label = record["label"]
node_labels = record.get("labels", [])
node_label = node_labels[0] if node_labels else "Unknown"
node_props = record["properties"]
# 根据节点类型提取需要的属性字段
filtered_props = await _extract_node_properties(node_label, node_props,node_id)

View File

@@ -7,6 +7,8 @@ import uuid
from uuid import UUID
from datetime import datetime, timezone
from math import ceil
from pathlib import Path
import shutil
from typing import Any, Dict, List, Optional
import redis
@@ -16,8 +18,13 @@ import trio
# Import a unified Celery instance
from app.celery_app import celery_app
from app.core.config import settings
from app.core.rag.crawler.web_crawler import WebCrawler
from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
from app.core.rag.integrations.feishu.client import FeishuAPIClient
from app.core.rag.integrations.feishu.models import FileInfo
from app.core.rag.integrations.yuque.client import YuqueAPIClient
from app.core.rag.integrations.yuque.models import YuqueDocInfo
from app.core.rag.llm.chat_model import Base
from app.core.rag.llm.cv_model import QWenCV
from app.core.rag.llm.embedding_model import OpenAIEmbed
@@ -29,7 +36,9 @@ from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
)
from app.db import get_db, get_db_context
from app.models.document_model import Document
from app.models.file_model import File
from app.models.knowledge_model import Knowledge
from app.schemas import file_schema, document_schema
from app.services.memory_agent_service import MemoryAgentService
@@ -382,6 +391,480 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
db.close()
@celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb")
def sync_knowledge_for_kb(kb_id: uuid.UUID):
"""
sync knowledge document and Document parsing, vectorization, and storage
"""
db = next(get_db()) # Manually call the generator
db_knowledge = None
try:
db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first()
# 1. get vector_service
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
# 2. sync data
match db_knowledge.type:
case "Web": # Crawl webpages in batches through a web crawler
entry_url = db_knowledge.parser_config.get("entry_url", "")
max_pages = db_knowledge.parser_config.get("max_pages", 20)
delay_seconds = db_knowledge.parser_config.get("delay_seconds", 1.0)
timeout_seconds = db_knowledge.parser_config.get("timeout_seconds", 10)
user_agent = db_knowledge.parser_config.get("user_agent", "KnowledgeBaseCrawler/1.0")
# Create crawler
crawler = WebCrawler(
entry_url=entry_url,
max_pages=max_pages,
delay_seconds=delay_seconds,
timeout_seconds=timeout_seconds,
user_agent=user_agent
)
try:
# 初始化存储已爬取 URLs 的集合
file_urls = set()
# crawl entry_url by yield
for crawled_document in crawler.crawl():
file_urls.add(crawled_document.url)
db_file = db.query(File).filter(File.kb_id == db_knowledge.id,
File.file_url == crawled_document.url).first()
if db_file:
if db_file.file_size == crawled_document.content_length: # same
continue
else: # --update
if crawled_document.content_length:
# 1. update file
db_file.file_name = f"{crawled_document.title}.txt"
db_file.file_ext=".txt"
db_file.file_size=crawled_document.content_length
db.commit()
db.refresh(db_file)
# Construct a save path/files/{kb_id}/{parent_id}/{file.id}{file_extension}
save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id))
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}")
# update file
if os.path.exists(save_path):
os.remove(save_path) # Delete a single file
content_bytes = crawled_document.content.encode('utf-8')
with open(save_path, "wb") as f:
f.write(content_bytes)
# 2. update a document
db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id,
Document.file_id == db_file.id).first()
if db_document:
db_document.file_name = db_file.file_name
db_document.file_ext = db_file.file_ext
db_document.file_size = db_file.file_size
db_document.updated_at = datetime.now()
db.commit()
db.refresh(db_document)
# 3. Document parsing, vectorization, and storage
parse_document(file_path=save_path, document_id=db_document.id)
else: # --add
if crawled_document.content_length:
# 1. upload file
upload_file = file_schema.FileCreate(
kb_id=db_knowledge.id,
created_by=db_knowledge.created_by,
parent_id=db_knowledge.id,
file_name=f"{crawled_document.title}.txt",
file_ext=".txt",
file_size=crawled_document.content_length,
file_url=crawled_document.url,
)
db_file = File(**upload_file.model_dump())
db.add(db_file)
db.commit()
# Construct a save path/files/{kb_id}/{parent_id}/{file.id}{file_extension}
save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.id))
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}")
# Save file
content_bytes = crawled_document.content.encode('utf-8')
with open(save_path, "wb") as f:
f.write(content_bytes)
# 2. Create a document
create_document_data = document_schema.DocumentCreate(
kb_id=db_knowledge.id,
created_by=db_knowledge.created_by,
file_id=db_file.id,
file_name=db_file.file_name,
file_ext=db_file.file_ext,
file_size=db_file.file_size,
file_meta={},
parser_id="naive",
parser_config={
"layout_recognize": "DeepDOC",
"chunk_token_num": 128,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": "false"
}
)
db_document = Document(**create_document_data.model_dump())
db.add(db_document)
db.commit()
# 3. Document parsing, vectorization, and storage
parse_document(file_path=save_path, document_id=db_document.id)
db_files = db.query(File).filter(File.kb_id == db_knowledge.id, File.file_url.notin_(file_urls)).all()
if db_files: # --delete
for db_file in db_files:
db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id,
Document.file_id == db_file.id).first()
if db_document:
# 1. Delete vector index
vector_service.delete_by_metadata_field(key="document_id", value=str(db_document.id))
# 2. Delete document
db.delete(db_document)
# 3. Delete file
file_path = Path(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
if file_path.exists():
file_path.unlink() # Delete a single file
db.delete(db_file)
# commit transaction
db.commit()
except Exception as e:
print(f"\n\nError during crawl: {e}")
case "Third-party": # Integration of knowledge bases from three parties
yuque_user_id = db_knowledge.parser_config.get("yuque_user_id", "")
feishu_app_id = db_knowledge.parser_config.get("feishu_app_id", "")
if yuque_user_id: # Yuque Knowledge Base
yuque_token = db_knowledge.parser_config.get("yuque_token", "")
# Create yuqueAPIClient
api_client = YuqueAPIClient(
user_id=yuque_user_id,
token=yuque_token
)
try:
# 初始化存储获取语雀 URLs 的集合
file_urls = set()
# Get all files from all repos
async def async_get_files(api_client: YuqueAPIClient):
async with api_client as client:
print("\n=== Fetching repositories ===")
repos = await client.get_user_repos()
print(f"Found {len(repos)} repositories:")
all_files = []
for repo in repos:
# Get documents from repository
print(f"\n=== Fetching documents from '{repo.name}' ===")
docs = await client.get_repo_docs(repo.id)
all_files.extend(docs)
return all_files
files = asyncio.run(async_get_files(api_client))
for doc in files:
file_urls.add(doc.slug)
db_file = db.query(File).filter(File.kb_id == db_knowledge.id,
File.file_url == doc.slug).first()
if db_file:
if db_file.created_at == doc.updated_at: # same
continue
else: # --update
# 1. update file
# Construct a save path/files/{kb_id}/{parent_id}/{file.id}{file_extension}
save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id))
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
# download document from Feishu FileInfo
async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, save_dir: str):
async with api_client as client:
file_path = await client.download_document(doc, save_dir)
return file_path
file_path = asyncio.run(async_download_document(api_client, doc, save_dir))
save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}")
# update file
if os.path.exists(save_path):
os.remove(save_path) # Delete a single file
shutil.copyfile(file_path, save_path)
# update db_file
file_name = os.path.basename(file_path)
_, file_extension = os.path.splitext(file_name)
file_size = os.path.getsize(file_path)
db_file.file_name = file_name
db_file.file_ext = file_extension.lower()
db_file.file_size = file_size
db_file.created_at = doc.updated_at
db.commit()
db.refresh(db_file)
# 2. update a document
db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id,
Document.file_id == db_file.id).first()
if db_document:
db_document.file_name = db_file.file_name
db_document.file_ext = db_file.file_ext
db_document.file_size = db_file.file_size
db_document.created_at = db_file.created_at
db_document.updated_at = datetime.now()
db.commit()
db.refresh(db_document)
# 3. Document parsing, vectorization, and storage
parse_document(file_path=save_path, document_id=db_document.id)
else: # --add
# 1. update file
# Construct a save path/files/{kb_id}/{parent_id}/{file.id}{file_extension}
save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id))
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
# download document from Feishu FileInfo
async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, save_dir: str):
async with api_client as client:
file_path = await client.download_document(doc, save_dir)
return file_path
file_path = asyncio.run(async_download_document(api_client, doc, save_dir))
# add db_file
file_name = os.path.basename(file_path)
_, file_extension = os.path.splitext(file_name)
file_size = os.path.getsize(file_path)
upload_file = file_schema.FileCreate(
kb_id=db_knowledge.id,
created_by=db_knowledge.created_by,
parent_id=db_knowledge.id,
file_name=file_name,
file_ext=file_extension.lower(),
file_size=file_size,
file_url=doc.slug,
created_at=doc.updated_at
)
db_file = File(**upload_file.model_dump())
db.add(db_file)
db.commit()
# Save file
save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}")
# update file
if os.path.exists(save_path):
os.remove(save_path) # Delete a single file
shutil.copyfile(file_path, save_path)
# 2. Create a document
create_document_data = document_schema.DocumentCreate(
kb_id=db_knowledge.id,
created_by=db_knowledge.created_by,
file_id=db_file.id,
file_name=db_file.file_name,
file_ext=db_file.file_ext,
file_size=db_file.file_size,
file_meta={},
parser_id="naive",
parser_config={
"layout_recognize": "DeepDOC",
"chunk_token_num": 128,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": "false"
}
)
db_document = Document(**create_document_data.model_dump())
db.add(db_document)
db.commit()
# 3. Document parsing, vectorization, and storage
parse_document(file_path=save_path, document_id=db_document.id)
db_files = db.query(File).filter(File.kb_id == db_knowledge.id,
File.file_url.notin_(file_urls)).all()
if db_files: # --delete
for db_file in db_files:
db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id,
Document.file_id == db_file.id).first()
if db_document:
# 1. Delete vector index
vector_service.delete_by_metadata_field(key="document_id",
value=str(db_document.id))
# 2. Delete document
db.delete(db_document)
# 3. Delete file
file_path = Path(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
if file_path.exists():
file_path.unlink() # Delete a single file
db.delete(db_file)
# commit transaction
db.commit()
except Exception as e:
print(f"\n\nError during fetch feishu: {e}")
if feishu_app_id: # Feishu Knowledge Base
feishu_app_secret = db_knowledge.parser_config.get("feishu_app_secret", "")
feishu_folder_token = db_knowledge.parser_config.get("feishu_folder_token", "")
# Create feishuAPIClient
api_client = FeishuAPIClient(
app_id=feishu_app_id,
app_secret=feishu_app_secret
)
try:
# 初始化存储获取飞书 URLs 的集合
file_urls = set()
# Get all files from folder
async def async_get_files(api_client: FeishuAPIClient, feishu_folder_token: str):
async with api_client as client:
files = await client.list_all_folder_files(feishu_folder_token, recursive=True)
return files
files = asyncio.run(async_get_files(api_client, feishu_folder_token))
# Filter out folders, only sync documents
documents = [f for f in files if f.type in ["doc", "docx", "sheet", "bitable", "file"]]
for doc in documents:
file_urls.add(doc.url)
db_file = db.query(File).filter(File.kb_id == db_knowledge.id,
File.file_url == doc.url).first()
if db_file:
if db_file.created_at == doc.modified_time: # same
continue
else: # --update
# 1. update file
# Construct a save path/files/{kb_id}/{parent_id}/{file.id}{file_extension}
save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id),
str(db_knowledge.parent_id))
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
# download document from Feishu FileInfo
async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, save_dir: str):
async with api_client as client:
file_path = await client.download_document(document=doc, save_dir=save_dir)
return file_path
file_path = asyncio.run(async_download_document(api_client, doc, save_dir))
save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}")
# update file
if os.path.exists(save_path):
os.remove(save_path) # Delete a single file
shutil.copyfile(file_path, save_path)
# update db_file
file_name = os.path.basename(file_path)
_, file_extension = os.path.splitext(file_name)
file_size = os.path.getsize(file_path)
db_file.file_name = file_name
db_file.file_ext = file_extension.lower()
db_file.file_size = file_size
db_file.created_at = doc.modified_time
db.commit()
db.refresh(db_file)
# 2. update a document
db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id,
Document.file_id == db_file.id).first()
if db_document:
db_document.file_name = db_file.file_name
db_document.file_ext = db_file.file_ext
db_document.file_size = db_file.file_size
db_document.created_at = db_file.created_at
db_document.updated_at = datetime.now()
db.commit()
db.refresh(db_document)
# 3. Document parsing, vectorization, and storage
parse_document(file_path=save_path, document_id=db_document.id)
else: # --add
# 1. update file
# Construct a save path/files/{kb_id}/{parent_id}/{file.id}{file_extension}
save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id),
str(db_knowledge.parent_id))
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
# download document from Feishu FileInfo
async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, save_dir: str):
async with api_client as client:
file_path = await client.download_document(document=doc, save_dir=save_dir)
return file_path
file_path = asyncio.run(async_download_document(api_client, doc, save_dir))
# add db_file
file_name = os.path.basename(file_path)
_, file_extension = os.path.splitext(file_name)
file_size = os.path.getsize(file_path)
upload_file = file_schema.FileCreate(
kb_id=db_knowledge.id,
created_by=db_knowledge.created_by,
parent_id=db_knowledge.id,
file_name=file_name,
file_ext=file_extension.lower(),
file_size=file_size,
file_url=doc.url,
created_at = doc.modified_time
)
db_file = File(**upload_file.model_dump())
db.add(db_file)
db.commit()
# Save file
save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}")
# update file
if os.path.exists(save_path):
os.remove(save_path) # Delete a single file
shutil.copyfile(file_path, save_path)
# 2. Create a document
create_document_data = document_schema.DocumentCreate(
kb_id=db_knowledge.id,
created_by=db_knowledge.created_by,
file_id=db_file.id,
file_name=db_file.file_name,
file_ext=db_file.file_ext,
file_size=db_file.file_size,
file_meta={},
parser_id="naive",
parser_config={
"layout_recognize": "DeepDOC",
"chunk_token_num": 128,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": "false"
}
)
db_document = Document(**create_document_data.model_dump())
db.add(db_document)
db.commit()
# 3. Document parsing, vectorization, and storage
parse_document(file_path=save_path, document_id=db_document.id)
db_files = db.query(File).filter(File.kb_id == db_knowledge.id,
File.file_url.notin_(file_urls)).all()
if db_files: # --delete
for db_file in db_files:
db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id,
Document.file_id == db_file.id).first()
if db_document:
# 1. Delete vector index
vector_service.delete_by_metadata_field(key="document_id",
value=str(db_document.id))
# 2. Delete document
db.delete(db_document)
# 3. Delete file
file_path = Path(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
if file_path.exists():
file_path.unlink() # Delete a single file
db.delete(db_file)
# commit transaction
db.commit()
except Exception as e:
print(f"\n\nError during fetch feishu: {e}")
case _: # General
print(f"General: No synchronization needed\n")
result = f"sync knowledge '{db_knowledge.name}' processed successfully."
return result
except Exception as e:
if 'db_knowledge' in locals():
print(f"Failed to sync knowledge:{str(e)}\n")
result = f"sync knowledge '{db_knowledge.name}' failed."
return result
finally:
db.close()
@celery_app.task(name="app.core.memory.agent.read_message", bind=True)
def read_message_task(self, end_user_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str, storage_type:str, user_rag_memory_id:str) -> Dict[str, Any]:

View File

@@ -5,42 +5,68 @@ Shared utilities for configuration handling to avoid circular imports.
"""
from uuid import UUID
from sqlalchemy.orm import Session
import uuid as uuid_module
def resolve_config_id(config_id: UUID | int|str, db: Session) -> UUID:
def resolve_config_id(config_id: UUID | int | str, db: Session) -> UUID:
"""
解析 config_id如果是整数则通过 config_id_old 查找对应的 UUID
解析 config_id支持 UUID、UUID字符串、整数等多种格式
Args:
config_id: 配置IDUUID 或整数)
config_id: 配置IDUUID、UUID字符串 整数)
db: 数据库会话
Returns:
UUID: 解析后的配置ID
Raises:
ValueError: 当找不到对应的配置时
ValueError: 当找不到对应的配置时或格式无效时
"""
from app.models.memory_config_model import MemoryConfig
if isinstance(config_id, UUID):
# 1. 如果已经是 UUID 类型,直接返回
if isinstance(config_id, UUID):
return config_id
if isinstance(config_id, str) and len(config_id)<=6:
memory_config = db.query(MemoryConfig).filter(
MemoryConfig.config_id_old == int(config_id)
).first()
print(memory_config)
if not memory_config:
raise ValueError(f"STR 未找到 config_id_old={config_id} 对应的配置")
return memory_config.config_id
# 2. 如果是字符串类型
if isinstance(config_id, str):
config_id_stripped = config_id.strip()
# 2.1 尝试解析为 UUID标准 UUID 字符串长度为 36
try:
return uuid_module.UUID(config_id_stripped)
except ValueError:
pass
# 2.2 尝试解析为整数(用于查询 config_id_old
try:
old_id = int(config_id_stripped)
if old_id > 0:
memory_config = db.query(MemoryConfig).filter(
MemoryConfig.config_id_old == old_id
).first()
if not memory_config:
raise ValueError(f"未找到 config_id_old={old_id} 对应的配置")
return memory_config.config_id
except ValueError:
pass
# 2.3 无法解析的字符串格式
raise ValueError(f"无效的 config_id 格式: '{config_id}'(必须是 UUID 或正整数)")
# 3. 如果是整数类型,通过 config_id_old 查找
if isinstance(config_id, int):
if config_id <= 0:
raise ValueError(f"config_id 必须是正整数: {config_id}")
memory_config = db.query(MemoryConfig).filter(
MemoryConfig.config_id_old == config_id
).first()
if not memory_config:
raise ValueError(f"INT 未找到 config_id_old={config_id} 对应的配置")
raise ValueError(f"未找到 config_id_old={config_id} 对应的配置")
return memory_config.config_id
return config_id
# 4. 不支持的类型
raise ValueError(f"不支持的 config_id 类型: {type(config_id).__name__}")

View File

@@ -0,0 +1,32 @@
"""202602061233
Revision ID: ef0787b85c35
Revises: 9b28b66cf8e8
Create Date: 2026-02-06 12:33:26.114673
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = 'ef0787b85c35'
down_revision: Union[str, None] = '9b28b66cf8e8'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('files', sa.Column('file_url', sa.String(), nullable=True, comment='file comes from a website url'))
op.create_index(op.f('ix_files_file_url'), 'files', ['file_url'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_files_file_url'), table_name='files')
op.drop_column('files', 'file_url')
# ### end Alembic commands ###

View File

@@ -141,6 +141,8 @@ dependencies = [
"flower>=2.0.1",
"aiofiles>=23.0.0",
"owlready2>=0.46",
"lxml>=4.9.0",
"httpx>=0.28.0",
]
[tool.pytest.ini_options]

View File

@@ -134,3 +134,5 @@ xlrd==2.0.2
oss2>=2.18.0
boto3>=1.28.0
aiofiles>=23.0.0
lxml>=4.9.0
httpx>=0.28.0

4
api/uv.lock generated
View File

@@ -3224,6 +3224,7 @@ dependencies = [
{ name = "hanziconv" },
{ name = "html5lib" },
{ name = "httptools" },
{ name = "httpx" },
{ name = "huggingface-hub" },
{ name = "idna" },
{ name = "jieba" },
@@ -3237,6 +3238,7 @@ dependencies = [
{ name = "langchain-ollama" },
{ name = "langchain-openai" },
{ name = "langfuse" },
{ name = "lxml" },
{ name = "mako" },
{ name = "mammoth" },
{ name = "markdown" },
@@ -3361,6 +3363,7 @@ requires-dist = [
{ name = "hanziconv", specifier = "==0.3.2" },
{ name = "html5lib", specifier = "==1.1" },
{ name = "httptools", specifier = "==0.7.1" },
{ name = "httpx", specifier = ">=0.28.0" },
{ name = "huggingface-hub", specifier = "==0.25.2" },
{ name = "idna", specifier = "==3.11" },
{ name = "jieba", specifier = ">=0.42.1" },
@@ -3375,6 +3378,7 @@ requires-dist = [
{ name = "langchain-ollama" },
{ name = "langchain-openai", specifier = ">=1.0.2" },
{ name = "langfuse", specifier = ">=3.10.0" },
{ name = "lxml", specifier = ">=4.9.0" },
{ name = "mako", specifier = "==1.3.10" },
{ name = "mammoth", specifier = "==1.11.0" },
{ name = "markdown", specifier = "==3.8" },

View File

@@ -13,6 +13,14 @@
"@antv/layout": "^1.2.14-beta.8",
"@antv/x6": "^3.0.1",
"@antv/x6-react-shape": "^3.0.1",
"@codemirror/lang-cpp": "^6.0.3",
"@codemirror/lang-java": "^6.0.2",
"@codemirror/lang-javascript": "^6.2.4",
"@codemirror/lang-python": "^6.2.1",
"@codemirror/lang-rust": "^6.0.2",
"@codemirror/state": "^6.5.4",
"@codemirror/theme-one-dark": "^6.1.3",
"@codemirror/view": "^6.39.12",
"@dnd-kit/core": "^6.3.1",
"@dnd-kit/modifiers": "^9.0.0",
"@dnd-kit/sortable": "^10.0.0",
@@ -25,6 +33,7 @@
"antd": "^5.27.4",
"axios": "^1.12.2",
"clsx": "^2.1.1",
"codemirror": "^6.0.2",
"copy-to-clipboard": "^3.3.3",
"crypto-js": "^4.2.0",
"dayjs": "^1.11.18",
@@ -55,6 +64,7 @@
"@tailwindcss/postcss": "^4.1.14",
"@tailwindcss/typography": "^0.5.19",
"@tailwindcss/vite": "^4.1.14",
"@types/codemirror": "^5.60.17",
"@types/crypto-js": "^4.2.2",
"@types/js-yaml": "^4.0.9",
"@types/node": "^24.6.0",

View File

@@ -256,7 +256,7 @@ export const updateMemoryExtractionConfig = (values: ExtractionConfigForm) => {
return request.post('/memory-storage/update_config_extracted', values)
}
// Memory Extraction Engine - Pilot run
export const pilotRunMemoryExtractionConfig = (values: { config_id: number | string; dialogue_text: string; }, onMessage?: (data: SSEMessage[]) => void) => {
export const pilotRunMemoryExtractionConfig = (values: { config_id: number | string; dialogue_text: string; custom_text?: string; }, onMessage?: (data: SSEMessage[]) => void) => {
return handleSSE('/memory-storage/pilot_run', values, onMessage)
}
// Emotion Engine - Get configuration

View File

@@ -8,6 +8,7 @@ import { request } from '@/utils/request'
import type { Query, OntologyModalData, OntologyClassModalData, OntologyClassExtractModalData, OntologyExportModalData } from '@/views/Ontology/types'
// Scene list
export const getOntologyScenesSimpleUrl = '/memory/ontology/scenes/simple'
export const getOntologyScenesUrl = '/memory/ontology/scenes'
export const getOntologyScenesList = (data: Query) => {
return request.get(getOntologyScenesUrl, data)

View File

@@ -0,0 +1,150 @@
/*
* @Author: ZhaoYing
* @Date: 2026-02-04 17:20:52
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-04 17:20:52
*/
import { useEffect, useRef, useMemo } from 'react';
import { EditorView, basicSetup } from 'codemirror';
import { EditorState } from '@codemirror/state';
import { python } from '@codemirror/lang-python';
import { javascript } from '@codemirror/lang-javascript';
import { java } from '@codemirror/lang-java';
import { cpp } from '@codemirror/lang-cpp';
import { rust } from '@codemirror/lang-rust';
import { oneDark } from '@codemirror/theme-one-dark';
/**
* Props for the CodeMirrorEditor component
* @property {string} value - The initial code content to display in the editor
* @property {string} language - Programming language for syntax highlighting (python, python3, javascript, typescript, java, cpp, c, rust)
* @property {function} onChange - Callback function triggered when editor content changes, receives the new code value
* @property {string} theme - Editor theme, either 'light' or 'dark'
* @property {boolean} readOnly - Whether the editor is read-only
* @property {string} height - Custom height for the editor
* @property {string} size - Predefined size preset: 'default' (120px min-height, 14px font) or 'small' (60px min-height, 12px font)
*/
interface CodeMirrorEditorProps {
value?: string;
language?: 'python' | 'python3' | 'javascript' | 'typescript' | 'java' | 'cpp' | 'c' | 'rust';
onChange?: (value: string) => void;
theme?: 'light' | 'dark';
readOnly?: boolean;
height?: string;
size?: 'default' | 'small';
}
/**
* Map of language identifiers to their corresponding CodeMirror language extensions
* Supports multiple programming languages with syntax highlighting
*/
const languageExtensions: Record<string, any> = {
python: python(),
python3: python(),
javascript: javascript(),
typescript: javascript({ typescript: true }),
java: java(),
cpp: cpp(),
c: cpp(),
rust: rust(),
};
/**
* CodeMirrorEditor - A React wrapper component for CodeMirror 6 editor
* Provides a code editor with syntax highlighting, theme support, and customizable sizing
* Used in workflow code execution nodes for editing Python and JavaScript code
*/
const CodeMirrorEditor = ({
value = '',
language = 'javascript',
onChange,
theme = 'light',
readOnly = false,
size,
}: CodeMirrorEditorProps) => {
// Reference to the DOM element that will contain the editor
const editorRef = useRef<HTMLDivElement>(null);
// Reference to the CodeMirror EditorView instance
const viewRef = useRef<EditorView | null>(null);
/**
* Initialize CodeMirror editor when component mounts or when language/theme/readOnly changes
* Sets up extensions for syntax highlighting, change listeners, and theme
*/
useEffect(() => {
if (!editorRef.current) return;
// Get the appropriate language extension, fallback to JavaScript if not found
const langExtension = languageExtensions[language] || languageExtensions.javascript;
// Configure editor extensions
const extensions = [
basicSetup, // Basic editor features (line numbers, bracket matching, etc.)
langExtension, // Language-specific syntax highlighting
// Listen for document changes and trigger onChange callback
EditorView.updateListener.of((update) => {
if (update.docChanged && onChange) {
onChange(update.state.doc.toString());
}
}),
EditorState.readOnly.of(readOnly), // Set read-only mode
];
// Apply dark theme if specified
if (theme === 'dark') {
extensions.push(oneDark);
}
// Create editor state with initial value and extensions
const state = EditorState.create({
doc: value,
extensions,
});
// Create and mount the editor view
viewRef.current = new EditorView({
state,
parent: editorRef.current,
});
// Cleanup: destroy editor instance when component unmounts or dependencies change
return () => {
viewRef.current?.destroy();
};
}, [language, theme, readOnly]);
/**
* Update editor content when the value prop changes externally
* Only updates if the new value differs from current editor content
*/
useEffect(() => {
if (viewRef.current && value !== viewRef.current.state.doc.toString()) {
viewRef.current.dispatch({
changes: {
from: 0,
to: viewRef.current.state.doc.length,
insert: value,
},
});
}
}, [value]);
// Calculate minimum height based on size prop: small (60px) or default (120px)
const minHeight = useMemo(() => {
return `${size === 'small' ? 60 : 120}px`
}, [size])
// Calculate font size based on size prop: small (12px) or default (14px)
const fontSize = useMemo(() => {
return `${size === 'small' ? 12 : 14}px`
}, [size])
// Calculate line height based on size prop: small (16px) or default (20px)
const lineHeight = useMemo(() => {
return `${size === 'small' ? 16 : 20}px`
}, [size])
return <div ref={editorRef} style={{ minHeight, fontSize, lineHeight }} />;
};
export default CodeMirrorEditor;

View File

@@ -81,7 +81,7 @@ const components = {
audio: ({ src, ...props }: any) => <AudioBlock node={{ children: [{ properties: { src: src || '' } }] }} {...props} />,
a: ({ href, children, ...props }: any) => <Link href={href || '#'} {...props}>{children}</Link>,
button: ({ children }: any) => <RbButton node={{ children }}>{[children]}</RbButton>,
table: ({ children, ...props }: any) => <table className="rb:border rb:border-[#D9D9D9] rb:mb-2" {...props}>{children}</table>,
table: ({ children, ...props }: any) => <div className="rb:overflow-x-auto rb:max-w-full"><table className="rb:border rb:border-[#D9D9D9] rb:mb-2" {...props}>{children}</table></div>,
tr: ({ children, ...props }: any) => <tr className="rb:border rb:border-[#D9D9D9]" {...props}>{children}</tr>,
th: ({ children, ...props }: any) => <th className="rb:border rb:border-[#D9D9D9] rb:px-2 rb:py-1 rb:text-left rb:font-bold" {...props}>{children}</th>,
td: ({ children, ...props }: any) => <td className="rb:border rb:border-[#D9D9D9] rb:px-2 rb:py-1 rb:text-left" {...props}>{children}</td>,

View File

@@ -1543,7 +1543,8 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
text_preprocessing_desc: 'Text split into {{count}} semantic fragments',
knowledge_extraction_desc: 'Knowledge extraction completed, identified {{entities}} entities, {{statements}} statements, {{temporal_ranges_count}} temporal extractions, {{triplets}} triplets',
creating_nodes_edges_desc: 'Entity relationship creation completed, {{num}} relationships in total',
deduplication_desc: 'Deduplication and disambiguation completed, {{count}} unique entities in total'
deduplication_desc: 'Deduplication and disambiguation completed, {{count}} unique entities in total',
custom_text: 'Debug Text',
},
memoryConversation: {
searchPlaceholder: 'Enter user ID...',

View File

@@ -1617,7 +1617,8 @@ export const zh = {
text_preprocessing_desc: '文本切分为{{count}}个语义片段',
knowledge_extraction_desc: '知识抽取完成,共识别{{entities}}个实体,{{statements}}个句子, {{temporal_ranges_count}}个时间提取, {{triplets}}个三元组',
creating_nodes_edges_desc: '实体关系创建完成,共{{num}}条关系',
deduplication_desc: '去重消歧完成,最终{{count}}个唯一实体'
deduplication_desc: '去重消歧完成,最终{{count}}个唯一实体',
custom_text: '调试文本',
},
memoryConversation: {
chatEmpty:'有什么我可以帮您的吗?',

View File

@@ -180,4 +180,9 @@ body {
.x6-node foreignObject > body {
min-height: 100%;
max-height: 100%;
}
.ͼ2 .cm-gutters {
background-color: #FFFFFF;
border: none;
}

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-03 16:29:21
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-04 20:16:45
* @Last Modified time: 2026-02-06 11:20:14
*/
import { type FC, type ReactNode, useEffect, useRef, useState, forwardRef, useImperativeHandle } from 'react';
import clsx from 'clsx'
@@ -38,8 +38,8 @@ import CustomSelect from '@/components/CustomSelect'
import aiPrompt from '@/assets/images/application/aiPrompt.png'
import AiPromptModal from './components/AiPromptModal'
import ToolList from './components/ToolList/ToolList'
import ChatVariableConfigModal from './components/ChatVariableConfigModal';
import SkillList from './components/Skill'
import ChatVariableConfigModal from './components/ChatVariableConfigModal';
import type { Skill } from '@/views/Skills/types'
/**
@@ -169,7 +169,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
const { skills } = response
let allSkills = Array.isArray(skills?.skill_ids) ? skills?.skill_ids.map(vo => ({ id: vo })) : []
let allTools = Array.isArray(response.tools) ? response.tools : []
const memoryContent = response.memory?.memory_content
const memoryContent = response.memory?.memory_config_id
const parsedMemoryContent = memoryContent === null || memoryContent === ''
? undefined
: !isNaN(Number(memoryContent)) ? Number(memoryContent) : memoryContent
@@ -178,7 +178,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
tools: allTools,
memory: {
...response.memory,
memory_content: parsedMemoryContent
memory_config_id: parsedMemoryContent
},
skills: {
...skills,
@@ -262,7 +262,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
if (!isSave || !data) return Promise.resolve()
const { memory, knowledge_retrieval, tools, skills, ...rest } = values
const { knowledge_bases = [], ...knowledgeRest } = knowledge_retrieval || {}
const { memory_content } = memory || {}
const { memory_config_id } = memory || {}
// Get other necessary properties of memory from original data
const originalMemory = data.memory || ({} as MemoryConfig)
@@ -272,7 +272,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
memory: {
...originalMemory,
...memory,
memory_content: memory_content ? String(memory_content) : '',
memory_config_id: memory_config_id ? String(memory_config_id) : '',
},
knowledge_retrieval: knowledge_bases.length > 0 ? {
...data.knowledge_retrieval,
@@ -444,7 +444,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
<SelectWrapper
title="selectMemoryContent"
desc="selectMemoryContentDesc"
name={['memory', 'memory_content']}
name={['memory', 'memory_config_id']}
url={memoryConfigListUrl}
/>
</Space>

View File

@@ -140,7 +140,7 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi
title={t('application.knowledgeBaseAssociation')}
extra={
<Space>
<Button style={{ padding: '0 8px', height: '24px' }} onClick={handleKnowledgeConfig}>{t('workflow.config.knowledge-retrieval.recallConfig')}</Button>
<Button style={{ padding: '0 8px', height: '24px' }} onClick={handleKnowledgeConfig}>{t('application.globalConfig')}</Button>
<Button style={{ padding: '0 8px', height: '24px' }} onClick={handleAddKnowledge}>+</Button>
</Space>
}

View File

@@ -39,7 +39,7 @@ const processObj = [
* @param value - Current skill configuration values
* @param onChange - Callback function when configuration changes
*/
const Skill: FC<{value?: SkillConfigForm; onChange?: (config: SkillConfigForm) => void}> = () => {
const SkillList: FC<{value?: SkillConfigForm; onChange?: (config: SkillConfigForm) => void}> = () => {
const { t } = useTranslation()
const form = Form.useFormInstance()
const skillConfig = Form.useWatch(['skills'], form)
@@ -148,4 +148,4 @@ const Skill: FC<{value?: SkillConfigForm; onChange?: (config: SkillConfigForm) =
</Card>
)
}
export default Skill
export default SkillList

View File

@@ -43,7 +43,7 @@ export interface MemoryConfig {
/** Whether memory is enabled */
enabled: boolean;
/** Memory content */
memory_content?: string;
memory_config_id?: string;
/** Maximum history length */
max_history?: number | string;
}

View File

@@ -13,7 +13,7 @@
import { type FC, useState } from 'react'
import { useParams } from 'react-router-dom'
import { useTranslation } from 'react-i18next'
import { Space, Button, Progress } from 'antd'
import { Space, Button, Progress, Form, Input } from 'antd'
import { ExclamationCircleFilled, CheckCircleFilled, ClockCircleOutlined, LoadingOutlined } from '@ant-design/icons'
import clsx from 'clsx'
import type { AnyObject } from 'antd/es/_util/type';
@@ -79,6 +79,8 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => {
const [creatingNodesEdges, setCreatingNodesEdges] = useState<ModuleItem>(initObj as ModuleItem)
const [deduplication, setDeduplication] = useState<ModuleItem>(initObj as ModuleItem)
const [runForm] = Form.useForm()
/** Run pilot test */
const handleRun = () => {
if(!id) return
@@ -187,6 +189,7 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => {
pilotRunMemoryExtractionConfig({
config_id: id,
dialogue_text: t('memoryExtractionEngine.exampleText'),
custom_text: runForm.getFieldValue('custom_text')
}, handleStreamMessage)
.finally(() => {
setRunLoading(false)
@@ -222,6 +225,14 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => {
headerClassName="rb:pb-0! rb:pt-4!"
bodyClassName="rb:min-h-[calc(100vh-388px)] rb:p-[16px_20px]!"
>
<Form form={runForm} layout="vertical">
<Form.Item
name="custom_text"
label={t('memoryExtractionEngine.custom_text')}
>
<Input.TextArea placeholder={t('common.pleaseEnter')} />
</Form.Item>
</Form>
<div className="rb:min-h-[calc(100vh-480px)] rb:overflow-y-auto">
{runLoading
? <>

View File

@@ -16,7 +16,7 @@ import { useTranslation } from 'react-i18next';
import type { MemoryFormData, Memory, MemoryFormRef } from '../types';
import RbModal from '@/components/RbModal'
import { createMemoryConfig, updateMemoryConfig } from '@/api/memory'
import { getOntologyScenesUrl } from '@/api/ontology'
import { getOntologyScenesSimpleUrl } from '@/api/ontology'
import CustomSelect from '@/components/CustomSelect';
const FormItem = Form.Item;
@@ -129,8 +129,7 @@ const MemoryForm = forwardRef<MemoryFormRef, MemoryFormProps>(({
>
<CustomSelect
placeholder={t('common.pleaseSelect')}
url={getOntologyScenesUrl}
params={{ pagesize: 100, page: 1 }}
url={getOntologyScenesSimpleUrl}
hasAll={false}
valueKey='scene_id'
labelKey="scene_name"

View File

@@ -112,7 +112,7 @@ const MemoryManagement: React.FC = () => {
title={item.config_name}
>
<Tooltip title={item.config_desc}>
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.25 rb:font-regular rb:-mt-1 rb:wrap-break-word rb:line-clamp-1">{item.config_desc}</div>
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.25 rb:font-regular rb:-mt-1 rb:wrap-break-word rb:line-clamp-1 rb:h-[17px]">{item.config_desc}</div>
</Tooltip>
<RbAlert className="rb:mt-3 ">
<div className={clsx("rb:flex rb:gap-5 rb:font-regular rb:text-[14px]")}>

View File

@@ -103,9 +103,9 @@ const MultiKeyConfigModal = forwardRef<MultiKeyConfigModalRef, MultiKeyConfigMod
{model.api_keys && model.api_keys.length > 0 && (
<div className="rb:mb-4">
{model.api_keys.map((key) => (
<div key={key.id} className="rb:flex rb:items-center rb:justify-between rb:p-3 rb:bg-[#F5F6F7] rb:rounded-lg rb:mb-2">
<div>
<div className="rb:text-[#1D2129] rb:text-[14px] rb:font-medium">{key.api_key}</div>
<div key={key.id} className="rb:flex rb:gap-3 rb:items-center rb:justify-between rb:p-3 rb:bg-[#F5F6F7] rb:rounded-lg rb:mb-2">
<div className="rb:flex-1">
<div className="rb:text-[#1D2129] rb:text-[14px] rb:font-medium rb:break-all">{key.api_key}</div>
<div className="rb:text-[#5B6167] rb:text-[12px] rb:mt-1">{key.api_base}</div>
</div>
<Button type="primary" danger ghost onClick={() => handleDelete(key.id)}>{t('common.remove')}</Button>

View File

@@ -15,8 +15,6 @@ import CharacterCountPlugin from './plugin/CharacterCountPlugin'
import InitialValuePlugin from './plugin/InitialValuePlugin';
import CommandPlugin from './plugin/CommandPlugin';
import Jinja2HighlightPlugin from './plugin/Jinja2HighlightPlugin';
import Python3HighlightPlugin from './plugin/Python3HighlightPlugin';
import JavaScriptHighlightPlugin from './plugin/JavaScriptHighlightPlugin';
import LineNumberPlugin from './plugin/LineNumberPlugin';
import BlurPlugin from './plugin/BlurPlugin';
import { VariableNode } from './nodes/VariableNode'
@@ -32,7 +30,7 @@ export interface LexicalEditorProps {
lineHeight?: number;
size?: 'default' | 'small';
type?: 'input' | 'textarea',
language?: 'string' | 'jinja2' | 'python3' | 'javascript'
language?: 'string' | 'jinja2'
}
const theme = {
@@ -67,7 +65,7 @@ const Editor: FC<LexicalEditorProps> =({
const [enableLineNumbers, setEnableLineNumbers] = useState(false)
useEffect(() => {
const needsLineNumbers = language === 'jinja2' || language === 'python3' || language === 'javascript';
const needsLineNumbers = language === 'jinja2';
setEnableJinja2(language === 'jinja2');
setEnableLineNumbers(needsLineNumbers);
@@ -237,13 +235,11 @@ const Editor: FC<LexicalEditorProps> =({
<HistoryPlugin />
<CommandPlugin />
{language === 'jinja2' && <Jinja2HighlightPlugin />}
{language === 'python3' && <Python3HighlightPlugin />}
{language === 'javascript' && <JavaScriptHighlightPlugin />}
{enableLineNumbers && <LineNumberPlugin />}
<AutocompletePlugin options={options} enableJinja2={enableJinja2} />
<CharacterCountPlugin setCount={(count) => { setCount(count) }} onChange={onChange} />
<InitialValuePlugin key={language} value={value} options={options} enableLineNumbers={enableLineNumbers} />
{enableLineNumbers && <BlurPlugin />}
<InitialValuePlugin value={value} options={options} enableLineNumbers={enableLineNumbers} />
{enableJinja2 && <BlurPlugin />}
</div>
</LexicalComposer>
);

View File

@@ -1,182 +0,0 @@
import { useEffect, useRef } from 'react';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { TextNode, $createTextNode, $getSelection, $isRangeSelection, COMMAND_PRIORITY_LOW, PASTE_COMMAND } from 'lexical';
const JS_KEYWORDS = new Set([
'async', 'await', 'break', 'case', 'catch', 'class', 'const', 'continue', 'debugger', 'default',
'delete', 'do', 'else', 'export', 'extends', 'finally', 'for', 'function', 'if', 'import',
'in', 'instanceof', 'let', 'new', 'return', 'super', 'switch', 'this', 'throw', 'try',
'typeof', 'var', 'void', 'while', 'with', 'yield', 'true', 'false', 'null', 'undefined'
]);
const JavaScriptHighlightPlugin = () => {
const [editor] = useLexicalComposerContext();
const isPastingRef = useRef(false);
useEffect(() => {
return editor.registerCommand(
PASTE_COMMAND,
() => {
isPastingRef.current = true;
setTimeout(() => {
isPastingRef.current = false;
}, 100);
return false;
},
COMMAND_PRIORITY_LOW
);
}, [editor]);
useEffect(() => {
return editor.registerNodeTransform(TextNode, (textNode: TextNode) => {
if (isPastingRef.current) return;
const text = textNode.getTextContent();
if (textNode.hasFormat('code')) return;
if (!needsHighlight(text)) return;
if (textNode.getStyle()) return;
const parent = textNode.getParent();
if (!parent) return;
const selection = $getSelection();
let selectionOffset = null;
if ($isRangeSelection(selection)) {
const anchor = selection.anchor;
if (anchor.getNode() === textNode) {
selectionOffset = anchor.offset;
}
}
const tokens = tokenizeJavaScript(text);
if (tokens.length <= 1) return;
const newNodes = tokens.map(token => {
const newNode = $createTextNode(token.text);
newNode.toggleFormat('code');
switch (token.type) {
case 'keyword':
newNode.setStyle('color: #d73a49; font-weight: 600;');
break;
case 'string':
newNode.setStyle('color: #032f62;');
break;
case 'comment':
newNode.setStyle('color: #6a737d; font-style: italic;');
break;
case 'number':
newNode.setStyle('color: #005cc5; font-weight: 500;');
break;
case 'function':
newNode.setStyle('color: #6f42c1; font-weight: 500;');
break;
}
return newNode;
});
if (newNodes.length > 1) {
textNode.replace(newNodes[0]);
for (let i = 1; i < newNodes.length; i++) {
newNodes[i - 1].insertAfter(newNodes[i]);
}
if (selectionOffset !== null && $isRangeSelection(selection)) {
let currentOffset = 0;
for (const node of newNodes) {
const nodeLength = node.getTextContent().length;
if (currentOffset + nodeLength >= selectionOffset) {
node.select(selectionOffset - currentOffset, selectionOffset - currentOffset);
break;
}
currentOffset += nodeLength;
}
}
}
});
}, [editor]);
return null;
};
function needsHighlight(text: string): boolean {
return /[a-zA-Z0-9_/"'`]/.test(text);
}
function tokenizeJavaScript(text: string): Array<{text: string, type: string}> {
const tokens: Array<{text: string, type: string}> = [];
let i = 0;
while (i < text.length) {
// Single-line comments
if (text.slice(i, i + 2) === '//') {
let start = i;
while (i < text.length && text[i] !== '\n') i++;
tokens.push({ text: text.slice(start, i), type: 'comment' });
continue;
}
// Multi-line comments
if (text.slice(i, i + 2) === '/*') {
let start = i;
i += 2;
while (i < text.length && text.slice(i, i + 2) !== '*/') i++;
if (i < text.length) i += 2;
tokens.push({ text: text.slice(start, i), type: 'comment' });
continue;
}
// Strings
if (text[i] === '"' || text[i] === "'" || text[i] === '`') {
const quote = text[i];
let start = i++;
while (i < text.length) {
if (text[i] === quote && text[i - 1] !== '\\') {
i++;
break;
}
i++;
}
tokens.push({ text: text.slice(start, i), type: 'string' });
continue;
}
// Numbers
if (/\d/.test(text[i])) {
let start = i;
while (i < text.length && /[\d.]/.test(text[i])) i++;
tokens.push({ text: text.slice(start, i), type: 'number' });
continue;
}
// Keywords and identifiers
if (/[a-zA-Z_$]/.test(text[i])) {
let start = i;
while (i < text.length && /[a-zA-Z0-9_$]/.test(text[i])) i++;
const word = text.slice(start, i);
if (JS_KEYWORDS.has(word)) {
tokens.push({ text: word, type: 'keyword' });
} else if (i < text.length && text[i] === '(') {
tokens.push({ text: word, type: 'function' });
} else {
tokens.push({ text: word, type: 'text' });
}
continue;
}
// Other characters
let start = i;
while (i < text.length && !/[a-zA-Z0-9_$/"'`]/.test(text[i])) i++;
if (start < i) {
tokens.push({ text: text.slice(start, i), type: 'text' });
}
}
return tokens;
}
export default JavaScriptHighlightPlugin;

View File

@@ -1,177 +0,0 @@
import { useEffect, useRef } from 'react';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { TextNode, $createTextNode, $getSelection, $isRangeSelection, COMMAND_PRIORITY_LOW, PASTE_COMMAND } from 'lexical';
const PYTHON_KEYWORDS = new Set([
'False', 'None', 'True', 'and', 'as', 'assert', 'async', 'await', 'break', 'class', 'continue',
'def', 'del', 'elif', 'else', 'except', 'finally', 'for', 'from', 'global', 'if', 'import',
'in', 'is', 'lambda', 'nonlocal', 'not', 'or', 'pass', 'raise', 'return', 'try', 'while',
'with', 'yield'
]);
const Python3HighlightPlugin = () => {
const [editor] = useLexicalComposerContext();
const isPastingRef = useRef(false);
useEffect(() => {
return editor.registerCommand(
PASTE_COMMAND,
() => {
isPastingRef.current = true;
setTimeout(() => {
isPastingRef.current = false;
}, 100);
return false;
},
COMMAND_PRIORITY_LOW
);
}, [editor]);
useEffect(() => {
return editor.registerNodeTransform(TextNode, (textNode: TextNode) => {
if (isPastingRef.current) return;
const text = textNode.getTextContent();
if (textNode.hasFormat('code')) return;
if (textNode.getStyle()) return;
if (!needsHighlight(text)) return;
const parent = textNode.getParent();
if (!parent) return;
const selection = $getSelection();
let selectionOffset = null;
if ($isRangeSelection(selection)) {
const anchor = selection.anchor;
if (anchor.getNode() === textNode) {
selectionOffset = anchor.offset;
}
}
const tokens = tokenizePython(text);
if (tokens.length <= 1) return;
const newNodes = tokens.map(token => {
const newNode = $createTextNode(token.text);
newNode.toggleFormat('code');
switch (token.type) {
case 'keyword':
newNode.setStyle('color: #d73a49; font-weight: 600;');
break;
case 'string':
newNode.setStyle('color: #032f62;');
break;
case 'comment':
newNode.setStyle('color: #6a737d; font-style: italic;');
break;
case 'number':
newNode.setStyle('color: #005cc5; font-weight: 500;');
break;
case 'function':
newNode.setStyle('color: #6f42c1; font-weight: 500;');
break;
}
return newNode;
});
if (newNodes.length > 1) {
textNode.replace(newNodes[0]);
for (let i = 1; i < newNodes.length; i++) {
newNodes[i - 1].insertAfter(newNodes[i]);
}
if (selectionOffset !== null && $isRangeSelection(selection)) {
let currentOffset = 0;
for (const node of newNodes) {
const nodeLength = node.getTextContent().length;
if (currentOffset + nodeLength >= selectionOffset) {
node.select(selectionOffset - currentOffset, selectionOffset - currentOffset);
break;
}
currentOffset += nodeLength;
}
}
}
});
}, [editor]);
return null;
};
function needsHighlight(text: string): boolean {
return /[a-zA-Z0-9_#"']/.test(text);
}
function tokenizePython(text: string): Array<{text: string, type: string}> {
const tokens: Array<{text: string, type: string}> = [];
let i = 0;
while (i < text.length) {
// Comments
if (text[i] === '#') {
let start = i;
while (i < text.length && text[i] !== '\n') i++;
tokens.push({ text: text.slice(start, i), type: 'comment' });
continue;
}
// Strings
if (text[i] === '"' || text[i] === "'") {
const quote = text[i];
let start = i++;
const isTriple = text.slice(start, start + 3) === quote.repeat(3);
if (isTriple) i += 2;
while (i < text.length) {
if (isTriple && text.slice(i, i + 3) === quote.repeat(3)) {
i += 3;
break;
} else if (!isTriple && text[i] === quote && text[i - 1] !== '\\') {
i++;
break;
}
i++;
}
tokens.push({ text: text.slice(start, i), type: 'string' });
continue;
}
// Numbers
if (/\d/.test(text[i])) {
let start = i;
while (i < text.length && /[\d.]/.test(text[i])) i++;
tokens.push({ text: text.slice(start, i), type: 'number' });
continue;
}
// Keywords and identifiers
if (/[a-zA-Z_]/.test(text[i])) {
let start = i;
while (i < text.length && /[a-zA-Z0-9_]/.test(text[i])) i++;
const word = text.slice(start, i);
if (PYTHON_KEYWORDS.has(word)) {
tokens.push({ text: word, type: 'keyword' });
} else if (i < text.length && text[i] === '(') {
tokens.push({ text: word, type: 'function' });
} else {
tokens.push({ text: word, type: 'text' });
}
continue;
}
// Other characters
let start = i;
while (i < text.length && !/[a-zA-Z0-9_#"']/.test(text[i])) i++;
if (start < i) {
tokens.push({ text: text.slice(start, i), type: 'text' });
}
}
return tokens;
}
export default Python3HighlightPlugin;

View File

@@ -5,8 +5,8 @@ import { Node } from '@antv/x6'
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
import MappingList from '../MappingList'
import Editor from '../../Editor'
import OutputList from './OutputList'
import CodeMirrorEditor from '@/components/CodeMirrorEditor';
interface MappingItem {
name?: string
@@ -110,7 +110,10 @@ const CodeExecution: FC<CodeExecutionProps> = ({ options }) => {
<Form.Item noStyle shouldUpdate={(prev, curr) => prev.language !== curr.language}>
{() => (
<Form.Item name="code" noStyle>
<Editor size="small" language={form.getFieldValue('language')} />
<CodeMirrorEditor
language={form.getFieldValue('language')}
size="small"
/>
</Form.Item>
)}
</Form.Item>

View File

@@ -126,7 +126,7 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi
<div
className="rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/workflow/recall.svg')] rb:group-hover:bg-[url('@/assets/images/workflow/recall_hover.svg')]"
></div>
{t('workflow.config.knowledge-retrieval.recallConfig')}
{t('application.globalConfig')}
</Button>
</div>