diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 002547f6..db78a368 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -76,6 +76,7 @@ celery_app.conf.update( # Document tasks → document_tasks queue (prefork worker) 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, + 'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'}, # Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker) 'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'}, diff --git a/api/app/controllers/knowledge_controller.py b/api/app/controllers/knowledge_controller.py index 901208ba..01f89a3d 100644 --- a/api/app/controllers/knowledge_controller.py +++ b/api/app/controllers/knowledge_controller.py @@ -9,13 +9,16 @@ from sqlalchemy import or_ from sqlalchemy.orm import Session from app.celery_app import celery_app +from app.core.error_codes import BizCode from app.core.logging_config import get_api_logger from app.core.rag.common import settings +from app.core.rag.integrations.feishu.client import FeishuAPIClient +from app.core.rag.integrations.yuque.client import YuqueAPIClient from app.core.rag.llm.chat_model import Base from app.core.rag.nlp import rag_tokenizer, search from app.core.rag.prompts.generator import graph_entity_types from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory -from app.core.response_utils import success +from app.core.response_utils import success, fail from app.db import get_db from app.dependencies import get_current_user from app.models import knowledge_model @@ -484,3 +487,99 @@ async def rebuild_knowledge_graph( except Exception as e: api_logger.error(f"Failed to rebuild knowledge graph: knowledge_id={knowledge_id} - {str(e)}") raise + + +@router.get("/check/yuque/auth", response_model=ApiResponse) +async def check_yuque_auth( + yuque_user_id: str, + yuque_token: str, + 
db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + check yuque auth info + """ + api_logger.info(f"check yuque auth info, username: {current_user.username}") + + try: + api_client = YuqueAPIClient( + user_id=yuque_user_id, + token=yuque_token + ) + async with api_client as client: + repos = await client.get_user_repos() + if repos: + return success(data=repos, msg="Successfully auth yuque info") + return fail(BizCode.UNAUTHORIZED, msg="auth yuque info failed", error="user_id or token is incorrect") + except HTTPException: + raise + except Exception as e: + api_logger.error(f"auth yuque info failed: {str(e)}") + raise + + +@router.get("/check/feishu/auth", response_model=ApiResponse) +async def check_yuque_auth( + feishu_app_id: str, + feishu_app_secret: str, + feishu_folder_token: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + check feishu auth info + """ + api_logger.info(f"check feishu auth info, username: {current_user.username}") + + try: + api_client = FeishuAPIClient( + app_id=feishu_app_id, + app_secret=feishu_app_secret + ) + async with api_client as client: + files = await client.list_all_folder_files(feishu_folder_token, recursive=True) + if files: + return success(data=files, msg="Successfully auth feishu info") + return fail(BizCode.UNAUTHORIZED, msg="auth feishu info failed", error="app_id or app_secret or feishu_folder_token is incorrect") + except HTTPException: + raise + except Exception as e: + api_logger.error(f"auth feishu info failed: {str(e)}") + raise + + +@router.post("/{knowledge_id}/sync", response_model=ApiResponse) +async def sync_knowledge( + knowledge_id: uuid.UUID, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + sync knowledge base information based on knowledge_id + """ + api_logger.info(f"Obtain details of the knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}") + + 
try: + # 1. Query knowledge base information from the database + api_logger.debug(f"Query knowledge base: {knowledge_id}") + db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user) + if not db_knowledge: + api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={knowledge_id}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="The knowledge base does not exist or access is denied" + ) + + # 2. sync knowledge + # from app.tasks import sync_knowledge_for_kb + # sync_knowledge_for_kb(kb_id) + task = celery_app.send_task("app.core.rag.tasks.sync_knowledge_for_kb", args=[knowledge_id]) + result = { + "task_id": task.id + } + return success(data=result, msg="Task accepted. sync knowledge is being processed in the background.") + except HTTPException: + raise + except Exception as e: + api_logger.error(f"Failed to sync knowledge: knowledge_id={knowledge_id} - {str(e)}") + raise diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index e831cc01..7d74b85f 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -191,6 +191,11 @@ def update_config( api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + # 校验至少有一个字段需要更新 + if payload.config_name is None and payload.config_desc is None and payload.scene_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段") + return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空") + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}") try: svc = DataConfigService(db) diff --git a/api/app/controllers/ontology_controller.py b/api/app/controllers/ontology_controller.py index f36aa6c5..588c913c 100644 --- 
a/api/app/controllers/ontology_controller.py +++ b/api/app/controllers/ontology_controller.py @@ -52,6 +52,7 @@ from app.services.ontology_service import OntologyService from app.core.memory.llm_tools.openai_client import OpenAIClient from app.core.memory.utils.validation.owl_validator import OWLValidator from app.services.model_service import ModelConfigService +from app.repositories.ontology_scene_repository import OntologySceneRepository api_logger = get_api_logger() @@ -116,27 +117,35 @@ def _get_ontology_service( detail=f"找不到指定的LLM模型: {llm_id}" ) - # 验证模型配置了API密钥 - if not model_config.api_keys: - logger.error(f"Model {llm_id} has no API key configuration") + # 通过 Repository 获取可用的 API Key(负载均衡逻辑由 Repository 处理) + from app.repositories.model_repository import ModelApiKeyRepository + api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id) + if not api_keys: + logger.error(f"Model {llm_id} has no active API key") raise HTTPException( status_code=400, - detail="指定的LLM模型没有配置API密钥" + detail="指定的LLM模型没有可用的API密钥" ) + api_key_config = api_keys[0] - api_key_config = model_config.api_keys[0] - + is_composite = getattr(model_config, 'is_composite', False) logger.info( f"Using specified model - user: {current_user.id}, " - f"model_id: {llm_id}, model_name: {api_key_config.model_name}" + f"model_id: {llm_id}, model_name: {api_key_config.model_name}, " + f"is_composite: {is_composite}, api_key_id: {api_key_config.id}" ) # 创建模型配置对象 from app.core.models.base import RedBearModelConfig + # 对于组合模型,使用 API Key 的 provider;否则使用 model_config 的 provider + actual_provider = api_key_config.provider if is_composite else ( + getattr(model_config, 'provider', None) or "openai" + ) + llm_model_config = RedBearModelConfig( model_name=api_key_config.model_name, - provider=model_config.provider if hasattr(model_config, 'provider') else "openai", + provider=actual_provider, api_key=api_key_config.api_key, base_url=api_key_config.api_base, max_retries=3, @@ -648,6 +657,46 @@ 
async def delete_scene( return fail(BizCode.INTERNAL_ERROR, "场景删除失败", str(e)) +@router.get("/scenes/simple", response_model=ApiResponse) +async def get_scenes_simple( + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """获取场景简单列表(轻量级,用于下拉选择) + + 仅返回 scene_id 和 scene_name,不加载关联数据,响应速度快。 + 适用于前端下拉选择场景的场景。 + + Args: + db: 数据库会话 + current_user: 当前用户 + + Returns: + ApiResponse: 包含场景简单列表 + + Examples: + GET /scenes/simple + 返回: {"data": [{"scene_id": "xxx", "scene_name": "场景1"}, ...]} + """ + api_logger.info(f"Simple scene list requested by user {current_user.id}") + + try: + workspace_id = current_user.current_workspace_id + if not workspace_id: + api_logger.warning(f"User {current_user.id} has no current workspace") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + + repo = OntologySceneRepository(db) + scenes = repo.get_simple_list(workspace_id) + + api_logger.info(f"Simple scene list retrieved: {len(scenes)} scenes") + return success(data=scenes, msg="查询成功") + + except Exception as e: + api_logger.error(f"Failed to get simple scene list: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e)) + + @router.get("/scenes", response_model=ApiResponse) async def get_scenes( workspace_id: Optional[str] = None, diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 40cf068e..fae20ea2 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -7,30 +7,21 @@ LangChain Agent 封装 - 支持流式输出 - 使用 RedBearLLM 支持多提供商 """ -import os + import time from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence -from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages, format_parsing, messages_parse -from app.core.memory.agent.langgraph_graph.write_graph import long_term_storage +from app.core.memory.agent.langgraph_graph.write_graph import write_long_term from app.db import get_db from 
app.core.logging_config import get_business_logger -from app.core.memory.agent.utils.redis_tool import store from app.core.models import RedBearLLM, RedBearModelConfig from app.models.models_model import ModelType -from app.repositories.memory_short_repository import LongTermMemoryRepository from app.services.memory_agent_service import ( get_end_user_connected_config, ) -from app.services.memory_konwledges_server import write_rag -from app.services.task_service import get_task_memory_write_result -from app.tasks import write_message_task from langchain.agents import create_agent from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.tools import BaseTool - -from app.utils.config_utils import resolve_config_id - logger = get_business_logger() @@ -289,105 +280,6 @@ class LangChainAgent: return content_parts - async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type): - db = next(get_db()) - #TODO: 魔法数字 - scope=6 - - try: - repo = LongTermMemoryRepository(db) - await long_term_storage(long_term_type="chunk", langchain_messages=long_term_messages, - memory_config=actual_config_id, end_user_id=end_user_id, scope=scope) - - from app.core.memory.agent.utils.redis_tool import write_store - result = write_store.get_session_by_userid(end_user_id) - - # Handle case where no session exists in Redis (returns False) - if not result or result is False: - logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update") - return - - if type=="chunk" or type=="aggregate": - data = await format_parsing(result, "dict") - chunk_data = data[:scope] - if len(chunk_data)==scope: - repo.upsert(end_user_id, chunk_data) - logger.info(f'写入短长期:') - else: - # TODO: This branch handles type="time" strategy, currently unused. - # Will be activated when time-based long-term storage is implemented. 
- # TODO: 魔法数字 - extract 5 to a constant - long_time_data = write_store.find_user_recent_sessions(end_user_id, 5) - # Handle case where no session exists in Redis (returns False or empty) - if not long_time_data or long_time_data is False: - logger.debug(f"No recent sessions in Redis for user {end_user_id}") - return - long_messages = await messages_parse(long_time_data) - repo.upsert(end_user_id, long_messages) - logger.info(f'写入短长期:') - finally: - db.close() - - async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id): - """ - 写入记忆(支持结构化消息) - - Args: - storage_type: 存储类型 (neo4j/rag) - end_user_id: 终端用户ID - user_message: 用户消息内容 - ai_message: AI 回复内容 - user_rag_memory_id: RAG 记忆ID - actual_end_user_id: 实际用户ID - actual_config_id: 配置ID - - 逻辑说明: - - RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变 - - Neo4j 模式:使用结构化消息列表 - 1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant] - 2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景) - 3. 
每条消息会被转换为独立的 Chunk,保留 speaker 字段 - """ - - db = next(get_db()) - try: - actual_config_id=resolve_config_id(actual_config_id, db) - - if storage_type == "rag": - # RAG 模式:组合消息为字符串格式(保持原有逻辑) - combined_message = f"user: {user_message}\nassistant: {ai_message}" - await write_rag(end_user_id, combined_message, user_rag_memory_id) - logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') - else: - # Neo4j 模式:使用结构化消息列表 - structured_messages = [] - - # 始终添加用户消息(如果不为空) - if user_message: - structured_messages.append({"role": "user", "content": user_message}) - - # 只有当 AI 回复不为空时才添加 assistant 消息 - if ai_message: - structured_messages.append({"role": "assistant", "content": ai_message}) - - # 如果没有消息,直接返回 - if not structured_messages: - logger.warning(f"No messages to write for user {actual_end_user_id}") - return - - logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") - write_id = write_message_task.delay( - actual_end_user_id, # end_user_id: 用户ID - structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] - actual_config_id, # config_id: 配置ID - storage_type, # storage_type: "neo4j" - user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用) - ) - logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") - write_status = get_task_memory_write_result(str(write_id)) - logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') - finally: - db.close() async def chat( self, message: str, @@ -520,14 +412,7 @@ class LangChainAgent: elapsed_time = time.time() - start_time if memory_flag: - long_term_messages=await agent_chat_messages(message_chat,content) - # TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable. 
- # This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j - # when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph. - # Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time. - await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id) - # Batched long-term memory storage (Redis buffer + Neo4j when window full) - await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk") + await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id) response = { "content": content, "model": self.model_name, @@ -710,15 +595,7 @@ class LangChainAgent: yield total_tokens break if memory_flag: - # TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable. - # This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j - # when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph. - # Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time. 
- long_term_messages = await agent_chat_messages(message_chat, full_content) - await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id) - # Batched long-term memory storage (Redis buffer + Neo4j when window full) - await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk") - + await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id) except Exception as e: logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) raise diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index e9de02b6..895f61ac 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -1,8 +1,9 @@ +import json import os from app.core.logging_config import get_agent_logger -from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format, format_parsing -from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph +from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse +from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ @@ -10,46 +11,108 @@ from app.core.memory.agent.utils.redis_tool import write_store from app.core.memory.agent.utils.redis_tool import count_store from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context +from app.db import get_db_context, get_db +from app.repositories.memory_short_repository import LongTermMemoryRepository +from 
app.schemas.memory_agent_schema import AgentMemory_Long_Term +from app.services.memory_konwledges_server import write_rag +from app.services.task_service import get_task_memory_write_result +from app.tasks import write_message_task +from app.utils.config_utils import resolve_config_id logger = get_agent_logger(__name__) template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') +async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id): + # RAG 模式:组合消息为字符串格式(保持原有逻辑) + combined_message = f"user: {user_message}\nassistant: {ai_message}" + await write_rag(end_user_id, combined_message, user_rag_memory_id) + logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') +async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, + actual_config_id, long_term_messages=[]): + """ + 写入记忆(支持结构化消息) -async def write_messages(end_user_id,langchain_messages,memory_config): - ''' - 写入数据到neo4j: - Args: + Args: + storage_type: 存储类型 (neo4j/rag) end_user_id: 终端用户ID - memory_config: 内存配置对象 - langchain_messages:原始数据LIST - ''' + user_message: 用户消息内容 + ai_message: AI 回复内容 + user_rag_memory_id: RAG 记忆ID + actual_end_user_id: 实际用户ID + actual_config_id: 配置ID + + 逻辑说明: + - RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变 + - Neo4j 模式:使用结构化消息列表 + 1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant] + 2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景) + 3. 
每条消息会被转换为独立的 Chunk,保留 speaker 字段 + """ + + db = next(get_db()) try: + actual_config_id = resolve_config_id(actual_config_id, db) + # Neo4j 模式:使用结构化消息列表 + structured_messages = [] + + # 始终添加用户消息(如果不为空) + if isinstance(user_message, str) and user_message.strip() != "": + structured_messages.append({"role": "user", "content": user_message}) + + # 只有当 AI 回复不为空时才添加 assistant 消息 + if isinstance(ai_message, str) and ai_message.strip() != "": + structured_messages.append({"role": "assistant", "content": ai_message}) + + # 如果提供了 long_term_messages,使用它替代 structured_messages + if long_term_messages and isinstance(long_term_messages, list): + structured_messages = long_term_messages + elif long_term_messages and isinstance(long_term_messages, str): + # 如果是 JSON 字符串,先解析 + try: + structured_messages = json.loads(long_term_messages) + except json.JSONDecodeError: + logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}") + + # 如果没有消息,直接返回 + if not structured_messages: + logger.warning(f"No messages to write for user {actual_end_user_id}") + return + + logger.info( + f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") + write_id = write_message_task.delay( + actual_end_user_id, # end_user_id: 用户ID + structured_messages, # message: JSON 字符串格式的消息列表 + str(actual_config_id), # config_id: 配置ID字符串 + storage_type, # storage_type: "neo4j" + user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用) + ) + logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") + write_status = get_task_memory_write_result(str(write_id)) + logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') + finally: + db.close() + +async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope): + with get_db_context() as db_session: + repo = LongTermMemoryRepository(db_session) + + + from app.core.memory.agent.utils.redis_tool import write_store 
+ result = write_store.get_session_by_userid(end_user_id) + if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: + data = await format_parsing(result, "dict") + chunk_data = data[:scope] + if len(chunk_data)==scope: + repo.upsert(end_user_id, chunk_data) + logger.info(f'---------写入短长期-----------') + else: + long_time_data = write_store.find_user_recent_sessions(end_user_id, 5) + long_messages = await messages_parse(long_time_data) + repo.upsert(end_user_id, long_messages) + logger.info(f'写入短长期:') + - async with make_write_graph() as graph: - config = {"configurable": {"thread_id": end_user_id}} - # 初始状态 - 包含所有必要字段 - initial_state = { - "messages": langchain_messages, - "end_user_id": end_user_id, - "memory_config": memory_config - } - # 获取节点更新信息 - async for update_event in graph.astream( - initial_state, - stream_mode="updates", - config=config - ): - for node_name, node_data in update_event.items(): - if 'save_neo4j' == node_name: - massages = node_data - # TODO:删除 - massagesstatus = massages.get('write_result')['status'] - contents = massages.get('write_result') - print(contents) - except Exception as e: - import traceback - traceback.print_exc() '''根据窗口''' async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): ''' @@ -61,25 +124,26 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): scope:窗口大小 ''' scope=scope - redis_messages = [] is_end_user_id = count_store.get_sessions_count(end_user_id) if is_end_user_id is not False: is_end_user_id = count_store.get_sessions_count(end_user_id)[0] redis_messages = count_store.get_sessions_count(end_user_id)[1] if is_end_user_id and int(is_end_user_id) != int(scope): - print(is_end_user_id) is_end_user_id += 1 langchain_messages += redis_messages count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages) elif int(is_end_user_id) == int(scope): - print('写入长期记忆,并且设置为0') - print(is_end_user_id) - formatted_messages 
= await chat_data_format(redis_messages) - print(100*'-') - print(formatted_messages) - print(100*'-') - await write_messages(end_user_id, formatted_messages, memory_config) - count_store.update_sessions_count(end_user_id, 0, '') + logger.info('写入长期记忆NEO4J') + formatted_messages = (redis_messages) + # 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用) + if hasattr(memory_config, 'config_id'): + config_id = memory_config.config_id + else: + config_id = memory_config + + await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id, + config_id, formatted_messages) + count_store.update_sessions_count(end_user_id, 1, langchain_messages) else: count_store.save_sessions_count(end_user_id, 1, langchain_messages) @@ -93,12 +157,15 @@ async def memory_long_term_storage(end_user_id,memory_config,time): memory_config: 内存配置对象 ''' long_time_data = write_store.find_user_recent_sessions(end_user_id, time) - # Handle case where no session exists in Redis (returns False or empty) - if not long_time_data or long_time_data is False: - return - format_messages = await chat_data_format(long_time_data) + format_messages = (long_time_data) + messages=[] + memory_config=memory_config.config_id + for i in format_messages: + message=json.loads(i['Query']) + messages+= message if format_messages!=[]: - await write_messages(end_user_id, format_messages, memory_config) + await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id, + memory_config, messages) '''聚合判断''' async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict: """ @@ -109,13 +176,12 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] memory_config: 内存配置对象 """ - + try: # 1. 
获取历史会话数据(使用新方法) result = write_store.get_all_sessions_by_end_user_id(end_user_id) - - # Handle case where no session exists in Redis (returns False or empty) - if not result or result is False: + history = await format_parsing(result) + if not result: history = [] else: history = await format_parsing(result) @@ -154,7 +220,8 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config } if not structured.is_same_event: logger.info(result_dict) - await write_messages(end_user_id, output_value, memory_config) + await write("neo4j", end_user_id, "", "", None, end_user_id, + memory_config.config_id, output_value) return result_dict except Exception as e: diff --git a/api/app/core/memory/agent/langgraph_graph/tools/tool.py b/api/app/core/memory/agent/langgraph_graph/tools/tool.py index c4814de1..fcbb18e3 100644 --- a/api/app/core/memory/agent/langgraph_graph/tools/tool.py +++ b/api/app/core/memory/agent/langgraph_graph/tools/tool.py @@ -186,10 +186,11 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): 清理后的数据 """ # 需要过滤的字段列表 + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 fields_to_remove = { 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids', 'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id', - 'user_id', 'statement_ids', 'updated_at',"chunk_ids","fact_summary" + 'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary" } if isinstance(data, dict): diff --git a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py index a1fb8226..9ce581ee 100644 --- a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py +++ b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py @@ -1,8 +1,6 @@ import json from langchain_core.messages import HumanMessage, AIMessage - - async def format_parsing(messages: list,type:str='string'): """ 格式化解析消息列表 @@ -26,13 +24,13 @@ async def format_parsing(messages: list,type:str='string'): role = 
content['role'] content = content['content'] if type == "string": - if role == 'human': + if role == 'human' or role=="user": content = '用户:' + content else: content = 'AI:' + content result.append(content) - if type == "dict": - if role == 'human': + if type == "dict" : + if role == 'human' or role=="user": user.append( content) else: ai.append(content) @@ -57,33 +55,7 @@ async def messages_parse(messages: list | dict): for key, values in zip(user, ai): database.append({key, values}) return database -async def chat_data_format(messages: list | dict): - """ - 将消息格式化为 LangChain 消息格式 - - Args: - messages: 消息列表或字典 - - Returns: - LangChain 消息列表 - """ - langchain_messages = [] - if isinstance(messages, list): - for msg in messages: - if 'role' in msg.keys(): - if msg['role'] == 'user': - langchain_messages.append(HumanMessage(content=msg['content'])) - elif msg['role'] == 'assistant': - langchain_messages.append(AIMessage(content=msg['content'])) - if "Query" in msg.keys(): - langchain_messages.append(HumanMessage(content=msg['Query'])) - langchain_messages.append(AIMessage(content=msg['Answer'])) - if isinstance(messages, dict): - if messages['type'] == 'human': - langchain_messages.append(HumanMessage(content=messages['content'])) - elif messages['type'] == 'ai': - langchain_messages.append(AIMessage(content=messages['content'])) - return langchain_messages + async def agent_chat_messages(user_content,ai_content): messages = [ diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index 9b858f47..1134acc7 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -1,13 +1,18 @@ import asyncio +import json import sys import warnings from contextlib import asynccontextmanager from langgraph.constants import END, START from langgraph.graph import StateGraph +from app.db import get_db, get_db_context from 
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list | None = None,
                            memory_config: str = '', end_user_id: str = '', scope: int = 6):
    """Persist a dialogue to the session store and run one long-term storage strategy.

    The messages are first saved through ``write_store.save_session_write``;
    the memory configuration identified by ``memory_config`` is then loaded and
    handed to the routine that implements the requested strategy.

    Args:
        long_term_type: Strategy selector: ``'chunk'`` (dialogue window),
            ``'time'`` or ``'aggregate'`` (aggregate judgment).
        langchain_messages: Messages to store. ``None`` (the default) is
            treated as an empty list.
        memory_config: Identifier passed to ``load_memory_config`` as ``config_id``.
        end_user_id: End user identifier.
        scope: Window size forwarded to the ``'chunk'`` strategy.
    """
    from app.core.memory.agent.langgraph_graph.routing.write_router import (
        aggregate_judgment,
        memory_long_term_storage,
        window_dialogue,
    )
    from app.core.memory.agent.utils.redis_tool import write_store

    # A fresh list per call: a mutable default would be shared between calls.
    if langchain_messages is None:
        langchain_messages = []

    write_store.save_session_write(end_user_id, langchain_messages)

    # NOTE(review): the dispatch below is kept inside the session context so the
    # loaded configuration is used while its session is open - confirm whether
    # the strategies really need the session to stay open during the awaits.
    with get_db_context() as db_session:
        config_service = MemoryConfigService(db_session)
        loaded_config = config_service.load_memory_config(
            config_id=memory_config,
            service_name="MemoryAgentService"
        )

        if long_term_type == 'chunk':
            # Strategy 1: dialogue window of `scope` rounds
            await window_dialogue(end_user_id, langchain_messages, loaded_config, scope)
        elif long_term_type == 'time':
            # Strategy 2: time based; 5 is presumably the time window - TODO confirm unit
            await memory_long_term_storage(end_user_id, loaded_config, 5)
        elif long_term_type == 'aggregate':
            # Strategy 3: aggregate judgment
            await aggregate_judgment(end_user_id, langchain_messages, loaded_config)
        else:
            # Previously an unknown strategy was ignored without any trace.
            logger.warning(f"Unknown long_term_type '{long_term_type}', nothing stored for end_user_id={end_user_id}")
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
    """Write one user/AI exchange to long-term memory.

    With the RAG storage type the exchange is handed to ``write_rag_agent``.
    Otherwise the user message and the AI reply are paired into one message
    list, stored with the chunk strategy and saved via ``term_memory_save``.
    """
    from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
    from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
    from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages

    if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
        await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
        return

    # Pair the user message with the AI reply so the whole exchange is written at once
    strategy = AgentMemory_Long_Term.STRATEGY_CHUNK
    window = AgentMemory_Long_Term.DEFAULT_SCOPE
    paired_messages = await agent_chat_messages(message_chat, aimessages)
    await long_term_storage(
        long_term_type=strategy,
        langchain_messages=paired_messages,
        memory_config=actual_config_id,
        end_user_id=end_user_id,
        scope=window,
    )
    await term_memory_save(paired_messages, actual_config_id, end_user_id, strategy, scope=window)
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}") @@ -320,31 +325,47 @@ class RedisCountStore: list 或 False: 如果找到返回 [count, messages],否则返回 False """ try: - search_pattern = 'session:count:*' + # 使用索引键快速查找 + index_key = f'session:count:index:{end_user_id}' - for key in self.r.keys(search_pattern): - data = self.r.hgetall(key) - - if not data: - continue - - if data.get('end_user_id') == end_user_id: - count = data.get('count') - messages_str = data.get('messages') - - if count is not None: - messages = deserialize_messages(messages_str) - return [int(count), messages] + # 检查索引键类型,避免 WRONGTYPE 错误 + try: + key_type = self.r.type(index_key) + if key_type != 'string' and key_type != 'none': + self.r.delete(index_key) + return False + except Exception as type_error: + print(f"[get_sessions_count] 检查键类型失败: {type_error}") + + session_id = self.r.get(index_key) + + if not session_id: + return False + + # 直接获取数据 + key = generate_session_key(session_id, key_type="count") + data = self.r.hgetall(key) + + if not data: + # 索引存在但数据不存在,清理索引 + self.r.delete(index_key) + return False + + count = data.get('count') + messages_str = data.get('messages') + + if count is not None: + messages = deserialize_messages(messages_str) + return [int(count), messages] return False except Exception as e: print(f"[get_sessions_count] 查询失败: {e}") return False - def update_sessions_count(self, end_user_id: str, new_count: int, messages: Any) -> bool: """ - 通过 end_user_id 修改访问次数统计 + 通过 end_user_id 修改访问次数统计(优化版:使用索引) Args: end_user_id: 终端用户ID @@ -355,23 +376,39 @@ class RedisCountStore: bool: 更新成功返回 True,未找到记录返回 False """ try: + # 使用索引键快速查找 + index_key = f'session:count:index:{end_user_id}' + + # 检查索引键类型,避免 WRONGTYPE 错误 + try: + key_type = self.r.type(index_key) + if key_type != 'string' and key_type != 'none': + # 索引键类型错误,删除并返回 False + print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引") + self.r.delete(index_key) + print(f"[update_sessions_count] 未找到记录: 
end_user_id={end_user_id}") + return False + except Exception as type_error: + print(f"[update_sessions_count] 检查键类型失败: {type_error}") + + session_id = self.r.get(index_key) + + if not session_id: + print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") + return False + + # 直接更新数据 + key = generate_session_key(session_id, key_type="count") messages_str = serialize_messages(messages) - search_pattern = 'session:count:*' - for key in self.r.keys(search_pattern): - data = self.r.hgetall(key) - - if not data: - continue - - if data.get('end_user_id') == end_user_id: - self.r.hset(key, 'count', int(new_count)) - self.r.hset(key, 'messages', messages_str) - print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}") - return True + pipe = self.r.pipeline() + pipe.hset(key, 'count', int(new_count)) + pipe.hset(key, 'messages', messages_str) + result = pipe.execute() + + print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}") + return True - print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") - return False except Exception as e: print(f"[update_sessions_count] 更新失败: {e}") return False diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 76a28156..fadc7669 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -4,6 +4,7 @@ Write Tools for Memory Knowledge Extraction Pipeline This module provides the main write function for executing the knowledge extraction pipeline. Only MemoryConfig is needed - clients are constructed internally. 
""" +import asyncio import time from datetime import datetime @@ -124,23 +125,48 @@ async def write( except Exception as e: logger.error(f"Error creating indexes: {e}", exc_info=True) + # 添加死锁重试机制 + max_retries = 3 + retry_delay = 1 # 秒 + + for attempt in range(max_retries): + try: + success = await save_dialog_and_statements_to_neo4j( + dialogue_nodes=all_dialogue_nodes, + chunk_nodes=all_chunk_nodes, + statement_nodes=all_statement_nodes, + entity_nodes=all_entity_nodes, + statement_chunk_edges=all_statement_chunk_edges, + statement_entity_edges=all_statement_entity_edges, + entity_edges=all_entity_entity_edges, + connector=neo4j_connector + ) + if success: + logger.info("Successfully saved all data to Neo4j") + break + else: + logger.warning("Failed to save some data to Neo4j") + if attempt < max_retries - 1: + logger.info(f"Retrying... (attempt {attempt + 2}/{max_retries})") + await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避 + except Exception as e: + error_msg = str(e) + # 检查是否是死锁错误 + if "DeadlockDetected" in error_msg or "deadlock" in error_msg.lower(): + if attempt < max_retries - 1: + logger.warning(f"Deadlock detected, retrying... 
(attempt {attempt + 2}/{max_retries})") + await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避 + else: + logger.error(f"Failed after {max_retries} attempts due to deadlock: {e}") + raise + else: + # 非死锁错误,直接抛出 + raise + try: - success = await save_dialog_and_statements_to_neo4j( - dialogue_nodes=all_dialogue_nodes, - chunk_nodes=all_chunk_nodes, - statement_nodes=all_statement_nodes, - entity_nodes=all_entity_nodes, - statement_chunk_edges=all_statement_chunk_edges, - statement_entity_edges=all_statement_entity_edges, - entity_edges=all_entity_entity_edges, - connector=neo4j_connector - ) - if success: - logger.info("Successfully saved all data to Neo4j") - else: - logger.warning("Failed to save some data to Neo4j") - finally: await neo4j_connector.close() + except Exception as e: + logger.error(f"Error closing Neo4j connector: {e}") log_time("Neo4j Database Save", time.time() - step_start, log_file) diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index 79b88fdc..1880b9ab 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -413,7 +413,8 @@ class ExtractedEntityNode(Node): description="Entity aliases - alternative names for this entity" ) name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector") - fact_summary: str = Field(default="", description="Summary of the fact about this entity") + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # fact_summary: str = Field(default="", description="Summary of the fact about this entity") connect_strength: str = Field(..., description="Strong VS Weak about this entity") config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)") diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py 
b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py index a425e0ed..f2f14d9e 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py @@ -134,42 +134,45 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode): if len(desc_b) > len(desc_a): canonical.description = desc_b # 合并事实摘要:统一保留一个“实体: name”行,来源行去重保序 - fact_a = getattr(canonical, "fact_summary", "") or "" - fact_b = getattr(ent, "fact_summary", "") or "" - def _extract_sources(txt: str) -> List[str]: - sources: List[str] = [] - if not txt: - return sources - for line in str(txt).splitlines(): - ln = line.strip() + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # fact_a = getattr(canonical, "fact_summary", "") or "" + # fact_b = getattr(ent, "fact_summary", "") or "" + # def _extract_sources(txt: str) -> List[str]: + # sources: List[str] = [] + # if not txt: + # return sources + # for line in str(txt).splitlines(): + # ln = line.strip() # 支持“来源:”或“来源:”前缀 - m = re.match(r"^来源[::]\s*(.+)$", ln) - if m: - content = m.group(1).strip() - if content: - sources.append(content) + # m = re.match(r"^来源[::]\s*(.+)$", ln) + # if m: + # content = m.group(1).strip() + # if content: + # sources.append(content) # 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失 - if not sources and txt.strip(): - sources.append(txt.strip()) - return sources + # if not sources and txt.strip(): + # sources.append(txt.strip()) + # return sources try: - src_a = _extract_sources(fact_a) - src_b = _extract_sources(fact_b) - seen = set() - merged_sources: List[str] = [] - for s in src_a + src_b: - if s and s not in seen: - seen.add(s) - merged_sources.append(s) - if merged_sources: - name_line = f"实体: {getattr(canonical, 'name', '')}".strip() - canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources]) - elif fact_b and not fact_a: - 
canonical.fact_summary = fact_b + # src_a = _extract_sources(fact_a) + # src_b = _extract_sources(fact_b) + # seen = set() + # merged_sources: List[str] = [] + # for s in src_a + src_b: + # if s and s not in seen: + # seen.add(s) + # merged_sources.append(s) + # if merged_sources: + # name_line = f"实体: {getattr(canonical, 'name', '')}".strip() + # canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources]) + # elif fact_b and not fact_a: + # canonical.fact_summary = fact_b + pass except Exception: # 兜底:若解析失败,保留较长文本 - if len(fact_b) > len(fact_a): - canonical.fact_summary = fact_b + # if len(fact_b) > len(fact_a): + # canonical.fact_summary = fact_b + pass except Exception: pass diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py index 0249ac1f..a028e916 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py @@ -145,10 +145,13 @@ def _choose_canonical(a: ExtractedEntityNode, b: ExtractedEntityNode) -> int: # # 2. 
第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整) desc_a = (getattr(a, "description", "") or "") desc_b = (getattr(b, "description", "") or "") - fact_a = (getattr(a, "fact_summary", "") or "") - fact_b = (getattr(b, "fact_summary", "") or "") - score_a = len(desc_a) + len(fact_a) - score_b = len(desc_b) + len(fact_b) + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # fact_a = (getattr(a, "fact_summary", "") or "") + # fact_b = (getattr(b, "fact_summary", "") or "") + # score_a = len(desc_a) + len(fact_a) + # score_b = len(desc_b) + len(fact_b) + score_a = len(desc_a) + score_b = len(desc_b) if score_a != score_b: return 0 if score_a >= score_b else 1 return 0 @@ -189,7 +192,8 @@ async def _judge_pair( "entity_type": getattr(a, "entity_type", None), "description": getattr(a, "description", None), "aliases": getattr(a, "aliases", None) or [], - "fact_summary": getattr(a, "fact_summary", None), + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # "fact_summary": getattr(a, "fact_summary", None), "connect_strength": getattr(a, "connect_strength", None), } entity_b = { @@ -197,7 +201,8 @@ async def _judge_pair( "entity_type": getattr(b, "entity_type", None), "description": getattr(b, "description", None), "aliases": getattr(b, "aliases", None) or [], - "fact_summary": getattr(b, "fact_summary", None), + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # "fact_summary": getattr(b, "fact_summary", None), "connect_strength": getattr(b, "connect_strength", None), } # 5. 
渲染LLM提示词(用工具函数填充模板,包含实体信息、上下文、输出格式) @@ -248,7 +253,8 @@ async def _judge_pair_disamb( "entity_type": getattr(a, "entity_type", None), "description": getattr(a, "description", None), "aliases": getattr(a, "aliases", None) or [], - "fact_summary": getattr(a, "fact_summary", None), + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # "fact_summary": getattr(a, "fact_summary", None), "connect_strength": getattr(a, "connect_strength", None), } entity_b = { @@ -256,7 +262,8 @@ async def _judge_pair_disamb( "entity_type": getattr(b, "entity_type", None), "description": getattr(b, "description", None), "aliases": getattr(b, "aliases", None) or [], - "fact_summary": getattr(b, "fact_summary", None), + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # "fact_summary": getattr(b, "fact_summary", None), "connect_strength": getattr(b, "connect_strength", None), } prompt = render_entity_dedup_prompt( diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py index dbc697d9..028a926f 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py @@ -72,7 +72,8 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode: description=row.get("description") or "", aliases=row.get("aliases") or [], name_embedding=row.get("name_embedding") or [], - fact_summary=row.get("fact_summary") or "", + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # fact_summary=row.get("fact_summary") or "", connect_strength=row.get("connect_strength") or "", ) diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 98bec522..08be0aeb 100644 --- 
a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -1088,7 +1088,8 @@ class ExtractionOrchestrator: entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type description=getattr(entity, 'description', ''), # 添加必需的 description 字段 example=getattr(entity, 'example', ''), # 新增:传递示例字段 - fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段 + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段 connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases name_embedding=getattr(entity, 'name_embedding', None), diff --git a/api/app/core/memory/utils/alias_utils.py b/api/app/core/memory/utils/alias_utils.py index df75752a..ff139128 100644 --- a/api/app/core/memory/utils/alias_utils.py +++ b/api/app/core/memory/utils/alias_utils.py @@ -296,7 +296,9 @@ def resolve_alias_cycles(entities: List[Any], cycles: Dict[str, Set[str]]) -> Li key=lambda eid: ( _strength_rank(eid), len(getattr(entity_by_id.get(eid), 'description', '') or ''), - len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '') + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '') + 0 # 临时占位 ), reverse=True ) diff --git a/api/app/core/memory/utils/prompt/prompts/entity_dedup.jinja2 b/api/app/core/memory/utils/prompt/prompts/entity_dedup.jinja2 index be53c9d4..7fb465a2 100644 --- a/api/app/core/memory/utils/prompt/prompts/entity_dedup.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/entity_dedup.jinja2 @@ -9,7 +9,8 @@ - 类型: "{{ entity_a.entity_type | default('') }}" - 描述: "{{ entity_a.description | default('') }}" - 别名: {{ entity_a.aliases | default([]) }} -- 摘要: "{{ entity_a.fact_summary | default('') }}" 
def setup_logging(verbose: bool = False):
    """Configure root logging to write to stdout.

    Args:
        verbose: Use DEBUG level when true, INFO otherwise.
    """
    if verbose:
        level = logging.DEBUG
    else:
        level = logging.INFO

    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[stdout_handler],
    )
def main(entry_url: str,
         max_pages: int = 200,
         delay_seconds: float = 1.0,
         timeout_seconds: int = 10,
         user_agent: str = "KnowledgeBaseCrawler/1.0"):
    """Crawl a site from ``entry_url`` and print every page plus a summary.

    Each crawled document is printed as it arrives and collected as a plain
    dict. A keyboard interrupt stops the crawl but still prints the summary;
    any other error is reported and terminates the process with exit code 1.
    """
    web_crawler = WebCrawler(
        entry_url=entry_url,
        max_pages=max_pages,
        delay_seconds=delay_seconds,
        timeout_seconds=timeout_seconds,
        user_agent=user_agent
    )

    collected = []
    try:
        for page in web_crawler.crawl():
            # Per-page report
            print(f"\n{'=' * 80}")
            print(f"URL: {page.url}")
            print(f"Title: {page.title}")
            print(f"Content Length: {page.content_length} characters")
            print(f"Word Count: {page.metadata.get('word_count', 0)} words")
            print(f"{'=' * 80}\n")

            record = {}
            record['url'] = page.url
            record['title'] = page.title
            record['content'] = page.content
            record['content_length'] = page.content_length
            record['crawl_timestamp'] = page.crawl_timestamp.isoformat()
            record['http_status'] = page.http_status
            record['metadata'] = page.metadata
            collected.append(record)

    except KeyboardInterrupt:
        print("\n\nCrawl interrupted by user.")

    except Exception as e:
        print(f"\n\nError during crawl: {e}")
        sys.exit(1)

    # Final report
    report = web_crawler.get_summary()
    print(f"\n{'=' * 80}")
    print("CRAWL SUMMARY")
    print(f"{'=' * 80}")
    print(f"Total Pages Processed: {report.total_pages_processed}")
    print(f"Total Errors: {report.total_errors}")
    print(f"Total Skipped: {report.total_skipped}")
    print(f"Total URLs Discovered: {report.total_urls_discovered}")
    print(f"Duration: {report.duration_seconds:.2f} seconds")
    print(f"documents: {collected}")

    if not report.error_breakdown:
        return

    print(f"\nError Breakdown:")
    for error_type, count in report.error_breakdown.items():
        print(f" {error_type}: {count}")
class ContentExtractor:
    """Extract clean, readable text from HTML pages.

    ``extract`` is the entry point: it parses the HTML, drops boilerplate
    tags, locates the main content area, collects its text and normalizes
    whitespace.
    """

    # Tags to remove completely
    REMOVE_TAGS = ['script', 'style', 'nav', 'header', 'footer', 'aside']

    # Tags that typically contain main content
    MAIN_CONTENT_TAGS = ['article', 'main']

    # Content extraction tags
    CONTENT_TAGS = ['p', 'div', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li', 'td', 'th', 'section']

    def is_static_content(self, html: str) -> bool:
        """
        Determine if the HTML represents static content.

        Detects JavaScript-rendered content by checking for minimal body
        with heavy script tag presence.

        Args:
            html: Raw HTML string

        Returns:
            bool: True if static, False if JavaScript-rendered. True is also
            returned when the check itself fails.
        """
        try:
            soup = BeautifulSoup(html, 'lxml')

            # Count script tags before they are stripped from the body
            script_count = len(soup.find_all('script'))

            body = soup.find('body')
            if not body:
                return False

            # Scripts and styles must not count as visible text
            for tag in body.find_all(['script', 'style']):
                tag.decompose()

            text_length = len(body.get_text(strip=True))

            # If there's very little text but many scripts, likely JS-rendered
            if script_count > 5 and text_length < 200:
                logger.warning("Detected JavaScript-rendered content (many scripts, little text)")
                return False

            # If there's no meaningful text, likely JS-rendered
            if text_length < 50:
                logger.warning("Detected JavaScript-rendered content (minimal text)")
                return False

            return True

        except Exception as e:
            logger.error(f"Error checking if content is static: {e}")
            return True  # Assume static on error

    def extract(self, html: str, url: str) -> ExtractedContent:
        """
        Extract clean text content from HTML.

        Args:
            html: Raw HTML string
            url: Source URL (for context)

        Returns:
            ExtractedContent: Contains title, text, metadata. On failure the
            result has the URL as title, empty text and the error message in
            its metadata.
        """
        try:
            soup = BeautifulSoup(html, 'lxml')

            # Check if content is static
            is_static = self.is_static_content(html)

            # Extract title before boilerplate tags (e.g. <header>) are removed
            title = self._extract_title(soup)

            # Remove unwanted tags
            for tag_name in self.REMOVE_TAGS:
                for tag in soup.find_all(tag_name):
                    tag.decompose()

            # Extract main content
            text = self._extract_main_content(soup)

            # Normalize whitespace
            text = self._normalize_whitespace(text)

            # Count words
            word_count = len(text.split())

            logger.info(f"Extracted {word_count} words from {url}")

            return ExtractedContent(
                title=title,
                text=text,
                is_static=is_static,
                word_count=word_count,
                metadata={'url': url}
            )

        except Exception as e:
            logger.error(f"Error extracting content from {url}: {e}")
            return ExtractedContent(
                title=url,
                text="",
                is_static=False,
                word_count=0,
                metadata={'url': url, 'error': str(e)}
            )

    def _extract_title(self, soup: BeautifulSoup) -> str:
        """
        Extract title from HTML.

        Tries the <title> tag first, then the first <h1>. A <title> that is
        empty or whitespace-only is skipped so the <h1> fallback still applies.

        Args:
            soup: BeautifulSoup object

        Returns:
            str: Page title, or "" when none is found
        """
        title_tag = soup.find('title')
        if title_tag:
            # get_text() also works when <title> has several children, where
            # `.string` would be None and the title would be ignored.
            title = title_tag.get_text(strip=True)
            if title:
                return title

        h1_tag = soup.find('h1')
        if h1_tag:
            return h1_tag.get_text(strip=True)

        return ""

    def _extract_main_content(self, soup: BeautifulSoup) -> str:
        """
        Extract main content from HTML.

        Prioritizes semantic HTML5 elements like <article> and <main>, then
        ``div[role=main]``, then common class/id patterns, then <body>.

        Every piece of text is emitted once: a content tag that wraps other
        content tags contributes only the text that is not inside those
        nested tags, because the nested tags are emitted on their own.

        Args:
            soup: BeautifulSoup object

        Returns:
            str: Extracted text content, one block per line
        """
        main_content = None

        # Priority 1: <article> or <main> tags
        for tag_name in self.MAIN_CONTENT_TAGS:
            main_content = soup.find(tag_name)
            if main_content:
                logger.debug(f"Found main content in <{tag_name}> tag")
                break

        # Priority 2: div with role="main"
        if not main_content:
            main_content = soup.find('div', role='main')
            if main_content:
                logger.debug("Found main content in div[role='main']")

        # Priority 3: Common class/id patterns
        if not main_content:
            for pattern in ['content', 'main', 'article', 'post']:
                main_content = soup.find(['div', 'section'], class_=re.compile(pattern, re.I))
                if main_content:
                    logger.debug(f"Found main content with class pattern '{pattern}'")
                    break

                main_content = soup.find(['div', 'section'], id=re.compile(pattern, re.I))
                if main_content:
                    logger.debug(f"Found main content with id pattern '{pattern}'")
                    break

        # Fallback: use body
        if not main_content:
            main_content = soup.find('body')
            logger.debug("Using <body> as main content (no specific content area found)")

        if not main_content:
            return ""

        text_parts = []
        for tag in main_content.find_all(self.CONTENT_TAGS):
            if tag.find(self.CONTENT_TAGS) is None:
                # Leaf block: all of its text belongs to it
                text = tag.get_text(strip=True)
            else:
                # Container: nested content tags are visited separately
                text = self._own_text(tag)
            if text:
                text_parts.append(text)

        if not text_parts:
            # The area holds no content tags with text (e.g. bare text inside
            # <article>): use its plain text instead of returning nothing.
            return main_content.get_text(strip=True)

        return '\n'.join(text_parts)

    def _own_text(self, tag) -> str:
        """Return the text of *tag* that is not inside a nested content tag."""
        from bs4.element import CData, NavigableString

        pieces = []
        for child in tag.children:
            if isinstance(child, NavigableString):
                # Exact type check: comments, doctypes and similar string
                # subclasses are not page text.
                if type(child) not in (NavigableString, CData):
                    continue
                piece = child.strip()
            elif child.name in self.CONTENT_TAGS:
                # Emitted as its own block by _extract_main_content
                continue
            elif child.find(self.CONTENT_TAGS) is not None:
                # Inline wrapper around content tags: keep only its own text
                piece = self._own_text(child)
            else:
                piece = child.get_text(strip=True)
            if piece:
                pieces.append(piece)

        return ' '.join(pieces)

    def _normalize_whitespace(self, text: str) -> str:
        """
        Normalize whitespace in text.

        - Collapse multiple spaces to single space
        - Reduce excessive newlines to maximum 2
        - Strip leading/trailing whitespace

        Args:
            text: Text to normalize

        Returns:
            str: Normalized text
        """
        text = re.sub(r' +', ' ', text)
        text = re.sub(r'\n{3,}', '\n\n', text)
        return text.strip()
class HTTPFetcher:
    """Handle HTTP requests with retries, error handling, and response validation."""

    # Encodings tried, in order, when neither the meta tag, the response
    # encoding nor UTF-8 decodes the body. ISO-8859-1 is not listed: it maps
    # every byte and is therefore the unconditional last resort.
    FALLBACK_ENCODINGS = (
        'gbk',        # Chinese (Simplified)
        'gb2312',     # Chinese (Simplified, older)
        'gb18030',    # Chinese (Simplified, extended)
        'big5',       # Chinese (Traditional)
        'shift_jis',  # Japanese
        'euc-jp',     # Japanese
        'euc-kr',     # Korean
    )

    def __init__(
        self,
        timeout: int = 10,
        max_retries: int = 3,
        user_agent: str = "KnowledgeBaseCrawler/1.0"
    ):
        """
        Initialize HTTP fetcher.

        Args:
            timeout: Request timeout in seconds
            max_retries: Maximum number of attempts per URL
            user_agent: User-Agent header value
        """
        self.timeout = timeout
        self.max_retries = max_retries
        self.user_agent = user_agent

        # Create session for connection pooling
        self.session = requests.Session()
        self.session.headers.update({
            'User-Agent': user_agent
        })

    def fetch(self, url: str) -> FetchResult:
        """
        Fetch a URL with retry logic and error handling.

        429, 503, other 5xx responses, timeouts, connection errors and generic
        request errors are retried with exponential backoff. 404, other 4xx
        responses, SSL errors and unexpected status codes fail immediately.
        At least one attempt is always made, even if ``max_retries`` is 0.

        Args:
            url: URL to fetch

        Returns:
            FetchResult: Contains status_code, content, headers, error info
        """
        last_error = None
        attempts = max(1, self.max_retries)

        for attempt in range(attempts):
            is_last_attempt = attempt >= attempts - 1
            try:
                # Calculate backoff delay for retries
                if attempt > 0:
                    backoff_delay = 2 ** (attempt - 1)  # 1s, 2s, 4s
                    logger.info(f"Retry attempt {attempt + 1}/{attempts} for {url} after {backoff_delay}s")
                    time.sleep(backoff_delay)

                response = self.session.get(
                    url,
                    timeout=self.timeout,
                    allow_redirects=True
                )
                status = response.status_code

                if status == 429:
                    # Too Many Requests - backoff and retry
                    logger.warning(f"429 Too Many Requests for {url}, backing off")
                    if not is_last_attempt:
                        continue

                if status == 503:
                    # Service Unavailable - pause and retry
                    logger.warning(f"503 Service Unavailable for {url}")
                    if not is_last_attempt:
                        time.sleep(5)  # Longer pause for 503
                        continue

                if 200 <= status < 300:
                    logger.info(f"Successfully fetched {url} (status: {status})")
                    return FetchResult(
                        url=url,
                        final_url=response.url,
                        status_code=status,
                        content=self._get_decoded_content(response),
                        headers=dict(response.headers),
                        error=None,
                        success=True
                    )

                if status == 404:
                    logger.info(f"404 Not Found: {url}")
                    return self._failure(url, "Not Found", response)

                if 400 <= status < 500:
                    # Client errors are not retried (a final 429 lands here too)
                    logger.warning(f"Client error {status} for {url}")
                    return self._failure(url, f"Client error: {status}", response)

                if 500 <= status < 600:
                    logger.error(f"Server error {status} for {url}")
                    last_error = f"Server error: {status}"
                    if not is_last_attempt:
                        continue
                    return self._failure(url, last_error, response)

                # 1xx, unfollowed 3xx (e.g. 304) or non-standard codes: report
                # the real status instead of retrying into an "Unknown error".
                logger.warning(f"Unexpected status {status} for {url}")
                return self._failure(url, f"Unexpected status: {status}", response)

            except requests.exceptions.Timeout:
                last_error = "Request timeout"
                logger.warning(f"Timeout fetching {url} (attempt {attempt + 1}/{attempts})")
                if is_last_attempt:
                    break
                continue

            except requests.exceptions.SSLError as e:
                # Retrying does not fix a certificate problem
                last_error = f"SSL/TLS error: {str(e)}"
                logger.error(f"SSL/TLS error for {url}: {e}")
                return self._failure(url, last_error)

            except requests.exceptions.ConnectionError as e:
                last_error = f"Connection error: {str(e)}"
                logger.warning(f"Connection error for {url} (attempt {attempt + 1}/{attempts}): {e}")
                if is_last_attempt:
                    break
                continue

            except requests.exceptions.RequestException as e:
                last_error = f"Request error: {str(e)}"
                logger.error(f"Request error for {url}: {e}")
                if is_last_attempt:
                    break
                continue

        # All retries exhausted
        logger.error(f"Failed to fetch {url} after {attempts} attempts: {last_error}")
        return self._failure(url, last_error or "Unknown error")

    def _failure(self, url: str, error: str, response=None) -> FetchResult:
        """Build a failed FetchResult, using the response's details when one exists."""
        # `response is not None` on purpose: a requests Response with a 4xx/5xx
        # status evaluates to False in a boolean context.
        has_response = response is not None
        return FetchResult(
            url=url,
            final_url=response.url if has_response else url,
            status_code=response.status_code if has_response else 0,
            content=None,
            headers=dict(response.headers) if has_response else {},
            error=error,
            success=False
        )

    def _get_decoded_content(self, response) -> str:
        """
        Get correctly decoded content from response.

        Decoding strategies, in order:
        1. Encoding declared in the HTML meta tags
        2. response.encoding, unless it is ISO-8859-1 (which requests assumes
           when the Content-Type header carries no charset)
        3. UTF-8
        4. The encodings in FALLBACK_ENCODINGS
        5. ISO-8859-1, which maps every byte and therefore never fails

        Args:
            response: requests.Response object

        Returns:
            str: Decoded content
        """
        raw = response.content

        meta_encoding = self._detect_encoding_from_meta(raw)
        if meta_encoding:
            try:
                content = raw.decode(meta_encoding)
                logger.info(f"Successfully decoded with meta tag encoding: {meta_encoding}")
                return content
            except (UnicodeDecodeError, LookupError) as e:
                logger.warning(f"Failed to decode with meta encoding {meta_encoding}: {e}")

        if response.encoding and response.encoding.lower() != 'iso-8859-1':
            try:
                return response.text
            except (UnicodeDecodeError, LookupError) as e:
                logger.warning(f"Failed to decode with detected encoding {response.encoding}: {e}")

        try:
            return raw.decode('utf-8')
        except UnicodeDecodeError:
            logger.debug("UTF-8 decoding failed, trying other encodings")

        for encoding in self.FALLBACK_ENCODINGS:
            try:
                content = raw.decode(encoding)
                logger.info(f"Successfully decoded with {encoding}")
                return content
            except (UnicodeDecodeError, LookupError):
                continue

        # Last resort. The encodings that used to follow iso-8859-1 in the
        # candidate list could never be reached, so they were removed.
        logger.warning("No candidate encoding matched, decoding as iso-8859-1")
        return raw.decode('iso-8859-1')

    def _detect_encoding_from_meta(self, content: bytes) -> Optional[str]:
        """
        Detect encoding from HTML meta tags.

        Looks for:
        - <meta charset="...">
        - <meta http-equiv="Content-Type" content="...; charset=...">

        Args:
            content: Raw response content (bytes)

        Returns:
            Optional[str]: Detected encoding in lower case, or None
        """
        try:
            # Only check first 2KB for performance
            head = content[:2048]

            # Meta tags are ASCII; dropping other bytes cannot raise, so no
            # second decoding attempt is needed.
            head_str = head.decode('ascii', errors='ignore')

            # Look for <meta charset="...">
            charset_match = re.search(
                r'<meta[^>]+charset=["\']?([a-zA-Z0-9_-]+)',
                head_str,
                re.IGNORECASE
            )
            if charset_match:
                encoding = charset_match.group(1).lower()
                logger.debug(f"Found charset in meta tag: {encoding}")
                return encoding

            # Look for <meta http-equiv="Content-Type" content="...; charset=...">
            content_type_match = re.search(
                r'<meta[^>]+http-equiv=["\']?content-type["\']?[^>]+content=["\']([^"\']+)',
                head_str,
                re.IGNORECASE
            )
            if content_type_match:
                content_value = content_type_match.group(1)
                charset_match = re.search(r'charset=([a-zA-Z0-9_-]+)', content_value, re.IGNORECASE)
                if charset_match:
                    encoding = charset_match.group(1).lower()
                    logger.debug(f"Found charset in Content-Type meta: {encoding}")
                    return encoding

        except Exception as e:
            logger.debug(f"Error detecting encoding from meta tags: {e}")

        return None
