diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 3e7db8cb..002547f6 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -7,6 +7,10 @@ from celery import Celery from app.core.config import settings +# macOS fork() safety - must be set before any Celery initialization +if platform.system() == 'Darwin': + os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES') + # 创建 Celery 应用实例 # broker: 任务队列(使用 Redis DB 0) # backend: 结果存储(使用 Redis DB 10) @@ -64,6 +68,11 @@ celery_app.conf.update( 'app.core.memory.agent.read_message': {'queue': 'memory_tasks'}, 'app.core.memory.agent.write_message': {'queue': 'memory_tasks'}, + # Long-term storage tasks → memory_tasks queue (batched write strategies) + 'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'}, + 'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'}, + 'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'}, + # Document tasks → document_tasks queue (prefork worker) 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index 42d7fe87..67040f40 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -45,6 +45,7 @@ from . import ( memory_perceptual_controller, memory_working_controller, ontology_controller, + skill_controller ) # 创建管理端 API 路由器 @@ -90,5 +91,6 @@ manager_router.include_router(memory_perceptual_controller.router) manager_router.include_router(memory_working_controller.router) manager_router.include_router(file_storage_controller.router) manager_router.include_router(ontology_controller.router) +manager_router.include_router(skill_controller.router) __all__ = ["manager_router"] diff --git a/api/app/controllers/ontology_controller.py b/api/app/controllers/ontology_controller.py index 94e3118c..f36aa6c5 100644 --- a/api/app/controllers/ontology_controller.py +++ b/api/app/controllers/ontology_controller.py @@ -116,14 +116,6 @@ def _get_ontology_service( detail=f"找不到指定的LLM模型: {llm_id}" ) - # 检查是否为组合模型 - if hasattr(model_config, 'is_composite') and model_config.is_composite: - logger.error(f"Model {llm_id} is a composite model, which is not supported for ontology extraction") - raise HTTPException( - status_code=400, - detail="本体提取不支持使用组合模型,请选择单个模型" - ) - # 验证模型配置了API密钥 if not model_config.api_keys: logger.error(f"Model {llm_id} has no API key configuration") diff --git a/api/app/controllers/skill_controller.py b/api/app/controllers/skill_controller.py new file mode 100644 index 00000000..2308307b --- /dev/null +++ b/api/app/controllers/skill_controller.py @@ -0,0 +1,90 @@ +"""Skill Controller - 技能市场管理""" +from fastapi import APIRouter, Depends, Query +from sqlalchemy.orm import Session +from typing import Optional +import uuid + +from app.db import get_db +from app.dependencies import get_current_user, cur_workspace_access_guard +from app.models import User +from app.schemas import skill_schema +from app.schemas.response_schema import PageData, PageMeta +from app.services.skill_service import SkillService +from app.core.response_utils import success + +router = APIRouter(prefix="/skills", tags=["Skills"]) + + +@router.post("", summary="创建技能") +@cur_workspace_access_guard() +def create_skill( + data: skill_schema.SkillCreate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """创建技能 - 可以关联现有工具(内置、MCP、自定义)""" + tenant_id = current_user.tenant_id + skill = SkillService.create_skill(db, data, tenant_id) + return success(data=skill_schema.Skill.model_validate(skill), msg="技能创建成功") + + +@router.get("", summary="技能列表") +@cur_workspace_access_guard() +def list_skills( + search: Optional[str] = Query(None, description="搜索关键词"), + is_active: Optional[bool] = Query(None, description="是否激活"), + is_public: Optional[bool] = Query(None, description="是否公开"), + page: int = Query(1, ge=1, description="页码"), + pagesize: int = Query(10, ge=1, le=100, description="每页数量"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """技能市场列表 - 包含本工作空间和公开的技能""" + tenant_id = current_user.tenant_id + skills, total = SkillService.list_skills( + db, tenant_id, search, is_active, is_public, page, pagesize + ) + + items = [skill_schema.Skill.model_validate(s) for s in skills] + meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total) + return success(data=PageData(page=meta, items=items), msg="技能市场列表获取成功") + + +@router.get("/{skill_id}", summary="获取技能详情") +@cur_workspace_access_guard() +def get_skill( + skill_id: uuid.UUID, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """获取技能详情""" + tenant_id = current_user.tenant_id + skill = SkillService.get_skill(db, skill_id, tenant_id) + return success(data=skill_schema.Skill.model_validate(skill), msg="获取技能详情成功") + + +@router.put("/{skill_id}", summary="更新技能") +@cur_workspace_access_guard() +def update_skill( + skill_id: uuid.UUID, + data: skill_schema.SkillUpdate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """更新技能""" + tenant_id = current_user.tenant_id + skill = SkillService.update_skill(db, skill_id, data, tenant_id) + return success(data=skill_schema.Skill.model_validate(skill), msg="技能更新成功") + + +@router.delete("/{skill_id}", summary="删除技能") +@cur_workspace_access_guard() +def delete_skill( + skill_id: uuid.UUID, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """删除技能""" + tenant_id = current_user.tenant_id + SkillService.delete_skill(db, skill_id, tenant_id) + return success(msg="技能删除成功") diff --git a/api/app/core/agent/agent_middleware.py b/api/app/core/agent/agent_middleware.py new file mode 100644 index 00000000..ef5f7847 --- /dev/null +++ b/api/app/core/agent/agent_middleware.py @@ -0,0 +1,151 @@ +"""Agent Middleware - 动态技能过滤""" +import uuid +from typing import List, Dict, Any, Optional +from langchain_core.runnables import RunnablePassthrough + +from app.services.skill_service import SkillService +from app.repositories.skill_repository import SkillRepository + + +class AgentMiddleware: + """Agent 中间件 - 用于动态过滤和加载技能""" + + def __init__(self, skill_ids: Optional[List[str]] = None): + """ + 初始化中间件 + + Args: + skill_ids: 技能ID列表 + """ + self.skill_ids = skill_ids or [] + + @staticmethod + def filter_tools( + tools: List, + message: str = "", + skill_configs: Dict[str, Any] = None, + tool_to_skill_map: Dict[str, str] = None + ) -> tuple[List, List[str]]: + """ + 根据消息内容和技能配置动态过滤工具 + + Args: + tools: 所有可用工具列表 + message: 用户消息(可用于智能过滤) + skill_configs: 技能配置字典 {skill_id: {"keywords": [...], "enabled": True, "prompt": "..."}} + tool_to_skill_map: 工具到技能的映射 {tool_name: skill_id} + + Returns: + (过滤后的工具列表, 激活的技能ID列表) + """ + if not tools: + return [], [] + + # 如果没有技能配置,返回所有工具 + if not skill_configs: + return tools, [] + + # 基于关键词匹配激活技能 + activated_skill_ids = [] + message_lower = message.lower() + + for skill_id, config in skill_configs.items(): + if not config.get('enabled', True): + continue + + keywords = config.get('keywords', []) + # 如果没有关键词限制,或消息包含关键词,则激活该技能 + if not keywords or any(kw.lower() in message_lower for kw in keywords): + activated_skill_ids.append(skill_id) + + # 如果没有工具映射关系,返回所有工具 + if not tool_to_skill_map: + return tools, activated_skill_ids + + # 根据激活的技能过滤工具 + filtered_tools = [] + for tool in tools: + tool_name = getattr(tool, 'name', str(id(tool))) + # 如果工具不属于任何skill(base_tools),或者工具所属的skill被激活,则保留 + if tool_name not in tool_to_skill_map or tool_to_skill_map[tool_name] in activated_skill_ids: + filtered_tools.append(tool) + + return filtered_tools, activated_skill_ids + + def load_skill_tools(self, db, tenant_id: uuid.UUID, base_tools: List = None) -> tuple[List, Dict[str, Any], Dict[str, str]]: + """ + 加载技能关联的工具 + + Args: + db: 数据库会话 + tenant_id: 租户id + base_tools: 基础工具列表 + + Returns: + (工具列表, 技能配置字典, 工具到技能的映射 {tool_name: skill_id}) + """ + + tools_dict = {} + tool_to_skill_map = {} # 工具名称到技能ID的映射 + + if base_tools: + for tool in base_tools: + tool_name = getattr(tool, 'name', str(id(tool))) + tools_dict[tool_name] = tool + # base_tools 不属于任何 skill,不加入映射 + + skill_configs = {} + + if self.skill_ids: + for skill_id in self.skill_ids: + try: + skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id) + if skill and skill.is_active: + # 保存技能配置(包含prompt) + config = skill.config or {} + config['prompt'] = skill.prompt + config['name'] = skill.name + skill_configs[skill_id] = config + except Exception: + continue + + # 加载技能工具并获取映射关系 + skill_tools, skill_tool_map = SkillService.load_skill_tools(db, self.skill_ids, tenant_id) + + # 只添加不冲突的 skill_tools + for tool in skill_tools: + tool_name = getattr(tool, 'name', str(id(tool))) + if tool_name not in tools_dict: + tools_dict[tool_name] = tool + # 复制映射关系 + if tool_name in skill_tool_map: + tool_to_skill_map[tool_name] = skill_tool_map[tool_name] + + return list(tools_dict.values()), skill_configs, tool_to_skill_map + + @staticmethod + def get_active_prompts(activated_skill_ids: List[str], skill_configs: Dict[str, Any]) -> str: + """ + 根据激活的技能ID获取对应的提示词 + + Args: + activated_skill_ids: 被激活的技能ID列表 + skill_configs: 技能配置字典 + + Returns: + 合并后的提示词 + """ + prompts = [] + for skill_id in activated_skill_ids: + config = skill_configs.get(skill_id, {}) + prompt = config.get('prompt') + name = config.get('name', 'Skill') + if prompt: + prompts.append(f"# {name}\n{prompt}") + + return "\n\n".join(prompts) if prompts else "" + + @staticmethod + def create_runnable(): + """创建可运行的中间件""" + return RunnablePassthrough() diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index f0b8bb92..40cf068e 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -291,6 +291,7 @@ class LangChainAgent: async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type): db = next(get_db()) + #TODO: 魔法数字 scope=6 try: @@ -300,6 +301,12 @@ class LangChainAgent: from app.core.memory.agent.utils.redis_tool import write_store result = write_store.get_session_by_userid(end_user_id) + + # Handle case where no session exists in Redis (returns False) + if not result or result is False: + logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update") + return + if type=="chunk" or type=="aggregate": data = await format_parsing(result, "dict") chunk_data = data[:scope] @@ -307,7 +314,14 @@ class LangChainAgent: repo.upsert(end_user_id, chunk_data) logger.info(f'写入短长期:') else: + # TODO: This branch handles type="time" strategy, currently unused. + # Will be activated when time-based long-term storage is implemented. + # TODO: 魔法数字 - extract 5 to a constant long_time_data = write_store.find_user_recent_sessions(end_user_id, 5) + # Handle case where no session exists in Redis (returns False or empty) + if not long_time_data or long_time_data is False: + logger.debug(f"No recent sessions in Redis for user {end_user_id}") + return long_messages = await messages_parse(long_time_data) repo.upsert(end_user_id, long_messages) logger.info(f'写入短长期:') @@ -507,9 +521,12 @@ class LangChainAgent: elapsed_time = time.time() - start_time if memory_flag: long_term_messages=await agent_chat_messages(message_chat,content) - # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + # TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable. + # This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j + # when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph. + # Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time. await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id) - '''长期''' + # Batched long-term memory storage (Redis buffer + Neo4j when window full) await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk") response = { "content": content, @@ -693,9 +710,13 @@ class LangChainAgent: yield total_tokens break if memory_flag: - # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + # TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable. + # This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j + # when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph. + # Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time. long_term_messages = await agent_chat_messages(message_chat, full_content) await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id) + # Batched long-term memory storage (Redis buffer + Neo4j when window full) await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk") except Exception as e: diff --git a/api/app/core/config.py b/api/app/core/config.py index 0de957c7..bf721af9 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -215,6 +215,9 @@ class Settings: # official environment system version SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1") + # model square loading + LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true" + # workflow config WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600)) diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index d6fbbb38..e9de02b6 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -43,6 +43,7 @@ async def write_messages(end_user_id,langchain_messages,memory_config): for node_name, node_data in update_event.items(): if 'save_neo4j' == node_name: massages = node_data + # TODO:删除 massagesstatus = massages.get('write_result')['status'] contents = massages.get('write_result') print(contents) @@ -60,6 +61,7 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): scope:窗口大小 ''' scope=scope + redis_messages = [] is_end_user_id = count_store.get_sessions_count(end_user_id) if is_end_user_id is not False: is_end_user_id = count_store.get_sessions_count(end_user_id)[0] @@ -91,6 +93,9 @@ async def memory_long_term_storage(end_user_id,memory_config,time): memory_config: 内存配置对象 ''' long_time_data = write_store.find_user_recent_sessions(end_user_id, time) + # Handle case where no session exists in Redis (returns False or empty) + if not long_time_data or long_time_data is False: + return format_messages = await chat_data_format(long_time_data) if format_messages!=[]: await write_messages(end_user_id, format_messages, memory_config) @@ -108,8 +113,9 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config try: # 1. 获取历史会话数据(使用新方法) result = write_store.get_all_sessions_by_end_user_id(end_user_id) - history = await format_parsing(result) - if not result: + + # Handle case where no session exists in Redis (returns False or empty) + if not result or result is False: history = [] else: history = await format_parsing(result) diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index d0e8a45d..84ea9381 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -1,18 +1,14 @@ import asyncio -import json import sys import warnings from contextlib import asynccontextmanager from langgraph.constants import END, START from langgraph.graph import StateGraph -from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, chat_data_format, messages_parse -from app.db import get_db from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node -from app.services.memory_config_service import MemoryConfigService warnings.filterwarnings("ignore", category=RuntimeWarning) logger = get_agent_logger(__name__) @@ -40,27 +36,55 @@ async def make_write_graph(): yield graph async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6): - from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment - from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format - from app.core.memory.agent.utils.redis_tool import write_store - write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages)) - # 获取数据库会话 - db_session = next(get_db()) - config_service = MemoryConfigService(db_session) - memory_config = config_service.load_memory_config( - config_id=memory_config, # 改为整数 - service_name="MemoryAgentService" + """Dispatch long-term memory storage to Celery background tasks. + + Args: + long_term_type: Storage strategy - 'chunk' (window), 'time', or 'aggregate' + langchain_messages: List of messages to store + memory_config: Memory configuration ID (string) + end_user_id: End user identifier + scope: Window size for 'chunk' strategy (default: 6) + """ + from app.tasks import ( + long_term_storage_window_task, + # TODO: Uncomment when implemented + # long_term_storage_time_task, + # long_term_storage_aggregate_task, ) - if long_term_type=='chunk': - '''方案一:对话窗口6轮对话''' - await window_dialogue(end_user_id,langchain_messages,memory_config,scope) - if long_term_type=='time': - """时间""" - await memory_long_term_storage(end_user_id, memory_config,5) - if long_term_type=='aggregate': - - """方案三:聚合判断""" - await aggregate_judgment(end_user_id, langchain_messages, memory_config) + from app.core.logging_config import get_logger + + logger = get_logger(__name__) + + # Convert config to string if needed + config_id = str(memory_config) if memory_config else '' + + if long_term_type == 'chunk': + # Strategy 1: Window-based batching (6 rounds of dialogue) + logger.info(f"[LONG_TERM] Dispatching window task - end_user_id={end_user_id}, scope={scope}") + long_term_storage_window_task.delay( + end_user_id=end_user_id, + langchain_messages=langchain_messages, + config_id=config_id, + scope=scope + ) + # TODO: Uncomment when time-based strategy is fully implemented + # elif long_term_type == 'time': + # # Strategy 2: Time-based retrieval + # logger.info(f"[LONG_TERM] Dispatching time task - end_user_id={end_user_id}") + # long_term_storage_time_task.delay( + # end_user_id=end_user_id, + # config_id=config_id, + # time_window=5 + # ) + # TODO: Uncomment when aggregate strategy is fully implemented + # elif long_term_type == 'aggregate': + # # Strategy 3: Aggregate judgment (deduplication) + # logger.info(f"[LONG_TERM] Dispatching aggregate task - end_user_id={end_user_id}") + # long_term_storage_aggregate_task.delay( + # end_user_id=end_user_id, + # langchain_messages=langchain_messages, + # config_id=config_id + # ) # async def main(): diff --git a/api/app/core/models/scripts/bedrock_models.yaml b/api/app/core/models/scripts/bedrock_models.yaml index 453aaa13..e5b91d1c 100644 --- a/api/app/core/models/scripts/bedrock_models.yaml +++ b/api/app/core/models/scripts/bedrock_models.yaml @@ -1,5 +1,4 @@ provider: bedrock -enabled: false models: - name: ai21 type: llm diff --git a/api/app/core/models/scripts/dashscope_models.yaml b/api/app/core/models/scripts/dashscope_models.yaml index bcdb467e..df538e72 100644 --- a/api/app/core/models/scripts/dashscope_models.yaml +++ b/api/app/core/models/scripts/dashscope_models.yaml @@ -1,5 +1,4 @@ provider: dashscope -enabled: false models: - name: deepseek-r1-distill-qwen-14b type: llm diff --git a/api/app/core/models/scripts/loader.py b/api/app/core/models/scripts/loader.py index 6469656c..a14d3268 100644 --- a/api/app/core/models/scripts/loader.py +++ b/api/app/core/models/scripts/loader.py @@ -1,11 +1,11 @@ """模型配置加载器 - 用于将预定义模型批量导入到数据库""" -import os from pathlib import Path from typing import Callable import yaml from sqlalchemy.orm import Session + from app.models.models_model import ModelBase, ModelProvider @@ -19,31 +19,9 @@ def _load_yaml_config(provider: ModelProvider) -> list[dict]: with open(config_file, 'r', encoding='utf-8') as f: data = yaml.safe_load(f) - - # 检查是否需要加载(默认为 true) - if not data.get('enabled', True): - return [] - return data.get('models', []) -def _disable_yaml_config(provider: ModelProvider) -> None: - """将YAML文件的enabled标志设置为false""" - config_dir = Path(__file__).parent - config_file = config_dir / f"{provider.value}_models.yaml" - - if not config_file.exists(): - return - - with open(config_file, 'r', encoding='utf-8') as f: - data = yaml.safe_load(f) - - data['enabled'] = False - - with open(config_file, 'w', encoding='utf-8') as f: - yaml.dump(data, f, allow_unicode=True, sort_keys=False) - - def load_models(db: Session, providers: list[str] = None, silent: bool = False) -> dict: """ 加载模型配置到数据库 @@ -75,8 +53,7 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False) if not silent: print(f"\n正在加载 {provider.value} 的 {len(models)} 个模型...") - - # provider_success = 0 + for model_data in models: try: # 检查模型是否已存在 @@ -93,7 +70,6 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False) if not silent: print(f"更新成功: {model_data['name']}") result["success"] += 1 - # provider_success += 1 else: # 创建新模型 model = ModelBase(**model_data) @@ -102,17 +78,12 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False) if not silent: print(f"添加成功: {model_data['name']}") result["success"] += 1 - # provider_success += 1 except Exception as e: db.rollback() if not silent: print(f"添加失败: {model_data['name']} - {str(e)}") result["failed"] += 1 - - # 如果该供应商的模型全部加载成功,将enabled设置为false - # if provider_success == len(models): - _disable_yaml_config(provider) return result diff --git a/api/app/core/models/scripts/openai_models.yaml b/api/app/core/models/scripts/openai_models.yaml index 5a416264..68c63ee2 100644 --- a/api/app/core/models/scripts/openai_models.yaml +++ b/api/app/core/models/scripts/openai_models.yaml @@ -1,5 +1,4 @@ provider: openai -enabled: false models: - name: chatgpt-4o-latest type: llm diff --git a/api/app/core/workflow/nodes/code/node.py b/api/app/core/workflow/nodes/code/node.py index fa7ceeb0..f6176edf 100644 --- a/api/app/core/workflow/nodes/code/node.py +++ b/api/app/core/workflow/nodes/code/node.py @@ -2,6 +2,7 @@ import base64 import json import logging import re +import urllib.parse from string import Template from textwrap import dedent from typing import Any diff --git a/api/app/main.py b/api/app/main.py index 38020d4c..af5ed796 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -50,13 +50,16 @@ async def lifespan(app: FastAPI): logger.info("自动数据库升级已禁用 (DB_AUTO_UPGRADE=false)") # 加载预定义模型 - logger.info("开始加载预定义模型...") - try: - with get_db_context() as db: - result = load_models(db, silent=True) - logger.info(f"预定义模型加载完成: 成功{result['success']}个, 跳过{result['skipped']}个, 失败{result['failed']}个") - except Exception as e: - logger.warning(f"加载预定义模型时出错: {str(e)}") + if settings.LOAD_MODEL: + logger.info("开始加载预定义模型...") + try: + with get_db_context() as db: + result = load_models(db, silent=True) + logger.info(f"预定义模型加载完成: 成功{result['success']}个, 跳过{result['skipped']}个, 失败{result['failed']}个") + except Exception as e: + logger.warning(f"加载预定义模型时出错: {str(e)}") + else: + logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)") logger.info("应用程序启动完成") yield diff --git a/api/app/models/__init__.py b/api/app/models/__init__.py index 984212de..daf03841 100644 --- a/api/app/models/__init__.py +++ b/api/app/models/__init__.py @@ -28,6 +28,7 @@ from .tool_model import ( ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus ) from .memory_perceptual_model import MemoryPerceptualModel +from .skill_model import Skill from .ontology_scene import OntologyScene from .ontology_class import OntologyClass from .ontology_scene import OntologyScene @@ -84,5 +85,6 @@ __all__ = [ "ExecutionStatus", "MemoryPerceptualModel", "ModelBase", - "LoadBalanceStrategy" + "LoadBalanceStrategy", + "Skill" ] diff --git a/api/app/models/agent_app_config_model.py b/api/app/models/agent_app_config_model.py index 96752c8e..7ed70728 100644 --- a/api/app/models/agent_app_config_model.py +++ b/api/app/models/agent_app_config_model.py @@ -30,6 +30,7 @@ class AgentConfig(Base): memory = Column(JSON, nullable=True, comment="记忆配置") variables = Column(JSON, default=list, nullable=True, comment="变量配置") tools = Column(JSON, default=dict, nullable=True, comment="工具配置") + skill_ids = Column(JSON, default=list, nullable=True, comment="关联的技能ID列表") # 多 Agent 相关字段 agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone") diff --git a/api/app/models/skill_model.py b/api/app/models/skill_model.py new file mode 100644 index 00000000..97fdeb03 --- /dev/null +++ b/api/app/models/skill_model.py @@ -0,0 +1,37 @@ +"""Skill 模型定义""" +import datetime +import uuid +from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey +from sqlalchemy.dialects.postgresql import UUID, JSON + +from app.db import Base + + +class Skill(Base): + """技能模型 - 可以关联工具(内置、MCP、自定义)""" + __tablename__ = "skills" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) + name = Column(String, nullable=False, comment="技能名称") + description = Column(Text, comment="技能描述") + tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, index=True, comment="租户ID") + + # 关联的工具 + tools = Column(JSON, default=list, comment="关联的工具列表") + + # 技能配置 + config = Column(JSON, default=dict, comment="技能配置") + + # 专属提示词 + prompt = Column(Text, comment="技能专属提示词") + + # 状态 + is_active = Column(Boolean, default=True, nullable=False, comment="是否激活") + is_public = Column(Boolean, default=False, nullable=False, comment="是否公开到市场") + + # 时间戳 + created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间") + updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") + + def __repr__(self): + return f"" diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index 568c262f..22972669 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -235,6 +235,8 @@ class MemoryConfigRepository: llm_id=params.llm_id, embedding_id=params.embedding_id, rerank_id=params.rerank_id, + reflection_model_id=params.reflection_model_id, + emotion_model_id=params.emotion_model_id, ) db.add(db_config) db.flush() # 获取自增ID但不提交事务 diff --git a/api/app/repositories/skill_repository.py b/api/app/repositories/skill_repository.py new file mode 100644 index 00000000..6eeb7e08 --- /dev/null +++ b/api/app/repositories/skill_repository.py @@ -0,0 +1,111 @@ +"""Skill Repository""" +from typing import List, Optional, Tuple, Any +from sqlalchemy.orm import Session +from sqlalchemy import and_, or_ +import uuid + +from app.models.skill_model import Skill +from app.schemas.skill_schema import SkillCreate, SkillUpdate + + +class SkillRepository: + """Skill 数据访问层""" + + @staticmethod + def create(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill: + """创建技能""" + skill = Skill( + **data.model_dump(), + tenant_id=tenant_id + ) + db.add(skill) + db.flush() + return skill + + @staticmethod + def get_by_id(db: Session, skill_id: uuid.UUID, tenant_id: Optional[uuid.UUID] = None) -> Optional[Skill]: + """根据ID获取技能""" + query = db.query(Skill).filter(Skill.id == skill_id) + if tenant_id: + query = query.filter( + or_( + Skill.tenant_id == tenant_id, + Skill.is_public == True + ) + ) + return query.first() + + @staticmethod + def list_skills( + db: Session, + tenant_id: uuid.UUID, + search: Optional[str] = None, + is_active: Optional[bool] = None, + is_public: Optional[bool] = None, + page: int = 1, + pagesize: int = 10 + ) -> tuple[list[type[Skill]], int]: + """列出技能""" + filters = [ + or_( + Skill.tenant_id == tenant_id, + Skill.is_public == True + ) + ] + + if search: + filters.append( + or_( + Skill.name.ilike(f"%{search}%"), + # Skill.description.ilike(f"%{search}%") + ) + ) + + if is_active is not None: + filters.append(Skill.is_active == is_active) + + if is_public is not None: + filters.append(Skill.is_public == is_public) + + query = db.query(Skill).filter(and_(*filters)) + total = query.count() + + skills = query.order_by(Skill.created_at.desc()).offset( + (page - 1) * pagesize + ).limit(pagesize).all() + + return skills, total + + @staticmethod + def update(db: Session, skill_id: uuid.UUID, data: SkillUpdate, tenant_id: uuid.UUID) -> Optional[Skill]: + """更新技能""" + skill = db.query(Skill).filter( + Skill.id == skill_id, + Skill.tenant_id == tenant_id + ).first() + + if not skill: + return None + + update_data = data.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(skill, key, value) + + db.flush() + return skill + + @staticmethod + def delete(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> bool: + """删除技能""" + skill = db.query(Skill).filter( + Skill.id == skill_id, + Skill.tenant_id == tenant_id + ).first() + + if not skill: + return False + + # db.delete(skill) + skill.is_active = False + db.flush() + return True diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 26d9b246..bcfeca57 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -156,6 +156,9 @@ class AgentConfigCreate(BaseModel): description="Agent 可用的工具列表" ) + # 技能配置 + skill_ids: Optional[List[str]] = Field(default=None, description="关联的技能ID列表") + class AppCreate(BaseModel): name: str @@ -207,6 +210,9 @@ class AgentConfigUpdate(BaseModel): # 工具配置 tools: Optional[List[ToolConfig]] = Field(default_factory=list, description="工具列表") + + # 技能配置 + skill_ids: Optional[List[str]] = Field(default=None, description="关联的技能ID列表") # ---------- Output Schemas ---------- @@ -266,6 +272,8 @@ class AgentConfig(BaseModel): # 工具配置 tools: Union[List[ToolConfig], Dict[str, ToolOldConfig]] = [] + skill_ids: Optional[List[str]] = [] + is_active: bool created_at: datetime.datetime updated_at: datetime.datetime diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index 5e22d70f..11cacda0 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -236,6 +236,8 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body, llm_id: Optional[str] = Field(None, description="LLM模型配置ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") + reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致") + emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致") class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体) diff --git a/api/app/schemas/skill_schema.py b/api/app/schemas/skill_schema.py new file mode 100644 index 00000000..27f16b99 --- /dev/null +++ b/api/app/schemas/skill_schema.py @@ -0,0 +1,57 @@ +"""Skill Schema 定义""" +from typing import Optional, List, Dict, Any +from pydantic import BaseModel, Field, field_serializer +import uuid +from datetime import datetime + + +class SkillBase(BaseModel): + """Skill 基础 Schema""" + name: str = Field(..., description="技能名称") + description: Optional[str] = Field(None, description="技能描述") + tools: List[Dict[str, str]] = Field(default_factory=list, description="工具对象列表: [{\"tool_id\": \"xxx\", \"operation\": \"yyy\"}]") + config: Dict[str, Any] = Field(default_factory=dict, description="技能配置") + prompt: Optional[str] = Field(None, description="技能专属提示词") + is_active: bool = Field(True, description="是否激活") + is_public: bool = Field(False, description="是否公开到市场") + + +class SkillCreate(SkillBase): + """创建 Skill""" + pass + + +class SkillUpdate(BaseModel): + """更新 Skill""" + name: Optional[str] = None + description: Optional[str] = None + tools: Optional[List[Dict[str, str]]] = None + config: Optional[Dict[str, Any]] = None + prompt: Optional[str] = None + is_active: Optional[bool] = None + is_public: Optional[bool] = None + + +class Skill(SkillBase): + """Skill 响应 Schema""" + id: uuid.UUID + tenant_id: uuid.UUID + created_at: datetime + updated_at: datetime + + @field_serializer('created_at', 'updated_at') + def serialize_datetime_to_timestamp(self, value: datetime) -> int: + """(毫秒级)时间戳""" + return int(value.timestamp() * 1000) + + class Config: + from_attributes = True + + +class SkillQuery(BaseModel): + """Skill 查询参数""" + search: Optional[str] = None + is_active: Optional[bool] = None + is_public: Optional[bool] = None + page: int = Field(1, ge=1) + pagesize: int = Field(10, ge=1, le=100) diff --git a/api/app/services/agent_config_converter.py b/api/app/services/agent_config_converter.py index 094aade8..ba76e299 100644 --- a/api/app/services/agent_config_converter.py +++ b/api/app/services/agent_config_converter.py @@ -48,6 +48,9 @@ class AgentConfigConverter: # 5. 工具配置 if hasattr(config, 'tools') and config.tools: result["tools"] = [tool.model_dump() for tool in config.tools] + + if hasattr(config, "skill_ids") and config.skill_ids: + result["skill_ids"] = [skill for skill in config.skill_ids] return result @@ -58,6 +61,7 @@ class AgentConfigConverter: memory: Optional[Dict[str, Any]], variables: Optional[list], tools: Optional[Union[list, Dict[str, Any]]], + skill_ids: Optional[list] ) -> Dict[str, Any]: """ 将数据库存储格式转换为 Pydantic 对象 @@ -68,6 +72,7 @@ class AgentConfigConverter: memory: 记忆配置 variables: 变量配置 tools: 工具配置 + skill_ids: 技能 ID 列表 Returns: 包含 Pydantic 对象的字典 @@ -78,6 +83,7 @@ class AgentConfigConverter: "memory": MemoryConfig(enabled=True), "variables": [], "tools": [], + "skill_ids": [] } # 1. 解析模型参数配置 @@ -117,5 +123,8 @@ class AgentConfigConverter: name: ToolOldConfig(**tool_data) for name, tool_data in tools.items() } + + if skill_ids: + result["skill_ids"] = [skill for skill in skill_ids] return result diff --git a/api/app/services/agent_config_helper.py b/api/app/services/agent_config_helper.py index ae195913..ef6e22a4 100644 --- a/api/app/services/agent_config_helper.py +++ b/api/app/services/agent_config_helper.py @@ -26,6 +26,7 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig: memory=agent_cfg.memory, variables=agent_cfg.variables, tools=agent_cfg.tools, + skill_ids=agent_cfg.skill_ids ) # 将解析后的字段添加到对象上(用于序列化) @@ -34,5 +35,6 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig: agent_cfg.memory = parsed["memory"] agent_cfg.variables = parsed["variables"] agent_cfg.tools = parsed["tools"] + agent_cfg.skill_ids = parsed["skill_ids"] return agent_cfg diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 7ec4bc0e..1759206f 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -304,6 +304,7 @@ class AppService: memory=storage_data.get("memory"), variables=storage_data.get("variables", []), tools=storage_data.get("tools", []), + skill_ids=storage_data.get("skill_ids", []), is_active=True, created_at=now, updated_at=now, @@ -907,6 +908,7 @@ class AppService: agent_cfg.variables = storage_data.get("variables", []) # if data.tools is not None: agent_cfg.tools = storage_data.get("tools", []) + agent_cfg.skill_ids = storage_data.get("skill_ids", []) agent_cfg.updated_at = now diff --git a/api/app/services/app_statistics_service.py b/api/app/services/app_statistics_service.py index c164924a..5cfa3229 100644 --- a/api/app/services/app_statistics_service.py +++ b/api/app/services/app_statistics_service.py @@ -187,7 +187,7 @@ class AppStatisticsService: daily_tokens[date_str] = 0 daily_tokens[date_str] += int(tokens) - daily_data = [{"date": date, "tokens": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0] - total = sum(row["tokens"] for row in daily_data) + daily_data = [{"date": date, "count": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0] + total = sum(row["count"] for row in daily_data) return {"daily": daily_data, "total": total} diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index edad0123..0e0922bc 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -10,6 +10,11 @@ import time import uuid from typing import Any, AsyncGenerator, Dict, List, Optional +from langchain.tools import tool +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.orm import Session + from app.celery_app import celery_app from app.core.error_codes import BizCode from app.core.exceptions import BusinessException @@ -26,10 +31,8 @@ from app.services.memory_agent_service import MemoryAgentService from app.services.model_parameter_merger import ModelParameterMerger from app.services.tool_service import ToolService from app.services.multimodal_service import MultimodalService -from langchain.tools import tool -from pydantic import BaseModel, Field -from sqlalchemy import select -from sqlalchemy.orm import Session +from app.core.agent.agent_middleware import AgentMiddleware + logger = get_business_logger() class KnowledgeRetrievalInput(BaseModel): @@ -310,6 +313,7 @@ class DraftRunService: tools = [] tool_service = ToolService(self.db) + tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) # 从配置中获取启用的工具 if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list): @@ -320,9 +324,7 @@ class DraftRunService: print(f"tool_config:{tool_config}") if tool_config.get("enabled", False): # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), - ToolRepository.get_tenant_id_by_workspace_id( - self.db, str(workspace_id))) + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) if tool_instance: if tool_instance.name == "baidu_search_tool" and not web_search: continue @@ -345,6 +347,22 @@ class DraftRunService: } ) + # 加载技能关联的工具 + if hasattr(agent_config, 'skill_ids') and agent_config.skill_ids: + middleware = AgentMiddleware(skill_ids=agent_config.skill_ids) + skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) + tools.extend(skill_tools) + logger.debug(f"已加载 {len(skill_tools)} 个技能工具") + + # 应用动态过滤 + if skill_configs: + tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map) + logger.debug(f"过滤后剩余 {len(tools)} 个工具") + active_prompts = AgentMiddleware.get_active_prompts( + activated_skill_ids, skill_configs + ) + system_prompt = f"{system_prompt}\n\n{active_prompts}" + # 添加知识库检索工具 if agent_config.knowledge_retrieval: kb_config = agent_config.knowledge_retrieval @@ -558,6 +576,7 @@ class DraftRunService: tools = [] tool_service = ToolService(self.db) + tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) # 从配置中获取启用的工具 if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list): @@ -567,9 +586,7 @@ class DraftRunService: # print(f"tool_config:{tool_config}") if tool_config.get("enabled", False): # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), - ToolRepository.get_tenant_id_by_workspace_id( - self.db, str(workspace_id))) + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) if tool_instance: if tool_instance.name == "baidu_search_tool" and not web_search: continue @@ -592,6 +609,23 @@ class DraftRunService: } ) + # 加载技能关联的工具 + skill_configs = {} + if hasattr(agent_config, 'skill_ids') and agent_config.skill_ids: + middleware = AgentMiddleware(skill_ids=agent_config.skill_ids) + skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) + tools.extend(skill_tools) + logger.debug(f"已加载 {len(skill_tools)} 个技能工具") + + # 应用动态过滤 + if skill_configs: + tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map) + logger.debug(f"过滤后剩余 {len(tools)} 个工具") + active_prompts = AgentMiddleware.get_active_prompts( + activated_skill_ids, skill_configs + ) + system_prompt = f"{system_prompt}\n\n{active_prompts}" + # 添加知识库检索工具 if agent_config.knowledge_retrieval: @@ -628,7 +662,6 @@ class DraftRunService: } ) - # 4. 创建 LangChain Agent agent = LangChainAgent( model_name=api_key_config["model_name"], diff --git a/api/app/services/memory_dashboard_service.py b/api/app/services/memory_dashboard_service.py index 06a94060..6fa8b228 100644 --- a/api/app/services/memory_dashboard_service.py +++ b/api/app/services/memory_dashboard_service.py @@ -53,7 +53,10 @@ def get_workspace_end_users( workspace_id: uuid.UUID, current_user: User ) -> List[EndUser]: - """获取工作空间的所有宿主(优化版本:减少数据库查询次数)""" + """获取工作空间的所有宿主(优化版本:减少数据库查询次数) + + 返回结果按 updated_at 从新到旧排序(NULL 值排在最后) + """ business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}") try: @@ -68,9 +71,14 @@ def get_workspace_end_users( app_ids = [app.id for app in apps_orm] # 批量查询所有 end_users(一次查询而非循环查询) + # 按 updated_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性 from app.models.end_user_model import EndUser as EndUserModel + from sqlalchemy import desc, nullslast end_users_orm = db.query(EndUserModel).filter( EndUserModel.app_id.in_(app_ids) + ).order_by( + nullslast(desc(EndUserModel.updated_at)), + desc(EndUserModel.id) ).all() # 转换为 Pydantic 模型(只在需要时转换) diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 82baef9f..b7079e62 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -129,6 +129,12 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) if not params.rerank_id: params.rerank_id = configs.get('rerank') + # reflection_model_id 和 emotion_model_id 默认与 llm_id 一致 + if not params.reflection_model_id: + params.reflection_model_id = params.llm_id + if not params.emotion_model_id: + params.emotion_model_id = params.llm_id + config = MemoryConfigRepository.create(self.db, params) self.db.commit() return {"affected": 1, "config_id": config.config_id} @@ -203,6 +209,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) "end_user_id": config.end_user_id, "config_id_old": config_id_old, "apply_id": config.apply_id, + "scene_id": config.scene_id, "llm_id": config.llm_id, "embedding_id": config.embedding_id, "rerank_id": config.rerank_id, diff --git a/api/app/services/skill_service.py b/api/app/services/skill_service.py new file mode 100644 index 00000000..a20e1b22 --- /dev/null +++ b/api/app/services/skill_service.py @@ -0,0 +1,109 @@ +"""Skill Service""" +import uuid +from typing import List + +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session + +from app.repositories.skill_repository import SkillRepository +from app.schemas.skill_schema import SkillCreate, SkillUpdate +from app.models.skill_model import Skill +from app.core.exceptions import BusinessException +from app.core.error_codes import BizCode +from app.services.tool_service import ToolService + + +class SkillService: + """Skill 业务逻辑层""" + + @staticmethod + def create_skill(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill: + """创建技能""" + skill = SkillRepository.create(db, data, tenant_id) + db.commit() + db.refresh(skill) + return skill + + @staticmethod + def get_skill(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> Skill: + """获取技能""" + try: + skill = SkillRepository.get_by_id(db, skill_id, tenant_id) + if not skill: + raise BusinessException(f"技能{skill_id}不存在", BizCode.NOT_FOUND) + return skill + except (BusinessException, SQLAlchemyError) as e: + db.rollback() + raise e + + @staticmethod + def list_skills( + db: Session, + tenant_id: uuid.UUID, + search: str = None, + is_active: bool = None, + is_public: bool = None, + page: int = 1, + pagesize: int = 10 + ) -> tuple[list[type[Skill]], int]: + """列出技能""" + return SkillRepository.list_skills( + db, tenant_id, search, is_active, is_public, page, pagesize + ) + + @staticmethod + def update_skill(db: Session, skill_id: uuid.UUID, data: SkillUpdate, tenant_id: uuid.UUID) -> Skill: + """更新技能""" + try: + skill = SkillRepository.update(db, skill_id, data, tenant_id) + if not skill: + raise BusinessException(f"技能{skill_id}不存在或无权限", BizCode.NOT_FOUND) + db.commit() + db.refresh(skill) + return skill + except (BusinessException, SQLAlchemyError) as e: + db.rollback() + raise e + + @staticmethod + def delete_skill(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> bool: + """删除技能""" + try: + success = SkillRepository.delete(db, skill_id, tenant_id) + if not success: + raise BusinessException("技能不存在或无权限", BizCode.NOT_FOUND) + db.commit() + return True + except (BusinessException, SQLAlchemyError) as e: + db.rollback() + raise e + + @staticmethod + def load_skill_tools(db: Session, skill_ids: List[str], tenant_id: uuid.UUID) -> tuple[List, dict[str, str]]: + """加载技能关联的工具 + + Returns: + (tools, tool_to_skill_map) - 工具列表和工具到技能的映射 + """ + tools = [] + tool_to_skill_map = {} # {tool_name: skill_id} + tool_service = ToolService(db) + + for skill_id in skill_ids: + try: + skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id) + if skill and skill.is_active: + # 加载技能关联的工具 + for tool_config in skill.tools: + tool = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) + if tool: + langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None)) + tools.append(langchain_tool) + # 建立工具到技能的映射 + tool_name = getattr(langchain_tool, 'name', str(id(langchain_tool))) + tool_to_skill_map[tool_name] = skill_id + except Exception as e: + print(f"加载技能 {skill_id} 的工具时出错: {e}") + continue + + return tools, tool_to_skill_map diff --git a/api/app/tasks.py b/api/app/tasks.py index 247cba76..a46a3a7b 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1069,6 +1069,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]: f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务") except Exception as e: + db.rollback() # Rollback failed transaction to allow next query api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}") all_reflection_results.append({ "workspace_id": str(workspace_id), @@ -1207,3 +1208,290 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di return result finally: loop.close() + + +# ============================================================================= +# Long-term Memory Storage Tasks (Batched Write Strategies) +# ============================================================================= + +@celery_app.task(name="app.core.memory.agent.long_term_storage.window", bind=True) +def long_term_storage_window_task( + self, + end_user_id: str, + langchain_messages: List[Dict[str, Any]], + config_id: str, + scope: int = 6 +) -> Dict[str, Any]: + """Celery task for window-based long-term memory storage. + + Accumulates messages in Redis buffer until window size (scope) is reached, + then writes batched messages to Neo4j. + + Args: + end_user_id: End user identifier + langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}] + config_id: Memory configuration ID + scope: Window size (number of messages before triggering write) + + Returns: + Dict containing task status and metadata + """ + from app.core.logging_config import get_logger + logger = get_logger(__name__) + + logger.info(f"[LONG_TERM_WINDOW] Starting task - end_user_id={end_user_id}, scope={scope}") + start_time = time.time() + + async def _run() -> Dict[str, Any]: + from app.core.memory.agent.langgraph_graph.routing.write_router import window_dialogue + from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format + from app.core.memory.agent.utils.redis_tool import write_store + from app.services.memory_config_service import MemoryConfigService + + db = next(get_db()) + try: + # Save to Redis buffer first + write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages)) + + # Load memory config + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config( + config_id=config_id, + service_name="LongTermStorageTask" + ) + + # Execute window-based dialogue storage + await window_dialogue(end_user_id, langchain_messages, memory_config, scope) + + return {"status": "SUCCESS", "strategy": "window", "scope": scope} + finally: + db.close() + + try: + import nest_asyncio + nest_asyncio.apply() + except ImportError: + pass + + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + result = loop.run_until_complete(_run()) + elapsed_time = time.time() - start_time + + logger.info(f"[LONG_TERM_WINDOW] Task completed - elapsed_time={elapsed_time:.2f}s") + + return { + **result, + "end_user_id": end_user_id, + "config_id": config_id, + "elapsed_time": elapsed_time, + "task_id": self.request.id + } + except Exception as e: + elapsed_time = time.time() - start_time + logger.error(f"[LONG_TERM_WINDOW] Task failed - error={str(e)}", exc_info=True) + + return { + "status": "FAILURE", + "strategy": "window", + "error": str(e), + "end_user_id": end_user_id, + "config_id": config_id, + "elapsed_time": elapsed_time, + "task_id": self.request.id + } + + +# @celery_app.task(name="app.core.memory.agent.long_term_storage.time", bind=True) +# def long_term_storage_time_task( +# self, +# end_user_id: str, +# config_id: str, +# time_window: int = 5 +# ) -> Dict[str, Any]: +# """Celery task for time-based long-term memory storage. + +# Retrieves recent sessions from Redis within time window and writes to Neo4j. + +# Args: +# end_user_id: End user identifier +# config_id: Memory configuration ID +# time_window: Time window in minutes for retrieving recent sessions + +# Returns: +# Dict containing task status and metadata +# """ +# from app.core.logging_config import get_logger +# logger = get_logger(__name__) + +# logger.info(f"[LONG_TERM_TIME] Starting task - end_user_id={end_user_id}, time_window={time_window}") +# start_time = time.time() + +# async def _run() -> Dict[str, Any]: +# from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage +# from app.services.memory_config_service import MemoryConfigService + +# db = next(get_db()) +# try: +# # Load memory config +# config_service = MemoryConfigService(db) +# memory_config = config_service.load_memory_config( +# config_id=config_id, +# service_name="LongTermStorageTask" +# ) + +# # Execute time-based storage +# await memory_long_term_storage(end_user_id, memory_config, time_window) + +# return {"status": "SUCCESS", "strategy": "time", "time_window": time_window} +# finally: +# db.close() + +# try: +# import nest_asyncio +# nest_asyncio.apply() +# except ImportError: +# pass + +# try: +# loop = asyncio.get_event_loop() +# if loop.is_closed(): +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) +# except RuntimeError: +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) + +# try: +# result = loop.run_until_complete(_run()) +# elapsed_time = time.time() - start_time + +# logger.info(f"[LONG_TERM_TIME] Task completed - elapsed_time={elapsed_time:.2f}s") + +# return { +# **result, +# "end_user_id": end_user_id, +# "config_id": config_id, +# "elapsed_time": elapsed_time, +# "task_id": self.request.id +# } +# except Exception as e: +# elapsed_time = time.time() - start_time +# logger.error(f"[LONG_TERM_TIME] Task failed - error={str(e)}", exc_info=True) + +# return { +# "status": "FAILURE", +# "strategy": "time", +# "error": str(e), +# "end_user_id": end_user_id, +# "config_id": config_id, +# "elapsed_time": elapsed_time, +# "task_id": self.request.id +# } + + +# @celery_app.task(name="app.core.memory.agent.long_term_storage.aggregate", bind=True) +# def long_term_storage_aggregate_task( +# self, +# end_user_id: str, +# langchain_messages: List[Dict[str, Any]], +# config_id: str +# ) -> Dict[str, Any]: +# """Celery task for aggregate-based long-term memory storage. + +# Uses LLM to determine if new messages describe the same event as history. +# Only writes to Neo4j if messages represent new information (not duplicates). + +# Args: +# end_user_id: End user identifier +# langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}] +# config_id: Memory configuration ID + +# Returns: +# Dict containing task status, is_same_event flag, and metadata +# """ +# from app.core.logging_config import get_logger +# logger = get_logger(__name__) + +# logger.info(f"[LONG_TERM_AGGREGATE] Starting task - end_user_id={end_user_id}") +# start_time = time.time() + +# async def _run() -> Dict[str, Any]: +# from app.core.memory.agent.langgraph_graph.routing.write_router import aggregate_judgment +# from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format +# from app.core.memory.agent.utils.redis_tool import write_store +# from app.services.memory_config_service import MemoryConfigService + +# db = next(get_db()) +# try: +# # Save to Redis buffer first +# write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages)) + +# # Load memory config +# config_service = MemoryConfigService(db) +# memory_config = config_service.load_memory_config( +# config_id=config_id, +# service_name="LongTermStorageTask" +# ) + +# # Execute aggregate judgment +# result = await aggregate_judgment(end_user_id, langchain_messages, memory_config) + +# return { +# "status": "SUCCESS", +# "strategy": "aggregate", +# "is_same_event": result.get("is_same_event", False), +# "wrote_to_neo4j": not result.get("is_same_event", False) +# } +# finally: +# db.close() + +# try: +# import nest_asyncio +# nest_asyncio.apply() +# except ImportError: +# pass + +# try: +# loop = asyncio.get_event_loop() +# if loop.is_closed(): +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) +# except RuntimeError: +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) + +# try: +# result = loop.run_until_complete(_run()) +# elapsed_time = time.time() - start_time + +# logger.info(f"[LONG_TERM_AGGREGATE] Task completed - is_same_event={result.get('is_same_event')}, elapsed_time={elapsed_time:.2f}s") + +# return { +# **result, +# "end_user_id": end_user_id, +# "config_id": config_id, +# "elapsed_time": elapsed_time, +# "task_id": self.request.id +# } +# except Exception as e: +# elapsed_time = time.time() - start_time +# logger.error(f"[LONG_TERM_AGGREGATE] Task failed - error={str(e)}", exc_info=True) + +# return { +# "status": "FAILURE", +# "strategy": "aggregate", +# "error": str(e), +# "end_user_id": end_user_id, +# "config_id": config_id, +# "elapsed_time": elapsed_time, +# "task_id": self.request.id +# } diff --git a/api/migrations/versions/e7c7afa249d1_202602041355.py b/api/migrations/versions/e7c7afa249d1_202602041355.py new file mode 100644 index 00000000..0559d5b4 --- /dev/null +++ b/api/migrations/versions/e7c7afa249d1_202602041355.py @@ -0,0 +1,50 @@ +"""202602041355 + +Revision ID: e7c7afa249d1 +Revises: 9def72f79398 +Create Date: 2026-02-04 13:55:22.284981 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = 'e7c7afa249d1' +down_revision: Union[str, None] = '9def72f79398' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('skills', + sa.Column('id', sa.UUID(), nullable=False), + sa.Column('name', sa.String(), nullable=False, comment='技能名称'), + sa.Column('description', sa.Text(), nullable=True, comment='技能描述'), + sa.Column('tenant_id', sa.UUID(), nullable=False, comment='租户ID'), + sa.Column('tools', postgresql.JSON(astext_type=sa.Text()), nullable=True, comment='关联的工具列表'), + sa.Column('config', postgresql.JSON(astext_type=sa.Text()), nullable=True, comment='技能配置'), + sa.Column('prompt', sa.Text(), nullable=True, comment='技能专属提示词'), + sa.Column('is_active', sa.Boolean(), nullable=False, comment='是否激活'), + sa.Column('is_public', sa.Boolean(), nullable=False, comment='是否公开到市场'), + sa.Column('created_at', sa.DateTime(), nullable=True, comment='创建时间'), + sa.Column('updated_at', sa.DateTime(), nullable=True, comment='更新时间'), + sa.ForeignKeyConstraint(['tenant_id'], ['tenants.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_skills_id'), 'skills', ['id'], unique=False) + op.create_index(op.f('ix_skills_tenant_id'), 'skills', ['tenant_id'], unique=False) + op.add_column('agent_configs', sa.Column('skill_ids', postgresql.JSON(astext_type=sa.Text()), nullable=True, comment='关联的技能ID列表')) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('agent_configs', 'skill_ids') + op.drop_index(op.f('ix_skills_tenant_id'), table_name='skills') + op.drop_index(op.f('ix_skills_id'), table_name='skills') + op.drop_table('skills') + # ### end Alembic commands ### diff --git a/web/src/components/PageScrollList/index.tsx b/web/src/components/PageScrollList/index.tsx index 49173a68..a877a9c7 100644 --- a/web/src/components/PageScrollList/index.tsx +++ b/web/src/components/PageScrollList/index.tsx @@ -142,7 +142,7 @@ const PageScrollList = forwardRef(>({ dataLength={data.length} next={loadMoreData} hasMore={hasMore} - loader={needLoading ? : undefined} + loader={loading && needLoading ? : false} // endMessage={It is all, nothing more 🤐} scrollableTarget="scrollableDiv" className='rb:h-full!' diff --git a/web/src/styles/index.css b/web/src/styles/index.css index 53670dab..bbbe9cd9 100644 --- a/web/src/styles/index.css +++ b/web/src/styles/index.css @@ -180,7 +180,4 @@ body { .x6-node foreignObject > body { min-height: 100%; max-height: 100%; -} -#scrollableDiv .infinite-scroll-component__outerdiv { - height: 100%; } \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/components/Editor/index.tsx b/web/src/views/ApplicationConfig/components/Editor/index.tsx index 0f878678..a5247d1b 100644 --- a/web/src/views/ApplicationConfig/components/Editor/index.tsx +++ b/web/src/views/ApplicationConfig/components/Editor/index.tsx @@ -21,6 +21,7 @@ import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext import InitialValuePlugin from './plugin/InitialValuePlugin' import LineBreakPlugin from './plugin/LineBreakPlugin'; import InsertTextPlugin from './plugin/InsertTextPlugin'; +import EditablePlugin from './plugin/EditablePlugin'; /** * Editor ref methods exposed to parent components @@ -50,6 +51,7 @@ interface LexicalEditorProps { onChange?: (value: string) => void; /** Editor height in pixels */ height?: number; + disabled?: boolean; } /** @@ -71,6 +73,7 @@ const EditorContent = forwardRef(({ value, placeholder = "Please enter content...", onChange, + disabled }, ref) => { const [editor] = useLexicalComposerContext(); @@ -132,7 +135,11 @@ const EditorContent = forwardRef(({ } placeholder={ @@ -145,6 +152,7 @@ const EditorContent = forwardRef(({ + ); }); @@ -158,6 +166,7 @@ const Editor = forwardRef((props, ref) => { namespace: 'Editor', theme, nodes: [], + editable: !props.disabled, onError: (error: Error) => { console.error(error); }, diff --git a/web/src/views/ApplicationConfig/components/Editor/plugin/EditablePlugin.tsx b/web/src/views/ApplicationConfig/components/Editor/plugin/EditablePlugin.tsx new file mode 100644 index 00000000..6c237f01 --- /dev/null +++ b/web/src/views/ApplicationConfig/components/Editor/plugin/EditablePlugin.tsx @@ -0,0 +1,48 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-04 11:20:49 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-04 11:20:49 + */ +import { useEffect } from 'react'; +import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; + +/** + * Props for the EditablePlugin component + */ +interface EditablePluginProps { + /** Whether the editor should be disabled (read-only mode) */ + disabled?: boolean; +} + +/** + * EditablePlugin - A Lexical editor plugin that controls the editable state of the editor + * + * This plugin allows you to dynamically toggle between editable and read-only modes. + * When disabled is true, the editor becomes read-only and users cannot modify content. + * When disabled is false or undefined, the editor is fully editable. + * + * @param {EditablePluginProps} props - Component props + * @param {boolean} [props.disabled] - Controls whether the editor is in read-only mode + * @returns {null} This plugin doesn't render any UI elements + * + * @example + * ```tsx + * + * + * + * ``` + */ +export default function EditablePlugin({ disabled }: EditablePluginProps) { + // Get the editor instance from Lexical composer context + const [editor] = useLexicalComposerContext(); + + // Update editor's editable state whenever the disabled prop changes + useEffect(() => { + // Set editor to editable when disabled is false, read-only when disabled is true + editor.setEditable(!disabled); + }, [editor, disabled]); + + // This plugin doesn't render any UI, it only manages editor state + return null; +} diff --git a/web/src/views/Prompt/Prompt.tsx b/web/src/views/Prompt/Prompt.tsx index ded9550c..01fdb35b 100644 --- a/web/src/views/Prompt/Prompt.tsx +++ b/web/src/views/Prompt/Prompt.tsx @@ -156,9 +156,9 @@ const Prompt: FC<{ editVo: HistoryItem | null; refresh: () => void; }> = ({ edit currentPromptValueRef.current = undefined; setChatList([]) refresh() + updateSession() } - console.log(values) return ( <>
@@ -217,12 +217,13 @@ const Prompt: FC<{ editVo: HistoryItem | null; refresh: () => void; }> = ({ edit ref={editorRef} placeholder={t('prompt.promptPlaceholder')} className="rb:h-[calc(100vh-260px)]" + disabled={loading} // onChange={(value) => form.setFieldValue('current_prompt', value)} />
- - + +
diff --git a/web/src/views/SpaceManagement/components/SpaceModal.tsx b/web/src/views/SpaceManagement/components/SpaceModal.tsx index 70365312..a0703d81 100644 --- a/web/src/views/SpaceManagement/components/SpaceModal.tsx +++ b/web/src/views/SpaceManagement/components/SpaceModal.tsx @@ -103,6 +103,8 @@ const SpaceModal = forwardRef(({ }).catch(() => { handleUpdate(formData) }) + } else { + handleUpdate(formData) } } }) @@ -158,6 +160,7 @@ const SpaceModal = forwardRef(({ label={t('space.spaceIcon')} valuePropName="fileList" hidden={currentStep === 1} + rules={[{ required: true, message: t('common.selectPlaceholder', { title: t('space.spaceIcon') }) }]} > diff --git a/web/src/views/Workflow/components/Editor/index.tsx b/web/src/views/Workflow/components/Editor/index.tsx index e37c71de..4c8540a8 100644 --- a/web/src/views/Workflow/components/Editor/index.tsx +++ b/web/src/views/Workflow/components/Editor/index.tsx @@ -242,7 +242,7 @@ const Editor: FC =({ {enableLineNumbers && } { setCount(count) }} onChange={onChange} /> - + {enableLineNumbers && } diff --git a/web/src/views/Workflow/components/Editor/plugin/BlurPlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/BlurPlugin.tsx index b636605b..0fb6c48f 100644 --- a/web/src/views/Workflow/components/Editor/plugin/BlurPlugin.tsx +++ b/web/src/views/Workflow/components/Editor/plugin/BlurPlugin.tsx @@ -16,6 +16,12 @@ export default function BlurPlugin() { return; } + // 检查是否是粘贴操作导致的焦点变化 + const relatedTarget = e.relatedTarget as HTMLElement; + if (!relatedTarget || relatedTarget === document.body) { + return; + } + editor.update(() => { $setSelection(null); }); diff --git a/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx index 22de9592..4021a9ee 100644 --- a/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx +++ b/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx @@ -8,12 +8,13 @@ import { type Suggestion } from '../plugin/AutocompletePlugin' interface InitialValuePluginProps { value: string; options?: Suggestion[]; - enableJinja2?: boolean; + enableLineNumbers?: boolean; } -const InitialValuePlugin: React.FC = ({ value, options = [], enableJinja2 = false }) => { +const InitialValuePlugin: React.FC = ({ value, options = [], enableLineNumbers = false }) => { const [editor] = useLexicalComposerContext(); const prevValueRef = useRef(''); + const prevEnableLineNumbersRef = useRef(enableLineNumbers); const isUserInputRef = useRef(false); useEffect(() => { @@ -32,7 +33,7 @@ const InitialValuePlugin: React.FC = ({ value, options }, [editor]); useEffect(() => { - if (value !== prevValueRef.current && !isUserInputRef.current) { + if ((value !== prevValueRef.current || enableLineNumbers !== prevEnableLineNumbersRef.current) && !isUserInputRef.current) { queueMicrotask(() => { editor.update(() => { const root = $getRoot(); @@ -40,7 +41,7 @@ const InitialValuePlugin: React.FC = ({ value, options const parts = value.split(/(\{\{[^}]+\}\})/); - if (enableJinja2) { + if (enableLineNumbers) { // Handle newlines properly in Jinja2 mode const lines = value.split('\n'); lines.forEach((line) => { @@ -104,8 +105,9 @@ const InitialValuePlugin: React.FC = ({ value, options } prevValueRef.current = value; + prevEnableLineNumbersRef.current = enableLineNumbers; isUserInputRef.current = false; - }, [value, options, editor, enableJinja2]); + }, [value, options, editor, enableLineNumbers]); return null; }; diff --git a/web/src/views/Workflow/components/Editor/plugin/JavaScriptHighlightPlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/JavaScriptHighlightPlugin.tsx index 90053646..21219139 100644 --- a/web/src/views/Workflow/components/Editor/plugin/JavaScriptHighlightPlugin.tsx +++ b/web/src/views/Workflow/components/Editor/plugin/JavaScriptHighlightPlugin.tsx @@ -1,6 +1,6 @@ -import { useEffect } from 'react'; +import { useEffect, useRef } from 'react'; import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; -import { TextNode, $createTextNode, $getSelection, $isRangeSelection } from 'lexical'; +import { TextNode, $createTextNode, $getSelection, $isRangeSelection, COMMAND_PRIORITY_LOW, PASTE_COMMAND } from 'lexical'; const JS_KEYWORDS = new Set([ 'async', 'await', 'break', 'case', 'catch', 'class', 'const', 'continue', 'debugger', 'default', @@ -11,13 +11,31 @@ const JS_KEYWORDS = new Set([ const JavaScriptHighlightPlugin = () => { const [editor] = useLexicalComposerContext(); + const isPastingRef = useRef(false); + + useEffect(() => { + return editor.registerCommand( + PASTE_COMMAND, + () => { + isPastingRef.current = true; + setTimeout(() => { + isPastingRef.current = false; + }, 100); + return false; + }, + COMMAND_PRIORITY_LOW + ); + }, [editor]); useEffect(() => { return editor.registerNodeTransform(TextNode, (textNode: TextNode) => { + if (isPastingRef.current) return; + const text = textNode.getTextContent(); if (textNode.hasFormat('code')) return; if (!needsHighlight(text)) return; + if (textNode.getStyle()) return; const parent = textNode.getParent(); if (!parent) return; diff --git a/web/src/views/Workflow/components/Editor/plugin/Python3HighlightPlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/Python3HighlightPlugin.tsx index 387160ed..12830ffb 100644 --- a/web/src/views/Workflow/components/Editor/plugin/Python3HighlightPlugin.tsx +++ b/web/src/views/Workflow/components/Editor/plugin/Python3HighlightPlugin.tsx @@ -1,6 +1,6 @@ -import { useEffect } from 'react'; +import { useEffect, useRef } from 'react'; import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; -import { TextNode, $createTextNode, $getSelection, $isRangeSelection } from 'lexical'; +import { TextNode, $createTextNode, $getSelection, $isRangeSelection, COMMAND_PRIORITY_LOW, PASTE_COMMAND } from 'lexical'; const PYTHON_KEYWORDS = new Set([ 'False', 'None', 'True', 'and', 'as', 'assert', 'async', 'await', 'break', 'class', 'continue', @@ -11,12 +11,30 @@ const PYTHON_KEYWORDS = new Set([ const Python3HighlightPlugin = () => { const [editor] = useLexicalComposerContext(); + const isPastingRef = useRef(false); + + useEffect(() => { + return editor.registerCommand( + PASTE_COMMAND, + () => { + isPastingRef.current = true; + setTimeout(() => { + isPastingRef.current = false; + }, 100); + return false; + }, + COMMAND_PRIORITY_LOW + ); + }, [editor]); useEffect(() => { return editor.registerNodeTransform(TextNode, (textNode: TextNode) => { + if (isPastingRef.current) return; + const text = textNode.getTextContent(); if (textNode.hasFormat('code')) return; + if (textNode.getStyle()) return; if (!needsHighlight(text)) return; const parent = textNode.getParent(); diff --git a/web/src/views/Workflow/components/Properties/CodeExecution/index.tsx b/web/src/views/Workflow/components/Properties/CodeExecution/index.tsx index 7c95a4a2..8a0ea03e 100644 --- a/web/src/views/Workflow/components/Properties/CodeExecution/index.tsx +++ b/web/src/views/Workflow/components/Properties/CodeExecution/index.tsx @@ -33,7 +33,6 @@ const codeTemplate = { const CodeExecution: FC = ({ options }) => { const { t } = useTranslation() const form = Form.useFormInstance() - const values = Form.useWatch([], form) || {} const handleRefresh = () => { const code = form.getFieldValue('code') || '' @@ -66,7 +65,6 @@ const CodeExecution: FC = ({ options }) => { form.setFieldValue('code', newTemplate) } const handleChangeLanguage = (value: string) => { - form.setFieldValue('code', codeTemplate[value as keyof typeof codeTemplate]) form.setFieldsValue({ input_variables: [{ name: 'arg1' }, { name: 'arg2' }], code: codeTemplate[value as keyof typeof codeTemplate] @@ -109,8 +107,12 @@ const CodeExecution: FC = ({ options }) => { - - + prev.language !== curr.language}> + {() => ( + + + + )} diff --git a/web/src/views/Workflow/hooks/useWorkflowGraph.ts b/web/src/views/Workflow/hooks/useWorkflowGraph.ts index 68d08aaa..d267faf8 100644 --- a/web/src/views/Workflow/hooks/useWorkflowGraph.ts +++ b/web/src/views/Workflow/hooks/useWorkflowGraph.ts @@ -159,7 +159,7 @@ export const useWorkflowGraph = ({ nodeLibraryConfig.config[key].defaultValue = Object.entries(config[key]).map(([name, value]) => ({ name, value })) } else if (type === 'code' && key === 'code' && config[key] && nodeLibraryConfig.config && nodeLibraryConfig.config[key]) { try { - nodeLibraryConfig.config[key].defaultValue = atob(config[key] as string) + nodeLibraryConfig.config[key].defaultValue = decodeURIComponent(atob(config[key] as string)) } catch { nodeLibraryConfig.config[key].defaultValue = config[key] } @@ -943,7 +943,7 @@ export const useWorkflowGraph = ({ const code = data.config[key].defaultValue || '' itemConfig = { ...itemConfig, - code: btoa(code || '') + code: btoa(encodeURIComponent(code || '')) } } else if (key === 'memory' && data.config[key] && 'defaultValue' in data.config[key]) { const { messages, ...rest } = data.config[key].defaultValue