Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop
# Conflicts: # api/app/core/agent/langchain_agent.py
This commit is contained in:
@@ -7,6 +7,10 @@ from celery import Celery
|
|||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
|
# macOS fork() safety - must be set before any Celery initialization
|
||||||
|
if platform.system() == 'Darwin':
|
||||||
|
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||||
|
|
||||||
# 创建 Celery 应用实例
|
# 创建 Celery 应用实例
|
||||||
# broker: 任务队列(使用 Redis DB 0)
|
# broker: 任务队列(使用 Redis DB 0)
|
||||||
# backend: 结果存储(使用 Redis DB 10)
|
# backend: 结果存储(使用 Redis DB 10)
|
||||||
@@ -64,6 +68,11 @@ celery_app.conf.update(
|
|||||||
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||||
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||||
|
|
||||||
|
# Long-term storage tasks → memory_tasks queue (batched write strategies)
|
||||||
|
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
|
||||||
|
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
||||||
|
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
||||||
|
|
||||||
# Document tasks → document_tasks queue (prefork worker)
|
# Document tasks → document_tasks queue (prefork worker)
|
||||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ from . import (
|
|||||||
memory_perceptual_controller,
|
memory_perceptual_controller,
|
||||||
memory_working_controller,
|
memory_working_controller,
|
||||||
ontology_controller,
|
ontology_controller,
|
||||||
|
skill_controller
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建管理端 API 路由器
|
# 创建管理端 API 路由器
|
||||||
@@ -90,5 +91,6 @@ manager_router.include_router(memory_perceptual_controller.router)
|
|||||||
manager_router.include_router(memory_working_controller.router)
|
manager_router.include_router(memory_working_controller.router)
|
||||||
manager_router.include_router(file_storage_controller.router)
|
manager_router.include_router(file_storage_controller.router)
|
||||||
manager_router.include_router(ontology_controller.router)
|
manager_router.include_router(ontology_controller.router)
|
||||||
|
manager_router.include_router(skill_controller.router)
|
||||||
|
|
||||||
__all__ = ["manager_router"]
|
__all__ = ["manager_router"]
|
||||||
|
|||||||
@@ -116,14 +116,6 @@ def _get_ontology_service(
|
|||||||
detail=f"找不到指定的LLM模型: {llm_id}"
|
detail=f"找不到指定的LLM模型: {llm_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查是否为组合模型
|
|
||||||
if hasattr(model_config, 'is_composite') and model_config.is_composite:
|
|
||||||
logger.error(f"Model {llm_id} is a composite model, which is not supported for ontology extraction")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="本体提取不支持使用组合模型,请选择单个模型"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 验证模型配置了API密钥
|
# 验证模型配置了API密钥
|
||||||
if not model_config.api_keys:
|
if not model_config.api_keys:
|
||||||
logger.error(f"Model {llm_id} has no API key configuration")
|
logger.error(f"Model {llm_id} has no API key configuration")
|
||||||
|
|||||||
90
api/app/controllers/skill_controller.py
Normal file
90
api/app/controllers/skill_controller.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
"""Skill Controller - 技能市场管理"""
|
||||||
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import Optional
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||||
|
from app.models import User
|
||||||
|
from app.schemas import skill_schema
|
||||||
|
from app.schemas.response_schema import PageData, PageMeta
|
||||||
|
from app.services.skill_service import SkillService
|
||||||
|
from app.core.response_utils import success
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/skills", tags=["Skills"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", summary="创建技能")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def create_skill(
|
||||||
|
data: skill_schema.SkillCreate,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""创建技能 - 可以关联现有工具(内置、MCP、自定义)"""
|
||||||
|
tenant_id = current_user.tenant_id
|
||||||
|
skill = SkillService.create_skill(db, data, tenant_id)
|
||||||
|
return success(data=skill_schema.Skill.model_validate(skill), msg="技能创建成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", summary="技能列表")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def list_skills(
|
||||||
|
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||||
|
is_active: Optional[bool] = Query(None, description="是否激活"),
|
||||||
|
is_public: Optional[bool] = Query(None, description="是否公开"),
|
||||||
|
page: int = Query(1, ge=1, description="页码"),
|
||||||
|
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""技能市场列表 - 包含本工作空间和公开的技能"""
|
||||||
|
tenant_id = current_user.tenant_id
|
||||||
|
skills, total = SkillService.list_skills(
|
||||||
|
db, tenant_id, search, is_active, is_public, page, pagesize
|
||||||
|
)
|
||||||
|
|
||||||
|
items = [skill_schema.Skill.model_validate(s) for s in skills]
|
||||||
|
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
||||||
|
return success(data=PageData(page=meta, items=items), msg="技能市场列表获取成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{skill_id}", summary="获取技能详情")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def get_skill(
|
||||||
|
skill_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""获取技能详情"""
|
||||||
|
tenant_id = current_user.tenant_id
|
||||||
|
skill = SkillService.get_skill(db, skill_id, tenant_id)
|
||||||
|
return success(data=skill_schema.Skill.model_validate(skill), msg="获取技能详情成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{skill_id}", summary="更新技能")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def update_skill(
|
||||||
|
skill_id: uuid.UUID,
|
||||||
|
data: skill_schema.SkillUpdate,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""更新技能"""
|
||||||
|
tenant_id = current_user.tenant_id
|
||||||
|
skill = SkillService.update_skill(db, skill_id, data, tenant_id)
|
||||||
|
return success(data=skill_schema.Skill.model_validate(skill), msg="技能更新成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{skill_id}", summary="删除技能")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def delete_skill(
|
||||||
|
skill_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""删除技能"""
|
||||||
|
tenant_id = current_user.tenant_id
|
||||||
|
SkillService.delete_skill(db, skill_id, tenant_id)
|
||||||
|
return success(msg="技能删除成功")
|
||||||
151
api/app/core/agent/agent_middleware.py
Normal file
151
api/app/core/agent/agent_middleware.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
"""Agent Middleware - 动态技能过滤"""
|
||||||
|
import uuid
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
from langchain_core.runnables import RunnablePassthrough
|
||||||
|
|
||||||
|
from app.services.skill_service import SkillService
|
||||||
|
from app.repositories.skill_repository import SkillRepository
|
||||||
|
|
||||||
|
|
||||||
|
class AgentMiddleware:
|
||||||
|
"""Agent 中间件 - 用于动态过滤和加载技能"""
|
||||||
|
|
||||||
|
def __init__(self, skill_ids: Optional[List[str]] = None):
|
||||||
|
"""
|
||||||
|
初始化中间件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill_ids: 技能ID列表
|
||||||
|
"""
|
||||||
|
self.skill_ids = skill_ids or []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def filter_tools(
|
||||||
|
tools: List,
|
||||||
|
message: str = "",
|
||||||
|
skill_configs: Dict[str, Any] = None,
|
||||||
|
tool_to_skill_map: Dict[str, str] = None
|
||||||
|
) -> tuple[List, List[str]]:
|
||||||
|
"""
|
||||||
|
根据消息内容和技能配置动态过滤工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools: 所有可用工具列表
|
||||||
|
message: 用户消息(可用于智能过滤)
|
||||||
|
skill_configs: 技能配置字典 {skill_id: {"keywords": [...], "enabled": True, "prompt": "..."}}
|
||||||
|
tool_to_skill_map: 工具到技能的映射 {tool_name: skill_id}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(过滤后的工具列表, 激活的技能ID列表)
|
||||||
|
"""
|
||||||
|
if not tools:
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
# 如果没有技能配置,返回所有工具
|
||||||
|
if not skill_configs:
|
||||||
|
return tools, []
|
||||||
|
|
||||||
|
# 基于关键词匹配激活技能
|
||||||
|
activated_skill_ids = []
|
||||||
|
message_lower = message.lower()
|
||||||
|
|
||||||
|
for skill_id, config in skill_configs.items():
|
||||||
|
if not config.get('enabled', True):
|
||||||
|
continue
|
||||||
|
|
||||||
|
keywords = config.get('keywords', [])
|
||||||
|
# 如果没有关键词限制,或消息包含关键词,则激活该技能
|
||||||
|
if not keywords or any(kw.lower() in message_lower for kw in keywords):
|
||||||
|
activated_skill_ids.append(skill_id)
|
||||||
|
|
||||||
|
# 如果没有工具映射关系,返回所有工具
|
||||||
|
if not tool_to_skill_map:
|
||||||
|
return tools, activated_skill_ids
|
||||||
|
|
||||||
|
# 根据激活的技能过滤工具
|
||||||
|
filtered_tools = []
|
||||||
|
for tool in tools:
|
||||||
|
tool_name = getattr(tool, 'name', str(id(tool)))
|
||||||
|
# 如果工具不属于任何skill(base_tools),或者工具所属的skill被激活,则保留
|
||||||
|
if tool_name not in tool_to_skill_map or tool_to_skill_map[tool_name] in activated_skill_ids:
|
||||||
|
filtered_tools.append(tool)
|
||||||
|
|
||||||
|
return filtered_tools, activated_skill_ids
|
||||||
|
|
||||||
|
def load_skill_tools(self, db, tenant_id: uuid.UUID, base_tools: List = None) -> tuple[List, Dict[str, Any], Dict[str, str]]:
|
||||||
|
"""
|
||||||
|
加载技能关联的工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
tenant_id: 租户id
|
||||||
|
base_tools: 基础工具列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(工具列表, 技能配置字典, 工具到技能的映射 {tool_name: skill_id})
|
||||||
|
"""
|
||||||
|
|
||||||
|
tools_dict = {}
|
||||||
|
tool_to_skill_map = {} # 工具名称到技能ID的映射
|
||||||
|
|
||||||
|
if base_tools:
|
||||||
|
for tool in base_tools:
|
||||||
|
tool_name = getattr(tool, 'name', str(id(tool)))
|
||||||
|
tools_dict[tool_name] = tool
|
||||||
|
# base_tools 不属于任何 skill,不加入映射
|
||||||
|
|
||||||
|
skill_configs = {}
|
||||||
|
|
||||||
|
if self.skill_ids:
|
||||||
|
for skill_id in self.skill_ids:
|
||||||
|
try:
|
||||||
|
skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id)
|
||||||
|
if skill and skill.is_active:
|
||||||
|
# 保存技能配置(包含prompt)
|
||||||
|
config = skill.config or {}
|
||||||
|
config['prompt'] = skill.prompt
|
||||||
|
config['name'] = skill.name
|
||||||
|
skill_configs[skill_id] = config
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 加载技能工具并获取映射关系
|
||||||
|
skill_tools, skill_tool_map = SkillService.load_skill_tools(db, self.skill_ids, tenant_id)
|
||||||
|
|
||||||
|
# 只添加不冲突的 skill_tools
|
||||||
|
for tool in skill_tools:
|
||||||
|
tool_name = getattr(tool, 'name', str(id(tool)))
|
||||||
|
if tool_name not in tools_dict:
|
||||||
|
tools_dict[tool_name] = tool
|
||||||
|
# 复制映射关系
|
||||||
|
if tool_name in skill_tool_map:
|
||||||
|
tool_to_skill_map[tool_name] = skill_tool_map[tool_name]
|
||||||
|
|
||||||
|
return list(tools_dict.values()), skill_configs, tool_to_skill_map
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_active_prompts(activated_skill_ids: List[str], skill_configs: Dict[str, Any]) -> str:
|
||||||
|
"""
|
||||||
|
根据激活的技能ID获取对应的提示词
|
||||||
|
|
||||||
|
Args:
|
||||||
|
activated_skill_ids: 被激活的技能ID列表
|
||||||
|
skill_configs: 技能配置字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
合并后的提示词
|
||||||
|
"""
|
||||||
|
prompts = []
|
||||||
|
for skill_id in activated_skill_ids:
|
||||||
|
config = skill_configs.get(skill_id, {})
|
||||||
|
prompt = config.get('prompt')
|
||||||
|
name = config.get('name', 'Skill')
|
||||||
|
if prompt:
|
||||||
|
prompts.append(f"# {name}\n{prompt}")
|
||||||
|
|
||||||
|
return "\n\n".join(prompts) if prompts else ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_runnable():
|
||||||
|
"""创建可运行的中间件"""
|
||||||
|
return RunnablePassthrough()
|
||||||
@@ -291,6 +291,7 @@ class LangChainAgent:
|
|||||||
|
|
||||||
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
|
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
|
||||||
db = next(get_db())
|
db = next(get_db())
|
||||||
|
#TODO: 魔法数字
|
||||||
scope=6
|
scope=6
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -300,6 +301,12 @@ class LangChainAgent:
|
|||||||
|
|
||||||
from app.core.memory.agent.utils.redis_tool import write_store
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
result = write_store.get_session_by_userid(end_user_id)
|
result = write_store.get_session_by_userid(end_user_id)
|
||||||
|
|
||||||
|
# Handle case where no session exists in Redis (returns False)
|
||||||
|
if not result or result is False:
|
||||||
|
logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update")
|
||||||
|
return
|
||||||
|
|
||||||
if type=="chunk" or type=="aggregate":
|
if type=="chunk" or type=="aggregate":
|
||||||
data = await format_parsing(result, "dict")
|
data = await format_parsing(result, "dict")
|
||||||
chunk_data = data[:scope]
|
chunk_data = data[:scope]
|
||||||
@@ -307,7 +314,14 @@ class LangChainAgent:
|
|||||||
repo.upsert(end_user_id, chunk_data)
|
repo.upsert(end_user_id, chunk_data)
|
||||||
logger.info(f'写入短长期:')
|
logger.info(f'写入短长期:')
|
||||||
else:
|
else:
|
||||||
|
# TODO: This branch handles type="time" strategy, currently unused.
|
||||||
|
# Will be activated when time-based long-term storage is implemented.
|
||||||
|
# TODO: 魔法数字 - extract 5 to a constant
|
||||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
||||||
|
# Handle case where no session exists in Redis (returns False or empty)
|
||||||
|
if not long_time_data or long_time_data is False:
|
||||||
|
logger.debug(f"No recent sessions in Redis for user {end_user_id}")
|
||||||
|
return
|
||||||
long_messages = await messages_parse(long_time_data)
|
long_messages = await messages_parse(long_time_data)
|
||||||
repo.upsert(end_user_id, long_messages)
|
repo.upsert(end_user_id, long_messages)
|
||||||
logger.info(f'写入短长期:')
|
logger.info(f'写入短长期:')
|
||||||
@@ -507,9 +521,12 @@ class LangChainAgent:
|
|||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
if memory_flag:
|
if memory_flag:
|
||||||
long_term_messages=await agent_chat_messages(message_chat,content)
|
long_term_messages=await agent_chat_messages(message_chat,content)
|
||||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
|
||||||
|
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
|
||||||
|
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
|
||||||
|
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
|
||||||
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
|
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||||
'''长期'''
|
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
|
||||||
await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
|
await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
|
||||||
response = {
|
response = {
|
||||||
"content": content,
|
"content": content,
|
||||||
@@ -693,9 +710,13 @@ class LangChainAgent:
|
|||||||
yield total_tokens
|
yield total_tokens
|
||||||
break
|
break
|
||||||
if memory_flag:
|
if memory_flag:
|
||||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
|
||||||
|
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
|
||||||
|
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
|
||||||
|
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
|
||||||
long_term_messages = await agent_chat_messages(message_chat, full_content)
|
long_term_messages = await agent_chat_messages(message_chat, full_content)
|
||||||
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
|
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
|
||||||
|
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
|
||||||
await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk")
|
await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -215,6 +215,9 @@ class Settings:
|
|||||||
# official environment system version
|
# official environment system version
|
||||||
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1")
|
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1")
|
||||||
|
|
||||||
|
# model square loading
|
||||||
|
LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true"
|
||||||
|
|
||||||
# workflow config
|
# workflow config
|
||||||
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
||||||
|
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ async def write_messages(end_user_id,langchain_messages,memory_config):
|
|||||||
for node_name, node_data in update_event.items():
|
for node_name, node_data in update_event.items():
|
||||||
if 'save_neo4j' == node_name:
|
if 'save_neo4j' == node_name:
|
||||||
massages = node_data
|
massages = node_data
|
||||||
|
# TODO:删除
|
||||||
massagesstatus = massages.get('write_result')['status']
|
massagesstatus = massages.get('write_result')['status']
|
||||||
contents = massages.get('write_result')
|
contents = massages.get('write_result')
|
||||||
print(contents)
|
print(contents)
|
||||||
@@ -60,6 +61,7 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
|||||||
scope:窗口大小
|
scope:窗口大小
|
||||||
'''
|
'''
|
||||||
scope=scope
|
scope=scope
|
||||||
|
redis_messages = []
|
||||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||||
if is_end_user_id is not False:
|
if is_end_user_id is not False:
|
||||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||||
@@ -91,6 +93,9 @@ async def memory_long_term_storage(end_user_id,memory_config,time):
|
|||||||
memory_config: 内存配置对象
|
memory_config: 内存配置对象
|
||||||
'''
|
'''
|
||||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
||||||
|
# Handle case where no session exists in Redis (returns False or empty)
|
||||||
|
if not long_time_data or long_time_data is False:
|
||||||
|
return
|
||||||
format_messages = await chat_data_format(long_time_data)
|
format_messages = await chat_data_format(long_time_data)
|
||||||
if format_messages!=[]:
|
if format_messages!=[]:
|
||||||
await write_messages(end_user_id, format_messages, memory_config)
|
await write_messages(end_user_id, format_messages, memory_config)
|
||||||
@@ -108,8 +113,9 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
|||||||
try:
|
try:
|
||||||
# 1. 获取历史会话数据(使用新方法)
|
# 1. 获取历史会话数据(使用新方法)
|
||||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
||||||
history = await format_parsing(result)
|
|
||||||
if not result:
|
# Handle case where no session exists in Redis (returns False or empty)
|
||||||
|
if not result or result is False:
|
||||||
history = []
|
history = []
|
||||||
else:
|
else:
|
||||||
history = await format_parsing(result)
|
history = await format_parsing(result)
|
||||||
|
|||||||
@@ -1,18 +1,14 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from langgraph.constants import END, START
|
from langgraph.constants import END, START
|
||||||
from langgraph.graph import StateGraph
|
from langgraph.graph import StateGraph
|
||||||
|
|
||||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, chat_data_format, messages_parse
|
|
||||||
from app.db import get_db
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
@@ -40,27 +36,55 @@ async def make_write_graph():
|
|||||||
|
|
||||||
yield graph
|
yield graph
|
||||||
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
||||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
"""Dispatch long-term memory storage to Celery background tasks.
|
||||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
|
|
||||||
from app.core.memory.agent.utils.redis_tool import write_store
|
Args:
|
||||||
write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
|
long_term_type: Storage strategy - 'chunk' (window), 'time', or 'aggregate'
|
||||||
# 获取数据库会话
|
langchain_messages: List of messages to store
|
||||||
db_session = next(get_db())
|
memory_config: Memory configuration ID (string)
|
||||||
config_service = MemoryConfigService(db_session)
|
end_user_id: End user identifier
|
||||||
memory_config = config_service.load_memory_config(
|
scope: Window size for 'chunk' strategy (default: 6)
|
||||||
config_id=memory_config, # 改为整数
|
"""
|
||||||
service_name="MemoryAgentService"
|
from app.tasks import (
|
||||||
|
long_term_storage_window_task,
|
||||||
|
# TODO: Uncomment when implemented
|
||||||
|
# long_term_storage_time_task,
|
||||||
|
# long_term_storage_aggregate_task,
|
||||||
)
|
)
|
||||||
if long_term_type=='chunk':
|
from app.core.logging_config import get_logger
|
||||||
'''方案一:对话窗口6轮对话'''
|
|
||||||
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
logger = get_logger(__name__)
|
||||||
if long_term_type=='time':
|
|
||||||
"""时间"""
|
# Convert config to string if needed
|
||||||
await memory_long_term_storage(end_user_id, memory_config,5)
|
config_id = str(memory_config) if memory_config else ''
|
||||||
if long_term_type=='aggregate':
|
|
||||||
|
if long_term_type == 'chunk':
|
||||||
"""方案三:聚合判断"""
|
# Strategy 1: Window-based batching (6 rounds of dialogue)
|
||||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
logger.info(f"[LONG_TERM] Dispatching window task - end_user_id={end_user_id}, scope={scope}")
|
||||||
|
long_term_storage_window_task.delay(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
langchain_messages=langchain_messages,
|
||||||
|
config_id=config_id,
|
||||||
|
scope=scope
|
||||||
|
)
|
||||||
|
# TODO: Uncomment when time-based strategy is fully implemented
|
||||||
|
# elif long_term_type == 'time':
|
||||||
|
# # Strategy 2: Time-based retrieval
|
||||||
|
# logger.info(f"[LONG_TERM] Dispatching time task - end_user_id={end_user_id}")
|
||||||
|
# long_term_storage_time_task.delay(
|
||||||
|
# end_user_id=end_user_id,
|
||||||
|
# config_id=config_id,
|
||||||
|
# time_window=5
|
||||||
|
# )
|
||||||
|
# TODO: Uncomment when aggregate strategy is fully implemented
|
||||||
|
# elif long_term_type == 'aggregate':
|
||||||
|
# # Strategy 3: Aggregate judgment (deduplication)
|
||||||
|
# logger.info(f"[LONG_TERM] Dispatching aggregate task - end_user_id={end_user_id}")
|
||||||
|
# long_term_storage_aggregate_task.delay(
|
||||||
|
# end_user_id=end_user_id,
|
||||||
|
# langchain_messages=langchain_messages,
|
||||||
|
# config_id=config_id
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
# async def main():
|
# async def main():
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
provider: bedrock
|
provider: bedrock
|
||||||
enabled: false
|
|
||||||
models:
|
models:
|
||||||
- name: ai21
|
- name: ai21
|
||||||
type: llm
|
type: llm
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
provider: dashscope
|
provider: dashscope
|
||||||
enabled: false
|
|
||||||
models:
|
models:
|
||||||
- name: deepseek-r1-distill-qwen-14b
|
- name: deepseek-r1-distill-qwen-14b
|
||||||
type: llm
|
type: llm
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
"""模型配置加载器 - 用于将预定义模型批量导入到数据库"""
|
"""模型配置加载器 - 用于将预定义模型批量导入到数据库"""
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.models.models_model import ModelBase, ModelProvider
|
from app.models.models_model import ModelBase, ModelProvider
|
||||||
|
|
||||||
|
|
||||||
@@ -19,31 +19,9 @@ def _load_yaml_config(provider: ModelProvider) -> list[dict]:
|
|||||||
|
|
||||||
with open(config_file, 'r', encoding='utf-8') as f:
|
with open(config_file, 'r', encoding='utf-8') as f:
|
||||||
data = yaml.safe_load(f)
|
data = yaml.safe_load(f)
|
||||||
|
|
||||||
# 检查是否需要加载(默认为 true)
|
|
||||||
if not data.get('enabled', True):
|
|
||||||
return []
|
|
||||||
|
|
||||||
return data.get('models', [])
|
return data.get('models', [])
|
||||||
|
|
||||||
|
|
||||||
def _disable_yaml_config(provider: ModelProvider) -> None:
|
|
||||||
"""将YAML文件的enabled标志设置为false"""
|
|
||||||
config_dir = Path(__file__).parent
|
|
||||||
config_file = config_dir / f"{provider.value}_models.yaml"
|
|
||||||
|
|
||||||
if not config_file.exists():
|
|
||||||
return
|
|
||||||
|
|
||||||
with open(config_file, 'r', encoding='utf-8') as f:
|
|
||||||
data = yaml.safe_load(f)
|
|
||||||
|
|
||||||
data['enabled'] = False
|
|
||||||
|
|
||||||
with open(config_file, 'w', encoding='utf-8') as f:
|
|
||||||
yaml.dump(data, f, allow_unicode=True, sort_keys=False)
|
|
||||||
|
|
||||||
|
|
||||||
def load_models(db: Session, providers: list[str] = None, silent: bool = False) -> dict:
|
def load_models(db: Session, providers: list[str] = None, silent: bool = False) -> dict:
|
||||||
"""
|
"""
|
||||||
加载模型配置到数据库
|
加载模型配置到数据库
|
||||||
@@ -75,8 +53,7 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
|
|||||||
|
|
||||||
if not silent:
|
if not silent:
|
||||||
print(f"\n正在加载 {provider.value} 的 {len(models)} 个模型...")
|
print(f"\n正在加载 {provider.value} 的 {len(models)} 个模型...")
|
||||||
|
|
||||||
# provider_success = 0
|
|
||||||
for model_data in models:
|
for model_data in models:
|
||||||
try:
|
try:
|
||||||
# 检查模型是否已存在
|
# 检查模型是否已存在
|
||||||
@@ -93,7 +70,6 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
|
|||||||
if not silent:
|
if not silent:
|
||||||
print(f"更新成功: {model_data['name']}")
|
print(f"更新成功: {model_data['name']}")
|
||||||
result["success"] += 1
|
result["success"] += 1
|
||||||
# provider_success += 1
|
|
||||||
else:
|
else:
|
||||||
# 创建新模型
|
# 创建新模型
|
||||||
model = ModelBase(**model_data)
|
model = ModelBase(**model_data)
|
||||||
@@ -102,17 +78,12 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
|
|||||||
if not silent:
|
if not silent:
|
||||||
print(f"添加成功: {model_data['name']}")
|
print(f"添加成功: {model_data['name']}")
|
||||||
result["success"] += 1
|
result["success"] += 1
|
||||||
# provider_success += 1
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
if not silent:
|
if not silent:
|
||||||
print(f"添加失败: {model_data['name']} - {str(e)}")
|
print(f"添加失败: {model_data['name']} - {str(e)}")
|
||||||
result["failed"] += 1
|
result["failed"] += 1
|
||||||
|
|
||||||
# 如果该供应商的模型全部加载成功,将enabled设置为false
|
|
||||||
# if provider_success == len(models):
|
|
||||||
_disable_yaml_config(provider)
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
provider: openai
|
provider: openai
|
||||||
enabled: false
|
|
||||||
models:
|
models:
|
||||||
- name: chatgpt-4o-latest
|
- name: chatgpt-4o-latest
|
||||||
type: llm
|
type: llm
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import base64
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
import urllib.parse
|
||||||
from string import Template
|
from string import Template
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|||||||
@@ -50,13 +50,16 @@ async def lifespan(app: FastAPI):
|
|||||||
logger.info("自动数据库升级已禁用 (DB_AUTO_UPGRADE=false)")
|
logger.info("自动数据库升级已禁用 (DB_AUTO_UPGRADE=false)")
|
||||||
|
|
||||||
# 加载预定义模型
|
# 加载预定义模型
|
||||||
logger.info("开始加载预定义模型...")
|
if settings.LOAD_MODEL:
|
||||||
try:
|
logger.info("开始加载预定义模型...")
|
||||||
with get_db_context() as db:
|
try:
|
||||||
result = load_models(db, silent=True)
|
with get_db_context() as db:
|
||||||
logger.info(f"预定义模型加载完成: 成功{result['success']}个, 跳过{result['skipped']}个, 失败{result['failed']}个")
|
result = load_models(db, silent=True)
|
||||||
except Exception as e:
|
logger.info(f"预定义模型加载完成: 成功{result['success']}个, 跳过{result['skipped']}个, 失败{result['failed']}个")
|
||||||
logger.warning(f"加载预定义模型时出错: {str(e)}")
|
except Exception as e:
|
||||||
|
logger.warning(f"加载预定义模型时出错: {str(e)}")
|
||||||
|
else:
|
||||||
|
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
|
||||||
|
|
||||||
logger.info("应用程序启动完成")
|
logger.info("应用程序启动完成")
|
||||||
yield
|
yield
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from .tool_model import (
|
|||||||
ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus
|
ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus
|
||||||
)
|
)
|
||||||
from .memory_perceptual_model import MemoryPerceptualModel
|
from .memory_perceptual_model import MemoryPerceptualModel
|
||||||
|
from .skill_model import Skill
|
||||||
from .ontology_scene import OntologyScene
|
from .ontology_scene import OntologyScene
|
||||||
from .ontology_class import OntologyClass
|
from .ontology_class import OntologyClass
|
||||||
from .ontology_scene import OntologyScene
|
from .ontology_scene import OntologyScene
|
||||||
@@ -84,5 +85,6 @@ __all__ = [
|
|||||||
"ExecutionStatus",
|
"ExecutionStatus",
|
||||||
"MemoryPerceptualModel",
|
"MemoryPerceptualModel",
|
||||||
"ModelBase",
|
"ModelBase",
|
||||||
"LoadBalanceStrategy"
|
"LoadBalanceStrategy",
|
||||||
|
"Skill"
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ class AgentConfig(Base):
|
|||||||
memory = Column(JSON, nullable=True, comment="记忆配置")
|
memory = Column(JSON, nullable=True, comment="记忆配置")
|
||||||
variables = Column(JSON, default=list, nullable=True, comment="变量配置")
|
variables = Column(JSON, default=list, nullable=True, comment="变量配置")
|
||||||
tools = Column(JSON, default=dict, nullable=True, comment="工具配置")
|
tools = Column(JSON, default=dict, nullable=True, comment="工具配置")
|
||||||
|
skill_ids = Column(JSON, default=list, nullable=True, comment="关联的技能ID列表")
|
||||||
|
|
||||||
# 多 Agent 相关字段
|
# 多 Agent 相关字段
|
||||||
agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone")
|
agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone")
|
||||||
|
|||||||
37
api/app/models/skill_model.py
Normal file
37
api/app/models/skill_model.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""Skill 模型定义"""
|
||||||
|
import datetime
|
||||||
|
import uuid
|
||||||
|
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID, JSON
|
||||||
|
|
||||||
|
from app.db import Base
|
||||||
|
|
||||||
|
|
||||||
|
class Skill(Base):
|
||||||
|
"""技能模型 - 可以关联工具(内置、MCP、自定义)"""
|
||||||
|
__tablename__ = "skills"
|
||||||
|
|
||||||
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||||
|
name = Column(String, nullable=False, comment="技能名称")
|
||||||
|
description = Column(Text, comment="技能描述")
|
||||||
|
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, index=True, comment="租户ID")
|
||||||
|
|
||||||
|
# 关联的工具
|
||||||
|
tools = Column(JSON, default=list, comment="关联的工具列表")
|
||||||
|
|
||||||
|
# 技能配置
|
||||||
|
config = Column(JSON, default=dict, comment="技能配置")
|
||||||
|
|
||||||
|
# 专属提示词
|
||||||
|
prompt = Column(Text, comment="技能专属提示词")
|
||||||
|
|
||||||
|
# 状态
|
||||||
|
is_active = Column(Boolean, default=True, nullable=False, comment="是否激活")
|
||||||
|
is_public = Column(Boolean, default=False, nullable=False, comment="是否公开到市场")
|
||||||
|
|
||||||
|
# 时间戳
|
||||||
|
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
|
||||||
|
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<Skill(id={self.id}, name={self.name})>"
|
||||||
@@ -235,6 +235,8 @@ class MemoryConfigRepository:
|
|||||||
llm_id=params.llm_id,
|
llm_id=params.llm_id,
|
||||||
embedding_id=params.embedding_id,
|
embedding_id=params.embedding_id,
|
||||||
rerank_id=params.rerank_id,
|
rerank_id=params.rerank_id,
|
||||||
|
reflection_model_id=params.reflection_model_id,
|
||||||
|
emotion_model_id=params.emotion_model_id,
|
||||||
)
|
)
|
||||||
db.add(db_config)
|
db.add(db_config)
|
||||||
db.flush() # 获取自增ID但不提交事务
|
db.flush() # 获取自增ID但不提交事务
|
||||||
|
|||||||
111
api/app/repositories/skill_repository.py
Normal file
111
api/app/repositories/skill_repository.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
"""Skill Repository"""
|
||||||
|
from typing import List, Optional, Tuple, Any
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy import and_, or_
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from app.models.skill_model import Skill
|
||||||
|
from app.schemas.skill_schema import SkillCreate, SkillUpdate
|
||||||
|
|
||||||
|
|
||||||
|
class SkillRepository:
|
||||||
|
"""Skill 数据访问层"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill:
|
||||||
|
"""创建技能"""
|
||||||
|
skill = Skill(
|
||||||
|
**data.model_dump(),
|
||||||
|
tenant_id=tenant_id
|
||||||
|
)
|
||||||
|
db.add(skill)
|
||||||
|
db.flush()
|
||||||
|
return skill
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_by_id(db: Session, skill_id: uuid.UUID, tenant_id: Optional[uuid.UUID] = None) -> Optional[Skill]:
|
||||||
|
"""根据ID获取技能"""
|
||||||
|
query = db.query(Skill).filter(Skill.id == skill_id)
|
||||||
|
if tenant_id:
|
||||||
|
query = query.filter(
|
||||||
|
or_(
|
||||||
|
Skill.tenant_id == tenant_id,
|
||||||
|
Skill.is_public == True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return query.first()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def list_skills(
|
||||||
|
db: Session,
|
||||||
|
tenant_id: uuid.UUID,
|
||||||
|
search: Optional[str] = None,
|
||||||
|
is_active: Optional[bool] = None,
|
||||||
|
is_public: Optional[bool] = None,
|
||||||
|
page: int = 1,
|
||||||
|
pagesize: int = 10
|
||||||
|
) -> tuple[list[type[Skill]], int]:
|
||||||
|
"""列出技能"""
|
||||||
|
filters = [
|
||||||
|
or_(
|
||||||
|
Skill.tenant_id == tenant_id,
|
||||||
|
Skill.is_public == True
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
if search:
|
||||||
|
filters.append(
|
||||||
|
or_(
|
||||||
|
Skill.name.ilike(f"%{search}%"),
|
||||||
|
# Skill.description.ilike(f"%{search}%")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_active is not None:
|
||||||
|
filters.append(Skill.is_active == is_active)
|
||||||
|
|
||||||
|
if is_public is not None:
|
||||||
|
filters.append(Skill.is_public == is_public)
|
||||||
|
|
||||||
|
query = db.query(Skill).filter(and_(*filters))
|
||||||
|
total = query.count()
|
||||||
|
|
||||||
|
skills = query.order_by(Skill.created_at.desc()).offset(
|
||||||
|
(page - 1) * pagesize
|
||||||
|
).limit(pagesize).all()
|
||||||
|
|
||||||
|
return skills, total
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def update(db: Session, skill_id: uuid.UUID, data: SkillUpdate, tenant_id: uuid.UUID) -> Optional[Skill]:
|
||||||
|
"""更新技能"""
|
||||||
|
skill = db.query(Skill).filter(
|
||||||
|
Skill.id == skill_id,
|
||||||
|
Skill.tenant_id == tenant_id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not skill:
|
||||||
|
return None
|
||||||
|
|
||||||
|
update_data = data.model_dump(exclude_unset=True)
|
||||||
|
for key, value in update_data.items():
|
||||||
|
setattr(skill, key, value)
|
||||||
|
|
||||||
|
db.flush()
|
||||||
|
return skill
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def delete(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> bool:
|
||||||
|
"""删除技能"""
|
||||||
|
skill = db.query(Skill).filter(
|
||||||
|
Skill.id == skill_id,
|
||||||
|
Skill.tenant_id == tenant_id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not skill:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# db.delete(skill)
|
||||||
|
skill.is_active = False
|
||||||
|
db.flush()
|
||||||
|
return True
|
||||||
@@ -156,6 +156,9 @@ class AgentConfigCreate(BaseModel):
|
|||||||
description="Agent 可用的工具列表"
|
description="Agent 可用的工具列表"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 技能配置
|
||||||
|
skill_ids: Optional[List[str]] = Field(default=None, description="关联的技能ID列表")
|
||||||
|
|
||||||
|
|
||||||
class AppCreate(BaseModel):
|
class AppCreate(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
@@ -207,6 +210,9 @@ class AgentConfigUpdate(BaseModel):
|
|||||||
|
|
||||||
# 工具配置
|
# 工具配置
|
||||||
tools: Optional[List[ToolConfig]] = Field(default_factory=list, description="工具列表")
|
tools: Optional[List[ToolConfig]] = Field(default_factory=list, description="工具列表")
|
||||||
|
|
||||||
|
# 技能配置
|
||||||
|
skill_ids: Optional[List[str]] = Field(default=None, description="关联的技能ID列表")
|
||||||
|
|
||||||
|
|
||||||
# ---------- Output Schemas ----------
|
# ---------- Output Schemas ----------
|
||||||
@@ -266,6 +272,8 @@ class AgentConfig(BaseModel):
|
|||||||
# 工具配置
|
# 工具配置
|
||||||
tools: Union[List[ToolConfig], Dict[str, ToolOldConfig]] = []
|
tools: Union[List[ToolConfig], Dict[str, ToolOldConfig]] = []
|
||||||
|
|
||||||
|
skill_ids: Optional[List[str]] = []
|
||||||
|
|
||||||
is_active: bool
|
is_active: bool
|
||||||
created_at: datetime.datetime
|
created_at: datetime.datetime
|
||||||
updated_at: datetime.datetime
|
updated_at: datetime.datetime
|
||||||
|
|||||||
@@ -236,6 +236,8 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body,
|
|||||||
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
||||||
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
||||||
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
||||||
|
reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致")
|
||||||
|
emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致")
|
||||||
|
|
||||||
|
|
||||||
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
|
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
|
||||||
|
|||||||
57
api/app/schemas/skill_schema.py
Normal file
57
api/app/schemas/skill_schema.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
"""Skill Schema 定义"""
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from pydantic import BaseModel, Field, field_serializer
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class SkillBase(BaseModel):
|
||||||
|
"""Skill 基础 Schema"""
|
||||||
|
name: str = Field(..., description="技能名称")
|
||||||
|
description: Optional[str] = Field(None, description="技能描述")
|
||||||
|
tools: List[Dict[str, str]] = Field(default_factory=list, description="工具对象列表: [{\"tool_id\": \"xxx\", \"operation\": \"yyy\"}]")
|
||||||
|
config: Dict[str, Any] = Field(default_factory=dict, description="技能配置")
|
||||||
|
prompt: Optional[str] = Field(None, description="技能专属提示词")
|
||||||
|
is_active: bool = Field(True, description="是否激活")
|
||||||
|
is_public: bool = Field(False, description="是否公开到市场")
|
||||||
|
|
||||||
|
|
||||||
|
class SkillCreate(SkillBase):
|
||||||
|
"""创建 Skill"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SkillUpdate(BaseModel):
|
||||||
|
"""更新 Skill"""
|
||||||
|
name: Optional[str] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
tools: Optional[List[Dict[str, str]]] = None
|
||||||
|
config: Optional[Dict[str, Any]] = None
|
||||||
|
prompt: Optional[str] = None
|
||||||
|
is_active: Optional[bool] = None
|
||||||
|
is_public: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Skill(SkillBase):
|
||||||
|
"""Skill 响应 Schema"""
|
||||||
|
id: uuid.UUID
|
||||||
|
tenant_id: uuid.UUID
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
@field_serializer('created_at', 'updated_at')
|
||||||
|
def serialize_datetime_to_timestamp(self, value: datetime) -> int:
|
||||||
|
"""(毫秒级)时间戳"""
|
||||||
|
return int(value.timestamp() * 1000)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
class SkillQuery(BaseModel):
|
||||||
|
"""Skill 查询参数"""
|
||||||
|
search: Optional[str] = None
|
||||||
|
is_active: Optional[bool] = None
|
||||||
|
is_public: Optional[bool] = None
|
||||||
|
page: int = Field(1, ge=1)
|
||||||
|
pagesize: int = Field(10, ge=1, le=100)
|
||||||
@@ -48,6 +48,9 @@ class AgentConfigConverter:
|
|||||||
# 5. 工具配置
|
# 5. 工具配置
|
||||||
if hasattr(config, 'tools') and config.tools:
|
if hasattr(config, 'tools') and config.tools:
|
||||||
result["tools"] = [tool.model_dump() for tool in config.tools]
|
result["tools"] = [tool.model_dump() for tool in config.tools]
|
||||||
|
|
||||||
|
if hasattr(config, "skill_ids") and config.skill_ids:
|
||||||
|
result["skill_ids"] = [skill for skill in config.skill_ids]
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -58,6 +61,7 @@ class AgentConfigConverter:
|
|||||||
memory: Optional[Dict[str, Any]],
|
memory: Optional[Dict[str, Any]],
|
||||||
variables: Optional[list],
|
variables: Optional[list],
|
||||||
tools: Optional[Union[list, Dict[str, Any]]],
|
tools: Optional[Union[list, Dict[str, Any]]],
|
||||||
|
skill_ids: Optional[list]
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
将数据库存储格式转换为 Pydantic 对象
|
将数据库存储格式转换为 Pydantic 对象
|
||||||
@@ -68,6 +72,7 @@ class AgentConfigConverter:
|
|||||||
memory: 记忆配置
|
memory: 记忆配置
|
||||||
variables: 变量配置
|
variables: 变量配置
|
||||||
tools: 工具配置
|
tools: 工具配置
|
||||||
|
skill_ids: 技能 ID 列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
包含 Pydantic 对象的字典
|
包含 Pydantic 对象的字典
|
||||||
@@ -78,6 +83,7 @@ class AgentConfigConverter:
|
|||||||
"memory": MemoryConfig(enabled=True),
|
"memory": MemoryConfig(enabled=True),
|
||||||
"variables": [],
|
"variables": [],
|
||||||
"tools": [],
|
"tools": [],
|
||||||
|
"skill_ids": []
|
||||||
}
|
}
|
||||||
|
|
||||||
# 1. 解析模型参数配置
|
# 1. 解析模型参数配置
|
||||||
@@ -117,5 +123,8 @@ class AgentConfigConverter:
|
|||||||
name: ToolOldConfig(**tool_data)
|
name: ToolOldConfig(**tool_data)
|
||||||
for name, tool_data in tools.items()
|
for name, tool_data in tools.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if skill_ids:
|
||||||
|
result["skill_ids"] = [skill for skill in skill_ids]
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
|
|||||||
memory=agent_cfg.memory,
|
memory=agent_cfg.memory,
|
||||||
variables=agent_cfg.variables,
|
variables=agent_cfg.variables,
|
||||||
tools=agent_cfg.tools,
|
tools=agent_cfg.tools,
|
||||||
|
skill_ids=agent_cfg.skill_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
# 将解析后的字段添加到对象上(用于序列化)
|
# 将解析后的字段添加到对象上(用于序列化)
|
||||||
@@ -34,5 +35,6 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
|
|||||||
agent_cfg.memory = parsed["memory"]
|
agent_cfg.memory = parsed["memory"]
|
||||||
agent_cfg.variables = parsed["variables"]
|
agent_cfg.variables = parsed["variables"]
|
||||||
agent_cfg.tools = parsed["tools"]
|
agent_cfg.tools = parsed["tools"]
|
||||||
|
agent_cfg.skill_ids = parsed["skill_ids"]
|
||||||
|
|
||||||
return agent_cfg
|
return agent_cfg
|
||||||
|
|||||||
@@ -304,6 +304,7 @@ class AppService:
|
|||||||
memory=storage_data.get("memory"),
|
memory=storage_data.get("memory"),
|
||||||
variables=storage_data.get("variables", []),
|
variables=storage_data.get("variables", []),
|
||||||
tools=storage_data.get("tools", []),
|
tools=storage_data.get("tools", []),
|
||||||
|
skill_ids=storage_data.get("skill_ids", []),
|
||||||
is_active=True,
|
is_active=True,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
@@ -907,6 +908,7 @@ class AppService:
|
|||||||
agent_cfg.variables = storage_data.get("variables", [])
|
agent_cfg.variables = storage_data.get("variables", [])
|
||||||
# if data.tools is not None:
|
# if data.tools is not None:
|
||||||
agent_cfg.tools = storage_data.get("tools", [])
|
agent_cfg.tools = storage_data.get("tools", [])
|
||||||
|
agent_cfg.skill_ids = storage_data.get("skill_ids", [])
|
||||||
|
|
||||||
agent_cfg.updated_at = now
|
agent_cfg.updated_at = now
|
||||||
|
|
||||||
|
|||||||
@@ -187,7 +187,7 @@ class AppStatisticsService:
|
|||||||
daily_tokens[date_str] = 0
|
daily_tokens[date_str] = 0
|
||||||
daily_tokens[date_str] += int(tokens)
|
daily_tokens[date_str] += int(tokens)
|
||||||
|
|
||||||
daily_data = [{"date": date, "tokens": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
|
daily_data = [{"date": date, "count": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
|
||||||
total = sum(row["tokens"] for row in daily_data)
|
total = sum(row["count"] for row in daily_data)
|
||||||
|
|
||||||
return {"daily": daily_data, "total": total}
|
return {"daily": daily_data, "total": total}
|
||||||
|
|||||||
@@ -10,6 +10,11 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
|
|
||||||
|
from langchain.tools import tool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
@@ -26,10 +31,8 @@ from app.services.memory_agent_service import MemoryAgentService
|
|||||||
from app.services.model_parameter_merger import ModelParameterMerger
|
from app.services.model_parameter_merger import ModelParameterMerger
|
||||||
from app.services.tool_service import ToolService
|
from app.services.tool_service import ToolService
|
||||||
from app.services.multimodal_service import MultimodalService
|
from app.services.multimodal_service import MultimodalService
|
||||||
from langchain.tools import tool
|
from app.core.agent.agent_middleware import AgentMiddleware
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
class KnowledgeRetrievalInput(BaseModel):
|
class KnowledgeRetrievalInput(BaseModel):
|
||||||
@@ -310,6 +313,7 @@ class DraftRunService:
|
|||||||
tools = []
|
tools = []
|
||||||
|
|
||||||
tool_service = ToolService(self.db)
|
tool_service = ToolService(self.db)
|
||||||
|
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||||
|
|
||||||
# 从配置中获取启用的工具
|
# 从配置中获取启用的工具
|
||||||
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
|
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
|
||||||
@@ -320,9 +324,7 @@ class DraftRunService:
|
|||||||
print(f"tool_config:{tool_config}")
|
print(f"tool_config:{tool_config}")
|
||||||
if tool_config.get("enabled", False):
|
if tool_config.get("enabled", False):
|
||||||
# 根据工具名称查找工具实例
|
# 根据工具名称查找工具实例
|
||||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||||
ToolRepository.get_tenant_id_by_workspace_id(
|
|
||||||
self.db, str(workspace_id)))
|
|
||||||
if tool_instance:
|
if tool_instance:
|
||||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||||
continue
|
continue
|
||||||
@@ -345,6 +347,22 @@ class DraftRunService:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 加载技能关联的工具
|
||||||
|
if hasattr(agent_config, 'skill_ids') and agent_config.skill_ids:
|
||||||
|
middleware = AgentMiddleware(skill_ids=agent_config.skill_ids)
|
||||||
|
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||||
|
tools.extend(skill_tools)
|
||||||
|
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||||
|
|
||||||
|
# 应用动态过滤
|
||||||
|
if skill_configs:
|
||||||
|
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
|
||||||
|
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||||
|
active_prompts = AgentMiddleware.get_active_prompts(
|
||||||
|
activated_skill_ids, skill_configs
|
||||||
|
)
|
||||||
|
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||||
|
|
||||||
# 添加知识库检索工具
|
# 添加知识库检索工具
|
||||||
if agent_config.knowledge_retrieval:
|
if agent_config.knowledge_retrieval:
|
||||||
kb_config = agent_config.knowledge_retrieval
|
kb_config = agent_config.knowledge_retrieval
|
||||||
@@ -558,6 +576,7 @@ class DraftRunService:
|
|||||||
tools = []
|
tools = []
|
||||||
|
|
||||||
tool_service = ToolService(self.db)
|
tool_service = ToolService(self.db)
|
||||||
|
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||||
|
|
||||||
# 从配置中获取启用的工具
|
# 从配置中获取启用的工具
|
||||||
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
|
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
|
||||||
@@ -567,9 +586,7 @@ class DraftRunService:
|
|||||||
# print(f"tool_config:{tool_config}")
|
# print(f"tool_config:{tool_config}")
|
||||||
if tool_config.get("enabled", False):
|
if tool_config.get("enabled", False):
|
||||||
# 根据工具名称查找工具实例
|
# 根据工具名称查找工具实例
|
||||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||||
ToolRepository.get_tenant_id_by_workspace_id(
|
|
||||||
self.db, str(workspace_id)))
|
|
||||||
if tool_instance:
|
if tool_instance:
|
||||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||||
continue
|
continue
|
||||||
@@ -592,6 +609,23 @@ class DraftRunService:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 加载技能关联的工具
|
||||||
|
skill_configs = {}
|
||||||
|
if hasattr(agent_config, 'skill_ids') and agent_config.skill_ids:
|
||||||
|
middleware = AgentMiddleware(skill_ids=agent_config.skill_ids)
|
||||||
|
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||||
|
tools.extend(skill_tools)
|
||||||
|
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||||
|
|
||||||
|
# 应用动态过滤
|
||||||
|
if skill_configs:
|
||||||
|
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
|
||||||
|
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||||
|
active_prompts = AgentMiddleware.get_active_prompts(
|
||||||
|
activated_skill_ids, skill_configs
|
||||||
|
)
|
||||||
|
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||||
|
|
||||||
|
|
||||||
# 添加知识库检索工具
|
# 添加知识库检索工具
|
||||||
if agent_config.knowledge_retrieval:
|
if agent_config.knowledge_retrieval:
|
||||||
@@ -628,7 +662,6 @@ class DraftRunService:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# 4. 创建 LangChain Agent
|
# 4. 创建 LangChain Agent
|
||||||
agent = LangChainAgent(
|
agent = LangChainAgent(
|
||||||
model_name=api_key_config["model_name"],
|
model_name=api_key_config["model_name"],
|
||||||
|
|||||||
@@ -53,7 +53,10 @@ def get_workspace_end_users(
|
|||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
current_user: User
|
current_user: User
|
||||||
) -> List[EndUser]:
|
) -> List[EndUser]:
|
||||||
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)"""
|
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)
|
||||||
|
|
||||||
|
返回结果按 updated_at 从新到旧排序(NULL 值排在最后)
|
||||||
|
"""
|
||||||
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -68,9 +71,14 @@ def get_workspace_end_users(
|
|||||||
app_ids = [app.id for app in apps_orm]
|
app_ids = [app.id for app in apps_orm]
|
||||||
|
|
||||||
# 批量查询所有 end_users(一次查询而非循环查询)
|
# 批量查询所有 end_users(一次查询而非循环查询)
|
||||||
|
# 按 updated_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性
|
||||||
from app.models.end_user_model import EndUser as EndUserModel
|
from app.models.end_user_model import EndUser as EndUserModel
|
||||||
|
from sqlalchemy import desc, nullslast
|
||||||
end_users_orm = db.query(EndUserModel).filter(
|
end_users_orm = db.query(EndUserModel).filter(
|
||||||
EndUserModel.app_id.in_(app_ids)
|
EndUserModel.app_id.in_(app_ids)
|
||||||
|
).order_by(
|
||||||
|
nullslast(desc(EndUserModel.updated_at)),
|
||||||
|
desc(EndUserModel.id)
|
||||||
).all()
|
).all()
|
||||||
|
|
||||||
# 转换为 Pydantic 模型(只在需要时转换)
|
# 转换为 Pydantic 模型(只在需要时转换)
|
||||||
|
|||||||
@@ -129,6 +129,12 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
if not params.rerank_id:
|
if not params.rerank_id:
|
||||||
params.rerank_id = configs.get('rerank')
|
params.rerank_id = configs.get('rerank')
|
||||||
|
|
||||||
|
# reflection_model_id 和 emotion_model_id 默认与 llm_id 一致
|
||||||
|
if not params.reflection_model_id:
|
||||||
|
params.reflection_model_id = params.llm_id
|
||||||
|
if not params.emotion_model_id:
|
||||||
|
params.emotion_model_id = params.llm_id
|
||||||
|
|
||||||
config = MemoryConfigRepository.create(self.db, params)
|
config = MemoryConfigRepository.create(self.db, params)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
return {"affected": 1, "config_id": config.config_id}
|
return {"affected": 1, "config_id": config.config_id}
|
||||||
@@ -203,6 +209,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
"end_user_id": config.end_user_id,
|
"end_user_id": config.end_user_id,
|
||||||
"config_id_old": config_id_old,
|
"config_id_old": config_id_old,
|
||||||
"apply_id": config.apply_id,
|
"apply_id": config.apply_id,
|
||||||
|
"scene_id": config.scene_id,
|
||||||
"llm_id": config.llm_id,
|
"llm_id": config.llm_id,
|
||||||
"embedding_id": config.embedding_id,
|
"embedding_id": config.embedding_id,
|
||||||
"rerank_id": config.rerank_id,
|
"rerank_id": config.rerank_id,
|
||||||
|
|||||||
109
api/app/services/skill_service.py
Normal file
109
api/app/services/skill_service.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
"""Skill Service"""
|
||||||
|
import uuid
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.repositories.skill_repository import SkillRepository
|
||||||
|
from app.schemas.skill_schema import SkillCreate, SkillUpdate
|
||||||
|
from app.models.skill_model import Skill
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.services.tool_service import ToolService
|
||||||
|
|
||||||
|
|
||||||
|
class SkillService:
|
||||||
|
"""Skill 业务逻辑层"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_skill(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill:
|
||||||
|
"""创建技能"""
|
||||||
|
skill = SkillRepository.create(db, data, tenant_id)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(skill)
|
||||||
|
return skill
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_skill(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> Skill:
|
||||||
|
"""获取技能"""
|
||||||
|
try:
|
||||||
|
skill = SkillRepository.get_by_id(db, skill_id, tenant_id)
|
||||||
|
if not skill:
|
||||||
|
raise BusinessException(f"技能{skill_id}不存在", BizCode.NOT_FOUND)
|
||||||
|
return skill
|
||||||
|
except (BusinessException, SQLAlchemyError) as e:
|
||||||
|
db.rollback()
|
||||||
|
raise e
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def list_skills(
|
||||||
|
db: Session,
|
||||||
|
tenant_id: uuid.UUID,
|
||||||
|
search: str = None,
|
||||||
|
is_active: bool = None,
|
||||||
|
is_public: bool = None,
|
||||||
|
page: int = 1,
|
||||||
|
pagesize: int = 10
|
||||||
|
) -> tuple[list[type[Skill]], int]:
|
||||||
|
"""列出技能"""
|
||||||
|
return SkillRepository.list_skills(
|
||||||
|
db, tenant_id, search, is_active, is_public, page, pagesize
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def update_skill(db: Session, skill_id: uuid.UUID, data: SkillUpdate, tenant_id: uuid.UUID) -> Skill:
|
||||||
|
"""更新技能"""
|
||||||
|
try:
|
||||||
|
skill = SkillRepository.update(db, skill_id, data, tenant_id)
|
||||||
|
if not skill:
|
||||||
|
raise BusinessException(f"技能{skill_id}不存在或无权限", BizCode.NOT_FOUND)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(skill)
|
||||||
|
return skill
|
||||||
|
except (BusinessException, SQLAlchemyError) as e:
|
||||||
|
db.rollback()
|
||||||
|
raise e
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def delete_skill(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> bool:
|
||||||
|
"""删除技能"""
|
||||||
|
try:
|
||||||
|
success = SkillRepository.delete(db, skill_id, tenant_id)
|
||||||
|
if not success:
|
||||||
|
raise BusinessException("技能不存在或无权限", BizCode.NOT_FOUND)
|
||||||
|
db.commit()
|
||||||
|
return True
|
||||||
|
except (BusinessException, SQLAlchemyError) as e:
|
||||||
|
db.rollback()
|
||||||
|
raise e
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_skill_tools(db: Session, skill_ids: List[str], tenant_id: uuid.UUID) -> tuple[List, dict[str, str]]:
|
||||||
|
"""加载技能关联的工具
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(tools, tool_to_skill_map) - 工具列表和工具到技能的映射
|
||||||
|
"""
|
||||||
|
tools = []
|
||||||
|
tool_to_skill_map = {} # {tool_name: skill_id}
|
||||||
|
tool_service = ToolService(db)
|
||||||
|
|
||||||
|
for skill_id in skill_ids:
|
||||||
|
try:
|
||||||
|
skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id)
|
||||||
|
if skill and skill.is_active:
|
||||||
|
# 加载技能关联的工具
|
||||||
|
for tool_config in skill.tools:
|
||||||
|
tool = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||||
|
if tool:
|
||||||
|
langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None))
|
||||||
|
tools.append(langchain_tool)
|
||||||
|
# 建立工具到技能的映射
|
||||||
|
tool_name = getattr(langchain_tool, 'name', str(id(langchain_tool)))
|
||||||
|
tool_to_skill_map[tool_name] = skill_id
|
||||||
|
except Exception as e:
|
||||||
|
print(f"加载技能 {skill_id} 的工具时出错: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
return tools, tool_to_skill_map
|
||||||
288
api/app/tasks.py
288
api/app/tasks.py
@@ -1069,6 +1069,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
|||||||
f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务")
|
f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
db.rollback() # Rollback failed transaction to allow next query
|
||||||
api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
|
api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
|
||||||
all_reflection_results.append({
|
all_reflection_results.append({
|
||||||
"workspace_id": str(workspace_id),
|
"workspace_id": str(workspace_id),
|
||||||
@@ -1207,3 +1208,290 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
|||||||
return result
|
return result
|
||||||
finally:
|
finally:
|
||||||
loop.close()
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Long-term Memory Storage Tasks (Batched Write Strategies)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@celery_app.task(name="app.core.memory.agent.long_term_storage.window", bind=True)
|
||||||
|
def long_term_storage_window_task(
|
||||||
|
self,
|
||||||
|
end_user_id: str,
|
||||||
|
langchain_messages: List[Dict[str, Any]],
|
||||||
|
config_id: str,
|
||||||
|
scope: int = 6
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Celery task for window-based long-term memory storage.
|
||||||
|
|
||||||
|
Accumulates messages in Redis buffer until window size (scope) is reached,
|
||||||
|
then writes batched messages to Neo4j.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: End user identifier
|
||||||
|
langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}]
|
||||||
|
config_id: Memory configuration ID
|
||||||
|
scope: Window size (number of messages before triggering write)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict containing task status and metadata
|
||||||
|
"""
|
||||||
|
from app.core.logging_config import get_logger
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
logger.info(f"[LONG_TERM_WINDOW] Starting task - end_user_id={end_user_id}, scope={scope}")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
async def _run() -> Dict[str, Any]:
|
||||||
|
from app.core.memory.agent.langgraph_graph.routing.write_router import window_dialogue
|
||||||
|
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
|
||||||
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
|
db = next(get_db())
|
||||||
|
try:
|
||||||
|
# Save to Redis buffer first
|
||||||
|
write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
|
||||||
|
|
||||||
|
# Load memory config
|
||||||
|
config_service = MemoryConfigService(db)
|
||||||
|
memory_config = config_service.load_memory_config(
|
||||||
|
config_id=config_id,
|
||||||
|
service_name="LongTermStorageTask"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute window-based dialogue storage
|
||||||
|
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||||
|
|
||||||
|
return {"status": "SUCCESS", "strategy": "window", "scope": scope}
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
try:
|
||||||
|
import nest_asyncio
|
||||||
|
nest_asyncio.apply()
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
if loop.is_closed():
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
except RuntimeError:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = loop.run_until_complete(_run())
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
|
logger.info(f"[LONG_TERM_WINDOW] Task completed - elapsed_time={elapsed_time:.2f}s")
|
||||||
|
|
||||||
|
return {
|
||||||
|
**result,
|
||||||
|
"end_user_id": end_user_id,
|
||||||
|
"config_id": config_id,
|
||||||
|
"elapsed_time": elapsed_time,
|
||||||
|
"task_id": self.request.id
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
logger.error(f"[LONG_TERM_WINDOW] Task failed - error={str(e)}", exc_info=True)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "FAILURE",
|
||||||
|
"strategy": "window",
|
||||||
|
"error": str(e),
|
||||||
|
"end_user_id": end_user_id,
|
||||||
|
"config_id": config_id,
|
||||||
|
"elapsed_time": elapsed_time,
|
||||||
|
"task_id": self.request.id
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# @celery_app.task(name="app.core.memory.agent.long_term_storage.time", bind=True)
|
||||||
|
# def long_term_storage_time_task(
|
||||||
|
# self,
|
||||||
|
# end_user_id: str,
|
||||||
|
# config_id: str,
|
||||||
|
# time_window: int = 5
|
||||||
|
# ) -> Dict[str, Any]:
|
||||||
|
# """Celery task for time-based long-term memory storage.
|
||||||
|
|
||||||
|
# Retrieves recent sessions from Redis within time window and writes to Neo4j.
|
||||||
|
|
||||||
|
# Args:
|
||||||
|
# end_user_id: End user identifier
|
||||||
|
# config_id: Memory configuration ID
|
||||||
|
# time_window: Time window in minutes for retrieving recent sessions
|
||||||
|
|
||||||
|
# Returns:
|
||||||
|
# Dict containing task status and metadata
|
||||||
|
# """
|
||||||
|
# from app.core.logging_config import get_logger
|
||||||
|
# logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# logger.info(f"[LONG_TERM_TIME] Starting task - end_user_id={end_user_id}, time_window={time_window}")
|
||||||
|
# start_time = time.time()
|
||||||
|
|
||||||
|
# async def _run() -> Dict[str, Any]:
|
||||||
|
# from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage
|
||||||
|
# from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
|
# db = next(get_db())
|
||||||
|
# try:
|
||||||
|
# # Load memory config
|
||||||
|
# config_service = MemoryConfigService(db)
|
||||||
|
# memory_config = config_service.load_memory_config(
|
||||||
|
# config_id=config_id,
|
||||||
|
# service_name="LongTermStorageTask"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # Execute time-based storage
|
||||||
|
# await memory_long_term_storage(end_user_id, memory_config, time_window)
|
||||||
|
|
||||||
|
# return {"status": "SUCCESS", "strategy": "time", "time_window": time_window}
|
||||||
|
# finally:
|
||||||
|
# db.close()
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# import nest_asyncio
|
||||||
|
# nest_asyncio.apply()
|
||||||
|
# except ImportError:
|
||||||
|
# pass
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# loop = asyncio.get_event_loop()
|
||||||
|
# if loop.is_closed():
|
||||||
|
# loop = asyncio.new_event_loop()
|
||||||
|
# asyncio.set_event_loop(loop)
|
||||||
|
# except RuntimeError:
|
||||||
|
# loop = asyncio.new_event_loop()
|
||||||
|
# asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# result = loop.run_until_complete(_run())
|
||||||
|
# elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
|
# logger.info(f"[LONG_TERM_TIME] Task completed - elapsed_time={elapsed_time:.2f}s")
|
||||||
|
|
||||||
|
# return {
|
||||||
|
# **result,
|
||||||
|
# "end_user_id": end_user_id,
|
||||||
|
# "config_id": config_id,
|
||||||
|
# "elapsed_time": elapsed_time,
|
||||||
|
# "task_id": self.request.id
|
||||||
|
# }
|
||||||
|
# except Exception as e:
|
||||||
|
# elapsed_time = time.time() - start_time
|
||||||
|
# logger.error(f"[LONG_TERM_TIME] Task failed - error={str(e)}", exc_info=True)
|
||||||
|
|
||||||
|
# return {
|
||||||
|
# "status": "FAILURE",
|
||||||
|
# "strategy": "time",
|
||||||
|
# "error": str(e),
|
||||||
|
# "end_user_id": end_user_id,
|
||||||
|
# "config_id": config_id,
|
||||||
|
# "elapsed_time": elapsed_time,
|
||||||
|
# "task_id": self.request.id
|
||||||
|
# }
|
||||||
|
|
||||||
|
|
||||||
|
# @celery_app.task(name="app.core.memory.agent.long_term_storage.aggregate", bind=True)
|
||||||
|
# def long_term_storage_aggregate_task(
|
||||||
|
# self,
|
||||||
|
# end_user_id: str,
|
||||||
|
# langchain_messages: List[Dict[str, Any]],
|
||||||
|
# config_id: str
|
||||||
|
# ) -> Dict[str, Any]:
|
||||||
|
# """Celery task for aggregate-based long-term memory storage.
|
||||||
|
|
||||||
|
# Uses LLM to determine if new messages describe the same event as history.
|
||||||
|
# Only writes to Neo4j if messages represent new information (not duplicates).
|
||||||
|
|
||||||
|
# Args:
|
||||||
|
# end_user_id: End user identifier
|
||||||
|
# langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}]
|
||||||
|
# config_id: Memory configuration ID
|
||||||
|
|
||||||
|
# Returns:
|
||||||
|
# Dict containing task status, is_same_event flag, and metadata
|
||||||
|
# """
|
||||||
|
# from app.core.logging_config import get_logger
|
||||||
|
# logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# logger.info(f"[LONG_TERM_AGGREGATE] Starting task - end_user_id={end_user_id}")
|
||||||
|
# start_time = time.time()
|
||||||
|
|
||||||
|
# async def _run() -> Dict[str, Any]:
|
||||||
|
# from app.core.memory.agent.langgraph_graph.routing.write_router import aggregate_judgment
|
||||||
|
# from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
|
||||||
|
# from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
|
# from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
|
# db = next(get_db())
|
||||||
|
# try:
|
||||||
|
# # Save to Redis buffer first
|
||||||
|
# write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
|
||||||
|
|
||||||
|
# # Load memory config
|
||||||
|
# config_service = MemoryConfigService(db)
|
||||||
|
# memory_config = config_service.load_memory_config(
|
||||||
|
# config_id=config_id,
|
||||||
|
# service_name="LongTermStorageTask"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # Execute aggregate judgment
|
||||||
|
# result = await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||||
|
|
||||||
|
# return {
|
||||||
|
# "status": "SUCCESS",
|
||||||
|
# "strategy": "aggregate",
|
||||||
|
# "is_same_event": result.get("is_same_event", False),
|
||||||
|
# "wrote_to_neo4j": not result.get("is_same_event", False)
|
||||||
|
# }
|
||||||
|
# finally:
|
||||||
|
# db.close()
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# import nest_asyncio
|
||||||
|
# nest_asyncio.apply()
|
||||||
|
# except ImportError:
|
||||||
|
# pass
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# loop = asyncio.get_event_loop()
|
||||||
|
# if loop.is_closed():
|
||||||
|
# loop = asyncio.new_event_loop()
|
||||||
|
# asyncio.set_event_loop(loop)
|
||||||
|
# except RuntimeError:
|
||||||
|
# loop = asyncio.new_event_loop()
|
||||||
|
# asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# result = loop.run_until_complete(_run())
|
||||||
|
# elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
|
# logger.info(f"[LONG_TERM_AGGREGATE] Task completed - is_same_event={result.get('is_same_event')}, elapsed_time={elapsed_time:.2f}s")
|
||||||
|
|
||||||
|
# return {
|
||||||
|
# **result,
|
||||||
|
# "end_user_id": end_user_id,
|
||||||
|
# "config_id": config_id,
|
||||||
|
# "elapsed_time": elapsed_time,
|
||||||
|
# "task_id": self.request.id
|
||||||
|
# }
|
||||||
|
# except Exception as e:
|
||||||
|
# elapsed_time = time.time() - start_time
|
||||||
|
# logger.error(f"[LONG_TERM_AGGREGATE] Task failed - error={str(e)}", exc_info=True)
|
||||||
|
|
||||||
|
# return {
|
||||||
|
# "status": "FAILURE",
|
||||||
|
# "strategy": "aggregate",
|
||||||
|
# "error": str(e),
|
||||||
|
# "end_user_id": end_user_id,
|
||||||
|
# "config_id": config_id,
|
||||||
|
# "elapsed_time": elapsed_time,
|
||||||
|
# "task_id": self.request.id
|
||||||
|
# }
|
||||||
|
|||||||
50
api/migrations/versions/e7c7afa249d1_202602041355.py
Normal file
50
api/migrations/versions/e7c7afa249d1_202602041355.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""202602041355
|
||||||
|
|
||||||
|
Revision ID: e7c7afa249d1
|
||||||
|
Revises: 9def72f79398
|
||||||
|
Create Date: 2026-02-04 13:55:22.284981
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'e7c7afa249d1'
|
||||||
|
down_revision: Union[str, None] = '9def72f79398'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('skills',
|
||||||
|
sa.Column('id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('name', sa.String(), nullable=False, comment='技能名称'),
|
||||||
|
sa.Column('description', sa.Text(), nullable=True, comment='技能描述'),
|
||||||
|
sa.Column('tenant_id', sa.UUID(), nullable=False, comment='租户ID'),
|
||||||
|
sa.Column('tools', postgresql.JSON(astext_type=sa.Text()), nullable=True, comment='关联的工具列表'),
|
||||||
|
sa.Column('config', postgresql.JSON(astext_type=sa.Text()), nullable=True, comment='技能配置'),
|
||||||
|
sa.Column('prompt', sa.Text(), nullable=True, comment='技能专属提示词'),
|
||||||
|
sa.Column('is_active', sa.Boolean(), nullable=False, comment='是否激活'),
|
||||||
|
sa.Column('is_public', sa.Boolean(), nullable=False, comment='是否公开到市场'),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=True, comment='创建时间'),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), nullable=True, comment='更新时间'),
|
||||||
|
sa.ForeignKeyConstraint(['tenant_id'], ['tenants.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_skills_id'), 'skills', ['id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_skills_tenant_id'), 'skills', ['tenant_id'], unique=False)
|
||||||
|
op.add_column('agent_configs', sa.Column('skill_ids', postgresql.JSON(astext_type=sa.Text()), nullable=True, comment='关联的技能ID列表'))
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_column('agent_configs', 'skill_ids')
|
||||||
|
op.drop_index(op.f('ix_skills_tenant_id'), table_name='skills')
|
||||||
|
op.drop_index(op.f('ix_skills_id'), table_name='skills')
|
||||||
|
op.drop_table('skills')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -142,7 +142,7 @@ const PageScrollList = forwardRef(<T, Q = Record<string, unknown>>({
|
|||||||
dataLength={data.length}
|
dataLength={data.length}
|
||||||
next={loadMoreData}
|
next={loadMoreData}
|
||||||
hasMore={hasMore}
|
hasMore={hasMore}
|
||||||
loader={needLoading ? <PageLoading /> : undefined}
|
loader={loading && needLoading ? <PageLoading /> : false}
|
||||||
// endMessage={<Divider plain>It is all, nothing more 🤐</Divider>}
|
// endMessage={<Divider plain>It is all, nothing more 🤐</Divider>}
|
||||||
scrollableTarget="scrollableDiv"
|
scrollableTarget="scrollableDiv"
|
||||||
className='rb:h-full!'
|
className='rb:h-full!'
|
||||||
|
|||||||
@@ -180,7 +180,4 @@ body {
|
|||||||
.x6-node foreignObject > body {
|
.x6-node foreignObject > body {
|
||||||
min-height: 100%;
|
min-height: 100%;
|
||||||
max-height: 100%;
|
max-height: 100%;
|
||||||
}
|
|
||||||
#scrollableDiv .infinite-scroll-component__outerdiv {
|
|
||||||
height: 100%;
|
|
||||||
}
|
}
|
||||||
@@ -21,6 +21,7 @@ import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext
|
|||||||
import InitialValuePlugin from './plugin/InitialValuePlugin'
|
import InitialValuePlugin from './plugin/InitialValuePlugin'
|
||||||
import LineBreakPlugin from './plugin/LineBreakPlugin';
|
import LineBreakPlugin from './plugin/LineBreakPlugin';
|
||||||
import InsertTextPlugin from './plugin/InsertTextPlugin';
|
import InsertTextPlugin from './plugin/InsertTextPlugin';
|
||||||
|
import EditablePlugin from './plugin/EditablePlugin';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Editor ref methods exposed to parent components
|
* Editor ref methods exposed to parent components
|
||||||
@@ -50,6 +51,7 @@ interface LexicalEditorProps {
|
|||||||
onChange?: (value: string) => void;
|
onChange?: (value: string) => void;
|
||||||
/** Editor height in pixels */
|
/** Editor height in pixels */
|
||||||
height?: number;
|
height?: number;
|
||||||
|
disabled?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -71,6 +73,7 @@ const EditorContent = forwardRef<EditorRef, LexicalEditorProps>(({
|
|||||||
value,
|
value,
|
||||||
placeholder = "Please enter content...",
|
placeholder = "Please enter content...",
|
||||||
onChange,
|
onChange,
|
||||||
|
disabled
|
||||||
}, ref) => {
|
}, ref) => {
|
||||||
const [editor] = useLexicalComposerContext();
|
const [editor] = useLexicalComposerContext();
|
||||||
|
|
||||||
@@ -132,7 +135,11 @@ const EditorContent = forwardRef<EditorRef, LexicalEditorProps>(({
|
|||||||
<RichTextPlugin
|
<RichTextPlugin
|
||||||
contentEditable={
|
contentEditable={
|
||||||
<ContentEditable
|
<ContentEditable
|
||||||
className={clsx("rb:outline-none rb:resize-none rb:text-[14px] rb:leading-5 rb:px-4 rb:py-5 rb:bg-[#FBFDFF] rb:border rb:border-[#DFE4ED] rb:rounded-lg rb:overflow-auto", className)}
|
className={clsx(
|
||||||
|
"rb:outline-none rb:resize-none rb:text-[14px] rb:leading-5 rb:px-4 rb:py-5 rb:bg-[#FBFDFF] rb:border rb:border-[#DFE4ED] rb:rounded-lg rb:overflow-auto",
|
||||||
|
disabled && "rb:cursor-not-allowed rb:bg-[#F6F8FC] rb:text-[#5B6167]",
|
||||||
|
className
|
||||||
|
)}
|
||||||
/>
|
/>
|
||||||
}
|
}
|
||||||
placeholder={
|
placeholder={
|
||||||
@@ -145,6 +152,7 @@ const EditorContent = forwardRef<EditorRef, LexicalEditorProps>(({
|
|||||||
<LineBreakPlugin onChange={onChange} />
|
<LineBreakPlugin onChange={onChange} />
|
||||||
<InitialValuePlugin value={value} />
|
<InitialValuePlugin value={value} />
|
||||||
<InsertTextPlugin />
|
<InsertTextPlugin />
|
||||||
|
<EditablePlugin disabled={disabled} />
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
@@ -158,6 +166,7 @@ const Editor = forwardRef<EditorRef, LexicalEditorProps>((props, ref) => {
|
|||||||
namespace: 'Editor',
|
namespace: 'Editor',
|
||||||
theme,
|
theme,
|
||||||
nodes: [],
|
nodes: [],
|
||||||
|
editable: !props.disabled,
|
||||||
onError: (error: Error) => {
|
onError: (error: Error) => {
|
||||||
console.error(error);
|
console.error(error);
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -0,0 +1,48 @@
|
|||||||
|
/*
|
||||||
|
* @Author: ZhaoYing
|
||||||
|
* @Date: 2026-02-04 11:20:49
|
||||||
|
* @Last Modified by: ZhaoYing
|
||||||
|
* @Last Modified time: 2026-02-04 11:20:49
|
||||||
|
*/
|
||||||
|
import { useEffect } from 'react';
|
||||||
|
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Props for the EditablePlugin component
|
||||||
|
*/
|
||||||
|
interface EditablePluginProps {
|
||||||
|
/** Whether the editor should be disabled (read-only mode) */
|
||||||
|
disabled?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* EditablePlugin - A Lexical editor plugin that controls the editable state of the editor
|
||||||
|
*
|
||||||
|
* This plugin allows you to dynamically toggle between editable and read-only modes.
|
||||||
|
* When disabled is true, the editor becomes read-only and users cannot modify content.
|
||||||
|
* When disabled is false or undefined, the editor is fully editable.
|
||||||
|
*
|
||||||
|
* @param {EditablePluginProps} props - Component props
|
||||||
|
* @param {boolean} [props.disabled] - Controls whether the editor is in read-only mode
|
||||||
|
* @returns {null} This plugin doesn't render any UI elements
|
||||||
|
*
|
||||||
|
* @example
|
||||||
|
* ```tsx
|
||||||
|
* <LexicalComposer>
|
||||||
|
* <EditablePlugin disabled={isReadOnly} />
|
||||||
|
* </LexicalComposer>
|
||||||
|
* ```
|
||||||
|
*/
|
||||||
|
export default function EditablePlugin({ disabled }: EditablePluginProps) {
|
||||||
|
// Get the editor instance from Lexical composer context
|
||||||
|
const [editor] = useLexicalComposerContext();
|
||||||
|
|
||||||
|
// Update editor's editable state whenever the disabled prop changes
|
||||||
|
useEffect(() => {
|
||||||
|
// Set editor to editable when disabled is false, read-only when disabled is true
|
||||||
|
editor.setEditable(!disabled);
|
||||||
|
}, [editor, disabled]);
|
||||||
|
|
||||||
|
// This plugin doesn't render any UI, it only manages editor state
|
||||||
|
return null;
|
||||||
|
}
|
||||||
@@ -156,9 +156,9 @@ const Prompt: FC<{ editVo: HistoryItem | null; refresh: () => void; }> = ({ edit
|
|||||||
currentPromptValueRef.current = undefined;
|
currentPromptValueRef.current = undefined;
|
||||||
setChatList([])
|
setChatList([])
|
||||||
refresh()
|
refresh()
|
||||||
|
updateSession()
|
||||||
}
|
}
|
||||||
|
|
||||||
console.log(values)
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Form form={form}>
|
<Form form={form}>
|
||||||
@@ -217,12 +217,13 @@ const Prompt: FC<{ editVo: HistoryItem | null; refresh: () => void; }> = ({ edit
|
|||||||
ref={editorRef}
|
ref={editorRef}
|
||||||
placeholder={t('prompt.promptPlaceholder')}
|
placeholder={t('prompt.promptPlaceholder')}
|
||||||
className="rb:h-[calc(100vh-260px)]"
|
className="rb:h-[calc(100vh-260px)]"
|
||||||
|
disabled={loading}
|
||||||
// onChange={(value) => form.setFieldValue('current_prompt', value)}
|
// onChange={(value) => form.setFieldValue('current_prompt', value)}
|
||||||
/>
|
/>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<div className="rb:grid rb:grid-cols-2 rb:gap-4 rb:mt-6">
|
<div className="rb:grid rb:grid-cols-2 rb:gap-4 rb:mt-6">
|
||||||
<Button type="primary" block disabled={!values?.current_prompt} onClick={handleSave}>{t('common.save')}</Button>
|
<Button type="primary" block disabled={!values?.current_prompt || loading} onClick={handleSave}>{t('common.save')}</Button>
|
||||||
<Button block disabled={!values?.current_prompt} onClick={handleCopy}>{t('common.copy')}</Button>
|
<Button block disabled={!values?.current_prompt || loading} onClick={handleCopy}>{t('common.copy')}</Button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -103,6 +103,8 @@ const SpaceModal = forwardRef<SpaceModalRef, SpaceModalProps>(({
|
|||||||
}).catch(() => {
|
}).catch(() => {
|
||||||
handleUpdate(formData)
|
handleUpdate(formData)
|
||||||
})
|
})
|
||||||
|
} else {
|
||||||
|
handleUpdate(formData)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -158,6 +160,7 @@ const SpaceModal = forwardRef<SpaceModalRef, SpaceModalProps>(({
|
|||||||
label={t('space.spaceIcon')}
|
label={t('space.spaceIcon')}
|
||||||
valuePropName="fileList"
|
valuePropName="fileList"
|
||||||
hidden={currentStep === 1}
|
hidden={currentStep === 1}
|
||||||
|
rules={[{ required: true, message: t('common.selectPlaceholder', { title: t('space.spaceIcon') }) }]}
|
||||||
>
|
>
|
||||||
<UploadImages />
|
<UploadImages />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
|
|||||||
@@ -242,7 +242,7 @@ const Editor: FC<LexicalEditorProps> =({
|
|||||||
{enableLineNumbers && <LineNumberPlugin />}
|
{enableLineNumbers && <LineNumberPlugin />}
|
||||||
<AutocompletePlugin options={options} enableJinja2={enableJinja2} />
|
<AutocompletePlugin options={options} enableJinja2={enableJinja2} />
|
||||||
<CharacterCountPlugin setCount={(count) => { setCount(count) }} onChange={onChange} />
|
<CharacterCountPlugin setCount={(count) => { setCount(count) }} onChange={onChange} />
|
||||||
<InitialValuePlugin value={value} options={options} enableJinja2={enableJinja2} />
|
<InitialValuePlugin key={language} value={value} options={options} enableLineNumbers={enableLineNumbers} />
|
||||||
{enableLineNumbers && <BlurPlugin />}
|
{enableLineNumbers && <BlurPlugin />}
|
||||||
</div>
|
</div>
|
||||||
</LexicalComposer>
|
</LexicalComposer>
|
||||||
|
|||||||
@@ -16,6 +16,12 @@ export default function BlurPlugin() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查是否是粘贴操作导致的焦点变化
|
||||||
|
const relatedTarget = e.relatedTarget as HTMLElement;
|
||||||
|
if (!relatedTarget || relatedTarget === document.body) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
editor.update(() => {
|
editor.update(() => {
|
||||||
$setSelection(null);
|
$setSelection(null);
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -8,12 +8,13 @@ import { type Suggestion } from '../plugin/AutocompletePlugin'
|
|||||||
interface InitialValuePluginProps {
|
interface InitialValuePluginProps {
|
||||||
value: string;
|
value: string;
|
||||||
options?: Suggestion[];
|
options?: Suggestion[];
|
||||||
enableJinja2?: boolean;
|
enableLineNumbers?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options = [], enableJinja2 = false }) => {
|
const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options = [], enableLineNumbers = false }) => {
|
||||||
const [editor] = useLexicalComposerContext();
|
const [editor] = useLexicalComposerContext();
|
||||||
const prevValueRef = useRef<string>('');
|
const prevValueRef = useRef<string>('');
|
||||||
|
const prevEnableLineNumbersRef = useRef<boolean>(enableLineNumbers);
|
||||||
const isUserInputRef = useRef(false);
|
const isUserInputRef = useRef(false);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@@ -32,7 +33,7 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
|
|||||||
}, [editor]);
|
}, [editor]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (value !== prevValueRef.current && !isUserInputRef.current) {
|
if ((value !== prevValueRef.current || enableLineNumbers !== prevEnableLineNumbersRef.current) && !isUserInputRef.current) {
|
||||||
queueMicrotask(() => {
|
queueMicrotask(() => {
|
||||||
editor.update(() => {
|
editor.update(() => {
|
||||||
const root = $getRoot();
|
const root = $getRoot();
|
||||||
@@ -40,7 +41,7 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
|
|||||||
|
|
||||||
const parts = value.split(/(\{\{[^}]+\}\})/);
|
const parts = value.split(/(\{\{[^}]+\}\})/);
|
||||||
|
|
||||||
if (enableJinja2) {
|
if (enableLineNumbers) {
|
||||||
// Handle newlines properly in Jinja2 mode
|
// Handle newlines properly in Jinja2 mode
|
||||||
const lines = value.split('\n');
|
const lines = value.split('\n');
|
||||||
lines.forEach((line) => {
|
lines.forEach((line) => {
|
||||||
@@ -104,8 +105,9 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
|
|||||||
}
|
}
|
||||||
|
|
||||||
prevValueRef.current = value;
|
prevValueRef.current = value;
|
||||||
|
prevEnableLineNumbersRef.current = enableLineNumbers;
|
||||||
isUserInputRef.current = false;
|
isUserInputRef.current = false;
|
||||||
}, [value, options, editor, enableJinja2]);
|
}, [value, options, editor, enableLineNumbers]);
|
||||||
|
|
||||||
return null;
|
return null;
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import { useEffect } from 'react';
|
import { useEffect, useRef } from 'react';
|
||||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
|
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
|
||||||
import { TextNode, $createTextNode, $getSelection, $isRangeSelection } from 'lexical';
|
import { TextNode, $createTextNode, $getSelection, $isRangeSelection, COMMAND_PRIORITY_LOW, PASTE_COMMAND } from 'lexical';
|
||||||
|
|
||||||
const JS_KEYWORDS = new Set([
|
const JS_KEYWORDS = new Set([
|
||||||
'async', 'await', 'break', 'case', 'catch', 'class', 'const', 'continue', 'debugger', 'default',
|
'async', 'await', 'break', 'case', 'catch', 'class', 'const', 'continue', 'debugger', 'default',
|
||||||
@@ -11,13 +11,31 @@ const JS_KEYWORDS = new Set([
|
|||||||
|
|
||||||
const JavaScriptHighlightPlugin = () => {
|
const JavaScriptHighlightPlugin = () => {
|
||||||
const [editor] = useLexicalComposerContext();
|
const [editor] = useLexicalComposerContext();
|
||||||
|
const isPastingRef = useRef(false);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
return editor.registerCommand(
|
||||||
|
PASTE_COMMAND,
|
||||||
|
() => {
|
||||||
|
isPastingRef.current = true;
|
||||||
|
setTimeout(() => {
|
||||||
|
isPastingRef.current = false;
|
||||||
|
}, 100);
|
||||||
|
return false;
|
||||||
|
},
|
||||||
|
COMMAND_PRIORITY_LOW
|
||||||
|
);
|
||||||
|
}, [editor]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
return editor.registerNodeTransform(TextNode, (textNode: TextNode) => {
|
return editor.registerNodeTransform(TextNode, (textNode: TextNode) => {
|
||||||
|
if (isPastingRef.current) return;
|
||||||
|
|
||||||
const text = textNode.getTextContent();
|
const text = textNode.getTextContent();
|
||||||
|
|
||||||
if (textNode.hasFormat('code')) return;
|
if (textNode.hasFormat('code')) return;
|
||||||
if (!needsHighlight(text)) return;
|
if (!needsHighlight(text)) return;
|
||||||
|
if (textNode.getStyle()) return;
|
||||||
|
|
||||||
const parent = textNode.getParent();
|
const parent = textNode.getParent();
|
||||||
if (!parent) return;
|
if (!parent) return;
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import { useEffect } from 'react';
|
import { useEffect, useRef } from 'react';
|
||||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
|
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
|
||||||
import { TextNode, $createTextNode, $getSelection, $isRangeSelection } from 'lexical';
|
import { TextNode, $createTextNode, $getSelection, $isRangeSelection, COMMAND_PRIORITY_LOW, PASTE_COMMAND } from 'lexical';
|
||||||
|
|
||||||
const PYTHON_KEYWORDS = new Set([
|
const PYTHON_KEYWORDS = new Set([
|
||||||
'False', 'None', 'True', 'and', 'as', 'assert', 'async', 'await', 'break', 'class', 'continue',
|
'False', 'None', 'True', 'and', 'as', 'assert', 'async', 'await', 'break', 'class', 'continue',
|
||||||
@@ -11,12 +11,30 @@ const PYTHON_KEYWORDS = new Set([
|
|||||||
|
|
||||||
const Python3HighlightPlugin = () => {
|
const Python3HighlightPlugin = () => {
|
||||||
const [editor] = useLexicalComposerContext();
|
const [editor] = useLexicalComposerContext();
|
||||||
|
const isPastingRef = useRef(false);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
return editor.registerCommand(
|
||||||
|
PASTE_COMMAND,
|
||||||
|
() => {
|
||||||
|
isPastingRef.current = true;
|
||||||
|
setTimeout(() => {
|
||||||
|
isPastingRef.current = false;
|
||||||
|
}, 100);
|
||||||
|
return false;
|
||||||
|
},
|
||||||
|
COMMAND_PRIORITY_LOW
|
||||||
|
);
|
||||||
|
}, [editor]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
return editor.registerNodeTransform(TextNode, (textNode: TextNode) => {
|
return editor.registerNodeTransform(TextNode, (textNode: TextNode) => {
|
||||||
|
if (isPastingRef.current) return;
|
||||||
|
|
||||||
const text = textNode.getTextContent();
|
const text = textNode.getTextContent();
|
||||||
|
|
||||||
if (textNode.hasFormat('code')) return;
|
if (textNode.hasFormat('code')) return;
|
||||||
|
if (textNode.getStyle()) return;
|
||||||
if (!needsHighlight(text)) return;
|
if (!needsHighlight(text)) return;
|
||||||
|
|
||||||
const parent = textNode.getParent();
|
const parent = textNode.getParent();
|
||||||
|
|||||||
@@ -33,7 +33,6 @@ const codeTemplate = {
|
|||||||
const CodeExecution: FC<CodeExecutionProps> = ({ options }) => {
|
const CodeExecution: FC<CodeExecutionProps> = ({ options }) => {
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
const form = Form.useFormInstance()
|
const form = Form.useFormInstance()
|
||||||
const values = Form.useWatch([], form) || {}
|
|
||||||
|
|
||||||
const handleRefresh = () => {
|
const handleRefresh = () => {
|
||||||
const code = form.getFieldValue('code') || ''
|
const code = form.getFieldValue('code') || ''
|
||||||
@@ -66,7 +65,6 @@ const CodeExecution: FC<CodeExecutionProps> = ({ options }) => {
|
|||||||
form.setFieldValue('code', newTemplate)
|
form.setFieldValue('code', newTemplate)
|
||||||
}
|
}
|
||||||
const handleChangeLanguage = (value: string) => {
|
const handleChangeLanguage = (value: string) => {
|
||||||
form.setFieldValue('code', codeTemplate[value as keyof typeof codeTemplate])
|
|
||||||
form.setFieldsValue({
|
form.setFieldsValue({
|
||||||
input_variables: [{ name: 'arg1' }, { name: 'arg2' }],
|
input_variables: [{ name: 'arg1' }, { name: 'arg2' }],
|
||||||
code: codeTemplate[value as keyof typeof codeTemplate]
|
code: codeTemplate[value as keyof typeof codeTemplate]
|
||||||
@@ -109,8 +107,12 @@ const CodeExecution: FC<CodeExecutionProps> = ({ options }) => {
|
|||||||
</Form.Item>
|
</Form.Item>
|
||||||
</Col>
|
</Col>
|
||||||
</Row>
|
</Row>
|
||||||
<Form.Item name="code" noStyle>
|
<Form.Item noStyle shouldUpdate={(prev, curr) => prev.language !== curr.language}>
|
||||||
<Editor size="small" language={values.language} />
|
{() => (
|
||||||
|
<Form.Item name="code" noStyle>
|
||||||
|
<Editor size="small" language={form.getFieldValue('language')} />
|
||||||
|
</Form.Item>
|
||||||
|
)}
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
</Space>
|
</Space>
|
||||||
|
|
||||||
|
|||||||
@@ -159,7 +159,7 @@ export const useWorkflowGraph = ({
|
|||||||
nodeLibraryConfig.config[key].defaultValue = Object.entries(config[key]).map(([name, value]) => ({ name, value }))
|
nodeLibraryConfig.config[key].defaultValue = Object.entries(config[key]).map(([name, value]) => ({ name, value }))
|
||||||
} else if (type === 'code' && key === 'code' && config[key] && nodeLibraryConfig.config && nodeLibraryConfig.config[key]) {
|
} else if (type === 'code' && key === 'code' && config[key] && nodeLibraryConfig.config && nodeLibraryConfig.config[key]) {
|
||||||
try {
|
try {
|
||||||
nodeLibraryConfig.config[key].defaultValue = atob(config[key] as string)
|
nodeLibraryConfig.config[key].defaultValue = decodeURIComponent(atob(config[key] as string))
|
||||||
} catch {
|
} catch {
|
||||||
nodeLibraryConfig.config[key].defaultValue = config[key]
|
nodeLibraryConfig.config[key].defaultValue = config[key]
|
||||||
}
|
}
|
||||||
@@ -943,7 +943,7 @@ export const useWorkflowGraph = ({
|
|||||||
const code = data.config[key].defaultValue || ''
|
const code = data.config[key].defaultValue || ''
|
||||||
itemConfig = {
|
itemConfig = {
|
||||||
...itemConfig,
|
...itemConfig,
|
||||||
code: btoa(code || '')
|
code: btoa(encodeURIComponent(code || ''))
|
||||||
}
|
}
|
||||||
} else if (key === 'memory' && data.config[key] && 'defaultValue' in data.config[key]) {
|
} else if (key === 'memory' && data.config[key] && 'defaultValue' in data.config[key]) {
|
||||||
const { messages, ...rest } = data.config[key].defaultValue
|
const { messages, ...rest } = data.config[key].defaultValue
|
||||||
|
|||||||
Reference in New Issue
Block a user