class RateLimiter:
    """Enforce delays between requests to be polite to servers."""

    def __init__(self, delay_seconds: float = 1.0):
        """
        Initialize rate limiter.

        Args:
            delay_seconds: Minimum delay between requests
        """
        self.delay_seconds = delay_seconds
        # -inf stands for "no request made yet": the first wait() then never
        # sleeps, whatever the monotonic clock currently reads.
        self.last_request_time = float("-inf")
        self.max_delay = 60.0  # Cap maximum delay at 60 seconds

    def wait(self):
        """
        Block until enough time has passed since last request.
        Respects the configured delay.
        """
        # A monotonic clock is used so that wall-clock adjustments (NTP, DST,
        # manual changes) can neither skip the delay nor stretch it.
        elapsed = time.monotonic() - self.last_request_time

        if elapsed < self.delay_seconds:
            sleep_time = self.delay_seconds - elapsed
            logger.debug(f"Rate limiting: sleeping for {sleep_time:.2f} seconds")
            time.sleep(sleep_time)

        self.last_request_time = time.monotonic()

    def set_delay(self, delay_seconds: float):
        """
        Update the delay (useful for respecting Crawl-delay from robots.txt).

        The value is clamped to the range [0, max_delay].

        Args:
            delay_seconds: New delay in seconds
        """
        self.delay_seconds = min(max(delay_seconds, 0.0), self.max_delay)
        logger.info(f"Rate limiter delay updated to {self.delay_seconds} seconds")

    def backoff(self, multiplier: float = 2.0):
        """
        Increase delay exponentially for backoff scenarios (429, 503 responses).

        Args:
            multiplier: Factor to multiply current delay by
        """
        old_delay = self.delay_seconds
        # Multiplying a zero (or negative) delay would never back off at all,
        # so start from one second in that case.
        base_delay = old_delay if old_delay > 0 else 1.0
        self.delay_seconds = min(base_delay * multiplier, self.max_delay)
        logger.warning(f"Rate limiter backing off: {old_delay:.2f}s -> {self.delay_seconds:.2f}s")
