Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop

# Conflicts:
#	api/app/core/agent/langchain_agent.py
This commit is contained in:
Mark
2026-02-04 15:51:44 +08:00
46 changed files with 1219 additions and 117 deletions

View File

@@ -7,6 +7,10 @@ from celery import Celery
from app.core.config import settings
# macOS fork() safety - must be set before any Celery initialization
if platform.system() == 'Darwin':
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
# 创建 Celery 应用实例
# broker: 任务队列(使用 Redis DB 0
# backend: 结果存储(使用 Redis DB 10
@@ -64,6 +68,11 @@ celery_app.conf.update(
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
# Long-term storage tasks → memory_tasks queue (batched write strategies)
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
# Document tasks → document_tasks queue (prefork worker)
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},

View File

@@ -45,6 +45,7 @@ from . import (
memory_perceptual_controller,
memory_working_controller,
ontology_controller,
skill_controller
)
# 创建管理端 API 路由器
@@ -90,5 +91,6 @@ manager_router.include_router(memory_perceptual_controller.router)
manager_router.include_router(memory_working_controller.router)
manager_router.include_router(file_storage_controller.router)
manager_router.include_router(ontology_controller.router)
manager_router.include_router(skill_controller.router)
__all__ = ["manager_router"]

View File

@@ -116,14 +116,6 @@ def _get_ontology_service(
detail=f"找不到指定的LLM模型: {llm_id}"
)
# 检查是否为组合模型
if hasattr(model_config, 'is_composite') and model_config.is_composite:
logger.error(f"Model {llm_id} is a composite model, which is not supported for ontology extraction")
raise HTTPException(
status_code=400,
detail="本体提取不支持使用组合模型,请选择单个模型"
)
# 验证模型配置了API密钥
if not model_config.api_keys:
logger.error(f"Model {llm_id} has no API key configuration")

View File

