Merge branch 'develop' into fix/memory-enduser-config
This commit is contained in:
@@ -7,6 +7,10 @@ from celery import Celery
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# macOS fork() safety - must be set before any Celery initialization
|
||||
if platform.system() == 'Darwin':
|
||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||
|
||||
# 创建 Celery 应用实例
|
||||
# broker: 任务队列(使用 Redis DB 0)
|
||||
# backend: 结果存储(使用 Redis DB 10)
|
||||
@@ -64,6 +68,11 @@ celery_app.conf.update(
|
||||
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||
|
||||
# Long-term storage tasks → memory_tasks queue (batched write strategies)
|
||||
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
||||
|
||||
# Document tasks → document_tasks queue (prefork worker)
|
||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||
|
||||
@@ -41,9 +41,9 @@ from . import (
|
||||
upload_controller,
|
||||
user_controller,
|
||||
user_memory_controllers,
|
||||
workflow_controller,
|
||||
workspace_controller,
|
||||
ontology_controller,
|
||||
skill_controller
|
||||
)
|
||||
|
||||
# 创建管理端 API 路由器
|
||||
@@ -76,7 +76,6 @@ manager_router.include_router(release_share_controller.router)
|
||||
manager_router.include_router(public_share_controller.router) # 公开路由(无需认证)
|
||||
manager_router.include_router(memory_dashboard_controller.router)
|
||||
manager_router.include_router(multi_agent_controller.router)
|
||||
manager_router.include_router(workflow_controller.router)
|
||||
manager_router.include_router(emotion_controller.router)
|
||||
manager_router.include_router(emotion_config_controller.router)
|
||||
manager_router.include_router(prompt_optimizer_controller.router)
|
||||
@@ -90,5 +89,6 @@ manager_router.include_router(memory_perceptual_controller.router)
|
||||
manager_router.include_router(memory_working_controller.router)
|
||||
manager_router.include_router(file_storage_controller.router)
|
||||
manager_router.include_router(ontology_controller.router)
|
||||
manager_router.include_router(skill_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
@@ -22,6 +22,7 @@ from app.services import app_service, workspace_service
|
||||
from app.services.agent_config_helper import enrich_agent_config
|
||||
from app.services.app_service import AppService
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||
logger = get_business_logger()
|
||||
@@ -904,8 +905,6 @@ def get_app_statistics(
|
||||
- total_tokens: 总token消耗
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
stats_service = AppStatisticsService(db)
|
||||
|
||||
result = stats_service.get_app_statistics(
|
||||
@@ -916,3 +915,36 @@ def get_app_statistics(
|
||||
)
|
||||
|
||||
return success(data=result)
|
||||
|
||||
|
||||
@router.get("/workspace/api-statistics", summary="工作空间API调用统计")
|
||||
@cur_workspace_access_guard()
|
||||
def get_workspace_api_statistics(
|
||||
start_date: int,
|
||||
end_date: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""获取工作空间API调用统计
|
||||
|
||||
Args:
|
||||
start_date: 开始时间戳(毫秒)
|
||||
end_date: 结束时间戳(毫秒)
|
||||
|
||||
Returns:
|
||||
每日统计数据列表,每项包含:
|
||||
- date: 日期
|
||||
- total_calls: 当日总调用次数
|
||||
- app_calls: 当日应用调用次数
|
||||
- service_calls: 当日服务调用次数
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
stats_service = AppStatisticsService(db)
|
||||
|
||||
result = stats_service.get_workspace_api_statistics(
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
return success(data=result)
|
||||
|
||||
@@ -116,14 +116,6 @@ def _get_ontology_service(
|
||||
detail=f"找不到指定的LLM模型: {llm_id}"
|
||||
)
|
||||
|
||||
# 检查是否为组合模型
|
||||
if hasattr(model_config, 'is_composite') and model_config.is_composite:
|
||||
logger.error(f"Model {llm_id} is a composite model, which is not supported for ontology extraction")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="本体提取不支持使用组合模型,请选择单个模型"
|
||||
)
|
||||
|
||||
# 验证模型配置了API密钥
|
||||
if not model_config.api_keys:
|
||||
logger.error(f"Model {llm_id} has no API key configuration")
|
||||
|
||||
@@ -120,7 +120,8 @@ async def get_prompt_opt(
|
||||
session_id=session_id,
|
||||
user_id=current_user.id,
|
||||
current_prompt=data.current_prompt,
|
||||
user_require=data.message
|
||||
user_require=data.message,
|
||||
skill=data.skill
|
||||
):
|
||||
# chunk 是 prompt 的增量内容
|
||||
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
|
||||
|
||||
@@ -587,7 +587,8 @@ async def chat(
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=release.app_id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=release.id
|
||||
release_id=release.id,
|
||||
public=True
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
90
api/app/controllers/skill_controller.py
Normal file
90
api/app/controllers/skill_controller.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Skill Controller - 技能市场管理"""
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
from app.models import User
|
||||
from app.schemas import skill_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services.skill_service import SkillService
|
||||
from app.core.response_utils import success
|
||||
|
||||
router = APIRouter(prefix="/skills", tags=["Skills"])
|
||||
|
||||
|
||||
@router.post("", summary="创建技能")
|
||||
@cur_workspace_access_guard()
|
||||
def create_skill(
|
||||
data: skill_schema.SkillCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""创建技能 - 可以关联现有工具(内置、MCP、自定义)"""
|
||||
tenant_id = current_user.tenant_id
|
||||
skill = SkillService.create_skill(db, data, tenant_id)
|
||||
return success(data=skill_schema.Skill.model_validate(skill), msg="技能创建成功")
|
||||
|
||||
|
||||
@router.get("", summary="技能列表")
|
||||
@cur_workspace_access_guard()
|
||||
def list_skills(
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
is_active: Optional[bool] = Query(None, description="是否激活"),
|
||||
is_public: Optional[bool] = Query(None, description="是否公开"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""技能市场列表 - 包含本工作空间和公开的技能"""
|
||||
tenant_id = current_user.tenant_id
|
||||
skills, total = SkillService.list_skills(
|
||||
db, tenant_id, search, is_active, is_public, page, pagesize
|
||||
)
|
||||
|
||||
items = [skill_schema.Skill.model_validate(s) for s in skills]
|
||||
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
||||
return success(data=PageData(page=meta, items=items), msg="技能市场列表获取成功")
|
||||
|
||||
|
||||
@router.get("/{skill_id}", summary="获取技能详情")
|
||||
@cur_workspace_access_guard()
|
||||
def get_skill(
|
||||
skill_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取技能详情"""
|
||||
tenant_id = current_user.tenant_id
|
||||
skill = SkillService.get_skill(db, skill_id, tenant_id)
|
||||
return success(data=skill_schema.Skill.model_validate(skill), msg="获取技能详情成功")
|
||||
|
||||
|
||||
@router.put("/{skill_id}", summary="更新技能")
|
||||
@cur_workspace_access_guard()
|
||||
def update_skill(
|
||||
skill_id: uuid.UUID,
|
||||
data: skill_schema.SkillUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新技能"""
|
||||
tenant_id = current_user.tenant_id
|
||||
skill = SkillService.update_skill(db, skill_id, data, tenant_id)
|
||||
return success(data=skill_schema.Skill.model_validate(skill), msg="技能更新成功")
|
||||
|
||||
|
||||
@router.delete("/{skill_id}", summary="删除技能")
|
||||
@cur_workspace_access_guard()
|
||||
def delete_skill(
|
||||
skill_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除技能"""
|
||||
tenant_id = current_user.tenant_id
|
||||
SkillService.delete_skill(db, skill_id, tenant_id)
|
||||
return success(msg="技能删除成功")
|
||||
@@ -1,610 +0,0 @@
|
||||
"""
|
||||
工作流 API 控制器
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Path, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
|
||||
from app.models.user_model import User
|
||||
from app.models.app_model import App
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
from app.schemas.workflow_schema import (
|
||||
WorkflowConfigCreate,
|
||||
WorkflowConfigUpdate,
|
||||
WorkflowConfig,
|
||||
WorkflowValidationResponse,
|
||||
WorkflowExecution,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowExecutionRequest,
|
||||
WorkflowExecutionResponse
|
||||
)
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["workflow"])
|
||||
|
||||
|
||||
# ==================== 工作流配置管理 ====================
|
||||
|
||||
@router.post("/{app_id}/workflow")
|
||||
@cur_workspace_access_guard()
|
||||
async def create_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
config: WorkflowConfigCreate,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""创建工作流配置
|
||||
|
||||
创建或更新应用的工作流配置。配置会进行基础验证,但允许保存不完整的配置(草稿)。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 验证应用类型
|
||||
if app.type != "workflow":
|
||||
return fail(
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
msg=f"应用类型必须为 workflow,当前为 {app.type}"
|
||||
)
|
||||
|
||||
# 创建工作流配置
|
||||
workflow_config = service.create_workflow_config(
|
||||
app_id=app_id,
|
||||
nodes=[node.model_dump() for node in config.nodes],
|
||||
edges=[edge.model_dump() for edge in config.edges],
|
||||
variables=[var.model_dump() for var in config.variables],
|
||||
execution_config=config.execution_config.model_dump(),
|
||||
triggers=[trigger.model_dump() for trigger in config.triggers],
|
||||
validate=True # 进行基础验证
|
||||
)
|
||||
|
||||
return success(
|
||||
data=WorkflowConfig.model_validate(workflow_config),
|
||||
msg="工作流配置创建成功"
|
||||
)
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"创建工作流配置失败: {e.message}")
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"创建工作流配置异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"创建工作流配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# @router.get("/{app_id}/workflow")
|
||||
# async def get_workflow_config(
|
||||
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
# db: Annotated[Session, Depends(get_db)],
|
||||
# current_user: Annotated[User, Depends(get_current_user)]
|
||||
#
|
||||
# ):
|
||||
# """获取工作流配置
|
||||
#
|
||||
# 获取应用的工作流配置详情。
|
||||
# """
|
||||
# try:
|
||||
# # 验证应用是否存在且属于当前工作空间
|
||||
# app = db.query(App).filter(
|
||||
# App.id == app_id,
|
||||
# App.workspace_id == current_user.current_workspace_id,
|
||||
# App.is_active == True
|
||||
# ).first()
|
||||
#
|
||||
# if not app:
|
||||
# return fail(
|
||||
# code=BizCode.NOT_FOUND,
|
||||
# msg="应用不存在或无权访问"
|
||||
# )
|
||||
#
|
||||
# # 获取工作流配置
|
||||
# service = WorkflowService(db)
|
||||
# workflow_config = service.get_workflow_config(app_id)
|
||||
#
|
||||
# if not workflow_config:
|
||||
# return fail(
|
||||
# code=BizCode.NOT_FOUND,
|
||||
# msg="工作流配置不存在"
|
||||
# )
|
||||
#
|
||||
# return success(
|
||||
# data=WorkflowConfig.model_validate(workflow_config)
|
||||
# )
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"获取工作流配置异常: {e}", exc_info=True)
|
||||
# return fail(
|
||||
# code=BizCode.INTERNAL_ERROR,
|
||||
# msg=f"获取工作流配置失败: {str(e)}"
|
||||
# )
|
||||
|
||||
|
||||
# @router.put("/{app_id}/workflow")
|
||||
# async def update_workflow_config(
|
||||
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
# config: WorkflowConfigUpdate,
|
||||
# db: Annotated[Session, Depends(get_db)],
|
||||
# current_user: Annotated[User, Depends(get_current_user)],
|
||||
# service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
# ):
|
||||
# """更新工作流配置
|
||||
|
||||
# 更新应用的工作流配置。可以部分更新,未提供的字段保持不变。
|
||||
# """
|
||||
# try:
|
||||
# # 验证应用是否存在且属于当前工作空间
|
||||
# app = db.query(App).filter(
|
||||
# App.id == app_id,
|
||||
# App.workspace_id == current_user.current_workspace_id,
|
||||
# App.is_active == True
|
||||
# ).first()
|
||||
|
||||
# if not app:
|
||||
# return fail(
|
||||
# code=BizCode.NOT_FOUND,
|
||||
# msg="应用不存在或无权访问"
|
||||
# )
|
||||
|
||||
# # 更新工作流配置
|
||||
# workflow_config = service.update_workflow_config(
|
||||
# app_id=app_id,
|
||||
# nodes=[node.model_dump() for node in config.nodes] if config.nodes else None,
|
||||
# edges=[edge.model_dump() for edge in config.edges] if config.edges else None,
|
||||
# variables=[var.model_dump() for var in config.variables] if config.variables else None,
|
||||
# execution_config=config.execution_config.model_dump() if config.execution_config else None,
|
||||
# triggers=[trigger.model_dump() for trigger in config.triggers] if config.triggers else None,
|
||||
# validate=True
|
||||
# )
|
||||
|
||||
# return success(
|
||||
# data=WorkflowConfig.model_validate(workflow_config),
|
||||
# msg="工作流配置更新成功"
|
||||
# )
|
||||
|
||||
# except BusinessException as e:
|
||||
# logger.warning(f"更新工作流配置失败: {e.message}")
|
||||
# return fail(code=e.error_code, msg=e.message)
|
||||
# except Exception as e:
|
||||
# logger.error(f"更新工作流配置异常: {e}", exc_info=True)
|
||||
# return fail(
|
||||
# code=BizCode.INTERNAL_ERROR,
|
||||
# msg=f"更新工作流配置失败: {str(e)}"
|
||||
# )
|
||||
|
||||
|
||||
@router.delete("/{app_id}/workflow")
|
||||
async def delete_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""删除工作流配置
|
||||
|
||||
删除应用的工作流配置。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 删除工作流配置
|
||||
deleted = service.delete_workflow_config(app_id)
|
||||
|
||||
if not deleted:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="工作流配置不存在"
|
||||
)
|
||||
|
||||
return success(msg="工作流配置删除成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"删除工作流配置异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"删除工作流配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{app_id}/workflow/validate")
|
||||
async def validate_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
|
||||
):
|
||||
"""验证工作流配置
|
||||
|
||||
验证工作流配置是否有效。可以选择是否进行发布级别的严格验证。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 验证工作流配置
|
||||
|
||||
if for_publish:
|
||||
is_valid, errors = service.validate_workflow_config_for_publish(app_id)
|
||||
else:
|
||||
workflow_config = service.get_workflow_config(app_id)
|
||||
if not workflow_config:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="工作流配置不存在"
|
||||
)
|
||||
|
||||
from app.core.workflow.validator import validate_workflow_config as validate_config
|
||||
config_dict = {
|
||||
"nodes": workflow_config.nodes,
|
||||
"edges": workflow_config.edges,
|
||||
"variables": workflow_config.variables,
|
||||
"execution_config": workflow_config.execution_config,
|
||||
"triggers": workflow_config.triggers
|
||||
}
|
||||
is_valid, errors = validate_config(config_dict, for_publish=False)
|
||||
|
||||
return success(
|
||||
data=WorkflowValidationResponse(
|
||||
is_valid=is_valid,
|
||||
errors=errors,
|
||||
warnings=[]
|
||||
)
|
||||
)
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"验证工作流配置失败: {e.message}")
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"验证工作流配置异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"验证工作流配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# ==================== 工作流执行管理 ====================
|
||||
|
||||
@router.get("/{app_id}/workflow/executions")
|
||||
async def get_workflow_executions(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||
offset: Annotated[int, Query(ge=0)] = 0
|
||||
):
|
||||
"""获取工作流执行记录列表
|
||||
|
||||
获取应用的工作流执行历史记录。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 获取执行记录
|
||||
executions = service.get_executions_by_app(app_id, limit, offset)
|
||||
|
||||
# 获取统计信息
|
||||
statistics = service.get_execution_statistics(app_id)
|
||||
|
||||
return success(
|
||||
data={
|
||||
"executions": [WorkflowExecution.model_validate(e) for e in executions],
|
||||
"statistics": statistics,
|
||||
"pagination": {
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"total": statistics["total"]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工作流执行记录异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"获取工作流执行记录失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/workflow/executions/{execution_id}")
|
||||
async def get_workflow_execution(
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""获取工作流执行详情
|
||||
|
||||
获取单个工作流执行的详细信息,包括所有节点的执行记录。
|
||||
"""
|
||||
try:
|
||||
# 获取执行记录
|
||||
execution = service.get_execution(execution_id)
|
||||
|
||||
if not execution:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="执行记录不存在"
|
||||
)
|
||||
|
||||
# 验证应用是否属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == execution.app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="无权访问该执行记录"
|
||||
)
|
||||
|
||||
# 获取节点执行记录
|
||||
node_executions = service.node_execution_repo.get_by_execution_id(execution.id)
|
||||
|
||||
return success(
|
||||
data={
|
||||
"execution": WorkflowExecution.model_validate(execution),
|
||||
"node_executions": [
|
||||
WorkflowNodeExecution.model_validate(ne) for ne in node_executions
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工作流执行详情异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"获取工作流执行详情失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# ==================== 工作流执行 ====================
|
||||
@router.post("/{app_id}/workflow/run")
|
||||
async def run_workflow(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
request: WorkflowExecutionRequest,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""执行工作流
|
||||
|
||||
执行工作流并返回结果。支持流式和非流式两种模式。
|
||||
|
||||
**非流式模式**:等待工作流执行完成后返回完整结果。
|
||||
|
||||
**流式模式**:实时返回执行过程中的事件(节点开始、节点完成、工作流完成等)。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 验证应用类型
|
||||
if app.type != "workflow":
|
||||
return fail(
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
msg=f"应用类型必须为 workflow,当前为 {app.type}"
|
||||
)
|
||||
|
||||
# 准备输入数据
|
||||
input_data = {
|
||||
"message": request.message or "",
|
||||
"variables": request.variables
|
||||
}
|
||||
|
||||
# 执行工作流
|
||||
|
||||
if request.stream:
|
||||
# 流式执行
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
|
||||
async def event_generator():
|
||||
"""生成 SSE 事件
|
||||
|
||||
SSE 格式:
|
||||
event: <event_type>
|
||||
data: <json_data>
|
||||
|
||||
支持的事件类型:
|
||||
- workflow_start: 工作流开始
|
||||
- workflow_end: 工作流结束
|
||||
- node_start: 节点开始执行
|
||||
- node_end: 节点执行完成
|
||||
- node_chunk: 中间节点的流式输出
|
||||
- message: 最终消息的流式输出(End 节点及其相邻节点)
|
||||
"""
|
||||
try:
|
||||
async for event in await service.run_workflow(
|
||||
app_id=app_id,
|
||||
input_data=input_data,
|
||||
triggered_by=current_user.id,
|
||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||
stream=True
|
||||
):
|
||||
# 提取事件类型和数据
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
# 转换为标准 SSE 格式(字符串)
|
||||
# event: <type>
|
||||
# data: <json>
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||
yield sse_message
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式执行异常: {e}", exc_info=True)
|
||||
# 发送错误事件
|
||||
sse_error = f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
|
||||
yield sse_error
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no" # 禁用 nginx 缓冲
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 非流式执行
|
||||
result = await service.run_workflow(
|
||||
app_id=app_id,
|
||||
input_data=input_data,
|
||||
triggered_by=current_user.id,
|
||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||
stream=False
|
||||
)
|
||||
|
||||
return success(
|
||||
data=WorkflowExecutionResponse(
|
||||
execution_id=result["execution_id"],
|
||||
status=result["status"],
|
||||
output=result.get("output"),
|
||||
output_data=result.get("output_data"),
|
||||
error_message=result.get("error_message"),
|
||||
elapsed_time=result.get("elapsed_time"),
|
||||
token_usage=result.get("token_usage")
|
||||
),
|
||||
msg="工作流执行完成"
|
||||
)
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"执行工作流失败: {e.message}")
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"执行工作流异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"执行工作流失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/workflow/executions/{execution_id}/cancel")
|
||||
async def cancel_workflow_execution(
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""取消工作流执行
|
||||
|
||||
取消正在运行的工作流执行。
|
||||
|
||||
**注意**:当前版本仅更新状态为 cancelled,实际的执行取消功能待实现。
|
||||
"""
|
||||
try:
|
||||
# 获取执行记录
|
||||
execution = service.get_execution(execution_id)
|
||||
|
||||
if not execution:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="执行记录不存在"
|
||||
)
|
||||
|
||||
# 验证应用是否属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == execution.app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="无权访问该执行记录"
|
||||
)
|
||||
|
||||
# 检查执行状态
|
||||
if execution.status not in ["pending", "running"]:
|
||||
return fail(
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
msg=f"无法取消状态为 {execution.status} 的执行"
|
||||
)
|
||||
|
||||
# 更新状态为 cancelled
|
||||
service.update_execution_status(execution_id, "cancelled")
|
||||
|
||||
return success(msg="工作流执行已取消")
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"取消工作流执行失败: {e.message}")
|
||||
return fail(code=e.code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"取消工作流执行异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"取消工作流执行失败: {str(e)}"
|
||||
)
|
||||
162
api/app/core/agent/agent_middleware.py
Normal file
162
api/app/core/agent/agent_middleware.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""Agent Middleware - 动态技能过滤"""
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
|
||||
from app.services.skill_service import SkillService
|
||||
from app.repositories.skill_repository import SkillRepository
|
||||
|
||||
|
||||
class AgentMiddleware:
|
||||
"""Agent 中间件 - 用于动态过滤和加载技能"""
|
||||
|
||||
def __init__(self, skills: Optional[dict] = None):
|
||||
"""
|
||||
初始化中间件
|
||||
|
||||
Args:
|
||||
skills: 技能配置字典 {"enabled": bool, "all_skills": bool, "skill_ids": [...]}
|
||||
"""
|
||||
self.skills = skills or {}
|
||||
self.enabled = self.skills.get('enabled', False)
|
||||
self.all_skills = self.skills.get('all_skills', False)
|
||||
self.skill_ids = self.skills.get('skill_ids', [])
|
||||
|
||||
@staticmethod
|
||||
def filter_tools(
|
||||
tools: List,
|
||||
message: str = "",
|
||||
skill_configs: Dict[str, Any] = None,
|
||||
tool_to_skill_map: Dict[str, str] = None
|
||||
) -> tuple[List, List[str]]:
|
||||
"""
|
||||
根据消息内容和技能配置动态过滤工具
|
||||
|
||||
Args:
|
||||
tools: 所有可用工具列表
|
||||
message: 用户消息(可用于智能过滤)
|
||||
skill_configs: 技能配置字典 {skill_id: {"keywords": [...], "enabled": True, "prompt": "..."}}
|
||||
tool_to_skill_map: 工具到技能的映射 {tool_name: skill_id}
|
||||
|
||||
Returns:
|
||||
(过滤后的工具列表, 激活的技能ID列表)
|
||||
"""
|
||||
if not tools:
|
||||
return [], []
|
||||
|
||||
# 如果没有技能配置,返回所有工具
|
||||
if not skill_configs:
|
||||
return tools, []
|
||||
|
||||
# 基于关键词匹配激活技能
|
||||
activated_skill_ids = []
|
||||
message_lower = message.lower()
|
||||
|
||||
for skill_id, config in skill_configs.items():
|
||||
if not config.get('enabled', True):
|
||||
continue
|
||||
|
||||
keywords = config.get('keywords', [])
|
||||
# 如果没有关键词限制,或消息包含关键词,则激活该技能
|
||||
if not keywords or any(kw.lower() in message_lower for kw in keywords):
|
||||
activated_skill_ids.append(skill_id)
|
||||
|
||||
# 如果没有工具映射关系,返回所有工具
|
||||
if not tool_to_skill_map:
|
||||
return tools, activated_skill_ids
|
||||
|
||||
# 根据激活的技能过滤工具
|
||||
filtered_tools = []
|
||||
for tool in tools:
|
||||
tool_name = getattr(tool, 'name', str(id(tool)))
|
||||
# 如果工具不属于任何skill(base_tools),或者工具所属的skill被激活,则保留
|
||||
if tool_name not in tool_to_skill_map or tool_to_skill_map[tool_name] in activated_skill_ids:
|
||||
filtered_tools.append(tool)
|
||||
|
||||
return filtered_tools, activated_skill_ids
|
||||
|
||||
def load_skill_tools(self, db, tenant_id: uuid.UUID, base_tools: List = None) -> tuple[List, Dict[str, Any], Dict[str, str]]:
|
||||
"""
|
||||
加载技能关联的工具
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
tenant_id: 租户id
|
||||
base_tools: 基础工具列表
|
||||
|
||||
Returns:
|
||||
(工具列表, 技能配置字典, 工具到技能的映射 {tool_name: skill_id})
|
||||
"""
|
||||
|
||||
tools_dict = {}
|
||||
tool_to_skill_map = {} # 工具名称到技能ID的映射
|
||||
|
||||
if base_tools:
|
||||
for tool in base_tools:
|
||||
tool_name = getattr(tool, 'name', str(id(tool)))
|
||||
tools_dict[tool_name] = tool
|
||||
# base_tools 不属于任何 skill,不加入映射
|
||||
|
||||
skill_configs = {}
|
||||
skill_ids_to_load = []
|
||||
|
||||
# 如果启用技能且 all_skills 为 True,加载租户下所有激活的技能
|
||||
if self.enabled and self.all_skills:
|
||||
skills, _ = SkillRepository.list_skills(db, tenant_id, is_active=True, page=1, pagesize=1000)
|
||||
skill_ids_to_load = [str(skill.id) for skill in skills]
|
||||
elif self.enabled and self.skill_ids:
|
||||
skill_ids_to_load = self.skill_ids
|
||||
|
||||
if skill_ids_to_load:
|
||||
for skill_id in skill_ids_to_load:
|
||||
try:
|
||||
skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id)
|
||||
if skill and skill.is_active:
|
||||
# 保存技能配置(包含prompt)
|
||||
config = skill.config or {}
|
||||
config['prompt'] = skill.prompt
|
||||
config['name'] = skill.name
|
||||
skill_configs[skill_id] = config
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 加载技能工具并获取映射关系
|
||||
skill_tools, skill_tool_map = SkillService.load_skill_tools(db, skill_ids_to_load, tenant_id)
|
||||
|
||||
# 只添加不冲突的 skill_tools
|
||||
for tool in skill_tools:
|
||||
tool_name = getattr(tool, 'name', str(id(tool)))
|
||||
if tool_name not in tools_dict:
|
||||
tools_dict[tool_name] = tool
|
||||
# 复制映射关系
|
||||
if tool_name in skill_tool_map:
|
||||
tool_to_skill_map[tool_name] = skill_tool_map[tool_name]
|
||||
|
||||
return list(tools_dict.values()), skill_configs, tool_to_skill_map
|
||||
|
||||
@staticmethod
|
||||
def get_active_prompts(activated_skill_ids: List[str], skill_configs: Dict[str, Any]) -> str:
|
||||
"""
|
||||
根据激活的技能ID获取对应的提示词
|
||||
|
||||
Args:
|
||||
activated_skill_ids: 被激活的技能ID列表
|
||||
skill_configs: 技能配置字典
|
||||
|
||||
Returns:
|
||||
合并后的提示词
|
||||
"""
|
||||
prompts = []
|
||||
for skill_id in activated_skill_ids:
|
||||
config = skill_configs.get(skill_id, {})
|
||||
prompt = config.get('prompt')
|
||||
name = config.get('name', 'Skill')
|
||||
if prompt:
|
||||
prompts.append(f"# {name}\n{prompt}")
|
||||
|
||||
return "\n\n".join(prompts) if prompts else ""
|
||||
|
||||
@staticmethod
|
||||
def create_runnable():
|
||||
"""创建可运行的中间件"""
|
||||
return RunnablePassthrough()
|
||||
@@ -289,10 +289,9 @@ class LangChainAgent:
|
||||
|
||||
return content_parts
|
||||
|
||||
return messages
|
||||
|
||||
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
|
||||
db = next(get_db())
|
||||
#TODO: 魔法数字
|
||||
scope=6
|
||||
|
||||
try:
|
||||
@@ -302,6 +301,12 @@ class LangChainAgent:
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
result = write_store.get_session_by_userid(end_user_id)
|
||||
|
||||
# Handle case where no session exists in Redis (returns False)
|
||||
if not result or result is False:
|
||||
logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update")
|
||||
return
|
||||
|
||||
if type=="chunk" or type=="aggregate":
|
||||
data = await format_parsing(result, "dict")
|
||||
chunk_data = data[:scope]
|
||||
@@ -309,7 +314,14 @@ class LangChainAgent:
|
||||
repo.upsert(end_user_id, chunk_data)
|
||||
logger.info(f'写入短长期:')
|
||||
else:
|
||||
# TODO: This branch handles type="time" strategy, currently unused.
|
||||
# Will be activated when time-based long-term storage is implemented.
|
||||
# TODO: 魔法数字 - extract 5 to a constant
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
||||
# Handle case where no session exists in Redis (returns False or empty)
|
||||
if not long_time_data or long_time_data is False:
|
||||
logger.debug(f"No recent sessions in Redis for user {end_user_id}")
|
||||
return
|
||||
long_messages = await messages_parse(long_time_data)
|
||||
repo.upsert(end_user_id, long_messages)
|
||||
logger.info(f'写入短长期:')
|
||||
@@ -509,13 +521,13 @@ class LangChainAgent:
|
||||
elapsed_time = time.time() - start_time
|
||||
if memory_flag:
|
||||
long_term_messages=await agent_chat_messages(message_chat,content)
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
|
||||
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
|
||||
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
|
||||
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
|
||||
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||
'''长期'''
|
||||
if actual_config_id:
|
||||
await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
|
||||
else:
|
||||
logger.warning(f"Skipping term_memory_save: no memory config available for end_user {end_user_id}")
|
||||
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
|
||||
await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
|
||||
response = {
|
||||
"content": content,
|
||||
"model": self.model_name,
|
||||
@@ -698,13 +710,14 @@ class LangChainAgent:
|
||||
yield total_tokens
|
||||
break
|
||||
if memory_flag:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
|
||||
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
|
||||
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
|
||||
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
|
||||
long_term_messages = await agent_chat_messages(message_chat, full_content)
|
||||
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
|
||||
if actual_config_id:
|
||||
await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk")
|
||||
else:
|
||||
logger.warning(f"Skipping term_memory_save: no memory config available for end_user {end_user_id}")
|
||||
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
|
||||
await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||
|
||||
@@ -215,6 +215,9 @@ class Settings:
|
||||
# official environment system version
|
||||
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1")
|
||||
|
||||
# model square loading
|
||||
LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true"
|
||||
|
||||
# workflow config
|
||||
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ async def write_messages(end_user_id,langchain_messages,memory_config):
|
||||
for node_name, node_data in update_event.items():
|
||||
if 'save_neo4j' == node_name:
|
||||
massages = node_data
|
||||
# TODO:删除
|
||||
massagesstatus = massages.get('write_result')['status']
|
||||
contents = massages.get('write_result')
|
||||
print(contents)
|
||||
@@ -60,6 +61,7 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||
scope:窗口大小
|
||||
'''
|
||||
scope=scope
|
||||
redis_messages = []
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_id is not False:
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||
@@ -91,6 +93,9 @@ async def memory_long_term_storage(end_user_id,memory_config,time):
|
||||
memory_config: 内存配置对象
|
||||
'''
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
||||
# Handle case where no session exists in Redis (returns False or empty)
|
||||
if not long_time_data or long_time_data is False:
|
||||
return
|
||||
format_messages = await chat_data_format(long_time_data)
|
||||
if format_messages!=[]:
|
||||
await write_messages(end_user_id, format_messages, memory_config)
|
||||
@@ -108,8 +113,9 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
||||
try:
|
||||
# 1. 获取历史会话数据(使用新方法)
|
||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
||||
history = await format_parsing(result)
|
||||
if not result:
|
||||
|
||||
# Handle case where no session exists in Redis (returns False or empty)
|
||||
if not result or result is False:
|
||||
history = []
|
||||
else:
|
||||
history = await format_parsing(result)
|
||||
|
||||
@@ -1,24 +1,21 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, chat_data_format, messages_parse
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_write_graph():
|
||||
"""
|
||||
@@ -39,29 +36,59 @@ async def make_write_graph():
|
||||
graph = workflow.compile()
|
||||
|
||||
yield graph
|
||||
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
|
||||
# 获取数据库会话
|
||||
db_session = next(get_db())
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=memory_config, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
|
||||
|
||||
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
|
||||
end_user_id: str = '', scope: int = 6):
|
||||
"""Dispatch long-term memory storage to Celery background tasks.
|
||||
|
||||
Args:
|
||||
long_term_type: Storage strategy - 'chunk' (window), 'time', or 'aggregate'
|
||||
langchain_messages: List of messages to store
|
||||
memory_config: Memory configuration ID (string)
|
||||
end_user_id: End user identifier
|
||||
scope: Window size for 'chunk' strategy (default: 6)
|
||||
"""
|
||||
from app.tasks import (
|
||||
long_term_storage_window_task,
|
||||
# TODO: Uncomment when implemented
|
||||
# long_term_storage_time_task,
|
||||
# long_term_storage_aggregate_task,
|
||||
)
|
||||
if long_term_type=='chunk':
|
||||
'''方案一:对话窗口6轮对话'''
|
||||
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
||||
if long_term_type=='time':
|
||||
"""时间"""
|
||||
await memory_long_term_storage(end_user_id, memory_config,5)
|
||||
if long_term_type=='aggregate':
|
||||
from app.core.logging_config import get_logger
|
||||
|
||||
"""方案三:聚合判断"""
|
||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Convert config to string if needed
|
||||
config_id = str(memory_config) if memory_config else ''
|
||||
|
||||
if long_term_type == 'chunk':
|
||||
# Strategy 1: Window-based batching (6 rounds of dialogue)
|
||||
logger.info(f"[LONG_TERM] Dispatching window task - end_user_id={end_user_id}, scope={scope}")
|
||||
long_term_storage_window_task.delay(
|
||||
end_user_id=end_user_id,
|
||||
langchain_messages=langchain_messages,
|
||||
config_id=config_id,
|
||||
scope=scope
|
||||
)
|
||||
# TODO: Uncomment when time-based strategy is fully implemented
|
||||
# elif long_term_type == 'time':
|
||||
# # Strategy 2: Time-based retrieval
|
||||
# logger.info(f"[LONG_TERM] Dispatching time task - end_user_id={end_user_id}")
|
||||
# long_term_storage_time_task.delay(
|
||||
# end_user_id=end_user_id,
|
||||
# config_id=config_id,
|
||||
# time_window=5
|
||||
# )
|
||||
# TODO: Uncomment when aggregate strategy is fully implemented
|
||||
# elif long_term_type == 'aggregate':
|
||||
# # Strategy 3: Aggregate judgment (deduplication)
|
||||
# logger.info(f"[LONG_TERM] Dispatching aggregate task - end_user_id={end_user_id}")
|
||||
# long_term_storage_aggregate_task.delay(
|
||||
# end_user_id=end_user_id,
|
||||
# langchain_messages=langchain_messages,
|
||||
# config_id=config_id
|
||||
# )
|
||||
|
||||
# async def main():
|
||||
# """主函数 - 运行工作流"""
|
||||
|
||||
@@ -174,4 +174,4 @@ async def write(
|
||||
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
|
||||
|
||||
logger.info("=== Pipeline Complete ===")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
@@ -1,5 +1,4 @@
|
||||
provider: bedrock
|
||||
enabled: false
|
||||
models:
|
||||
- name: ai21
|
||||
type: llm
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
provider: dashscope
|
||||
enabled: false
|
||||
models:
|
||||
- name: deepseek-r1-distill-qwen-14b
|
||||
type: llm
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""模型配置加载器 - 用于将预定义模型批量导入到数据库"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import yaml
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.models_model import ModelBase, ModelProvider
|
||||
|
||||
|
||||
@@ -19,31 +19,9 @@ def _load_yaml_config(provider: ModelProvider) -> list[dict]:
|
||||
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
# 检查是否需要加载(默认为 true)
|
||||
if not data.get('enabled', True):
|
||||
return []
|
||||
|
||||
return data.get('models', [])
|
||||
|
||||
|
||||
def _disable_yaml_config(provider: ModelProvider) -> None:
|
||||
"""将YAML文件的enabled标志设置为false"""
|
||||
config_dir = Path(__file__).parent
|
||||
config_file = config_dir / f"{provider.value}_models.yaml"
|
||||
|
||||
if not config_file.exists():
|
||||
return
|
||||
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
data['enabled'] = False
|
||||
|
||||
with open(config_file, 'w', encoding='utf-8') as f:
|
||||
yaml.dump(data, f, allow_unicode=True, sort_keys=False)
|
||||
|
||||
|
||||
def load_models(db: Session, providers: list[str] = None, silent: bool = False) -> dict:
|
||||
"""
|
||||
加载模型配置到数据库
|
||||
@@ -75,8 +53,7 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
|
||||
|
||||
if not silent:
|
||||
print(f"\n正在加载 {provider.value} 的 {len(models)} 个模型...")
|
||||
|
||||
# provider_success = 0
|
||||
|
||||
for model_data in models:
|
||||
try:
|
||||
# 检查模型是否已存在
|
||||
@@ -93,7 +70,6 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
|
||||
if not silent:
|
||||
print(f"更新成功: {model_data['name']}")
|
||||
result["success"] += 1
|
||||
# provider_success += 1
|
||||
else:
|
||||
# 创建新模型
|
||||
model = ModelBase(**model_data)
|
||||
@@ -102,17 +78,12 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
|
||||
if not silent:
|
||||
print(f"添加成功: {model_data['name']}")
|
||||
result["success"] += 1
|
||||
# provider_success += 1
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
if not silent:
|
||||
print(f"添加失败: {model_data['name']} - {str(e)}")
|
||||
result["failed"] += 1
|
||||
|
||||
# 如果该供应商的模型全部加载成功,将enabled设置为false
|
||||
# if provider_success == len(models):
|
||||
_disable_yaml_config(provider)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
provider: openai
|
||||
enabled: false
|
||||
models:
|
||||
- name: chatgpt-4o-latest
|
||||
type: llm
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,3 @@
|
||||
"""
|
||||
安全的表达式求值器
|
||||
|
||||
使用 simpleeval 库提供安全的表达式评估,避免代码注入攻击。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
@@ -14,160 +8,119 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExpressionEvaluator:
|
||||
"""安全的表达式求值器"""
|
||||
"""Safe expression evaluator for workflow variables and node outputs."""
|
||||
|
||||
# 保留的命名空间
|
||||
# Reserved namespaces
|
||||
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
|
||||
|
||||
@staticmethod
|
||||
def evaluate(
|
||||
expression: str,
|
||||
variables: dict[str, Any],
|
||||
conv_vars: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""安全地评估表达式
|
||||
|
||||
Args:
|
||||
expression: 表达式字符串,如 "{{var.score}} > 0.8"
|
||||
variables: 用户定义的变量
|
||||
node_outputs: 节点输出结果
|
||||
system_vars: 系统变量
|
||||
|
||||
Returns:
|
||||
表达式求值结果
|
||||
|
||||
Raises:
|
||||
ValueError: 表达式无效或求值失败
|
||||
|
||||
Examples:
|
||||
>>> evaluator = ExpressionEvaluator()
|
||||
>>> evaluator.evaluate(
|
||||
... "var.score > 0.8",
|
||||
... {"score": 0.9},
|
||||
... {},
|
||||
... {}
|
||||
... )
|
||||
True
|
||||
|
||||
>>> evaluator.evaluate(
|
||||
... "node.intent.output == '售前咨询'",
|
||||
... {},
|
||||
... {"intent": {"output": "售前咨询"}},
|
||||
... {}
|
||||
... )
|
||||
True
|
||||
"""
|
||||
# 移除 Jinja2 模板语法的花括号(如果存在)
|
||||
Safely evaluate an expression using workflow variables.
|
||||
|
||||
Args:
|
||||
expression (str): The expression string, e.g., "var.score > 0.8"
|
||||
conv_vars (dict): Conversation-level variables
|
||||
node_outputs (dict): Outputs from workflow nodes
|
||||
system_vars (dict, optional): System variables
|
||||
|
||||
Returns:
|
||||
Any: Result of the evaluated expression
|
||||
|
||||
Raises:
|
||||
ValueError: If the expression is invalid or evaluation fails
|
||||
"""
|
||||
# Remove Jinja2-style brackets if present
|
||||
expression = expression.strip()
|
||||
# "{{system.message}} == {{ user.messge }}" -> "system.message == user.message"
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
expression = re.sub(pattern, r"\1", expression).strip()
|
||||
|
||||
# 构建命名空间上下文
|
||||
# Build context for evaluation
|
||||
context = {
|
||||
"var": variables, # 用户变量
|
||||
"node": node_outputs, # 节点输出
|
||||
"sys": system_vars or {}, # 系统变量
|
||||
"conv": conv_vars, # conversation variables
|
||||
"node": node_outputs, # node outputs
|
||||
"sys": system_vars or {}, # system variables
|
||||
}
|
||||
|
||||
# 为了向后兼容,也支持直接访问(但会在日志中警告)
|
||||
context.update(variables)
|
||||
|
||||
context.update(conv_vars)
|
||||
context["nodes"] = node_outputs
|
||||
context.update(node_outputs)
|
||||
|
||||
try:
|
||||
# simpleeval 只支持安全的操作:
|
||||
# - 算术运算: +, -, *, /, //, %, **
|
||||
# - 比较运算: ==, !=, <, <=, >, >=
|
||||
# - 逻辑运算: and, or, not
|
||||
# - 成员运算: in, not in
|
||||
# - 属性访问: obj.attr
|
||||
# - 字典/列表访问: obj["key"], obj[0]
|
||||
# 不支持:函数调用、导入、赋值等危险操作
|
||||
# simpleeval supports safe operations:
|
||||
# arithmetic, comparisons, logical ops, attribute/dict/list access
|
||||
result = simple_eval(expression, names=context)
|
||||
return result
|
||||
|
||||
except NameNotDefined as e:
|
||||
logger.error(f"表达式中引用了未定义的变量: {expression}, 错误: {e}")
|
||||
raise ValueError(f"未定义的变量: {e}")
|
||||
logger.error(f"Undefined variable in expression: {expression}, error: {e}")
|
||||
raise ValueError(f"Undefined variable: {e}")
|
||||
|
||||
except InvalidExpression as e:
|
||||
logger.error(f"表达式语法无效: {expression}, 错误: {e}")
|
||||
raise ValueError(f"表达式语法无效: {e}")
|
||||
logger.error(f"Invalid expression syntax: {expression}, error: {e}")
|
||||
raise ValueError(f"Invalid expression syntax: {e}")
|
||||
|
||||
except SyntaxError as e:
|
||||
logger.error(f"表达式语法错误: {expression}, 错误: {e}")
|
||||
raise ValueError(f"表达式语法错误: {e}")
|
||||
logger.error(f"Syntax error in expression: {expression}, error: {e}")
|
||||
raise ValueError(f"Syntax error: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"表达式求值异常: {expression}, 错误: {e}")
|
||||
raise ValueError(f"表达式求值失败: {e}")
|
||||
logger.error(f"Expression evaluation failed: {expression}, error: {e}")
|
||||
raise ValueError(f"Expression evaluation failed: {e}")
|
||||
|
||||
@staticmethod
|
||||
def evaluate_bool(
|
||||
expression: str,
|
||||
variables: dict[str, Any],
|
||||
conv_var: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> bool:
|
||||
"""评估布尔表达式(用于条件判断)
|
||||
|
||||
"""
|
||||
Evaluate a boolean expression (for conditions).
|
||||
|
||||
Args:
|
||||
expression: 布尔表达式
|
||||
variables: 用户变量
|
||||
node_outputs: 节点输出
|
||||
system_vars: 系统变量
|
||||
|
||||
expression (str): Boolean expression
|
||||
conv_var (dict): Conversation variables
|
||||
node_outputs (dict): Node outputs
|
||||
system_vars (dict, optional): System variables
|
||||
|
||||
Returns:
|
||||
布尔值结果
|
||||
|
||||
Examples:
|
||||
>>> ExpressionEvaluator.evaluate_bool(
|
||||
... "var.count >= 10 and var.status == 'active'",
|
||||
... {"count": 15, "status": "active"},
|
||||
... {},
|
||||
... {}
|
||||
... )
|
||||
True
|
||||
bool: Boolean result
|
||||
"""
|
||||
result = ExpressionEvaluator.evaluate(
|
||||
expression, variables, node_outputs, system_vars
|
||||
expression, conv_var, node_outputs, system_vars
|
||||
)
|
||||
return bool(result)
|
||||
|
||||
@staticmethod
|
||||
def validate_variable_names(variables: list[dict]) -> list[str]:
|
||||
"""验证变量名是否合法
|
||||
|
||||
"""
|
||||
Validate variable names for legality.
|
||||
|
||||
Args:
|
||||
variables: 变量定义列表
|
||||
|
||||
variables (list[dict]): List of variable definitions
|
||||
|
||||
Returns:
|
||||
错误列表,如果为空则验证通过
|
||||
|
||||
Examples:
|
||||
>>> ExpressionEvaluator.validate_variable_names([
|
||||
... {"name": "user_input"},
|
||||
... {"name": "var"} # 保留字
|
||||
... ])
|
||||
["变量名 'var' 是保留的命名空间,请使用其他名称"]
|
||||
list[str]: List of error messages. Empty if all names are valid.
|
||||
"""
|
||||
errors = []
|
||||
|
||||
for var in variables:
|
||||
var_name = var.get("name", "")
|
||||
|
||||
# 检查是否为保留命名空间
|
||||
|
||||
if var_name in ExpressionEvaluator.RESERVED_NAMESPACES:
|
||||
errors.append(
|
||||
f"变量名 '{var_name}' 是保留的命名空间,请使用其他名称"
|
||||
f"Variable name '{var_name}' is a reserved namespace, please use another name"
|
||||
)
|
||||
|
||||
# 检查是否为有效的 Python 标识符
|
||||
|
||||
if not var_name.isidentifier():
|
||||
errors.append(
|
||||
f"变量名 '{var_name}' 不是有效的标识符"
|
||||
f"Variable name '{var_name}' is not a valid Python identifier"
|
||||
)
|
||||
|
||||
return errors
|
||||
@@ -176,23 +129,23 @@ class ExpressionEvaluator:
|
||||
# 便捷函数
|
||||
def evaluate_expression(
|
||||
expression: str,
|
||||
variables: dict[str, Any],
|
||||
conv_var: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
system_vars: dict[str, Any]
|
||||
) -> Any:
|
||||
"""评估表达式(便捷函数)"""
|
||||
"""Evaluate an expression (convenience function)."""
|
||||
return ExpressionEvaluator.evaluate(
|
||||
expression, variables, node_outputs, system_vars
|
||||
expression, conv_var, node_outputs, system_vars
|
||||
)
|
||||
|
||||
|
||||
def evaluate_condition(
|
||||
expression: str,
|
||||
variables: dict[str, Any],
|
||||
conv_var: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> bool:
|
||||
"""评估条件表达式(便捷函数)"""
|
||||
"""Evaluate a boolean condition expression (convenience function)."""
|
||||
return ExpressionEvaluator.evaluate_bool(
|
||||
expression, variables, node_outputs, system_vars
|
||||
expression, conv_var, node_outputs, system_vars
|
||||
)
|
||||
|
||||
@@ -14,9 +14,14 @@ from pydantic import BaseModel, Field
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SCOPE_PATTERN = re.compile(
|
||||
r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\.[a-zA-Z0-9_]+\s*}}"
|
||||
)
|
||||
|
||||
|
||||
class OutputContent(BaseModel):
|
||||
"""
|
||||
@@ -53,6 +58,12 @@ class OutputContent(BaseModel):
|
||||
)
|
||||
)
|
||||
|
||||
_SCOPE: str | None = None
|
||||
|
||||
def get_scope(self) -> str:
|
||||
self._SCOPE = SCOPE_PATTERN.findall(self.literal)[0]
|
||||
return self._SCOPE
|
||||
|
||||
def depends_on_scope(self, scope: str) -> bool:
|
||||
"""
|
||||
Check if this segment depends on a given scope.
|
||||
@@ -63,8 +74,9 @@ class OutputContent(BaseModel):
|
||||
Returns:
|
||||
bool: True if this segment references the given scope.
|
||||
"""
|
||||
pattern = rf"\{{\{{\s*{re.escape(scope)}\.[a-zA-Z0-9_]+\s*\}}\}}"
|
||||
return bool(re.search(pattern, self.literal))
|
||||
if self._SCOPE:
|
||||
return self._SCOPE == scope
|
||||
return self.get_scope() == scope
|
||||
|
||||
|
||||
class StreamOutputConfig(BaseModel):
|
||||
@@ -167,6 +179,7 @@ class GraphBuilder:
|
||||
workflow_config: dict[str, Any],
|
||||
stream: bool = False,
|
||||
subgraph: bool = False,
|
||||
variable_pool: VariablePool | None = None
|
||||
):
|
||||
self.workflow_config = workflow_config
|
||||
|
||||
@@ -180,6 +193,10 @@ class GraphBuilder:
|
||||
self._find_upstream_branch_node = lru_cache(
|
||||
maxsize=len(self.nodes) * 2
|
||||
)(self._find_upstream_branch_node)
|
||||
if variable_pool:
|
||||
self.variable_pool = variable_pool
|
||||
else:
|
||||
self.variable_pool = VariablePool()
|
||||
|
||||
self.graph = StateGraph(WorkflowState)
|
||||
self.add_nodes()
|
||||
@@ -452,9 +469,9 @@ class GraphBuilder:
|
||||
if self.stream:
|
||||
# Stream mode: create an async generator function
|
||||
# LangGraph collects all yielded values; the last yielded dictionary is merged into the state
|
||||
def make_stream_func(inst):
|
||||
def make_stream_func(inst, variable_pool=self.variable_pool):
|
||||
async def node_func(state: WorkflowState):
|
||||
async for item in inst.run_stream(state):
|
||||
async for item in inst.run_stream(state, variable_pool):
|
||||
yield item
|
||||
|
||||
return node_func
|
||||
@@ -462,9 +479,9 @@ class GraphBuilder:
|
||||
self.graph.add_node(node_id, make_stream_func(node_instance))
|
||||
else:
|
||||
# Non-stream mode: create an async function
|
||||
def make_func(inst):
|
||||
def make_func(inst, variable_pool=self.variable_pool):
|
||||
async def node_func(state: WorkflowState):
|
||||
return await inst.run(state)
|
||||
return await inst.run(state, variable_pool)
|
||||
|
||||
return node_func
|
||||
|
||||
@@ -567,27 +584,28 @@ class GraphBuilder:
|
||||
for target in branch_info["target"]:
|
||||
waiting_edges[target].append(branch_info["node"]["name"])
|
||||
|
||||
def router_fn(state: WorkflowState) -> list[Send]:
|
||||
def router_fn(state: WorkflowState, variable_pool: VariablePool = self.variable_pool) -> list[Send]:
|
||||
branch_activate = []
|
||||
new_state = state.copy()
|
||||
new_state["activate"] = dict(state.get("activate", {})) # deep copy of activate
|
||||
|
||||
node_output = variable_pool.get_node_output(src, defalut=dict(), strict=False)
|
||||
for label, branch in unique_branch.items():
|
||||
if evaluate_condition(
|
||||
if node_output and evaluate_condition(
|
||||
branch["condition"],
|
||||
state.get("variables", {}),
|
||||
state.get("runtime_vars", {}),
|
||||
{
|
||||
"execution_id": state.get("execution_id"),
|
||||
"workspace_id": state.get("workspace_id"),
|
||||
"user_id": state.get("user_id")
|
||||
}
|
||||
{},
|
||||
{src: node_output},
|
||||
{}
|
||||
):
|
||||
logger.debug(f"Conditional routing {src}: selected branch {label}")
|
||||
new_state["activate"][branch["node"]["name"]] = True
|
||||
branch_activate.append(
|
||||
Send(
|
||||
branch['node']['name'],
|
||||
new_state
|
||||
)
|
||||
)
|
||||
continue
|
||||
new_state["activate"][branch["node"]["name"]] = False
|
||||
for label, branch in unique_branch.items():
|
||||
branch_activate.append(
|
||||
Send(
|
||||
branch['node']['name'],
|
||||
|
||||
@@ -15,7 +15,6 @@ from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from app.core.workflow.nodes.tool import ToolNode
|
||||
@@ -25,7 +24,6 @@ __all__ = [
|
||||
"WorkflowState",
|
||||
"LLMNode",
|
||||
"AgentNode",
|
||||
"TransformNode",
|
||||
"IfElseNode",
|
||||
"StartNode",
|
||||
"EndNode",
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
|
||||
class AgentNodeConfig(BaseNodeConfig):
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
Agent 节点实现
|
||||
|
||||
调用已发布的 Agent 应用。
|
||||
# TODO
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -9,6 +10,8 @@ from typing import Any
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
from app.models import AppRelease
|
||||
from app.db import get_db
|
||||
@@ -30,19 +33,22 @@ class AgentNode(BaseNode):
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
def _prepare_agent(self, state: WorkflowState) -> tuple[DraftRunService, AppRelease, str]:
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {"output": VariableType.STRING}
|
||||
|
||||
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[DraftRunService, AppRelease, str]:
|
||||
"""准备 Agent(公共逻辑)
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
variable_pool: 变量池
|
||||
|
||||
Returns:
|
||||
(draft_service, release, message): 服务实例、发布配置、消息
|
||||
"""
|
||||
# 1. 渲染消息
|
||||
message_template = self.config.get("message", "")
|
||||
message = self._render_template(message_template, state)
|
||||
message = self._render_template(message_template, variable_pool)
|
||||
|
||||
# 2. 获取 Agent 配置
|
||||
agent_id = self.config.get("agent_id")
|
||||
@@ -61,16 +67,17 @@ class AgentNode(BaseNode):
|
||||
|
||||
return draft_service, release, message
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""非流式执行
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
variable_pool: 变量池
|
||||
|
||||
Returns:
|
||||
状态更新字典
|
||||
"""
|
||||
draft_service, release, message = self._prepare_agent(state)
|
||||
draft_service, release, message = self._prepare_agent(variable_pool)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
|
||||
|
||||
@@ -79,9 +86,9 @@ class AgentNode(BaseNode):
|
||||
agent_config=release.config,
|
||||
model_config=None,
|
||||
message=message,
|
||||
workspace_id=state.get("workspace_id"),
|
||||
workspace_id=variable_pool.get_value("sys.workspace_id"),
|
||||
user_id=state.get("user_id"),
|
||||
variables=state.get("variables", {})
|
||||
variables=variable_pool.get_all_conversation_vars()
|
||||
)
|
||||
|
||||
response = result.get("response", "")
|
||||
@@ -99,16 +106,17 @@ class AgentNode(BaseNode):
|
||||
}
|
||||
}
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
|
||||
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||
"""流式执行
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
variable_pool: 变量池
|
||||
|
||||
Yields:
|
||||
流式事件字典
|
||||
"""
|
||||
draft_service, release, message = self._prepare_agent(state)
|
||||
draft_service, release, message = self._prepare_agent(variable_pool)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
|
||||
|
||||
@@ -120,9 +128,9 @@ class AgentNode(BaseNode):
|
||||
agent_config=release.config,
|
||||
model_config=None,
|
||||
message=message,
|
||||
workspace_id=state.get("workspace_id"),
|
||||
workspace_id=variable_pool.get_value("sys.workspace_id"),
|
||||
user_id=state.get("user_id"),
|
||||
variables=state.get("variables", {})
|
||||
variables=variable_pool.get_all_conversation_vars()
|
||||
):
|
||||
# 提取内容
|
||||
content = chunk.get("content", "")
|
||||
|
||||
@@ -6,6 +6,7 @@ from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.enums import AssignmentOperator
|
||||
from app.core.workflow.nodes.operators import AssignmentOperatorInstance, AssignmentOperatorResolver
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -17,13 +18,17 @@ class AssignerNode(BaseNode):
|
||||
self.variable_updater = True
|
||||
self.typed_config: AssignerNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {}
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
Execute the assignment operation defined by this node.
|
||||
|
||||
Args:
|
||||
state: The current workflow state, including conversation variables,
|
||||
node outputs, and system variables.
|
||||
variable_pool: variable pool
|
||||
|
||||
Returns:
|
||||
None or the result of the assignment operation.
|
||||
@@ -31,60 +36,57 @@ class AssignerNode(BaseNode):
|
||||
# Initialize a variable pool for accessing conversation, node, and system variables
|
||||
self.typed_config = AssignerNodeConfig(**self.config)
|
||||
logger.info(f"节点 {self.node_id} 开始执行")
|
||||
pool = VariablePool(state)
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
|
||||
for assignment in self.typed_config.assignments:
|
||||
# Get the target variable selector (e.g., "conv.test")
|
||||
variable_selector = assignment.variable_selector
|
||||
if isinstance(variable_selector, str):
|
||||
# Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"]
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
expression = re.sub(pattern, r"\1", variable_selector).strip()
|
||||
variable_selector = expression.split('.')
|
||||
namespace = re.sub(pattern, r"\1", variable_selector).split('.')[0]
|
||||
|
||||
# Only conversation variables ('conv') are allowed
|
||||
if variable_selector[0] != 'conv' and variable_selector[0] not in state["cycle_nodes"]:
|
||||
raise ValueError("Only conversation or cycle variables can be assigned.")
|
||||
if namespace != 'conv' and namespace not in state["cycle_nodes"]:
|
||||
raise ValueError(f"Only conversation or cycle variables can be assigned. - {variable_selector}")
|
||||
|
||||
# Get the value or expression to assign
|
||||
value = assignment.value
|
||||
logger.debug(f"left:{variable_selector}, right: {value}")
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
|
||||
if isinstance(value, str):
|
||||
expression = re.match(pattern, value)
|
||||
if expression:
|
||||
expression = expression.group(1)
|
||||
expression = re.sub(pattern, r"\1", expression).strip()
|
||||
value = self.get_variable(expression, state)
|
||||
value = self.get_variable(expression, variable_pool, default=value, strict=False)
|
||||
|
||||
# Select the appropriate assignment operator instance based on the target variable type
|
||||
operator: AssignmentOperatorInstance = AssignmentOperatorResolver.resolve_by_value(
|
||||
pool.get(variable_selector)
|
||||
variable_pool.get_value(variable_selector)
|
||||
)(
|
||||
pool, variable_selector, value
|
||||
variable_pool, variable_selector, value
|
||||
)
|
||||
|
||||
# Execute the configured assignment operation
|
||||
match assignment.operation:
|
||||
case AssignmentOperator.COVER:
|
||||
operator.assign()
|
||||
await operator.assign()
|
||||
case AssignmentOperator.ASSIGN:
|
||||
operator.assign()
|
||||
await operator.assign()
|
||||
case AssignmentOperator.CLEAR:
|
||||
operator.clear()
|
||||
await operator.clear()
|
||||
case AssignmentOperator.ADD:
|
||||
operator.add()
|
||||
await operator.add()
|
||||
case AssignmentOperator.SUBTRACT:
|
||||
operator.subtract()
|
||||
await operator.subtract()
|
||||
case AssignmentOperator.MULTIPLY:
|
||||
operator.multiply()
|
||||
await operator.multiply()
|
||||
case AssignmentOperator.DIVIDE:
|
||||
operator.divide()
|
||||
await operator.divide()
|
||||
case AssignmentOperator.APPEND:
|
||||
operator.append()
|
||||
await operator.append()
|
||||
case AssignmentOperator.REMOVE_FIRST:
|
||||
operator.remove_first()
|
||||
await operator.remove_first()
|
||||
case AssignmentOperator.REMOVE_LAST:
|
||||
operator.remove_last()
|
||||
await operator.remove_last()
|
||||
case _:
|
||||
raise ValueError(f"Invalid Operator: {assignment.operation}")
|
||||
logger.info(f"Node {self.node_id}: execution completed")
|
||||
|
||||
@@ -3,79 +3,13 @@
|
||||
定义所有节点配置的通用字段和数据结构。
|
||||
"""
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
VARIABLE_PATTERN = r"\{\{\s*(.*?)\s*\}\}"
|
||||
|
||||
|
||||
class VariableType(StrEnum):
|
||||
"""变量类型枚举"""
|
||||
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
BOOLEAN = "boolean"
|
||||
OBJECT = "object"
|
||||
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_BOOLEAN = "array[boolean]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
|
||||
|
||||
class TypedVariable(BaseModel):
|
||||
"""
|
||||
TODO: 强类型限制
|
||||
Strongly typed variable that validates value on assignment.
|
||||
"""
|
||||
|
||||
value: Any = Field(..., description="Variable value")
|
||||
type: VariableType = Field(..., description="Declared type of the variable")
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True
|
||||
)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name == "value":
|
||||
self._validate_value(value)
|
||||
if name == "type":
|
||||
raise RuntimeError("Cannot modify variable type at runtime")
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def _validate_value(self, v: Any):
|
||||
t = self.type
|
||||
match t:
|
||||
case VariableType.STRING:
|
||||
if not isinstance(v, str):
|
||||
raise TypeError("Variable value does not match type STRING")
|
||||
case VariableType.BOOLEAN:
|
||||
if not isinstance(v, bool):
|
||||
raise TypeError("Variable value does not match type BOOLEAN")
|
||||
case VariableType.NUMBER:
|
||||
if not isinstance(v, (int, float)):
|
||||
raise TypeError("Variable value does not match type NUMBER")
|
||||
case VariableType.OBJECT:
|
||||
if not isinstance(v, dict):
|
||||
raise TypeError("Variable value does not match type OBJECT")
|
||||
case VariableType.ARRAY_STRING:
|
||||
if not isinstance(v, list) or not all(isinstance(i, str) for i in v):
|
||||
raise TypeError("Variable value does not match type ARRAY_STRING")
|
||||
case VariableType.ARRAY_NUMBER:
|
||||
if not isinstance(v, list) or not all(isinstance(i, (int, float)) for i in v):
|
||||
raise TypeError("Variable value does not match type ARRAY_NUMBER")
|
||||
case VariableType.ARRAY_BOOLEAN:
|
||||
if not isinstance(v, list) or not all(isinstance(i, bool) for i in v):
|
||||
raise TypeError("Variable value does not match type ARRAY_BOOLEAN")
|
||||
case VariableType.ARRAY_OBJECT:
|
||||
if not isinstance(v, list) or not all(isinstance(i, dict) for i in v):
|
||||
raise TypeError("Variable value does not match type ARRAY_OBJECT")
|
||||
case _:
|
||||
raise TypeError(f"Unknown variable type: {t}")
|
||||
|
||||
|
||||
class VariableDefinition(BaseModel):
|
||||
"""变量定义
|
||||
|
||||
|
||||
@@ -1,12 +1,7 @@
|
||||
"""
|
||||
工作流节点基类
|
||||
|
||||
定义节点的基本接口和通用功能。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from langgraph.config import get_stream_writer
|
||||
@@ -14,7 +9,9 @@ from typing_extensions import TypedDict, Annotated
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.workflow.nodes.enums import BRANCH_NODES
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.services.multimodal_service import PROVIDER_STRATEGIES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -42,22 +39,10 @@ class WorkflowState(TypedDict):
|
||||
cycle_nodes: list
|
||||
looping: Annotated[int, merge_looping_state]
|
||||
|
||||
# Input variables (passed from configured variables)
|
||||
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
|
||||
variables: Annotated[dict[str, Any], lambda x, y: {
|
||||
**x,
|
||||
**{k: {**x.get(k, {}), **v} if isinstance(v, dict) and isinstance(x.get(k), dict) else v
|
||||
for k, v in y.items()}
|
||||
}]
|
||||
|
||||
# Node outputs (stores execution results of each node for variable references)
|
||||
# Uses a custom merge function to combine new node outputs into the existing dictionary
|
||||
node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
# Runtime node variables (simplified version, stores business data for fast access between nodes)
|
||||
# Format: {node_id: business_result}
|
||||
runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
# Execution context
|
||||
execution_id: str
|
||||
workspace_id: str
|
||||
@@ -72,17 +57,17 @@ class WorkflowState(TypedDict):
|
||||
|
||||
|
||||
class BaseNode(ABC):
|
||||
"""节点基类
|
||||
|
||||
所有节点类型都应该继承此基类,实现 execute 方法。
|
||||
"""Base class for workflow nodes.
|
||||
|
||||
All node types should inherit from this class and implement the `execute` method.
|
||||
"""
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
"""初始化节点
|
||||
|
||||
"""Initialize the node.
|
||||
|
||||
Args:
|
||||
node_config: 节点配置
|
||||
workflow_config: 工作流配置
|
||||
node_config: Configuration of the node.
|
||||
workflow_config: Configuration of the workflow.
|
||||
"""
|
||||
self.node_config = node_config
|
||||
self.workflow_config = workflow_config
|
||||
@@ -94,7 +79,27 @@ class BaseNode(ABC):
|
||||
self.config = node_config.get("config") or {}
|
||||
self.error_handling = node_config.get("error_handling") or {}
|
||||
|
||||
self.variable_updater = False
|
||||
self.variable_change_able = False
|
||||
|
||||
@cached_property
|
||||
def output_types(self) -> dict[str, VariableType]:
|
||||
"""Returns the output variable types of the node.
|
||||
|
||||
This property is cached to avoid recomputation.
|
||||
"""
|
||||
return self._output_types()
|
||||
|
||||
@abstractmethod
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
"""Defines output variable types for the node.
|
||||
|
||||
Subclasses must override this method to declare the variables
|
||||
produced by the node and their corresponding types.
|
||||
|
||||
Returns:
|
||||
A mapping from output variable names to ``VariableType``.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def check_activate(self, state: WorkflowState):
|
||||
"""Check if the current node is activated in the workflow state.
|
||||
@@ -136,92 +141,84 @@ class BaseNode(ABC):
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
"""执行节点业务逻辑(非流式)
|
||||
|
||||
节点只需要返回业务结果,不需要关心输出格式、时间统计等。
|
||||
BaseNode 会自动包装成标准格式。
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""Executes the node business logic (non-streaming).
|
||||
|
||||
The node implementation should only return the business result.
|
||||
It does not need to handle output formatting, timing, or statistics.
|
||||
The ``BaseNode`` will automatically wrap the result into a standard
|
||||
response format.
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
state: The current workflow state.
|
||||
variable_pool: The variable pool used for reading and writing variables.
|
||||
|
||||
Returns:
|
||||
业务结果(任意类型)
|
||||
|
||||
Examples:
|
||||
>>> # LLM 节点
|
||||
>>> "这是 AI 的回复"
|
||||
|
||||
>>> # Transform 节点
|
||||
>>> {"processed_data": [...]}
|
||||
|
||||
>>> # Start/End 节点
|
||||
>>> {"message": "开始", "conversation_id": "xxx"}
|
||||
The business result produced by the node. The return value can be
|
||||
of any type.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
|
||||
"""执行节点业务逻辑(流式)
|
||||
|
||||
子类可以重写此方法以支持流式输出。
|
||||
默认实现:执行非流式方法并一次性返回。
|
||||
|
||||
节点需要:
|
||||
1. yield 中间结果(如文本片段)
|
||||
2. 最后 yield 一个特殊的完成标记:{"__final__": True, "result": final_result}
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Yields:
|
||||
业务数据(chunk)或完成标记
|
||||
|
||||
Examples:
|
||||
# 流式 LLM 节点
|
||||
full_response = ""
|
||||
async for chunk in llm.astream(prompt):
|
||||
full_response += chunk
|
||||
yield chunk # yield 文本片段
|
||||
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||
"""Executes the node business logic in streaming mode.
|
||||
|
||||
# 最后 yield 完成标记
|
||||
yield {"__final__": True, "result": AIMessage(content=full_response)}
|
||||
Subclasses may override this method to support streaming output.
|
||||
The default implementation executes the non-streaming method and
|
||||
yields a single final result.
|
||||
|
||||
For streaming execution, a node implementation should:
|
||||
1. Yield intermediate results (e.g. text chunks).
|
||||
2. Yield a final completion marker in the following format:
|
||||
``{"__final__": True, "result": final_result}``.
|
||||
|
||||
Args:
|
||||
state: The current workflow state.
|
||||
variable_pool: The variable pool used for reading and writing variables.
|
||||
|
||||
Yields:
|
||||
Business data chunks or a final completion marker.
|
||||
"""
|
||||
result = await self.execute(state)
|
||||
# 默认实现:直接 yield 完成标记
|
||||
result = await self.execute(state, variable_pool)
|
||||
# Default implementation: yield a single final completion marker.
|
||||
yield {"__final__": True, "result": result}
|
||||
|
||||
def supports_streaming(self) -> bool:
|
||||
"""节点是否支持流式输出
|
||||
|
||||
"""Returns whether the node supports streaming output.
|
||||
|
||||
A node is considered to support streaming if its class overrides
|
||||
the ``execute_stream`` method. If the default implementation from
|
||||
``BaseNode`` is used, streaming is not supported.
|
||||
|
||||
Returns:
|
||||
是否支持流式输出
|
||||
True if the node supports streaming output, False otherwise.
|
||||
"""
|
||||
# 检查子类是否重写了 execute_stream 方法
|
||||
# Check whether the subclass overrides the execute_stream method.
|
||||
return self.__class__.execute_stream is not BaseNode.execute_stream
|
||||
|
||||
def get_timeout(self) -> int:
|
||||
"""获取超时时间(秒)
|
||||
|
||||
@staticmethod
|
||||
def get_timeout() -> int:
|
||||
"""Returns the execution timeout in seconds.
|
||||
|
||||
Returns:
|
||||
超时时间
|
||||
The timeout duration, in seconds.
|
||||
"""
|
||||
return settings.WORKFLOW_NODE_TIMEOUT
|
||||
# return self.error_handling.get("timeout", 60)
|
||||
|
||||
async def run(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""执行节点(带错误处理和输出包装,非流式)
|
||||
|
||||
这个方法由 Executor 调用,负责:
|
||||
1. 时间统计
|
||||
2. 调用节点的 execute() 方法
|
||||
3. 将业务结果包装成标准输出格式
|
||||
4. 错误处理
|
||||
|
||||
async def run(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""Runs the node with error handling and output wrapping (non-streaming).
|
||||
|
||||
This method is invoked by the Executor and is responsible for:
|
||||
1. Execution time measurement.
|
||||
2. Invoking the node's ``execute()`` method.
|
||||
3. Wrapping the business result into a standardized output format.
|
||||
4. Handling execution errors.
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
state: The current workflow state.
|
||||
variable_pool: The variable pool used for reading and writing variables.
|
||||
|
||||
Returns:
|
||||
标准化的状态更新字典
|
||||
A standardized state update dictionary.
|
||||
"""
|
||||
if not self.check_activate(state):
|
||||
return self.trans_activate(state)
|
||||
@@ -233,70 +230,78 @@ class BaseNode(ABC):
|
||||
timeout = self.get_timeout()
|
||||
|
||||
try:
|
||||
# 调用节点的业务逻辑
|
||||
# Invoke the node business logic.
|
||||
business_result = await asyncio.wait_for(
|
||||
self.execute(state),
|
||||
self.execute(state, variable_pool),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 提取处理后的输出(调用子类的 _extract_output)
|
||||
# Extract processed outputs using subclass-defined logic.
|
||||
extracted_output = self._extract_output(business_result)
|
||||
|
||||
# 包装成标准输出格式
|
||||
wrapped_output = self._wrap_output(business_result, elapsed_time, state)
|
||||
# Wrap the business result into the standard output format.
|
||||
wrapped_output = self._wrap_output(business_result, elapsed_time, state, variable_pool)
|
||||
|
||||
# 将提取后的输出存储到运行时变量中(供后续节点快速访问)
|
||||
# 如果提取后的输出是字典,拆包存储;否则存储为 output 字段
|
||||
if isinstance(extracted_output, dict):
|
||||
runtime_var = extracted_output
|
||||
else:
|
||||
runtime_var = {"output": extracted_output}
|
||||
# Store extracted outputs as runtime variables for downstream nodes.
|
||||
if extracted_output is not None:
|
||||
runtime_vars = extracted_output
|
||||
if not isinstance(extracted_output, dict):
|
||||
runtime_vars = {"output": extracted_output}
|
||||
for k, v in runtime_vars.items():
|
||||
await variable_pool.new(self.node_id, k, v, self.output_types[k], mut=self.variable_change_able)
|
||||
|
||||
# 返回包装后的输出和运行时变量
|
||||
# Return the wrapped output along with activation state updates.
|
||||
return {
|
||||
**wrapped_output,
|
||||
"messages": state["messages"],
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
},
|
||||
"looping": state["looping"]
|
||||
} | self.trans_activate(state)
|
||||
|
||||
except TimeoutError:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
|
||||
return self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
|
||||
logger.error(
|
||||
f"Node {self.node_id} execution timed out ({timeout} seconds)."
|
||||
)
|
||||
return self._wrap_error(
|
||||
f"Node execution timed out ({timeout} seconds).",
|
||||
elapsed_time,
|
||||
state,
|
||||
variable_pool,
|
||||
)
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
|
||||
return self._wrap_error(str(e), elapsed_time, state)
|
||||
logger.error(
|
||||
f"Node {self.node_id} execution failed: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
return self._wrap_error(str(e), elapsed_time, state, variable_pool)
|
||||
|
||||
async def run_stream(
|
||||
self, state: WorkflowState,
|
||||
variable_pool: VariablePool
|
||||
) -> AsyncGenerator[dict[str, Any], Any]:
|
||||
"""Executes the node with error handling and output wrapping (streaming).
|
||||
|
||||
async def run_stream(self, state: WorkflowState) -> AsyncGenerator[dict[str, Any], Any]:
|
||||
"""Execute node with error handling and output wrapping (streaming)
|
||||
|
||||
This method is called by the Executor and is responsible for:
|
||||
1. Time tracking
|
||||
2. Calling the node's execute_stream() method
|
||||
3. Using LangGraph's stream writer to send chunks
|
||||
4. Updating streaming buffer in state for downstream nodes
|
||||
5. Wrapping business data into standard output format
|
||||
6. Error handling
|
||||
|
||||
Special handling for End nodes:
|
||||
- End nodes don't send chunks via writer (prefix and LLM content already sent)
|
||||
- End nodes only yield suffix for final result assembly
|
||||
|
||||
1. Tracking execution time.
|
||||
2. Calling the node's ``execute_stream()`` method.
|
||||
3. Sending streaming chunks via LangGraph's stream writer.
|
||||
4. Updating activation-related state for downstream nodes.
|
||||
5. Wrapping business data into a standardized output format.
|
||||
6. Handling execution errors.
|
||||
|
||||
Args:
|
||||
state: Workflow state
|
||||
|
||||
state: The current workflow state.
|
||||
variable_pool: The variable pool used for reading and writing variables.
|
||||
|
||||
Yields:
|
||||
State updates with streaming buffer and final result
|
||||
Incremental state updates, including activation state changes and
|
||||
the final wrapped result.
|
||||
"""
|
||||
if not self.check_activate(state):
|
||||
yield self.trans_activate(state)
|
||||
logger.info(f"jump node: {self.node_id}")
|
||||
logger.debug(f"jump node: {self.node_id}")
|
||||
return
|
||||
|
||||
import time
|
||||
@@ -317,7 +322,7 @@ class BaseNode(ABC):
|
||||
# Stream chunks in real-time
|
||||
loop_start = asyncio.get_event_loop().time()
|
||||
|
||||
async for item in self.execute_stream(state):
|
||||
async for item in self.execute_stream(state, variable_pool):
|
||||
# Check timeout
|
||||
if asyncio.get_event_loop().time() - loop_start > timeout:
|
||||
raise TimeoutError()
|
||||
@@ -332,7 +337,7 @@ class BaseNode(ABC):
|
||||
chunks.append(content)
|
||||
|
||||
# Send chunks for all nodes (including End nodes for suffix)
|
||||
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {content[:50]}...")
|
||||
logger.debug(f"Node {self.node_id} sent chunk #{chunk_count}: {content[:50]}...")
|
||||
|
||||
# 1. Send via stream writer (for real-time client updates)
|
||||
writer({
|
||||
@@ -344,27 +349,26 @@ class BaseNode(ABC):
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
|
||||
logger.info(f"Node {self.node_id} streaming execution finished, "
|
||||
f"time elapsed: {elapsed_time:.2f}s, chunks: {chunk_count}")
|
||||
|
||||
# Extract processed output (call subclass's _extract_output)
|
||||
extracted_output = self._extract_output(final_result)
|
||||
|
||||
# Wrap final result
|
||||
final_output = self._wrap_output(final_result, elapsed_time, state)
|
||||
final_output = self._wrap_output(final_result, elapsed_time, state, variable_pool)
|
||||
|
||||
# Store extracted output in runtime variables (for quick access by subsequent nodes)
|
||||
if isinstance(extracted_output, dict):
|
||||
runtime_var = extracted_output
|
||||
else:
|
||||
runtime_var = {"output": extracted_output}
|
||||
if extracted_output is not None:
|
||||
runtime_vars = extracted_output
|
||||
if not isinstance(extracted_output, dict):
|
||||
runtime_vars = {"output": extracted_output}
|
||||
for k, v in runtime_vars.items():
|
||||
await variable_pool.new(self.node_id, k, v, self.output_types[k], mut=self.variable_change_able)
|
||||
|
||||
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
|
||||
state_update = {
|
||||
**final_output,
|
||||
"messages": state["messages"],
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
},
|
||||
"looping": state["looping"]
|
||||
}
|
||||
|
||||
@@ -374,41 +378,49 @@ class BaseNode(ABC):
|
||||
|
||||
except TimeoutError:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"节点 {self.node_id} 执行超时 ({timeout}s)")
|
||||
error_output = self._wrap_error(f"节点执行超时 ({timeout}s)", elapsed_time, state)
|
||||
logger.error(f"Node {self.node_id} execution timed out ({timeout}s)")
|
||||
error_output = self._wrap_error(
|
||||
f"Node execution timed out ({timeout}s)",
|
||||
elapsed_time,
|
||||
state,
|
||||
variable_pool
|
||||
)
|
||||
yield error_output
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
|
||||
error_output = self._wrap_error(str(e), elapsed_time, state)
|
||||
logger.error(f"Node {self.node_id} execution failed: {e}", exc_info=True)
|
||||
error_output = self._wrap_error(str(e), elapsed_time, state, variable_pool)
|
||||
yield error_output
|
||||
|
||||
def _wrap_output(
|
||||
self,
|
||||
business_result: Any,
|
||||
elapsed_time: float,
|
||||
state: WorkflowState
|
||||
state: WorkflowState,
|
||||
variable_pool: VariablePool
|
||||
) -> dict[str, Any]:
|
||||
"""将业务结果包装成标准输出格式
|
||||
|
||||
Args:
|
||||
business_result: 节点返回的业务结果
|
||||
elapsed_time: 执行耗时
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
标准化的状态更新字典
|
||||
"""
|
||||
# 提取输入数据(用于记录)
|
||||
input_data = self._extract_input(state)
|
||||
"""Wraps the business result into a standardized node output format.
|
||||
|
||||
# 提取 token 使用情况(如果有)
|
||||
Args:
|
||||
business_result: The result returned by the node's business logic.
|
||||
elapsed_time: Time elapsed during node execution (in seconds).
|
||||
state: The current workflow state.
|
||||
variable_pool: The variable pool used for reading and writing variables.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the standardized state update for this node,
|
||||
including node outputs, input, output, elapsed time, token usage, and status.
|
||||
"""
|
||||
# Extract input data (for logging or audit purposes)
|
||||
input_data = self._extract_input(state, variable_pool)
|
||||
|
||||
# Extract token usage information (if applicable)
|
||||
token_usage = self._extract_token_usage(business_result)
|
||||
|
||||
# 提取实际输出(去除元数据)
|
||||
# Extract actual output (strip any metadata)
|
||||
output = self._extract_output(business_result)
|
||||
|
||||
# 构建标准节点输出
|
||||
# Construct standardized node output
|
||||
node_output = {
|
||||
"node_id": self.node_id,
|
||||
"node_type": self.node_type,
|
||||
@@ -423,8 +435,6 @@ class BaseNode(ABC):
|
||||
final_output = {
|
||||
"node_outputs": {self.node_id: node_output},
|
||||
}
|
||||
if self.variable_updater:
|
||||
final_output = final_output | {"variables": state["variables"]}
|
||||
|
||||
return final_output
|
||||
|
||||
@@ -432,25 +442,33 @@ class BaseNode(ABC):
|
||||
self,
|
||||
error_message: str,
|
||||
elapsed_time: float,
|
||||
state: WorkflowState
|
||||
state: WorkflowState,
|
||||
variable_pool: VariablePool
|
||||
) -> dict[str, Any]:
|
||||
"""将错误包装成标准输出格式
|
||||
|
||||
"""Wraps an error into a standardized node output format.
|
||||
|
||||
This method handles both cases:
|
||||
- If an error edge is defined, the workflow can continue to the error handling node.
|
||||
- If no error edge exists, the workflow is stopped by raising an exception.
|
||||
|
||||
Args:
|
||||
error_message: 错误信息
|
||||
elapsed_time: 执行耗时
|
||||
state: 工作流状态
|
||||
|
||||
error_message: The error message describing the failure.
|
||||
elapsed_time: Time elapsed during node execution (in seconds).
|
||||
state: The current workflow state.
|
||||
variable_pool: The variable pool used for reading and writing variables.
|
||||
|
||||
Returns:
|
||||
标准化的状态更新字典
|
||||
A dictionary representing the standardized state update for this node
|
||||
when an error edge exists. If no error edge exists, this method
|
||||
raises an exception to stop the workflow.
|
||||
"""
|
||||
# 查找错误边
|
||||
# Check if the node has an error edge defined
|
||||
error_edge = self._find_error_edge()
|
||||
|
||||
# 提取输入数据
|
||||
input_data = self._extract_input(state)
|
||||
# Extract input data (for logging or audit purposes)
|
||||
input_data = self._extract_input(state, variable_pool)
|
||||
|
||||
# 构建错误输出
|
||||
# Construct the standardized node output for the error
|
||||
node_output = {
|
||||
"node_id": self.node_id,
|
||||
"node_type": self.node_type,
|
||||
@@ -464,9 +482,9 @@ class BaseNode(ABC):
|
||||
}
|
||||
|
||||
if error_edge:
|
||||
# 有错误边:记录错误并继续
|
||||
# If an error edge exists, log a warning and continue to error node
|
||||
logger.warning(
|
||||
f"节点 {self.node_id} 执行失败,跳转到错误处理节点: {error_edge['target']}"
|
||||
f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
|
||||
)
|
||||
return {
|
||||
"node_outputs": {
|
||||
@@ -476,198 +494,188 @@ class BaseNode(ABC):
|
||||
"error_node": self.node_id
|
||||
}
|
||||
else:
|
||||
# If no error edge, send the error via stream writer and stop the workflow
|
||||
writer = get_stream_writer()
|
||||
writer({
|
||||
"type": "node_error",
|
||||
**node_output
|
||||
})
|
||||
# 无错误边:抛出异常停止工作流
|
||||
logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}")
|
||||
raise Exception(f"节点 {self.node_id} 执行失败: {error_message}")
|
||||
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
|
||||
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""Extracts the input data for this node (used for logging or audit).
|
||||
|
||||
Subclasses may override this method to customize what input data
|
||||
should be recorded.
|
||||
|
||||
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""提取节点输入数据(用于记录)
|
||||
|
||||
子类可以重写此方法来自定义输入记录。
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
state: The current workflow state.
|
||||
variable_pool: The variable pool used for reading and writing variables.
|
||||
|
||||
Returns:
|
||||
输入数据字典
|
||||
A dictionary containing the node's input data.
|
||||
"""
|
||||
# 默认返回配置
|
||||
# Default implementation returns the node configuration
|
||||
return {"config": self.config}
|
||||
|
||||
def _extract_output(self, business_result: Any) -> Any:
|
||||
"""从业务结果中提取实际输出
|
||||
|
||||
子类可以重写此方法来自定义输出提取。
|
||||
|
||||
"""Extracts the actual output from the business result.
|
||||
|
||||
Subclasses may override this method to customize how the node's
|
||||
output is extracted.
|
||||
|
||||
Args:
|
||||
business_result: 业务结果
|
||||
|
||||
business_result: The result returned by the node's business logic.
|
||||
|
||||
Returns:
|
||||
实际输出
|
||||
The actual output extracted from the business result.
|
||||
"""
|
||||
# 默认直接返回业务结果
|
||||
# Default implementation returns the business result directly
|
||||
return business_result
|
||||
|
||||
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||
"""从业务结果中提取 token 使用情况
|
||||
|
||||
子类可以重写此方法来提取 token 信息。
|
||||
|
||||
"""Extracts token usage information from the business result.
|
||||
|
||||
Subclasses may override this method to extract token usage statistics
|
||||
(e.g., for LLM nodes).
|
||||
|
||||
Args:
|
||||
business_result: 业务结果
|
||||
|
||||
business_result: The result returned by the node's business logic.
|
||||
|
||||
Returns:
|
||||
token 使用情况或 None
|
||||
A dictionary mapping token types to counts, or None if not applicable.
|
||||
"""
|
||||
# 默认返回 None
|
||||
# Default implementation returns None
|
||||
return None
|
||||
|
||||
def _find_error_edge(self) -> dict[str, Any] | None:
|
||||
"""查找错误边
|
||||
|
||||
"""Finds the error edge for this node, if any.
|
||||
|
||||
An error edge is used to redirect workflow execution when this node
|
||||
fails.
|
||||
|
||||
Returns:
|
||||
错误边配置或 None
|
||||
A dictionary representing the error edge configuration if it exists,
|
||||
or None if no error edge is defined.
|
||||
"""
|
||||
for edge in self.workflow_config.get("edges", []):
|
||||
if edge.get("source") == self.node_id and edge.get("type") == "error":
|
||||
return edge
|
||||
return None
|
||||
|
||||
def _render_template(self, template: str, state: WorkflowState | None, strict: bool = True) -> str:
|
||||
"""渲染模板
|
||||
|
||||
支持的变量命名空间:
|
||||
- sys.xxx: 系统变量(message, execution_id, workspace_id, user_id, conversation_id)
|
||||
- conv.xxx: 会话变量(跨多轮对话保持)
|
||||
- node_id.xxx: 节点输出
|
||||
|
||||
@staticmethod
|
||||
def _render_template(template: str, variable_pool: VariablePool, strict: bool = True) -> str:
|
||||
"""Renders a template string using the provided variable pool.
|
||||
|
||||
Supported variable namespaces:
|
||||
- sys.xxx: System variables (e.g., message, execution_id, workspace_id,
|
||||
user_id, conversation_id)
|
||||
- conv.xxx: Conversation variables (persist across multiple turns)
|
||||
- node_id.xxx: Node outputs
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
state: 工作流状态
|
||||
|
||||
template: The template string to render.
|
||||
variable_pool: The variable pool containing system, conversation, and
|
||||
node variables.
|
||||
strict: If True, missing variables will raise an error; if False,
|
||||
missing variables are ignored.
|
||||
|
||||
Returns:
|
||||
渲染后的字符串
|
||||
The rendered string with all variables substituted.
|
||||
"""
|
||||
from app.core.workflow.template_renderer import render_template
|
||||
|
||||
# 处理 state 为 None 的情况
|
||||
if state is None:
|
||||
state = {}
|
||||
|
||||
# 使用变量池获取变量
|
||||
pool = VariablePool(state)
|
||||
|
||||
# 构建完整的 variables 结构
|
||||
variables = {
|
||||
"sys": pool.get_all_system_vars(),
|
||||
"conv": pool.get_all_conversation_vars()
|
||||
}
|
||||
|
||||
return render_template(
|
||||
template=template,
|
||||
variables=variables,
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars(),
|
||||
conv_vars=variable_pool.get_all_conversation_vars(),
|
||||
node_outputs=variable_pool.get_all_node_outputs(),
|
||||
system_vars=variable_pool.get_all_system_vars(),
|
||||
strict=strict
|
||||
)
|
||||
|
||||
def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool:
|
||||
"""评估条件表达式
|
||||
|
||||
支持的变量命名空间:
|
||||
- sys.xxx: 系统变量
|
||||
- conv.xxx: 会话变量
|
||||
- node_id.xxx: 节点输出
|
||||
|
||||
@staticmethod
|
||||
def _evaluate_condition(expression: str, variable_pool: VariablePool) -> bool:
|
||||
"""Evaluates a conditional expression using the provided variable pool.
|
||||
|
||||
Supported variable namespaces:
|
||||
- sys.xxx: System variables
|
||||
- conv.xxx: Conversation variables
|
||||
- node_id.xxx: Node outputs
|
||||
|
||||
Args:
|
||||
expression: 条件表达式
|
||||
state: 工作流状态
|
||||
|
||||
expression: The conditional expression to evaluate.
|
||||
variable_pool: The variable pool containing system, conversation, and
|
||||
node variables.
|
||||
|
||||
Returns:
|
||||
布尔值结果
|
||||
The boolean result of evaluating the expression.
|
||||
"""
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
|
||||
# 处理 state 为 None 的情况
|
||||
if state is None:
|
||||
state = {}
|
||||
|
||||
# 使用变量池获取变量
|
||||
pool = VariablePool(state)
|
||||
|
||||
# 构建完整的 variables 结构(包含 sys 和 conv)
|
||||
variables = {
|
||||
"sys": pool.get_all_system_vars(),
|
||||
"conv": pool.get_all_conversation_vars()
|
||||
}
|
||||
|
||||
return evaluate_condition(
|
||||
expression=expression,
|
||||
variables=variables,
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars()
|
||||
conv_var=variable_pool.get_all_conversation_vars(),
|
||||
node_outputs=variable_pool.get_all_node_outputs(),
|
||||
system_vars=variable_pool.get_all_system_vars()
|
||||
)
|
||||
|
||||
def get_variable_pool(self, state: WorkflowState) -> VariablePool:
|
||||
"""获取变量池实例
|
||||
|
||||
VariablePool 是轻量级包装器,只持有 state 的引用,创建成本极低。
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
VariablePool 实例
|
||||
|
||||
Examples:
|
||||
>>> pool = self.get_variable_pool(state)
|
||||
>>> message = pool.get("sys.message")
|
||||
>>> llm_output = pool.get("llm_qa.output")
|
||||
"""
|
||||
return VariablePool(state)
|
||||
|
||||
@staticmethod
|
||||
def get_variable(
|
||||
self,
|
||||
selector: list[str] | str,
|
||||
state: WorkflowState,
|
||||
default: Any = None
|
||||
selector: str,
|
||||
variable_pool: VariablePool,
|
||||
default: Any = None,
|
||||
strict: bool = True
|
||||
) -> Any:
|
||||
"""获取变量值(便捷方法)
|
||||
|
||||
Args:
|
||||
selector: 变量选择器
|
||||
state: 工作流状态
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
变量值
|
||||
|
||||
Examples:
|
||||
>>> message = self.get_variable("sys.message", state)
|
||||
>>> output = self.get_variable(["llm_qa", "output"], state)
|
||||
>>> custom = self.get_variable("var.custom", state, default="默认值")
|
||||
"""
|
||||
pool = VariablePool(state)
|
||||
return pool.get(selector, default=default)
|
||||
"""Retrieves a variable value from the variable pool (convenience method).
|
||||
|
||||
def has_variable(self, selector: list[str] | str, state: WorkflowState) -> bool:
|
||||
"""检查变量是否存在(便捷方法)
|
||||
|
||||
Args:
|
||||
selector: 变量选择器
|
||||
state: 工作流状态
|
||||
|
||||
selector: The variable selector (can be namespaced, e.g., sys.xxx, conv.xxx, node_id.xxx).
|
||||
variable_pool: The variable pool from which to fetch the value.
|
||||
default: The default value to return if the variable does not exist.
|
||||
strict: If True, raise an error when the variable is missing; if False, return the default.
|
||||
|
||||
Returns:
|
||||
变量是否存在
|
||||
|
||||
Examples:
|
||||
>>> if self.has_variable("llm_qa.output", state):
|
||||
... output = self.get_variable("llm_qa.output", state)
|
||||
The value of the selected variable, or the default if not found and strict is False.
|
||||
"""
|
||||
pool = VariablePool(state)
|
||||
return pool.has(selector)
|
||||
return variable_pool.get_value(selector, default, strict=strict)
|
||||
|
||||
@staticmethod
|
||||
def has_variable(selector: str, variable_pool: VariablePool) -> bool:
|
||||
"""Checks whether a variable exists in the variable pool (convenience method).
|
||||
|
||||
Args:
|
||||
selector: The variable selector (can be namespaced, e.g., sys.xxx, conv.xxx, node_id.xxx).
|
||||
variable_pool: The variable pool to check.
|
||||
|
||||
Returns:
|
||||
True if the variable exists in the pool, False otherwise.
|
||||
"""
|
||||
return variable_pool.has(selector)
|
||||
|
||||
@staticmethod
|
||||
async def process_message(provider, content, enable_file=False) -> dict | str | None:
|
||||
if isinstance(content, str):
|
||||
if enable_file:
|
||||
return {"text": content}
|
||||
return content
|
||||
elif isinstance(content, dict):
|
||||
trans_tool = PROVIDER_STRATEGIES[provider]()
|
||||
result = await trans_tool.format_image(content["url"])
|
||||
return result
|
||||
raise TypeError('Unexpect input value type')
|
||||
|
||||
@staticmethod
|
||||
def process_model_output(content) -> str:
|
||||
result = ""
|
||||
if isinstance(content, list):
|
||||
for msg in content:
|
||||
if isinstance(msg, dict):
|
||||
result += msg.get("text")
|
||||
elif isinstance(msg, str):
|
||||
result += msg
|
||||
elif isinstance(content, dict):
|
||||
result = content.get("text")
|
||||
elif isinstance(content, str):
|
||||
return content
|
||||
return result
|
||||
|
||||
@@ -2,6 +2,8 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -14,15 +16,19 @@ class BreakNode(BaseNode):
|
||||
to False, signaling the outer loop runtime to terminate further iterations.
|
||||
"""
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {}
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
Execute the break node.
|
||||
|
||||
Args:
|
||||
state: Current workflow state, including loop control flags.
|
||||
variable_pool: Pool of variables for the workflow.
|
||||
|
||||
Effects:
|
||||
- Sets 'looping' in the state to False to stop the loop.
|
||||
- Sets 'looping' in the state too False to stop the loop.
|
||||
- Logs the action for debugging purposes.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from typing import Literal
|
||||
from pydantic import Field, BaseModel
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
|
||||
class InputVariable(BaseModel):
|
||||
@@ -44,7 +45,7 @@ class CodeNodeConfig(BaseNodeConfig):
|
||||
description="code content"
|
||||
)
|
||||
|
||||
language: Literal['python3', 'nodejs'] = Field(
|
||||
language: Literal['python3', 'javascript'] = Field(
|
||||
...,
|
||||
description="language"
|
||||
)
|
||||
|
||||
@@ -2,15 +2,17 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import urllib.parse
|
||||
from string import Template
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
|
||||
import urllib.parse
|
||||
import httpx
|
||||
|
||||
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.base_config import VariableType
|
||||
from app.core.workflow.nodes.code.config import CodeNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -52,6 +54,12 @@ class CodeNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: CodeNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
output_dict = {}
|
||||
for output in self.typed_config.output_variables:
|
||||
output_dict[output.name] = output.type
|
||||
return output_dict
|
||||
|
||||
def extract_result(self, content: str):
|
||||
match = re.search(r'<<RESULT>>(.*?)<<RESULT>>', content, re.DOTALL)
|
||||
if match:
|
||||
@@ -92,15 +100,16 @@ class CodeNode(BaseNode):
|
||||
else:
|
||||
raise RuntimeError("The output of main must be a dictionary")
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
self.typed_config = CodeNodeConfig(**self.config)
|
||||
input_variable_dict = {}
|
||||
for input_variable in self.typed_config.input_variables:
|
||||
input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, state)
|
||||
input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, variable_pool)
|
||||
|
||||
code = base64.b64decode(
|
||||
self.typed_config.code
|
||||
).decode("utf-8")
|
||||
code = urllib.parse.unquote(code, encoding='utf-8')
|
||||
|
||||
input_variable_dict = base64.b64encode(
|
||||
json.dumps(input_variable_dict).encode("utf-8")
|
||||
@@ -110,7 +119,7 @@ class CodeNode(BaseNode):
|
||||
code=code,
|
||||
inputs_variable=input_variable_dict,
|
||||
)
|
||||
elif self.typed_config.language == 'nodejs':
|
||||
elif self.typed_config.language == 'javascript':
|
||||
final_script = NODEJS_SCRIPT_TEMPLATE.substitute(
|
||||
code=code,
|
||||
inputs_variable=input_variable_dict,
|
||||
|
||||
@@ -8,7 +8,6 @@ from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
||||
from app.core.workflow.nodes.base_config import (
|
||||
BaseNodeConfig,
|
||||
VariableDefinition,
|
||||
VariableType,
|
||||
)
|
||||
from app.core.workflow.nodes.code.config import CodeNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
|
||||
@@ -23,21 +22,18 @@ from app.core.workflow.nodes.parameter_extractor.config import ParameterExtracto
|
||||
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
|
||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||
|
||||
__all__ = [
|
||||
# 基础类
|
||||
"BaseNodeConfig",
|
||||
"VariableDefinition",
|
||||
"VariableType",
|
||||
# 节点配置
|
||||
"StartNodeConfig",
|
||||
"EndNodeConfig",
|
||||
"LLMNodeConfig",
|
||||
"MessageConfig",
|
||||
"AgentNodeConfig",
|
||||
"TransformNodeConfig",
|
||||
"IfElseNodeConfig",
|
||||
"KnowledgeRetrievalNodeConfig",
|
||||
"AssignerNodeConfig",
|
||||
|
||||
@@ -2,7 +2,8 @@ from typing import Any
|
||||
|
||||
from pydantic import Field, BaseModel, field_validator
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
|
||||
|
||||
|
||||
@@ -127,4 +128,9 @@ class IterationNodeConfig(BaseNodeConfig):
|
||||
description="Output of the loop iteration"
|
||||
)
|
||||
|
||||
output_type: VariableType = Field(
|
||||
default=None,
|
||||
description="Data type of the loop iteration output"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -28,6 +29,8 @@ class IterationRuntime:
|
||||
node_id: str,
|
||||
config: dict[str, Any],
|
||||
state: WorkflowState,
|
||||
variable_pool: VariablePool,
|
||||
child_variable_pool: VariablePool,
|
||||
):
|
||||
"""
|
||||
Initialize the iteration runtime.
|
||||
@@ -44,11 +47,13 @@ class IterationRuntime:
|
||||
self.node_id = node_id
|
||||
self.typed_config = IterationNodeConfig(**config)
|
||||
self.looping = True
|
||||
self.variable_pool = variable_pool
|
||||
self.child_variable_pool = child_variable_pool
|
||||
|
||||
self.output_value = None
|
||||
self.result: list = []
|
||||
|
||||
def _init_iteration_state(self, item, idx):
|
||||
async def _init_iteration_state(self, item, idx):
|
||||
"""
|
||||
Initialize a per-iteration copy of the workflow state.
|
||||
|
||||
@@ -62,10 +67,9 @@ class IterationRuntime:
|
||||
loopstate = WorkflowState(
|
||||
**self.state
|
||||
)
|
||||
loopstate["runtime_vars"][self.node_id] = {
|
||||
"item": item,
|
||||
"index": idx,
|
||||
}
|
||||
self.child_variable_pool.copy(self.variable_pool)
|
||||
await self.child_variable_pool.new(self.node_id, "item", item, VariableType.type_map(item), mut=True)
|
||||
await self.child_variable_pool.new(self.node_id, "index", item, VariableType.type_map(item), mut=True)
|
||||
loopstate["node_outputs"][self.node_id] = {
|
||||
"item": item,
|
||||
"index": idx,
|
||||
@@ -74,6 +78,11 @@ class IterationRuntime:
|
||||
loopstate["activate"][self.start_id] = True
|
||||
return loopstate
|
||||
|
||||
def merge_conv_vars(self):
|
||||
self.variable_pool.get_all_conversation_vars().update(
|
||||
self.child_variable_pool.get_all_conversation_vars()
|
||||
)
|
||||
|
||||
async def run_task(self, item, idx):
|
||||
"""
|
||||
Execute a single iteration asynchronously.
|
||||
@@ -82,8 +91,8 @@ class IterationRuntime:
|
||||
item: The input element for this iteration.
|
||||
idx: The index of this iteration.
|
||||
"""
|
||||
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
|
||||
output = VariablePool(result).get(self.output_value)
|
||||
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
|
||||
output = self.child_variable_pool.get_value(self.output_value)
|
||||
if isinstance(output, list) and self.typed_config.flatten:
|
||||
self.result.extend(output)
|
||||
else:
|
||||
@@ -125,7 +134,7 @@ class IterationRuntime:
|
||||
input_expression = re.sub(pattern, r"\1", self.typed_config.input).strip()
|
||||
self.output_value = re.sub(pattern, r"\1", self.typed_config.output).strip()
|
||||
|
||||
array_obj = VariablePool(self.state).get(input_expression)
|
||||
array_obj = self.variable_pool.get_value(input_expression)
|
||||
if not isinstance(array_obj, list):
|
||||
raise RuntimeError("Cannot iterate over a non-list variable")
|
||||
child_state = []
|
||||
@@ -137,14 +146,16 @@ class IterationRuntime:
|
||||
logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}")
|
||||
idx += self.typed_config.parallel_count
|
||||
child_state.extend(await asyncio.gather(*tasks))
|
||||
self.merge_conv_vars()
|
||||
else:
|
||||
# Execute iterations sequentially
|
||||
while idx < len(array_obj) and self.looping:
|
||||
logger.info(f"Iteration node {self.node_id}: running")
|
||||
item = array_obj[idx]
|
||||
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
|
||||
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
|
||||
child_state.append(result)
|
||||
output = VariablePool(result).get(self.output_value)
|
||||
output = self.child_variable_pool.get_value(self.output_value)
|
||||
self.merge_conv_vars()
|
||||
if isinstance(output, list) and self.typed_config.flatten:
|
||||
self.result.extend(output)
|
||||
else:
|
||||
|
||||
@@ -31,6 +31,8 @@ class LoopRuntime:
|
||||
node_id: str,
|
||||
config: dict[str, Any],
|
||||
state: WorkflowState,
|
||||
variable_pool: VariablePool,
|
||||
child_variable_pool: VariablePool
|
||||
):
|
||||
"""
|
||||
Initialize the loop runtime executor.
|
||||
@@ -40,6 +42,8 @@ class LoopRuntime:
|
||||
node_id: The unique identifier of the loop node in the workflow.
|
||||
config: Raw configuration dictionary for the loop node.
|
||||
state: The current workflow state before entering the loop.
|
||||
variable_pool: A VariablePool instance for accessing and modifying workflow variables.
|
||||
child_variable_pool: A VariablePool instance for managing child node outputs.
|
||||
"""
|
||||
self.start_id = start_id
|
||||
self.graph = graph
|
||||
@@ -47,8 +51,10 @@ class LoopRuntime:
|
||||
self.node_id = node_id
|
||||
self.typed_config = LoopNodeConfig(**config)
|
||||
self.looping = True
|
||||
self.variable_pool = variable_pool
|
||||
self.child_variable_pool = child_variable_pool
|
||||
|
||||
def _init_loop_state(self):
|
||||
async def _init_loop_state(self):
|
||||
"""
|
||||
Initialize workflow state for loop execution.
|
||||
|
||||
@@ -62,33 +68,35 @@ class LoopRuntime:
|
||||
Returns:
|
||||
WorkflowState: A prepared workflow state used for loop execution.
|
||||
"""
|
||||
pool = VariablePool(self.state)
|
||||
# 循环变量
|
||||
self.state["runtime_vars"][self.node_id] = {
|
||||
variable.name: evaluate_expression(
|
||||
expression=variable.value,
|
||||
variables=pool.get_all_conversation_vars(),
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars(),
|
||||
)
|
||||
if variable.input_type == ValueInputType.VARIABLE
|
||||
else TypeTransformer.transform(variable.value, variable.type)
|
||||
for variable in self.typed_config.cycle_vars
|
||||
}
|
||||
self.state["node_outputs"][self.node_id] = {
|
||||
variable.name: evaluate_expression(
|
||||
expression=variable.value,
|
||||
variables=pool.get_all_conversation_vars(),
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars(),
|
||||
)
|
||||
if variable.input_type == ValueInputType.VARIABLE
|
||||
else TypeTransformer.transform(variable.value, variable.type)
|
||||
for variable in self.typed_config.cycle_vars
|
||||
}
|
||||
self.child_variable_pool.copy(self.variable_pool)
|
||||
|
||||
for variable in self.typed_config.cycle_vars:
|
||||
if variable.input_type == ValueInputType.VARIABLE:
|
||||
value = evaluate_expression(
|
||||
expression=variable.value,
|
||||
conv_var=self.variable_pool.get_all_conversation_vars(),
|
||||
node_outputs=self.variable_pool.get_all_node_outputs(),
|
||||
system_vars=self.variable_pool.get_all_system_vars(),
|
||||
)
|
||||
else:
|
||||
value = TypeTransformer.transform(variable.value, variable.type)
|
||||
await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True)
|
||||
loopstate = WorkflowState(
|
||||
**self.state
|
||||
)
|
||||
loopstate["node_outputs"][self.node_id] = {
|
||||
variable.name: evaluate_expression(
|
||||
expression=variable.value,
|
||||
conv_var=self.variable_pool.get_all_conversation_vars(),
|
||||
node_outputs=self.variable_pool.get_all_node_outputs(),
|
||||
system_vars=self.variable_pool.get_all_system_vars(),
|
||||
)
|
||||
if variable.input_type == ValueInputType.VARIABLE
|
||||
else TypeTransformer.transform(variable.value, variable.type)
|
||||
for variable in self.typed_config.cycle_vars
|
||||
}
|
||||
|
||||
loopstate["looping"] = 1
|
||||
loopstate["activate"][self.start_id] = True
|
||||
return loopstate
|
||||
@@ -134,7 +142,12 @@ class LoopRuntime:
|
||||
case _:
|
||||
raise ValueError(f"Invalid condition: {operator}")
|
||||
|
||||
def evaluate_conditional(self, state) -> bool:
|
||||
def merge_conv_vars(self):
|
||||
self.variable_pool.variables["conv"].update(
|
||||
self.child_variable_pool.variables.get("conv", {})
|
||||
)
|
||||
|
||||
def evaluate_conditional(self) -> bool:
|
||||
"""
|
||||
Evaluate the loop continuation condition at runtime.
|
||||
|
||||
@@ -143,18 +156,15 @@ class LoopRuntime:
|
||||
- Evaluates each comparison expression immediately
|
||||
- Combines results using the configured logical operator (AND / OR)
|
||||
|
||||
Args:
|
||||
state: The current workflow state during loop execution.
|
||||
|
||||
Returns:
|
||||
bool: True if the loop should continue, False otherwise.
|
||||
"""
|
||||
conditions = []
|
||||
|
||||
for expression in self.typed_config.condition.expressions:
|
||||
left_value = VariablePool(state).get(expression.left)
|
||||
left_value = self.child_variable_pool.get_value(expression.left)
|
||||
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
|
||||
VariablePool(state),
|
||||
self.child_variable_pool,
|
||||
expression.left,
|
||||
expression.right,
|
||||
expression.input_type
|
||||
@@ -177,16 +187,18 @@ class LoopRuntime:
|
||||
Returns:
|
||||
dict[str, Any]: The final runtime variables of this loop node.
|
||||
"""
|
||||
loopstate = self._init_loop_state()
|
||||
loopstate = await self._init_loop_state()
|
||||
loop_time = self.typed_config.max_loop
|
||||
child_state = []
|
||||
while self.evaluate_conditional(loopstate) and self.looping and loop_time > 0:
|
||||
while not self.evaluate_conditional() and self.looping and loop_time > 0:
|
||||
logger.info(f"loop node {self.node_id}: running")
|
||||
result = await self.graph.ainvoke(loopstate)
|
||||
child_state.append(result)
|
||||
|
||||
self.merge_conv_vars()
|
||||
if result["looping"] == 2:
|
||||
self.looping = False
|
||||
loop_time -= 1
|
||||
|
||||
logger.info(f"loop node {self.node_id}: execution completed")
|
||||
return loopstate["runtime_vars"][self.node_id] | {"__child_state": child_state}
|
||||
return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state}
|
||||
|
||||
@@ -6,9 +6,12 @@ from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime
|
||||
from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,9 +38,41 @@ class CycleGraphNode(BaseNode):
|
||||
self.start_node_id = None # ID of the start node within the cycle
|
||||
|
||||
self.graph: StateGraph | CompiledStateGraph | None = None
|
||||
self.child_variable_pool: VariablePool | None = None
|
||||
self.build_graph()
|
||||
self.iteration_flag = True
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
outputs = {"__child_state": VariableType.ARRAY_OBJECT}
|
||||
if self.node_type == NodeType.LOOP:
|
||||
# Loop node outputs the final state of the loop
|
||||
config = LoopNodeConfig(**self.config)
|
||||
for var_def in config.cycle_vars:
|
||||
outputs[var_def.name] = var_def.type
|
||||
return outputs
|
||||
elif self.node_type == NodeType.ITERATION:
|
||||
# Iteration node outputs the processed collection
|
||||
config = IterationNodeConfig(**self.config)
|
||||
if not config.output_type:
|
||||
outputs['output'] = VariableType.ANY
|
||||
return outputs
|
||||
if config.output_type in [
|
||||
VariableType.ARRAY_FILE,
|
||||
VariableType.ARRAY_STRING,
|
||||
VariableType.NUMBER,
|
||||
VariableType.ARRAY_OBJECT,
|
||||
VariableType.BOOLEAN
|
||||
]:
|
||||
if config.flatten:
|
||||
outputs['output'] = config.output_type
|
||||
else:
|
||||
outputs['output'] = VariableType.ARRAY_STRING
|
||||
else:
|
||||
outputs['output'] = VariableType(f"array[{config.output_type}]")
|
||||
return outputs
|
||||
else:
|
||||
raise KeyError(f"Valid Cycle Node Type - {self.node_type}")
|
||||
|
||||
def pure_cycle_graph(self) -> tuple[list, list]:
|
||||
"""
|
||||
Extract cycle-scoped nodes and internal edges from the workflow configuration.
|
||||
@@ -103,17 +138,20 @@ class CycleGraphNode(BaseNode):
|
||||
"""
|
||||
from app.core.workflow.graph_builder import GraphBuilder
|
||||
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
||||
self.child_variable_pool = VariablePool()
|
||||
builder = GraphBuilder(
|
||||
{
|
||||
"nodes": self.cycle_nodes,
|
||||
"edges": self.cycle_edges,
|
||||
},
|
||||
subgraph=True
|
||||
subgraph=True,
|
||||
variable_pool=self.child_variable_pool
|
||||
)
|
||||
self.start_node_id = builder.start_node_id
|
||||
self.graph = builder.build()
|
||||
self.child_variable_pool = builder.variable_pool
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
Execute the cycle node at runtime.
|
||||
|
||||
@@ -123,6 +161,7 @@ class CycleGraphNode(BaseNode):
|
||||
|
||||
Args:
|
||||
state: The current workflow state when entering the cycle node.
|
||||
variable_pool: Variable Pool
|
||||
|
||||
Returns:
|
||||
Any: The runtime result produced by the loop or iteration executor.
|
||||
@@ -137,6 +176,8 @@ class CycleGraphNode(BaseNode):
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
state=state,
|
||||
variable_pool=variable_pool,
|
||||
child_variable_pool=self.child_variable_pool,
|
||||
).run()
|
||||
if self.node_type == NodeType.ITERATION:
|
||||
return await IterationRuntime(
|
||||
@@ -145,5 +186,7 @@ class CycleGraphNode(BaseNode):
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
state=state,
|
||||
variable_pool=variable_pool,
|
||||
child_variable_pool=self.child_variable_pool
|
||||
).run()
|
||||
raise RuntimeError("Unknown cycle node type")
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
|
||||
class EndNodeConfig(BaseNodeConfig):
|
||||
|
||||
@@ -7,6 +7,8 @@ End 节点实现
|
||||
import logging
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -17,12 +19,18 @@ class EndNode(BaseNode):
|
||||
工作流的结束节点,根据配置的模板输出最终结果。
|
||||
支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。
|
||||
"""
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
"""声明此节点的输出类型"""
|
||||
return {
|
||||
"output": VariableType.STRING
|
||||
}
|
||||
|
||||
async def execute(self, state: WorkflowState) -> str:
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> str:
|
||||
"""执行 end 节点业务逻辑
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
variable_pool: 变量池
|
||||
|
||||
Returns:
|
||||
最终输出字符串
|
||||
@@ -34,7 +42,7 @@ class EndNode(BaseNode):
|
||||
|
||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||
if output_template:
|
||||
output = self._render_template(output_template, state, strict=False)
|
||||
output = self._render_template(output_template, variable_pool, strict=False)
|
||||
else:
|
||||
output = ""
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ class NodeType(StrEnum):
|
||||
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
|
||||
IF_ELSE = "if-else"
|
||||
CODE = "code"
|
||||
TRANSFORM = "transform"
|
||||
QUESTION_CLASSIFIER = "question-classifier"
|
||||
HTTP_REQUEST = "http-request"
|
||||
TOOL = "tool"
|
||||
|
||||
@@ -10,6 +10,8 @@ from httpx import AsyncClient, Response, Timeout
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
|
||||
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
@@ -34,6 +36,14 @@ class HttpRequestNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: HttpRequestNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
"body": VariableType.STRING,
|
||||
"status_code": VariableType.NUMBER,
|
||||
"headers": VariableType.OBJECT,
|
||||
"output": VariableType.STRING
|
||||
}
|
||||
|
||||
def _build_timeout(self) -> Timeout:
|
||||
"""
|
||||
Build httpx Timeout configuration.
|
||||
@@ -50,7 +60,7 @@ class HttpRequestNode(BaseNode):
|
||||
)
|
||||
return timeout
|
||||
|
||||
def _build_auth(self, state: WorkflowState) -> dict[str, str]:
|
||||
def _build_auth(self, variable_pool: VariablePool) -> dict[str, str]:
|
||||
"""
|
||||
Build authentication-related HTTP headers.
|
||||
|
||||
@@ -58,12 +68,12 @@ class HttpRequestNode(BaseNode):
|
||||
the current workflow runtime state.
|
||||
|
||||
Args:
|
||||
state: Current workflow runtime state.
|
||||
variable_pool: Variable Pool
|
||||
|
||||
Returns:
|
||||
A dictionary of HTTP headers used for authentication.
|
||||
"""
|
||||
api_key = self._render_template(self.typed_config.auth.api_key, state)
|
||||
api_key = self._render_template(self.typed_config.auth.api_key, variable_pool)
|
||||
match self.typed_config.auth.auth_type:
|
||||
case HttpAuthType.NONE:
|
||||
return {}
|
||||
@@ -82,7 +92,7 @@ class HttpRequestNode(BaseNode):
|
||||
case _:
|
||||
raise RuntimeError(f"Auth type not supported: {self.typed_config.auth.auth_type}")
|
||||
|
||||
def _build_header(self, state: WorkflowState) -> dict[str, str]:
|
||||
def _build_header(self, variable_pool: VariablePool) -> dict[str, str]:
|
||||
"""
|
||||
Build HTTP request headers.
|
||||
|
||||
@@ -90,10 +100,10 @@ class HttpRequestNode(BaseNode):
|
||||
"""
|
||||
headers = {}
|
||||
for key, value in self.typed_config.headers.items():
|
||||
headers[self._render_template(key, state)] = self._render_template(value, state)
|
||||
headers[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool)
|
||||
return headers
|
||||
|
||||
def _build_params(self, state: WorkflowState) -> dict[str, str]:
|
||||
def _build_params(self, variable_pool: VariablePool) -> dict[str, str]:
|
||||
"""
|
||||
Build URL query parameters.
|
||||
|
||||
@@ -101,10 +111,10 @@ class HttpRequestNode(BaseNode):
|
||||
"""
|
||||
params = {}
|
||||
for key, value in self.typed_config.params.items():
|
||||
params[self._render_template(key, state)] = self._render_template(value, state)
|
||||
params[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool)
|
||||
return params
|
||||
|
||||
def _build_content(self, state) -> dict[str, Any]:
|
||||
def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""
|
||||
Build HTTP request body arguments for httpx request methods.
|
||||
|
||||
@@ -120,13 +130,13 @@ class HttpRequestNode(BaseNode):
|
||||
return {}
|
||||
case HttpContentType.JSON:
|
||||
content["json"] = json.loads(self._render_template(
|
||||
self.typed_config.body.data, state
|
||||
self.typed_config.body.data, variable_pool
|
||||
))
|
||||
case HttpContentType.FROM_DATA:
|
||||
data = {}
|
||||
for item in self.typed_config.body.data:
|
||||
if item.type == "text":
|
||||
data[self._render_template(item.key, state)] = self._render_template(item.value, state)
|
||||
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value, variable_pool)
|
||||
elif item.type == "file":
|
||||
# TODO: File support (Feature)
|
||||
pass
|
||||
@@ -136,11 +146,11 @@ class HttpRequestNode(BaseNode):
|
||||
pass
|
||||
case HttpContentType.WWW_FORM:
|
||||
content["data"] = json.loads(self._render_template(
|
||||
json.dumps(self.typed_config.body.data), state
|
||||
json.dumps(self.typed_config.body.data), variable_pool
|
||||
))
|
||||
|
||||
case HttpContentType.RAW:
|
||||
content["content"] = self._render_template(self.typed_config.body.data, state)
|
||||
content["content"] = self._render_template(self.typed_config.body.data, variable_pool)
|
||||
case _:
|
||||
raise RuntimeError(f"Content type not supported: {self.typed_config.body.content_type}")
|
||||
return content
|
||||
@@ -165,7 +175,7 @@ class HttpRequestNode(BaseNode):
|
||||
case _:
|
||||
raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}")
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict | str:
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict | str:
|
||||
"""
|
||||
Execute the HTTP request node.
|
||||
|
||||
@@ -176,6 +186,7 @@ class HttpRequestNode(BaseNode):
|
||||
|
||||
Args:
|
||||
state: Current workflow runtime state.
|
||||
variable_pool: Variable Pool
|
||||
|
||||
Returns:
|
||||
- dict: Serialized HttpRequestNodeOutput on success
|
||||
@@ -185,8 +196,8 @@ class HttpRequestNode(BaseNode):
|
||||
async with httpx.AsyncClient(
|
||||
verify=self.typed_config.verify_ssl,
|
||||
timeout=self._build_timeout(),
|
||||
headers=self._build_header(state) | self._build_auth(state),
|
||||
params=self._build_params(state),
|
||||
headers=self._build_header(variable_pool) | self._build_auth(variable_pool),
|
||||
params=self._build_params(variable_pool),
|
||||
follow_redirects=True
|
||||
) as client:
|
||||
retries = self.typed_config.retry.max_attempts
|
||||
@@ -194,8 +205,8 @@ class HttpRequestNode(BaseNode):
|
||||
try:
|
||||
request_func = self._get_client_method(client)
|
||||
resp = await request_func(
|
||||
url=self._render_template(self.typed_config.url, state),
|
||||
**self._build_content(state)
|
||||
url=self._render_template(self.typed_config.url, variable_pool),
|
||||
**self._build_content(variable_pool)
|
||||
)
|
||||
resp.raise_for_status()
|
||||
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
||||
|
||||
@@ -6,6 +6,8 @@ from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -15,6 +17,11 @@ class IfElseNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: IfElseNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
"output": VariableType.STRING
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
|
||||
match operator:
|
||||
@@ -45,7 +52,7 @@ class IfElseNode(BaseNode):
|
||||
case _:
|
||||
raise ValueError(f"Invalid condition: {operator}")
|
||||
|
||||
def evaluate_conditional_edge_expressions(self, state) -> list[bool]:
|
||||
def evaluate_conditional_edge_expressions(self, variable_pool: VariablePool) -> list[bool]:
|
||||
"""
|
||||
Build conditional edge expressions for the If-Else node.
|
||||
|
||||
@@ -72,11 +79,11 @@ class IfElseNode(BaseNode):
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
left_string = re.sub(pattern, r"\1", expression.left).strip()
|
||||
try:
|
||||
left_value = self.get_variable(left_string, state)
|
||||
left_value = self.get_variable(left_string, variable_pool)
|
||||
except KeyError:
|
||||
left_value = None
|
||||
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
|
||||
self.get_variable_pool(state),
|
||||
variable_pool,
|
||||
expression.left,
|
||||
expression.right,
|
||||
expression.input_type
|
||||
@@ -95,7 +102,7 @@ class IfElseNode(BaseNode):
|
||||
|
||||
return conditions
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
Execute the conditional branching logic of the node.
|
||||
|
||||
@@ -105,13 +112,13 @@ class IfElseNode(BaseNode):
|
||||
|
||||
Args:
|
||||
state (WorkflowState): The current workflow state, containing variables, messages, node outputs, etc.
|
||||
variable_pool: Variable Pool
|
||||
|
||||
Returns:
|
||||
str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions.
|
||||
"""
|
||||
self.typed_config = IfElseNodeConfig(**self.config)
|
||||
expressions = self.evaluate_conditional_edge_expressions(state)
|
||||
# TODO: 变量类型及文本类型解析
|
||||
expressions = self.evaluate_conditional_edge_expressions(variable_pool)
|
||||
for i in range(len(expressions)):
|
||||
if expressions[i]:
|
||||
logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}")
|
||||
|
||||
@@ -5,6 +5,8 @@ from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig
|
||||
from app.core.workflow.template_renderer import TemplateRenderer
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -14,7 +16,12 @@ class JinjaRenderNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: JinjaRenderNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
"output": VariableType.STRING
|
||||
}
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
Execute the node: render the Jinja2 template with mapped variables.
|
||||
|
||||
@@ -24,6 +31,7 @@ class JinjaRenderNode(BaseNode):
|
||||
Args:
|
||||
state (WorkflowState): Current workflow state containing variables,
|
||||
node outputs, and runtime variables.
|
||||
variable_pool: Variable Pool
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: Node output dictionary containing the rendered result
|
||||
@@ -40,7 +48,7 @@ class JinjaRenderNode(BaseNode):
|
||||
context = {}
|
||||
for variable in self.typed_config.mapping:
|
||||
try:
|
||||
context[variable.name] = self.get_variable(variable.value, state)
|
||||
context[variable.name] = self.get_variable(variable.value, variable_pool)
|
||||
except Exception:
|
||||
logger.info(f"variable not found, var: {variable.value}")
|
||||
continue
|
||||
|
||||
@@ -8,6 +8,8 @@ from app.core.models import RedBearRerank, RedBearModelConfig
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_read
|
||||
from app.models import knowledge_model, knowledgeshare_model, ModelType
|
||||
from app.repositories import knowledge_repository, knowledgeshare_repository
|
||||
@@ -22,6 +24,11 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
"output": VariableType.ARRAY_STRING
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
|
||||
"""
|
||||
@@ -149,7 +156,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
)
|
||||
return reranker
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
Execute the knowledge retrieval workflow node.
|
||||
|
||||
@@ -163,6 +170,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
|
||||
Args:
|
||||
state (WorkflowState): Current workflow execution state.
|
||||
variable_pool: Variable Pool
|
||||
|
||||
Returns:
|
||||
Any: List of retrieved knowledge chunks (dict format).
|
||||
@@ -171,7 +179,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
RuntimeError: If no valid knowledge base is found or access is denied.
|
||||
"""
|
||||
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
|
||||
query = self._render_template(self.typed_config.query, state)
|
||||
query = self._render_template(self.typed_config.query, variable_pool)
|
||||
with get_db_read() as db:
|
||||
knowledge_bases = self.typed_config.knowledge_bases
|
||||
existing_ids = self._get_existing_kb_ids(db, [kb.kb_id for kb in knowledge_bases])
|
||||
|
||||
@@ -4,7 +4,8 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
|
||||
class MessageConfig(BaseModel):
|
||||
@@ -70,6 +71,16 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
description="对话上下文窗口"
|
||||
)
|
||||
|
||||
vision: bool = Field(
|
||||
default=False,
|
||||
description="是否启用视觉模型"
|
||||
)
|
||||
|
||||
vision_input: str = Field(
|
||||
default=None,
|
||||
description="视觉输入"
|
||||
)
|
||||
|
||||
# 简单模式
|
||||
prompt: str | None = Field(
|
||||
default=None,
|
||||
|
||||
@@ -15,6 +15,8 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_context
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
@@ -66,59 +68,34 @@ class LLMNode(BaseNode):
|
||||
- ai/assistant: AI 消息(AIMessage)
|
||||
"""
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {"output": VariableType.STRING}
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: LLMNodeConfig | None = None
|
||||
|
||||
def _render_context(self, message, state):
|
||||
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
|
||||
def _render_context(self, message: str, variable_pool: VariablePool):
|
||||
context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
|
||||
return re.sub(r"{{context}}", context, message)
|
||||
|
||||
def _prepare_llm(self, state: WorkflowState, stream: bool = False) -> tuple[RedBearLLM, list | str]:
|
||||
async def _prepare_llm(
|
||||
self,
|
||||
state: WorkflowState,
|
||||
variable_pool: VariablePool,
|
||||
stream: bool = False
|
||||
) -> RedBearLLM:
|
||||
"""准备 LLM 实例(公共逻辑)
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
variable_pool: 变量池
|
||||
|
||||
Returns:
|
||||
(llm, messages_or_prompt): LLM 实例和消息列表或 prompt 字符串
|
||||
"""
|
||||
|
||||
# 1. 处理消息格式(优先使用 messages)
|
||||
self.typed_config = LLMNodeConfig(**self.config)
|
||||
messages_config = self.typed_config.messages
|
||||
|
||||
if messages_config:
|
||||
# 使用 LangChain 消息格式
|
||||
messages = []
|
||||
for msg_config in messages_config:
|
||||
role = msg_config.role.lower()
|
||||
content_template = msg_config.content
|
||||
content_template = self._render_context(content_template, state)
|
||||
content = self._render_template(content_template, state)
|
||||
|
||||
# 根据角色创建对应的消息对象
|
||||
if role == "system":
|
||||
messages.append({"role": "system", "content": content})
|
||||
elif role in ["user", "human"]:
|
||||
messages.append({"role": "user", "content": content})
|
||||
elif role in ["ai", "assistant"]:
|
||||
messages.append({"role": "assistant", "content": content})
|
||||
else:
|
||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||
messages.append({"role": "user", "content": content})
|
||||
|
||||
if self.typed_config.memory.enable:
|
||||
# if self.typed_config.memory.enable_window:
|
||||
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
|
||||
prompt_or_messages = messages
|
||||
else:
|
||||
# 使用简单的 prompt 格式(向后兼容)
|
||||
prompt_template = self.config.get("prompt", "")
|
||||
prompt_or_messages = self._render_template(prompt_template, state)
|
||||
|
||||
# 2. 获取模型配置
|
||||
model_id = self.config.get("model_id")
|
||||
model_id = self.typed_config.model_id
|
||||
if not model_id:
|
||||
raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置")
|
||||
|
||||
@@ -157,27 +134,82 @@ class LLMNode(BaseNode):
|
||||
|
||||
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
|
||||
|
||||
return llm, prompt_or_messages
|
||||
messages_config = self.typed_config.messages
|
||||
|
||||
async def execute(self, state: WorkflowState) -> AIMessage:
|
||||
if messages_config:
|
||||
# 使用 LangChain 消息格式
|
||||
messages = []
|
||||
for msg_config in messages_config:
|
||||
role = msg_config.role.lower()
|
||||
content_template = msg_config.content
|
||||
content_template = self._render_context(content_template, variable_pool)
|
||||
content = self._render_template(content_template, variable_pool)
|
||||
|
||||
# 根据角色创建对应的消息对象
|
||||
if role == "system":
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": content
|
||||
})
|
||||
elif role in ["user", "human"]:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": content
|
||||
})
|
||||
elif role in ["ai", "assistant"]:
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": content
|
||||
})
|
||||
else:
|
||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": content
|
||||
})
|
||||
|
||||
if self.typed_config.vision_input and self.typed_config.vision:
|
||||
file_content = []
|
||||
files = variable_pool.get_value(self.typed_config.vision_input)
|
||||
for file in files:
|
||||
content = await self.process_message(provider, file, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.append(content)
|
||||
if messages and messages[-1]["role"] == 'user':
|
||||
messages[-1]['content'] = [messages[-1]["content"]] + file_content
|
||||
else:
|
||||
messages.append({"role": "user", "content": file_content})
|
||||
|
||||
if self.typed_config.memory.enable:
|
||||
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
|
||||
self.messages = messages
|
||||
else:
|
||||
# 使用简单的 prompt 格式(向后兼容)
|
||||
prompt_template = self.config.get("prompt", "")
|
||||
self.messages = self._render_template(prompt_template, variable_pool)
|
||||
|
||||
return llm
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage:
|
||||
"""非流式执行 LLM 调用
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
variable_pool: 变量池
|
||||
|
||||
Returns:
|
||||
LLM 响应消息
|
||||
"""
|
||||
# self.typed_config = LLMNodeConfig(**self.config)
|
||||
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||
llm = await self._prepare_llm(state, variable_pool, False)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||
|
||||
# 调用 LLM(支持字符串或消息列表)
|
||||
response = await llm.ainvoke(prompt_or_messages)
|
||||
response = await llm.ainvoke(self.messages)
|
||||
# 提取内容
|
||||
if hasattr(response, 'content'):
|
||||
content = response.content
|
||||
content = self.process_model_output(response.content)
|
||||
else:
|
||||
content = str(response)
|
||||
|
||||
@@ -186,16 +218,15 @@ class LLMNode(BaseNode):
|
||||
# 返回 AIMessage(包含响应元数据)
|
||||
return response if isinstance(response, AIMessage) else AIMessage(content=content)
|
||||
|
||||
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""提取输入数据(用于记录)"""
|
||||
_, prompt_or_messages = self._prepare_llm(state)
|
||||
|
||||
return {
|
||||
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
|
||||
"prompt": self.messages if isinstance(self.messages, str) else None,
|
||||
"messages": [
|
||||
{"role": msg.get("role"), "content": msg.get("content", "")}
|
||||
for msg in prompt_or_messages
|
||||
] if isinstance(prompt_or_messages, list) else None,
|
||||
for msg in self.messages
|
||||
] if isinstance(self.messages, list) else None,
|
||||
"config": {
|
||||
"model_id": self.config.get("model_id"),
|
||||
"temperature": self.config.get("temperature"),
|
||||
@@ -215,24 +246,25 @@ class LLMNode(BaseNode):
|
||||
usage = business_result.response_metadata.get('token_usage')
|
||||
if usage:
|
||||
return {
|
||||
"prompt_tokens": usage.get('prompt_tokens', 0),
|
||||
"completion_tokens": usage.get('completion_tokens', 0),
|
||||
"prompt_tokens": usage.get('input_tokens', 0),
|
||||
"completion_tokens": usage.get('output_tokens', 0),
|
||||
"total_tokens": usage.get('total_tokens', 0)
|
||||
}
|
||||
return None
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
|
||||
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||
"""流式执行 LLM 调用
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
variable_pool: 变量池
|
||||
|
||||
Yields:
|
||||
文本片段(chunk)或完成标记
|
||||
"""
|
||||
self.typed_config = LLMNodeConfig(**self.config)
|
||||
|
||||
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||
llm = await self._prepare_llm(state, variable_pool, True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||
@@ -243,10 +275,10 @@ class LLMNode(BaseNode):
|
||||
|
||||
# 调用 LLM(流式,支持字符串或消息列表)
|
||||
last_meta_data = {}
|
||||
async for chunk in llm.astream(prompt_or_messages, stream_usage=True):
|
||||
async for chunk in llm.astream(self.messages, stream_usage=True):
|
||||
# 提取内容
|
||||
if hasattr(chunk, 'content'):
|
||||
content = chunk.content
|
||||
content = self.process_model_output(chunk.content)
|
||||
else:
|
||||
content = str(chunk)
|
||||
if hasattr(chunk, 'response_metadata'):
|
||||
|
||||
@@ -3,6 +3,8 @@ from typing import Any
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_read
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.tasks import write_message_task
|
||||
@@ -13,17 +15,23 @@ class MemoryReadNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: MemoryReadNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
"answer": VariableType.STRING,
|
||||
"intermediate_outputs": VariableType.ARRAY_OBJECT
|
||||
}
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
self.typed_config = MemoryReadNodeConfig(**self.config)
|
||||
with get_db_read() as db:
|
||||
end_user_id = self.get_variable("sys.user_id", state)
|
||||
end_user_id = self.get_variable("sys.user_id", variable_pool)
|
||||
|
||||
if not end_user_id:
|
||||
raise RuntimeError("End user id is required")
|
||||
|
||||
return await MemoryAgentService().read_memory(
|
||||
end_user_id=end_user_id,
|
||||
message=self._render_template(self.typed_config.message, state),
|
||||
message=self._render_template(self.typed_config.message, variable_pool),
|
||||
config_id=self.typed_config.config_id,
|
||||
search_switch=self.typed_config.search_switch,
|
||||
history=[],
|
||||
@@ -38,16 +46,19 @@ class MemoryWriteNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: MemoryWriteNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {"output": VariableType.STRING}
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
||||
end_user_id = self.get_variable("sys.user_id", state)
|
||||
end_user_id = self.get_variable("sys.user_id", variable_pool)
|
||||
|
||||
if not end_user_id:
|
||||
raise RuntimeError("End user id is required")
|
||||
|
||||
write_message_task.delay(
|
||||
end_user_id,
|
||||
self._render_template(self.typed_config.message, state),
|
||||
self._render_template(self.typed_config.message, variable_pool),
|
||||
str(self.typed_config.config_id),
|
||||
"neo4j",
|
||||
""
|
||||
|
||||
@@ -22,7 +22,6 @@ from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.memory import MemoryReadNode, MemoryWriteNode
|
||||
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from app.core.workflow.nodes.breaker import BreakNode
|
||||
@@ -37,7 +36,6 @@ WorkflowNode = Union[
|
||||
LLMNode,
|
||||
IfElseNode,
|
||||
AgentNode,
|
||||
TransformNode,
|
||||
AssignerNode,
|
||||
HttpRequestNode,
|
||||
KnowledgeRetrievalNode,
|
||||
@@ -67,7 +65,6 @@ class NodeFactory:
|
||||
NodeType.END: EndNode,
|
||||
NodeType.LLM: LLMNode,
|
||||
NodeType.AGENT: AgentNode,
|
||||
NodeType.TRANSFORM: TransformNode,
|
||||
NodeType.IF_ELSE: IfElseNode,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
||||
NodeType.ASSIGNER: AssignerNode,
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC
|
||||
from typing import Union, Type, NoReturn
|
||||
from typing import Union, Type, NoReturn, Any
|
||||
|
||||
from app.core.workflow.nodes.base_config import VariableType
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.nodes.enums import ValueInputType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
@@ -69,7 +69,7 @@ class TypeTransformer:
|
||||
|
||||
|
||||
class OperatorBase(ABC):
|
||||
def __init__(self, pool: VariablePool, left_selector, right):
|
||||
def __init__(self, pool: VariablePool, left_selector: str, right: Any):
|
||||
self.pool = pool
|
||||
self.left_selector = left_selector
|
||||
self.right = right
|
||||
@@ -77,7 +77,7 @@ class OperatorBase(ABC):
|
||||
self.type_limit: type[str, int, dict, list] = None
|
||||
|
||||
def check(self, no_right=False):
|
||||
left = self.pool.get(self.left_selector)
|
||||
left = self.pool.get_value(self.left_selector)
|
||||
if not isinstance(left, self.type_limit):
|
||||
raise TypeError(f"The variable to be operated on must be of {self.type_limit} type")
|
||||
|
||||
@@ -92,13 +92,13 @@ class StringOperator(OperatorBase):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = str
|
||||
|
||||
def assign(self) -> None:
|
||||
async def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
await self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
async def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, '')
|
||||
await self.pool.set(self.left_selector, '')
|
||||
|
||||
|
||||
class NumberOperator(OperatorBase):
|
||||
@@ -106,33 +106,33 @@ class NumberOperator(OperatorBase):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = (float, int)
|
||||
|
||||
def assign(self) -> None:
|
||||
async def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
await self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
async def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, 0)
|
||||
await self.pool.set(self.left_selector, 0)
|
||||
|
||||
def add(self) -> None:
|
||||
async def add(self) -> None:
|
||||
self.check()
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin + self.right)
|
||||
origin = self.pool.get_value(self.left_selector)
|
||||
await self.pool.set(self.left_selector, origin + self.right)
|
||||
|
||||
def subtract(self) -> None:
|
||||
async def subtract(self) -> None:
|
||||
self.check()
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin - self.right)
|
||||
origin = self.pool.get_value(self.left_selector)
|
||||
await self.pool.set(self.left_selector, origin - self.right)
|
||||
|
||||
def multiply(self) -> None:
|
||||
async def multiply(self) -> None:
|
||||
self.check()
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin * self.right)
|
||||
origin = self.pool.get_value(self.left_selector)
|
||||
await self.pool.set(self.left_selector, origin * self.right)
|
||||
|
||||
def divide(self) -> None:
|
||||
async def divide(self) -> None:
|
||||
self.check()
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin / self.right)
|
||||
origin = self.pool.get_value(self.left_selector)
|
||||
await self.pool.set(self.left_selector, origin / self.right)
|
||||
|
||||
|
||||
class BooleanOperator(OperatorBase):
|
||||
@@ -140,13 +140,13 @@ class BooleanOperator(OperatorBase):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = bool
|
||||
|
||||
def assign(self) -> None:
|
||||
async def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
await self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
async def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, False)
|
||||
await self.pool.set(self.left_selector, False)
|
||||
|
||||
|
||||
class ArrayOperator(OperatorBase):
|
||||
@@ -154,38 +154,37 @@ class ArrayOperator(OperatorBase):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = list
|
||||
|
||||
def assign(self) -> None:
|
||||
async def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
await self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
async def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, list())
|
||||
await self.pool.set(self.left_selector, list())
|
||||
|
||||
def append(self) -> None:
|
||||
async def append(self) -> None:
|
||||
self.check(no_right=True)
|
||||
# TODO:require type limit in list
|
||||
origin = self.pool.get(self.left_selector)
|
||||
origin = self.pool.get_value(self.left_selector)
|
||||
origin.append(self.right)
|
||||
self.pool.set(self.left_selector, origin)
|
||||
await self.pool.set(self.left_selector, origin)
|
||||
|
||||
def extend(self) -> None:
|
||||
async def extend(self) -> None:
|
||||
self.check(no_right=True)
|
||||
origin = self.pool.get(self.left_selector)
|
||||
origin = self.pool.get_value(self.left_selector)
|
||||
origin.extend(self.right)
|
||||
self.pool.set(self.left_selector, origin)
|
||||
await self.pool.set(self.left_selector, origin)
|
||||
|
||||
def remove_last(self) -> None:
|
||||
async def remove_last(self) -> None:
|
||||
self.check(no_right=True)
|
||||
origin = self.pool.get(self.left_selector)
|
||||
origin = self.pool.get_value(self.left_selector)
|
||||
origin.pop()
|
||||
self.pool.set(self.left_selector, origin)
|
||||
await self.pool.set(self.left_selector, origin)
|
||||
|
||||
def remove_first(self) -> None:
|
||||
async def remove_first(self) -> None:
|
||||
self.check(no_right=True)
|
||||
origin = self.pool.get(self.left_selector)
|
||||
origin = self.pool.get_value(self.left_selector)
|
||||
origin.pop(0)
|
||||
self.pool.set(self.left_selector, origin)
|
||||
await self.pool.set(self.left_selector, origin)
|
||||
|
||||
|
||||
class ObjectOperator(OperatorBase):
|
||||
@@ -193,13 +192,13 @@ class ObjectOperator(OperatorBase):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = dict
|
||||
|
||||
def assign(self) -> None:
|
||||
async def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
await self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
async def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, dict())
|
||||
await self.pool.set(self.left_selector, dict())
|
||||
|
||||
|
||||
class AssignmentOperatorResolver:
|
||||
@@ -245,7 +244,7 @@ class ConditionBase(ABC):
|
||||
self.right_selector = right_selector
|
||||
self.input_type = input_type
|
||||
|
||||
self.left_value = self.pool.get(self.left_selector)
|
||||
self.left_value = self.pool.get_value(self.left_selector)
|
||||
self.right_value = self.resolve_right_literal_value()
|
||||
|
||||
self.type_limit = getattr(self, "type_limit", None)
|
||||
@@ -254,7 +253,7 @@ class ConditionBase(ABC):
|
||||
if self.input_type == ValueInputType.VARIABLE:
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
right_expression = re.sub(pattern, r"\1", self.right_selector).strip()
|
||||
return self.pool.get(right_expression)
|
||||
return self.pool.get_value(right_expression)
|
||||
elif self.input_type == ValueInputType.CONSTANT:
|
||||
return self.right_selector
|
||||
raise RuntimeError("Unsupported variable type")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import uuid
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import Field, BaseModel
|
||||
from enum import StrEnum
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_read
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
@@ -24,6 +26,12 @@ class ParameterExtractorNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: ParameterExtractorNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
outputs = {}
|
||||
for param in self.typed_config.params:
|
||||
outputs[param.name] = param.type
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def _get_prompt():
|
||||
"""
|
||||
@@ -120,7 +128,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
field_type[param.name] = f'{param.type}, required:{str(param.required)}'
|
||||
return field_type
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
Main execution function for this node.
|
||||
|
||||
@@ -138,6 +146,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
|
||||
Args:
|
||||
state (WorkflowState): Current state of the workflow, used for template rendering.
|
||||
variable_pool (VariablePool): Used for accessing and setting variables during execution.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: Dictionary containing extracted parameters under the "output" key.
|
||||
@@ -153,7 +162,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
rendered_user_prompt = user_prompt_teplate.render(
|
||||
field_descriptions=str(self._get_field_desc()),
|
||||
field_type=str(self._get_field_type()),
|
||||
text_input=self._render_template(self.typed_config.text, state)
|
||||
text_input=self._render_template(self.typed_config.text, variable_pool)
|
||||
)
|
||||
|
||||
messages = [
|
||||
@@ -162,7 +171,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
]
|
||||
if self.typed_config.prompt:
|
||||
messages.extend([
|
||||
("user", self._render_template(self.typed_config.prompt, state)),
|
||||
("user", self._render_template(self.typed_config.prompt, variable_pool)),
|
||||
("user", rendered_user_prompt),
|
||||
])
|
||||
else:
|
||||
|
||||
@@ -6,6 +6,8 @@ from app.core.workflow.nodes.question_classifier.config import QuestionClassifie
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_read
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
@@ -24,6 +26,12 @@ class QuestionClassifierNode(BaseNode):
|
||||
self.typed_config: QuestionClassifierNodeConfig | None = None
|
||||
self.category_to_case_map = {}
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
"class_name": VariableType.STRING,
|
||||
"output": VariableType.STRING
|
||||
}
|
||||
|
||||
def _get_llm_instance(self) -> RedBearLLM:
|
||||
"""获取LLM实例"""
|
||||
with get_db_read() as db:
|
||||
@@ -65,7 +73,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
category_map[category_name] = case_tag
|
||||
return category_map
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict:
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict:
|
||||
"""执行问题分类"""
|
||||
self.typed_config = QuestionClassifierNodeConfig(**self.config)
|
||||
self.category_to_case_map = self._build_category_case_map()
|
||||
@@ -102,7 +110,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
categories=", ".join(category_names),
|
||||
supplement_prompt=supplement_prompt
|
||||
),
|
||||
state
|
||||
variable_pool
|
||||
)
|
||||
|
||||
messages = [
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
|
||||
class StartNodeConfig(BaseNodeConfig):
|
||||
|
||||
@@ -7,9 +7,10 @@ Start 节点实现
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_config import VariableType
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -36,14 +37,25 @@ class StartNode(BaseNode):
|
||||
|
||||
# 解析并验证配置
|
||||
self.typed_config: StartNodeConfig | None = None
|
||||
self.output_var_types = {}
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return self.output_var_types | {
|
||||
"message": VariableType.STRING,
|
||||
"execution_id": VariableType.STRING,
|
||||
"conversation_id": VariableType.STRING,
|
||||
"workspace_id": VariableType.STRING,
|
||||
"user_id": VariableType.STRING,
|
||||
}
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""执行 start 节点业务逻辑
|
||||
|
||||
Start 节点输出系统变量、会话变量和自定义变量。
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
variable_pool: 变量池
|
||||
|
||||
Returns:
|
||||
包含系统参数、会话变量和自定义变量的字典
|
||||
@@ -51,19 +63,16 @@ class StartNode(BaseNode):
|
||||
self.typed_config = StartNodeConfig(**self.config)
|
||||
logger.info(f"节点 {self.node_id} (Start) 开始执行")
|
||||
|
||||
# 创建变量池实例(在方法内复用)
|
||||
pool = self.get_variable_pool(state)
|
||||
|
||||
# 处理自定义变量(传入 pool 避免重复创建)
|
||||
custom_vars = self._process_custom_variables(pool)
|
||||
custom_vars = self._process_custom_variables(variable_pool)
|
||||
|
||||
# 返回业务数据(包含自定义变量)
|
||||
result = {
|
||||
"message": pool.get("sys.message"),
|
||||
"execution_id": pool.get("sys.execution_id"),
|
||||
"conversation_id": pool.get("sys.conversation_id"),
|
||||
"workspace_id": pool.get("sys.workspace_id"),
|
||||
"user_id": pool.get("sys.user_id"),
|
||||
"message": variable_pool.get_value("sys.message"),
|
||||
"execution_id": variable_pool.get_value("sys.execution_id"),
|
||||
"conversation_id": variable_pool.get_value("sys.conversation_id"),
|
||||
"workspace_id": variable_pool.get_value("sys.workspace_id"),
|
||||
"user_id": variable_pool.get_value("sys.user_id"),
|
||||
**custom_vars # 自定义变量作为节点输出的一部分
|
||||
}
|
||||
|
||||
@@ -74,7 +83,7 @@ class StartNode(BaseNode):
|
||||
|
||||
return result
|
||||
|
||||
def _process_custom_variables(self, pool) -> dict[str, Any]:
|
||||
def _process_custom_variables(self, pool: VariablePool) -> dict[str, Any]:
|
||||
"""处理自定义变量
|
||||
|
||||
从输入数据中提取自定义变量,应用默认值和验证。
|
||||
@@ -89,13 +98,14 @@ class StartNode(BaseNode):
|
||||
ValueError: 缺少必需变量
|
||||
"""
|
||||
# 获取输入数据中的自定义变量
|
||||
input_variables = pool.get("sys.input_variables", default={})
|
||||
input_variables = pool.get_value("sys.input_variables", default={}, strict=False)
|
||||
|
||||
processed = {}
|
||||
|
||||
# 遍历配置的变量定义
|
||||
for var_def in self.typed_config.variables:
|
||||
var_name = var_def.name
|
||||
var_type = var_def.type
|
||||
|
||||
# 检查变量是否存在
|
||||
if var_name in input_variables:
|
||||
@@ -116,21 +126,12 @@ class StartNode(BaseNode):
|
||||
f"变量 '{var_name}' 使用默认值: {var_def.default}"
|
||||
)
|
||||
else:
|
||||
match var_def.type:
|
||||
case VariableType.STRING:
|
||||
processed[var_name] = ""
|
||||
case VariableType.NUMBER:
|
||||
processed[var_name] = 0
|
||||
case VariableType.OBJECT:
|
||||
processed[var_name] = {}
|
||||
case VariableType.BOOLEAN:
|
||||
processed[var_name] = False
|
||||
case VariableType.ARRAY_NUMBER | VariableType.ARRAY_OBJECT | VariableType.ARRAY_BOOLEAN | VariableType.ARRAY_STRING:
|
||||
processed[var_name] = []
|
||||
processed[var_name] = DEFAULT_VALUE(var_type)
|
||||
self.output_var_types[var_name] = var_type
|
||||
|
||||
return processed
|
||||
|
||||
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""提取输入数据(用于记录)
|
||||
|
||||
Args:
|
||||
@@ -139,11 +140,9 @@ class StartNode(BaseNode):
|
||||
Returns:
|
||||
输入数据字典
|
||||
"""
|
||||
pool = self.get_variable_pool(state)
|
||||
|
||||
return {
|
||||
"execution_id": pool.get("sys.execution_id"),
|
||||
"conversation_id": pool.get("sys.conversation_id"),
|
||||
"message": pool.get("sys.message"),
|
||||
"conversation_vars": pool.get_all_conversation_vars()
|
||||
"execution_id": variable_pool.get_value("sys.execution_id"),
|
||||
"conversation_id": variable_pool.get_value("sys.conversation_id"),
|
||||
"message": variable_pool.get_value("sys.message"),
|
||||
"conversation_vars": variable_pool.get_all_conversation_vars()
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.services.tool_service import ToolService
|
||||
from app.db import get_db_read
|
||||
|
||||
@@ -21,13 +23,20 @@ class ToolNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: ToolNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
"data": VariableType.STRING,
|
||||
"error_code": VariableType.STRING,
|
||||
"execution_time": VariableType.NUMBER
|
||||
}
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""执行工具"""
|
||||
self.typed_config = ToolNodeConfig(**self.config)
|
||||
# 获取租户ID和用户ID
|
||||
tenant_id = self.get_variable("sys.tenant_id", state)
|
||||
user_id = self.get_variable("sys.user_id", state)
|
||||
workspace_id = self.get_variable("sys.workspace_id", state)
|
||||
tenant_id = self.get_variable("sys.tenant_id", variable_pool, strict=False)
|
||||
user_id = self.get_variable("sys.user_id", variable_pool)
|
||||
workspace_id = self.get_variable("sys.workspace_id", variable_pool)
|
||||
|
||||
# 如果没有租户ID,尝试从工作流ID获取
|
||||
if not tenant_id:
|
||||
@@ -48,7 +57,7 @@ class ToolNode(BaseNode):
|
||||
for param_name, param_template in self.typed_config.tool_parameters.items():
|
||||
if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template):
|
||||
try:
|
||||
rendered_value = self._render_template(param_template, state)
|
||||
rendered_value = self._render_template(param_template, variable_pool)
|
||||
except Exception as e:
|
||||
raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e
|
||||
else:
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Transform 节点"""
|
||||
|
||||
from app.core.workflow.nodes.transform.node import TransformNode
|
||||
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||
|
||||
__all__ = ["TransformNode", "TransformNodeConfig"]
|
||||
@@ -1,80 +0,0 @@
|
||||
"""Transform 节点配置"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
|
||||
|
||||
class TransformNodeConfig(BaseNodeConfig):
|
||||
"""Transform 节点配置
|
||||
|
||||
用于数据转换和处理。
|
||||
"""
|
||||
|
||||
transform_type: Literal["template", "code", "json"] = Field(
|
||||
default="template",
|
||||
description="转换类型:template(模板), code(代码), json(JSON处理)"
|
||||
)
|
||||
|
||||
# 模板模式
|
||||
template: str | None = Field(
|
||||
default=None,
|
||||
description="转换模板,支持变量引用"
|
||||
)
|
||||
|
||||
# 代码模式
|
||||
code: str | None = Field(
|
||||
default=None,
|
||||
description="Python 代码,用于数据转换"
|
||||
)
|
||||
|
||||
# JSON 模式
|
||||
json_path: str | None = Field(
|
||||
default=None,
|
||||
description="JSON 路径表达式"
|
||||
)
|
||||
|
||||
# 输入变量
|
||||
inputs: dict[str, str] | None = Field(
|
||||
default=None,
|
||||
description="输入变量映射,key 为变量名,value 为变量选择器"
|
||||
)
|
||||
|
||||
# 输出变量
|
||||
output_key: str = Field(
|
||||
default="result",
|
||||
description="输出变量的键名"
|
||||
)
|
||||
|
||||
# 输出变量定义
|
||||
output_variables: list[VariableDefinition] = Field(
|
||||
default_factory=lambda: [
|
||||
VariableDefinition(
|
||||
name="result",
|
||||
type=VariableType.STRING,
|
||||
description="转换后的结果"
|
||||
)
|
||||
],
|
||||
description="输出变量定义(根据 output_key 动态生成)"
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"examples": [
|
||||
{
|
||||
"transform_type": "template",
|
||||
"template": "用户问题:{{ sys.message }}\n回答:{{ llm_qa.output }}",
|
||||
"output_key": "formatted_result"
|
||||
},
|
||||
{
|
||||
"transform_type": "code",
|
||||
"code": "result = input_text.upper()",
|
||||
"inputs": {
|
||||
"input_text": "{{ sys.message }}"
|
||||
},
|
||||
"output_key": "uppercase_text"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
"""
|
||||
Transform 节点实现
|
||||
|
||||
数据转换节点,用于处理和转换数据。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TransformNode(BaseNode):
|
||||
"""数据转换节点
|
||||
|
||||
配置示例:
|
||||
{
|
||||
"type": "transform",
|
||||
"config": {
|
||||
"mapping": {
|
||||
"output_field": "{{node.previous.output}}",
|
||||
"processed": "{{var.input | upper}}"
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""执行数据转换
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
状态更新字典
|
||||
"""
|
||||
logger.info(f"节点 {self.node_id} 开始执行数据转换")
|
||||
|
||||
# 获取映射配置
|
||||
mapping = self.config.get("mapping", {})
|
||||
|
||||
# 执行数据转换
|
||||
transformed_data = {}
|
||||
for target_key, source_template in mapping.items():
|
||||
# 渲染模板获取值
|
||||
value = self._render_template(str(source_template), state)
|
||||
transformed_data[target_key] = value
|
||||
|
||||
logger.info(f"节点 {self.node_id} 数据转换完成,输出字段: {list(transformed_data.keys())}")
|
||||
|
||||
return {
|
||||
"node_outputs": {
|
||||
self.node_id: {
|
||||
"output": transformed_data,
|
||||
"status": "completed"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
|
||||
class VariableAggregatorNodeConfig(BaseNodeConfig):
|
||||
@@ -14,6 +15,11 @@ class VariableAggregatorNodeConfig(BaseNodeConfig):
|
||||
description="需要被聚合的变量"
|
||||
)
|
||||
|
||||
group_type: dict[str, VariableType] = Field(
|
||||
default=None,
|
||||
description="每个分组的变量类型"
|
||||
)
|
||||
|
||||
@field_validator("group_variables")
|
||||
@classmethod
|
||||
def group_variables_validator(cls, v, info):
|
||||
|
||||
@@ -5,6 +5,8 @@ from typing import Any
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -14,6 +16,17 @@ class VariableAggregatorNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: VariableAggregatorNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
config = VariableAggregatorNodeConfig(**self.config)
|
||||
output = {}
|
||||
if not config.group_type:
|
||||
for group_name in config.group_variables.keys():
|
||||
output[group_name] = VariableType.ANY
|
||||
return output
|
||||
for var_type in config.group_type:
|
||||
output[var_type] = config.group_type[var_type]
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def _get_express(variable_string: str) -> Any:
|
||||
"""
|
||||
@@ -29,7 +42,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
expression = re.sub(pattern, r"\1", variable_string).strip()
|
||||
return expression
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
Execute the variable aggregation logic.
|
||||
|
||||
@@ -45,7 +58,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
for variable in self.typed_config.group_variables:
|
||||
var_express = self._get_express(variable)
|
||||
try:
|
||||
value = self.get_variable(var_express, state)
|
||||
value = self.get_variable(var_express, variable_pool)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get variable '{var_express}': {e}")
|
||||
continue
|
||||
@@ -55,7 +68,9 @@ class VariableAggregatorNode(BaseNode):
|
||||
return value
|
||||
|
||||
logger.info("No variable found in non-group mode; returning empty string.")
|
||||
return ""
|
||||
if not self.typed_config.group_type:
|
||||
return ""
|
||||
return DEFAULT_VALUE(self.typed_config.group_type["output"])
|
||||
|
||||
# --------------------------
|
||||
# Group mode
|
||||
@@ -65,7 +80,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
for variable in variables:
|
||||
var_express = self._get_express(variable)
|
||||
try:
|
||||
value = self.get_variable(var_express, state)
|
||||
value = self.get_variable(var_express, variable_pool)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get variable '{var_express}' in group '{group_name}': {e}")
|
||||
continue
|
||||
@@ -74,7 +89,10 @@ class VariableAggregatorNode(BaseNode):
|
||||
result[group_name] = value
|
||||
break
|
||||
else:
|
||||
result[group_name] = ""
|
||||
if not self.typed_config.group_type:
|
||||
result[group_name] = ""
|
||||
else:
|
||||
result[group_name] = DEFAULT_VALUE(self.typed_config.group_type[group_name])
|
||||
logger.info(f"No variable found for group '{group_name}'; set empty string.")
|
||||
logger.info(f"Node: {self.node_id} variable aggregation result: {result}")
|
||||
return result
|
||||
|
||||
@@ -43,7 +43,7 @@ class TemplateRenderer:
|
||||
def render(
|
||||
self,
|
||||
template: str,
|
||||
variables: dict[str, Any],
|
||||
conv_vars: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> str:
|
||||
@@ -51,7 +51,7 @@ class TemplateRenderer:
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
variables: 用户定义的变量
|
||||
conv_vars: 会话变量
|
||||
node_outputs: 节点输出结果
|
||||
system_vars: 系统变量
|
||||
|
||||
@@ -80,20 +80,11 @@ class TemplateRenderer:
|
||||
'分析结果: 正面情绪'
|
||||
"""
|
||||
# 构建命名空间上下文
|
||||
# variables 的结构:{"sys": {...}, "conv": {...}}
|
||||
sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {}
|
||||
conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {}
|
||||
if self.strict:
|
||||
context = defaultdict(dict)
|
||||
context["conv"] = conv_vars
|
||||
context["node"] = node_outputs
|
||||
context["sys"] = {**(system_vars or {}), **sys_vars}
|
||||
else:
|
||||
context = {
|
||||
"conv": conv_vars, # 会话变量:{{conv.user_name}}
|
||||
"node": node_outputs, # 节点输出:{{node.node_1.output}}
|
||||
"sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源)
|
||||
}
|
||||
context = {
|
||||
"conv": conv_vars, # 会话变量:{{conv.user_name}}
|
||||
"node": node_outputs, # 节点输出:{{node.node_1.output}}
|
||||
"sys": system_vars, # 系统变量:{{sys.execution_id}}
|
||||
}
|
||||
|
||||
# 支持直接通过节点ID访问节点输出:{{llm_qa.output}}
|
||||
# 将所有节点输出添加到顶层上下文
|
||||
@@ -157,9 +148,9 @@ _default_renderer = TemplateRenderer(strict=True)
|
||||
|
||||
def render_template(
|
||||
template: str,
|
||||
variables: dict[str, Any],
|
||||
conv_vars: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None,
|
||||
system_vars: dict[str, Any],
|
||||
strict: bool = True
|
||||
) -> str:
|
||||
"""渲染模板(便捷函数)
|
||||
@@ -167,7 +158,7 @@ def render_template(
|
||||
Args:
|
||||
strict: 严格模式
|
||||
template: 模板字符串
|
||||
variables: 用户变量
|
||||
conv_vars: 会话变量
|
||||
node_outputs: 节点输出
|
||||
system_vars: 系统变量
|
||||
|
||||
@@ -184,7 +175,7 @@ def render_template(
|
||||
'请分析: 这是一段文本'
|
||||
"""
|
||||
renderer = TemplateRenderer(strict=strict)
|
||||
return renderer.render(template, variables, node_outputs, system_vars)
|
||||
return renderer.render(template, conv_vars, node_outputs, system_vars)
|
||||
|
||||
|
||||
def validate_template(template: str) -> list[str]:
|
||||
|
||||
@@ -5,10 +5,13 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
from typing import Any, Union, TYPE_CHECKING
|
||||
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.workflow_schema import WorkflowConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -64,7 +67,7 @@ class WorkflowValidator:
|
||||
return cycle_nodes, cycle_edges
|
||||
|
||||
@classmethod
|
||||
def get_subgraph(cls, workflow_config: Union[dict[str, Any], Any]) -> list:
|
||||
def get_subgraph(cls, workflow_config: Union[dict[str, Any], "WorkflowConfig"]) -> list:
|
||||
if not isinstance(workflow_config, dict):
|
||||
workflow_config = {
|
||||
"nodes": workflow_config.nodes,
|
||||
@@ -331,7 +334,7 @@ class WorkflowValidator:
|
||||
|
||||
|
||||
def validate_workflow_config(
|
||||
workflow_config: dict[str, Any],
|
||||
workflow_config: Union[dict[str, Any], 'WorkflowConfig'],
|
||||
for_publish: bool = False
|
||||
) -> tuple[bool, list[str]]:
|
||||
"""验证工作流配置(便捷函数)
|
||||
|
||||
170
api/app/core/workflow/variable/base_variable.py
Normal file
170
api/app/core/workflow/variable/base_variable.py
Normal file
@@ -0,0 +1,170 @@
|
||||
from enum import StrEnum
|
||||
from abc import abstractmethod, ABC
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.schemas import FileType
|
||||
|
||||
|
||||
class VariableType(StrEnum):
|
||||
"""Enumeration of supported variable types in the workflow."""
|
||||
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
BOOLEAN = "boolean"
|
||||
OBJECT = "object"
|
||||
FILE = "file"
|
||||
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_BOOLEAN = "array[boolean]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_FILE = "array[file]"
|
||||
|
||||
NESTED_ARRAY = "array_nest"
|
||||
|
||||
ANY = 'any'
|
||||
|
||||
@classmethod
|
||||
def type_map(cls, var: Any) -> "VariableType":
|
||||
"""Maps a Python value to a corresponding VariableType.
|
||||
|
||||
Args:
|
||||
var: The Python value to map.
|
||||
|
||||
Returns:
|
||||
The VariableType corresponding to the input value.
|
||||
|
||||
Raises:
|
||||
TypeError: If the type of the input value is not supported.
|
||||
"""
|
||||
var_type = type(var)
|
||||
if isinstance(var_type, str):
|
||||
return cls.STRING
|
||||
elif isinstance(var_type, (int, float)):
|
||||
return cls.NUMBER
|
||||
elif isinstance(var_type, bool):
|
||||
return cls.BOOLEAN
|
||||
elif isinstance(var_type, FileObject) or (isinstance(var, dict) and var.get('__file')):
|
||||
return cls.FILE
|
||||
elif isinstance(var_type, dict):
|
||||
return cls.OBJECT
|
||||
elif isinstance(var_type, list):
|
||||
if len(var) == 0:
|
||||
return cls.ARRAY_STRING
|
||||
else:
|
||||
child_type = type(var[0])
|
||||
if child_type == str:
|
||||
return cls.ARRAY_STRING
|
||||
elif child_type == int or child_type == float:
|
||||
return cls.ARRAY_NUMBER
|
||||
elif child_type == bool:
|
||||
return cls.ARRAY_BOOLEAN
|
||||
elif child_type == dict:
|
||||
return cls.ARRAY_OBJECT
|
||||
elif child_type == list:
|
||||
return cls.NESTED_ARRAY
|
||||
else:
|
||||
raise TypeError(f"Unsupported array child type - {child_type}")
|
||||
raise TypeError(f"Unsupported type - {var_type}")
|
||||
|
||||
|
||||
def DEFAULT_VALUE(var_type: VariableType) -> Any:
|
||||
"""Returns the default value for a given VariableType.
|
||||
|
||||
Args:
|
||||
var_type: The variable type for which to get the default value.
|
||||
|
||||
Returns:
|
||||
The default Python value corresponding to the VariableType.
|
||||
|
||||
Raises:
|
||||
TypeError: If the VariableType is invalid.
|
||||
"""
|
||||
match var_type:
|
||||
case VariableType.STRING:
|
||||
return ""
|
||||
case VariableType.NUMBER:
|
||||
return 0
|
||||
case VariableType.BOOLEAN:
|
||||
return False
|
||||
case VariableType.OBJECT:
|
||||
return {}
|
||||
case VariableType.FILE:
|
||||
return None
|
||||
case VariableType.ARRAY_STRING:
|
||||
return []
|
||||
case VariableType.ARRAY_NUMBER:
|
||||
return []
|
||||
case VariableType.ARRAY_BOOLEAN:
|
||||
return []
|
||||
case VariableType.ARRAY_OBJECT:
|
||||
return []
|
||||
case VariableType.ARRAY_FILE:
|
||||
return []
|
||||
case _:
|
||||
raise TypeError(f"Invalid type - {type}")
|
||||
|
||||
|
||||
class FileObject(BaseModel):
|
||||
type: FileType
|
||||
url: str
|
||||
__file: bool
|
||||
|
||||
|
||||
class BaseVariable(ABC):
|
||||
"""Abstract base class for all workflow variables.
|
||||
|
||||
Subclasses must implement validation and serialization methods.
|
||||
"""
|
||||
type = None
|
||||
|
||||
def __init__(self, value: Any):
|
||||
"""Initializes a variable instance.
|
||||
|
||||
Args:
|
||||
value: The initial value for the variable.
|
||||
|
||||
Attributes:
|
||||
self.value: The validated value stored in the variable.
|
||||
self.literal: A string representation of the variable.
|
||||
"""
|
||||
self.value = self.valid_value(value)
|
||||
self.literal = self.to_literal()
|
||||
|
||||
@abstractmethod
|
||||
def valid_value(self, value) -> Any:
|
||||
"""Validates or converts a value to the correct type for the variable.
|
||||
|
||||
Args:
|
||||
value: The value to validate.
|
||||
|
||||
Returns:
|
||||
The validated or converted value.
|
||||
|
||||
Raises:
|
||||
TypeError: If the value is invalid.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def to_literal(self) -> str:
|
||||
"""Converts the variable value to a string literal representation.
|
||||
|
||||
Returns:
|
||||
A string representing the variable's value.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_value(self) -> Any:
|
||||
"""Returns the current value of the variable."""
|
||||
return self.value
|
||||
|
||||
def set(self, value):
|
||||
"""Sets the variable to a new value after validation.
|
||||
|
||||
Args:
|
||||
value: The new value to assign to the variable.
|
||||
"""
|
||||
self.value = self.valid_value(value)
|
||||
174
api/app/core/workflow/variable/variable_objects.py
Normal file
174
api/app/core/workflow/variable/variable_objects.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from typing import Any, TypeVar, Type, Generic
|
||||
|
||||
from deprecated import deprecated
|
||||
|
||||
from app.core.workflow.variable.base_variable import BaseVariable, VariableType, FileObject, FileType
|
||||
|
||||
T = TypeVar("T", bound=BaseVariable)
|
||||
|
||||
|
||||
class StringVariable(BaseVariable):
|
||||
type = 'str'
|
||||
|
||||
def valid_value(self, value) -> str:
|
||||
if not isinstance(value, str):
|
||||
raise TypeError(f"Value must be a string - {type(value)}:{value}")
|
||||
return value
|
||||
|
||||
def to_literal(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
class NumberVariable(BaseVariable):
|
||||
type = 'number'
|
||||
|
||||
def valid_value(self, value) -> int | float:
|
||||
if not isinstance(value, (int, float)):
|
||||
raise TypeError(f"Value must be a number - {type(value)}:{value}")
|
||||
return value
|
||||
|
||||
def to_literal(self) -> str:
|
||||
return str(self.value)
|
||||
|
||||
|
||||
class BooleanVariable(BaseVariable):
|
||||
type = 'boolean'
|
||||
|
||||
def valid_value(self, value) -> bool:
|
||||
if not isinstance(value, bool):
|
||||
raise TypeError(f"Value must be a boolean - {type(value)}:{value}")
|
||||
return value
|
||||
|
||||
def to_literal(self) -> str:
|
||||
return str(self.value).lower()
|
||||
|
||||
|
||||
class DictVariable(BaseVariable):
|
||||
type = 'object'
|
||||
|
||||
def valid_value(self, value) -> dict:
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError(f"Value must be a dict - {type(value)}:{value}")
|
||||
return value
|
||||
|
||||
def to_literal(self) -> str:
|
||||
return str(self.value)
|
||||
|
||||
|
||||
class FileVariable(BaseVariable):
|
||||
type = 'file'
|
||||
|
||||
def valid_value(self, value) -> FileObject:
|
||||
|
||||
if isinstance(value, dict):
|
||||
if not value.get("__file"):
|
||||
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
|
||||
return FileObject(
|
||||
**{
|
||||
"type": str(value.get('type')),
|
||||
"url": value.get('url'),
|
||||
"__file": True
|
||||
}
|
||||
)
|
||||
if isinstance(value, FileObject):
|
||||
return value
|
||||
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
|
||||
|
||||
def to_literal(self) -> str:
|
||||
return f'{"!"if self.value.type == FileType.IMAGE else ""}[file]({self.value.url})'
|
||||
|
||||
def get_value(self) -> Any:
|
||||
return self.value.model_dump()
|
||||
|
||||
|
||||
class ArrayObject(BaseVariable, Generic[T]):
|
||||
type = 'array'
|
||||
|
||||
def __init__(self, child_type: Type[T], value: list[Any]):
|
||||
if not issubclass(child_type, BaseVariable):
|
||||
raise TypeError("child_type must be a subclass of BaseVariable")
|
||||
self.child_type = child_type
|
||||
super().__init__(value)
|
||||
|
||||
def valid_value(self, value: list[Any]) -> list[T]:
|
||||
if not isinstance(value, list):
|
||||
raise TypeError(f"Value must be a list - {type(value)}:{value}")
|
||||
final_value = []
|
||||
for v in value:
|
||||
try:
|
||||
final_value.append(self.child_type(v))
|
||||
except:
|
||||
raise TypeError(f"All elements must be of type {self.child_type.type}")
|
||||
return final_value
|
||||
|
||||
def to_literal(self) -> str:
|
||||
return "\n".join([v.to_literal() for v in self.value])
|
||||
|
||||
def get_value(self) -> Any:
|
||||
return [v.get_value() for v in self.value]
|
||||
|
||||
|
||||
class NestedArrayObject(BaseVariable):
|
||||
type = 'array_nest'
|
||||
|
||||
def valid_value(self, value: list[T]) -> list[T]:
|
||||
if not isinstance(value, list):
|
||||
raise TypeError(f"Value must be a list - {type(value)}:{value}")
|
||||
final_value = []
|
||||
for v in value:
|
||||
if not isinstance(v, ArrayObject):
|
||||
raise TypeError("All elements must be of type list")
|
||||
final_value.append(v)
|
||||
return final_value
|
||||
|
||||
def to_literal(self) -> str:
|
||||
return "\n".join(["\n".join([item.to_literal() for item in row]) for row in self.value])
|
||||
|
||||
def get_value(self) -> Any:
|
||||
return [[item.get_value() for item in row] for row in self.value]
|
||||
|
||||
|
||||
@deprecated(
|
||||
reason="Using arbitrary-type values may cause unexpected errors; please switch to strongly-typed values.",
|
||||
category=RuntimeWarning
|
||||
)
|
||||
class AnyObject(BaseVariable):
|
||||
type = 'any'
|
||||
|
||||
def valid_value(self, value: Any) -> Any:
|
||||
return value
|
||||
|
||||
def to_literal(self) -> str:
|
||||
return str(self.value)
|
||||
|
||||
|
||||
def make_array(child_type: Type[T], value: list[Any]) -> ArrayObject[T]:
|
||||
"""简化 ArrayObject 创建,不需要重复写类型"""
|
||||
|
||||
return ArrayObject(child_type, value)
|
||||
|
||||
|
||||
def create_variable_instance(var_type: VariableType, value: Any) -> T:
|
||||
match var_type:
|
||||
case VariableType.STRING:
|
||||
return StringVariable(value)
|
||||
case VariableType.NUMBER:
|
||||
return NumberVariable(value)
|
||||
case VariableType.BOOLEAN:
|
||||
return BooleanVariable(value)
|
||||
case VariableType.OBJECT:
|
||||
return DictVariable(value)
|
||||
case VariableType.ARRAY_STRING:
|
||||
return make_array(StringVariable, value)
|
||||
case VariableType.ARRAY_NUMBER:
|
||||
return make_array(NumberVariable, value)
|
||||
case VariableType.ARRAY_BOOLEAN:
|
||||
return make_array(BooleanVariable, value)
|
||||
case VariableType.ARRAY_OBJECT:
|
||||
return make_array(DictVariable, value)
|
||||
case VariableType.ARRAY_FILE:
|
||||
return make_array(FileVariable, value)
|
||||
case VariableType.ANY:
|
||||
return AnyObject(value)
|
||||
case _:
|
||||
raise TypeError(f"Invalid type - {var_type}")
|
||||
@@ -11,10 +11,15 @@
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from asyncio import Lock
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from typing import Any, Generic
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable.variable_objects import T, create_variable_instance
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,11 +28,6 @@ class VariableSelector:
|
||||
"""变量选择器
|
||||
|
||||
用于引用变量的路径表示。
|
||||
|
||||
Examples:
|
||||
>>> selector = VariableSelector(["sys", "message"])
|
||||
>>> selector = VariableSelector(["node_A", "output"])
|
||||
>>> selector = VariableSelector.from_string("sys.message")
|
||||
"""
|
||||
|
||||
def __init__(self, path: list[str]):
|
||||
@@ -52,10 +52,6 @@ class VariableSelector:
|
||||
|
||||
Returns:
|
||||
VariableSelector 实例
|
||||
|
||||
Examples:
|
||||
>>> selector = VariableSelector.from_string("sys.message")
|
||||
>>> selector = VariableSelector.from_string("llm_qa.output")
|
||||
"""
|
||||
path = selector_str.split(".")
|
||||
return cls(path)
|
||||
@@ -67,160 +63,212 @@ class VariableSelector:
|
||||
return f"VariableSelector({self.path})"
|
||||
|
||||
|
||||
class VariableStruct(BaseModel, Generic[T]):
|
||||
"""A typed variable struct.
|
||||
|
||||
Represents a runtime variable with an associated logical type and
|
||||
a concrete value object.
|
||||
|
||||
This class bridges the static type system (via generics) and the
|
||||
runtime type system (via ``VariableType``).
|
||||
|
||||
Attributes:
|
||||
type:
|
||||
Logical variable type descriptor used for runtime validation,
|
||||
serialization, and workflow type checking.
|
||||
instance:
|
||||
The concrete variable object. The actual Python type is
|
||||
represented by the generic parameter ``T`` (e.g. StringVariable,
|
||||
NumberVariable, ArrayObject[StringVariable]).
|
||||
mut:
|
||||
Whether the variable is mutable.
|
||||
"""
|
||||
type: VariableType
|
||||
instance: T
|
||||
mut: bool
|
||||
|
||||
model_config = {
|
||||
"arbitrary_types_allowed": True
|
||||
}
|
||||
|
||||
|
||||
class VariablePool:
|
||||
"""变量池
|
||||
|
||||
管理工作流执行过程中的所有变量。
|
||||
|
||||
变量命名空间:
|
||||
- sys.*: 系统变量(message, execution_id, workspace_id, user_id, conversation_id)
|
||||
- conv.*: 会话变量(跨多轮对话保持的变量)
|
||||
- <node_id>.*: 节点输出
|
||||
|
||||
Examples:
|
||||
>>> pool = VariablePool(state)
|
||||
>>> pool.get(["sys", "message"])
|
||||
"用户的问题"
|
||||
>>> pool.get(["llm_qa", "output"])
|
||||
"AI 的回答"
|
||||
>>> pool.set(["conv", "user_name"], "张三")
|
||||
"""Variable pool.
|
||||
|
||||
Manages all variables during workflow execution, including storage,
|
||||
namespacing, and concurrency control.
|
||||
|
||||
Variable namespace conventions:
|
||||
- ``sys.*``:
|
||||
System variables (e.g. message, execution_id, workspace_id,
|
||||
user_id, conversation_id).
|
||||
- ``conv.*``:
|
||||
Conversation-level variables that persist across multiple turns.
|
||||
- ``<node_id>.*``:
|
||||
Variables produced by workflow nodes.
|
||||
"""
|
||||
|
||||
def __init__(self, state: "WorkflowState"):
|
||||
"""初始化变量池
|
||||
|
||||
Args:
|
||||
state: 工作流状态(LangGraph State)
|
||||
"""
|
||||
self.state = state
|
||||
def __init__(self):
|
||||
"""Initialize the variable pool.
|
||||
|
||||
Attributes:
|
||||
self.locks:
|
||||
A per-key lock table used for fine-grained concurrency control.
|
||||
|
||||
self.variables:
|
||||
Storage for all variables managed by the pool.
|
||||
"""
|
||||
self.locks = defaultdict(Lock)
|
||||
self.variables: dict[str, dict[str, VariableStruct[Any]]] = {}
|
||||
|
||||
@staticmethod
|
||||
def transform_selector(selector):
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
variable_literal = re.sub(pattern, r"\1", selector).strip()
|
||||
selector = VariableSelector.from_string(variable_literal).path
|
||||
if len(selector) != 2:
|
||||
raise ValueError(f"Selector not valid - {selector}")
|
||||
return selector
|
||||
|
||||
def _get_variable_struct(
|
||||
self,
|
||||
selector: str
|
||||
) -> VariableStruct[T] | None:
|
||||
"""Retrieve a variable struct from the variable pool.
|
||||
|
||||
def get(self, selector: list[str] | str, default: Any = None) -> Any:
|
||||
"""获取变量值
|
||||
|
||||
Args:
|
||||
selector: 变量选择器,可以是列表或字符串
|
||||
default: 默认值(变量不存在时返回)
|
||||
|
||||
selector:
|
||||
Variable selector, either:
|
||||
- A string variable literal (e.g. "{{ sys.message }}")
|
||||
|
||||
Returns:
|
||||
变量值
|
||||
|
||||
Examples:
|
||||
>>> pool.get(["sys", "message"])
|
||||
>>> pool.get("sys.message")
|
||||
>>> pool.get(["llm_qa", "output"])
|
||||
>>> pool.get("llm_qa.output")
|
||||
|
||||
Raises:
|
||||
KeyError: 变量不存在且未提供默认值
|
||||
The variable's struct if it exists; otherwise returns None.
|
||||
"""
|
||||
# 转换为 VariableSelector
|
||||
if isinstance(selector, str):
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
variable_literal = re.sub(pattern, r"\1", selector).strip()
|
||||
selector = VariableSelector.from_string(variable_literal).path
|
||||
|
||||
if not selector or len(selector) < 1:
|
||||
raise ValueError("变量选择器不能为空")
|
||||
selector = self.transform_selector(selector)
|
||||
|
||||
namespace = selector[0]
|
||||
variable_name = selector[1]
|
||||
|
||||
try:
|
||||
# 系统变量
|
||||
if namespace == "sys":
|
||||
key = selector[1] if len(selector) > 1 else None
|
||||
if not key:
|
||||
return self.state.get("variables", {}).get("sys", {})
|
||||
return self.state.get("variables", {}).get("sys", {}).get(key, default)
|
||||
namespace_variables = self.variables.get(namespace)
|
||||
if namespace_variables is None:
|
||||
return None
|
||||
|
||||
# 会话变量
|
||||
elif namespace == "conv":
|
||||
key = selector[1] if len(selector) > 1 else None
|
||||
if not key:
|
||||
return self.state.get("variables", {}).get("conv", {})
|
||||
return self.state.get("variables", {}).get("conv", {}).get(key, default)
|
||||
var_instance = namespace_variables.get(variable_name)
|
||||
if var_instance is None:
|
||||
return None
|
||||
return var_instance
|
||||
|
||||
# 节点输出(从 runtime_vars 读取)
|
||||
else:
|
||||
node_id = namespace
|
||||
runtime_vars = self.state.get("runtime_vars", {})
|
||||
def get_value(
|
||||
self,
|
||||
selector: str,
|
||||
default: Any = None,
|
||||
strict: bool = True,
|
||||
) -> Any:
|
||||
"""Retrieve a variable value from the variable pool.
|
||||
|
||||
if node_id not in runtime_vars:
|
||||
if default is not None:
|
||||
return default
|
||||
raise KeyError(f"节点 '{node_id}' 的输出不存在")
|
||||
Args:
|
||||
selector:
|
||||
Variable selector, either:
|
||||
- A list of path components (e.g. ["sys", "message"])
|
||||
- A string variable literal (e.g. "{{ sys.message }}")
|
||||
default:
|
||||
The value to return if the variable does not exist.
|
||||
strict:
|
||||
If True, raises KeyError when the variable does not exist.
|
||||
|
||||
node_var = runtime_vars[node_id]
|
||||
Returns:
|
||||
The variable's value if it exists; otherwise returns `default`.
|
||||
|
||||
# 如果只有节点 ID,返回整个变量
|
||||
if len(selector) == 1:
|
||||
return node_var
|
||||
Raises:
|
||||
KeyError: If strict is True and the variable does not exist.
|
||||
"""
|
||||
variable_struct = self._get_variable_struct(selector)
|
||||
if variable_struct is None:
|
||||
if strict:
|
||||
raise KeyError(f"{selector} not exist")
|
||||
return default
|
||||
|
||||
# 获取特定字段
|
||||
# 支持嵌套访问,如 node_id.field.subfield
|
||||
result = node_var
|
||||
for k in selector[1:]:
|
||||
if isinstance(result, dict):
|
||||
result = result.get(k)
|
||||
if result is None:
|
||||
if default is not None:
|
||||
return default
|
||||
raise KeyError(f"字段 '{'.'.join(selector)}' 不存在")
|
||||
else:
|
||||
if default is not None:
|
||||
return default
|
||||
raise KeyError(f"无法访问 '{'.'.join(selector)}'")
|
||||
return variable_struct.instance.get_value()
|
||||
|
||||
return result
|
||||
def get_literal(
|
||||
self,
|
||||
selector: str,
|
||||
default: Any = None,
|
||||
strict: bool = True,
|
||||
) -> Any:
|
||||
"""Retrieve a variable value from the variable pool.
|
||||
|
||||
except KeyError:
|
||||
if default is not None:
|
||||
return default
|
||||
raise
|
||||
Args:
|
||||
selector:
|
||||
Variable selector, either:
|
||||
- A list of path components (e.g. ["sys", "message"])
|
||||
- A string variable literal (e.g. "{{ sys.message }}")
|
||||
default:
|
||||
The value to return if the variable does not exist.
|
||||
strict:
|
||||
If True, raises KeyError when the variable does not exist.
|
||||
|
||||
def set(self, selector: list[str] | str, value: Any):
|
||||
Returns:
|
||||
The variable's value if it exists; otherwise returns `default`.
|
||||
|
||||
Raises:
|
||||
KeyError: If strict is True and the variable does not exist.
|
||||
"""
|
||||
variable_struct = self._get_variable_struct(selector)
|
||||
if variable_struct is None:
|
||||
if strict:
|
||||
raise KeyError(f"{selector} not exist")
|
||||
return default
|
||||
|
||||
return variable_struct.instance.to_literal()
|
||||
|
||||
async def set(
|
||||
self,
|
||||
selector: str,
|
||||
value: Any
|
||||
):
|
||||
"""设置变量值
|
||||
|
||||
Args:
|
||||
selector: 变量选择器
|
||||
value: 变量值
|
||||
|
||||
Examples:
|
||||
>>> pool.set(["conv", "user_name"], "张三")
|
||||
>>> pool.set("conv.user_name", "张三")
|
||||
|
||||
|
||||
Note:
|
||||
- 只能设置会话变量 (conv.*)
|
||||
- 系统变量和节点输出是只读的
|
||||
"""
|
||||
# 转换为 VariableSelector
|
||||
if isinstance(selector, str):
|
||||
selector = VariableSelector.from_string(selector).path
|
||||
variable_struct = self._get_variable_struct(selector)
|
||||
if variable_struct is None:
|
||||
raise KeyError(f"Variable {selector} is not defined")
|
||||
if not variable_struct.mut:
|
||||
raise KeyError(f"{selector} cannot be modified")
|
||||
async with self.locks[selector]:
|
||||
variable_struct.instance.set(value)
|
||||
|
||||
if not selector or len(selector) < 2:
|
||||
raise ValueError("变量选择器必须包含命名空间和键名")
|
||||
async def new(
|
||||
self,
|
||||
namespace: str,
|
||||
key: str,
|
||||
value: Any,
|
||||
var_type: VariableType,
|
||||
mut: bool
|
||||
):
|
||||
if self.has(f"{namespace}.{key}"):
|
||||
try:
|
||||
await self.set(f"{namespace}.{key}", value)
|
||||
except KeyError:
|
||||
pass
|
||||
instance = create_variable_instance(var_type, value)
|
||||
variable_struct = VariableStruct(type=var_type, instance=instance, mut=mut)
|
||||
namespace_variable = self.variables.get(namespace)
|
||||
if namespace_variable is None:
|
||||
self.variables[namespace] = {
|
||||
key: variable_struct
|
||||
}
|
||||
else:
|
||||
self.variables[namespace][key] = variable_struct
|
||||
|
||||
namespace = selector[0]
|
||||
|
||||
if namespace != "conv" and namespace not in self.state["cycle_nodes"]:
|
||||
raise ValueError("Only conversation or cycle variables can be assigned.")
|
||||
|
||||
key = selector[1]
|
||||
|
||||
# 确保 variables 结构存在
|
||||
if "variables" not in self.state:
|
||||
self.state["variables"] = {"sys": {}, "conv": {}}
|
||||
if namespace == "conv":
|
||||
if "conv" not in self.state["variables"]:
|
||||
self.state["variables"]["conv"] = {}
|
||||
|
||||
# 设置值
|
||||
self.state["variables"]["conv"][key] = value
|
||||
elif namespace in self.state["cycle_nodes"]:
|
||||
self.state["runtime_vars"][namespace][key] = value
|
||||
|
||||
logger.debug(f"设置变量: {'.'.join(selector)} = {value}")
|
||||
|
||||
def has(self, selector: list[str] | str) -> bool:
|
||||
def has(self, selector: str) -> bool:
|
||||
"""检查变量是否存在
|
||||
|
||||
Args:
|
||||
@@ -228,18 +276,8 @@ class VariablePool:
|
||||
|
||||
Returns:
|
||||
变量是否存在
|
||||
|
||||
Examples:
|
||||
>>> pool.has(["sys", "message"])
|
||||
True
|
||||
>>> pool.has("llm_qa.output")
|
||||
False
|
||||
"""
|
||||
try:
|
||||
self.get(selector)
|
||||
return True
|
||||
except KeyError:
|
||||
return False
|
||||
return self._get_variable_struct(selector) is not None
|
||||
|
||||
def get_all_system_vars(self) -> dict[str, Any]:
|
||||
"""获取所有系统变量
|
||||
@@ -247,7 +285,8 @@ class VariablePool:
|
||||
Returns:
|
||||
系统变量字典
|
||||
"""
|
||||
return self.state.get("variables", {}).get("sys", {})
|
||||
sys_namespace = self.variables.get("sys", {})
|
||||
return {k: v.instance.get_value() for k, v in sys_namespace.items()}
|
||||
|
||||
def get_all_conversation_vars(self) -> dict[str, Any]:
|
||||
"""获取所有会话变量
|
||||
@@ -255,7 +294,8 @@ class VariablePool:
|
||||
Returns:
|
||||
会话变量字典
|
||||
"""
|
||||
return self.state.get("variables", {}).get("conv", {})
|
||||
conv_namespace = self.variables.get("conv", {})
|
||||
return {k: v.instance.get_value() for k, v in conv_namespace.items()}
|
||||
|
||||
def get_all_node_outputs(self) -> dict[str, Any]:
|
||||
"""获取所有节点输出(运行时变量)
|
||||
@@ -263,18 +303,37 @@ class VariablePool:
|
||||
Returns:
|
||||
节点输出字典,键为节点 ID
|
||||
"""
|
||||
return self.state.get("runtime_vars", {})
|
||||
runtime_vars = {
|
||||
namespace: {
|
||||
k: v.instance.get_value()
|
||||
for k, v in vars_dict.items()
|
||||
}
|
||||
for namespace, vars_dict in self.variables.items()
|
||||
if namespace not in ("sys", "conv")
|
||||
}
|
||||
return runtime_vars
|
||||
|
||||
def get_node_output(self, node_id: str) -> dict[str, Any] | None:
|
||||
def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None:
|
||||
"""获取指定节点的输出(运行时变量)
|
||||
|
||||
Args:
|
||||
node_id: 节点 ID
|
||||
defalut: 默认值
|
||||
strict: 是否严格模式
|
||||
|
||||
Returns:
|
||||
节点输出或 None
|
||||
"""
|
||||
return self.state.get("runtime_vars", {}).get(node_id)
|
||||
node_namespace = self.variables.get(node_id)
|
||||
if node_namespace:
|
||||
return {k: v.instance.get_value() for k, v in node_namespace.items()}
|
||||
if strict:
|
||||
raise KeyError(f"node {node_id} output not exist")
|
||||
else:
|
||||
return defalut
|
||||
|
||||
def copy(self, pool: 'VariablePool'):
|
||||
self.variables = deepcopy(pool.variables)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""导出为字典
|
||||
|
||||
@@ -50,13 +50,16 @@ async def lifespan(app: FastAPI):
|
||||
logger.info("自动数据库升级已禁用 (DB_AUTO_UPGRADE=false)")
|
||||
|
||||
# 加载预定义模型
|
||||
logger.info("开始加载预定义模型...")
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
result = load_models(db, silent=True)
|
||||
logger.info(f"预定义模型加载完成: 成功{result['success']}个, 跳过{result['skipped']}个, 失败{result['failed']}个")
|
||||
except Exception as e:
|
||||
logger.warning(f"加载预定义模型时出错: {str(e)}")
|
||||
if settings.LOAD_MODEL:
|
||||
logger.info("开始加载预定义模型...")
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
result = load_models(db, silent=True)
|
||||
logger.info(f"预定义模型加载完成: 成功{result['success']}个, 跳过{result['skipped']}个, 失败{result['failed']}个")
|
||||
except Exception as e:
|
||||
logger.warning(f"加载预定义模型时出错: {str(e)}")
|
||||
else:
|
||||
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
|
||||
|
||||
logger.info("应用程序启动完成")
|
||||
yield
|
||||
@@ -77,10 +80,14 @@ default_origins = [
|
||||
]
|
||||
allowed_origins = list({o for o in (default_origins + settings.CORS_ORIGINS) if o})
|
||||
|
||||
# 如果 CORS_ORIGINS 包含 "*",则允许所有来源
|
||||
if "*" in settings.CORS_ORIGINS:
|
||||
allowed_origins = ["*"]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=allowed_origins,
|
||||
allow_credentials=True,
|
||||
allow_credentials=True if "*" not in allowed_origins else False, # 允许所有来源时不能使用 credentials
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
@@ -28,6 +28,7 @@ from .tool_model import (
|
||||
ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus
|
||||
)
|
||||
from .memory_perceptual_model import MemoryPerceptualModel
|
||||
from .skill_model import Skill
|
||||
from .ontology_scene import OntologyScene
|
||||
from .ontology_class import OntologyClass
|
||||
from .ontology_scene import OntologyScene
|
||||
@@ -84,5 +85,6 @@ __all__ = [
|
||||
"ExecutionStatus",
|
||||
"MemoryPerceptualModel",
|
||||
"ModelBase",
|
||||
"LoadBalanceStrategy"
|
||||
"LoadBalanceStrategy",
|
||||
"Skill"
|
||||
]
|
||||
|
||||
@@ -29,7 +29,8 @@ class AgentConfig(Base):
|
||||
knowledge_retrieval = Column(JSON, nullable=True, comment="知识库检索配置")
|
||||
memory = Column(JSON, nullable=True, comment="记忆配置")
|
||||
variables = Column(JSON, default=list, nullable=True, comment="变量配置")
|
||||
tools = Column(JSON, default=dict, nullable=True, comment="工具配置")
|
||||
tools = Column(JSON, default=list, nullable=True, comment="工具配置")
|
||||
skills = Column(JSON, default=dict, nullable=True, comment="技能配置")
|
||||
|
||||
# 多 Agent 相关字段
|
||||
agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone")
|
||||
|
||||
37
api/app/models/skill_model.py
Normal file
37
api/app/models/skill_model.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Skill 模型定义"""
|
||||
import datetime
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSON
|
||||
|
||||
from app.db import Base
|
||||
|
||||
|
||||
class Skill(Base):
|
||||
"""技能模型 - 可以关联工具(内置、MCP、自定义)"""
|
||||
__tablename__ = "skills"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||
name = Column(String, nullable=False, comment="技能名称")
|
||||
description = Column(Text, comment="技能描述")
|
||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, index=True, comment="租户ID")
|
||||
|
||||
# 关联的工具
|
||||
tools = Column(JSON, default=list, comment="关联的工具列表")
|
||||
|
||||
# 技能配置
|
||||
config = Column(JSON, default=dict, comment="技能配置")
|
||||
|
||||
# 专属提示词
|
||||
prompt = Column(Text, comment="技能专属提示词")
|
||||
|
||||
# 状态
|
||||
is_active = Column(Boolean, default=True, nullable=False, comment="是否激活")
|
||||
is_public = Column(Boolean, default=False, nullable=False, comment="是否公开到市场")
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Skill(id={self.id}, name={self.name})>"
|
||||
@@ -235,6 +235,8 @@ class MemoryConfigRepository:
|
||||
llm_id=params.llm_id,
|
||||
embedding_id=params.embedding_id,
|
||||
rerank_id=params.rerank_id,
|
||||
reflection_model_id=params.reflection_model_id,
|
||||
emotion_model_id=params.emotion_model_id,
|
||||
)
|
||||
db.add(db_config)
|
||||
db.flush() # 获取自增ID但不提交事务
|
||||
|
||||
@@ -583,7 +583,7 @@ class ModelApiKeyRepository:
|
||||
db_api_key.usage_count = str(current_count + 1)
|
||||
db_api_key.last_used_at = func.now()
|
||||
|
||||
db.commit()
|
||||
db.flush()
|
||||
db_logger.debug(f"API Key使用统计更新成功: api_key_id={api_key_id}")
|
||||
return True
|
||||
|
||||
|
||||
@@ -207,4 +207,4 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
except Exception as e:
|
||||
print(f"Neo4j integration error: {e}")
|
||||
print("Continuing without database storage...")
|
||||
return False
|
||||
return False
|
||||
111
api/app/repositories/skill_repository.py
Normal file
111
api/app/repositories/skill_repository.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Skill Repository"""
|
||||
from typing import List, Optional, Tuple, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, or_
|
||||
import uuid
|
||||
|
||||
from app.models.skill_model import Skill
|
||||
from app.schemas.skill_schema import SkillCreate, SkillUpdate
|
||||
|
||||
|
||||
class SkillRepository:
|
||||
"""Skill 数据访问层"""
|
||||
|
||||
@staticmethod
|
||||
def create(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill:
|
||||
"""创建技能"""
|
||||
skill = Skill(
|
||||
**data.model_dump(),
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
db.add(skill)
|
||||
db.flush()
|
||||
return skill
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(db: Session, skill_id: uuid.UUID, tenant_id: Optional[uuid.UUID] = None) -> Optional[Skill]:
|
||||
"""根据ID获取技能"""
|
||||
query = db.query(Skill).filter(Skill.id == skill_id)
|
||||
if tenant_id:
|
||||
query = query.filter(
|
||||
or_(
|
||||
Skill.tenant_id == tenant_id,
|
||||
Skill.is_public == True
|
||||
)
|
||||
)
|
||||
return query.first()
|
||||
|
||||
@staticmethod
|
||||
def list_skills(
|
||||
db: Session,
|
||||
tenant_id: uuid.UUID,
|
||||
search: Optional[str] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
is_public: Optional[bool] = None,
|
||||
page: int = 1,
|
||||
pagesize: int = 10
|
||||
) -> tuple[list[type[Skill]], int]:
|
||||
"""列出技能"""
|
||||
filters = [
|
||||
or_(
|
||||
Skill.tenant_id == tenant_id,
|
||||
Skill.is_public == True
|
||||
)
|
||||
]
|
||||
|
||||
if search:
|
||||
filters.append(
|
||||
or_(
|
||||
Skill.name.ilike(f"%{search}%"),
|
||||
# Skill.description.ilike(f"%{search}%")
|
||||
)
|
||||
)
|
||||
|
||||
if is_active is not None:
|
||||
filters.append(Skill.is_active == is_active)
|
||||
|
||||
if is_public is not None:
|
||||
filters.append(Skill.is_public == is_public)
|
||||
|
||||
query = db.query(Skill).filter(and_(*filters))
|
||||
total = query.count()
|
||||
|
||||
skills = query.order_by(Skill.created_at.desc()).offset(
|
||||
(page - 1) * pagesize
|
||||
).limit(pagesize).all()
|
||||
|
||||
return skills, total
|
||||
|
||||
@staticmethod
|
||||
def update(db: Session, skill_id: uuid.UUID, data: SkillUpdate, tenant_id: uuid.UUID) -> Optional[Skill]:
|
||||
"""更新技能"""
|
||||
skill = db.query(Skill).filter(
|
||||
Skill.id == skill_id,
|
||||
Skill.tenant_id == tenant_id
|
||||
).first()
|
||||
|
||||
if not skill:
|
||||
return None
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(skill, key, value)
|
||||
|
||||
db.flush()
|
||||
return skill
|
||||
|
||||
@staticmethod
|
||||
def delete(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> bool:
|
||||
"""删除技能"""
|
||||
skill = db.query(Skill).filter(
|
||||
Skill.id == skill_id,
|
||||
Skill.tenant_id == tenant_id
|
||||
).first()
|
||||
|
||||
if not skill:
|
||||
return False
|
||||
|
||||
# db.delete(skill)
|
||||
skill.is_active = False
|
||||
db.flush()
|
||||
return True
|
||||
@@ -1,14 +1,14 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from typing import Optional, Any, List, Dict, Union
|
||||
from enum import Enum
|
||||
from enum import Enum, StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
||||
|
||||
|
||||
# ---------- Multimodal File Support ----------
|
||||
|
||||
class FileType(str, Enum):
|
||||
class FileType(StrEnum):
|
||||
"""文件类型枚举"""
|
||||
IMAGE = "image"
|
||||
DOCUMENT = "document"
|
||||
@@ -82,6 +82,12 @@ class ToolConfig(BaseModel):
|
||||
tool_id: Optional[str] = Field(default=None, description="工具ID")
|
||||
operation: Optional[str] = Field(default=None, description="工具特定配置")
|
||||
|
||||
class SkillConfig(BaseModel):
|
||||
"""技能配置"""
|
||||
enabled: bool = Field(default=True, description="是否启用该技能")
|
||||
skill_ids: Optional[list[str]] = Field(default=list, description="技能ID列表")
|
||||
all_skills: Optional[bool] = Field(default=False, description="是否允许访问所有技能")
|
||||
|
||||
|
||||
class ToolOldConfig(BaseModel):
|
||||
"""工具配置"""
|
||||
@@ -92,7 +98,7 @@ class ToolOldConfig(BaseModel):
|
||||
class MemoryConfig(BaseModel):
|
||||
"""记忆配置"""
|
||||
enabled: bool = Field(default=True, description="是否启用对话历史记忆")
|
||||
memory_content: Optional[str] = Field(default=None, description="选择记忆的内容类型")
|
||||
memory_config_id: Optional[str] = Field(default=None, description="选择记忆的内容类型")
|
||||
max_history: int = Field(default=10, ge=0, le=100, description="最大保留的历史对话轮数")
|
||||
|
||||
|
||||
@@ -156,6 +162,9 @@ class AgentConfigCreate(BaseModel):
|
||||
description="Agent 可用的工具列表"
|
||||
)
|
||||
|
||||
# 技能配置
|
||||
skills: Optional[SkillConfig] = Field(default=dict, description="关联的技能列表")
|
||||
|
||||
|
||||
class AppCreate(BaseModel):
|
||||
name: str
|
||||
@@ -207,6 +216,9 @@ class AgentConfigUpdate(BaseModel):
|
||||
|
||||
# 工具配置
|
||||
tools: Optional[List[ToolConfig]] = Field(default_factory=list, description="工具列表")
|
||||
|
||||
# 技能配置
|
||||
skills: Optional[SkillConfig] = Field(default=dict, description="关联的技能列表")
|
||||
|
||||
|
||||
# ---------- Output Schemas ----------
|
||||
@@ -266,6 +278,8 @@ class AgentConfig(BaseModel):
|
||||
# 工具配置
|
||||
tools: Union[List[ToolConfig], Dict[str, ToolOldConfig]] = []
|
||||
|
||||
skills: Optional[SkillConfig] = {}
|
||||
|
||||
is_active: bool
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
@@ -236,6 +236,8 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body,
|
||||
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
||||
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
||||
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
||||
reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致")
|
||||
emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致")
|
||||
|
||||
|
||||
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
|
||||
|
||||
@@ -21,6 +21,11 @@ class PromptOptMessage(BaseModel):
|
||||
description="currently optimized prompt"
|
||||
)
|
||||
|
||||
skill: bool = Field(
|
||||
default=False,
|
||||
description="Enable variable output"
|
||||
)
|
||||
|
||||
|
||||
class PromptSaveRequest(BaseModel):
|
||||
session_id: UUID = Field(
|
||||
|
||||
64
api/app/schemas/skill_schema.py
Normal file
64
api/app/schemas/skill_schema.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Skill Schema 定义"""
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
from pydantic import BaseModel, Field, field_serializer
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class SkillBase(BaseModel):
|
||||
"""Skill 基础 Schema"""
|
||||
name: str = Field(..., description="技能名称")
|
||||
description: Optional[str] = Field(None, description="技能描述")
|
||||
tools: List[Dict[str, str]] = Field(default_factory=list, description="工具对象列表: [{\"tool_id\": \"xxx\", \"operation\": \"yyy\"}]")
|
||||
config: Dict[str, Any] = Field(default_factory=dict, description="技能配置")
|
||||
prompt: Optional[str] = Field(None, description="技能专属提示词")
|
||||
is_active: bool = Field(True, description="是否激活")
|
||||
is_public: bool = Field(False, description="是否公开到市场")
|
||||
|
||||
|
||||
class SkillCreate(SkillBase):
|
||||
"""创建 Skill"""
|
||||
pass
|
||||
|
||||
|
||||
class SkillUpdate(BaseModel):
|
||||
"""更新 Skill"""
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
tools: Optional[List[Dict[str, str]]] = None
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
prompt: Optional[str] = None
|
||||
is_active: Optional[bool] = None
|
||||
is_public: Optional[bool] = None
|
||||
|
||||
|
||||
class Skill(BaseModel):
|
||||
"""Skill 响应 Schema"""
|
||||
id: uuid.UUID
|
||||
tenant_id: uuid.UUID
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
tools: Union[List[Dict[str, Any]], List[Dict[str, str]]] = Field(default_factory=list, description="工具列表,可以是简单格式或包含工具详情")
|
||||
config: Dict[str, Any] = Field(default_factory=dict)
|
||||
prompt: Optional[str] = None
|
||||
is_active: bool = True
|
||||
is_public: bool = False
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@field_serializer('created_at', 'updated_at')
|
||||
def serialize_datetime_to_timestamp(self, value: datetime) -> int:
|
||||
"""(毫秒级)时间戳"""
|
||||
return int(value.timestamp() * 1000)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class SkillQuery(BaseModel):
|
||||
"""Skill 查询参数"""
|
||||
search: Optional[str] = None
|
||||
is_active: Optional[bool] = None
|
||||
is_public: Optional[bool] = None
|
||||
page: int = Field(1, ge=1)
|
||||
pagesize: int = Field(10, ge=1, le=100)
|
||||
@@ -9,7 +9,7 @@ from app.schemas.app_schema import (
|
||||
VariableDefinition,
|
||||
ToolConfig,
|
||||
AgentConfigCreate,
|
||||
AgentConfigUpdate, ToolOldConfig,
|
||||
AgentConfigUpdate, ToolOldConfig, SkillConfig,
|
||||
)
|
||||
|
||||
|
||||
@@ -48,6 +48,9 @@ class AgentConfigConverter:
|
||||
# 5. 工具配置
|
||||
if hasattr(config, 'tools') and config.tools:
|
||||
result["tools"] = [tool.model_dump() for tool in config.tools]
|
||||
|
||||
if hasattr(config, "skills") and config.skills:
|
||||
result["skills"] = config.skills.model_dump()
|
||||
|
||||
return result
|
||||
|
||||
@@ -58,6 +61,7 @@ class AgentConfigConverter:
|
||||
memory: Optional[Dict[str, Any]],
|
||||
variables: Optional[list],
|
||||
tools: Optional[Union[list, Dict[str, Any]]],
|
||||
skills: Optional[dict]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
将数据库存储格式转换为 Pydantic 对象
|
||||
@@ -68,6 +72,7 @@ class AgentConfigConverter:
|
||||
memory: 记忆配置
|
||||
variables: 变量配置
|
||||
tools: 工具配置
|
||||
skills: 技能列表
|
||||
|
||||
Returns:
|
||||
包含 Pydantic 对象的字典
|
||||
@@ -78,6 +83,7 @@ class AgentConfigConverter:
|
||||
"memory": MemoryConfig(enabled=True),
|
||||
"variables": [],
|
||||
"tools": [],
|
||||
"skills": SkillConfig(enabled=False, all_skills=False, skill_ids=[])
|
||||
}
|
||||
|
||||
# 1. 解析模型参数配置
|
||||
@@ -117,5 +123,10 @@ class AgentConfigConverter:
|
||||
name: ToolOldConfig(**tool_data)
|
||||
for name, tool_data in tools.items()
|
||||
}
|
||||
|
||||
if skills:
|
||||
result["skills"] = SkillConfig(**skills)
|
||||
else:
|
||||
result["skills"] = SkillConfig(enabled=False, all_skills=False, skill_ids=[])
|
||||
|
||||
return result
|
||||
|
||||
@@ -26,6 +26,7 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
|
||||
memory=agent_cfg.memory,
|
||||
variables=agent_cfg.variables,
|
||||
tools=agent_cfg.tools,
|
||||
skills=agent_cfg.skills
|
||||
)
|
||||
|
||||
# 将解析后的字段添加到对象上(用于序列化)
|
||||
@@ -34,5 +35,6 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
|
||||
agent_cfg.memory = parsed["memory"]
|
||||
agent_cfg.variables = parsed["variables"]
|
||||
agent_cfg.tools = parsed["tools"]
|
||||
agent_cfg.skills = parsed["skills"]
|
||||
|
||||
return agent_cfg
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.agent.agent_middleware import AgentMiddleware
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
@@ -63,7 +64,7 @@ class AppChatService:
|
||||
|
||||
# 获取模型配置ID
|
||||
model_config_id = config.default_model_config_id
|
||||
api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
|
||||
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
|
||||
# 处理系统提示词(支持变量替换)
|
||||
system_prompt = config.system_prompt
|
||||
if variables:
|
||||
@@ -79,21 +80,55 @@ class AppChatService:
|
||||
|
||||
# 获取工具服务
|
||||
tool_service = ToolService(self.db)
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
|
||||
for tool_config in config.tools:
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||
ToolRepository.get_tenant_id_by_workspace_id(
|
||||
self.db, workspace_id))
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
||||
web_tools = config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 加载技能关联的工具
|
||||
if hasattr(config, 'skills') and config.skills:
|
||||
skills = config.skills
|
||||
skill_enable = skills.get("enabled", False)
|
||||
if skill_enable:
|
||||
middleware = AgentMiddleware(skills=skills)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
||||
tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
|
||||
# 添加知识库检索工具
|
||||
knowledge_retrieval = config.knowledge_retrieval
|
||||
@@ -113,22 +148,6 @@ class AppChatService:
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
||||
web_tools = config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 获取模型参数
|
||||
model_parameters = config.model_parameters
|
||||
|
||||
@@ -192,6 +211,8 @@ class AppChatService:
|
||||
}
|
||||
)
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
@@ -230,7 +251,7 @@ class AppChatService:
|
||||
|
||||
# 获取模型配置ID
|
||||
model_config_id = config.default_model_config_id
|
||||
api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
|
||||
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
|
||||
# 处理系统提示词(支持变量替换)
|
||||
system_prompt = config.system_prompt
|
||||
if variables:
|
||||
@@ -246,20 +267,54 @@ class AppChatService:
|
||||
|
||||
# 获取工具服务
|
||||
tool_service = ToolService(self.db)
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||
|
||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
|
||||
for tool_config in config.tools:
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||
ToolRepository.get_tenant_id_by_workspace_id(
|
||||
self.db, workspace_id))
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
||||
web_tools = config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 加载技能关联的工具
|
||||
if hasattr(config, 'skills') and config.skills:
|
||||
skills = config.skills
|
||||
skill_enable = skills.get("enabled", False)
|
||||
if skill_enable:
|
||||
middleware = AgentMiddleware(skills=skills)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
||||
tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
|
||||
# 添加知识库检索工具
|
||||
knowledge_retrieval = config.knowledge_retrieval
|
||||
@@ -279,22 +334,6 @@ class AppChatService:
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
||||
web_tools = config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 获取模型参数
|
||||
model_parameters = config.model_parameters
|
||||
|
||||
@@ -374,6 +413,8 @@ class AppChatService:
|
||||
}
|
||||
)
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
|
||||
|
||||
# 发送结束事件
|
||||
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
|
||||
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
|
||||
@@ -618,6 +659,7 @@ class AppChatService:
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
public=False
|
||||
|
||||
) -> AsyncGenerator[dict, None]:
|
||||
"""聊天(流式)"""
|
||||
@@ -634,7 +676,8 @@ class AppChatService:
|
||||
payload=payload,
|
||||
config=config,
|
||||
workspace_id=workspace_id,
|
||||
release_id=release_id
|
||||
release_id=release_id,
|
||||
public=public
|
||||
):
|
||||
yield event
|
||||
|
||||
|
||||
@@ -313,6 +313,7 @@ class AppService:
|
||||
memory=storage_data.get("memory"),
|
||||
variables=storage_data.get("variables", []),
|
||||
tools=storage_data.get("tools", []),
|
||||
skills=storage_data.get("skills", {}),
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
@@ -916,6 +917,7 @@ class AppService:
|
||||
agent_cfg.variables = storage_data.get("variables", [])
|
||||
# if data.tools is not None:
|
||||
agent_cfg.tools = storage_data.get("tools", [])
|
||||
agent_cfg.skills = storage_data.get("skills", {})
|
||||
|
||||
agent_cfg.updated_at = now
|
||||
|
||||
@@ -1003,11 +1005,12 @@ class AppService:
|
||||
},
|
||||
memory={
|
||||
"enabled": True,
|
||||
"memory_content": None,
|
||||
"memory_config_id": None,
|
||||
"max_history": 10
|
||||
},
|
||||
variables=[],
|
||||
tools=[],
|
||||
skills=[],
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
@@ -1403,6 +1406,7 @@ class AppService:
|
||||
"memory": agent_cfg.memory,
|
||||
"variables": agent_cfg.variables or [],
|
||||
"tools": agent_cfg.tools or [],
|
||||
"skills": agent_cfg.skills or {},
|
||||
}
|
||||
# config = AgentConfigConverter.from_storage_format(agent_cfg)
|
||||
default_model_config_id = agent_cfg.default_model_config_id
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
"""应用统计服务"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, Any
|
||||
import uuid
|
||||
from sqlalchemy import func, and_, cast, Date
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.conversation_model import Conversation, Message
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.api_key_model import ApiKey, ApiKeyLog
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.models.api_key_model import ApiKey, ApiKeyLog, ApiKeyType
|
||||
|
||||
|
||||
class AppStatisticsService:
|
||||
@@ -146,7 +144,6 @@ class AppStatisticsService:
|
||||
end_dt: datetime
|
||||
) -> Dict[str, Any]:
|
||||
"""获取Token消耗统计(从Message的meta_data中提取)"""
|
||||
from sqlalchemy import text
|
||||
|
||||
# 查询所有相关消息的token使用情况
|
||||
# meta_data中可能包含: {"usage": {"total_tokens": 100}} 或 {"tokens": 100}
|
||||
@@ -187,7 +184,80 @@ class AppStatisticsService:
|
||||
daily_tokens[date_str] = 0
|
||||
daily_tokens[date_str] += int(tokens)
|
||||
|
||||
daily_data = [{"date": date, "tokens": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
|
||||
total = sum(row["tokens"] for row in daily_data)
|
||||
daily_data = [{"date": date, "count": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
|
||||
total = sum(row["count"] for row in daily_data)
|
||||
|
||||
return {"daily": daily_data, "total": total}
|
||||
|
||||
def get_workspace_api_statistics(
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
start_date: int,
|
||||
end_date: int
|
||||
) -> list[Any]:
|
||||
"""获取工作空间API调用统计
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
start_date: 开始时间戳(毫秒)
|
||||
end_date: 结束时间戳(毫秒)
|
||||
|
||||
Returns:
|
||||
每日统计数据列表
|
||||
"""
|
||||
# 将毫秒时间戳转换为 datetime
|
||||
start_time = datetime.fromtimestamp(start_date / 1000)
|
||||
end_time = datetime.fromtimestamp(end_date / 1000)
|
||||
|
||||
# 应用类型(agent, multi_agent, workflow)
|
||||
app_types = [ApiKeyType.AGENT, ApiKeyType.CLUSTER, ApiKeyType.WORKFLOW]
|
||||
|
||||
# 每日应用类型调用次数
|
||||
daily_app_calls = self.db.query(
|
||||
cast(ApiKeyLog.created_at, Date).label('date'),
|
||||
func.count(ApiKeyLog.id).label('count')
|
||||
).join(
|
||||
ApiKey, ApiKeyLog.api_key_id == ApiKey.id
|
||||
).filter(
|
||||
and_(
|
||||
ApiKey.workspace_id == workspace_id,
|
||||
ApiKey.type.in_(app_types),
|
||||
ApiKeyLog.created_at >= start_time,
|
||||
ApiKeyLog.created_at <= end_time
|
||||
)
|
||||
).group_by(cast(ApiKeyLog.created_at, Date)).all()
|
||||
|
||||
# 每日服务类型调用次数
|
||||
daily_service_calls = self.db.query(
|
||||
cast(ApiKeyLog.created_at, Date).label('date'),
|
||||
func.count(ApiKeyLog.id).label('count')
|
||||
).join(
|
||||
ApiKey, ApiKeyLog.api_key_id == ApiKey.id
|
||||
).filter(
|
||||
and_(
|
||||
ApiKey.workspace_id == workspace_id,
|
||||
ApiKey.type == ApiKeyType.SERVICE,
|
||||
ApiKeyLog.created_at >= start_time,
|
||||
ApiKeyLog.created_at <= end_time
|
||||
)
|
||||
).group_by(cast(ApiKeyLog.created_at, Date)).all()
|
||||
|
||||
# 构建每日数据
|
||||
app_calls_dict = {str(row.date): row.count for row in daily_app_calls}
|
||||
service_calls_dict = {str(row.date): row.count for row in daily_service_calls}
|
||||
|
||||
# 合并所有日期
|
||||
all_dates = sorted(set(app_calls_dict.keys()) | set(service_calls_dict.keys()))
|
||||
|
||||
daily_data = []
|
||||
for date in all_dates:
|
||||
app_count = app_calls_dict.get(date, 0)
|
||||
service_count = service_calls_dict.get(date, 0)
|
||||
daily_data.append({
|
||||
"date": date,
|
||||
"total_calls": app_count + service_count,
|
||||
"app_calls": app_count,
|
||||
"service_calls": service_count
|
||||
})
|
||||
|
||||
return daily_data
|
||||
|
||||
@@ -24,6 +24,7 @@ from app.core.error_codes import BizCode
|
||||
from app.core.models import RedBearLLM
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -357,6 +358,8 @@ class CollaborativeOrchestrator:
|
||||
"usage": response.get("usage", {"total_tokens": 0}),
|
||||
"is_final_answer": True
|
||||
}
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, agent_config.get("api_key_id"))
|
||||
|
||||
# 检查是否有工具调用(handoff)
|
||||
tool_calls = response.get("tool_calls", [])
|
||||
@@ -427,7 +430,7 @@ class CollaborativeOrchestrator:
|
||||
)
|
||||
|
||||
# 获取 API Key
|
||||
api_key_config = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
|
||||
api_key_config = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
|
||||
if not api_key_config:
|
||||
raise BusinessException(
|
||||
f"Agent 模型没有可用的 API Key: {agent_id}",
|
||||
@@ -442,7 +445,8 @@ class CollaborativeOrchestrator:
|
||||
"provider": api_key_config.provider,
|
||||
"api_key": api_key_config.api_key,
|
||||
"api_base": api_key_config.api_base,
|
||||
"model_parameters": config_data.get("model_parameters", {})
|
||||
"model_parameters": config_data.get("model_parameters", {}),
|
||||
"api_key_id": api_key_config.id
|
||||
}
|
||||
|
||||
except ValueError:
|
||||
|
||||
@@ -10,6 +10,11 @@ import time
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
@@ -24,12 +29,11 @@ from app.services import task_service
|
||||
from app.services.langchain_tool_server import Search
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.tool_service import ToolService
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from app.core.agent.agent_middleware import AgentMiddleware
|
||||
|
||||
|
||||
logger = get_business_logger()
|
||||
class KnowledgeRetrievalInput(BaseModel):
|
||||
@@ -59,7 +63,8 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
||||
长期记忆工具
|
||||
"""
|
||||
# search_switch = memory_config.get("search_switch", "2")
|
||||
config_id= memory_config.get("memory_content") or memory_config.get("memory_config",None)
|
||||
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
||||
config_id = memory_config.get("memory_config_id") or memory_config.get("memory_content", None)
|
||||
logger.info(f"创建长期记忆工具,配置: end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}")
|
||||
@tool(args_schema=LongTermMemoryInput)
|
||||
def long_term_memory(question: str) -> str:
|
||||
@@ -310,6 +315,7 @@ class DraftRunService:
|
||||
tools = []
|
||||
|
||||
tool_service = ToolService(self.db)
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
|
||||
@@ -320,9 +326,7 @@ class DraftRunService:
|
||||
print(f"tool_config:{tool_config}")
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||
ToolRepository.get_tenant_id_by_workspace_id(
|
||||
self.db, str(workspace_id)))
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
@@ -345,6 +349,25 @@ class DraftRunService:
|
||||
}
|
||||
)
|
||||
|
||||
# 加载技能关联的工具
|
||||
if hasattr(agent_config, 'skills') and agent_config.skills:
|
||||
skills = agent_config.skills
|
||||
skill_enable = skills.get("enabled", False)
|
||||
if skill_enable:
|
||||
middleware = AgentMiddleware(skills=skills)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
|
||||
# 添加知识库检索工具
|
||||
if agent_config.knowledge_retrieval:
|
||||
kb_config = agent_config.knowledge_retrieval
|
||||
@@ -433,7 +456,8 @@ class DraftRunService:
|
||||
)
|
||||
|
||||
memory_config_= agent_config.memory
|
||||
config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None)
|
||||
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
||||
config_id = memory_config_.get("memory_config_id") or memory_config_.get("memory_content", None)
|
||||
|
||||
# 8. 调用 Agent(支持多模态)
|
||||
result = await agent.chat(
|
||||
@@ -450,6 +474,8 @@ class DraftRunService:
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
|
||||
|
||||
# 9. 保存会话消息
|
||||
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
|
||||
await self._save_conversation_message(
|
||||
@@ -558,6 +584,7 @@ class DraftRunService:
|
||||
tools = []
|
||||
|
||||
tool_service = ToolService(self.db)
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
|
||||
@@ -567,9 +594,7 @@ class DraftRunService:
|
||||
# print(f"tool_config:{tool_config}")
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||
ToolRepository.get_tenant_id_by_workspace_id(
|
||||
self.db, str(workspace_id)))
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
@@ -592,6 +617,25 @@ class DraftRunService:
|
||||
}
|
||||
)
|
||||
|
||||
# 加载技能关联的工具
|
||||
if hasattr(agent_config, 'skills') and agent_config.skills:
|
||||
skills = agent_config.skills
|
||||
skill_enable = skills.get("enabled", False)
|
||||
if skill_enable:
|
||||
middleware = AgentMiddleware(skills=skills)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
|
||||
|
||||
# 添加知识库检索工具
|
||||
if agent_config.knowledge_retrieval:
|
||||
@@ -628,7 +672,6 @@ class DraftRunService:
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 4. 创建 LangChain Agent
|
||||
agent = LangChainAgent(
|
||||
model_name=api_key_config["model_name"],
|
||||
@@ -677,7 +720,8 @@ class DraftRunService:
|
||||
})
|
||||
|
||||
memory_config_ = agent_config.memory
|
||||
config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None)
|
||||
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
||||
config_id = memory_config_.get("memory_config_id") or memory_config_.get("memory_content", None)
|
||||
|
||||
# 9. 流式调用 Agent(支持多模态)
|
||||
full_content = ""
|
||||
@@ -704,6 +748,8 @@ class DraftRunService:
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
|
||||
|
||||
if sub_agent:
|
||||
yield self._format_sse_event("sub_usage", {
|
||||
"total_tokens": total_tokens
|
||||
@@ -770,7 +816,7 @@ class DraftRunService:
|
||||
Raises:
|
||||
BusinessException: 当没有可用的 API Key 时
|
||||
"""
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
|
||||
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
|
||||
# stmt = (
|
||||
# select(ModelApiKey).join(
|
||||
# ModelConfig, ModelApiKey.model_configs
|
||||
@@ -784,7 +830,8 @@ class DraftRunService:
|
||||
# )
|
||||
#
|
||||
# api_key = self.db.scalars(stmt).first()
|
||||
api_key = api_keys[0] if api_keys else None
|
||||
# api_key = api_keys[0] if api_keys else None
|
||||
api_key = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
|
||||
|
||||
if not api_key:
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
@@ -793,7 +840,8 @@ class DraftRunService:
|
||||
"model_name": api_key.model_name,
|
||||
"provider": api_key.provider,
|
||||
"api_key": api_key.api_key,
|
||||
"api_base": api_key.api_base
|
||||
"api_base": api_key.api_base,
|
||||
"api_key_id": api_key.id
|
||||
}
|
||||
|
||||
async def _ensure_conversation(
|
||||
@@ -1051,7 +1099,7 @@ class DraftRunService:
|
||||
|
||||
except Exception as e:
|
||||
# 对于多 Agent 应用,没有直接的 AgentConfig 是正常的
|
||||
logger.debug("获取配置快照失败(可能是多 Agent 应用)", extra={"error": str(e)})
|
||||
logger.debug("获取配置快照失败(可能是多 Agent 应用)", exc_info=True, extra={"error": str(e)})
|
||||
return {}
|
||||
|
||||
def _replace_variables(
|
||||
|
||||
@@ -537,7 +537,7 @@ def convert_multi_agent_config_to_handoffs(
|
||||
|
||||
# 获取该 Agent 的模型配置
|
||||
if release.default_model_config_id:
|
||||
model_api_key = ModelApiKeyService.get_a_api_key(db, release.default_model_config_id)
|
||||
model_api_key = ModelApiKeyService.get_available_api_key(db, release.default_model_config_id)
|
||||
if model_api_key:
|
||||
model_config = RedBearModelConfig(
|
||||
model_name=model_api_key.model_name,
|
||||
@@ -551,6 +551,7 @@ def convert_multi_agent_config_to_handoffs(
|
||||
}
|
||||
)
|
||||
logger.debug(f"Agent {agent_name} 使用模型: {model_api_key.model_name}")
|
||||
ModelApiKeyService.record_api_key_usage(db, model_api_key.id)
|
||||
else:
|
||||
logger.warning(f"Agent {agent_name} 模型配置无效: {release.default_model_config_id}")
|
||||
else:
|
||||
|
||||
@@ -382,6 +382,7 @@ class LLMRouter:
|
||||
from app.core.models import RedBearLLM
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.models import ModelApiKey, ModelType
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
|
||||
# 获取 API Key 配置(通过关联关系)
|
||||
# api_key_config = self.db.query(ModelApiKey).join(
|
||||
@@ -389,8 +390,9 @@ class LLMRouter:
|
||||
# ).filter(ModelConfig.id == self.routing_model_config.id,
|
||||
# ModelApiKey.is_active == True
|
||||
# ).first()
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, self.routing_model_config.id)
|
||||
api_key_config = api_keys[0] if api_keys else None
|
||||
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, self.routing_model_config.id)
|
||||
# api_key_config = api_keys[0] if api_keys else None
|
||||
api_key_config = ModelApiKeyService.get_available_api_key(self.db, self.routing_model_config.id)
|
||||
|
||||
if not api_key_config:
|
||||
raise Exception("路由模型没有可用的 API Key")
|
||||
@@ -424,7 +426,6 @@ class LLMRouter:
|
||||
# 调用模型
|
||||
response = await llm.ainvoke(prompt)
|
||||
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
|
||||
|
||||
# 提取响应内容
|
||||
|
||||
@@ -349,7 +349,7 @@ class MasterAgentRouter:
|
||||
from app.models import ModelApiKey, ModelType
|
||||
|
||||
# 获取 API Key 配置
|
||||
api_key_config = ModelApiKeyService.get_a_api_key(self.db, self.master_model_config.id)
|
||||
api_key_config = ModelApiKeyService.get_available_api_key(self.db, self.master_model_config.id)
|
||||
|
||||
if not api_key_config:
|
||||
raise Exception("Master Agent 模型没有可用的 API Key")
|
||||
@@ -400,6 +400,7 @@ class MasterAgentRouter:
|
||||
|
||||
# 调用模型
|
||||
response = await llm.ainvoke(prompt)
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
|
||||
|
||||
# 提取响应内容
|
||||
if hasattr(response, 'content'):
|
||||
|
||||
@@ -1194,7 +1194,9 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
workspace_id=app.workspace_id
|
||||
)
|
||||
|
||||
memory_config_id = str(memory_config.config_id) if memory_config else None
|
||||
memory_obj = config.get('memory', {})
|
||||
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
||||
memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||
|
||||
result = {
|
||||
"end_user_id": str(end_user_id),
|
||||
@@ -1284,7 +1286,8 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
if release:
|
||||
config = release.config or {}
|
||||
memory_obj = config.get('memory', {})
|
||||
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
||||
memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||
if memory_config_id:
|
||||
# 判断是否为UUID格式
|
||||
if len(str(memory_config_id))>=5:
|
||||
@@ -1330,7 +1333,8 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
# 从 config 中提取 memory_config_id
|
||||
config = release.config or {}
|
||||
memory_obj = config.get('memory', {})
|
||||
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
||||
memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||
|
||||
# 获取配置名称(使用字符串形式的ID进行查找,兼容新旧格式)
|
||||
memory_config_name = config_id_to_name.get(str(memory_config_id)) if memory_config_id else None
|
||||
|
||||
@@ -53,7 +53,10 @@ def get_workspace_end_users(
|
||||
workspace_id: uuid.UUID,
|
||||
current_user: User
|
||||
) -> List[EndUser]:
|
||||
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)"""
|
||||
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)
|
||||
|
||||
返回结果按 updated_at 从新到旧排序(NULL 值排在最后)
|
||||
"""
|
||||
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
@@ -68,9 +71,14 @@ def get_workspace_end_users(
|
||||
app_ids = [app.id for app in apps_orm]
|
||||
|
||||
# 批量查询所有 end_users(一次查询而非循环查询)
|
||||
# 按 updated_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性
|
||||
from app.models.end_user_model import EndUser as EndUserModel
|
||||
from sqlalchemy import desc, nullslast
|
||||
end_users_orm = db.query(EndUserModel).filter(
|
||||
EndUserModel.app_id.in_(app_ids)
|
||||
).order_by(
|
||||
nullslast(desc(EndUserModel.updated_at)),
|
||||
desc(EndUserModel.id)
|
||||
).all()
|
||||
|
||||
# 转换为 Pydantic 模型(只在需要时转换)
|
||||
|
||||
@@ -108,13 +108,14 @@ class WorkspaceAppService:
|
||||
app_info["releases"].append(release_info)
|
||||
|
||||
def _extract_memory_content(self, config: Any) -> str:
|
||||
"""Extract memory_comtent from config"""
|
||||
"""Extract memory_config_id from config (兼容新旧字段名)"""
|
||||
if not config or not isinstance(config, dict):
|
||||
return None
|
||||
|
||||
memory_obj = config.get('memory')
|
||||
if memory_obj and isinstance(memory_obj, dict):
|
||||
return memory_obj.get('memory_content')
|
||||
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
||||
return memory_obj.get('memory_config_id') or memory_obj.get('memory_content')
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -129,6 +129,12 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
if not params.rerank_id:
|
||||
params.rerank_id = configs.get('rerank')
|
||||
|
||||
# reflection_model_id 和 emotion_model_id 默认与 llm_id 一致
|
||||
if not params.reflection_model_id:
|
||||
params.reflection_model_id = params.llm_id
|
||||
if not params.emotion_model_id:
|
||||
params.emotion_model_id = params.llm_id
|
||||
|
||||
config = MemoryConfigRepository.create(self.db, params)
|
||||
self.db.commit()
|
||||
return {"affected": 1, "config_id": config.config_id}
|
||||
@@ -203,6 +209,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
"end_user_id": config.end_user_id,
|
||||
"config_id_old": config_id_old,
|
||||
"apply_id": config.apply_id,
|
||||
"scene_id": config.scene_id,
|
||||
"llm_id": config.llm_id,
|
||||
"embedding_id": config.embedding_id,
|
||||
"rerank_id": config.rerank_id,
|
||||
|
||||
@@ -6,7 +6,7 @@ import math
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, LoadBalanceStrategy
|
||||
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository, ModelBaseRepository
|
||||
from app.schemas import model_schema
|
||||
from app.schemas.model_schema import (
|
||||
@@ -633,19 +633,31 @@ class ModelApiKeyService:
|
||||
|
||||
@staticmethod
|
||||
def get_available_api_key(db: Session, model_config_id: uuid.UUID) -> Optional[ModelApiKey]:
|
||||
"""获取可用的API Key(按优先级和负载均衡)"""
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active=True)
|
||||
"""获取可用的API Key(根据负载均衡策略)"""
|
||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||
if not model_config:
|
||||
return None
|
||||
|
||||
api_keys = [key for key in model_config.api_keys if key.is_active]
|
||||
if not api_keys:
|
||||
return None
|
||||
return min(api_keys, key=lambda x: int(x.usage_count or "0"))
|
||||
|
||||
# 如果是轮询策略,按使用次数最少,次数相同则选最早使用的
|
||||
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
|
||||
return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
|
||||
|
||||
# 否则返回第一个
|
||||
return api_keys[0]
|
||||
|
||||
@staticmethod
|
||||
def record_api_key_usage(db: Session, api_key_id: uuid.UUID) -> bool:
|
||||
def record_api_key_usage(db: Session, api_key_id: uuid.UUID | None) -> bool:
|
||||
"""记录API Key使用"""
|
||||
success = ModelApiKeyRepository.update_usage(db, api_key_id)
|
||||
if success:
|
||||
db.commit()
|
||||
return success
|
||||
if api_key_id:
|
||||
success = ModelApiKeyRepository.update_usage(db, api_key_id)
|
||||
if success:
|
||||
db.commit()
|
||||
return success
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_a_api_key(db: Session, model_config_id: uuid.UUID) -> ModelApiKey:
|
||||
|
||||
@@ -14,6 +14,7 @@ from app.services.conversation_state_manager import ConversationStateManager
|
||||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -2569,8 +2570,9 @@ class MultiAgentOrchestrator:
|
||||
# ModelConfig.id == default_model_config_id,
|
||||
# ModelApiKey.is_active.is_(True)
|
||||
# ).first()
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id)
|
||||
api_key_config = api_keys[0] if api_keys else None
|
||||
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id)
|
||||
# api_key_config = api_keys[0] if api_keys else None
|
||||
api_key_config = ModelApiKeyService.get_available_api_key(self.db, default_model_config_id)
|
||||
|
||||
if not api_key_config:
|
||||
logger.warning("Master Agent 没有可用的 API Key,使用简单整合")
|
||||
@@ -2601,6 +2603,8 @@ class MultiAgentOrchestrator:
|
||||
# 调用模型进行整合
|
||||
response = await llm.ainvoke(merge_prompt)
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
|
||||
|
||||
# 提取响应内容
|
||||
if hasattr(response, 'content'):
|
||||
merged_response = response.content
|
||||
@@ -2730,8 +2734,9 @@ class MultiAgentOrchestrator:
|
||||
# ModelConfig.id == default_model_config_id,
|
||||
# ModelApiKey.is_active.is_(True)
|
||||
# ).first()
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id)
|
||||
api_key_config = api_keys[0] if api_keys else None
|
||||
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id)
|
||||
# api_key_config = api_keys[0] if api_keys else None
|
||||
api_key_config = ModelApiKeyService.get_available_api_key(self.db, default_model_config_id)
|
||||
|
||||
if not api_key_config:
|
||||
logger.warning("Master Agent 没有可用的 API Key,使用简单整合")
|
||||
@@ -2790,6 +2795,8 @@ class MultiAgentOrchestrator:
|
||||
logger.debug(f"收到流式 chunk #{chunk_count}: {content[:30]}...")
|
||||
yield self._format_sse_event("message", {"content": content})
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
|
||||
|
||||
logger.info(f"Master Agent 流式整合完成,共 {chunk_count} 个 chunks")
|
||||
|
||||
except AttributeError as e:
|
||||
|
||||
@@ -23,7 +23,7 @@ logger = get_business_logger()
|
||||
|
||||
class ImageFormatStrategy(Protocol):
|
||||
"""图片格式策略接口"""
|
||||
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
"""将图片 URL 转换为特定 provider 的格式"""
|
||||
...
|
||||
@@ -31,7 +31,7 @@ class ImageFormatStrategy(Protocol):
|
||||
|
||||
class DashScopeImageStrategy:
|
||||
"""通义千问图片格式策略"""
|
||||
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
"""通义千问格式: {"type": "image", "image": "url"}"""
|
||||
return {
|
||||
@@ -42,7 +42,7 @@ class DashScopeImageStrategy:
|
||||
|
||||
class BedrockImageStrategy:
|
||||
"""Bedrock/Anthropic 图片格式策略"""
|
||||
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Bedrock/Anthropic 格式: base64 编码
|
||||
@@ -51,17 +51,17 @@ class BedrockImageStrategy:
|
||||
import httpx
|
||||
import base64
|
||||
from mimetypes import guess_type
|
||||
|
||||
|
||||
logger.info(f"下载并编码图片: {url}")
|
||||
|
||||
|
||||
# 下载图片
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
# 获取图片数据
|
||||
image_data = response.content
|
||||
|
||||
|
||||
# 确定 media type
|
||||
content_type = response.headers.get("content-type")
|
||||
if content_type and content_type.startswith("image/"):
|
||||
@@ -69,12 +69,12 @@ class BedrockImageStrategy:
|
||||
else:
|
||||
guessed_type, _ = guess_type(url)
|
||||
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
|
||||
|
||||
|
||||
# 转换为 base64
|
||||
base64_data = base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
|
||||
logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
||||
|
||||
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
@@ -87,7 +87,7 @@ class BedrockImageStrategy:
|
||||
|
||||
class OpenAIImageStrategy:
|
||||
"""OpenAI 图片格式策略"""
|
||||
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
"""OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}"""
|
||||
return {
|
||||
@@ -109,7 +109,7 @@ PROVIDER_STRATEGIES = {
|
||||
|
||||
class MultimodalService:
|
||||
"""多模态文件处理服务"""
|
||||
|
||||
|
||||
def __init__(self, db: Session, provider: str = "dashscope"):
|
||||
"""
|
||||
初始化多模态服务
|
||||
@@ -120,10 +120,10 @@ class MultimodalService:
|
||||
"""
|
||||
self.db = db
|
||||
self.provider = provider.lower()
|
||||
|
||||
|
||||
async def process_files(
|
||||
self,
|
||||
files: Optional[List[FileInput]]
|
||||
self,
|
||||
files: Optional[List[FileInput]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
处理文件列表,返回 LLM 可用的格式
|
||||
@@ -136,7 +136,7 @@ class MultimodalService:
|
||||
"""
|
||||
if not files:
|
||||
return []
|
||||
|
||||
|
||||
result = []
|
||||
for idx, file in enumerate(files):
|
||||
try:
|
||||
@@ -168,10 +168,10 @@ class MultimodalService:
|
||||
"type": "text",
|
||||
"text": f"[文件处理失败: {str(e)}]"
|
||||
})
|
||||
|
||||
|
||||
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
||||
return result
|
||||
|
||||
|
||||
async def _process_image(self, file: FileInput) -> Dict[str, Any]:
|
||||
"""
|
||||
处理图片文件
|
||||
@@ -184,14 +184,10 @@ class MultimodalService:
|
||||
- Anthropic/Bedrock: {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
|
||||
- 通义千问: {"type": "image", "image": "url"}
|
||||
"""
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
url = file.url
|
||||
else:
|
||||
# 本地文件,获取访问 URL
|
||||
url = await self._get_file_url(file.upload_file_id)
|
||||
|
||||
url = await self.get_file_url(file)
|
||||
|
||||
logger.debug(f"处理图片: {url}, provider={self.provider}")
|
||||
|
||||
|
||||
# 根据 provider 返回不同格式
|
||||
if self.provider in ["bedrock", "anthropic"]:
|
||||
# Anthropic/Bedrock 只支持 base64 格式,需要下载并转换
|
||||
@@ -223,7 +219,7 @@ class MultimodalService:
|
||||
"type": "image",
|
||||
"image": url
|
||||
}
|
||||
|
||||
|
||||
async def _download_and_encode_image(self, url: str) -> tuple[str, str]:
|
||||
"""
|
||||
下载图片并转换为 base64
|
||||
@@ -237,15 +233,15 @@ class MultimodalService:
|
||||
import httpx
|
||||
import base64
|
||||
from mimetypes import guess_type
|
||||
|
||||
|
||||
# 下载图片
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
# 获取图片数据
|
||||
image_data = response.content
|
||||
|
||||
|
||||
# 确定 media type
|
||||
content_type = response.headers.get("content-type")
|
||||
if content_type and content_type.startswith("image/"):
|
||||
@@ -254,14 +250,14 @@ class MultimodalService:
|
||||
# 从 URL 推断
|
||||
guessed_type, _ = guess_type(url)
|
||||
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
|
||||
|
||||
|
||||
# 转换为 base64
|
||||
base64_data = base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
|
||||
logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
||||
|
||||
|
||||
return base64_data, media_type
|
||||
|
||||
|
||||
async def _process_document(self, file: FileInput) -> Dict[str, Any]:
|
||||
"""
|
||||
处理文档文件(PDF、Word 等)
|
||||
@@ -284,14 +280,14 @@ class MultimodalService:
|
||||
generic_file = self.db.query(GenericFile).filter(
|
||||
GenericFile.id == file.upload_file_id
|
||||
).first()
|
||||
|
||||
|
||||
file_name = generic_file.file_name if generic_file else "unknown"
|
||||
|
||||
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||
}
|
||||
|
||||
|
||||
async def _process_audio(self, file: FileInput) -> Dict[str, Any]:
|
||||
"""
|
||||
处理音频文件
|
||||
@@ -307,7 +303,7 @@ class MultimodalService:
|
||||
"type": "text",
|
||||
"text": "[音频文件,暂不支持处理]"
|
||||
}
|
||||
|
||||
|
||||
async def _process_video(self, file: FileInput) -> Dict[str, Any]:
|
||||
"""
|
||||
处理视频文件
|
||||
@@ -323,13 +319,13 @@ class MultimodalService:
|
||||
"type": "text",
|
||||
"text": "[视频文件,暂不支持处理]"
|
||||
}
|
||||
|
||||
async def _get_file_url(self, file_id: uuid.UUID) -> str:
|
||||
|
||||
async def get_file_url(self, file: FileInput) -> str:
|
||||
"""
|
||||
获取文件的访问 URL
|
||||
|
||||
Args:
|
||||
file_id: 文件ID
|
||||
file: File Input Struct
|
||||
|
||||
Returns:
|
||||
str: 文件访问 URL
|
||||
@@ -337,26 +333,31 @@ class MultimodalService:
|
||||
Raises:
|
||||
BusinessException: 文件不存在
|
||||
"""
|
||||
generic_file = self.db.query(GenericFile).filter(
|
||||
GenericFile.id == file_id,
|
||||
GenericFile.status == "active"
|
||||
).first()
|
||||
|
||||
if not generic_file:
|
||||
raise BusinessException(
|
||||
f"文件不存在或已删除: {file_id}",
|
||||
BizCode.NOT_FOUND
|
||||
)
|
||||
|
||||
# 如果有 access_url,直接返回
|
||||
if generic_file.access_url:
|
||||
return generic_file.access_url
|
||||
|
||||
# 否则,根据 storage_path 生成 URL
|
||||
# TODO: 根据实际存储方式生成 URL(本地存储、OSS 等)
|
||||
# 这里暂时返回一个占位 URL
|
||||
return f"/api/files/{file_id}/download"
|
||||
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
return file.url
|
||||
else:
|
||||
# 本地文件,获取访问 URL
|
||||
file_id = file.upload_file_id
|
||||
generic_file = self.db.query(GenericFile).filter(
|
||||
GenericFile.id == file.upload_file_id,
|
||||
GenericFile.status == "active"
|
||||
).first()
|
||||
|
||||
if not generic_file:
|
||||
raise BusinessException(
|
||||
f"文件不存在或已删除: {file.upload_file_id}",
|
||||
BizCode.NOT_FOUND
|
||||
)
|
||||
|
||||
# 如果有 access_url,直接返回
|
||||
if generic_file.access_url:
|
||||
return generic_file.access_url
|
||||
|
||||
# 否则,根据 storage_path 生成 URL
|
||||
# TODO: 根据实际存储方式生成 URL(本地存储、OSS 等)
|
||||
# 这里暂时返回一个占位 URL
|
||||
return f"/api/files/{file_id}/download"
|
||||
|
||||
async def _extract_document_text(self, file_id: uuid.UUID) -> str:
|
||||
"""
|
||||
提取文档文本内容
|
||||
@@ -371,20 +372,20 @@ class MultimodalService:
|
||||
GenericFile.id == file_id,
|
||||
GenericFile.status == "active"
|
||||
).first()
|
||||
|
||||
|
||||
if not generic_file:
|
||||
raise BusinessException(
|
||||
f"文件不存在或已删除: {file_id}",
|
||||
BizCode.NOT_FOUND
|
||||
)
|
||||
|
||||
|
||||
# TODO: 根据文件类型提取文本
|
||||
# - PDF: 使用 PyPDF2 或 pdfplumber
|
||||
# - Word: 使用 python-docx
|
||||
# - TXT/MD: 直接读取
|
||||
|
||||
|
||||
file_ext = generic_file.file_ext.lower()
|
||||
|
||||
|
||||
if file_ext in ['.txt', '.md', '.markdown']:
|
||||
return await self._read_text_file(generic_file.storage_path)
|
||||
elif file_ext == '.pdf':
|
||||
@@ -393,7 +394,7 @@ class MultimodalService:
|
||||
return await self._extract_word_text(generic_file.storage_path)
|
||||
else:
|
||||
return f"[不支持的文档格式: {file_ext}]"
|
||||
|
||||
|
||||
async def _read_text_file(self, storage_path: str) -> str:
|
||||
"""读取纯文本文件"""
|
||||
try:
|
||||
@@ -402,7 +403,7 @@ class MultimodalService:
|
||||
except Exception as e:
|
||||
logger.error(f"读取文本文件失败: {e}")
|
||||
return f"[文件读取失败: {str(e)}]"
|
||||
|
||||
|
||||
async def _extract_pdf_text(self, storage_path: str) -> str:
|
||||
"""提取 PDF 文本"""
|
||||
try:
|
||||
@@ -412,7 +413,7 @@ class MultimodalService:
|
||||
except Exception as e:
|
||||
logger.error(f"提取 PDF 文本失败: {e}")
|
||||
return f"[PDF 提取失败: {str(e)}]"
|
||||
|
||||
|
||||
async def _extract_word_text(self, storage_path: str) -> str:
|
||||
"""提取 Word 文档文本"""
|
||||
try:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
{% raw %}
|
||||
Role: AI Prompt Optimization Expert
|
||||
|
||||
Profile
|
||||
@@ -12,11 +11,11 @@ Skills
|
||||
Core Optimization Skills
|
||||
Requirement Analysis: Accurately understand the relationship between the user’s current needs and the original prompt.
|
||||
Structural Reconstruction: Transform vague requirements into clear, block-structured instructions.
|
||||
Variable Handling: Identify and standardize dynamic variables in prompts.
|
||||
{% if skill != true %}Variable Handling: Identify and standardize dynamic variables in prompts.{% endif %}
|
||||
Conflict Resolution: Prioritize current requirements when historical requirements conflict with current needs.
|
||||
|
||||
Auxiliary Generation Skills
|
||||
Completeness Check: Ensure all necessary elements (input, output, constraints, etc.) are explicitly defined.
|
||||
{% if skill != true %}Completeness Check: Ensure all necessary elements (input, output, constraints, etc.) are explicitly defined.{% endif %}
|
||||
Language Consistency: Maintain consistency between label language and user input language.
|
||||
Executability Verification: Ensure optimized prompts can be directly used in AI tools.
|
||||
Format Standardization: Strictly adhere to specified output format requirements.
|
||||
@@ -25,30 +24,30 @@ Rules
|
||||
Basic Principles
|
||||
Priority Rule: When historical requirements conflict with current requirements, unconditionally prioritize current requirements.
|
||||
Completeness Rule: If the original prompt is empty, generate a complete prompt based on the current requirements.
|
||||
Structure Rule: Use a clear block structure, and the contents of each block are roles, tasks, requirements, inputs, outputs, and constraints
|
||||
{% if skill != true %}Structure Rule: Use a clear block structure, and the contents of each block are roles, tasks, requirements, inputs, outputs, and constraints{% endif %}
|
||||
Language Rule: All label languages must fully match the user input language.
|
||||
|
||||
Behavior Guidelines
|
||||
Precision Guideline: All instructions must be precise and directly executable, avoiding ambiguity.
|
||||
Readability Guideline: Ensure optimized prompts have good readability and logical flow.
|
||||
Variable Handling Guideline: Use lowercase English variable names wrapped in {{}} when variables are needed.
|
||||
Constraint Handling Guideline: Do not mention variable-related limitations under the [Constraints] label.
|
||||
{% if skill != true %}{% raw %}Variable Handling Guideline: Use lowercase English variable names wrapped in {{}} when variables are needed.
|
||||
Constraint Handling Guideline: Do not mention variable-related limitations under the [Constraints] label.{% endraw %}{% endif %}
|
||||
|
||||
Constraints
|
||||
Output Constraint: Must output in JSON format including the fields "prompt" and "desc".
|
||||
Content Constraint: Must not include any explanations, analyses, or additional comments.
|
||||
Language Constraint: Must use clear and concise language.
|
||||
Completeness Constraint: Must fully define all missing elements (input details, output format, constraints, etc.).
|
||||
{% if skill != true %}Completeness Constraint: Must fully define all missing elements (input details, output format, constraints, etc.).{% endif %}
|
||||
|
||||
Workflows
|
||||
Goal: Optimize or generate AI prompts that can be directly used according to user requirements.
|
||||
Step 1: Receive the user’s current requirement description {{user_require}} and the original prompt {{original_prompt}}.
|
||||
Step 2: Analyze requirements, identify conflicts, and prioritize current requirements.
|
||||
Step 3: Optimize or generate the prompt in a block-structured format, ensuring all elements are fully defined.
|
||||
{% if skill != true %}Step 3: Optimize or generate the prompt in a block-structured format, ensuring all elements are fully defined.
|
||||
Step 4: Generate a JSON output containing the optimized prompt and its description.
|
||||
{% else %}Step 3: Generate a JSON output containing the optimized prompt and its description.{% endif %}
|
||||
|
||||
Expected Outcome: Obtain a clear, directly executable AI prompt accompanied by an optimization description.
|
||||
|
||||
Initialization
|
||||
As an AI Prompt Optimization Expert, you must follow the above Rules and execute tasks according to the Workflows.
|
||||
{% endraw %}
|
||||
As an AI Prompt Optimization Expert, you must follow the above Rules and execute tasks according to the Workflows.
|
||||
@@ -23,6 +23,7 @@ from app.repositories.prompt_optimizer_repository import (
|
||||
PromptReleaseRepository
|
||||
)
|
||||
from app.schemas.prompt_optimizer_schema import OptimizePromptResult
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -128,7 +129,8 @@ class PromptOptimizerService:
|
||||
session_id: uuid.UUID,
|
||||
user_id: uuid.UUID,
|
||||
current_prompt: str,
|
||||
user_require: str
|
||||
user_require: str,
|
||||
skill: bool = False
|
||||
) -> AsyncGenerator[dict[str, str | Any], Any]:
|
||||
"""
|
||||
Optimize a user-provided prompt using a configured prompt optimizer LLM.
|
||||
@@ -157,6 +159,7 @@ class PromptOptimizerService:
|
||||
user_id (uuid.UUID): Identifier of the user associated with the session.
|
||||
current_prompt (str): Original prompt to optimize.
|
||||
user_require (str): User's requirements or instructions for optimization.
|
||||
skill(bool): Is skill required
|
||||
|
||||
Returns:
|
||||
OptimizePromptResult: An object containing:
|
||||
@@ -174,8 +177,9 @@ class PromptOptimizerService:
|
||||
logger.info(f"Prompt optimization started, user_id={user_id}, session_id={session_id}")
|
||||
|
||||
# Create LLM instance
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config.id)
|
||||
api_config: ModelApiKey = api_keys[0] if api_keys else None
|
||||
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config.id)
|
||||
# api_config: ModelApiKey = api_keys[0] if api_keys else None
|
||||
api_config: ModelApiKey = ModelApiKeyService.get_available_api_key(self.db, model_config.id)
|
||||
llm = RedBearLLM(RedBearModelConfig(
|
||||
model_name=api_config.model_name,
|
||||
provider=api_config.provider,
|
||||
@@ -186,7 +190,7 @@ class PromptOptimizerService:
|
||||
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
||||
with open(os.path.join(prompt_path, 'prompt_optimizer_system.jinja2'), 'r', encoding='utf-8') as f:
|
||||
opt_system_prompt = f.read()
|
||||
rendered_system_message = Template(opt_system_prompt).render()
|
||||
rendered_system_message = Template(opt_system_prompt).render(skill=skill)
|
||||
|
||||
with open(os.path.join(prompt_path, 'prompt_optimizer_user.jinja2'), 'r', encoding='utf-8') as f:
|
||||
opt_user_prompt = f.read()
|
||||
@@ -250,6 +254,7 @@ class PromptOptimizerService:
|
||||
optim_result = json_repair.repair_json(buffer, return_objects=True)
|
||||
# prompt = optim_result.get("prompt")
|
||||
desc = optim_result.get("desc")
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_config.id)
|
||||
self.create_message(
|
||||
tenant_id=tenant_id,
|
||||
session_id=session_id,
|
||||
|
||||
@@ -10,6 +10,7 @@ from app.services.memory_konwledges_server import write_rag
|
||||
from app.models import ReleaseShare, AppRelease, Conversation
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.draft_run_service import create_web_search_tool
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||
from app.core.error_codes import BizCode
|
||||
@@ -178,8 +179,9 @@ class SharedChatService:
|
||||
# .limit(1)
|
||||
# )
|
||||
# api_key_obj = self.db.scalars(stmt).first()
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
|
||||
api_key_obj = api_keys[0] if api_keys else None
|
||||
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
|
||||
# api_key_obj = api_keys[0] if api_keys else None
|
||||
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
|
||||
if not api_key_obj:
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
@@ -309,6 +311,8 @@ class SharedChatService:
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
|
||||
|
||||
|
||||
return {
|
||||
"conversation_id": conversation.id,
|
||||
@@ -349,7 +353,8 @@ class SharedChatService:
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
memory_config = {"enabled": memory, "memory_content": "17", "max_history": 10}
|
||||
# 兼容新旧字段名:使用 memory_config_id
|
||||
memory_config = {"enabled": memory, "memory_config_id": "17", "max_history": 10}
|
||||
|
||||
try:
|
||||
# 获取发布版本和配置
|
||||
@@ -383,8 +388,9 @@ class SharedChatService:
|
||||
# .limit(1)
|
||||
# )
|
||||
# api_key_obj = self.db.scalars(stmt).first()
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
|
||||
api_key_obj = api_keys[0] if api_keys else None
|
||||
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
|
||||
# api_key_obj = api_keys[0] if api_keys else None
|
||||
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
|
||||
if not api_key_obj:
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
@@ -513,7 +519,8 @@ class SharedChatService:
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
|
||||
|
||||
# 发送结束事件
|
||||
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
|
||||
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
|
||||
|
||||
133
api/app/services/skill_service.py
Normal file
133
api/app/services/skill_service.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Skill Service"""
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.repositories.skill_repository import SkillRepository
|
||||
from app.schemas.skill_schema import SkillCreate, SkillUpdate
|
||||
from app.models.skill_model import Skill
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.services.tool_service import ToolService
|
||||
|
||||
|
||||
class SkillService:
|
||||
"""Skill 业务逻辑层"""
|
||||
|
||||
@staticmethod
|
||||
def create_skill(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill:
|
||||
"""创建技能"""
|
||||
# 检查同名技能
|
||||
existing = db.query(Skill).filter(
|
||||
Skill.tenant_id == tenant_id,
|
||||
Skill.name == data.name
|
||||
).first()
|
||||
if existing:
|
||||
raise BusinessException(f"技能名称'{data.name}'已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
skill = SkillRepository.create(db, data, tenant_id)
|
||||
db.commit()
|
||||
db.refresh(skill)
|
||||
return skill
|
||||
|
||||
@staticmethod
|
||||
def get_skill(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> Skill:
|
||||
"""获取技能"""
|
||||
try:
|
||||
skill = SkillRepository.get_by_id(db, skill_id, tenant_id)
|
||||
if not skill:
|
||||
raise BusinessException(f"技能{skill_id}不存在", BizCode.NOT_FOUND)
|
||||
|
||||
# 填充工具详情
|
||||
tool_service = ToolService(db)
|
||||
enriched_tools = []
|
||||
for tool_config in skill.tools:
|
||||
tool_id = tool_config.get("tool_id")
|
||||
if tool_id:
|
||||
tool_info = tool_service.get_tool_info(tool_id, tenant_id)
|
||||
if tool_info:
|
||||
enriched_tools.append({
|
||||
"tool_id": tool_id,
|
||||
"operation": tool_config.get("operation"),
|
||||
"tool_info": tool_info
|
||||
})
|
||||
skill.tools = enriched_tools
|
||||
|
||||
return skill
|
||||
except (BusinessException, SQLAlchemyError) as e:
|
||||
db.rollback()
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def list_skills(
|
||||
db: Session,
|
||||
tenant_id: uuid.UUID,
|
||||
search: str = None,
|
||||
is_active: bool = None,
|
||||
is_public: bool = None,
|
||||
page: int = 1,
|
||||
pagesize: int = 10
|
||||
) -> tuple[list[type[Skill]], int]:
|
||||
"""列出技能"""
|
||||
return SkillRepository.list_skills(
|
||||
db, tenant_id, search, is_active, is_public, page, pagesize
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_skill(db: Session, skill_id: uuid.UUID, data: SkillUpdate, tenant_id: uuid.UUID) -> Skill:
|
||||
"""更新技能"""
|
||||
try:
|
||||
skill = SkillRepository.update(db, skill_id, data, tenant_id)
|
||||
if not skill:
|
||||
raise BusinessException(f"技能{skill_id}不存在或无权限", BizCode.NOT_FOUND)
|
||||
db.commit()
|
||||
db.refresh(skill)
|
||||
return skill
|
||||
except (BusinessException, SQLAlchemyError) as e:
|
||||
db.rollback()
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def delete_skill(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> bool:
|
||||
"""删除技能"""
|
||||
try:
|
||||
success = SkillRepository.delete(db, skill_id, tenant_id)
|
||||
if not success:
|
||||
raise BusinessException("技能不存在或无权限", BizCode.NOT_FOUND)
|
||||
db.commit()
|
||||
return True
|
||||
except (BusinessException, SQLAlchemyError) as e:
|
||||
db.rollback()
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def load_skill_tools(db: Session, skill_ids: List[str], tenant_id: uuid.UUID) -> tuple[List, dict[str, str]]:
|
||||
"""加载技能关联的工具
|
||||
|
||||
Returns:
|
||||
(tools, tool_to_skill_map) - 工具列表和工具到技能的映射
|
||||
"""
|
||||
tools = []
|
||||
tool_to_skill_map = {} # {tool_name: skill_id}
|
||||
tool_service = ToolService(db)
|
||||
|
||||
for skill_id in skill_ids:
|
||||
try:
|
||||
skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id)
|
||||
if skill and skill.is_active:
|
||||
# 加载技能关联的工具
|
||||
for tool_config in skill.tools:
|
||||
tool = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool:
|
||||
langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
# 建立工具到技能的映射
|
||||
tool_name = getattr(langchain_tool, 'name', str(id(langchain_tool)))
|
||||
tool_to_skill_map[tool_name] = skill_id
|
||||
except Exception as e:
|
||||
print(f"加载技能 {skill_id} 的工具时出错: {e}")
|
||||
continue
|
||||
|
||||
return tools, tool_to_skill_map
|
||||
@@ -4,9 +4,8 @@
|
||||
import datetime
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Annotated, AsyncGenerator, Optional
|
||||
from typing import Any, Annotated, Optional
|
||||
|
||||
from deprecated import deprecated
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -23,6 +22,7 @@ from app.repositories.workflow_repository import (
|
||||
from app.schemas import DraftRunRequest
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.multi_agent_service import convert_uuids_to_str
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -36,6 +36,7 @@ class WorkflowService:
|
||||
self.execution_repo = WorkflowExecutionRepository(db)
|
||||
self.node_execution_repo = WorkflowNodeExecutionRepository(db)
|
||||
self.conversation_service = ConversationService(db)
|
||||
self.multimodal_service = MultimodalService(db)
|
||||
|
||||
# ==================== 配置管理 ====================
|
||||
|
||||
@@ -445,24 +446,22 @@ class WorkflowService:
|
||||
code=BizCode.CONFIG_MISSING,
|
||||
message=f"工作流配置不存在: app_id={app_id}"
|
||||
)
|
||||
input_data = {"message": payload.message, "variables": payload.variables,
|
||||
"conversation_id": payload.conversation_id}
|
||||
files = []
|
||||
if payload.files:
|
||||
for file in payload.files:
|
||||
files.append(
|
||||
{
|
||||
"type": file.type,
|
||||
"url": await self.multimodal_service.get_file_url(file),
|
||||
"__file": True
|
||||
}
|
||||
)
|
||||
|
||||
# 转换 user_id 为 UUID
|
||||
triggered_by_uuid = None
|
||||
if payload.user_id:
|
||||
try:
|
||||
triggered_by_uuid = uuid.UUID(payload.user_id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.warning(f"无效的 user_id 格式: {payload.user_id}")
|
||||
input_data = {"message": payload.message, "variables": payload.variables,
|
||||
"conversation_id": payload.conversation_id, "files": files}
|
||||
|
||||
# 转换 conversation_id 为 UUID
|
||||
conversation_id_uuid = None
|
||||
if payload.conversation_id:
|
||||
try:
|
||||
conversation_id_uuid = uuid.UUID(payload.conversation_id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}")
|
||||
conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None
|
||||
|
||||
# 2. 创建执行记录
|
||||
execution = self.create_execution(
|
||||
@@ -544,10 +543,10 @@ class WorkflowService:
|
||||
return {
|
||||
"execution_id": execution.execution_id,
|
||||
"status": result.get("status"),
|
||||
"variables": result.get("variables"),
|
||||
"messages": result.get("messages"),
|
||||
# "variables": result.get("variables"),
|
||||
# "messages": result.get("messages"),
|
||||
"output": result.get("output"), # 最终输出(字符串)
|
||||
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||
# "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
||||
"error_message": result.get("error"),
|
||||
"elapsed_time": result.get("elapsed_time"),
|
||||
@@ -566,6 +565,41 @@ class WorkflowService:
|
||||
message=f"工作流执行失败: {str(e)}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _map_public_event(event: dict) -> dict | None:
|
||||
event_type = event.get("event")
|
||||
payload = event.get("data")
|
||||
match event_type:
|
||||
case "workflow_start":
|
||||
return {
|
||||
"event": "start",
|
||||
"data": {
|
||||
"conversation_id": payload.get("conversation_id"),
|
||||
}
|
||||
}
|
||||
case "workflow_end":
|
||||
return {
|
||||
"event": "end",
|
||||
"data": {
|
||||
"elapsed_time": payload.get("elapsed_time"),
|
||||
"message_length": len(payload.get("output", ""))
|
||||
}
|
||||
}
|
||||
case "node_start" | "node_end" | "node_error":
|
||||
return None
|
||||
case _:
|
||||
return event
|
||||
|
||||
def _emit(self, public: bool, internal_event: dict):
|
||||
"""
|
||||
decide
|
||||
"""
|
||||
if public:
|
||||
mapped = self._map_public_event(internal_event)
|
||||
else:
|
||||
mapped = internal_event
|
||||
return mapped
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
@@ -573,6 +607,7 @@ class WorkflowService:
|
||||
config: WorkflowConfig,
|
||||
workspace_id: uuid.UUID,
|
||||
release_id: Optional[uuid.UUID] = None,
|
||||
public: bool = False
|
||||
):
|
||||
"""运行工作流(流式)
|
||||
|
||||
@@ -582,6 +617,7 @@ class WorkflowService:
|
||||
app_id: 应用 ID
|
||||
payload: 请求对象(包含 message, variables, conversation_id 等)
|
||||
config: 存储类型(可选)
|
||||
public: 是否发布
|
||||
|
||||
Yields:
|
||||
SSE 格式的流式事件
|
||||
@@ -597,24 +633,23 @@ class WorkflowService:
|
||||
code=BizCode.CONFIG_MISSING,
|
||||
message=f"工作流配置不存在: app_id={app_id}"
|
||||
)
|
||||
input_data = {"message": payload.message, "variables": payload.variables,
|
||||
"conversation_id": payload.conversation_id}
|
||||
|
||||
# 转换 user_id 为 UUID
|
||||
triggered_by_uuid = None
|
||||
if payload.user_id:
|
||||
try:
|
||||
triggered_by_uuid = uuid.UUID(payload.user_id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.warning(f"无效的 user_id 格式: {payload.user_id}")
|
||||
files = []
|
||||
if payload.files:
|
||||
for file in payload.files:
|
||||
files.append(
|
||||
{
|
||||
"type": file.type,
|
||||
"url": await self.multimodal_service.get_file_url(file),
|
||||
"__file": True
|
||||
}
|
||||
)
|
||||
|
||||
input_data = {"message": payload.message, "variables": payload.variables,
|
||||
"conversation_id": payload.conversation_id, "files": files}
|
||||
|
||||
# 转换 conversation_id 为 UUID
|
||||
conversation_id_uuid = None
|
||||
if payload.conversation_id:
|
||||
try:
|
||||
conversation_id_uuid = uuid.UUID(payload.conversation_id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}")
|
||||
conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None
|
||||
|
||||
# 2. 创建执行记录
|
||||
execution = self.create_execution(
|
||||
@@ -661,7 +696,7 @@ class WorkflowService:
|
||||
input_data=input_data,
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id=str(workspace_id),
|
||||
user_id=payload.user_id
|
||||
user_id=payload.user_id,
|
||||
):
|
||||
if event.get("event") == "workflow_end":
|
||||
|
||||
@@ -692,7 +727,9 @@ class WorkflowService:
|
||||
)
|
||||
else:
|
||||
logger.error(f"unexpect workflow run status, status: {status}")
|
||||
yield event
|
||||
event = self._emit(public, event)
|
||||
if event:
|
||||
yield event
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
||||
@@ -710,134 +747,6 @@ class WorkflowService:
|
||||
}
|
||||
}
|
||||
|
||||
@deprecated(reason="This method is deprecated. "
|
||||
"Please use WorkflowService.run / run_stream instead.")
|
||||
async def run_workflow(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
input_data: dict[str, Any],
|
||||
triggered_by: uuid.UUID,
|
||||
conversation_id: uuid.UUID | None = None,
|
||||
stream: bool = False
|
||||
) -> AsyncGenerator | dict:
|
||||
"""运行工作流
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
input_data: 输入数据(包含 message 和 variables)
|
||||
triggered_by: 触发用户 ID
|
||||
conversation_id: 会话 ID(可选)
|
||||
stream: 是否流式返回
|
||||
|
||||
Returns:
|
||||
执行结果(非流式)或生成器(流式)
|
||||
|
||||
Raises:
|
||||
BusinessException: 配置不存在或执行失败时抛出
|
||||
"""
|
||||
# 1. 获取工作流配置
|
||||
config = self.get_workflow_config(app_id)
|
||||
if not config:
|
||||
raise BusinessException(
|
||||
code=BizCode.NOT_FOUND,
|
||||
message=f"工作流配置不存在: app_id={app_id}"
|
||||
)
|
||||
|
||||
# 2. 创建执行记录
|
||||
execution = self.create_execution(
|
||||
workflow_config_id=config.id,
|
||||
app_id=app_id,
|
||||
trigger_type="manual",
|
||||
triggered_by=triggered_by,
|
||||
conversation_id=conversation_id,
|
||||
input_data=input_data
|
||||
)
|
||||
|
||||
# 3. 构建工作流配置字典
|
||||
workflow_config_dict = {
|
||||
"nodes": config.nodes,
|
||||
"edges": config.edges,
|
||||
"variables": config.variables,
|
||||
"execution_config": config.execution_config
|
||||
}
|
||||
|
||||
# 4. 获取工作空间 ID(从 app 获取)
|
||||
from app.models import App
|
||||
app = self.db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
if not app:
|
||||
raise BusinessException(
|
||||
code=BizCode.NOT_FOUND,
|
||||
message=f"应用不存在: app_id={app_id}"
|
||||
)
|
||||
|
||||
# 5. 执行工作流
|
||||
from app.core.workflow.executor import execute_workflow
|
||||
|
||||
try:
|
||||
# 更新状态为运行中
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
|
||||
if stream:
|
||||
# 流式执行
|
||||
return self._run_workflow_stream(
|
||||
workflow_config_dict,
|
||||
input_data,
|
||||
execution.execution_id,
|
||||
str(app.workspace_id),
|
||||
str(triggered_by)
|
||||
)
|
||||
else:
|
||||
# 非流式执行
|
||||
result = await execute_workflow(
|
||||
workflow_config=workflow_config_dict,
|
||||
input_data=input_data,
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id=str(app.workspace_id),
|
||||
user_id=str(triggered_by)
|
||||
)
|
||||
|
||||
# 更新执行结果
|
||||
if result.get("status") == "completed":
|
||||
token_usage = result.get("data").get("token_usage", {}) or {}
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"completed",
|
||||
output_data=result.get("node_outputs", {}),
|
||||
token_usage=token_usage.get("total_tokens", None)
|
||||
)
|
||||
else:
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"failed",
|
||||
error_message=result.get("error")
|
||||
)
|
||||
|
||||
# 返回增强的响应结构
|
||||
return {
|
||||
"execution_id": execution.execution_id,
|
||||
"status": result.get("status"),
|
||||
"output": result.get("output"), # 最终输出(字符串)
|
||||
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||
"error_message": result.get("error"),
|
||||
"elapsed_time": result.get("elapsed_time"),
|
||||
"token_usage": result.get("token_usage")
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"failed",
|
||||
error_message=str(e)
|
||||
)
|
||||
raise BusinessException(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
message=f"工作流执行失败: {str(e)}"
|
||||
)
|
||||
|
||||
def _clean_event_for_json(self, event: dict[str, Any]) -> dict[str, Any]:
|
||||
"""清理事件数据,移除不可序列化的对象
|
||||
|
||||
@@ -869,72 +778,6 @@ class WorkflowService:
|
||||
|
||||
return clean_value(event)
|
||||
|
||||
@deprecated(reason="This method is deprecated. Please use WorkflowService.run_stream instead.")
|
||||
async def _run_workflow_stream(
|
||||
self,
|
||||
workflow_config: dict[str, Any],
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str):
|
||||
"""运行工作流(流式,内部方法)
|
||||
|
||||
Args:
|
||||
workflow_config: 工作流配置
|
||||
input_data: 输入数据
|
||||
execution_id: 执行 ID
|
||||
workspace_id: 工作空间 ID
|
||||
user_id: 用户 ID
|
||||
|
||||
Yields:
|
||||
流式事件(格式:{"event": "<type>", "data": {...}})
|
||||
"""
|
||||
from app.core.workflow.executor import execute_workflow_stream
|
||||
|
||||
try:
|
||||
async for event in execute_workflow_stream(
|
||||
workflow_config=workflow_config,
|
||||
input_data=input_data,
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
):
|
||||
# 直接转发事件(executor 已经返回正确格式)
|
||||
if event.get("event") == "workflow_end":
|
||||
token_usage = event.get("data").get("token_usage", {}) or {}
|
||||
status = event.get("data", {}).get("status")
|
||||
if status == "completed":
|
||||
self.update_execution_status(
|
||||
execution_id,
|
||||
"completed",
|
||||
output_data=event.get("data"),
|
||||
token_usage=token_usage.get("total_tokens", None)
|
||||
)
|
||||
elif status == "failed":
|
||||
self.update_execution_status(
|
||||
execution_id,
|
||||
"failed",
|
||||
output_data=event.get("data")
|
||||
)
|
||||
else:
|
||||
logger.error(f"unexpect workflow run status, status: {status}")
|
||||
yield event
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作流流式执行失败: execution_id={execution_id}, error={e}", exc_info=True)
|
||||
self.update_execution_status(
|
||||
execution_id,
|
||||
"failed",
|
||||
error_message=str(e)
|
||||
)
|
||||
yield {
|
||||
"event": "error",
|
||||
"data": {
|
||||
"execution_id": execution_id,
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ==================== 依赖注入函数 ====================
|
||||
|
||||
|
||||
288
api/app/tasks.py
288
api/app/tasks.py
@@ -1069,6 +1069,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务")
|
||||
|
||||
except Exception as e:
|
||||
db.rollback() # Rollback failed transaction to allow next query
|
||||
api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
|
||||
all_reflection_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
@@ -1207,3 +1208,290 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
||||
return result
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Long-term Memory Storage Tasks (Batched Write Strategies)
|
||||
# =============================================================================
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.long_term_storage.window", bind=True)
|
||||
def long_term_storage_window_task(
|
||||
self,
|
||||
end_user_id: str,
|
||||
langchain_messages: List[Dict[str, Any]],
|
||||
config_id: str,
|
||||
scope: int = 6
|
||||
) -> Dict[str, Any]:
|
||||
"""Celery task for window-based long-term memory storage.
|
||||
|
||||
Accumulates messages in Redis buffer until window size (scope) is reached,
|
||||
then writes batched messages to Neo4j.
|
||||
|
||||
Args:
|
||||
end_user_id: End user identifier
|
||||
langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}]
|
||||
config_id: Memory configuration ID
|
||||
scope: Window size (number of messages before triggering write)
|
||||
|
||||
Returns:
|
||||
Dict containing task status and metadata
|
||||
"""
|
||||
from app.core.logging_config import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
logger.info(f"[LONG_TERM_WINDOW] Starting task - end_user_id={end_user_id}, scope={scope}")
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import window_dialogue
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
# Save to Redis buffer first
|
||||
write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
|
||||
|
||||
# Load memory config
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="LongTermStorageTask"
|
||||
)
|
||||
|
||||
# Execute window-based dialogue storage
|
||||
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||
|
||||
return {"status": "SUCCESS", "strategy": "window", "scope": scope}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
logger.info(f"[LONG_TERM_WINDOW] Task completed - elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
return {
|
||||
**result,
|
||||
"end_user_id": end_user_id,
|
||||
"config_id": config_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"[LONG_TERM_WINDOW] Task failed - error={str(e)}", exc_info=True)
|
||||
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"strategy": "window",
|
||||
"error": str(e),
|
||||
"end_user_id": end_user_id,
|
||||
"config_id": config_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
|
||||
|
||||
# @celery_app.task(name="app.core.memory.agent.long_term_storage.time", bind=True)
|
||||
# def long_term_storage_time_task(
|
||||
# self,
|
||||
# end_user_id: str,
|
||||
# config_id: str,
|
||||
# time_window: int = 5
|
||||
# ) -> Dict[str, Any]:
|
||||
# """Celery task for time-based long-term memory storage.
|
||||
|
||||
# Retrieves recent sessions from Redis within time window and writes to Neo4j.
|
||||
|
||||
# Args:
|
||||
# end_user_id: End user identifier
|
||||
# config_id: Memory configuration ID
|
||||
# time_window: Time window in minutes for retrieving recent sessions
|
||||
|
||||
# Returns:
|
||||
# Dict containing task status and metadata
|
||||
# """
|
||||
# from app.core.logging_config import get_logger
|
||||
# logger = get_logger(__name__)
|
||||
|
||||
# logger.info(f"[LONG_TERM_TIME] Starting task - end_user_id={end_user_id}, time_window={time_window}")
|
||||
# start_time = time.time()
|
||||
|
||||
# async def _run() -> Dict[str, Any]:
|
||||
# from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage
|
||||
# from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# db = next(get_db())
|
||||
# try:
|
||||
# # Load memory config
|
||||
# config_service = MemoryConfigService(db)
|
||||
# memory_config = config_service.load_memory_config(
|
||||
# config_id=config_id,
|
||||
# service_name="LongTermStorageTask"
|
||||
# )
|
||||
|
||||
# # Execute time-based storage
|
||||
# await memory_long_term_storage(end_user_id, memory_config, time_window)
|
||||
|
||||
# return {"status": "SUCCESS", "strategy": "time", "time_window": time_window}
|
||||
# finally:
|
||||
# db.close()
|
||||
|
||||
# try:
|
||||
# import nest_asyncio
|
||||
# nest_asyncio.apply()
|
||||
# except ImportError:
|
||||
# pass
|
||||
|
||||
# try:
|
||||
# loop = asyncio.get_event_loop()
|
||||
# if loop.is_closed():
|
||||
# loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
# except RuntimeError:
|
||||
# loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
|
||||
# try:
|
||||
# result = loop.run_until_complete(_run())
|
||||
# elapsed_time = time.time() - start_time
|
||||
|
||||
# logger.info(f"[LONG_TERM_TIME] Task completed - elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
# return {
|
||||
# **result,
|
||||
# "end_user_id": end_user_id,
|
||||
# "config_id": config_id,
|
||||
# "elapsed_time": elapsed_time,
|
||||
# "task_id": self.request.id
|
||||
# }
|
||||
# except Exception as e:
|
||||
# elapsed_time = time.time() - start_time
|
||||
# logger.error(f"[LONG_TERM_TIME] Task failed - error={str(e)}", exc_info=True)
|
||||
|
||||
# return {
|
||||
# "status": "FAILURE",
|
||||
# "strategy": "time",
|
||||
# "error": str(e),
|
||||
# "end_user_id": end_user_id,
|
||||
# "config_id": config_id,
|
||||
# "elapsed_time": elapsed_time,
|
||||
# "task_id": self.request.id
|
||||
# }
|
||||
|
||||
|
||||
# @celery_app.task(name="app.core.memory.agent.long_term_storage.aggregate", bind=True)
|
||||
# def long_term_storage_aggregate_task(
|
||||
# self,
|
||||
# end_user_id: str,
|
||||
# langchain_messages: List[Dict[str, Any]],
|
||||
# config_id: str
|
||||
# ) -> Dict[str, Any]:
|
||||
# """Celery task for aggregate-based long-term memory storage.
|
||||
|
||||
# Uses LLM to determine if new messages describe the same event as history.
|
||||
# Only writes to Neo4j if messages represent new information (not duplicates).
|
||||
|
||||
# Args:
|
||||
# end_user_id: End user identifier
|
||||
# langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}]
|
||||
# config_id: Memory configuration ID
|
||||
|
||||
# Returns:
|
||||
# Dict containing task status, is_same_event flag, and metadata
|
||||
# """
|
||||
# from app.core.logging_config import get_logger
|
||||
# logger = get_logger(__name__)
|
||||
|
||||
# logger.info(f"[LONG_TERM_AGGREGATE] Starting task - end_user_id={end_user_id}")
|
||||
# start_time = time.time()
|
||||
|
||||
# async def _run() -> Dict[str, Any]:
|
||||
# from app.core.memory.agent.langgraph_graph.routing.write_router import aggregate_judgment
|
||||
# from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
|
||||
# from app.core.memory.agent.utils.redis_tool import write_store
|
||||
# from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# db = next(get_db())
|
||||
# try:
|
||||
# # Save to Redis buffer first
|
||||
# write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
|
||||
|
||||
# # Load memory config
|
||||
# config_service = MemoryConfigService(db)
|
||||
# memory_config = config_service.load_memory_config(
|
||||
# config_id=config_id,
|
||||
# service_name="LongTermStorageTask"
|
||||
# )
|
||||
|
||||
# # Execute aggregate judgment
|
||||
# result = await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
|
||||
# return {
|
||||
# "status": "SUCCESS",
|
||||
# "strategy": "aggregate",
|
||||
# "is_same_event": result.get("is_same_event", False),
|
||||
# "wrote_to_neo4j": not result.get("is_same_event", False)
|
||||
# }
|
||||
# finally:
|
||||
# db.close()
|
||||
|
||||
# try:
|
||||
# import nest_asyncio
|
||||
# nest_asyncio.apply()
|
||||
# except ImportError:
|
||||
# pass
|
||||
|
||||
# try:
|
||||
# loop = asyncio.get_event_loop()
|
||||
# if loop.is_closed():
|
||||
# loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
# except RuntimeError:
|
||||
# loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
|
||||
# try:
|
||||
# result = loop.run_until_complete(_run())
|
||||
# elapsed_time = time.time() - start_time
|
||||
|
||||
# logger.info(f"[LONG_TERM_AGGREGATE] Task completed - is_same_event={result.get('is_same_event')}, elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
# return {
|
||||
# **result,
|
||||
# "end_user_id": end_user_id,
|
||||
# "config_id": config_id,
|
||||
# "elapsed_time": elapsed_time,
|
||||
# "task_id": self.request.id
|
||||
# }
|
||||
# except Exception as e:
|
||||
# elapsed_time = time.time() - start_time
|
||||
# logger.error(f"[LONG_TERM_AGGREGATE] Task failed - error={str(e)}", exc_info=True)
|
||||
|
||||
# return {
|
||||
# "status": "FAILURE",
|
||||
# "strategy": "aggregate",
|
||||
# "error": str(e),
|
||||
# "end_user_id": end_user_id,
|
||||
# "config_id": config_id,
|
||||
# "elapsed_time": elapsed_time,
|
||||
# "task_id": self.request.id
|
||||
# }
|
||||
|
||||
@@ -99,7 +99,8 @@ def agent_config_4_app_release(release: AppRelease) -> AgentConfig:
|
||||
knowledge_retrieval=config_dict.get("knowledge_retrieval"),
|
||||
memory=config_dict.get("memory"),
|
||||
variables=config_dict.get("variables", []),
|
||||
tools=config_dict.get("tools", {}),
|
||||
tools=config_dict.get("tools", []),
|
||||
skills=config_dict.get("skills", {})
|
||||
)
|
||||
|
||||
return agent_config
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user