+ + Args: + url: URL to get robots.txt for + + Returns: + str: robots.txt URL + """ + parsed = urlparse(url) + robots_url = f"{parsed.scheme}://{parsed.netloc}/robots.txt" + return robots_url + + def _get_parser(self, url: str) -> RobotFileParser: + """ + Get or create a RobotFileParser for the domain. + + Args: + url: URL to get parser for + + Returns: + RobotFileParser: Parser for the domain + """ + robots_url = self._get_robots_url(url) + + # Return cached parser if available + if robots_url in self._parsers: + return self._parsers[robots_url] + + # Create new parser + parser = RobotFileParser() + parser.set_url(robots_url) + + try: + # Fetch and parse robots.txt + parser.read() + logger.info(f"Successfully fetched robots.txt from {robots_url}") + except Exception as e: + # If robots.txt cannot be fetched, assume all URLs are allowed + logger.warning(f"Could not fetch robots.txt from {robots_url}: {e}. Assuming all URLs allowed.") + # Create a permissive parser + parser = RobotFileParser() + parser.parse([]) # Empty robots.txt allows everything + + # Cache the parser + self._parsers[robots_url] = parser + return parser + + def can_fetch(self, url: str) -> bool: + """ + Check if the given URL can be fetched according to robots.txt. + + Args: + url: URL to check + + Returns: + bool: True if allowed, False if disallowed + """ + try: + parser = self._get_parser(url) + allowed = parser.can_fetch(self.user_agent, url) + + if not allowed: + logger.info(f"URL disallowed by robots.txt: {url}") + + return allowed + except Exception as e: + logger.error(f"Error checking robots.txt for {url}: {e}") + # On error, assume allowed + return True + + def get_crawl_delay(self, url: str) -> Optional[float]: + """ + Get the Crawl-delay directive from robots.txt if present. 
class URLNormalizer:
    """Normalize and validate URLs for deduplication and domain checking."""

    # Common tracking parameters to remove
    TRACKING_PARAMS = {
        'utm_source', 'utm_medium', 'utm_campaign', 'utm_term', 'utm_content',
        'fbclid', 'gclid', 'msclkid', '_ga', 'mc_cid', 'mc_eid'
    }

    # href prefixes that never point at a crawlable page
    _SKIPPED_PREFIXES = ('javascript:', 'mailto:', 'tel:')

    def __init__(self, base_domain: str):
        """
        Initialize URL normalizer with base domain.

        Args:
            base_domain: URL whose host is used for same-domain checks
        """
        parsed = urlparse(base_domain)
        self.base_domain = parsed.netloc.lower()  # example.com:8000
        self.base_scheme = parsed.scheme or 'https'  # https

    def normalize(self, url: str) -> Optional[str]:
        """
        Normalize a URL for deduplication.

        Normalization rules:
        1. Convert domain to lowercase
        2. Remove fragments (#section)
        3. Remove default ports (80 for http, 443 for https)
        4. Remove trailing slashes (except for root)
        5. Sort query parameters alphabetically
        6. Remove common tracking parameters

        Args:
            url: URL to normalize

        Returns:
            Optional[str]: Normalized URL, or None if the URL is malformed or
            its scheme is not http/https
        """
        try:
            parsed = urlparse(url)

            # Validate scheme
            if parsed.scheme not in ('http', 'https'):
                return None

            # Normalize domain to lowercase
            netloc = parsed.netloc.lower()

            # Remove default ports
            if ':' in netloc:
                host, port = netloc.rsplit(':', 1)
                if (parsed.scheme, port) in (('http', '80'), ('https', '443')):
                    netloc = host

            # Remove trailing slashes except for root; an empty path becomes "/"
            path = parsed.path
            if path != '/' and path.endswith('/'):
                path = path.rstrip('/')
            if not path:
                path = '/'

            # Drop tracking parameters and sort the rest alphabetically
            query = ''
            if parsed.query:
                params = parse_qs(parsed.query, keep_blank_values=True)
                kept = sorted(
                    (key, values) for key, values in params.items()
                    if key not in self.TRACKING_PARAMS
                )
                if kept:
                    query = urlencode(kept, doseq=True)

            # Reconstruct URL without fragment
            return urlunparse((parsed.scheme, netloc, path, parsed.params, query, ''))

        except (ValueError, TypeError, AttributeError):
            # Malformed URL or non-string input
            return None

    def is_same_domain(self, url: str) -> bool:
        """
        Check if URL belongs to the same host as base_domain.

        Ports are ignored. Hosts are compared via ``hostname`` rather than by
        splitting on ':' so that bracketed IPv6 literals are compared
        correctly.

        Args:
            url: URL to check

        Returns:
            bool: True if same host, False otherwise (including URLs without
            a host and malformed URLs)
        """
        try:
            host = urlparse(url).hostname
            base_host = urlparse(f"//{self.base_domain}").hostname
        except (ValueError, TypeError, AttributeError):
            return False

        if not host:
            return False
        return host == base_host

    def extract_links(self, html: str, base_url: str) -> List[str]:
        """
        Extract and normalize all same-domain links from HTML.

        Every href is resolved against ``base_url`` first, so absolute,
        relative and protocol-relative links all go through the same
        same-domain filter.

        Args:
            html: HTML content
            base_url: Base URL for resolving relative links

        Returns:
            List[str]: List of normalized absolute URLs (empty when the HTML
            cannot be parsed)
        """
        import logging  # local import: this module has no module-level logger

        links: List[str] = []

        try:
            soup = BeautifulSoup(html, 'lxml')
            anchors = soup.find_all('a', href=True)
        except Exception as e:
            # Best effort: an unparsable page yields no links, but say so
            logging.getLogger(__name__).warning(f"Failed to parse links from {base_url}: {e}")
            return links

        for anchor in anchors:
            href = anchor['href'].strip()

            # Skip empty hrefs and javascript:/mailto:/tel: links
            if not href or href.lower().startswith(self._SKIPPED_PREFIXES):
                continue

            try:
                absolute_url = urljoin(base_url, href)
            except ValueError:
                # Malformed href: skip this anchor, keep the others
                continue

            if not self.is_same_domain(absolute_url):
                continue

            normalized_url = self.normalize(absolute_url)
            if normalized_url:
                links.append(normalized_url)

        return links
def crawl(self) -> Iterator[CrawledDocument]:
    """
    Execute the crawl and yield documents as they are processed.

    URLs are taken from the queue in FIFO order until the queue is empty
    or ``max_pages`` documents have been produced. ``end_time`` is recorded
    even when the consumer stops iterating early or an error propagates.

    Yields:
        CrawledDocument: Structured document with extracted content
    """
    logger.info(f"Starting crawl from {self.entry_url} (max_pages: {self.max_pages})")
    self.start_time = datetime.now()

    # Add entry URL to queue
    normalized_entry = self.url_normalizer.normalize(self.entry_url)
    if normalized_entry:
        self.url_queue.append(normalized_entry)
        self.stats['urls_discovered'] += 1

    # Check robots.txt and update rate limiter if needed
    crawl_delay = self.robots_parser.get_crawl_delay(self.entry_url)
    if crawl_delay:
        self.rate_limiter.set_delay(crawl_delay)

    # Mirror of the queue contents: membership tests on a deque are O(n),
    # which made link queuing quadratic on link-heavy sites.
    queued = set(self.url_queue)

    try:
        # Main crawl loop
        while self.url_queue and self.pages_processed < self.max_pages:
            url = self.url_queue.popleft()
            queued.discard(url)

            # Skip if already visited
            if url in self.visited_urls:
                continue

            # Mark as visited
            self.visited_urls.add(url)

            # Check robots.txt permission
            if not self.robots_parser.can_fetch(url):
                logger.info(f"Skipping {url} (disallowed by robots.txt)")
                self.stats['skipped'] += 1
                continue

            # Apply rate limiting
            self.rate_limiter.wait()

            # Fetch URL
            logger.info(f"Fetching {url} ({self.pages_processed + 1}/{self.max_pages})")
            fetch_result = self.http_fetcher.fetch(url)

            # Handle fetch errors
            if not fetch_result.success:
                self._record_error(fetch_result.error or "Unknown error")
                continue

            # Check Content-Type. Header names are case-insensitive, but the
            # headers arrive here as a plain dict, so compare names lowercased.
            content_type = next(
                (
                    value for name, value in fetch_result.headers.items()
                    if name.lower() == 'content-type'
                ),
                ''
            ).lower()
            if not any(substring in content_type for substring in ['text/html', 'application/xhtml+xml']):
                logger.warning(f"Skipping {url} (Content-Type: {content_type})")
                self.stats['skipped'] += 1
                continue

            # Extract content
            try:
                extracted = self.content_extractor.extract(fetch_result.content, url)

                # Check if static content
                if not extracted.is_static:
                    logger.warning(f"Skipping {url} (JavaScript-rendered content)")
                    self.stats['skipped'] += 1
                    continue

                # Create document
                document = CrawledDocument(
                    url=url,
                    title=extracted.title,
                    content=extracted.text,
                    content_length=len(extracted.text),
                    crawl_timestamp=datetime.now(),
                    http_status=fetch_result.status_code,
                    metadata={
                        'word_count': extracted.word_count,
                        'final_url': fetch_result.final_url
                    }
                )

                # Update statistics
                self.pages_processed += 1
                self.stats['success'] += 1

                # Extract and queue links
                links = self.url_normalizer.extract_links(fetch_result.content, url)
                for link in links:
                    if link in self.visited_urls or link in queued:
                        continue
                    if not self.url_normalizer.is_same_domain(link):
                        continue
                    self.url_queue.append(link)
                    queued.add(link)
                    self.stats['urls_discovered'] += 1

                # Yield document
                yield document

            except Exception as e:
                logger.error(f"Error processing {url}: {e}")
                self._record_error(f"Processing error: {str(e)}")
                continue
    finally:
        # Runs on normal completion, on an early close of the generator and
        # on propagated errors, so the summary always has an end time.
        self.end_time = datetime.now()
        logger.info(f"Crawl completed. Processed {self.pages_processed} pages.")