@@ -0,0 +1,90 @@
"""Skill Controller - 技能市场管理"""
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from typing import Optional
import uuid
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models import User
from app.schemas import skill_schema
from app.schemas.response_schema import PageData, PageMeta
from app.services.skill_service import SkillService
from app.core.response_utils import success
router = APIRouter(prefix="/skills", tags=["Skills"])
@router.post("", summary="创建技能")
@cur_workspace_access_guard()
def create_skill(
data: skill_schema.SkillCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""创建技能 - 可以关联现有工具内置、MCP、自定义"""
tenant_id = current_user.tenant_id
skill = SkillService.create_skill(db, data, tenant_id)
return success(data=skill_schema.Skill.model_validate(skill), msg="技能创建成功")
@router.get("", summary="技能列表")
@cur_workspace_access_guard()
def list_skills(
search: Optional[str] = Query(None, description="搜索关键词"),
is_active: Optional[bool] = Query(None, description="是否激活"),
is_public: Optional[bool] = Query(None, description="是否公开"),
page: int = Query(1, ge=1, description="页码"),
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""技能市场列表 - 包含本工作空间和公开的技能"""
tenant_id = current_user.tenant_id
skills, total = SkillService.list_skills(
db, tenant_id, search, is_active, is_public, page, pagesize
)
items = [skill_schema.Skill.model_validate(s) for s in skills]
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
return success(data=PageData(page=meta, items=items), msg="技能市场列表获取成功")
@router.get("/{skill_id}", summary="获取技能详情")
@cur_workspace_access_guard()
def get_skill(
skill_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取技能详情"""
tenant_id = current_user.tenant_id
skill = SkillService.get_skill(db, skill_id, tenant_id)
return success(data=skill_schema.Skill.model_validate(skill), msg="获取技能详情成功")
@router.put("/{skill_id}", summary="更新技能")
@cur_workspace_access_guard()
def update_skill(
skill_id: uuid.UUID,
data: skill_schema.SkillUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""更新技能"""
tenant_id = current_user.tenant_id
skill = SkillService.update_skill(db, skill_id, data, tenant_id)
return success(data=skill_schema.Skill.model_validate(skill), msg="技能更新成功")
@router.delete("/{skill_id}", summary="删除技能")
@cur_workspace_access_guard()
def delete_skill(
skill_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""删除技能"""
tenant_id = current_user.tenant_id
SkillService.delete_skill(db, skill_id, tenant_id)
return success(msg="技能删除成功")

View File

@@ -0,0 +1,151 @@
"""Agent Middleware - 动态技能过滤"""
import uuid
from typing import List, Dict, Any, Optional
from langchain_core.runnables import RunnablePassthrough
from app.services.skill_service import SkillService
from app.repositories.skill_repository import SkillRepository
class AgentMiddleware:
"""Agent 中间件 - 用于动态过滤和加载技能"""
def __init__(self, skill_ids: Optional[List[str]] = None):
"""
初始化中间件
Args:
skill_ids: 技能ID列表
"""
self.skill_ids = skill_ids or []
@staticmethod
def filter_tools(
tools: List,
message: str = "",
skill_configs: Dict[str, Any] = None,
tool_to_skill_map: Dict[str, str] = None
) -> tuple[List, List[str]]:
"""
根据消息内容和技能配置动态过滤工具
Args:
tools: 所有可用工具列表
message: 用户消息(可用于智能过滤)
skill_configs: 技能配置字典 {skill_id: {"keywords": [...], "enabled": True, "prompt": "..."}}
tool_to_skill_map: 工具到技能的映射 {tool_name: skill_id}
Returns:
(过滤后的工具列表, 激活的技能ID列表)
"""
if not tools:
return [], []
# 如果没有技能配置,返回所有工具
if not skill_configs:
return tools, []
# 基于关键词匹配激活技能
activated_skill_ids = []
message_lower = message.lower()
for skill_id, config in skill_configs.items():
if not config.get('enabled', True):
continue
keywords = config.get('keywords', [])
# 如果没有关键词限制,或消息包含关键词,则激活该技能
if not keywords or any(kw.lower() in message_lower for kw in keywords):
activated_skill_ids.append(skill_id)
# 如果没有工具映射关系,返回所有工具
if not tool_to_skill_map:
return tools, activated_skill_ids
# 根据激活的技能过滤工具
filtered_tools = []
for tool in tools:
tool_name = getattr(tool, 'name', str(id(tool)))
# 如果工具不属于任何skillbase_tools或者工具所属的skill被激活则保留
if tool_name not in tool_to_skill_map or tool_to_skill_map[tool_name] in activated_skill_ids:
filtered_tools.append(tool)
return filtered_tools, activated_skill_ids
def load_skill_tools(self, db, tenant_id: uuid.UUID, base_tools: List = None) -> tuple[List, Dict[str, Any], Dict[str, str]]:
"""
加载技能关联的工具
Args:
db: 数据库会话
tenant_id: 租户id
base_tools: 基础工具列表
Returns:
(工具列表, 技能配置字典, 工具到技能的映射 {tool_name: skill_id})
"""
tools_dict = {}
tool_to_skill_map = {} # 工具名称到技能ID的映射
if base_tools:
for tool in base_tools:
tool_name = getattr(tool, 'name', str(id(tool)))
tools_dict[tool_name] = tool
# base_tools 不属于任何 skill不加入映射
skill_configs = {}
if self.skill_ids:
for skill_id in self.skill_ids:
try:
skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id)
if skill and skill.is_active:
# 保存技能配置包含prompt
config = skill.config or {}
config['prompt'] = skill.prompt
config['name'] = skill.name
skill_configs[skill_id] = config
except Exception:
continue
# 加载技能工具并获取映射关系
skill_tools, skill_tool_map = SkillService.load_skill_tools(db, self.skill_ids, tenant_id)
# 只添加不冲突的 skill_tools
for tool in skill_tools:
tool_name = getattr(tool, 'name', str(id(tool)))
if tool_name not in tools_dict:
tools_dict[tool_name] = tool
# 复制映射关系
if tool_name in skill_tool_map:
tool_to_skill_map[tool_name] = skill_tool_map[tool_name]
return list(tools_dict.values()), skill_configs, tool_to_skill_map
@staticmethod
def get_active_prompts(activated_skill_ids: List[str], skill_configs: Dict[str, Any]) -> str:
"""
根据激活的技能ID获取对应的提示词
Args:
activated_skill_ids: 被激活的技能ID列表
skill_configs: 技能配置字典
Returns:
合并后的提示词
"""
prompts = []
for skill_id in activated_skill_ids:
config = skill_configs.get(skill_id, {})
prompt = config.get('prompt')
name = config.get('name', 'Skill')
if prompt:
prompts.append(f"# {name}\n{prompt}")
return "\n\n".join(prompts) if prompts else ""
@staticmethod
def create_runnable():
"""创建可运行的中间件"""
return RunnablePassthrough()

View File

@@ -291,6 +291,7 @@ class LangChainAgent:
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
db = next(get_db())
#TODO: 魔法数字
scope=6
try:
@@ -300,6 +301,12 @@ class LangChainAgent:
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
# Handle case where no session exists in Redis (returns False)
if not result or result is False:
logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update")
return
if type=="chunk" or type=="aggregate":
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
@@ -307,7 +314,14 @@ class LangChainAgent:
repo.upsert(end_user_id, chunk_data)
logger.info(f'写入短长期:')
else:
# TODO: This branch handles type="time" strategy, currently unused.
# Will be activated when time-based long-term storage is implemented.
# TODO: 魔法数字 - extract 5 to a constant
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
# Handle case where no session exists in Redis (returns False or empty)
if not long_time_data or long_time_data is False:
logger.debug(f"No recent sessions in Redis for user {end_user_id}")
return
long_messages = await messages_parse(long_time_data)
repo.upsert(end_user_id, long_messages)
logger.info(f'写入短长期:')
@@ -507,9 +521,12 @@ class LangChainAgent:
elapsed_time = time.time() - start_time
if memory_flag:
long_term_messages=await agent_chat_messages(message_chat,content)
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
'''长期'''
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
response = {
"content": content,
@@ -693,9 +710,13 @@ class LangChainAgent:
yield total_tokens
break
if memory_flag:
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
long_term_messages = await agent_chat_messages(message_chat, full_content)
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk")
except Exception as e:

View File

@@ -215,6 +215,9 @@ class Settings:
# official environment system version
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1")
# model square loading
LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true"
# workflow config
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))

View File

@@ -43,6 +43,7 @@ async def write_messages(end_user_id,langchain_messages,memory_config):
for node_name, node_data in update_event.items():
if 'save_neo4j' == node_name:
massages = node_data
# TODO删除
massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
print(contents)
@@ -60,6 +61,7 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
scope窗口大小
'''
scope=scope
redis_messages = []
is_end_user_id = count_store.get_sessions_count(end_user_id)
if is_end_user_id is not False:
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
@@ -91,6 +93,9 @@ async def memory_long_term_storage(end_user_id,memory_config,time):
memory_config: 内存配置对象
'''
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
# Handle case where no session exists in Redis (returns False or empty)
if not long_time_data or long_time_data is False:
return
format_messages = await chat_data_format(long_time_data)
if format_messages!=[]:
await write_messages(end_user_id, format_messages, memory_config)
@@ -108,8 +113,9 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
try:
# 1. 获取历史会话数据(使用新方法)
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
history = await format_parsing(result)
if not result:
# Handle case where no session exists in Redis (returns False or empty)
if not result or result is False:
history = []
else:
history = await format_parsing(result)

View File

@@ -1,18 +1,14 @@
import asyncio
import json
import sys
import warnings
from contextlib import asynccontextmanager
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, chat_data_format, messages_parse
from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__)
@@ -40,27 +36,55 @@ async def make_write_graph():
yield graph
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
from app.core.memory.agent.utils.redis_tool import write_store
write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
# 获取数据库会话
db_session = next(get_db())
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=memory_config, # 改为整数
service_name="MemoryAgentService"
"""Dispatch long-term memory storage to Celery background tasks.
Args:
long_term_type: Storage strategy - 'chunk' (window), 'time', or 'aggregate'
langchain_messages: List of messages to store
memory_config: Memory configuration ID (string)
end_user_id: End user identifier
scope: Window size for 'chunk' strategy (default: 6)
"""
from app.tasks import (
long_term_storage_window_task,
# TODO: Uncomment when implemented
# long_term_storage_time_task,
# long_term_storage_aggregate_task,
)
if long_term_type=='chunk':
'''方案一:对话窗口6轮对话'''
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
if long_term_type=='time':
"""时间"""
await memory_long_term_storage(end_user_id, memory_config,5)
if long_term_type=='aggregate':
"""方案三:聚合判断"""
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
from app.core.logging_config import get_logger
logger = get_logger(__name__)
# Convert config to string if needed
config_id = str(memory_config) if memory_config else ''
if long_term_type == 'chunk':
# Strategy 1: Window-based batching (6 rounds of dialogue)
logger.info(f"[LONG_TERM] Dispatching window task - end_user_id={end_user_id}, scope={scope}")
long_term_storage_window_task.delay(
end_user_id=end_user_id,
langchain_messages=langchain_messages,
config_id=config_id,
scope=scope
)
# TODO: Uncomment when time-based strategy is fully implemented
# elif long_term_type == 'time':
# # Strategy 2: Time-based retrieval
# logger.info(f"[LONG_TERM] Dispatching time task - end_user_id={end_user_id}")
# long_term_storage_time_task.delay(
# end_user_id=end_user_id,
# config_id=config_id,
# time_window=5
# )
# TODO: Uncomment when aggregate strategy is fully implemented
# elif long_term_type == 'aggregate':
# # Strategy 3: Aggregate judgment (deduplication)
# logger.info(f"[LONG_TERM] Dispatching aggregate task - end_user_id={end_user_id}")
# long_term_storage_aggregate_task.delay(
# end_user_id=end_user_id,
# langchain_messages=langchain_messages,
# config_id=config_id
# )
# async def main():

View File

@@ -1,5 +1,4 @@
provider: bedrock
enabled: false
models:
- name: ai21
type: llm

View File

@@ -1,5 +1,4 @@
provider: dashscope
enabled: false
models:
- name: deepseek-r1-distill-qwen-14b
type: llm

View File

@@ -1,11 +1,11 @@
"""模型配置加载器 - 用于将预定义模型批量导入到数据库"""
import os
from pathlib import Path
from typing import Callable
import yaml
from sqlalchemy.orm import Session
from app.models.models_model import ModelBase, ModelProvider
@@ -19,31 +19,9 @@ def _load_yaml_config(provider: ModelProvider) -> list[dict]:
with open(config_file, 'r', encoding='utf-8') as f:
data = yaml.safe_load(f)
# 检查是否需要加载(默认为 true
if not data.get('enabled', True):
return []
return data.get('models', [])
def _disable_yaml_config(provider: ModelProvider) -> None:
"""将YAML文件的enabled标志设置为false"""
config_dir = Path(__file__).parent
config_file = config_dir / f"{provider.value}_models.yaml"
if not config_file.exists():
return
with open(config_file, 'r', encoding='utf-8') as f:
data = yaml.safe_load(f)
data['enabled'] = False
with open(config_file, 'w', encoding='utf-8') as f:
yaml.dump(data, f, allow_unicode=True, sort_keys=False)
def load_models(db: Session, providers: list[str] = None, silent: bool = False) -> dict:
"""
加载模型配置到数据库
@@ -75,8 +53,7 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
if not silent:
print(f"\n正在加载 {provider.value}{len(models)} 个模型...")
# provider_success = 0
for model_data in models:
try:
# 检查模型是否已存在
@@ -93,7 +70,6 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
if not silent:
print(f"更新成功: {model_data['name']}")
result["success"] += 1
# provider_success += 1
else:
# 创建新模型
model = ModelBase(**model_data)
@@ -102,17 +78,12 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
if not silent:
print(f"添加成功: {model_data['name']}")
result["success"] += 1
# provider_success += 1
except Exception as e:
db.rollback()
if not silent:
print(f"添加失败: {model_data['name']} - {str(e)}")
result["failed"] += 1
# 如果该供应商的模型全部加载成功将enabled设置为false
# if provider_success == len(models):
_disable_yaml_config(provider)
return result

View File

@@ -1,5 +1,4 @@
provider: openai
enabled: false
models:
- name: chatgpt-4o-latest
type: llm

View File

@@ -2,6 +2,7 @@ import base64
import json
import logging
import re
import urllib.parse
from string import Template
from textwrap import dedent
from typing import Any

View File

@@ -50,13 +50,16 @@ async def lifespan(app: FastAPI):
logger.info("自动数据库升级已禁用 (DB_AUTO_UPGRADE=false)")
# 加载预定义模型
logger.info("开始加载预定义模型...")
try:
with get_db_context() as db:
result = load_models(db, silent=True)
logger.info(f"预定义模型加载完成: 成功{result['success']}个, 跳过{result['skipped']}个, 失败{result['failed']}")
except Exception as e:
logger.warning(f"加载预定义模型时出错: {str(e)}")
if settings.LOAD_MODEL:
logger.info("开始加载预定义模型...")
try:
with get_db_context() as db:
result = load_models(db, silent=True)
logger.info(f"预定义模型加载完成: 成功{result['success']}个, 跳过{result['skipped']}个, 失败{result['failed']}")
except Exception as e:
logger.warning(f"加载预定义模型时出错: {str(e)}")
else:
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
logger.info("应用程序启动完成")
yield

View File

@@ -28,6 +28,7 @@ from .tool_model import (
ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus
)
from .memory_perceptual_model import MemoryPerceptualModel
from .skill_model import Skill
from .ontology_scene import OntologyScene
from .ontology_class import OntologyClass
from .ontology_scene import OntologyScene
@@ -84,5 +85,6 @@ __all__ = [
"ExecutionStatus",
"MemoryPerceptualModel",
"ModelBase",
"LoadBalanceStrategy"
"LoadBalanceStrategy",
"Skill"
]

View File

@@ -30,6 +30,7 @@ class AgentConfig(Base):
memory = Column(JSON, nullable=True, comment="记忆配置")
variables = Column(JSON, default=list, nullable=True, comment="变量配置")
tools = Column(JSON, default=dict, nullable=True, comment="工具配置")
skill_ids = Column(JSON, default=list, nullable=True, comment="关联的技能ID列表")
# 多 Agent 相关字段
agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone")

View File

@@ -0,0 +1,37 @@
"""Skill 模型定义"""
import datetime
import uuid
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey
from sqlalchemy.dialects.postgresql import UUID, JSON
from app.db import Base
class Skill(Base):
"""技能模型 - 可以关联工具内置、MCP、自定义"""
__tablename__ = "skills"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
name = Column(String, nullable=False, comment="技能名称")
description = Column(Text, comment="技能描述")
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, index=True, comment="租户ID")
# 关联的工具
tools = Column(JSON, default=list, comment="关联的工具列表")
# 技能配置
config = Column(JSON, default=dict, comment="技能配置")
# 专属提示词
prompt = Column(Text, comment="技能专属提示词")
# 状态
is_active = Column(Boolean, default=True, nullable=False, comment="是否激活")
is_public = Column(Boolean, default=False, nullable=False, comment="是否公开到市场")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
def __repr__(self):
return f"<Skill(id={self.id}, name={self.name})>"

View File

@@ -235,6 +235,8 @@ class MemoryConfigRepository:
llm_id=params.llm_id,
embedding_id=params.embedding_id,
rerank_id=params.rerank_id,
reflection_model_id=params.reflection_model_id,
emotion_model_id=params.emotion_model_id,
)
db.add(db_config)
db.flush() # 获取自增ID但不提交事务

View File

@@ -0,0 +1,111 @@
"""Skill Repository"""
from typing import List, Optional, Tuple, Any
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
import uuid
from app.models.skill_model import Skill
from app.schemas.skill_schema import SkillCreate, SkillUpdate
class SkillRepository:
"""Skill 数据访问层"""
@staticmethod
def create(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill:
"""创建技能"""
skill = Skill(
**data.model_dump(),
tenant_id=tenant_id
)
db.add(skill)
db.flush()
return skill
@staticmethod
def get_by_id(db: Session, skill_id: uuid.UUID, tenant_id: Optional[uuid.UUID] = None) -> Optional[Skill]:
"""根据ID获取技能"""
query = db.query(Skill).filter(Skill.id == skill_id)
if tenant_id:
query = query.filter(
or_(
Skill.tenant_id == tenant_id,
Skill.is_public == True
)
)
return query.first()
@staticmethod
def list_skills(
db: Session,
tenant_id: uuid.UUID,
search: Optional[str] = None,
is_active: Optional[bool] = None,
is_public: Optional[bool] = None,
page: int = 1,
pagesize: int = 10
) -> tuple[list[type[Skill]], int]:
"""列出技能"""
filters = [
or_(
Skill.tenant_id == tenant_id,
Skill.is_public == True
)
]
if search:
filters.append(
or_(
Skill.name.ilike(f"%{search}%"),
# Skill.description.ilike(f"%{search}%")
)
)
if is_active is not None:
filters.append(Skill.is_active == is_active)
if is_public is not None:
filters.append(Skill.is_public == is_public)
query = db.query(Skill).filter(and_(*filters))
total = query.count()
skills = query.order_by(Skill.created_at.desc()).offset(
(page - 1) * pagesize
).limit(pagesize).all()
return skills, total
@staticmethod
def update(db: Session, skill_id: uuid.UUID, data: SkillUpdate, tenant_id: uuid.UUID) -> Optional[Skill]:
"""更新技能"""
skill = db.query(Skill).filter(
Skill.id == skill_id,
Skill.tenant_id == tenant_id
).first()
if not skill:
return None
update_data = data.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(skill, key, value)
db.flush()
return skill
@staticmethod
def delete(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> bool:
"""删除技能"""
skill = db.query(Skill).filter(
Skill.id == skill_id,
Skill.tenant_id == tenant_id
).first()
if not skill:
return False
# db.delete(skill)
skill.is_active = False
db.flush()
return True

View File

@@ -156,6 +156,9 @@ class AgentConfigCreate(BaseModel):
description="Agent 可用的工具列表"
)
# 技能配置
skill_ids: Optional[List[str]] = Field(default=None, description="关联的技能ID列表")
class AppCreate(BaseModel):
name: str
@@ -207,6 +210,9 @@ class AgentConfigUpdate(BaseModel):
# 工具配置
tools: Optional[List[ToolConfig]] = Field(default_factory=list, description="工具列表")
# 技能配置
skill_ids: Optional[List[str]] = Field(default=None, description="关联的技能ID列表")
# ---------- Output Schemas ----------
@@ -266,6 +272,8 @@ class AgentConfig(BaseModel):
# 工具配置
tools: Union[List[ToolConfig], Dict[str, ToolOldConfig]] = []
skill_ids: Optional[List[str]] = []
is_active: bool
created_at: datetime.datetime
updated_at: datetime.datetime

View File

@@ -236,6 +236,8 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
reflection_model_id: Optional[str] = Field(None, description="反思模型ID默认与llm_id一致")
emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID默认与llm_id一致")
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)

View File

@@ -0,0 +1,57 @@
"""Skill Schema 定义"""
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field, field_serializer
import uuid
from datetime import datetime
class SkillBase(BaseModel):
"""Skill 基础 Schema"""
name: str = Field(..., description="技能名称")
description: Optional[str] = Field(None, description="技能描述")
tools: List[Dict[str, str]] = Field(default_factory=list, description="工具对象列表: [{\"tool_id\": \"xxx\", \"operation\": \"yyy\"}]")
config: Dict[str, Any] = Field(default_factory=dict, description="技能配置")
prompt: Optional[str] = Field(None, description="技能专属提示词")
is_active: bool = Field(True, description="是否激活")
is_public: bool = Field(False, description="是否公开到市场")
class SkillCreate(SkillBase):
"""创建 Skill"""
pass
class SkillUpdate(BaseModel):
"""更新 Skill"""
name: Optional[str] = None
description: Optional[str] = None
tools: Optional[List[Dict[str, str]]] = None
config: Optional[Dict[str, Any]] = None
prompt: Optional[str] = None
is_active: Optional[bool] = None
is_public: Optional[bool] = None
class Skill(SkillBase):
"""Skill 响应 Schema"""
id: uuid.UUID
tenant_id: uuid.UUID
created_at: datetime
updated_at: datetime
@field_serializer('created_at', 'updated_at')
def serialize_datetime_to_timestamp(self, value: datetime) -> int:
"""(毫秒级)时间戳"""
return int(value.timestamp() * 1000)
class Config:
from_attributes = True
class SkillQuery(BaseModel):
"""Skill 查询参数"""
search: Optional[str] = None
is_active: Optional[bool] = None
is_public: Optional[bool] = None
page: int = Field(1, ge=1)
pagesize: int = Field(10, ge=1, le=100)

View File

@@ -48,6 +48,9 @@ class AgentConfigConverter:
# 5. 工具配置
if hasattr(config, 'tools') and config.tools:
result["tools"] = [tool.model_dump() for tool in config.tools]
if hasattr(config, "skill_ids") and config.skill_ids:
result["skill_ids"] = [skill for skill in config.skill_ids]
return result
@@ -58,6 +61,7 @@ class AgentConfigConverter:
memory: Optional[Dict[str, Any]],
variables: Optional[list],
tools: Optional[Union[list, Dict[str, Any]]],
skill_ids: Optional[list]
) -> Dict[str, Any]:
"""
将数据库存储格式转换为 Pydantic 对象
@@ -68,6 +72,7 @@ class AgentConfigConverter:
memory: 记忆配置
variables: 变量配置
tools: 工具配置
skill_ids: 技能 ID 列表
Returns:
包含 Pydantic 对象的字典
@@ -78,6 +83,7 @@ class AgentConfigConverter:
"memory": MemoryConfig(enabled=True),
"variables": [],
"tools": [],
"skill_ids": []
}
# 1. 解析模型参数配置
@@ -117,5 +123,8 @@ class AgentConfigConverter:
name: ToolOldConfig(**tool_data)
for name, tool_data in tools.items()
}
if skill_ids:
result["skill_ids"] = [skill for skill in skill_ids]
return result

View File

@@ -26,6 +26,7 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
memory=agent_cfg.memory,
variables=agent_cfg.variables,
tools=agent_cfg.tools,
skill_ids=agent_cfg.skill_ids
)
# 将解析后的字段添加到对象上(用于序列化)
@@ -34,5 +35,6 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
agent_cfg.memory = parsed["memory"]
agent_cfg.variables = parsed["variables"]
agent_cfg.tools = parsed["tools"]
agent_cfg.skill_ids = parsed["skill_ids"]
return agent_cfg

View File

@@ -304,6 +304,7 @@ class AppService:
memory=storage_data.get("memory"),
variables=storage_data.get("variables", []),
tools=storage_data.get("tools", []),
skill_ids=storage_data.get("skill_ids", []),
is_active=True,
created_at=now,
updated_at=now,
@@ -907,6 +908,7 @@ class AppService:
agent_cfg.variables = storage_data.get("variables", [])
# if data.tools is not None:
agent_cfg.tools = storage_data.get("tools", [])
agent_cfg.skill_ids = storage_data.get("skill_ids", [])
agent_cfg.updated_at = now

View File

@@ -187,7 +187,7 @@ class AppStatisticsService:
daily_tokens[date_str] = 0
daily_tokens[date_str] += int(tokens)
daily_data = [{"date": date, "tokens": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
total = sum(row["tokens"] for row in daily_data)
daily_data = [{"date": date, "count": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
total = sum(row["count"] for row in daily_data)
return {"daily": daily_data, "total": total}

View File

@@ -10,6 +10,11 @@ import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
@@ -26,10 +31,8 @@ from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger
from app.services.tool_service import ToolService
from app.services.multimodal_service import MultimodalService
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.agent.agent_middleware import AgentMiddleware
logger = get_business_logger()
class KnowledgeRetrievalInput(BaseModel):
@@ -310,6 +313,7 @@ class DraftRunService:
tools = []
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
@@ -320,9 +324,7 @@ class DraftRunService:
print(f"tool_config:{tool_config}")
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, str(workspace_id)))
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
@@ -345,6 +347,22 @@ class DraftRunService:
}
)
# 加载技能关联的工具
if hasattr(agent_config, 'skill_ids') and agent_config.skill_ids:
middleware = AgentMiddleware(skill_ids=agent_config.skill_ids)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
if agent_config.knowledge_retrieval:
kb_config = agent_config.knowledge_retrieval
@@ -558,6 +576,7 @@ class DraftRunService:
tools = []
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
@@ -567,9 +586,7 @@ class DraftRunService:
# print(f"tool_config:{tool_config}")
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, str(workspace_id)))
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
@@ -592,6 +609,23 @@ class DraftRunService:
}
)
# 加载技能关联的工具
skill_configs = {}
if hasattr(agent_config, 'skill_ids') and agent_config.skill_ids:
middleware = AgentMiddleware(skill_ids=agent_config.skill_ids)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
if agent_config.knowledge_retrieval:
@@ -628,7 +662,6 @@ class DraftRunService:
}
)
# 4. 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_config["model_name"],

View File

@@ -53,7 +53,10 @@ def get_workspace_end_users(
workspace_id: uuid.UUID,
current_user: User
) -> List[EndUser]:
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)"""
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)
返回结果按 updated_at 从新到旧排序NULL 值排在最后)
"""
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
@@ -68,9 +71,14 @@ def get_workspace_end_users(
app_ids = [app.id for app in apps_orm]
# 批量查询所有 end_users一次查询而非循环查询
# 按 updated_at 降序排序NULL 值排在最后id 作为次级排序键保证确定性
from app.models.end_user_model import EndUser as EndUserModel
from sqlalchemy import desc, nullslast
end_users_orm = db.query(EndUserModel).filter(
EndUserModel.app_id.in_(app_ids)
).order_by(
nullslast(desc(EndUserModel.updated_at)),
desc(EndUserModel.id)
).all()
# 转换为 Pydantic 模型(只在需要时转换)

View File

@@ -129,6 +129,12 @@ class DataConfigService: # 数据配置服务类PostgreSQL
if not params.rerank_id:
params.rerank_id = configs.get('rerank')
# reflection_model_id 和 emotion_model_id 默认与 llm_id 一致
if not params.reflection_model_id:
params.reflection_model_id = params.llm_id
if not params.emotion_model_id:
params.emotion_model_id = params.llm_id
config = MemoryConfigRepository.create(self.db, params)
self.db.commit()
return {"affected": 1, "config_id": config.config_id}
@@ -203,6 +209,7 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"end_user_id": config.end_user_id,
"config_id_old": config_id_old,
"apply_id": config.apply_id,
"scene_id": config.scene_id,
"llm_id": config.llm_id,
"embedding_id": config.embedding_id,
"rerank_id": config.rerank_id,

View File

@@ -0,0 +1,109 @@
"""Skill Service"""
import uuid
from typing import List
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from app.repositories.skill_repository import SkillRepository
from app.schemas.skill_schema import SkillCreate, SkillUpdate
from app.models.skill_model import Skill
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.services.tool_service import ToolService
class SkillService:
"""Skill 业务逻辑层"""
@staticmethod
def create_skill(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill:
"""创建技能"""
skill = SkillRepository.create(db, data, tenant_id)
db.commit()
db.refresh(skill)
return skill
@staticmethod
def get_skill(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> Skill:
"""获取技能"""
try:
skill = SkillRepository.get_by_id(db, skill_id, tenant_id)
if not skill:
raise BusinessException(f"技能{skill_id}不存在", BizCode.NOT_FOUND)
return skill
except (BusinessException, SQLAlchemyError) as e:
db.rollback()
raise e
@staticmethod
def list_skills(
db: Session,
tenant_id: uuid.UUID,
search: str = None,
is_active: bool = None,
is_public: bool = None,
page: int = 1,
pagesize: int = 10
) -> tuple[list[type[Skill]], int]:
"""列出技能"""
return SkillRepository.list_skills(
db, tenant_id, search, is_active, is_public, page, pagesize
)
@staticmethod
def update_skill(db: Session, skill_id: uuid.UUID, data: SkillUpdate, tenant_id: uuid.UUID) -> Skill:
"""更新技能"""
try:
skill = SkillRepository.update(db, skill_id, data, tenant_id)
if not skill:
raise BusinessException(f"技能{skill_id}不存在或无权限", BizCode.NOT_FOUND)
db.commit()
db.refresh(skill)
return skill
except (BusinessException, SQLAlchemyError) as e:
db.rollback()
raise e
@staticmethod
def delete_skill(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> bool:
"""删除技能"""
try:
success = SkillRepository.delete(db, skill_id, tenant_id)
if not success:
raise BusinessException("技能不存在或无权限", BizCode.NOT_FOUND)
db.commit()
return True
except (BusinessException, SQLAlchemyError) as e:
db.rollback()
raise e
@staticmethod
def load_skill_tools(db: Session, skill_ids: List[str], tenant_id: uuid.UUID) -> tuple[List, dict[str, str]]:
"""加载技能关联的工具
Returns:
(tools, tool_to_skill_map) - 工具列表和工具到技能的映射
"""
tools = []
tool_to_skill_map = {} # {tool_name: skill_id}
tool_service = ToolService(db)
for skill_id in skill_ids:
try:
skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id)
if skill and skill.is_active:
# 加载技能关联的工具
for tool_config in skill.tools:
tool = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool:
langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
# 建立工具到技能的映射
tool_name = getattr(langchain_tool, 'name', str(id(langchain_tool)))
tool_to_skill_map[tool_name] = skill_id
except Exception as e:
print(f"加载技能 {skill_id} 的工具时出错: {e}")
continue
return tools, tool_to_skill_map

View File

@@ -1069,6 +1069,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务")
except Exception as e:
db.rollback() # Rollback failed transaction to allow next query
api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
all_reflection_results.append({
"workspace_id": str(workspace_id),
@@ -1207,3 +1208,290 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
return result
finally:
loop.close()
# =============================================================================
# Long-term Memory Storage Tasks (Batched Write Strategies)
# =============================================================================
@celery_app.task(name="app.core.memory.agent.long_term_storage.window", bind=True)
def long_term_storage_window_task(
self,
end_user_id: str,
langchain_messages: List[Dict[str, Any]],
config_id: str,
scope: int = 6
) -> Dict[str, Any]:
"""Celery task for window-based long-term memory storage.
Accumulates messages in Redis buffer until window size (scope) is reached,
then writes batched messages to Neo4j.
Args:
end_user_id: End user identifier
langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}]
config_id: Memory configuration ID
scope: Window size (number of messages before triggering write)
Returns:
Dict containing task status and metadata
"""
from app.core.logging_config import get_logger
logger = get_logger(__name__)
logger.info(f"[LONG_TERM_WINDOW] Starting task - end_user_id={end_user_id}, scope={scope}")
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.core.memory.agent.langgraph_graph.routing.write_router import window_dialogue
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
from app.core.memory.agent.utils.redis_tool import write_store
from app.services.memory_config_service import MemoryConfigService
db = next(get_db())
try:
# Save to Redis buffer first
write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
# Load memory config
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="LongTermStorageTask"
)
# Execute window-based dialogue storage
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
return {"status": "SUCCESS", "strategy": "window", "scope": scope}
finally:
db.close()
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
logger.info(f"[LONG_TERM_WINDOW] Task completed - elapsed_time={elapsed_time:.2f}s")
return {
**result,
"end_user_id": end_user_id,
"config_id": config_id,
"elapsed_time": elapsed_time,
"task_id": self.request.id
}
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(f"[LONG_TERM_WINDOW] Task failed - error={str(e)}", exc_info=True)
return {
"status": "FAILURE",
"strategy": "window",
"error": str(e),
"end_user_id": end_user_id,
"config_id": config_id,
"elapsed_time": elapsed_time,
"task_id": self.request.id
}
# @celery_app.task(name="app.core.memory.agent.long_term_storage.time", bind=True)
# def long_term_storage_time_task(
# self,
# end_user_id: str,
# config_id: str,
# time_window: int = 5
# ) -> Dict[str, Any]:
# """Celery task for time-based long-term memory storage.
# Retrieves recent sessions from Redis within time window and writes to Neo4j.
# Args:
# end_user_id: End user identifier
# config_id: Memory configuration ID
# time_window: Time window in minutes for retrieving recent sessions
# Returns:
# Dict containing task status and metadata
# """
# from app.core.logging_config import get_logger
# logger = get_logger(__name__)
# logger.info(f"[LONG_TERM_TIME] Starting task - end_user_id={end_user_id}, time_window={time_window}")
# start_time = time.time()
# async def _run() -> Dict[str, Any]:
# from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage
# from app.services.memory_config_service import MemoryConfigService
# db = next(get_db())
# try:
# # Load memory config
# config_service = MemoryConfigService(db)
# memory_config = config_service.load_memory_config(
# config_id=config_id,
# service_name="LongTermStorageTask"
# )
# # Execute time-based storage
# await memory_long_term_storage(end_user_id, memory_config, time_window)
# return {"status": "SUCCESS", "strategy": "time", "time_window": time_window}
# finally:
# db.close()
# try:
# import nest_asyncio
# nest_asyncio.apply()
# except ImportError:
# pass
# try:
# loop = asyncio.get_event_loop()
# if loop.is_closed():
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# except RuntimeError:
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# try:
# result = loop.run_until_complete(_run())
# elapsed_time = time.time() - start_time
# logger.info(f"[LONG_TERM_TIME] Task completed - elapsed_time={elapsed_time:.2f}s")
# return {
# **result,
# "end_user_id": end_user_id,
# "config_id": config_id,
# "elapsed_time": elapsed_time,
# "task_id": self.request.id
# }
# except Exception as e:
# elapsed_time = time.time() - start_time
# logger.error(f"[LONG_TERM_TIME] Task failed - error={str(e)}", exc_info=True)
# return {
# "status": "FAILURE",
# "strategy": "time",
# "error": str(e),
# "end_user_id": end_user_id,
# "config_id": config_id,
# "elapsed_time": elapsed_time,
# "task_id": self.request.id
# }
# @celery_app.task(name="app.core.memory.agent.long_term_storage.aggregate", bind=True)
# def long_term_storage_aggregate_task(
# self,
# end_user_id: str,
# langchain_messages: List[Dict[str, Any]],
# config_id: str
# ) -> Dict[str, Any]:
# """Celery task for aggregate-based long-term memory storage.
# Uses LLM to determine if new messages describe the same event as history.
# Only writes to Neo4j if messages represent new information (not duplicates).
# Args:
# end_user_id: End user identifier
# langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}]
# config_id: Memory configuration ID
# Returns:
# Dict containing task status, is_same_event flag, and metadata
# """
# from app.core.logging_config import get_logger
# logger = get_logger(__name__)
# logger.info(f"[LONG_TERM_AGGREGATE] Starting task - end_user_id={end_user_id}")
# start_time = time.time()
# async def _run() -> Dict[str, Any]:
# from app.core.memory.agent.langgraph_graph.routing.write_router import aggregate_judgment
# from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
# from app.core.memory.agent.utils.redis_tool import write_store
# from app.services.memory_config_service import MemoryConfigService
# db = next(get_db())
# try:
# # Save to Redis buffer first
# write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
# # Load memory config
# config_service = MemoryConfigService(db)
# memory_config = config_service.load_memory_config(
# config_id=config_id,
# service_name="LongTermStorageTask"
# )
# # Execute aggregate judgment
# result = await aggregate_judgment(end_user_id, langchain_messages, memory_config)
# return {
# "status": "SUCCESS",
# "strategy": "aggregate",
# "is_same_event": result.get("is_same_event", False),
# "wrote_to_neo4j": not result.get("is_same_event", False)
# }
# finally:
# db.close()
# try:
# import nest_asyncio
# nest_asyncio.apply()
# except ImportError:
# pass
# try:
# loop = asyncio.get_event_loop()
# if loop.is_closed():
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# except RuntimeError:
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# try:
# result = loop.run_until_complete(_run())
# elapsed_time = time.time() - start_time
# logger.info(f"[LONG_TERM_AGGREGATE] Task completed - is_same_event={result.get('is_same_event')}, elapsed_time={elapsed_time:.2f}s")
# return {
# **result,
# "end_user_id": end_user_id,
# "config_id": config_id,
# "elapsed_time": elapsed_time,
# "task_id": self.request.id
# }
# except Exception as e:
# elapsed_time = time.time() - start_time
# logger.error(f"[LONG_TERM_AGGREGATE] Task failed - error={str(e)}", exc_info=True)
# return {
# "status": "FAILURE",
# "strategy": "aggregate",
# "error": str(e),
# "end_user_id": end_user_id,
# "config_id": config_id,
# "elapsed_time": elapsed_time,
# "task_id": self.request.id
# }

