Merge branch 'develop' into fix/memory-enduser-config

This commit is contained in:
Ke Sun
2026-02-06 11:56:21 +08:00
294 changed files with 9936 additions and 4180 deletions

View File

@@ -7,6 +7,10 @@ from celery import Celery
from app.core.config import settings
# macOS fork() safety - must be set before any Celery initialization
if platform.system() == 'Darwin':
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
# 创建 Celery 应用实例
# broker: 任务队列(使用 Redis DB 0)
# backend: 结果存储(使用 Redis DB 10)
@@ -64,6 +68,11 @@ celery_app.conf.update(
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
# Long-term storage tasks → memory_tasks queue (batched write strategies)
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
# Document tasks → document_tasks queue (prefork worker)
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},

View File

@@ -41,9 +41,9 @@ from . import (
upload_controller,
user_controller,
user_memory_controllers,
workflow_controller,
workspace_controller,
ontology_controller,
skill_controller
)
# 创建管理端 API 路由器
@@ -76,7 +76,6 @@ manager_router.include_router(release_share_controller.router)
manager_router.include_router(public_share_controller.router) # 公开路由(无需认证)
manager_router.include_router(memory_dashboard_controller.router)
manager_router.include_router(multi_agent_controller.router)
manager_router.include_router(workflow_controller.router)
manager_router.include_router(emotion_controller.router)
manager_router.include_router(emotion_config_controller.router)
manager_router.include_router(prompt_optimizer_controller.router)
@@ -90,5 +89,6 @@ manager_router.include_router(memory_perceptual_controller.router)
manager_router.include_router(memory_working_controller.router)
manager_router.include_router(file_storage_controller.router)
manager_router.include_router(ontology_controller.router)
manager_router.include_router(skill_controller.router)
__all__ = ["manager_router"]

View File

@@ -22,6 +22,7 @@ from app.services import app_service, workspace_service
from app.services.agent_config_helper import enrich_agent_config
from app.services.app_service import AppService
from app.services.workflow_service import WorkflowService, get_workflow_service
from app.services.app_statistics_service import AppStatisticsService
router = APIRouter(prefix="/apps", tags=["Apps"])
logger = get_business_logger()
@@ -904,8 +905,6 @@ def get_app_statistics(
- total_tokens: 总token消耗
"""
workspace_id = current_user.current_workspace_id
from app.services.app_statistics_service import AppStatisticsService
stats_service = AppStatisticsService(db)
result = stats_service.get_app_statistics(
@@ -916,3 +915,36 @@ def get_app_statistics(
)
return success(data=result)
@router.get("/workspace/api-statistics", summary="工作空间API调用统计")
@cur_workspace_access_guard()
def get_workspace_api_statistics(
    start_date: int,
    end_date: int,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Return API call statistics for the current user's workspace.

    Args:
        start_date: Start of the range, as a millisecond timestamp.
        end_date: End of the range, as a millisecond timestamp.
        db: Database session injected by FastAPI.
        current_user: Authenticated user; its ``current_workspace_id``
            selects the workspace to report on.

    Returns:
        A success response wrapping the value produced by
        ``AppStatisticsService.get_workspace_api_statistics``. According to
        the endpoint's original documentation this is a list of daily
        entries, each carrying ``date``, ``total_calls``, ``app_calls`` and
        ``service_calls``.
    """
    # Collect the query arguments first so the service call stays on one line.
    query_args = {
        "workspace_id": current_user.current_workspace_id,
        "start_date": start_date,
        "end_date": end_date,
    }
    daily_stats = AppStatisticsService(db).get_workspace_api_statistics(**query_args)
    return success(data=daily_stats)

View File

@@ -116,14 +116,6 @@ def _get_ontology_service(
detail=f"找不到指定的LLM模型: {llm_id}"
)
# 检查是否为组合模型
if hasattr(model_config, 'is_composite') and model_config.is_composite:
logger.error(f"Model {llm_id} is a composite model, which is not supported for ontology extraction")
raise HTTPException(
status_code=400,
detail="本体提取不支持使用组合模型,请选择单个模型"
)
# 验证模型配置了API密钥
if not model_config.api_keys:
logger.error(f"Model {llm_id} has no API key configuration")

View File

@@ -120,7 +120,8 @@ async def get_prompt_opt(
session_id=session_id,
user_id=current_user.id,
current_prompt=data.current_prompt,
user_require=data.message
user_require=data.message,
skill=data.skill
):
# chunk 是 prompt 的增量内容
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"

View File

@@ -587,7 +587,8 @@ async def chat(
user_rag_memory_id=user_rag_memory_id,
app_id=release.app_id,
workspace_id=workspace_id,
release_id=release.id
release_id=release.id,
public=True
):
event_type = event.get("event", "message")
event_data = event.get("data", {})

View File

@@ -0,0 +1,90 @@
"""Skill Controller - 技能市场管理"""
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from typing import Optional
import uuid
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models import User
from app.schemas import skill_schema
from app.schemas.response_schema import PageData, PageMeta
from app.services.skill_service import SkillService
from app.core.response_utils import success
router = APIRouter(prefix="/skills", tags=["Skills"])
@router.post("", summary="创建技能")
@cur_workspace_access_guard()
def create_skill(
    data: skill_schema.SkillCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create a skill for the current user's tenant.

    Per the original endpoint description, a skill may be linked to existing
    tools (built-in, MCP, or custom).

    Args:
        data: Validated creation payload.
        db: Database session injected by FastAPI.
        current_user: Authenticated user; its ``tenant_id`` owns the skill.

    Returns:
        A success response carrying the created skill serialized through
        ``skill_schema.Skill``.
    """
    owner_tenant = current_user.tenant_id
    created = SkillService.create_skill(db, data, owner_tenant)
    payload = skill_schema.Skill.model_validate(created)
    return success(data=payload, msg="技能创建成功")
@router.get("", summary="技能列表")
@cur_workspace_access_guard()
def list_skills(
    search: Optional[str] = Query(None, description="搜索关键词"),
    is_active: Optional[bool] = Query(None, description="是否激活"),
    is_public: Optional[bool] = Query(None, description="是否公开"),
    page: int = Query(1, ge=1, description="页码"),
    pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List skills in the skill market, one page at a time.

    The original description states the listing covers the caller's own
    skills plus public ones; the visible code scopes the query by the
    caller's ``tenant_id`` and leaves the rest to ``SkillService``.

    Args:
        search: Optional search keyword.
        is_active: Optional filter on the active flag.
        is_public: Optional filter on the public flag.
        page: 1-based page number.
        pagesize: Page size (1-100).
        db: Database session injected by FastAPI.
        current_user: Authenticated user.

    Returns:
        A success response wrapping a ``PageData`` with the serialized
        skills and a ``PageMeta`` describing the page.
    """
    owner_tenant = current_user.tenant_id
    rows, total = SkillService.list_skills(
        db, owner_tenant, search, is_active, is_public, page, pagesize
    )

    serialized = []
    for row in rows:
        serialized.append(skill_schema.Skill.model_validate(row))

    # More pages exist while the rows consumed so far do not reach the total.
    has_more = (page * pagesize) < total
    paging = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=has_more)
    listing = PageData(page=paging, items=serialized)
    return success(data=listing, msg="技能市场列表获取成功")
@router.get("/{skill_id}", summary="获取技能详情")
@cur_workspace_access_guard()
def get_skill(
    skill_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Fetch the details of a single skill.

    Args:
        skill_id: Identifier of the skill to load.
        db: Database session injected by FastAPI.
        current_user: Authenticated user; its ``tenant_id`` is passed to the
            service together with the skill id.

    Returns:
        A success response carrying the skill serialized through
        ``skill_schema.Skill``.
    """
    owner_tenant = current_user.tenant_id
    found = SkillService.get_skill(db, skill_id, owner_tenant)
    payload = skill_schema.Skill.model_validate(found)
    return success(data=payload, msg="获取技能详情成功")
@router.put("/{skill_id}", summary="更新技能")
@cur_workspace_access_guard()
def update_skill(
    skill_id: uuid.UUID,
    data: skill_schema.SkillUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Update an existing skill.

    Args:
        skill_id: Identifier of the skill to update.
        data: Validated update payload.
        db: Database session injected by FastAPI.
        current_user: Authenticated user; its ``tenant_id`` is passed to the
            service together with the skill id.

    Returns:
        A success response carrying the updated skill serialized through
        ``skill_schema.Skill``.
    """
    owner_tenant = current_user.tenant_id
    updated = SkillService.update_skill(db, skill_id, data, owner_tenant)
    payload = skill_schema.Skill.model_validate(updated)
    return success(data=payload, msg="技能更新成功")
@router.delete("/{skill_id}", summary="删除技能")
@cur_workspace_access_guard()
def delete_skill(
    skill_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete a skill.

    Args:
        skill_id: Identifier of the skill to delete.
        db: Database session injected by FastAPI.
        current_user: Authenticated user; its ``tenant_id`` is passed to the
            service together with the skill id.

    Returns:
        A success response with a confirmation message and no data.
    """
    owner_tenant = current_user.tenant_id
    SkillService.delete_skill(db, skill_id, owner_tenant)
    return success(msg="技能删除成功")

View File

@@ -1,610 +0,0 @@
"""
工作流 API 控制器
"""
import logging
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, Path, Query
from sqlalchemy.orm import Session
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models.user_model import User
from app.models.app_model import App
from app.services.workflow_service import WorkflowService, get_workflow_service
from app.schemas.workflow_schema import (
WorkflowConfigCreate,
WorkflowConfigUpdate,
WorkflowConfig,
WorkflowValidationResponse,
WorkflowExecution,
WorkflowNodeExecution,
WorkflowExecutionRequest,
WorkflowExecutionResponse
)
from app.core.response_utils import success, fail
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/apps", tags=["workflow"])
# ==================== 工作流配置管理 ====================
@router.post("/{app_id}/workflow")
@cur_workspace_access_guard()
async def create_workflow_config(
    app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
    config: WorkflowConfigCreate,
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_user)],
    service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
    """Create the workflow configuration of an application.

    Per the original description this creates or updates the configuration;
    basic validation is performed, but an incomplete configuration (a draft)
    may still be saved.

    Args:
        app_id: Application the configuration belongs to.
        config: Workflow definition (nodes, edges, variables, execution
            config, triggers).
        db: Database session injected by FastAPI.
        current_user: Authenticated user; its current workspace must own the
            application.
        service: Workflow service injected by FastAPI.

    Returns:
        A success response with the stored configuration, or a failure
        response when the app is missing/inaccessible, is not of type
        ``workflow``, or the service raises.
    """
    try:
        # The app must exist, be active, and belong to the caller's workspace.
        ownership_filter = (
            App.id == app_id,
            App.workspace_id == current_user.current_workspace_id,
            App.is_active.is_(True),
        )
        target_app = db.query(App).filter(*ownership_filter).first()
        if not target_app:
            return fail(code=BizCode.NOT_FOUND, msg="应用不存在或无权访问")

        # Only workflow-type applications can carry a workflow configuration.
        if target_app.type != "workflow":
            return fail(
                code=BizCode.INVALID_PARAMETER,
                msg=f"应用类型必须为 workflow当前为 {target_app.type}"
            )

        # Convert the pydantic payload into plain structures for the service.
        definition = {
            "nodes": [item.model_dump() for item in config.nodes],
            "edges": [item.model_dump() for item in config.edges],
            "variables": [item.model_dump() for item in config.variables],
            "execution_config": config.execution_config.model_dump(),
            "triggers": [item.model_dump() for item in config.triggers],
        }
        stored = service.create_workflow_config(
            app_id=app_id,
            validate=True,  # run basic validation
            **definition,
        )
        return success(
            data=WorkflowConfig.model_validate(stored),
            msg="工作流配置创建成功"
        )
    except BusinessException as biz_err:
        logger.warning(f"创建工作流配置失败: {biz_err.message}")
        return fail(code=biz_err.error_code, msg=biz_err.message)
    except Exception as exc:
        logger.error(f"创建工作流配置异常: {exc}", exc_info=True)
        return fail(
            code=BizCode.INTERNAL_ERROR,
            msg=f"创建工作流配置失败: {str(exc)}"
        )
#
# @router.get("/{app_id}/workflow")
# async def get_workflow_config(
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
# db: Annotated[Session, Depends(get_db)],
# current_user: Annotated[User, Depends(get_current_user)]
#
# ):
# """获取工作流配置
#
# 获取应用的工作流配置详情。
# """
# try:
# # 验证应用是否存在且属于当前工作空间
# app = db.query(App).filter(
# App.id == app_id,
# App.workspace_id == current_user.current_workspace_id,
# App.is_active == True
# ).first()
#
# if not app:
# return fail(
# code=BizCode.NOT_FOUND,
# msg="应用不存在或无权访问"
# )
#
# # 获取工作流配置
# service = WorkflowService(db)
# workflow_config = service.get_workflow_config(app_id)
#
# if not workflow_config:
# return fail(
# code=BizCode.NOT_FOUND,
# msg="工作流配置不存在"
# )
#
# return success(
# data=WorkflowConfig.model_validate(workflow_config)
# )
#
# except Exception as e:
# logger.error(f"获取工作流配置异常: {e}", exc_info=True)
# return fail(
# code=BizCode.INTERNAL_ERROR,
# msg=f"获取工作流配置失败: {str(e)}"
# )
# @router.put("/{app_id}/workflow")
# async def update_workflow_config(
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
# config: WorkflowConfigUpdate,
# db: Annotated[Session, Depends(get_db)],
# current_user: Annotated[User, Depends(get_current_user)],
# service: Annotated[WorkflowService, Depends(get_workflow_service)]
# ):
# """更新工作流配置
# 更新应用的工作流配置。可以部分更新,未提供的字段保持不变。
# """
# try:
# # 验证应用是否存在且属于当前工作空间
# app = db.query(App).filter(
# App.id == app_id,
# App.workspace_id == current_user.current_workspace_id,
# App.is_active == True
# ).first()
# if not app:
# return fail(
# code=BizCode.NOT_FOUND,
# msg="应用不存在或无权访问"
# )
# # 更新工作流配置
# workflow_config = service.update_workflow_config(
# app_id=app_id,
# nodes=[node.model_dump() for node in config.nodes] if config.nodes else None,
# edges=[edge.model_dump() for edge in config.edges] if config.edges else None,
# variables=[var.model_dump() for var in config.variables] if config.variables else None,
# execution_config=config.execution_config.model_dump() if config.execution_config else None,
# triggers=[trigger.model_dump() for trigger in config.triggers] if config.triggers else None,
# validate=True
# )
# return success(
# data=WorkflowConfig.model_validate(workflow_config),
# msg="工作流配置更新成功"
# )
# except BusinessException as e:
# logger.warning(f"更新工作流配置失败: {e.message}")
# return fail(code=e.error_code, msg=e.message)
# except Exception as e:
# logger.error(f"更新工作流配置异常: {e}", exc_info=True)
# return fail(
# code=BizCode.INTERNAL_ERROR,
# msg=f"更新工作流配置失败: {str(e)}"
# )
@router.delete("/{app_id}/workflow")
async def delete_workflow_config(
    app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_user)],
    service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
    """Delete the workflow configuration of an application.

    Args:
        app_id: Application whose configuration is removed.
        db: Database session injected by FastAPI.
        current_user: Authenticated user; its current workspace must own the
            application.
        service: Workflow service injected by FastAPI.

    Returns:
        A success response when a configuration was deleted; a NOT_FOUND
        failure when the app or its configuration does not exist; an
        INTERNAL_ERROR failure on any unexpected exception.
    """
    try:
        # The app must exist, be active, and belong to the caller's workspace.
        ownership_filter = (
            App.id == app_id,
            App.workspace_id == current_user.current_workspace_id,
            App.is_active.is_(True),
        )
        target_app = db.query(App).filter(*ownership_filter).first()
        if not target_app:
            return fail(code=BizCode.NOT_FOUND, msg="应用不存在或无权访问")

        # A falsy result from the service means there was nothing to delete.
        if service.delete_workflow_config(app_id):
            return success(msg="工作流配置删除成功")
        return fail(code=BizCode.NOT_FOUND, msg="工作流配置不存在")
    except Exception as exc:
        logger.error(f"删除工作流配置异常: {exc}", exc_info=True)
        return fail(
            code=BizCode.INTERNAL_ERROR,
            msg=f"删除工作流配置失败: {str(exc)}"
        )
@router.post("/{app_id}/workflow/validate")
async def validate_workflow_config(
    app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_user)],
    service: Annotated[WorkflowService, Depends(get_workflow_service)],
    for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
):
    """Validate the stored workflow configuration of an application.

    Args:
        app_id: Application whose configuration is validated.
        db: Database session injected by FastAPI.
        current_user: Authenticated user; its current workspace must own the
            application.
        service: Workflow service injected by FastAPI.
        for_publish: When true, delegate to the service's publish-level
            validation; otherwise run the basic validator on the stored
            configuration.

    Returns:
        A success response wrapping a ``WorkflowValidationResponse`` (the
        ``warnings`` list is always empty), or a failure response when the
        app/configuration is missing or validation raises.
    """
    try:
        # The app must exist, be active, and belong to the caller's workspace.
        ownership_filter = (
            App.id == app_id,
            App.workspace_id == current_user.current_workspace_id,
            App.is_active.is_(True),
        )
        target_app = db.query(App).filter(*ownership_filter).first()
        if not target_app:
            return fail(code=BizCode.NOT_FOUND, msg="应用不存在或无权访问")

        if for_publish:
            verdict = service.validate_workflow_config_for_publish(app_id)
        else:
            stored = service.get_workflow_config(app_id)
            if not stored:
                return fail(code=BizCode.NOT_FOUND, msg="工作流配置不存在")

            from app.core.workflow.validator import validate_workflow_config as validate_config

            # Rebuild the plain-dict shape the validator expects.
            snapshot = {
                "nodes": stored.nodes,
                "edges": stored.edges,
                "variables": stored.variables,
                "execution_config": stored.execution_config,
                "triggers": stored.triggers,
            }
            verdict = validate_config(snapshot, for_publish=False)

        is_valid, errors = verdict
        report = WorkflowValidationResponse(
            is_valid=is_valid,
            errors=errors,
            warnings=[]
        )
        return success(data=report)
    except BusinessException as biz_err:
        logger.warning(f"验证工作流配置失败: {biz_err.message}")
        return fail(code=biz_err.error_code, msg=biz_err.message)
    except Exception as exc:
        logger.error(f"验证工作流配置异常: {exc}", exc_info=True)
        return fail(
            code=BizCode.INTERNAL_ERROR,
            msg=f"验证工作流配置失败: {str(exc)}"
        )
# ==================== 工作流执行管理 ====================
@router.get("/{app_id}/workflow/executions")
async def get_workflow_executions(
    app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_user)],
    service: Annotated[WorkflowService, Depends(get_workflow_service)],
    limit: Annotated[int, Query(ge=1, le=100)] = 50,
    offset: Annotated[int, Query(ge=0)] = 0
):
    """List the workflow execution history of an application.

    Args:
        app_id: Application whose executions are listed.
        db: Database session injected by FastAPI.
        current_user: Authenticated user; its current workspace must own the
            application.
        service: Workflow service injected by FastAPI.
        limit: Maximum number of records to return (1-100).
        offset: Number of records to skip.

    Returns:
        A success response with ``executions``, ``statistics`` and
        ``pagination`` (whose ``total`` is taken from the statistics), or a
        failure response when the app is missing or an error occurs.
    """
    try:
        # The app must exist, be active, and belong to the caller's workspace.
        ownership_filter = (
            App.id == app_id,
            App.workspace_id == current_user.current_workspace_id,
            App.is_active.is_(True),
        )
        target_app = db.query(App).filter(*ownership_filter).first()
        if not target_app:
            return fail(code=BizCode.NOT_FOUND, msg="应用不存在或无权访问")

        history = service.get_executions_by_app(app_id, limit, offset)
        stats = service.get_execution_statistics(app_id)

        serialized = []
        for record in history:
            serialized.append(WorkflowExecution.model_validate(record))

        page_info = {
            "limit": limit,
            "offset": offset,
            "total": stats["total"],
        }
        return success(
            data={
                "executions": serialized,
                "statistics": stats,
                "pagination": page_info,
            }
        )
    except Exception as exc:
        logger.error(f"获取工作流执行记录异常: {exc}", exc_info=True)
        return fail(
            code=BizCode.INTERNAL_ERROR,
            msg=f"获取工作流执行记录失败: {str(exc)}"
        )
@router.get("/workflow/executions/{execution_id}")
async def get_workflow_execution(
    execution_id: Annotated[str, Path(description="执行 ID")],
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_user)],
    service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
    """Return one workflow execution together with its node executions.

    Args:
        execution_id: Identifier of the execution to load.
        db: Database session injected by FastAPI.
        current_user: Authenticated user; its current workspace must own the
            application the execution belongs to.
        service: Workflow service injected by FastAPI.

    Returns:
        A success response with ``execution`` and ``node_executions``; a
        NOT_FOUND failure when the execution does not exist or its app is not
        accessible; an INTERNAL_ERROR failure on any unexpected exception.
    """
    try:
        record = service.get_execution(execution_id)
        if not record:
            return fail(code=BizCode.NOT_FOUND, msg="执行记录不存在")

        # The owning app must be active and in the caller's workspace.
        ownership_filter = (
            App.id == record.app_id,
            App.workspace_id == current_user.current_workspace_id,
            App.is_active.is_(True),
        )
        owning_app = db.query(App).filter(*ownership_filter).first()
        if not owning_app:
            return fail(code=BizCode.NOT_FOUND, msg="无权访问该执行记录")

        node_records = service.node_execution_repo.get_by_execution_id(record.id)

        execution_view = WorkflowExecution.model_validate(record)
        node_views = []
        for node_record in node_records:
            node_views.append(WorkflowNodeExecution.model_validate(node_record))

        return success(
            data={
                "execution": execution_view,
                "node_executions": node_views,
            }
        )
    except Exception as exc:
        logger.error(f"获取工作流执行详情异常: {exc}", exc_info=True)
        return fail(
            code=BizCode.INTERNAL_ERROR,
            msg=f"获取工作流执行详情失败: {str(exc)}"
        )
# ==================== 工作流执行 ====================
@router.post("/{app_id}/workflow/run")
async def run_workflow(
    app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
    request: WorkflowExecutionRequest,
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_user)],
    service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
    """Execute a workflow in streaming or non-streaming mode.

    In non-streaming mode the call waits for the workflow to finish and
    returns the complete result. In streaming mode it returns a Server-Sent
    Events response that relays the events produced by the service while the
    workflow runs.

    Args:
        app_id: Application whose workflow is executed.
        request: Execution request (message, variables, optional
            conversation id, stream flag).
        db: Database session injected by FastAPI.
        current_user: Authenticated user; its current workspace must own the
            application and its id is recorded as ``triggered_by``.
        service: Workflow service injected by FastAPI.

    Returns:
        A ``StreamingResponse`` (``text/event-stream``) when
        ``request.stream`` is truthy; otherwise a success response wrapping a
        ``WorkflowExecutionResponse``. A failure response is returned when
        the app is missing, is not a workflow app, the conversation id is
        malformed, or execution raises.
    """
    try:
        # The app must exist, be active, and belong to the caller's workspace.
        app = db.query(App).filter(
            App.id == app_id,
            App.workspace_id == current_user.current_workspace_id,
            App.is_active.is_(True)
        ).first()
        if not app:
            return fail(
                code=BizCode.NOT_FOUND,
                msg="应用不存在或无权访问"
            )

        # Only workflow-type applications can be executed here.
        if app.type != "workflow":
            return fail(
                code=BizCode.INVALID_PARAMETER,
                msg=f"应用类型必须为 workflow当前为 {app.type}"
            )

        # Parse the conversation id once, before any execution starts. Doing
        # it here reports a malformed id as a parameter error instead of an
        # internal error (non-stream) or a mid-stream error event (stream).
        conversation_id = None
        if request.conversation_id:
            try:
                conversation_id = uuid.UUID(request.conversation_id)
            except (ValueError, AttributeError, TypeError):
                return fail(
                    code=BizCode.INVALID_PARAMETER,
                    msg=f"无效的 conversation_id: {request.conversation_id}"
                )

        input_data = {
            "message": request.message or "",
            "variables": request.variables
        }

        if request.stream:
            from fastapi.responses import StreamingResponse
            import json

            async def event_generator():
                """Yield the workflow's events as SSE-formatted strings.

                Each message has the form::

                    event: <event_type>
                    data: <json_data>

                Event types listed by the original documentation:
                workflow_start, workflow_end, node_start, node_end,
                node_chunk (streamed output of intermediate nodes) and
                message (streamed final output).
                """
                try:
                    async for event in await service.run_workflow(
                        app_id=app_id,
                        input_data=input_data,
                        triggered_by=current_user.id,
                        conversation_id=conversation_id,
                        stream=True
                    ):
                        event_type = event.get("event", "message")
                        event_data = event.get("data", {})
                        yield f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
                except Exception as e:
                    logger.error(f"流式执行异常: {e}", exc_info=True)
                    # The HTTP response has already started, so the failure
                    # is reported in-band as an SSE "error" event.
                    yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"

            return StreamingResponse(
                event_generator(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no"  # disable nginx buffering
                }
            )

        # Non-streaming execution: wait for the full result.
        result = await service.run_workflow(
            app_id=app_id,
            input_data=input_data,
            triggered_by=current_user.id,
            conversation_id=conversation_id,
            stream=False
        )
        return success(
            data=WorkflowExecutionResponse(
                execution_id=result["execution_id"],
                status=result["status"],
                output=result.get("output"),
                output_data=result.get("output_data"),
                error_message=result.get("error_message"),
                elapsed_time=result.get("elapsed_time"),
                token_usage=result.get("token_usage")
            ),
            msg="工作流执行完成"
        )
    except BusinessException as e:
        logger.warning(f"执行工作流失败: {e.message}")
        return fail(code=e.error_code, msg=e.message)
    except Exception as e:
        logger.error(f"执行工作流异常: {e}", exc_info=True)
        return fail(
            code=BizCode.INTERNAL_ERROR,
            msg=f"执行工作流失败: {str(e)}"
        )
@router.post("/workflow/executions/{execution_id}/cancel")
async def cancel_workflow_execution(
    execution_id: Annotated[str, Path(description="执行 ID")],
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_user)],
    service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
    """Cancel a workflow execution.

    Per the original note, the current version only updates the stored
    status to ``cancelled``; actually interrupting a running execution is
    not implemented yet.

    Args:
        execution_id: Identifier of the execution to cancel.
        db: Database session injected by FastAPI.
        current_user: Authenticated user; its current workspace must own the
            application the execution belongs to.
        service: Workflow service injected by FastAPI.

    Returns:
        A success response when the status was updated; a NOT_FOUND failure
        when the execution does not exist or is not accessible; an
        INVALID_PARAMETER failure when the execution is not ``pending`` or
        ``running``; otherwise a failure describing the raised error.
    """
    try:
        execution = service.get_execution(execution_id)
        if not execution:
            return fail(
                code=BizCode.NOT_FOUND,
                msg="执行记录不存在"
            )

        # The owning app must be active and in the caller's workspace.
        app = db.query(App).filter(
            App.id == execution.app_id,
            App.workspace_id == current_user.current_workspace_id,
            App.is_active.is_(True)
        ).first()
        if not app:
            return fail(
                code=BizCode.NOT_FOUND,
                msg="无权访问该执行记录"
            )

        # Only executions that have not finished can be cancelled.
        if execution.status not in ["pending", "running"]:
            return fail(
                code=BizCode.INVALID_PARAMETER,
                msg=f"无法取消状态为 {execution.status} 的执行"
            )

        service.update_execution_status(execution_id, "cancelled")
        return success(msg="工作流执行已取消")
    except BusinessException as e:
        logger.warning(f"取消工作流执行失败: {e.message}")
        # Use error_code, the attribute every other handler in this module
        # reads from BusinessException.
        return fail(code=e.error_code, msg=e.message)
    except Exception as e:
        logger.error(f"取消工作流执行异常: {e}", exc_info=True)
        return fail(
            code=BizCode.INTERNAL_ERROR,
            msg=f"取消工作流执行失败: {str(e)}"
        )

View File

@@ -0,0 +1,162 @@
"""Agent Middleware - 动态技能过滤"""
import uuid
from typing import List, Dict, Any, Optional
from langchain_core.runnables import RunnablePassthrough
from app.services.skill_service import SkillService
from app.repositories.skill_repository import SkillRepository
class AgentMiddleware:
    """Agent middleware that loads skill tools and filters them per message."""

    def __init__(self, skills: Optional[dict] = None):
        """Store the skill settings.

        Args:
            skills: Skill settings of the form
                ``{"enabled": bool, "all_skills": bool, "skill_ids": [...]}``.
                Missing keys default to disabled / no skills.
        """
        self.skills = skills or {}
        self.enabled = self.skills.get('enabled', False)
        self.all_skills = self.skills.get('all_skills', False)
        self.skill_ids = self.skills.get('skill_ids', [])

    @staticmethod
    def filter_tools(
        tools: List,
        message: str = "",
        skill_configs: Dict[str, Any] = None,
        tool_to_skill_map: Dict[str, str] = None
    ) -> tuple[List, List[str]]:
        """Filter tools by the skills that the message activates.

        A skill is activated when it is enabled and either declares no
        keywords or one of its keywords occurs (case-insensitively) in the
        message. A tool is kept when it is not mapped to any skill, or when
        the skill it maps to is activated.

        Args:
            tools: All available tools.
            message: User message; ``None`` is treated as an empty message.
            skill_configs: ``{skill_id: {"keywords": [...], "enabled": bool,
                "prompt": str}}``. When empty, all tools are returned.
            tool_to_skill_map: ``{tool_name: skill_id}``. When empty, all
                tools are returned alongside the activated skill ids.

        Returns:
            ``(filtered_tools, activated_skill_ids)``.
        """
        if not tools:
            return [], []

        # Without skill configuration there is nothing to filter on.
        if not skill_configs:
            return tools, []

        # Guard against message=None, which would crash on .lower().
        message_lower = (message or "").lower()

        activated_skill_ids = []
        for skill_id, config in skill_configs.items():
            if not config.get('enabled', True):
                continue
            keywords = config.get('keywords') or []
            # No keyword restriction, or at least one keyword matches.
            if not keywords or any(kw.lower() in message_lower for kw in keywords):
                activated_skill_ids.append(skill_id)

        # Without a tool-to-skill mapping every tool is kept.
        if not tool_to_skill_map:
            return tools, activated_skill_ids

        filtered_tools = []
        for tool in tools:
            tool_name = getattr(tool, 'name', str(id(tool)))
            # Keep tools that belong to no skill (base tools) and tools whose
            # owning skill is activated.
            if tool_name not in tool_to_skill_map or tool_to_skill_map[tool_name] in activated_skill_ids:
                filtered_tools.append(tool)
        return filtered_tools, activated_skill_ids

    def load_skill_tools(self, db, tenant_id: uuid.UUID, base_tools: List = None) -> tuple[List, Dict[str, Any], Dict[str, str]]:
        """Load the tools attached to the configured skills.

        Base tools are always included and are never mapped to a skill. Skill
        tools are added only when no tool of the same name is already
        present.

        Args:
            db: Database session.
            tenant_id: Tenant whose skills are loaded.
            base_tools: Tools that are available regardless of skills.

        Returns:
            ``(tools, skill_configs, tool_to_skill_map)`` where
            ``skill_configs`` maps each loaded skill id to its config
            (extended with ``prompt`` and ``name``) and
            ``tool_to_skill_map`` maps tool names to skill ids.
        """
        import logging

        tools_dict = {}
        tool_to_skill_map = {}

        if base_tools:
            for tool in base_tools:
                tool_name = getattr(tool, 'name', str(id(tool)))
                tools_dict[tool_name] = tool

        skill_configs = {}
        skill_ids_to_load = []

        # all_skills loads every active skill of the tenant; otherwise only
        # the explicitly configured ids are used.
        if self.enabled and self.all_skills:
            skills, _ = SkillRepository.list_skills(db, tenant_id, is_active=True, page=1, pagesize=1000)
            skill_ids_to_load = [str(skill.id) for skill in skills]
        elif self.enabled and self.skill_ids:
            skill_ids_to_load = self.skill_ids

        if skill_ids_to_load:
            for skill_id in skill_ids_to_load:
                try:
                    skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id)
                    if skill and skill.is_active:
                        # Work on a copy so that adding prompt/name does not
                        # modify the dict held by the loaded skill object.
                        config = dict(skill.config or {})
                        config['prompt'] = skill.prompt
                        config['name'] = skill.name
                        skill_configs[skill_id] = config
                except Exception as exc:
                    # Best effort: one broken skill must not block the
                    # others, but the failure should not go unnoticed.
                    logging.getLogger(__name__).warning(
                        "Failed to load skill %s: %s", skill_id, exc
                    )
                    continue

            skill_tools, skill_tool_map = SkillService.load_skill_tools(db, skill_ids_to_load, tenant_id)

            # Add only skill tools whose name does not clash with a tool that
            # is already present.
            for tool in skill_tools:
                tool_name = getattr(tool, 'name', str(id(tool)))
                if tool_name not in tools_dict:
                    tools_dict[tool_name] = tool
                    if tool_name in skill_tool_map:
                        tool_to_skill_map[tool_name] = skill_tool_map[tool_name]

        return list(tools_dict.values()), skill_configs, tool_to_skill_map

    @staticmethod
    def get_active_prompts(activated_skill_ids: List[str], skill_configs: Dict[str, Any]) -> str:
        """Merge the prompts of the activated skills.

        Args:
            activated_skill_ids: Ids of the activated skills.
            skill_configs: Skill configuration keyed by skill id.

        Returns:
            One ``# <name>\\n<prompt>`` section per activated skill that has a
            prompt, joined by blank lines; an empty string when none has one.
        """
        prompts = []
        for skill_id in activated_skill_ids:
            config = skill_configs.get(skill_id, {})
            prompt = config.get('prompt')
            name = config.get('name', 'Skill')
            if prompt:
                prompts.append(f"# {name}\n{prompt}")
        return "\n\n".join(prompts)

    @staticmethod
    def create_runnable():
        """Return a ``RunnablePassthrough`` acting as the middleware runnable."""
        return RunnablePassthrough()

View File

@@ -289,10 +289,9 @@ class LangChainAgent:
return content_parts
return messages
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
db = next(get_db())
#TODO: 魔法数字
scope=6
try:
@@ -302,6 +301,12 @@ class LangChainAgent:
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
# Handle case where no session exists in Redis (returns False)
if not result or result is False:
logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update")
return
if type=="chunk" or type=="aggregate":
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
@@ -309,7 +314,14 @@ class LangChainAgent:
repo.upsert(end_user_id, chunk_data)
logger.info(f'写入短长期:')
else:
# TODO: This branch handles type="time" strategy, currently unused.
# Will be activated when time-based long-term storage is implemented.
# TODO: 魔法数字 - extract 5 to a constant
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
# Handle case where no session exists in Redis (returns False or empty)
if not long_time_data or long_time_data is False:
logger.debug(f"No recent sessions in Redis for user {end_user_id}")
return
long_messages = await messages_parse(long_time_data)
repo.upsert(end_user_id, long_messages)
logger.info(f'写入短长期:')
@@ -509,13 +521,13 @@ class LangChainAgent:
elapsed_time = time.time() - start_time
if memory_flag:
long_term_messages=await agent_chat_messages(message_chat,content)
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
'''长期'''
if actual_config_id:
await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
else:
logger.warning(f"Skipping term_memory_save: no memory config available for end_user {end_user_id}")
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
response = {
"content": content,
"model": self.model_name,
@@ -698,13 +710,14 @@ class LangChainAgent:
yield total_tokens
break
if memory_flag:
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
long_term_messages = await agent_chat_messages(message_chat, full_content)
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
if actual_config_id:
await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk")
else:
logger.warning(f"Skipping term_memory_save: no memory config available for end_user {end_user_id}")
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk")
except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)

View File

@@ -215,6 +215,9 @@ class Settings:
# official environment system version
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1")
# model square loading
LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true"
# workflow config
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))

View File

@@ -43,6 +43,7 @@ async def write_messages(end_user_id,langchain_messages,memory_config):
for node_name, node_data in update_event.items():
if 'save_neo4j' == node_name:
massages = node_data
# TODO删除
massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
print(contents)
@@ -60,6 +61,7 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
scope窗口大小
'''
scope=scope
redis_messages = []
is_end_user_id = count_store.get_sessions_count(end_user_id)
if is_end_user_id is not False:
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
@@ -91,6 +93,9 @@ async def memory_long_term_storage(end_user_id,memory_config,time):
memory_config: 内存配置对象
'''
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
# Handle case where no session exists in Redis (returns False or empty)
if not long_time_data or long_time_data is False:
return
format_messages = await chat_data_format(long_time_data)
if format_messages!=[]:
await write_messages(end_user_id, format_messages, memory_config)
@@ -108,8 +113,9 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
try:
# 1. 获取历史会话数据(使用新方法)
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
history = await format_parsing(result)
if not result:
# Handle case where no session exists in Redis (returns False or empty)
if not result or result is False:
history = []
else:
history = await format_parsing(result)

View File

@@ -1,24 +1,21 @@
import asyncio
import json
import sys
import warnings
from contextlib import asynccontextmanager
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, chat_data_format, messages_parse
from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__)
if sys.platform.startswith("win"):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
@asynccontextmanager
async def make_write_graph():
"""
@@ -39,29 +36,59 @@ async def make_write_graph():
graph = workflow.compile()
yield graph
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
from app.core.memory.agent.utils.redis_tool import write_store
write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
# 获取数据库会话
db_session = next(get_db())
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=memory_config, # 改为整数
service_name="MemoryAgentService"
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
end_user_id: str = '', scope: int = 6):
"""Dispatch long-term memory storage to Celery background tasks.
Args:
long_term_type: Storage strategy - 'chunk' (window), 'time', or 'aggregate'
langchain_messages: List of messages to store
memory_config: Memory configuration ID (string)
end_user_id: End user identifier
scope: Window size for 'chunk' strategy (default: 6)
"""
from app.tasks import (
long_term_storage_window_task,
# TODO: Uncomment when implemented
# long_term_storage_time_task,
# long_term_storage_aggregate_task,
)
if long_term_type=='chunk':
'''方案一:对话窗口6轮对话'''
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
if long_term_type=='time':
"""时间"""
await memory_long_term_storage(end_user_id, memory_config,5)
if long_term_type=='aggregate':
from app.core.logging_config import get_logger
"""方案三:聚合判断"""
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
logger = get_logger(__name__)
# Convert config to string if needed
config_id = str(memory_config) if memory_config else ''
if long_term_type == 'chunk':
# Strategy 1: Window-based batching (6 rounds of dialogue)
logger.info(f"[LONG_TERM] Dispatching window task - end_user_id={end_user_id}, scope={scope}")
long_term_storage_window_task.delay(
end_user_id=end_user_id,
langchain_messages=langchain_messages,
config_id=config_id,
scope=scope
)
# TODO: Uncomment when time-based strategy is fully implemented
# elif long_term_type == 'time':
# # Strategy 2: Time-based retrieval
# logger.info(f"[LONG_TERM] Dispatching time task - end_user_id={end_user_id}")
# long_term_storage_time_task.delay(
# end_user_id=end_user_id,
# config_id=config_id,
# time_window=5
# )
# TODO: Uncomment when aggregate strategy is fully implemented
# elif long_term_type == 'aggregate':
# # Strategy 3: Aggregate judgment (deduplication)
# logger.info(f"[LONG_TERM] Dispatching aggregate task - end_user_id={end_user_id}")
# long_term_storage_aggregate_task.delay(
# end_user_id=end_user_id,
# langchain_messages=langchain_messages,
# config_id=config_id
# )
# async def main():
# """主函数 - 运行工作流"""

View File

@@ -174,4 +174,4 @@ async def write(
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
logger.info("=== Pipeline Complete ===")
logger.info(f"Total execution time: {total_time:.2f} seconds")
logger.info(f"Total execution time: {total_time:.2f} seconds")

View File

@@ -1,5 +1,4 @@
provider: bedrock
enabled: false
models:
- name: ai21
type: llm

View File

@@ -1,5 +1,4 @@
provider: dashscope
enabled: false
models:
- name: deepseek-r1-distill-qwen-14b
type: llm

View File

@@ -1,11 +1,11 @@
"""模型配置加载器 - 用于将预定义模型批量导入到数据库"""
import os
from pathlib import Path
from typing import Callable
import yaml
from sqlalchemy.orm import Session
from app.models.models_model import ModelBase, ModelProvider
@@ -19,31 +19,9 @@ def _load_yaml_config(provider: ModelProvider) -> list[dict]:
with open(config_file, 'r', encoding='utf-8') as f:
data = yaml.safe_load(f)
# 检查是否需要加载(默认为 true
if not data.get('enabled', True):
return []
return data.get('models', [])
def _disable_yaml_config(provider: ModelProvider) -> None:
    """Set the ``enabled`` flag to ``False`` in a provider's YAML config file.

    The file is ``<provider.value>_models.yaml`` in the directory of this
    module. Nothing is written when the file does not exist or when its
    content is not a YAML mapping.

    Args:
        provider: Provider whose ``value`` selects the config file.
    """
    config_file = Path(__file__).parent / f"{provider.value}_models.yaml"

    if not config_file.exists():
        return

    with open(config_file, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)

    # safe_load yields None for an empty document (and a list/scalar for
    # non-mapping documents); item assignment would raise TypeError on those.
    if not isinstance(data, dict):
        return

    data['enabled'] = False

    with open(config_file, 'w', encoding='utf-8') as f:
        yaml.dump(data, f, allow_unicode=True, sort_keys=False)
def load_models(db: Session, providers: list[str] = None, silent: bool = False) -> dict:
"""
加载模型配置到数据库
@@ -75,8 +53,7 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
if not silent:
print(f"\n正在加载 {provider.value}{len(models)} 个模型...")
# provider_success = 0
for model_data in models:
try:
# 检查模型是否已存在
@@ -93,7 +70,6 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
if not silent:
print(f"更新成功: {model_data['name']}")
result["success"] += 1
# provider_success += 1
else:
# 创建新模型
model = ModelBase(**model_data)
@@ -102,17 +78,12 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
if not silent:
print(f"添加成功: {model_data['name']}")
result["success"] += 1
# provider_success += 1
except Exception as e:
db.rollback()
if not silent:
print(f"添加失败: {model_data['name']} - {str(e)}")
result["failed"] += 1
# 如果该供应商的模型全部加载成功将enabled设置为false
# if provider_success == len(models):
_disable_yaml_config(provider)
return result

View File

@@ -1,5 +1,4 @@
provider: openai
enabled: false
models:
- name: chatgpt-4o-latest
type: llm

File diff suppressed because it is too large Load Diff

View File

@@ -1,9 +1,3 @@
"""
安全的表达式求值器
使用 simpleeval 库提供安全的表达式评估,避免代码注入攻击。
"""
import logging
import re
from typing import Any
@@ -14,160 +8,119 @@ logger = logging.getLogger(__name__)
class ExpressionEvaluator:
"""安全的表达式求值器"""
"""Safe expression evaluator for workflow variables and node outputs."""
# 保留的命名空间
# Reserved namespaces
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
@staticmethod
def evaluate(
expression: str,
variables: dict[str, Any],
conv_vars: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> Any:
"""安全地评估表达式
Args:
expression: 表达式字符串,如 "{{var.score}} > 0.8"
variables: 用户定义的变量
node_outputs: 节点输出结果
system_vars: 系统变量
Returns:
表达式求值结果
Raises:
ValueError: 表达式无效或求值失败
Examples:
>>> evaluator = ExpressionEvaluator()
>>> evaluator.evaluate(
... "var.score > 0.8",
... {"score": 0.9},
... {},
... {}
... )
True
>>> evaluator.evaluate(
... "node.intent.output == '售前咨询'",
... {},
... {"intent": {"output": "售前咨询"}},
... {}
... )
True
"""
# 移除 Jinja2 模板语法的花括号(如果存在)
Safely evaluate an expression using workflow variables.
Args:
expression (str): The expression string, e.g., "var.score > 0.8"
conv_vars (dict): Conversation-level variables
node_outputs (dict): Outputs from workflow nodes
system_vars (dict, optional): System variables
Returns:
Any: Result of the evaluated expression
Raises:
ValueError: If the expression is invalid or evaluation fails
"""
# Remove Jinja2-style brackets if present
expression = expression.strip()
# "{{system.message}} == {{ user.message }}" -> "system.message == user.message"
pattern = r"\{\{\s*(.*?)\s*\}\}"
expression = re.sub(pattern, r"\1", expression).strip()
# 构建命名空间上下文
# Build context for evaluation
context = {
"var": variables, # 用户变量
"node": node_outputs, # 节点输出
"sys": system_vars or {}, # 系统变量
"conv": conv_vars, # conversation variables
"node": node_outputs, # node outputs
"sys": system_vars or {}, # system variables
}
# 为了向后兼容,也支持直接访问(但会在日志中警告)
context.update(variables)
context.update(conv_vars)
context["nodes"] = node_outputs
context.update(node_outputs)
try:
# simpleeval 只支持安全的操作:
# - 算术运算: +, -, *, /, //, %, **
# - 比较运算: ==, !=, <, <=, >, >=
# - 逻辑运算: and, or, not
# - 成员运算: in, not in
# - 属性访问: obj.attr
# - 字典/列表访问: obj["key"], obj[0]
# 不支持:函数调用、导入、赋值等危险操作
# simpleeval supports safe operations:
# arithmetic, comparisons, logical ops, attribute/dict/list access
result = simple_eval(expression, names=context)
return result
except NameNotDefined as e:
logger.error(f"表达式中引用了未定义的变量: {expression}, 错误: {e}")
raise ValueError(f"未定义的变量: {e}")
logger.error(f"Undefined variable in expression: {expression}, error: {e}")
raise ValueError(f"Undefined variable: {e}")
except InvalidExpression as e:
logger.error(f"表达式语法无效: {expression}, 错误: {e}")
raise ValueError(f"表达式语法无效: {e}")
logger.error(f"Invalid expression syntax: {expression}, error: {e}")
raise ValueError(f"Invalid expression syntax: {e}")
except SyntaxError as e:
logger.error(f"表达式语法错误: {expression}, 错误: {e}")
raise ValueError(f"表达式语法错误: {e}")
logger.error(f"Syntax error in expression: {expression}, error: {e}")
raise ValueError(f"Syntax error: {e}")
except Exception as e:
logger.error(f"表达式求值异常: {expression}, 错误: {e}")
raise ValueError(f"表达式求值失败: {e}")
logger.error(f"Expression evaluation failed: {expression}, error: {e}")
raise ValueError(f"Expression evaluation failed: {e}")
@staticmethod
def evaluate_bool(
expression: str,
variables: dict[str, Any],
conv_var: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> bool:
"""评估布尔表达式(用于条件判断)
"""
Evaluate a boolean expression (for conditions).
Args:
expression: 布尔表达式
variables: 用户变量
node_outputs: 节点输出
system_vars: 系统变量
expression (str): Boolean expression
conv_var (dict): Conversation variables
node_outputs (dict): Node outputs
system_vars (dict, optional): System variables
Returns:
布尔值结果
Examples:
>>> ExpressionEvaluator.evaluate_bool(
... "var.count >= 10 and var.status == 'active'",
... {"count": 15, "status": "active"},
... {},
... {}
... )
True
bool: Boolean result
"""
result = ExpressionEvaluator.evaluate(
expression, variables, node_outputs, system_vars
expression, conv_var, node_outputs, system_vars
)
return bool(result)
@staticmethod
def validate_variable_names(variables: list[dict]) -> list[str]:
"""验证变量名是否合法
"""
Validate variable names for legality.
Args:
variables: 变量定义列表
variables (list[dict]): List of variable definitions
Returns:
错误列表,如果为空则验证通过
Examples:
>>> ExpressionEvaluator.validate_variable_names([
... {"name": "user_input"},
... {"name": "var"} # 保留字
... ])
["变量名 'var' 是保留的命名空间,请使用其他名称"]
list[str]: List of error messages. Empty if all names are valid.
"""
errors = []
for var in variables:
var_name = var.get("name", "")
# 检查是否为保留命名空间
if var_name in ExpressionEvaluator.RESERVED_NAMESPACES:
errors.append(
f"变量名 '{var_name}' 是保留的命名空间,请使用其他名称"
f"Variable name '{var_name}' is a reserved namespace, please use another name"
)
# 检查是否为有效的 Python 标识符
if not var_name.isidentifier():
errors.append(
f"变量名 '{var_name}' 不是有效的标识符"
f"Variable name '{var_name}' is not a valid Python identifier"
)
return errors
@@ -176,23 +129,23 @@ class ExpressionEvaluator:
# 便捷函数
def evaluate_expression(
expression: str,
variables: dict[str, Any],
conv_var: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
system_vars: dict[str, Any]
) -> Any:
"""评估表达式(便捷函数)"""
"""Evaluate an expression (convenience function)."""
return ExpressionEvaluator.evaluate(
expression, variables, node_outputs, system_vars
expression, conv_var, node_outputs, system_vars
)
def evaluate_condition(
expression: str,
variables: dict[str, Any],
conv_var: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> bool:
"""评估条件表达式(便捷函数)"""
"""Evaluate a boolean condition expression (convenience function)."""
return ExpressionEvaluator.evaluate_bool(
expression, variables, node_outputs, system_vars
expression, conv_var, node_outputs, system_vars
)

View File

@@ -14,9 +14,14 @@ from pydantic import BaseModel, Field
from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
SCOPE_PATTERN = re.compile(
r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\.[a-zA-Z0-9_]+\s*}}"
)
class OutputContent(BaseModel):
"""
@@ -53,6 +58,12 @@ class OutputContent(BaseModel):
)
)
_SCOPE: str | None = None
def get_scope(self) -> str:
    """Return the scope of the first ``{{scope.field}}`` reference in ``literal``.

    The result is stored on ``self._SCOPE`` before being returned. Only the
    first reference matched by ``SCOPE_PATTERN`` is considered.

    Returns:
        The scope name, or an empty string when ``literal`` contains no
        ``{{scope.field}}`` reference (instead of raising ``IndexError``).
    """
    # findall returns the captured scope names; the list is empty when the
    # literal holds no variable reference.
    scopes = SCOPE_PATTERN.findall(self.literal)
    self._SCOPE = scopes[0] if scopes else ""
    return self._SCOPE
def depends_on_scope(self, scope: str) -> bool:
"""
Check if this segment depends on a given scope.
@@ -63,8 +74,9 @@ class OutputContent(BaseModel):
Returns:
bool: True if this segment references the given scope.
"""
pattern = rf"\{{\{{\s*{re.escape(scope)}\.[a-zA-Z0-9_]+\s*\}}\}}"
return bool(re.search(pattern, self.literal))
if self._SCOPE:
return self._SCOPE == scope
return self.get_scope() == scope
class StreamOutputConfig(BaseModel):
@@ -167,6 +179,7 @@ class GraphBuilder:
workflow_config: dict[str, Any],
stream: bool = False,
subgraph: bool = False,
variable_pool: VariablePool | None = None
):
self.workflow_config = workflow_config
@@ -180,6 +193,10 @@ class GraphBuilder:
self._find_upstream_branch_node = lru_cache(
maxsize=len(self.nodes) * 2
)(self._find_upstream_branch_node)
if variable_pool:
self.variable_pool = variable_pool
else:
self.variable_pool = VariablePool()
self.graph = StateGraph(WorkflowState)
self.add_nodes()
@@ -452,9 +469,9 @@ class GraphBuilder:
if self.stream:
# Stream mode: create an async generator function
# LangGraph collects all yielded values; the last yielded dictionary is merged into the state
def make_stream_func(inst):
def make_stream_func(inst, variable_pool=self.variable_pool):
async def node_func(state: WorkflowState):
async for item in inst.run_stream(state):
async for item in inst.run_stream(state, variable_pool):
yield item
return node_func
@@ -462,9 +479,9 @@ class GraphBuilder:
self.graph.add_node(node_id, make_stream_func(node_instance))
else:
# Non-stream mode: create an async function
def make_func(inst):
def make_func(inst, variable_pool=self.variable_pool):
async def node_func(state: WorkflowState):
return await inst.run(state)
return await inst.run(state, variable_pool)
return node_func
@@ -567,27 +584,28 @@ class GraphBuilder:
for target in branch_info["target"]:
waiting_edges[target].append(branch_info["node"]["name"])
def router_fn(state: WorkflowState) -> list[Send]:
def router_fn(state: WorkflowState, variable_pool: VariablePool = self.variable_pool) -> list[Send]:
branch_activate = []
new_state = state.copy()
new_state["activate"] = dict(state.get("activate", {})) # deep copy of activate
node_output = variable_pool.get_node_output(src, defalut=dict(), strict=False)
for label, branch in unique_branch.items():
if evaluate_condition(
if node_output and evaluate_condition(
branch["condition"],
state.get("variables", {}),
state.get("runtime_vars", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
{},
{src: node_output},
{}
):
logger.debug(f"Conditional routing {src}: selected branch {label}")
new_state["activate"][branch["node"]["name"]] = True
branch_activate.append(
Send(
branch['node']['name'],
new_state
)
)
continue
new_state["activate"][branch["node"]["name"]] = False
for label, branch in unique_branch.items():
branch_activate.append(
Send(
branch['node']['name'],

View File

@@ -15,7 +15,6 @@ from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.tool import ToolNode
@@ -25,7 +24,6 @@ __all__ = [
"WorkflowState",
"LLMNode",
"AgentNode",
"TransformNode",
"IfElseNode",
"StartNode",
"EndNode",

View File

@@ -2,7 +2,8 @@
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
from app.core.workflow.variable.base_variable import VariableType
class AgentNodeConfig(BaseNodeConfig):

View File

@@ -2,6 +2,7 @@
Agent 节点实现
调用已发布的 Agent 应用。
# TODO
"""
import logging
@@ -9,6 +10,8 @@ from typing import Any
from langchain_core.messages import AIMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
from app.services.draft_run_service import DraftRunService
from app.models import AppRelease
from app.db import get_db
@@ -30,19 +33,22 @@ class AgentNode(BaseNode):
}
}
"""
def _prepare_agent(self, state: WorkflowState) -> tuple[DraftRunService, AppRelease, str]:
def _output_types(self) -> dict[str, VariableType]:
    """Declare a single output named ``output`` of type ``VariableType.STRING``."""
    declared = {"output": VariableType.STRING}
    return declared
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[DraftRunService, AppRelease, str]:
"""准备 Agent公共逻辑
Args:
state: 工作流状态
variable_pool: 变量池
Returns:
(draft_service, release, message): 服务实例、发布配置、消息
"""
# 1. 渲染消息
message_template = self.config.get("message", "")
message = self._render_template(message_template, state)
message = self._render_template(message_template, variable_pool)
# 2. 获取 Agent 配置
agent_id = self.config.get("agent_id")
@@ -61,16 +67,17 @@ class AgentNode(BaseNode):
return draft_service, release, message
async def execute(self, state: WorkflowState) -> dict[str, Any]:
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""非流式执行
Args:
state: 工作流状态
variable_pool: 变量池
Returns:
状态更新字典
"""
draft_service, release, message = self._prepare_agent(state)
draft_service, release, message = self._prepare_agent(variable_pool)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
@@ -79,9 +86,9 @@ class AgentNode(BaseNode):
agent_config=release.config,
model_config=None,
message=message,
workspace_id=state.get("workspace_id"),
workspace_id=variable_pool.get_value("sys.workspace_id"),
user_id=state.get("user_id"),
variables=state.get("variables", {})
variables=variable_pool.get_all_conversation_vars()
)
response = result.get("response", "")
@@ -99,16 +106,17 @@ class AgentNode(BaseNode):
}
}
async def execute_stream(self, state: WorkflowState):
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
"""流式执行
Args:
state: 工作流状态
variable_pool: 变量池
Yields:
流式事件字典
"""
draft_service, release, message = self._prepare_agent(state)
draft_service, release, message = self._prepare_agent(variable_pool)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
@@ -120,9 +128,9 @@ class AgentNode(BaseNode):
agent_config=release.config,
model_config=None,
message=message,
workspace_id=state.get("workspace_id"),
workspace_id=variable_pool.get_value("sys.workspace_id"),
user_id=state.get("user_id"),
variables=state.get("variables", {})
variables=variable_pool.get_all_conversation_vars()
):
# 提取内容
content = chunk.get("content", "")

View File

@@ -6,6 +6,7 @@ from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import AssignmentOperator
from app.core.workflow.nodes.operators import AssignmentOperatorInstance, AssignmentOperatorResolver
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -17,13 +18,17 @@ class AssignerNode(BaseNode):
self.variable_updater = True
self.typed_config: AssignerNodeConfig | None = None
async def execute(self, state: WorkflowState) -> Any:
def _output_types(self) -> dict[str, VariableType]:
return {}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
Execute the assignment operation defined by this node.
Args:
state: The current workflow state, including conversation variables,
node outputs, and system variables.
variable_pool: variable pool
Returns:
None or the result of the assignment operation.
@@ -31,60 +36,57 @@ class AssignerNode(BaseNode):
# Initialize a variable pool for accessing conversation, node, and system variables
self.typed_config = AssignerNodeConfig(**self.config)
logger.info(f"节点 {self.node_id} 开始执行")
pool = VariablePool(state)
pattern = r"\{\{\s*(.*?)\s*\}\}"
for assignment in self.typed_config.assignments:
# Get the target variable selector (e.g., "conv.test")
variable_selector = assignment.variable_selector
if isinstance(variable_selector, str):
# Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"]
pattern = r"\{\{\s*(.*?)\s*\}\}"
expression = re.sub(pattern, r"\1", variable_selector).strip()
variable_selector = expression.split('.')
namespace = re.sub(pattern, r"\1", variable_selector).split('.')[0]
# Only conversation ('conv') variables or cycle-node variables may be assigned
if variable_selector[0] != 'conv' and variable_selector[0] not in state["cycle_nodes"]:
raise ValueError("Only conversation or cycle variables can be assigned.")
if namespace != 'conv' and namespace not in state["cycle_nodes"]:
raise ValueError(f"Only conversation or cycle variables can be assigned. - {variable_selector}")
# Get the value or expression to assign
value = assignment.value
logger.debug(f"left:{variable_selector}, right: {value}")
pattern = r"\{\{\s*(.*?)\s*\}\}"
if isinstance(value, str):
expression = re.match(pattern, value)
if expression:
expression = expression.group(1)
expression = re.sub(pattern, r"\1", expression).strip()
value = self.get_variable(expression, state)
value = self.get_variable(expression, variable_pool, default=value, strict=False)
# Select the appropriate assignment operator instance based on the target variable type
operator: AssignmentOperatorInstance = AssignmentOperatorResolver.resolve_by_value(
pool.get(variable_selector)
variable_pool.get_value(variable_selector)
)(
pool, variable_selector, value
variable_pool, variable_selector, value
)
# Execute the configured assignment operation
match assignment.operation:
case AssignmentOperator.COVER:
operator.assign()
await operator.assign()
case AssignmentOperator.ASSIGN:
operator.assign()
await operator.assign()
case AssignmentOperator.CLEAR:
operator.clear()
await operator.clear()
case AssignmentOperator.ADD:
operator.add()
await operator.add()
case AssignmentOperator.SUBTRACT:
operator.subtract()
await operator.subtract()
case AssignmentOperator.MULTIPLY:
operator.multiply()
await operator.multiply()
case AssignmentOperator.DIVIDE:
operator.divide()
await operator.divide()
case AssignmentOperator.APPEND:
operator.append()
await operator.append()
case AssignmentOperator.REMOVE_FIRST:
operator.remove_first()
await operator.remove_first()
case AssignmentOperator.REMOVE_LAST:
operator.remove_last()
await operator.remove_last()
case _:
raise ValueError(f"Invalid Operator: {assignment.operation}")
logger.info(f"Node {self.node_id}: execution completed")

View File

@@ -3,79 +3,13 @@
定义所有节点配置的通用字段和数据结构。
"""
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from app.core.workflow.variable.base_variable import VariableType
VARIABLE_PATTERN = r"\{\{\s*(.*?)\s*\}\}"
class VariableType(StrEnum):
"""变量类型枚举"""
STRING = "string"
NUMBER = "number"
BOOLEAN = "boolean"
OBJECT = "object"
ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]"
ARRAY_BOOLEAN = "array[boolean]"
ARRAY_OBJECT = "array[object]"
class TypedVariable(BaseModel):
"""
TODO: 强类型限制
Strongly typed variable that validates value on assignment.
"""
value: Any = Field(..., description="Variable value")
type: VariableType = Field(..., description="Declared type of the variable")
model_config = ConfigDict(
validate_assignment=True
)
def __setattr__(self, name, value):
if name == "value":
self._validate_value(value)
if name == "type":
raise RuntimeError("Cannot modify variable type at runtime")
super().__setattr__(name, value)
def _validate_value(self, v: Any):
t = self.type
match t:
case VariableType.STRING:
if not isinstance(v, str):
raise TypeError("Variable value does not match type STRING")
case VariableType.BOOLEAN:
if not isinstance(v, bool):
raise TypeError("Variable value does not match type BOOLEAN")
case VariableType.NUMBER:
if not isinstance(v, (int, float)):
raise TypeError("Variable value does not match type NUMBER")
case VariableType.OBJECT:
if not isinstance(v, dict):
raise TypeError("Variable value does not match type OBJECT")
case VariableType.ARRAY_STRING:
if not isinstance(v, list) or not all(isinstance(i, str) for i in v):
raise TypeError("Variable value does not match type ARRAY_STRING")
case VariableType.ARRAY_NUMBER:
if not isinstance(v, list) or not all(isinstance(i, (int, float)) for i in v):
raise TypeError("Variable value does not match type ARRAY_NUMBER")
case VariableType.ARRAY_BOOLEAN:
if not isinstance(v, list) or not all(isinstance(i, bool) for i in v):
raise TypeError("Variable value does not match type ARRAY_BOOLEAN")
case VariableType.ARRAY_OBJECT:
if not isinstance(v, list) or not all(isinstance(i, dict) for i in v):
raise TypeError("Variable value does not match type ARRAY_OBJECT")
case _:
raise TypeError(f"Unknown variable type: {t}")
class VariableDefinition(BaseModel):
"""变量定义

View File

@@ -1,12 +1,7 @@
"""
工作流节点基类
定义节点的基本接口和通用功能。
"""
import asyncio
import logging
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Any, AsyncGenerator
from langgraph.config import get_stream_writer
@@ -14,7 +9,9 @@ from typing_extensions import TypedDict, Annotated
from app.core.config import settings
from app.core.workflow.nodes.enums import BRANCH_NODES
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
from app.services.multimodal_service import PROVIDER_STRATEGIES
logger = logging.getLogger(__name__)
@@ -42,22 +39,10 @@ class WorkflowState(TypedDict):
cycle_nodes: list
looping: Annotated[int, merge_looping_state]
# Input variables (passed from configured variables)
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
variables: Annotated[dict[str, Any], lambda x, y: {
**x,
**{k: {**x.get(k, {}), **v} if isinstance(v, dict) and isinstance(x.get(k), dict) else v
for k, v in y.items()}
}]
# Node outputs (stores execution results of each node for variable references)
# Uses a custom merge function to combine new node outputs into the existing dictionary
node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# Runtime node variables (simplified version, stores business data for fast access between nodes)
# Format: {node_id: business_result}
runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# Execution context
execution_id: str
workspace_id: str
@@ -72,17 +57,17 @@ class WorkflowState(TypedDict):
class BaseNode(ABC):
"""节点基类
所有节点类型都应该继承此基类,实现 execute 方法。
"""Base class for workflow nodes.
All node types should inherit from this class and implement the `execute` method.
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
"""初始化节点
"""Initialize the node.
Args:
node_config: 节点配置
workflow_config: 工作流配置
node_config: Configuration of the node.
workflow_config: Configuration of the workflow.
"""
self.node_config = node_config
self.workflow_config = workflow_config
@@ -94,7 +79,27 @@ class BaseNode(ABC):
self.config = node_config.get("config") or {}
self.error_handling = node_config.get("error_handling") or {}
self.variable_updater = False
self.variable_change_able = False
@cached_property
def output_types(self) -> dict[str, VariableType]:
    """Output variable types of the node, as declared by ``_output_types()``.

    Computed on first access and cached afterwards.
    """
    declared = self._output_types()
    return declared
@abstractmethod
def _output_types(self) -> dict[str, VariableType]:
"""Defines output variable types for the node.
Subclasses must override this method to declare the variables
produced by the node and their corresponding types.
Returns:
A mapping from output variable names to ``VariableType``.
"""
return {}
def check_activate(self, state: WorkflowState):
"""Check if the current node is activated in the workflow state.
@@ -136,92 +141,84 @@ class BaseNode(ABC):
}
@abstractmethod
async def execute(self, state: WorkflowState) -> Any:
"""执行节点业务逻辑(非流式)
节点只需要返回业务结果,不需要关心输出格式、时间统计等。
BaseNode 会自动包装成标准格式。
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""Executes the node business logic (non-streaming).
The node implementation should only return the business result.
It does not need to handle output formatting, timing, or statistics.
The ``BaseNode`` will automatically wrap the result into a standard
response format.
Args:
state: 工作流状态
state: The current workflow state.
variable_pool: The variable pool used for reading and writing variables.
Returns:
业务结果(任意类型)
Examples:
>>> # LLM 节点
>>> "这是 AI 的回复"
>>> # Transform 节点
>>> {"processed_data": [...]}
>>> # Start/End 节点
>>> {"message": "开始", "conversation_id": "xxx"}
The business result produced by the node. The return value can be
of any type.
"""
pass
async def execute_stream(self, state: WorkflowState):
"""执行节点业务逻辑(流式)
子类可以重写此方法以支持流式输出。
默认实现:执行非流式方法并一次性返回。
节点需要:
1. yield 中间结果(如文本片段)
2. 最后 yield 一个特殊的完成标记:{"__final__": True, "result": final_result}
Args:
state: 工作流状态
Yields:
业务数据chunk或完成标记
Examples:
# 流式 LLM 节点
full_response = ""
async for chunk in llm.astream(prompt):
full_response += chunk
yield chunk # yield 文本片段
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
"""Executes the node business logic in streaming mode.
# 最后 yield 完成标记
yield {"__final__": True, "result": AIMessage(content=full_response)}
Subclasses may override this method to support streaming output.
The default implementation executes the non-streaming method and
yields a single final result.
For streaming execution, a node implementation should:
1. Yield intermediate results (e.g. text chunks).
2. Yield a final completion marker in the following format:
``{"__final__": True, "result": final_result}``.
Args:
state: The current workflow state.
variable_pool: The variable pool used for reading and writing variables.
Yields:
Business data chunks or a final completion marker.
"""
result = await self.execute(state)
# 默认实现:直接 yield 完成标记
result = await self.execute(state, variable_pool)
# Default implementation: yield a single final completion marker.
yield {"__final__": True, "result": result}
def supports_streaming(self) -> bool:
"""节点是否支持流式输出
"""Returns whether the node supports streaming output.
A node is considered to support streaming if its class overrides
the ``execute_stream`` method. If the default implementation from
``BaseNode`` is used, streaming is not supported.
Returns:
是否支持流式输出
True if the node supports streaming output, False otherwise.
"""
# 检查子类是否重写了 execute_stream 方法
# Check whether the subclass overrides the execute_stream method.
return self.__class__.execute_stream is not BaseNode.execute_stream
def get_timeout(self) -> int:
"""获取超时时间(秒)
@staticmethod
def get_timeout() -> int:
"""Returns the execution timeout in seconds.
Returns:
超时时间
The timeout duration, in seconds.
"""
return settings.WORKFLOW_NODE_TIMEOUT
# return self.error_handling.get("timeout", 60)
async def run(self, state: WorkflowState) -> dict[str, Any]:
"""执行节点(带错误处理和输出包装,非流式)
这个方法由 Executor 调用,负责:
1. 时间统计
2. 调用节点的 execute() 方法
3. 将业务结果包装成标准输出格式
4. 错误处理
async def run(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""Runs the node with error handling and output wrapping (non-streaming).
This method is invoked by the Executor and is responsible for:
1. Execution time measurement.
2. Invoking the node's ``execute()`` method.
3. Wrapping the business result into a standardized output format.
4. Handling execution errors.
Args:
state: 工作流状态
state: The current workflow state.
variable_pool: The variable pool used for reading and writing variables.
Returns:
标准化的状态更新字典
A standardized state update dictionary.
"""
if not self.check_activate(state):
return self.trans_activate(state)
@@ -233,70 +230,78 @@ class BaseNode(ABC):
timeout = self.get_timeout()
try:
# 调用节点的业务逻辑
# Invoke the node business logic.
business_result = await asyncio.wait_for(
self.execute(state),
self.execute(state, variable_pool),
timeout=timeout
)
elapsed_time = time.time() - start_time
# 提取处理后的输出(调用子类的 _extract_output
# Extract processed outputs using subclass-defined logic.
extracted_output = self._extract_output(business_result)
# 包装成标准输出格式
wrapped_output = self._wrap_output(business_result, elapsed_time, state)
# Wrap the business result into the standard output format.
wrapped_output = self._wrap_output(business_result, elapsed_time, state, variable_pool)
# 将提取后的输出存储到运行时变量中(供后续节点快速访问)
# 如果提取后的输出是字典,拆包存储;否则存储为 output 字段
if isinstance(extracted_output, dict):
runtime_var = extracted_output
else:
runtime_var = {"output": extracted_output}
# Store extracted outputs as runtime variables for downstream nodes.
if extracted_output is not None:
runtime_vars = extracted_output
if not isinstance(extracted_output, dict):
runtime_vars = {"output": extracted_output}
for k, v in runtime_vars.items():
await variable_pool.new(self.node_id, k, v, self.output_types[k], mut=self.variable_change_able)
# 返回包装后的输出和运行时变量
# Return the wrapped output along with activation state updates.
return {
**wrapped_output,
"messages": state["messages"],
"runtime_vars": {
self.node_id: runtime_var
},
"looping": state["looping"]
} | self.trans_activate(state)
except TimeoutError:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
return self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
logger.error(
f"Node {self.node_id} execution timed out ({timeout} seconds)."
)
return self._wrap_error(
f"Node execution timed out ({timeout} seconds).",
elapsed_time,
state,
variable_pool,
)
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
return self._wrap_error(str(e), elapsed_time, state)
logger.error(
f"Node {self.node_id} execution failed: {e}",
exc_info=True,
)
return self._wrap_error(str(e), elapsed_time, state, variable_pool)
async def run_stream(
self, state: WorkflowState,
variable_pool: VariablePool
) -> AsyncGenerator[dict[str, Any], Any]:
"""Executes the node with error handling and output wrapping (streaming).
async def run_stream(self, state: WorkflowState) -> AsyncGenerator[dict[str, Any], Any]:
"""Execute node with error handling and output wrapping (streaming)
This method is called by the Executor and is responsible for:
1. Time tracking
2. Calling the node's execute_stream() method
3. Using LangGraph's stream writer to send chunks
4. Updating streaming buffer in state for downstream nodes
5. Wrapping business data into standard output format
6. Error handling
Special handling for End nodes:
- End nodes don't send chunks via writer (prefix and LLM content already sent)
- End nodes only yield suffix for final result assembly
1. Tracking execution time.
2. Calling the node's ``execute_stream()`` method.
3. Sending streaming chunks via LangGraph's stream writer.
4. Updating activation-related state for downstream nodes.
5. Wrapping business data into a standardized output format.
6. Handling execution errors.
Args:
state: Workflow state
state: The current workflow state.
variable_pool: The variable pool used for reading and writing variables.
Yields:
State updates with streaming buffer and final result
Incremental state updates, including activation state changes and
the final wrapped result.
"""
if not self.check_activate(state):
yield self.trans_activate(state)
logger.info(f"jump node: {self.node_id}")
logger.debug(f"jump node: {self.node_id}")
return
import time
@@ -317,7 +322,7 @@ class BaseNode(ABC):
# Stream chunks in real-time
loop_start = asyncio.get_event_loop().time()
async for item in self.execute_stream(state):
async for item in self.execute_stream(state, variable_pool):
# Check timeout
if asyncio.get_event_loop().time() - loop_start > timeout:
raise TimeoutError()
@@ -332,7 +337,7 @@ class BaseNode(ABC):
chunks.append(content)
# Send chunks for all nodes (including End nodes for suffix)
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {content[:50]}...")
logger.debug(f"Node {self.node_id} sent chunk #{chunk_count}: {content[:50]}...")
# 1. Send via stream writer (for real-time client updates)
writer({
@@ -344,27 +349,26 @@ class BaseNode(ABC):
elapsed_time = time.time() - start_time
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
logger.info(f"Node {self.node_id} streaming execution finished, "
f"time elapsed: {elapsed_time:.2f}s, chunks: {chunk_count}")
# Extract processed output (call subclass's _extract_output)
extracted_output = self._extract_output(final_result)
# Wrap final result
final_output = self._wrap_output(final_result, elapsed_time, state)
final_output = self._wrap_output(final_result, elapsed_time, state, variable_pool)
# Store extracted output in runtime variables (for quick access by subsequent nodes)
if isinstance(extracted_output, dict):
runtime_var = extracted_output
else:
runtime_var = {"output": extracted_output}
if extracted_output is not None:
runtime_vars = extracted_output
if not isinstance(extracted_output, dict):
runtime_vars = {"output": extracted_output}
for k, v in runtime_vars.items():
await variable_pool.new(self.node_id, k, v, self.output_types[k], mut=self.variable_change_able)
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
state_update = {
**final_output,
"messages": state["messages"],
"runtime_vars": {
self.node_id: runtime_var
},
"looping": state["looping"]
}
@@ -374,41 +378,49 @@ class BaseNode(ABC):
except TimeoutError:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行超时 ({timeout}s)")
error_output = self._wrap_error(f"节点执行超时 ({timeout}s)", elapsed_time, state)
logger.error(f"Node {self.node_id} execution timed out ({timeout}s)")
error_output = self._wrap_error(
f"Node execution timed out ({timeout}s)",
elapsed_time,
state,
variable_pool
)
yield error_output
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
error_output = self._wrap_error(str(e), elapsed_time, state)
logger.error(f"Node {self.node_id} execution failed: {e}", exc_info=True)
error_output = self._wrap_error(str(e), elapsed_time, state, variable_pool)
yield error_output
def _wrap_output(
self,
business_result: Any,
elapsed_time: float,
state: WorkflowState
state: WorkflowState,
variable_pool: VariablePool
) -> dict[str, Any]:
"""将业务结果包装成标准输出格式
Args:
business_result: 节点返回的业务结果
elapsed_time: 执行耗时
state: 工作流状态
Returns:
标准化的状态更新字典
"""
# 提取输入数据(用于记录)
input_data = self._extract_input(state)
"""Wraps the business result into a standardized node output format.
# 提取 token 使用情况(如果有)
Args:
business_result: The result returned by the node's business logic.
elapsed_time: Time elapsed during node execution (in seconds).
state: The current workflow state.
variable_pool: The variable pool used for reading and writing variables.
Returns:
A dictionary representing the standardized state update for this node,
including node outputs, input, output, elapsed time, token usage, and status.
"""
# Extract input data (for logging or audit purposes)
input_data = self._extract_input(state, variable_pool)
# Extract token usage information (if applicable)
token_usage = self._extract_token_usage(business_result)
# 提取实际输出(去除元数据)
# Extract actual output (strip any metadata)
output = self._extract_output(business_result)
# 构建标准节点输出
# Construct standardized node output
node_output = {
"node_id": self.node_id,
"node_type": self.node_type,
@@ -423,8 +435,6 @@ class BaseNode(ABC):
final_output = {
"node_outputs": {self.node_id: node_output},
}
if self.variable_updater:
final_output = final_output | {"variables": state["variables"]}
return final_output
@@ -432,25 +442,33 @@ class BaseNode(ABC):
self,
error_message: str,
elapsed_time: float,
state: WorkflowState
state: WorkflowState,
variable_pool: VariablePool
) -> dict[str, Any]:
"""将错误包装成标准输出格式
"""Wraps an error into a standardized node output format.
This method handles both cases:
- If an error edge is defined, the workflow can continue to the error handling node.
- If no error edge exists, the workflow is stopped by raising an exception.
Args:
error_message: 错误信息
elapsed_time: 执行耗时
state: 工作流状态
error_message: The error message describing the failure.
elapsed_time: Time elapsed during node execution (in seconds).
state: The current workflow state.
variable_pool: The variable pool used for reading and writing variables.
Returns:
标准化的状态更新字典
A dictionary representing the standardized state update for this node
when an error edge exists. If no error edge exists, this method
raises an exception to stop the workflow.
"""
# 查找错误边
# Check if the node has an error edge defined
error_edge = self._find_error_edge()
# 提取输入数据
input_data = self._extract_input(state)
# Extract input data (for logging or audit purposes)
input_data = self._extract_input(state, variable_pool)
# 构建错误输出
# Construct the standardized node output for the error
node_output = {
"node_id": self.node_id,
"node_type": self.node_type,
@@ -464,9 +482,9 @@ class BaseNode(ABC):
}
if error_edge:
# 有错误边:记录错误并继续
# If an error edge exists, log a warning and continue to error node
logger.warning(
f"节点 {self.node_id} 执行失败,跳转到错误处理节点: {error_edge['target']}"
f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
)
return {
"node_outputs": {
@@ -476,198 +494,188 @@ class BaseNode(ABC):
"error_node": self.node_id
}
else:
# If no error edge, send the error via stream writer and stop the workflow
writer = get_stream_writer()
writer({
"type": "node_error",
**node_output
})
# 无错误边:抛出异常停止工作流
logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}")
raise Exception(f"节点 {self.node_id} 执行失败: {error_message}")
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""Extracts the input data for this node (used for logging or audit).
Subclasses may override this method to customize what input data
should be recorded.
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
"""提取节点输入数据(用于记录)
子类可以重写此方法来自定义输入记录。
Args:
state: 工作流状态
state: The current workflow state.
variable_pool: The variable pool used for reading and writing variables.
Returns:
输入数据字典
A dictionary containing the node's input data.
"""
# 默认返回配置
# Default implementation returns the node configuration
return {"config": self.config}
def _extract_output(self, business_result: Any) -> Any:
    """Extracts the actual output from the business result.

    Subclasses may override this method to customize how the node's
    output is extracted.

    Args:
        business_result: The result returned by the node's business logic.

    Returns:
        The actual output extracted from the business result.
    """
    # Default implementation returns the business result directly
    return business_result
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
    """Extracts token usage information from the business result.

    Subclasses may override this method to extract token usage statistics
    (e.g., for LLM nodes).

    Args:
        business_result: The result returned by the node's business logic.

    Returns:
        A dictionary mapping token types to counts, or None if not applicable.
    """
    # Default implementation returns None
    return None
def _find_error_edge(self) -> dict[str, Any] | None:
"""查找错误边
"""Finds the error edge for this node, if any.
An error edge is used to redirect workflow execution when this node
fails.
Returns:
错误边配置或 None
A dictionary representing the error edge configuration if it exists,
or None if no error edge is defined.
"""
for edge in self.workflow_config.get("edges", []):
if edge.get("source") == self.node_id and edge.get("type") == "error":
return edge
return None
def _render_template(self, template: str, state: WorkflowState | None, strict: bool = True) -> str:
"""渲染模板
支持的变量命名空间:
- sys.xxx: 系统变量message, execution_id, workspace_id, user_id, conversation_id
- conv.xxx: 会话变量(跨多轮对话保持)
- node_id.xxx: 节点输出
@staticmethod
def _render_template(template: str, variable_pool: VariablePool, strict: bool = True) -> str:
"""Renders a template string using the provided variable pool.
Supported variable namespaces:
- sys.xxx: System variables (e.g., message, execution_id, workspace_id,
user_id, conversation_id)
- conv.xxx: Conversation variables (persist across multiple turns)
- node_id.xxx: Node outputs
Args:
template: 模板字符串
state: 工作流状态
template: The template string to render.
variable_pool: The variable pool containing system, conversation, and
node variables.
strict: If True, missing variables will raise an error; if False,
missing variables are ignored.
Returns:
渲染后的字符串
The rendered string with all variables substituted.
"""
from app.core.workflow.template_renderer import render_template
# 处理 state 为 None 的情况
if state is None:
state = {}
# 使用变量池获取变量
pool = VariablePool(state)
# 构建完整的 variables 结构
variables = {
"sys": pool.get_all_system_vars(),
"conv": pool.get_all_conversation_vars()
}
return render_template(
template=template,
variables=variables,
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
conv_vars=variable_pool.get_all_conversation_vars(),
node_outputs=variable_pool.get_all_node_outputs(),
system_vars=variable_pool.get_all_system_vars(),
strict=strict
)
def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool:
"""评估条件表达式
支持的变量命名空间:
- sys.xxx: 系统变量
- conv.xxx: 会话变量
- node_id.xxx: 节点输出
@staticmethod
def _evaluate_condition(expression: str, variable_pool: VariablePool) -> bool:
"""Evaluates a conditional expression using the provided variable pool.
Supported variable namespaces:
- sys.xxx: System variables
- conv.xxx: Conversation variables
- node_id.xxx: Node outputs
Args:
expression: 条件表达式
state: 工作流状态
expression: The conditional expression to evaluate.
variable_pool: The variable pool containing system, conversation, and
node variables.
Returns:
布尔值结果
The boolean result of evaluating the expression.
"""
from app.core.workflow.expression_evaluator import evaluate_condition
# 处理 state 为 None 的情况
if state is None:
state = {}
# 使用变量池获取变量
pool = VariablePool(state)
# 构建完整的 variables 结构(包含 sys 和 conv
variables = {
"sys": pool.get_all_system_vars(),
"conv": pool.get_all_conversation_vars()
}
return evaluate_condition(
expression=expression,
variables=variables,
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars()
conv_var=variable_pool.get_all_conversation_vars(),
node_outputs=variable_pool.get_all_node_outputs(),
system_vars=variable_pool.get_all_system_vars()
)
def get_variable_pool(self, state: WorkflowState) -> VariablePool:
"""获取变量池实例
VariablePool 是轻量级包装器,只持有 state 的引用,创建成本极低。
Args:
state: 工作流状态
Returns:
VariablePool 实例
Examples:
>>> pool = self.get_variable_pool(state)
>>> message = pool.get("sys.message")
>>> llm_output = pool.get("llm_qa.output")
"""
return VariablePool(state)
@staticmethod
def get_variable(
self,
selector: list[str] | str,
state: WorkflowState,
default: Any = None
selector: str,
variable_pool: VariablePool,
default: Any = None,
strict: bool = True
) -> Any:
"""获取变量值(便捷方法)
Args:
selector: 变量选择器
state: 工作流状态
default: 默认值
Returns:
变量值
Examples:
>>> message = self.get_variable("sys.message", state)
>>> output = self.get_variable(["llm_qa", "output"], state)
>>> custom = self.get_variable("var.custom", state, default="默认值")
"""
pool = VariablePool(state)
return pool.get(selector, default=default)
"""Retrieves a variable value from the variable pool (convenience method).
def has_variable(self, selector: list[str] | str, state: WorkflowState) -> bool:
"""检查变量是否存在(便捷方法)
Args:
selector: 变量选择器
state: 工作流状态
selector: The variable selector (can be namespaced, e.g., sys.xxx, conv.xxx, node_id.xxx).
variable_pool: The variable pool from which to fetch the value.
default: The default value to return if the variable does not exist.
strict: If True, raise an error when the variable is missing; if False, return the default.
Returns:
变量是否存在
Examples:
>>> if self.has_variable("llm_qa.output", state):
... output = self.get_variable("llm_qa.output", state)
The value of the selected variable, or the default if not found and strict is False.
"""
pool = VariablePool(state)
return pool.has(selector)
return variable_pool.get_value(selector, default, strict=strict)
@staticmethod
def has_variable(selector: str, variable_pool: VariablePool) -> bool:
"""Checks whether a variable exists in the variable pool (convenience method).
Args:
selector: The variable selector (can be namespaced, e.g., sys.xxx, conv.xxx, node_id.xxx).
variable_pool: The variable pool to check.
Returns:
True if the variable exists in the pool, False otherwise.
"""
return variable_pool.has(selector)
@staticmethod
async def process_message(provider, content, enable_file=False) -> dict | str | None:
if isinstance(content, str):
if enable_file:
return {"text": content}
return content
elif isinstance(content, dict):
trans_tool = PROVIDER_STRATEGIES[provider]()
result = await trans_tool.format_image(content["url"])
return result
raise TypeError('Unexpect input value type')
@staticmethod
def process_model_output(content) -> str:
result = ""
if isinstance(content, list):
for msg in content:
if isinstance(msg, dict):
result += msg.get("text")
elif isinstance(msg, str):
result += msg
elif isinstance(content, dict):
result = content.get("text")
elif isinstance(content, str):
return content
return result

View File

@@ -2,6 +2,8 @@ import logging
from typing import Any
from app.core.workflow.nodes import BaseNode, WorkflowState
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -14,15 +16,19 @@ class BreakNode(BaseNode):
to False, signaling the outer loop runtime to terminate further iterations.
"""
async def execute(self, state: WorkflowState) -> Any:
def _output_types(self) -> dict[str, VariableType]:
    """Returns the output variable types of this node; the break node declares none."""
    return {}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
Execute the break node.
Args:
state: Current workflow state, including loop control flags.
variable_pool: Pool of variables for the workflow.
Effects:
- Sets 'looping' in the state to False to stop the loop.
- Sets 'looping' in the state to False to stop the loop.
- Logs the action for debugging purposes.
Returns:

View File

@@ -1,7 +1,8 @@
from typing import Literal
from pydantic import Field, BaseModel
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.variable.base_variable import VariableType
class InputVariable(BaseModel):
@@ -44,7 +45,7 @@ class CodeNodeConfig(BaseNodeConfig):
description="code content"
)
language: Literal['python3', 'nodejs'] = Field(
language: Literal['python3', 'javascript'] = Field(
...,
description="language"
)

View File

@@ -2,15 +2,17 @@ import base64
import json
import logging
import re
import urllib.parse
from string import Template
from textwrap import dedent
from typing import Any
import urllib.parse
import httpx
from app.core.workflow.nodes import BaseNode, WorkflowState
from app.core.workflow.nodes.base_config import VariableType
from app.core.workflow.nodes.code.config import CodeNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -52,6 +54,12 @@ class CodeNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: CodeNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
output_dict = {}
for output in self.typed_config.output_variables:
output_dict[output.name] = output.type
return output_dict
def extract_result(self, content: str):
match = re.search(r'<<RESULT>>(.*?)<<RESULT>>', content, re.DOTALL)
if match:
@@ -92,15 +100,16 @@ class CodeNode(BaseNode):
else:
raise RuntimeError("The output of main must be a dictionary")
async def execute(self, state: WorkflowState) -> Any:
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
self.typed_config = CodeNodeConfig(**self.config)
input_variable_dict = {}
for input_variable in self.typed_config.input_variables:
input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, state)
input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, variable_pool)
code = base64.b64decode(
self.typed_config.code
).decode("utf-8")
code = urllib.parse.unquote(code, encoding='utf-8')
input_variable_dict = base64.b64encode(
json.dumps(input_variable_dict).encode("utf-8")
@@ -110,7 +119,7 @@ class CodeNode(BaseNode):
code=code,
inputs_variable=input_variable_dict,
)
elif self.typed_config.language == 'nodejs':
elif self.typed_config.language == 'javascript':
final_script = NODEJS_SCRIPT_TEMPLATE.substitute(
code=code,
inputs_variable=input_variable_dict,

View File

@@ -8,7 +8,6 @@ from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
from app.core.workflow.nodes.base_config import (
BaseNodeConfig,
VariableDefinition,
VariableType,
)
from app.core.workflow.nodes.code.config import CodeNodeConfig
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
@@ -23,21 +22,18 @@ from app.core.workflow.nodes.parameter_extractor.config import ParameterExtracto
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
from app.core.workflow.nodes.start.config import StartNodeConfig
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
__all__ = [
# 基础类
"BaseNodeConfig",
"VariableDefinition",
"VariableType",
# 节点配置
"StartNodeConfig",
"EndNodeConfig",
"LLMNodeConfig",
"MessageConfig",
"AgentNodeConfig",
"TransformNodeConfig",
"IfElseNodeConfig",
"KnowledgeRetrievalNodeConfig",
"AssignerNodeConfig",

View File

@@ -2,7 +2,8 @@ from typing import Any
from pydantic import Field, BaseModel, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
@@ -127,4 +128,9 @@ class IterationNodeConfig(BaseNodeConfig):
description="Output of the loop iteration"
)
# Optional: stays None when no iteration output type is configured, so the
# annotation must admit None (an explicit ``output_type=None`` would
# otherwise fail validation).
output_type: VariableType | None = Field(
    default=None,
    description="Data type of the loop iteration output"
)

View File

@@ -7,6 +7,7 @@ from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -28,6 +29,8 @@ class IterationRuntime:
node_id: str,
config: dict[str, Any],
state: WorkflowState,
variable_pool: VariablePool,
child_variable_pool: VariablePool,
):
"""
Initialize the iteration runtime.
@@ -44,11 +47,13 @@ class IterationRuntime:
self.node_id = node_id
self.typed_config = IterationNodeConfig(**config)
self.looping = True
self.variable_pool = variable_pool
self.child_variable_pool = child_variable_pool
self.output_value = None
self.result: list = []
def _init_iteration_state(self, item, idx):
async def _init_iteration_state(self, item, idx):
"""
Initialize a per-iteration copy of the workflow state.
@@ -62,10 +67,9 @@ class IterationRuntime:
loopstate = WorkflowState(
**self.state
)
loopstate["runtime_vars"][self.node_id] = {
"item": item,
"index": idx,
}
self.child_variable_pool.copy(self.variable_pool)
await self.child_variable_pool.new(self.node_id, "item", item, VariableType.type_map(item), mut=True)
await self.child_variable_pool.new(self.node_id, "index", item, VariableType.type_map(item), mut=True)
loopstate["node_outputs"][self.node_id] = {
"item": item,
"index": idx,
@@ -74,6 +78,11 @@ class IterationRuntime:
loopstate["activate"][self.start_id] = True
return loopstate
def merge_conv_vars(self):
    """Copies the child pool's conversation variables over the parent pool's.

    NOTE(review): this updates whatever mapping
    ``get_all_conversation_vars()`` returns; the merge only persists if
    that is the pool's live mapping rather than a copy - confirm against
    VariablePool.
    """
    parent_conv = self.variable_pool.get_all_conversation_vars()
    child_conv = self.child_variable_pool.get_all_conversation_vars()
    parent_conv.update(child_conv)
async def run_task(self, item, idx):
"""
Execute a single iteration asynchronously.
@@ -82,8 +91,8 @@ class IterationRuntime:
item: The input element for this iteration.
idx: The index of this iteration.
"""
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
output = VariablePool(result).get(self.output_value)
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
output = self.child_variable_pool.get_value(self.output_value)
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
else:
@@ -125,7 +134,7 @@ class IterationRuntime:
input_expression = re.sub(pattern, r"\1", self.typed_config.input).strip()
self.output_value = re.sub(pattern, r"\1", self.typed_config.output).strip()
array_obj = VariablePool(self.state).get(input_expression)
array_obj = self.variable_pool.get_value(input_expression)
if not isinstance(array_obj, list):
raise RuntimeError("Cannot iterate over a non-list variable")
child_state = []
@@ -137,14 +146,16 @@ class IterationRuntime:
logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}")
idx += self.typed_config.parallel_count
child_state.extend(await asyncio.gather(*tasks))
self.merge_conv_vars()
else:
# Execute iterations sequentially
while idx < len(array_obj) and self.looping:
logger.info(f"Iteration node {self.node_id}: running")
item = array_obj[idx]
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
child_state.append(result)
output = VariablePool(result).get(self.output_value)
output = self.child_variable_pool.get_value(self.output_value)
self.merge_conv_vars()
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
else:

View File

@@ -31,6 +31,8 @@ class LoopRuntime:
node_id: str,
config: dict[str, Any],
state: WorkflowState,
variable_pool: VariablePool,
child_variable_pool: VariablePool
):
"""
Initialize the loop runtime executor.
@@ -40,6 +42,8 @@ class LoopRuntime:
node_id: The unique identifier of the loop node in the workflow.
config: Raw configuration dictionary for the loop node.
state: The current workflow state before entering the loop.
variable_pool: A VariablePool instance for accessing and modifying workflow variables.
child_variable_pool: A VariablePool instance for managing child node outputs.
"""
self.start_id = start_id
self.graph = graph
@@ -47,8 +51,10 @@ class LoopRuntime:
self.node_id = node_id
self.typed_config = LoopNodeConfig(**config)
self.looping = True
self.variable_pool = variable_pool
self.child_variable_pool = child_variable_pool
def _init_loop_state(self):
async def _init_loop_state(self):
"""
Initialize workflow state for loop execution.
@@ -62,33 +68,35 @@ class LoopRuntime:
Returns:
WorkflowState: A prepared workflow state used for loop execution.
"""
pool = VariablePool(self.state)
# 循环变量
self.state["runtime_vars"][self.node_id] = {
variable.name: evaluate_expression(
expression=variable.value,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
)
if variable.input_type == ValueInputType.VARIABLE
else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars
}
self.state["node_outputs"][self.node_id] = {
variable.name: evaluate_expression(
expression=variable.value,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
)
if variable.input_type == ValueInputType.VARIABLE
else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars
}
self.child_variable_pool.copy(self.variable_pool)
for variable in self.typed_config.cycle_vars:
if variable.input_type == ValueInputType.VARIABLE:
value = evaluate_expression(
expression=variable.value,
conv_var=self.variable_pool.get_all_conversation_vars(),
node_outputs=self.variable_pool.get_all_node_outputs(),
system_vars=self.variable_pool.get_all_system_vars(),
)
else:
value = TypeTransformer.transform(variable.value, variable.type)
await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True)
loopstate = WorkflowState(
**self.state
)
loopstate["node_outputs"][self.node_id] = {
variable.name: evaluate_expression(
expression=variable.value,
conv_var=self.variable_pool.get_all_conversation_vars(),
node_outputs=self.variable_pool.get_all_node_outputs(),
system_vars=self.variable_pool.get_all_system_vars(),
)
if variable.input_type == ValueInputType.VARIABLE
else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars
}
loopstate["looping"] = 1
loopstate["activate"][self.start_id] = True
return loopstate
@@ -134,7 +142,12 @@ class LoopRuntime:
case _:
raise ValueError(f"Invalid condition: {operator}")
def evaluate_conditional(self, state) -> bool:
def merge_conv_vars(self):
    """Merges the child pool's ``"conv"`` variables into the parent pool's.

    The parent pool must already have a ``"conv"`` entry in ``variables``;
    a child pool without one contributes nothing.
    """
    parent_conv = self.variable_pool.variables["conv"]
    child_conv = self.child_variable_pool.variables.get("conv", {})
    parent_conv.update(child_conv)
def evaluate_conditional(self) -> bool:
"""
Evaluate the loop continuation condition at runtime.
@@ -143,18 +156,15 @@ class LoopRuntime:
- Evaluates each comparison expression immediately
- Combines results using the configured logical operator (AND / OR)
Args:
state: The current workflow state during loop execution.
Returns:
bool: True if the loop should continue, False otherwise.
"""
conditions = []
for expression in self.typed_config.condition.expressions:
left_value = VariablePool(state).get(expression.left)
left_value = self.child_variable_pool.get_value(expression.left)
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
VariablePool(state),
self.child_variable_pool,
expression.left,
expression.right,
expression.input_type
@@ -177,16 +187,18 @@ class LoopRuntime:
Returns:
dict[str, Any]: The final runtime variables of this loop node.
"""
loopstate = self._init_loop_state()
loopstate = await self._init_loop_state()
loop_time = self.typed_config.max_loop
child_state = []
while self.evaluate_conditional(loopstate) and self.looping and loop_time > 0:
while not self.evaluate_conditional() and self.looping and loop_time > 0:
logger.info(f"loop node {self.node_id}: running")
result = await self.graph.ainvoke(loopstate)
child_state.append(result)
self.merge_conv_vars()
if result["looping"] == 2:
self.looping = False
loop_time -= 1
logger.info(f"loop node {self.node_id}: execution completed")
return loopstate["runtime_vars"][self.node_id] | {"__child_state": child_state}
return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state}

View File

@@ -6,9 +6,12 @@ from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig
from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime
from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -35,9 +38,41 @@ class CycleGraphNode(BaseNode):
self.start_node_id = None # ID of the start node within the cycle
self.graph: StateGraph | CompiledStateGraph | None = None
self.child_variable_pool: VariablePool | None = None
self.build_graph()
self.iteration_flag = True
def _output_types(self) -> dict[str, VariableType]:
    """Declare the output variables of this cycle node and their types.

    Every cycle node exposes ``__child_state`` as an array of objects. The
    remaining entries depend on the node type:

    - Loop: one entry per configured cycle variable, using the type given
      in that variable's definition.
    - Iteration: a single ``output`` entry derived from the configured
      ``output_type`` and ``flatten`` options.

    Returns:
        Mapping of output name to ``VariableType``.

    Raises:
        KeyError: If the node type is neither loop nor iteration.
    """
    outputs = {"__child_state": VariableType.ARRAY_OBJECT}

    if self.node_type == NodeType.LOOP:
        # Loop node: expose every cycle variable with its declared type.
        config = LoopNodeConfig(**self.config)
        outputs.update({var_def.name: var_def.type for var_def in config.cycle_vars})
        return outputs

    if self.node_type == NodeType.ITERATION:
        # Iteration node: expose the processed collection as ``output``.
        config = IterationNodeConfig(**self.config)
        if not config.output_type:
            # No element type configured, so nothing more specific is known.
            outputs['output'] = VariableType.ANY
        elif config.output_type in [
            VariableType.ARRAY_FILE,
            VariableType.ARRAY_STRING,
            VariableType.NUMBER,
            VariableType.ARRAY_OBJECT,
            VariableType.BOOLEAN
        ]:
            # NOTE(review): NUMBER and BOOLEAN are scalar types listed among
            # array types here; they may have been meant as their array
            # counterparts. Confirm against VariableType before changing.
            outputs['output'] = config.output_type if config.flatten else VariableType.ARRAY_STRING
        else:
            # Any other element type is wrapped into the matching array type.
            outputs['output'] = VariableType(f"array[{config.output_type}]")
        return outputs

    # The message previously read "Valid Cycle Node Type", which contradicted
    # the error being reported.
    raise KeyError(f"Invalid Cycle Node Type - {self.node_type}")
def pure_cycle_graph(self) -> tuple[list, list]:
"""
Extract cycle-scoped nodes and internal edges from the workflow configuration.
@@ -103,17 +138,20 @@ class CycleGraphNode(BaseNode):
"""
from app.core.workflow.graph_builder import GraphBuilder
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.child_variable_pool = VariablePool()
builder = GraphBuilder(
{
"nodes": self.cycle_nodes,
"edges": self.cycle_edges,
},
subgraph=True
subgraph=True,
variable_pool=self.child_variable_pool
)
self.start_node_id = builder.start_node_id
self.graph = builder.build()
self.child_variable_pool = builder.variable_pool
async def execute(self, state: WorkflowState) -> Any:
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
Execute the cycle node at runtime.
@@ -123,6 +161,7 @@ class CycleGraphNode(BaseNode):
Args:
state: The current workflow state when entering the cycle node.
variable_pool: Variable Pool
Returns:
Any: The runtime result produced by the loop or iteration executor.
@@ -137,6 +176,8 @@ class CycleGraphNode(BaseNode):
node_id=self.node_id,
config=self.config,
state=state,
variable_pool=variable_pool,
child_variable_pool=self.child_variable_pool,
).run()
if self.node_type == NodeType.ITERATION:
return await IterationRuntime(
@@ -145,5 +186,7 @@ class CycleGraphNode(BaseNode):
node_id=self.node_id,
config=self.config,
state=state,
variable_pool=variable_pool,
child_variable_pool=self.child_variable_pool
).run()
raise RuntimeError("Unknown cycle node type")

View File

@@ -2,7 +2,8 @@
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
from app.core.workflow.variable.base_variable import VariableType
class EndNodeConfig(BaseNodeConfig):

View File

@@ -7,6 +7,8 @@ End 节点实现
import logging
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -17,12 +19,18 @@ class EndNode(BaseNode):
工作流的结束节点,根据配置的模板输出最终结果。
支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。
"""
def _output_types(self) -> dict[str, VariableType]:
    """Declare the output types of this node.

    Returns:
        A mapping with a single ``output`` entry typed as a string.
    """
    return dict(output=VariableType.STRING)
async def execute(self, state: WorkflowState) -> str:
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> str:
"""执行 end 节点业务逻辑
Args:
state: 工作流状态
variable_pool: 变量池
Returns:
最终输出字符串
@@ -34,7 +42,7 @@ class EndNode(BaseNode):
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template:
output = self._render_template(output_template, state, strict=False)
output = self._render_template(output_template, variable_pool, strict=False)
else:
output = ""

View File

@@ -9,7 +9,6 @@ class NodeType(StrEnum):
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
IF_ELSE = "if-else"
CODE = "code"
TRANSFORM = "transform"
QUESTION_CLASSIFIER = "question-classifier"
HTTP_REQUEST = "http-request"
TOOL = "tool"

View File

@@ -10,6 +10,8 @@ from httpx import AsyncClient, Response, Timeout
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
# Name the logger after the module rather than the file path, so it sits in
# the dotted logger hierarchy like the loggers of the other workflow modules.
logger = logging.getLogger(__name__)
@@ -34,6 +36,14 @@ class HttpRequestNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: HttpRequestNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
    """Return the type declared for each output field of this node.

    Returns:
        Mapping with ``body`` and ``output`` as strings, ``status_code`` as
        a number and ``headers`` as an object.
    """
    text_type = VariableType.STRING
    return dict(
        body=text_type,
        status_code=VariableType.NUMBER,
        headers=VariableType.OBJECT,
        output=text_type,
    )
def _build_timeout(self) -> Timeout:
"""
Build httpx Timeout configuration.
@@ -50,7 +60,7 @@ class HttpRequestNode(BaseNode):
)
return timeout
def _build_auth(self, state: WorkflowState) -> dict[str, str]:
def _build_auth(self, variable_pool: VariablePool) -> dict[str, str]:
"""
Build authentication-related HTTP headers.
@@ -58,12 +68,12 @@ class HttpRequestNode(BaseNode):
the current workflow runtime state.
Args:
state: Current workflow runtime state.
variable_pool: Variable Pool
Returns:
A dictionary of HTTP headers used for authentication.
"""
api_key = self._render_template(self.typed_config.auth.api_key, state)
api_key = self._render_template(self.typed_config.auth.api_key, variable_pool)
match self.typed_config.auth.auth_type:
case HttpAuthType.NONE:
return {}
@@ -82,7 +92,7 @@ class HttpRequestNode(BaseNode):
case _:
raise RuntimeError(f"Auth type not supported: {self.typed_config.auth.auth_type}")
def _build_header(self, state: WorkflowState) -> dict[str, str]:
def _build_header(self, variable_pool: VariablePool) -> dict[str, str]:
"""
Build HTTP request headers.
@@ -90,10 +100,10 @@ class HttpRequestNode(BaseNode):
"""
headers = {}
for key, value in self.typed_config.headers.items():
headers[self._render_template(key, state)] = self._render_template(value, state)
headers[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool)
return headers
def _build_params(self, state: WorkflowState) -> dict[str, str]:
def _build_params(self, variable_pool: VariablePool) -> dict[str, str]:
"""
Build URL query parameters.
@@ -101,10 +111,10 @@ class HttpRequestNode(BaseNode):
"""
params = {}
for key, value in self.typed_config.params.items():
params[self._render_template(key, state)] = self._render_template(value, state)
params[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool)
return params
def _build_content(self, state) -> dict[str, Any]:
def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]:
"""
Build HTTP request body arguments for httpx request methods.
@@ -120,13 +130,13 @@ class HttpRequestNode(BaseNode):
return {}
case HttpContentType.JSON:
content["json"] = json.loads(self._render_template(
self.typed_config.body.data, state
self.typed_config.body.data, variable_pool
))
case HttpContentType.FROM_DATA:
data = {}
for item in self.typed_config.body.data:
if item.type == "text":
data[self._render_template(item.key, state)] = self._render_template(item.value, state)
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value, variable_pool)
elif item.type == "file":
# TODO: File support (Feature)
pass
@@ -136,11 +146,11 @@ class HttpRequestNode(BaseNode):
pass
case HttpContentType.WWW_FORM:
content["data"] = json.loads(self._render_template(
json.dumps(self.typed_config.body.data), state
json.dumps(self.typed_config.body.data), variable_pool
))
case HttpContentType.RAW:
content["content"] = self._render_template(self.typed_config.body.data, state)
content["content"] = self._render_template(self.typed_config.body.data, variable_pool)
case _:
raise RuntimeError(f"Content type not supported: {self.typed_config.body.content_type}")
return content
@@ -165,7 +175,7 @@ class HttpRequestNode(BaseNode):
case _:
raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}")
async def execute(self, state: WorkflowState) -> dict | str:
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict | str:
"""
Execute the HTTP request node.
@@ -176,6 +186,7 @@ class HttpRequestNode(BaseNode):
Args:
state: Current workflow runtime state.
variable_pool: Variable Pool
Returns:
- dict: Serialized HttpRequestNodeOutput on success
@@ -185,8 +196,8 @@ class HttpRequestNode(BaseNode):
async with httpx.AsyncClient(
verify=self.typed_config.verify_ssl,
timeout=self._build_timeout(),
headers=self._build_header(state) | self._build_auth(state),
params=self._build_params(state),
headers=self._build_header(variable_pool) | self._build_auth(variable_pool),
params=self._build_params(variable_pool),
follow_redirects=True
) as client:
retries = self.typed_config.retry.max_attempts
@@ -194,8 +205,8 @@ class HttpRequestNode(BaseNode):
try:
request_func = self._get_client_method(client)
resp = await request_func(
url=self._render_template(self.typed_config.url, state),
**self._build_content(state)
url=self._render_template(self.typed_config.url, variable_pool),
**self._build_content(variable_pool)
)
resp.raise_for_status()
logger.info(f"Node {self.node_id}: HTTP request succeeded")

View File

@@ -6,6 +6,8 @@ from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -15,6 +17,11 @@ class IfElseNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: IfElseNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
    """Return the output type declaration of this node.

    Returns:
        A mapping with a single ``output`` entry typed as a string.
    """
    declared: dict[str, VariableType] = {}
    declared["output"] = VariableType.STRING
    return declared
@staticmethod
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
match operator:
@@ -45,7 +52,7 @@ class IfElseNode(BaseNode):
case _:
raise ValueError(f"Invalid condition: {operator}")
def evaluate_conditional_edge_expressions(self, state) -> list[bool]:
def evaluate_conditional_edge_expressions(self, variable_pool: VariablePool) -> list[bool]:
"""
Build conditional edge expressions for the If-Else node.
@@ -72,11 +79,11 @@ class IfElseNode(BaseNode):
pattern = r"\{\{\s*(.*?)\s*\}\}"
left_string = re.sub(pattern, r"\1", expression.left).strip()
try:
left_value = self.get_variable(left_string, state)
left_value = self.get_variable(left_string, variable_pool)
except KeyError:
left_value = None
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
self.get_variable_pool(state),
variable_pool,
expression.left,
expression.right,
expression.input_type
@@ -95,7 +102,7 @@ class IfElseNode(BaseNode):
return conditions
async def execute(self, state: WorkflowState) -> Any:
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
Execute the conditional branching logic of the node.
@@ -105,13 +112,13 @@ class IfElseNode(BaseNode):
Args:
state (WorkflowState): The current workflow state, containing variables, messages, node outputs, etc.
variable_pool: Variable Pool
Returns:
str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions.
"""
self.typed_config = IfElseNodeConfig(**self.config)
expressions = self.evaluate_conditional_edge_expressions(state)
# TODO: 变量类型及文本类型解析
expressions = self.evaluate_conditional_edge_expressions(variable_pool)
for i in range(len(expressions)):
if expressions[i]:
logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}")

View File

@@ -5,6 +5,8 @@ from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig
from app.core.workflow.template_renderer import TemplateRenderer
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -14,7 +16,12 @@ class JinjaRenderNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: JinjaRenderNodeConfig | None = None
async def execute(self, state: WorkflowState) -> Any:
def _output_types(self) -> dict[str, VariableType]:
    """Return the output type declaration of this node.

    Returns:
        A mapping with a single ``output`` entry typed as a string.
    """
    output_type = VariableType.STRING
    return {"output": output_type}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
Execute the node: render the Jinja2 template with mapped variables.
@@ -24,6 +31,7 @@ class JinjaRenderNode(BaseNode):
Args:
state (WorkflowState): Current workflow state containing variables,
node outputs, and runtime variables.
variable_pool: Variable Pool
Returns:
dict[str, Any]: Node output dictionary containing the rendered result
@@ -40,7 +48,7 @@ class JinjaRenderNode(BaseNode):
context = {}
for variable in self.typed_config.mapping:
try:
context[variable.name] = self.get_variable(variable.value, state)
context[variable.name] = self.get_variable(variable.value, variable_pool)
except Exception:
logger.info(f"variable not found, var: {variable.value}")
continue

View File

@@ -8,6 +8,8 @@ from app.core.models import RedBearRerank, RedBearModelConfig
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
from app.db import get_db_read
from app.models import knowledge_model, knowledgeshare_model, ModelType
from app.repositories import knowledge_repository, knowledgeshare_repository
@@ -22,6 +24,11 @@ class KnowledgeRetrievalNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
    """Return the output type declaration of this node.

    Returns:
        A mapping with a single ``output`` entry typed as an array of
        strings.
    """
    return dict(output=VariableType.ARRAY_STRING)
@staticmethod
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
"""
@@ -149,7 +156,7 @@ class KnowledgeRetrievalNode(BaseNode):
)
return reranker
async def execute(self, state: WorkflowState) -> Any:
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
Execute the knowledge retrieval workflow node.
@@ -163,6 +170,7 @@ class KnowledgeRetrievalNode(BaseNode):
Args:
state (WorkflowState): Current workflow execution state.
variable_pool: Variable Pool
Returns:
Any: List of retrieved knowledge chunks (dict format).
@@ -171,7 +179,7 @@ class KnowledgeRetrievalNode(BaseNode):
RuntimeError: If no valid knowledge base is found or access is denied.
"""
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
query = self._render_template(self.typed_config.query, state)
query = self._render_template(self.typed_config.query, variable_pool)
with get_db_read() as db:
knowledge_bases = self.typed_config.knowledge_bases
existing_ids = self._get_existing_kb_ids(db, [kb.kb_id for kb in knowledge_bases])

View File

@@ -4,7 +4,8 @@ from typing import Any
from pydantic import BaseModel, Field, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
from app.core.workflow.variable.base_variable import VariableType
class MessageConfig(BaseModel):
@@ -70,6 +71,16 @@ class LLMNodeConfig(BaseNodeConfig):
description="对话上下文窗口"
)
vision: bool = Field(
default=False,
description="是否启用视觉模型"
)
# Optional: the default is None, so the annotation must admit None as well;
# otherwise an explicit null in the node config fails validation.
vision_input: str | None = Field(
    default=None,
    description="视觉输入"
)
# 简单模式
prompt: str | None = Field(
default=None,

View File

@@ -15,6 +15,8 @@ from app.core.exceptions import BusinessException
from app.core.models import RedBearLLM, RedBearModelConfig
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.llm.config import LLMNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
from app.db import get_db_context
from app.models import ModelType
from app.services.model_service import ModelConfigService
@@ -66,59 +68,34 @@ class LLMNode(BaseNode):
- ai/assistant: AI 消息AIMessage
"""
def _output_types(self) -> dict[str, VariableType]:
    """Return the output type declaration of this node.

    Returns:
        A mapping with a single ``output`` entry typed as a string.
    """
    return dict(output=VariableType.STRING)
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config: LLMNodeConfig | None = None
def _render_context(self, message, state):
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
def _render_context(self, message: str, variable_pool: VariablePool):
    """Substitute the ``{{context}}`` placeholder in *message*.

    The configured context template is rendered against *variable_pool*,
    wrapped in ``<context>...</context>`` tags, and inserted wherever the
    literal text ``{{context}}`` appears in *message*.

    Args:
        message: Message template that may contain ``{{context}}``.
        variable_pool: Variable pool used to render the context template.

    Returns:
        The message with every ``{{context}}`` occurrence replaced.
    """
    context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
    # Use literal replacement: re.sub() treats its replacement argument as a
    # template, so backslashes in the rendered context (paths, regexes,
    # LaTeX, ...) would raise re.error or be rewritten as escapes.
    return message.replace("{{context}}", context)
def _prepare_llm(self, state: WorkflowState, stream: bool = False) -> tuple[RedBearLLM, list | str]:
async def _prepare_llm(
self,
state: WorkflowState,
variable_pool: VariablePool,
stream: bool = False
) -> RedBearLLM:
"""准备 LLM 实例(公共逻辑)
Args:
state: 工作流状态
variable_pool: 变量池
Returns:
(llm, messages_or_prompt): LLM 实例和消息列表或 prompt 字符串
"""
# 1. 处理消息格式(优先使用 messages
self.typed_config = LLMNodeConfig(**self.config)
messages_config = self.typed_config.messages
if messages_config:
# 使用 LangChain 消息格式
messages = []
for msg_config in messages_config:
role = msg_config.role.lower()
content_template = msg_config.content
content_template = self._render_context(content_template, state)
content = self._render_template(content_template, state)
# 根据角色创建对应的消息对象
if role == "system":
messages.append({"role": "system", "content": content})
elif role in ["user", "human"]:
messages.append({"role": "user", "content": content})
elif role in ["ai", "assistant"]:
messages.append({"role": "assistant", "content": content})
else:
logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append({"role": "user", "content": content})
if self.typed_config.memory.enable:
# if self.typed_config.memory.enable_window:
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
prompt_or_messages = messages
else:
# 使用简单的 prompt 格式(向后兼容)
prompt_template = self.config.get("prompt", "")
prompt_or_messages = self._render_template(prompt_template, state)
# 2. 获取模型配置
model_id = self.config.get("model_id")
model_id = self.typed_config.model_id
if not model_id:
raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置")
@@ -157,27 +134,82 @@ class LLMNode(BaseNode):
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
return llm, prompt_or_messages
messages_config = self.typed_config.messages
async def execute(self, state: WorkflowState) -> AIMessage:
if messages_config:
# 使用 LangChain 消息格式
messages = []
for msg_config in messages_config:
role = msg_config.role.lower()
content_template = msg_config.content
content_template = self._render_context(content_template, variable_pool)
content = self._render_template(content_template, variable_pool)
# 根据角色创建对应的消息对象
if role == "system":
messages.append({
"role": "system",
"content": content
})
elif role in ["user", "human"]:
messages.append({
"role": "user",
"content": content
})
elif role in ["ai", "assistant"]:
messages.append({
"role": "assistant",
"content": content
})
else:
logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append({
"role": "user",
"content": content
})
if self.typed_config.vision_input and self.typed_config.vision:
file_content = []
files = variable_pool.get_value(self.typed_config.vision_input)
for file in files:
content = await self.process_message(provider, file, self.typed_config.vision)
if content:
file_content.append(content)
if messages and messages[-1]["role"] == 'user':
messages[-1]['content'] = [messages[-1]["content"]] + file_content
else:
messages.append({"role": "user", "content": file_content})
if self.typed_config.memory.enable:
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
self.messages = messages
else:
# 使用简单的 prompt 格式(向后兼容)
prompt_template = self.config.get("prompt", "")
self.messages = self._render_template(prompt_template, variable_pool)
return llm
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage:
"""非流式执行 LLM 调用
Args:
state: 工作流状态
variable_pool: 变量池
Returns:
LLM 响应消息
"""
# self.typed_config = LLMNodeConfig(**self.config)
llm, prompt_or_messages = self._prepare_llm(state, True)
llm = await self._prepare_llm(state, variable_pool, False)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
# 调用 LLM支持字符串或消息列表
response = await llm.ainvoke(prompt_or_messages)
response = await llm.ainvoke(self.messages)
# 提取内容
if hasattr(response, 'content'):
content = response.content
content = self.process_model_output(response.content)
else:
content = str(response)
@@ -186,16 +218,15 @@ class LLMNode(BaseNode):
# 返回 AIMessage包含响应元数据
return response if isinstance(response, AIMessage) else AIMessage(content=content)
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""提取输入数据(用于记录)"""
_, prompt_or_messages = self._prepare_llm(state)
return {
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
"prompt": self.messages if isinstance(self.messages, str) else None,
"messages": [
{"role": msg.get("role"), "content": msg.get("content", "")}
for msg in prompt_or_messages
] if isinstance(prompt_or_messages, list) else None,
for msg in self.messages
] if isinstance(self.messages, list) else None,
"config": {
"model_id": self.config.get("model_id"),
"temperature": self.config.get("temperature"),
@@ -215,24 +246,25 @@ class LLMNode(BaseNode):
usage = business_result.response_metadata.get('token_usage')
if usage:
return {
"prompt_tokens": usage.get('prompt_tokens', 0),
"completion_tokens": usage.get('completion_tokens', 0),
"prompt_tokens": usage.get('input_tokens', 0),
"completion_tokens": usage.get('output_tokens', 0),
"total_tokens": usage.get('total_tokens', 0)
}
return None
async def execute_stream(self, state: WorkflowState):
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
"""流式执行 LLM 调用
Args:
state: 工作流状态
variable_pool: 变量池
Yields:
文本片段chunk或完成标记
"""
self.typed_config = LLMNodeConfig(**self.config)
llm, prompt_or_messages = self._prepare_llm(state, True)
llm = await self._prepare_llm(state, variable_pool, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
@@ -243,10 +275,10 @@ class LLMNode(BaseNode):
# 调用 LLM流式支持字符串或消息列表
last_meta_data = {}
async for chunk in llm.astream(prompt_or_messages, stream_usage=True):
async for chunk in llm.astream(self.messages, stream_usage=True):
# 提取内容
if hasattr(chunk, 'content'):
content = chunk.content
content = self.process_model_output(chunk.content)
else:
content = str(chunk)
if hasattr(chunk, 'response_metadata'):

View File

@@ -3,6 +3,8 @@ from typing import Any
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
from app.db import get_db_read
from app.services.memory_agent_service import MemoryAgentService
from app.tasks import write_message_task
@@ -13,17 +15,23 @@ class MemoryReadNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: MemoryReadNodeConfig | None = None
async def execute(self, state: WorkflowState) -> Any:
def _output_types(self) -> dict[str, VariableType]:
    """Return the type declared for each output field of this node.

    Returns:
        Mapping with ``answer`` as a string and ``intermediate_outputs`` as
        an array of objects.
    """
    return dict(
        answer=VariableType.STRING,
        intermediate_outputs=VariableType.ARRAY_OBJECT,
    )
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
self.typed_config = MemoryReadNodeConfig(**self.config)
with get_db_read() as db:
end_user_id = self.get_variable("sys.user_id", state)
end_user_id = self.get_variable("sys.user_id", variable_pool)
if not end_user_id:
raise RuntimeError("End user id is required")
return await MemoryAgentService().read_memory(
end_user_id=end_user_id,
message=self._render_template(self.typed_config.message, state),
message=self._render_template(self.typed_config.message, variable_pool),
config_id=self.typed_config.config_id,
search_switch=self.typed_config.search_switch,
history=[],
@@ -38,16 +46,19 @@ class MemoryWriteNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: MemoryWriteNodeConfig | None = None
async def execute(self, state: WorkflowState) -> Any:
def _output_types(self) -> dict[str, VariableType]:
    """Return the output type declaration of this node.

    Returns:
        A mapping with a single ``output`` entry typed as a string.
    """
    declared = dict(output=VariableType.STRING)
    return declared
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
self.typed_config = MemoryWriteNodeConfig(**self.config)
end_user_id = self.get_variable("sys.user_id", state)
end_user_id = self.get_variable("sys.user_id", variable_pool)
if not end_user_id:
raise RuntimeError("End user id is required")
write_message_task.delay(
end_user_id,
self._render_template(self.typed_config.message, state),
self._render_template(self.typed_config.message, variable_pool),
str(self.typed_config.config_id),
"neo4j",
""

View File

@@ -22,7 +22,6 @@ from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.memory import MemoryReadNode, MemoryWriteNode
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.breaker import BreakNode
@@ -37,7 +36,6 @@ WorkflowNode = Union[
LLMNode,
IfElseNode,
AgentNode,
TransformNode,
AssignerNode,
HttpRequestNode,
KnowledgeRetrievalNode,
@@ -67,7 +65,6 @@ class NodeFactory:
NodeType.END: EndNode,
NodeType.LLM: LLMNode,
NodeType.AGENT: AgentNode,
NodeType.TRANSFORM: TransformNode,
NodeType.IF_ELSE: IfElseNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.ASSIGNER: AssignerNode,

View File

@@ -1,9 +1,9 @@
import json
import re
from abc import ABC
from typing import Union, Type, NoReturn
from typing import Union, Type, NoReturn, Any
from app.core.workflow.nodes.base_config import VariableType
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.nodes.enums import ValueInputType
from app.core.workflow.variable_pool import VariablePool
@@ -69,7 +69,7 @@ class TypeTransformer:
class OperatorBase(ABC):
def __init__(self, pool: VariablePool, left_selector, right):
def __init__(self, pool: VariablePool, left_selector: str, right: Any):
self.pool = pool
self.left_selector = left_selector
self.right = right
@@ -77,7 +77,7 @@ class OperatorBase(ABC):
self.type_limit: type[str, int, dict, list] = None
def check(self, no_right=False):
left = self.pool.get(self.left_selector)
left = self.pool.get_value(self.left_selector)
if not isinstance(left, self.type_limit):
raise TypeError(f"The variable to be operated on must be of {self.type_limit} type")
@@ -92,13 +92,13 @@ class StringOperator(OperatorBase):
super().__init__(pool, left_selector, right)
self.type_limit = str
def assign(self) -> None:
async def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
await self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
async def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, '')
await self.pool.set(self.left_selector, '')
class NumberOperator(OperatorBase):
@@ -106,33 +106,33 @@ class NumberOperator(OperatorBase):
super().__init__(pool, left_selector, right)
self.type_limit = (float, int)
def assign(self) -> None:
async def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
await self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
async def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, 0)
await self.pool.set(self.left_selector, 0)
def add(self) -> None:
async def add(self) -> None:
self.check()
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin + self.right)
origin = self.pool.get_value(self.left_selector)
await self.pool.set(self.left_selector, origin + self.right)
def subtract(self) -> None:
async def subtract(self) -> None:
self.check()
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin - self.right)
origin = self.pool.get_value(self.left_selector)
await self.pool.set(self.left_selector, origin - self.right)
def multiply(self) -> None:
async def multiply(self) -> None:
self.check()
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin * self.right)
origin = self.pool.get_value(self.left_selector)
await self.pool.set(self.left_selector, origin * self.right)
def divide(self) -> None:
async def divide(self) -> None:
self.check()
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin / self.right)
origin = self.pool.get_value(self.left_selector)
await self.pool.set(self.left_selector, origin / self.right)
class BooleanOperator(OperatorBase):
@@ -140,13 +140,13 @@ class BooleanOperator(OperatorBase):
super().__init__(pool, left_selector, right)
self.type_limit = bool
def assign(self) -> None:
async def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
await self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
async def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, False)
await self.pool.set(self.left_selector, False)
class ArrayOperator(OperatorBase):
@@ -154,38 +154,37 @@ class ArrayOperator(OperatorBase):
super().__init__(pool, left_selector, right)
self.type_limit = list
def assign(self) -> None:
async def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
await self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
async def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, list())
await self.pool.set(self.left_selector, list())
def append(self) -> None:
async def append(self) -> None:
self.check(no_right=True)
# TODOrequire type limit in list
origin = self.pool.get(self.left_selector)
origin = self.pool.get_value(self.left_selector)
origin.append(self.right)
self.pool.set(self.left_selector, origin)
await self.pool.set(self.left_selector, origin)
def extend(self) -> None:
async def extend(self) -> None:
self.check(no_right=True)
origin = self.pool.get(self.left_selector)
origin = self.pool.get_value(self.left_selector)
origin.extend(self.right)
self.pool.set(self.left_selector, origin)
await self.pool.set(self.left_selector, origin)
def remove_last(self) -> None:
async def remove_last(self) -> None:
self.check(no_right=True)
origin = self.pool.get(self.left_selector)
origin = self.pool.get_value(self.left_selector)
origin.pop()
self.pool.set(self.left_selector, origin)
await self.pool.set(self.left_selector, origin)
def remove_first(self) -> None:
async def remove_first(self) -> None:
self.check(no_right=True)
origin = self.pool.get(self.left_selector)
origin = self.pool.get_value(self.left_selector)
origin.pop(0)
self.pool.set(self.left_selector, origin)
await self.pool.set(self.left_selector, origin)
class ObjectOperator(OperatorBase):
@@ -193,13 +192,13 @@ class ObjectOperator(OperatorBase):
super().__init__(pool, left_selector, right)
self.type_limit = dict
def assign(self) -> None:
async def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
await self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
async def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, dict())
await self.pool.set(self.left_selector, dict())
class AssignmentOperatorResolver:
@@ -245,7 +244,7 @@ class ConditionBase(ABC):
self.right_selector = right_selector
self.input_type = input_type
self.left_value = self.pool.get(self.left_selector)
self.left_value = self.pool.get_value(self.left_selector)
self.right_value = self.resolve_right_literal_value()
self.type_limit = getattr(self, "type_limit", None)
@@ -254,7 +253,7 @@ class ConditionBase(ABC):
if self.input_type == ValueInputType.VARIABLE:
pattern = r"\{\{\s*(.*?)\s*\}\}"
right_expression = re.sub(pattern, r"\1", self.right_selector).strip()
return self.pool.get(right_expression)
return self.pool.get_value(right_expression)
elif self.input_type == ValueInputType.CONSTANT:
return self.right_selector
raise RuntimeError("Unsupported variable type")

View File

@@ -1,7 +1,7 @@
import uuid
from enum import StrEnum
from pydantic import Field, BaseModel
from enum import StrEnum
from app.core.workflow.nodes.base_config import BaseNodeConfig

View File

@@ -12,6 +12,8 @@ from app.core.models import RedBearLLM, RedBearModelConfig
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
from app.db import get_db_read
from app.models import ModelType
from app.services.model_service import ModelConfigService
@@ -24,6 +26,12 @@ class ParameterExtractorNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: ParameterExtractorNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
outputs = {}
for param in self.typed_config.params:
outputs[param.name] = param.type
return outputs
@staticmethod
def _get_prompt():
"""
@@ -120,7 +128,7 @@ class ParameterExtractorNode(BaseNode):
field_type[param.name] = f'{param.type}, required:{str(param.required)}'
return field_type
async def execute(self, state: WorkflowState) -> Any:
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
Main execution function for this node.
@@ -138,6 +146,7 @@ class ParameterExtractorNode(BaseNode):
Args:
state (WorkflowState): Current state of the workflow, used for template rendering.
variable_pool (VariablePool): Used for accessing and setting variables during execution.
Returns:
dict[str, Any]: Dictionary containing extracted parameters under the "output" key.
@@ -153,7 +162,7 @@ class ParameterExtractorNode(BaseNode):
rendered_user_prompt = user_prompt_teplate.render(
field_descriptions=str(self._get_field_desc()),
field_type=str(self._get_field_type()),
text_input=self._render_template(self.typed_config.text, state)
text_input=self._render_template(self.typed_config.text, variable_pool)
)
messages = [
@@ -162,7 +171,7 @@ class ParameterExtractorNode(BaseNode):
]
if self.typed_config.prompt:
messages.extend([
("user", self._render_template(self.typed_config.prompt, state)),
("user", self._render_template(self.typed_config.prompt, variable_pool)),
("user", rendered_user_prompt),
])
else:

View File

@@ -6,6 +6,8 @@ from app.core.workflow.nodes.question_classifier.config import QuestionClassifie
from app.core.models import RedBearLLM, RedBearModelConfig
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
from app.db import get_db_read
from app.models import ModelType
from app.services.model_service import ModelConfigService
@@ -24,6 +26,12 @@ class QuestionClassifierNode(BaseNode):
self.typed_config: QuestionClassifierNodeConfig | None = None
self.category_to_case_map = {}
def _output_types(self) -> dict[str, VariableType]:
return {
"class_name": VariableType.STRING,
"output": VariableType.STRING
}
def _get_llm_instance(self) -> RedBearLLM:
"""获取LLM实例"""
with get_db_read() as db:
@@ -65,7 +73,7 @@ class QuestionClassifierNode(BaseNode):
category_map[category_name] = case_tag
return category_map
async def execute(self, state: WorkflowState) -> dict:
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict:
"""执行问题分类"""
self.typed_config = QuestionClassifierNodeConfig(**self.config)
self.category_to_case_map = self._build_category_case_map()
@@ -102,7 +110,7 @@ class QuestionClassifierNode(BaseNode):
categories=", ".join(category_names),
supplement_prompt=supplement_prompt
),
state
variable_pool
)
messages = [

View File

@@ -2,7 +2,8 @@
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
from app.core.workflow.variable.base_variable import VariableType
class StartNodeConfig(BaseNodeConfig):

View File

@@ -7,9 +7,10 @@ Start 节点实现
import logging
from typing import Any
from app.core.workflow.nodes.base_config import VariableType
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.start.config import StartNodeConfig
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -36,14 +37,25 @@ class StartNode(BaseNode):
# 解析并验证配置
self.typed_config: StartNodeConfig | None = None
self.output_var_types = {}
async def execute(self, state: WorkflowState) -> dict[str, Any]:
def _output_types(self) -> dict[str, VariableType]:
return self.output_var_types | {
"message": VariableType.STRING,
"execution_id": VariableType.STRING,
"conversation_id": VariableType.STRING,
"workspace_id": VariableType.STRING,
"user_id": VariableType.STRING,
}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""执行 start 节点业务逻辑
Start 节点输出系统变量、会话变量和自定义变量。
Args:
state: 工作流状态
variable_pool: 变量池
Returns:
包含系统参数、会话变量和自定义变量的字典
@@ -51,19 +63,16 @@ class StartNode(BaseNode):
self.typed_config = StartNodeConfig(**self.config)
logger.info(f"节点 {self.node_id} (Start) 开始执行")
# 创建变量池实例(在方法内复用)
pool = self.get_variable_pool(state)
# 处理自定义变量(传入 pool 避免重复创建)
custom_vars = self._process_custom_variables(pool)
custom_vars = self._process_custom_variables(variable_pool)
# 返回业务数据(包含自定义变量)
result = {
"message": pool.get("sys.message"),
"execution_id": pool.get("sys.execution_id"),
"conversation_id": pool.get("sys.conversation_id"),
"workspace_id": pool.get("sys.workspace_id"),
"user_id": pool.get("sys.user_id"),
"message": variable_pool.get_value("sys.message"),
"execution_id": variable_pool.get_value("sys.execution_id"),
"conversation_id": variable_pool.get_value("sys.conversation_id"),
"workspace_id": variable_pool.get_value("sys.workspace_id"),
"user_id": variable_pool.get_value("sys.user_id"),
**custom_vars # 自定义变量作为节点输出的一部分
}
@@ -74,7 +83,7 @@ class StartNode(BaseNode):
return result
def _process_custom_variables(self, pool) -> dict[str, Any]:
def _process_custom_variables(self, pool: VariablePool) -> dict[str, Any]:
"""处理自定义变量
从输入数据中提取自定义变量,应用默认值和验证。
@@ -89,13 +98,14 @@ class StartNode(BaseNode):
ValueError: 缺少必需变量
"""
# 获取输入数据中的自定义变量
input_variables = pool.get("sys.input_variables", default={})
input_variables = pool.get_value("sys.input_variables", default={}, strict=False)
processed = {}
# 遍历配置的变量定义
for var_def in self.typed_config.variables:
var_name = var_def.name
var_type = var_def.type
# 检查变量是否存在
if var_name in input_variables:
@@ -116,21 +126,12 @@ class StartNode(BaseNode):
f"变量 '{var_name}' 使用默认值: {var_def.default}"
)
else:
match var_def.type:
case VariableType.STRING:
processed[var_name] = ""
case VariableType.NUMBER:
processed[var_name] = 0
case VariableType.OBJECT:
processed[var_name] = {}
case VariableType.BOOLEAN:
processed[var_name] = False
case VariableType.ARRAY_NUMBER | VariableType.ARRAY_OBJECT | VariableType.ARRAY_BOOLEAN | VariableType.ARRAY_STRING:
processed[var_name] = []
processed[var_name] = DEFAULT_VALUE(var_type)
self.output_var_types[var_name] = var_type
return processed
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""提取输入数据(用于记录)
Args:
@@ -139,11 +140,9 @@ class StartNode(BaseNode):
Returns:
输入数据字典
"""
pool = self.get_variable_pool(state)
return {
"execution_id": pool.get("sys.execution_id"),
"conversation_id": pool.get("sys.conversation_id"),
"message": pool.get("sys.message"),
"conversation_vars": pool.get_all_conversation_vars()
"execution_id": variable_pool.get_value("sys.execution_id"),
"conversation_id": variable_pool.get_value("sys.conversation_id"),
"message": variable_pool.get_value("sys.message"),
"conversation_vars": variable_pool.get_all_conversation_vars()
}

View File

@@ -6,6 +6,8 @@ from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
from app.services.tool_service import ToolService
from app.db import get_db_read
@@ -21,13 +23,20 @@ class ToolNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: ToolNodeConfig | None = None
async def execute(self, state: WorkflowState) -> dict[str, Any]:
def _output_types(self) -> dict[str, VariableType]:
return {
"data": VariableType.STRING,
"error_code": VariableType.STRING,
"execution_time": VariableType.NUMBER
}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""执行工具"""
self.typed_config = ToolNodeConfig(**self.config)
# 获取租户ID和用户ID
tenant_id = self.get_variable("sys.tenant_id", state)
user_id = self.get_variable("sys.user_id", state)
workspace_id = self.get_variable("sys.workspace_id", state)
tenant_id = self.get_variable("sys.tenant_id", variable_pool, strict=False)
user_id = self.get_variable("sys.user_id", variable_pool)
workspace_id = self.get_variable("sys.workspace_id", variable_pool)
# 如果没有租户ID尝试从工作流ID获取
if not tenant_id:
@@ -48,7 +57,7 @@ class ToolNode(BaseNode):
for param_name, param_template in self.typed_config.tool_parameters.items():
if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template):
try:
rendered_value = self._render_template(param_template, state)
rendered_value = self._render_template(param_template, variable_pool)
except Exception as e:
raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e
else:

View File

@@ -1,6 +0,0 @@
"""Transform 节点"""
from app.core.workflow.nodes.transform.node import TransformNode
from app.core.workflow.nodes.transform.config import TransformNodeConfig
__all__ = ["TransformNode", "TransformNodeConfig"]

View File

@@ -1,80 +0,0 @@
"""Transform 节点配置"""
from typing import Literal
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
class TransformNodeConfig(BaseNodeConfig):
"""Transform 节点配置
用于数据转换和处理。
"""
transform_type: Literal["template", "code", "json"] = Field(
default="template",
description="转换类型template(模板), code(代码), json(JSON处理)"
)
# 模板模式
template: str | None = Field(
default=None,
description="转换模板,支持变量引用"
)
# 代码模式
code: str | None = Field(
default=None,
description="Python 代码,用于数据转换"
)
# JSON 模式
json_path: str | None = Field(
default=None,
description="JSON 路径表达式"
)
# 输入变量
inputs: dict[str, str] | None = Field(
default=None,
description="输入变量映射key 为变量名value 为变量选择器"
)
# 输出变量
output_key: str = Field(
default="result",
description="输出变量的键名"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="result",
type=VariableType.STRING,
description="转换后的结果"
)
],
description="输出变量定义(根据 output_key 动态生成)"
)
class Config:
json_schema_extra = {
"examples": [
{
"transform_type": "template",
"template": "用户问题:{{ sys.message }}\n回答:{{ llm_qa.output }}",
"output_key": "formatted_result"
},
{
"transform_type": "code",
"code": "result = input_text.upper()",
"inputs": {
"input_text": "{{ sys.message }}"
},
"output_key": "uppercase_text"
}
]
}

View File

@@ -1,60 +0,0 @@
"""
Transform 节点实现
数据转换节点,用于处理和转换数据。
"""
import logging
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
logger = logging.getLogger(__name__)
class TransformNode(BaseNode):
"""数据转换节点
配置示例:
{
"type": "transform",
"config": {
"mapping": {
"output_field": "{{node.previous.output}}",
"processed": "{{var.input | upper}}"
}
}
}
"""
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行数据转换
Args:
state: 工作流状态
Returns:
状态更新字典
"""
logger.info(f"节点 {self.node_id} 开始执行数据转换")
# 获取映射配置
mapping = self.config.get("mapping", {})
# 执行数据转换
transformed_data = {}
for target_key, source_template in mapping.items():
# 渲染模板获取值
value = self._render_template(str(source_template), state)
transformed_data[target_key] = value
logger.info(f"节点 {self.node_id} 数据转换完成,输出字段: {list(transformed_data.keys())}")
return {
"node_outputs": {
self.node_id: {
"output": transformed_data,
"status": "completed"
}
}
}

View File

@@ -1,6 +1,7 @@
from pydantic import Field, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.variable.base_variable import VariableType
class VariableAggregatorNodeConfig(BaseNodeConfig):
@@ -14,6 +15,11 @@ class VariableAggregatorNodeConfig(BaseNodeConfig):
description="需要被聚合的变量"
)
group_type: dict[str, VariableType] = Field(
default=None,
description="每个分组的变量类型"
)
@field_validator("group_variables")
@classmethod
def group_variables_validator(cls, v, info):

View File

@@ -5,6 +5,8 @@ from typing import Any
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -14,6 +16,17 @@ class VariableAggregatorNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: VariableAggregatorNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
config = VariableAggregatorNodeConfig(**self.config)
output = {}
if not config.group_type:
for group_name in config.group_variables.keys():
output[group_name] = VariableType.ANY
return output
for var_type in config.group_type:
output[var_type] = config.group_type[var_type]
return output
@staticmethod
def _get_express(variable_string: str) -> Any:
"""
@@ -29,7 +42,7 @@ class VariableAggregatorNode(BaseNode):
expression = re.sub(pattern, r"\1", variable_string).strip()
return expression
async def execute(self, state: WorkflowState) -> Any:
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
Execute the variable aggregation logic.
@@ -45,7 +58,7 @@ class VariableAggregatorNode(BaseNode):
for variable in self.typed_config.group_variables:
var_express = self._get_express(variable)
try:
value = self.get_variable(var_express, state)
value = self.get_variable(var_express, variable_pool)
except Exception as e:
logger.warning(f"Failed to get variable '{var_express}': {e}")
continue
@@ -55,7 +68,9 @@ class VariableAggregatorNode(BaseNode):
return value
logger.info("No variable found in non-group mode; returning empty string.")
return ""
if not self.typed_config.group_type:
return ""
return DEFAULT_VALUE(self.typed_config.group_type["output"])
# --------------------------
# Group mode
@@ -65,7 +80,7 @@ class VariableAggregatorNode(BaseNode):
for variable in variables:
var_express = self._get_express(variable)
try:
value = self.get_variable(var_express, state)
value = self.get_variable(var_express, variable_pool)
except Exception as e:
logger.warning(f"Failed to get variable '{var_express}' in group '{group_name}': {e}")
continue
@@ -74,7 +89,10 @@ class VariableAggregatorNode(BaseNode):
result[group_name] = value
break
else:
result[group_name] = ""
if not self.typed_config.group_type:
result[group_name] = ""
else:
result[group_name] = DEFAULT_VALUE(self.typed_config.group_type[group_name])
logger.info(f"No variable found for group '{group_name}'; set empty string.")
logger.info(f"Node: {self.node_id} variable aggregation result: {result}")
return result

View File

@@ -43,7 +43,7 @@ class TemplateRenderer:
def render(
self,
template: str,
variables: dict[str, Any],
conv_vars: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> str:
@@ -51,7 +51,7 @@ class TemplateRenderer:
Args:
template: 模板字符串
variables: 用户定义的变量
conv_vars: 会话变量
node_outputs: 节点输出结果
system_vars: 系统变量
@@ -80,20 +80,11 @@ class TemplateRenderer:
'分析结果: 正面情绪'
"""
# 构建命名空间上下文
# variables 的结构:{"sys": {...}, "conv": {...}}
sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {}
conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {}
if self.strict:
context = defaultdict(dict)
context["conv"] = conv_vars
context["node"] = node_outputs
context["sys"] = {**(system_vars or {}), **sys_vars}
else:
context = {
"conv": conv_vars, # 会话变量:{{conv.user_name}}
"node": node_outputs, # 节点输出:{{node.node_1.output}}
"sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源)
}
context = {
"conv": conv_vars, # 会话变量:{{conv.user_name}}
"node": node_outputs, # 节点输出:{{node.node_1.output}}
"sys": system_vars, # 系统变量:{{sys.execution_id}}
}
# 支持直接通过节点ID访问节点输出{{llm_qa.output}}
# 将所有节点输出添加到顶层上下文
@@ -157,9 +148,9 @@ _default_renderer = TemplateRenderer(strict=True)
def render_template(
template: str,
variables: dict[str, Any],
conv_vars: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None,
system_vars: dict[str, Any],
strict: bool = True
) -> str:
"""渲染模板(便捷函数)
@@ -167,7 +158,7 @@ def render_template(
Args:
strict: 严格模式
template: 模板字符串
variables: 用户变量
conv_vars: 会话变量
node_outputs: 节点输出
system_vars: 系统变量
@@ -184,7 +175,7 @@ def render_template(
'请分析: 这是一段文本'
"""
renderer = TemplateRenderer(strict=strict)
return renderer.render(template, variables, node_outputs, system_vars)
return renderer.render(template, conv_vars, node_outputs, system_vars)
def validate_template(template: str) -> list[str]:

View File

@@ -5,10 +5,13 @@
"""
import logging
from typing import Any, Union
from typing import Any, Union, TYPE_CHECKING
from app.core.workflow.nodes.enums import NodeType
if TYPE_CHECKING:
from app.schemas.workflow_schema import WorkflowConfig
logger = logging.getLogger(__name__)
@@ -64,7 +67,7 @@ class WorkflowValidator:
return cycle_nodes, cycle_edges
@classmethod
def get_subgraph(cls, workflow_config: Union[dict[str, Any], Any]) -> list:
def get_subgraph(cls, workflow_config: Union[dict[str, Any], "WorkflowConfig"]) -> list:
if not isinstance(workflow_config, dict):
workflow_config = {
"nodes": workflow_config.nodes,
@@ -331,7 +334,7 @@ class WorkflowValidator:
def validate_workflow_config(
workflow_config: dict[str, Any],
workflow_config: Union[dict[str, Any], 'WorkflowConfig'],
for_publish: bool = False
) -> tuple[bool, list[str]]:
"""验证工作流配置(便捷函数)

View File

@@ -0,0 +1,170 @@
from enum import StrEnum
from abc import abstractmethod, ABC
from typing import Any
from pydantic import BaseModel
from app.schemas import FileType
class VariableType(StrEnum):
"""Enumeration of supported variable types in the workflow."""
STRING = "string"
NUMBER = "number"
BOOLEAN = "boolean"
OBJECT = "object"
FILE = "file"
ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]"
ARRAY_BOOLEAN = "array[boolean]"
ARRAY_OBJECT = "array[object]"
ARRAY_FILE = "array[file]"
NESTED_ARRAY = "array_nest"
ANY = 'any'
@classmethod
def type_map(cls, var: Any) -> "VariableType":
"""Maps a Python value to a corresponding VariableType.
Args:
var: The Python value to map.
Returns:
The VariableType corresponding to the input value.
Raises:
TypeError: If the type of the input value is not supported.
"""
var_type = type(var)
if isinstance(var_type, str):
return cls.STRING
elif isinstance(var_type, (int, float)):
return cls.NUMBER
elif isinstance(var_type, bool):
return cls.BOOLEAN
elif isinstance(var_type, FileObject) or (isinstance(var, dict) and var.get('__file')):
return cls.FILE
elif isinstance(var_type, dict):
return cls.OBJECT
elif isinstance(var_type, list):
if len(var) == 0:
return cls.ARRAY_STRING
else:
child_type = type(var[0])
if child_type == str:
return cls.ARRAY_STRING
elif child_type == int or child_type == float:
return cls.ARRAY_NUMBER
elif child_type == bool:
return cls.ARRAY_BOOLEAN
elif child_type == dict:
return cls.ARRAY_OBJECT
elif child_type == list:
return cls.NESTED_ARRAY
else:
raise TypeError(f"Unsupported array child type - {child_type}")
raise TypeError(f"Unsupported type - {var_type}")
def DEFAULT_VALUE(var_type: VariableType) -> Any:
"""Returns the default value for a given VariableType.
Args:
var_type: The variable type for which to get the default value.
Returns:
The default Python value corresponding to the VariableType.
Raises:
TypeError: If the VariableType is invalid.
"""
match var_type:
case VariableType.STRING:
return ""
case VariableType.NUMBER:
return 0
case VariableType.BOOLEAN:
return False
case VariableType.OBJECT:
return {}
case VariableType.FILE:
return None
case VariableType.ARRAY_STRING:
return []
case VariableType.ARRAY_NUMBER:
return []
case VariableType.ARRAY_BOOLEAN:
return []
case VariableType.ARRAY_OBJECT:
return []
case VariableType.ARRAY_FILE:
return []
case _:
raise TypeError(f"Invalid type - {type}")
class FileObject(BaseModel):
    """Pydantic model describing a file carried by a workflow variable."""

    # Category of the file, expressed with the project's FileType.
    type: FileType
    # Location of the file.
    url: str
    # Marker flag for file payloads.
    # NOTE(review): a class-body name starting with two underscores is
    # name-mangled by Python, and pydantic normally treats underscore-prefixed
    # names as private attributes rather than fields - confirm that this flag
    # is really validated and serialized under the '__file' key.
    __file: bool
class BaseVariable(ABC):
    """Abstract base class for all workflow variables.

    Subclasses must implement validation (``valid_value``) and
    serialization (``to_literal``).
    """

    type = None

    def __init__(self, value: Any):
        """Initialize a variable instance.

        Args:
            value: The initial value for the variable.

        Attributes:
            self.value: The validated value stored in the variable.
            self.literal: A string representation of the variable.
        """
        self.value = self.valid_value(value)
        self.literal = self.to_literal()

    @abstractmethod
    def valid_value(self, value) -> Any:
        """Validate or convert a value to the correct type for the variable.

        Args:
            value: The value to validate.

        Returns:
            The validated or converted value.

        Raises:
            TypeError: If the value is invalid.
        """
        pass

    @abstractmethod
    def to_literal(self) -> str:
        """Convert the variable value to a string literal representation.

        Returns:
            A string representing the variable's value.
        """
        pass

    def get_value(self) -> Any:
        """Return the current value of the variable."""
        return self.value

    def set(self, value):
        """Set the variable to a new value after validation.

        Args:
            value: The new value to assign to the variable.

        Raises:
            TypeError: Propagated from ``valid_value`` when the value is
                invalid; in that case the stored value is left unchanged.
        """
        self.value = self.valid_value(value)
        # Keep the cached literal in sync; it was previously computed only in
        # __init__ and went stale after every set().
        self.literal = self.to_literal()

View File

@@ -0,0 +1,174 @@
from typing import Any, TypeVar, Type, Generic
from deprecated import deprecated
from app.core.workflow.variable.base_variable import BaseVariable, VariableType, FileObject, FileType
T = TypeVar("T", bound=BaseVariable)
class StringVariable(BaseVariable):
    """Workflow variable that stores a ``str``."""

    type = 'str'

    def valid_value(self, value) -> str:
        """Return ``value`` unchanged when it is a string.

        Raises:
            TypeError: If ``value`` is not a ``str``.
        """
        if isinstance(value, str):
            return value
        raise TypeError(f"Value must be a string - {type(value)}:{value}")

    def to_literal(self) -> str:
        """Return the stored string itself as the literal form."""
        literal = self.value
        return literal
class NumberVariable(BaseVariable):
    """Workflow variable that stores an ``int`` or ``float``."""

    type = 'number'

    def valid_value(self, value) -> int | float:
        """Return ``value`` unchanged when it is a real number.

        ``bool`` is a subclass of ``int`` in Python, so it is rejected
        explicitly: booleans have their own variable type and must not be
        accepted silently as numbers.

        Raises:
            TypeError: If ``value`` is a ``bool`` or not an ``int``/``float``.
        """
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise TypeError(f"Value must be a number - {type(value)}:{value}")
        return value

    def to_literal(self) -> str:
        """Return the number rendered with ``str``."""
        return str(self.value)
class BooleanVariable(BaseVariable):
    """Workflow variable that stores a ``bool``."""

    type = 'boolean'

    def valid_value(self, value) -> bool:
        """Return ``value`` unchanged when it is a boolean.

        Raises:
            TypeError: If ``value`` is not a ``bool``.
        """
        if isinstance(value, bool):
            return value
        raise TypeError(f"Value must be a boolean - {type(value)}:{value}")

    def to_literal(self) -> str:
        """Return the lower-cased ``str`` form of the stored boolean."""
        text = str(self.value)
        return text.lower()
class DictVariable(BaseVariable):
    """Workflow variable that stores a ``dict`` (the ``object`` type)."""

    type = 'object'

    def valid_value(self, value) -> dict:
        """Return ``value`` unchanged when it is a dictionary.

        Raises:
            TypeError: If ``value`` is not a ``dict``.
        """
        if isinstance(value, dict):
            return value
        raise TypeError(f"Value must be a dict - {type(value)}:{value}")

    def to_literal(self) -> str:
        """Return the dictionary rendered with ``str``."""
        rendered = str(self.value)
        return rendered
class FileVariable(BaseVariable):
    """Workflow variable that stores a ``FileObject``."""

    type = 'file'

    def valid_value(self, value) -> FileObject:
        """Return a ``FileObject`` for ``value``.

        A ``FileObject`` is returned as is. A dict is accepted only when it
        carries a truthy ``"__file"`` entry; a new ``FileObject`` is then
        built from its ``type`` and ``url`` entries.

        Raises:
            TypeError: If ``value`` is neither a ``FileObject`` nor a dict
                marked with ``"__file"``.
        """
        if isinstance(value, FileObject):
            return value
        if isinstance(value, dict) and value.get("__file"):
            payload = {
                "type": str(value.get('type')),
                "url": value.get('url'),
                "__file": True
            }
            return FileObject(**payload)
        raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")

    def to_literal(self) -> str:
        """Return a ``[file](url)`` link, prefixed with ``!`` for image files."""
        prefix = "!" if self.value.type == FileType.IMAGE else ""
        return f'{prefix}[file]({self.value.url})'

    def get_value(self) -> Any:
        """Return the stored ``FileObject`` dumped with ``model_dump()``."""
        file_object = self.value
        return file_object.model_dump()
class ArrayObject(BaseVariable, Generic[T]):
    """Workflow variable that stores a list of same-typed child variables."""

    type = 'array'

    def __init__(self, child_type: Type[T], value: list[Any]):
        """Initialize the array.

        Args:
            child_type: ``BaseVariable`` subclass used to wrap every element.
            value: The raw list of element values.

        Raises:
            TypeError: If ``child_type`` is not a ``BaseVariable`` subclass,
                or if ``value`` fails validation.
        """
        if not issubclass(child_type, BaseVariable):
            raise TypeError("child_type must be a subclass of BaseVariable")
        # Must be set before super().__init__, which calls valid_value().
        self.child_type = child_type
        super().__init__(value)

    def valid_value(self, value: list[Any]) -> list[T]:
        """Wrap every element of ``value`` in ``child_type``.

        Raises:
            TypeError: If ``value`` is not a list or an element cannot be
                wrapped in ``child_type``.
        """
        if not isinstance(value, list):
            raise TypeError(f"Value must be a list - {type(value)}:{value}")
        final_value = []
        for v in value:
            try:
                final_value.append(self.child_type(v))
            except Exception as exc:
                # Catch Exception, not a bare except, so KeyboardInterrupt and
                # SystemExit propagate; chain the cause to keep the reason.
                raise TypeError(f"All elements must be of type {self.child_type.type}") from exc
        return final_value

    def to_literal(self) -> str:
        """Return the element literals joined by newlines."""
        return "\n".join([v.to_literal() for v in self.value])

    def get_value(self) -> Any:
        """Return the plain values of all elements as a list."""
        return [v.get_value() for v in self.value]
class NestedArrayObject(BaseVariable):
    """Workflow variable that stores a list of ``ArrayObject`` rows."""

    type = 'array_nest'

    def valid_value(self, value: list[T]) -> list[T]:
        """Return a new list holding the rows of ``value``.

        Raises:
            TypeError: If ``value`` is not a list or any row is not an
                ``ArrayObject``.
        """
        if not isinstance(value, list):
            raise TypeError(f"Value must be a list - {type(value)}:{value}")
        final_value = []
        for v in value:
            if not isinstance(v, ArrayObject):
                raise TypeError("All elements must be of type list")
            final_value.append(v)
        return final_value

    def to_literal(self) -> str:
        """Return the literals of all rows joined by newlines."""
        # Rows are ArrayObject instances, which are not iterable; delegate to
        # each row's own to_literal() instead of looping over the row.
        return "\n".join([row.to_literal() for row in self.value])

    def get_value(self) -> Any:
        """Return the plain values as a list of lists."""
        # Same delegation as to_literal(): ArrayObject.get_value() already
        # returns the list of element values of one row.
        return [row.get_value() for row in self.value]
@deprecated(
    reason="Using arbitrary-type values may cause unexpected errors; please switch to strongly-typed values.",
    category=RuntimeWarning
)
class AnyObject(BaseVariable):
    """Untyped workflow variable that accepts any value without validation."""

    type = 'any'

    def valid_value(self, value: Any) -> Any:
        """Return ``value`` as is; no type check is performed."""
        return value

    def to_literal(self) -> str:
        """Return the stored value rendered with ``str``."""
        rendered = str(self.value)
        return rendered
def make_array(child_type: Type[T], value: list[Any]) -> ArrayObject[T]:
    """Build an ``ArrayObject`` without repeating the element type.

    Args:
        child_type: ``BaseVariable`` subclass used for every element.
        value: The raw list of element values.

    Returns:
        The ``ArrayObject`` wrapping ``value``.
    """
    array = ArrayObject(child_type, value)
    return array
def create_variable_instance(var_type: VariableType, value: Any) -> T:
match var_type:
case VariableType.STRING:
return StringVariable(value)
case VariableType.NUMBER:
return NumberVariable(value)
case VariableType.BOOLEAN:
return BooleanVariable(value)
case VariableType.OBJECT:
return DictVariable(value)
case VariableType.ARRAY_STRING:
return make_array(StringVariable, value)
case VariableType.ARRAY_NUMBER:
return make_array(NumberVariable, value)
case VariableType.ARRAY_BOOLEAN:
return make_array(BooleanVariable, value)
case VariableType.ARRAY_OBJECT:
return make_array(DictVariable, value)
case VariableType.ARRAY_FILE:
return make_array(FileVariable, value)
case VariableType.ANY:
return AnyObject(value)
case _:
raise TypeError(f"Invalid type - {var_type}")

View File

@@ -11,10 +11,15 @@
import logging
import re
from typing import Any, TYPE_CHECKING
from asyncio import Lock
from collections import defaultdict
from copy import deepcopy
from typing import Any, Generic
if TYPE_CHECKING:
from app.core.workflow.nodes import WorkflowState
from pydantic import BaseModel
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable.variable_objects import T, create_variable_instance
logger = logging.getLogger(__name__)
@@ -23,11 +28,6 @@ class VariableSelector:
"""变量选择器
用于引用变量的路径表示。
Examples:
>>> selector = VariableSelector(["sys", "message"])
>>> selector = VariableSelector(["node_A", "output"])
>>> selector = VariableSelector.from_string("sys.message")
"""
def __init__(self, path: list[str]):
@@ -52,10 +52,6 @@ class VariableSelector:
Returns:
VariableSelector 实例
Examples:
>>> selector = VariableSelector.from_string("sys.message")
>>> selector = VariableSelector.from_string("llm_qa.output")
"""
path = selector_str.split(".")
return cls(path)
@@ -67,160 +63,212 @@ class VariableSelector:
return f"VariableSelector({self.path})"
class VariableStruct(BaseModel, Generic[T]):
"""A typed variable struct.
Represents a runtime variable with an associated logical type and
a concrete value object.
This class bridges the static type system (via generics) and the
runtime type system (via ``VariableType``).
Attributes:
type:
Logical variable type descriptor used for runtime validation,
serialization, and workflow type checking.
instance:
The concrete variable object. The actual Python type is
represented by the generic parameter ``T`` (e.g. StringVariable,
NumberVariable, ArrayObject[StringVariable]).
mut:
Whether the variable is mutable.
"""
type: VariableType
instance: T
mut: bool
model_config = {
"arbitrary_types_allowed": True
}
class VariablePool:
"""变量池
管理工作流执行过程中的所有变量。
变量命名空间:
- sys.*: 系统变量message, execution_id, workspace_id, user_id, conversation_id
- conv.*: 会话变量(跨多轮对话保持的变量)
- <node_id>.*: 节点输出
Examples:
>>> pool = VariablePool(state)
>>> pool.get(["sys", "message"])
"用户的问题"
>>> pool.get(["llm_qa", "output"])
"AI 的回答"
>>> pool.set(["conv", "user_name"], "张三")
"""Variable pool.
Manages all variables during workflow execution, including storage,
namespacing, and concurrency control.
Variable namespace conventions:
- ``sys.*``:
System variables (e.g. message, execution_id, workspace_id,
user_id, conversation_id).
- ``conv.*``:
Conversation-level variables that persist across multiple turns.
- ``<node_id>.*``:
Variables produced by workflow nodes.
"""
def __init__(self, state: "WorkflowState"):
"""初始化变量池
Args:
state: 工作流状态LangGraph State
"""
self.state = state
def __init__(self):
"""Initialize the variable pool.
Attributes:
self.locks:
A per-key lock table used for fine-grained concurrency control.
self.variables:
Storage for all variables managed by the pool.
"""
self.locks = defaultdict(Lock)
self.variables: dict[str, dict[str, VariableStruct[Any]]] = {}
@staticmethod
def transform_selector(selector):
pattern = r"\{\{\s*(.*?)\s*\}\}"
variable_literal = re.sub(pattern, r"\1", selector).strip()
selector = VariableSelector.from_string(variable_literal).path
if len(selector) != 2:
raise ValueError(f"Selector not valid - {selector}")
return selector
def _get_variable_struct(
self,
selector: str
) -> VariableStruct[T] | None:
"""Retrieve a variable struct from the variable pool.
def get(self, selector: list[str] | str, default: Any = None) -> Any:
"""获取变量值
Args:
selector: 变量选择器,可以是列表或字符串
default: 默认值(变量不存在时返回)
selector:
Variable selector, either:
- A string variable literal (e.g. "{{ sys.message }}")
Returns:
变量值
Examples:
>>> pool.get(["sys", "message"])
>>> pool.get("sys.message")
>>> pool.get(["llm_qa", "output"])
>>> pool.get("llm_qa.output")
Raises:
KeyError: 变量不存在且未提供默认值
The variable's struct if it exists; otherwise returns None.
"""
# 转换为 VariableSelector
if isinstance(selector, str):
pattern = r"\{\{\s*(.*?)\s*\}\}"
variable_literal = re.sub(pattern, r"\1", selector).strip()
selector = VariableSelector.from_string(variable_literal).path
if not selector or len(selector) < 1:
raise ValueError("变量选择器不能为空")
selector = self.transform_selector(selector)
namespace = selector[0]
variable_name = selector[1]
try:
# 系统变量
if namespace == "sys":
key = selector[1] if len(selector) > 1 else None
if not key:
return self.state.get("variables", {}).get("sys", {})
return self.state.get("variables", {}).get("sys", {}).get(key, default)
namespace_variables = self.variables.get(namespace)
if namespace_variables is None:
return None
# 会话变量
elif namespace == "conv":
key = selector[1] if len(selector) > 1 else None
if not key:
return self.state.get("variables", {}).get("conv", {})
return self.state.get("variables", {}).get("conv", {}).get(key, default)
var_instance = namespace_variables.get(variable_name)
if var_instance is None:
return None
return var_instance
# 节点输出(从 runtime_vars 读取)
else:
node_id = namespace
runtime_vars = self.state.get("runtime_vars", {})
def get_value(
self,
selector: str,
default: Any = None,
strict: bool = True,
) -> Any:
"""Retrieve a variable value from the variable pool.
if node_id not in runtime_vars:
if default is not None:
return default
raise KeyError(f"节点 '{node_id}' 的输出不存在")
Args:
selector:
Variable selector, either:
- A list of path components (e.g. ["sys", "message"])
- A string variable literal (e.g. "{{ sys.message }}")
default:
The value to return if the variable does not exist.
strict:
If True, raises KeyError when the variable does not exist.
node_var = runtime_vars[node_id]
Returns:
The variable's value if it exists; otherwise returns `default`.
# 如果只有节点 ID返回整个变量
if len(selector) == 1:
return node_var
Raises:
KeyError: If strict is True and the variable does not exist.
"""
variable_struct = self._get_variable_struct(selector)
if variable_struct is None:
if strict:
raise KeyError(f"{selector} not exist")
return default
# 获取特定字段
# 支持嵌套访问,如 node_id.field.subfield
result = node_var
for k in selector[1:]:
if isinstance(result, dict):
result = result.get(k)
if result is None:
if default is not None:
return default
raise KeyError(f"字段 '{'.'.join(selector)}' 不存在")
else:
if default is not None:
return default
raise KeyError(f"无法访问 '{'.'.join(selector)}'")
return variable_struct.instance.get_value()
return result
def get_literal(
self,
selector: str,
default: Any = None,
strict: bool = True,
) -> Any:
"""Retrieve a variable value from the variable pool.
except KeyError:
if default is not None:
return default
raise
Args:
selector:
Variable selector, either:
- A list of path components (e.g. ["sys", "message"])
- A string variable literal (e.g. "{{ sys.message }}")
default:
The value to return if the variable does not exist.
strict:
If True, raises KeyError when the variable does not exist.
def set(self, selector: list[str] | str, value: Any):
Returns:
The variable's value if it exists; otherwise returns `default`.
Raises:
KeyError: If strict is True and the variable does not exist.
"""
variable_struct = self._get_variable_struct(selector)
if variable_struct is None:
if strict:
raise KeyError(f"{selector} not exist")
return default
return variable_struct.instance.to_literal()
async def set(
self,
selector: str,
value: Any
):
"""设置变量值
Args:
selector: 变量选择器
value: 变量值
Examples:
>>> pool.set(["conv", "user_name"], "张三")
>>> pool.set("conv.user_name", "张三")
Note:
- 只能设置会话变量 (conv.*)
- 系统变量和节点输出是只读的
"""
# 转换为 VariableSelector
if isinstance(selector, str):
selector = VariableSelector.from_string(selector).path
variable_struct = self._get_variable_struct(selector)
if variable_struct is None:
raise KeyError(f"Variable {selector} is not defined")
if not variable_struct.mut:
raise KeyError(f"{selector} cannot be modified")
async with self.locks[selector]:
variable_struct.instance.set(value)
if not selector or len(selector) < 2:
raise ValueError("变量选择器必须包含命名空间和键名")
async def new(
self,
namespace: str,
key: str,
value: Any,
var_type: VariableType,
mut: bool
):
if self.has(f"{namespace}.{key}"):
try:
await self.set(f"{namespace}.{key}", value)
except KeyError:
pass
instance = create_variable_instance(var_type, value)
variable_struct = VariableStruct(type=var_type, instance=instance, mut=mut)
namespace_variable = self.variables.get(namespace)
if namespace_variable is None:
self.variables[namespace] = {
key: variable_struct
}
else:
self.variables[namespace][key] = variable_struct
namespace = selector[0]
if namespace != "conv" and namespace not in self.state["cycle_nodes"]:
raise ValueError("Only conversation or cycle variables can be assigned.")
key = selector[1]
# 确保 variables 结构存在
if "variables" not in self.state:
self.state["variables"] = {"sys": {}, "conv": {}}
if namespace == "conv":
if "conv" not in self.state["variables"]:
self.state["variables"]["conv"] = {}
# 设置值
self.state["variables"]["conv"][key] = value
elif namespace in self.state["cycle_nodes"]:
self.state["runtime_vars"][namespace][key] = value
logger.debug(f"设置变量: {'.'.join(selector)} = {value}")
def has(self, selector: list[str] | str) -> bool:
def has(self, selector: str) -> bool:
"""检查变量是否存在
Args:
@@ -228,18 +276,8 @@ class VariablePool:
Returns:
变量是否存在
Examples:
>>> pool.has(["sys", "message"])
True
>>> pool.has("llm_qa.output")
False
"""
try:
self.get(selector)
return True
except KeyError:
return False
return self._get_variable_struct(selector) is not None
def get_all_system_vars(self) -> dict[str, Any]:
"""获取所有系统变量
@@ -247,7 +285,8 @@ class VariablePool:
Returns:
系统变量字典
"""
return self.state.get("variables", {}).get("sys", {})
sys_namespace = self.variables.get("sys", {})
return {k: v.instance.get_value() for k, v in sys_namespace.items()}
def get_all_conversation_vars(self) -> dict[str, Any]:
"""获取所有会话变量
@@ -255,7 +294,8 @@ class VariablePool:
Returns:
会话变量字典
"""
return self.state.get("variables", {}).get("conv", {})
conv_namespace = self.variables.get("conv", {})
return {k: v.instance.get_value() for k, v in conv_namespace.items()}
def get_all_node_outputs(self) -> dict[str, Any]:
"""获取所有节点输出(运行时变量)
@@ -263,18 +303,37 @@ class VariablePool:
Returns:
节点输出字典,键为节点 ID
"""
return self.state.get("runtime_vars", {})
runtime_vars = {
namespace: {
k: v.instance.get_value()
for k, v in vars_dict.items()
}
for namespace, vars_dict in self.variables.items()
if namespace not in ("sys", "conv")
}
return runtime_vars
def get_node_output(self, node_id: str) -> dict[str, Any] | None:
def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None:
"""获取指定节点的输出(运行时变量)
Args:
node_id: 节点 ID
defalut: 默认值
strict: 是否严格模式
Returns:
节点输出或 None
"""
return self.state.get("runtime_vars", {}).get(node_id)
node_namespace = self.variables.get(node_id)
if node_namespace:
return {k: v.instance.get_value() for k, v in node_namespace.items()}
if strict:
raise KeyError(f"node {node_id} output not exist")
else:
return defalut
def copy(self, pool: 'VariablePool'):
    """Replace this pool's variables with a deep copy of another pool's.

    Args:
        pool: The pool whose ``variables`` mapping is cloned. It is left
            untouched, and the clone shares no nested objects with it.
    """
    cloned_variables = deepcopy(pool.variables)
    self.variables = cloned_variables
def to_dict(self) -> dict[str, Any]:
"""导出为字典

View File

@@ -50,13 +50,16 @@ async def lifespan(app: FastAPI):
logger.info("自动数据库升级已禁用 (DB_AUTO_UPGRADE=false)")
# 加载预定义模型
logger.info("开始加载预定义模型...")
try:
with get_db_context() as db:
result = load_models(db, silent=True)
logger.info(f"预定义模型加载完成: 成功{result['success']}个, 跳过{result['skipped']}个, 失败{result['failed']}")
except Exception as e:
logger.warning(f"加载预定义模型时出错: {str(e)}")
if settings.LOAD_MODEL:
logger.info("开始加载预定义模型...")
try:
with get_db_context() as db:
result = load_models(db, silent=True)
logger.info(f"预定义模型加载完成: 成功{result['success']}个, 跳过{result['skipped']}个, 失败{result['failed']}")
except Exception as e:
logger.warning(f"加载预定义模型时出错: {str(e)}")
else:
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
logger.info("应用程序启动完成")
yield
@@ -77,10 +80,14 @@ default_origins = [
]
allowed_origins = list({o for o in (default_origins + settings.CORS_ORIGINS) if o})
# 如果 CORS_ORIGINS 包含 "*",则允许所有来源
if "*" in settings.CORS_ORIGINS:
allowed_origins = ["*"]
app.add_middleware(
CORSMiddleware,
allow_origins=allowed_origins,
allow_credentials=True,
allow_credentials=True if "*" not in allowed_origins else False, # 允许所有来源时不能使用 credentials
allow_methods=["*"],
allow_headers=["*"],
)

View File

@@ -28,6 +28,7 @@ from .tool_model import (
ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus
)
from .memory_perceptual_model import MemoryPerceptualModel
from .skill_model import Skill
from .ontology_scene import OntologyScene
from .ontology_class import OntologyClass
from .ontology_scene import OntologyScene
@@ -84,5 +85,6 @@ __all__ = [
"ExecutionStatus",
"MemoryPerceptualModel",
"ModelBase",
"LoadBalanceStrategy"
"LoadBalanceStrategy",
"Skill"
]

View File

@@ -29,7 +29,8 @@ class AgentConfig(Base):
knowledge_retrieval = Column(JSON, nullable=True, comment="知识库检索配置")
memory = Column(JSON, nullable=True, comment="记忆配置")
variables = Column(JSON, default=list, nullable=True, comment="变量配置")
tools = Column(JSON, default=dict, nullable=True, comment="工具配置")
tools = Column(JSON, default=list, nullable=True, comment="工具配置")
skills = Column(JSON, default=dict, nullable=True, comment="技能配置")
# 多 Agent 相关字段
agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone")

View File

@@ -0,0 +1,37 @@
"""Skill 模型定义"""
import datetime
import uuid
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey
from sqlalchemy.dialects.postgresql import UUID, JSON
from app.db import Base
class Skill(Base):
    """Skill model; a skill can be linked to tools (built-in, MCP, custom)."""

    __tablename__ = "skills"

    # Identity and ownership.
    id = Column(UUID(as_uuid=True), default=uuid.uuid4, primary_key=True, index=True)
    name = Column(String, comment="技能名称", nullable=False)
    description = Column(Text, comment="技能描述")
    tenant_id = Column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id"),
        index=True,
        nullable=False,
        comment="租户ID",
    )

    # Tools linked to this skill.
    tools = Column(JSON, comment="关联的工具列表", default=list)

    # Skill configuration.
    config = Column(JSON, comment="技能配置", default=dict)

    # Skill-specific prompt.
    prompt = Column(Text, comment="技能专属提示词")

    # Status flags.
    is_active = Column(Boolean, comment="是否激活", default=True, nullable=False)
    is_public = Column(Boolean, comment="是否公开到市场", default=False, nullable=False)

    # Timestamps.
    created_at = Column(DateTime, comment="创建时间", default=datetime.datetime.now)
    updated_at = Column(
        DateTime,
        comment="更新时间",
        default=datetime.datetime.now,
        onupdate=datetime.datetime.now,
    )

    def __repr__(self):
        """Return a short debug representation containing the id and name."""
        return "<Skill(id={}, name={})>".format(self.id, self.name)

View File

@@ -235,6 +235,8 @@ class MemoryConfigRepository:
llm_id=params.llm_id,
embedding_id=params.embedding_id,
rerank_id=params.rerank_id,
reflection_model_id=params.reflection_model_id,
emotion_model_id=params.emotion_model_id,
)
db.add(db_config)
db.flush() # 获取自增ID但不提交事务

View File

@@ -583,7 +583,7 @@ class ModelApiKeyRepository:
db_api_key.usage_count = str(current_count + 1)
db_api_key.last_used_at = func.now()
db.commit()
db.flush()
db_logger.debug(f"API Key使用统计更新成功: api_key_id={api_key_id}")
return True

View File

@@ -207,4 +207,4 @@ async def save_dialog_and_statements_to_neo4j(
except Exception as e:
print(f"Neo4j integration error: {e}")
print("Continuing without database storage...")
return False
return False

View File

@@ -0,0 +1,111 @@
"""Skill Repository"""
from typing import List, Optional, Tuple, Any
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
import uuid
from app.models.skill_model import Skill
from app.schemas.skill_schema import SkillCreate, SkillUpdate
class SkillRepository:
    """Data-access layer for ``Skill`` rows.

    Write methods call ``db.flush()`` and never commit; committing the
    transaction is left to the caller.
    """

    @staticmethod
    def _visible_to(tenant_id: uuid.UUID):
        """Return the condition matching skills owned by the tenant or public."""
        return or_(Skill.tenant_id == tenant_id, Skill.is_public.is_(True))

    @staticmethod
    def _get_owned(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> Optional[Skill]:
        """Return the skill only when it belongs to ``tenant_id``, else None."""
        return db.query(Skill).filter(
            Skill.id == skill_id,
            Skill.tenant_id == tenant_id
        ).first()

    @staticmethod
    def _escape_like(term: str) -> str:
        """Escape LIKE wildcards so ``term`` is matched literally."""
        # The backslash is escaped first so it cannot combine with the
        # escapes added for ``%`` and ``_``.
        return term.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    @staticmethod
    def create(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill:
        """Create a skill owned by ``tenant_id`` and flush it to the session.

        Args:
            db: Database session.
            data: Validated creation payload; its fields become column values.
            tenant_id: Owning tenant.

        Returns:
            The newly added ``Skill`` (flushed, not committed).
        """
        skill = Skill(
            **data.model_dump(),
            tenant_id=tenant_id
        )
        db.add(skill)
        db.flush()
        return skill

    @staticmethod
    def get_by_id(db: Session, skill_id: uuid.UUID, tenant_id: Optional[uuid.UUID] = None) -> Optional[Skill]:
        """Fetch a skill by id.

        Args:
            db: Database session.
            skill_id: Primary key of the skill.
            tenant_id: When given, the skill is returned only if it is owned
                by this tenant or is public. When omitted, no visibility
                filter is applied.

        Returns:
            The matching ``Skill`` or None.
        """
        query = db.query(Skill).filter(Skill.id == skill_id)
        if tenant_id:
            query = query.filter(SkillRepository._visible_to(tenant_id))
        return query.first()

    @staticmethod
    def list_skills(
            db: Session,
            tenant_id: uuid.UUID,
            search: Optional[str] = None,
            is_active: Optional[bool] = None,
            is_public: Optional[bool] = None,
            page: int = 1,
            pagesize: int = 10
    ) -> tuple[list[Skill], int]:
        """List skills visible to a tenant, newest first, with pagination.

        Args:
            db: Database session.
            tenant_id: Tenant whose own skills (plus public ones) are listed.
            search: Optional case-insensitive substring matched against the
                skill name. ``%`` and ``_`` are matched literally.
            is_active: Optional filter on the active flag.
            is_public: Optional filter on the public flag.
            page: 1-based page number.
            pagesize: Number of rows per page.

        Returns:
            A ``(skills, total)`` tuple where ``total`` is the number of rows
            matching the filters before pagination.
        """
        filters = [SkillRepository._visible_to(tenant_id)]

        if search:
            # Escape user input so LIKE wildcards in the search term do not
            # widen the match.
            pattern = f"%{SkillRepository._escape_like(search)}%"
            filters.append(Skill.name.ilike(pattern, escape="\\"))

        if is_active is not None:
            filters.append(Skill.is_active == is_active)

        if is_public is not None:
            filters.append(Skill.is_public == is_public)

        query = db.query(Skill).filter(and_(*filters))
        total = query.count()
        skills = query.order_by(Skill.created_at.desc()).offset(
            (page - 1) * pagesize
        ).limit(pagesize).all()

        return skills, total

    @staticmethod
    def update(db: Session, skill_id: uuid.UUID, data: SkillUpdate, tenant_id: uuid.UUID) -> Optional[Skill]:
        """Apply a partial update to a skill owned by ``tenant_id``.

        Only fields explicitly set on ``data`` are written.

        Returns:
            The updated ``Skill`` (flushed, not committed), or None when the
            skill does not exist or belongs to another tenant.
        """
        skill = SkillRepository._get_owned(db, skill_id, tenant_id)
        if not skill:
            return None

        update_data = data.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(skill, key, value)

        db.flush()
        return skill

    @staticmethod
    def delete(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> bool:
        """Soft-delete a skill owned by ``tenant_id``.

        The row is kept; only ``is_active`` is set to False.

        Returns:
            True when the skill was found and deactivated, False otherwise.
        """
        skill = SkillRepository._get_owned(db, skill_id, tenant_id)
        if not skill:
            return False

        skill.is_active = False
        db.flush()
        return True

View File

@@ -1,14 +1,14 @@
import datetime
import uuid
from typing import Optional, Any, List, Dict, Union
from enum import Enum
from enum import Enum, StrEnum
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
# ---------- Multimodal File Support ----------
class FileType(str, Enum):
class FileType(StrEnum):
"""文件类型枚举"""
IMAGE = "image"
DOCUMENT = "document"
@@ -82,6 +82,12 @@ class ToolConfig(BaseModel):
tool_id: Optional[str] = Field(default=None, description="工具ID")
operation: Optional[str] = Field(default=None, description="工具特定配置")
class SkillConfig(BaseModel):
    """Skill configuration attached to an agent."""
    enabled: bool = Field(default=True, description="是否启用该技能")
    # default_factory builds a fresh empty list per instance; the previous
    # ``default=list`` stored the ``list`` type object itself as the default.
    skill_ids: Optional[list[str]] = Field(default_factory=list, description="技能ID列表")
    all_skills: Optional[bool] = Field(default=False, description="是否允许访问所有技能")
class ToolOldConfig(BaseModel):
"""工具配置"""
@@ -92,7 +98,7 @@ class ToolOldConfig(BaseModel):
class MemoryConfig(BaseModel):
"""记忆配置"""
enabled: bool = Field(default=True, description="是否启用对话历史记忆")
memory_content: Optional[str] = Field(default=None, description="选择记忆的内容类型")
memory_config_id: Optional[str] = Field(default=None, description="选择记忆的内容类型")
max_history: int = Field(default=10, ge=0, le=100, description="最大保留的历史对话轮数")
@@ -156,6 +162,9 @@ class AgentConfigCreate(BaseModel):
description="Agent 可用的工具列表"
)
# Skill configuration. The default is None rather than ``default=dict``:
# the latter stored the ``dict`` type object, which is truthy and is not a
# SkillConfig, so code calling ``config.skills.model_dump()`` failed.
skills: Optional[SkillConfig] = Field(default=None, description="关联的技能列表")
class AppCreate(BaseModel):
name: str
@@ -207,6 +216,9 @@ class AgentConfigUpdate(BaseModel):
# 工具配置
tools: Optional[List[ToolConfig]] = Field(default_factory=list, description="工具列表")
# Skill configuration. The default is None rather than ``default=dict``:
# the latter stored the ``dict`` type object, which is truthy and is not a
# SkillConfig, so code calling ``config.skills.model_dump()`` failed.
skills: Optional[SkillConfig] = Field(default=None, description="关联的技能列表")
# ---------- Output Schemas ----------
@@ -266,6 +278,8 @@ class AgentConfig(BaseModel):
# 工具配置
tools: Union[List[ToolConfig], Dict[str, ToolOldConfig]] = []
skills: Optional[SkillConfig] = {}
is_active: bool
created_at: datetime.datetime
updated_at: datetime.datetime

View File

@@ -236,6 +236,8 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
reflection_model_id: Optional[str] = Field(None, description="反思模型ID默认与llm_id一致")
emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID默认与llm_id一致")
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)

View File

@@ -21,6 +21,11 @@ class PromptOptMessage(BaseModel):
description="currently optimized prompt"
)
skill: bool = Field(
default=False,
description="Enable variable output"
)
class PromptSaveRequest(BaseModel):
session_id: UUID = Field(

View File

@@ -0,0 +1,64 @@
"""Skill Schema 定义"""
from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel, Field, field_serializer
import uuid
from datetime import datetime
class SkillBase(BaseModel):
    """Base schema holding the fields shared by skill payloads."""
    name: str = Field(description="技能名称")
    description: Optional[str] = Field(default=None, description="技能描述")
    tools: List[Dict[str, str]] = Field(
        default_factory=list,
        description="工具对象列表: [{\"tool_id\": \"xxx\", \"operation\": \"yyy\"}]",
    )
    config: Dict[str, Any] = Field(default_factory=dict, description="技能配置")
    prompt: Optional[str] = Field(default=None, description="技能专属提示词")
    is_active: bool = Field(default=True, description="是否激活")
    is_public: bool = Field(default=False, description="是否公开到市场")
class SkillCreate(SkillBase):
    """Payload for creating a skill; adds no fields to ``SkillBase``."""
class SkillUpdate(BaseModel):
    """Payload for updating a skill; every field is optional."""
    name: Optional[str] = Field(default=None)
    description: Optional[str] = Field(default=None)
    tools: Optional[List[Dict[str, str]]] = Field(default=None)
    config: Optional[Dict[str, Any]] = Field(default=None)
    prompt: Optional[str] = Field(default=None)
    is_active: Optional[bool] = Field(default=None)
    is_public: Optional[bool] = Field(default=None)
class Skill(BaseModel):
    """Skill response schema."""
    # Allow building the schema from ORM objects via attribute access.
    # Declared through ``model_config`` because the nested ``class Config``
    # form is deprecated in Pydantic v2.
    model_config = {"from_attributes": True}

    id: uuid.UUID
    tenant_id: uuid.UUID
    name: str
    description: Optional[str] = None
    tools: Union[List[Dict[str, Any]], List[Dict[str, str]]] = Field(default_factory=list, description="工具列表,可以是简单格式或包含工具详情")
    config: Dict[str, Any] = Field(default_factory=dict)
    prompt: Optional[str] = None
    is_active: bool = True
    is_public: bool = False
    created_at: datetime
    updated_at: datetime

    @field_serializer('created_at', 'updated_at')
    def serialize_datetime_to_timestamp(self, value: datetime) -> int:
        """Serialize a datetime as an integer timestamp in milliseconds."""
        return int(value.timestamp() * 1000)
class SkillQuery(BaseModel):
    """Query parameters for listing skills."""
    search: Optional[str] = Field(default=None)
    is_active: Optional[bool] = Field(default=None)
    is_public: Optional[bool] = Field(default=None)
    page: int = Field(default=1, ge=1)
    pagesize: int = Field(default=10, ge=1, le=100)

View File

@@ -9,7 +9,7 @@ from app.schemas.app_schema import (
VariableDefinition,
ToolConfig,
AgentConfigCreate,
AgentConfigUpdate, ToolOldConfig,
AgentConfigUpdate, ToolOldConfig, SkillConfig,
)
@@ -48,6 +48,9 @@ class AgentConfigConverter:
# 5. 工具配置
if hasattr(config, 'tools') and config.tools:
result["tools"] = [tool.model_dump() for tool in config.tools]
if hasattr(config, "skills") and config.skills:
result["skills"] = config.skills.model_dump()
return result
@@ -58,6 +61,7 @@ class AgentConfigConverter:
memory: Optional[Dict[str, Any]],
variables: Optional[list],
tools: Optional[Union[list, Dict[str, Any]]],
skills: Optional[dict]
) -> Dict[str, Any]:
"""
将数据库存储格式转换为 Pydantic 对象
@@ -68,6 +72,7 @@ class AgentConfigConverter:
memory: 记忆配置
variables: 变量配置
tools: 工具配置
skills: 技能列表
Returns:
包含 Pydantic 对象的字典
@@ -78,6 +83,7 @@ class AgentConfigConverter:
"memory": MemoryConfig(enabled=True),
"variables": [],
"tools": [],
"skills": SkillConfig(enabled=False, all_skills=False, skill_ids=[])
}
# 1. 解析模型参数配置
@@ -117,5 +123,10 @@ class AgentConfigConverter:
name: ToolOldConfig(**tool_data)
for name, tool_data in tools.items()
}
if skills:
result["skills"] = SkillConfig(**skills)
else:
result["skills"] = SkillConfig(enabled=False, all_skills=False, skill_ids=[])
return result

View File

@@ -26,6 +26,7 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
memory=agent_cfg.memory,
variables=agent_cfg.variables,
tools=agent_cfg.tools,
skills=agent_cfg.skills
)
# 将解析后的字段添加到对象上(用于序列化)
@@ -34,5 +35,6 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
agent_cfg.memory = parsed["memory"]
agent_cfg.variables = parsed["variables"]
agent_cfg.tools = parsed["tools"]
agent_cfg.skills = parsed["skills"]
return agent_cfg

View File

@@ -8,6 +8,7 @@ from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List
from fastapi import Depends
from sqlalchemy.orm import Session
from app.core.agent.agent_middleware import AgentMiddleware
from app.core.agent.langchain_agent import LangChainAgent
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
@@ -63,7 +64,7 @@ class AppChatService:
# 获取模型配置ID
model_config_id = config.default_model_config_id
api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
# 处理系统提示词(支持变量替换)
system_prompt = config.system_prompt
if variables:
@@ -79,21 +80,55 @@ class AppChatService:
# 获取工具服务
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
# 从配置中获取启用的工具
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
for tool_config in config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, workspace_id))
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 加载技能关联的工具
if hasattr(config, 'skills') and config.skills:
skills = config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
knowledge_retrieval = config.knowledge_retrieval
@@ -113,22 +148,6 @@ class AppChatService:
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 获取模型参数
model_parameters = config.model_parameters
@@ -192,6 +211,8 @@ class AppChatService:
}
)
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
elapsed_time = time.time() - start_time
return {
@@ -230,7 +251,7 @@ class AppChatService:
# 获取模型配置ID
model_config_id = config.default_model_config_id
api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
# 处理系统提示词(支持变量替换)
system_prompt = config.system_prompt
if variables:
@@ -246,20 +267,54 @@ class AppChatService:
# 获取工具服务
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
for tool_config in config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, workspace_id))
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 加载技能关联的工具
if hasattr(config, 'skills') and config.skills:
skills = config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
knowledge_retrieval = config.knowledge_retrieval
@@ -279,22 +334,6 @@ class AppChatService:
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 获取模型参数
model_parameters = config.model_parameters
@@ -374,6 +413,8 @@ class AppChatService:
}
)
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
# 发送结束事件
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
@@ -618,6 +659,7 @@ class AppChatService:
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
public=False
) -> AsyncGenerator[dict, None]:
"""聊天(流式)"""
@@ -634,7 +676,8 @@ class AppChatService:
payload=payload,
config=config,
workspace_id=workspace_id,
release_id=release_id
release_id=release_id,
public=public
):
yield event

View File

@@ -313,6 +313,7 @@ class AppService:
memory=storage_data.get("memory"),
variables=storage_data.get("variables", []),
tools=storage_data.get("tools", []),
skills=storage_data.get("skills", {}),
is_active=True,
created_at=now,
updated_at=now,
@@ -916,6 +917,7 @@ class AppService:
agent_cfg.variables = storage_data.get("variables", [])
# if data.tools is not None:
agent_cfg.tools = storage_data.get("tools", [])
agent_cfg.skills = storage_data.get("skills", {})
agent_cfg.updated_at = now
@@ -1003,11 +1005,12 @@ class AppService:
},
memory={
"enabled": True,
"memory_content": None,
"memory_config_id": None,
"max_history": 10
},
variables=[],
tools=[],
skills=[],
is_active=True,
created_at=now,
updated_at=now,
@@ -1403,6 +1406,7 @@ class AppService:
"memory": agent_cfg.memory,
"variables": agent_cfg.variables or [],
"tools": agent_cfg.tools or [],
"skills": agent_cfg.skills or {},
}
# config = AgentConfigConverter.from_storage_format(agent_cfg)
default_model_config_id = agent_cfg.default_model_config_id

View File

@@ -1,15 +1,13 @@
"""应用统计服务"""
from datetime import datetime, timedelta
from typing import Dict, Any, List
from typing import Dict, Any
import uuid
from sqlalchemy import func, and_, cast, Date
from sqlalchemy.orm import Session
from app.models.conversation_model import Conversation, Message
from app.models.end_user_model import EndUser
from app.models.api_key_model import ApiKey, ApiKeyLog
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.models.api_key_model import ApiKey, ApiKeyLog, ApiKeyType
class AppStatisticsService:
@@ -146,7 +144,6 @@ class AppStatisticsService:
end_dt: datetime
) -> Dict[str, Any]:
"""获取Token消耗统计从Message的meta_data中提取"""
from sqlalchemy import text
# 查询所有相关消息的token使用情况
# meta_data中可能包含: {"usage": {"total_tokens": 100}} 或 {"tokens": 100}
@@ -187,7 +184,80 @@ class AppStatisticsService:
daily_tokens[date_str] = 0
daily_tokens[date_str] += int(tokens)
daily_data = [{"date": date, "tokens": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
total = sum(row["tokens"] for row in daily_data)
daily_data = [{"date": date, "count": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
total = sum(row["count"] for row in daily_data)
return {"daily": daily_data, "total": total}
def get_workspace_api_statistics(
        self,
        workspace_id: uuid.UUID,
        start_date: int,
        end_date: int
) -> list[Any]:
    """Return per-day API call counts for a workspace.

    Args:
        workspace_id: Workspace ID.
        start_date: Start of the range as a timestamp in milliseconds.
        end_date: End of the range as a timestamp in milliseconds.

    Returns:
        A list sorted by date. Each item is a dict with the keys ``date``,
        ``total_calls``, ``app_calls`` and ``service_calls``. Only days with
        at least one logged call appear.
    """
    # Convert the millisecond timestamps to naive local-time datetimes.
    # NOTE(review): presumably ApiKeyLog.created_at is stored in the same
    # local time - confirm.
    start_time = datetime.fromtimestamp(start_date / 1000)
    end_time = datetime.fromtimestamp(end_date / 1000)

    # Application key types: agent, multi_agent (cluster), workflow.
    app_types = [ApiKeyType.AGENT, ApiKeyType.CLUSTER, ApiKeyType.WORKFLOW]

    app_calls = self._daily_api_call_counts(
        workspace_id, ApiKey.type.in_(app_types), start_time, end_time
    )
    service_calls = self._daily_api_call_counts(
        workspace_id, ApiKey.type == ApiKeyType.SERVICE, start_time, end_time
    )

    # Merge both series over the union of dates that have any calls.
    daily_data = []
    for date in sorted(app_calls.keys() | service_calls.keys()):
        app_count = app_calls.get(date, 0)
        service_count = service_calls.get(date, 0)
        daily_data.append({
            "date": date,
            "total_calls": app_count + service_count,
            "app_calls": app_count,
            "service_calls": service_count
        })

    return daily_data

def _daily_api_call_counts(
        self,
        workspace_id: uuid.UUID,
        type_filter: Any,
        start_time: datetime,
        end_time: datetime
) -> dict[str, int]:
    """Count API key log rows per day for keys matching ``type_filter``.

    Returns:
        A mapping from the date string to the number of calls on that day.
    """
    day = cast(ApiKeyLog.created_at, Date)
    rows = self.db.query(
        day.label('date'),
        func.count(ApiKeyLog.id).label('count')
    ).join(
        ApiKey, ApiKeyLog.api_key_id == ApiKey.id
    ).filter(
        and_(
            ApiKey.workspace_id == workspace_id,
            type_filter,
            ApiKeyLog.created_at >= start_time,
            ApiKeyLog.created_at <= end_time
        )
    ).group_by(day).all()
    # Unpack rows positionally: attribute access ``row.count`` resolves to
    # the row's sequence ``count`` method instead of the labelled column.
    return {str(call_date): call_count for call_date, call_count in rows}

View File

@@ -24,6 +24,7 @@ from app.core.error_codes import BizCode
from app.core.models import RedBearLLM
from app.core.models.base import RedBearModelConfig
from app.models import ModelType
from app.services.model_service import ModelApiKeyService
logger = get_business_logger()
@@ -357,6 +358,8 @@ class CollaborativeOrchestrator:
"usage": response.get("usage", {"total_tokens": 0}),
"is_final_answer": True
}
ModelApiKeyService.record_api_key_usage(self.db, agent_config.get("api_key_id"))
# 检查是否有工具调用handoff
tool_calls = response.get("tool_calls", [])
@@ -427,7 +430,7 @@ class CollaborativeOrchestrator:
)
# 获取 API Key
api_key_config = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
api_key_config = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
if not api_key_config:
raise BusinessException(
f"Agent 模型没有可用的 API Key: {agent_id}",
@@ -442,7 +445,8 @@ class CollaborativeOrchestrator:
"provider": api_key_config.provider,
"api_key": api_key_config.api_key,
"api_base": api_key_config.api_base,
"model_parameters": config_data.get("model_parameters", {})
"model_parameters": config_data.get("model_parameters", {}),
"api_key_id": api_key_config.id
}
except ValueError:

View File

@@ -10,6 +10,11 @@ import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
@@ -24,12 +29,11 @@ from app.services import task_service
from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger
from app.services.model_service import ModelApiKeyService
from app.services.tool_service import ToolService
from app.services.multimodal_service import MultimodalService
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.agent.agent_middleware import AgentMiddleware
logger = get_business_logger()
class KnowledgeRetrievalInput(BaseModel):
@@ -59,7 +63,8 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
长期记忆工具
"""
# search_switch = memory_config.get("search_switch", "2")
config_id= memory_config.get("memory_content") or memory_config.get("memory_config",None)
# 兼容新旧字段名:优先使用 memory_config_id回退到 memory_content
config_id = memory_config.get("memory_config_id") or memory_config.get("memory_content", None)
logger.info(f"创建长期记忆工具,配置: end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}")
@tool(args_schema=LongTermMemoryInput)
def long_term_memory(question: str) -> str:
@@ -310,6 +315,7 @@ class DraftRunService:
tools = []
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
@@ -320,9 +326,7 @@ class DraftRunService:
print(f"tool_config:{tool_config}")
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, str(workspace_id)))
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
@@ -345,6 +349,25 @@ class DraftRunService:
}
)
# 加载技能关联的工具
if hasattr(agent_config, 'skills') and agent_config.skills:
skills = agent_config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
if agent_config.knowledge_retrieval:
kb_config = agent_config.knowledge_retrieval
@@ -433,7 +456,8 @@ class DraftRunService:
)
memory_config_= agent_config.memory
config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None)
# 兼容新旧字段名:优先使用 memory_config_id回退到 memory_content
config_id = memory_config_.get("memory_config_id") or memory_config_.get("memory_content", None)
# 8. 调用 Agent支持多模态
result = await agent.chat(
@@ -450,6 +474,8 @@ class DraftRunService:
elapsed_time = time.time() - start_time
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
# 9. 保存会话消息
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
await self._save_conversation_message(
@@ -558,6 +584,7 @@ class DraftRunService:
tools = []
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
@@ -567,9 +594,7 @@ class DraftRunService:
# print(f"tool_config:{tool_config}")
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, str(workspace_id)))
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
@@ -592,6 +617,25 @@ class DraftRunService:
}
)
# 加载技能关联的工具
if hasattr(agent_config, 'skills') and agent_config.skills:
skills = agent_config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
if agent_config.knowledge_retrieval:
@@ -628,7 +672,6 @@ class DraftRunService:
}
)
# 4. 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_config["model_name"],
@@ -677,7 +720,8 @@ class DraftRunService:
})
memory_config_ = agent_config.memory
config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None)
# 兼容新旧字段名:优先使用 memory_config_id回退到 memory_content
config_id = memory_config_.get("memory_config_id") or memory_config_.get("memory_content", None)
# 9. 流式调用 Agent支持多模态
full_content = ""
@@ -704,6 +748,8 @@ class DraftRunService:
elapsed_time = time.time() - start_time
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
if sub_agent:
yield self._format_sse_event("sub_usage", {
"total_tokens": total_tokens
@@ -770,7 +816,7 @@ class DraftRunService:
Raises:
BusinessException: 当没有可用的 API Key 时
"""
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
# stmt = (
# select(ModelApiKey).join(
# ModelConfig, ModelApiKey.model_configs
@@ -784,7 +830,8 @@ class DraftRunService:
# )
#
# api_key = self.db.scalars(stmt).first()
api_key = api_keys[0] if api_keys else None
# api_key = api_keys[0] if api_keys else None
api_key = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
if not api_key:
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
@@ -793,7 +840,8 @@ class DraftRunService:
"model_name": api_key.model_name,
"provider": api_key.provider,
"api_key": api_key.api_key,
"api_base": api_key.api_base
"api_base": api_key.api_base,
"api_key_id": api_key.id
}
async def _ensure_conversation(
@@ -1051,7 +1099,7 @@ class DraftRunService:
except Exception as e:
# 对于多 Agent 应用,没有直接的 AgentConfig 是正常的
logger.debug("获取配置快照失败(可能是多 Agent 应用)", extra={"error": str(e)})
logger.debug("获取配置快照失败(可能是多 Agent 应用)", exc_info=True, extra={"error": str(e)})
return {}
def _replace_variables(

View File

@@ -537,7 +537,7 @@ def convert_multi_agent_config_to_handoffs(
# 获取该 Agent 的模型配置
if release.default_model_config_id:
model_api_key = ModelApiKeyService.get_a_api_key(db, release.default_model_config_id)
model_api_key = ModelApiKeyService.get_available_api_key(db, release.default_model_config_id)
if model_api_key:
model_config = RedBearModelConfig(
model_name=model_api_key.model_name,
@@ -551,6 +551,7 @@ def convert_multi_agent_config_to_handoffs(
}
)
logger.debug(f"Agent {agent_name} 使用模型: {model_api_key.model_name}")
ModelApiKeyService.record_api_key_usage(db, model_api_key.id)
else:
logger.warning(f"Agent {agent_name} 模型配置无效: {release.default_model_config_id}")
else:

View File

@@ -382,6 +382,7 @@ class LLMRouter:
from app.core.models import RedBearLLM
from app.core.models.base import RedBearModelConfig
from app.models import ModelApiKey, ModelType
from app.services.model_service import ModelApiKeyService
# 获取 API Key 配置(通过关联关系)
# api_key_config = self.db.query(ModelApiKey).join(
@@ -389,8 +390,9 @@ class LLMRouter:
# ).filter(ModelConfig.id == self.routing_model_config.id,
# ModelApiKey.is_active == True
# ).first()
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, self.routing_model_config.id)
api_key_config = api_keys[0] if api_keys else None
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, self.routing_model_config.id)
# api_key_config = api_keys[0] if api_keys else None
api_key_config = ModelApiKeyService.get_available_api_key(self.db, self.routing_model_config.id)
if not api_key_config:
raise Exception("路由模型没有可用的 API Key")
@@ -424,7 +426,6 @@ class LLMRouter:
# 调用模型
response = await llm.ainvoke(prompt)
from app.services.model_service import ModelApiKeyService
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
# 提取响应内容

View File

@@ -349,7 +349,7 @@ class MasterAgentRouter:
from app.models import ModelApiKey, ModelType
# 获取 API Key 配置
api_key_config = ModelApiKeyService.get_a_api_key(self.db, self.master_model_config.id)
api_key_config = ModelApiKeyService.get_available_api_key(self.db, self.master_model_config.id)
if not api_key_config:
raise Exception("Master Agent 模型没有可用的 API Key")
@@ -400,6 +400,7 @@ class MasterAgentRouter:
# 调用模型
response = await llm.ainvoke(prompt)
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
# 提取响应内容
if hasattr(response, 'content'):

View File

@@ -1194,7 +1194,9 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
workspace_id=app.workspace_id
)
memory_config_id = str(memory_config.config_id) if memory_config else None
memory_obj = config.get('memory', {})
# 兼容新旧字段名:优先使用 memory_config_id回退到 memory_content
memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
result = {
"end_user_id": str(end_user_id),
@@ -1284,7 +1286,8 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
if release:
config = release.config or {}
memory_obj = config.get('memory', {})
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
# 兼容新旧字段名:优先使用 memory_config_id回退到 memory_content
memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
if memory_config_id:
# 判断是否为UUID格式
if len(str(memory_config_id))>=5:
@@ -1330,7 +1333,8 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 从 config 中提取 memory_config_id
config = release.config or {}
memory_obj = config.get('memory', {})
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
# 兼容新旧字段名:优先使用 memory_config_id回退到 memory_content
memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
# 获取配置名称使用字符串形式的ID进行查找兼容新旧格式
memory_config_name = config_id_to_name.get(str(memory_config_id)) if memory_config_id else None

View File

@@ -53,7 +53,10 @@ def get_workspace_end_users(
workspace_id: uuid.UUID,
current_user: User
) -> List[EndUser]:
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)"""
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)
返回结果按 updated_at 从新到旧排序NULL 值排在最后)
"""
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
@@ -68,9 +71,14 @@ def get_workspace_end_users(
app_ids = [app.id for app in apps_orm]
# 批量查询所有 end_users一次查询而非循环查询
# 按 updated_at 降序排序NULL 值排在最后id 作为次级排序键保证确定性
from app.models.end_user_model import EndUser as EndUserModel
from sqlalchemy import desc, nullslast
end_users_orm = db.query(EndUserModel).filter(
EndUserModel.app_id.in_(app_ids)
).order_by(
nullslast(desc(EndUserModel.updated_at)),
desc(EndUserModel.id)
).all()
# 转换为 Pydantic 模型(只在需要时转换)

View File

@@ -108,13 +108,14 @@ class WorkspaceAppService:
app_info["releases"].append(release_info)
def _extract_memory_content(self, config: Any) -> str:
"""Extract memory_comtent from config"""
"""Extract memory_config_id from config (兼容新旧字段名)"""
if not config or not isinstance(config, dict):
return None
memory_obj = config.get('memory')
if memory_obj and isinstance(memory_obj, dict):
return memory_obj.get('memory_content')
# 兼容新旧字段名:优先使用 memory_config_id回退到 memory_content
return memory_obj.get('memory_config_id') or memory_obj.get('memory_content')
return None

View File

@@ -129,6 +129,12 @@ class DataConfigService: # 数据配置服务类PostgreSQL
if not params.rerank_id:
params.rerank_id = configs.get('rerank')
# reflection_model_id 和 emotion_model_id 默认与 llm_id 一致
if not params.reflection_model_id:
params.reflection_model_id = params.llm_id
if not params.emotion_model_id:
params.emotion_model_id = params.llm_id
config = MemoryConfigRepository.create(self.db, params)
self.db.commit()
return {"affected": 1, "config_id": config.config_id}
@@ -203,6 +209,7 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"end_user_id": config.end_user_id,
"config_id_old": config_id_old,
"apply_id": config.apply_id,
"scene_id": config.scene_id,
"llm_id": config.llm_id,
"embedding_id": config.embedding_id,
"rerank_id": config.rerank_id,

View File

@@ -6,7 +6,7 @@ import math
import time
import asyncio
from app.models.models_model import ModelConfig, ModelApiKey, ModelType
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, LoadBalanceStrategy
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository, ModelBaseRepository
from app.schemas import model_schema
from app.schemas.model_schema import (
@@ -633,19 +633,31 @@ class ModelApiKeyService:
@staticmethod
def get_available_api_key(db: Session, model_config_id: uuid.UUID) -> Optional[ModelApiKey]:
    """Pick an active API key for a model config according to its load-balance strategy.

    Args:
        db: Database session used to load the model config.
        model_config_id: Id of the model config whose keys are considered.

    Returns:
        ``None`` when the model config does not exist or has no active key.
        With ``LoadBalanceStrategy.ROUND_ROBIN``, the active key with the
        lowest usage count; ties go first to a key that has never been used
        (``last_used_at`` is empty) and then to the least recently used one.
        For any other strategy, the first active key.
    """
    model_config = ModelConfigRepository.get_by_id(db, model_config_id)
    if not model_config:
        return None

    active_keys = [key for key in model_config.api_keys if key.is_active]
    if not active_keys:
        return None

    if model_config.load_balance_strategy != LoadBalanceStrategy.ROUND_ROBIN:
        return active_keys[0]

    def _round_robin_rank(key):
        # An empty usage_count is counted as 0.
        # The "has been used" flag sorts never-used keys first and ensures
        # last_used_at is only ever compared between two real datetimes.
        # That avoids a naive datetime.min sentinel, which cannot be compared
        # with timezone-aware values.
        return (
            int(key.usage_count or "0"),
            key.last_used_at is not None,
            key.last_used_at,
        )

    return min(active_keys, key=_round_robin_rank)
@staticmethod
def record_api_key_usage(db: Session, api_key_id: uuid.UUID) -> bool:
def record_api_key_usage(db: Session, api_key_id: uuid.UUID | None) -> bool:
"""记录API Key使用"""
success = ModelApiKeyRepository.update_usage(db, api_key_id)
if success:
db.commit()
return success
if api_key_id:
success = ModelApiKeyRepository.update_usage(db, api_key_id)
if success:
db.commit()
return success
return False
@staticmethod
def get_a_api_key(db: Session, model_config_id: uuid.UUID) -> ModelApiKey:

View File

@@ -14,6 +14,7 @@ from app.services.conversation_state_manager import ConversationStateManager
from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
from app.services.model_service import ModelApiKeyService
logger = get_business_logger()
@@ -2569,8 +2570,9 @@ class MultiAgentOrchestrator:
# ModelConfig.id == default_model_config_id,
# ModelApiKey.is_active.is_(True)
# ).first()
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id)
api_key_config = api_keys[0] if api_keys else None
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id)
# api_key_config = api_keys[0] if api_keys else None
api_key_config = ModelApiKeyService.get_available_api_key(self.db, default_model_config_id)
if not api_key_config:
logger.warning("Master Agent 没有可用的 API Key使用简单整合")
@@ -2601,6 +2603,8 @@ class MultiAgentOrchestrator:
# 调用模型进行整合
response = await llm.ainvoke(merge_prompt)
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
# 提取响应内容
if hasattr(response, 'content'):
merged_response = response.content
@@ -2730,8 +2734,9 @@ class MultiAgentOrchestrator:
# ModelConfig.id == default_model_config_id,
# ModelApiKey.is_active.is_(True)
# ).first()
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id)
api_key_config = api_keys[0] if api_keys else None
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id)
# api_key_config = api_keys[0] if api_keys else None
api_key_config = ModelApiKeyService.get_available_api_key(self.db, default_model_config_id)
if not api_key_config:
logger.warning("Master Agent 没有可用的 API Key使用简单整合")
@@ -2790,6 +2795,8 @@ class MultiAgentOrchestrator:
logger.debug(f"收到流式 chunk #{chunk_count}: {content[:30]}...")
yield self._format_sse_event("message", {"content": content})
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
logger.info(f"Master Agent 流式整合完成,共 {chunk_count} 个 chunks")
except AttributeError as e:

View File

@@ -23,7 +23,7 @@ logger = get_business_logger()
class ImageFormatStrategy(Protocol):
    """Strategy interface for formatting an image for a model provider."""

    async def format_image(self, url: str) -> Dict[str, Any]:
        """Convert an image URL into the provider-specific format."""
        ...
@@ -31,7 +31,7 @@ class ImageFormatStrategy(Protocol):
class DashScopeImageStrategy:
"""通义千问图片格式策略"""
async def format_image(self, url: str) -> Dict[str, Any]:
"""通义千问格式: {"type": "image", "image": "url"}"""
return {
@@ -42,7 +42,7 @@ class DashScopeImageStrategy:
class BedrockImageStrategy:
"""Bedrock/Anthropic 图片格式策略"""
async def format_image(self, url: str) -> Dict[str, Any]:
"""
Bedrock/Anthropic 格式: base64 编码
@@ -51,17 +51,17 @@ class BedrockImageStrategy:
import httpx
import base64
from mimetypes import guess_type
logger.info(f"下载并编码图片: {url}")
# 下载图片
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response.raise_for_status()
# 获取图片数据
image_data = response.content
# 确定 media type
content_type = response.headers.get("content-type")
if content_type and content_type.startswith("image/"):
@@ -69,12 +69,12 @@ class BedrockImageStrategy:
else:
guessed_type, _ = guess_type(url)
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
# 转换为 base64
base64_data = base64.b64encode(image_data).decode("utf-8")
logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
return {
"type": "image",
"source": {
@@ -87,7 +87,7 @@ class BedrockImageStrategy:
class OpenAIImageStrategy:
"""OpenAI 图片格式策略"""
async def format_image(self, url: str) -> Dict[str, Any]:
"""OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}"""
return {
@@ -109,7 +109,7 @@ PROVIDER_STRATEGIES = {
class MultimodalService:
"""多模态文件处理服务"""
def __init__(self, db: Session, provider: str = "dashscope"):
"""
初始化多模态服务
@@ -120,10 +120,10 @@ class MultimodalService:
"""
self.db = db
self.provider = provider.lower()
async def process_files(
self,
files: Optional[List[FileInput]]
self,
files: Optional[List[FileInput]]
) -> List[Dict[str, Any]]:
"""
处理文件列表,返回 LLM 可用的格式
@@ -136,7 +136,7 @@ class MultimodalService:
"""
if not files:
return []
result = []
for idx, file in enumerate(files):
try:
@@ -168,10 +168,10 @@ class MultimodalService:
"type": "text",
"text": f"[文件处理失败: {str(e)}]"
})
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件provider={self.provider}")
return result
async def _process_image(self, file: FileInput) -> Dict[str, Any]:
"""
处理图片文件
@@ -184,14 +184,10 @@ class MultimodalService:
- Anthropic/Bedrock: {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
- 通义千问: {"type": "image", "image": "url"}
"""
if file.transfer_method == TransferMethod.REMOTE_URL:
url = file.url
else:
# 本地文件,获取访问 URL
url = await self._get_file_url(file.upload_file_id)
url = await self.get_file_url(file)
logger.debug(f"处理图片: {url}, provider={self.provider}")
# 根据 provider 返回不同格式
if self.provider in ["bedrock", "anthropic"]:
# Anthropic/Bedrock 只支持 base64 格式,需要下载并转换
@@ -223,7 +219,7 @@ class MultimodalService:
"type": "image",
"image": url
}
async def _download_and_encode_image(self, url: str) -> tuple[str, str]:
"""
下载图片并转换为 base64
@@ -237,15 +233,15 @@ class MultimodalService:
import httpx
import base64
from mimetypes import guess_type
# 下载图片
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response.raise_for_status()
# 获取图片数据
image_data = response.content
# 确定 media type
content_type = response.headers.get("content-type")
if content_type and content_type.startswith("image/"):
@@ -254,14 +250,14 @@ class MultimodalService:
# 从 URL 推断
guessed_type, _ = guess_type(url)
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
# 转换为 base64
base64_data = base64.b64encode(image_data).decode("utf-8")
logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
return base64_data, media_type
async def _process_document(self, file: FileInput) -> Dict[str, Any]:
"""
处理文档文件PDF、Word 等)
@@ -284,14 +280,14 @@ class MultimodalService:
generic_file = self.db.query(GenericFile).filter(
GenericFile.id == file.upload_file_id
).first()
file_name = generic_file.file_name if generic_file else "unknown"
return {
"type": "text",
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
}
async def _process_audio(self, file: FileInput) -> Dict[str, Any]:
"""
处理音频文件
@@ -307,7 +303,7 @@ class MultimodalService:
"type": "text",
"text": "[音频文件,暂不支持处理]"
}
async def _process_video(self, file: FileInput) -> Dict[str, Any]:
"""
处理视频文件
@@ -323,13 +319,13 @@ class MultimodalService:
"type": "text",
"text": "[视频文件,暂不支持处理]"
}
async def _get_file_url(self, file_id: uuid.UUID) -> str:
async def get_file_url(self, file: FileInput) -> str:
"""
获取文件的访问 URL
Args:
file_id: 文件ID
file: File Input Struct
Returns:
str: 文件访问 URL
@@ -337,26 +333,31 @@ class MultimodalService:
Raises:
BusinessException: 文件不存在
"""
generic_file = self.db.query(GenericFile).filter(
GenericFile.id == file_id,
GenericFile.status == "active"
).first()
if not generic_file:
raise BusinessException(
f"文件不存在或已删除: {file_id}",
BizCode.NOT_FOUND
)
# 如果有 access_url直接返回
if generic_file.access_url:
return generic_file.access_url
# 否则,根据 storage_path 生成 URL
# TODO: 根据实际存储方式生成 URL本地存储、OSS 等)
# 这里暂时返回一个占位 URL
return f"/api/files/{file_id}/download"
if file.transfer_method == TransferMethod.REMOTE_URL:
return file.url
else:
# 本地文件,获取访问 URL
file_id = file.upload_file_id
generic_file = self.db.query(GenericFile).filter(
GenericFile.id == file.upload_file_id,
GenericFile.status == "active"
).first()
if not generic_file:
raise BusinessException(
f"文件不存在或已删除: {file.upload_file_id}",
BizCode.NOT_FOUND
)
# 如果有 access_url直接返回
if generic_file.access_url:
return generic_file.access_url
# 否则,根据 storage_path 生成 URL
# TODO: 根据实际存储方式生成 URL本地存储、OSS 等)
# 这里暂时返回一个占位 URL
return f"/api/files/{file_id}/download"
async def _extract_document_text(self, file_id: uuid.UUID) -> str:
"""
提取文档文本内容
@@ -371,20 +372,20 @@ class MultimodalService:
GenericFile.id == file_id,
GenericFile.status == "active"
).first()
if not generic_file:
raise BusinessException(
f"文件不存在或已删除: {file_id}",
BizCode.NOT_FOUND
)
# TODO: 根据文件类型提取文本
# - PDF: 使用 PyPDF2 或 pdfplumber
# - Word: 使用 python-docx
# - TXT/MD: 直接读取
file_ext = generic_file.file_ext.lower()
if file_ext in ['.txt', '.md', '.markdown']:
return await self._read_text_file(generic_file.storage_path)
elif file_ext == '.pdf':
@@ -393,7 +394,7 @@ class MultimodalService:
return await self._extract_word_text(generic_file.storage_path)
else:
return f"[不支持的文档格式: {file_ext}]"
async def _read_text_file(self, storage_path: str) -> str:
"""读取纯文本文件"""
try:
@@ -402,7 +403,7 @@ class MultimodalService:
except Exception as e:
logger.error(f"读取文本文件失败: {e}")
return f"[文件读取失败: {str(e)}]"
async def _extract_pdf_text(self, storage_path: str) -> str:
"""提取 PDF 文本"""
try:
@@ -412,7 +413,7 @@ class MultimodalService:
except Exception as e:
logger.error(f"提取 PDF 文本失败: {e}")
return f"[PDF 提取失败: {str(e)}]"
async def _extract_word_text(self, storage_path: str) -> str:
"""提取 Word 文档文本"""
try:

View File

@@ -1,4 +1,3 @@
{% raw %}
Role: AI Prompt Optimization Expert
Profile
@@ -12,11 +11,11 @@ Skills
Core Optimization Skills
Requirement Analysis: Accurately understand the relationship between the user's current needs and the original prompt.
Structural Reconstruction: Transform vague requirements into clear, block-structured instructions.
Variable Handling: Identify and standardize dynamic variables in prompts.
{% if skill != true %}Variable Handling: Identify and standardize dynamic variables in prompts.{% endif %}
Conflict Resolution: Prioritize current requirements when historical requirements conflict with current needs.
Auxiliary Generation Skills
Completeness Check: Ensure all necessary elements (input, output, constraints, etc.) are explicitly defined.
{% if skill != true %}Completeness Check: Ensure all necessary elements (input, output, constraints, etc.) are explicitly defined.{% endif %}
Language Consistency: Maintain consistency between label language and user input language.
Executability Verification: Ensure optimized prompts can be directly used in AI tools.
Format Standardization: Strictly adhere to specified output format requirements.
@@ -25,30 +24,30 @@ Rules
Basic Principles
Priority Rule: When historical requirements conflict with current requirements, unconditionally prioritize current requirements.
Completeness Rule: If the original prompt is empty, generate a complete prompt based on the current requirements.
Structure Rule: Use a clear block structure, and the contents of each block are roles, tasks, requirements, inputs, outputs, and constraints
{% if skill != true %}Structure Rule: Use a clear block structure, and the contents of each block are roles, tasks, requirements, inputs, outputs, and constraints{% endif %}
Language Rule: All label languages must fully match the user input language.
Behavior Guidelines
Precision Guideline: All instructions must be precise and directly executable, avoiding ambiguity.
Readability Guideline: Ensure optimized prompts have good readability and logical flow.
Variable Handling Guideline: Use lowercase English variable names wrapped in {{}} when variables are needed.
Constraint Handling Guideline: Do not mention variable-related limitations under the [Constraints] label.
{% if skill != true %}{% raw %}Variable Handling Guideline: Use lowercase English variable names wrapped in {{}} when variables are needed.
Constraint Handling Guideline: Do not mention variable-related limitations under the [Constraints] label.{% endraw %}{% endif %}
Constraints
Output Constraint: Must output in JSON format including the fields "prompt" and "desc".
Content Constraint: Must not include any explanations, analyses, or additional comments.
Language Constraint: Must use clear and concise language.
Completeness Constraint: Must fully define all missing elements (input details, output format, constraints, etc.).
{% if skill != true %}Completeness Constraint: Must fully define all missing elements (input details, output format, constraints, etc.).{% endif %}
Workflows
Goal: Optimize or generate AI prompts that can be directly used according to user requirements.
Step 1: Receive the user's current requirement description {% raw %}{{user_require}}{% endraw %} and the original prompt {% raw %}{{original_prompt}}{% endraw %}.
Step 2: Analyze requirements, identify conflicts, and prioritize current requirements.
Step 3: Optimize or generate the prompt in a block-structured format, ensuring all elements are fully defined.
{% if skill != true %}Step 3: Optimize or generate the prompt in a block-structured format, ensuring all elements are fully defined.
Step 4: Generate a JSON output containing the optimized prompt and its description.
{% else %}Step 3: Generate a JSON output containing the optimized prompt and its description.{% endif %}
Expected Outcome: Obtain a clear, directly executable AI prompt accompanied by an optimization description.
Initialization
As an AI Prompt Optimization Expert, you must follow the above Rules and execute tasks according to the Workflows.
{% endraw %}
As an AI Prompt Optimization Expert, you must follow the above Rules and execute tasks according to the Workflows.

View File

@@ -23,6 +23,7 @@ from app.repositories.prompt_optimizer_repository import (
PromptReleaseRepository
)
from app.schemas.prompt_optimizer_schema import OptimizePromptResult
from app.services.model_service import ModelApiKeyService
logger = get_business_logger()
@@ -128,7 +129,8 @@ class PromptOptimizerService:
session_id: uuid.UUID,
user_id: uuid.UUID,
current_prompt: str,
user_require: str
user_require: str,
skill: bool = False
) -> AsyncGenerator[dict[str, str | Any], Any]:
"""
Optimize a user-provided prompt using a configured prompt optimizer LLM.
@@ -157,6 +159,7 @@ class PromptOptimizerService:
user_id (uuid.UUID): Identifier of the user associated with the session.
current_prompt (str): Original prompt to optimize.
user_require (str): User's requirements or instructions for optimization.
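skill (bool): Whether to render the skill variant of the system prompt; passed to the system template as ``skill``. Defaults to False.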
Returns:
OptimizePromptResult: An object containing:
@@ -174,8 +177,9 @@ class PromptOptimizerService:
logger.info(f"Prompt optimization started, user_id={user_id}, session_id={session_id}")
# Create LLM instance
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config.id)
api_config: ModelApiKey = api_keys[0] if api_keys else None
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config.id)
# api_config: ModelApiKey = api_keys[0] if api_keys else None
api_config: ModelApiKey = ModelApiKeyService.get_available_api_key(self.db, model_config.id)
llm = RedBearLLM(RedBearModelConfig(
model_name=api_config.model_name,
provider=api_config.provider,
@@ -186,7 +190,7 @@ class PromptOptimizerService:
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
with open(os.path.join(prompt_path, 'prompt_optimizer_system.jinja2'), 'r', encoding='utf-8') as f:
opt_system_prompt = f.read()
rendered_system_message = Template(opt_system_prompt).render()
rendered_system_message = Template(opt_system_prompt).render(skill=skill)
with open(os.path.join(prompt_path, 'prompt_optimizer_user.jinja2'), 'r', encoding='utf-8') as f:
opt_user_prompt = f.read()
@@ -250,6 +254,7 @@ class PromptOptimizerService:
optim_result = json_repair.repair_json(buffer, return_objects=True)
# prompt = optim_result.get("prompt")
desc = optim_result.get("desc")
ModelApiKeyService.record_api_key_usage(self.db, api_config.id)
self.create_message(
tenant_id=tenant_id,
session_id=session_id,

View File

@@ -10,6 +10,7 @@ from app.services.memory_konwledges_server import write_rag
from app.models import ReleaseShare, AppRelease, Conversation
from app.services.conversation_service import ConversationService
from app.services.draft_run_service import create_web_search_tool
from app.services.model_service import ModelApiKeyService
from app.services.release_share_service import ReleaseShareService
from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.core.error_codes import BizCode
@@ -178,8 +179,9 @@ class SharedChatService:
# .limit(1)
# )
# api_key_obj = self.db.scalars(stmt).first()
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
api_key_obj = api_keys[0] if api_keys else None
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
# api_key_obj = api_keys[0] if api_keys else None
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
if not api_key_obj:
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
@@ -309,6 +311,8 @@ class SharedChatService:
elapsed_time = time.time() - start_time
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
return {
"conversation_id": conversation.id,
@@ -349,7 +353,8 @@ class SharedChatService:
if variables is None:
variables = {}
memory_config = {"enabled": memory, "memory_content": "17", "max_history": 10}
# 兼容新旧字段名:使用 memory_config_id
memory_config = {"enabled": memory, "memory_config_id": "17", "max_history": 10}
try:
# 获取发布版本和配置
@@ -383,8 +388,9 @@ class SharedChatService:
# .limit(1)
# )
# api_key_obj = self.db.scalars(stmt).first()
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
api_key_obj = api_keys[0] if api_keys else None
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
# api_key_obj = api_keys[0] if api_keys else None
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
if not api_key_obj:
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
@@ -513,7 +519,8 @@ class SharedChatService:
}
)
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
# 发送结束事件
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"

View File

@@ -0,0 +1,133 @@
"""Skill Service"""
import logging
import uuid
from typing import List

from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session

from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.models.skill_model import Skill
from app.repositories.skill_repository import SkillRepository
from app.schemas.skill_schema import SkillCreate, SkillUpdate
from app.services.tool_service import ToolService
logger = logging.getLogger(__name__)


class SkillService:
    """Business-logic layer for skills: tenant-scoped CRUD plus tool loading."""

    @staticmethod
    def create_skill(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill:
        """Create a skill, rejecting a name already used inside the tenant.

        Args:
            db: Active SQLAlchemy session.
            data: Creation payload; ``data.name`` is used for the duplicate check.
            tenant_id: Tenant that will own the skill.

        Returns:
            The persisted and refreshed ``Skill``.

        Raises:
            BusinessException: ``BizCode.DUPLICATE_NAME`` when the tenant
                already has a skill with the same name.
            SQLAlchemyError: Re-raised after the session has been rolled back.
        """
        # Duplicate-name check within the tenant
        existing = db.query(Skill).filter(
            Skill.tenant_id == tenant_id,
            Skill.name == data.name
        ).first()
        if existing:
            raise BusinessException(f"技能名称'{data.name}'已存在", BizCode.DUPLICATE_NAME)

        try:
            skill = SkillRepository.create(db, data, tenant_id)
            db.commit()
            db.refresh(skill)
            return skill
        except SQLAlchemyError:
            # Same policy as update_skill / delete_skill: never leave the
            # session in a failed transaction.
            db.rollback()
            raise

    @staticmethod
    def get_skill(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> Skill:
        """Fetch one skill and replace its tool configs with enriched entries.

        Each entry of ``skill.tools`` that carries a ``tool_id`` and for which
        ``ToolService.get_tool_info`` returns something is rewritten as
        ``{"tool_id", "operation", "tool_info"}``; other entries are dropped.

        Raises:
            BusinessException: ``BizCode.NOT_FOUND`` when the skill is missing.
            SQLAlchemyError: Re-raised after rollback.
        """
        try:
            skill = SkillRepository.get_by_id(db, skill_id, tenant_id)
            if not skill:
                raise BusinessException(f"技能{skill_id}不存在", BizCode.NOT_FOUND)

            # Fill in tool details
            tool_service = ToolService(db)
            enriched_tools = []
            # `or []` tolerates a skill whose tools attribute is None.
            for tool_config in skill.tools or []:
                tool_id = tool_config.get("tool_id")
                if not tool_id:
                    continue
                tool_info = tool_service.get_tool_info(tool_id, tenant_id)
                if not tool_info:
                    continue
                enriched_tools.append({
                    "tool_id": tool_id,
                    "operation": tool_config.get("operation"),
                    "tool_info": tool_info
                })

            # NOTE(review): this assigns onto the loaded model instance. If
            # `tools` is a mapped column, a later commit on this session would
            # presumably persist the enriched payload - confirm with the caller.
            skill.tools = enriched_tools
            return skill
        except (BusinessException, SQLAlchemyError):
            db.rollback()
            raise

    @staticmethod
    def list_skills(
            db: Session,
            tenant_id: uuid.UUID,
            search: str | None = None,
            is_active: bool | None = None,
            is_public: bool | None = None,
            page: int = 1,
            pagesize: int = 10
    ) -> tuple[list[Skill], int]:
        """List skills; all arguments are forwarded to ``SkillRepository.list_skills``.

        Returns:
            Whatever the repository returns, annotated as ``(skills, total)``.
        """
        return SkillRepository.list_skills(
            db, tenant_id, search, is_active, is_public, page, pagesize
        )

    @staticmethod
    def update_skill(db: Session, skill_id: uuid.UUID, data: SkillUpdate, tenant_id: uuid.UUID) -> Skill:
        """Update a skill and commit.

        Raises:
            BusinessException: ``BizCode.NOT_FOUND`` when the repository
                returns nothing (missing skill or no permission).
            SQLAlchemyError: Re-raised after rollback.
        """
        try:
            skill = SkillRepository.update(db, skill_id, data, tenant_id)
            if not skill:
                raise BusinessException(f"技能{skill_id}不存在或无权限", BizCode.NOT_FOUND)
            db.commit()
            db.refresh(skill)
            return skill
        except (BusinessException, SQLAlchemyError):
            db.rollback()
            raise

    @staticmethod
    def delete_skill(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> bool:
        """Delete a skill and commit.

        Returns:
            ``True`` on success.

        Raises:
            BusinessException: ``BizCode.NOT_FOUND`` when the repository
                reports that nothing was deleted.
            SQLAlchemyError: Re-raised after rollback.
        """
        try:
            success = SkillRepository.delete(db, skill_id, tenant_id)
            if not success:
                raise BusinessException("技能不存在或无权限", BizCode.NOT_FOUND)
            db.commit()
            return True
        except (BusinessException, SQLAlchemyError):
            db.rollback()
            raise

    @staticmethod
    def load_skill_tools(db: Session, skill_ids: List[str], tenant_id: uuid.UUID) -> tuple[List, dict[str, str]]:
        """Load the LangChain tools attached to the given active skills.

        Loading is best-effort: a failure while processing one skill is logged
        and that skill is skipped, the remaining skills are still processed.

        Args:
            db: Active SQLAlchemy session.
            skill_ids: Skill identifiers; strings are parsed as UUIDs and
                ``uuid.UUID`` instances are accepted as-is.
            tenant_id: Tenant used for the skill and tool lookups.

        Returns:
            ``(tools, tool_to_skill_map)`` - the tool list and a mapping
            ``{tool_name: skill_id}``. When two tools share a name the later
            one wins in the mapping.
        """
        tools = []
        tool_to_skill_map = {}  # {tool_name: skill_id}
        tool_service = ToolService(db)

        for skill_id in skill_ids:
            try:
                # uuid.UUID() raises on an existing UUID instance, which used
                # to make such ids silently skipped by the handler below.
                skill_uuid = skill_id if isinstance(skill_id, uuid.UUID) else uuid.UUID(str(skill_id))
                skill = SkillRepository.get_by_id(db, skill_uuid, tenant_id)
                if not (skill and skill.is_active):
                    continue

                # Load the tools attached to this skill
                for tool_config in skill.tools or []:
                    tool_id = tool_config.get("tool_id")
                    if not tool_id:
                        # Same guard as get_skill: nothing to look up.
                        continue
                    tool = tool_service._get_tool_instance(tool_id, tenant_id)
                    if not tool:
                        continue
                    langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None))
                    tools.append(langchain_tool)
                    # Record which skill contributed this tool
                    tool_name = getattr(langchain_tool, 'name', str(id(langchain_tool)))
                    tool_to_skill_map[tool_name] = skill_id
            except Exception as e:
                # Deliberate best-effort: log with traceback instead of print().
                logger.error(f"加载技能 {skill_id} 的工具时出错: {e}", exc_info=True)
                continue

        return tools, tool_to_skill_map

View File

@@ -4,9 +4,8 @@
import datetime
import logging
import uuid
from typing import Any, Annotated, AsyncGenerator, Optional
from typing import Any, Annotated, Optional
from deprecated import deprecated
from fastapi import Depends
from sqlalchemy.orm import Session
@@ -23,6 +22,7 @@ from app.repositories.workflow_repository import (
from app.schemas import DraftRunRequest
from app.services.conversation_service import ConversationService
from app.services.multi_agent_service import convert_uuids_to_str
from app.services.multimodal_service import MultimodalService
logger = logging.getLogger(__name__)
@@ -36,6 +36,7 @@ class WorkflowService:
self.execution_repo = WorkflowExecutionRepository(db)
self.node_execution_repo = WorkflowNodeExecutionRepository(db)
self.conversation_service = ConversationService(db)
self.multimodal_service = MultimodalService(db)
# ==================== 配置管理 ====================
@@ -445,24 +446,22 @@ class WorkflowService:
code=BizCode.CONFIG_MISSING,
message=f"工作流配置不存在: app_id={app_id}"
)
input_data = {"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id}
files = []
if payload.files:
for file in payload.files:
files.append(
{
"type": file.type,
"url": await self.multimodal_service.get_file_url(file),
"__file": True
}
)
# 转换 user_id 为 UUID
triggered_by_uuid = None
if payload.user_id:
try:
triggered_by_uuid = uuid.UUID(payload.user_id)
except (ValueError, AttributeError):
logger.warning(f"无效的 user_id 格式: {payload.user_id}")
input_data = {"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id, "files": files}
# 转换 conversation_id 为 UUID
conversation_id_uuid = None
if payload.conversation_id:
try:
conversation_id_uuid = uuid.UUID(payload.conversation_id)
except (ValueError, AttributeError):
logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}")
conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None
# 2. 创建执行记录
execution = self.create_execution(
@@ -544,10 +543,10 @@ class WorkflowService:
return {
"execution_id": execution.execution_id,
"status": result.get("status"),
"variables": result.get("variables"),
"messages": result.get("messages"),
# "variables": result.get("variables"),
# "messages": result.get("messages"),
"output": result.get("output"), # 最终输出(字符串)
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
# "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
"conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID
"error_message": result.get("error"),
"elapsed_time": result.get("elapsed_time"),
@@ -566,6 +565,41 @@ class WorkflowService:
message=f"工作流执行失败: {str(e)}"
)
@staticmethod
def _map_public_event(event: dict) -> dict | None:
event_type = event.get("event")
payload = event.get("data")
match event_type:
case "workflow_start":
return {
"event": "start",
"data": {
"conversation_id": payload.get("conversation_id"),
}
}
case "workflow_end":
return {
"event": "end",
"data": {
"elapsed_time": payload.get("elapsed_time"),
"message_length": len(payload.get("output", ""))
}
}
case "node_start" | "node_end" | "node_error":
return None
case _:
return event
def _emit(self, public: bool, internal_event: dict):
    """Pick the form of an event that should be sent to the client.

    Args:
        public: When true the event is passed through
            ``self._map_public_event``; otherwise it is returned untouched.
        internal_event: Event produced by the workflow executor.

    Returns:
        The event to emit, or whatever ``_map_public_event`` returns for it
        (the caller skips falsy results).
    """
    if not public:
        return internal_event
    return self._map_public_event(internal_event)
async def run_stream(
self,
app_id: uuid.UUID,
@@ -573,6 +607,7 @@ class WorkflowService:
config: WorkflowConfig,
workspace_id: uuid.UUID,
release_id: Optional[uuid.UUID] = None,
public: bool = False
):
"""运行工作流(流式)
@@ -582,6 +617,7 @@ class WorkflowService:
app_id: 应用 ID
payload: 请求对象(包含 message, variables, conversation_id 等)
config: 存储类型(可选)
public: 是否发布
Yields:
SSE 格式的流式事件
@@ -597,24 +633,23 @@ class WorkflowService:
code=BizCode.CONFIG_MISSING,
message=f"工作流配置不存在: app_id={app_id}"
)
input_data = {"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id}
# 转换 user_id 为 UUID
triggered_by_uuid = None
if payload.user_id:
try:
triggered_by_uuid = uuid.UUID(payload.user_id)
except (ValueError, AttributeError):
logger.warning(f"无效的 user_id 格式: {payload.user_id}")
files = []
if payload.files:
for file in payload.files:
files.append(
{
"type": file.type,
"url": await self.multimodal_service.get_file_url(file),
"__file": True
}
)
input_data = {"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id, "files": files}
# 转换 conversation_id 为 UUID
conversation_id_uuid = None
if payload.conversation_id:
try:
conversation_id_uuid = uuid.UUID(payload.conversation_id)
except (ValueError, AttributeError):
logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}")
conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None
# 2. 创建执行记录
execution = self.create_execution(
@@ -661,7 +696,7 @@ class WorkflowService:
input_data=input_data,
execution_id=execution.execution_id,
workspace_id=str(workspace_id),
user_id=payload.user_id
user_id=payload.user_id,
):
if event.get("event") == "workflow_end":
@@ -692,7 +727,9 @@ class WorkflowService:
)
else:
logger.error(f"unexpect workflow run status, status: {status}")
yield event
event = self._emit(public, event)
if event:
yield event
except Exception as e:
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
@@ -710,134 +747,6 @@ class WorkflowService:
}
}
@deprecated(reason="This method is deprecated. "
                   "Please use WorkflowService.run / run_stream instead.")
async def run_workflow(
        self,
        app_id: uuid.UUID,
        input_data: dict[str, Any],
        triggered_by: uuid.UUID,
        conversation_id: uuid.UUID | None = None,
        stream: bool = False
) -> AsyncGenerator | dict:
    """Run a workflow (deprecated entry point).

    Args:
        app_id: Application ID.
        input_data: Input data (contains the message and variables).
        triggered_by: ID of the user who triggered the run.
        conversation_id: Conversation ID (optional).
        stream: Whether to return a streaming generator.

    Returns:
        The execution result dict (non-streaming) or the async generator
        returned by ``_run_workflow_stream`` (streaming).

    Raises:
        BusinessException: When the configuration or the app does not exist,
            or when execution fails.
    """
    # 1. Load the workflow configuration
    config = self.get_workflow_config(app_id)
    if not config:
        raise BusinessException(
            code=BizCode.NOT_FOUND,
            message=f"工作流配置不存在: app_id={app_id}"
        )

    # 2. Create the execution record
    execution = self.create_execution(
        workflow_config_id=config.id,
        app_id=app_id,
        trigger_type="manual",
        triggered_by=triggered_by,
        conversation_id=conversation_id,
        input_data=input_data
    )

    # 3. Build the workflow configuration dict
    workflow_config_dict = {
        "nodes": config.nodes,
        "edges": config.edges,
        "variables": config.variables,
        "execution_config": config.execution_config
    }

    # 4. Resolve the workspace ID through the app
    from app.models import App

    app = self.db.query(App).filter(
        App.id == app_id,
        App.is_active.is_(True)
    ).first()
    if not app:
        raise BusinessException(
            code=BizCode.NOT_FOUND,
            message=f"应用不存在: app_id={app_id}"
        )

    # 5. Execute the workflow
    from app.core.workflow.executor import execute_workflow

    try:
        # Mark the execution as running
        self.update_execution_status(execution.execution_id, "running")

        if stream:
            # Streaming: hand back the generator, the caller iterates it
            return self._run_workflow_stream(
                workflow_config_dict,
                input_data,
                execution.execution_id,
                str(app.workspace_id),
                str(triggered_by)
            )

        # Non-streaming execution
        result = await execute_workflow(
            workflow_config=workflow_config_dict,
            input_data=input_data,
            execution_id=execution.execution_id,
            workspace_id=str(app.workspace_id),
            user_id=str(triggered_by)
        )

        # Persist the outcome
        if result.get("status") == "completed":
            # A result without a "data" payload must not turn a completed
            # run into a failure, so every level is read None-safely.
            token_usage = (result.get("data") or {}).get("token_usage") or {}
            self.update_execution_status(
                execution.execution_id,
                "completed",
                output_data=result.get("node_outputs", {}),
                token_usage=token_usage.get("total_tokens", None)
            )
        else:
            self.update_execution_status(
                execution.execution_id,
                "failed",
                error_message=result.get("error")
            )

        # Enriched response structure
        return {
            "execution_id": execution.execution_id,
            "status": result.get("status"),
            "output": result.get("output"),  # final output (string)
            "output_data": result.get("node_outputs", {}),  # outputs of every node (detailed)
            "error_message": result.get("error"),
            "elapsed_time": result.get("elapsed_time"),
            "token_usage": result.get("token_usage")
        }

    except Exception as e:
        logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
        self.update_execution_status(
            execution.execution_id,
            "failed",
            error_message=str(e)
        )
        raise BusinessException(
            code=BizCode.INTERNAL_ERROR,
            message=f"工作流执行失败: {str(e)}"
        ) from e
def _clean_event_for_json(self, event: dict[str, Any]) -> dict[str, Any]:
"""清理事件数据,移除不可序列化的对象
@@ -869,72 +778,6 @@ class WorkflowService:
return clean_value(event)
@deprecated(reason="This method is deprecated. Please use WorkflowService.run_stream instead.")
async def _run_workflow_stream(
        self,
        workflow_config: dict[str, Any],
        input_data: dict[str, Any],
        execution_id: str,
        workspace_id: str,
        user_id: str):
    """Run a workflow in streaming mode (deprecated internal helper).

    Every event coming from ``execute_workflow_stream`` is forwarded as-is.
    On a ``workflow_end`` event the execution record is updated according to
    the status found in the event payload.

    Args:
        workflow_config: Workflow configuration.
        input_data: Input data.
        execution_id: Execution ID.
        workspace_id: Workspace ID.
        user_id: User ID.

    Yields:
        Stream events shaped like ``{"event": "<type>", "data": {...}}``.
        If the executor raises, the execution is marked failed and a final
        ``error`` event is yielded instead of propagating the exception.
    """
    from app.core.workflow.executor import execute_workflow_stream

    try:
        async for event in execute_workflow_stream(
                workflow_config=workflow_config,
                input_data=input_data,
                execution_id=execution_id,
                workspace_id=workspace_id,
                user_id=user_id
        ):
            # Events are forwarded directly (the executor already emits the
            # right format); only the terminal event touches the DB record.
            if event.get("event") == "workflow_end":
                # Read the payload None-safely at every level so a bare
                # end event cannot abort the stream.
                data = event.get("data") or {}
                token_usage = data.get("token_usage") or {}
                status = data.get("status")
                if status == "completed":
                    self.update_execution_status(
                        execution_id,
                        "completed",
                        output_data=event.get("data"),
                        token_usage=token_usage.get("total_tokens", None)
                    )
                elif status == "failed":
                    self.update_execution_status(
                        execution_id,
                        "failed",
                        output_data=event.get("data")
                    )
                else:
                    logger.error(f"unexpect workflow run status, status: {status}")
            yield event

    except Exception as e:
        logger.error(f"工作流流式执行失败: execution_id={execution_id}, error={e}", exc_info=True)
        self.update_execution_status(
            execution_id,
            "failed",
            error_message=str(e)
        )
        yield {
            "event": "error",
            "data": {
                "execution_id": execution_id,
                "error": str(e)
            }
        }
# ==================== 依赖注入函数 ====================

View File

@@ -1069,6 +1069,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务")
except Exception as e:
db.rollback() # Rollback failed transaction to allow next query
api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
all_reflection_results.append({
"workspace_id": str(workspace_id),
@@ -1207,3 +1208,290 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
return result
finally:
loop.close()
# =============================================================================
# Long-term Memory Storage Tasks (Batched Write Strategies)
# =============================================================================
@celery_app.task(name="app.core.memory.agent.long_term_storage.window", bind=True)
def long_term_storage_window_task(
        self,
        end_user_id: str,
        langchain_messages: List[Dict[str, Any]],
        config_id: str,
        scope: int = 6
) -> Dict[str, Any]:
    """Celery task for the window-based long-term memory storage strategy.

    The incoming messages are formatted and saved through ``write_store``,
    the memory configuration is loaded, and ``window_dialogue`` is invoked
    with the window size. Per the original description, messages accumulate
    in a Redis buffer until ``scope`` is reached and are then written to
    Neo4j as a batch; that logic lives in the called helpers.

    Args:
        end_user_id: End user identifier.
        langchain_messages: Messages like ``[{"role": "user/assistant", "content": "..."}]``.
        config_id: Memory configuration ID.
        scope: Window size (number of messages before a write is triggered).

    Returns:
        A dict with the task status and metadata. Failures are reported as
        ``{"status": "FAILURE", ...}`` rather than raised.
    """
    from app.core.logging_config import get_logger
    logger = get_logger(__name__)

    logger.info(f"[LONG_TERM_WINDOW] Starting task - end_user_id={end_user_id}, scope={scope}")
    start_time = time.time()

    async def _run() -> Dict[str, Any]:
        from app.core.memory.agent.langgraph_graph.routing.write_router import window_dialogue
        from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
        from app.core.memory.agent.utils.redis_tool import write_store
        from app.services.memory_config_service import MemoryConfigService

        db = next(get_db())
        try:
            # Buffer the formatted messages first
            formatted_messages = await chat_data_format(langchain_messages)
            write_store.save_session_write(end_user_id, formatted_messages)

            # Load the memory configuration
            memory_config = MemoryConfigService(db).load_memory_config(
                config_id=config_id,
                service_name="LongTermStorageTask"
            )

            # Run the window-based dialogue storage
            await window_dialogue(end_user_id, langchain_messages, memory_config, scope)

            return {"status": "SUCCESS", "strategy": "window", "scope": scope}
        finally:
            db.close()

    def _with_metadata(payload: Dict[str, Any], elapsed: float) -> Dict[str, Any]:
        # Fields shared by the success and the failure response.
        return {
            **payload,
            "end_user_id": end_user_id,
            "config_id": config_id,
            "elapsed_time": elapsed,
            "task_id": self.request.id
        }

    # nest_asyncio is optional; it is applied before the loop is looked up.
    try:
        import nest_asyncio
        nest_asyncio.apply()
    except ImportError:
        pass

    # Reuse the current event loop when it is usable, otherwise install a new one.
    loop = None
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        pass
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    try:
        result = loop.run_until_complete(_run())
        elapsed_time = time.time() - start_time
        logger.info(f"[LONG_TERM_WINDOW] Task completed - elapsed_time={elapsed_time:.2f}s")
        return _with_metadata(result, elapsed_time)
    except Exception as e:
        elapsed_time = time.time() - start_time
        logger.error(f"[LONG_TERM_WINDOW] Task failed - error={str(e)}", exc_info=True)
        return _with_metadata(
            {
                "status": "FAILURE",
                "strategy": "window",
                "error": str(e),
            },
            elapsed_time
        )
# @celery_app.task(name="app.core.memory.agent.long_term_storage.time", bind=True)
# def long_term_storage_time_task(
# self,
# end_user_id: str,
# config_id: str,
# time_window: int = 5
# ) -> Dict[str, Any]:
# """Celery task for time-based long-term memory storage.
# Retrieves recent sessions from Redis within time window and writes to Neo4j.
# Args:
# end_user_id: End user identifier
# config_id: Memory configuration ID
# time_window: Time window in minutes for retrieving recent sessions
# Returns:
# Dict containing task status and metadata
# """
# from app.core.logging_config import get_logger
# logger = get_logger(__name__)
# logger.info(f"[LONG_TERM_TIME] Starting task - end_user_id={end_user_id}, time_window={time_window}")
# start_time = time.time()
# async def _run() -> Dict[str, Any]:
# from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage
# from app.services.memory_config_service import MemoryConfigService
# db = next(get_db())
# try:
# # Load memory config
# config_service = MemoryConfigService(db)
# memory_config = config_service.load_memory_config(
# config_id=config_id,
# service_name="LongTermStorageTask"
# )
# # Execute time-based storage
# await memory_long_term_storage(end_user_id, memory_config, time_window)
# return {"status": "SUCCESS", "strategy": "time", "time_window": time_window}
# finally:
# db.close()
# try:
# import nest_asyncio
# nest_asyncio.apply()
# except ImportError:
# pass
# try:
# loop = asyncio.get_event_loop()
# if loop.is_closed():
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# except RuntimeError:
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# try:
# result = loop.run_until_complete(_run())
# elapsed_time = time.time() - start_time
# logger.info(f"[LONG_TERM_TIME] Task completed - elapsed_time={elapsed_time:.2f}s")
# return {
# **result,
# "end_user_id": end_user_id,
# "config_id": config_id,
# "elapsed_time": elapsed_time,
# "task_id": self.request.id
# }
# except Exception as e:
# elapsed_time = time.time() - start_time
# logger.error(f"[LONG_TERM_TIME] Task failed - error={str(e)}", exc_info=True)
# return {
# "status": "FAILURE",
# "strategy": "time",
# "error": str(e),
# "end_user_id": end_user_id,
# "config_id": config_id,
# "elapsed_time": elapsed_time,
# "task_id": self.request.id
# }
# @celery_app.task(name="app.core.memory.agent.long_term_storage.aggregate", bind=True)
# def long_term_storage_aggregate_task(
# self,
# end_user_id: str,
# langchain_messages: List[Dict[str, Any]],
# config_id: str
# ) -> Dict[str, Any]:
# """Celery task for aggregate-based long-term memory storage.
# Uses LLM to determine if new messages describe the same event as history.
# Only writes to Neo4j if messages represent new information (not duplicates).
# Args:
# end_user_id: End user identifier
# langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}]
# config_id: Memory configuration ID
# Returns:
# Dict containing task status, is_same_event flag, and metadata
# """
# from app.core.logging_config import get_logger
# logger = get_logger(__name__)
# logger.info(f"[LONG_TERM_AGGREGATE] Starting task - end_user_id={end_user_id}")
# start_time = time.time()
# async def _run() -> Dict[str, Any]:
# from app.core.memory.agent.langgraph_graph.routing.write_router import aggregate_judgment
# from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
# from app.core.memory.agent.utils.redis_tool import write_store
# from app.services.memory_config_service import MemoryConfigService
# db = next(get_db())
# try:
# # Save to Redis buffer first
# write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
# # Load memory config
# config_service = MemoryConfigService(db)
# memory_config = config_service.load_memory_config(
# config_id=config_id,
# service_name="LongTermStorageTask"
# )
# # Execute aggregate judgment
# result = await aggregate_judgment(end_user_id, langchain_messages, memory_config)
# return {
# "status": "SUCCESS",
# "strategy": "aggregate",
# "is_same_event": result.get("is_same_event", False),
# "wrote_to_neo4j": not result.get("is_same_event", False)
# }
# finally:
# db.close()
# try:
# import nest_asyncio
# nest_asyncio.apply()
# except ImportError:
# pass
# try:
# loop = asyncio.get_event_loop()
# if loop.is_closed():
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# except RuntimeError:
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# try:
# result = loop.run_until_complete(_run())
# elapsed_time = time.time() - start_time
# logger.info(f"[LONG_TERM_AGGREGATE] Task completed - is_same_event={result.get('is_same_event')}, elapsed_time={elapsed_time:.2f}s")
# return {
# **result,
# "end_user_id": end_user_id,
# "config_id": config_id,
# "elapsed_time": elapsed_time,
# "task_id": self.request.id
# }
# except Exception as e:
# elapsed_time = time.time() - start_time
# logger.error(f"[LONG_TERM_AGGREGATE] Task failed - error={str(e)}", exc_info=True)
# return {
# "status": "FAILURE",
# "strategy": "aggregate",
# "error": str(e),
# "end_user_id": end_user_id,
# "config_id": config_id,
# "elapsed_time": elapsed_time,
# "task_id": self.request.id
# }

View File

@@ -99,7 +99,8 @@ def agent_config_4_app_release(release: AppRelease) -> AgentConfig:
knowledge_retrieval=config_dict.get("knowledge_retrieval"),
memory=config_dict.get("memory"),
variables=config_dict.get("variables", []),
tools=config_dict.get("tools", {}),
tools=config_dict.get("tools", []),
skills=config_dict.get("skills", {})
)
return agent_config

Some files were not shown because too many files have changed in this diff Show More