def main(feishu_app_id: str,  # Feishu application ID
         feishu_app_secret: str,  # Feishu application secret
         feishu_folder_token: str,  # Feishu Folder Token
         save_dir: str,  # save file directory
         feishu_api_base_url: str = "https://open.feishu.cn/open-apis",  # Feishu API base URL
         timeout: int = 30,  # Request timeout in seconds
         max_retries: int = 3,  # Maximum number of retries
         recursive: bool = True  # Whether to sync subfolders recursively
         ):
    """
    List a Feishu folder and download every document in it to ``save_dir``.

    Listing and all downloads run in a single event loop and a single client
    session. Details of each document and the path it was saved to are
    printed. Any failure, including one while listing, is reported and the
    process exits with status 1.
    """
    # Entry types that are downloaded; folders and other types are ignored
    # documents = [f for f in files if f.type in ["doc", "docx", "sheet", "bitable", "file", "slides"]]
    document_types = ["doc", "docx", "sheet", "bitable", "file"]

    api_client = FeishuAPIClient(
        app_id=feishu_app_id,
        app_secret=feishu_app_secret,
        api_base_url=feishu_api_base_url,
        timeout=timeout,
        max_retries=max_retries
    )

    async def list_files(client):
        """Return all entries of the folder, walking subfolders when requested."""
        if recursive:
            return await client.list_all_folder_files(feishu_folder_token, recursive=True)

        all_files = []
        page_token = None
        while True:
            files_page, page_token = await client.list_folder_files(
                feishu_folder_token, page_token
            )
            all_files.extend(files_page)
            if not page_token:
                return all_files

    async def sync_folder():
        """List the folder, then download each document in turn."""
        async with api_client as client:
            files = await list_files(client)

            # Filter out folders, only sync documents
            documents = [f for f in files if f.type in document_types]

            for doc in documents:
                print(f"\n{'=' * 80}")
                print(f"token: {doc.token}")
                print(f"name: {doc.name}")
                print(f"type: {doc.type}")
                print(f"created_time: {doc.created_time}")
                print(f"modified_time: {doc.modified_time}")
                print(f"owner_id: {doc.owner_id}")
                print(f"url: {doc.url}")
                print(f"{'=' * 80}\n")

                # download document from Feishu FileInfo
                file_path = await client.download_document(document=doc, save_dir=save_dir)
                print(file_path)

    try:
        asyncio.run(sync_folder())

    except KeyboardInterrupt:
        print("\n\nfeishu integration interrupted by user.")

    except Exception as e:
        print(f"\n\nError during feishu integration: {e}")
        sys.exit(1)


if __name__ == '__main__':
    import argparse

    # Credentials and paths come from the command line instead of being
    # hard-coded, so the module runs on any machine.
    cli = argparse.ArgumentParser(description="Download all documents of a Feishu folder.")
    cli.add_argument("feishu_app_id", help="Feishu application ID")
    cli.add_argument("feishu_app_secret", help="Feishu application secret")
    cli.add_argument("feishu_folder_token", help="Token of the folder to download")
    cli.add_argument("save_dir", help="Directory the files are written to")
    cli.add_argument("--no-recursive", action="store_true", help="Do not descend into subfolders")
    args = cli.parse_args()

    main(
        args.feishu_app_id,
        args.feishu_app_secret,
        args.feishu_folder_token,
        args.save_dir,
        recursive=not args.no_recursive,
    )
datetime, timedelta +import httpx +from cachetools import TTLCache +import urllib.parse + +from app.core.rag.integrations.feishu.exceptions import ( + FeishuAuthError, + FeishuAPIError, + FeishuNotFoundError, + FeishuPermissionError, + FeishuRateLimitError, + FeishuNetworkError, +) +from app.core.rag.integrations.feishu.models import FileInfo +from app.core.rag.integrations.feishu.retry import with_retry + + +class FeishuAPIClient: + """Feishu API client for document synchronization.""" + + def __init__( + self, + app_id: str, + app_secret: str, + api_base_url: str = "https://open.feishu.cn/open-apis", + timeout: int = 30, + max_retries: int = 3 + ): + """ + Initialize Feishu API client. + + Args: + app_id: Feishu application ID + app_secret: Feishu application secret + api_base_url: Feishu API base URL + timeout: Request timeout in seconds + max_retries: Maximum number of retries + """ + self.app_id = app_id + self.app_secret = app_secret + self.api_base_url = api_base_url + self.timeout = timeout + self.max_retries = max_retries + self._http_client: Optional[httpx.AsyncClient] = None + self._token_cache: TTLCache = TTLCache(maxsize=1, ttl=7200 - 300) # 2 hours - 5 minutes + self._token_lock = asyncio.Lock() + + async def __aenter__(self): + """Async context manager entry.""" + self._http_client = httpx.AsyncClient( + base_url=self.api_base_url, + timeout=self.timeout, + headers={"Content-Type": "application/json"} + ) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + if self._http_client: + await self._http_client.aclose() + + async def get_tenant_access_token(self) -> str: + """ + Get tenant access token with caching. 
+ + Returns: + Access token string + + Raises: + FeishuAuthError: If authentication fails + """ + # Check cache first + cached_token = self._token_cache.get("access_token") + if cached_token: + return cached_token + + # Use lock to prevent concurrent token requests + async with self._token_lock: + # Double-check cache after acquiring lock + cached_token = self._token_cache.get("access_token") + if cached_token: + return cached_token + + # Request new token + try: + if not self._http_client: + raise FeishuAuthError("HTTP client not initialized") + + response = await self._http_client.post( + "/auth/v3/tenant_access_token/internal", + json={ + "app_id": self.app_id, + "app_secret": self.app_secret + } + ) + + data = response.json() + + if data.get("code") != 0: + error_msg = data.get("msg", "Unknown error") + raise FeishuAuthError( + f"Authentication failed: {error_msg}", + error_code=str(data.get("code")), + details=data + ) + + token = data.get("tenant_access_token") + if not token: + raise FeishuAuthError("No access token in response") + + # Cache the token + self._token_cache["access_token"] = token + + return token + + except httpx.HTTPError as e: + raise FeishuAuthError(f"HTTP error during authentication: {str(e)}") + except Exception as e: + if isinstance(e, FeishuAuthError): + raise + raise FeishuAuthError(f"Unexpected error during authentication: {str(e)}") + + @with_retry + async def list_folder_files( + self, + folder_token: str, + page_token: Optional[str] = None + ) -> Tuple[List[FileInfo], Optional[str]]: + """ + Get list of files in a folder with pagination support. 
@with_retry
async def list_folder_files(
    self,
    folder_token: str,
    page_token: Optional[str] = None
) -> Tuple[List[FileInfo], Optional[str]]:
    """
    Get list of files in a folder with pagination support.

    Requests one page of up to 200 entries. Entries whose timestamps cannot
    be converted are skipped rather than failing the whole page.

    Args:
        folder_token: Folder token
        page_token: Page token for pagination

    Returns:
        Tuple of (list of FileInfo, next page token). The token is None
        when the response carries no further page.

    Raises:
        FeishuAPIError: If API call fails
        FeishuNotFoundError: If folder not found
        FeishuPermissionError: If permission denied
    """
    try:
        token = await self.get_tenant_access_token()

        if not self._http_client:
            raise FeishuAPIError("HTTP client not initialized")

        # Build request parameters
        params = {"page_size": 200, "folder_token": folder_token}
        if page_token:
            params["page_token"] = page_token

        # Make API request
        response = await self._http_client.get(
            "/drive/v1/files",
            params=params,
            headers={"Authorization": f"Bearer {token}"}
        )

        data = response.json()

        # Handle errors
        error_code = data.get("code")
        if error_code != 0:
            error_msg = data.get("msg", "Unknown error")

            if error_code in (404, 230005):
                raise FeishuNotFoundError(
                    f"Folder not found: {error_msg}",
                    error_code=str(error_code),
                    details=data
                )
            if error_code in (403, 230003):
                raise FeishuPermissionError(
                    f"Permission denied: {error_msg}",
                    error_code=str(error_code),
                    details=data
                )
            raise FeishuAPIError(
                f"API error: {error_msg}",
                error_code=str(error_code),
                details=data
            )

        # Parse response; "data" / "files" may be present but null
        payload = data.get("data") or {}
        files_data = payload.get("files") or []
        next_page_token = payload.get("next_page_token", None)

        # Convert to FileInfo objects
        files = []
        for file_data in files_data:
            try:
                files.append(FileInfo(
                    token=file_data.get("token", ""),
                    name=file_data.get("name", ""),
                    type=file_data.get("type", ""),
                    created_time=datetime.fromtimestamp(int(file_data.get("created_time", 0))),
                    modified_time=datetime.fromtimestamp(int(file_data.get("modified_time", 0))),
                    owner_id=file_data.get("owner_id", ""),
                    url=file_data.get("url", "")
                ))
            except (ValueError, TypeError):
                # Skip invalid file entries
                continue

        return files, next_page_token

    except (FeishuAPIError, FeishuNotFoundError, FeishuPermissionError):
        # Already one of the documented error types
        raise
    except httpx.HTTPError as e:
        raise FeishuAPIError(f"HTTP error: {str(e)}") from e
    except Exception as e:
        raise FeishuAPIError(f"Unexpected error: {str(e)}") from e
async def list_all_folder_files(
    self,
    folder_token: str,
    recursive: bool = True
) -> List[FileInfo]:
    """
    Collect every entry of a folder, following pagination to the end.

    With ``recursive`` enabled, each entry of type "folder" found at this
    level is expanded in turn and its contents are appended after the
    entries of this level. A subfolder that cannot be listed is skipped.

    Args:
        folder_token: Folder token
        recursive: Whether to recursively get files from subfolders

    Returns:
        List of all FileInfo objects (folder entries included)

    Raises:
        FeishuAPIError: If API call fails
    """
    collected = []
    cursor = None

    # Walk the pages of this folder until no continuation token is returned
    while True:
        page, cursor = await self.list_folder_files(folder_token, cursor)
        collected.extend(page)
        if not cursor:
            break

    if not recursive:
        return collected

    # Snapshot the folders of this level before appending nested results
    nested_folders = [entry for entry in collected if entry.type == "folder"]
    for folder in nested_folders:
        try:
            nested_entries = await self.list_all_folder_files(
                folder.token,
                recursive=True
            )
            collected.extend(nested_entries)
        except Exception:
            # Continue with other folders if one fails
            continue

    return collected
@with_retry
async def download_document(
    self,
    document: FileInfo,
    save_dir: str
) -> str:
    """
    Download one document into ``save_dir``.

    Online document types ("doc", "docx", "sheet", "bitable") go through
    the export path; "file" and "slides" entries are downloaded directly.

    Args:
        document: Document FileInfo
        save_dir: save dir

    Returns:
        file_full_path

    Raises:
        FeishuAPIError: If API call fails or the type is unsupported
        FeishuNotFoundError: If document not found
        FeishuPermissionError: If permission denied
    """
    exported_types = ("doc", "docx", "sheet", "bitable")
    direct_types = ("file", "slides")

    try:
        token = await self.get_tenant_access_token()

        if not self._http_client:
            raise FeishuAPIError("HTTP client not initialized")

        # Different API endpoints for different document types
        if document.type in exported_types:
            return await self._export_file(document, token, save_dir)
        if document.type in direct_types:
            return await self._download_file(document, token, save_dir)
        raise FeishuAPIError(f"Unsupported document type: {document.type}")

    except (FeishuAPIError, FeishuNotFoundError, FeishuPermissionError):
        # Documented error types pass through unchanged
        raise
    except Exception as e:
        raise FeishuAPIError(f"Unexpected error: {str(e)}")
response") + + # 2.轮序查询导出任务结果 + max_retries = 10 # 最大轮询次数 + poll_interval = 2 # 每次轮询间隔时间(秒) + file_token = None + for attempt in range(max_retries): + # 查询导出任务 + response = await self._http_client.get( + f"/drive/v1/export_tasks/{ticket}", + params={"token": document.token}, + headers={"Authorization": f"Bearer {access_token}"} + ) + data = response.json() + print(f"2. 尝试查询导出任务结果 (第{attempt + 1}次): {data}") + + if data.get("code") != 0: + error_code = data.get("code") + error_msg = data.get("msg", "Unknown error") + raise FeishuAPIError( + f"API error: {error_msg}", + error_code=str(error_code), + details=data, + ) + + # 检查导出任务结果 + file_token = data.get("data", {}).get("result", {}).get("file_token", None) + if file_token: + # 如果导出任务成功生成 file_token,则退出轮询 + break + + # 如果结果还没准备好,等待一段时间再进行下一次轮询 + await asyncio.sleep(poll_interval) + + if not file_token: + raise FeishuAPIError("Export task did not complete within the allowed time") + + # 3.下载导出任务 + response = await self._http_client.get( + f"/drive/v1/export_tasks/file/{file_token}/download", + headers={"Authorization": f"Bearer {access_token}"} + ) + response.raise_for_status() + print(f'3.下载导出任务: {response.headers.get("Content-Disposition")}') + + file_full_path = os.path.join(save_dir, document.name + "." 
+ file_extension) + if os.path.exists(file_full_path): + os.remove(file_full_path) # Delete a single file + with open(file_full_path, "wb") as file: + file.write(response.content) + + return file_full_path + + except httpx.HTTPError as e: + raise FeishuAPIError(f"HTTP error: {str(e)}") + except Exception as e: + raise FeishuAPIError(f"Unexpected error during file download: {str(e)}") + + async def _download_file(self, document: FileInfo, access_token: str, save_dir: str) -> str: + """download file for file type.""" + try: + response = await self._http_client.get( + f"/drive/v1/files/{document.token}/download", + headers={"Authorization": f"Bearer {access_token}"} + ) + response.raise_for_status() + + filename_header = response.headers.get("Content-Disposition") + + # 最终的文件名(初始化为 None) + filename = None + if filename_header: + # 优先解析 filename* 格式 + match = re.search(r"filename\*=([^']*)''([^;]+)", filename_header) + if match: + # 使用 `filename*` 提取(已编码) + encoding = match.group(1) # 编码部分(如 UTF-8) + encoded_filename = match.group(2) # 文件名部分 + filename = urllib.parse.unquote(encoded_filename) # 解码 URL 编码的文件名 + + # 如果 `filename*` 不存在,回退到解析 `filename` + if not filename: + match = re.search(r'filename="([^"]+)"', filename_header) + if match: + filename = match.group(1) + # 如果文件名仍为 None,则使用默认文件名 + if not filename: + filename = f"{document.name}.pdf" + # 确保文件名合法,替换非法字符 + filename = re.sub(r'[\/:*?"<>|]', '_', filename) + + file_full_path = os.path.join(save_dir, filename) + if os.path.exists(file_full_path): + os.remove(file_full_path) # Delete a single file + with open(file_full_path, "wb") as file: + file.write(response.content) + + return file_full_path + + except httpx.HTTPError as e: + raise FeishuAPIError(f"HTTP error: {str(e)}") + except Exception as e: + raise FeishuAPIError(f"Unexpected error during file download: {str(e)}") diff --git a/api/app/core/rag/integrations/feishu/exceptions.py b/api/app/core/rag/integrations/feishu/exceptions.py new file mode 100644 
index 00000000..26e42a07 --- /dev/null +++ b/api/app/core/rag/integrations/feishu/exceptions.py @@ -0,0 +1,46 @@ +"""Exception classes for Feishu integration.""" + + +class FeishuError(Exception): + """Base exception for all Feishu-related errors.""" + + def __init__(self, message: str, error_code: str = None, details: dict = None): + super().__init__(message) + self.message = message + self.error_code = error_code + self.details = details or {} + + +class FeishuAuthError(FeishuError): + """Authentication error with Feishu API.""" + pass + + +class FeishuAPIError(FeishuError): + """General API error from Feishu.""" + pass + + +class FeishuNotFoundError(FeishuError): + """Resource not found error (404).""" + pass + + +class FeishuPermissionError(FeishuError): + """Permission denied error (403).""" + pass + + +class FeishuRateLimitError(FeishuError): + """Rate limit exceeded error (429).""" + pass + + +class FeishuNetworkError(FeishuError): + """Network-related error (timeout, connection failure).""" + pass + + +class FeishuDataError(FeishuError): + """Data parsing or validation error.""" + pass diff --git a/api/app/core/rag/integrations/feishu/models.py b/api/app/core/rag/integrations/feishu/models.py new file mode 100644 index 00000000..b194afc1 --- /dev/null +++ b/api/app/core/rag/integrations/feishu/models.py @@ -0,0 +1,17 @@ +"""Data models for Feishu integration.""" + +from dataclasses import dataclass +from datetime import datetime +from typing import Dict, Any, List, Optional + + +@dataclass +class FileInfo: + """File information from Feishu.""" + token: str + name: str + type: str # doc/docx/sheet/bitable/file/slides/folder + created_time: datetime + modified_time: datetime + owner_id: str + url: str diff --git a/api/app/core/rag/integrations/feishu/retry.py b/api/app/core/rag/integrations/feishu/retry.py new file mode 100644 index 00000000..c1d9aff1 --- /dev/null +++ b/api/app/core/rag/integrations/feishu/retry.py @@ -0,0 +1,137 @@ +"""Retry strategy for 
class RetryStrategy:
    """Retry policy for Feishu API calls: error classification plus backoff."""

    # Retryable error types
    RETRYABLE_ERRORS = (
        FeishuNetworkError,
        FeishuRateLimitError,
        httpx.TimeoutException,
        httpx.ConnectError,
        httpx.ReadError,
    )

    # Non-retryable error types
    NON_RETRYABLE_ERRORS = (
        FeishuAuthError,
        FeishuPermissionError,
        FeishuNotFoundError,
        FeishuDataError,
    )

    # Retry configuration
    MAX_RETRIES = 3
    BACKOFF_DELAYS = [1, 2, 4]  # seconds

    # FeishuAPIError codes treated as retryable (the original code labelled
    # these as rate-limit error codes).
    _RETRYABLE_API_ERROR_CODES = frozenset({"99991400", "99991401"})

    @classmethod
    def is_retryable(cls, error: Exception) -> bool:
        """
        Check if an error is retryable.

        Args:
            error: Exception raised by the wrapped call

        Returns:
            True for errors in RETRYABLE_ERRORS, HTTP 429/5xx status errors
            and FeishuAPIError carrying a retryable error code; False for
            everything else.
        """
        if isinstance(error, cls.RETRYABLE_ERRORS):
            return True

        if isinstance(error, cls.NON_RETRYABLE_ERRORS):
            return False

        if isinstance(error, httpx.HTTPStatusError):
            status_code = error.response.status_code
            # 429 is the only retryable 4xx; every 5xx is retried.
            return status_code == 429 or 500 <= status_code < 600

        if isinstance(error, FeishuAPIError) and error.error_code:
            # Compare as text so an int code matches as well as a string one.
            return str(error.error_code) in cls._RETRYABLE_API_ERROR_CODES

        return False

    @classmethod
    def _backoff_delay(cls, attempt: int) -> float:
        """Return the wait before retry number ``attempt + 1`` (0 if no delays are configured)."""
        delays = cls.BACKOFF_DELAYS
        if not delays:
            return 0
        return delays[min(attempt, len(delays) - 1)]

    @classmethod
    async def execute_with_retry(
        cls,
        func: Callable[..., T],
        *args,
        **kwargs
    ) -> T:
        """
        Execute a function with retry logic.

        Args:
            func: Async function to execute
            *args: Positional arguments for the function
            **kwargs: Keyword arguments for the function

        Returns:
            Function result

        Raises:
            Exception: The exception of the failing call, re-raised unchanged
                when it is not retryable or the attempts are exhausted
        """
        # Always make at least one call, even if MAX_RETRIES is misconfigured.
        total_attempts = max(cls.MAX_RETRIES, 0) + 1

        for attempt in range(total_attempts):
            try:
                return await func(*args, **kwargs)
            except Exception as exc:
                is_last_attempt = attempt + 1 >= total_attempts
                if is_last_attempt or not cls.is_retryable(exc):
                    raise

            # Wait before retrying
            await asyncio.sleep(cls._backoff_delay(attempt))
def main(yuque_user_id: str,  # yuque User ID
         yuque_token: str,  # yuque Token
         save_dir: str,  # save file directory
         ):
    """
    Manual smoke test for the Yuque integration.

    Lists every repository of the user, prints the metadata of each
    document and downloads it into ``save_dir``. Everything runs in a single
    event loop with a single client session.

    Args:
        yuque_user_id: Yuque user ID or login name
        yuque_token: Yuque personal access token
        save_dir: Directory the downloaded documents are written to
    """
    api_client = YuqueAPIClient(
        user_id=yuque_user_id,
        token=yuque_token
    )

    # Attributes printed for every document, in display order.
    printed_fields = (
        "id", "type", "slug", "title", "book_id", "public", "status",
        "created_at", "updated_at", "published_at", "word_count",
        "cover", "description",
    )

    async def sync_all_documents():
        """Fetch all documents from all repositories and download each one."""
        async with api_client as client:
            print("\n=== Fetching repositories ===")
            repos = await client.get_user_repos()
            print(f"Found {len(repos)} repositories:")

            all_docs = []
            for repo in repos:
                print(f"\n=== Fetching documents from '{repo.name}' ===")
                all_docs.extend(await client.get_repo_docs(repo.id))

            for doc in all_docs:
                print(f"\n{'=' * 80}")
                for field_name in printed_fields:
                    print(f"{field_name}: {getattr(doc, field_name)}")
                print(f"{'=' * 80}\n")

                file_path = await client.download_document(doc, save_dir)
                print(file_path)

    try:
        asyncio.run(sync_all_documents())

    except KeyboardInterrupt:
        print("\n\nyuque integration interrupted by user.")

    except Exception as e:
        print(f"\n\nError during yuque integration: {e}")
        sys.exit(1)
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """
    Async context manager exit: close the HTTP client and forget it.

    The reference is cleared before closing so the
    "HTTP client not initialized" guards in the request methods fire if the
    client is used after the ``async with`` block, instead of the call
    reaching an already closed client.
    """
    http_client, self._http_client = self._http_client, None
    if http_client is not None:
        await http_client.aclose()
