Merge branch 'develop' into fix/memory-enduser-config
This commit is contained in:
@@ -76,6 +76,7 @@ celery_app.conf.update(
|
|||||||
# Document tasks → document_tasks queue (prefork worker)
|
# Document tasks → document_tasks queue (prefork worker)
|
||||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||||
|
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
|
||||||
|
|
||||||
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
||||||
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
||||||
|
|||||||
@@ -9,13 +9,16 @@ from sqlalchemy import or_
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.rag.common import settings
|
from app.core.rag.common import settings
|
||||||
|
from app.core.rag.integrations.feishu.client import FeishuAPIClient
|
||||||
|
from app.core.rag.integrations.yuque.client import YuqueAPIClient
|
||||||
from app.core.rag.llm.chat_model import Base
|
from app.core.rag.llm.chat_model import Base
|
||||||
from app.core.rag.nlp import rag_tokenizer, search
|
from app.core.rag.nlp import rag_tokenizer, search
|
||||||
from app.core.rag.prompts.generator import graph_entity_types
|
from app.core.rag.prompts.generator import graph_entity_types
|
||||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success, fail
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.models import knowledge_model
|
from app.models import knowledge_model
|
||||||
@@ -484,3 +487,99 @@ async def rebuild_knowledge_graph(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Failed to rebuild knowledge graph: knowledge_id={knowledge_id} - {str(e)}")
|
api_logger.error(f"Failed to rebuild knowledge graph: knowledge_id={knowledge_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/check/yuque/auth", response_model=ApiResponse)
|
||||||
|
async def check_yuque_auth(
|
||||||
|
yuque_user_id: str,
|
||||||
|
yuque_token: str,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
check yuque auth info
|
||||||
|
"""
|
||||||
|
api_logger.info(f"check yuque auth info, username: {current_user.username}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
api_client = YuqueAPIClient(
|
||||||
|
user_id=yuque_user_id,
|
||||||
|
token=yuque_token
|
||||||
|
)
|
||||||
|
async with api_client as client:
|
||||||
|
repos = await client.get_user_repos()
|
||||||
|
if repos:
|
||||||
|
return success(data=repos, msg="Successfully auth yuque info")
|
||||||
|
return fail(BizCode.UNAUTHORIZED, msg="auth yuque info failed", error="user_id or token is incorrect")
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"auth yuque info failed: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/check/feishu/auth", response_model=ApiResponse)
|
||||||
|
async def check_yuque_auth(
|
||||||
|
feishu_app_id: str,
|
||||||
|
feishu_app_secret: str,
|
||||||
|
feishu_folder_token: str,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
check feishu auth info
|
||||||
|
"""
|
||||||
|
api_logger.info(f"check feishu auth info, username: {current_user.username}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
api_client = FeishuAPIClient(
|
||||||
|
app_id=feishu_app_id,
|
||||||
|
app_secret=feishu_app_secret
|
||||||
|
)
|
||||||
|
async with api_client as client:
|
||||||
|
files = await client.list_all_folder_files(feishu_folder_token, recursive=True)
|
||||||
|
if files:
|
||||||
|
return success(data=files, msg="Successfully auth feishu info")
|
||||||
|
return fail(BizCode.UNAUTHORIZED, msg="auth feishu info failed", error="app_id or app_secret or feishu_folder_token is incorrect")
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"auth feishu info failed: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{knowledge_id}/sync", response_model=ApiResponse)
|
||||||
|
async def sync_knowledge(
|
||||||
|
knowledge_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
sync knowledge base information based on knowledge_id
|
||||||
|
"""
|
||||||
|
api_logger.info(f"Obtain details of the knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. Query knowledge base information from the database
|
||||||
|
api_logger.debug(f"Query knowledge base: {knowledge_id}")
|
||||||
|
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
|
||||||
|
if not db_knowledge:
|
||||||
|
api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={knowledge_id}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="The knowledge base does not exist or access is denied"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. sync knowledge
|
||||||
|
# from app.tasks import sync_knowledge_for_kb
|
||||||
|
# sync_knowledge_for_kb(kb_id)
|
||||||
|
task = celery_app.send_task("app.core.rag.tasks.sync_knowledge_for_kb", args=[knowledge_id])
|
||||||
|
result = {
|
||||||
|
"task_id": task.id
|
||||||
|
}
|
||||||
|
return success(data=result, msg="Task accepted. sync knowledge is being processed in the background.")
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"Failed to sync knowledge: knowledge_id={knowledge_id} - {str(e)}")
|
||||||
|
raise
|
||||||
|
|||||||
@@ -191,6 +191,11 @@ def update_config(
|
|||||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
|
# 校验至少有一个字段需要更新
|
||||||
|
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
|
||||||
|
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
|
||||||
|
|
||||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
||||||
try:
|
try:
|
||||||
svc = DataConfigService(db)
|
svc = DataConfigService(db)
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ from app.services.ontology_service import OntologyService
|
|||||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||||
from app.core.memory.utils.validation.owl_validator import OWLValidator
|
from app.core.memory.utils.validation.owl_validator import OWLValidator
|
||||||
from app.services.model_service import ModelConfigService
|
from app.services.model_service import ModelConfigService
|
||||||
|
from app.repositories.ontology_scene_repository import OntologySceneRepository
|
||||||
|
|
||||||
|
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -116,27 +117,35 @@ def _get_ontology_service(
|
|||||||
detail=f"找不到指定的LLM模型: {llm_id}"
|
detail=f"找不到指定的LLM模型: {llm_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 验证模型配置了API密钥
|
# 通过 Repository 获取可用的 API Key(负载均衡逻辑由 Repository 处理)
|
||||||
if not model_config.api_keys:
|
from app.repositories.model_repository import ModelApiKeyRepository
|
||||||
logger.error(f"Model {llm_id} has no API key configuration")
|
api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id)
|
||||||
|
if not api_keys:
|
||||||
|
logger.error(f"Model {llm_id} has no active API key")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail="指定的LLM模型没有配置API密钥"
|
detail="指定的LLM模型没有可用的API密钥"
|
||||||
)
|
)
|
||||||
|
api_key_config = api_keys[0]
|
||||||
|
|
||||||
api_key_config = model_config.api_keys[0]
|
is_composite = getattr(model_config, 'is_composite', False)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Using specified model - user: {current_user.id}, "
|
f"Using specified model - user: {current_user.id}, "
|
||||||
f"model_id: {llm_id}, model_name: {api_key_config.model_name}"
|
f"model_id: {llm_id}, model_name: {api_key_config.model_name}, "
|
||||||
|
f"is_composite: {is_composite}, api_key_id: {api_key_config.id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建模型配置对象
|
# 创建模型配置对象
|
||||||
from app.core.models.base import RedBearModelConfig
|
from app.core.models.base import RedBearModelConfig
|
||||||
|
|
||||||
|
# 对于组合模型,使用 API Key 的 provider;否则使用 model_config 的 provider
|
||||||
|
actual_provider = api_key_config.provider if is_composite else (
|
||||||
|
getattr(model_config, 'provider', None) or "openai"
|
||||||
|
)
|
||||||
|
|
||||||
llm_model_config = RedBearModelConfig(
|
llm_model_config = RedBearModelConfig(
|
||||||
model_name=api_key_config.model_name,
|
model_name=api_key_config.model_name,
|
||||||
provider=model_config.provider if hasattr(model_config, 'provider') else "openai",
|
provider=actual_provider,
|
||||||
api_key=api_key_config.api_key,
|
api_key=api_key_config.api_key,
|
||||||
base_url=api_key_config.api_base,
|
base_url=api_key_config.api_base,
|
||||||
max_retries=3,
|
max_retries=3,
|
||||||
@@ -648,6 +657,46 @@ async def delete_scene(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "场景删除失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "场景删除失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/scenes/simple", response_model=ApiResponse)
|
||||||
|
async def get_scenes_simple(
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""获取场景简单列表(轻量级,用于下拉选择)
|
||||||
|
|
||||||
|
仅返回 scene_id 和 scene_name,不加载关联数据,响应速度快。
|
||||||
|
适用于前端下拉选择场景的场景。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
current_user: 当前用户
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse: 包含场景简单列表
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
GET /scenes/simple
|
||||||
|
返回: {"data": [{"scene_id": "xxx", "scene_name": "场景1"}, ...]}
|
||||||
|
"""
|
||||||
|
api_logger.info(f"Simple scene list requested by user {current_user.id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
if not workspace_id:
|
||||||
|
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||||
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||||
|
|
||||||
|
repo = OntologySceneRepository(db)
|
||||||
|
scenes = repo.get_simple_list(workspace_id)
|
||||||
|
|
||||||
|
api_logger.info(f"Simple scene list retrieved: {len(scenes)} scenes")
|
||||||
|
return success(data=scenes, msg="查询成功")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"Failed to get simple scene list: {str(e)}", exc_info=True)
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/scenes", response_model=ApiResponse)
|
@router.get("/scenes", response_model=ApiResponse)
|
||||||
async def get_scenes(
|
async def get_scenes(
|
||||||
workspace_id: Optional[str] = None,
|
workspace_id: Optional[str] = None,
|
||||||
|
|||||||
@@ -7,30 +7,21 @@ LangChain Agent 封装
|
|||||||
- 支持流式输出
|
- 支持流式输出
|
||||||
- 使用 RedBearLLM 支持多提供商
|
- 使用 RedBearLLM 支持多提供商
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
import time
|
import time
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages, format_parsing, messages_parse
|
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||||
from app.core.memory.agent.langgraph_graph.write_graph import long_term_storage
|
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.memory.agent.utils.redis_tool import store
|
|
||||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
from app.models.models_model import ModelType
|
from app.models.models_model import ModelType
|
||||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
|
||||||
from app.services.memory_agent_service import (
|
from app.services.memory_agent_service import (
|
||||||
get_end_user_connected_config,
|
get_end_user_connected_config,
|
||||||
)
|
)
|
||||||
from app.services.memory_konwledges_server import write_rag
|
|
||||||
from app.services.task_service import get_task_memory_write_result
|
|
||||||
from app.tasks import write_message_task
|
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
from app.utils.config_utils import resolve_config_id
|
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
@@ -289,105 +280,6 @@ class LangChainAgent:
|
|||||||
|
|
||||||
return content_parts
|
return content_parts
|
||||||
|
|
||||||
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
|
|
||||||
db = next(get_db())
|
|
||||||
#TODO: 魔法数字
|
|
||||||
scope=6
|
|
||||||
|
|
||||||
try:
|
|
||||||
repo = LongTermMemoryRepository(db)
|
|
||||||
await long_term_storage(long_term_type="chunk", langchain_messages=long_term_messages,
|
|
||||||
memory_config=actual_config_id, end_user_id=end_user_id, scope=scope)
|
|
||||||
|
|
||||||
from app.core.memory.agent.utils.redis_tool import write_store
|
|
||||||
result = write_store.get_session_by_userid(end_user_id)
|
|
||||||
|
|
||||||
# Handle case where no session exists in Redis (returns False)
|
|
||||||
if not result or result is False:
|
|
||||||
logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update")
|
|
||||||
return
|
|
||||||
|
|
||||||
if type=="chunk" or type=="aggregate":
|
|
||||||
data = await format_parsing(result, "dict")
|
|
||||||
chunk_data = data[:scope]
|
|
||||||
if len(chunk_data)==scope:
|
|
||||||
repo.upsert(end_user_id, chunk_data)
|
|
||||||
logger.info(f'写入短长期:')
|
|
||||||
else:
|
|
||||||
# TODO: This branch handles type="time" strategy, currently unused.
|
|
||||||
# Will be activated when time-based long-term storage is implemented.
|
|
||||||
# TODO: 魔法数字 - extract 5 to a constant
|
|
||||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
|
||||||
# Handle case where no session exists in Redis (returns False or empty)
|
|
||||||
if not long_time_data or long_time_data is False:
|
|
||||||
logger.debug(f"No recent sessions in Redis for user {end_user_id}")
|
|
||||||
return
|
|
||||||
long_messages = await messages_parse(long_time_data)
|
|
||||||
repo.upsert(end_user_id, long_messages)
|
|
||||||
logger.info(f'写入短长期:')
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
|
|
||||||
"""
|
|
||||||
写入记忆(支持结构化消息)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
storage_type: 存储类型 (neo4j/rag)
|
|
||||||
end_user_id: 终端用户ID
|
|
||||||
user_message: 用户消息内容
|
|
||||||
ai_message: AI 回复内容
|
|
||||||
user_rag_memory_id: RAG 记忆ID
|
|
||||||
actual_end_user_id: 实际用户ID
|
|
||||||
actual_config_id: 配置ID
|
|
||||||
|
|
||||||
逻辑说明:
|
|
||||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
|
||||||
- Neo4j 模式:使用结构化消息列表
|
|
||||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
|
||||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
|
||||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
|
||||||
"""
|
|
||||||
|
|
||||||
db = next(get_db())
|
|
||||||
try:
|
|
||||||
actual_config_id=resolve_config_id(actual_config_id, db)
|
|
||||||
|
|
||||||
if storage_type == "rag":
|
|
||||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
|
||||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
|
||||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
|
||||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
|
||||||
else:
|
|
||||||
# Neo4j 模式:使用结构化消息列表
|
|
||||||
structured_messages = []
|
|
||||||
|
|
||||||
# 始终添加用户消息(如果不为空)
|
|
||||||
if user_message:
|
|
||||||
structured_messages.append({"role": "user", "content": user_message})
|
|
||||||
|
|
||||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
|
||||||
if ai_message:
|
|
||||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
|
||||||
|
|
||||||
# 如果没有消息,直接返回
|
|
||||||
if not structured_messages:
|
|
||||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
|
||||||
write_id = write_message_task.delay(
|
|
||||||
actual_end_user_id, # end_user_id: 用户ID
|
|
||||||
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
|
||||||
actual_config_id, # config_id: 配置ID
|
|
||||||
storage_type, # storage_type: "neo4j"
|
|
||||||
user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
|
||||||
)
|
|
||||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
|
||||||
write_status = get_task_memory_write_result(str(write_id))
|
|
||||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
@@ -520,14 +412,7 @@ class LangChainAgent:
|
|||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
if memory_flag:
|
if memory_flag:
|
||||||
long_term_messages=await agent_chat_messages(message_chat,content)
|
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id)
|
||||||
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
|
|
||||||
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
|
|
||||||
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
|
|
||||||
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
|
|
||||||
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
|
|
||||||
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
|
|
||||||
await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
|
|
||||||
response = {
|
response = {
|
||||||
"content": content,
|
"content": content,
|
||||||
"model": self.model_name,
|
"model": self.model_name,
|
||||||
@@ -710,15 +595,7 @@ class LangChainAgent:
|
|||||||
yield total_tokens
|
yield total_tokens
|
||||||
break
|
break
|
||||||
if memory_flag:
|
if memory_flag:
|
||||||
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
|
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id)
|
||||||
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
|
|
||||||
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
|
|
||||||
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
|
|
||||||
long_term_messages = await agent_chat_messages(message_chat, full_content)
|
|
||||||
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
|
|
||||||
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
|
|
||||||
await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format, format_parsing
|
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
|
||||||
|
|
||||||
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
||||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||||
@@ -10,46 +11,108 @@ from app.core.memory.agent.utils.redis_tool import write_store
|
|||||||
from app.core.memory.agent.utils.redis_tool import count_store
|
from app.core.memory.agent.utils.redis_tool import count_store
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context, get_db
|
||||||
|
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||||
|
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||||
|
from app.services.memory_konwledges_server import write_rag
|
||||||
|
from app.services.task_service import get_task_memory_write_result
|
||||||
|
from app.tasks import write_message_task
|
||||||
|
from app.utils.config_utils import resolve_config_id
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
|
|
||||||
|
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
||||||
|
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||||
|
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||||
|
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||||
|
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||||
|
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
|
||||||
|
actual_config_id, long_term_messages=[]):
|
||||||
|
"""
|
||||||
|
写入记忆(支持结构化消息)
|
||||||
|
|
||||||
async def write_messages(end_user_id,langchain_messages,memory_config):
|
Args:
|
||||||
'''
|
storage_type: 存储类型 (neo4j/rag)
|
||||||
写入数据到neo4j:
|
|
||||||
Args:
|
|
||||||
end_user_id: 终端用户ID
|
end_user_id: 终端用户ID
|
||||||
memory_config: 内存配置对象
|
user_message: 用户消息内容
|
||||||
langchain_messages:原始数据LIST
|
ai_message: AI 回复内容
|
||||||
'''
|
user_rag_memory_id: RAG 记忆ID
|
||||||
|
actual_end_user_id: 实际用户ID
|
||||||
|
actual_config_id: 配置ID
|
||||||
|
|
||||||
|
逻辑说明:
|
||||||
|
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||||
|
- Neo4j 模式:使用结构化消息列表
|
||||||
|
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||||
|
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||||
|
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||||
|
"""
|
||||||
|
|
||||||
|
db = next(get_db())
|
||||||
try:
|
try:
|
||||||
|
actual_config_id = resolve_config_id(actual_config_id, db)
|
||||||
|
# Neo4j 模式:使用结构化消息列表
|
||||||
|
structured_messages = []
|
||||||
|
|
||||||
|
# 始终添加用户消息(如果不为空)
|
||||||
|
if isinstance(user_message, str) and user_message.strip() != "":
|
||||||
|
structured_messages.append({"role": "user", "content": user_message})
|
||||||
|
|
||||||
|
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||||
|
if isinstance(ai_message, str) and ai_message.strip() != "":
|
||||||
|
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||||
|
|
||||||
|
# 如果提供了 long_term_messages,使用它替代 structured_messages
|
||||||
|
if long_term_messages and isinstance(long_term_messages, list):
|
||||||
|
structured_messages = long_term_messages
|
||||||
|
elif long_term_messages and isinstance(long_term_messages, str):
|
||||||
|
# 如果是 JSON 字符串,先解析
|
||||||
|
try:
|
||||||
|
structured_messages = json.loads(long_term_messages)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
|
||||||
|
|
||||||
|
# 如果没有消息,直接返回
|
||||||
|
if not structured_messages:
|
||||||
|
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||||
|
write_id = write_message_task.delay(
|
||||||
|
actual_end_user_id, # end_user_id: 用户ID
|
||||||
|
structured_messages, # message: JSON 字符串格式的消息列表
|
||||||
|
str(actual_config_id), # config_id: 配置ID字符串
|
||||||
|
storage_type, # storage_type: "neo4j"
|
||||||
|
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||||
|
)
|
||||||
|
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||||
|
write_status = get_task_memory_write_result(str(write_id))
|
||||||
|
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
|
||||||
|
with get_db_context() as db_session:
|
||||||
|
repo = LongTermMemoryRepository(db_session)
|
||||||
|
|
||||||
|
|
||||||
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
|
result = write_store.get_session_by_userid(end_user_id)
|
||||||
|
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||||
|
data = await format_parsing(result, "dict")
|
||||||
|
chunk_data = data[:scope]
|
||||||
|
if len(chunk_data)==scope:
|
||||||
|
repo.upsert(end_user_id, chunk_data)
|
||||||
|
logger.info(f'---------写入短长期-----------')
|
||||||
|
else:
|
||||||
|
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
||||||
|
long_messages = await messages_parse(long_time_data)
|
||||||
|
repo.upsert(end_user_id, long_messages)
|
||||||
|
logger.info(f'写入短长期:')
|
||||||
|
|
||||||
|
|
||||||
async with make_write_graph() as graph:
|
|
||||||
config = {"configurable": {"thread_id": end_user_id}}
|
|
||||||
# 初始状态 - 包含所有必要字段
|
|
||||||
initial_state = {
|
|
||||||
"messages": langchain_messages,
|
|
||||||
"end_user_id": end_user_id,
|
|
||||||
"memory_config": memory_config
|
|
||||||
}
|
|
||||||
|
|
||||||
# 获取节点更新信息
|
|
||||||
async for update_event in graph.astream(
|
|
||||||
initial_state,
|
|
||||||
stream_mode="updates",
|
|
||||||
config=config
|
|
||||||
):
|
|
||||||
for node_name, node_data in update_event.items():
|
|
||||||
if 'save_neo4j' == node_name:
|
|
||||||
massages = node_data
|
|
||||||
# TODO:删除
|
|
||||||
massagesstatus = massages.get('write_result')['status']
|
|
||||||
contents = massages.get('write_result')
|
|
||||||
print(contents)
|
|
||||||
except Exception as e:
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
'''根据窗口'''
|
'''根据窗口'''
|
||||||
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||||
'''
|
'''
|
||||||
@@ -61,25 +124,26 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
|||||||
scope:窗口大小
|
scope:窗口大小
|
||||||
'''
|
'''
|
||||||
scope=scope
|
scope=scope
|
||||||
redis_messages = []
|
|
||||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||||
if is_end_user_id is not False:
|
if is_end_user_id is not False:
|
||||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||||
redis_messages = count_store.get_sessions_count(end_user_id)[1]
|
redis_messages = count_store.get_sessions_count(end_user_id)[1]
|
||||||
if is_end_user_id and int(is_end_user_id) != int(scope):
|
if is_end_user_id and int(is_end_user_id) != int(scope):
|
||||||
print(is_end_user_id)
|
|
||||||
is_end_user_id += 1
|
is_end_user_id += 1
|
||||||
langchain_messages += redis_messages
|
langchain_messages += redis_messages
|
||||||
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
||||||
elif int(is_end_user_id) == int(scope):
|
elif int(is_end_user_id) == int(scope):
|
||||||
print('写入长期记忆,并且设置为0')
|
logger.info('写入长期记忆NEO4J')
|
||||||
print(is_end_user_id)
|
formatted_messages = (redis_messages)
|
||||||
formatted_messages = await chat_data_format(redis_messages)
|
# 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用)
|
||||||
print(100*'-')
|
if hasattr(memory_config, 'config_id'):
|
||||||
print(formatted_messages)
|
config_id = memory_config.config_id
|
||||||
print(100*'-')
|
else:
|
||||||
await write_messages(end_user_id, formatted_messages, memory_config)
|
config_id = memory_config
|
||||||
count_store.update_sessions_count(end_user_id, 0, '')
|
|
||||||
|
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||||
|
config_id, formatted_messages)
|
||||||
|
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
||||||
else:
|
else:
|
||||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||||
|
|
||||||
@@ -93,12 +157,15 @@ async def memory_long_term_storage(end_user_id,memory_config,time):
|
|||||||
memory_config: 内存配置对象
|
memory_config: 内存配置对象
|
||||||
'''
|
'''
|
||||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
||||||
# Handle case where no session exists in Redis (returns False or empty)
|
format_messages = (long_time_data)
|
||||||
if not long_time_data or long_time_data is False:
|
messages=[]
|
||||||
return
|
memory_config=memory_config.config_id
|
||||||
format_messages = await chat_data_format(long_time_data)
|
for i in format_messages:
|
||||||
|
message=json.loads(i['Query'])
|
||||||
|
messages+= message
|
||||||
if format_messages!=[]:
|
if format_messages!=[]:
|
||||||
await write_messages(end_user_id, format_messages, memory_config)
|
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||||
|
memory_config, messages)
|
||||||
'''聚合判断'''
|
'''聚合判断'''
|
||||||
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
||||||
"""
|
"""
|
||||||
@@ -109,13 +176,12 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
|||||||
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||||
memory_config: 内存配置对象
|
memory_config: 内存配置对象
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. 获取历史会话数据(使用新方法)
|
# 1. 获取历史会话数据(使用新方法)
|
||||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
||||||
|
history = await format_parsing(result)
|
||||||
# Handle case where no session exists in Redis (returns False or empty)
|
if not result:
|
||||||
if not result or result is False:
|
|
||||||
history = []
|
history = []
|
||||||
else:
|
else:
|
||||||
history = await format_parsing(result)
|
history = await format_parsing(result)
|
||||||
@@ -154,7 +220,8 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
|||||||
}
|
}
|
||||||
if not structured.is_same_event:
|
if not structured.is_same_event:
|
||||||
logger.info(result_dict)
|
logger.info(result_dict)
|
||||||
await write_messages(end_user_id, output_value, memory_config)
|
await write("neo4j", end_user_id, "", "", None, end_user_id,
|
||||||
|
memory_config.config_id, output_value)
|
||||||
return result_dict
|
return result_dict
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -186,10 +186,11 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
|||||||
清理后的数据
|
清理后的数据
|
||||||
"""
|
"""
|
||||||
# 需要过滤的字段列表
|
# 需要过滤的字段列表
|
||||||
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
fields_to_remove = {
|
fields_to_remove = {
|
||||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||||
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
||||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids","fact_summary"
|
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
|
||||||
}
|
}
|
||||||
|
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, AIMessage
|
from langchain_core.messages import HumanMessage, AIMessage
|
||||||
|
|
||||||
|
|
||||||
async def format_parsing(messages: list,type:str='string'):
|
async def format_parsing(messages: list,type:str='string'):
|
||||||
"""
|
"""
|
||||||
格式化解析消息列表
|
格式化解析消息列表
|
||||||
@@ -26,13 +24,13 @@ async def format_parsing(messages: list,type:str='string'):
|
|||||||
role = content['role']
|
role = content['role']
|
||||||
content = content['content']
|
content = content['content']
|
||||||
if type == "string":
|
if type == "string":
|
||||||
if role == 'human':
|
if role == 'human' or role=="user":
|
||||||
content = '用户:' + content
|
content = '用户:' + content
|
||||||
else:
|
else:
|
||||||
content = 'AI:' + content
|
content = 'AI:' + content
|
||||||
result.append(content)
|
result.append(content)
|
||||||
if type == "dict":
|
if type == "dict" :
|
||||||
if role == 'human':
|
if role == 'human' or role=="user":
|
||||||
user.append( content)
|
user.append( content)
|
||||||
else:
|
else:
|
||||||
ai.append(content)
|
ai.append(content)
|
||||||
@@ -57,33 +55,7 @@ async def messages_parse(messages: list | dict):
|
|||||||
for key, values in zip(user, ai):
|
for key, values in zip(user, ai):
|
||||||
database.append({key, values})
|
database.append({key, values})
|
||||||
return database
|
return database
|
||||||
async def chat_data_format(messages: list | dict):
|
|
||||||
"""
|
|
||||||
将消息格式化为 LangChain 消息格式
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: 消息列表或字典
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
LangChain 消息列表
|
|
||||||
"""
|
|
||||||
langchain_messages = []
|
|
||||||
if isinstance(messages, list):
|
|
||||||
for msg in messages:
|
|
||||||
if 'role' in msg.keys():
|
|
||||||
if msg['role'] == 'user':
|
|
||||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
|
||||||
elif msg['role'] == 'assistant':
|
|
||||||
langchain_messages.append(AIMessage(content=msg['content']))
|
|
||||||
if "Query" in msg.keys():
|
|
||||||
langchain_messages.append(HumanMessage(content=msg['Query']))
|
|
||||||
langchain_messages.append(AIMessage(content=msg['Answer']))
|
|
||||||
if isinstance(messages, dict):
|
|
||||||
if messages['type'] == 'human':
|
|
||||||
langchain_messages.append(HumanMessage(content=messages['content']))
|
|
||||||
elif messages['type'] == 'ai':
|
|
||||||
langchain_messages.append(AIMessage(content=messages['content']))
|
|
||||||
return langchain_messages
|
|
||||||
|
|
||||||
async def agent_chat_messages(user_content,ai_content):
|
async def agent_chat_messages(user_content,ai_content):
|
||||||
messages = [
|
messages = [
|
||||||
|
|||||||
@@ -1,13 +1,18 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from langgraph.constants import END, START
|
from langgraph.constants import END, START
|
||||||
from langgraph.graph import StateGraph
|
from langgraph.graph import StateGraph
|
||||||
|
|
||||||
|
from app.db import get_db, get_db_context
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||||
|
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
@@ -37,76 +42,61 @@ async def make_write_graph():
|
|||||||
|
|
||||||
yield graph
|
yield graph
|
||||||
|
|
||||||
|
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
||||||
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
|
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
||||||
end_user_id: str = '', scope: int = 6):
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
"""Dispatch long-term memory storage to Celery background tasks.
|
write_store.save_session_write(end_user_id, (langchain_messages))
|
||||||
|
# 获取数据库会话
|
||||||
Args:
|
with get_db_context() as db_session:
|
||||||
long_term_type: Storage strategy - 'chunk' (window), 'time', or 'aggregate'
|
config_service = MemoryConfigService(db_session)
|
||||||
langchain_messages: List of messages to store
|
memory_config = config_service.load_memory_config(
|
||||||
memory_config: Memory configuration ID (string)
|
config_id=memory_config, # 改为整数
|
||||||
end_user_id: End user identifier
|
service_name="MemoryAgentService"
|
||||||
scope: Window size for 'chunk' strategy (default: 6)
|
|
||||||
"""
|
|
||||||
from app.tasks import (
|
|
||||||
long_term_storage_window_task,
|
|
||||||
# TODO: Uncomment when implemented
|
|
||||||
# long_term_storage_time_task,
|
|
||||||
# long_term_storage_aggregate_task,
|
|
||||||
)
|
|
||||||
from app.core.logging_config import get_logger
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
# Convert config to string if needed
|
|
||||||
config_id = str(memory_config) if memory_config else ''
|
|
||||||
|
|
||||||
if long_term_type == 'chunk':
|
|
||||||
# Strategy 1: Window-based batching (6 rounds of dialogue)
|
|
||||||
logger.info(f"[LONG_TERM] Dispatching window task - end_user_id={end_user_id}, scope={scope}")
|
|
||||||
long_term_storage_window_task.delay(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
langchain_messages=langchain_messages,
|
|
||||||
config_id=config_id,
|
|
||||||
scope=scope
|
|
||||||
)
|
)
|
||||||
# TODO: Uncomment when time-based strategy is fully implemented
|
if long_term_type=='chunk':
|
||||||
# elif long_term_type == 'time':
|
'''方案一:对话窗口6轮对话'''
|
||||||
# # Strategy 2: Time-based retrieval
|
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
||||||
# logger.info(f"[LONG_TERM] Dispatching time task - end_user_id={end_user_id}")
|
if long_term_type=='time':
|
||||||
# long_term_storage_time_task.delay(
|
"""时间"""
|
||||||
# end_user_id=end_user_id,
|
await memory_long_term_storage(end_user_id, memory_config,5)
|
||||||
# config_id=config_id,
|
if long_term_type=='aggregate':
|
||||||
# time_window=5
|
"""方案三:聚合判断"""
|
||||||
# )
|
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||||
# TODO: Uncomment when aggregate strategy is fully implemented
|
|
||||||
# elif long_term_type == 'aggregate':
|
|
||||||
# # Strategy 3: Aggregate judgment (deduplication)
|
|
||||||
# logger.info(f"[LONG_TERM] Dispatching aggregate task - end_user_id={end_user_id}")
|
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
|
||||||
# long_term_storage_aggregate_task.delay(
|
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
||||||
# end_user_id=end_user_id,
|
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||||
# langchain_messages=langchain_messages,
|
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||||
# config_id=config_id
|
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||||
# )
|
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
||||||
|
else:
|
||||||
|
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||||
|
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||||
|
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||||
|
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
||||||
|
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
|
||||||
|
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
|
||||||
|
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
|
||||||
|
|
||||||
# async def main():
|
# async def main():
|
||||||
# """主函数 - 运行工作流"""
|
# """主函数 - 运行工作流"""
|
||||||
# langchain_messages = [
|
# langchain_messages = [
|
||||||
# {
|
# {
|
||||||
# "role": "user",
|
# "role": "user",
|
||||||
# "content": "今天周五好开心啊"
|
# "content": "今天周五去爬山"
|
||||||
# },
|
# },
|
||||||
# {
|
# {
|
||||||
# "role": "assistant",
|
# "role": "assistant",
|
||||||
# "content": "你也这么觉得,我也是耶"
|
# "content": "好耶"
|
||||||
# }
|
# }
|
||||||
#
|
#
|
||||||
# ]
|
# ]
|
||||||
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
||||||
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
||||||
# # await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
||||||
# result=await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
#
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
# if __name__ == "__main__":
|
# if __name__ == "__main__":
|
||||||
|
|||||||
@@ -294,6 +294,7 @@ class RedisCountStore:
|
|||||||
"""
|
"""
|
||||||
session_id = str(uuid.uuid4())
|
session_id = str(uuid.uuid4())
|
||||||
key = generate_session_key(session_id, key_type="count")
|
key = generate_session_key(session_id, key_type="count")
|
||||||
|
index_key = f'session:count:index:{end_user_id}' # 索引键
|
||||||
|
|
||||||
pipe = self.r.pipeline()
|
pipe = self.r.pipeline()
|
||||||
pipe.hset(key, mapping={
|
pipe.hset(key, mapping={
|
||||||
@@ -304,6 +305,10 @@ class RedisCountStore:
|
|||||||
"starttime": get_current_timestamp()
|
"starttime": get_current_timestamp()
|
||||||
})
|
})
|
||||||
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
|
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
|
||||||
|
|
||||||
|
# 创建索引:end_user_id -> session_id 映射
|
||||||
|
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
|
||||||
|
|
||||||
result = pipe.execute()
|
result = pipe.execute()
|
||||||
|
|
||||||
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
||||||
@@ -320,31 +325,47 @@ class RedisCountStore:
|
|||||||
list 或 False: 如果找到返回 [count, messages],否则返回 False
|
list 或 False: 如果找到返回 [count, messages],否则返回 False
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
search_pattern = 'session:count:*'
|
# 使用索引键快速查找
|
||||||
|
index_key = f'session:count:index:{end_user_id}'
|
||||||
|
|
||||||
for key in self.r.keys(search_pattern):
|
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||||
data = self.r.hgetall(key)
|
try:
|
||||||
|
key_type = self.r.type(index_key)
|
||||||
if not data:
|
if key_type != 'string' and key_type != 'none':
|
||||||
continue
|
self.r.delete(index_key)
|
||||||
|
return False
|
||||||
if data.get('end_user_id') == end_user_id:
|
except Exception as type_error:
|
||||||
count = data.get('count')
|
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
||||||
messages_str = data.get('messages')
|
|
||||||
|
session_id = self.r.get(index_key)
|
||||||
if count is not None:
|
|
||||||
messages = deserialize_messages(messages_str)
|
if not session_id:
|
||||||
return [int(count), messages]
|
return False
|
||||||
|
|
||||||
|
# 直接获取数据
|
||||||
|
key = generate_session_key(session_id, key_type="count")
|
||||||
|
data = self.r.hgetall(key)
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
# 索引存在但数据不存在,清理索引
|
||||||
|
self.r.delete(index_key)
|
||||||
|
return False
|
||||||
|
|
||||||
|
count = data.get('count')
|
||||||
|
messages_str = data.get('messages')
|
||||||
|
|
||||||
|
if count is not None:
|
||||||
|
messages = deserialize_messages(messages_str)
|
||||||
|
return [int(count), messages]
|
||||||
|
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[get_sessions_count] 查询失败: {e}")
|
print(f"[get_sessions_count] 查询失败: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def update_sessions_count(self, end_user_id: str, new_count: int,
|
def update_sessions_count(self, end_user_id: str, new_count: int,
|
||||||
messages: Any) -> bool:
|
messages: Any) -> bool:
|
||||||
"""
|
"""
|
||||||
通过 end_user_id 修改访问次数统计
|
通过 end_user_id 修改访问次数统计(优化版:使用索引)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id: 终端用户ID
|
end_user_id: 终端用户ID
|
||||||
@@ -355,23 +376,39 @@ class RedisCountStore:
|
|||||||
bool: 更新成功返回 True,未找到记录返回 False
|
bool: 更新成功返回 True,未找到记录返回 False
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# 使用索引键快速查找
|
||||||
|
index_key = f'session:count:index:{end_user_id}'
|
||||||
|
|
||||||
|
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||||
|
try:
|
||||||
|
key_type = self.r.type(index_key)
|
||||||
|
if key_type != 'string' and key_type != 'none':
|
||||||
|
# 索引键类型错误,删除并返回 False
|
||||||
|
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
||||||
|
self.r.delete(index_key)
|
||||||
|
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||||
|
return False
|
||||||
|
except Exception as type_error:
|
||||||
|
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
||||||
|
|
||||||
|
session_id = self.r.get(index_key)
|
||||||
|
|
||||||
|
if not session_id:
|
||||||
|
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 直接更新数据
|
||||||
|
key = generate_session_key(session_id, key_type="count")
|
||||||
messages_str = serialize_messages(messages)
|
messages_str = serialize_messages(messages)
|
||||||
search_pattern = 'session:count:*'
|
|
||||||
|
|
||||||
for key in self.r.keys(search_pattern):
|
pipe = self.r.pipeline()
|
||||||
data = self.r.hgetall(key)
|
pipe.hset(key, 'count', int(new_count))
|
||||||
|
pipe.hset(key, 'messages', messages_str)
|
||||||
if not data:
|
result = pipe.execute()
|
||||||
continue
|
|
||||||
|
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
||||||
if data.get('end_user_id') == end_user_id:
|
return True
|
||||||
self.r.hset(key, 'count', int(new_count))
|
|
||||||
self.r.hset(key, 'messages', messages_str)
|
|
||||||
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[update_sessions_count] 更新失败: {e}")
|
print(f"[update_sessions_count] 更新失败: {e}")
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ Write Tools for Memory Knowledge Extraction Pipeline
|
|||||||
This module provides the main write function for executing the knowledge extraction
|
This module provides the main write function for executing the knowledge extraction
|
||||||
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
@@ -124,23 +125,48 @@ async def write(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# 添加死锁重试机制
|
||||||
|
max_retries = 3
|
||||||
|
retry_delay = 1 # 秒
|
||||||
|
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
success = await save_dialog_and_statements_to_neo4j(
|
||||||
|
dialogue_nodes=all_dialogue_nodes,
|
||||||
|
chunk_nodes=all_chunk_nodes,
|
||||||
|
statement_nodes=all_statement_nodes,
|
||||||
|
entity_nodes=all_entity_nodes,
|
||||||
|
statement_chunk_edges=all_statement_chunk_edges,
|
||||||
|
statement_entity_edges=all_statement_entity_edges,
|
||||||
|
entity_edges=all_entity_entity_edges,
|
||||||
|
connector=neo4j_connector
|
||||||
|
)
|
||||||
|
if success:
|
||||||
|
logger.info("Successfully saved all data to Neo4j")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logger.warning("Failed to save some data to Neo4j")
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
logger.info(f"Retrying... (attempt {attempt + 2}/{max_retries})")
|
||||||
|
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
# 检查是否是死锁错误
|
||||||
|
if "DeadlockDetected" in error_msg or "deadlock" in error_msg.lower():
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
logger.warning(f"Deadlock detected, retrying... (attempt {attempt + 2}/{max_retries})")
|
||||||
|
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed after {max_retries} attempts due to deadlock: {e}")
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
# 非死锁错误,直接抛出
|
||||||
|
raise
|
||||||
|
|
||||||
try:
|
try:
|
||||||
success = await save_dialog_and_statements_to_neo4j(
|
|
||||||
dialogue_nodes=all_dialogue_nodes,
|
|
||||||
chunk_nodes=all_chunk_nodes,
|
|
||||||
statement_nodes=all_statement_nodes,
|
|
||||||
entity_nodes=all_entity_nodes,
|
|
||||||
statement_chunk_edges=all_statement_chunk_edges,
|
|
||||||
statement_entity_edges=all_statement_entity_edges,
|
|
||||||
entity_edges=all_entity_entity_edges,
|
|
||||||
connector=neo4j_connector
|
|
||||||
)
|
|
||||||
if success:
|
|
||||||
logger.info("Successfully saved all data to Neo4j")
|
|
||||||
else:
|
|
||||||
logger.warning("Failed to save some data to Neo4j")
|
|
||||||
finally:
|
|
||||||
await neo4j_connector.close()
|
await neo4j_connector.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error closing Neo4j connector: {e}")
|
||||||
|
|
||||||
log_time("Neo4j Database Save", time.time() - step_start, log_file)
|
log_time("Neo4j Database Save", time.time() - step_start, log_file)
|
||||||
|
|
||||||
|
|||||||
@@ -413,7 +413,8 @@ class ExtractedEntityNode(Node):
|
|||||||
description="Entity aliases - alternative names for this entity"
|
description="Entity aliases - alternative names for this entity"
|
||||||
)
|
)
|
||||||
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
|
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
|
||||||
fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
|
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||||
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
||||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
||||||
|
|
||||||
|
|||||||
@@ -134,42 +134,45 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
|
|||||||
if len(desc_b) > len(desc_a):
|
if len(desc_b) > len(desc_a):
|
||||||
canonical.description = desc_b
|
canonical.description = desc_b
|
||||||
# 合并事实摘要:统一保留一个“实体: name”行,来源行去重保序
|
# 合并事实摘要:统一保留一个“实体: name”行,来源行去重保序
|
||||||
fact_a = getattr(canonical, "fact_summary", "") or ""
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
fact_b = getattr(ent, "fact_summary", "") or ""
|
# fact_a = getattr(canonical, "fact_summary", "") or ""
|
||||||
def _extract_sources(txt: str) -> List[str]:
|
# fact_b = getattr(ent, "fact_summary", "") or ""
|
||||||
sources: List[str] = []
|
# def _extract_sources(txt: str) -> List[str]:
|
||||||
if not txt:
|
# sources: List[str] = []
|
||||||
return sources
|
# if not txt:
|
||||||
for line in str(txt).splitlines():
|
# return sources
|
||||||
ln = line.strip()
|
# for line in str(txt).splitlines():
|
||||||
|
# ln = line.strip()
|
||||||
# 支持“来源:”或“来源:”前缀
|
# 支持“来源:”或“来源:”前缀
|
||||||
m = re.match(r"^来源[::]\s*(.+)$", ln)
|
# m = re.match(r"^来源[::]\s*(.+)$", ln)
|
||||||
if m:
|
# if m:
|
||||||
content = m.group(1).strip()
|
# content = m.group(1).strip()
|
||||||
if content:
|
# if content:
|
||||||
sources.append(content)
|
# sources.append(content)
|
||||||
# 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失
|
# 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失
|
||||||
if not sources and txt.strip():
|
# if not sources and txt.strip():
|
||||||
sources.append(txt.strip())
|
# sources.append(txt.strip())
|
||||||
return sources
|
# return sources
|
||||||
try:
|
try:
|
||||||
src_a = _extract_sources(fact_a)
|
# src_a = _extract_sources(fact_a)
|
||||||
src_b = _extract_sources(fact_b)
|
# src_b = _extract_sources(fact_b)
|
||||||
seen = set()
|
# seen = set()
|
||||||
merged_sources: List[str] = []
|
# merged_sources: List[str] = []
|
||||||
for s in src_a + src_b:
|
# for s in src_a + src_b:
|
||||||
if s and s not in seen:
|
# if s and s not in seen:
|
||||||
seen.add(s)
|
# seen.add(s)
|
||||||
merged_sources.append(s)
|
# merged_sources.append(s)
|
||||||
if merged_sources:
|
# if merged_sources:
|
||||||
name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
|
# name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
|
||||||
canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
|
# canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
|
||||||
elif fact_b and not fact_a:
|
# elif fact_b and not fact_a:
|
||||||
canonical.fact_summary = fact_b
|
# canonical.fact_summary = fact_b
|
||||||
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
# 兜底:若解析失败,保留较长文本
|
# 兜底:若解析失败,保留较长文本
|
||||||
if len(fact_b) > len(fact_a):
|
# if len(fact_b) > len(fact_a):
|
||||||
canonical.fact_summary = fact_b
|
# canonical.fact_summary = fact_b
|
||||||
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -145,10 +145,13 @@ def _choose_canonical(a: ExtractedEntityNode, b: ExtractedEntityNode) -> int: #
|
|||||||
# 2. 第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整)
|
# 2. 第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整)
|
||||||
desc_a = (getattr(a, "description", "") or "")
|
desc_a = (getattr(a, "description", "") or "")
|
||||||
desc_b = (getattr(b, "description", "") or "")
|
desc_b = (getattr(b, "description", "") or "")
|
||||||
fact_a = (getattr(a, "fact_summary", "") or "")
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
fact_b = (getattr(b, "fact_summary", "") or "")
|
# fact_a = (getattr(a, "fact_summary", "") or "")
|
||||||
score_a = len(desc_a) + len(fact_a)
|
# fact_b = (getattr(b, "fact_summary", "") or "")
|
||||||
score_b = len(desc_b) + len(fact_b)
|
# score_a = len(desc_a) + len(fact_a)
|
||||||
|
# score_b = len(desc_b) + len(fact_b)
|
||||||
|
score_a = len(desc_a)
|
||||||
|
score_b = len(desc_b)
|
||||||
if score_a != score_b:
|
if score_a != score_b:
|
||||||
return 0 if score_a >= score_b else 1
|
return 0 if score_a >= score_b else 1
|
||||||
return 0
|
return 0
|
||||||
@@ -189,7 +192,8 @@ async def _judge_pair(
|
|||||||
"entity_type": getattr(a, "entity_type", None),
|
"entity_type": getattr(a, "entity_type", None),
|
||||||
"description": getattr(a, "description", None),
|
"description": getattr(a, "description", None),
|
||||||
"aliases": getattr(a, "aliases", None) or [],
|
"aliases": getattr(a, "aliases", None) or [],
|
||||||
"fact_summary": getattr(a, "fact_summary", None),
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
|
# "fact_summary": getattr(a, "fact_summary", None),
|
||||||
"connect_strength": getattr(a, "connect_strength", None),
|
"connect_strength": getattr(a, "connect_strength", None),
|
||||||
}
|
}
|
||||||
entity_b = {
|
entity_b = {
|
||||||
@@ -197,7 +201,8 @@ async def _judge_pair(
|
|||||||
"entity_type": getattr(b, "entity_type", None),
|
"entity_type": getattr(b, "entity_type", None),
|
||||||
"description": getattr(b, "description", None),
|
"description": getattr(b, "description", None),
|
||||||
"aliases": getattr(b, "aliases", None) or [],
|
"aliases": getattr(b, "aliases", None) or [],
|
||||||
"fact_summary": getattr(b, "fact_summary", None),
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
|
# "fact_summary": getattr(b, "fact_summary", None),
|
||||||
"connect_strength": getattr(b, "connect_strength", None),
|
"connect_strength": getattr(b, "connect_strength", None),
|
||||||
}
|
}
|
||||||
# 5. 渲染LLM提示词(用工具函数填充模板,包含实体信息、上下文、输出格式)
|
# 5. 渲染LLM提示词(用工具函数填充模板,包含实体信息、上下文、输出格式)
|
||||||
@@ -248,7 +253,8 @@ async def _judge_pair_disamb(
|
|||||||
"entity_type": getattr(a, "entity_type", None),
|
"entity_type": getattr(a, "entity_type", None),
|
||||||
"description": getattr(a, "description", None),
|
"description": getattr(a, "description", None),
|
||||||
"aliases": getattr(a, "aliases", None) or [],
|
"aliases": getattr(a, "aliases", None) or [],
|
||||||
"fact_summary": getattr(a, "fact_summary", None),
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
|
# "fact_summary": getattr(a, "fact_summary", None),
|
||||||
"connect_strength": getattr(a, "connect_strength", None),
|
"connect_strength": getattr(a, "connect_strength", None),
|
||||||
}
|
}
|
||||||
entity_b = {
|
entity_b = {
|
||||||
@@ -256,7 +262,8 @@ async def _judge_pair_disamb(
|
|||||||
"entity_type": getattr(b, "entity_type", None),
|
"entity_type": getattr(b, "entity_type", None),
|
||||||
"description": getattr(b, "description", None),
|
"description": getattr(b, "description", None),
|
||||||
"aliases": getattr(b, "aliases", None) or [],
|
"aliases": getattr(b, "aliases", None) or [],
|
||||||
"fact_summary": getattr(b, "fact_summary", None),
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
|
# "fact_summary": getattr(b, "fact_summary", None),
|
||||||
"connect_strength": getattr(b, "connect_strength", None),
|
"connect_strength": getattr(b, "connect_strength", None),
|
||||||
}
|
}
|
||||||
prompt = render_entity_dedup_prompt(
|
prompt = render_entity_dedup_prompt(
|
||||||
|
|||||||
@@ -72,7 +72,8 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
|
|||||||
description=row.get("description") or "",
|
description=row.get("description") or "",
|
||||||
aliases=row.get("aliases") or [],
|
aliases=row.get("aliases") or [],
|
||||||
name_embedding=row.get("name_embedding") or [],
|
name_embedding=row.get("name_embedding") or [],
|
||||||
fact_summary=row.get("fact_summary") or "",
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
|
# fact_summary=row.get("fact_summary") or "",
|
||||||
connect_strength=row.get("connect_strength") or "",
|
connect_strength=row.get("connect_strength") or "",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1088,7 +1088,8 @@ class ExtractionOrchestrator:
|
|||||||
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
|
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
|
||||||
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
|
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
|
||||||
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
||||||
fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
|
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
||||||
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
||||||
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
||||||
name_embedding=getattr(entity, 'name_embedding', None),
|
name_embedding=getattr(entity, 'name_embedding', None),
|
||||||
|
|||||||
@@ -296,7 +296,9 @@ def resolve_alias_cycles(entities: List[Any], cycles: Dict[str, Set[str]]) -> Li
|
|||||||
key=lambda eid: (
|
key=lambda eid: (
|
||||||
_strength_rank(eid),
|
_strength_rank(eid),
|
||||||
len(getattr(entity_by_id.get(eid), 'description', '') or ''),
|
len(getattr(entity_by_id.get(eid), 'description', '') or ''),
|
||||||
len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
|
# len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
|
||||||
|
0 # 临时占位
|
||||||
),
|
),
|
||||||
reverse=True
|
reverse=True
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -9,7 +9,8 @@
|
|||||||
- 类型: "{{ entity_a.entity_type | default('') }}"
|
- 类型: "{{ entity_a.entity_type | default('') }}"
|
||||||
- 描述: "{{ entity_a.description | default('') }}"
|
- 描述: "{{ entity_a.description | default('') }}"
|
||||||
- 别名: {{ entity_a.aliases | default([]) }}
|
- 别名: {{ entity_a.aliases | default([]) }}
|
||||||
- 摘要: "{{ entity_a.fact_summary | default('') }}"
|
{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #}
|
||||||
|
{# - 摘要: "{{ entity_a.fact_summary | default('') }}" #}
|
||||||
- 连接强弱: "{{ entity_a.connect_strength | default('') }}"
|
- 连接强弱: "{{ entity_a.connect_strength | default('') }}"
|
||||||
|
|
||||||
实体B:
|
实体B:
|
||||||
@@ -17,7 +18,8 @@
|
|||||||
- 类型: "{{ entity_b.entity_type | default('') }}"
|
- 类型: "{{ entity_b.entity_type | default('') }}"
|
||||||
- 描述: "{{ entity_b.description | default('') }}"
|
- 描述: "{{ entity_b.description | default('') }}"
|
||||||
- 别名: {{ entity_b.aliases | default([]) }}
|
- 别名: {{ entity_b.aliases | default([]) }}
|
||||||
- 摘要: "{{ entity_b.fact_summary | default('') }}"
|
{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #}
|
||||||
|
{# - 摘要: "{{ entity_b.fact_summary | default('') }}" #}
|
||||||
- 连接强弱: "{{ entity_b.connect_strength | default('') }}"
|
- 连接强弱: "{{ entity_b.connect_strength | default('') }}"
|
||||||
|
|
||||||
上下文:
|
上下文:
|
||||||
|
|||||||
0
api/app/core/rag/crawler/__init__.py
Normal file
0
api/app/core/rag/crawler/__init__.py
Normal file
89
api/app/core/rag/crawler/__main__.py
Normal file
89
api/app/core/rag/crawler/__main__.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
"""Command-line interface for web crawler."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from app.core.rag.crawler.web_crawler import WebCrawler
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(verbose: bool = False):
|
||||||
|
"""Set up logging configuration."""
|
||||||
|
level = logging.DEBUG if verbose else logging.INFO
|
||||||
|
logging.basicConfig(
|
||||||
|
level=level,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[
|
||||||
|
logging.StreamHandler(sys.stdout)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(entry_url: str,
|
||||||
|
max_pages: int = 200,
|
||||||
|
delay_seconds: float = 1.0,
|
||||||
|
timeout_seconds: int = 10,
|
||||||
|
user_agent: str = "KnowledgeBaseCrawler/1.0"):
|
||||||
|
"""Main entry point for the crawler."""
|
||||||
|
# Create crawler
|
||||||
|
crawler = WebCrawler(
|
||||||
|
entry_url=entry_url,
|
||||||
|
max_pages=max_pages,
|
||||||
|
delay_seconds=delay_seconds,
|
||||||
|
timeout_seconds=timeout_seconds,
|
||||||
|
user_agent=user_agent
|
||||||
|
)
|
||||||
|
|
||||||
|
# Crawl and collect documents
|
||||||
|
documents = []
|
||||||
|
try:
|
||||||
|
for doc in crawler.crawl():
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print(f"URL: {doc.url}")
|
||||||
|
print(f"Title: {doc.title}")
|
||||||
|
print(f"Content Length: {doc.content_length} characters")
|
||||||
|
print(f"Word Count: {doc.metadata.get('word_count', 0)} words")
|
||||||
|
print(f"{'=' * 80}\n")
|
||||||
|
|
||||||
|
documents.append({
|
||||||
|
'url': doc.url,
|
||||||
|
'title': doc.title,
|
||||||
|
'content': doc.content,
|
||||||
|
'content_length': doc.content_length,
|
||||||
|
'crawl_timestamp': doc.crawl_timestamp.isoformat(),
|
||||||
|
'http_status': doc.http_status,
|
||||||
|
'metadata': doc.metadata
|
||||||
|
})
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\nCrawl interrupted by user.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n\nError during crawl: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Get summary
|
||||||
|
summary = crawler.get_summary()
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print("CRAWL SUMMARY")
|
||||||
|
print(f"{'=' * 80}")
|
||||||
|
print(f"Total Pages Processed: {summary.total_pages_processed}")
|
||||||
|
print(f"Total Errors: {summary.total_errors}")
|
||||||
|
print(f"Total Skipped: {summary.total_skipped}")
|
||||||
|
print(f"Total URLs Discovered: {summary.total_urls_discovered}")
|
||||||
|
print(f"Duration: {summary.duration_seconds:.2f} seconds")
|
||||||
|
print(f"documents: {documents}")
|
||||||
|
|
||||||
|
if summary.error_breakdown:
|
||||||
|
print(f"\nError Breakdown:")
|
||||||
|
for error_type, count in summary.error_breakdown.items():
|
||||||
|
print(f" {error_type}: {count}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
entry_url = "https://www.xxx.com"
|
||||||
|
max_pages = 20
|
||||||
|
delay_seconds = 1.0
|
||||||
|
timeout_seconds = 10
|
||||||
|
user_agent = "KnowledgeBaseCrawler/1.0"
|
||||||
|
|
||||||
|
main(entry_url, max_pages, delay_seconds, timeout_seconds, user_agent)
|
||||||
233
api/app/core/rag/crawler/content_extractor.py
Normal file
233
api/app/core/rag/crawler/content_extractor.py
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
"""Content extractor for web crawler."""
|
||||||
|
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
import re
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from app.core.rag.crawler.models import ExtractedContent
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ContentExtractor:
|
||||||
|
"""Extract clean, readable text from HTML pages."""
|
||||||
|
|
||||||
|
# Tags to remove completely
|
||||||
|
REMOVE_TAGS = ['script', 'style', 'nav', 'header', 'footer', 'aside']
|
||||||
|
|
||||||
|
# Tags that typically contain main content
|
||||||
|
MAIN_CONTENT_TAGS = ['article', 'main']
|
||||||
|
|
||||||
|
# Content extraction tags
|
||||||
|
CONTENT_TAGS = ['p', 'div', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li', 'td', 'th', 'section']
|
||||||
|
|
||||||
|
def is_static_content(self, html: str) -> bool:
|
||||||
|
"""
|
||||||
|
Determine if the HTML represents static content.
|
||||||
|
|
||||||
|
Detects JavaScript-rendered content by checking for minimal body
|
||||||
|
with heavy script tag presence.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
html: Raw HTML string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if static, False if JavaScript-rendered
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
soup = BeautifulSoup(html, 'lxml')
|
||||||
|
|
||||||
|
# Count script tags
|
||||||
|
script_tags = soup.find_all('script')
|
||||||
|
script_count = len(script_tags)
|
||||||
|
|
||||||
|
# Get body content (excluding scripts and styles)
|
||||||
|
body = soup.find('body')
|
||||||
|
if not body:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Remove scripts and styles temporarily for text check
|
||||||
|
for tag in body.find_all(['script', 'style']):
|
||||||
|
tag.decompose()
|
||||||
|
|
||||||
|
# Get text content
|
||||||
|
text = body.get_text(strip=True)
|
||||||
|
text_length = len(text)
|
||||||
|
|
||||||
|
# If there's very little text but many scripts, likely JS-rendered
|
||||||
|
if script_count > 5 and text_length < 200:
|
||||||
|
logger.warning("Detected JavaScript-rendered content (many scripts, little text)")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# If there's no meaningful text, likely JS-rendered
|
||||||
|
if text_length < 50:
|
||||||
|
logger.warning("Detected JavaScript-rendered content (minimal text)")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error checking if content is static: {e}")
|
||||||
|
return True # Assume static on error
|
||||||
|
|
||||||
|
def extract(self, html: str, url: str) -> ExtractedContent:
|
||||||
|
"""
|
||||||
|
Extract clean text content from HTML.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
html: Raw HTML string
|
||||||
|
url: Source URL (for context)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ExtractedContent: Contains title, text, metadata
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
soup = BeautifulSoup(html, 'lxml')
|
||||||
|
|
||||||
|
# Check if content is static
|
||||||
|
is_static = self.is_static_content(html)
|
||||||
|
|
||||||
|
# Extract title
|
||||||
|
title = self._extract_title(soup)
|
||||||
|
|
||||||
|
# Remove unwanted tags
|
||||||
|
for tag_name in self.REMOVE_TAGS:
|
||||||
|
for tag in soup.find_all(tag_name):
|
||||||
|
tag.decompose()
|
||||||
|
|
||||||
|
# Extract main content
|
||||||
|
text = self._extract_main_content(soup)
|
||||||
|
|
||||||
|
# Normalize whitespace
|
||||||
|
text = self._normalize_whitespace(text)
|
||||||
|
|
||||||
|
# Count words
|
||||||
|
word_count = len(text.split())
|
||||||
|
|
||||||
|
logger.info(f"Extracted {word_count} words from {url}")
|
||||||
|
|
||||||
|
return ExtractedContent(
|
||||||
|
title=title,
|
||||||
|
text=text,
|
||||||
|
is_static=is_static,
|
||||||
|
word_count=word_count,
|
||||||
|
metadata={'url': url}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error extracting content from {url}: {e}")
|
||||||
|
return ExtractedContent(
|
||||||
|
title=url,
|
||||||
|
text="",
|
||||||
|
is_static=False,
|
||||||
|
word_count=0,
|
||||||
|
metadata={'url': url, 'error': str(e)}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _extract_title(self, soup: BeautifulSoup) -> str:
|
||||||
|
"""
|
||||||
|
Extract title from HTML.
|
||||||
|
|
||||||
|
Tries <title> tag first, then first <h1>.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
soup: BeautifulSoup object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Page title
|
||||||
|
"""
|
||||||
|
# Try <title> tag
|
||||||
|
title_tag = soup.find('title')
|
||||||
|
if title_tag and title_tag.string:
|
||||||
|
return title_tag.string.strip()
|
||||||
|
|
||||||
|
# Try first <h1>
|
||||||
|
h1_tag = soup.find('h1')
|
||||||
|
if h1_tag:
|
||||||
|
return h1_tag.get_text(strip=True)
|
||||||
|
|
||||||
|
# Default to empty string
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _extract_main_content(self, soup: BeautifulSoup) -> str:
|
||||||
|
"""
|
||||||
|
Extract main content from HTML.
|
||||||
|
|
||||||
|
Prioritizes semantic HTML5 elements like <article> and <main>.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
soup: BeautifulSoup object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Extracted text content
|
||||||
|
"""
|
||||||
|
# Try to find main content area
|
||||||
|
main_content = None
|
||||||
|
|
||||||
|
# Priority 1: <article> or <main> tags
|
||||||
|
for tag_name in self.MAIN_CONTENT_TAGS:
|
||||||
|
main_content = soup.find(tag_name)
|
||||||
|
if main_content:
|
||||||
|
logger.debug(f"Found main content in <{tag_name}> tag")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Priority 2: div with role="main"
|
||||||
|
if not main_content:
|
||||||
|
main_content = soup.find('div', role='main')
|
||||||
|
if main_content:
|
||||||
|
logger.debug("Found main content in div[role='main']")
|
||||||
|
|
||||||
|
# Priority 3: Common class/id patterns
|
||||||
|
if not main_content:
|
||||||
|
for pattern in ['content', 'main', 'article', 'post']:
|
||||||
|
main_content = soup.find(['div', 'section'], class_=re.compile(pattern, re.I))
|
||||||
|
if main_content:
|
||||||
|
logger.debug(f"Found main content with class pattern '{pattern}'")
|
||||||
|
break
|
||||||
|
|
||||||
|
main_content = soup.find(['div', 'section'], id=re.compile(pattern, re.I))
|
||||||
|
if main_content:
|
||||||
|
logger.debug(f"Found main content with id pattern '{pattern}'")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Fallback: use body
|
||||||
|
if not main_content:
|
||||||
|
main_content = soup.find('body')
|
||||||
|
logger.debug("Using <body> as main content (no specific content area found)")
|
||||||
|
|
||||||
|
# Extract text from content tags
|
||||||
|
if main_content:
|
||||||
|
text_parts = []
|
||||||
|
for tag in main_content.find_all(self.CONTENT_TAGS):
|
||||||
|
text = tag.get_text(strip=True)
|
||||||
|
if text:
|
||||||
|
text_parts.append(text)
|
||||||
|
|
||||||
|
return '\n'.join(text_parts)
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _normalize_whitespace(self, text: str) -> str:
|
||||||
|
"""
|
||||||
|
Normalize whitespace in text.
|
||||||
|
|
||||||
|
- Collapse multiple spaces to single space
|
||||||
|
- Reduce excessive newlines to maximum 2
|
||||||
|
- Strip leading/trailing whitespace
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to normalize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Normalized text
|
||||||
|
"""
|
||||||
|
# Collapse multiple spaces to single space
|
||||||
|
text = re.sub(r' +', ' ', text)
|
||||||
|
|
||||||
|
# Reduce excessive newlines to maximum 2
|
||||||
|
text = re.sub(r'\n{3,}', '\n\n', text)
|
||||||
|
|
||||||
|
# Strip leading/trailing whitespace
|
||||||
|
text = text.strip()
|
||||||
|
|
||||||
|
return text
|
||||||
302
api/app/core/rag/crawler/http_fetcher.py
Normal file
302
api/app/core/rag/crawler/http_fetcher.py
Normal file
@@ -0,0 +1,302 @@
|
|||||||
|
"""HTTP fetcher for web crawler."""
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Optional, Dict
|
||||||
|
|
||||||
|
|
||||||
|
from app.core.rag.crawler.models import FetchResult
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPFetcher:
|
||||||
|
"""Handle HTTP requests with retries, error handling, and response validation."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
timeout: int = 10,
|
||||||
|
max_retries: int = 3,
|
||||||
|
user_agent: str = "KnowledgeBaseCrawler/1.0"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize HTTP fetcher.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: Request timeout in seconds
|
||||||
|
max_retries: Maximum number of retry attempts
|
||||||
|
user_agent: User-Agent header value
|
||||||
|
"""
|
||||||
|
self.timeout = timeout
|
||||||
|
self.max_retries = max_retries
|
||||||
|
self.user_agent = user_agent
|
||||||
|
|
||||||
|
# Create session for connection pooling
|
||||||
|
self.session = requests.Session()
|
||||||
|
self.session.headers.update({
|
||||||
|
'User-Agent': user_agent
|
||||||
|
})
|
||||||
|
|
||||||
|
def fetch(self, url: str) -> FetchResult:
|
||||||
|
"""
|
||||||
|
Fetch a URL with retry logic and error handling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: URL to fetch
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FetchResult: Contains status_code, content, headers, error info
|
||||||
|
"""
|
||||||
|
last_error = None
|
||||||
|
|
||||||
|
for attempt in range(self.max_retries):
|
||||||
|
try:
|
||||||
|
# Calculate backoff delay for retries
|
||||||
|
if attempt > 0:
|
||||||
|
backoff_delay = 2 ** (attempt - 1) # 1s, 2s, 4s
|
||||||
|
logger.info(f"Retry attempt {attempt + 1}/{self.max_retries} for {url} after {backoff_delay}s")
|
||||||
|
time.sleep(backoff_delay)
|
||||||
|
|
||||||
|
# Make HTTP request
|
||||||
|
response = self.session.get(
|
||||||
|
url,
|
||||||
|
timeout=self.timeout,
|
||||||
|
allow_redirects=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle different status codes
|
||||||
|
if response.status_code == 429:
|
||||||
|
# Too Many Requests - backoff and retry
|
||||||
|
logger.warning(f"429 Too Many Requests for {url}, backing off")
|
||||||
|
if attempt < self.max_retries - 1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if response.status_code == 503:
|
||||||
|
# Service Unavailable - pause and retry
|
||||||
|
logger.warning(f"503 Service Unavailable for {url}")
|
||||||
|
if attempt < self.max_retries - 1:
|
||||||
|
time.sleep(5) # Longer pause for 503
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Success or client error (don't retry 4xx except 429)
|
||||||
|
if 200 <= response.status_code < 300:
|
||||||
|
logger.info(f"Successfully fetched {url} (status: {response.status_code})")
|
||||||
|
|
||||||
|
# Get correctly encoded content
|
||||||
|
content = self._get_decoded_content(response)
|
||||||
|
|
||||||
|
return FetchResult(
|
||||||
|
url=url,
|
||||||
|
final_url=response.url,
|
||||||
|
status_code=response.status_code,
|
||||||
|
content=content,
|
||||||
|
headers=dict(response.headers),
|
||||||
|
error=None,
|
||||||
|
success=True
|
||||||
|
)
|
||||||
|
elif response.status_code == 404:
|
||||||
|
logger.info(f"404 Not Found: {url}")
|
||||||
|
return FetchResult(
|
||||||
|
url=url,
|
||||||
|
final_url=response.url,
|
||||||
|
status_code=response.status_code,
|
||||||
|
content=None,
|
||||||
|
headers=dict(response.headers),
|
||||||
|
error="Not Found",
|
||||||
|
success=False
|
||||||
|
)
|
||||||
|
elif 400 <= response.status_code < 500:
|
||||||
|
logger.warning(f"Client error {response.status_code} for {url}")
|
||||||
|
return FetchResult(
|
||||||
|
url=url,
|
||||||
|
final_url=response.url,
|
||||||
|
status_code=response.status_code,
|
||||||
|
content=None,
|
||||||
|
headers=dict(response.headers),
|
||||||
|
error=f"Client error: {response.status_code}",
|
||||||
|
success=False
|
||||||
|
)
|
||||||
|
elif 500 <= response.status_code < 600:
|
||||||
|
logger.error(f"Server error {response.status_code} for {url}")
|
||||||
|
last_error = f"Server error: {response.status_code}"
|
||||||
|
if attempt < self.max_retries - 1:
|
||||||
|
continue
|
||||||
|
return FetchResult(
|
||||||
|
url=url,
|
||||||
|
final_url=url,
|
||||||
|
status_code=response.status_code,
|
||||||
|
content=None,
|
||||||
|
headers={},
|
||||||
|
error=last_error,
|
||||||
|
success=False
|
||||||
|
)
|
||||||
|
|
||||||
|
except requests.exceptions.Timeout:
|
||||||
|
last_error = "Request timeout"
|
||||||
|
logger.warning(f"Timeout fetching {url} (attempt {attempt + 1}/{self.max_retries})")
|
||||||
|
if attempt >= self.max_retries - 1:
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
|
||||||
|
except requests.exceptions.SSLError as e:
|
||||||
|
last_error = f"SSL/TLS error: {str(e)}"
|
||||||
|
logger.error(f"SSL/TLS error for {url}: {e}")
|
||||||
|
return FetchResult(
|
||||||
|
url=url,
|
||||||
|
final_url=url,
|
||||||
|
status_code=0,
|
||||||
|
content=None,
|
||||||
|
headers={},
|
||||||
|
error=last_error,
|
||||||
|
success=False
|
||||||
|
)
|
||||||
|
|
||||||
|
except requests.exceptions.ConnectionError as e:
|
||||||
|
last_error = f"Connection error: {str(e)}"
|
||||||
|
logger.warning(f"Connection error for {url} (attempt {attempt + 1}/{self.max_retries}): {e}")
|
||||||
|
if attempt >= self.max_retries - 1:
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
last_error = f"Request error: {str(e)}"
|
||||||
|
logger.error(f"Request error for {url}: {e}")
|
||||||
|
if attempt >= self.max_retries - 1:
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
|
||||||
|
# All retries exhausted
|
||||||
|
logger.error(f"Failed to fetch {url} after {self.max_retries} attempts: {last_error}")
|
||||||
|
return FetchResult(
|
||||||
|
url=url,
|
||||||
|
final_url=url,
|
||||||
|
status_code=0,
|
||||||
|
content=None,
|
||||||
|
headers={},
|
||||||
|
error=last_error or "Unknown error",
|
||||||
|
success=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_decoded_content(self, response) -> str:
|
||||||
|
"""
|
||||||
|
Get correctly decoded content from response.
|
||||||
|
|
||||||
|
Handles encoding detection and fallback strategies:
|
||||||
|
1. Try encoding from HTML meta tags
|
||||||
|
2. Try response.encoding (from Content-Type header or detected)
|
||||||
|
3. Try UTF-8
|
||||||
|
4. Try common encodings (GB2312, GBK for Chinese, etc.)
|
||||||
|
5. Fall back to latin-1 with error replacement
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: requests.Response object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Decoded content
|
||||||
|
"""
|
||||||
|
# Try to detect encoding from HTML meta tags
|
||||||
|
meta_encoding = self._detect_encoding_from_meta(response.content)
|
||||||
|
if meta_encoding:
|
||||||
|
try:
|
||||||
|
content = response.content.decode(meta_encoding)
|
||||||
|
logger.info(f"Successfully decoded with meta tag encoding: {meta_encoding}")
|
||||||
|
return content
|
||||||
|
except (UnicodeDecodeError, LookupError) as e:
|
||||||
|
logger.warning(f"Failed to decode with meta encoding {meta_encoding}: {e}")
|
||||||
|
|
||||||
|
# Try response.encoding (from Content-Type header or detected by requests)
|
||||||
|
if response.encoding and response.encoding.lower() != 'iso-8859-1':
|
||||||
|
# Note: requests defaults to ISO-8859-1 if no charset in Content-Type,
|
||||||
|
# so we skip it here and try UTF-8 first
|
||||||
|
try:
|
||||||
|
return response.text
|
||||||
|
except (UnicodeDecodeError, LookupError) as e:
|
||||||
|
logger.warning(f"Failed to decode with detected encoding {response.encoding}: {e}")
|
||||||
|
|
||||||
|
# Try UTF-8 first (most common)
|
||||||
|
try:
|
||||||
|
return response.content.decode('utf-8')
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
logger.debug("UTF-8 decoding failed, trying other encodings")
|
||||||
|
|
||||||
|
# Try common encodings for different languages
|
||||||
|
encodings_to_try = [
|
||||||
|
'gbk', # Chinese (Simplified)
|
||||||
|
'gb2312', # Chinese (Simplified, older)
|
||||||
|
'gb18030', # Chinese (Simplified, extended)
|
||||||
|
'big5', # Chinese (Traditional)
|
||||||
|
'shift_jis', # Japanese
|
||||||
|
'euc-jp', # Japanese
|
||||||
|
'euc-kr', # Korean
|
||||||
|
'iso-8859-1', # Western European
|
||||||
|
'windows-1252', # Windows Western European
|
||||||
|
'windows-1251', # Cyrillic
|
||||||
|
]
|
||||||
|
|
||||||
|
for encoding in encodings_to_try:
|
||||||
|
try:
|
||||||
|
content = response.content.decode(encoding)
|
||||||
|
logger.info(f"Successfully decoded with {encoding}")
|
||||||
|
return content
|
||||||
|
except (UnicodeDecodeError, LookupError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Last resort: use latin-1 with error replacement
|
||||||
|
logger.warning("All encoding attempts failed, using latin-1 with error replacement")
|
||||||
|
return response.content.decode('latin-1', errors='replace')
|
||||||
|
|
||||||
|
def _detect_encoding_from_meta(self, content: bytes) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Detect encoding from HTML meta tags.
|
||||||
|
|
||||||
|
Looks for:
|
||||||
|
- <meta charset="...">
|
||||||
|
- <meta http-equiv="Content-Type" content="...; charset=...">
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Raw response content (bytes)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[str]: Detected encoding or None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Only check first 2KB for performance
|
||||||
|
head = content[:2048]
|
||||||
|
|
||||||
|
# Try to decode as ASCII/Latin-1 to search for meta tags
|
||||||
|
try:
|
||||||
|
head_str = head.decode('ascii', errors='ignore')
|
||||||
|
except:
|
||||||
|
head_str = head.decode('latin-1', errors='ignore')
|
||||||
|
|
||||||
|
# Look for <meta charset="...">
|
||||||
|
charset_match = re.search(
|
||||||
|
r'<meta[^>]+charset=["\']?([a-zA-Z0-9_-]+)',
|
||||||
|
head_str,
|
||||||
|
re.IGNORECASE
|
||||||
|
)
|
||||||
|
if charset_match:
|
||||||
|
encoding = charset_match.group(1).lower()
|
||||||
|
logger.debug(f"Found charset in meta tag: {encoding}")
|
||||||
|
return encoding
|
||||||
|
|
||||||
|
# Look for <meta http-equiv="Content-Type" content="...; charset=...">
|
||||||
|
content_type_match = re.search(
|
||||||
|
r'<meta[^>]+http-equiv=["\']?content-type["\']?[^>]+content=["\']([^"\']+)',
|
||||||
|
head_str,
|
||||||
|
re.IGNORECASE
|
||||||
|
)
|
||||||
|
if content_type_match:
|
||||||
|
content_value = content_type_match.group(1)
|
||||||
|
charset_match = re.search(r'charset=([a-zA-Z0-9_-]+)', content_value, re.IGNORECASE)
|
||||||
|
if charset_match:
|
||||||
|
encoding = charset_match.group(1).lower()
|
||||||
|
logger.debug(f"Found charset in Content-Type meta: {encoding}")
|
||||||
|
return encoding
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error detecting encoding from meta tags: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
52
api/app/core/rag/crawler/models.py
Normal file
52
api/app/core/rag/crawler/models.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
"""Data models for web crawler."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CrawledDocument:
|
||||||
|
"""Represents a successfully processed web page with extracted content."""
|
||||||
|
url: str
|
||||||
|
title: str
|
||||||
|
content: str
|
||||||
|
content_length: int
|
||||||
|
crawl_timestamp: datetime
|
||||||
|
http_status: int
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FetchResult:
|
||||||
|
"""Represents the result of an HTTP fetch operation."""
|
||||||
|
url: str
|
||||||
|
final_url: str
|
||||||
|
status_code: int
|
||||||
|
content: Optional[str]
|
||||||
|
headers: Dict[str, str]
|
||||||
|
error: Optional[str]
|
||||||
|
success: bool
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExtractedContent:
|
||||||
|
"""Represents content extracted from HTML."""
|
||||||
|
title: str
|
||||||
|
text: str
|
||||||
|
is_static: bool
|
||||||
|
word_count: int
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CrawlSummary:
|
||||||
|
"""Represents statistics from a completed crawl."""
|
||||||
|
total_pages_processed: int
|
||||||
|
total_errors: int
|
||||||
|
total_skipped: int
|
||||||
|
total_urls_discovered: int
|
||||||
|
start_time: datetime
|
||||||
|
end_time: datetime
|
||||||
|
duration_seconds: float
|
||||||
|
error_breakdown: Dict[str, int] = field(default_factory=dict)
|
||||||
57
api/app/core/rag/crawler/rate_limiter.py
Normal file
57
api/app/core/rag/crawler/rate_limiter.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
"""Rate limiter for web crawler."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimiter:
|
||||||
|
"""Enforce delays between requests to be polite to servers."""
|
||||||
|
|
||||||
|
def __init__(self, delay_seconds: float = 1.0):
|
||||||
|
"""
|
||||||
|
Initialize rate limiter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
delay_seconds: Minimum delay between requests
|
||||||
|
"""
|
||||||
|
self.delay_seconds = delay_seconds
|
||||||
|
self.last_request_time = 0.0
|
||||||
|
self.max_delay = 60.0 # Cap maximum delay at 60 seconds
|
||||||
|
|
||||||
|
def wait(self):
|
||||||
|
"""
|
||||||
|
Block until enough time has passed since last request.
|
||||||
|
Respects the configured delay.
|
||||||
|
"""
|
||||||
|
current_time = time.time()
|
||||||
|
elapsed = current_time - self.last_request_time
|
||||||
|
|
||||||
|
if elapsed < self.delay_seconds:
|
||||||
|
sleep_time = self.delay_seconds - elapsed
|
||||||
|
logger.debug(f"Rate limiting: sleeping for {sleep_time:.2f} seconds")
|
||||||
|
time.sleep(sleep_time)
|
||||||
|
|
||||||
|
self.last_request_time = time.time()
|
||||||
|
|
||||||
|
def set_delay(self, delay_seconds: float):
|
||||||
|
"""
|
||||||
|
Update the delay (useful for respecting Crawl-delay from robots.txt).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
delay_seconds: New delay in seconds
|
||||||
|
"""
|
||||||
|
self.delay_seconds = min(delay_seconds, self.max_delay)
|
||||||
|
logger.info(f"Rate limiter delay updated to {self.delay_seconds} seconds")
|
||||||
|
|
||||||
|
def backoff(self, multiplier: float = 2.0):
|
||||||
|
"""
|
||||||
|
Increase delay exponentially for backoff scenarios (429, 503 responses).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
multiplier: Factor to multiply current delay by
|
||||||
|
"""
|
||||||
|
old_delay = self.delay_seconds
|
||||||
|
self.delay_seconds = min(self.delay_seconds * multiplier, self.max_delay)
|
||||||
|
logger.warning(f"Rate limiter backing off: {old_delay:.2f}s -> {self.delay_seconds:.2f}s")
|
||||||
118
api/app/core/rag/crawler/robots_parser.py
Normal file
118
api/app/core/rag/crawler/robots_parser.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
"""Robots.txt parser for web crawler."""
|
||||||
|
|
||||||
|
from urllib.robotparser import RobotFileParser
|
||||||
|
from urllib.parse import urlparse, urljoin
|
||||||
|
from typing import Optional
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RobotsParser:
|
||||||
|
"""Parse and check robots.txt compliance for URLs."""
|
||||||
|
|
||||||
|
def __init__(self, user_agent: str, timeout: int = 10):
|
||||||
|
"""
|
||||||
|
Initialize robots.txt parser.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_agent: User agent string to check permissions for
|
||||||
|
timeout: Timeout for fetching robots.txt
|
||||||
|
"""
|
||||||
|
self.user_agent = user_agent
|
||||||
|
self.timeout = timeout
|
||||||
|
self._parsers = {} # Cache parsers by domain
|
||||||
|
|
||||||
|
def _get_robots_url(self, url: str) -> str:
|
||||||
|
"""
|
||||||
|
Get the robots.txt URL for a given URL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: URL to get robots.txt for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: robots.txt URL
|
||||||
|
"""
|
||||||
|
parsed = urlparse(url)
|
||||||
|
robots_url = f"{parsed.scheme}://{parsed.netloc}/robots.txt"
|
||||||
|
return robots_url
|
||||||
|
|
||||||
|
def _get_parser(self, url: str) -> RobotFileParser:
|
||||||
|
"""
|
||||||
|
Get or create a RobotFileParser for the domain.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: URL to get parser for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RobotFileParser: Parser for the domain
|
||||||
|
"""
|
||||||
|
robots_url = self._get_robots_url(url)
|
||||||
|
|
||||||
|
# Return cached parser if available
|
||||||
|
if robots_url in self._parsers:
|
||||||
|
return self._parsers[robots_url]
|
||||||
|
|
||||||
|
# Create new parser
|
||||||
|
parser = RobotFileParser()
|
||||||
|
parser.set_url(robots_url)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Fetch and parse robots.txt
|
||||||
|
parser.read()
|
||||||
|
logger.info(f"Successfully fetched robots.txt from {robots_url}")
|
||||||
|
except Exception as e:
|
||||||
|
# If robots.txt cannot be fetched, assume all URLs are allowed
|
||||||
|
logger.warning(f"Could not fetch robots.txt from {robots_url}: {e}. Assuming all URLs allowed.")
|
||||||
|
# Create a permissive parser
|
||||||
|
parser = RobotFileParser()
|
||||||
|
parser.parse([]) # Empty robots.txt allows everything
|
||||||
|
|
||||||
|
# Cache the parser
|
||||||
|
self._parsers[robots_url] = parser
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def can_fetch(self, url: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the given URL can be fetched according to robots.txt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: URL to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if allowed, False if disallowed
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
parser = self._get_parser(url)
|
||||||
|
allowed = parser.can_fetch(self.user_agent, url)
|
||||||
|
|
||||||
|
if not allowed:
|
||||||
|
logger.info(f"URL disallowed by robots.txt: {url}")
|
||||||
|
|
||||||
|
return allowed
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error checking robots.txt for {url}: {e}")
|
||||||
|
# On error, assume allowed
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_crawl_delay(self, url: str) -> Optional[float]:
|
||||||
|
"""
|
||||||
|
Get the Crawl-delay directive from robots.txt if present.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: URL to get crawl delay for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[float]: Delay in seconds, or None if not specified
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
parser = self._get_parser(url)
|
||||||
|
delay = parser.crawl_delay(self.user_agent)
|
||||||
|
|
||||||
|
if delay is not None:
|
||||||
|
logger.info(f"Crawl-delay from robots.txt: {delay} seconds")
|
||||||
|
|
||||||
|
return delay
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting crawl delay for {url}: {e}")
|
||||||
|
return None
|
||||||
171
api/app/core/rag/crawler/url_normalizer.py
Normal file
171
api/app/core/rag/crawler/url_normalizer.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
"""URL normalization and validation for web crawler."""
|
||||||
|
|
||||||
|
from typing import Optional, List
|
||||||
|
from urllib.parse import urlparse, urlunparse, parse_qs, urlencode, urljoin
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
|
|
||||||
|
class URLNormalizer:
|
||||||
|
"""Normalize and validate URLs for deduplication and domain checking."""
|
||||||
|
|
||||||
|
# Common tracking parameters to remove
|
||||||
|
TRACKING_PARAMS = {
|
||||||
|
'utm_source', 'utm_medium', 'utm_campaign', 'utm_term', 'utm_content',
|
||||||
|
'fbclid', 'gclid', 'msclkid', '_ga', 'mc_cid', 'mc_eid'
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, base_domain: str):
|
||||||
|
"""
|
||||||
|
Initialize URL normalizer with base domain.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_domain: The domain to use for same-domain checks
|
||||||
|
"""
|
||||||
|
parsed = urlparse(base_domain)
|
||||||
|
self.base_domain = parsed.netloc.lower() # example.com:8000
|
||||||
|
self.base_scheme = parsed.scheme or 'https' # https
|
||||||
|
|
||||||
|
def normalize(self, url: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Normalize a URL for deduplication.
|
||||||
|
|
||||||
|
Normalization rules:
|
||||||
|
1. Convert domain to lowercase
|
||||||
|
2. Remove fragments (#section)
|
||||||
|
3. Remove default ports (80 for http, 443 for https)
|
||||||
|
4. Remove trailing slashes (except for root)
|
||||||
|
5. Sort query parameters alphabetically
|
||||||
|
6. Remove common tracking parameters
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: URL to normalize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[str]: Normalized URL, or None if invalid
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
|
||||||
|
# Validate scheme
|
||||||
|
if parsed.scheme not in ('http', 'https'):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Normalize domain to lowercase
|
||||||
|
netloc = parsed.netloc.lower()
|
||||||
|
|
||||||
|
# Remove default ports
|
||||||
|
if ':' in netloc:
|
||||||
|
host, port = netloc.rsplit(':', 1)
|
||||||
|
if (parsed.scheme == 'http' and port == '80') or \
|
||||||
|
(parsed.scheme == 'https' and port == '443'):
|
||||||
|
netloc = host
|
||||||
|
|
||||||
|
# Normalize path
|
||||||
|
path = parsed.path
|
||||||
|
# Remove trailing slash except for root
|
||||||
|
if path != '/' and path.endswith('/'):
|
||||||
|
path = path.rstrip('/')
|
||||||
|
# Ensure path starts with /
|
||||||
|
if not path:
|
||||||
|
path = '/'
|
||||||
|
|
||||||
|
# Process query parameters
|
||||||
|
query = ''
|
||||||
|
if parsed.query:
|
||||||
|
# Parse query parameters
|
||||||
|
params = parse_qs(parsed.query, keep_blank_values=True)
|
||||||
|
# Remove tracking parameters
|
||||||
|
filtered_params = {
|
||||||
|
k: v for k, v in params.items()
|
||||||
|
if k not in self.TRACKING_PARAMS
|
||||||
|
}
|
||||||
|
# Sort parameters alphabetically
|
||||||
|
if filtered_params:
|
||||||
|
sorted_params = sorted(filtered_params.items())
|
||||||
|
query = urlencode(sorted_params, doseq=True)
|
||||||
|
|
||||||
|
# Reconstruct URL without fragment
|
||||||
|
normalized = urlunparse((
|
||||||
|
parsed.scheme,
|
||||||
|
netloc,
|
||||||
|
path,
|
||||||
|
parsed.params,
|
||||||
|
query,
|
||||||
|
'' # Remove fragment
|
||||||
|
))
|
||||||
|
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def is_same_domain(self, url: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if URL belongs to the same domain as base_domain.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: URL to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if same domain, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
domain = parsed.netloc.lower()
|
||||||
|
|
||||||
|
# Remove port if present
|
||||||
|
if ':' in domain:
|
||||||
|
domain = domain.split(':')[0]
|
||||||
|
|
||||||
|
# Check if domains match
|
||||||
|
return domain == self.base_domain or domain == self.base_domain.split(':')[0]
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def extract_links(self, html: str, base_url: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
Extract and normalize all links from HTML.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
html: HTML content
|
||||||
|
base_url: Base URL for resolving relative links
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: List of normalized absolute URLs
|
||||||
|
"""
|
||||||
|
links = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
soup = BeautifulSoup(html, 'lxml')
|
||||||
|
|
||||||
|
# Find all anchor tags
|
||||||
|
for anchor in soup.find_all('a', href=True):
|
||||||
|
href = anchor['href']
|
||||||
|
|
||||||
|
# Skip empty hrefs
|
||||||
|
if not href or href.strip() == '':
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Skip javascript: and mailto: links
|
||||||
|
if href.startswith(('javascript:', 'mailto:', 'tel:')):
|
||||||
|
continue
|
||||||
|
|
||||||
|
normalized_url = None
|
||||||
|
# Check if href starts with http/https (absolute URL)
|
||||||
|
if href.startswith(('http://', 'https://')):
|
||||||
|
if self.is_same_domain(href):
|
||||||
|
normalized_url = self.normalize(href)
|
||||||
|
else:
|
||||||
|
# Convert relative URL to absolute
|
||||||
|
absolute_url = urljoin(base_url, href)
|
||||||
|
# Normalize the URL
|
||||||
|
normalized_url = self.normalize(absolute_url)
|
||||||
|
|
||||||
|
if normalized_url:
|
||||||
|
links.append(normalized_url)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return links
|
||||||
215
api/app/core/rag/crawler/web_crawler.py
Normal file
215
api/app/core/rag/crawler/web_crawler.py
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
"""Main web crawler orchestrator."""
|
||||||
|
|
||||||
|
from collections import deque
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Iterator, Optional, List, Set
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from app.core.rag.crawler.url_normalizer import URLNormalizer
|
||||||
|
from app.core.rag.crawler.robots_parser import RobotsParser
|
||||||
|
from app.core.rag.crawler.rate_limiter import RateLimiter
|
||||||
|
from app.core.rag.crawler.http_fetcher import HTTPFetcher
|
||||||
|
from app.core.rag.crawler.content_extractor import ContentExtractor
|
||||||
|
from app.core.rag.crawler.models import CrawledDocument, CrawlSummary
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WebCrawler:
|
||||||
|
"""Main orchestrator for web crawling."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
entry_url: str,
|
||||||
|
max_pages: int = 200,
|
||||||
|
delay_seconds: float = 1.0,
|
||||||
|
timeout_seconds: int = 10,
|
||||||
|
user_agent: str = "KnowledgeBaseCrawler/1.0",
|
||||||
|
include_patterns: Optional[List[str]] = None,
|
||||||
|
exclude_patterns: Optional[List[str]] = None,
|
||||||
|
content_extractor: Optional[ContentExtractor] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the web crawler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entry_url: Starting URL for the crawl
|
||||||
|
max_pages: Maximum number of pages to crawl (default: 200)
|
||||||
|
delay_seconds: Delay between requests in seconds (default: 1.0)
|
||||||
|
timeout_seconds: HTTP request timeout (default: 10)
|
||||||
|
user_agent: User-Agent header string
|
||||||
|
include_patterns: List of regex patterns for URLs to include
|
||||||
|
exclude_patterns: List of regex patterns for URLs to exclude
|
||||||
|
content_extractor: Custom content extractor (optional)
|
||||||
|
"""
|
||||||
|
# Validate entry URL
|
||||||
|
parsed = urlparse(entry_url)
|
||||||
|
if not parsed.scheme or not parsed.netloc:
|
||||||
|
raise ValueError(f"Invalid entry URL: {entry_url}")
|
||||||
|
|
||||||
|
self.entry_url = entry_url
|
||||||
|
self.max_pages = max_pages
|
||||||
|
self.user_agent = user_agent
|
||||||
|
|
||||||
|
# Extract domain from entry URL
|
||||||
|
self.domain = parsed.netloc
|
||||||
|
|
||||||
|
# Initialize components
|
||||||
|
self.url_normalizer = URLNormalizer(entry_url)
|
||||||
|
self.robots_parser = RobotsParser(user_agent, timeout_seconds)
|
||||||
|
self.rate_limiter = RateLimiter(delay_seconds)
|
||||||
|
self.http_fetcher = HTTPFetcher(timeout_seconds, max_retries=3, user_agent=user_agent)
|
||||||
|
self.content_extractor = content_extractor or ContentExtractor()
|
||||||
|
|
||||||
|
# State management
|
||||||
|
self.url_queue: deque = deque()
|
||||||
|
self.visited_urls: Set[str] = set()
|
||||||
|
self.pages_processed = 0
|
||||||
|
|
||||||
|
# Statistics
|
||||||
|
self.stats = {
|
||||||
|
'success': 0,
|
||||||
|
'errors': 0,
|
||||||
|
'skipped': 0,
|
||||||
|
'urls_discovered': 0,
|
||||||
|
'error_breakdown': {}
|
||||||
|
}
|
||||||
|
self.start_time: Optional[datetime] = None
|
||||||
|
self.end_time: Optional[datetime] = None
|
||||||
|
|
||||||
|
def crawl(self) -> Iterator[CrawledDocument]:
|
||||||
|
"""
|
||||||
|
Execute the crawl and yield documents as they are processed.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
CrawledDocument: Structured document with extracted content
|
||||||
|
"""
|
||||||
|
logger.info(f"Starting crawl from {self.entry_url} (max_pages: {self.max_pages})")
|
||||||
|
self.start_time = datetime.now()
|
||||||
|
|
||||||
|
# Add entry URL to queue
|
||||||
|
normalized_entry = self.url_normalizer.normalize(self.entry_url)
|
||||||
|
if normalized_entry:
|
||||||
|
self.url_queue.append(normalized_entry)
|
||||||
|
self.stats['urls_discovered'] += 1
|
||||||
|
|
||||||
|
# Check robots.txt and update rate limiter if needed
|
||||||
|
crawl_delay = self.robots_parser.get_crawl_delay(self.entry_url)
|
||||||
|
if crawl_delay:
|
||||||
|
self.rate_limiter.set_delay(crawl_delay)
|
||||||
|
|
||||||
|
# Main crawl loop
|
||||||
|
while self.url_queue and self.pages_processed < self.max_pages:
|
||||||
|
url = self.url_queue.popleft()
|
||||||
|
|
||||||
|
# Skip if already visited
|
||||||
|
if url in self.visited_urls:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Mark as visited
|
||||||
|
self.visited_urls.add(url)
|
||||||
|
|
||||||
|
# Check robots.txt permission
|
||||||
|
if not self.robots_parser.can_fetch(url):
|
||||||
|
logger.info(f"Skipping {url} (disallowed by robots.txt)")
|
||||||
|
self.stats['skipped'] += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Apply rate limiting
|
||||||
|
self.rate_limiter.wait()
|
||||||
|
|
||||||
|
# Fetch URL
|
||||||
|
logger.info(f"Fetching {url} ({self.pages_processed + 1}/{self.max_pages})")
|
||||||
|
fetch_result = self.http_fetcher.fetch(url)
|
||||||
|
|
||||||
|
# Handle fetch errors
|
||||||
|
if not fetch_result.success:
|
||||||
|
self._record_error(fetch_result.error or "Unknown error")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check Content-Type
|
||||||
|
content_type = fetch_result.headers.get('Content-Type', '').lower()
|
||||||
|
if not any(substring in content_type for substring in ['text/html', 'application/xhtml+xml']):
|
||||||
|
logger.warning(f"Skipping {url} (Content-Type: {content_type})")
|
||||||
|
self.stats['skipped'] += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Extract content
|
||||||
|
try:
|
||||||
|
extracted = self.content_extractor.extract(fetch_result.content, url)
|
||||||
|
|
||||||
|
# Check if static content
|
||||||
|
if not extracted.is_static:
|
||||||
|
logger.warning(f"Skipping {url} (JavaScript-rendered content)")
|
||||||
|
self.stats['skipped'] += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create document
|
||||||
|
document = CrawledDocument(
|
||||||
|
url=url,
|
||||||
|
title=extracted.title,
|
||||||
|
content=extracted.text,
|
||||||
|
content_length=len(extracted.text),
|
||||||
|
crawl_timestamp=datetime.now(),
|
||||||
|
http_status=fetch_result.status_code,
|
||||||
|
metadata={
|
||||||
|
'word_count': extracted.word_count,
|
||||||
|
'final_url': fetch_result.final_url
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update statistics
|
||||||
|
self.pages_processed += 1
|
||||||
|
self.stats['success'] += 1
|
||||||
|
|
||||||
|
# Extract and queue links
|
||||||
|
links = self.url_normalizer.extract_links(fetch_result.content, url)
|
||||||
|
for link in links:
|
||||||
|
if link not in self.visited_urls and self.url_normalizer.is_same_domain(link):
|
||||||
|
if link not in self.url_queue:
|
||||||
|
self.url_queue.append(link)
|
||||||
|
self.stats['urls_discovered'] += 1
|
||||||
|
|
||||||
|
# Yield document
|
||||||
|
yield document
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing {url}: {e}")
|
||||||
|
self._record_error(f"Processing error: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.end_time = datetime.now()
|
||||||
|
logger.info(f"Crawl completed. Processed {self.pages_processed} pages.")
|
||||||
|
|
||||||
|
def get_summary(self) -> CrawlSummary:
|
||||||
|
"""
|
||||||
|
Get summary statistics after crawl completion.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CrawlSummary: Statistics including success/error/skip counts
|
||||||
|
"""
|
||||||
|
if not self.start_time:
|
||||||
|
self.start_time = datetime.now()
|
||||||
|
if not self.end_time:
|
||||||
|
self.end_time = datetime.now()
|
||||||
|
|
||||||
|
duration = (self.end_time - self.start_time).total_seconds()
|
||||||
|
|
||||||
|
return CrawlSummary(
|
||||||
|
total_pages_processed=self.stats['success'],
|
||||||
|
total_errors=self.stats['errors'],
|
||||||
|
total_skipped=self.stats['skipped'],
|
||||||
|
total_urls_discovered=self.stats['urls_discovered'],
|
||||||
|
start_time=self.start_time,
|
||||||
|
end_time=self.end_time,
|
||||||
|
duration_seconds=duration,
|
||||||
|
error_breakdown=self.stats['error_breakdown']
|
||||||
|
)
|
||||||
|
|
||||||
|
def _record_error(self, error: str):
|
||||||
|
"""Record an error in statistics."""
|
||||||
|
self.stats['errors'] += 1
|
||||||
|
error_type = error.split(':')[0] if ':' in error else error
|
||||||
|
self.stats['error_breakdown'][error_type] = \
|
||||||
|
self.stats['error_breakdown'].get(error_type, 0) + 1
|
||||||
1
api/app/core/rag/integrations/__init__.py
Normal file
1
api/app/core/rag/integrations/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Integrations package for external services."""
|
||||||
1
api/app/core/rag/integrations/feishu/__init__.py
Normal file
1
api/app/core/rag/integrations/feishu/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Feishu integration module for document synchronization."""
|
||||||
84
api/app/core/rag/integrations/feishu/__main__.py
Normal file
84
api/app/core/rag/integrations/feishu/__main__.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""Command-line interface for feishu integration."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from app.core.rag.integrations.feishu.client import FeishuAPIClient
|
||||||
|
from app.core.rag.integrations.feishu.models import FileInfo
|
||||||
|
|
||||||
|
|
||||||
|
def main(feishu_app_id: str, # Feishu application ID
|
||||||
|
feishu_app_secret: str, # Feishu application secret
|
||||||
|
feishu_folder_token: str, # Feishu Folder Token
|
||||||
|
save_dir: str, # save file directory
|
||||||
|
feishu_api_base_url: str = "https://open.feishu.cn/open-apis", # Feishu API base URL
|
||||||
|
timeout: int = 30, # Request timeout in seconds
|
||||||
|
max_retries: int = 3, # Maximum number of retries
|
||||||
|
recursive: bool = True # recursive: Whether to sync subfolders recursively,
|
||||||
|
):
|
||||||
|
"""Main entry point for the feishuAPIClient."""
|
||||||
|
# Create feishuAPIClient
|
||||||
|
api_client = FeishuAPIClient(
|
||||||
|
app_id=feishu_app_id,
|
||||||
|
app_secret=feishu_app_secret,
|
||||||
|
api_base_url=feishu_api_base_url,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get all files from folder
|
||||||
|
async def async_get_files(api_client: FeishuAPIClient, feishu_folder_token: str):
|
||||||
|
async with api_client as client:
|
||||||
|
if recursive:
|
||||||
|
files = await client.list_all_folder_files(feishu_folder_token, recursive=True)
|
||||||
|
else:
|
||||||
|
all_files = []
|
||||||
|
page_token = None
|
||||||
|
while True:
|
||||||
|
files_page, page_token = await client.list_folder_files(
|
||||||
|
feishu_folder_token, page_token
|
||||||
|
)
|
||||||
|
all_files.extend(files_page)
|
||||||
|
if not page_token:
|
||||||
|
break
|
||||||
|
files = all_files
|
||||||
|
return files
|
||||||
|
files = asyncio.run(async_get_files(api_client,feishu_folder_token))
|
||||||
|
|
||||||
|
# Filter out folders, only sync documents
|
||||||
|
# documents = [f for f in files if f.type in ["doc", "docx", "sheet", "bitable", "file", "slides"]]
|
||||||
|
documents = [f for f in files if f.type in ["doc", "docx", "sheet", "bitable", "file"]]
|
||||||
|
|
||||||
|
try:
|
||||||
|
for doc in documents:
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print(f"token: {doc.token}")
|
||||||
|
print(f"name: {doc.name}")
|
||||||
|
print(f"type: {doc.type}")
|
||||||
|
print(f"created_time: {doc.created_time}")
|
||||||
|
print(f"modified_time: {doc.modified_time}")
|
||||||
|
print(f"owner_id: {doc.owner_id}")
|
||||||
|
print(f"url: {doc.url}")
|
||||||
|
print(f"{'=' * 80}\n")
|
||||||
|
# download document from Feishu FileInfo
|
||||||
|
async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, save_dir: str):
|
||||||
|
async with api_client as client:
|
||||||
|
file_path = await client.download_document(document=doc, save_dir=save_dir)
|
||||||
|
return file_path
|
||||||
|
|
||||||
|
file_path = asyncio.run(async_download_document(api_client, doc, save_dir))
|
||||||
|
print(file_path)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\nfeishu integration interrupted by user.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n\nError during feishu integration: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
feishu_app_id = ""
|
||||||
|
feishu_app_secret = ""
|
||||||
|
feishu_folder_token = ""
|
||||||
|
save_dir = "/Volumes/MacintoshBD/Repository/RedBearAI/MemoryBear/api/files/"
|
||||||
|
main(feishu_app_id, feishu_app_secret, feishu_folder_token, save_dir)
|
||||||
452
api/app/core/rag/integrations/feishu/client.py
Normal file
452
api/app/core/rag/integrations/feishu/client.py
Normal file
@@ -0,0 +1,452 @@
|
|||||||
|
"""Feishu API client for document operations."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from typing import Optional, Tuple, List
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
import httpx
|
||||||
|
from cachetools import TTLCache
|
||||||
|
import urllib.parse
|
||||||
|
|
||||||
|
from app.core.rag.integrations.feishu.exceptions import (
|
||||||
|
FeishuAuthError,
|
||||||
|
FeishuAPIError,
|
||||||
|
FeishuNotFoundError,
|
||||||
|
FeishuPermissionError,
|
||||||
|
FeishuRateLimitError,
|
||||||
|
FeishuNetworkError,
|
||||||
|
)
|
||||||
|
from app.core.rag.integrations.feishu.models import FileInfo
|
||||||
|
from app.core.rag.integrations.feishu.retry import with_retry
|
||||||
|
|
||||||
|
|
||||||
|
class FeishuAPIClient:
|
||||||
|
"""Feishu API client for document synchronization."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
app_id: str,
|
||||||
|
app_secret: str,
|
||||||
|
api_base_url: str = "https://open.feishu.cn/open-apis",
|
||||||
|
timeout: int = 30,
|
||||||
|
max_retries: int = 3
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize Feishu API client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: Feishu application ID
|
||||||
|
app_secret: Feishu application secret
|
||||||
|
api_base_url: Feishu API base URL
|
||||||
|
timeout: Request timeout in seconds
|
||||||
|
max_retries: Maximum number of retries
|
||||||
|
"""
|
||||||
|
self.app_id = app_id
|
||||||
|
self.app_secret = app_secret
|
||||||
|
self.api_base_url = api_base_url
|
||||||
|
self.timeout = timeout
|
||||||
|
self.max_retries = max_retries
|
||||||
|
self._http_client: Optional[httpx.AsyncClient] = None
|
||||||
|
self._token_cache: TTLCache = TTLCache(maxsize=1, ttl=7200 - 300) # 2 hours - 5 minutes
|
||||||
|
self._token_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
"""Async context manager entry."""
|
||||||
|
self._http_client = httpx.AsyncClient(
|
||||||
|
base_url=self.api_base_url,
|
||||||
|
timeout=self.timeout,
|
||||||
|
headers={"Content-Type": "application/json"}
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""Async context manager exit."""
|
||||||
|
if self._http_client:
|
||||||
|
await self._http_client.aclose()
|
||||||
|
|
||||||
|
async def get_tenant_access_token(self) -> str:
|
||||||
|
"""
|
||||||
|
Get tenant access token with caching.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Access token string
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FeishuAuthError: If authentication fails
|
||||||
|
"""
|
||||||
|
# Check cache first
|
||||||
|
cached_token = self._token_cache.get("access_token")
|
||||||
|
if cached_token:
|
||||||
|
return cached_token
|
||||||
|
|
||||||
|
# Use lock to prevent concurrent token requests
|
||||||
|
async with self._token_lock:
|
||||||
|
# Double-check cache after acquiring lock
|
||||||
|
cached_token = self._token_cache.get("access_token")
|
||||||
|
if cached_token:
|
||||||
|
return cached_token
|
||||||
|
|
||||||
|
# Request new token
|
||||||
|
try:
|
||||||
|
if not self._http_client:
|
||||||
|
raise FeishuAuthError("HTTP client not initialized")
|
||||||
|
|
||||||
|
response = await self._http_client.post(
|
||||||
|
"/auth/v3/tenant_access_token/internal",
|
||||||
|
json={
|
||||||
|
"app_id": self.app_id,
|
||||||
|
"app_secret": self.app_secret
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if data.get("code") != 0:
|
||||||
|
error_msg = data.get("msg", "Unknown error")
|
||||||
|
raise FeishuAuthError(
|
||||||
|
f"Authentication failed: {error_msg}",
|
||||||
|
error_code=str(data.get("code")),
|
||||||
|
details=data
|
||||||
|
)
|
||||||
|
|
||||||
|
token = data.get("tenant_access_token")
|
||||||
|
if not token:
|
||||||
|
raise FeishuAuthError("No access token in response")
|
||||||
|
|
||||||
|
# Cache the token
|
||||||
|
self._token_cache["access_token"] = token
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
raise FeishuAuthError(f"HTTP error during authentication: {str(e)}")
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, FeishuAuthError):
|
||||||
|
raise
|
||||||
|
raise FeishuAuthError(f"Unexpected error during authentication: {str(e)}")
|
||||||
|
|
||||||
|
@with_retry
|
||||||
|
async def list_folder_files(
|
||||||
|
self,
|
||||||
|
folder_token: str,
|
||||||
|
page_token: Optional[str] = None
|
||||||
|
) -> Tuple[List[FileInfo], Optional[str]]:
|
||||||
|
"""
|
||||||
|
Get list of files in a folder with pagination support.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
folder_token: Folder token
|
||||||
|
page_token: Page token for pagination
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (list of FileInfo, next page token)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FeishuAPIError: If API call fails
|
||||||
|
FeishuNotFoundError: If folder not found
|
||||||
|
FeishuPermissionError: If permission denied
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
token = await self.get_tenant_access_token()
|
||||||
|
|
||||||
|
if not self._http_client:
|
||||||
|
raise FeishuAPIError("HTTP client not initialized")
|
||||||
|
|
||||||
|
# Build request parameters
|
||||||
|
params = {"page_size": 200, "folder_token": folder_token}
|
||||||
|
if page_token:
|
||||||
|
params["page_token"] = page_token
|
||||||
|
|
||||||
|
# Make API request
|
||||||
|
response = await self._http_client.get(
|
||||||
|
f"/drive/v1/files",
|
||||||
|
params=params,
|
||||||
|
headers={"Authorization": f"Bearer {token}"}
|
||||||
|
)
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
# print(f"get files: {data}")
|
||||||
|
|
||||||
|
# Handle errors
|
||||||
|
if data.get("code") != 0:
|
||||||
|
error_code = data.get("code")
|
||||||
|
error_msg = data.get("msg", "Unknown error")
|
||||||
|
|
||||||
|
if error_code == 404 or error_code == 230005:
|
||||||
|
raise FeishuNotFoundError(
|
||||||
|
f"Folder not found: {error_msg}",
|
||||||
|
error_code=str(error_code),
|
||||||
|
details=data
|
||||||
|
)
|
||||||
|
elif error_code == 403 or error_code == 230003:
|
||||||
|
raise FeishuPermissionError(
|
||||||
|
f"Permission denied: {error_msg}",
|
||||||
|
error_code=str(error_code),
|
||||||
|
details=data
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise FeishuAPIError(
|
||||||
|
f"API error: {error_msg}",
|
||||||
|
error_code=str(error_code),
|
||||||
|
details=data
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse response
|
||||||
|
files_data = data.get("data", {}).get("files", [])
|
||||||
|
next_page_token = data.get("data", {}).get("next_page_token", None)
|
||||||
|
|
||||||
|
# Convert to FileInfo objects
|
||||||
|
files = []
|
||||||
|
for file_data in files_data:
|
||||||
|
try:
|
||||||
|
file_info = FileInfo(
|
||||||
|
token=file_data.get("token", ""),
|
||||||
|
name=file_data.get("name", ""),
|
||||||
|
type=file_data.get("type", ""),
|
||||||
|
created_time=datetime.fromtimestamp(int(file_data.get("created_time", 0))),
|
||||||
|
modified_time=datetime.fromtimestamp(int(file_data.get("modified_time", 0))),
|
||||||
|
owner_id=file_data.get("owner_id", ""),
|
||||||
|
url=file_data.get("url", "")
|
||||||
|
)
|
||||||
|
files.append(file_info)
|
||||||
|
except (ValueError, TypeError) as e:
|
||||||
|
# Skip invalid file entries
|
||||||
|
continue
|
||||||
|
|
||||||
|
return files, next_page_token
|
||||||
|
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
raise FeishuAPIError(f"HTTP error: {str(e)}")
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, (FeishuAPIError, FeishuNotFoundError, FeishuPermissionError)):
|
||||||
|
raise
|
||||||
|
raise FeishuAPIError(f"Unexpected error: {str(e)}")
|
||||||
|
|
||||||
|
async def list_all_folder_files(
|
||||||
|
self,
|
||||||
|
folder_token: str,
|
||||||
|
recursive: bool = True
|
||||||
|
) -> List[FileInfo]:
|
||||||
|
"""
|
||||||
|
Get all files in a folder, handling pagination automatically.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
folder_token: Folder token
|
||||||
|
recursive: Whether to recursively get files from subfolders
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of all FileInfo objects
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FeishuAPIError: If API call fails
|
||||||
|
"""
|
||||||
|
all_files = []
|
||||||
|
page_token = None
|
||||||
|
|
||||||
|
# Get all files with pagination
|
||||||
|
while True:
|
||||||
|
files, page_token = await self.list_folder_files(folder_token, page_token)
|
||||||
|
all_files.extend(files)
|
||||||
|
|
||||||
|
if not page_token:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Recursively get files from subfolders if requested
|
||||||
|
if recursive:
|
||||||
|
subfolders = [f for f in all_files if f.type == "folder"]
|
||||||
|
for subfolder in subfolders:
|
||||||
|
try:
|
||||||
|
subfolder_files = await self.list_all_folder_files(
|
||||||
|
subfolder.token,
|
||||||
|
recursive=True
|
||||||
|
)
|
||||||
|
all_files.extend(subfolder_files)
|
||||||
|
except Exception:
|
||||||
|
# Continue with other folders if one fails
|
||||||
|
continue
|
||||||
|
|
||||||
|
return all_files
|
||||||
|
|
||||||
|
@with_retry
|
||||||
|
async def download_document(
|
||||||
|
self,
|
||||||
|
document: FileInfo,
|
||||||
|
save_dir: str
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
download document content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document: Document FileInfo
|
||||||
|
save_dir: save dir
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
file_full_path
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FeishuAPIError: If API call fails
|
||||||
|
FeishuNotFoundError: If document not found
|
||||||
|
FeishuPermissionError: If permission denied
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
token = await self.get_tenant_access_token()
|
||||||
|
|
||||||
|
if not self._http_client:
|
||||||
|
raise FeishuAPIError("HTTP client not initialized")
|
||||||
|
|
||||||
|
# Different API endpoints for different document types
|
||||||
|
if document.type == "doc" or document.type == "docx" or document.type == "sheet" or document.type == "bitable":
|
||||||
|
return await self._export_file(document, token, save_dir)
|
||||||
|
elif document.type == "file" or document.type == "slides":
|
||||||
|
return await self._download_file(document, token, save_dir)
|
||||||
|
else:
|
||||||
|
raise FeishuAPIError(f"Unsupported document type: {document.type}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, (FeishuAPIError, FeishuNotFoundError, FeishuPermissionError)):
|
||||||
|
raise
|
||||||
|
raise FeishuAPIError(f"Unexpected error: {str(e)}")
|
||||||
|
|
||||||
|
async def _export_file(self, document: FileInfo, access_token: str, save_dir: str) -> str:
|
||||||
|
"""export file for feishu online file type."""
|
||||||
|
try:
|
||||||
|
# 1.创建导出任务
|
||||||
|
file_extension = "pdf"
|
||||||
|
match document.type:
|
||||||
|
case "doc":
|
||||||
|
file_extension = "doc"
|
||||||
|
case "docx":
|
||||||
|
file_extension = "docx"
|
||||||
|
case "sheet":
|
||||||
|
file_extension = "xlsx"
|
||||||
|
case "bitable":
|
||||||
|
file_extension = "xlsx"
|
||||||
|
case _:
|
||||||
|
file_extension = "pdf"
|
||||||
|
response = await self._http_client.post(
|
||||||
|
"/drive/v1/export_tasks",
|
||||||
|
json={
|
||||||
|
"file_extension": file_extension,
|
||||||
|
"token": document.token,
|
||||||
|
"type": document.type
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"}
|
||||||
|
)
|
||||||
|
data = response.json()
|
||||||
|
print(f"1.创建导出任务: {data}")
|
||||||
|
|
||||||
|
if data.get("code") != 0:
|
||||||
|
error_code = data.get("code")
|
||||||
|
error_msg = data.get("msg", "Unknown error")
|
||||||
|
raise FeishuAPIError(
|
||||||
|
f"API error: {error_msg}",
|
||||||
|
error_code=str(error_code),
|
||||||
|
details=data
|
||||||
|
)
|
||||||
|
|
||||||
|
ticket = data.get("data", {}).get("ticket", None)
|
||||||
|
if not ticket:
|
||||||
|
raise FeishuAuthError("No ticket in response")
|
||||||
|
|
||||||
|
# 2.轮序查询导出任务结果
|
||||||
|
max_retries = 10 # 最大轮询次数
|
||||||
|
poll_interval = 2 # 每次轮询间隔时间(秒)
|
||||||
|
file_token = None
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
# 查询导出任务
|
||||||
|
response = await self._http_client.get(
|
||||||
|
f"/drive/v1/export_tasks/{ticket}",
|
||||||
|
params={"token": document.token},
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"}
|
||||||
|
)
|
||||||
|
data = response.json()
|
||||||
|
print(f"2. 尝试查询导出任务结果 (第{attempt + 1}次): {data}")
|
||||||
|
|
||||||
|
if data.get("code") != 0:
|
||||||
|
error_code = data.get("code")
|
||||||
|
error_msg = data.get("msg", "Unknown error")
|
||||||
|
raise FeishuAPIError(
|
||||||
|
f"API error: {error_msg}",
|
||||||
|
error_code=str(error_code),
|
||||||
|
details=data,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查导出任务结果
|
||||||
|
file_token = data.get("data", {}).get("result", {}).get("file_token", None)
|
||||||
|
if file_token:
|
||||||
|
# 如果导出任务成功生成 file_token,则退出轮询
|
||||||
|
break
|
||||||
|
|
||||||
|
# 如果结果还没准备好,等待一段时间再进行下一次轮询
|
||||||
|
await asyncio.sleep(poll_interval)
|
||||||
|
|
||||||
|
if not file_token:
|
||||||
|
raise FeishuAPIError("Export task did not complete within the allowed time")
|
||||||
|
|
||||||
|
# 3.下载导出任务
|
||||||
|
response = await self._http_client.get(
|
||||||
|
f"/drive/v1/export_tasks/file/{file_token}/download",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"}
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
print(f'3.下载导出任务: {response.headers.get("Content-Disposition")}')
|
||||||
|
|
||||||
|
file_full_path = os.path.join(save_dir, document.name + "." + file_extension)
|
||||||
|
if os.path.exists(file_full_path):
|
||||||
|
os.remove(file_full_path) # Delete a single file
|
||||||
|
with open(file_full_path, "wb") as file:
|
||||||
|
file.write(response.content)
|
||||||
|
|
||||||
|
return file_full_path
|
||||||
|
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
raise FeishuAPIError(f"HTTP error: {str(e)}")
|
||||||
|
except Exception as e:
|
||||||
|
raise FeishuAPIError(f"Unexpected error during file download: {str(e)}")
|
||||||
|
|
||||||
|
async def _download_file(self, document: FileInfo, access_token: str, save_dir: str) -> str:
|
||||||
|
"""download file for file type."""
|
||||||
|
try:
|
||||||
|
response = await self._http_client.get(
|
||||||
|
f"/drive/v1/files/{document.token}/download",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"}
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
filename_header = response.headers.get("Content-Disposition")
|
||||||
|
|
||||||
|
# 最终的文件名(初始化为 None)
|
||||||
|
filename = None
|
||||||
|
if filename_header:
|
||||||
|
# 优先解析 filename* 格式
|
||||||
|
match = re.search(r"filename\*=([^']*)''([^;]+)", filename_header)
|
||||||
|
if match:
|
||||||
|
# 使用 `filename*` 提取(已编码)
|
||||||
|
encoding = match.group(1) # 编码部分(如 UTF-8)
|
||||||
|
encoded_filename = match.group(2) # 文件名部分
|
||||||
|
filename = urllib.parse.unquote(encoded_filename) # 解码 URL 编码的文件名
|
||||||
|
|
||||||
|
# 如果 `filename*` 不存在,回退到解析 `filename`
|
||||||
|
if not filename:
|
||||||
|
match = re.search(r'filename="([^"]+)"', filename_header)
|
||||||
|
if match:
|
||||||
|
filename = match.group(1)
|
||||||
|
# 如果文件名仍为 None,则使用默认文件名
|
||||||
|
if not filename:
|
||||||
|
filename = f"{document.name}.pdf"
|
||||||
|
# 确保文件名合法,替换非法字符
|
||||||
|
filename = re.sub(r'[\/:*?"<>|]', '_', filename)
|
||||||
|
|
||||||
|
file_full_path = os.path.join(save_dir, filename)
|
||||||
|
if os.path.exists(file_full_path):
|
||||||
|
os.remove(file_full_path) # Delete a single file
|
||||||
|
with open(file_full_path, "wb") as file:
|
||||||
|
file.write(response.content)
|
||||||
|
|
||||||
|
return file_full_path
|
||||||
|
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
raise FeishuAPIError(f"HTTP error: {str(e)}")
|
||||||
|
except Exception as e:
|
||||||
|
raise FeishuAPIError(f"Unexpected error during file download: {str(e)}")
|
||||||
46
api/app/core/rag/integrations/feishu/exceptions.py
Normal file
46
api/app/core/rag/integrations/feishu/exceptions.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""Exception classes for Feishu integration."""
|
||||||
|
|
||||||
|
|
||||||
|
class FeishuError(Exception):
|
||||||
|
"""Base exception for all Feishu-related errors."""
|
||||||
|
|
||||||
|
def __init__(self, message: str, error_code: str = None, details: dict = None):
|
||||||
|
super().__init__(message)
|
||||||
|
self.message = message
|
||||||
|
self.error_code = error_code
|
||||||
|
self.details = details or {}
|
||||||
|
|
||||||
|
|
||||||
|
class FeishuAuthError(FeishuError):
|
||||||
|
"""Authentication error with Feishu API."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FeishuAPIError(FeishuError):
|
||||||
|
"""General API error from Feishu."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FeishuNotFoundError(FeishuError):
|
||||||
|
"""Resource not found error (404)."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FeishuPermissionError(FeishuError):
|
||||||
|
"""Permission denied error (403)."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FeishuRateLimitError(FeishuError):
|
||||||
|
"""Rate limit exceeded error (429)."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FeishuNetworkError(FeishuError):
|
||||||
|
"""Network-related error (timeout, connection failure)."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FeishuDataError(FeishuError):
|
||||||
|
"""Data parsing or validation error."""
|
||||||
|
pass
|
||||||
17
api/app/core/rag/integrations/feishu/models.py
Normal file
17
api/app/core/rag/integrations/feishu/models.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""Data models for Feishu integration."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FileInfo:
|
||||||
|
"""File information from Feishu."""
|
||||||
|
token: str
|
||||||
|
name: str
|
||||||
|
type: str # doc/docx/sheet/bitable/file/slides/folder
|
||||||
|
created_time: datetime
|
||||||
|
modified_time: datetime
|
||||||
|
owner_id: str
|
||||||
|
url: str
|
||||||
137
api/app/core/rag/integrations/feishu/retry.py
Normal file
137
api/app/core/rag/integrations/feishu/retry.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
"""Retry strategy for Feishu API calls."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import functools
|
||||||
|
from typing import Callable, TypeVar
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.core.rag.integrations.feishu.exceptions import (
|
||||||
|
FeishuAuthError,
|
||||||
|
FeishuPermissionError,
|
||||||
|
FeishuNotFoundError,
|
||||||
|
FeishuRateLimitError,
|
||||||
|
FeishuNetworkError,
|
||||||
|
FeishuDataError,
|
||||||
|
FeishuAPIError,
|
||||||
|
)
|
||||||
|
|
||||||
|
T = TypeVar('T')
|
||||||
|
|
||||||
|
|
||||||
|
class RetryStrategy:
|
||||||
|
"""Retry strategy for API calls."""
|
||||||
|
|
||||||
|
# Retryable error types
|
||||||
|
RETRYABLE_ERRORS = (
|
||||||
|
FeishuNetworkError,
|
||||||
|
FeishuRateLimitError,
|
||||||
|
httpx.TimeoutException,
|
||||||
|
httpx.ConnectError,
|
||||||
|
httpx.ReadError,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Non-retryable error types
|
||||||
|
NON_RETRYABLE_ERRORS = (
|
||||||
|
FeishuAuthError,
|
||||||
|
FeishuPermissionError,
|
||||||
|
FeishuNotFoundError,
|
||||||
|
FeishuDataError,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Retry configuration
|
||||||
|
MAX_RETRIES = 3
|
||||||
|
BACKOFF_DELAYS = [1, 2, 4] # seconds
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_retryable(cls, error: Exception) -> bool:
|
||||||
|
"""Check if an error is retryable."""
|
||||||
|
# Check for specific retryable errors
|
||||||
|
if isinstance(error, cls.RETRYABLE_ERRORS):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check for non-retryable errors
|
||||||
|
if isinstance(error, cls.NON_RETRYABLE_ERRORS):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check for HTTP status codes
|
||||||
|
if isinstance(error, httpx.HTTPStatusError):
|
||||||
|
status_code = error.response.status_code
|
||||||
|
# Retry on 429 (rate limit), 503 (service unavailable), 502 (bad gateway)
|
||||||
|
if status_code in [429, 502, 503]:
|
||||||
|
return True
|
||||||
|
# Don't retry on 4xx errors (except 429)
|
||||||
|
if 400 <= status_code < 500:
|
||||||
|
return False
|
||||||
|
# Retry on 5xx errors
|
||||||
|
if 500 <= status_code < 600:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check for FeishuAPIError with specific codes
|
||||||
|
if isinstance(error, FeishuAPIError):
|
||||||
|
if error.error_code:
|
||||||
|
# Rate limit error codes
|
||||||
|
if error.error_code in ["99991400", "99991401"]:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute_with_retry(
|
||||||
|
cls,
|
||||||
|
func: Callable[..., T],
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
) -> T:
|
||||||
|
"""
|
||||||
|
Execute a function with retry logic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: Async function to execute
|
||||||
|
*args: Positional arguments for the function
|
||||||
|
**kwargs: Keyword arguments for the function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Function result
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: The last exception if all retries fail
|
||||||
|
"""
|
||||||
|
last_exception = None
|
||||||
|
|
||||||
|
for attempt in range(cls.MAX_RETRIES + 1):
|
||||||
|
try:
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
last_exception = e
|
||||||
|
|
||||||
|
# Don't retry if not retryable
|
||||||
|
if not cls.is_retryable(e):
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Don't retry if this was the last attempt
|
||||||
|
if attempt >= cls.MAX_RETRIES:
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Wait before retrying
|
||||||
|
delay = cls.BACKOFF_DELAYS[attempt] if attempt < len(cls.BACKOFF_DELAYS) else cls.BACKOFF_DELAYS[-1]
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
|
# Should not reach here, but raise last exception if we do
|
||||||
|
if last_exception:
|
||||||
|
raise last_exception
|
||||||
|
|
||||||
|
|
||||||
|
def with_retry(func: Callable[..., T]) -> Callable[..., T]:
|
||||||
|
"""
|
||||||
|
Decorator to add retry logic to async functions.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
@with_retry
|
||||||
|
async def my_api_call():
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
@functools.wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
return await RetryStrategy.execute_with_retry(func, *args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
1
api/app/core/rag/integrations/yuque/__init__.py
Normal file
1
api/app/core/rag/integrations/yuque/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Yuque integration module for document synchronization."""
|
||||||
77
api/app/core/rag/integrations/yuque/__main__.py
Normal file
77
api/app/core/rag/integrations/yuque/__main__.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
"""Main entry point for Yuque integration testing."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from app.core.rag.integrations.yuque.client import YuqueAPIClient
|
||||||
|
from app.core.rag.integrations.yuque.models import YuqueDocInfo
|
||||||
|
|
||||||
|
|
||||||
|
def main(yuque_user_id: str, # yuque User ID
|
||||||
|
yuque_token: str, # yuque Token
|
||||||
|
save_dir: str, # save file directory
|
||||||
|
):
|
||||||
|
"""Main entry point for the YuqueAPIClient."""
|
||||||
|
# Create feishuAPIClient
|
||||||
|
api_client = YuqueAPIClient(
|
||||||
|
user_id=yuque_user_id,
|
||||||
|
token=yuque_token
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get all files from all repos
|
||||||
|
async def async_get_files(api_client: YuqueAPIClient):
|
||||||
|
async with api_client as client:
|
||||||
|
print("\n=== Fetching repositories ===")
|
||||||
|
repos = await client.get_user_repos()
|
||||||
|
print(f"Found {len(repos)} repositories:")
|
||||||
|
all_files = []
|
||||||
|
for repo in repos:
|
||||||
|
# Get documents from repository
|
||||||
|
print(f"\n=== Fetching documents from '{repo.name}' ===")
|
||||||
|
docs = await client.get_repo_docs(repo.id)
|
||||||
|
all_files.extend(docs)
|
||||||
|
return all_files
|
||||||
|
files = asyncio.run(async_get_files(api_client))
|
||||||
|
|
||||||
|
try:
|
||||||
|
for doc in files:
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print(f"id: {doc.id}")
|
||||||
|
print(f"type: {doc.type}")
|
||||||
|
print(f"slug: {doc.slug}")
|
||||||
|
print(f"title: {doc.title}")
|
||||||
|
print(f"book_id: {doc.book_id}")
|
||||||
|
# print(f"format: {doc.format}")
|
||||||
|
# print(f"body: {doc.body}")
|
||||||
|
# print(f"body_draft: {doc.body_draft}")
|
||||||
|
# print(f"body_html: {doc.body_html}")
|
||||||
|
print(f"public: {doc.public}")
|
||||||
|
print(f"status: {doc.status}")
|
||||||
|
print(f"created_at: {doc.created_at}")
|
||||||
|
print(f"updated_at: {doc.updated_at}")
|
||||||
|
print(f"published_at: {doc.published_at}")
|
||||||
|
print(f"word_count: {doc.word_count}")
|
||||||
|
print(f"cover: {doc.cover}")
|
||||||
|
print(f"description: {doc.description}")
|
||||||
|
print(f"{'=' * 80}\n")
|
||||||
|
# download document from Feishu FileInfo
|
||||||
|
async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, save_dir: str):
|
||||||
|
async with api_client as client:
|
||||||
|
file_path = await client.download_document(doc, save_dir)
|
||||||
|
return file_path
|
||||||
|
|
||||||
|
file_path = asyncio.run(async_download_document(api_client, doc, save_dir))
|
||||||
|
print(file_path)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\nfeishu integration interrupted by user.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n\nError during feishu integration: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
yuque_user_id = ""
|
||||||
|
yuque_token = ""
|
||||||
|
save_dir = "/Volumes/MacintoshBD/Repository/RedBearAI/MemoryBear/api/files/"
|
||||||
|
main(yuque_user_id, yuque_token, save_dir)
|
||||||
544
api/app/core/rag/integrations/yuque/client.py
Normal file
544
api/app/core/rag/integrations/yuque/client.py
Normal file
@@ -0,0 +1,544 @@
|
|||||||
|
"""Yuque API client for document operations."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from typing import Optional, List
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
import httpx
|
||||||
|
import urllib.parse
|
||||||
|
import json
|
||||||
|
from openpyxl import Workbook
|
||||||
|
from openpyxl.styles import Font, Alignment, PatternFill
|
||||||
|
from openpyxl.utils import get_column_letter
|
||||||
|
import zlib
|
||||||
|
|
||||||
|
from app.core.rag.integrations.yuque.exceptions import (
|
||||||
|
YuqueAuthError,
|
||||||
|
YuqueAPIError,
|
||||||
|
YuqueNotFoundError,
|
||||||
|
YuquePermissionError,
|
||||||
|
YuqueRateLimitError,
|
||||||
|
YuqueNetworkError,
|
||||||
|
)
|
||||||
|
from app.core.rag.integrations.yuque.models import YuqueDocInfo, YuqueRepoInfo
|
||||||
|
from app.core.rag.integrations.yuque.retry import with_retry
|
||||||
|
|
||||||
|
|
||||||
|
class YuqueAPIClient:
|
||||||
|
"""Yuque API client for document synchronization."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
token: str,
|
||||||
|
api_base_url: str = "https://www.yuque.com/api/v2",
|
||||||
|
timeout: int = 30,
|
||||||
|
max_retries: int = 3
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize Yuque API client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Yuque user ID or login name
|
||||||
|
token: Yuque personal access token
|
||||||
|
api_base_url: Yuque API base URL
|
||||||
|
timeout: Request timeout in seconds
|
||||||
|
max_retries: Maximum number of retries
|
||||||
|
"""
|
||||||
|
self.user_id = user_id
|
||||||
|
self.token = token
|
||||||
|
self.api_base_url = api_base_url
|
||||||
|
self.timeout = timeout
|
||||||
|
self.max_retries = max_retries
|
||||||
|
self._http_client: Optional[httpx.AsyncClient] = None
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
"""Async context manager entry."""
|
||||||
|
self._http_client = httpx.AsyncClient(
|
||||||
|
base_url=self.api_base_url,
|
||||||
|
timeout=self.timeout,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"X-Auth-Token": self.token,
|
||||||
|
"User-Agent": "Yuque-Integration-Client"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""Async context manager exit."""
|
||||||
|
if self._http_client:
|
||||||
|
await self._http_client.aclose()
|
||||||
|
|
||||||
|
def _handle_api_error(self, response: httpx.Response):
|
||||||
|
"""Handle API error responses."""
|
||||||
|
try:
|
||||||
|
data = response.json()
|
||||||
|
except Exception:
|
||||||
|
data = {}
|
||||||
|
|
||||||
|
status_code = response.status_code
|
||||||
|
error_msg = data.get("message", "Unknown error")
|
||||||
|
|
||||||
|
# Rate limit errors
|
||||||
|
if status_code == 429:
|
||||||
|
raise YuqueRateLimitError(
|
||||||
|
f"Rate limit exceeded: {error_msg}",
|
||||||
|
error_code=str(status_code),
|
||||||
|
details=data
|
||||||
|
)
|
||||||
|
# Not found errors
|
||||||
|
elif status_code == 404:
|
||||||
|
raise YuqueNotFoundError(
|
||||||
|
f"Resource not found: {error_msg}",
|
||||||
|
error_code=str(status_code),
|
||||||
|
details=data
|
||||||
|
)
|
||||||
|
# Permission errors
|
||||||
|
elif status_code == 403:
|
||||||
|
raise YuquePermissionError(
|
||||||
|
f"Permission denied: {error_msg}",
|
||||||
|
error_code=str(status_code),
|
||||||
|
details=data
|
||||||
|
)
|
||||||
|
# Authentication errors
|
||||||
|
elif status_code == 401:
|
||||||
|
raise YuqueAuthError(
|
||||||
|
f"Authentication failed: {error_msg}",
|
||||||
|
error_code=str(status_code),
|
||||||
|
details=data
|
||||||
|
)
|
||||||
|
# Generic API error
|
||||||
|
else:
|
||||||
|
raise YuqueAPIError(
|
||||||
|
f"API error: {error_msg}",
|
||||||
|
error_code=str(status_code),
|
||||||
|
details=data
|
||||||
|
)
|
||||||
|
|
||||||
|
@with_retry
|
||||||
|
async def get_user_repos(self) -> List[YuqueRepoInfo]:
|
||||||
|
"""
|
||||||
|
Get all repositories (知识库) for the user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of YuqueRepoInfo objects
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
YuqueAPIError: If API call fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not self._http_client:
|
||||||
|
raise YuqueAPIError("HTTP client not initialized")
|
||||||
|
|
||||||
|
response = await self._http_client.get(f"/users/{self.user_id}/repos")
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
self._handle_api_error(response)
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
repos_data = data.get("data", [])
|
||||||
|
|
||||||
|
repos = []
|
||||||
|
for repo_data in repos_data:
|
||||||
|
try:
|
||||||
|
repo = YuqueRepoInfo(
|
||||||
|
id=repo_data.get("id"),
|
||||||
|
type=repo_data.get("type", ""),
|
||||||
|
name=repo_data.get("name", ""),
|
||||||
|
namespace=repo_data.get("namespace", ""),
|
||||||
|
slug=repo_data.get("slug", ""),
|
||||||
|
description=repo_data.get("description"),
|
||||||
|
public=repo_data.get("public", 0),
|
||||||
|
items_count=repo_data.get("items_count", 0),
|
||||||
|
created_at=datetime.fromisoformat(repo_data.get("created_at", "").replace("Z", "+00:00")),
|
||||||
|
updated_at=datetime.fromisoformat(repo_data.get("updated_at", "").replace("Z", "+00:00"))
|
||||||
|
)
|
||||||
|
repos.append(repo)
|
||||||
|
except (ValueError, TypeError, KeyError) as e:
|
||||||
|
# Skip invalid repo entries
|
||||||
|
continue
|
||||||
|
|
||||||
|
return repos
|
||||||
|
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
raise YuqueAPIError(f"HTTP error: {str(e)}")
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, (YuqueAPIError, YuqueAuthError)):
|
||||||
|
raise
|
||||||
|
raise YuqueAPIError(f"Unexpected error: {str(e)}")
|
||||||
|
|
||||||
|
@with_retry
|
||||||
|
async def get_repo_docs(self, book_id: int) -> List[YuqueDocInfo]:
|
||||||
|
"""
|
||||||
|
Get all documents in a repository.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
book_id: repository id
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of YuqueDocInfo objects (without body content)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
YuqueAPIError: If API call fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not self._http_client:
|
||||||
|
raise YuqueAPIError("HTTP client not initialized")
|
||||||
|
|
||||||
|
response = await self._http_client.get(f"/repos/{book_id}/docs")
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
self._handle_api_error(response)
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
docs_data = data.get("data", [])
|
||||||
|
|
||||||
|
docs = []
|
||||||
|
for doc_data in docs_data:
|
||||||
|
try:
|
||||||
|
published_at = doc_data.get("published_at")
|
||||||
|
doc = YuqueDocInfo(
|
||||||
|
id=doc_data.get("id"),
|
||||||
|
type=doc_data.get("type", ""),
|
||||||
|
slug=doc_data.get("slug", ""),
|
||||||
|
title=doc_data.get("title", ""),
|
||||||
|
book_id=doc_data.get("book_id"),
|
||||||
|
format=doc_data.get("format", "markdown"),
|
||||||
|
body=None, # Body not included in list API
|
||||||
|
body_draft=None,
|
||||||
|
body_html=None,
|
||||||
|
public=doc_data.get("public", 0),
|
||||||
|
status=doc_data.get("status", 0),
|
||||||
|
created_at=datetime.fromisoformat(doc_data.get("created_at", "").replace("Z", "+00:00")),
|
||||||
|
updated_at=datetime.fromisoformat(doc_data.get("updated_at", "").replace("Z", "+00:00")),
|
||||||
|
published_at=datetime.fromisoformat(published_at.replace("Z", "+00:00")) if published_at else None,
|
||||||
|
word_count=doc_data.get("word_count", 0),
|
||||||
|
cover=doc_data.get("cover"),
|
||||||
|
description=doc_data.get("description")
|
||||||
|
)
|
||||||
|
docs.append(doc)
|
||||||
|
except (ValueError, TypeError, KeyError) as e:
|
||||||
|
# Skip invalid doc entries
|
||||||
|
continue
|
||||||
|
|
||||||
|
return docs
|
||||||
|
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
raise YuqueAPIError(f"HTTP error: {str(e)}")
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, (YuqueAPIError, YuqueNotFoundError)):
|
||||||
|
raise
|
||||||
|
raise YuqueAPIError(f"Unexpected error: {str(e)}")
|
||||||
|
|
||||||
|
@with_retry
|
||||||
|
async def get_doc_detail(self, id: int) -> YuqueDocInfo:
|
||||||
|
"""
|
||||||
|
Get detailed document information including content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id: document ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
YuqueDocInfo object with full content
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
YuqueAPIError: If API call fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not self._http_client:
|
||||||
|
raise YuqueAPIError("HTTP client not initialized")
|
||||||
|
|
||||||
|
response = await self._http_client.get(
|
||||||
|
f"/repos/docs/{id}",
|
||||||
|
params={"raw": 1} # Get raw markdown content
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
self._handle_api_error(response)
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
doc_data = data.get("data", {})
|
||||||
|
|
||||||
|
published_at = doc_data.get("published_at")
|
||||||
|
doc = YuqueDocInfo(
|
||||||
|
id=doc_data.get("id"),
|
||||||
|
type=doc_data.get("type", ""),
|
||||||
|
slug=doc_data.get("slug", ""),
|
||||||
|
title=doc_data.get("title", ""),
|
||||||
|
book_id=doc_data.get("book_id"),
|
||||||
|
format=doc_data.get("format", "markdown"),
|
||||||
|
body=doc_data.get("body", ""),
|
||||||
|
body_draft=doc_data.get("body_draft"),
|
||||||
|
body_html=doc_data.get("body_html"),
|
||||||
|
public=doc_data.get("public", 0),
|
||||||
|
status=doc_data.get("status", 0),
|
||||||
|
created_at=datetime.fromisoformat(doc_data.get("created_at", "").replace("Z", "+00:00")),
|
||||||
|
updated_at=datetime.fromisoformat(doc_data.get("updated_at", "").replace("Z", "+00:00")),
|
||||||
|
published_at=datetime.fromisoformat(published_at.replace("Z", "+00:00")) if published_at else None,
|
||||||
|
word_count=doc_data.get("word_count", 0),
|
||||||
|
cover=doc_data.get("cover"),
|
||||||
|
description=doc_data.get("description")
|
||||||
|
)
|
||||||
|
|
||||||
|
return doc
|
||||||
|
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
raise YuqueAPIError(f"HTTP error: {str(e)}")
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, (YuqueAPIError, YuqueNotFoundError)):
|
||||||
|
raise
|
||||||
|
raise YuqueAPIError(f"Unexpected error: {str(e)}")
|
||||||
|
|
||||||
|
async def download_document(
|
||||||
|
self,
|
||||||
|
doc: YuqueDocInfo,
|
||||||
|
save_dir: str
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Download document content to local file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc: Document info (can be without body)
|
||||||
|
save_dir: Directory to save the file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Full path to the saved file
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
YuqueAPIError: If download fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get full document content if not already loaded
|
||||||
|
if not doc.body:
|
||||||
|
doc = await self.get_doc_detail(doc.id)
|
||||||
|
|
||||||
|
# Sanitize filename
|
||||||
|
filename = re.sub(r'[\/:*?"<>|]', '_', doc.title)
|
||||||
|
|
||||||
|
# Determine file extension based on format
|
||||||
|
content = doc.body or ""
|
||||||
|
if doc.format == "markdown":
|
||||||
|
file_extension = "md"
|
||||||
|
elif doc.format == "lake":
|
||||||
|
file_extension = "md" # Save lake format as markdown
|
||||||
|
elif doc.format == "html":
|
||||||
|
file_extension = "html"
|
||||||
|
elif doc.format == "lakesheet":
|
||||||
|
file_extension = "xlsx"
|
||||||
|
|
||||||
|
body_data = json.loads(doc.body)
|
||||||
|
sheet_data = body_data.get("sheet", "")
|
||||||
|
try:
|
||||||
|
sheet_raw = zlib.decompress(bytes(sheet_data, 'latin-1'))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error decompressing sheet data: {e}")
|
||||||
|
raise ValueError("Invalid or unsupported sheet data format.")
|
||||||
|
try:
|
||||||
|
sheet_text = sheet_raw.decode("utf-8") # 假设是 UTF-8 编码
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
sheet_text = sheet_raw.decode("gbk") # 如果 UTF-8 解码失败,尝试 GBK
|
||||||
|
|
||||||
|
file_full_path = os.path.join(save_dir, f"{filename}.{file_extension}")
|
||||||
|
self.generate_excel_from_sheet(sheet_text, file_full_path)
|
||||||
|
return file_full_path
|
||||||
|
else:
|
||||||
|
file_extension = "txt"
|
||||||
|
|
||||||
|
file_full_path = os.path.join(save_dir, f"{filename}.{file_extension}")
|
||||||
|
# Remove existing file if it exists
|
||||||
|
if os.path.exists(file_full_path):
|
||||||
|
os.remove(file_full_path)
|
||||||
|
|
||||||
|
# Write content to file
|
||||||
|
with open(file_full_path, "w", encoding="utf-8") as file:
|
||||||
|
file.write(content)
|
||||||
|
|
||||||
|
return file_full_path
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, YuqueAPIError):
|
||||||
|
raise
|
||||||
|
raise YuqueAPIError(f"Unexpected error during file download: {str(e)}")
|
||||||
|
|
||||||
|
def generate_excel_from_sheet(self, sheet_text: str, save_path: str):
|
||||||
|
"""
|
||||||
|
将解析的 sheet_text 数据转换为 Excel 文件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sheet_text (str): JSON 格式的 sheet 数据。
|
||||||
|
save_path (str): Excel 文件的保存路径。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 解析 JSON 数据
|
||||||
|
sheets = json.loads(sheet_text)
|
||||||
|
|
||||||
|
if not isinstance(sheets, list):
|
||||||
|
raise ValueError("sheet_text must be a JSON array of sheets.")
|
||||||
|
|
||||||
|
# 创建一个新的 Excel 工作簿
|
||||||
|
workbook = Workbook()
|
||||||
|
|
||||||
|
for sheet_index, sheet_data in enumerate(sheets):
|
||||||
|
sheet_name = sheet_data.get("name", f"Sheet{sheet_index + 1}")
|
||||||
|
row_data = sheet_data.get("data", {})
|
||||||
|
merge_cells = sheet_data.get("mergeCells", {})
|
||||||
|
rows_styles = sheet_data.get("rows", [])
|
||||||
|
cols_styles = sheet_data.get("columns", [])
|
||||||
|
|
||||||
|
# 创建 Sheet
|
||||||
|
if sheet_index == 0:
|
||||||
|
worksheet = workbook.active
|
||||||
|
worksheet.title = sheet_name
|
||||||
|
else:
|
||||||
|
worksheet = workbook.create_sheet(title=sheet_name)
|
||||||
|
|
||||||
|
# 设置列宽
|
||||||
|
for col_index, col_style in enumerate(cols_styles):
|
||||||
|
col_width = col_style.get("size", 82.125) / 7.0
|
||||||
|
col_letter = get_column_letter(col_index + 1) # Excel 列从1开始
|
||||||
|
worksheet.column_dimensions[col_letter].width = col_width
|
||||||
|
|
||||||
|
# 设置行高
|
||||||
|
for row_index, row_style in enumerate(rows_styles):
|
||||||
|
row_height = row_style.get("size", 24) / 1.5
|
||||||
|
worksheet.row_dimensions[row_index + 1].height = row_height
|
||||||
|
|
||||||
|
# 写入单元格数据
|
||||||
|
for r_index, row in row_data.items():
|
||||||
|
for c_index, cell in row.items():
|
||||||
|
# 防御性检查:确保行号和列号都是有效的整数
|
||||||
|
try:
|
||||||
|
row_number = int(r_index) + 1
|
||||||
|
col_number = int(c_index) + 1
|
||||||
|
except ValueError:
|
||||||
|
print(f"Invalid row or column index: r_index={r_index}, c_index={c_index}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if col_number < 1 or col_number > 16384: # Excel 最大列数支持到 XFD,即 16384 列
|
||||||
|
print(f"Invalid column index: c_index={c_index}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
cell_obj = worksheet.cell(row=row_number, column=col_number)
|
||||||
|
|
||||||
|
# 处理值和公式
|
||||||
|
cell_value = cell.get("value", "")
|
||||||
|
if isinstance(cell_value, dict):
|
||||||
|
# 检查是否为公式
|
||||||
|
if cell_value.get("class") == "formula" and "formula" in cell_value:
|
||||||
|
cell_obj.value = f"={cell_value['formula']}" # 写入公式
|
||||||
|
else:
|
||||||
|
cell_obj.value = cell_value.get("value", "") # 写入值
|
||||||
|
else:
|
||||||
|
cell_obj.value = cell_value # 写入简单值
|
||||||
|
|
||||||
|
# 应用样式
|
||||||
|
style = cell.get("style", {})
|
||||||
|
self.apply_cell_style(cell_obj, style)
|
||||||
|
|
||||||
|
# 合并单元格
|
||||||
|
for key, merge_def in merge_cells.items():
|
||||||
|
start_row = merge_def["row"] + 1
|
||||||
|
start_col = merge_def["col"] + 1
|
||||||
|
end_row = start_row + merge_def["rowCount"] - 1
|
||||||
|
end_col = start_col + merge_def["colCount"] - 1
|
||||||
|
worksheet.merge_cells(
|
||||||
|
start_row=start_row, start_column=start_col, end_row=end_row, end_column=end_col
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存 Excel 文件
|
||||||
|
workbook.save(save_path)
|
||||||
|
print(f"Excel file successfully saved to: {save_path}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error generating Excel file: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def apply_cell_style(self, cell, style):
|
||||||
|
"""
|
||||||
|
应用单元格样式,包括字体、对齐、背景颜色等。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cell: openpyxl 的单元格对象。
|
||||||
|
style: 字典格式的样式信息。
|
||||||
|
"""
|
||||||
|
# 定义允许的对齐值
|
||||||
|
allowed_horizontal_alignments = {"general", "left", "center", "centerContinuous", "right", "fill", "justify",
|
||||||
|
"distributed"}
|
||||||
|
allowed_vertical_alignments = {"top", "center", "justify", "distributed", "bottom"}
|
||||||
|
|
||||||
|
# 处理字体
|
||||||
|
font = Font(
|
||||||
|
size=style.get("fontSize", 11),
|
||||||
|
bold=style.get("fontWeight", False),
|
||||||
|
italic=style.get("fontStyle", "normal") == "italic",
|
||||||
|
underline="single" if style.get("underline", False) else None,
|
||||||
|
color=self.convert_color_to_hex(style.get("color", "#000000")),
|
||||||
|
)
|
||||||
|
cell.font = font
|
||||||
|
|
||||||
|
# 处理对齐方式
|
||||||
|
horizontal_alignment = style.get("hAlign", "left")
|
||||||
|
vertical_alignment = style.get("vAlign", "top")
|
||||||
|
|
||||||
|
# 如果对齐值无效,则使用默认值
|
||||||
|
if horizontal_alignment not in allowed_horizontal_alignments:
|
||||||
|
horizontal_alignment = "left"
|
||||||
|
if vertical_alignment not in allowed_vertical_alignments:
|
||||||
|
vertical_alignment = "top"
|
||||||
|
|
||||||
|
alignment = Alignment(
|
||||||
|
horizontal=horizontal_alignment,
|
||||||
|
vertical=vertical_alignment,
|
||||||
|
wrap_text=style.get("overflow") == "wrap",
|
||||||
|
)
|
||||||
|
cell.alignment = alignment
|
||||||
|
|
||||||
|
# 处理背景颜色
|
||||||
|
background_color = style.get("backColor", None)
|
||||||
|
if background_color:
|
||||||
|
hex_color = self.convert_color_to_hex(background_color)
|
||||||
|
if hex_color:
|
||||||
|
cell.fill = PatternFill(
|
||||||
|
start_color=hex_color,
|
||||||
|
end_color=hex_color,
|
||||||
|
fill_type="solid"
|
||||||
|
)
|
||||||
|
|
||||||
|
def convert_color_to_hex(self, color):
|
||||||
|
"""
|
||||||
|
将颜色从 `rgba(...)` 或 `rgb(...)` 转换为 aRGB 十六进制格式。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
color (str): 原始颜色字符串,如 `rgba(255,255,0,1.00)` 或 `#FFFFFF`。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 转换后的颜色字符串(符合 openpyxl 的格式),例如 `FFFF0000`。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not color:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 如果是 `#RRGGBB` 或 `#AARRGGBB` 格式,直接返回
|
||||||
|
if color.startswith("#"):
|
||||||
|
return color.lstrip("#").upper()
|
||||||
|
|
||||||
|
# 如果是 `rgb(...)` 格式,例如 `rgb(255,255,0)`
|
||||||
|
if color.startswith("rgb("):
|
||||||
|
rgb_values = color.strip("rgb()").split(",")
|
||||||
|
red, green, blue = [int(v) for v in rgb_values]
|
||||||
|
return f"FF{red:02X}{green:02X}{blue:02X}"
|
||||||
|
|
||||||
|
# 如果是 `rgba(...)` 格式,例如 `rgba(255,255,0,1.00)`
|
||||||
|
if color.startswith("rgba("):
|
||||||
|
rgba_values = color.strip("rgba()").split(",")
|
||||||
|
red, green, blue = [int(v) for v in rgba_values[:3]]
|
||||||
|
alpha = float(rgba_values[3])
|
||||||
|
alpha_hex = int(alpha * 255) # 将透明度转换为 [00, FF]
|
||||||
|
return f"{alpha_hex:02X}{red:02X}{green:02X}{blue:02X}"
|
||||||
|
|
||||||
|
# 返回默认颜色
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error parsing color '{color}': {e}")
|
||||||
|
return None
|
||||||
46
api/app/core/rag/integrations/yuque/exceptions.py
Normal file
46
api/app/core/rag/integrations/yuque/exceptions.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""Exception classes for Yuque integration."""
|
||||||
|
|
||||||
|
|
||||||
|
class YuqueError(Exception):
|
||||||
|
"""Base exception for all Yuque-related errors."""
|
||||||
|
|
||||||
|
def __init__(self, message: str, error_code: str = None, details: dict = None):
|
||||||
|
super().__init__(message)
|
||||||
|
self.message = message
|
||||||
|
self.error_code = error_code
|
||||||
|
self.details = details or {}
|
||||||
|
|
||||||
|
|
||||||
|
class YuqueAuthError(YuqueError):
|
||||||
|
"""Authentication error with Yuque API."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class YuqueAPIError(YuqueError):
|
||||||
|
"""General API error from Yuque."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class YuqueNotFoundError(YuqueError):
|
||||||
|
"""Resource not found error (404)."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class YuquePermissionError(YuqueError):
|
||||||
|
"""Permission denied error (403)."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class YuqueRateLimitError(YuqueError):
|
||||||
|
"""Rate limit exceeded error (429)."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class YuqueNetworkError(YuqueError):
|
||||||
|
"""Network-related error (timeout, connection failure)."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class YuqueDataError(YuqueError):
|
||||||
|
"""Data parsing or validation error."""
|
||||||
|
pass
|
||||||
42
api/app/core/rag/integrations/yuque/models.py
Normal file
42
api/app/core/rag/integrations/yuque/models.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""Data models for Yuque integration."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class YuqueRepoInfo:
|
||||||
|
"""Repository (知识库) information from Yuque."""
|
||||||
|
id: int # 知识库 ID
|
||||||
|
type: str # 类型 (Book:文档, Design:图集, Sheet:表格, Resource:资源)
|
||||||
|
name: str # 名称
|
||||||
|
namespace: str # 完整路径: user/repo format
|
||||||
|
slug: str # 路径
|
||||||
|
description: Optional[str] # 简介
|
||||||
|
public: int # 公开性 (0:私密, 1:公开, 2:企业内公开)
|
||||||
|
items_count: int # 文档数量
|
||||||
|
created_at: datetime # 创建时间
|
||||||
|
updated_at: datetime # 更新时间
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class YuqueDocInfo:
|
||||||
|
"""Document information from Yuque."""
|
||||||
|
id: int # 文档 ID
|
||||||
|
type: str # 文档类型 (Doc:普通文档, Sheet:表格, Thread:话题, Board:图集, Table:数据表)
|
||||||
|
slug: str # 路径
|
||||||
|
title: str # 标题
|
||||||
|
book_id: int # 归属知识库 ID
|
||||||
|
format: str # 内容格式 (markdown:Markdown 格式, lake:语雀 Lake 格式, html:HTML 标准格式, lakesheet:语雀表格)
|
||||||
|
body: Optional[str] # 正文原始内容
|
||||||
|
body_draft: Optional[str] # 正文草稿内容
|
||||||
|
body_html: Optional[str] # 正文 HTML 标准格式内容
|
||||||
|
public: int # 公开性 (0:私密, 1:公开, 2:企业内公开)
|
||||||
|
status: int # 状态 (0:草稿, 1:发布)
|
||||||
|
created_at: datetime # 创建时间
|
||||||
|
updated_at: datetime # 更新时间
|
||||||
|
published_at: Optional[datetime] # 发布时间
|
||||||
|
word_count: int # 内容字数
|
||||||
|
cover: Optional[str] # 封面
|
||||||
|
description: Optional[str] # 摘要
|
||||||
134
api/app/core/rag/integrations/yuque/retry.py
Normal file
134
api/app/core/rag/integrations/yuque/retry.py
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
"""Retry strategy for Yuque API calls."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import functools
|
||||||
|
from typing import Callable, TypeVar
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.core.rag.integrations.yuque.exceptions import (
|
||||||
|
YuqueAuthError,
|
||||||
|
YuquePermissionError,
|
||||||
|
YuqueNotFoundError,
|
||||||
|
YuqueRateLimitError,
|
||||||
|
YuqueNetworkError,
|
||||||
|
YuqueDataError,
|
||||||
|
YuqueAPIError,
|
||||||
|
)
|
||||||
|
|
||||||
|
T = TypeVar('T')
|
||||||
|
|
||||||
|
|
||||||
|
class RetryStrategy:
|
||||||
|
"""Retry strategy for API calls."""
|
||||||
|
|
||||||
|
# Retryable error types
|
||||||
|
RETRYABLE_ERRORS = (
|
||||||
|
YuqueNetworkError,
|
||||||
|
YuqueRateLimitError,
|
||||||
|
httpx.TimeoutException,
|
||||||
|
httpx.ConnectError,
|
||||||
|
httpx.ReadError,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Non-retryable error types
|
||||||
|
NON_RETRYABLE_ERRORS = (
|
||||||
|
YuqueAuthError,
|
||||||
|
YuquePermissionError,
|
||||||
|
YuqueNotFoundError,
|
||||||
|
YuqueDataError,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Retry configuration
|
||||||
|
MAX_RETRIES = 3
|
||||||
|
BACKOFF_DELAYS = [1, 2, 4] # seconds
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_retryable(cls, error: Exception) -> bool:
|
||||||
|
"""Check if an error is retryable."""
|
||||||
|
# Check for specific retryable errors
|
||||||
|
if isinstance(error, cls.RETRYABLE_ERRORS):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check for non-retryable errors
|
||||||
|
if isinstance(error, cls.NON_RETRYABLE_ERRORS):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check for HTTP status codes
|
||||||
|
if isinstance(error, httpx.HTTPStatusError):
|
||||||
|
status_code = error.response.status_code
|
||||||
|
# Retry on 429 (rate limit), 503 (service unavailable), 502 (bad gateway)
|
||||||
|
if status_code in [429, 502, 503]:
|
||||||
|
return True
|
||||||
|
# Don't retry on 4xx errors (except 429)
|
||||||
|
if 400 <= status_code < 500:
|
||||||
|
return False
|
||||||
|
# Retry on 5xx errors
|
||||||
|
if 500 <= status_code < 600:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check for YuqueRateLimitError
|
||||||
|
if isinstance(error, YuqueRateLimitError):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute_with_retry(
|
||||||
|
cls,
|
||||||
|
func: Callable[..., T],
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
) -> T:
|
||||||
|
"""
|
||||||
|
Execute a function with retry logic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: Async function to execute
|
||||||
|
*args: Positional arguments for the function
|
||||||
|
**kwargs: Keyword arguments for the function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Function result
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: The last exception if all retries fail
|
||||||
|
"""
|
||||||
|
last_exception = None
|
||||||
|
|
||||||
|
for attempt in range(cls.MAX_RETRIES + 1):
|
||||||
|
try:
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
last_exception = e
|
||||||
|
|
||||||
|
# Don't retry if not retryable
|
||||||
|
if not cls.is_retryable(e):
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Don't retry if this was the last attempt
|
||||||
|
if attempt >= cls.MAX_RETRIES:
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Wait before retrying
|
||||||
|
delay = cls.BACKOFF_DELAYS[attempt] if attempt < len(cls.BACKOFF_DELAYS) else cls.BACKOFF_DELAYS[-1]
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
|
# Should not reach here, but raise last exception if we do
|
||||||
|
if last_exception:
|
||||||
|
raise last_exception
|
||||||
|
|
||||||
|
|
||||||
|
def with_retry(func: Callable[..., T]) -> Callable[..., T]:
|
||||||
|
"""
|
||||||
|
Decorator to add retry logic to async functions.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
@with_retry
|
||||||
|
async def my_api_call():
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
@functools.wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
return await RetryStrategy.execute_with_retry(func, *args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
@@ -28,7 +28,9 @@ from app.core.rag.common.float_utils import get_float
|
|||||||
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
|
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
|
||||||
from app.core.rag.llm.chat_model import Base
|
from app.core.rag.llm.chat_model import Base
|
||||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def knowledge_retrieval(
|
def knowledge_retrieval(
|
||||||
query: str,
|
query: str,
|
||||||
@@ -62,7 +64,15 @@ def knowledge_retrieval(
|
|||||||
merge_strategy = config.get("merge_strategy", "weight")
|
merge_strategy = config.get("merge_strategy", "weight")
|
||||||
reranker_id = config.get("reranker_id")
|
reranker_id = config.get("reranker_id")
|
||||||
reranker_top_k = config.get("reranker_top_k", 1024)
|
reranker_top_k = config.get("reranker_top_k", 1024)
|
||||||
use_graph = config.get("use_graph", "false").lower() == "true"
|
# use_graph = config.get("use_graph", "false").lower() == "true"
|
||||||
|
|
||||||
|
use_graph_value = config.get("use_graph", False)
|
||||||
|
if isinstance(use_graph_value, bool):
|
||||||
|
use_graph = use_graph_value
|
||||||
|
elif isinstance(use_graph_value, str):
|
||||||
|
use_graph = use_graph_value.lower() in ("true", "1", "yes")
|
||||||
|
else:
|
||||||
|
use_graph = False
|
||||||
|
|
||||||
file_names_filter = []
|
file_names_filter = []
|
||||||
if user_ids:
|
if user_ids:
|
||||||
@@ -159,13 +169,29 @@ def knowledge_retrieval(
|
|||||||
|
|
||||||
# Use the specified reranker for re-ranking
|
# Use the specified reranker for re-ranking
|
||||||
if reranker_id:
|
if reranker_id:
|
||||||
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
|
try:
|
||||||
# use graph
|
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
|
||||||
|
except Exception as rerank_error:
|
||||||
|
# If reranker fails, log warning and continue with original results
|
||||||
|
logger.warning(
|
||||||
|
"Reranker failed, falling back to original results",
|
||||||
|
extra={
|
||||||
|
"reranker_id": reranker_id,
|
||||||
|
"query": query,
|
||||||
|
"doc_count": len(all_results),
|
||||||
|
"error": str(rerank_error),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
if use_graph:
|
if use_graph:
|
||||||
from app.core.rag.common.settings import kg_retriever
|
try:
|
||||||
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
from app.core.rag.common.settings import kg_retriever
|
||||||
if doc:
|
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||||
all_results.insert(0, doc)
|
if doc:
|
||||||
|
all_results.insert(0, doc)
|
||||||
|
except Exception as graph_error:
|
||||||
|
print(f"Failed to retrieve from knowledge graph: {str(graph_error)}")
|
||||||
|
|
||||||
return all_results
|
return all_results
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -25,6 +25,18 @@ class ParameterExtractorNode(BaseNode):
|
|||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config: ParameterExtractorNodeConfig | None = None
|
self.typed_config: ParameterExtractorNodeConfig | None = None
|
||||||
|
self.response_metadata = {}
|
||||||
|
|
||||||
|
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||||
|
if self.response_metadata:
|
||||||
|
usage = self.response_metadata.get('token_usage')
|
||||||
|
if usage:
|
||||||
|
return {
|
||||||
|
"prompt_tokens": usage.get('prompt_tokens', 0),
|
||||||
|
"completion_tokens": usage.get('completion_tokens', 0),
|
||||||
|
"total_tokens": usage.get('total_tokens', 0)
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
outputs = {}
|
outputs = {}
|
||||||
@@ -180,6 +192,7 @@ class ParameterExtractorNode(BaseNode):
|
|||||||
])
|
])
|
||||||
|
|
||||||
model_resp = await llm.ainvoke(messages)
|
model_resp = await llm.ainvoke(messages)
|
||||||
|
self.response_metadata = model_resp.response_metadata
|
||||||
result = json_repair.repair_json(model_resp.content, return_objects=True)
|
result = json_repair.repair_json(model_resp.content, return_objects=True)
|
||||||
logger.info(f"node: {self.node_id} get params:{result}")
|
logger.info(f"node: {self.node_id} get params:{result}")
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,18 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config: QuestionClassifierNodeConfig | None = None
|
self.typed_config: QuestionClassifierNodeConfig | None = None
|
||||||
self.category_to_case_map = {}
|
self.category_to_case_map = {}
|
||||||
|
self.response_metadata = {}
|
||||||
|
|
||||||
|
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||||
|
if self.response_metadata:
|
||||||
|
usage = self.response_metadata.get('token_usage')
|
||||||
|
if usage:
|
||||||
|
return {
|
||||||
|
"prompt_tokens": usage.get('prompt_tokens', 0),
|
||||||
|
"completion_tokens": usage.get('completion_tokens', 0),
|
||||||
|
"total_tokens": usage.get('total_tokens', 0)
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {
|
return {
|
||||||
@@ -120,6 +132,7 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
|
|
||||||
response = await llm.ainvoke(messages)
|
response = await llm.ainvoke(messages)
|
||||||
result = response.content.strip()
|
result = response.content.strip()
|
||||||
|
self.response_metadata = response.response_metadata
|
||||||
|
|
||||||
if result in category_names:
|
if result in category_names:
|
||||||
category = result
|
category = result
|
||||||
|
|||||||
@@ -14,4 +14,5 @@ class File(Base):
|
|||||||
file_name = Column(String, index=True, nullable=False, comment="file name or folder name,default folder name is /")
|
file_name = Column(String, index=True, nullable=False, comment="file name or folder name,default folder name is /")
|
||||||
file_ext = Column(String, index=True, nullable=False, comment="file extension:folder|pdf")
|
file_ext = Column(String, index=True, nullable=False, comment="file extension:folder|pdf")
|
||||||
file_size = Column(Integer, default=0, comment="file size(byte)")
|
file_size = Column(Integer, default=0, comment="file size(byte)")
|
||||||
|
file_url = Column(String, index=True, nullable=True, comment="file comes from a website url")
|
||||||
created_at = Column(DateTime, default=datetime.datetime.now)
|
created_at = Column(DateTime, default=datetime.datetime.now)
|
||||||
@@ -57,6 +57,17 @@ class Knowledge(Base):
|
|||||||
parser_id = Column(String, index=True, default="naive", comment="default parser ID")
|
parser_id = Column(String, index=True, default="naive", comment="default parser ID")
|
||||||
parser_config = Column(JSON, nullable=False,
|
parser_config = Column(JSON, nullable=False,
|
||||||
default={
|
default={
|
||||||
|
"entry_url": "https://ai.redbearai.com",
|
||||||
|
"max_pages": 20,
|
||||||
|
"delay_seconds": 1.0,
|
||||||
|
"timeout_seconds": 10,
|
||||||
|
"user_agent": "KnowledgeBaseCrawler/1.0",
|
||||||
|
"yuque_user_id": "User ID",
|
||||||
|
"yuque_token": "Token",
|
||||||
|
"feishu_app_id": "App ID",
|
||||||
|
"feishu_app_secret": "App Secret",
|
||||||
|
"feishu_folder_token": "Folder Token",
|
||||||
|
"sync_cron": "30 7 * * 1-5",
|
||||||
"layout_recognize": "DeepDOC",
|
"layout_recognize": "DeepDOC",
|
||||||
"chunk_token_num": 128,
|
"chunk_token_num": 128,
|
||||||
"delimiter": "\n",
|
"delimiter": "\n",
|
||||||
|
|||||||
@@ -86,7 +86,8 @@ class MemoryConfigRepository:
|
|||||||
n.description AS description,
|
n.description AS description,
|
||||||
n.entity_type AS entity_type,
|
n.entity_type AS entity_type,
|
||||||
n.name AS name,
|
n.name AS name,
|
||||||
COALESCE(n.fact_summary, '') AS fact_summary,
|
// TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
|
// COALESCE(n.fact_summary, '') AS fact_summary,
|
||||||
n.end_user_id AS end_user_id,
|
n.end_user_id AS end_user_id,
|
||||||
n.apply_id AS apply_id,
|
n.apply_id AS apply_id,
|
||||||
n.user_id AS user_id,
|
n.user_id AS user_id,
|
||||||
@@ -279,6 +280,9 @@ class MemoryConfigRepository:
|
|||||||
if update.config_desc is not None:
|
if update.config_desc is not None:
|
||||||
db_config.config_desc = update.config_desc
|
db_config.config_desc = update.config_desc
|
||||||
has_update = True
|
has_update = True
|
||||||
|
if update.scene_id is not None:
|
||||||
|
db_config.scene_id = update.scene_id
|
||||||
|
has_update = True
|
||||||
|
|
||||||
if not has_update:
|
if not has_update:
|
||||||
raise ValueError("No fields to update")
|
raise ValueError("No fields to update")
|
||||||
@@ -650,28 +654,32 @@ class MemoryConfigRepository:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[MemoryConfig]:
|
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[Tuple[MemoryConfig, Optional[str]]]:
|
||||||
"""获取所有配置参数
|
"""获取所有配置参数,包含关联的场景名称
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
workspace_id: 工作空间ID,用于过滤查询结果
|
workspace_id: 工作空间ID,用于过滤查询结果
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[MemoryConfig]: 配置列表
|
List[Tuple[MemoryConfig, Optional[str]]]: 配置列表,每项为 (配置对象, 场景名称)
|
||||||
"""
|
"""
|
||||||
|
from app.models.ontology_scene import OntologyScene
|
||||||
|
|
||||||
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
|
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
query = db.query(MemoryConfig)
|
query = db.query(MemoryConfig, OntologyScene.scene_name).outerjoin(
|
||||||
|
OntologyScene, MemoryConfig.scene_id == OntologyScene.scene_id
|
||||||
|
)
|
||||||
|
|
||||||
if workspace_id:
|
if workspace_id:
|
||||||
query = query.filter(MemoryConfig.workspace_id == workspace_id)
|
query = query.filter(MemoryConfig.workspace_id == workspace_id)
|
||||||
|
|
||||||
configs = query.order_by(desc(MemoryConfig.updated_at)).all()
|
results = query.order_by(desc(MemoryConfig.updated_at)).all()
|
||||||
|
|
||||||
db_logger.debug(f"配置列表查询成功: 数量={len(configs)}")
|
db_logger.debug(f"配置列表查询成功: 数量={len(results)}")
|
||||||
return configs
|
return results
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}")
|
db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}")
|
||||||
|
|||||||
@@ -79,7 +79,8 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
|
|||||||
try:
|
try:
|
||||||
edges: List[dict] = []
|
edges: List[dict] = []
|
||||||
for s in summaries:
|
for s in summaries:
|
||||||
for chunk_id in getattr(s, "chunk_ids", []) or []:
|
chunk_ids = getattr(s, "chunk_ids", []) or []
|
||||||
|
for chunk_id in chunk_ids:
|
||||||
edges.append({
|
edges.append({
|
||||||
"summary_id": s.id,
|
"summary_id": s.id,
|
||||||
"chunk_id": chunk_id,
|
"chunk_id": chunk_id,
|
||||||
@@ -91,12 +92,11 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
|
|||||||
|
|
||||||
if not edges:
|
if not edges:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
result = await connector.execute_query(
|
result = await connector.execute_query(
|
||||||
MEMORY_SUMMARY_STATEMENT_EDGE_SAVE,
|
MEMORY_SUMMARY_STATEMENT_EDGE_SAVE,
|
||||||
edges=edges
|
edges=edges
|
||||||
)
|
)
|
||||||
created = [record.get("uuid") for record in result] if result else []
|
created = [record.get("uuid") for record in result] if result else []
|
||||||
return created
|
return created
|
||||||
except Exception:
|
except Exception as e:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -217,8 +217,10 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
|||||||
summaries=flattened
|
summaries=flattened
|
||||||
)
|
)
|
||||||
created_ids = [record.get("uuid") for record in result]
|
created_ids = [record.get("uuid") for record in result]
|
||||||
|
print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
|
||||||
return created_ids
|
return created_ids
|
||||||
except Exception:
|
except Exception as e:
|
||||||
|
print(f"Failed to save MemorySummary nodes to Neo4j: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -101,10 +101,11 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
|
|||||||
e.name_embedding = CASE
|
e.name_embedding = CASE
|
||||||
WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding
|
WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding
|
||||||
ELSE e.name_embedding END,
|
ELSE e.name_embedding END,
|
||||||
e.fact_summary = CASE
|
// TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> ''
|
// e.fact_summary = CASE
|
||||||
AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary))
|
// WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> ''
|
||||||
THEN entity.fact_summary ELSE e.fact_summary END,
|
// AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary))
|
||||||
|
// THEN entity.fact_summary ELSE e.fact_summary END,
|
||||||
e.connect_strength = CASE
|
e.connect_strength = CASE
|
||||||
WHEN entity.connect_strength IS NULL OR entity.connect_strength = '' THEN e.connect_strength
|
WHEN entity.connect_strength IS NULL OR entity.connect_strength = '' THEN e.connect_strength
|
||||||
ELSE CASE
|
ELSE CASE
|
||||||
@@ -321,7 +322,8 @@ RETURN e.id AS id,
|
|||||||
e.description AS description,
|
e.description AS description,
|
||||||
e.aliases AS aliases,
|
e.aliases AS aliases,
|
||||||
e.name_embedding AS name_embedding,
|
e.name_embedding AS name_embedding,
|
||||||
COALESCE(e.fact_summary, '') AS fact_summary,
|
// TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
|
// COALESCE(e.fact_summary, '') AS fact_summary,
|
||||||
e.connect_strength AS connect_strength,
|
e.connect_strength AS connect_strength,
|
||||||
collect(DISTINCT s.id) AS statement_ids,
|
collect(DISTINCT s.id) AS statement_ids,
|
||||||
collect(DISTINCT c.id) AS chunk_ids,
|
collect(DISTINCT c.id) AS chunk_ids,
|
||||||
@@ -1002,3 +1004,58 @@ RETURN DISTINCT
|
|||||||
x.statement as statement,x.created_at as created_at
|
x.statement as statement,x.created_at as created_at
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
Graph_Node_query = """
|
||||||
|
MATCH (n:MemorySummary)
|
||||||
|
WHERE n.end_user_id = $end_user_id
|
||||||
|
RETURN
|
||||||
|
elementId(n) AS id,
|
||||||
|
labels(n) AS labels,
|
||||||
|
properties(n) AS properties,
|
||||||
|
0 AS priority
|
||||||
|
LIMIT $limit
|
||||||
|
|
||||||
|
UNION ALL
|
||||||
|
|
||||||
|
MATCH (n:Dialogue)
|
||||||
|
WHERE n.end_user_id = $end_user_id
|
||||||
|
RETURN
|
||||||
|
elementId(n) AS id,
|
||||||
|
labels(n) AS labels,
|
||||||
|
properties(n) AS properties,
|
||||||
|
1 AS priority
|
||||||
|
LIMIT 1
|
||||||
|
|
||||||
|
UNION ALL
|
||||||
|
|
||||||
|
MATCH (n:Statement)
|
||||||
|
WHERE n.end_user_id = $end_user_id
|
||||||
|
RETURN
|
||||||
|
elementId(n) AS id,
|
||||||
|
labels(n) AS labels,
|
||||||
|
properties(n) AS properties,
|
||||||
|
1 AS priority
|
||||||
|
LIMIT $limit
|
||||||
|
|
||||||
|
UNION ALL
|
||||||
|
|
||||||
|
MATCH (n:ExtractedEntity)
|
||||||
|
WHERE n.end_user_id = $end_user_id
|
||||||
|
RETURN
|
||||||
|
elementId(n) AS id,
|
||||||
|
labels(n) AS labels,
|
||||||
|
properties(n) AS properties,
|
||||||
|
2 AS priority
|
||||||
|
LIMIT $limit
|
||||||
|
|
||||||
|
UNION ALL
|
||||||
|
|
||||||
|
MATCH (n:Chunk)
|
||||||
|
WHERE n.end_user_id = $end_user_id
|
||||||
|
RETURN
|
||||||
|
elementId(n) AS id,
|
||||||
|
labels(n) AS labels,
|
||||||
|
properties(n) AS properties,
|
||||||
|
3 AS priority
|
||||||
|
LIMIT $limit
|
||||||
|
|
||||||
|
"""
|
||||||
@@ -21,7 +21,8 @@ from app.core.memory.models.graph_models import (
|
|||||||
ExtractedEntityNode,
|
ExtractedEntityNode,
|
||||||
EntityEntityEdge,
|
EntityEntityEdge,
|
||||||
)
|
)
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
async def save_entities_and_relationships(
|
async def save_entities_and_relationships(
|
||||||
entity_nodes: List[ExtractedEntityNode],
|
entity_nodes: List[ExtractedEntityNode],
|
||||||
entity_entity_edges: List[EntityEntityEdge],
|
entity_entity_edges: List[EntityEntityEdge],
|
||||||
@@ -41,8 +42,8 @@ async def save_entities_and_relationships(
|
|||||||
'statement': edge.statement,
|
'statement': edge.statement,
|
||||||
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
|
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
|
||||||
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
|
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
|
||||||
'created_at': edge.created_at.isoformat(),
|
'created_at': edge.created_at.isoformat() if edge.created_at else None,
|
||||||
'expired_at': edge.expired_at.isoformat(),
|
'expired_at': edge.expired_at.isoformat() if edge.expired_at else None,
|
||||||
'run_id': edge.run_id,
|
'run_id': edge.run_id,
|
||||||
'end_user_id': edge.end_user_id,
|
'end_user_id': edge.end_user_id,
|
||||||
}
|
}
|
||||||
@@ -147,14 +148,14 @@ async def save_statement_entity_edges(
|
|||||||
|
|
||||||
|
|
||||||
async def save_dialog_and_statements_to_neo4j(
|
async def save_dialog_and_statements_to_neo4j(
|
||||||
dialogue_nodes: List[DialogueNode],
|
dialogue_nodes: List[DialogueNode],
|
||||||
chunk_nodes: List[ChunkNode],
|
chunk_nodes: List[ChunkNode],
|
||||||
statement_nodes: List[StatementNode],
|
statement_nodes: List[StatementNode],
|
||||||
entity_nodes: List[ExtractedEntityNode],
|
entity_nodes: List[ExtractedEntityNode],
|
||||||
entity_edges: List[EntityEntityEdge],
|
entity_edges: List[EntityEntityEdge],
|
||||||
statement_chunk_edges: List[StatementChunkEdge],
|
statement_chunk_edges: List[StatementChunkEdge],
|
||||||
statement_entity_edges: List[StatementEntityEdge],
|
statement_entity_edges: List[StatementEntityEdge],
|
||||||
connector: Neo4jConnector
|
connector: Neo4jConnector
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
||||||
|
|
||||||
@@ -171,40 +172,126 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
Returns:
|
Returns:
|
||||||
bool: True if successful, False otherwise
|
bool: True if successful, False otherwise
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
# Save all dialogue nodes in batch
|
# 定义事务函数,将所有写操作放在一个事务中
|
||||||
dialogue_uuids = await add_dialogue_nodes(dialogue_nodes, connector)
|
async def _save_all_in_transaction(tx):
|
||||||
if dialogue_uuids:
|
"""在单个事务中执行所有保存操作,避免死锁"""
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
# 1. Save all dialogue nodes in batch
|
||||||
|
if dialogue_nodes:
|
||||||
|
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE
|
||||||
|
dialogue_data = [node.model_dump() for node in dialogue_nodes]
|
||||||
|
result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data)
|
||||||
|
dialogue_uuids = [record["uuid"] async for record in result]
|
||||||
|
results['dialogues'] = dialogue_uuids
|
||||||
print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
|
print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
|
||||||
else:
|
|
||||||
print("Failed to save dialogues to Neo4j")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Save all chunk nodes in batch
|
# 2. Save all chunk nodes in batch
|
||||||
await save_chunk_nodes(chunk_nodes, connector)
|
if chunk_nodes:
|
||||||
|
from app.repositories.neo4j.cypher_queries import CHUNK_NODE_SAVE
|
||||||
|
chunk_data = [node.model_dump() for node in chunk_nodes]
|
||||||
|
result = await tx.run(CHUNK_NODE_SAVE, chunks=chunk_data)
|
||||||
|
chunk_uuids = [record["uuid"] async for record in result]
|
||||||
|
results['chunks'] = chunk_uuids
|
||||||
|
logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j")
|
||||||
|
|
||||||
# Save all statement nodes in batch
|
# 3. Save all statement nodes in batch
|
||||||
if statement_nodes:
|
if statement_nodes:
|
||||||
statement_uuids = await add_statement_nodes(statement_nodes, connector)
|
from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE
|
||||||
if statement_uuids:
|
statement_data = [node.model_dump() for node in statement_nodes]
|
||||||
print(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j")
|
result = await tx.run(STATEMENT_NODE_SAVE, statements=statement_data)
|
||||||
else:
|
statement_uuids = [record["uuid"] async for record in result]
|
||||||
print("Failed to save statement nodes to Neo4j")
|
results['statements'] = statement_uuids
|
||||||
return False
|
logger.info(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j")
|
||||||
else:
|
|
||||||
print("No statement nodes to save")
|
|
||||||
|
|
||||||
# Save entities and relationships
|
# 4. Save entities
|
||||||
await save_entities_and_relationships(entity_nodes, entity_edges, connector)
|
if entity_nodes:
|
||||||
print("Successfully saved entities and relationships to Neo4j")
|
from app.repositories.neo4j.cypher_queries import EXTRACTED_ENTITY_NODE_SAVE
|
||||||
|
entity_data = [entity.model_dump() for entity in entity_nodes]
|
||||||
|
result = await tx.run(EXTRACTED_ENTITY_NODE_SAVE, entities=entity_data)
|
||||||
|
entity_uuids = [record["uuid"] async for record in result]
|
||||||
|
results['entities'] = entity_uuids
|
||||||
|
logger.info(f"Successfully saved {len(entity_uuids)} entity nodes to Neo4j")
|
||||||
|
|
||||||
# Save new edges
|
# 5. Create entity relationships
|
||||||
await save_statement_chunk_edges(statement_chunk_edges, connector)
|
if entity_edges:
|
||||||
await save_statement_entity_edges(statement_entity_edges, connector)
|
from app.repositories.neo4j.cypher_queries import ENTITY_RELATIONSHIP_SAVE
|
||||||
|
relationship_data = []
|
||||||
|
for edge in entity_edges:
|
||||||
|
relationship_data.append({
|
||||||
|
'source_id': edge.source,
|
||||||
|
'target_id': edge.target,
|
||||||
|
'predicate': edge.relation_type,
|
||||||
|
'statement_id': edge.source_statement_id,
|
||||||
|
'value': edge.relation_value,
|
||||||
|
'statement': edge.statement,
|
||||||
|
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
|
||||||
|
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
|
||||||
|
'created_at': edge.created_at.isoformat() if edge.created_at else None,
|
||||||
|
'expired_at': edge.expired_at.isoformat() if edge.expired_at else None,
|
||||||
|
'run_id': edge.run_id,
|
||||||
|
'end_user_id': edge.end_user_id,
|
||||||
|
})
|
||||||
|
result = await tx.run(ENTITY_RELATIONSHIP_SAVE, relationships=relationship_data)
|
||||||
|
rel_uuids = [record["uuid"] async for record in result]
|
||||||
|
results['entity_relationships'] = rel_uuids
|
||||||
|
logger.info(f"Successfully saved {len(rel_uuids)} entity relationships to Neo4j")
|
||||||
|
|
||||||
|
# 6. Save statement-chunk edges
|
||||||
|
if statement_chunk_edges:
|
||||||
|
from app.repositories.neo4j.cypher_queries import CHUNK_STATEMENT_EDGE_SAVE
|
||||||
|
sc_edge_data = []
|
||||||
|
for edge in statement_chunk_edges:
|
||||||
|
sc_edge_data.append({
|
||||||
|
"id": edge.id,
|
||||||
|
"source": edge.source,
|
||||||
|
"target": edge.target,
|
||||||
|
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
||||||
|
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
|
||||||
|
"run_id": edge.run_id,
|
||||||
|
"end_user_id": edge.end_user_id,
|
||||||
|
})
|
||||||
|
result = await tx.run(CHUNK_STATEMENT_EDGE_SAVE, chunk_statement_edges=sc_edge_data)
|
||||||
|
sc_uuids = [record["uuid"] async for record in result]
|
||||||
|
results['statement_chunk_edges'] = sc_uuids
|
||||||
|
logger.info(f"Successfully saved {len(sc_uuids)} statement-chunk edges to Neo4j")
|
||||||
|
|
||||||
|
# 7. Save statement-entity edges
|
||||||
|
if statement_entity_edges:
|
||||||
|
from app.repositories.neo4j.cypher_queries import STATEMENT_ENTITY_EDGE_SAVE
|
||||||
|
se_edge_data = []
|
||||||
|
for edge in statement_entity_edges:
|
||||||
|
se_edge_data.append({
|
||||||
|
"source": edge.source,
|
||||||
|
"target": edge.target,
|
||||||
|
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
||||||
|
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
|
||||||
|
"run_id": edge.run_id,
|
||||||
|
"end_user_id": edge.end_user_id,
|
||||||
|
"connect_strength": getattr(edge, "connect_strength", "strong"),
|
||||||
|
})
|
||||||
|
result = await tx.run(STATEMENT_ENTITY_EDGE_SAVE, relationships=se_edge_data)
|
||||||
|
se_uuids = [record["uuid"] async for record in result]
|
||||||
|
results['statement_entity_edges'] = se_uuids
|
||||||
|
logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 使用显式写事务执行所有操作,避免死锁
|
||||||
|
results = await connector.execute_write_transaction(_save_all_in_transaction)
|
||||||
|
summary = {
|
||||||
|
key: len(value)
|
||||||
|
for key, value in results.items()
|
||||||
|
if isinstance(value, (list, tuple, set))
|
||||||
|
}
|
||||||
|
logger.info("Transaction completed. Summary: %s", summary)
|
||||||
|
logger.debug("Full transaction results: %r", results)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.error(f"Neo4j integration error: {e}", exc_info=True)
|
||||||
print(f"Neo4j integration error: {e}")
|
print(f"Neo4j integration error: {e}")
|
||||||
print("Continuing without database storage...")
|
print("Continuing without database storage...")
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -392,3 +392,48 @@ class OntologySceneRepository:
|
|||||||
exc_info=True
|
exc_info=True
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def get_simple_list(self, workspace_id: UUID) -> List[dict]:
|
||||||
|
"""获取场景简单列表(仅包含scene_id和scene_name,用于下拉选择)
|
||||||
|
|
||||||
|
这是一个轻量级查询,不加载关联的classes,响应速度快。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: 工作空间ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[dict]: 场景简单列表,每项包含scene_id和scene_name
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> repo = OntologySceneRepository(db)
|
||||||
|
>>> scenes = repo.get_simple_list(workspace_id)
|
||||||
|
>>> # [{"scene_id": "xxx", "scene_name": "场景1"}, ...]
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.debug(f"Getting simple scene list for workspace: {workspace_id}")
|
||||||
|
|
||||||
|
# 只查询需要的字段,不加载关联数据
|
||||||
|
results = self.db.query(
|
||||||
|
OntologyScene.scene_id,
|
||||||
|
OntologyScene.scene_name
|
||||||
|
).filter(
|
||||||
|
OntologyScene.workspace_id == workspace_id
|
||||||
|
).order_by(
|
||||||
|
OntologyScene.updated_at.desc()
|
||||||
|
).all()
|
||||||
|
|
||||||
|
scenes = [
|
||||||
|
{"scene_id": str(r.scene_id), "scene_name": r.scene_name}
|
||||||
|
for r in results
|
||||||
|
]
|
||||||
|
|
||||||
|
logger.info(f"Found {len(scenes)} scenes (simple list) in workspace {workspace_id}")
|
||||||
|
|
||||||
|
return scenes
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to get simple scene list: {str(e)}",
|
||||||
|
exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ class FileBase(BaseModel):
|
|||||||
file_name: str
|
file_name: str
|
||||||
file_ext: str
|
file_ext: str
|
||||||
file_size: int
|
file_size: int
|
||||||
|
file_url: str | None = None
|
||||||
|
created_at: datetime.datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
class FileCreate(FileBase):
|
class FileCreate(FileBase):
|
||||||
@@ -26,6 +28,7 @@ class FileUpdate(BaseModel):
|
|||||||
file_name: str | None = Field(None)
|
file_name: str | None = Field(None)
|
||||||
file_ext: str | None = Field(None)
|
file_ext: str | None = Field(None)
|
||||||
file_size: str | None = Field(None)
|
file_size: str | None = Field(None)
|
||||||
|
file_url: str | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class File(FileBase):
|
class File(FileBase):
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from abc import ABC
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -14,4 +15,15 @@ class UserInput(BaseModel):
|
|||||||
class Write_UserInput(BaseModel):
|
class Write_UserInput(BaseModel):
|
||||||
messages: list[dict]
|
messages: list[dict]
|
||||||
end_user_id: str
|
end_user_id: str
|
||||||
config_id: Optional[str] = None
|
config_id: Optional[str] = None
|
||||||
|
|
||||||
|
class AgentMemory_Long_Term(ABC):
|
||||||
|
"""长期记忆配置常量"""
|
||||||
|
STORAGE_NEO4J = "neo4j"
|
||||||
|
STORAGE_RAG = "rag"
|
||||||
|
STRATEGY_AGGREGATE = "aggregate"
|
||||||
|
STRATEGY_CHUNK = "chunk"
|
||||||
|
STRATEGY_TIME = "time"
|
||||||
|
DEFAULT_SCOPE = 6
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -248,8 +248,9 @@ class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
|
|||||||
|
|
||||||
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
||||||
config_id: Union[uuid.UUID, int, str] = None
|
config_id: Union[uuid.UUID, int, str] = None
|
||||||
config_name: str = Field("配置名称", description="配置名称(字符串)")
|
config_name: Optional[str] = Field(None, description="配置名称(字符串)")
|
||||||
config_desc: str = Field("配置描述", description="配置描述(字符串)")
|
config_desc: Optional[str] = Field(None, description="配置描述(字符串)")
|
||||||
|
scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID")
|
||||||
|
|
||||||
|
|
||||||
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
||||||
|
|||||||
@@ -964,8 +964,15 @@ class AppService:
|
|||||||
).order_by(
|
).order_by(
|
||||||
AgentConfig.updated_at.desc()
|
AgentConfig.updated_at.desc()
|
||||||
)
|
)
|
||||||
|
|
||||||
config = self.db.scalars(stmt).first()
|
config = self.db.scalars(stmt).first()
|
||||||
|
|
||||||
|
try:
|
||||||
|
config_memory=config.memory
|
||||||
|
if 'memory_content' in config_memory:
|
||||||
|
config.memory['memory_config_id'] = config.memory.pop('memory_content')
|
||||||
|
except:
|
||||||
|
logger.debug("记忆配置不存在")
|
||||||
if config:
|
if config:
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|||||||
@@ -114,6 +114,8 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
|||||||
result = task_service.get_task_memory_read_result(task.id)
|
result = task_service.get_task_memory_read_result(task.id)
|
||||||
status = result.get("status")
|
status = result.get("status")
|
||||||
logger.info(f"读取任务状态:{status}")
|
logger.info(f"读取任务状态:{status}")
|
||||||
|
if memory_content:
|
||||||
|
memory_content = memory_content['answer']
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
@@ -127,11 +129,6 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
|||||||
"content_length": len(str(memory_content))
|
"content_length": len(str(memory_content))
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查是否有有效内容
|
|
||||||
if not memory_content or str(memory_content).strip() == "" or "answer" in str(memory_content) and str(memory_content).count("''") > 0:
|
|
||||||
return "未找到相关的历史记忆。请直接回答用户的问题,不要再次调用此工具。"
|
|
||||||
|
|
||||||
return f"检索到以下历史记忆:\n\n{memory_content}"
|
return f"检索到以下历史记忆:\n\n{memory_content}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})
|
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})
|
||||||
|
|||||||
@@ -183,11 +183,11 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
|
|
||||||
# --- Read All ---
|
# --- Read All ---
|
||||||
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
|
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
|
||||||
configs = MemoryConfigRepository.get_all(self.db, workspace_id)
|
results = MemoryConfigRepository.get_all(self.db, workspace_id)
|
||||||
|
|
||||||
# 将 ORM 对象转换为字典列表
|
# 将 ORM 对象转换为字典列表
|
||||||
data_list = []
|
data_list = []
|
||||||
for config in configs:
|
for config, scene_name in results:
|
||||||
# 安全地转换 user_id 为 int
|
# 安全地转换 user_id 为 int
|
||||||
config_id_old = None
|
config_id_old = None
|
||||||
if config.config_id_old:
|
if config.config_id_old:
|
||||||
@@ -209,7 +209,8 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
"end_user_id": config.end_user_id,
|
"end_user_id": config.end_user_id,
|
||||||
"config_id_old": config_id_old,
|
"config_id_old": config_id_old,
|
||||||
"apply_id": config.apply_id,
|
"apply_id": config.apply_id,
|
||||||
"scene_id": config.scene_id,
|
"scene_id": str(config.scene_id) if config.scene_id else None,
|
||||||
|
"scene_name": scene_name, # 新增:场景名称
|
||||||
"llm_id": config.llm_id,
|
"llm_id": config.llm_id,
|
||||||
"embedding_id": config.embedding_id,
|
"embedding_id": config.embedding_id,
|
||||||
"rerank_id": config.rerank_id,
|
"rerank_id": config.rerank_id,
|
||||||
@@ -637,10 +638,9 @@ async def analytics_recent_activity_stats() -> Dict[str, Any]:
|
|||||||
if m < 1:
|
if m < 1:
|
||||||
latest_relative = "刚刚"
|
latest_relative = "刚刚"
|
||||||
elif m < 60:
|
elif m < 60:
|
||||||
latest_relative = f"{m}分钟前"
|
latest_relative = "一会前"
|
||||||
else:
|
else:
|
||||||
h = int(m // 60)
|
latest_relative = "较早前"
|
||||||
latest_relative = f"{h}小时前" if h < 24 else f"{int(h // 24)}天前"
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
|||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.repositories.conversation_repository import ConversationRepository
|
from app.repositories.conversation_repository import ConversationRepository
|
||||||
from app.repositories.end_user_repository import EndUserRepository
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
|
from app.repositories.neo4j.cypher_queries import Graph_Node_query
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
|
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
|
||||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||||
@@ -1525,7 +1526,6 @@ async def analytics_graph_data(
|
|||||||
user_uuid = uuid.UUID(end_user_id)
|
user_uuid = uuid.UUID(end_user_id)
|
||||||
repo = EndUserRepository(db)
|
repo = EndUserRepository(db)
|
||||||
end_user = repo.get_by_id(user_uuid)
|
end_user = repo.get_by_id(user_uuid)
|
||||||
|
|
||||||
if not end_user:
|
if not end_user:
|
||||||
logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户")
|
logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户")
|
||||||
return {
|
return {
|
||||||
@@ -1579,21 +1579,11 @@ async def analytics_graph_data(
|
|||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
# 查询所有节点
|
# 查询所有节点
|
||||||
node_query = """
|
node_query=Graph_Node_query
|
||||||
MATCH (n)
|
|
||||||
WHERE n.end_user_id = $end_user_id
|
|
||||||
RETURN
|
|
||||||
elementId(n) as id,
|
|
||||||
labels(n)[0] as label,
|
|
||||||
properties(n) as properties
|
|
||||||
LIMIT $limit
|
|
||||||
"""
|
|
||||||
node_params = {
|
node_params = {
|
||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"limit": limit
|
"limit": limit
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# 执行节点查询
|
# 执行节点查询
|
||||||
node_results = await _neo4j_connector.execute_query(node_query, **node_params)
|
node_results = await _neo4j_connector.execute_query(node_query, **node_params)
|
||||||
|
|
||||||
@@ -1604,9 +1594,9 @@ async def analytics_graph_data(
|
|||||||
|
|
||||||
for record in node_results:
|
for record in node_results:
|
||||||
node_id = record["id"]
|
node_id = record["id"]
|
||||||
node_label = record["label"]
|
node_labels = record.get("labels", [])
|
||||||
|
node_label = node_labels[0] if node_labels else "Unknown"
|
||||||
node_props = record["properties"]
|
node_props = record["properties"]
|
||||||
|
|
||||||
# 根据节点类型提取需要的属性字段
|
# 根据节点类型提取需要的属性字段
|
||||||
filtered_props = await _extract_node_properties(node_label, node_props,node_id)
|
filtered_props = await _extract_node_properties(node_label, node_props,node_id)
|
||||||
|
|
||||||
|
|||||||
483
api/app/tasks.py
483
api/app/tasks.py
@@ -7,6 +7,8 @@ import uuid
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from math import ceil
|
from math import ceil
|
||||||
|
from pathlib import Path
|
||||||
|
import shutil
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import redis
|
import redis
|
||||||
@@ -16,8 +18,13 @@ import trio
|
|||||||
# Import a unified Celery instance
|
# Import a unified Celery instance
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
from app.core.rag.crawler.web_crawler import WebCrawler
|
||||||
from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb
|
from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb
|
||||||
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
|
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
|
||||||
|
from app.core.rag.integrations.feishu.client import FeishuAPIClient
|
||||||
|
from app.core.rag.integrations.feishu.models import FileInfo
|
||||||
|
from app.core.rag.integrations.yuque.client import YuqueAPIClient
|
||||||
|
from app.core.rag.integrations.yuque.models import YuqueDocInfo
|
||||||
from app.core.rag.llm.chat_model import Base
|
from app.core.rag.llm.chat_model import Base
|
||||||
from app.core.rag.llm.cv_model import QWenCV
|
from app.core.rag.llm.cv_model import QWenCV
|
||||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||||
@@ -29,7 +36,9 @@ from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
|
|||||||
)
|
)
|
||||||
from app.db import get_db, get_db_context
|
from app.db import get_db, get_db_context
|
||||||
from app.models.document_model import Document
|
from app.models.document_model import Document
|
||||||
|
from app.models.file_model import File
|
||||||
from app.models.knowledge_model import Knowledge
|
from app.models.knowledge_model import Knowledge
|
||||||
|
from app.schemas import file_schema, document_schema
|
||||||
from app.services.memory_agent_service import MemoryAgentService
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
|
|
||||||
|
|
||||||
@@ -382,6 +391,480 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
|
|||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb")
|
||||||
|
def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
||||||
|
"""
|
||||||
|
sync knowledge document and Document parsing, vectorization, and storage
|
||||||
|
"""
|
||||||
|
db = next(get_db()) # Manually call the generator
|
||||||
|
db_knowledge = None
|
||||||
|
try:
|
||||||
|
db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first()
|
||||||
|
# 1. get vector_service
|
||||||
|
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||||
|
|
||||||
|
# 2. sync data
|
||||||
|
match db_knowledge.type:
|
||||||
|
case "Web": # Crawl webpages in batches through a web crawler
|
||||||
|
entry_url = db_knowledge.parser_config.get("entry_url", "")
|
||||||
|
max_pages = db_knowledge.parser_config.get("max_pages", 20)
|
||||||
|
delay_seconds = db_knowledge.parser_config.get("delay_seconds", 1.0)
|
||||||
|
timeout_seconds = db_knowledge.parser_config.get("timeout_seconds", 10)
|
||||||
|
user_agent = db_knowledge.parser_config.get("user_agent", "KnowledgeBaseCrawler/1.0")
|
||||||
|
# Create crawler
|
||||||
|
crawler = WebCrawler(
|
||||||
|
entry_url=entry_url,
|
||||||
|
max_pages=max_pages,
|
||||||
|
delay_seconds=delay_seconds,
|
||||||
|
timeout_seconds=timeout_seconds,
|
||||||
|
user_agent=user_agent
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
# 初始化存储已爬取 URLs 的集合
|
||||||
|
file_urls = set()
|
||||||
|
# crawl entry_url by yield
|
||||||
|
for crawled_document in crawler.crawl():
|
||||||
|
file_urls.add(crawled_document.url)
|
||||||
|
db_file = db.query(File).filter(File.kb_id == db_knowledge.id,
|
||||||
|
File.file_url == crawled_document.url).first()
|
||||||
|
if db_file:
|
||||||
|
if db_file.file_size == crawled_document.content_length: # same
|
||||||
|
continue
|
||||||
|
else: # --update
|
||||||
|
if crawled_document.content_length:
|
||||||
|
# 1. update file
|
||||||
|
db_file.file_name = f"{crawled_document.title}.txt"
|
||||||
|
db_file.file_ext=".txt"
|
||||||
|
db_file.file_size=crawled_document.content_length
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_file)
|
||||||
|
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
||||||
|
save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id))
|
||||||
|
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
|
||||||
|
save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}")
|
||||||
|
# update file
|
||||||
|
if os.path.exists(save_path):
|
||||||
|
os.remove(save_path) # Delete a single file
|
||||||
|
content_bytes = crawled_document.content.encode('utf-8')
|
||||||
|
with open(save_path, "wb") as f:
|
||||||
|
f.write(content_bytes)
|
||||||
|
# 2. update a document
|
||||||
|
db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id,
|
||||||
|
Document.file_id == db_file.id).first()
|
||||||
|
if db_document:
|
||||||
|
db_document.file_name = db_file.file_name
|
||||||
|
db_document.file_ext = db_file.file_ext
|
||||||
|
db_document.file_size = db_file.file_size
|
||||||
|
db_document.updated_at = datetime.now()
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_document)
|
||||||
|
# 3. Document parsing, vectorization, and storage
|
||||||
|
parse_document(file_path=save_path, document_id=db_document.id)
|
||||||
|
else: # --add
|
||||||
|
if crawled_document.content_length:
|
||||||
|
# 1. upload file
|
||||||
|
upload_file = file_schema.FileCreate(
|
||||||
|
kb_id=db_knowledge.id,
|
||||||
|
created_by=db_knowledge.created_by,
|
||||||
|
parent_id=db_knowledge.id,
|
||||||
|
file_name=f"{crawled_document.title}.txt",
|
||||||
|
file_ext=".txt",
|
||||||
|
file_size=crawled_document.content_length,
|
||||||
|
file_url=crawled_document.url,
|
||||||
|
)
|
||||||
|
db_file = File(**upload_file.model_dump())
|
||||||
|
db.add(db_file)
|
||||||
|
db.commit()
|
||||||
|
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
||||||
|
save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.id))
|
||||||
|
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
|
||||||
|
save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}")
|
||||||
|
# Save file
|
||||||
|
content_bytes = crawled_document.content.encode('utf-8')
|
||||||
|
with open(save_path, "wb") as f:
|
||||||
|
f.write(content_bytes)
|
||||||
|
# 2. Create a document
|
||||||
|
create_document_data = document_schema.DocumentCreate(
|
||||||
|
kb_id=db_knowledge.id,
|
||||||
|
created_by=db_knowledge.created_by,
|
||||||
|
file_id=db_file.id,
|
||||||
|
file_name=db_file.file_name,
|
||||||
|
file_ext=db_file.file_ext,
|
||||||
|
file_size=db_file.file_size,
|
||||||
|
file_meta={},
|
||||||
|
parser_id="naive",
|
||||||
|
parser_config={
|
||||||
|
"layout_recognize": "DeepDOC",
|
||||||
|
"chunk_token_num": 128,
|
||||||
|
"delimiter": "\n",
|
||||||
|
"auto_keywords": 0,
|
||||||
|
"auto_questions": 0,
|
||||||
|
"html4excel": "false"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
db_document = Document(**create_document_data.model_dump())
|
||||||
|
db.add(db_document)
|
||||||
|
db.commit()
|
||||||
|
# 3. Document parsing, vectorization, and storage
|
||||||
|
parse_document(file_path=save_path, document_id=db_document.id)
|
||||||
|
db_files = db.query(File).filter(File.kb_id == db_knowledge.id, File.file_url.notin_(file_urls)).all()
|
||||||
|
if db_files: # --delete
|
||||||
|
for db_file in db_files:
|
||||||
|
db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id,
|
||||||
|
Document.file_id == db_file.id).first()
|
||||||
|
if db_document:
|
||||||
|
# 1. Delete vector index
|
||||||
|
vector_service.delete_by_metadata_field(key="document_id", value=str(db_document.id))
|
||||||
|
# 2. Delete document
|
||||||
|
db.delete(db_document)
|
||||||
|
# 3. Delete file
|
||||||
|
file_path = Path(
|
||||||
|
settings.FILE_PATH,
|
||||||
|
str(db_file.kb_id),
|
||||||
|
str(db_file.parent_id),
|
||||||
|
f"{db_file.id}{db_file.file_ext}"
|
||||||
|
)
|
||||||
|
if file_path.exists():
|
||||||
|
file_path.unlink() # Delete a single file
|
||||||
|
db.delete(db_file)
|
||||||
|
# commit transaction
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n\nError during crawl: {e}")
|
||||||
|
case "Third-party": # Integration of knowledge bases from three parties
|
||||||
|
yuque_user_id = db_knowledge.parser_config.get("yuque_user_id", "")
|
||||||
|
feishu_app_id = db_knowledge.parser_config.get("feishu_app_id", "")
|
||||||
|
if yuque_user_id: # Yuque Knowledge Base
|
||||||
|
yuque_token = db_knowledge.parser_config.get("yuque_token", "")
|
||||||
|
# Create yuqueAPIClient
|
||||||
|
api_client = YuqueAPIClient(
|
||||||
|
user_id=yuque_user_id,
|
||||||
|
token=yuque_token
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
# 初始化存储获取语雀 URLs 的集合
|
||||||
|
file_urls = set()
|
||||||
|
|
||||||
|
# Get all files from all repos
|
||||||
|
async def async_get_files(api_client: YuqueAPIClient):
|
||||||
|
async with api_client as client:
|
||||||
|
print("\n=== Fetching repositories ===")
|
||||||
|
repos = await client.get_user_repos()
|
||||||
|
print(f"Found {len(repos)} repositories:")
|
||||||
|
all_files = []
|
||||||
|
for repo in repos:
|
||||||
|
# Get documents from repository
|
||||||
|
print(f"\n=== Fetching documents from '{repo.name}' ===")
|
||||||
|
docs = await client.get_repo_docs(repo.id)
|
||||||
|
all_files.extend(docs)
|
||||||
|
return all_files
|
||||||
|
|
||||||
|
files = asyncio.run(async_get_files(api_client))
|
||||||
|
for doc in files:
|
||||||
|
file_urls.add(doc.slug)
|
||||||
|
db_file = db.query(File).filter(File.kb_id == db_knowledge.id,
|
||||||
|
File.file_url == doc.slug).first()
|
||||||
|
if db_file:
|
||||||
|
if db_file.created_at == doc.updated_at: # same
|
||||||
|
continue
|
||||||
|
else: # --update
|
||||||
|
# 1. update file
|
||||||
|
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
||||||
|
save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id))
|
||||||
|
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
|
||||||
|
|
||||||
|
# download document from Feishu FileInfo
|
||||||
|
async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, save_dir: str):
|
||||||
|
async with api_client as client:
|
||||||
|
file_path = await client.download_document(doc, save_dir)
|
||||||
|
return file_path
|
||||||
|
|
||||||
|
file_path = asyncio.run(async_download_document(api_client, doc, save_dir))
|
||||||
|
|
||||||
|
save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}")
|
||||||
|
# update file
|
||||||
|
if os.path.exists(save_path):
|
||||||
|
os.remove(save_path) # Delete a single file
|
||||||
|
shutil.copyfile(file_path, save_path)
|
||||||
|
# update db_file
|
||||||
|
file_name = os.path.basename(file_path)
|
||||||
|
_, file_extension = os.path.splitext(file_name)
|
||||||
|
file_size = os.path.getsize(file_path)
|
||||||
|
db_file.file_name = file_name
|
||||||
|
db_file.file_ext = file_extension.lower()
|
||||||
|
db_file.file_size = file_size
|
||||||
|
db_file.created_at = doc.updated_at
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_file)
|
||||||
|
# 2. update a document
|
||||||
|
db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id,
|
||||||
|
Document.file_id == db_file.id).first()
|
||||||
|
if db_document:
|
||||||
|
db_document.file_name = db_file.file_name
|
||||||
|
db_document.file_ext = db_file.file_ext
|
||||||
|
db_document.file_size = db_file.file_size
|
||||||
|
db_document.created_at = db_file.created_at
|
||||||
|
db_document.updated_at = datetime.now()
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_document)
|
||||||
|
# 3. Document parsing, vectorization, and storage
|
||||||
|
parse_document(file_path=save_path, document_id=db_document.id)
|
||||||
|
else: # --add
|
||||||
|
# 1. update file
|
||||||
|
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
||||||
|
save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id))
|
||||||
|
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
|
||||||
|
|
||||||
|
# download document from Feishu FileInfo
|
||||||
|
async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, save_dir: str):
|
||||||
|
async with api_client as client:
|
||||||
|
file_path = await client.download_document(doc, save_dir)
|
||||||
|
return file_path
|
||||||
|
|
||||||
|
file_path = asyncio.run(async_download_document(api_client, doc, save_dir))
|
||||||
|
# add db_file
|
||||||
|
file_name = os.path.basename(file_path)
|
||||||
|
_, file_extension = os.path.splitext(file_name)
|
||||||
|
file_size = os.path.getsize(file_path)
|
||||||
|
upload_file = file_schema.FileCreate(
|
||||||
|
kb_id=db_knowledge.id,
|
||||||
|
created_by=db_knowledge.created_by,
|
||||||
|
parent_id=db_knowledge.id,
|
||||||
|
file_name=file_name,
|
||||||
|
file_ext=file_extension.lower(),
|
||||||
|
file_size=file_size,
|
||||||
|
file_url=doc.slug,
|
||||||
|
created_at=doc.updated_at
|
||||||
|
)
|
||||||
|
db_file = File(**upload_file.model_dump())
|
||||||
|
db.add(db_file)
|
||||||
|
db.commit()
|
||||||
|
# Save file
|
||||||
|
save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}")
|
||||||
|
# update file
|
||||||
|
if os.path.exists(save_path):
|
||||||
|
os.remove(save_path) # Delete a single file
|
||||||
|
shutil.copyfile(file_path, save_path)
|
||||||
|
# 2. Create a document
|
||||||
|
create_document_data = document_schema.DocumentCreate(
|
||||||
|
kb_id=db_knowledge.id,
|
||||||
|
created_by=db_knowledge.created_by,
|
||||||
|
file_id=db_file.id,
|
||||||
|
file_name=db_file.file_name,
|
||||||
|
file_ext=db_file.file_ext,
|
||||||
|
file_size=db_file.file_size,
|
||||||
|
file_meta={},
|
||||||
|
parser_id="naive",
|
||||||
|
parser_config={
|
||||||
|
"layout_recognize": "DeepDOC",
|
||||||
|
"chunk_token_num": 128,
|
||||||
|
"delimiter": "\n",
|
||||||
|
"auto_keywords": 0,
|
||||||
|
"auto_questions": 0,
|
||||||
|
"html4excel": "false"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
db_document = Document(**create_document_data.model_dump())
|
||||||
|
db.add(db_document)
|
||||||
|
db.commit()
|
||||||
|
# 3. Document parsing, vectorization, and storage
|
||||||
|
parse_document(file_path=save_path, document_id=db_document.id)
|
||||||
|
db_files = db.query(File).filter(File.kb_id == db_knowledge.id,
|
||||||
|
File.file_url.notin_(file_urls)).all()
|
||||||
|
if db_files: # --delete
|
||||||
|
for db_file in db_files:
|
||||||
|
db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id,
|
||||||
|
Document.file_id == db_file.id).first()
|
||||||
|
if db_document:
|
||||||
|
# 1. Delete vector index
|
||||||
|
vector_service.delete_by_metadata_field(key="document_id",
|
||||||
|
value=str(db_document.id))
|
||||||
|
# 2. Delete document
|
||||||
|
db.delete(db_document)
|
||||||
|
# 3. Delete file
|
||||||
|
file_path = Path(
|
||||||
|
settings.FILE_PATH,
|
||||||
|
str(db_file.kb_id),
|
||||||
|
str(db_file.parent_id),
|
||||||
|
f"{db_file.id}{db_file.file_ext}"
|
||||||
|
)
|
||||||
|
if file_path.exists():
|
||||||
|
file_path.unlink() # Delete a single file
|
||||||
|
db.delete(db_file)
|
||||||
|
# commit transaction
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n\nError during fetch feishu: {e}")
|
||||||
|
if feishu_app_id: # Feishu Knowledge Base
|
||||||
|
feishu_app_secret = db_knowledge.parser_config.get("feishu_app_secret", "")
|
||||||
|
feishu_folder_token = db_knowledge.parser_config.get("feishu_folder_token", "")
|
||||||
|
# Create feishuAPIClient
|
||||||
|
api_client = FeishuAPIClient(
|
||||||
|
app_id=feishu_app_id,
|
||||||
|
app_secret=feishu_app_secret
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
# 初始化存储获取飞书 URLs 的集合
|
||||||
|
file_urls = set()
|
||||||
|
# Get all files from folder
|
||||||
|
async def async_get_files(api_client: FeishuAPIClient, feishu_folder_token: str):
|
||||||
|
async with api_client as client:
|
||||||
|
files = await client.list_all_folder_files(feishu_folder_token, recursive=True)
|
||||||
|
return files
|
||||||
|
files = asyncio.run(async_get_files(api_client, feishu_folder_token))
|
||||||
|
# Filter out folders, only sync documents
|
||||||
|
documents = [f for f in files if f.type in ["doc", "docx", "sheet", "bitable", "file"]]
|
||||||
|
for doc in documents:
|
||||||
|
file_urls.add(doc.url)
|
||||||
|
db_file = db.query(File).filter(File.kb_id == db_knowledge.id,
|
||||||
|
File.file_url == doc.url).first()
|
||||||
|
if db_file:
|
||||||
|
if db_file.created_at == doc.modified_time: # same
|
||||||
|
continue
|
||||||
|
else: # --update
|
||||||
|
# 1. update file
|
||||||
|
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
||||||
|
save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id),
|
||||||
|
str(db_knowledge.parent_id))
|
||||||
|
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
|
||||||
|
# download document from Feishu FileInfo
|
||||||
|
async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, save_dir: str):
|
||||||
|
async with api_client as client:
|
||||||
|
file_path = await client.download_document(document=doc, save_dir=save_dir)
|
||||||
|
return file_path
|
||||||
|
file_path = asyncio.run(async_download_document(api_client, doc, save_dir))
|
||||||
|
|
||||||
|
save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}")
|
||||||
|
# update file
|
||||||
|
if os.path.exists(save_path):
|
||||||
|
os.remove(save_path) # Delete a single file
|
||||||
|
shutil.copyfile(file_path, save_path)
|
||||||
|
# update db_file
|
||||||
|
file_name = os.path.basename(file_path)
|
||||||
|
_, file_extension = os.path.splitext(file_name)
|
||||||
|
file_size = os.path.getsize(file_path)
|
||||||
|
db_file.file_name = file_name
|
||||||
|
db_file.file_ext = file_extension.lower()
|
||||||
|
db_file.file_size = file_size
|
||||||
|
db_file.created_at = doc.modified_time
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_file)
|
||||||
|
# 2. update a document
|
||||||
|
db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id,
|
||||||
|
Document.file_id == db_file.id).first()
|
||||||
|
if db_document:
|
||||||
|
db_document.file_name = db_file.file_name
|
||||||
|
db_document.file_ext = db_file.file_ext
|
||||||
|
db_document.file_size = db_file.file_size
|
||||||
|
db_document.created_at = db_file.created_at
|
||||||
|
db_document.updated_at = datetime.now()
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_document)
|
||||||
|
# 3. Document parsing, vectorization, and storage
|
||||||
|
parse_document(file_path=save_path, document_id=db_document.id)
|
||||||
|
else: # --add
|
||||||
|
# 1. update file
|
||||||
|
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
||||||
|
save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id),
|
||||||
|
str(db_knowledge.parent_id))
|
||||||
|
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
|
||||||
|
# download document from Feishu FileInfo
|
||||||
|
async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, save_dir: str):
|
||||||
|
async with api_client as client:
|
||||||
|
file_path = await client.download_document(document=doc, save_dir=save_dir)
|
||||||
|
return file_path
|
||||||
|
file_path = asyncio.run(async_download_document(api_client, doc, save_dir))
|
||||||
|
# add db_file
|
||||||
|
file_name = os.path.basename(file_path)
|
||||||
|
_, file_extension = os.path.splitext(file_name)
|
||||||
|
file_size = os.path.getsize(file_path)
|
||||||
|
upload_file = file_schema.FileCreate(
|
||||||
|
kb_id=db_knowledge.id,
|
||||||
|
created_by=db_knowledge.created_by,
|
||||||
|
parent_id=db_knowledge.id,
|
||||||
|
file_name=file_name,
|
||||||
|
file_ext=file_extension.lower(),
|
||||||
|
file_size=file_size,
|
||||||
|
file_url=doc.url,
|
||||||
|
created_at = doc.modified_time
|
||||||
|
)
|
||||||
|
db_file = File(**upload_file.model_dump())
|
||||||
|
db.add(db_file)
|
||||||
|
db.commit()
|
||||||
|
# Save file
|
||||||
|
save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}")
|
||||||
|
# update file
|
||||||
|
if os.path.exists(save_path):
|
||||||
|
os.remove(save_path) # Delete a single file
|
||||||
|
shutil.copyfile(file_path, save_path)
|
||||||
|
# 2. Create a document
|
||||||
|
create_document_data = document_schema.DocumentCreate(
|
||||||
|
kb_id=db_knowledge.id,
|
||||||
|
created_by=db_knowledge.created_by,
|
||||||
|
file_id=db_file.id,
|
||||||
|
file_name=db_file.file_name,
|
||||||
|
file_ext=db_file.file_ext,
|
||||||
|
file_size=db_file.file_size,
|
||||||
|
file_meta={},
|
||||||
|
parser_id="naive",
|
||||||
|
parser_config={
|
||||||
|
"layout_recognize": "DeepDOC",
|
||||||
|
"chunk_token_num": 128,
|
||||||
|
"delimiter": "\n",
|
||||||
|
"auto_keywords": 0,
|
||||||
|
"auto_questions": 0,
|
||||||
|
"html4excel": "false"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
db_document = Document(**create_document_data.model_dump())
|
||||||
|
db.add(db_document)
|
||||||
|
db.commit()
|
||||||
|
# 3. Document parsing, vectorization, and storage
|
||||||
|
parse_document(file_path=save_path, document_id=db_document.id)
|
||||||
|
db_files = db.query(File).filter(File.kb_id == db_knowledge.id,
|
||||||
|
File.file_url.notin_(file_urls)).all()
|
||||||
|
if db_files: # --delete
|
||||||
|
for db_file in db_files:
|
||||||
|
db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id,
|
||||||
|
Document.file_id == db_file.id).first()
|
||||||
|
if db_document:
|
||||||
|
# 1. Delete vector index
|
||||||
|
vector_service.delete_by_metadata_field(key="document_id",
|
||||||
|
value=str(db_document.id))
|
||||||
|
# 2. Delete document
|
||||||
|
db.delete(db_document)
|
||||||
|
# 3. Delete file
|
||||||
|
file_path = Path(
|
||||||
|
settings.FILE_PATH,
|
||||||
|
str(db_file.kb_id),
|
||||||
|
str(db_file.parent_id),
|
||||||
|
f"{db_file.id}{db_file.file_ext}"
|
||||||
|
)
|
||||||
|
if file_path.exists():
|
||||||
|
file_path.unlink() # Delete a single file
|
||||||
|
db.delete(db_file)
|
||||||
|
# commit transaction
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n\nError during fetch feishu: {e}")
|
||||||
|
case _: # General
|
||||||
|
print(f"General: No synchronization needed\n")
|
||||||
|
|
||||||
|
|
||||||
|
result = f"sync knowledge '{db_knowledge.name}' processed successfully."
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
if 'db_knowledge' in locals():
|
||||||
|
print(f"Failed to sync knowledge:{str(e)}\n")
|
||||||
|
result = f"sync knowledge '{db_knowledge.name}' failed."
|
||||||
|
return result
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="app.core.memory.agent.read_message", bind=True)
|
@celery_app.task(name="app.core.memory.agent.read_message", bind=True)
|
||||||
def read_message_task(self, end_user_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str, storage_type:str, user_rag_memory_id:str) -> Dict[str, Any]:
|
def read_message_task(self, end_user_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str, storage_type:str, user_rag_memory_id:str) -> Dict[str, Any]:
|
||||||
|
|
||||||
|
|||||||
@@ -5,42 +5,68 @@ Shared utilities for configuration handling to avoid circular imports.
|
|||||||
"""
|
"""
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
import uuid as uuid_module
|
||||||
|
|
||||||
|
|
||||||
def resolve_config_id(config_id: UUID | int|str, db: Session) -> UUID:
|
def resolve_config_id(config_id: UUID | int | str, db: Session) -> UUID:
|
||||||
"""
|
"""
|
||||||
解析 config_id,如果是整数则通过 config_id_old 查找对应的 UUID
|
解析 config_id,支持 UUID、UUID字符串、整数等多种格式
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config_id: 配置ID(UUID 或整数)
|
config_id: 配置ID(UUID、UUID字符串 或 整数)
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
UUID: 解析后的配置ID
|
UUID: 解析后的配置ID
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: 当找不到对应的配置时
|
ValueError: 当找不到对应的配置时或格式无效时
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from app.models.memory_config_model import MemoryConfig
|
from app.models.memory_config_model import MemoryConfig
|
||||||
if isinstance(config_id, UUID):
|
|
||||||
|
# 1. 如果已经是 UUID 类型,直接返回
|
||||||
|
if isinstance(config_id, UUID):
|
||||||
return config_id
|
return config_id
|
||||||
if isinstance(config_id, str) and len(config_id)<=6:
|
|
||||||
memory_config = db.query(MemoryConfig).filter(
|
# 2. 如果是字符串类型
|
||||||
MemoryConfig.config_id_old == int(config_id)
|
if isinstance(config_id, str):
|
||||||
).first()
|
config_id_stripped = config_id.strip()
|
||||||
print(memory_config)
|
|
||||||
if not memory_config:
|
# 2.1 尝试解析为 UUID(标准 UUID 字符串长度为 36)
|
||||||
raise ValueError(f"STR 未找到 config_id_old={config_id} 对应的配置")
|
try:
|
||||||
return memory_config.config_id
|
return uuid_module.UUID(config_id_stripped)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 2.2 尝试解析为整数(用于查询 config_id_old)
|
||||||
|
try:
|
||||||
|
old_id = int(config_id_stripped)
|
||||||
|
if old_id > 0:
|
||||||
|
memory_config = db.query(MemoryConfig).filter(
|
||||||
|
MemoryConfig.config_id_old == old_id
|
||||||
|
).first()
|
||||||
|
if not memory_config:
|
||||||
|
raise ValueError(f"未找到 config_id_old={old_id} 对应的配置")
|
||||||
|
return memory_config.config_id
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 2.3 无法解析的字符串格式
|
||||||
|
raise ValueError(f"无效的 config_id 格式: '{config_id}'(必须是 UUID 或正整数)")
|
||||||
|
|
||||||
|
# 3. 如果是整数类型,通过 config_id_old 查找
|
||||||
if isinstance(config_id, int):
|
if isinstance(config_id, int):
|
||||||
|
if config_id <= 0:
|
||||||
|
raise ValueError(f"config_id 必须是正整数: {config_id}")
|
||||||
|
|
||||||
memory_config = db.query(MemoryConfig).filter(
|
memory_config = db.query(MemoryConfig).filter(
|
||||||
MemoryConfig.config_id_old == config_id
|
MemoryConfig.config_id_old == config_id
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not memory_config:
|
if not memory_config:
|
||||||
raise ValueError(f"INT 未找到 config_id_old={config_id} 对应的配置")
|
raise ValueError(f"未找到 config_id_old={config_id} 对应的配置")
|
||||||
|
|
||||||
return memory_config.config_id
|
return memory_config.config_id
|
||||||
|
|
||||||
return config_id
|
# 4. 不支持的类型
|
||||||
|
raise ValueError(f"不支持的 config_id 类型: {type(config_id).__name__}")
|
||||||
|
|||||||
32
api/migrations/versions/ef0787b85c35_202602061233.py
Normal file
32
api/migrations/versions/ef0787b85c35_202602061233.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""202602061233
|
||||||
|
|
||||||
|
Revision ID: ef0787b85c35
|
||||||
|
Revises: 9b28b66cf8e8
|
||||||
|
Create Date: 2026-02-06 12:33:26.114673
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'ef0787b85c35'
|
||||||
|
down_revision: Union[str, None] = '9b28b66cf8e8'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column('files', sa.Column('file_url', sa.String(), nullable=True, comment='file comes from a website url'))
|
||||||
|
op.create_index(op.f('ix_files_file_url'), 'files', ['file_url'], unique=False)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_index(op.f('ix_files_file_url'), table_name='files')
|
||||||
|
op.drop_column('files', 'file_url')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -141,6 +141,8 @@ dependencies = [
|
|||||||
"flower>=2.0.1",
|
"flower>=2.0.1",
|
||||||
"aiofiles>=23.0.0",
|
"aiofiles>=23.0.0",
|
||||||
"owlready2>=0.46",
|
"owlready2>=0.46",
|
||||||
|
"lxml>=4.9.0",
|
||||||
|
"httpx>=0.28.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
|
|||||||
@@ -134,3 +134,5 @@ xlrd==2.0.2
|
|||||||
oss2>=2.18.0
|
oss2>=2.18.0
|
||||||
boto3>=1.28.0
|
boto3>=1.28.0
|
||||||
aiofiles>=23.0.0
|
aiofiles>=23.0.0
|
||||||
|
lxml>=4.9.0
|
||||||
|
httpx>=0.28.0
|
||||||
|
|||||||
4
api/uv.lock
generated
4
api/uv.lock
generated
@@ -3224,6 +3224,7 @@ dependencies = [
|
|||||||
{ name = "hanziconv" },
|
{ name = "hanziconv" },
|
||||||
{ name = "html5lib" },
|
{ name = "html5lib" },
|
||||||
{ name = "httptools" },
|
{ name = "httptools" },
|
||||||
|
{ name = "httpx" },
|
||||||
{ name = "huggingface-hub" },
|
{ name = "huggingface-hub" },
|
||||||
{ name = "idna" },
|
{ name = "idna" },
|
||||||
{ name = "jieba" },
|
{ name = "jieba" },
|
||||||
@@ -3237,6 +3238,7 @@ dependencies = [
|
|||||||
{ name = "langchain-ollama" },
|
{ name = "langchain-ollama" },
|
||||||
{ name = "langchain-openai" },
|
{ name = "langchain-openai" },
|
||||||
{ name = "langfuse" },
|
{ name = "langfuse" },
|
||||||
|
{ name = "lxml" },
|
||||||
{ name = "mako" },
|
{ name = "mako" },
|
||||||
{ name = "mammoth" },
|
{ name = "mammoth" },
|
||||||
{ name = "markdown" },
|
{ name = "markdown" },
|
||||||
@@ -3361,6 +3363,7 @@ requires-dist = [
|
|||||||
{ name = "hanziconv", specifier = "==0.3.2" },
|
{ name = "hanziconv", specifier = "==0.3.2" },
|
||||||
{ name = "html5lib", specifier = "==1.1" },
|
{ name = "html5lib", specifier = "==1.1" },
|
||||||
{ name = "httptools", specifier = "==0.7.1" },
|
{ name = "httptools", specifier = "==0.7.1" },
|
||||||
|
{ name = "httpx", specifier = ">=0.28.0" },
|
||||||
{ name = "huggingface-hub", specifier = "==0.25.2" },
|
{ name = "huggingface-hub", specifier = "==0.25.2" },
|
||||||
{ name = "idna", specifier = "==3.11" },
|
{ name = "idna", specifier = "==3.11" },
|
||||||
{ name = "jieba", specifier = ">=0.42.1" },
|
{ name = "jieba", specifier = ">=0.42.1" },
|
||||||
@@ -3375,6 +3378,7 @@ requires-dist = [
|
|||||||
{ name = "langchain-ollama" },
|
{ name = "langchain-ollama" },
|
||||||
{ name = "langchain-openai", specifier = ">=1.0.2" },
|
{ name = "langchain-openai", specifier = ">=1.0.2" },
|
||||||
{ name = "langfuse", specifier = ">=3.10.0" },
|
{ name = "langfuse", specifier = ">=3.10.0" },
|
||||||
|
{ name = "lxml", specifier = ">=4.9.0" },
|
||||||
{ name = "mako", specifier = "==1.3.10" },
|
{ name = "mako", specifier = "==1.3.10" },
|
||||||
{ name = "mammoth", specifier = "==1.11.0" },
|
{ name = "mammoth", specifier = "==1.11.0" },
|
||||||
{ name = "markdown", specifier = "==3.8" },
|
{ name = "markdown", specifier = "==3.8" },
|
||||||
|
|||||||
@@ -13,6 +13,14 @@
|
|||||||
"@antv/layout": "^1.2.14-beta.8",
|
"@antv/layout": "^1.2.14-beta.8",
|
||||||
"@antv/x6": "^3.0.1",
|
"@antv/x6": "^3.0.1",
|
||||||
"@antv/x6-react-shape": "^3.0.1",
|
"@antv/x6-react-shape": "^3.0.1",
|
||||||
|
"@codemirror/lang-cpp": "^6.0.3",
|
||||||
|
"@codemirror/lang-java": "^6.0.2",
|
||||||
|
"@codemirror/lang-javascript": "^6.2.4",
|
||||||
|
"@codemirror/lang-python": "^6.2.1",
|
||||||
|
"@codemirror/lang-rust": "^6.0.2",
|
||||||
|
"@codemirror/state": "^6.5.4",
|
||||||
|
"@codemirror/theme-one-dark": "^6.1.3",
|
||||||
|
"@codemirror/view": "^6.39.12",
|
||||||
"@dnd-kit/core": "^6.3.1",
|
"@dnd-kit/core": "^6.3.1",
|
||||||
"@dnd-kit/modifiers": "^9.0.0",
|
"@dnd-kit/modifiers": "^9.0.0",
|
||||||
"@dnd-kit/sortable": "^10.0.0",
|
"@dnd-kit/sortable": "^10.0.0",
|
||||||
@@ -25,6 +33,7 @@
|
|||||||
"antd": "^5.27.4",
|
"antd": "^5.27.4",
|
||||||
"axios": "^1.12.2",
|
"axios": "^1.12.2",
|
||||||
"clsx": "^2.1.1",
|
"clsx": "^2.1.1",
|
||||||
|
"codemirror": "^6.0.2",
|
||||||
"copy-to-clipboard": "^3.3.3",
|
"copy-to-clipboard": "^3.3.3",
|
||||||
"crypto-js": "^4.2.0",
|
"crypto-js": "^4.2.0",
|
||||||
"dayjs": "^1.11.18",
|
"dayjs": "^1.11.18",
|
||||||
@@ -55,6 +64,7 @@
|
|||||||
"@tailwindcss/postcss": "^4.1.14",
|
"@tailwindcss/postcss": "^4.1.14",
|
||||||
"@tailwindcss/typography": "^0.5.19",
|
"@tailwindcss/typography": "^0.5.19",
|
||||||
"@tailwindcss/vite": "^4.1.14",
|
"@tailwindcss/vite": "^4.1.14",
|
||||||
|
"@types/codemirror": "^5.60.17",
|
||||||
"@types/crypto-js": "^4.2.2",
|
"@types/crypto-js": "^4.2.2",
|
||||||
"@types/js-yaml": "^4.0.9",
|
"@types/js-yaml": "^4.0.9",
|
||||||
"@types/node": "^24.6.0",
|
"@types/node": "^24.6.0",
|
||||||
|
|||||||
@@ -256,7 +256,7 @@ export const updateMemoryExtractionConfig = (values: ExtractionConfigForm) => {
|
|||||||
return request.post('/memory-storage/update_config_extracted', values)
|
return request.post('/memory-storage/update_config_extracted', values)
|
||||||
}
|
}
|
||||||
// Memory Extraction Engine - Pilot run
|
// Memory Extraction Engine - Pilot run
|
||||||
export const pilotRunMemoryExtractionConfig = (values: { config_id: number | string; dialogue_text: string; }, onMessage?: (data: SSEMessage[]) => void) => {
|
export const pilotRunMemoryExtractionConfig = (values: { config_id: number | string; dialogue_text: string; custom_text?: string; }, onMessage?: (data: SSEMessage[]) => void) => {
|
||||||
return handleSSE('/memory-storage/pilot_run', values, onMessage)
|
return handleSSE('/memory-storage/pilot_run', values, onMessage)
|
||||||
}
|
}
|
||||||
// Emotion Engine - Get configuration
|
// Emotion Engine - Get configuration
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import { request } from '@/utils/request'
|
|||||||
import type { Query, OntologyModalData, OntologyClassModalData, OntologyClassExtractModalData, OntologyExportModalData } from '@/views/Ontology/types'
|
import type { Query, OntologyModalData, OntologyClassModalData, OntologyClassExtractModalData, OntologyExportModalData } from '@/views/Ontology/types'
|
||||||
|
|
||||||
// Scene list
|
// Scene list
|
||||||
|
export const getOntologyScenesSimpleUrl = '/memory/ontology/scenes/simple'
|
||||||
export const getOntologyScenesUrl = '/memory/ontology/scenes'
|
export const getOntologyScenesUrl = '/memory/ontology/scenes'
|
||||||
export const getOntologyScenesList = (data: Query) => {
|
export const getOntologyScenesList = (data: Query) => {
|
||||||
return request.get(getOntologyScenesUrl, data)
|
return request.get(getOntologyScenesUrl, data)
|
||||||
|
|||||||
150
web/src/components/CodeMirrorEditor/index.tsx
Normal file
150
web/src/components/CodeMirrorEditor/index.tsx
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
/*
|
||||||
|
* @Author: ZhaoYing
|
||||||
|
* @Date: 2026-02-04 17:20:52
|
||||||
|
* @Last Modified by: ZhaoYing
|
||||||
|
* @Last Modified time: 2026-02-04 17:20:52
|
||||||
|
*/
|
||||||
|
import { useEffect, useRef, useMemo } from 'react';
|
||||||
|
import { EditorView, basicSetup } from 'codemirror';
|
||||||
|
import { EditorState } from '@codemirror/state';
|
||||||
|
import { python } from '@codemirror/lang-python';
|
||||||
|
import { javascript } from '@codemirror/lang-javascript';
|
||||||
|
import { java } from '@codemirror/lang-java';
|
||||||
|
import { cpp } from '@codemirror/lang-cpp';
|
||||||
|
import { rust } from '@codemirror/lang-rust';
|
||||||
|
import { oneDark } from '@codemirror/theme-one-dark';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Props for the CodeMirrorEditor component
|
||||||
|
* @property {string} value - The initial code content to display in the editor
|
||||||
|
* @property {string} language - Programming language for syntax highlighting (python, python3, javascript, typescript, java, cpp, c, rust)
|
||||||
|
* @property {function} onChange - Callback function triggered when editor content changes, receives the new code value
|
||||||
|
* @property {string} theme - Editor theme, either 'light' or 'dark'
|
||||||
|
* @property {boolean} readOnly - Whether the editor is read-only
|
||||||
|
* @property {string} height - Custom height for the editor
|
||||||
|
* @property {string} size - Predefined size preset: 'default' (120px min-height, 14px font) or 'small' (60px min-height, 12px font)
|
||||||
|
*/
|
||||||
|
interface CodeMirrorEditorProps {
|
||||||
|
value?: string;
|
||||||
|
language?: 'python' | 'python3' | 'javascript' | 'typescript' | 'java' | 'cpp' | 'c' | 'rust';
|
||||||
|
onChange?: (value: string) => void;
|
||||||
|
theme?: 'light' | 'dark';
|
||||||
|
readOnly?: boolean;
|
||||||
|
height?: string;
|
||||||
|
size?: 'default' | 'small';
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Map of language identifiers to their corresponding CodeMirror language extensions
|
||||||
|
* Supports multiple programming languages with syntax highlighting
|
||||||
|
*/
|
||||||
|
const languageExtensions: Record<string, any> = {
|
||||||
|
python: python(),
|
||||||
|
python3: python(),
|
||||||
|
javascript: javascript(),
|
||||||
|
typescript: javascript({ typescript: true }),
|
||||||
|
java: java(),
|
||||||
|
cpp: cpp(),
|
||||||
|
c: cpp(),
|
||||||
|
rust: rust(),
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* CodeMirrorEditor - A React wrapper component for CodeMirror 6 editor
|
||||||
|
* Provides a code editor with syntax highlighting, theme support, and customizable sizing
|
||||||
|
* Used in workflow code execution nodes for editing Python and JavaScript code
|
||||||
|
*/
|
||||||
|
const CodeMirrorEditor = ({
|
||||||
|
value = '',
|
||||||
|
language = 'javascript',
|
||||||
|
onChange,
|
||||||
|
theme = 'light',
|
||||||
|
readOnly = false,
|
||||||
|
size,
|
||||||
|
}: CodeMirrorEditorProps) => {
|
||||||
|
// Reference to the DOM element that will contain the editor
|
||||||
|
const editorRef = useRef<HTMLDivElement>(null);
|
||||||
|
// Reference to the CodeMirror EditorView instance
|
||||||
|
const viewRef = useRef<EditorView | null>(null);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initialize CodeMirror editor when component mounts or when language/theme/readOnly changes
|
||||||
|
* Sets up extensions for syntax highlighting, change listeners, and theme
|
||||||
|
*/
|
||||||
|
useEffect(() => {
|
||||||
|
if (!editorRef.current) return;
|
||||||
|
|
||||||
|
// Get the appropriate language extension, fallback to JavaScript if not found
|
||||||
|
const langExtension = languageExtensions[language] || languageExtensions.javascript;
|
||||||
|
|
||||||
|
// Configure editor extensions
|
||||||
|
const extensions = [
|
||||||
|
basicSetup, // Basic editor features (line numbers, bracket matching, etc.)
|
||||||
|
langExtension, // Language-specific syntax highlighting
|
||||||
|
// Listen for document changes and trigger onChange callback
|
||||||
|
EditorView.updateListener.of((update) => {
|
||||||
|
if (update.docChanged && onChange) {
|
||||||
|
onChange(update.state.doc.toString());
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
EditorState.readOnly.of(readOnly), // Set read-only mode
|
||||||
|
];
|
||||||
|
|
||||||
|
// Apply dark theme if specified
|
||||||
|
if (theme === 'dark') {
|
||||||
|
extensions.push(oneDark);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create editor state with initial value and extensions
|
||||||
|
const state = EditorState.create({
|
||||||
|
doc: value,
|
||||||
|
extensions,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Create and mount the editor view
|
||||||
|
viewRef.current = new EditorView({
|
||||||
|
state,
|
||||||
|
parent: editorRef.current,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Cleanup: destroy editor instance when component unmounts or dependencies change
|
||||||
|
return () => {
|
||||||
|
viewRef.current?.destroy();
|
||||||
|
};
|
||||||
|
}, [language, theme, readOnly]);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Update editor content when the value prop changes externally
|
||||||
|
* Only updates if the new value differs from current editor content
|
||||||
|
*/
|
||||||
|
useEffect(() => {
|
||||||
|
if (viewRef.current && value !== viewRef.current.state.doc.toString()) {
|
||||||
|
viewRef.current.dispatch({
|
||||||
|
changes: {
|
||||||
|
from: 0,
|
||||||
|
to: viewRef.current.state.doc.length,
|
||||||
|
insert: value,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [value]);
|
||||||
|
|
||||||
|
// Calculate minimum height based on size prop: small (60px) or default (120px)
|
||||||
|
const minHeight = useMemo(() => {
|
||||||
|
return `${size === 'small' ? 60 : 120}px`
|
||||||
|
}, [size])
|
||||||
|
|
||||||
|
// Calculate font size based on size prop: small (12px) or default (14px)
|
||||||
|
const fontSize = useMemo(() => {
|
||||||
|
return `${size === 'small' ? 12 : 14}px`
|
||||||
|
}, [size])
|
||||||
|
|
||||||
|
// Calculate line height based on size prop: small (16px) or default (20px)
|
||||||
|
const lineHeight = useMemo(() => {
|
||||||
|
return `${size === 'small' ? 16 : 20}px`
|
||||||
|
}, [size])
|
||||||
|
|
||||||
|
return <div ref={editorRef} style={{ minHeight, fontSize, lineHeight }} />;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default CodeMirrorEditor;
|
||||||
@@ -81,7 +81,7 @@ const components = {
|
|||||||
audio: ({ src, ...props }: any) => <AudioBlock node={{ children: [{ properties: { src: src || '' } }] }} {...props} />,
|
audio: ({ src, ...props }: any) => <AudioBlock node={{ children: [{ properties: { src: src || '' } }] }} {...props} />,
|
||||||
a: ({ href, children, ...props }: any) => <Link href={href || '#'} {...props}>{children}</Link>,
|
a: ({ href, children, ...props }: any) => <Link href={href || '#'} {...props}>{children}</Link>,
|
||||||
button: ({ children }: any) => <RbButton node={{ children }}>{[children]}</RbButton>,
|
button: ({ children }: any) => <RbButton node={{ children }}>{[children]}</RbButton>,
|
||||||
table: ({ children, ...props }: any) => <table className="rb:border rb:border-[#D9D9D9] rb:mb-2" {...props}>{children}</table>,
|
table: ({ children, ...props }: any) => <div className="rb:overflow-x-auto rb:max-w-full"><table className="rb:border rb:border-[#D9D9D9] rb:mb-2" {...props}>{children}</table></div>,
|
||||||
tr: ({ children, ...props }: any) => <tr className="rb:border rb:border-[#D9D9D9]" {...props}>{children}</tr>,
|
tr: ({ children, ...props }: any) => <tr className="rb:border rb:border-[#D9D9D9]" {...props}>{children}</tr>,
|
||||||
th: ({ children, ...props }: any) => <th className="rb:border rb:border-[#D9D9D9] rb:px-2 rb:py-1 rb:text-left rb:font-bold" {...props}>{children}</th>,
|
th: ({ children, ...props }: any) => <th className="rb:border rb:border-[#D9D9D9] rb:px-2 rb:py-1 rb:text-left rb:font-bold" {...props}>{children}</th>,
|
||||||
td: ({ children, ...props }: any) => <td className="rb:border rb:border-[#D9D9D9] rb:px-2 rb:py-1 rb:text-left" {...props}>{children}</td>,
|
td: ({ children, ...props }: any) => <td className="rb:border rb:border-[#D9D9D9] rb:px-2 rb:py-1 rb:text-left" {...props}>{children}</td>,
|
||||||
|
|||||||
@@ -1543,7 +1543,8 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
|||||||
text_preprocessing_desc: 'Text split into {{count}} semantic fragments',
|
text_preprocessing_desc: 'Text split into {{count}} semantic fragments',
|
||||||
knowledge_extraction_desc: 'Knowledge extraction completed, identified {{entities}} entities, {{statements}} statements, {{temporal_ranges_count}} temporal extractions, {{triplets}} triplets',
|
knowledge_extraction_desc: 'Knowledge extraction completed, identified {{entities}} entities, {{statements}} statements, {{temporal_ranges_count}} temporal extractions, {{triplets}} triplets',
|
||||||
creating_nodes_edges_desc: 'Entity relationship creation completed, {{num}} relationships in total',
|
creating_nodes_edges_desc: 'Entity relationship creation completed, {{num}} relationships in total',
|
||||||
deduplication_desc: 'Deduplication and disambiguation completed, {{count}} unique entities in total'
|
deduplication_desc: 'Deduplication and disambiguation completed, {{count}} unique entities in total',
|
||||||
|
custom_text: 'Debug Text',
|
||||||
},
|
},
|
||||||
memoryConversation: {
|
memoryConversation: {
|
||||||
searchPlaceholder: 'Enter user ID...',
|
searchPlaceholder: 'Enter user ID...',
|
||||||
|
|||||||
@@ -1617,7 +1617,8 @@ export const zh = {
|
|||||||
text_preprocessing_desc: '文本切分为{{count}}个语义片段',
|
text_preprocessing_desc: '文本切分为{{count}}个语义片段',
|
||||||
knowledge_extraction_desc: '知识抽取完成,共识别{{entities}}个实体,{{statements}}个句子, {{temporal_ranges_count}}个时间提取, {{triplets}}个三元组',
|
knowledge_extraction_desc: '知识抽取完成,共识别{{entities}}个实体,{{statements}}个句子, {{temporal_ranges_count}}个时间提取, {{triplets}}个三元组',
|
||||||
creating_nodes_edges_desc: '实体关系创建完成,共{{num}}条关系',
|
creating_nodes_edges_desc: '实体关系创建完成,共{{num}}条关系',
|
||||||
deduplication_desc: '去重消歧完成,最终{{count}}个唯一实体'
|
deduplication_desc: '去重消歧完成,最终{{count}}个唯一实体',
|
||||||
|
custom_text: '调试文本',
|
||||||
},
|
},
|
||||||
memoryConversation: {
|
memoryConversation: {
|
||||||
chatEmpty:'有什么我可以帮您的吗?',
|
chatEmpty:'有什么我可以帮您的吗?',
|
||||||
|
|||||||
@@ -180,4 +180,9 @@ body {
|
|||||||
.x6-node foreignObject > body {
|
.x6-node foreignObject > body {
|
||||||
min-height: 100%;
|
min-height: 100%;
|
||||||
max-height: 100%;
|
max-height: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.ͼ2 .cm-gutters {
|
||||||
|
background-color: #FFFFFF;
|
||||||
|
border: none;
|
||||||
}
|
}
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 16:29:21
|
* @Date: 2026-02-03 16:29:21
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-02-04 20:16:45
|
* @Last Modified time: 2026-02-06 11:20:14
|
||||||
*/
|
*/
|
||||||
import { type FC, type ReactNode, useEffect, useRef, useState, forwardRef, useImperativeHandle } from 'react';
|
import { type FC, type ReactNode, useEffect, useRef, useState, forwardRef, useImperativeHandle } from 'react';
|
||||||
import clsx from 'clsx'
|
import clsx from 'clsx'
|
||||||
@@ -38,8 +38,8 @@ import CustomSelect from '@/components/CustomSelect'
|
|||||||
import aiPrompt from '@/assets/images/application/aiPrompt.png'
|
import aiPrompt from '@/assets/images/application/aiPrompt.png'
|
||||||
import AiPromptModal from './components/AiPromptModal'
|
import AiPromptModal from './components/AiPromptModal'
|
||||||
import ToolList from './components/ToolList/ToolList'
|
import ToolList from './components/ToolList/ToolList'
|
||||||
import ChatVariableConfigModal from './components/ChatVariableConfigModal';
|
|
||||||
import SkillList from './components/Skill'
|
import SkillList from './components/Skill'
|
||||||
|
import ChatVariableConfigModal from './components/ChatVariableConfigModal';
|
||||||
import type { Skill } from '@/views/Skills/types'
|
import type { Skill } from '@/views/Skills/types'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -169,7 +169,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
|
|||||||
const { skills } = response
|
const { skills } = response
|
||||||
let allSkills = Array.isArray(skills?.skill_ids) ? skills?.skill_ids.map(vo => ({ id: vo })) : []
|
let allSkills = Array.isArray(skills?.skill_ids) ? skills?.skill_ids.map(vo => ({ id: vo })) : []
|
||||||
let allTools = Array.isArray(response.tools) ? response.tools : []
|
let allTools = Array.isArray(response.tools) ? response.tools : []
|
||||||
const memoryContent = response.memory?.memory_content
|
const memoryContent = response.memory?.memory_config_id
|
||||||
const parsedMemoryContent = memoryContent === null || memoryContent === ''
|
const parsedMemoryContent = memoryContent === null || memoryContent === ''
|
||||||
? undefined
|
? undefined
|
||||||
: !isNaN(Number(memoryContent)) ? Number(memoryContent) : memoryContent
|
: !isNaN(Number(memoryContent)) ? Number(memoryContent) : memoryContent
|
||||||
@@ -178,7 +178,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
|
|||||||
tools: allTools,
|
tools: allTools,
|
||||||
memory: {
|
memory: {
|
||||||
...response.memory,
|
...response.memory,
|
||||||
memory_content: parsedMemoryContent
|
memory_config_id: parsedMemoryContent
|
||||||
},
|
},
|
||||||
skills: {
|
skills: {
|
||||||
...skills,
|
...skills,
|
||||||
@@ -262,7 +262,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
|
|||||||
if (!isSave || !data) return Promise.resolve()
|
if (!isSave || !data) return Promise.resolve()
|
||||||
const { memory, knowledge_retrieval, tools, skills, ...rest } = values
|
const { memory, knowledge_retrieval, tools, skills, ...rest } = values
|
||||||
const { knowledge_bases = [], ...knowledgeRest } = knowledge_retrieval || {}
|
const { knowledge_bases = [], ...knowledgeRest } = knowledge_retrieval || {}
|
||||||
const { memory_content } = memory || {}
|
const { memory_config_id } = memory || {}
|
||||||
// Get other necessary properties of memory from original data
|
// Get other necessary properties of memory from original data
|
||||||
const originalMemory = data.memory || ({} as MemoryConfig)
|
const originalMemory = data.memory || ({} as MemoryConfig)
|
||||||
|
|
||||||
@@ -272,7 +272,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
|
|||||||
memory: {
|
memory: {
|
||||||
...originalMemory,
|
...originalMemory,
|
||||||
...memory,
|
...memory,
|
||||||
memory_content: memory_content ? String(memory_content) : '',
|
memory_config_id: memory_config_id ? String(memory_config_id) : '',
|
||||||
},
|
},
|
||||||
knowledge_retrieval: knowledge_bases.length > 0 ? {
|
knowledge_retrieval: knowledge_bases.length > 0 ? {
|
||||||
...data.knowledge_retrieval,
|
...data.knowledge_retrieval,
|
||||||
@@ -444,7 +444,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
|
|||||||
<SelectWrapper
|
<SelectWrapper
|
||||||
title="selectMemoryContent"
|
title="selectMemoryContent"
|
||||||
desc="selectMemoryContentDesc"
|
desc="selectMemoryContentDesc"
|
||||||
name={['memory', 'memory_content']}
|
name={['memory', 'memory_config_id']}
|
||||||
url={memoryConfigListUrl}
|
url={memoryConfigListUrl}
|
||||||
/>
|
/>
|
||||||
</Space>
|
</Space>
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi
|
|||||||
title={t('application.knowledgeBaseAssociation')}
|
title={t('application.knowledgeBaseAssociation')}
|
||||||
extra={
|
extra={
|
||||||
<Space>
|
<Space>
|
||||||
<Button style={{ padding: '0 8px', height: '24px' }} onClick={handleKnowledgeConfig}>{t('workflow.config.knowledge-retrieval.recallConfig')}</Button>
|
<Button style={{ padding: '0 8px', height: '24px' }} onClick={handleKnowledgeConfig}>{t('application.globalConfig')}</Button>
|
||||||
<Button style={{ padding: '0 8px', height: '24px' }} onClick={handleAddKnowledge}>+</Button>
|
<Button style={{ padding: '0 8px', height: '24px' }} onClick={handleAddKnowledge}>+</Button>
|
||||||
</Space>
|
</Space>
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ const processObj = [
|
|||||||
* @param value - Current skill configuration values
|
* @param value - Current skill configuration values
|
||||||
* @param onChange - Callback function when configuration changes
|
* @param onChange - Callback function when configuration changes
|
||||||
*/
|
*/
|
||||||
const Skill: FC<{value?: SkillConfigForm; onChange?: (config: SkillConfigForm) => void}> = () => {
|
const SkillList: FC<{value?: SkillConfigForm; onChange?: (config: SkillConfigForm) => void}> = () => {
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
const form = Form.useFormInstance()
|
const form = Form.useFormInstance()
|
||||||
const skillConfig = Form.useWatch(['skills'], form)
|
const skillConfig = Form.useWatch(['skills'], form)
|
||||||
@@ -148,4 +148,4 @@ const Skill: FC<{value?: SkillConfigForm; onChange?: (config: SkillConfigForm) =
|
|||||||
</Card>
|
</Card>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
export default Skill
|
export default SkillList
|
||||||
@@ -43,7 +43,7 @@ export interface MemoryConfig {
|
|||||||
/** Whether memory is enabled */
|
/** Whether memory is enabled */
|
||||||
enabled: boolean;
|
enabled: boolean;
|
||||||
/** Memory content */
|
/** Memory content */
|
||||||
memory_content?: string;
|
memory_config_id?: string;
|
||||||
/** Maximum history length */
|
/** Maximum history length */
|
||||||
max_history?: number | string;
|
max_history?: number | string;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
import { type FC, useState } from 'react'
|
import { type FC, useState } from 'react'
|
||||||
import { useParams } from 'react-router-dom'
|
import { useParams } from 'react-router-dom'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import { Space, Button, Progress } from 'antd'
|
import { Space, Button, Progress, Form, Input } from 'antd'
|
||||||
import { ExclamationCircleFilled, CheckCircleFilled, ClockCircleOutlined, LoadingOutlined } from '@ant-design/icons'
|
import { ExclamationCircleFilled, CheckCircleFilled, ClockCircleOutlined, LoadingOutlined } from '@ant-design/icons'
|
||||||
import clsx from 'clsx'
|
import clsx from 'clsx'
|
||||||
import type { AnyObject } from 'antd/es/_util/type';
|
import type { AnyObject } from 'antd/es/_util/type';
|
||||||
@@ -79,6 +79,8 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => {
|
|||||||
const [creatingNodesEdges, setCreatingNodesEdges] = useState<ModuleItem>(initObj as ModuleItem)
|
const [creatingNodesEdges, setCreatingNodesEdges] = useState<ModuleItem>(initObj as ModuleItem)
|
||||||
const [deduplication, setDeduplication] = useState<ModuleItem>(initObj as ModuleItem)
|
const [deduplication, setDeduplication] = useState<ModuleItem>(initObj as ModuleItem)
|
||||||
|
|
||||||
|
const [runForm] = Form.useForm()
|
||||||
|
|
||||||
/** Run pilot test */
|
/** Run pilot test */
|
||||||
const handleRun = () => {
|
const handleRun = () => {
|
||||||
if(!id) return
|
if(!id) return
|
||||||
@@ -187,6 +189,7 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => {
|
|||||||
pilotRunMemoryExtractionConfig({
|
pilotRunMemoryExtractionConfig({
|
||||||
config_id: id,
|
config_id: id,
|
||||||
dialogue_text: t('memoryExtractionEngine.exampleText'),
|
dialogue_text: t('memoryExtractionEngine.exampleText'),
|
||||||
|
custom_text: runForm.getFieldValue('custom_text')
|
||||||
}, handleStreamMessage)
|
}, handleStreamMessage)
|
||||||
.finally(() => {
|
.finally(() => {
|
||||||
setRunLoading(false)
|
setRunLoading(false)
|
||||||
@@ -222,6 +225,14 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => {
|
|||||||
headerClassName="rb:pb-0! rb:pt-4!"
|
headerClassName="rb:pb-0! rb:pt-4!"
|
||||||
bodyClassName="rb:min-h-[calc(100vh-388px)] rb:p-[16px_20px]!"
|
bodyClassName="rb:min-h-[calc(100vh-388px)] rb:p-[16px_20px]!"
|
||||||
>
|
>
|
||||||
|
<Form form={runForm} layout="vertical">
|
||||||
|
<Form.Item
|
||||||
|
name="custom_text"
|
||||||
|
label={t('memoryExtractionEngine.custom_text')}
|
||||||
|
>
|
||||||
|
<Input.TextArea placeholder={t('common.pleaseEnter')} />
|
||||||
|
</Form.Item>
|
||||||
|
</Form>
|
||||||
<div className="rb:min-h-[calc(100vh-480px)] rb:overflow-y-auto">
|
<div className="rb:min-h-[calc(100vh-480px)] rb:overflow-y-auto">
|
||||||
{runLoading
|
{runLoading
|
||||||
? <>
|
? <>
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import { useTranslation } from 'react-i18next';
|
|||||||
import type { MemoryFormData, Memory, MemoryFormRef } from '../types';
|
import type { MemoryFormData, Memory, MemoryFormRef } from '../types';
|
||||||
import RbModal from '@/components/RbModal'
|
import RbModal from '@/components/RbModal'
|
||||||
import { createMemoryConfig, updateMemoryConfig } from '@/api/memory'
|
import { createMemoryConfig, updateMemoryConfig } from '@/api/memory'
|
||||||
import { getOntologyScenesUrl } from '@/api/ontology'
|
import { getOntologyScenesSimpleUrl } from '@/api/ontology'
|
||||||
import CustomSelect from '@/components/CustomSelect';
|
import CustomSelect from '@/components/CustomSelect';
|
||||||
|
|
||||||
const FormItem = Form.Item;
|
const FormItem = Form.Item;
|
||||||
@@ -129,8 +129,7 @@ const MemoryForm = forwardRef<MemoryFormRef, MemoryFormProps>(({
|
|||||||
>
|
>
|
||||||
<CustomSelect
|
<CustomSelect
|
||||||
placeholder={t('common.pleaseSelect')}
|
placeholder={t('common.pleaseSelect')}
|
||||||
url={getOntologyScenesUrl}
|
url={getOntologyScenesSimpleUrl}
|
||||||
params={{ pagesize: 100, page: 1 }}
|
|
||||||
hasAll={false}
|
hasAll={false}
|
||||||
valueKey='scene_id'
|
valueKey='scene_id'
|
||||||
labelKey="scene_name"
|
labelKey="scene_name"
|
||||||
|
|||||||
@@ -112,7 +112,7 @@ const MemoryManagement: React.FC = () => {
|
|||||||
title={item.config_name}
|
title={item.config_name}
|
||||||
>
|
>
|
||||||
<Tooltip title={item.config_desc}>
|
<Tooltip title={item.config_desc}>
|
||||||
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.25 rb:font-regular rb:-mt-1 rb:wrap-break-word rb:line-clamp-1">{item.config_desc}</div>
|
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.25 rb:font-regular rb:-mt-1 rb:wrap-break-word rb:line-clamp-1 rb:h-[17px]">{item.config_desc}</div>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
<RbAlert className="rb:mt-3 ">
|
<RbAlert className="rb:mt-3 ">
|
||||||
<div className={clsx("rb:flex rb:gap-5 rb:font-regular rb:text-[14px]")}>
|
<div className={clsx("rb:flex rb:gap-5 rb:font-regular rb:text-[14px]")}>
|
||||||
|
|||||||
@@ -103,9 +103,9 @@ const MultiKeyConfigModal = forwardRef<MultiKeyConfigModalRef, MultiKeyConfigMod
|
|||||||
{model.api_keys && model.api_keys.length > 0 && (
|
{model.api_keys && model.api_keys.length > 0 && (
|
||||||
<div className="rb:mb-4">
|
<div className="rb:mb-4">
|
||||||
{model.api_keys.map((key) => (
|
{model.api_keys.map((key) => (
|
||||||
<div key={key.id} className="rb:flex rb:items-center rb:justify-between rb:p-3 rb:bg-[#F5F6F7] rb:rounded-lg rb:mb-2">
|
<div key={key.id} className="rb:flex rb:gap-3 rb:items-center rb:justify-between rb:p-3 rb:bg-[#F5F6F7] rb:rounded-lg rb:mb-2">
|
||||||
<div>
|
<div className="rb:flex-1">
|
||||||
<div className="rb:text-[#1D2129] rb:text-[14px] rb:font-medium">{key.api_key}</div>
|
<div className="rb:text-[#1D2129] rb:text-[14px] rb:font-medium rb:break-all">{key.api_key}</div>
|
||||||
<div className="rb:text-[#5B6167] rb:text-[12px] rb:mt-1">{key.api_base}</div>
|
<div className="rb:text-[#5B6167] rb:text-[12px] rb:mt-1">{key.api_base}</div>
|
||||||
</div>
|
</div>
|
||||||
<Button type="primary" danger ghost onClick={() => handleDelete(key.id)}>{t('common.remove')}</Button>
|
<Button type="primary" danger ghost onClick={() => handleDelete(key.id)}>{t('common.remove')}</Button>
|
||||||
|
|||||||
@@ -15,8 +15,6 @@ import CharacterCountPlugin from './plugin/CharacterCountPlugin'
|
|||||||
import InitialValuePlugin from './plugin/InitialValuePlugin';
|
import InitialValuePlugin from './plugin/InitialValuePlugin';
|
||||||
import CommandPlugin from './plugin/CommandPlugin';
|
import CommandPlugin from './plugin/CommandPlugin';
|
||||||
import Jinja2HighlightPlugin from './plugin/Jinja2HighlightPlugin';
|
import Jinja2HighlightPlugin from './plugin/Jinja2HighlightPlugin';
|
||||||
import Python3HighlightPlugin from './plugin/Python3HighlightPlugin';
|
|
||||||
import JavaScriptHighlightPlugin from './plugin/JavaScriptHighlightPlugin';
|
|
||||||
import LineNumberPlugin from './plugin/LineNumberPlugin';
|
import LineNumberPlugin from './plugin/LineNumberPlugin';
|
||||||
import BlurPlugin from './plugin/BlurPlugin';
|
import BlurPlugin from './plugin/BlurPlugin';
|
||||||
import { VariableNode } from './nodes/VariableNode'
|
import { VariableNode } from './nodes/VariableNode'
|
||||||
@@ -32,7 +30,7 @@ export interface LexicalEditorProps {
|
|||||||
lineHeight?: number;
|
lineHeight?: number;
|
||||||
size?: 'default' | 'small';
|
size?: 'default' | 'small';
|
||||||
type?: 'input' | 'textarea',
|
type?: 'input' | 'textarea',
|
||||||
language?: 'string' | 'jinja2' | 'python3' | 'javascript'
|
language?: 'string' | 'jinja2'
|
||||||
}
|
}
|
||||||
|
|
||||||
const theme = {
|
const theme = {
|
||||||
@@ -67,7 +65,7 @@ const Editor: FC<LexicalEditorProps> =({
|
|||||||
const [enableLineNumbers, setEnableLineNumbers] = useState(false)
|
const [enableLineNumbers, setEnableLineNumbers] = useState(false)
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const needsLineNumbers = language === 'jinja2' || language === 'python3' || language === 'javascript';
|
const needsLineNumbers = language === 'jinja2';
|
||||||
setEnableJinja2(language === 'jinja2');
|
setEnableJinja2(language === 'jinja2');
|
||||||
setEnableLineNumbers(needsLineNumbers);
|
setEnableLineNumbers(needsLineNumbers);
|
||||||
|
|
||||||
@@ -237,13 +235,11 @@ const Editor: FC<LexicalEditorProps> =({
|
|||||||
<HistoryPlugin />
|
<HistoryPlugin />
|
||||||
<CommandPlugin />
|
<CommandPlugin />
|
||||||
{language === 'jinja2' && <Jinja2HighlightPlugin />}
|
{language === 'jinja2' && <Jinja2HighlightPlugin />}
|
||||||
{language === 'python3' && <Python3HighlightPlugin />}
|
|
||||||
{language === 'javascript' && <JavaScriptHighlightPlugin />}
|
|
||||||
{enableLineNumbers && <LineNumberPlugin />}
|
{enableLineNumbers && <LineNumberPlugin />}
|
||||||
<AutocompletePlugin options={options} enableJinja2={enableJinja2} />
|
<AutocompletePlugin options={options} enableJinja2={enableJinja2} />
|
||||||
<CharacterCountPlugin setCount={(count) => { setCount(count) }} onChange={onChange} />
|
<CharacterCountPlugin setCount={(count) => { setCount(count) }} onChange={onChange} />
|
||||||
<InitialValuePlugin key={language} value={value} options={options} enableLineNumbers={enableLineNumbers} />
|
<InitialValuePlugin value={value} options={options} enableLineNumbers={enableLineNumbers} />
|
||||||
{enableLineNumbers && <BlurPlugin />}
|
{enableJinja2 && <BlurPlugin />}
|
||||||
</div>
|
</div>
|
||||||
</LexicalComposer>
|
</LexicalComposer>
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -1,182 +0,0 @@
|
|||||||
import { useEffect, useRef } from 'react';
|
|
||||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
|
|
||||||
import { TextNode, $createTextNode, $getSelection, $isRangeSelection, COMMAND_PRIORITY_LOW, PASTE_COMMAND } from 'lexical';
|
|
||||||
|
|
||||||
const JS_KEYWORDS = new Set([
|
|
||||||
'async', 'await', 'break', 'case', 'catch', 'class', 'const', 'continue', 'debugger', 'default',
|
|
||||||
'delete', 'do', 'else', 'export', 'extends', 'finally', 'for', 'function', 'if', 'import',
|
|
||||||
'in', 'instanceof', 'let', 'new', 'return', 'super', 'switch', 'this', 'throw', 'try',
|
|
||||||
'typeof', 'var', 'void', 'while', 'with', 'yield', 'true', 'false', 'null', 'undefined'
|
|
||||||
]);
|
|
||||||
|
|
||||||
const JavaScriptHighlightPlugin = () => {
|
|
||||||
const [editor] = useLexicalComposerContext();
|
|
||||||
const isPastingRef = useRef(false);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
return editor.registerCommand(
|
|
||||||
PASTE_COMMAND,
|
|
||||||
() => {
|
|
||||||
isPastingRef.current = true;
|
|
||||||
setTimeout(() => {
|
|
||||||
isPastingRef.current = false;
|
|
||||||
}, 100);
|
|
||||||
return false;
|
|
||||||
},
|
|
||||||
COMMAND_PRIORITY_LOW
|
|
||||||
);
|
|
||||||
}, [editor]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
return editor.registerNodeTransform(TextNode, (textNode: TextNode) => {
|
|
||||||
if (isPastingRef.current) return;
|
|
||||||
|
|
||||||
const text = textNode.getTextContent();
|
|
||||||
|
|
||||||
if (textNode.hasFormat('code')) return;
|
|
||||||
if (!needsHighlight(text)) return;
|
|
||||||
if (textNode.getStyle()) return;
|
|
||||||
|
|
||||||
const parent = textNode.getParent();
|
|
||||||
if (!parent) return;
|
|
||||||
|
|
||||||
const selection = $getSelection();
|
|
||||||
let selectionOffset = null;
|
|
||||||
if ($isRangeSelection(selection)) {
|
|
||||||
const anchor = selection.anchor;
|
|
||||||
if (anchor.getNode() === textNode) {
|
|
||||||
selectionOffset = anchor.offset;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const tokens = tokenizeJavaScript(text);
|
|
||||||
if (tokens.length <= 1) return;
|
|
||||||
|
|
||||||
const newNodes = tokens.map(token => {
|
|
||||||
const newNode = $createTextNode(token.text);
|
|
||||||
newNode.toggleFormat('code');
|
|
||||||
|
|
||||||
switch (token.type) {
|
|
||||||
case 'keyword':
|
|
||||||
newNode.setStyle('color: #d73a49; font-weight: 600;');
|
|
||||||
break;
|
|
||||||
case 'string':
|
|
||||||
newNode.setStyle('color: #032f62;');
|
|
||||||
break;
|
|
||||||
case 'comment':
|
|
||||||
newNode.setStyle('color: #6a737d; font-style: italic;');
|
|
||||||
break;
|
|
||||||
case 'number':
|
|
||||||
newNode.setStyle('color: #005cc5; font-weight: 500;');
|
|
||||||
break;
|
|
||||||
case 'function':
|
|
||||||
newNode.setStyle('color: #6f42c1; font-weight: 500;');
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
return newNode;
|
|
||||||
});
|
|
||||||
|
|
||||||
if (newNodes.length > 1) {
|
|
||||||
textNode.replace(newNodes[0]);
|
|
||||||
for (let i = 1; i < newNodes.length; i++) {
|
|
||||||
newNodes[i - 1].insertAfter(newNodes[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (selectionOffset !== null && $isRangeSelection(selection)) {
|
|
||||||
let currentOffset = 0;
|
|
||||||
for (const node of newNodes) {
|
|
||||||
const nodeLength = node.getTextContent().length;
|
|
||||||
if (currentOffset + nodeLength >= selectionOffset) {
|
|
||||||
node.select(selectionOffset - currentOffset, selectionOffset - currentOffset);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
currentOffset += nodeLength;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}, [editor]);
|
|
||||||
|
|
||||||
return null;
|
|
||||||
};
|
|
||||||
|
|
||||||
function needsHighlight(text: string): boolean {
|
|
||||||
return /[a-zA-Z0-9_/"'`]/.test(text);
|
|
||||||
}
|
|
||||||
|
|
||||||
function tokenizeJavaScript(text: string): Array<{text: string, type: string}> {
|
|
||||||
const tokens: Array<{text: string, type: string}> = [];
|
|
||||||
let i = 0;
|
|
||||||
|
|
||||||
while (i < text.length) {
|
|
||||||
// Single-line comments
|
|
||||||
if (text.slice(i, i + 2) === '//') {
|
|
||||||
let start = i;
|
|
||||||
while (i < text.length && text[i] !== '\n') i++;
|
|
||||||
tokens.push({ text: text.slice(start, i), type: 'comment' });
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Multi-line comments
|
|
||||||
if (text.slice(i, i + 2) === '/*') {
|
|
||||||
let start = i;
|
|
||||||
i += 2;
|
|
||||||
while (i < text.length && text.slice(i, i + 2) !== '*/') i++;
|
|
||||||
if (i < text.length) i += 2;
|
|
||||||
tokens.push({ text: text.slice(start, i), type: 'comment' });
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Strings
|
|
||||||
if (text[i] === '"' || text[i] === "'" || text[i] === '`') {
|
|
||||||
const quote = text[i];
|
|
||||||
let start = i++;
|
|
||||||
|
|
||||||
while (i < text.length) {
|
|
||||||
if (text[i] === quote && text[i - 1] !== '\\') {
|
|
||||||
i++;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
i++;
|
|
||||||
}
|
|
||||||
tokens.push({ text: text.slice(start, i), type: 'string' });
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Numbers
|
|
||||||
if (/\d/.test(text[i])) {
|
|
||||||
let start = i;
|
|
||||||
while (i < text.length && /[\d.]/.test(text[i])) i++;
|
|
||||||
tokens.push({ text: text.slice(start, i), type: 'number' });
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Keywords and identifiers
|
|
||||||
if (/[a-zA-Z_$]/.test(text[i])) {
|
|
||||||
let start = i;
|
|
||||||
while (i < text.length && /[a-zA-Z0-9_$]/.test(text[i])) i++;
|
|
||||||
const word = text.slice(start, i);
|
|
||||||
|
|
||||||
if (JS_KEYWORDS.has(word)) {
|
|
||||||
tokens.push({ text: word, type: 'keyword' });
|
|
||||||
} else if (i < text.length && text[i] === '(') {
|
|
||||||
tokens.push({ text: word, type: 'function' });
|
|
||||||
} else {
|
|
||||||
tokens.push({ text: word, type: 'text' });
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Other characters
|
|
||||||
let start = i;
|
|
||||||
while (i < text.length && !/[a-zA-Z0-9_$/"'`]/.test(text[i])) i++;
|
|
||||||
if (start < i) {
|
|
||||||
tokens.push({ text: text.slice(start, i), type: 'text' });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokens;
|
|
||||||
}
|
|
||||||
|
|
||||||
export default JavaScriptHighlightPlugin;
|
|
||||||
@@ -1,177 +0,0 @@
|
|||||||
import { useEffect, useRef } from 'react';
|
|
||||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
|
|
||||||
import { TextNode, $createTextNode, $getSelection, $isRangeSelection, COMMAND_PRIORITY_LOW, PASTE_COMMAND } from 'lexical';
|
|
||||||
|
|
||||||
const PYTHON_KEYWORDS = new Set([
|
|
||||||
'False', 'None', 'True', 'and', 'as', 'assert', 'async', 'await', 'break', 'class', 'continue',
|
|
||||||
'def', 'del', 'elif', 'else', 'except', 'finally', 'for', 'from', 'global', 'if', 'import',
|
|
||||||
'in', 'is', 'lambda', 'nonlocal', 'not', 'or', 'pass', 'raise', 'return', 'try', 'while',
|
|
||||||
'with', 'yield'
|
|
||||||
]);
|
|
||||||
|
|
||||||
const Python3HighlightPlugin = () => {
|
|
||||||
const [editor] = useLexicalComposerContext();
|
|
||||||
const isPastingRef = useRef(false);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
return editor.registerCommand(
|
|
||||||
PASTE_COMMAND,
|
|
||||||
() => {
|
|
||||||
isPastingRef.current = true;
|
|
||||||
setTimeout(() => {
|
|
||||||
isPastingRef.current = false;
|
|
||||||
}, 100);
|
|
||||||
return false;
|
|
||||||
},
|
|
||||||
COMMAND_PRIORITY_LOW
|
|
||||||
);
|
|
||||||
}, [editor]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
return editor.registerNodeTransform(TextNode, (textNode: TextNode) => {
|
|
||||||
if (isPastingRef.current) return;
|
|
||||||
|
|
||||||
const text = textNode.getTextContent();
|
|
||||||
|
|
||||||
if (textNode.hasFormat('code')) return;
|
|
||||||
if (textNode.getStyle()) return;
|
|
||||||
if (!needsHighlight(text)) return;
|
|
||||||
|
|
||||||
const parent = textNode.getParent();
|
|
||||||
if (!parent) return;
|
|
||||||
|
|
||||||
const selection = $getSelection();
|
|
||||||
let selectionOffset = null;
|
|
||||||
if ($isRangeSelection(selection)) {
|
|
||||||
const anchor = selection.anchor;
|
|
||||||
if (anchor.getNode() === textNode) {
|
|
||||||
selectionOffset = anchor.offset;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const tokens = tokenizePython(text);
|
|
||||||
if (tokens.length <= 1) return;
|
|
||||||
|
|
||||||
const newNodes = tokens.map(token => {
|
|
||||||
const newNode = $createTextNode(token.text);
|
|
||||||
newNode.toggleFormat('code');
|
|
||||||
|
|
||||||
switch (token.type) {
|
|
||||||
case 'keyword':
|
|
||||||
newNode.setStyle('color: #d73a49; font-weight: 600;');
|
|
||||||
break;
|
|
||||||
case 'string':
|
|
||||||
newNode.setStyle('color: #032f62;');
|
|
||||||
break;
|
|
||||||
case 'comment':
|
|
||||||
newNode.setStyle('color: #6a737d; font-style: italic;');
|
|
||||||
break;
|
|
||||||
case 'number':
|
|
||||||
newNode.setStyle('color: #005cc5; font-weight: 500;');
|
|
||||||
break;
|
|
||||||
case 'function':
|
|
||||||
newNode.setStyle('color: #6f42c1; font-weight: 500;');
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
return newNode;
|
|
||||||
});
|
|
||||||
|
|
||||||
if (newNodes.length > 1) {
|
|
||||||
textNode.replace(newNodes[0]);
|
|
||||||
for (let i = 1; i < newNodes.length; i++) {
|
|
||||||
newNodes[i - 1].insertAfter(newNodes[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (selectionOffset !== null && $isRangeSelection(selection)) {
|
|
||||||
let currentOffset = 0;
|
|
||||||
for (const node of newNodes) {
|
|
||||||
const nodeLength = node.getTextContent().length;
|
|
||||||
if (currentOffset + nodeLength >= selectionOffset) {
|
|
||||||
node.select(selectionOffset - currentOffset, selectionOffset - currentOffset);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
currentOffset += nodeLength;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}, [editor]);
|
|
||||||
|
|
||||||
return null;
|
|
||||||
};
|
|
||||||
|
|
||||||
function needsHighlight(text: string): boolean {
|
|
||||||
return /[a-zA-Z0-9_#"']/.test(text);
|
|
||||||
}
|
|
||||||
|
|
||||||
function tokenizePython(text: string): Array<{text: string, type: string}> {
|
|
||||||
const tokens: Array<{text: string, type: string}> = [];
|
|
||||||
let i = 0;
|
|
||||||
|
|
||||||
while (i < text.length) {
|
|
||||||
// Comments
|
|
||||||
if (text[i] === '#') {
|
|
||||||
let start = i;
|
|
||||||
while (i < text.length && text[i] !== '\n') i++;
|
|
||||||
tokens.push({ text: text.slice(start, i), type: 'comment' });
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Strings
|
|
||||||
if (text[i] === '"' || text[i] === "'") {
|
|
||||||
const quote = text[i];
|
|
||||||
let start = i++;
|
|
||||||
const isTriple = text.slice(start, start + 3) === quote.repeat(3);
|
|
||||||
if (isTriple) i += 2;
|
|
||||||
|
|
||||||
while (i < text.length) {
|
|
||||||
if (isTriple && text.slice(i, i + 3) === quote.repeat(3)) {
|
|
||||||
i += 3;
|
|
||||||
break;
|
|
||||||
} else if (!isTriple && text[i] === quote && text[i - 1] !== '\\') {
|
|
||||||
i++;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
i++;
|
|
||||||
}
|
|
||||||
tokens.push({ text: text.slice(start, i), type: 'string' });
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Numbers
|
|
||||||
if (/\d/.test(text[i])) {
|
|
||||||
let start = i;
|
|
||||||
while (i < text.length && /[\d.]/.test(text[i])) i++;
|
|
||||||
tokens.push({ text: text.slice(start, i), type: 'number' });
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Keywords and identifiers
|
|
||||||
if (/[a-zA-Z_]/.test(text[i])) {
|
|
||||||
let start = i;
|
|
||||||
while (i < text.length && /[a-zA-Z0-9_]/.test(text[i])) i++;
|
|
||||||
const word = text.slice(start, i);
|
|
||||||
|
|
||||||
if (PYTHON_KEYWORDS.has(word)) {
|
|
||||||
tokens.push({ text: word, type: 'keyword' });
|
|
||||||
} else if (i < text.length && text[i] === '(') {
|
|
||||||
tokens.push({ text: word, type: 'function' });
|
|
||||||
} else {
|
|
||||||
tokens.push({ text: word, type: 'text' });
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Other characters
|
|
||||||
let start = i;
|
|
||||||
while (i < text.length && !/[a-zA-Z0-9_#"']/.test(text[i])) i++;
|
|
||||||
if (start < i) {
|
|
||||||
tokens.push({ text: text.slice(start, i), type: 'text' });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokens;
|
|
||||||
}
|
|
||||||
|
|
||||||
export default Python3HighlightPlugin;
|
|
||||||
@@ -5,8 +5,8 @@ import { Node } from '@antv/x6'
|
|||||||
|
|
||||||
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
|
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
|
||||||
import MappingList from '../MappingList'
|
import MappingList from '../MappingList'
|
||||||
import Editor from '../../Editor'
|
|
||||||
import OutputList from './OutputList'
|
import OutputList from './OutputList'
|
||||||
|
import CodeMirrorEditor from '@/components/CodeMirrorEditor';
|
||||||
|
|
||||||
interface MappingItem {
|
interface MappingItem {
|
||||||
name?: string
|
name?: string
|
||||||
@@ -110,7 +110,10 @@ const CodeExecution: FC<CodeExecutionProps> = ({ options }) => {
|
|||||||
<Form.Item noStyle shouldUpdate={(prev, curr) => prev.language !== curr.language}>
|
<Form.Item noStyle shouldUpdate={(prev, curr) => prev.language !== curr.language}>
|
||||||
{() => (
|
{() => (
|
||||||
<Form.Item name="code" noStyle>
|
<Form.Item name="code" noStyle>
|
||||||
<Editor size="small" language={form.getFieldValue('language')} />
|
<CodeMirrorEditor
|
||||||
|
language={form.getFieldValue('language')}
|
||||||
|
size="small"
|
||||||
|
/>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
)}
|
)}
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
|
|||||||
@@ -126,7 +126,7 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi
|
|||||||
<div
|
<div
|
||||||
className="rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/workflow/recall.svg')] rb:group-hover:bg-[url('@/assets/images/workflow/recall_hover.svg')]"
|
className="rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/workflow/recall.svg')] rb:group-hover:bg-[url('@/assets/images/workflow/recall_hover.svg')]"
|
||||||
></div>
|
></div>
|
||||||
{t('workflow.config.knowledge-retrieval.recallConfig')}
|
{t('application.globalConfig')}
|
||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user