View File

@@ -0,0 +1,50 @@
"""202602041355
Revision ID: e7c7afa249d1
Revises: 9def72f79398
Create Date: 2026-02-04 13:55:22.284981
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = 'e7c7afa249d1'
down_revision: Union[str, None] = '9def72f79398'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('skills',
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('name', sa.String(), nullable=False, comment='技能名称'),
sa.Column('description', sa.Text(), nullable=True, comment='技能描述'),
sa.Column('tenant_id', sa.UUID(), nullable=False, comment='租户ID'),
sa.Column('tools', postgresql.JSON(astext_type=sa.Text()), nullable=True, comment='关联的工具列表'),
sa.Column('config', postgresql.JSON(astext_type=sa.Text()), nullable=True, comment='技能配置'),
sa.Column('prompt', sa.Text(), nullable=True, comment='技能专属提示词'),
sa.Column('is_active', sa.Boolean(), nullable=False, comment='是否激活'),
sa.Column('is_public', sa.Boolean(), nullable=False, comment='是否公开到市场'),
sa.Column('created_at', sa.DateTime(), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['tenant_id'], ['tenants.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_skills_id'), 'skills', ['id'], unique=False)
op.create_index(op.f('ix_skills_tenant_id'), 'skills', ['tenant_id'], unique=False)
op.add_column('agent_configs', sa.Column('skill_ids', postgresql.JSON(astext_type=sa.Text()), nullable=True, comment='关联的技能ID列表'))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('agent_configs', 'skill_ids')
op.drop_index(op.f('ix_skills_tenant_id'), table_name='skills')
op.drop_index(op.f('ix_skills_id'), table_name='skills')
op.drop_table('skills')
# ### end Alembic commands ###

View File

@@ -142,7 +142,7 @@ const PageScrollList = forwardRef(<T, Q = Record<string, unknown>>({
dataLength={data.length}
next={loadMoreData}
hasMore={hasMore}
loader={needLoading ? <PageLoading /> : undefined}
loader={loading && needLoading ? <PageLoading /> : false}
// endMessage={<Divider plain>It is all, nothing more 🤐</Divider>}
scrollableTarget="scrollableDiv"
className='rb:h-full!'

View File

@@ -180,7 +180,4 @@ body {
.x6-node foreignObject > body {
min-height: 100%;
max-height: 100%;
}
#scrollableDiv .infinite-scroll-component__outerdiv {
height: 100%;
}

View File

@@ -21,6 +21,7 @@ import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext
import InitialValuePlugin from './plugin/InitialValuePlugin'
import LineBreakPlugin from './plugin/LineBreakPlugin';
import InsertTextPlugin from './plugin/InsertTextPlugin';
import EditablePlugin from './plugin/EditablePlugin';
/**
* Editor ref methods exposed to parent components
@@ -50,6 +51,7 @@ interface LexicalEditorProps {
onChange?: (value: string) => void;
/** Editor height in pixels */
height?: number;
disabled?: boolean;
}
/**
@@ -71,6 +73,7 @@ const EditorContent = forwardRef<EditorRef, LexicalEditorProps>(({
value,
placeholder = "Please enter content...",
onChange,
disabled
}, ref) => {
const [editor] = useLexicalComposerContext();
@@ -132,7 +135,11 @@ const EditorContent = forwardRef<EditorRef, LexicalEditorProps>(({
<RichTextPlugin
contentEditable={
<ContentEditable
className={clsx("rb:outline-none rb:resize-none rb:text-[14px] rb:leading-5 rb:px-4 rb:py-5 rb:bg-[#FBFDFF] rb:border rb:border-[#DFE4ED] rb:rounded-lg rb:overflow-auto", className)}
className={clsx(
"rb:outline-none rb:resize-none rb:text-[14px] rb:leading-5 rb:px-4 rb:py-5 rb:bg-[#FBFDFF] rb:border rb:border-[#DFE4ED] rb:rounded-lg rb:overflow-auto",
disabled && "rb:cursor-not-allowed rb:bg-[#F6F8FC] rb:text-[#5B6167]",
className
)}
/>
}
placeholder={
@@ -145,6 +152,7 @@ const EditorContent = forwardRef<EditorRef, LexicalEditorProps>(({
<LineBreakPlugin onChange={onChange} />
<InitialValuePlugin value={value} />
<InsertTextPlugin />
<EditablePlugin disabled={disabled} />
</div>
);
});
@@ -158,6 +166,7 @@ const Editor = forwardRef<EditorRef, LexicalEditorProps>((props, ref) => {
namespace: 'Editor',
theme,
nodes: [],
editable: !props.disabled,
onError: (error: Error) => {
console.error(error);
},

View File

@@ -0,0 +1,48 @@
/*
* @Author: ZhaoYing
* @Date: 2026-02-04 11:20:49
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-04 11:20:49
*/
import { useEffect } from 'react';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
/**
* Props for the EditablePlugin component
*/
interface EditablePluginProps {
/** Whether the editor should be disabled (read-only mode) */
disabled?: boolean;
}
/**
* EditablePlugin - A Lexical editor plugin that controls the editable state of the editor
*
* This plugin allows you to dynamically toggle between editable and read-only modes.
* When disabled is true, the editor becomes read-only and users cannot modify content.
* When disabled is false or undefined, the editor is fully editable.
*
* @param {EditablePluginProps} props - Component props
* @param {boolean} [props.disabled] - Controls whether the editor is in read-only mode
* @returns {null} This plugin doesn't render any UI elements
*
* @example
* ```tsx
* <LexicalComposer>
* <EditablePlugin disabled={isReadOnly} />
* </LexicalComposer>
* ```
*/
export default function EditablePlugin({ disabled }: EditablePluginProps) {
// Get the editor instance from Lexical composer context
const [editor] = useLexicalComposerContext();
// Update editor's editable state whenever the disabled prop changes
useEffect(() => {
// Set editor to editable when disabled is false, read-only when disabled is true
editor.setEditable(!disabled);
}, [editor, disabled]);
// This plugin doesn't render any UI, it only manages editor state
return null;
}

View File

@@ -156,9 +156,9 @@ const Prompt: FC<{ editVo: HistoryItem | null; refresh: () => void; }> = ({ edit
currentPromptValueRef.current = undefined;
setChatList([])
refresh()
updateSession()
}
console.log(values)
return (
<>
<Form form={form}>
@@ -217,12 +217,13 @@ const Prompt: FC<{ editVo: HistoryItem | null; refresh: () => void; }> = ({ edit
ref={editorRef}
placeholder={t('prompt.promptPlaceholder')}
className="rb:h-[calc(100vh-260px)]"
disabled={loading}
// onChange={(value) => form.setFieldValue('current_prompt', value)}
/>
</Form.Item>
<div className="rb:grid rb:grid-cols-2 rb:gap-4 rb:mt-6">
<Button type="primary" block disabled={!values?.current_prompt} onClick={handleSave}>{t('common.save')}</Button>
<Button block disabled={!values?.current_prompt} onClick={handleCopy}>{t('common.copy')}</Button>
<Button type="primary" block disabled={!values?.current_prompt || loading} onClick={handleSave}>{t('common.save')}</Button>
<Button block disabled={!values?.current_prompt || loading} onClick={handleCopy}>{t('common.copy')}</Button>
</div>
</div>
</div>

View File

@@ -103,6 +103,8 @@ const SpaceModal = forwardRef<SpaceModalRef, SpaceModalProps>(({
}).catch(() => {
handleUpdate(formData)
})
} else {
handleUpdate(formData)
}
}
})
@@ -158,6 +160,7 @@ const SpaceModal = forwardRef<SpaceModalRef, SpaceModalProps>(({
label={t('space.spaceIcon')}
valuePropName="fileList"
hidden={currentStep === 1}
rules={[{ required: true, message: t('common.selectPlaceholder', { title: t('space.spaceIcon') }) }]}
>
<UploadImages />
</Form.Item>

View File

@@ -242,7 +242,7 @@ const Editor: FC<LexicalEditorProps> =({
{enableLineNumbers && <LineNumberPlugin />}
<AutocompletePlugin options={options} enableJinja2={enableJinja2} />
<CharacterCountPlugin setCount={(count) => { setCount(count) }} onChange={onChange} />
<InitialValuePlugin value={value} options={options} enableJinja2={enableJinja2} />
<InitialValuePlugin key={language} value={value} options={options} enableLineNumbers={enableLineNumbers} />
{enableLineNumbers && <BlurPlugin />}
</div>
</LexicalComposer>

View File

@@ -16,6 +16,12 @@ export default function BlurPlugin() {
return;
}
// 检查是否是粘贴操作导致的焦点变化
const relatedTarget = e.relatedTarget as HTMLElement;
if (!relatedTarget || relatedTarget === document.body) {
return;
}
editor.update(() => {
$setSelection(null);
});

View File

@@ -8,12 +8,13 @@ import { type Suggestion } from '../plugin/AutocompletePlugin'
interface InitialValuePluginProps {
value: string;
options?: Suggestion[];
enableJinja2?: boolean;
enableLineNumbers?: boolean;
}
const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options = [], enableJinja2 = false }) => {
const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options = [], enableLineNumbers = false }) => {
const [editor] = useLexicalComposerContext();
const prevValueRef = useRef<string>('');
const prevEnableLineNumbersRef = useRef<boolean>(enableLineNumbers);
const isUserInputRef = useRef(false);
useEffect(() => {
@@ -32,7 +33,7 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
}, [editor]);
useEffect(() => {
if (value !== prevValueRef.current && !isUserInputRef.current) {
if ((value !== prevValueRef.current || enableLineNumbers !== prevEnableLineNumbersRef.current) && !isUserInputRef.current) {
queueMicrotask(() => {
editor.update(() => {
const root = $getRoot();
@@ -40,7 +41,7 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
const parts = value.split(/(\{\{[^}]+\}\})/);
if (enableJinja2) {
if (enableLineNumbers) {
// Handle newlines properly in Jinja2 mode
const lines = value.split('\n');
lines.forEach((line) => {
@@ -104,8 +105,9 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
}
prevValueRef.current = value;
prevEnableLineNumbersRef.current = enableLineNumbers;
isUserInputRef.current = false;
}, [value, options, editor, enableJinja2]);
}, [value, options, editor, enableLineNumbers]);
return null;
};

View File

@@ -1,6 +1,6 @@
import { useEffect } from 'react';
import { useEffect, useRef } from 'react';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { TextNode, $createTextNode, $getSelection, $isRangeSelection } from 'lexical';
import { TextNode, $createTextNode, $getSelection, $isRangeSelection, COMMAND_PRIORITY_LOW, PASTE_COMMAND } from 'lexical';
const JS_KEYWORDS = new Set([
'async', 'await', 'break', 'case', 'catch', 'class', 'const', 'continue', 'debugger', 'default',
@@ -11,13 +11,31 @@ const JS_KEYWORDS = new Set([
const JavaScriptHighlightPlugin = () => {
const [editor] = useLexicalComposerContext();
const isPastingRef = useRef(false);
useEffect(() => {
return editor.registerCommand(
PASTE_COMMAND,
() => {
isPastingRef.current = true;
setTimeout(() => {
isPastingRef.current = false;
}, 100);
return false;
},
COMMAND_PRIORITY_LOW
);
}, [editor]);
useEffect(() => {
return editor.registerNodeTransform(TextNode, (textNode: TextNode) => {
if (isPastingRef.current) return;
const text = textNode.getTextContent();
if (textNode.hasFormat('code')) return;
if (!needsHighlight(text)) return;
if (textNode.getStyle()) return;
const parent = textNode.getParent();
if (!parent) return;

View File

@@ -1,6 +1,6 @@
import { useEffect } from 'react';
import { useEffect, useRef } from 'react';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { TextNode, $createTextNode, $getSelection, $isRangeSelection } from 'lexical';
import { TextNode, $createTextNode, $getSelection, $isRangeSelection, COMMAND_PRIORITY_LOW, PASTE_COMMAND } from 'lexical';
const PYTHON_KEYWORDS = new Set([
'False', 'None', 'True', 'and', 'as', 'assert', 'async', 'await', 'break', 'class', 'continue',
@@ -11,12 +11,30 @@ const PYTHON_KEYWORDS = new Set([
const Python3HighlightPlugin = () => {
const [editor] = useLexicalComposerContext();
const isPastingRef = useRef(false);
useEffect(() => {
return editor.registerCommand(
PASTE_COMMAND,
() => {
isPastingRef.current = true;
setTimeout(() => {
isPastingRef.current = false;
}, 100);
return false;
},
COMMAND_PRIORITY_LOW
);
}, [editor]);
useEffect(() => {
return editor.registerNodeTransform(TextNode, (textNode: TextNode) => {
if (isPastingRef.current) return;
const text = textNode.getTextContent();
if (textNode.hasFormat('code')) return;
if (textNode.getStyle()) return;
if (!needsHighlight(text)) return;
const parent = textNode.getParent();

View File

@@ -33,7 +33,6 @@ const codeTemplate = {
const CodeExecution: FC<CodeExecutionProps> = ({ options }) => {
const { t } = useTranslation()
const form = Form.useFormInstance()
const values = Form.useWatch([], form) || {}
const handleRefresh = () => {
const code = form.getFieldValue('code') || ''
@@ -66,7 +65,6 @@ const CodeExecution: FC<CodeExecutionProps> = ({ options }) => {
form.setFieldValue('code', newTemplate)
}
const handleChangeLanguage = (value: string) => {
form.setFieldValue('code', codeTemplate[value as keyof typeof codeTemplate])
form.setFieldsValue({
input_variables: [{ name: 'arg1' }, { name: 'arg2' }],
code: codeTemplate[value as keyof typeof codeTemplate]
@@ -109,8 +107,12 @@ const CodeExecution: FC<CodeExecutionProps> = ({ options }) => {
</Form.Item>
</Col>
</Row>
<Form.Item name="code" noStyle>
<Editor size="small" language={values.language} />
<Form.Item noStyle shouldUpdate={(prev, curr) => prev.language !== curr.language}>
{() => (
<Form.Item name="code" noStyle>
<Editor size="small" language={form.getFieldValue('language')} />
</Form.Item>
)}
</Form.Item>
</Space>

View File

@@ -159,7 +159,7 @@ export const useWorkflowGraph = ({
nodeLibraryConfig.config[key].defaultValue = Object.entries(config[key]).map(([name, value]) => ({ name, value }))
} else if (type === 'code' && key === 'code' && config[key] && nodeLibraryConfig.config && nodeLibraryConfig.config[key]) {
try {
nodeLibraryConfig.config[key].defaultValue = atob(config[key] as string)
nodeLibraryConfig.config[key].defaultValue = decodeURIComponent(atob(config[key] as string))
} catch {
nodeLibraryConfig.config[key].defaultValue = config[key]
}
@@ -943,7 +943,7 @@ export const useWorkflowGraph = ({
const code = data.config[key].defaultValue || ''
itemConfig = {
...itemConfig,
code: btoa(code || '')
code: btoa(encodeURIComponent(code || ''))
}
} else if (key === 'memory' && data.config[key] && 'defaultValue' in data.config[key]) {
const { messages, ...rest } = data.config[key].defaultValue