@with_retry
async def get_user_repos(self) -> "List[YuqueRepoInfo]":
    """
    Get all repositories (knowledge bases) for the user.

    Returns:
        List of YuqueRepoInfo objects. Entries that cannot be parsed are
        logged and skipped.

    Raises:
        YuqueAuthError, YuqueNotFoundError, YuquePermissionError,
        YuqueRateLimitError: Propagated unchanged from ``_handle_api_error``
            so the retry decorator and callers can tell them apart.
        YuqueAPIError: For any other API, HTTP or unexpected failure.
    """
    import logging

    logger = logging.getLogger(__name__)

    def _parse_timestamp(value):
        """Parse an ISO-8601 timestamp, accepting a trailing 'Z' for UTC."""
        return datetime.fromisoformat((value or "").replace("Z", "+00:00"))

    try:
        if not self._http_client:
            raise YuqueAPIError("HTTP client not initialized")

        response = await self._http_client.get(f"/users/{self.user_id}/repos")

        if response.status_code != 200:
            self._handle_api_error(response)

        repos = []
        for repo_data in response.json().get("data") or []:
            try:
                repos.append(
                    YuqueRepoInfo(
                        id=repo_data.get("id"),
                        type=repo_data.get("type", ""),
                        name=repo_data.get("name", ""),
                        namespace=repo_data.get("namespace", ""),
                        slug=repo_data.get("slug", ""),
                        description=repo_data.get("description"),
                        public=repo_data.get("public", 0),
                        items_count=repo_data.get("items_count", 0),
                        created_at=_parse_timestamp(repo_data.get("created_at")),
                        updated_at=_parse_timestamp(repo_data.get("updated_at")),
                    )
                )
            except (ValueError, TypeError, KeyError, AttributeError) as e:
                # Skip invalid repo entries, but leave a trace.
                logger.warning("Skipping invalid Yuque repo entry: %s", e)

        return repos

    except (
        YuqueAPIError,
        YuqueAuthError,
        YuqueNotFoundError,
        YuquePermissionError,
        YuqueRateLimitError,
        YuqueNetworkError,
    ):
        # These classes are siblings, not subclasses of YuqueAPIError.
        # Re-wrapping them would hide rate limits from the retry decorator.
        raise
    except httpx.HTTPError as e:
        # NOTE(review): timeouts and connection errors are wrapped here too,
        # so the retry decorator does not see them as retryable — confirm
        # whether they should be raised as YuqueNetworkError instead.
        raise YuqueAPIError(f"HTTP error: {str(e)}") from e
    except Exception as e:
        raise YuqueAPIError(f"Unexpected error: {str(e)}") from e
+ + Args: + book_id: repository id + + Returns: + List of YuqueDocInfo objects (without body content) + + Raises: + YuqueAPIError: If API call fails + """ + try: + if not self._http_client: + raise YuqueAPIError("HTTP client not initialized") + + response = await self._http_client.get(f"/repos/{book_id}/docs") + + if response.status_code != 200: + self._handle_api_error(response) + + data = response.json() + docs_data = data.get("data", []) + + docs = [] + for doc_data in docs_data: + try: + published_at = doc_data.get("published_at") + doc = YuqueDocInfo( + id=doc_data.get("id"), + type=doc_data.get("type", ""), + slug=doc_data.get("slug", ""), + title=doc_data.get("title", ""), + book_id=doc_data.get("book_id"), + format=doc_data.get("format", "markdown"), + body=None, # Body not included in list API + body_draft=None, + body_html=None, + public=doc_data.get("public", 0), + status=doc_data.get("status", 0), + created_at=datetime.fromisoformat(doc_data.get("created_at", "").replace("Z", "+00:00")), + updated_at=datetime.fromisoformat(doc_data.get("updated_at", "").replace("Z", "+00:00")), + published_at=datetime.fromisoformat(published_at.replace("Z", "+00:00")) if published_at else None, + word_count=doc_data.get("word_count", 0), + cover=doc_data.get("cover"), + description=doc_data.get("description") + ) + docs.append(doc) + except (ValueError, TypeError, KeyError) as e: + # Skip invalid doc entries + continue + + return docs + + except httpx.HTTPError as e: + raise YuqueAPIError(f"HTTP error: {str(e)}") + except Exception as e: + if isinstance(e, (YuqueAPIError, YuqueNotFoundError)): + raise + raise YuqueAPIError(f"Unexpected error: {str(e)}") + + @with_retry + async def get_doc_detail(self, id: int) -> YuqueDocInfo: + """ + Get detailed document information including content. 
async def download_document(
    self,
    doc: "YuqueDocInfo",
    save_dir: str
) -> str:
    """
    Download document content to local file.

    The body is fetched with ``get_doc_detail`` when ``doc`` has none.
    ``markdown`` and ``lake`` documents are saved as ``.md``, ``html`` as
    ``.html``, ``lakesheet`` is converted to ``.xlsx`` and any other format
    is saved as ``.txt``.

    Args:
        doc: Document info (can be without body)
        save_dir: Directory to save the file

    Returns:
        Full path to the saved file

    Raises:
        YuqueAuthError, YuqueNotFoundError, YuquePermissionError,
        YuqueRateLimitError: Propagated unchanged from ``get_doc_detail``.
        YuqueAPIError: If the download or conversion fails for any other
            reason, including a spreadsheet that could not be generated.
    """
    import logging

    logger = logging.getLogger(__name__)

    # Text formats and the extension they are saved with; others fall back to txt.
    text_extensions = {"markdown": "md", "lake": "md", "html": "html"}

    try:
        # Get full document content if not already loaded
        if not doc.body:
            doc = await self.get_doc_detail(doc.id)

        # Sanitize filename; fall back to the id so an empty title does not
        # produce a bare ".md"-style file name.
        filename = re.sub(r'[\/:*?"<>|]', '_', doc.title or "") or f"untitled_{doc.id}"

        if doc.format == "lakesheet":
            file_full_path = os.path.join(save_dir, f"{filename}.xlsx")

            body_data = json.loads(doc.body)
            sheet_data = body_data.get("sheet", "")
            try:
                sheet_raw = zlib.decompress(bytes(sheet_data, 'latin-1'))
            except Exception as e:
                logger.warning("Error decompressing sheet data: %s", e)
                raise ValueError("Invalid or unsupported sheet data format.") from e
            try:
                sheet_text = sheet_raw.decode("utf-8")  # assume UTF-8 first
            except UnicodeDecodeError:
                sheet_text = sheet_raw.decode("gbk")  # fall back to GBK

            # Drop any stale file so the existence check below reflects this run.
            if os.path.exists(file_full_path):
                os.remove(file_full_path)
            self.generate_excel_from_sheet(sheet_text, file_full_path)
            # The generator reports failures without raising; verify its output.
            if not os.path.exists(file_full_path):
                raise YuqueAPIError(
                    f"Failed to generate Excel file for document {doc.id}"
                )
            return file_full_path

        file_extension = text_extensions.get(doc.format, "txt")
        file_full_path = os.path.join(save_dir, f"{filename}.{file_extension}")
        # Remove existing file if it exists
        if os.path.exists(file_full_path):
            os.remove(file_full_path)

        # Write content to file
        with open(file_full_path, "w", encoding="utf-8") as file:
            file.write(doc.body or "")

        return file_full_path

    except (
        YuqueAPIError,
        YuqueAuthError,
        YuqueNotFoundError,
        YuquePermissionError,
        YuqueRateLimitError,
        YuqueNetworkError,
    ):
        # Sibling exception classes: keep their type instead of re-wrapping.
        raise
    except Exception as e:
        raise YuqueAPIError(f"Unexpected error during file download: {str(e)}") from e
+ """ + try: + # 解析 JSON 数据 + sheets = json.loads(sheet_text) + + if not isinstance(sheets, list): + raise ValueError("sheet_text must be a JSON array of sheets.") + + # 创建一个新的 Excel 工作簿 + workbook = Workbook() + + for sheet_index, sheet_data in enumerate(sheets): + sheet_name = sheet_data.get("name", f"Sheet{sheet_index + 1}") + row_data = sheet_data.get("data", {}) + merge_cells = sheet_data.get("mergeCells", {}) + rows_styles = sheet_data.get("rows", []) + cols_styles = sheet_data.get("columns", []) + + # 创建 Sheet + if sheet_index == 0: + worksheet = workbook.active + worksheet.title = sheet_name + else: + worksheet = workbook.create_sheet(title=sheet_name) + + # 设置列宽 + for col_index, col_style in enumerate(cols_styles): + col_width = col_style.get("size", 82.125) / 7.0 + col_letter = get_column_letter(col_index + 1) # Excel 列从1开始 + worksheet.column_dimensions[col_letter].width = col_width + + # 设置行高 + for row_index, row_style in enumerate(rows_styles): + row_height = row_style.get("size", 24) / 1.5 + worksheet.row_dimensions[row_index + 1].height = row_height + + # 写入单元格数据 + for r_index, row in row_data.items(): + for c_index, cell in row.items(): + # 防御性检查:确保行号和列号都是有效的整数 + try: + row_number = int(r_index) + 1 + col_number = int(c_index) + 1 + except ValueError: + print(f"Invalid row or column index: r_index={r_index}, c_index={c_index}") + continue + + if col_number < 1 or col_number > 16384: # Excel 最大列数支持到 XFD,即 16384 列 + print(f"Invalid column index: c_index={c_index}") + continue + + cell_obj = worksheet.cell(row=row_number, column=col_number) + + # 处理值和公式 + cell_value = cell.get("value", "") + if isinstance(cell_value, dict): + # 检查是否为公式 + if cell_value.get("class") == "formula" and "formula" in cell_value: + cell_obj.value = f"={cell_value['formula']}" # 写入公式 + else: + cell_obj.value = cell_value.get("value", "") # 写入值 + else: + cell_obj.value = cell_value # 写入简单值 + + # 应用样式 + style = cell.get("style", {}) + self.apply_cell_style(cell_obj, style) + + # 
def convert_color_to_hex(self, color):
    """
    Convert a CSS-like color string to an aRGB/RGB hexadecimal string.

    Supported inputs are ``#RGB``, ``#RRGGBB``, ``#AARRGGBB``, ``rgb(r,g,b)``
    and ``rgba(r,g,b,a)``. Channel values are clamped to 0-255 and alpha to
    0-1 so the result always consists of valid hex digits.

    Args:
        color (str): Original color string, e.g. ``rgba(255,255,0,1.00)`` or
            ``#FFFFFF``.

    Returns:
        str | None: Upper-case hex digits without the leading ``#`` (hex
        input keeps its own length; ``rgb``/``rgba`` input yields 8 digits
        with alpha first), or None when the value is empty or cannot be
        parsed.
    """
    import logging

    logger = logging.getLogger(__name__)

    if not color:
        return None
    if not isinstance(color, str):
        logger.warning("Unsupported color value: %r", color)
        return None

    text = color.strip()

    # `#RGB`, `#RRGGBB` or `#AARRGGBB`
    if text.startswith("#"):
        digits = text.lstrip("#").upper()
        if len(digits) == 3:
            # Expand the shorthand form: each digit is doubled.
            digits = "".join(ch * 2 for ch in digits)
        if len(digits) in (6, 8) and re.fullmatch(r"[0-9A-F]+", digits):
            return digits
        logger.warning("Error parsing color '%s': not a valid hex color", color)
        return None

    # `rgb(r,g,b)` or `rgba(r,g,b,a)`
    matched = re.fullmatch(r"(rgba?)\(([^()]*)\)", text, flags=re.IGNORECASE)
    if matched is None:
        return None

    has_alpha = matched.group(1).lower() == "rgba"
    parts = [part.strip() for part in matched.group(2).split(",")]
    if len(parts) != (4 if has_alpha else 3):
        logger.warning("Error parsing color '%s': wrong number of components", color)
        return None

    try:
        red, green, blue = (min(max(int(part), 0), 255) for part in parts[:3])
        alpha = min(max(float(parts[3]), 0.0), 1.0) if has_alpha else 1.0
        alpha_hex = int(alpha * 255)  # map opacity onto [00, FF]
    except ValueError as e:
        logger.warning("Error parsing color '%s': %s", color, e)
        return None

    return f"{alpha_hex:02X}{red:02X}{green:02X}{blue:02X}"
(429).""" + pass + + +class YuqueNetworkError(YuqueError): + """Network-related error (timeout, connection failure).""" + pass + + +class YuqueDataError(YuqueError): + """Data parsing or validation error.""" + pass diff --git a/api/app/core/rag/integrations/yuque/models.py b/api/app/core/rag/integrations/yuque/models.py new file mode 100644 index 00000000..6230aa69 --- /dev/null +++ b/api/app/core/rag/integrations/yuque/models.py @@ -0,0 +1,42 @@ +"""Data models for Yuque integration.""" + +from dataclasses import dataclass +from datetime import datetime +from typing import Optional + + +@dataclass +class YuqueRepoInfo: + """Repository (知识库) information from Yuque.""" + id: int # 知识库 ID + type: str # 类型 (Book:文档, Design:图集, Sheet:表格, Resource:资源) + name: str # 名称 + namespace: str # 完整路径: user/repo format + slug: str # 路径 + description: Optional[str] # 简介 + public: int # 公开性 (0:私密, 1:公开, 2:企业内公开) + items_count: int # 文档数量 + created_at: datetime # 创建时间 + updated_at: datetime # 更新时间 + + +@dataclass +class YuqueDocInfo: + """Document information from Yuque.""" + id: int # 文档 ID + type: str # 文档类型 (Doc:普通文档, Sheet:表格, Thread:话题, Board:图集, Table:数据表) + slug: str # 路径 + title: str # 标题 + book_id: int # 归属知识库 ID + format: str # 内容格式 (markdown:Markdown 格式, lake:语雀 Lake 格式, html:HTML 标准格式, lakesheet:语雀表格) + body: Optional[str] # 正文原始内容 + body_draft: Optional[str] # 正文草稿内容 + body_html: Optional[str] # 正文 HTML 标准格式内容 + public: int # 公开性 (0:私密, 1:公开, 2:企业内公开) + status: int # 状态 (0:草稿, 1:发布) + created_at: datetime # 创建时间 + updated_at: datetime # 更新时间 + published_at: Optional[datetime] # 发布时间 + word_count: int # 内容字数 + cover: Optional[str] # 封面 + description: Optional[str] # 摘要 diff --git a/api/app/core/rag/integrations/yuque/retry.py b/api/app/core/rag/integrations/yuque/retry.py new file mode 100644 index 00000000..a68d6b47 --- /dev/null +++ b/api/app/core/rag/integrations/yuque/retry.py @@ -0,0 +1,134 @@ +"""Retry strategy for Yuque API calls.""" + +import asyncio +import 
class RetryStrategy:
    """Retry policy for Yuque API calls: error classification plus backoff."""

    # Retryable error types
    RETRYABLE_ERRORS = (
        YuqueNetworkError,
        YuqueRateLimitError,
        httpx.TimeoutException,
        httpx.ConnectError,
        httpx.ReadError,
    )

    # Non-retryable error types
    NON_RETRYABLE_ERRORS = (
        YuqueAuthError,
        YuquePermissionError,
        YuqueNotFoundError,
        YuqueDataError,
    )

    # Retry configuration
    MAX_RETRIES = 3
    BACKOFF_DELAYS = [1, 2, 4]  # seconds

    @classmethod
    def is_retryable(cls, error: Exception) -> bool:
        """
        Check if an error is retryable.

        Args:
            error: Exception raised by the wrapped call

        Returns:
            True for errors in RETRYABLE_ERRORS (which already includes
            YuqueRateLimitError) and for HTTP 429/5xx status errors; False
            for everything else.
        """
        if isinstance(error, cls.RETRYABLE_ERRORS):
            return True

        if isinstance(error, cls.NON_RETRYABLE_ERRORS):
            return False

        if isinstance(error, httpx.HTTPStatusError):
            status_code = error.response.status_code
            # 429 is the only retryable 4xx; every 5xx is retried.
            return status_code == 429 or 500 <= status_code < 600

        return False

    @classmethod
    def _backoff_delay(cls, attempt: int) -> float:
        """Return the wait before retry number ``attempt + 1`` (0 if no delays are configured)."""
        delays = cls.BACKOFF_DELAYS
        if not delays:
            return 0
        return delays[min(attempt, len(delays) - 1)]

    @classmethod
    async def execute_with_retry(
        cls,
        func: Callable[..., T],
        *args,
        **kwargs
    ) -> T:
        """
        Execute a function with retry logic.

        Args:
            func: Async function to execute
            *args: Positional arguments for the function
            **kwargs: Keyword arguments for the function

        Returns:
            Function result

        Raises:
            Exception: The exception of the failing call, re-raised unchanged
                when it is not retryable or the attempts are exhausted
        """
        # Always make at least one call, even if MAX_RETRIES is misconfigured.
        total_attempts = max(cls.MAX_RETRIES, 0) + 1

        for attempt in range(total_attempts):
            try:
                return await func(*args, **kwargs)
            except Exception as exc:
                is_last_attempt = attempt + 1 >= total_attempts
                if is_last_attempt or not cls.is_retryable(exc):
                    raise

            # Wait before retrying
            await asyncio.sleep(cls._backoff_delay(attempt))
+ """ + @functools.wraps(func) + async def wrapper(*args, **kwargs): + return await RetryStrategy.execute_with_retry(func, *args, **kwargs) + + return wrapper diff --git a/api/app/core/rag/nlp/search.py b/api/app/core/rag/nlp/search.py index 1f696c98..65fbd9cb 100644 --- a/api/app/core/rag/nlp/search.py +++ b/api/app/core/rag/nlp/search.py @@ -28,7 +28,9 @@ from app.core.rag.common.float_utils import get_float from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD from app.core.rag.llm.chat_model import Base from app.core.rag.llm.embedding_model import OpenAIEmbed +import logging +logger = logging.getLogger(__name__) def knowledge_retrieval( query: str, @@ -62,7 +64,15 @@ def knowledge_retrieval( merge_strategy = config.get("merge_strategy", "weight") reranker_id = config.get("reranker_id") reranker_top_k = config.get("reranker_top_k", 1024) - use_graph = config.get("use_graph", "false").lower() == "true" + # use_graph = config.get("use_graph", "false").lower() == "true" + + use_graph_value = config.get("use_graph", False) + if isinstance(use_graph_value, bool): + use_graph = use_graph_value + elif isinstance(use_graph_value, str): + use_graph = use_graph_value.lower() in ("true", "1", "yes") + else: + use_graph = False file_names_filter = [] if user_ids: @@ -159,13 +169,29 @@ def knowledge_retrieval( # Use the specified reranker for re-ranking if reranker_id: - return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k) - # use graph + try: + return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k) + except Exception as rerank_error: + # If reranker fails, log warning and continue with original results + logger.warning( + "Reranker failed, falling back to original results", + extra={ + "reranker_id": reranker_id, + "query": query, + "doc_count": len(all_results), + "error": str(rerank_error), + }, + ) + if use_graph: - from app.core.rag.common.settings import kg_retriever - 
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model) - if doc: - all_results.insert(0, doc) + try: + from app.core.rag.common.settings import kg_retriever + doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model) + if doc: + all_results.insert(0, doc) + except Exception as graph_error: + print(f"Failed to retrieve from knowledge graph: {str(graph_error)}") + return all_results except Exception as e: diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py index 475c54fe..31acaafc 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/node.py +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -25,6 +25,18 @@ class ParameterExtractorNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) self.typed_config: ParameterExtractorNodeConfig | None = None + self.response_metadata = {} + + def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: + if self.response_metadata: + usage = self.response_metadata.get('token_usage') + if usage: + return { + "prompt_tokens": usage.get('prompt_tokens', 0), + "completion_tokens": usage.get('completion_tokens', 0), + "total_tokens": usage.get('total_tokens', 0) + } + return None def _output_types(self) -> dict[str, VariableType]: outputs = {} @@ -180,6 +192,7 @@ class ParameterExtractorNode(BaseNode): ]) model_resp = await llm.ainvoke(messages) + self.response_metadata = model_resp.response_metadata result = json_repair.repair_json(model_resp.content, return_objects=True) logger.info(f"node: {self.node_id} get params:{result}") diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index d7496f12..38662b64 100644 --- 
a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -25,6 +25,18 @@ class QuestionClassifierNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: QuestionClassifierNodeConfig | None = None self.category_to_case_map = {} + self.response_metadata = {} + + def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: + if self.response_metadata: + usage = self.response_metadata.get('token_usage') + if usage: + return { + "prompt_tokens": usage.get('prompt_tokens', 0), + "completion_tokens": usage.get('completion_tokens', 0), + "total_tokens": usage.get('total_tokens', 0) + } + return None def _output_types(self) -> dict[str, VariableType]: return { @@ -120,6 +132,7 @@ class QuestionClassifierNode(BaseNode): response = await llm.ainvoke(messages) result = response.content.strip() + self.response_metadata = response.response_metadata if result in category_names: category = result diff --git a/api/app/models/file_model.py b/api/app/models/file_model.py index 842e3dc8..44a7d613 100644 --- a/api/app/models/file_model.py +++ b/api/app/models/file_model.py @@ -14,4 +14,5 @@ class File(Base): file_name = Column(String, index=True, nullable=False, comment="file name or folder name,default folder name is /") file_ext = Column(String, index=True, nullable=False, comment="file extension:folder|pdf") file_size = Column(Integer, default=0, comment="file size(byte)") + file_url = Column(String, index=True, nullable=True, comment="file comes from a website url") created_at = Column(DateTime, default=datetime.datetime.now) \ No newline at end of file diff --git a/api/app/models/knowledge_model.py b/api/app/models/knowledge_model.py index 8f0909d3..fbebe1b4 100644 --- a/api/app/models/knowledge_model.py +++ b/api/app/models/knowledge_model.py @@ -57,6 +57,17 @@ class Knowledge(Base): parser_id = Column(String, index=True, default="naive", comment="default parser ID") 
parser_config = Column(JSON, nullable=False, default={ + "entry_url": "https://ai.redbearai.com", + "max_pages": 20, + "delay_seconds": 1.0, + "timeout_seconds": 10, + "user_agent": "KnowledgeBaseCrawler/1.0", + "yuque_user_id": "User ID", + "yuque_token": "Token", + "feishu_app_id": "App ID", + "feishu_app_secret": "App Secret", + "feishu_folder_token": "Folder Token", + "sync_cron": "30 7 * * 1-5", "layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n", diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index 22972669..68e7cb04 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -86,7 +86,8 @@ class MemoryConfigRepository: n.description AS description, n.entity_type AS entity_type, n.name AS name, - COALESCE(n.fact_summary, '') AS fact_summary, + // TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + // COALESCE(n.fact_summary, '') AS fact_summary, n.end_user_id AS end_user_id, n.apply_id AS apply_id, n.user_id AS user_id, @@ -279,6 +280,9 @@ class MemoryConfigRepository: if update.config_desc is not None: db_config.config_desc = update.config_desc has_update = True + if update.scene_id is not None: + db_config.scene_id = update.scene_id + has_update = True if not has_update: raise ValueError("No fields to update") @@ -650,28 +654,32 @@ class MemoryConfigRepository: raise @staticmethod - def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[MemoryConfig]: - """获取所有配置参数 + def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[Tuple[MemoryConfig, Optional[str]]]: + """获取所有配置参数,包含关联的场景名称 Args: db: 数据库会话 workspace_id: 工作空间ID,用于过滤查询结果 Returns: - List[MemoryConfig]: 配置列表 + List[Tuple[MemoryConfig, Optional[str]]]: 配置列表,每项为 (配置对象, 场景名称) """ + from app.models.ontology_scene import OntologyScene + db_logger.debug(f"查询所有配置: workspace_id={workspace_id}") try: - query = db.query(MemoryConfig) + query = 
db.query(MemoryConfig, OntologyScene.scene_name).outerjoin( + OntologyScene, MemoryConfig.scene_id == OntologyScene.scene_id + ) if workspace_id: query = query.filter(MemoryConfig.workspace_id == workspace_id) - configs = query.order_by(desc(MemoryConfig.updated_at)).all() + results = query.order_by(desc(MemoryConfig.updated_at)).all() - db_logger.debug(f"配置列表查询成功: 数量={len(configs)}") - return configs + db_logger.debug(f"配置列表查询成功: 数量={len(results)}") + return results except Exception as e: db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}") diff --git a/api/app/repositories/neo4j/add_edges.py b/api/app/repositories/neo4j/add_edges.py index 162bf411..2b32551c 100644 --- a/api/app/repositories/neo4j/add_edges.py +++ b/api/app/repositories/neo4j/add_edges.py @@ -79,7 +79,8 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode], try: edges: List[dict] = [] for s in summaries: - for chunk_id in getattr(s, "chunk_ids", []) or []: + chunk_ids = getattr(s, "chunk_ids", []) or [] + for chunk_id in chunk_ids: edges.append({ "summary_id": s.id, "chunk_id": chunk_id, @@ -91,12 +92,11 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode], if not edges: return [] - result = await connector.execute_query( MEMORY_SUMMARY_STATEMENT_EDGE_SAVE, edges=edges ) created = [record.get("uuid") for record in result] if result else [] return created - except Exception: + except Exception as e: return None diff --git a/api/app/repositories/neo4j/add_nodes.py b/api/app/repositories/neo4j/add_nodes.py index fcf700b5..42c178b3 100644 --- a/api/app/repositories/neo4j/add_nodes.py +++ b/api/app/repositories/neo4j/add_nodes.py @@ -217,8 +217,10 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector summaries=flattened ) created_ids = [record.get("uuid") for record in result] + print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j") return created_ids - except Exception: + 
except Exception as e: + print(f"Failed to save MemorySummary nodes to Neo4j: {e}") return None diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index cf1732fd..651c513f 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -101,10 +101,11 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity e.name_embedding = CASE WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding ELSE e.name_embedding END, - e.fact_summary = CASE - WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> '' - AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary)) - THEN entity.fact_summary ELSE e.fact_summary END, + // TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + // e.fact_summary = CASE + // WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> '' + // AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary)) + // THEN entity.fact_summary ELSE e.fact_summary END, e.connect_strength = CASE WHEN entity.connect_strength IS NULL OR entity.connect_strength = '' THEN e.connect_strength ELSE CASE @@ -321,7 +322,8 @@ RETURN e.id AS id, e.description AS description, e.aliases AS aliases, e.name_embedding AS name_embedding, - COALESCE(e.fact_summary, '') AS fact_summary, + // TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + // COALESCE(e.fact_summary, '') AS fact_summary, e.connect_strength AS connect_strength, collect(DISTINCT s.id) AS statement_ids, collect(DISTINCT c.id) AS chunk_ids, @@ -1002,3 +1004,58 @@ RETURN DISTINCT x.statement as statement,x.created_at as created_at """ +Graph_Node_query = """ + MATCH (n:MemorySummary) + WHERE n.end_user_id = $end_user_id + RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 0 AS priority + LIMIT $limit + + UNION ALL + + MATCH 
async def save_dialog_and_statements_to_neo4j(
        dialogue_nodes: List[DialogueNode],
        chunk_nodes: List[ChunkNode],
        statement_nodes: List[StatementNode],
        entity_nodes: List[ExtractedEntityNode],
        entity_edges: List[EntityEntityEdge],
        statement_chunk_edges: List[StatementChunkEdge],
        statement_entity_edges: List[StatementEntityEdge],
        connector: Neo4jConnector
) -> bool:
    """Save dialogue, chunk, statement and entity nodes plus all edges to Neo4j.

    Every write is issued from one transaction function that is handed to
    ``connector.execute_write_transaction``. Nodes are written before the
    edges that reference them. Empty inputs are skipped.

    Args:
        dialogue_nodes: Dialogue nodes to save.
        chunk_nodes: Chunk nodes to save.
        statement_nodes: Statement nodes to save.
        entity_nodes: Extracted entity nodes to save.
        entity_edges: Entity-to-entity relationships to save.
        statement_chunk_edges: Statement-to-chunk edges to save.
        statement_entity_edges: Statement-to-entity edges to save.
        connector: Neo4j connector used to run the write transaction.

    Returns:
        bool: True if the transaction completed, False if any exception was
        raised (the exception is logged, not propagated).
    """

    def _iso(moment):
        """Return ``moment.isoformat()``, or None when the timestamp is unset."""
        return moment.isoformat() if moment else None

    async def _run_and_collect(tx, query, **params):
        """Run one Cypher statement on ``tx`` and collect each record's ``uuid``."""
        result = await tx.run(query, **params)
        return [record["uuid"] async for record in result]

    # Transaction function: all write operations go through the same ``tx``.
    # The original author's stated intent is to avoid deadlocks by keeping
    # every save in a single transaction.
    async def _save_all_in_transaction(tx):
        """Run every save step on ``tx``; return the saved UUIDs per category."""
        results = {}

        # 1. Save all dialogue nodes in batch
        if dialogue_nodes:
            from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE
            results['dialogues'] = await _run_and_collect(
                tx,
                DIALOGUE_NODE_SAVE,
                dialogues=[node.model_dump() for node in dialogue_nodes],
            )
            logger.info("Dialogues saved to Neo4j with UUIDs: %s", results['dialogues'])

        # 2. Save all chunk nodes in batch
        if chunk_nodes:
            from app.repositories.neo4j.cypher_queries import CHUNK_NODE_SAVE
            results['chunks'] = await _run_and_collect(
                tx,
                CHUNK_NODE_SAVE,
                chunks=[node.model_dump() for node in chunk_nodes],
            )
            logger.info("Successfully saved %d chunk nodes to Neo4j", len(results['chunks']))

        # 3. Save all statement nodes in batch
        if statement_nodes:
            from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE
            results['statements'] = await _run_and_collect(
                tx,
                STATEMENT_NODE_SAVE,
                statements=[node.model_dump() for node in statement_nodes],
            )
            logger.info("Successfully saved %d statement nodes to Neo4j", len(results['statements']))

        # 4. Save entities
        if entity_nodes:
            from app.repositories.neo4j.cypher_queries import EXTRACTED_ENTITY_NODE_SAVE
            results['entities'] = await _run_and_collect(
                tx,
                EXTRACTED_ENTITY_NODE_SAVE,
                entities=[entity.model_dump() for entity in entity_nodes],
            )
            logger.info("Successfully saved %d entity nodes to Neo4j", len(results['entities']))

        # 5. Create entity relationships
        if entity_edges:
            from app.repositories.neo4j.cypher_queries import ENTITY_RELATIONSHIP_SAVE
            relationship_data = [
                {
                    'source_id': edge.source,
                    'target_id': edge.target,
                    'predicate': edge.relation_type,
                    'statement_id': edge.source_statement_id,
                    'value': edge.relation_value,
                    'statement': edge.statement,
                    'valid_at': _iso(edge.valid_at),
                    'invalid_at': _iso(edge.invalid_at),
                    'created_at': _iso(edge.created_at),
                    'expired_at': _iso(edge.expired_at),
                    'run_id': edge.run_id,
                    'end_user_id': edge.end_user_id,
                }
                for edge in entity_edges
            ]
            results['entity_relationships'] = await _run_and_collect(
                tx, ENTITY_RELATIONSHIP_SAVE, relationships=relationship_data
            )
            logger.info(
                "Successfully saved %d entity relationships to Neo4j",
                len(results['entity_relationships']),
            )

        # 6. Save statement-chunk edges
        if statement_chunk_edges:
            from app.repositories.neo4j.cypher_queries import CHUNK_STATEMENT_EDGE_SAVE
            sc_edge_data = [
                {
                    "id": edge.id,
                    "source": edge.source,
                    "target": edge.target,
                    "created_at": _iso(edge.created_at),
                    "expired_at": _iso(edge.expired_at),
                    "run_id": edge.run_id,
                    "end_user_id": edge.end_user_id,
                }
                for edge in statement_chunk_edges
            ]
            results['statement_chunk_edges'] = await _run_and_collect(
                tx, CHUNK_STATEMENT_EDGE_SAVE, chunk_statement_edges=sc_edge_data
            )
            logger.info(
                "Successfully saved %d statement-chunk edges to Neo4j",
                len(results['statement_chunk_edges']),
            )

        # 7. Save statement-entity edges
        if statement_entity_edges:
            from app.repositories.neo4j.cypher_queries import STATEMENT_ENTITY_EDGE_SAVE
            se_edge_data = [
                {
                    "source": edge.source,
                    "target": edge.target,
                    "created_at": _iso(edge.created_at),
                    "expired_at": _iso(edge.expired_at),
                    "run_id": edge.run_id,
                    "end_user_id": edge.end_user_id,
                    # Edges without an explicit strength are stored as "strong".
                    "connect_strength": getattr(edge, "connect_strength", "strong"),
                }
                for edge in statement_entity_edges
            ]
            results['statement_entity_edges'] = await _run_and_collect(
                tx, STATEMENT_ENTITY_EDGE_SAVE, relationships=se_edge_data
            )
            logger.info(
                "Successfully saved %d statement-entity edges to Neo4j",
                len(results['statement_entity_edges']),
            )

        return results

    try:
        # Run everything through one explicit write transaction.
        results = await connector.execute_write_transaction(_save_all_in_transaction)
        summary = {
            key: len(value)
            for key, value in results.items()
            if isinstance(value, (list, tuple, set))
        }
        logger.info("Transaction completed. Summary: %s", summary)
        logger.debug("Full transaction results: %r", results)
        return True

    except Exception as e:
        # Best-effort persistence: report through the module logger only
        # (previously also duplicated on stdout) and signal failure to the caller.
        logger.error("Neo4j integration error: %s", e, exc_info=True)
        logger.warning("Continuing without database storage...")
        return False
class AgentMemory_Long_Term(ABC):
    """Long-term memory configuration constants."""
    # Storage identifiers (presumably the supported long-term memory
    # backends; verify against the code that reads them).
    STORAGE_NEO4J = "neo4j"
    STORAGE_RAG = "rag"
    # Strategy identifiers (presumably memory retrieval/organisation modes;
    # TODO confirm with callers).
    STRATEGY_AGGREGATE = "aggregate"
    STRATEGY_CHUNK = "chunk"
    STRATEGY_TIME = "time"
    # Default scope value; unit and meaning are not visible here (TODO confirm).
    DEFAULT_SCOPE = 6
class ConfigUpdate(BaseModel):
    """Request model used when updating a memory extraction engine config.

    Every field defaults to None. ``config_id`` is declared Optional so that
    its annotation matches its None default, like the sibling fields.
    """
    # Identifier of the configuration to update; accepts UUID, int or str.
    config_id: Optional[Union[uuid.UUID, int, str]] = None
    # New configuration name.
    config_name: Optional[str] = Field(None, description="配置名称(字符串)")
    # New configuration description.
    config_desc: Optional[str] = Field(None, description="配置描述(字符串)")
    # Ontology scene to associate with the configuration.
    scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID")
str(memory_content).strip() == "" or "answer" in str(memory_content) and str(memory_content).count("''") > 0: - return "未找到相关的历史记忆。请直接回答用户的问题,不要再次调用此工具。" - return f"检索到以下历史记忆:\n\n{memory_content}" except Exception as e: logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__}) diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index b7079e62..c9327ccf 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -183,11 +183,11 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # --- Read All --- def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数 - configs = MemoryConfigRepository.get_all(self.db, workspace_id) + results = MemoryConfigRepository.get_all(self.db, workspace_id) # 将 ORM 对象转换为字典列表 data_list = [] - for config in configs: + for config, scene_name in results: # 安全地转换 user_id 为 int config_id_old = None if config.config_id_old: @@ -209,7 +209,8 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) "end_user_id": config.end_user_id, "config_id_old": config_id_old, "apply_id": config.apply_id, - "scene_id": config.scene_id, + "scene_id": str(config.scene_id) if config.scene_id else None, + "scene_name": scene_name, # 新增:场景名称 "llm_id": config.llm_id, "embedding_id": config.embedding_id, "rerank_id": config.rerank_id, @@ -637,10 +638,9 @@ async def analytics_recent_activity_stats() -> Dict[str, Any]: if m < 1: latest_relative = "刚刚" elif m < 60: - latest_relative = f"{m}分钟前" + latest_relative = "一会前" else: - h = int(m // 60) - latest_relative = f"{h}小时前" if h < 24 else f"{int(h // 24)}天前" + latest_relative = "较早前" except Exception: pass diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 97ab64cb..80413c12 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -15,6 +15,7 @@ from app.core.memory.utils.llm.llm_utils import 
MemoryClientFactory from app.db import get_db_context from app.repositories.conversation_repository import ConversationRepository from app.repositories.end_user_repository import EndUserRepository +from app.repositories.neo4j.cypher_queries import Graph_Node_query from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping from app.services.implicit_memory_service import ImplicitMemoryService @@ -1525,7 +1526,6 @@ async def analytics_graph_data( user_uuid = uuid.UUID(end_user_id) repo = EndUserRepository(db) end_user = repo.get_by_id(user_uuid) - if not end_user: logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户") return { @@ -1579,21 +1579,11 @@ async def analytics_graph_data( } else: # 查询所有节点 - node_query = """ - MATCH (n) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) as id, - labels(n)[0] as label, - properties(n) as properties - LIMIT $limit - """ + node_query=Graph_Node_query node_params = { "end_user_id": end_user_id, "limit": limit } - - # 执行节点查询 node_results = await _neo4j_connector.execute_query(node_query, **node_params) @@ -1604,9 +1594,9 @@ async def analytics_graph_data( for record in node_results: node_id = record["id"] - node_label = record["label"] + node_labels = record.get("labels", []) + node_label = node_labels[0] if node_labels else "Unknown" node_props = record["properties"] - # 根据节点类型提取需要的属性字段 filtered_props = await _extract_node_properties(node_label, node_props,node_id) diff --git a/api/app/tasks.py b/api/app/tasks.py index 2111f2fd..8fb82b96 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -7,6 +7,8 @@ import uuid from uuid import UUID from datetime import datetime, timezone from math import ceil +from pathlib import Path +import shutil from typing import Any, Dict, List, Optional import redis @@ -16,8 +18,13 @@ import trio # Import a unified Celery instance from app.celery_app import celery_app from app.core.config 
import settings +from app.core.rag.crawler.web_crawler import WebCrawler from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache +from app.core.rag.integrations.feishu.client import FeishuAPIClient +from app.core.rag.integrations.feishu.models import FileInfo +from app.core.rag.integrations.yuque.client import YuqueAPIClient +from app.core.rag.integrations.yuque.models import YuqueDocInfo from app.core.rag.llm.chat_model import Base from app.core.rag.llm.cv_model import QWenCV from app.core.rag.llm.embedding_model import OpenAIEmbed @@ -29,7 +36,9 @@ from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ( ) from app.db import get_db, get_db_context from app.models.document_model import Document +from app.models.file_model import File from app.models.knowledge_model import Knowledge +from app.schemas import file_schema, document_schema from app.services.memory_agent_service import MemoryAgentService @@ -382,6 +391,480 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): db.close() +@celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb") +def sync_knowledge_for_kb(kb_id: uuid.UUID): + """ + sync knowledge document and Document parsing, vectorization, and storage + """ + db = next(get_db()) # Manually call the generator + db_knowledge = None + try: + db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first() + # 1. get vector_service + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + + # 2. 
sync data + match db_knowledge.type: + case "Web": # Crawl webpages in batches through a web crawler + entry_url = db_knowledge.parser_config.get("entry_url", "") + max_pages = db_knowledge.parser_config.get("max_pages", 20) + delay_seconds = db_knowledge.parser_config.get("delay_seconds", 1.0) + timeout_seconds = db_knowledge.parser_config.get("timeout_seconds", 10) + user_agent = db_knowledge.parser_config.get("user_agent", "KnowledgeBaseCrawler/1.0") + # Create crawler + crawler = WebCrawler( + entry_url=entry_url, + max_pages=max_pages, + delay_seconds=delay_seconds, + timeout_seconds=timeout_seconds, + user_agent=user_agent + ) + try: + # 初始化存储已爬取 URLs 的集合 + file_urls = set() + # crawl entry_url by yield + for crawled_document in crawler.crawl(): + file_urls.add(crawled_document.url) + db_file = db.query(File).filter(File.kb_id == db_knowledge.id, + File.file_url == crawled_document.url).first() + if db_file: + if db_file.file_size == crawled_document.content_length: # same + continue + else: # --update + if crawled_document.content_length: + # 1. update file + db_file.file_name = f"{crawled_document.title}.txt" + db_file.file_ext=".txt" + db_file.file_size=crawled_document.content_length + db.commit() + db.refresh(db_file) + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # update file + if os.path.exists(save_path): + os.remove(save_path) # Delete a single file + content_bytes = crawled_document.content.encode('utf-8') + with open(save_path, "wb") as f: + f.write(content_bytes) + # 2. 
update a document + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + db_document.file_name = db_file.file_name + db_document.file_ext = db_file.file_ext + db_document.file_size = db_file.file_size + db_document.updated_at = datetime.now() + db.commit() + db.refresh(db_document) + # 3. Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + else: # --add + if crawled_document.content_length: + # 1. upload file + upload_file = file_schema.FileCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + parent_id=db_knowledge.id, + file_name=f"{crawled_document.title}.txt", + file_ext=".txt", + file_size=crawled_document.content_length, + file_url=crawled_document.url, + ) + db_file = File(**upload_file.model_dump()) + db.add(db_file) + db.commit() + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # Save file + content_bytes = crawled_document.content.encode('utf-8') + with open(save_path, "wb") as f: + f.write(content_bytes) + # 2. Create a document + create_document_data = document_schema.DocumentCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + file_id=db_file.id, + file_name=db_file.file_name, + file_ext=db_file.file_ext, + file_size=db_file.file_size, + file_meta={}, + parser_id="naive", + parser_config={ + "layout_recognize": "DeepDOC", + "chunk_token_num": 128, + "delimiter": "\n", + "auto_keywords": 0, + "auto_questions": 0, + "html4excel": "false" + } + ) + db_document = Document(**create_document_data.model_dump()) + db.add(db_document) + db.commit() + # 3. 
Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + db_files = db.query(File).filter(File.kb_id == db_knowledge.id, File.file_url.notin_(file_urls)).all() + if db_files: # --delete + for db_file in db_files: + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + # 1. Delete vector index + vector_service.delete_by_metadata_field(key="document_id", value=str(db_document.id)) + # 2. Delete document + db.delete(db_document) + # 3. Delete file + file_path = Path( + settings.FILE_PATH, + str(db_file.kb_id), + str(db_file.parent_id), + f"{db_file.id}{db_file.file_ext}" + ) + if file_path.exists(): + file_path.unlink() # Delete a single file + db.delete(db_file) + # commit transaction + db.commit() + + except Exception as e: + print(f"\n\nError during crawl: {e}") + case "Third-party": # Integration of knowledge bases from three parties + yuque_user_id = db_knowledge.parser_config.get("yuque_user_id", "") + feishu_app_id = db_knowledge.parser_config.get("feishu_app_id", "") + if yuque_user_id: # Yuque Knowledge Base + yuque_token = db_knowledge.parser_config.get("yuque_token", "") + # Create yuqueAPIClient + api_client = YuqueAPIClient( + user_id=yuque_user_id, + token=yuque_token + ) + try: + # 初始化存储获取语雀 URLs 的集合 + file_urls = set() + + # Get all files from all repos + async def async_get_files(api_client: YuqueAPIClient): + async with api_client as client: + print("\n=== Fetching repositories ===") + repos = await client.get_user_repos() + print(f"Found {len(repos)} repositories:") + all_files = [] + for repo in repos: + # Get documents from repository + print(f"\n=== Fetching documents from '{repo.name}' ===") + docs = await client.get_repo_docs(repo.id) + all_files.extend(docs) + return all_files + + files = asyncio.run(async_get_files(api_client)) + for doc in files: + file_urls.add(doc.slug) + db_file = 
db.query(File).filter(File.kb_id == db_knowledge.id, + File.file_url == doc.slug).first() + if db_file: + if db_file.created_at == doc.updated_at: # same + continue + else: # --update + # 1. update file + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + + # download document from Feishu FileInfo + async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, save_dir: str): + async with api_client as client: + file_path = await client.download_document(doc, save_dir) + return file_path + + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) + + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # update file + if os.path.exists(save_path): + os.remove(save_path) # Delete a single file + shutil.copyfile(file_path, save_path) + # update db_file + file_name = os.path.basename(file_path) + _, file_extension = os.path.splitext(file_name) + file_size = os.path.getsize(file_path) + db_file.file_name = file_name + db_file.file_ext = file_extension.lower() + db_file.file_size = file_size + db_file.created_at = doc.updated_at + db.commit() + db.refresh(db_file) + # 2. update a document + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + db_document.file_name = db_file.file_name + db_document.file_ext = db_file.file_ext + db_document.file_size = db_file.file_size + db_document.created_at = db_file.created_at + db_document.updated_at = datetime.now() + db.commit() + db.refresh(db_document) + # 3. Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + else: # --add + # 1. 
update file + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + + # download document from Feishu FileInfo + async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, save_dir: str): + async with api_client as client: + file_path = await client.download_document(doc, save_dir) + return file_path + + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) + # add db_file + file_name = os.path.basename(file_path) + _, file_extension = os.path.splitext(file_name) + file_size = os.path.getsize(file_path) + upload_file = file_schema.FileCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + parent_id=db_knowledge.id, + file_name=file_name, + file_ext=file_extension.lower(), + file_size=file_size, + file_url=doc.slug, + created_at=doc.updated_at + ) + db_file = File(**upload_file.model_dump()) + db.add(db_file) + db.commit() + # Save file + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # update file + if os.path.exists(save_path): + os.remove(save_path) # Delete a single file + shutil.copyfile(file_path, save_path) + # 2. Create a document + create_document_data = document_schema.DocumentCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + file_id=db_file.id, + file_name=db_file.file_name, + file_ext=db_file.file_ext, + file_size=db_file.file_size, + file_meta={}, + parser_id="naive", + parser_config={ + "layout_recognize": "DeepDOC", + "chunk_token_num": 128, + "delimiter": "\n", + "auto_keywords": 0, + "auto_questions": 0, + "html4excel": "false" + } + ) + db_document = Document(**create_document_data.model_dump()) + db.add(db_document) + db.commit() + # 3. 
Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + db_files = db.query(File).filter(File.kb_id == db_knowledge.id, + File.file_url.notin_(file_urls)).all() + if db_files: # --delete + for db_file in db_files: + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + # 1. Delete vector index + vector_service.delete_by_metadata_field(key="document_id", + value=str(db_document.id)) + # 2. Delete document + db.delete(db_document) + # 3. Delete file + file_path = Path( + settings.FILE_PATH, + str(db_file.kb_id), + str(db_file.parent_id), + f"{db_file.id}{db_file.file_ext}" + ) + if file_path.exists(): + file_path.unlink() # Delete a single file + db.delete(db_file) + # commit transaction + db.commit() + + except Exception as e: + print(f"\n\nError during fetch feishu: {e}") + if feishu_app_id: # Feishu Knowledge Base + feishu_app_secret = db_knowledge.parser_config.get("feishu_app_secret", "") + feishu_folder_token = db_knowledge.parser_config.get("feishu_folder_token", "") + # Create feishuAPIClient + api_client = FeishuAPIClient( + app_id=feishu_app_id, + app_secret=feishu_app_secret + ) + try: + # 初始化存储获取飞书 URLs 的集合 + file_urls = set() + # Get all files from folder + async def async_get_files(api_client: FeishuAPIClient, feishu_folder_token: str): + async with api_client as client: + files = await client.list_all_folder_files(feishu_folder_token, recursive=True) + return files + files = asyncio.run(async_get_files(api_client, feishu_folder_token)) + # Filter out folders, only sync documents + documents = [f for f in files if f.type in ["doc", "docx", "sheet", "bitable", "file"]] + for doc in documents: + file_urls.add(doc.url) + db_file = db.query(File).filter(File.kb_id == db_knowledge.id, + File.file_url == doc.url).first() + if db_file: + if db_file.created_at == doc.modified_time: # same + continue + else: # --update + # 
1. update file + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), + str(db_knowledge.parent_id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + # download document from Feishu FileInfo + async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, save_dir: str): + async with api_client as client: + file_path = await client.download_document(document=doc, save_dir=save_dir) + return file_path + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) + + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # update file + if os.path.exists(save_path): + os.remove(save_path) # Delete a single file + shutil.copyfile(file_path, save_path) + # update db_file + file_name = os.path.basename(file_path) + _, file_extension = os.path.splitext(file_name) + file_size = os.path.getsize(file_path) + db_file.file_name = file_name + db_file.file_ext = file_extension.lower() + db_file.file_size = file_size + db_file.created_at = doc.modified_time + db.commit() + db.refresh(db_file) + # 2. update a document + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + db_document.file_name = db_file.file_name + db_document.file_ext = db_file.file_ext + db_document.file_size = db_file.file_size + db_document.created_at = db_file.created_at + db_document.updated_at = datetime.now() + db.commit() + db.refresh(db_document) + # 3. Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + else: # --add + # 1. 
update file + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), + str(db_knowledge.parent_id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + # download document from Feishu FileInfo + async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, save_dir: str): + async with api_client as client: + file_path = await client.download_document(document=doc, save_dir=save_dir) + return file_path + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) + # add db_file + file_name = os.path.basename(file_path) + _, file_extension = os.path.splitext(file_name) + file_size = os.path.getsize(file_path) + upload_file = file_schema.FileCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + parent_id=db_knowledge.id, + file_name=file_name, + file_ext=file_extension.lower(), + file_size=file_size, + file_url=doc.url, + created_at = doc.modified_time + ) + db_file = File(**upload_file.model_dump()) + db.add(db_file) + db.commit() + # Save file + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # update file + if os.path.exists(save_path): + os.remove(save_path) # Delete a single file + shutil.copyfile(file_path, save_path) + # 2. Create a document + create_document_data = document_schema.DocumentCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + file_id=db_file.id, + file_name=db_file.file_name, + file_ext=db_file.file_ext, + file_size=db_file.file_size, + file_meta={}, + parser_id="naive", + parser_config={ + "layout_recognize": "DeepDOC", + "chunk_token_num": 128, + "delimiter": "\n", + "auto_keywords": 0, + "auto_questions": 0, + "html4excel": "false" + } + ) + db_document = Document(**create_document_data.model_dump()) + db.add(db_document) + db.commit() + # 3. 
Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + db_files = db.query(File).filter(File.kb_id == db_knowledge.id, + File.file_url.notin_(file_urls)).all() + if db_files: # --delete + for db_file in db_files: + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + # 1. Delete vector index + vector_service.delete_by_metadata_field(key="document_id", + value=str(db_document.id)) + # 2. Delete document + db.delete(db_document) + # 3. Delete file + file_path = Path( + settings.FILE_PATH, + str(db_file.kb_id), + str(db_file.parent_id), + f"{db_file.id}{db_file.file_ext}" + ) + if file_path.exists(): + file_path.unlink() # Delete a single file + db.delete(db_file) + # commit transaction + db.commit() + + except Exception as e: + print(f"\n\nError during fetch feishu: {e}") + case _: # General + print(f"General: No synchronization needed\n") + + + result = f"sync knowledge '{db_knowledge.name}' processed successfully." + return result + except Exception as e: + if 'db_knowledge' in locals(): + print(f"Failed to sync knowledge:{str(e)}\n") + result = f"sync knowledge '{db_knowledge.name}' failed." + return result + finally: + db.close() + + @celery_app.task(name="app.core.memory.agent.read_message", bind=True) def read_message_task(self, end_user_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str, storage_type:str, user_rag_memory_id:str) -> Dict[str, Any]: diff --git a/api/app/utils/config_utils.py b/api/app/utils/config_utils.py index cc67afd2..55cfe8a3 100644 --- a/api/app/utils/config_utils.py +++ b/api/app/utils/config_utils.py @@ -5,42 +5,68 @@ Shared utilities for configuration handling to avoid circular imports. 
def _resolve_legacy_config_id(old_id: int, db: Session) -> UUID:
    """Return the UUID of the config whose ``config_id_old`` equals *old_id*.

    Raises:
        ValueError: If no configuration has that legacy id.
    """
    # Imported lazily: this module exists to avoid circular imports.
    from app.models.memory_config_model import MemoryConfig

    memory_config = db.query(MemoryConfig).filter(
        MemoryConfig.config_id_old == old_id
    ).first()
    if not memory_config:
        raise ValueError(f"未找到 config_id_old={old_id} 对应的配置")
    return memory_config.config_id


def resolve_config_id(config_id: UUID | int | str, db: Session) -> UUID:
    """Resolve a config id given as a UUID, a UUID string, or a legacy integer id.

    Args:
        config_id: A ``UUID`` (returned unchanged), a string holding either a
            UUID or a positive integer, or a positive integer. Integers are
            looked up through ``MemoryConfig.config_id_old``.
        db: Database session used for the legacy-id lookup.

    Returns:
        The resolved configuration UUID.

    Raises:
        ValueError: If the value has an unsupported type or format, is not a
            positive integer, or no configuration matches the legacy id.
    """
    # 1. Already a UUID: nothing to resolve.
    if isinstance(config_id, UUID):
        return config_id

    # 2. String: try UUID first, then a legacy integer id.
    if isinstance(config_id, str):
        text = config_id.strip()
        try:
            return UUID(text)
        except ValueError:
            pass

        try:
            old_id = int(text)
        except ValueError:
            old_id = None
        if old_id is None or old_id <= 0:
            raise ValueError(f"无效的 config_id 格式: '{config_id}'(必须是 UUID 或正整数)")
        # Kept outside the try block so "not found" is reported as such
        # instead of being swallowed and replaced by the format error.
        return _resolve_legacy_config_id(old_id, db)

    # 3. Integer: legacy id lookup.
    if isinstance(config_id, int):
        if config_id <= 0:
            raise ValueError(f"config_id 必须是正整数: {config_id}")
        return _resolve_legacy_config_id(config_id, db)

    # 4. Anything else is unsupported.
    raise ValueError(f"不支持的 config_id 类型: {type(config_id).__name__}")
### + op.drop_index(op.f('ix_files_file_url'), table_name='files') + op.drop_column('files', 'file_url') + # ### end Alembic commands ### diff --git a/api/pyproject.toml b/api/pyproject.toml index 6d23a3b9..66b1a295 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -141,6 +141,8 @@ dependencies = [ "flower>=2.0.1", "aiofiles>=23.0.0", "owlready2>=0.46", + "lxml>=4.9.0", + "httpx>=0.28.0", ] [tool.pytest.ini_options] diff --git a/api/requirements.txt b/api/requirements.txt index 6cdae2d1..144c0db2 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -134,3 +134,5 @@ xlrd==2.0.2 oss2>=2.18.0 boto3>=1.28.0 aiofiles>=23.0.0 +lxml>=4.9.0 +httpx>=0.28.0 diff --git a/api/uv.lock b/api/uv.lock index 587fc5b0..a9bde1ed 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -3224,6 +3224,7 @@ dependencies = [ { name = "hanziconv" }, { name = "html5lib" }, { name = "httptools" }, + { name = "httpx" }, { name = "huggingface-hub" }, { name = "idna" }, { name = "jieba" }, @@ -3237,6 +3238,7 @@ dependencies = [ { name = "langchain-ollama" }, { name = "langchain-openai" }, { name = "langfuse" }, + { name = "lxml" }, { name = "mako" }, { name = "mammoth" }, { name = "markdown" }, @@ -3361,6 +3363,7 @@ requires-dist = [ { name = "hanziconv", specifier = "==0.3.2" }, { name = "html5lib", specifier = "==1.1" }, { name = "httptools", specifier = "==0.7.1" }, + { name = "httpx", specifier = ">=0.28.0" }, { name = "huggingface-hub", specifier = "==0.25.2" }, { name = "idna", specifier = "==3.11" }, { name = "jieba", specifier = ">=0.42.1" }, @@ -3375,6 +3378,7 @@ requires-dist = [ { name = "langchain-ollama" }, { name = "langchain-openai", specifier = ">=1.0.2" }, { name = "langfuse", specifier = ">=3.10.0" }, + { name = "lxml", specifier = ">=4.9.0" }, { name = "mako", specifier = "==1.3.10" }, { name = "mammoth", specifier = "==1.11.0" }, { name = "markdown", specifier = "==3.8" }, diff --git a/web/package.json b/web/package.json index e28e8b56..89800fcf 100644 --- 
a/web/package.json +++ b/web/package.json @@ -13,6 +13,14 @@ "@antv/layout": "^1.2.14-beta.8", "@antv/x6": "^3.0.1", "@antv/x6-react-shape": "^3.0.1", + "@codemirror/lang-cpp": "^6.0.3", + "@codemirror/lang-java": "^6.0.2", + "@codemirror/lang-javascript": "^6.2.4", + "@codemirror/lang-python": "^6.2.1", + "@codemirror/lang-rust": "^6.0.2", + "@codemirror/state": "^6.5.4", + "@codemirror/theme-one-dark": "^6.1.3", + "@codemirror/view": "^6.39.12", "@dnd-kit/core": "^6.3.1", "@dnd-kit/modifiers": "^9.0.0", "@dnd-kit/sortable": "^10.0.0", @@ -25,6 +33,7 @@ "antd": "^5.27.4", "axios": "^1.12.2", "clsx": "^2.1.1", + "codemirror": "^6.0.2", "copy-to-clipboard": "^3.3.3", "crypto-js": "^4.2.0", "dayjs": "^1.11.18", @@ -55,6 +64,7 @@ "@tailwindcss/postcss": "^4.1.14", "@tailwindcss/typography": "^0.5.19", "@tailwindcss/vite": "^4.1.14", + "@types/codemirror": "^5.60.17", "@types/crypto-js": "^4.2.2", "@types/js-yaml": "^4.0.9", "@types/node": "^24.6.0", diff --git a/web/src/api/memory.ts b/web/src/api/memory.ts index 6f4e7f0e..987ef358 100644 --- a/web/src/api/memory.ts +++ b/web/src/api/memory.ts @@ -256,7 +256,7 @@ export const updateMemoryExtractionConfig = (values: ExtractionConfigForm) => { return request.post('/memory-storage/update_config_extracted', values) } // Memory Extraction Engine - Pilot run -export const pilotRunMemoryExtractionConfig = (values: { config_id: number | string; dialogue_text: string; }, onMessage?: (data: SSEMessage[]) => void) => { +export const pilotRunMemoryExtractionConfig = (values: { config_id: number | string; dialogue_text: string; custom_text?: string; }, onMessage?: (data: SSEMessage[]) => void) => { return handleSSE('/memory-storage/pilot_run', values, onMessage) } // Emotion Engine - Get configuration diff --git a/web/src/api/ontology.ts b/web/src/api/ontology.ts index becf899f..90a6857f 100644 --- a/web/src/api/ontology.ts +++ b/web/src/api/ontology.ts @@ -8,6 +8,7 @@ import { request } from '@/utils/request' import type { Query, 
/*
 * @Author: ZhaoYing
 * @Date: 2026-02-04 17:20:52
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-02-04 17:20:52
 */
import { useEffect, useRef, useMemo } from 'react';
import { EditorView, basicSetup } from 'codemirror';
import { EditorState, type Extension } from '@codemirror/state';
import { python } from '@codemirror/lang-python';
import { javascript } from '@codemirror/lang-javascript';
import { java } from '@codemirror/lang-java';
import { cpp } from '@codemirror/lang-cpp';
import { rust } from '@codemirror/lang-rust';
import { oneDark } from '@codemirror/theme-one-dark';

/**
 * Props for the CodeMirrorEditor component
 * @property {string} value - Code content shown in the editor; later changes are synced into it
 * @property {string} language - Programming language for syntax highlighting (python, python3, javascript, typescript, java, cpp, c, rust)
 * @property {function} onChange - Called with the new code when the user edits the document
 * @property {string} theme - Editor theme, either 'light' or 'dark'
 * @property {boolean} readOnly - Whether the editor is read-only
 * @property {string} height - Fixed CSS height of the editor; content scrolls when it overflows
 * @property {string} size - Predefined size preset: 'default' (120px min-height, 14px font) or 'small' (60px min-height, 12px font)
 */
interface CodeMirrorEditorProps {
  value?: string;
  language?: 'python' | 'python3' | 'javascript' | 'typescript' | 'java' | 'cpp' | 'c' | 'rust';
  onChange?: (value: string) => void;
  theme?: 'light' | 'dark';
  readOnly?: boolean;
  height?: string;
  size?: 'default' | 'small';
}

/**
 * Map of language identifiers to their corresponding CodeMirror language extensions
 */
const languageExtensions: Record<string, Extension> = {
  python: python(),
  python3: python(),
  javascript: javascript(),
  typescript: javascript({ typescript: true }),
  java: java(),
  cpp: cpp(),
  c: cpp(),
  rust: rust(),
};

/**
 * CodeMirrorEditor - A React wrapper component for the CodeMirror 6 editor.
 * Provides syntax highlighting, theme support and customizable sizing.
 */
const CodeMirrorEditor = ({
  value = '',
  language = 'javascript',
  onChange,
  theme = 'light',
  readOnly = false,
  height,
  size,
}: CodeMirrorEditorProps) => {
  // DOM element that hosts the editor
  const editorRef = useRef<HTMLDivElement>(null);
  // Current CodeMirror EditorView instance
  const viewRef = useRef<EditorView | null>(null);
  // Latest onChange callback; read through a ref so the editor does not have
  // to be rebuilt (and does not call a stale closure) when the parent re-renders
  const onChangeRef = useRef(onChange);
  // True while the `value` prop is being written into the editor, so that the
  // programmatic change is not echoed back to the parent as a user edit
  const syncingRef = useRef(false);

  useEffect(() => {
    onChangeRef.current = onChange;
  });

  /**
   * (Re)create the editor when language, theme or readOnly changes.
   */
  useEffect(() => {
    if (!editorRef.current) return;

    // Fall back to JavaScript for unknown languages
    const langExtension = languageExtensions[language] || languageExtensions.javascript;

    const extensions: Extension[] = [
      basicSetup, // Basic editor features (line numbers, bracket matching, etc.)
      langExtension, // Language-specific syntax highlighting
      EditorView.updateListener.of((update) => {
        if (update.docChanged && !syncingRef.current) {
          onChangeRef.current?.(update.state.doc.toString());
        }
      }),
      EditorState.readOnly.of(readOnly),
    ];

    if (theme === 'dark') {
      extensions.push(oneDark);
    }

    const view = new EditorView({
      state: EditorState.create({ doc: value, extensions }),
      parent: editorRef.current,
    });
    viewRef.current = view;

    // Destroy the instance on unmount or before it is rebuilt
    return () => {
      view.destroy();
      if (viewRef.current === view) {
        viewRef.current = null;
      }
    };
  }, [language, theme, readOnly]);

  /**
   * Write external `value` changes into the editor when they differ from its content.
   */
  useEffect(() => {
    const view = viewRef.current;
    if (!view || value === view.state.doc.toString()) return;

    syncingRef.current = true;
    try {
      view.dispatch({
        changes: {
          from: 0,
          to: view.state.doc.length,
          insert: value,
        },
      });
    } finally {
      syncingRef.current = false;
    }
  }, [value]);

  // Container style: the size preset drives min-height / font / line-height,
  // an explicit `height` fixes the height and lets the content scroll
  const style = useMemo(() => {
    const small = size === 'small';
    return {
      minHeight: `${small ? 60 : 120}px`,
      fontSize: `${small ? 12 : 14}px`,
      lineHeight: `${small ? 16 : 20}px`,
      ...(height ? { height, overflow: 'auto' } : {}),
    };
  }, [size, height]);

  return <div ref={editorRef} style={style} />;
};

export default CodeMirrorEditor;
relationships in total', - deduplication_desc: 'Deduplication and disambiguation completed, {{count}} unique entities in total' + deduplication_desc: 'Deduplication and disambiguation completed, {{count}} unique entities in total', + custom_text: 'Debug Text', }, memoryConversation: { searchPlaceholder: 'Enter user ID...', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 7fc8b652..a7ef34ac 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1617,7 +1617,8 @@ export const zh = { text_preprocessing_desc: '文本切分为{{count}}个语义片段', knowledge_extraction_desc: '知识抽取完成,共识别{{entities}}个实体,{{statements}}个句子, {{temporal_ranges_count}}个时间提取, {{triplets}}个三元组', creating_nodes_edges_desc: '实体关系创建完成,共{{num}}条关系', - deduplication_desc: '去重消歧完成,最终{{count}}个唯一实体' + deduplication_desc: '去重消歧完成,最终{{count}}个唯一实体', + custom_text: '调试文本', }, memoryConversation: { chatEmpty:'有什么我可以帮您的吗?', diff --git a/web/src/styles/index.css b/web/src/styles/index.css index bbbe9cd9..d937396a 100644 --- a/web/src/styles/index.css +++ b/web/src/styles/index.css @@ -180,4 +180,9 @@ body { .x6-node foreignObject > body { min-height: 100%; max-height: 100%; +} + +.ͼ2 .cm-gutters { + background-color: #FFFFFF; + border: none; } \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/Agent.tsx b/web/src/views/ApplicationConfig/Agent.tsx index 0bfd4ba7..6feb1548 100644 --- a/web/src/views/ApplicationConfig/Agent.tsx +++ b/web/src/views/ApplicationConfig/Agent.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:29:21 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-04 20:16:45 + * @Last Modified time: 2026-02-06 11:20:14 */ import { type FC, type ReactNode, useEffect, useRef, useState, forwardRef, useImperativeHandle } from 'react'; import clsx from 'clsx' @@ -38,8 +38,8 @@ import CustomSelect from '@/components/CustomSelect' import aiPrompt from '@/assets/images/application/aiPrompt.png' import AiPromptModal from './components/AiPromptModal' import 
ToolList from './components/ToolList/ToolList' -import ChatVariableConfigModal from './components/ChatVariableConfigModal'; import SkillList from './components/Skill' +import ChatVariableConfigModal from './components/ChatVariableConfigModal'; import type { Skill } from '@/views/Skills/types' /** @@ -169,7 +169,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => { const { skills } = response let allSkills = Array.isArray(skills?.skill_ids) ? skills?.skill_ids.map(vo => ({ id: vo })) : [] let allTools = Array.isArray(response.tools) ? response.tools : [] - const memoryContent = response.memory?.memory_content + const memoryContent = response.memory?.memory_config_id const parsedMemoryContent = memoryContent === null || memoryContent === '' ? undefined : !isNaN(Number(memoryContent)) ? Number(memoryContent) : memoryContent @@ -178,7 +178,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => { tools: allTools, memory: { ...response.memory, - memory_content: parsedMemoryContent + memory_config_id: parsedMemoryContent }, skills: { ...skills, @@ -262,7 +262,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => { if (!isSave || !data) return Promise.resolve() const { memory, knowledge_retrieval, tools, skills, ...rest } = values const { knowledge_bases = [], ...knowledgeRest } = knowledge_retrieval || {} - const { memory_content } = memory || {} + const { memory_config_id } = memory || {} // Get other necessary properties of memory from original data const originalMemory = data.memory || ({} as MemoryConfig) @@ -272,7 +272,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => { memory: { ...originalMemory, ...memory, - memory_content: memory_content ? String(memory_content) : '', + memory_config_id: memory_config_id ? String(memory_config_id) : '', }, knowledge_retrieval: knowledge_bases.length > 0 ? 
{ ...data.knowledge_retrieval, @@ -444,7 +444,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => { <SelectWrapper title="selectMemoryContent" desc="selectMemoryContentDesc" - name={['memory', 'memory_content']} + name={['memory', 'memory_config_id']} url={memoryConfigListUrl} /> </Space> diff --git a/web/src/views/ApplicationConfig/components/Knowledge/Knowledge.tsx b/web/src/views/ApplicationConfig/components/Knowledge/Knowledge.tsx index 297e9faa..7fdf1ab2 100644 --- a/web/src/views/ApplicationConfig/components/Knowledge/Knowledge.tsx +++ b/web/src/views/ApplicationConfig/components/Knowledge/Knowledge.tsx @@ -140,7 +140,7 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi title={t('application.knowledgeBaseAssociation')} extra={ <Space> - <Button style={{ padding: '0 8px', height: '24px' }} onClick={handleKnowledgeConfig}>{t('workflow.config.knowledge-retrieval.recallConfig')}</Button> + <Button style={{ padding: '0 8px', height: '24px' }} onClick={handleKnowledgeConfig}>{t('application.globalConfig')}</Button> <Button style={{ padding: '0 8px', height: '24px' }} onClick={handleAddKnowledge}>+</Button> </Space> } diff --git a/web/src/views/ApplicationConfig/components/Skill/index.tsx b/web/src/views/ApplicationConfig/components/Skill/index.tsx index 1a8dcc6d..d42edd3d 100644 --- a/web/src/views/ApplicationConfig/components/Skill/index.tsx +++ b/web/src/views/ApplicationConfig/components/Skill/index.tsx @@ -39,7 +39,7 @@ const processObj = [ * @param value - Current skill configuration values * @param onChange - Callback function when configuration changes */ -const Skill: FC<{value?: SkillConfigForm; onChange?: (config: SkillConfigForm) => void}> = () => { +const SkillList: FC<{value?: SkillConfigForm; onChange?: (config: SkillConfigForm) => void}> = () => { const { t } = useTranslation() const form = Form.useFormInstance() const skillConfig = Form.useWatch(['skills'], form) @@ -148,4 +148,4 @@ const Skill: FC<{value?: 
SkillConfigForm; onChange?: (config: SkillConfigForm) = </Card> ) } -export default Skill \ No newline at end of file +export default SkillList \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/types.ts b/web/src/views/ApplicationConfig/types.ts index fc799b91..2d09f739 100644 --- a/web/src/views/ApplicationConfig/types.ts +++ b/web/src/views/ApplicationConfig/types.ts @@ -43,7 +43,7 @@ export interface MemoryConfig { /** Whether memory is enabled */ enabled: boolean; /** Memory content */ - memory_content?: string; + memory_config_id?: string; /** Maximum history length */ max_history?: number | string; } diff --git a/web/src/views/MemoryExtractionEngine/components/Result.tsx b/web/src/views/MemoryExtractionEngine/components/Result.tsx index 6fdeb2af..cb89661a 100644 --- a/web/src/views/MemoryExtractionEngine/components/Result.tsx +++ b/web/src/views/MemoryExtractionEngine/components/Result.tsx @@ -13,7 +13,7 @@ import { type FC, useState } from 'react' import { useParams } from 'react-router-dom' import { useTranslation } from 'react-i18next' -import { Space, Button, Progress } from 'antd' +import { Space, Button, Progress, Form, Input } from 'antd' import { ExclamationCircleFilled, CheckCircleFilled, ClockCircleOutlined, LoadingOutlined } from '@ant-design/icons' import clsx from 'clsx' import type { AnyObject } from 'antd/es/_util/type'; @@ -79,6 +79,8 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => { const [creatingNodesEdges, setCreatingNodesEdges] = useState<ModuleItem>(initObj as ModuleItem) const [deduplication, setDeduplication] = useState<ModuleItem>(initObj as ModuleItem) + const [runForm] = Form.useForm() + /** Run pilot test */ const handleRun = () => { if(!id) return @@ -187,6 +189,7 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => { pilotRunMemoryExtractionConfig({ config_id: id, dialogue_text: t('memoryExtractionEngine.exampleText'), + custom_text: runForm.getFieldValue('custom_text') }, 
handleStreamMessage) .finally(() => { setRunLoading(false) @@ -222,6 +225,14 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => { headerClassName="rb:pb-0! rb:pt-4!" bodyClassName="rb:min-h-[calc(100vh-388px)] rb:p-[16px_20px]!" > + <Form form={runForm} layout="vertical"> + <Form.Item + name="custom_text" + label={t('memoryExtractionEngine.custom_text')} + > + <Input.TextArea placeholder={t('common.pleaseEnter')} /> + </Form.Item> + </Form> <div className="rb:min-h-[calc(100vh-480px)] rb:overflow-y-auto"> {runLoading ? <> diff --git a/web/src/views/MemoryManagement/components/MemoryForm.tsx b/web/src/views/MemoryManagement/components/MemoryForm.tsx index 22bff65a..93246ca9 100644 --- a/web/src/views/MemoryManagement/components/MemoryForm.tsx +++ b/web/src/views/MemoryManagement/components/MemoryForm.tsx @@ -16,7 +16,7 @@ import { useTranslation } from 'react-i18next'; import type { MemoryFormData, Memory, MemoryFormRef } from '../types'; import RbModal from '@/components/RbModal' import { createMemoryConfig, updateMemoryConfig } from '@/api/memory' -import { getOntologyScenesUrl } from '@/api/ontology' +import { getOntologyScenesSimpleUrl } from '@/api/ontology' import CustomSelect from '@/components/CustomSelect'; const FormItem = Form.Item; @@ -129,8 +129,7 @@ const MemoryForm = forwardRef<MemoryFormRef, MemoryFormProps>(({ > <CustomSelect placeholder={t('common.pleaseSelect')} - url={getOntologyScenesUrl} - params={{ pagesize: 100, page: 1 }} + url={getOntologyScenesSimpleUrl} hasAll={false} valueKey='scene_id' labelKey="scene_name" diff --git a/web/src/views/MemoryManagement/index.tsx b/web/src/views/MemoryManagement/index.tsx index dbda547f..ac2b4fa5 100644 --- a/web/src/views/MemoryManagement/index.tsx +++ b/web/src/views/MemoryManagement/index.tsx @@ -112,7 +112,7 @@ const MemoryManagement: React.FC = () => { title={item.config_name} > <Tooltip title={item.config_desc}> - <div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.25 
rb:font-regular rb:-mt-1 rb:wrap-break-word rb:line-clamp-1">{item.config_desc}</div> + <div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.25 rb:font-regular rb:-mt-1 rb:wrap-break-word rb:line-clamp-1 rb:h-[17px]">{item.config_desc}</div> </Tooltip> <RbAlert className="rb:mt-3 "> <div className={clsx("rb:flex rb:gap-5 rb:font-regular rb:text-[14px]")}> diff --git a/web/src/views/ModelManagement/components/MultiKeyConfigModal.tsx b/web/src/views/ModelManagement/components/MultiKeyConfigModal.tsx index 8a21e012..169d9690 100644 --- a/web/src/views/ModelManagement/components/MultiKeyConfigModal.tsx +++ b/web/src/views/ModelManagement/components/MultiKeyConfigModal.tsx @@ -103,9 +103,9 @@ const MultiKeyConfigModal = forwardRef<MultiKeyConfigModalRef, MultiKeyConfigMod {model.api_keys && model.api_keys.length > 0 && ( <div className="rb:mb-4"> {model.api_keys.map((key) => ( - <div key={key.id} className="rb:flex rb:items-center rb:justify-between rb:p-3 rb:bg-[#F5F6F7] rb:rounded-lg rb:mb-2"> - <div> - <div className="rb:text-[#1D2129] rb:text-[14px] rb:font-medium">{key.api_key}</div> + <div key={key.id} className="rb:flex rb:gap-3 rb:items-center rb:justify-between rb:p-3 rb:bg-[#F5F6F7] rb:rounded-lg rb:mb-2"> + <div className="rb:flex-1"> + <div className="rb:text-[#1D2129] rb:text-[14px] rb:font-medium rb:break-all">{key.api_key}</div> <div className="rb:text-[#5B6167] rb:text-[12px] rb:mt-1">{key.api_base}</div> </div> <Button type="primary" danger ghost onClick={() => handleDelete(key.id)}>{t('common.remove')}</Button> diff --git a/web/src/views/Workflow/components/Editor/index.tsx b/web/src/views/Workflow/components/Editor/index.tsx index 4c8540a8..60da03a7 100644 --- a/web/src/views/Workflow/components/Editor/index.tsx +++ b/web/src/views/Workflow/components/Editor/index.tsx @@ -15,8 +15,6 @@ import CharacterCountPlugin from './plugin/CharacterCountPlugin' import InitialValuePlugin from './plugin/InitialValuePlugin'; import CommandPlugin from 
'./plugin/CommandPlugin'; import Jinja2HighlightPlugin from './plugin/Jinja2HighlightPlugin'; -import Python3HighlightPlugin from './plugin/Python3HighlightPlugin'; -import JavaScriptHighlightPlugin from './plugin/JavaScriptHighlightPlugin'; import LineNumberPlugin from './plugin/LineNumberPlugin'; import BlurPlugin from './plugin/BlurPlugin'; import { VariableNode } from './nodes/VariableNode' @@ -32,7 +30,7 @@ export interface LexicalEditorProps { lineHeight?: number; size?: 'default' | 'small'; type?: 'input' | 'textarea', - language?: 'string' | 'jinja2' | 'python3' | 'javascript' + language?: 'string' | 'jinja2' } const theme = { @@ -67,7 +65,7 @@ const Editor: FC<LexicalEditorProps> =({ const [enableLineNumbers, setEnableLineNumbers] = useState(false) useEffect(() => { - const needsLineNumbers = language === 'jinja2' || language === 'python3' || language === 'javascript'; + const needsLineNumbers = language === 'jinja2'; setEnableJinja2(language === 'jinja2'); setEnableLineNumbers(needsLineNumbers); @@ -237,13 +235,11 @@ const Editor: FC<LexicalEditorProps> =({ <HistoryPlugin /> <CommandPlugin /> {language === 'jinja2' && <Jinja2HighlightPlugin />} - {language === 'python3' && <Python3HighlightPlugin />} - {language === 'javascript' && <JavaScriptHighlightPlugin />} {enableLineNumbers && <LineNumberPlugin />} <AutocompletePlugin options={options} enableJinja2={enableJinja2} /> <CharacterCountPlugin setCount={(count) => { setCount(count) }} onChange={onChange} /> - <InitialValuePlugin key={language} value={value} options={options} enableLineNumbers={enableLineNumbers} /> - {enableLineNumbers && <BlurPlugin />} + <InitialValuePlugin value={value} options={options} enableLineNumbers={enableLineNumbers} /> + {enableJinja2 && <BlurPlugin />} </div> </LexicalComposer> ); diff --git a/web/src/views/Workflow/components/Editor/plugin/JavaScriptHighlightPlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/JavaScriptHighlightPlugin.tsx deleted file mode 100644 
index 21219139..00000000 --- a/web/src/views/Workflow/components/Editor/plugin/JavaScriptHighlightPlugin.tsx +++ /dev/null @@ -1,182 +0,0 @@ -import { useEffect, useRef } from 'react'; -import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; -import { TextNode, $createTextNode, $getSelection, $isRangeSelection, COMMAND_PRIORITY_LOW, PASTE_COMMAND } from 'lexical'; - -const JS_KEYWORDS = new Set([ - 'async', 'await', 'break', 'case', 'catch', 'class', 'const', 'continue', 'debugger', 'default', - 'delete', 'do', 'else', 'export', 'extends', 'finally', 'for', 'function', 'if', 'import', - 'in', 'instanceof', 'let', 'new', 'return', 'super', 'switch', 'this', 'throw', 'try', - 'typeof', 'var', 'void', 'while', 'with', 'yield', 'true', 'false', 'null', 'undefined' -]); - -const JavaScriptHighlightPlugin = () => { - const [editor] = useLexicalComposerContext(); - const isPastingRef = useRef(false); - - useEffect(() => { - return editor.registerCommand( - PASTE_COMMAND, - () => { - isPastingRef.current = true; - setTimeout(() => { - isPastingRef.current = false; - }, 100); - return false; - }, - COMMAND_PRIORITY_LOW - ); - }, [editor]); - - useEffect(() => { - return editor.registerNodeTransform(TextNode, (textNode: TextNode) => { - if (isPastingRef.current) return; - - const text = textNode.getTextContent(); - - if (textNode.hasFormat('code')) return; - if (!needsHighlight(text)) return; - if (textNode.getStyle()) return; - - const parent = textNode.getParent(); - if (!parent) return; - - const selection = $getSelection(); - let selectionOffset = null; - if ($isRangeSelection(selection)) { - const anchor = selection.anchor; - if (anchor.getNode() === textNode) { - selectionOffset = anchor.offset; - } - } - - const tokens = tokenizeJavaScript(text); - if (tokens.length <= 1) return; - - const newNodes = tokens.map(token => { - const newNode = $createTextNode(token.text); - newNode.toggleFormat('code'); - - switch (token.type) { - case 
'keyword': - newNode.setStyle('color: #d73a49; font-weight: 600;'); - break; - case 'string': - newNode.setStyle('color: #032f62;'); - break; - case 'comment': - newNode.setStyle('color: #6a737d; font-style: italic;'); - break; - case 'number': - newNode.setStyle('color: #005cc5; font-weight: 500;'); - break; - case 'function': - newNode.setStyle('color: #6f42c1; font-weight: 500;'); - break; - } - - return newNode; - }); - - if (newNodes.length > 1) { - textNode.replace(newNodes[0]); - for (let i = 1; i < newNodes.length; i++) { - newNodes[i - 1].insertAfter(newNodes[i]); - } - - if (selectionOffset !== null && $isRangeSelection(selection)) { - let currentOffset = 0; - for (const node of newNodes) { - const nodeLength = node.getTextContent().length; - if (currentOffset + nodeLength >= selectionOffset) { - node.select(selectionOffset - currentOffset, selectionOffset - currentOffset); - break; - } - currentOffset += nodeLength; - } - } - } - }); - }, [editor]); - - return null; -}; - -function needsHighlight(text: string): boolean { - return /[a-zA-Z0-9_/"'`]/.test(text); -} - -function tokenizeJavaScript(text: string): Array<{text: string, type: string}> { - const tokens: Array<{text: string, type: string}> = []; - let i = 0; - - while (i < text.length) { - // Single-line comments - if (text.slice(i, i + 2) === '//') { - let start = i; - while (i < text.length && text[i] !== '\n') i++; - tokens.push({ text: text.slice(start, i), type: 'comment' }); - continue; - } - - // Multi-line comments - if (text.slice(i, i + 2) === '/*') { - let start = i; - i += 2; - while (i < text.length && text.slice(i, i + 2) !== '*/') i++; - if (i < text.length) i += 2; - tokens.push({ text: text.slice(start, i), type: 'comment' }); - continue; - } - - // Strings - if (text[i] === '"' || text[i] === "'" || text[i] === '`') { - const quote = text[i]; - let start = i++; - - while (i < text.length) { - if (text[i] === quote && text[i - 1] !== '\\') { - i++; - break; - } - i++; - } - 
tokens.push({ text: text.slice(start, i), type: 'string' }); - continue; - } - - // Numbers - if (/\d/.test(text[i])) { - let start = i; - while (i < text.length && /[\d.]/.test(text[i])) i++; - tokens.push({ text: text.slice(start, i), type: 'number' }); - continue; - } - - // Keywords and identifiers - if (/[a-zA-Z_$]/.test(text[i])) { - let start = i; - while (i < text.length && /[a-zA-Z0-9_$]/.test(text[i])) i++; - const word = text.slice(start, i); - - if (JS_KEYWORDS.has(word)) { - tokens.push({ text: word, type: 'keyword' }); - } else if (i < text.length && text[i] === '(') { - tokens.push({ text: word, type: 'function' }); - } else { - tokens.push({ text: word, type: 'text' }); - } - continue; - } - - // Other characters - let start = i; - while (i < text.length && !/[a-zA-Z0-9_$/"'`]/.test(text[i])) i++; - if (start < i) { - tokens.push({ text: text.slice(start, i), type: 'text' }); - } - } - - return tokens; -} - -export default JavaScriptHighlightPlugin; diff --git a/web/src/views/Workflow/components/Editor/plugin/Python3HighlightPlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/Python3HighlightPlugin.tsx deleted file mode 100644 index 12830ffb..00000000 --- a/web/src/views/Workflow/components/Editor/plugin/Python3HighlightPlugin.tsx +++ /dev/null @@ -1,177 +0,0 @@ -import { useEffect, useRef } from 'react'; -import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; -import { TextNode, $createTextNode, $getSelection, $isRangeSelection, COMMAND_PRIORITY_LOW, PASTE_COMMAND } from 'lexical'; - -const PYTHON_KEYWORDS = new Set([ - 'False', 'None', 'True', 'and', 'as', 'assert', 'async', 'await', 'break', 'class', 'continue', - 'def', 'del', 'elif', 'else', 'except', 'finally', 'for', 'from', 'global', 'if', 'import', - 'in', 'is', 'lambda', 'nonlocal', 'not', 'or', 'pass', 'raise', 'return', 'try', 'while', - 'with', 'yield' -]); - -const Python3HighlightPlugin = () => { - const [editor] = useLexicalComposerContext(); - 
const isPastingRef = useRef(false); - - useEffect(() => { - return editor.registerCommand( - PASTE_COMMAND, - () => { - isPastingRef.current = true; - setTimeout(() => { - isPastingRef.current = false; - }, 100); - return false; - }, - COMMAND_PRIORITY_LOW - ); - }, [editor]); - - useEffect(() => { - return editor.registerNodeTransform(TextNode, (textNode: TextNode) => { - if (isPastingRef.current) return; - - const text = textNode.getTextContent(); - - if (textNode.hasFormat('code')) return; - if (textNode.getStyle()) return; - if (!needsHighlight(text)) return; - - const parent = textNode.getParent(); - if (!parent) return; - - const selection = $getSelection(); - let selectionOffset = null; - if ($isRangeSelection(selection)) { - const anchor = selection.anchor; - if (anchor.getNode() === textNode) { - selectionOffset = anchor.offset; - } - } - - const tokens = tokenizePython(text); - if (tokens.length <= 1) return; - - const newNodes = tokens.map(token => { - const newNode = $createTextNode(token.text); - newNode.toggleFormat('code'); - - switch (token.type) { - case 'keyword': - newNode.setStyle('color: #d73a49; font-weight: 600;'); - break; - case 'string': - newNode.setStyle('color: #032f62;'); - break; - case 'comment': - newNode.setStyle('color: #6a737d; font-style: italic;'); - break; - case 'number': - newNode.setStyle('color: #005cc5; font-weight: 500;'); - break; - case 'function': - newNode.setStyle('color: #6f42c1; font-weight: 500;'); - break; - } - - return newNode; - }); - - if (newNodes.length > 1) { - textNode.replace(newNodes[0]); - for (let i = 1; i < newNodes.length; i++) { - newNodes[i - 1].insertAfter(newNodes[i]); - } - - if (selectionOffset !== null && $isRangeSelection(selection)) { - let currentOffset = 0; - for (const node of newNodes) { - const nodeLength = node.getTextContent().length; - if (currentOffset + nodeLength >= selectionOffset) { - node.select(selectionOffset - currentOffset, selectionOffset - currentOffset); - break; - } - 
currentOffset += nodeLength; - } - } - } - }); - }, [editor]); - - return null; -}; - -function needsHighlight(text: string): boolean { - return /[a-zA-Z0-9_#"']/.test(text); -} - -function tokenizePython(text: string): Array<{text: string, type: string}> { - const tokens: Array<{text: string, type: string}> = []; - let i = 0; - - while (i < text.length) { - // Comments - if (text[i] === '#') { - let start = i; - while (i < text.length && text[i] !== '\n') i++; - tokens.push({ text: text.slice(start, i), type: 'comment' }); - continue; - } - - // Strings - if (text[i] === '"' || text[i] === "'") { - const quote = text[i]; - let start = i++; - const isTriple = text.slice(start, start + 3) === quote.repeat(3); - if (isTriple) i += 2; - - while (i < text.length) { - if (isTriple && text.slice(i, i + 3) === quote.repeat(3)) { - i += 3; - break; - } else if (!isTriple && text[i] === quote && text[i - 1] !== '\\') { - i++; - break; - } - i++; - } - tokens.push({ text: text.slice(start, i), type: 'string' }); - continue; - } - - // Numbers - if (/\d/.test(text[i])) { - let start = i; - while (i < text.length && /[\d.]/.test(text[i])) i++; - tokens.push({ text: text.slice(start, i), type: 'number' }); - continue; - } - - // Keywords and identifiers - if (/[a-zA-Z_]/.test(text[i])) { - let start = i; - while (i < text.length && /[a-zA-Z0-9_]/.test(text[i])) i++; - const word = text.slice(start, i); - - if (PYTHON_KEYWORDS.has(word)) { - tokens.push({ text: word, type: 'keyword' }); - } else if (i < text.length && text[i] === '(') { - tokens.push({ text: word, type: 'function' }); - } else { - tokens.push({ text: word, type: 'text' }); - } - continue; - } - - // Other characters - let start = i; - while (i < text.length && !/[a-zA-Z0-9_#"']/.test(text[i])) i++; - if (start < i) { - tokens.push({ text: text.slice(start, i), type: 'text' }); - } - } - - return tokens; -} - -export default Python3HighlightPlugin; diff --git 
a/web/src/views/Workflow/components/Properties/CodeExecution/index.tsx b/web/src/views/Workflow/components/Properties/CodeExecution/index.tsx index 8a0ea03e..b9c2c881 100644 --- a/web/src/views/Workflow/components/Properties/CodeExecution/index.tsx +++ b/web/src/views/Workflow/components/Properties/CodeExecution/index.tsx @@ -5,8 +5,8 @@ import { Node } from '@antv/x6' import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin' import MappingList from '../MappingList' -import Editor from '../../Editor' import OutputList from './OutputList' +import CodeMirrorEditor from '@/components/CodeMirrorEditor'; interface MappingItem { name?: string @@ -110,7 +110,10 @@ const CodeExecution: FC<CodeExecutionProps> = ({ options }) => { <Form.Item noStyle shouldUpdate={(prev, curr) => prev.language !== curr.language}> {() => ( <Form.Item name="code" noStyle> - <Editor size="small" language={form.getFieldValue('language')} /> + <CodeMirrorEditor + language={form.getFieldValue('language')} + size="small" + /> </Form.Item> )} </Form.Item> diff --git a/web/src/views/Workflow/components/Properties/Knowledge/Knowledge.tsx b/web/src/views/Workflow/components/Properties/Knowledge/Knowledge.tsx index da9603c8..3cd7efcd 100644 --- a/web/src/views/Workflow/components/Properties/Knowledge/Knowledge.tsx +++ b/web/src/views/Workflow/components/Properties/Knowledge/Knowledge.tsx @@ -126,7 +126,7 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi <div className="rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/workflow/recall.svg')] rb:group-hover:bg-[url('@/assets/images/workflow/recall_hover.svg')]" ></div> - {t('workflow.config.knowledge-retrieval.recallConfig')} + {t('application.globalConfig')} </Button> </div>