Compare commits

..

1 Commits

Author SHA1 Message Date
zhaoying
8476f3b7a8 feat(web): workflow Safari browser compatibility 2026-04-28 12:12:19 +08:00
80 changed files with 1491 additions and 2824 deletions

View File

@@ -3,9 +3,12 @@ name: Sync to Gitee
on: on:
push: push:
branches: branches:
- '**' # All branchs - main # Production
- develop # Integration
- 'release/*' # Release preparation
- 'hotfix/*' # Urgent fixes
tags: tags:
- '**' # All version tags (v1.0.0, etc.) - '*' # All version tags (v1.0.0, etc.)
jobs: jobs:
sync: sync:

View File

@@ -17,7 +17,6 @@ def _mask_url(url: str) -> str:
"""隐藏 URL 中的密码部分,适用于 redis:// 和 amqp:// 等协议""" """隐藏 URL 中的密码部分,适用于 redis:// 和 amqp:// 等协议"""
return re.sub(r'(://[^:]*:)[^@]+(@)', r'\1***\2', url) return re.sub(r'(://[^:]*:)[^@]+(@)', r'\1***\2', url)
# macOS fork() safety - must be set before any Celery initialization # macOS fork() safety - must be set before any Celery initialization
if platform.system() == 'Darwin': if platform.system() == 'Darwin':
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES') os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
@@ -30,7 +29,7 @@ if platform.system() == 'Darwin':
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md # 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
_broker_url = os.getenv("CELERY_BROKER_URL") or \ _broker_url = os.getenv("CELERY_BROKER_URL") or \
f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}" f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}" _backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
os.environ["CELERY_BROKER_URL"] = _broker_url os.environ["CELERY_BROKER_URL"] = _broker_url
os.environ["CELERY_RESULT_BACKEND"] = _backend_url os.environ["CELERY_RESULT_BACKEND"] = _backend_url
@@ -67,11 +66,11 @@ celery_app.conf.update(
task_serializer='json', task_serializer='json',
accept_content=['json'], accept_content=['json'],
result_serializer='json', result_serializer='json',
# # 时区 # # 时区
# timezone='Asia/Shanghai', # timezone='Asia/Shanghai',
# enable_utc=False, # enable_utc=False,
# 任务追踪 # 任务追踪
task_track_started=True, task_track_started=True,
task_ignore_result=False, task_ignore_result=False,

View File

@@ -1,500 +0,0 @@
import hashlib
import json
import os
import socket
import threading
import time
import uuid
import redis
from app.core.config import settings
from app.core.logging_config import get_named_logger
from app.celery_app import celery_app
logger = get_named_logger("task_scheduler")
# per-user queue scheduler:uq:{user_id}
USER_QUEUE_PREFIX = "scheduler:uq:"
# User Collection of Pending Messages
ACTIVE_USERS = "scheduler:active_users"
# Set of users that can dispatch (ready signal)
READY_SET = "scheduler:ready_users"
# Metadata of tasks that have been dispatched and are pending completion
PENDING_HASH = "scheduler:pending_tasks"
# Dynamic Sharding: Instance Registry
REGISTRY_KEY = "scheduler:instances"
TASK_TIMEOUT = 7800 # Task timeout (seconds), considered lost if exceeded
HEARTBEAT_INTERVAL = 10 # Heartbeat interval (seconds)
INSTANCE_TTL = 30 # Instance timeout (seconds)
LUA_ATOMIC_LOCK = """
local dispatch_lock = KEYS[1]
local lock_key = KEYS[2]
local instance_id = ARGV[1]
local dispatch_ttl = tonumber(ARGV[2])
local lock_ttl = tonumber(ARGV[3])
if redis.call('SET', dispatch_lock, instance_id, 'NX', 'EX', dispatch_ttl) == false then
return 0
end
if redis.call('EXISTS', lock_key) == 1 then
redis.call('DEL', dispatch_lock)
return -1
end
redis.call('SET', lock_key, 'dispatching', 'EX', lock_ttl)
return 1
"""
LUA_SAFE_DELETE = """
if redis.call('GET', KEYS[1]) == ARGV[1] then
return redis.call('DEL', KEYS[1])
end
return 0
"""
def stable_hash(value: str) -> int:
return int.from_bytes(
hashlib.md5(value.encode("utf-8")).digest(),
"big"
)
def health_check_server(scheduler_ref):
import uvicorn
from fastapi import FastAPI
health_app = FastAPI()
@health_app.get("/")
def health():
return scheduler_ref.health()
port = int(os.environ.get("SCHEDULER_HEALTH_PORT", "8001"))
threading.Thread(
target=uvicorn.run,
kwargs={
"app": health_app,
"host": "0.0.0.0",
"port": port,
"log_config": None,
},
daemon=True,
).start()
logger.info("[Health] Server started at http://0.0.0.0:%s", port)
class RedisTaskScheduler:
def __init__(self):
self.redis = redis.Redis(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB_CELERY_BACKEND,
password=settings.REDIS_PASSWORD,
decode_responses=True,
)
self.running = False
self.dispatched = 0
self.errors = 0
self.instance_id = f"{socket.gethostname()}-{os.getpid()}"
self._shard_index = 0
self._shard_count = 1
self._last_heartbeat = 0.0
def push_task(self, task_name, user_id, params):
try:
msg_id = str(uuid.uuid4())
msg = json.dumps({
"msg_id": msg_id,
"task_name": task_name,
"user_id": user_id,
"params": json.dumps(params),
})
lock_key = f"{task_name}:{user_id}"
queue_key = f"{USER_QUEUE_PREFIX}{user_id}"
pipe = self.redis.pipeline()
pipe.rpush(queue_key, msg)
pipe.sadd(ACTIVE_USERS, user_id)
pipe.set(
f"task_tracker:{msg_id}",
json.dumps({"status": "QUEUED", "task_id": None}),
ex=86400,
)
pipe.execute()
if not self.redis.exists(lock_key):
self.redis.sadd(READY_SET, user_id)
logger.info("Task pushed: msg_id=%s task=%s user=%s", msg_id, task_name, user_id)
return msg_id
except Exception as e:
logger.error("Push task exception %s", e, exc_info=True)
raise
def get_task_status(self, msg_id: str) -> dict:
raw = self.redis.get(f"task_tracker:{msg_id}")
if raw is None:
return {"status": "NOT_FOUND"}
tracker = json.loads(raw)
status = tracker["status"]
task_id = tracker.get("task_id")
result_content = tracker.get("result") or {}
if status == "DISPATCHED" and task_id:
result_raw = self.redis.get(f"celery-task-meta-{task_id}")
if result_raw:
result_data = json.loads(result_raw)
status = result_data.get("status", status)
result_content = result_data.get("result")
return {"status": status, "task_id": task_id, "result": result_content}
def _cleanup_finished(self):
pending = self.redis.hgetall(PENDING_HASH)
if not pending:
return
now = time.time()
task_ids = list(pending.keys())
pipe = self.redis.pipeline()
for task_id in task_ids:
pipe.get(f"celery-task-meta-{task_id}")
results = pipe.execute()
cleanup_pipe = self.redis.pipeline()
has_cleanup = False
ready_user_ids = set()
for task_id, raw_result in zip(task_ids, results):
try:
meta = json.loads(pending[task_id])
lock_key = meta["lock_key"]
dispatched_at = meta.get("dispatched_at", 0)
age = now - dispatched_at
should_cleanup = False
result_data = {}
if raw_result is not None:
result_data = json.loads(raw_result)
if result_data.get("status") in ("SUCCESS", "FAILURE", "REVOKED"):
should_cleanup = True
logger.info(
"Task finished: %s state=%s", task_id,
result_data.get("status"),
)
elif age > TASK_TIMEOUT:
should_cleanup = True
logger.warning(
"Task expired or lost: %s age=%.0fs, force cleanup",
task_id, age,
)
if should_cleanup:
final_status = (
result_data.get("status", "UNKNOWN") if result_data else "EXPIRED"
)
self.redis.eval(LUA_SAFE_DELETE, 1, lock_key, task_id)
cleanup_pipe.hdel(PENDING_HASH, task_id)
tracker_msg_id = meta.get("msg_id")
if tracker_msg_id:
cleanup_pipe.set(
f"task_tracker:{tracker_msg_id}",
json.dumps({
"status": final_status,
"task_id": task_id,
"result": result_data.get("result") or {},
}),
ex=86400,
)
has_cleanup = True
parts = lock_key.split(":", 1)
if len(parts) == 2:
ready_user_ids.add(parts[1])
except Exception as e:
logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True)
self.errors += 1
if has_cleanup:
cleanup_pipe.execute()
if ready_user_ids:
self.redis.sadd(READY_SET, *ready_user_ids)
def _heartbeat(self):
now = time.time()
if now - self._last_heartbeat < HEARTBEAT_INTERVAL:
return
self._last_heartbeat = now
self.redis.hset(REGISTRY_KEY, self.instance_id, str(now))
all_instances = self.redis.hgetall(REGISTRY_KEY)
alive = []
dead = []
for iid, ts in all_instances.items():
if now - float(ts) < INSTANCE_TTL:
alive.append(iid)
else:
dead.append(iid)
if dead:
pipe = self.redis.pipeline()
for iid in dead:
pipe.hdel(REGISTRY_KEY, iid)
pipe.execute()
logger.info("Cleaned dead instances: %s", dead)
alive.sort()
self._shard_count = max(len(alive), 1)
self._shard_index = (
alive.index(self.instance_id) if self.instance_id in alive else 0
)
logger.debug(
"Shard: %s/%s (instance=%s, alive=%d)",
self._shard_index, self._shard_count,
self.instance_id, len(alive),
)
def _is_mine(self, user_id: str) -> bool:
if self._shard_count <= 1:
return True
return stable_hash(user_id) % self._shard_count == self._shard_index
def _dispatch(self, msg_id, msg_data) -> bool:
user_id = msg_data["user_id"]
task_name = msg_data["task_name"]
params = json.loads(msg_data.get("params", "{}"))
lock_key = f"{task_name}:{user_id}"
dispatch_lock = f"dispatch:{msg_id}"
result = self.redis.eval(
LUA_ATOMIC_LOCK, 2,
dispatch_lock, lock_key,
self.instance_id, str(300), str(3600),
)
if result == 0:
return False
if result == -1:
return False
try:
task = celery_app.send_task(task_name, kwargs=params)
except Exception as e:
pipe = self.redis.pipeline()
pipe.delete(dispatch_lock)
pipe.delete(lock_key)
pipe.execute()
self.errors += 1
logger.error(
"send_task failed for %s:%s msg=%s: %s",
task_name, user_id, msg_id, e, exc_info=True,
)
return False
try:
pipe = self.redis.pipeline()
pipe.set(lock_key, task.id, ex=3600)
pipe.hset(PENDING_HASH, task.id, json.dumps({
"lock_key": lock_key,
"dispatched_at": time.time(),
"msg_id": msg_id,
}))
pipe.delete(dispatch_lock)
pipe.set(
f"task_tracker:{msg_id}",
json.dumps({"status": "DISPATCHED", "task_id": task.id}),
ex=86400,
)
pipe.execute()
except Exception as e:
logger.error(
"Post-dispatch state update failed for %s: %s",
task.id, e, exc_info=True,
)
self.errors += 1
self.dispatched += 1
logger.info("Task dispatched: %s (msg=%s)", task.id, msg_id)
return True
def _process_batch(self, user_ids):
if not user_ids:
return
pipe = self.redis.pipeline()
for uid in user_ids:
pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0)
heads = pipe.execute()
candidates = [] # (user_id, msg_dict)
empty_users = []
for uid, head in zip(user_ids, heads):
if head is None:
empty_users.append(uid)
else:
try:
candidates.append((uid, json.loads(head)))
except (json.JSONDecodeError, TypeError) as e:
logger.error("Bad message in queue for user %s: %s", uid, e)
self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}")
if empty_users:
pipe = self.redis.pipeline()
for uid in empty_users:
pipe.srem(ACTIVE_USERS, uid)
pipe.execute()
if not candidates:
return
for uid, msg in candidates:
if self._dispatch(msg["msg_id"], msg):
self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}")
def schedule_loop(self):
self._heartbeat()
self._cleanup_finished()
pipe = self.redis.pipeline()
pipe.smembers(READY_SET)
pipe.delete(READY_SET)
results = pipe.execute()
ready_users = results[0] or set()
my_users = [uid for uid in ready_users if self._is_mine(uid)]
if not my_users:
time.sleep(0.5)
return
self._process_batch(my_users)
time.sleep(0.1)
def _full_scan(self):
cursor = 0
ready_batch = []
while True:
cursor, user_ids = self.redis.sscan(
ACTIVE_USERS, cursor=cursor, count=1000,
)
if user_ids:
my_users = [uid for uid in user_ids if self._is_mine(uid)]
if my_users:
pipe = self.redis.pipeline()
for uid in my_users:
pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0)
heads = pipe.execute()
for uid, head in zip(my_users, heads):
if head is None:
continue
try:
msg = json.loads(head)
lock_key = f"{msg['task_name']}:{uid}"
ready_batch.append((uid, lock_key))
except (json.JSONDecodeError, TypeError):
continue
if cursor == 0:
break
if not ready_batch:
return
pipe = self.redis.pipeline()
for _, lock_key in ready_batch:
pipe.exists(lock_key)
lock_exists = pipe.execute()
ready_uids = [
uid for (uid, _), locked in zip(ready_batch, lock_exists)
if not locked
]
if ready_uids:
self.redis.sadd(READY_SET, *ready_uids)
logger.info("Full scan found %d ready users", len(ready_uids))
def run_server(self):
health_check_server(self)
self.running = True
last_full_scan = 0.0
full_scan_interval = 30.0
logger.info(
"Scheduler started: instance=%s", self.instance_id,
)
while True:
try:
self.schedule_loop()
now = time.time()
if now - last_full_scan > full_scan_interval:
self._full_scan()
last_full_scan = now
except Exception as e:
logger.error("Scheduler exception %s", e, exc_info=True)
self.errors += 1
time.sleep(5)
def health(self) -> dict:
return {
"running": self.running,
"active_users": self.redis.scard(ACTIVE_USERS),
"ready_users": self.redis.scard(READY_SET),
"pending_tasks": self.redis.hlen(PENDING_HASH),
"dispatched": self.dispatched,
"errors": self.errors,
"shard": f"{self._shard_index}/{self._shard_count}",
"instance": self.instance_id,
}
def shutdown(self):
logger.info("Scheduler shutting down: instance=%s", self.instance_id)
self.running = False
try:
self.redis.hdel(REGISTRY_KEY, self.instance_id)
except Exception as e:
logger.error("Shutdown cleanup error: %s", e)
scheduler: RedisTaskScheduler | None = None
if scheduler is None:
scheduler = RedisTaskScheduler()
if __name__ == "__main__":
import signal
import sys
def _signal_handler(signum, frame):
scheduler.shutdown()
sys.exit(0)
signal.signal(signal.SIGTERM, _signal_handler)
signal.signal(signal.SIGINT, _signal_handler)
scheduler.run_server()

View File

@@ -9,7 +9,7 @@ from app.core.logging_config import get_business_logger
from app.core.response_utils import success from app.core.response_utils import success
from app.db import get_db from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard from app.dependencies import get_current_user, cur_workspace_access_guard
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail, AppLogMessage from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail
from app.schemas.response_schema import PageData, PageMeta from app.schemas.response_schema import PageData, PageMeta
from app.services.app_service import AppService from app.services.app_service import AppService
from app.services.app_log_service import AppLogService from app.services.app_log_service import AppLogService
@@ -41,7 +41,7 @@ def list_app_logs(
# 验证应用访问权限 # 验证应用访问权限
app_service = AppService(db) app_service = AppService(db)
app = app_service.get_app(app_id, workspace_id) app_service.get_app(app_id, workspace_id)
# 使用 Service 层查询 # 使用 Service 层查询
log_service = AppLogService(db) log_service = AppLogService(db)
@@ -51,8 +51,7 @@ def list_app_logs(
page=page, page=page,
pagesize=pagesize, pagesize=pagesize,
is_draft=is_draft, is_draft=is_draft,
keyword=keyword, keyword=keyword
app_type=app.type,
) )
items = [AppLogConversation.model_validate(c) for c in conversations] items = [AppLogConversation.model_validate(c) for c in conversations]
@@ -79,32 +78,17 @@ def get_app_log_detail(
# 验证应用访问权限 # 验证应用访问权限
app_service = AppService(db) app_service = AppService(db)
app = app_service.get_app(app_id, workspace_id) app_service.get_app(app_id, workspace_id)
# 使用 Service 层查询 # 使用 Service 层查询
log_service = AppLogService(db) log_service = AppLogService(db)
conversation, messages, node_executions_map = log_service.get_conversation_detail( conversation, node_executions_map = log_service.get_conversation_detail(
app_id=app_id, app_id=app_id,
conversation_id=conversation_id, conversation_id=conversation_id,
workspace_id=workspace_id, workspace_id=workspace_id
app_type=app.type
) )
# 构建基础会话信息(不经过 ORM relationship detail = AppLogConversationDetail.model_validate(conversation)
base = AppLogConversation.model_validate(conversation) detail.node_executions_map = node_executions_map
# 单独处理 messages避免触发 SQLAlchemy relationship 校验
if messages and isinstance(messages[0], AppLogMessage):
# 工作流:已经是 AppLogMessage 实例
msg_list = messages
else:
# AgentORM Message 对象逐个转换
msg_list = [AppLogMessage.model_validate(m) for m in messages]
detail = AppLogConversationDetail(
**base.model_dump(),
messages=msg_list,
node_executions_map=node_executions_map,
)
return success(data=detail) return success(data=detail)

View File

@@ -4,9 +4,7 @@
处理显性记忆相关的API接口包括情景记忆和语义记忆的查询。 处理显性记忆相关的API接口包括情景记忆和语义记忆的查询。
""" """
from typing import Optional from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Query
from app.core.logging_config import get_api_logger from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail from app.core.response_utils import success, fail
@@ -71,140 +69,6 @@ async def get_explicit_memory_overview_api(
return fail(BizCode.INTERNAL_ERROR, "显性记忆总览查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "显性记忆总览查询失败", str(e))
@router.get("/episodics", response_model=ApiResponse)
async def get_episodic_memory_list_api(
end_user_id: str = Query(..., description="end user ID"),
page: int = Query(1, gt=0, description="page number, starting from 1"),
pagesize: int = Query(10, gt=0, le=100, description="number of items per page, max 100"),
start_date: Optional[int] = Query(None, description="start timestamp (ms)"),
end_date: Optional[int] = Query(None, description="end timestamp (ms)"),
episodic_type: str = Query("all", description="episodic type all/conversation/project_work/learning/decision/important_event"),
current_user: User = Depends(get_current_user),
) -> dict:
"""
获取情景记忆分页列表
返回指定用户的情景记忆列表,支持分页、时间范围筛选和情景类型筛选。
Args:
end_user_id: 终端用户ID必填
page: 页码从1开始默认1
pagesize: 每页数量默认10最大100
start_date: 开始时间戳(可选,毫秒),自动扩展到当天 00:00:00
end_date: 结束时间戳(可选,毫秒),自动扩展到当天 23:59:59
episodic_type: 情景类型筛选可选默认all
current_user: 当前用户
Returns:
ApiResponse: 包含情景记忆分页列表
Examples:
- 基础分页查询GET /episodics?end_user_id=xxx&page=1&pagesize=5
返回第1页每页5条数据
- 按时间范围筛选GET /episodics?end_user_id=xxx&page=1&pagesize=5&start_date=1738684800000&end_date=1738771199000
返回指定时间范围内的数据
- 按情景类型筛选GET /episodics?end_user_id=xxx&page=1&pagesize=5&episodic_type=important_event
返回类型为"重要事件"的数据
Notes:
- start_date 和 end_date 必须同时提供或同时不提供
- start_date 不能大于 end_date
- episodic_type 可选值all, conversation, project_work, learning, decision, important_event
- total 为该用户情景记忆总数(不受筛选条件影响)
- page.total 为筛选后的总条数
"""
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询情景记忆列表但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(
f"情景记忆分页查询: end_user_id={end_user_id}, "
f"start_date={start_date}, end_date={end_date}, episodic_type={episodic_type}, "
f"page={page}, pagesize={pagesize}, username={current_user.username}"
)
# 1. 参数校验
if page < 1 or pagesize < 1:
api_logger.warning(f"分页参数错误: page={page}, pagesize={pagesize}")
return fail(BizCode.INVALID_PARAMETER, "分页参数必须大于0")
valid_episodic_types = ["all", "conversation", "project_work", "learning", "decision", "important_event"]
if episodic_type not in valid_episodic_types:
api_logger.warning(f"无效的情景类型参数: {episodic_type}")
return fail(BizCode.INVALID_PARAMETER, f"无效的情景类型参数,可选值:{', '.join(valid_episodic_types)}")
# 时间戳参数校验
if (start_date is not None and end_date is None) or (end_date is not None and start_date is None):
return fail(BizCode.INVALID_PARAMETER, "start_date和end_date必须同时提供")
if start_date is not None and end_date is not None and start_date > end_date:
return fail(BizCode.INVALID_PARAMETER, "start_date不能大于end_date")
# 2. 执行查询
try:
result = await memory_explicit_service.get_episodic_memory_list(
end_user_id=end_user_id,
page=page,
pagesize=pagesize,
start_date=start_date,
end_date=end_date,
episodic_type=episodic_type,
)
api_logger.info(
f"情景记忆分页查询成功: end_user_id={end_user_id}, "
f"total={result['total']}, 返回={len(result['items'])}"
)
except Exception as e:
api_logger.error(f"情景记忆分页查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "情景记忆分页查询失败", str(e))
# 3. 返回结构化响应
return success(data=result, msg="查询成功")
@router.get("/semantics", response_model=ApiResponse)
async def get_semantic_memory_list_api(
end_user_id: str = Query(..., description="终端用户ID"),
current_user: User = Depends(get_current_user),
) -> dict:
"""
获取语义记忆列表
返回指定用户的全量语义记忆列表。
Args:
end_user_id: 终端用户ID必填
current_user: 当前用户
Returns:
ApiResponse: 包含语义记忆全量列表
"""
workspace_id = current_user.current_workspace_id
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询语义记忆列表但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(
f"语义记忆列表查询: end_user_id={end_user_id}, username={current_user.username}"
)
try:
result = await memory_explicit_service.get_semantic_memory_list(
end_user_id=end_user_id
)
api_logger.info(
f"语义记忆列表查询成功: end_user_id={end_user_id}, total={len(result)}"
)
except Exception as e:
api_logger.error(f"语义记忆列表查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "语义记忆列表查询失败", str(e))
return success(data=result, msg="查询成功")
@router.post("/details", response_model=ApiResponse) @router.post("/details", response_model=ApiResponse)
async def get_explicit_memory_details_api( async def get_explicit_memory_details_api(
request: ExplicitMemoryDetailsRequest, request: ExplicitMemoryDetailsRequest,

View File

@@ -14,7 +14,6 @@ from . import (
rag_api_document_controller, rag_api_document_controller,
rag_api_file_controller, rag_api_file_controller,
rag_api_knowledge_controller, rag_api_knowledge_controller,
user_memory_api_controller,
) )
# 创建 V1 API 路由器 # 创建 V1 API 路由器
@@ -29,6 +28,5 @@ service_router.include_router(rag_api_chunk_controller.router)
service_router.include_router(memory_api_controller.router) service_router.include_router(memory_api_controller.router)
service_router.include_router(end_user_api_controller.router) service_router.include_router(end_user_api_controller.router)
service_router.include_router(memory_config_api_controller.router) service_router.include_router(memory_config_api_controller.router)
service_router.include_router(user_memory_api_controller.router)
__all__ = ["service_router"] __all__ = ["service_router"]

View File

@@ -296,7 +296,7 @@ async def chat(
} }
) )
# workflow 非流式返回 # 多 Agent 非流式返回
result = await app_chat_service.workflow_chat( result = await app_chat_service.workflow_chat(
message=payload.message, message=payload.message,

View File

@@ -3,7 +3,6 @@
from fastapi import APIRouter, Body, Depends, Query, Request from fastapi import APIRouter, Body, Depends, Query, Request
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.celery_task_scheduler import scheduler
from app.core.api_key_auth import require_api_key from app.core.api_key_auth import require_api_key
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.quota_stub import check_end_user_quota from app.core.quota_stub import check_end_user_quota
@@ -87,7 +86,7 @@ async def write_memory(
user_rag_memory_id=payload.user_rag_memory_id, user_rag_memory_id=payload.user_rag_memory_id,
) )
logger.info(f"Memory write task submitted: task_id: {result['task_id']} end_user_id: {payload.end_user_id}") logger.info(f"Memory write task submitted: task_id={result['task_id']}, end_user_id: {payload.end_user_id}")
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted") return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted")
@@ -106,7 +105,8 @@ async def get_write_task_status(
""" """
logger.info(f"Write task status check - task_id: {task_id}") logger.info(f"Write task status check - task_id: {task_id}")
result = scheduler.get_task_status(task_id) from app.services.task_service import get_task_memory_write_result
result = get_task_memory_write_result(task_id)
return success(data=_sanitize_task_result(result), msg="Task status retrieved") return success(data=_sanitize_task_result(result), msg="Task status retrieved")

View File

@@ -1,230 +0,0 @@
"""User Memory 服务接口 — 基于 API Key 认证
包装 user_memory_controllers.py 和 memory_agent_controller.py 中的内部接口,
提供基于 API Key 认证的对外服务:
1./analytics/graph_data - 知识图谱数据接口
2./analytics/community_graph - 社区图谱接口
3./analytics/node_statistics - 记忆节点统计接口
4./analytics/user_summary - 用户摘要接口
5./analytics/memory_insight - 记忆洞察接口
6./analytics/interest_distribution - 兴趣分布接口
7./analytics/end_user_info - 终端用户信息接口
8./analytics/generate_cache - 缓存生成接口
路由前缀: /memory
子路径: /analytics/...
最终路径: /v1/memory/analytics/...
认证方式: API Key (@require_api_key)
"""
from typing import Optional
from fastapi import APIRouter, Depends, Header, Query, Request, Body
from sqlalchemy.orm import Session
from app.core.api_key_auth import require_api_key
from app.core.api_key_utils import get_current_user_from_api_key, validate_end_user_in_workspace
from app.core.logging_config import get_business_logger
from app.db import get_db
from app.schemas.api_key_schema import ApiKeyAuth
from app.schemas.memory_storage_schema import GenerateCacheRequest
# 包装内部服务 controller
from app.controllers import user_memory_controllers, memory_agent_controller
router = APIRouter(prefix="/memory", tags=["V1 - User Memory API"])
logger = get_business_logger()
# ==================== 知识图谱 ====================
@router.get("/analytics/graph_data")
@require_api_key(scopes=["memory"])
async def get_graph_data(
request: Request,
end_user_id: str = Query(..., description="End user ID"),
node_types: Optional[str] = Query(None, description="Comma-separated node types filter"),
limit: int = Query(100, description="Max nodes to return (auto-capped at 1000 in service layer)"),
depth: int = Query(1, description="Graph traversal depth (auto-capped at 3 in service layer)"),
center_node_id: Optional[str] = Query(None, description="Center node for subgraph"),
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""Get knowledge graph data (nodes + edges) for an end user."""
current_user = get_current_user_from_api_key(db, api_key_auth)
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
return await user_memory_controllers.get_graph_data_api(
end_user_id=end_user_id,
node_types=node_types,
limit=limit,
depth=depth,
center_node_id=center_node_id,
current_user=current_user,
db=db,
)
@router.get("/analytics/community_graph")
@require_api_key(scopes=["memory"])
async def get_community_graph(
request: Request,
end_user_id: str = Query(..., description="End user ID"),
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""Get community clustering graph for an end user."""
current_user = get_current_user_from_api_key(db, api_key_auth)
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
return await user_memory_controllers.get_community_graph_data_api(
end_user_id=end_user_id,
current_user=current_user,
db=db,
)
# ==================== 节点统计 ====================
@router.get("/analytics/node_statistics")
@require_api_key(scopes=["memory"])
async def get_node_statistics(
request: Request,
end_user_id: str = Query(..., description="End user ID"),
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""Get memory node type statistics for an end user."""
current_user = get_current_user_from_api_key(db, api_key_auth)
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
return await user_memory_controllers.get_node_statistics_api(
end_user_id=end_user_id,
current_user=current_user,
db=db,
)
# ==================== 用户摘要 & 洞察 ====================
@router.get("/analytics/user_summary")
@require_api_key(scopes=["memory"])
async def get_user_summary(
request: Request,
end_user_id: str = Query(..., description="End user ID"),
language_type: str = Header(default=None, alias="X-Language-Type"),
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""Get cached user summary for an end user."""
current_user = get_current_user_from_api_key(db, api_key_auth)
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
return await user_memory_controllers.get_user_summary_api(
end_user_id=end_user_id,
language_type=language_type,
current_user=current_user,
db=db,
)
@router.get("/analytics/memory_insight")
@require_api_key(scopes=["memory"])
async def get_memory_insight(
request: Request,
end_user_id: str = Query(..., description="End user ID"),
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""Get cached memory insight report for an end user."""
current_user = get_current_user_from_api_key(db, api_key_auth)
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
return await user_memory_controllers.get_memory_insight_report_api(
end_user_id=end_user_id,
current_user=current_user,
db=db,
)
# ==================== 兴趣分布 ====================
@router.get("/analytics/interest_distribution")
@require_api_key(scopes=["memory"])
async def get_interest_distribution(
request: Request,
end_user_id: str = Query(..., description="End user ID"),
limit: int = Query(5, le=5, description="Max interest tags to return"),
language_type: str = Header(default=None, alias="X-Language-Type"),
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""Get interest distribution tags for an end user."""
current_user = get_current_user_from_api_key(db, api_key_auth)
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
return await memory_agent_controller.get_interest_distribution_by_user_api(
end_user_id=end_user_id,
limit=limit,
language_type=language_type,
current_user=current_user,
db=db,
)
# ==================== 终端用户信息 ====================
@router.get("/analytics/end_user_info")
@require_api_key(scopes=["memory"])
async def get_end_user_info(
request: Request,
end_user_id: str = Query(..., description="End user ID"),
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""Get end user basic information (name, aliases, metadata)."""
current_user = get_current_user_from_api_key(db, api_key_auth)
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
return await user_memory_controllers.get_end_user_info(
end_user_id=end_user_id,
current_user=current_user,
db=db,
)
# ==================== 缓存生成 ====================
@router.post("/analytics/generate_cache")
@require_api_key(scopes=["memory"])
async def generate_cache(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
message: str = Body(None, description="Request body"),
language_type: str = Header(default=None, alias="X-Language-Type"),
):
"""Trigger cache generation (user summary + memory insight) for an end user or all workspace users."""
body = await request.json()
cache_request = GenerateCacheRequest(**body)
current_user = get_current_user_from_api_key(db, api_key_auth)
if cache_request.end_user_id:
validate_end_user_in_workspace(db, cache_request.end_user_id, api_key_auth.workspace_id)
return await user_memory_controllers.generate_cache_api(
request=cache_request,
language_type=language_type,
current_user=current_user,
db=db,
)

View File

@@ -173,8 +173,6 @@ async def delete_tool(
return success(msg="工具删除成功") return success(msg="工具删除成功")
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
except HTTPException:
raise
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@@ -251,8 +249,6 @@ async def parse_openapi_schema(
if result["success"] is False: if result["success"] is False:
raise HTTPException(status_code=400, detail=result["message"]) raise HTTPException(status_code=400, detail=result["message"])
return success(data=result, msg="Schema解析完成") return success(data=result, msg="Schema解析完成")
except HTTPException:
raise
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))

View File

@@ -221,7 +221,7 @@ def update_workspace_members(
@router.delete("/members/{member_id}", response_model=ApiResponse) @router.delete("/members/{member_id}", response_model=ApiResponse)
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def delete_workspace_member( def delete_workspace_member(
member_id: uuid.UUID, member_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
@@ -230,7 +230,7 @@ async def delete_workspace_member(
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}") api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
await workspace_service.delete_workspace_member( workspace_service.delete_workspace_member(
db=db, db=db,
workspace_id=workspace_id, workspace_id=workspace_id,
member_id=member_id, member_id=member_id,

View File

@@ -1,15 +1,8 @@
"""API Key 工具函数""" """API Key 工具函数"""
import secrets import secrets
import uuid as _uuid
from typing import Optional, Union from typing import Optional, Union
from datetime import datetime from datetime import datetime
from sqlalchemy.orm import Session as _Session
from app.core.error_codes import BizCode as _BizCode
from app.core.exceptions import BusinessException as _BusinessException
from app.models.end_user_model import EndUser as _EndUser
from app.repositories.end_user_repository import EndUserRepository as _EndUserRepository
from app.models.api_key_model import ApiKeyType from app.models.api_key_model import ApiKeyType
from fastapi import Response from fastapi import Response
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
@@ -72,72 +65,3 @@ def datetime_to_timestamp(dt: Optional[datetime]) -> Optional[int]:
return None return None
return int(dt.timestamp() * 1000) return int(dt.timestamp() * 1000)
def get_current_user_from_api_key(db: _Session, api_key_auth):
"""通过 API Key 构造 current_user 对象。
从 API Key 反查创建者(管理员用户),并设置其 workspace 上下文。
与内部接口的 Depends(get_current_user) (JWT) 等价。
Args:
db: 数据库会话
api_key_auth: API Key 认证信息ApiKeyAuth
Returns:
User ORM 对象,已设置 current_workspace_id
"""
from app.services import api_key_service
api_key = api_key_service.ApiKeyService.get_api_key(
db, api_key_auth.api_key_id, api_key_auth.workspace_id
)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return current_user
def validate_end_user_in_workspace(
db: _Session,
end_user_id: str,
workspace_id,
) -> _EndUser:
"""校验 end_user 是否存在且属于指定 workspace。
Args:
db: 数据库会话
end_user_id: 终端用户 ID
workspace_id: 工作空间 IDUUID 或字符串均可)
Returns:
EndUser ORM 对象(校验通过时)
Raises:
BusinessException(INVALID_PARAMETER): end_user_id 格式无效
BusinessException(USER_NOT_FOUND): end_user 不存在
BusinessException(PERMISSION_DENIED): end_user 不属于该 workspace
"""
try:
_uuid.UUID(end_user_id)
except (ValueError, AttributeError):
raise _BusinessException(
f"Invalid end_user_id format: {end_user_id}",
_BizCode.INVALID_PARAMETER,
)
end_user_repo = _EndUserRepository(db)
end_user = end_user_repo.get_end_user_by_id(end_user_id)
if end_user is None:
raise _BusinessException(
"End user not found",
_BizCode.USER_NOT_FOUND,
)
if str(end_user.workspace_id) != str(workspace_id):
raise _BusinessException(
"End user does not belong to this workspace",
_BizCode.PERMISSION_DENIED,
)
return end_user

View File

@@ -241,8 +241,6 @@ class Settings:
SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587")) SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587"))
SMTP_USER: str = os.getenv("SMTP_USER", "") SMTP_USER: str = os.getenv("SMTP_USER", "")
SMTP_PASSWORD: str = os.getenv("SMTP_PASSWORD", "") SMTP_PASSWORD: str = os.getenv("SMTP_PASSWORD", "")
SANDBOX_URL: str = os.getenv("SANDBOX_URL", "")
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300")) REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600")) HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))

View File

@@ -1,7 +1,6 @@
import json import json
import os import os
from app.celery_task_scheduler import scheduler
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
@@ -13,6 +12,8 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context from app.db import get_db_context
from app.repositories.memory_short_repository import LongTermMemoryRepository from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from app.utils.config_utils import resolve_config_id from app.utils.config_utils import resolve_config_id
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
@@ -85,28 +86,16 @@ async def write(
logger.info( logger.info(
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
# write_id = write_message_task.delay( write_id = write_message_task.delay(
# actual_end_user_id, # end_user_id: User ID actual_end_user_id, # end_user_id: User ID
# structured_messages, # message: JSON string format message list structured_messages, # message: JSON string format message list
# str(actual_config_id), # config_id: Configuration ID string str(actual_config_id), # config_id: Configuration ID string
# storage_type, # storage_type: "neo4j" storage_type, # storage_type: "neo4j"
# user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode) user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
# )
scheduler.push_task(
"app.core.memory.agent.write_message",
str(actual_end_user_id),
{
"end_user_id": str(actual_end_user_id),
"message": structured_messages,
"config_id": str(actual_config_id),
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id or ""
}
) )
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
# logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") write_status = get_task_memory_write_result(str(write_id))
# write_status = get_task_memory_write_result(str(write_id)) logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
# logger.info(f'[WRITE] Task result - user={actual_end_user_id}')
async def term_memory_save(end_user_id, strategy_type, scope): async def term_memory_save(end_user_id, strategy_type, scope):
@@ -175,24 +164,13 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
else: else:
config_id = memory_config config_id = memory_config
scheduler.push_task( write_message_task.delay(
"app.core.memory.agent.write_message", end_user_id, # end_user_id: User ID
str(end_user_id), redis_messages, # message: JSON string format message list
{ config_id, # config_id: Configuration ID string
"end_user_id": str(end_user_id), AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
"message": redis_messages, "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
"config_id": str(config_id),
"storage_type": AgentMemory_Long_Term.STORAGE_NEO4J,
"user_rag_memory_id": ""
}
) )
# write_message_task.delay(
# end_user_id, # end_user_id: User ID
# redis_messages, # message: JSON string format message list
# config_id, # config_id: Configuration ID string
# AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
# "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
# )
count_store.update_sessions_count(end_user_id, 0, []) count_store.update_sessions_count(end_user_id, 0, [])

View File

@@ -1,8 +1,8 @@
from app.core.memory.enums import SearchStrategy, StorageType from app.core.memory.enums import SearchStrategy, StorageType
from app.core.memory.models.service_models import MemorySearchResult from app.core.memory.models.service_models import MemorySearchResult
from app.core.memory.pipelines.base_pipeline import ModelClientMixin, DBRequiredPipeline from app.core.memory.pipelines.base_pipeline import ModelClientMixin, DBRequiredPipeline
from app.core.memory.read_services.search_engine.content_search import Neo4jSearchService, RAGSearchService from app.core.memory.read_services.content_search import Neo4jSearchService, RAGSearchService
from app.core.memory.read_services.generate_engine.query_preprocessor import QueryPreprocessor from app.core.memory.read_services.query_preprocessor import QueryPreprocessor
class ReadPipeLine(ModelClientMixin, DBRequiredPipeline): class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):

View File

@@ -8,7 +8,7 @@ from neo4j import Session
from app.core.memory.enums import Neo4jNodeType from app.core.memory.enums import Neo4jNodeType
from app.core.memory.memory_service import MemoryContext from app.core.memory.memory_service import MemoryContext
from app.core.memory.models.service_models import Memory, MemorySearchResult from app.core.memory.models.service_models import Memory, MemorySearchResult
from app.core.memory.read_services.search_engine.result_builder import data_builder_factory from app.core.memory.read_services.result_builder import data_builder_factory
from app.core.models import RedBearEmbeddings from app.core.models import RedBearEmbeddings
from app.core.rag.nlp.search import knowledge_retrieval from app.core.rag.nlp.search import knowledge_retrieval
from app.repositories import knowledge_repository from app.repositories import knowledge_repository

View File

@@ -8,4 +8,4 @@ class RetrievalSummaryProcessor:
@staticmethod @staticmethod
def verify(content: str, llm_client: RedBearLLM): def verify(content: str, llm_client: RedBearLLM):
return return

View File

@@ -216,7 +216,7 @@ class RedBearModelFactory:
# 深度思考模式Claude 3.7 Sonnet 等支持思考的模型 # 深度思考模式Claude 3.7 Sonnet 等支持思考的模型
# 通过 additional_model_request_fields 传递 thinking 块关闭时不传Bedrock 无 disabled 选项) # 通过 additional_model_request_fields 传递 thinking 块关闭时不传Bedrock 无 disabled 选项)
if config.deep_thinking: if config.deep_thinking:
budget = config.thinking_budget_tokens or 1024 budget = config.thinking_budget_tokens or 10000
params["additional_model_request_fields"] = { params["additional_model_request_fields"] = {
"thinking": {"type": "enabled", "budget_tokens": budget} "thinking": {"type": "enabled", "budget_tokens": budget}
} }

View File

@@ -73,7 +73,6 @@ class CustomTool(BaseTool):
# 添加通用参数(基于第一个操作的参数) # 添加通用参数(基于第一个操作的参数)
if self._parsed_operations: if self._parsed_operations:
first_operation = next(iter(self._parsed_operations.values())) first_operation = next(iter(self._parsed_operations.values()))
# path/query 参数
for param_name, param_info in first_operation.get("parameters", {}).items(): for param_name, param_info in first_operation.get("parameters", {}).items():
params.append(ToolParameter( params.append(ToolParameter(
name=param_name, name=param_name,
@@ -86,23 +85,6 @@ class CustomTool(BaseTool):
maximum=param_info.get("maximum"), maximum=param_info.get("maximum"),
pattern=param_info.get("pattern") pattern=param_info.get("pattern")
)) ))
# requestBody 参数 — 将 body 字段平铺为独立参数暴露给模型
request_body = first_operation.get("request_body")
if request_body:
body_schema = request_body.get("properties", {})
required_fields = request_body.get("required", [])
for prop_name, prop_schema in body_schema.items():
params.append(ToolParameter(
name=prop_name,
type=self._convert_openapi_type(prop_schema.get("type", "string")),
description=prop_schema.get("description", ""),
required=prop_name in required_fields,
default=prop_schema.get("default"),
enum=prop_schema.get("enum"),
minimum=prop_schema.get("minimum"),
maximum=prop_schema.get("maximum"),
pattern=prop_schema.get("pattern")
))
return params return params

View File

@@ -16,7 +16,6 @@ from app.core.workflow.engine.runtime_schema import ExecutionContext
from app.core.workflow.engine.state_manager import WorkflowStateManager from app.core.workflow.engine.state_manager import WorkflowStateManager
from app.core.workflow.engine.stream_output_coordinator import StreamOutputCoordinator from app.core.workflow.engine.stream_output_coordinator import StreamOutputCoordinator
from app.core.workflow.engine.variable_pool import VariablePool, VariablePoolInitializer from app.core.workflow.engine.variable_pool import VariablePool, VariablePoolInitializer
from app.core.workflow.nodes.base_node import NodeExecutionError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -327,43 +326,10 @@ class WorkflowExecutor:
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}", logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
exc_info=True) exc_info=True)
# 1) 尝试从 checkpoint 回补已成功节点的 node_outputs
recovered: dict[str, Any] = {}
try:
if self.graph is not None:
recovered = self.graph.get_state(
self.execution_context.checkpoint_config
).values or {}
except Exception as recover_err:
logger.warning(
f"Recover state on failure failed: {recover_err}, "
f"execution_id={self.execution_context.execution_id}"
)
if result is None: if result is None:
result = dict(recovered) if recovered else {} result = {"error": str(e)}
else: else:
# 已有 result 与 recovered 合并node_outputs 深度合并 result["error"] = str(e)
for k, v in recovered.items():
if k == "node_outputs" and isinstance(v, dict):
existing = result.get("node_outputs") or {}
result["node_outputs"] = {**v, **existing}
else:
result.setdefault(k, v)
# 2) 如果是节点抛出的 NodeExecutionError把失败节点的 node_output 注入 node_outputs
failed_node_id: str | None = None
if isinstance(e, NodeExecutionError):
failed_node_id = e.node_id
node_outputs = result.setdefault("node_outputs", {})
# 不覆盖已有(理论上不会有),保底写入失败节点记录
node_outputs.setdefault(e.node_id, e.node_output)
result["error"] = str(e)
if failed_node_id:
result["error_node"] = failed_node_id
yield { yield {
"event": "workflow_end", "event": "workflow_end",
"data": self.result_builder.build_final_output( "data": self.result_builder.build_final_output(

View File

@@ -1,6 +1,5 @@
import asyncio import asyncio
import logging import logging
import time
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
@@ -23,20 +22,6 @@ from app.services.multimodal_service import MultimodalService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class NodeExecutionError(Exception):
"""节点执行失败异常。
携带失败节点的完整 node_output供 executor 兜底注入 node_outputs
保证 workflow_executions.output_data 里能看到失败节点的日志记录。
"""
def __init__(self, node_id: str, node_output: dict[str, Any], error_message: str):
super().__init__(f"Node {node_id} execution failed: {error_message}")
self.node_id = node_id
self.node_output = node_output
self.error_message = error_message
class BaseNode(ABC): class BaseNode(ABC):
"""Base class for workflow nodes. """Base class for workflow nodes.
@@ -411,8 +396,6 @@ class BaseNode(ABC):
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
"token_usage": token_usage, "token_usage": token_usage,
"error": None, "error": None,
# 单调递增序号用于日志按执行顺序排序JSONB 不保证 key 顺序)
"execution_order": time.monotonic_ns(),
**self._extract_extra_fields(business_result), **self._extract_extra_fields(business_result),
} }
final_output = { final_output = {
@@ -461,9 +444,7 @@ class BaseNode(ABC):
"output": None, "output": None,
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
"token_usage": None, "token_usage": None,
"error": error_message, "error": error_message
# 单调递增序号,用于日志按执行顺序排序
"execution_order": time.monotonic_ns(),
} }
# if error_edge: # if error_edge:
@@ -485,12 +466,7 @@ class BaseNode(ABC):
**node_output **node_output
}) })
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}") logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
# 抛出自定义异常,把 node_output 带给 executor供其写入 node_outputs raise Exception(f"Node {self.node_id} execution failed: {error_message}")
raise NodeExecutionError(
node_id=self.node_id,
node_output=node_output,
error_message=error_message,
)
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""Extracts the input data for this node (used for logging or audit). """Extracts the input data for this node (used for logging or audit).

View File

@@ -14,7 +14,6 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes import BaseNode from app.core.workflow.nodes import BaseNode
from app.core.workflow.nodes.code.config import CodeNodeConfig from app.core.workflow.nodes.code.config import CodeNodeConfig
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
from app.core.config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -132,7 +131,7 @@ class CodeNode(BaseNode):
async with httpx.AsyncClient(timeout=60) as client: async with httpx.AsyncClient(timeout=60) as client:
response = await client.post( response = await client.post(
f"{settings.SANDBOX_URL}:8194/v1/sandbox/run", "http://sandbox:8194/v1/sandbox/run",
headers={ headers={
"x-api-key": 'redbear-sandbox' "x-api-key": 'redbear-sandbox'
}, },

View File

@@ -174,18 +174,12 @@ class IterationRuntime:
continue continue
node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type") node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type")
cycle_variable = {"item": item} if node_type == NodeType.CYCLE_START else None cycle_variable = {"item": item} if node_type == NodeType.CYCLE_START else None
node_cfg = next(
(n for n in self.cycle_nodes if n.get("id") == node_name), None
)
self.event_write({ self.event_write({
"type": "cycle_item", "type": "cycle_item",
"data": { "data": {
"cycle_id": self.node_id, "cycle_id": self.node_id,
"cycle_idx": idx, "cycle_idx": idx,
"node_id": node_name, "node_id": node_name,
"node_type": node_type,
"node_name": node_cfg.get("data", {}).get("label") if node_cfg else node_name,
"status": result.get("node_outputs", {}).get(node_name, {}).get("status", "completed"),
"input": result.get("node_outputs", {}).get(node_name, {}).get("input") "input": result.get("node_outputs", {}).get(node_name, {}).get("input")
if not cycle_variable else cycle_variable, if not cycle_variable else cycle_variable,
"output": result.get("node_outputs", {}).get(node_name, {}).get("output") "output": result.get("node_outputs", {}).get(node_name, {}).get("output")

View File

@@ -210,9 +210,6 @@ class LoopRuntime:
"cycle_id": self.node_id, "cycle_id": self.node_id,
"cycle_idx": idx, "cycle_idx": idx,
"node_id": node_name, "node_id": node_name,
"node_type": node_type,
"node_name": node_name,
"status": result.get("node_outputs", {}).get(node_name, {}).get("status", "completed"),
"input": result.get("node_outputs", {}).get(node_name, {}).get("input") "input": result.get("node_outputs", {}).get(node_name, {}).get("input")
if not cycle_variable else cycle_variable, if not cycle_variable else cycle_variable,
"output": result.get("node_outputs", {}).get(node_name, {}).get("output") "output": result.get("node_outputs", {}).get(node_name, {}).get("output")

View File

@@ -182,7 +182,7 @@ class DocExtractorNode(BaseNode):
mime_type=f"image/{ext}", mime_type=f"image/{ext}",
is_file=True, is_file=True,
).model_dump()) ).model_dump())
text = text + f"\n{placeholder}: <img src=\"{url}\" data-url=\"{url}\">" text = text + f"\n{placeholder}: {url}"
except Exception as e: except Exception as e:
logger.error(f"Node {self.node_id}: failed to save image {placeholder}: {e}") logger.error(f"Node {self.node_id}: failed to save image {placeholder}: {e}")

View File

@@ -272,11 +272,6 @@ class HttpRequestNodeOutput(BaseModel):
description="HTTP response body", description="HTTP response body",
) )
process_data: dict = Field(
default_factory=dict,
description="Raw HTTP request details for debugging",
)
# files: list[File] = Field( # files: list[File] = Field(
# ... # ...
# ) # )

View File

@@ -160,6 +160,7 @@ class HttpRequestNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: HttpRequestNodeConfig | None = None self.typed_config: HttpRequestNodeConfig | None = None
self.last_request: str = ""
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
return { return {
@@ -170,6 +171,47 @@ class HttpRequestNode(BaseNode):
"output": VariableType.STRING "output": VariableType.STRING
} }
def _extract_output(self, business_result: Any) -> Any:
if isinstance(business_result, dict):
result = {k: v for k, v in business_result.items() if k != "request"}
return result
return business_result
def _extract_extra_fields(self, business_result: Any) -> dict[str, Any]:
if isinstance(business_result, dict) and "request" in business_result:
return {
"process": {
"request": business_result.get("request", "")
}
}
return {}
def _wrap_error(
self,
error_message: str,
elapsed_time: float,
state: WorkflowState,
variable_pool: VariablePool
) -> dict[str, Any]:
input_data = self._extract_input(state, variable_pool)
node_output = {
"node_id": self.node_id,
"node_type": self.node_type,
"node_name": self.node_name,
"status": "failed",
"input": input_data,
"output": None,
"process": {"request": self.last_request} if self.last_request else None,
"elapsed_time": elapsed_time,
"token_usage": None,
"error": error_message
}
return {
"node_outputs": {self.node_id: node_output},
"error": error_message,
"error_node": self.node_id
}
def _build_timeout(self) -> Timeout: def _build_timeout(self) -> Timeout:
""" """
Build httpx Timeout configuration. Build httpx Timeout configuration.
@@ -255,18 +297,13 @@ class HttpRequestNode(BaseNode):
case HttpContentType.NONE: case HttpContentType.NONE:
return {} return {}
case HttpContentType.JSON: case HttpContentType.JSON:
rendered = self._render_template( rendered_body = self._render_template(
self.typed_config.body.data, variable_pool self.typed_config.body.data, variable_pool
) ).strip()
if not rendered or not rendered.strip(): if not rendered_body:
# 第三方导入的工作流可能出现 content_type=json 但 data 为空的情况,视为无 body content["json"] = {}
return {} else:
try: content["json"] = json.loads(rendered_body)
content["json"] = json.loads(rendered)
except json.JSONDecodeError as e:
raise RuntimeError(
f"Invalid JSON body for HTTP request node: {e.msg} (data={rendered!r})"
)
case HttpContentType.FROM_DATA: case HttpContentType.FROM_DATA:
data = {} data = {}
files = [] files = []
@@ -334,15 +371,61 @@ class HttpRequestNode(BaseNode):
case _: case _:
raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}") raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}")
def _extract_output(self, business_result: Any) -> Any: def _generate_raw_request(
if isinstance(business_result, dict): self,
return {k: v for k, v in business_result.items() if k != "process_data"} variable_pool: VariablePool,
return business_result url: str,
headers: dict[str, str],
params: dict[str, str],
content: dict[str, Any]
) -> str:
"""
Generate raw HTTP request format for debugging.
def _extract_extra_fields(self, business_result: Any) -> dict: Args:
if isinstance(business_result, dict) and "process_data" in business_result: variable_pool: Variable Pool
return {"process": business_result["process_data"]} url: Rendered URL
return {} headers: Request headers
params: Query parameters
content: Request body content
Returns:
Raw HTTP request string
"""
method = self.typed_config.method.value
if params:
param_str = "&".join([f"{k}={v}" for k, v in params.items()])
full_url = f"{url}?{param_str}" if "?" not in url else f"{url}&{param_str}"
else:
full_url = url
lines = [f"{method} {full_url} HTTP/1.1"]
for key, value in headers.items():
lines.append(f"{key}: {value}")
if "json" in content and content["json"]:
json_body = json.dumps(content["json"], ensure_ascii=False)
lines.append(f"Content-Length: {len(json_body)}")
lines.append("")
lines.append(json_body)
elif "data" in content and "files" not in content:
if isinstance(content["data"], dict):
body_str = "&".join([f"{k}={v}" for k, v in content["data"].items()])
lines.append(f"Content-Length: {len(body_str)}")
lines.append("")
lines.append(body_str)
elif "content" in content:
lines.append(f"Content-Length: {len(content['content'])}")
lines.append("")
lines.append(content["content"])
elif "files" in content:
lines.append("Content-Length: 0")
lines.append("")
lines.append("# Note: This request includes file uploads")
return "\r\n".join(lines)
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict | str: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict | str:
""" """
@@ -362,42 +445,47 @@ class HttpRequestNode(BaseNode):
- str: Branch identifier (e.g. "ERROR") when branching is enabled - str: Branch identifier (e.g. "ERROR") when branching is enabled
""" """
self.typed_config = HttpRequestNodeConfig(**self.config) self.typed_config = HttpRequestNodeConfig(**self.config)
rendered_url = self._render_template(self.typed_config.url, variable_pool)
built_headers = self._build_header(variable_pool) | self._build_auth(variable_pool) # Build request components
built_params = self._build_params(variable_pool) headers = self._build_header(variable_pool) | self._build_auth(variable_pool)
params = self._build_params(variable_pool)
content = await self._build_content(variable_pool)
url = self._render_template(self.typed_config.url, variable_pool)
logger.info(f"Node {self.node_id}: headers={headers}, params={params}, content keys={list(content.keys())}")
# Generate raw HTTP request for debugging
raw_request = self._generate_raw_request(variable_pool, url, headers, params, content)
self.last_request = raw_request
logger.info(f"Node {self.node_id}: Generated HTTP request:\n{raw_request}")
async with httpx.AsyncClient( async with httpx.AsyncClient(
verify=self.typed_config.verify_ssl, verify=self.typed_config.verify_ssl,
timeout=self._build_timeout(), timeout=self._build_timeout(),
headers=built_headers, headers=headers,
params=built_params, params=params,
follow_redirects=True follow_redirects=True
) as client: ) as client:
retries = self.typed_config.retry.max_attempts retries = self.typed_config.retry.max_attempts
while retries > 0: while retries > 0:
try: try:
request_func = self._get_client_method(client) request_func = self._get_client_method(client)
built_content = await self._build_content(variable_pool)
resp = await request_func( resp = await request_func(
url=rendered_url, url=url,
**built_content **content
) )
resp.raise_for_status() resp.raise_for_status()
logger.info(f"Node {self.node_id}: HTTP request succeeded") logger.info(f"Node {self.node_id}: HTTP request succeeded")
response = HttpResponse(resp) response = HttpResponse(resp)
# Build raw request summary for process_data return {
raw_request = ( **HttpRequestNodeOutput(
f"{self.typed_config.method.upper()} {resp.request.url} HTTP/1.1\r\n" body=response.body,
+ "".join(f"{k}: {v}\r\n" for k, v in resp.request.headers.items()) status_code=resp.status_code,
+ "\r\n" headers=resp.headers,
+ (resp.request.content.decode(errors="replace") if resp.request.content else "") files=response.files
) ).model_dump(),
return HttpRequestNodeOutput( "request": raw_request
body=response.body, }
status_code=resp.status_code,
headers=resp.headers,
files=response.files,
process_data={"request": raw_request},
).model_dump()
except (httpx.HTTPStatusError, httpx.RequestError) as e: except (httpx.HTTPStatusError, httpx.RequestError) as e:
logger.error(f"HTTP request node exception: {e}") logger.error(f"HTTP request node exception: {e}")
retries -= 1 retries -= 1
@@ -413,10 +501,19 @@ class HttpRequestNode(BaseNode):
logger.warning( logger.warning(
f"Node {self.node_id}: HTTP request failed, returning default result" f"Node {self.node_id}: HTTP request failed, returning default result"
) )
return self.typed_config.error_handle.default.model_dump() error_result = self.typed_config.error_handle.default.model_dump()
error_result["request"] = raw_request
return error_result
case HttpErrorHandle.BRANCH: case HttpErrorHandle.BRANCH:
logger.warning( logger.warning(
f"Node {self.node_id}: HTTP request failed, switching to error handling branch" f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
) )
return {"output": "ERROR"} return {
"output": "ERROR",
"body": "",
"status_code": 500,
"headers": {},
"files": [],
"request": raw_request
}
raise RuntimeError("http request failed") raise RuntimeError("http request failed")

View File

@@ -334,8 +334,7 @@ class KnowledgeRetrievalNode(BaseNode):
for kb_config in knowledge_bases: for kb_config in knowledge_bases:
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id) db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1): if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1):
logger.warning("The knowledge base does not exist or access is denied.") raise RuntimeError("The knowledge base does not exist or access is denied.")
continue
tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config)) tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config))
if tasks: if tasks:
result = await asyncio.gather(*tasks) result = await asyncio.gather(*tasks)

View File

@@ -1,7 +1,6 @@
import re import re
from typing import Any from typing import Any
from app.celery_task_scheduler import scheduler
from app.core.memory.enums import SearchStrategy from app.core.memory.enums import SearchStrategy
from app.core.memory.memory_service import MemoryService from app.core.memory.memory_service import MemoryService
from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.state_manager import WorkflowState
@@ -12,6 +11,7 @@ from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
from app.db import get_db_read from app.db import get_db_read
from app.schemas import FileInput from app.schemas import FileInput
from app.tasks import write_message_task
class MemoryReadNode(BaseNode): class MemoryReadNode(BaseNode):
@@ -126,23 +126,12 @@ class MemoryWriteNode(BaseNode):
"files": file_info "files": file_info
}) })
scheduler.push_task( write_message_task.delay(
"app.core.memory.agent.write_message", end_user_id=end_user_id,
end_user_id, message=messages,
{ config_id=str(self.typed_config.config_id),
"end_user_id": end_user_id, storage_type=state["memory_storage_type"],
"message": messages, user_rag_memory_id=state["user_rag_memory_id"]
"config_id": str(self.typed_config.config_id),
"storage_type": state["memory_storage_type"],
"user_rag_memory_id": state["user_rag_memory_id"]
}
) )
# write_message_task.delay(
# end_user_id=end_user_id,
# message=messages,
# config_id=str(self.typed_config.config_id),
# storage_type=state["memory_storage_type"],
# user_rag_memory_id=state["user_rag_memory_id"]
# )
return "success" return "success"

View File

@@ -1,15 +1,13 @@
import uuid import uuid
from typing import Optional from typing import Optional
from sqlalchemy import select, desc, func, or_, cast, Text from sqlalchemy import select, desc, func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.exceptions import ResourceNotFoundException from app.core.exceptions import ResourceNotFoundException
from app.core.logging_config import get_db_logger from app.core.logging_config import get_db_logger
from app.models import Conversation, Message from app.models import Conversation, Message
from app.models.app_model import AppType
from app.models.conversation_model import ConversationDetail from app.models.conversation_model import ConversationDetail
from app.models.workflow_model import WorkflowExecution
logger = get_db_logger() logger = get_db_logger()
@@ -208,8 +206,7 @@ class ConversationRepository:
is_draft: Optional[bool] = None, is_draft: Optional[bool] = None,
keyword: Optional[str] = None, keyword: Optional[str] = None,
page: int = 1, page: int = 1,
pagesize: int = 20, pagesize: int = 20
app_type: Optional[str] = None,
) -> tuple[list[Conversation], int]: ) -> tuple[list[Conversation], int]:
""" """
查询应用日志会话列表(带分页和过滤) 查询应用日志会话列表(带分页和过滤)
@@ -221,9 +218,6 @@ class ConversationRepository:
keyword: 搜索关键词(匹配消息内容) keyword: 搜索关键词(匹配消息内容)
page: 页码(从 1 开始) page: 页码(从 1 开始)
pagesize: 每页数量 pagesize: 每页数量
app_type: 应用类型。WORKFLOW 类型改用 workflow_executions 的
input_data/output_data 做关键词过滤(因为失败的工作流不会写入 messages 表);
其他类型仍走 messages 表。
Returns: Returns:
Tuple[List[Conversation], int]: (会话列表,总数) Tuple[List[Conversation], int]: (会话列表,总数)
@@ -240,28 +234,12 @@ class ConversationRepository:
# 如果有关键词搜索,通过子查询过滤包含该关键词的 conversation # 如果有关键词搜索,通过子查询过滤包含该关键词的 conversation
if keyword: if keyword:
kw_pattern = f"%{keyword}%" # 查找包含关键词的 conversation_id 列表
if app_type == AppType.WORKFLOW: keyword_stmt = (
# 工作流:从 workflow_executions 的 input_data / output_data 匹配 select(Message.conversation_id)
# messages 表只存开场白 assistant 消息,失败的工作流也不会写入) .where(Message.content.ilike(f"%{keyword}%"))
keyword_stmt = ( .distinct()
select(WorkflowExecution.conversation_id) )
.where(
WorkflowExecution.conversation_id.is_not(None),
or_(
cast(WorkflowExecution.input_data, Text).ilike(kw_pattern),
cast(WorkflowExecution.output_data, Text).ilike(kw_pattern),
),
)
.distinct()
)
else:
# Agent 等其他类型:仍走 messages 表user + assistant 内容)
keyword_stmt = (
select(Message.conversation_id)
.where(Message.content.ilike(kw_pattern))
.distinct()
)
base_stmt = base_stmt.where(Conversation.id.in_(keyword_stmt)) base_stmt = base_stmt.where(Conversation.id.in_(keyword_stmt))
# Calculate total number of records # Calculate total number of records

View File

@@ -14,7 +14,6 @@ class AppLogMessage(BaseModel):
conversation_id: uuid.UUID conversation_id: uuid.UUID
role: str = Field(description="角色: user / assistant / system") role: str = Field(description="角色: user / assistant / system")
content: str content: str
status: Optional[str] = Field(default=None, description="执行状态(工作流专用): completed / failed")
meta_data: Optional[Dict[str, Any]] = None meta_data: Optional[Dict[str, Any]] = None
created_at: datetime.datetime created_at: datetime.datetime
@@ -59,7 +58,6 @@ class AppLogNodeExecution(BaseModel):
input: Optional[Any] = None input: Optional[Any] = None
process: Optional[Any] = None process: Optional[Any] = None
output: Optional[Any] = None output: Optional[Any] = None
cycle_items: Optional[List[Any]] = None
elapsed_time: Optional[float] = None elapsed_time: Optional[float] = None
token_usage: Optional[Dict[str, Any]] = None token_usage: Optional[Dict[str, Any]] = None

View File

@@ -3,7 +3,7 @@ import uuid
from typing import Optional, Any, List, Dict, Union from typing import Optional, Any, List, Dict, Union
from enum import Enum, StrEnum from enum import Enum, StrEnum
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator, model_serializer from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
from app.schemas.workflow_schema import WorkflowConfigCreate from app.schemas.workflow_schema import WorkflowConfigCreate
@@ -250,7 +250,7 @@ class ModelParameters(BaseModel):
n: int = Field(default=1, ge=1, le=10, description="生成的回复数量") n: int = Field(default=1, ge=1, le=10, description="生成的回复数量")
stop: Optional[List[str]] = Field(default=None, description="停止序列") stop: Optional[List[str]] = Field(default=None, description="停止序列")
deep_thinking: bool = Field(default=False, description="是否启用深度思考模式(需模型支持,如 DeepSeek-R1、QwQ 等)") deep_thinking: bool = Field(default=False, description="是否启用深度思考模式(需模型支持,如 DeepSeek-R1、QwQ 等)")
thinking_budget_tokens: Optional[int] = Field(default=None, ge=1, le=131072, description="深度思考 token 预算(仅部分模型支持)") thinking_budget_tokens: Optional[int] = Field(default=None, ge=1024, le=131072, description="深度思考 token 预算(仅部分模型支持)")
json_output: bool = Field(default=False, description="是否强制 JSON 格式输出(需模型支持 json_output 能力)") json_output: bool = Field(default=False, description="是否强制 JSON 格式输出(需模型支持 json_output 能力)")
@@ -661,11 +661,9 @@ class DraftRunResponse(BaseModel):
suggested_questions: List[str] = Field(default_factory=list, description="下一步建议问题") suggested_questions: List[str] = Field(default_factory=list, description="下一步建议问题")
citations: List[Dict[str, Any]] = Field(default_factory=list, description="引用来源") citations: List[Dict[str, Any]] = Field(default_factory=list, description="引用来源")
audio_url: Optional[str] = Field(default=None, description="TTS 语音URL") audio_url: Optional[str] = Field(default=None, description="TTS 语音URL")
audio_status: Optional[str] = Field(default=None, description="TTS 语音状态")
@model_serializer(mode="wrap") def model_dump(self, **kwargs):
def _serialize(self, handler): data = super().model_dump(**kwargs)
data = handler(self)
if not data.get("reasoning_content"): if not data.get("reasoning_content"):
data.pop("reasoning_content", None) data.pop("reasoning_content", None)
return data return data

View File

@@ -2,7 +2,7 @@
import uuid import uuid
import datetime import datetime
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
from pydantic import BaseModel, Field, ConfigDict, field_serializer, model_serializer from pydantic import BaseModel, Field, ConfigDict, field_serializer
# 导入 FileInput用于体验运行 # 导入 FileInput用于体验运行
from app.schemas.app_schema import FileInput from app.schemas.app_schema import FileInput
@@ -94,18 +94,6 @@ class ChatResponse(BaseModel):
message_id: str message_id: str
usage: Optional[Dict[str, Any]] = None usage: Optional[Dict[str, Any]] = None
elapsed_time: Optional[float] = None elapsed_time: Optional[float] = None
reasoning_content: Optional[str] = None
suggested_questions: Optional[List[str]] = None
citations: Optional[List[Dict[str, Any]]] = None
audio_url: Optional[str] = None
audio_status: Optional[str] = None
@model_serializer(mode="wrap")
def _serialize(self, handler):
data = handler(self)
if not data.get("reasoning_content"):
data.pop("reasoning_content", None)
return data
# ---------- Conversation Summary Schemas ---------- # ---------- Conversation Summary Schemas ----------

View File

@@ -112,12 +112,12 @@ class MemoryWriteResponse(BaseModel):
"""Response schema for memory write operation. """Response schema for memory write operation.
Attributes: Attributes:
task_id: task ID for status polling task_id: Celery task ID for status polling
status: Initial task status (QUEUED) status: Initial task status (PENDING)
end_user_id: End user ID the write was submitted for end_user_id: End user ID the write was submitted for
""" """
task_id: str = Field(..., description="task ID for polling") task_id: str = Field(..., description="Celery task ID for polling")
status: str = Field(..., description="Task status: QUEUED") status: str = Field(..., description="Task status: PENDING")
end_user_id: str = Field(..., description="End user ID") end_user_id: str = Field(..., description="End user ID")

View File

@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from sqlalchemy import select from sqlalchemy import select
from app.aioRedis import aio_redis from app.aioRedis import aio_redis
from app.models.api_key_model import ApiKey, ApiKeyType from app.models.api_key_model import ApiKey
from app.repositories.api_key_repository import ApiKeyRepository, ApiKeyLogRepository from app.repositories.api_key_repository import ApiKeyRepository, ApiKeyLogRepository
from app.schemas import api_key_schema from app.schemas import api_key_schema
from app.schemas.response_schema import PageData, PageMeta from app.schemas.response_schema import PageData, PageMeta
@@ -65,12 +65,6 @@ class ApiKeyService:
BizCode.BAD_REQUEST BizCode.BAD_REQUEST
) )
# SERVICE 类型的 resource_id 指向 workspace非应用跳过应用发布校验
if data.resource_id and data.type != ApiKeyType.SERVICE.value:
app = db.get(App, data.resource_id)
if not app or not app.current_release_id:
raise BusinessException("该应用未发布", BizCode.APP_NOT_PUBLISHED)
# 生成 API Key # 生成 API Key
api_key = generate_api_key(data.type) api_key = generate_api_key(data.type)
@@ -453,12 +447,9 @@ class ApiKeyAuthService:
def check_app_published(db: Session, api_key_obj: ApiKey) -> None: def check_app_published(db: Session, api_key_obj: ApiKey) -> None:
""" """
检查应用是否已发布,未发布则抛出异常 检查应用是否已发布,未发布则抛出异常
SERVICE 类型的 api_key 不绑定应用resource_id 指向 workspace跳过校验
""" """
if not api_key_obj.resource_id: if not api_key_obj.resource_id:
return return
if api_key_obj.type == ApiKeyType.SERVICE.value:
return
app = db.get(App, api_key_obj.resource_id) app = db.get(App, api_key_obj.resource_id)
if not app or not app.current_release_id: if not app or not app.current_release_id:
raise BusinessException("应用未发布,不可用", BizCode.APP_NOT_PUBLISHED) raise BusinessException("应用未发布,不可用", BizCode.APP_NOT_PUBLISHED)

View File

@@ -107,6 +107,23 @@ class AppChatService:
# 获取模型参数 # 获取模型参数
model_parameters = config.model_parameters model_parameters = config.model_parameters
# 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_obj.model_name,
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
is_omni=api_key_obj.is_omni,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
deep_thinking=model_parameters.get("deep_thinking", False),
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
json_output=model_parameters.get("json_output", False),
capability=api_key_obj.capability or [],
)
model_info = ModelInfo( model_info = ModelInfo(
model_name=api_key_obj.model_name, model_name=api_key_obj.model_name,
provider=api_key_obj.provider, provider=api_key_obj.provider,
@@ -160,30 +177,16 @@ class AppChatService:
if doc_img_recognition and "vision" in (api_key_obj.capability or []) and any( if doc_img_recognition and "vision" in (api_key_obj.capability or []) and any(
f.type == FileType.DOCUMENT for f in files f.type == FileType.DOCUMENT for f in files
): ):
system_prompt += ( from langchain.agents import create_agent
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>" agent.system_prompt += (
"请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片" "\n\n文档中包含图片,图片位置已在文本中以 [第N页 第M张图片]: URL 标记"
"重要:图片 URL 中包含 UUID如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" "请在回答中用 Markdown 格式 ![描述](URL) 展示相关图片,做到图文并茂。"
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。" )
agent.agent = create_agent(
model=agent.llm,
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
system_prompt=agent.system_prompt
) )
# 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_obj.model_name,
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
is_omni=api_key_obj.is_omni,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
deep_thinking=model_parameters.get("deep_thinking", False),
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
json_output=model_parameters.get("json_output", False),
capability=api_key_obj.capability or [],
)
# 为需要运行时上下文的工具注入上下文 # 为需要运行时上下文的工具注入上下文
for t in tools: for t in tools:
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'): if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
@@ -320,7 +323,7 @@ class AppChatService:
"suggested_questions": suggested_questions, "suggested_questions": suggested_questions,
"citations": filtered_citations, "citations": filtered_citations,
"audio_url": audio_url, "audio_url": audio_url,
"audio_status": "pending" if audio_url else None "audio_status": "pending"
} }
async def agnet_chat_stream( async def agnet_chat_stream(
@@ -396,6 +399,24 @@ class AppChatService:
# 获取模型参数 # 获取模型参数
model_parameters = config.model_parameters model_parameters = config.model_parameters
# 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_obj.model_name,
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
is_omni=api_key_obj.is_omni,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
streaming=True,
deep_thinking=model_parameters.get("deep_thinking", False),
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
json_output=model_parameters.get("json_output", False),
capability=api_key_obj.capability or [],
)
model_info = ModelInfo( model_info = ModelInfo(
model_name=api_key_obj.model_name, model_name=api_key_obj.model_name,
provider=api_key_obj.provider, provider=api_key_obj.provider,
@@ -450,30 +471,15 @@ class AppChatService:
f.type == FileType.DOCUMENT for f in files f.type == FileType.DOCUMENT for f in files
): ):
from langchain.agents import create_agent from langchain.agents import create_agent
system_prompt += ( agent.system_prompt += (
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>" "\n\n文档中包含图片,图片位置已在文本中以 [第N页 第M张图片]: URL 标记。"
"请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片" "请在回答中用 Markdown 格式 ![描述](URL) 展示相关图片,做到图文并茂"
"重要:图片 URL 中包含 UUID如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" )
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。" agent.agent = create_agent(
model=agent.llm,
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
system_prompt=agent.system_prompt
) )
# 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_obj.model_name,
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
is_omni=api_key_obj.is_omni,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
streaming=True,
deep_thinking=model_parameters.get("deep_thinking", False),
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
json_output=model_parameters.get("json_output", False),
capability=api_key_obj.capability or [],
)
# 为需要运行时上下文的工具注入上下文 # 为需要运行时上下文的工具注入上下文
for t in tools: for t in tools:

View File

@@ -1,17 +1,16 @@
"""应用日志服务层""" """应用日志服务层"""
import uuid import uuid
import datetime as dt
from typing import Optional, Tuple from typing import Optional, Tuple
from datetime import datetime
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.models.app_model import AppType
from app.models.conversation_model import Conversation, Message from app.models.conversation_model import Conversation, Message
from app.models.workflow_model import WorkflowExecution from app.models.workflow_model import WorkflowExecution
from app.repositories.conversation_repository import ConversationRepository, MessageRepository from app.repositories.conversation_repository import ConversationRepository, MessageRepository
from app.schemas.app_log_schema import AppLogMessage, AppLogNodeExecution from app.schemas.app_log_schema import AppLogNodeExecution
logger = get_business_logger() logger = get_business_logger()
@@ -32,7 +31,6 @@ class AppLogService:
pagesize: int = 20, pagesize: int = 20,
is_draft: Optional[bool] = None, is_draft: Optional[bool] = None,
keyword: Optional[str] = None, keyword: Optional[str] = None,
app_type: Optional[str] = None,
) -> Tuple[list[Conversation], int]: ) -> Tuple[list[Conversation], int]:
""" """
查询应用日志会话列表 查询应用日志会话列表
@@ -44,7 +42,6 @@ class AppLogService:
pagesize: 每页数量 pagesize: 每页数量
is_draft: 是否草稿会话None表示返回全部 is_draft: 是否草稿会话None表示返回全部
keyword: 搜索关键词(匹配消息内容) keyword: 搜索关键词(匹配消息内容)
app_type: 应用类型WORKFLOW 时关键词将从 workflow_executions 搜索)
Returns: Returns:
Tuple[list[Conversation], int]: (会话列表,总数) Tuple[list[Conversation], int]: (会话列表,总数)
@@ -57,8 +54,7 @@ class AppLogService:
"page": page, "page": page,
"pagesize": pagesize, "pagesize": pagesize,
"is_draft": is_draft, "is_draft": is_draft,
"keyword": keyword, "keyword": keyword
"app_type": app_type,
} }
) )
@@ -69,8 +65,7 @@ class AppLogService:
is_draft=is_draft, is_draft=is_draft,
keyword=keyword, keyword=keyword,
page=page, page=page,
pagesize=pagesize, pagesize=pagesize
app_type=app_type,
) )
logger.info( logger.info(
@@ -88,40 +83,51 @@ class AppLogService:
self, self,
app_id: uuid.UUID, app_id: uuid.UUID,
conversation_id: uuid.UUID, conversation_id: uuid.UUID,
workspace_id: uuid.UUID, workspace_id: uuid.UUID
app_type: str = AppType.AGENT ) -> Tuple[Conversation, dict[str, list[AppLogNodeExecution]]]:
) -> Tuple[Conversation, list, dict[str, list[AppLogNodeExecution]]]:
""" """
查询会话详情 查询会话详情(包含消息和工作流节点执行记录)
Args:
app_id: 应用 ID
conversation_id: 会话 ID
workspace_id: 工作空间 ID
Returns: Returns:
Tuple[Conversation, list[AppLogMessage|Message], dict[str, list[AppLogNodeExecution]]] Tuple[Conversation, dict[str, list[AppLogNodeExecution]]]:
(包含消息的会话对象, 按消息ID分组的节点执行记录)
Raises:
ResourceNotFoundException: 当会话不存在时
""" """
logger.info( logger.info(
"查询应用日志会话详情", "查询应用日志会话详情",
extra={ extra={
"app_id": str(app_id), "app_id": str(app_id),
"conversation_id": str(conversation_id), "conversation_id": str(conversation_id),
"workspace_id": str(workspace_id), "workspace_id": str(workspace_id)
"app_type": app_type
} }
) )
# 查询会话
conversation = self.conversation_repository.get_conversation_for_app_log( conversation = self.conversation_repository.get_conversation_for_app_log(
conversation_id=conversation_id, conversation_id=conversation_id,
app_id=app_id, app_id=app_id,
workspace_id=workspace_id workspace_id=workspace_id
) )
if app_type == AppType.WORKFLOW: # 查询消息(按时间正序)
messages, node_executions_map = self._get_workflow_messages_and_nodes(conversation_id) messages = self.message_repository.get_messages_by_conversation(
else: conversation_id=conversation_id
messages = self.message_repository.get_messages_by_conversation( )
conversation_id=conversation_id
) # 将消息附加到会话对象
node_executions_map = self._get_workflow_node_executions_with_map( conversation.messages = messages
conversation_id, messages
) # 查询工作流节点执行记录(按消息分组)
_, node_executions_map = self._get_workflow_node_executions_with_map(
conversation_id, messages
)
logger.info( logger.info(
"查询应用日志会话详情成功", "查询应用日志会话详情成功",
@@ -133,129 +139,13 @@ class AppLogService:
} }
) )
return conversation, messages, node_executions_map return conversation, node_executions_map
def _get_workflow_messages_and_nodes(
self,
conversation_id: uuid.UUID,
) -> Tuple[list[AppLogMessage], dict[str, list[AppLogNodeExecution]]]:
"""
工作流应用专用:从 workflow_executions 构建 messages 和节点日志。
每条 WorkflowExecution 对应一轮对话:
- user message来自 execution.input_datacontent 取 message 字段files 放 meta_data
- assistant message来自 execution.output_data失败时内容为错误信息
开场白的 suggested_questions 合并到第一条 assistant message 的 meta_data 里。
Returns:
(messages 列表, node_executions_map)
"""
stmt = (
select(WorkflowExecution)
.where(
WorkflowExecution.conversation_id == conversation_id,
WorkflowExecution.status.in_(["completed", "failed"])
)
.order_by(WorkflowExecution.started_at.asc())
)
executions = list(self.db.scalars(stmt).all())
# 查开场白Message 表里 meta_data 含 suggested_questions 的第一条 assistant 消息
opening_stmt = (
select(Message)
.where(
Message.conversation_id == conversation_id,
Message.role == "assistant",
)
.order_by(Message.created_at.asc())
.limit(10)
)
early_messages = list(self.db.scalars(opening_stmt).all())
suggested_questions: list = []
for m in early_messages:
if isinstance(m.meta_data, dict) and "suggested_questions" in m.meta_data:
suggested_questions = m.meta_data.get("suggested_questions") or []
break
messages: list[AppLogMessage] = []
node_executions_map: dict[str, list[AppLogNodeExecution]] = {}
# 如果有开场白,作为第一条 assistant 消息插入
if suggested_questions or early_messages:
opening_msg = next(
(m for m in early_messages
if isinstance(m.meta_data, dict) and "suggested_questions" in m.meta_data),
None
)
if opening_msg:
messages.append(AppLogMessage(
id=opening_msg.id,
conversation_id=conversation_id,
role="assistant",
content=opening_msg.content,
status=None,
meta_data={"suggested_questions": suggested_questions},
created_at=opening_msg.created_at,
))
for execution in executions:
started_at = execution.started_at or dt.datetime.now()
completed_at = execution.completed_at or started_at
# assistant message 的 id同时作为 node_executions_map 的 key
assistant_msg_id = uuid.uuid5(execution.id, "assistant")
# --- user message输入---
input_data = execution.input_data or {}
input_content = input_data.get("message") or _extract_text(input_data)
# 跳过没有用户输入的 execution如开场白触发的记录
if not input_content or not input_content.strip():
continue
files = input_data.get("files") or []
user_msg = AppLogMessage(
id=uuid.uuid5(execution.id, "user"),
conversation_id=conversation_id,
role="user",
content=input_content,
meta_data={"files": files} if files else None,
created_at=started_at,
)
messages.append(user_msg)
# --- assistant message输出---
if execution.status == "completed":
output_content = _extract_text(execution.output_data)
meta = {"usage": execution.token_usage or {}, "elapsed_time": execution.elapsed_time}
else:
output_content = _extract_text(execution.output_data) or ""
meta = {"error": execution.error_message, "error_node_id": execution.error_node_id}
assistant_msg = AppLogMessage(
id=assistant_msg_id,
conversation_id=conversation_id,
role="assistant",
content=output_content,
status=execution.status,
meta_data=meta,
created_at=completed_at,
)
messages.append(assistant_msg)
# --- 节点执行记录,从 workflow_executions.output_data["node_outputs"] 读取 ---
execution_nodes = _build_nodes_from_output_data(execution.output_data)
if execution_nodes:
node_executions_map[str(assistant_msg_id)] = execution_nodes
return messages, node_executions_map
def _get_workflow_node_executions_with_map( def _get_workflow_node_executions_with_map(
self, self,
conversation_id: uuid.UUID, conversation_id: uuid.UUID,
messages: list[Message] messages: list[Message]
) -> dict[str, list[AppLogNodeExecution]]: ) -> Tuple[list[AppLogNodeExecution], dict[str, list[AppLogNodeExecution]]]:
""" """
从 workflow_executions 表中提取节点执行记录,并按 assistant message 分组 从 workflow_executions 表中提取节点执行记录,并按 assistant message 分组
@@ -267,12 +157,13 @@ class AppLogService:
Tuple[list[AppLogNodeExecution], dict[str, list[AppLogNodeExecution]]]: Tuple[list[AppLogNodeExecution], dict[str, list[AppLogNodeExecution]]]:
(所有节点执行记录列表, 按 message_id 分组的节点执行记录字典) (所有节点执行记录列表, 按 message_id 分组的节点执行记录字典)
""" """
node_executions = []
node_executions_map: dict[str, list[AppLogNodeExecution]] = {} node_executions_map: dict[str, list[AppLogNodeExecution]] = {}
# 查询该会话关联的所有工作流执行记录(按时间正序) # 查询该会话关联的所有工作流执行记录(按时间正序)
stmt = select(WorkflowExecution).where( stmt = select(WorkflowExecution).where(
WorkflowExecution.conversation_id == conversation_id, WorkflowExecution.conversation_id == conversation_id,
WorkflowExecution.status.in_(["completed", "failed"]) WorkflowExecution.status == "completed"
).order_by(WorkflowExecution.started_at.asc()) ).order_by(WorkflowExecution.started_at.asc())
executions = self.db.scalars(stmt).all() executions = self.db.scalars(stmt).all()
@@ -297,18 +188,10 @@ class AppLogService:
used_message_ids: set[str] = set() used_message_ids: set[str] = set()
for execution in executions: for execution in executions:
# 构建节点执行记录列表,从 workflow_executions.output_data["node_outputs"] 读取 if not execution.output_data:
execution_nodes = _build_nodes_from_output_data(execution.output_data)
if not execution_nodes:
continue continue
# 失败的执行没有 assistant message,直接用 execution id 作为 key # 找到该 execution 对应的 assistant message
if execution.status == "failed":
node_executions_map[f"execution_{str(execution.id)}"] = execution_nodes
continue
# completed通过时序匹配关联到对应的 assistant message
# 逻辑:找 execution.started_at 之后最近的、未使用的 assistant message # 逻辑:找 execution.started_at 之后最近的、未使用的 assistant message
best_msg = None best_msg = None
best_dt = None best_dt = None
@@ -317,9 +200,9 @@ class AppLogService:
if msg_id_str in used_message_ids: if msg_id_str in used_message_ids:
continue continue
if msg.created_at and msg.created_at >= execution.started_at: if msg.created_at and msg.created_at >= execution.started_at:
delta = (msg.created_at - execution.started_at).total_seconds() dt = (msg.created_at - execution.started_at).total_seconds()
if best_dt is None or delta < best_dt: if best_dt is None or dt < best_dt:
best_dt = delta best_dt = dt
best_msg = msg best_msg = msg
if not best_msg: if not best_msg:
@@ -327,86 +210,31 @@ class AppLogService:
msg_id_str = str(best_msg.id) msg_id_str = str(best_msg.id)
used_message_ids.add(msg_id_str) used_message_ids.add(msg_id_str)
node_executions_map[msg_id_str] = execution_nodes
return node_executions_map # 提取节点输出
output_data = execution.output_data
if isinstance(output_data, dict):
node_outputs = output_data.get("node_outputs", {})
execution_nodes = []
for node_id, node_data in node_outputs.items():
if not isinstance(node_data, dict):
continue
node_execution = AppLogNodeExecution(
node_id=node_data.get("node_id", node_id),
node_type=node_data.get("node_type", "unknown"),
node_name=node_data.get("node_name"),
status=node_data.get("status", "unknown"),
error=node_data.get("error"),
input=node_data.get("input"),
process=node_data.get("process"),
output=node_data.get("output"),
elapsed_time=node_data.get("elapsed_time"),
token_usage=node_data.get("token_usage"),
)
node_executions.append(node_execution)
execution_nodes.append(node_execution)
# 将节点记录关联到 message_id
node_executions_map[msg_id_str] = execution_nodes
def _extract_text(data: Optional[dict]) -> str: return node_executions, node_executions_map
"""从 workflow execution 的 input_data / output_data 中提取可读文本。
优先取 'text''content''output' 字段;若都没有则 JSON 序列化整个 dict。
"""
if not data:
return ""
for key in ("message", "text", "content", "output", "result", "answer"):
if key in data and isinstance(data[key], str):
return data[key]
import json
return json.dumps(data, ensure_ascii=False)
def _build_nodes_from_output_data(output_data: Optional[dict]) -> list[AppLogNodeExecution]:
"""从 workflow_executions.output_data["node_outputs"] 构建节点执行记录列表。
output_data 结构:
{
"node_outputs": {
"<node_id>": {
"node_type": ...,
"node_name": ...,
"status": ...,
"input": ...,
"output": ...,
"elapsed_time": ...,
"token_usage": ...,
"error": ...,
"cycle_items": [...],
...
}
},
"error": ...,
...
}
"""
if not output_data:
return []
node_outputs: dict = output_data.get("node_outputs") or {}
# 按 execution_order节点执行时写入的单调递增序号排序。
# PostgreSQL JSONB 不保证 key 顺序,不能依赖 dict 插入顺序;
# 缺失 execution_order 的历史数据退化到 0保持在最前。
ordered_items = sorted(
node_outputs.items(),
key=lambda kv: (kv[1] or {}).get("execution_order", 0)
if isinstance(kv[1], dict) else 0
)
result = []
for node_id, node_data in ordered_items:
if not isinstance(node_data, dict):
continue
output = dict(node_data)
cycle_items = output.pop("cycle_items", None)
# 把已知的顶层字段剥离,剩余的作为 output
node_type = output.pop("node_type", "unknown")
node_name = output.pop("node_name", None)
status = output.pop("status", "completed")
error = output.pop("error", None)
inp = output.pop("input", None)
elapsed_time = output.pop("elapsed_time", None)
token_usage = output.pop("token_usage", None)
# execution_order 仅用于排序,不返回给前端
output.pop("execution_order", None)
result.append(AppLogNodeExecution(
node_id=node_id,
node_type=node_type,
node_name=node_name,
status=status,
error=error,
input=inp,
process=None,
output=output if output else None,
cycle_items=cycle_items,
elapsed_time=elapsed_time,
token_usage=token_usage,
))
return result

View File

@@ -595,6 +595,23 @@ class AgentRunService:
) )
tools.extend(memory_tools) tools.extend(memory_tools)
# 4. 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_config["model_name"],
api_key=api_key_config["api_key"],
provider=api_key_config.get("provider", "openai"),
api_base=api_key_config.get("api_base"),
is_omni=api_key_config.get("is_omni", False),
temperature=effective_params.get("temperature", 0.7),
max_tokens=effective_params.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
deep_thinking=effective_params.get("deep_thinking", False),
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
json_output=effective_params.get("json_output", False),
capability=api_key_config.get("capability", []),
)
# 5. 处理会话ID创建或验证新会话时写入开场白 # 5. 处理会话ID创建或验证新会话时写入开场白
is_new_conversation = not conversation_id is_new_conversation = not conversation_id
opening, suggested_questions = None, None opening, suggested_questions = None, None
@@ -649,29 +666,16 @@ class AgentRunService:
and any(f.type == FileType.DOCUMENT for f in files) and any(f.type == FileType.DOCUMENT for f in files)
) )
if has_doc_with_images: if has_doc_with_images:
system_prompt += ( agent.system_prompt += (
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>" "\n\n文档中包含图片,图片位置已在文本中以 [第N页 第M张图片]: URL 标记。"
"请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片" "请在回答中用 Markdown 格式 ![描述](URL) 展示相关图片,做到图文并茂"
"重要:图片 URL 中包含 UUID如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" )
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。" # 重建 agent graph 以使新 system_prompt 生效
agent.agent = create_agent(
model=agent.llm,
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
system_prompt=agent.system_prompt
) )
agent = LangChainAgent(
model_name=api_key_config["model_name"],
api_key=api_key_config["api_key"],
provider=api_key_config.get("provider", "openai"),
api_base=api_key_config.get("api_base"),
is_omni=api_key_config.get("is_omni", False),
temperature=effective_params.get("temperature", 0.7),
max_tokens=effective_params.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
deep_thinking=effective_params.get("deep_thinking", False),
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
json_output=effective_params.get("json_output", False),
capability=api_key_config.get("capability", []),
)
# 为需要运行时上下文的工具注入上下文 # 为需要运行时上下文的工具注入上下文
for t in tools: for t in tools:
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'): if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
@@ -757,7 +761,7 @@ class AgentRunService:
) if not sub_agent else [], ) if not sub_agent else [],
"citations": filtered_citations, "citations": filtered_citations,
"audio_url": audio_url, "audio_url": audio_url,
"audio_status": "pending" if audio_url else None "audio_status": "pending"
} }
logger.info( logger.info(
@@ -871,6 +875,24 @@ class AgentRunService:
user_rag_memory_id) user_rag_memory_id)
tools.extend(memory_tools) tools.extend(memory_tools)
# 4. 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_config["model_name"],
api_key=api_key_config["api_key"],
provider=api_key_config.get("provider", "openai"),
api_base=api_key_config.get("api_base"),
is_omni=api_key_config.get("is_omni", False),
temperature=effective_params.get("temperature", 0.7),
max_tokens=effective_params.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
streaming=True,
deep_thinking=effective_params.get("deep_thinking", False),
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
json_output=effective_params.get("json_output", False),
capability=api_key_config.get("capability", []),
)
# 5. 处理会话ID创建或验证新会话时写入开场白 # 5. 处理会话ID创建或验证新会话时写入开场白
is_new_conversation = not conversation_id is_new_conversation = not conversation_id
opening, suggested_questions = None, None opening, suggested_questions = None, None
@@ -926,31 +948,18 @@ class AgentRunService:
and any(f.type == FileType.DOCUMENT for f in files) and any(f.type == FileType.DOCUMENT for f in files)
) )
if has_doc_with_images: if has_doc_with_images:
system_prompt += ( agent.system_prompt += (
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>" "\n\n文档中包含图片,图片位置已在文本中以 [图片 第N页 第M张图片]: URL 标记。"
"请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片" "请在回答中用 Markdown 格式 ![描述](URL) 展示相关图片,做到图文并茂"
"重要:图片 URL 中包含 UUID如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" "**规则1图片URL必须原封不动、一字不差地复制,禁止修改、禁止省略任何字符**"
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。" "**规则2禁止修改URL中UUID里的任何数字和字母**"
"**规则3直接使用 ![描述](完整URL) 格式输出**"
)
agent.agent = create_agent(
model=agent.llm,
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
system_prompt=agent.system_prompt
) )
# 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_config["model_name"],
api_key=api_key_config["api_key"],
provider=api_key_config.get("provider", "openai"),
api_base=api_key_config.get("api_base"),
is_omni=api_key_config.get("is_omni", False),
temperature=effective_params.get("temperature", 0.7),
max_tokens=effective_params.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
streaming=True,
deep_thinking=effective_params.get("deep_thinking", False),
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
json_output=effective_params.get("json_output", False),
capability=api_key_config.get("capability", []),
)
# 为需要运行时上下文的工具注入上下文 # 为需要运行时上下文的工具注入上下文
for t in tools: for t in tools:
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'): if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):

View File

@@ -10,7 +10,6 @@ from typing import Any, Dict, Optional
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.celery_task_scheduler import scheduler
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException, ResourceNotFoundException from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.core.logging_config import get_logger from app.core.logging_config import get_logger
@@ -167,31 +166,20 @@ class MemoryAPIService:
# Convert to message list format expected by write_message_task # Convert to message list format expected by write_message_task
messages = message if isinstance(message, list) else [{"role": "user", "content": message}] messages = message if isinstance(message, list) else [{"role": "user", "content": message}]
# from app.tasks import write_message_task from app.tasks import write_message_task
# task = write_message_task.delay( task = write_message_task.delay(
# end_user_id,
# messages,
# config_id,
# storage_type,
# user_rag_memory_id or "",
# )
task_id = scheduler.push_task(
"app.core.memory.agent.write_message",
end_user_id, end_user_id,
{ messages,
"end_user_id": end_user_id, config_id,
"message": messages, storage_type,
"config_id": config_id, user_rag_memory_id or "",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id or ""
}
) )
logger.info(f"Memory write task submitted, task_id={task_id} end_user_id={end_user_id}") logger.info(f"Memory write task submitted: task_id={task.id}, end_user_id={end_user_id}")
return { return {
"task_id": task_id, "task_id": task.id,
"status": "QUEUED", "status": "PENDING",
"end_user_id": end_user_id, "end_user_id": end_user_id,
} }

View File

@@ -4,7 +4,7 @@
处理显性记忆相关的业务逻辑,包括情景记忆和语义记忆的查询。 处理显性记忆相关的业务逻辑,包括情景记忆和语义记忆的查询。
""" """
from typing import Any, Dict, Optional from typing import Any, Dict
from app.core.logging_config import get_logger from app.core.logging_config import get_logger
from app.services.memory_base_service import MemoryBaseService from app.services.memory_base_service import MemoryBaseService
@@ -104,7 +104,7 @@ class MemoryExplicitService(MemoryBaseService):
e.description AS core_definition e.description AS core_definition
ORDER BY e.name ASC ORDER BY e.name ASC
""" """
semantic_result = await self.neo4j_connector.execute_query( semantic_result = await self.neo4j_connector.execute_query(
semantic_query, semantic_query,
end_user_id=end_user_id end_user_id=end_user_id
@@ -146,209 +146,6 @@ class MemoryExplicitService(MemoryBaseService):
logger.error(f"获取显性记忆总览时出错: {str(e)}", exc_info=True) logger.error(f"获取显性记忆总览时出错: {str(e)}", exc_info=True)
raise raise
async def get_episodic_memory_list(
self,
end_user_id: str,
page: int,
pagesize: int,
start_date: Optional[int] = None,
end_date: Optional[int] = None,
episodic_type: str = "all",
) -> Dict[str, Any]:
"""
获取情景记忆分页列表
Args:
end_user_id: 终端用户ID
page: 页码
pagesize: 每页数量
start_date: 开始时间戳(毫秒),可选
end_date: 结束时间戳(毫秒),可选
episodic_type: 情景类型筛选
Returns:
{
"total": int, # 该用户情景记忆总数(不受筛选影响)
"items": [...], # 当前页数据
"page": {
"page": int,
"pagesize": int,
"total": int, # 筛选后总数
"hasnext": bool
}
}
"""
try:
logger.info(
f"情景记忆分页查询: end_user_id={end_user_id}, "
f"start_date={start_date}, end_date={end_date}, "
f"episodic_type={episodic_type}, page={page}, pagesize={pagesize}"
)
# 1. 查询情景记忆总数(不受筛选条件限制)
total_all_query = """
MATCH (s:MemorySummary)
WHERE s.end_user_id = $end_user_id
RETURN count(s) AS total
"""
total_all_result = await self.neo4j_connector.execute_query(
total_all_query, end_user_id=end_user_id
)
total_all = total_all_result[0]["total"] if total_all_result else 0
# 2. 构建筛选条件
where_clauses = ["s.end_user_id = $end_user_id"]
params = {"end_user_id": end_user_id}
# 时间戳筛选(毫秒时间戳转为 UTC ISO 字符串,使用 Neo4j datetime() 精确比较)
if start_date is not None and end_date is not None:
from datetime import datetime, timezone
start_dt = datetime.fromtimestamp(start_date / 1000, tz=timezone.utc)
end_dt = datetime.fromtimestamp(end_date / 1000, tz=timezone.utc)
# 开始时间取当天 UTC 00:00:00结束时间取当天 UTC 23:59:59.999999
start_iso = start_dt.strftime("%Y-%m-%dT") + "00:00:00.000000"
end_iso = end_dt.strftime("%Y-%m-%dT") + "23:59:59.999999"
where_clauses.append("datetime(s.created_at) >= datetime($start_iso) AND datetime(s.created_at) <= datetime($end_iso)")
params["start_iso"] = start_iso
params["end_iso"] = end_iso
# 类型筛选下推到 Cypher兼容中英文
if episodic_type != "all":
type_mapping = {
"conversation": "对话",
"project_work": "项目/工作",
"learning": "学习",
"decision": "决策",
"important_event": "重要事件"
}
chinese_type = type_mapping.get(episodic_type)
if chinese_type:
where_clauses.append(
"(s.memory_type = $episodic_type OR s.memory_type = $chinese_type)"
)
params["episodic_type"] = episodic_type
params["chinese_type"] = chinese_type
else:
where_clauses.append("s.memory_type = $episodic_type")
params["episodic_type"] = episodic_type
where_str = " AND ".join(where_clauses)
# 3. 查询筛选后的总数
count_query = f"""
MATCH (s:MemorySummary)
WHERE {where_str}
RETURN count(s) AS total
"""
count_result = await self.neo4j_connector.execute_query(count_query, **params)
filtered_total = count_result[0]["total"] if count_result else 0
# 4. 查询分页数据
skip = (page - 1) * pagesize
data_query = f"""
MATCH (s:MemorySummary)
WHERE {where_str}
RETURN elementId(s) AS id,
s.name AS title,
s.memory_type AS memory_type,
s.content AS content,
s.created_at AS created_at
ORDER BY s.created_at DESC
SKIP $skip LIMIT $limit
"""
params["skip"] = skip
params["limit"] = pagesize
result = await self.neo4j_connector.execute_query(data_query, **params)
# 5. 处理结果
items = []
if result:
for record in result:
raw_created_at = record.get("created_at")
created_at_timestamp = self.parse_timestamp(raw_created_at)
items.append({
"id": record["id"],
"title": record.get("title") or "未命名",
"memory_type": record.get("memory_type") or "其他",
"content": record.get("content") or "",
"created_at": created_at_timestamp
})
# 6. 构建返回结果
return {
"total": total_all,
"items": items,
"page": {
"page": page,
"pagesize": pagesize,
"total": filtered_total,
"hasnext": (page * pagesize) < filtered_total
}
}
except Exception as e:
logger.error(f"情景记忆分页查询出错: {str(e)}", exc_info=True)
raise
async def get_semantic_memory_list(
self,
end_user_id: str
) -> list:
"""
获取语义记忆全量列表
Args:
end_user_id: 终端用户ID
Returns:
[
{
"id": str,
"name": str,
"entity_type": str,
"core_definition": str
}
]
"""
try:
logger.info(f"语义记忆列表查询: end_user_id={end_user_id}")
semantic_query = """
MATCH (e:ExtractedEntity)
WHERE e.end_user_id = $end_user_id
AND e.is_explicit_memory = true
RETURN elementId(e) AS id,
e.name AS name,
e.entity_type AS entity_type,
e.description AS core_definition
ORDER BY e.name ASC
"""
result = await self.neo4j_connector.execute_query(
semantic_query, end_user_id=end_user_id
)
items = []
if result:
for record in result:
items.append({
"id": record["id"],
"name": record.get("name") or "未命名",
"entity_type": record.get("entity_type") or "未分类",
"core_definition": record.get("core_definition") or ""
})
logger.info(f"语义记忆列表查询成功: end_user_id={end_user_id}, total={len(items)}")
return items
except Exception as e:
logger.error(f"语义记忆列表查询出错: {str(e)}", exc_info=True)
raise
async def get_explicit_memory_details( async def get_explicit_memory_details(
self, self,
end_user_id: str, end_user_id: str,

View File

@@ -95,7 +95,7 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
"""通义千问文档格式""" """通义千问文档格式"""
return True, { return True, {
"type": "text", "type": "text",
"text": f"<document name=\"{file_name}\">\n文档内容:\n{text}\n</document>" "text": f"<document name=\"{file_name}\">\n{text}\n</document>"
} }
async def format_audio( async def format_audio(
@@ -167,7 +167,6 @@ class BedrockFormatStrategy(MultimodalFormatStrategy):
async def format_document(self, file_name: str, text: str) -> tuple[bool, Dict[str, Any]]: async def format_document(self, file_name: str, text: str) -> tuple[bool, Dict[str, Any]]:
"""Bedrock/Anthropic 文档格式(需要 base64 编码)""" """Bedrock/Anthropic 文档格式(需要 base64 编码)"""
# Bedrock 文档需要 base64 编码 # Bedrock 文档需要 base64 编码
text = f"文档内容:\n{text}\n"
text_bytes = text.encode('utf-8') text_bytes = text.encode('utf-8')
base64_text = base64.b64encode(text_bytes).decode('utf-8') base64_text = base64.b64encode(text_bytes).decode('utf-8')
@@ -224,7 +223,7 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy):
"""OpenAI 文档格式""" """OpenAI 文档格式"""
return True, { return True, {
"type": "text", "type": "text",
"text": f"<document name=\"{file_name}\">\n文档内容:\n{text}\n</document>" "text": f"<document name=\"{file_name}\">\n{text}\n</document>"
} }
async def format_audio( async def format_audio(
@@ -389,18 +388,17 @@ class MultimodalService:
from app.models.workspace_model import Workspace as WorkspaceModel from app.models.workspace_model import Workspace as WorkspaceModel
ws = self.db.query(WorkspaceModel).filter(WorkspaceModel.id == workspace_id).first() ws = self.db.query(WorkspaceModel).filter(WorkspaceModel.id == workspace_id).first()
tenant_id = ws.tenant_id if ws else None tenant_id = ws.tenant_id if ws else None
img_result = []
for img_info in img_infos: for img_info in img_infos:
page = img_info["page"] page = img_info["page"]
index = img_info["index"] index = img_info["index"]
ext = img_info.get("ext", "png") ext = img_info.get("ext", "png")
try: try:
_, img_url = await self._save_doc_image_to_storage(img_info["bytes"], ext, tenant_id, workspace_id) _, img_url = await self._save_doc_image_to_storage(img_info["bytes"], ext, tenant_id, workspace_id)
placeholder = f"{page}页 第{index + 1}" if page > 0 else f"{index + 1}" placeholder = f"{page}页 第{index + 1}图片" if page > 0 else f"{index + 1}图片"
# 在文本内容中追加图片位置标记 # 在文本内容中追加图片位置标记
if result and result[-1].get("type") in ("text", "document"): if result and result[-1].get("type") in ("text", "document"):
key = "text" if "text" in result[-1] else list(result[-1].keys())[-1] key = "text" if "text" in result[-1] else list(result[-1].keys())[-1]
result[-1][key] = result[-1].get(key, "") + f"\n[图片 {placeholder}]: <img src=\"{img_url}\" data-url=\"{img_url}\">" result[-1][key] = result[-1].get(key, "") + f"\n[图片 {placeholder}]: {img_url}"
# 将图片以视觉格式追加到消息内容中 # 将图片以视觉格式追加到消息内容中
img_file = FileInput( img_file = FileInput(
type=FileType.IMAGE, type=FileType.IMAGE,
@@ -409,10 +407,9 @@ class MultimodalService:
file_type="image/png", file_type="image/png",
) )
_, img_content = await self._process_image(img_file, strategy_class(img_file)) _, img_content = await self._process_image(img_file, strategy_class(img_file))
img_result.append(img_content) result.append(img_content)
except Exception as img_err: except Exception as img_err:
logger.warning(f"文档图片处理失败: {img_err}") logger.warning(f"文档图片处理失败: {img_err}")
result.extend(img_result)
elif file.type == FileType.AUDIO and "audio" in self.capability: elif file.type == FileType.AUDIO and "audio" in self.capability:
is_support, content = await self._process_audio(file, strategy) is_support, content = await self._process_audio(file, strategy)
result.append(content) result.append(content)

View File

@@ -815,12 +815,11 @@ class ToolService:
"default": param_info.get("default") "default": param_info.get("default")
}) })
# 请求体参数 — _extract_request_body 返回 {"schema": {...}, "required": bool, ...} # 请求体参数
request_body = operation.get("request_body") request_body = operation.get("request_body")
if request_body: if request_body:
body_schema = request_body.get("schema", {}) schema_props = request_body.get("schema", {}).get("properties", {})
schema_props = body_schema.get("properties", {}) required_props = request_body.get("schema", {}).get("required", [])
required_props = body_schema.get("required", [])
for prop_name, prop_schema in schema_props.items(): for prop_name, prop_schema in schema_props.items():
parameters.append({ parameters.append({

View File

@@ -17,9 +17,8 @@ from app.core.workflow.executor import execute_workflow, execute_workflow_stream
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.validator import validate_workflow_config from app.core.workflow.validator import validate_workflow_config
from app.db import get_db from app.db import get_db
from sqlalchemy import select
from app.models import App from app.models import App
from app.models.workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution from app.models.workflow_model import WorkflowConfig, WorkflowExecution
from app.repositories import knowledge_repository from app.repositories import knowledge_repository
from app.repositories.workflow_repository import ( from app.repositories.workflow_repository import (
WorkflowConfigRepository, WorkflowConfigRepository,
@@ -554,16 +553,13 @@ class WorkflowService:
} }
} }
case "workflow_end": case "workflow_end":
data = {
"elapsed_time": payload.get("elapsed_time"),
"message_length": len(payload.get("output", "")),
"error": payload.get("error", "")
}
if "citations" in payload and payload["citations"]:
data["citations"] = payload["citations"]
return { return {
"event": "end", "event": "end",
"data": data "data": {
"elapsed_time": payload.get("elapsed_time"),
"message_length": len(payload.get("output", "")),
"error": payload.get("error", "")
}
} }
case "node_start" | "node_end" | "node_error" | "cycle_item": case "node_start" | "node_end" | "node_error" | "cycle_item":
return None return None
@@ -922,7 +918,6 @@ class WorkflowService:
input_data["conv_messages"] = conv_messages input_data["conv_messages"] = conv_messages
init_message_length = len(input_data.get("conv_messages", [])) init_message_length = len(input_data.get("conv_messages", []))
message_id = uuid.uuid4() message_id = uuid.uuid4()
_cycle_items: dict[str, list] = {}
# 新会话时写入开场白 # 新会话时写入开场白
is_new_conversation = init_message_length == 0 is_new_conversation = init_message_length == 0
@@ -953,15 +948,6 @@ class WorkflowService:
memory_storage_type=storage_type, memory_storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id
): ):
event_type = event.get("event")
event_data = event.get("data", {})
if event_type == "cycle_item":
cycle_id = event_data.get("cycle_id")
if cycle_id not in _cycle_items:
_cycle_items[cycle_id] = []
_cycle_items[cycle_id].append(event_data)
if event.get("event") == "workflow_end": if event.get("event") == "workflow_end":
status = event.get("data", {}).get("status") status = event.get("data", {}).get("status")
token_usage = event.get("data", {}).get("token_usage", {}) or {} token_usage = event.get("data", {}).get("token_usage", {}) or {}
@@ -1033,18 +1019,6 @@ class WorkflowService:
) )
else: else:
logger.error(f"unexpect workflow run status, status: {status}") logger.error(f"unexpect workflow run status, status: {status}")
# 把积累的 cycle_item 写入 workflow_executions.output_data["node_outputs"]
if _cycle_items and execution.output_data:
import copy
new_output_data = copy.deepcopy(execution.output_data)
node_outputs = new_output_data.setdefault("node_outputs", {})
for cycle_node_id, items in _cycle_items.items():
if cycle_node_id in node_outputs:
node_outputs[cycle_node_id]["cycle_items"] = items
else:
node_outputs[cycle_node_id] = {"cycle_items": items}
execution.output_data = new_output_data
self.db.commit()
elif event.get("event") == "workflow_start": elif event.get("event") == "workflow_start":
event["data"]["message_id"] = str(message_id) event["data"]["message_id"] = str(message_id)
event = self._emit(public, event) event = self._emit(public, event)

View File

@@ -20,7 +20,6 @@ from app.models.workspace_model import (
) )
from app.repositories import workspace_repository from app.repositories import workspace_repository
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
from app.services.session_service import SessionService
from app.schemas.workspace_schema import ( from app.schemas.workspace_schema import (
InviteAcceptRequest, InviteAcceptRequest,
InviteValidateResponse, InviteValidateResponse,
@@ -59,7 +58,7 @@ def switch_workspace(
raise BusinessException(f"切换工作空间失败: {str(e)}", BizCode.INTERNAL_ERROR) raise BusinessException(f"切换工作空间失败: {str(e)}", BizCode.INTERNAL_ERROR)
async def delete_workspace_member( def delete_workspace_member(
db: Session, db: Session,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
member_id: uuid.UUID, member_id: uuid.UUID,
@@ -77,29 +76,10 @@ async def delete_workspace_member(
BizCode.WORKSPACE_NOT_FOUND) BizCode.WORKSPACE_NOT_FOUND)
try: try:
deleted_user = workspace_member.user
workspace_member.is_active = False workspace_member.is_active = False
deleted_user.current_workspace_id = None workspace_member.user.current_workspace_id = None
# 若被删除成员不是超级管理员且没有其他可用工作空间,则禁用该用户
if not deleted_user.is_superuser:
remaining = (
db.query(WorkspaceMember)
.filter(
WorkspaceMember.user_id == deleted_user.id,
WorkspaceMember.workspace_id != workspace_id,
WorkspaceMember.is_active.is_(True),
)
.count()
)
if remaining == 0:
deleted_user.is_active = False
db.commit() db.commit()
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}") business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
# 使被删除成员的所有 token 立即失效
await SessionService.invalidate_all_user_tokens(str(workspace_member.user_id))
except Exception as e: except Exception as e:
db.rollback() db.rollback()
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}") business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")

View File

@@ -34,7 +34,7 @@ from app.core.rag.prompts.generator import question_proposal
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ( from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
ElasticSearchVectorFactory, ElasticSearchVectorFactory,
) )
from app.db import get_db_context from app.db import get_db, get_db_context
from app.models import Document, File, Knowledge from app.models import Document, File, Knowledge
from app.models.end_user_model import EndUser from app.models.end_user_model import EndUser
from app.schemas import document_schema, file_schema from app.schemas import document_schema, file_schema
@@ -2025,7 +2025,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
end_users = db.query(EndUser).all() end_users = db.query(EndUser).all()
if not end_users: if not end_users:
logger.info("没有终端用户,跳过遗忘周期") logger.info("没有终端用户,跳过遗忘周期")
return {"status": "SUCCESS", "message": "没有终端用户", return {"status": "SUCCESS", "message": "没有终端用户",
"report": {"merged_count": 0, "failed_count": 0, "processed_users": 0}, "report": {"merged_count": 0, "failed_count": 0, "processed_users": 0},
"duration_seconds": time.time() - start_time} "duration_seconds": time.time() - start_time}
@@ -2039,7 +2039,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
# 获取用户配置(自动回退到工作空间默认配置) # 获取用户配置(自动回退到工作空间默认配置)
connected_config = get_end_user_connected_config(str(end_user.id), db) connected_config = get_end_user_connected_config(str(end_user.id), db)
user_config_id = resolve_config_id(connected_config.get("memory_config_id"), db) user_config_id = resolve_config_id(connected_config.get("memory_config_id"), db)
if not user_config_id: if not user_config_id:
failed_users.append({"end_user_id": str(end_user.id), "error": "无法获取配置"}) failed_users.append({"end_user_id": str(end_user.id), "error": "无法获取配置"})
continue continue
@@ -2048,13 +2048,13 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
report = await forget_service.trigger_forgetting_cycle( report = await forget_service.trigger_forgetting_cycle(
db=db, end_user_id=str(end_user.id), config_id=user_config_id db=db, end_user_id=str(end_user.id), config_id=user_config_id
) )
total_merged += report.get('merged_count', 0) total_merged += report.get('merged_count', 0)
total_failed += report.get('failed_count', 0) total_failed += report.get('failed_count', 0)
processed_users += 1 processed_users += 1
logger.info(f"用户 {end_user.id}: 融合 {report.get('merged_count', 0)} 对节点") logger.info(f"用户 {end_user.id}: 融合 {report.get('merged_count', 0)} 对节点")
except Exception as e: except Exception as e:
logger.error(f"处理用户 {end_user.id} 失败: {e}", exc_info=True) logger.error(f"处理用户 {end_user.id} 失败: {e}", exc_info=True)
failed_users.append({"end_user_id": str(end_user.id), "error": str(e)}) failed_users.append({"end_user_id": str(end_user.id), "error": str(e)})
@@ -2801,18 +2801,18 @@ def run_incremental_clustering(
包含任务执行结果的字典 包含任务执行结果的字典
""" """
start_time = time.time() start_time = time.time()
async def _run() -> Dict[str, Any]: async def _run() -> Dict[str, Any]:
from app.core.logging_config import get_logger from app.core.logging_config import get_logger
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine
logger = get_logger(__name__) logger = get_logger(__name__)
logger.info( logger.info(
f"[IncrementalClustering] 开始增量聚类任务 - end_user_id={end_user_id}, " f"[IncrementalClustering] 开始增量聚类任务 - end_user_id={end_user_id}, "
f"实体数={len(new_entity_ids)}, llm_model_id={llm_model_id}" f"实体数={len(new_entity_ids)}, llm_model_id={llm_model_id}"
) )
connector = Neo4jConnector() connector = Neo4jConnector()
try: try:
engine = LabelPropagationEngine( engine = LabelPropagationEngine(
@@ -2820,12 +2820,12 @@ def run_incremental_clustering(
llm_model_id=llm_model_id, llm_model_id=llm_model_id,
embedding_model_id=embedding_model_id, embedding_model_id=embedding_model_id,
) )
# 执行增量聚类 # 执行增量聚类
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids) await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
logger.info(f"[IncrementalClustering] 增量聚类完成 - end_user_id={end_user_id}") logger.info(f"[IncrementalClustering] 增量聚类完成 - end_user_id={end_user_id}")
return { return {
"status": "SUCCESS", "status": "SUCCESS",
"end_user_id": end_user_id, "end_user_id": end_user_id,
@@ -2836,18 +2836,18 @@ def run_incremental_clustering(
raise raise
finally: finally:
await connector.close() await connector.close()
try: try:
loop = set_asyncio_event_loop() loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run()) result = loop.run_until_complete(_run())
result["elapsed_time"] = time.time() - start_time result["elapsed_time"] = time.time() - start_time
result["task_id"] = self.request.id result["task_id"] = self.request.id
logger.info( logger.info(
f"[IncrementalClustering] 任务完成 - task_id={self.request.id}, " f"[IncrementalClustering] 任务完成 - task_id={self.request.id}, "
f"elapsed_time={result['elapsed_time']:.2f}s" f"elapsed_time={result['elapsed_time']:.2f}s"
) )
return result return result
except Exception as e: except Exception as e:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time

View File

@@ -63,23 +63,6 @@ services:
networks: networks:
- celery - celery
celery-task-scheduler:
image: redbear-mem-open:latest
container_name: celery-task-scheduler
env_file:
- .env
volumes:
- /etc/localtime:/etc/localtime:ro
command: python -m app.celery_task_scheduler
restart: unless-stopped
healthcheck:
test: CMD curl -f 127.0.0.1:8001 || exit 1
interval: 30s
timeout: 5s
retries: 3
networks:
- celery
# Celery Beat - scheduler # Celery Beat - scheduler
beat: beat:
image: redbear-mem-open:latest image: redbear-mem-open:latest

View File

@@ -62,6 +62,7 @@
"remark-gfm": "^4.0.1", "remark-gfm": "^4.0.1",
"remark-math": "^6.0.0", "remark-math": "^6.0.0",
"tailwindcss": "^4.1.14", "tailwindcss": "^4.1.14",
"x6-html-shape": "0.4.9",
"xlsx": "^0.18.5", "xlsx": "^0.18.5",
"zustand": "^5.0.8" "zustand": "^5.0.8"
}, },

View File

@@ -8,11 +8,12 @@ import { type FC, useRef, useEffect, useState } from 'react'
import clsx from 'clsx' import clsx from 'clsx'
import Markdown from '@/components/Markdown' import Markdown from '@/components/Markdown'
import type { ChatContentProps } from './types' import type { ChatContentProps } from './types'
import { Spin, Flex, Button } from 'antd' import { Spin, Image, Flex, Button } from 'antd'
import { SoundOutlined } from '@ant-design/icons' import { SoundOutlined } from '@ant-design/icons'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import MessageFiles from './MessageFiles' import AudioPlayer from './AudioPlayer'
import VideoPlayer from './VideoPlayer'
const getFileUrl = (file: any) => { const getFileUrl = (file: any) => {
return file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined) return file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined)
@@ -148,7 +149,72 @@ const ChatContent: FC<ChatContentProps> = ({
{labelFormat(item)} {labelFormat(item)}
</div> </div>
} }
<MessageFiles files={item.meta_data?.files ?? []} contentClassNames={contentClassNames} onDownload={handleDownload} /> {item?.meta_data?.files && item.meta_data?.files.length > 0 && <Flex gap={8} vertical align="end" className="rb:mb-2!">
{item.meta_data?.files?.map((file) => {
if (file.type.includes('image')) {
return (
<div key={file.url || file.uid} className={`rb:inline-block rb:group rb:relative rb:rounded-lg ${contentClassNames}`}>
<Image src={getFileUrl(file)} alt={file.name} className="rb:w-full rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
</div>
)
}
if (file.type.includes('video')) {
return (
<div key={file.url || file.uid} className="rb:w-50">
{/* <video src={getFileUrl(file)} controls className="rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" /> */}
<VideoPlayer key={file.url || file.uid} src={getFileUrl(file)} />
</div>
)
}
if (file.type.includes('audio')) {
return (
<div key={file.url || file.uid} className="rb:w-50">
<AudioPlayer key={file.url || file.uid} src={getFileUrl(file)} />
</div>
)
}
const documentType = (file.file_type || file.type)?.split('/')
return (
<Flex
key={file.url || file.uid}
align="center"
gap={10}
className="rb:text-left rb:w-45 rb:text-[12px] rb:group rb:relative rb:rounded-lg rb-border rb:py-2! rb:px-2.5! rb:border rb:border-[#F6F6F6]"
onClick={() => handleDownload(file)}
>
<div
className={clsx(
"rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf_disabled.svg')]",
file.type?.includes('pdf')
? "rb:bg-[url('@/assets/images/file/pdf.svg')]"
: (file.type?.includes('excel') || file.type?.includes('spreadsheetml.sheet')) || file.type?.includes('xls') || file.type?.includes('xlsx')
? "rb:bg-[url('@/assets/images/file/excel.svg')]"
: file.type?.includes('csv')
? "rb:bg-[url('@/assets/images/file/csv.svg')]"
: file.type?.includes('html')
? "rb:bg-[url('@/assets/images/file/html.svg')]"
: file.type?.includes('json')
? "rb:bg-[url('@/assets/images/file/json.svg')]"
: file.type?.includes('ppt')
? "rb:bg-[url('@/assets/images/file/ppt.svg')]"
: file.type?.includes('markdown')
? "rb:bg-[url('@/assets/images/file/md.svg')]"
: file.type?.includes('text')
? "rb:bg-[url('@/assets/images/file/txt.svg')]"
: (file.type?.includes('doc') || file.type?.includes('docx') || file.type?.includes('word') || file.type?.includes('wordprocessingml.document'))
? "rb:bg-[url('@/assets/images/file/word.svg')]"
: "rb:bg-[url('@/assets/images/file/txt.svg')]"
)}
></div>
<div className="rb:flex-1 rb:w-32.5">
<div className="rb:leading-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{file.name}</div>
<div className="rb:leading-3.5 rb:mt-0.5 rb:text-[#5B6167] rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{documentType?.[documentType.length - 1]} · {file.size}</div>
</div>
</Flex>
)
})}
</Flex>}
{/* Message bubble */} {/* Message bubble */}
<div className={clsx('rb:text-left rb:leading-5 rb:inline-block rb:wrap-break-word rb:relative', item.role === 'user' ? contentClassNames : '', { <div className={clsx('rb:text-left rb:leading-5 rb:inline-block rb:wrap-break-word rb:relative', item.role === 'user' ? contentClassNames : '', {
// Error message style (content is null and not assistant message) // Error message style (content is null and not assistant message)

View File

@@ -1,87 +0,0 @@
import { Image, Flex } from 'antd'
import clsx from 'clsx'
import AudioPlayer from './AudioPlayer'
import VideoPlayer from './VideoPlayer'
const getFileUrl = (file: any) =>
file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined)
const DOC_ICONS: [string[], string][] = [
[['pdf'], "rb:bg-[url('@/assets/images/file/pdf.svg')]"],
[['excel', 'spreadsheetml.sheet', 'xls', 'xlsx'], "rb:bg-[url('@/assets/images/file/excel.svg')]"],
[['csv'], "rb:bg-[url('@/assets/images/file/csv.svg')]"],
[['html'], "rb:bg-[url('@/assets/images/file/html.svg')]"],
[['json'], "rb:bg-[url('@/assets/images/file/json.svg')]"],
[['ppt'], "rb:bg-[url('@/assets/images/file/ppt.svg')]"],
[['markdown'], "rb:bg-[url('@/assets/images/file/md.svg')]"],
[['text'], "rb:bg-[url('@/assets/images/file/txt.svg')]"],
[['doc', 'docx', 'word', 'wordprocessingml.document'], "rb:bg-[url('@/assets/images/file/word.svg')]"],
]
const getDocIcon = (parts: string[]) => {
const match = DOC_ICONS.find(([keys]) => keys.some(k => parts.includes(k)))
return match ? match[1] : "rb:bg-[url('@/assets/images/file/txt.svg')]"
}
interface MessageFilesProps {
files: any[]
contentClassNames?: string | Record<string, boolean>
onDownload: (file: any) => void
}
const MessageFiles = ({ files, contentClassNames, onDownload }: MessageFilesProps) => {
if (!files?.length) return null
return (
<Flex gap={8} vertical align="end" className="rb:mb-2!">
{files.map((file) => {
const key = file.url || file.uid
if (file.type.includes('image')) {
return (
<div key={key} className={clsx('rb:inline-block rb:group rb:relative rb:rounded-lg', contentClassNames)}>
<Image src={getFileUrl(file)} alt={file.name} className="rb:w-full rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
</div>
)
}
if (file.type.includes('video')) {
return (
<div key={key} className="rb:w-50">
<VideoPlayer src={getFileUrl(file)} />
</div>
)
}
if (file.type.includes('audio')) {
return (
<div key={key} className="rb:w-50">
<AudioPlayer src={getFileUrl(file)} />
</div>
)
}
const documentType = (file.file_type || file.type)?.split('/') ?? []
return (
<Flex
key={key}
align="center"
gap={10}
className="rb:text-left rb:w-45 rb:text-[12px] rb:group rb:relative rb:rounded-lg rb-border rb:py-2! rb:px-2.5! rb:border rb:border-[#F6F6F6]"
onClick={() => onDownload(file)}
>
<div
className={clsx(
"rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf_disabled.svg')]",
getDocIcon(documentType)
)}
/>
<div className="rb:flex-1 rb:w-32.5">
<div className="rb:leading-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{file.name}</div>
<div className="rb:leading-3.5 rb:mt-0.5 rb:text-[#5B6167] rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">
{documentType?.[documentType.length - 1]} · {file.size}
</div>
</div>
</Flex>
)
})}
</Flex>
)
}
export default MessageFiles

View File

@@ -3,14 +3,14 @@ import { Popover, type PopoverProps } from 'antd'
import Tag, { type TagProps } from '@/components/Tag' import Tag, { type TagProps } from '@/components/Tag'
interface OverflowTagsProps { interface OverflowTagsProps {
items?: ReactNode[]; items: ReactNode[];
gap?: number; gap?: number;
numTagColor?: TagProps['color']; numTagColor?: TagProps['color'];
numTag?: (num?: number) => ReactNode; numTag?: (num?: number) => ReactNode;
popoverProps?: PopoverProps | false; popoverProps?: PopoverProps | false;
} }
const OverflowTags = ({ items = [], gap = 8, numTagColor = 'default', numTag, popoverProps }: OverflowTagsProps) => { const OverflowTags = ({ items, gap = 8, numTagColor = 'default', numTag, popoverProps }: OverflowTagsProps) => {
const containerRef = useRef<HTMLDivElement>(null) const containerRef = useRef<HTMLDivElement>(null)
const measureRef = useRef<HTMLDivElement>(null) const measureRef = useRef<HTMLDivElement>(null)
const [visibleCount, setVisibleCount] = useState(items.length) const [visibleCount, setVisibleCount] = useState(items.length)
@@ -20,7 +20,7 @@ const OverflowTags = ({ items = [], gap = 8, numTagColor = 'default', numTag, po
if (!measure || containerWidth === 0) return if (!measure || containerWidth === 0) return
const children = Array.from(measure.children) as HTMLElement[] const children = Array.from(measure.children) as HTMLElement[]
if (!children.length) { setVisibleCount(0); return } if (!children.length) return
// last child is the sample +N tag // last child is the sample +N tag
const extraTagWidth = (children[children.length - 1] as HTMLElement).offsetWidth const extraTagWidth = (children[children.length - 1] as HTMLElement).offsetWidth

View File

@@ -399,7 +399,7 @@ const Menu: FC<{
className="rb:overflow-y-auto rb:flex-1!" className="rb:overflow-y-auto rb:flex-1!"
/> />
{/* Return to space button for superusers */} {/* Return to space button for superusers */}
{source === 'space' && {user?.is_superuser && source === 'space' &&
<Flex gap={4} vertical className="rb:my-3! rb:mx-3!"> <Flex gap={4} vertical className="rb:my-3! rb:mx-3!">
<Divider className="rb:mb-2.5! rb:mt-0! rb:border-[#DFE4ED]! rb:mx-2! rb:min-w-[calc(100%-20px)]! rb:w-[calc(100%-20px)]!" /> <Divider className="rb:mb-2.5! rb:mt-0! rb:border-[#DFE4ED]! rb:mx-2! rb:min-w-[calc(100%-20px)]! rb:w-[calc(100%-20px)]!" />
<Flex <Flex
@@ -412,18 +412,16 @@ const Menu: FC<{
<div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/switch.svg')]"></div> <div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/switch.svg')]"></div>
{collapsed ? null : t('common.switchSpace')} {collapsed ? null : t('common.switchSpace')}
</Flex> </Flex>
{user?.is_superuser && <Flex
<Flex gap={8}
gap={8} align="center"
align="center" justify="start"
justify="start" onClick={goToSpace}
onClick={goToSpace} className="rb:p-2.5! rb:text-[13px] rb:hover:bg-[rgba(223,228,237,0.5)] rb:rounded-lg rb:leading-3.5 rb:font-regular rb:text-center rb:cursor-pointer"
className="rb:p-2.5! rb:text-[13px] rb:hover:bg-[rgba(223,228,237,0.5)] rb:rounded-lg rb:leading-3.5 rb:font-regular rb:text-center rb:cursor-pointer" >
> <div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/return.svg')]"></div>
<div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/return.svg')]"></div> {collapsed ? null : t('common.returnToSpace')}
{collapsed ? null : t('common.returnToSpace')} </Flex>
</Flex>
}
</Flex> </Flex>
} }
{source === 'manage' && subscription && !collapsed && {source === 'manage' && subscription && !collapsed &&

View File

@@ -1538,7 +1538,6 @@ export const en = {
json_output: 'Support JSON formatted output', json_output: 'Support JSON formatted output',
thinking_budget_tokens: 'thinking budget tokens', thinking_budget_tokens: 'thinking budget tokens',
thinking_budget_tokens_max_error: "Cannot exceed the max tokens limit ({{max}})", thinking_budget_tokens_max_error: "Cannot exceed the max tokens limit ({{max}})",
thinking_budget_tokens_min_error: "Cannot be less than {{min}}",
logSearchPlaceholder: 'Search log content', logSearchPlaceholder: 'Search log content',
}, },
userMemory: { userMemory: {

View File

@@ -868,7 +868,6 @@ export const zh = {
json_output: '支持JSON格式化输出', json_output: '支持JSON格式化输出',
thinking_budget_tokens: '深度思考预算Token数', thinking_budget_tokens: '深度思考预算Token数',
thinking_budget_tokens_max_error: "不能超过 最大令牌数 ({{max}})", thinking_budget_tokens_max_error: "不能超过 最大令牌数 ({{max}})",
thinking_budget_tokens_min_error: "不能小于 {{min}}",
logSearchPlaceholder: '搜索日志内容', logSearchPlaceholder: '搜索日志内容',
}, },
table: { table: {

176
web/src/vendor/x6-html-shape/index.js vendored Normal file
View File

@@ -0,0 +1,176 @@
// Patched x6-html-shape: replaces View.createElement (removed in X6 3.x) with document.createElement
import { Node as p, NodeView as l, Graph as C, Dom as s } from "@antv/x6";
import { getConfig as w, clickable as x, isInputElement as y, forwardEvent as S } from "./utils.js";
const u = "html-shape", h = "html-shape-view", T = p.define(w(h)), m = {};
export function register(i) {
const { shape: e, render: n, inherit: t = u, ...o } = i;
if (!e) throw new Error("should specify shape in config");
m[e] = n;
C.registerNode(e, { inherit: t, ...o }, true);
}
const a = "html";
// Determine which HTML layer a node belongs to.
// Parent (loop/iteration) nodes go behind the SVG layer so edges render above them.
// All other nodes go in front of the SVG layer so they render above edges.
function isBackNode(cell) {
const type = cell.getData?.()?.type;
return type === 'loop' || type === 'iteration';
}
// Ensure the two HTML container layers exist and are correctly positioned.
function ensureHtmlLayers(graph) {
if (!graph._htmlBack) {
const back = graph._htmlBack = document.createElement('div');
s.css(back, {
position: 'absolute', width: '100%', height: '100%',
'touch-action': 'none', 'user-select': 'none', 'pointer-events': 'none',
'z-index': 0, 'transform-origin': 'left top',
});
back.classList.add('x6-html-shape-container', 'x6-html-shape-back');
const svg = graph.container.querySelector('svg');
// back layer: before SVG → visually behind edges
graph.container.insertBefore(back, svg || null);
}
if (!graph._htmlFront) {
const front = graph._htmlFront = document.createElement('div');
s.css(front, {
position: 'absolute', width: '100%', height: '100%',
'touch-action': 'none', 'user-select': 'none', 'pointer-events': 'none',
'z-index': 0, 'transform-origin': 'left top',
});
front.classList.add('x6-html-shape-container', 'x6-html-shape-front');
// front layer: after SVG → visually above edges
graph.container.append(front);
}
// Keep legacy alias so updateHtmlContainerSize can iterate both
graph.htmlContainers = [graph._htmlBack, graph._htmlFront];
}
class BaseHTMLShapeView extends l {
confirmUpdate(e) {
const n = super.confirmUpdate(e);
return this.handleAction(n, a, () => {
if (!this.mounted) {
const t = m[this.cell.shape], o = this.ensureComponentContainer();
t && o && (this.mounted = t(this.cell, this.graph, o) || true,
this.onMounted(),
o.addEventListener("mousedown", this.prevEvent, true),
o.addEventListener("mouseup", this.prevEvent, true));
}
});
}
prevEvent(e) {
(x(e.target) || y(e.target)) && (e.preventDefault(), e.stopPropagation());
}
ensureComponentContainer() {}
onMounted() {}
onUnMount() {
if (this.onZIndexChange) {
this.cell.off("change:zIndex", this.onZIndexChange);
}
if (this.onNodeMoving) {
this.graph.off("node:moving", this.onNodeMoving);
}
}
unmount() {
typeof this.mounted == "function" && this.mounted();
this.componentContainer && this.componentContainer.remove();
this.onUnMount();
return super.unmount(), this;
}
}
BaseHTMLShapeView.config({ bootstrap: [a], actions: { component: a } });
class HTMLShapeView extends BaseHTMLShapeView {
constructor(...e) {
super(...e);
this.cell.on("change:visible", ({ cell: n }) => {
if (n.view === h) {
const t = this.graph.findViewByCell(n.id);
t && Promise.resolve().then(() => {
t.componentContainer.style.display = t.container.style.display;
});
}
});
}
onMounted() {
const listeners = this.graph.listeners;
// Always register per-cell zIndex listener regardless of shared transform events
this.onZIndexChange = () => this.updateContainerStyle();
this.cell.on("change:zIndex", this.onZIndexChange);
if (listeners?.hasTransformEvent?.length) return;
this.onTranslate = this.updateHtmlContainerSize.bind(this);
this.graph.on("translate", this.onTranslate);
this.graph.on("scale", this.onTranslate);
this.graph.on("node:change:position", this.onTranslate);
this.graph.on("hasTransformEvent", this.onTranslate);
// While dragging, lift this node's componentContainer to the top of its
// layer so its ports are never obscured by a sibling node underneath.
this.onNodeMoving = ({ node }) => {
if (node === this.cell && this.componentContainer) {
const layer = isBackNode(this.cell) ? this.graph._htmlBack : this.graph._htmlFront;
layer.append(this.componentContainer);
}
};
this.graph.on("node:moving", this.onNodeMoving);
this.updateHtmlContainerSize();
}
ensureComponentContainer() {
ensureHtmlLayers(this.graph);
const layer = isBackNode(this.cell) ? this.graph._htmlBack : this.graph._htmlFront;
if (!this.componentContainer) {
const e = this.componentContainer = document.createElement("div");
s.css(e, {
"pointer-events": "auto", "touch-action": "none", "user-select": "none",
"transform-origin": "center", position: "absolute"
});
e.classList.add("x6-html-shape-node");
"click,dblclick,contextmenu,mousedown,mousemove,mouseup,mouseover,mouseout,mouseenter,mouseleave"
.split(",").forEach(t => S(t, e, this.container));
layer.append(e);
}
return this.componentContainer;
}
resize() { super.resize(); this.updateContainerStyle(); }
updateTransform() { super.updateTransform(); this.updateContainerStyle(); }
updateContainerStyle() {
const e = this.ensureComponentContainer();
const { x: n, y: t } = this.cell.getBBox();
const { width: o, height: r } = this.cell.getSize();
const g = getComputedStyle(this.container).cursor;
const f = this.cell.getZIndex() ?? 0;
// Shrink the interactive width by the port hover radius (6px) so the right
// port circle is fully outside the componentContainer and never blocked by it.
// overflow:visible keeps the visual rendering intact.
const PORT_RADIUS = 6;
s.css(e, {
cursor: g, height: r + "px", width: (o - PORT_RADIUS) + "px",
overflow: "visible",
"z-index": f,
transform: `translate(${n}px, ${t}px) rotate(${this.cell.getAngle()}deg)`
});
}
updateHtmlContainerSize() {
const { graph: e } = this;
const t = e.transform.getMatrix();
const { offsetHeight: o, offsetWidth: r } = e.container;
const n = e.transform.getZoom();
const style = {
transform: `matrix(${t.a}, ${t.b}, ${t.c}, ${t.d}, ${t.e}, ${t.f})`,
width: r / n + "px",
height: o / n + "px",
};
// Update both layers
(e.htmlContainers || [e._htmlBack, e._htmlFront].filter(Boolean)).forEach(c => s.css(c, style));
}
}
l.registry.register(h, HTMLShapeView, true);
p.registry.register(u, T, true);
export { BaseHTMLShapeView, T as HTMLShape, u as HTMLShapeName, HTMLShapeView, h as HTMLView, a as action };

1
web/src/vendor/x6-html-shape/react.js vendored Normal file
View File

@@ -0,0 +1 @@
export { default } from "x6-html-shape/dist/react.js";

98
web/src/vendor/x6-html-shape/utils.js vendored Normal file
View File

@@ -0,0 +1,98 @@
import { Dom as u, ObjectExt as l, Markup as c } from "@antv/x6";
const o = "fo-shape-view";
function p(t, e, r) {
e.addEventListener(t, function(n) {
r.dispatchEvent(new n.constructor(n.type, n)), n.preventDefault(), n.stopPropagation();
});
}
function s(t, e = 3) {
return !t || !u.isHTMLElement(t) || e <= 0 ? !1 : ["a", "button"].includes(u.tagName(t)) || t.getAttribute("role") === "button" || t.getAttribute("type") === "button" ? !0 : s(t.parentNode, e - 1);
}
function g(t) {
if (u.tagName(t) === "input") {
const r = t.getAttribute("type");
if (r == null || ["text", "password", "number", "email", "search", "tel", "url"].includes(
r
))
return !0;
}
return !1;
}
function f(t = "rect", e = !0) {
return [
{
tagName: t,
selector: "body"
},
e ? c.getForeignObjectMarkup() : null,
{
tagName: "text",
selector: "label"
}
].filter((r) => r);
}
function b(t) {
return {
view: t,
markup: f("rect", t === o),
attrs: {
body: {
// fill: "none",
// 这里很奇怪none的时候不能触发节点移动改成transparent可以触发
fill: "transparent",
stroke: "none",
refWidth: "100%",
refHeight: "100%"
},
label: {
fontSize: 14,
fill: "#333",
refX: "50%",
refY: "50%",
textAnchor: "middle",
textVerticalAnchor: "middle"
},
fo: {
refWidth: "100%",
refHeight: "100%"
}
},
propHooks(e) {
if (e.markup == null) {
const { primer: r, view: n } = e;
if (r && r !== "rect") {
e.markup = f(r, n === o);
let i = {};
r === "circle" ? i = {
refCx: "50%",
refCy: "50%",
refR: "50%"
} : r === "ellipse" && (i = {
refCx: "50%",
refCy: "50%",
refRx: "50%",
refRy: "50%"
}), e.attrs = l.merge(
{},
{
body: {
refWidth: null,
refHeight: null,
...i
}
},
e.attrs || {}
);
}
}
return e;
}
};
}
export {
o as FOView,
s as clickable,
p as forwardEvent,
b as getConfig,
g as isInputElement
};

View File

@@ -7,7 +7,7 @@
import { type FC, useRef } from 'react'; import { type FC, useRef } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useParams } from 'react-router-dom'; import { useParams } from 'react-router-dom';
import { Flex, Button, Form } from 'antd'; import { Flex, Button } from 'antd';
import type { ColumnsType } from 'antd/es/table'; import type { ColumnsType } from 'antd/es/table';
import { getAppLogsUrl } from '@/api/application'; import { getAppLogsUrl } from '@/api/application';
@@ -15,14 +15,11 @@ import Table from '@/components/Table'
import { formatDateTime } from '@/utils/format'; import { formatDateTime } from '@/utils/format';
import type { LogItem, LogDetailModalRef } from './types' import type { LogItem, LogDetailModalRef } from './types'
import LogDetailModal from './components/LogDetailModal' import LogDetailModal from './components/LogDetailModal'
import SearchInput from '@/components/SearchInput'
const Statistics: FC = () => { const Statistics: FC = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const { id } = useParams(); const { id } = useParams();
const logDetailRef = useRef<LogDetailModalRef>(null); const logDetailRef = useRef<LogDetailModalRef>(null);
const [form] = Form.useForm();
const values = Form.useWatch([], form);
const handleViewDetail = (item: LogItem) => { const handleViewDetail = (item: LogItem) => {
logDetailRef.current?.handleOpen(item); logDetailRef.current?.handleOpen(item);
@@ -65,26 +62,15 @@ const Statistics: FC = () => {
]; ];
return ( return (
<div className="rb:bg-white rb:rounded-lg rb:pt-3 rb:px-3"> <div className="rb:bg-white rb:rounded-lg rb:pt-3 rb:px-3">
<Flex justify="flex-end" className="rb:mb-3!">
<Form form={form}>
<Form.Item name="keyword" noStyle>
<SearchInput
placeholder={t('application.logSearchPlaceholder')}
variant="outlined"
/>
</Form.Item>
</Form>
</Flex>
<Table<LogItem> <Table<LogItem>
apiUrl={getAppLogsUrl(id || '')} apiUrl={getAppLogsUrl(id || '')}
apiParams={{ apiParams={{
is_draft: false, is_draft: false,
...(values ?? {})
}} }}
columns={columns} columns={columns}
rowKey="id" rowKey="id"
isScroll={true} isScroll={true}
scrollY="calc(100vh - 242px)" scrollY="calc(100vh - 214px)"
/> />
<LogDetailModal ref={logDetailRef} /> <LogDetailModal ref={logDetailRef} />
</div> </div>

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-03-13 17:27:52 * @Date: 2026-03-13 17:27:52
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-24 18:14:25 * @Last Modified time: 2026-04-07 21:48:30
*/ */
import { type FC, useState, useRef, useEffect } from 'react' import { type FC, useState, useRef, useEffect } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
@@ -59,7 +59,6 @@ interface NodeData {
node_type?: string; node_type?: string;
input?: any; input?: any;
output?: any; output?: any;
process?: any;
elapsed_time?: string; elapsed_time?: string;
error?: any; error?: any;
state: Record<string, any>; state: Record<string, any>;
@@ -486,7 +485,7 @@ const TestChat: FC<TestChatProps> = ({
} }
const updateWorkflowNodeEndMessage = (data: NodeData) => { const updateWorkflowNodeEndMessage = (data: NodeData) => {
const { node_id, input, output, process, error, elapsed_time, status } = data; const { node_id, input, output, error, elapsed_time, status } = data;
setChatList(prev => { setChatList(prev => {
const newList = [...prev] const newList = [...prev]
const lastIndex = newList.length - 1 const lastIndex = newList.length - 1
@@ -499,7 +498,6 @@ const TestChat: FC<TestChatProps> = ({
content: { content: {
input, input,
output, output,
process,
error, error,
}, },
status: status || 'completed', status: status || 'completed',
@@ -516,7 +514,7 @@ const TestChat: FC<TestChatProps> = ({
} }
const updateWorkflowCycleMessage = (data: NodeData) => { const updateWorkflowCycleMessage = (data: NodeData) => {
const { node_id, cycle_id, cycle_idx, input, output, process, error, elapsed_time, status } = data; const { node_id, cycle_id, cycle_idx, input, output, error, elapsed_time, status } = data;
const { nodes } = config as WorkflowConfig const { nodes } = config as WorkflowConfig
const node = nodes.find(n => n.id === node_id); const node = nodes.find(n => n.id === node_id);
const { name, type } = node || {} const { name, type } = node || {}
@@ -540,7 +538,6 @@ const TestChat: FC<TestChatProps> = ({
cycle_idx, cycle_idx,
input, input,
output, output,
process,
error, error,
}, },
status: status || 'completed', status: status || 'completed',

View File

@@ -1,8 +1,8 @@
/* /*
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-03-24 16:31:24 * @Date: 2026-03-24 16:31:24
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-24 17:49:58 * @Last Modified time: 2026-03-24 16:31:24
*/ */
import { forwardRef, useImperativeHandle, useState, useEffect } from 'react'; import { forwardRef, useImperativeHandle, useState, useEffect } from 'react';
import { Flex, Button, Empty, Skeleton } from 'antd'; import { Flex, Button, Empty, Skeleton } from 'antd';
@@ -14,12 +14,6 @@ import { getAppLogDetail } from '@/api/application'
import ChatContent from '@/components/Chat/ChatContent' import ChatContent from '@/components/Chat/ChatContent'
import { formatDateTime } from '@/utils/format' import { formatDateTime } from '@/utils/format'
import type { ChatItem } from '@/components/Chat/types' import type { ChatItem } from '@/components/Chat/types'
import Runtime from '@/views/Workflow/components/Chat/Runtime'
import { nodeLibrary } from '@/views/Workflow/constant'
const nodeIconMap = Object.fromEntries(
nodeLibrary.flatMap(c => c.nodes.map(n => [n.type, n.icon]))
)
/** Log detail data with conversation messages */ /** Log detail data with conversation messages */
type Data = LogItem & { type Data = LogItem & {
@@ -60,30 +54,7 @@ const LogDetailModal = forwardRef<LogDetailModalRef>((_props, ref) => {
if (!vo) return if (!vo) return
setLoading(true) setLoading(true)
getAppLogDetail(vo.app_id, vo.id).then(res => { getAppLogDetail(vo.app_id, vo.id).then(res => {
const { node_executions_map, messages, ...rest } = res as Data; setData(res as Data)
let hasSubContentMessages = messages
if (messages && messages.length > 0 && node_executions_map && Object.keys(node_executions_map).length > 0) {
hasSubContentMessages = messages.map(item => {
if (item.id && node_executions_map[item.id]) {
item.subContent = node_executions_map[item.id]?.map(({ input, output, cycle_items = [], error, process, ...node }: any) => {
const converted: any = { ...node, icon: nodeIconMap[node.node_type], content: { input, output, process, error } }
if (node.node_type === 'loop' && Array.isArray(cycle_items) && cycle_items.length > 0) {
converted.subContent = cycle_items.map(({ input: cInput, output: cOutput, error: cError, process: cProcess, ...cNode }: any) => ({
...cNode,
icon: nodeIconMap[cNode.node_type],
content: { input: cInput, output: cOutput, process: cProcess, error: cError }
}))
}
return converted
})
}
return { ...item }
})
}
setData({
...rest,
messages: hasSubContentMessages
})
}) })
.finally(() => { .finally(() => {
setLoading(false) setLoading(false)
@@ -95,8 +66,6 @@ const LogDetailModal = forwardRef<LogDetailModalRef>((_props, ref) => {
handleClose handleClose
})); }));
console.log('data', data)
return ( return (
<RbModal <RbModal
title={<> title={<>
@@ -123,7 +92,6 @@ const LogDetailModal = forwardRef<LogDetailModalRef>((_props, ref) => {
data={data.messages || []} data={data.messages || []}
streamLoading={false} streamLoading={false}
labelFormat={(item) => formatDateTime(item.created_at)} labelFormat={(item) => formatDateTime(item.created_at)}
renderRuntime={(item, index) => <Runtime item={item} index={index} />}
/> />
) )
} }

View File

@@ -49,8 +49,6 @@ const configFields = [
{ key: 'n', max: 10, min: 1, step: 1, defaultValue: 1 }, { key: 'n', max: 10, min: 1, step: 1, defaultValue: 1 },
] ]
const minThinkingBudgetTokens = 128;
const defaultThinkingBudgetTokens = 1000;
const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(({ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(({
refresh, refresh,
data, data,
@@ -110,7 +108,7 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
const newValues: ModelConfig = { const newValues: ModelConfig = {
capability: (option as Model).capability, capability: (option as Model).capability,
deep_thinking: false, deep_thinking: false,
thinking_budget_tokens: defaultThinkingBudgetTokens, thinking_budget_tokens: undefined,
json_output: false, json_output: false,
} }
if (source === 'chat') { if (source === 'chat') {
@@ -130,12 +128,6 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
form.setFieldsValue({ ...rest }) form.setFieldsValue({ ...rest })
}, [data?.default_model_config_id]) }, [data?.default_model_config_id])
useEffect(() => {
if (values?.deep_thinking && !values?.thinking_budget_tokens) {
form.setFieldValue('thinking_budget_tokens', defaultThinkingBudgetTokens)
}
}, [values?.deep_thinking])
const handleReset = () => { const handleReset = () => {
if (!id) return if (!id) return
resetAppModelConfig(id).then((res) => { resetAppModelConfig(id).then((res) => {
@@ -186,20 +178,15 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
name="thinking_budget_tokens" name="thinking_budget_tokens"
label={t('application.thinking_budget_tokens')} label={t('application.thinking_budget_tokens')}
hidden={!['model', 'chat'].includes(source) || !(values?.deep_thinking || values?.capability?.includes('thinking'))} hidden={!['model', 'chat'].includes(source) || !(values?.deep_thinking || values?.capability?.includes('thinking'))}
extra={<>{t('application.range')}: [{minThinkingBudgetTokens}, {t(`application.max_tokens`)}: {values?.max_tokens}]</>} extra={<>{t('application.range')}: [{0}, {t(`application.max_tokens`)}: {values?.max_tokens}]</>}
rules={[ rules={[
{ required: values?.deep_thinking, message: t('common.pleaseEnter') }, { required: values?.deep_thinking, message: t('common.pleaseEnter') },
{ {
validator: (_, value) => { validator: (_, value) => {
const maxTokens = values?.max_tokens const maxTokens = values?.max_tokens
const deep_thinking = values?.deep_thinking; const deep_thinking = values?.deep_thinking;
if (deep_thinking && value !== undefined) { if (deep_thinking && value !== undefined && maxTokens !== undefined && value > maxTokens) {
if (value < minThinkingBudgetTokens) { return Promise.reject(t('application.thinking_budget_tokens_max_error', { max: maxTokens }))
return Promise.reject(t('application.thinking_budget_tokens_min_error', { min: minThinkingBudgetTokens }))
}
if (maxTokens !== undefined && value > maxTokens) {
return Promise.reject(t('application.thinking_budget_tokens_max_error', { max: maxTokens }))
}
} }
return Promise.resolve() return Promise.resolve()
} }
@@ -208,7 +195,7 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
> >
<RbSlider <RbSlider
step={1} step={1}
min={minThinkingBudgetTokens} min={0}
max={32000} max={32000}
isInput={true} isInput={true}
disabled={!values?.deep_thinking} disabled={!values?.deep_thinking}

View File

@@ -166,10 +166,10 @@ const Ontology: FC = () => {
<div className="rb:h-10 rb:wrap-break-word rb:line-clamp-2 rb:leading-5">{item.scene_description}</div> <div className="rb:h-10 rb:wrap-break-word rb:line-clamp-2 rb:leading-5">{item.scene_description}</div>
</Tooltip> </Tooltip>
<div className="rb:mt-2 rb:h-5.5"> <div className="rb:mt-2">
<OverflowTags <OverflowTags
popoverProps={false} popoverProps={false}
items={item.entity_type ? [...item.entity_type.map((type, i) => <Tag key={i} variant="borderless" color="dark">{type}</Tag>), <Tag variant="borderless" color="dark">{`+${item.type_num - 3}`}</Tag>] : []} items={[...item.entity_type?.map((type, i) => <Tag key={i} variant="borderless" color="dark">{type}</Tag>), <Tag variant="borderless" color="dark">{`+${item.type_num - 3}`}</Tag>]}
numTag={(num?: number) => <Tag variant="borderless" color="dark">{`+${item.type_num - 3 + (num ? num - 1 : 0)}`}</Tag>} numTag={(num?: number) => <Tag variant="borderless" color="dark">{`+${item.type_num - 3 + (num ? num - 1 : 0)}`}</Tag>}
/> />
</div> </div>

View File

@@ -101,7 +101,6 @@ const CustomToolModal = forwardRef<CustomToolModalRef, CustomToolModalProps>(({
}); });
}; };
const formatSchema = (value: string) => { const formatSchema = (value: string) => {
if (!value || value.trim() === '') return
setParseSchemaData({} as ParseSchemaData) setParseSchemaData({} as ParseSchemaData)
parseSchema({ schema_content: value }) parseSchema({ schema_content: value })
.then(res => { .then(res => {

View File

@@ -57,6 +57,7 @@ const CanvasToolbar: FC<CanvasToolbarProps> = ({
} }
}} }}
labelRender={(props) => { labelRender={(props) => {
console.log('props', props)
return `${props.value}%` return `${props.value}%`
}} }}
className="rb:w-20 rb:h-4!" className="rb:w-20 rb:h-4!"

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-06 21:10:56 * @Date: 2026-02-06 21:10:56
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-24 18:13:22 * @Last Modified time: 2026-04-21 14:59:13
*/ */
/** /**
* Workflow Chat Component * Workflow Chat Component
@@ -66,6 +66,8 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
const [fileList, setFileList] = useState<any[]>([]) const [fileList, setFileList] = useState<any[]>([])
const [message, setMessage] = useState<string | undefined>(undefined) const [message, setMessage] = useState<string | undefined>(undefined)
console.log('abortRef', abortRef)
/** /**
* Opens the chat drawer and loads workflow variables from the start node * Opens the chat drawer and loads workflow variables from the start node
*/ */
@@ -183,7 +185,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
*/ */
const handleStreamMessage = (data: SSEMessage[]) => { const handleStreamMessage = (data: SSEMessage[]) => {
data.forEach(item => { data.forEach(item => {
const { content, conversation_id, node_id, cycle_id, cycle_idx, input, output, process, error, elapsed_time, status, citations } = item.data as { const { content, conversation_id, node_id, cycle_id, cycle_idx, input, output, error, elapsed_time, status, citations } = item.data as {
content: string; content: string;
conversation_id: string | null; conversation_id: string | null;
cycle_id: string; cycle_id: string;
@@ -191,7 +193,6 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
node_id: string; node_id: string;
node_name?: string; node_name?: string;
node_type?: string; node_type?: string;
process?: any;
input?: any; input?: any;
output?: any; output?: any;
elapsed_time?: string; elapsed_time?: string;
@@ -276,7 +277,6 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
content: { content: {
input, input,
output, output,
process,
error, error,
}, },
status: status || 'completed', status: status || 'completed',
@@ -305,14 +305,13 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
cycle_id, cycle_id,
cycle_idx, cycle_idx,
node_id, node_id,
node_name: type === 'cycle-start' ? t('workflow.cycle-start') : name, node_name: name,
node_type: type, node_type: type,
icon, icon,
content: { content: {
cycle_idx, cycle_idx,
input, input,
output, output,
process,
error, error,
}, },
status: status || 'completed', status: status || 'completed',

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-24 17:57:08 * @Date: 2026-02-24 17:57:08
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-24 18:04:31 * @Last Modified time: 2026-04-20 15:33:48
*/ */
/* /*
* Runtime Component * Runtime Component
@@ -184,30 +184,27 @@ const Runtime: FC<{ item: ChatItem; index: number;}> = ({
</Flex> </Flex>
)} )}
{/* Display input and output data as JSON code blocks */} {/* Display input and output data as JSON code blocks */}
{['input', 'process', 'output'].map(key => { {['input', 'output'].map(key => (
if (vo.node_type !== 'http-request' && key === 'process') return null <div key={key} className="rb:bg-[#EBEBEB] rb:rounded-lg">
return ( <div className="rb:py-2 rb:px-3 rb:flex rb:justify-between rb:items-center rb:text-[12px]">
<div key={key} className="rb:bg-[#EBEBEB] rb:rounded-lg"> {isLoop ? t(`workflow.runtime.${key}_cycle_vars`) : t(`workflow.${key}_result`)}
<div className="rb:py-2 rb:px-3 rb:flex rb:justify-between rb:items-center rb:text-[12px]"> <Button
{isLoop ? t(`workflow.runtime.${key}_cycle_vars`) : t(`workflow.${key}_result`)} className="rb:py-0! rb:px-1! rb:text-[12px]!"
<Button size="small"
className="rb:py-0! rb:px-1! rb:text-[12px]!" onClick={() => handleCopy(typeof vo.content === 'object' && vo.content?.[key] ? JSON.stringify(vo.content[key], null, 2) : '{}')}
size="small" >{t('common.copy')}</Button>
onClick={() => handleCopy(typeof vo.content === 'object' && vo.content?.[key] ? JSON.stringify(vo.content[key], null, 2) : '{}')}
>{t('common.copy')}</Button>
</div>
<div className="rb:max-h-40 rb:overflow-auto">
<CodeBlock
size="small"
value={typeof vo.content === 'object' && vo.content?.[key] ? JSON.stringify(vo.content[key], null, 2) : '{}'}
needCopy={false}
showLineNumbers={true}
background="#EBEBEB"
/>
</div>
</div> </div>
) <div className="rb:max-h-40 rb:overflow-auto">
})} <CodeBlock
size="small"
value={typeof vo.content === 'object' && vo.content?.[key] ? JSON.stringify(vo.content[key], null, 2) : '{}'}
needCopy={false}
showLineNumbers={true}
background="#EBEBEB"
/>
</div>
</div>
))}
</Flex> </Flex>
) )
}]} }]}

View File

@@ -2,183 +2,29 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-09 18:31:30 * @Date: 2026-02-09 18:31:30
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-30 11:55:10 * @Last Modified time: 2026-04-28 10:24:58
*/ */
import { useState } from 'react'; import { Flex } from 'antd';
import { Popover, Flex } from 'antd';
import clsx from 'clsx'; import clsx from 'clsx';
import type { ReactShapeConfig } from '@antv/x6-react-shape'; import type { ReactShapeConfig } from '@antv/x6-react-shape';
import { nodeLibrary, graphNodeLibrary, edgeAttrs, nodeWidth } from '../../constant';
import { useTranslation } from 'react-i18next';
const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => { const AddNode: ReactShapeConfig['component'] = ({ node }) => {
const data = node?.getData() || {}; const data = node?.getData() || {};
const { t } = useTranslation();
const [open, setOpen] = useState(false);
// Handle node selection from popover and create new node replacing the add-node placeholder
const handleNodeSelect = (selectedNodeType: any) => {
graph.startBatch('add-node');
const parentBBox = node.getBBox();
const cycleId = data.cycle;
const horizontalSpacing = 0;
const id = `${selectedNodeType.type.replace(/-/g, '_') }_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
const newNode = graph.addNode({
...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default),
x: parentBBox.x + horizontalSpacing,
y: parentBBox.y - 12,
id,
data: {
id,
type: selectedNodeType.type,
icon: selectedNodeType.icon,
name: t(`workflow.${selectedNodeType.type}`),
cycle: cycleId,
parentId: data.parentId,
config: selectedNodeType.config || {}
},
});
// Add new node as child of parent node
if (cycleId) {
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
if (parentNode) {
parentNode.addChild(newNode, { silent: true });
}
}
const incomingEdges = graph.getIncomingEdges(node);
const outgoingEdges = graph.getOutgoingEdges(node);
const addedEdges: any[] = [];
incomingEdges?.forEach((edge: any) => {
addedEdges.push(graph.addEdge({
source: { cell: edge.getSourceCellId(), port: edge.getSourcePortId() },
target: { cell: newNode.id, port: newNode.getPorts().find((port: any) => port.group === 'left')?.id || 'left' },
...edgeAttrs
}));
});
outgoingEdges?.forEach((edge: any) => {
const targetCell = graph.getCellById(edge.getTargetCellId()) as any;
const targetPortId = targetCell?.getPorts?.()?.find((port: any) => port.group === 'left')?.id || edge.getTargetPortId();
addedEdges.push(graph.addEdge({
source: { cell: newNode.id, port: newNode.getPorts().find((port: any) => port.group === 'right')?.id || 'right' },
target: { cell: edge.getTargetCellId(), port: targetPortId },
...edgeAttrs
}));
});
// Remove all add-node type nodes
graph.getNodes().forEach((n: any) => {
if (n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId) {
n.remove();
}
});
// Automatically adjust loop node size
const loopNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
if (loopNode) {
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
if (childNodes.length > 0) {
const bounds = childNodes.reduce((acc, child) => {
const bbox = child.getBBox();
return {
minX: Math.min(acc.minX, bbox.x),
minY: Math.min(acc.minY, bbox.y),
maxX: Math.max(acc.maxX, bbox.x + bbox.width),
maxY: Math.max(acc.maxY, bbox.y + bbox.height)
};
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
const padding = 50;
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
loopNode.prop('size', { width: newWidth, height: newHeight });
loopNode.getPorts().forEach(port => {
if (port.group === 'right' && port.args) {
loopNode.portProp(port.id!, 'args/x', newWidth);
}
});
}
}
addedEdges.forEach(e => {
const src = graph.getCellById(e.getSourceCellId());
const tgt = graph.getCellById(e.getTargetCellId());
if (src?.isNode()) src.toFront();
if (tgt?.isNode()) tgt.toFront();
});
graph.stopBatch('add-node');
setOpen(false);
};
const content = (
<div style={{ maxHeight: '300px', overflowY: 'auto', minWidth: `${nodeWidth}px'` }}>
{nodeLibrary.map((category, categoryIndex) => {
const filteredNodes = category.nodes.filter(nodeType =>
nodeType.type !== 'start' && nodeType.type !== 'end' && nodeType.type !== 'iteration' && nodeType.type !== 'loop' && nodeType.type !== 'cycle-start'
);
if (filteredNodes.length === 0) return null;
return (
<div key={category.category}>
{categoryIndex > 0 && <div style={{ height: '1px', background: '#f0f0f0', margin: '4px 0' }} />}
<div style={{ padding: '4px 12px', fontSize: '12px', color: '#999', fontWeight: 'bold' }}>
{t(`workflow.${category.category}`)}
</div>
{filteredNodes.map((nodeType) => (
<div
key={nodeType.type}
style={{
padding: '8px 12px',
cursor: 'pointer',
display: 'flex',
alignItems: 'center',
gap: '8px',
}}
onClick={() => handleNodeSelect(nodeType)}
onMouseEnter={(e) => {
e.currentTarget.style.background = '#f0f8ff';
}}
onMouseLeave={(e) => {
e.currentTarget.style.background = 'white';
}}
>
<div className={`rb:size-4 rb:bg-cover ${nodeType.icon}`} />
<span style={{ fontSize: '14px' }}>{t(`workflow.${nodeType.type}`)}</span>
</div>
))}
</div>
);
})}
</div>
);
return ( return (
<Popover <Flex
content={content} align="center"
trigger="click" justify="center"
open={open} gap={4}
onOpenChange={setOpen} className={clsx('rb:text-[#212332] rb:font-medium rb:text-[12px] rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:border rb:rounded-lg rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)] rb:border-[#FCFCFD] rb:flex rb:items-center rb:justify-center', {
placement="bottomLeft" 'rb:border-orange-500 rb:border-[3px] rb:bg-[#FCFCFD] rb:text-[#475467]': data.isSelected,
'rb:border-[#d1d5db] rb:bg-[#FCFCFD] rb:text-[#374151]': !data.isSelected
})}
> >
<Flex <div className="rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/workflow/node_plus.png')]"></div>
align="center" {data.label}
justify="center" </Flex>
gap={4}
className={clsx('rb:text-[#212332] rb:font-medium rb:text-[12px] rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:border rb:rounded-lg rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)] rb:border-[#FCFCFD] rb:flex rb:items-center rb:justify-center', {
'rb:border-orange-500 rb:border-[3px] rb:bg-[#FCFCFD] rb:text-[#475467]': data.isSelected,
'rb:border-[#d1d5db] rb:bg-[#FCFCFD] rb:text-[#374151]': !data.isSelected
})}
>
<div className="rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/workflow/node_plus.png')]"></div>
{data.label}
</Flex>
</Popover>
); );
}; };
export default AddNode; export default AddNode;

View File

@@ -65,8 +65,8 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => {
return ( return (
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', { <div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
'rb:border-[#171719]!': data.isSelected && !data.executionStatus, 'rb:border-[#171719]!': data.isSelected,
'rb:border-[#FCFCFD]': !data.isSelected && !data.executionStatus, 'rb:border-[#FCFCFD]': !data.isSelected,
'rb:border-[#369F21]!': !data.isSelected && data.executionStatus === 'completed', 'rb:border-[#369F21]!': !data.isSelected && data.executionStatus === 'completed',
'rb:border-[#FF5D34]!': !data.isSelected && data.executionStatus === 'failed', 'rb:border-[#FF5D34]!': !data.isSelected && data.executionStatus === 'failed',
})}> })}>
@@ -99,7 +99,7 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => {
{data.type === 'if-else' && {data.type === 'if-else' &&
<Flex vertical gap={4} className="rb:mt-3!"> <Flex vertical gap={4} className="rb:mt-3!">
{data.config?.cases?.defaultValue.map((item: any, index: number) => ( {data.config?.cases?.defaultValue.map((item: any, index: number) => (
<div key={index}> <div key={index} className={item.expressions.length > 0 ? '' : 'rb:mb-1'}>
<Flex justify={item.expressions.length > 0 ? "space-between" : 'end'} className="rb:mb-1! rb:leading-4"> <Flex justify={item.expressions.length > 0 ? "space-between" : 'end'} className="rb:mb-1! rb:leading-4">
{item.expressions.length > 0 && <span className="rb:text-[#5B6167] rb:text-[10px] rb:pl-1">CASE{index + 1}</span>} {item.expressions.length > 0 && <span className="rb:text-[#5B6167] rb:text-[10px] rb:pl-1">CASE{index + 1}</span>}
<span className="rb:text-[#212332] rb:font-medium rb:text-[12px]">{index === 0 ? 'IF' : `ELIF`}</span> <span className="rb:text-[#212332] rb:font-medium rb:text-[12px]">{index === 0 ? 'IF' : `ELIF`}</span>

View File

@@ -1,19 +1,138 @@
import { useEffect } from 'react';
import { useTranslation } from 'react-i18next'
import clsx from 'clsx'; import clsx from 'clsx';
import type { ReactShapeConfig } from '@antv/x6-react-shape'; import type { ReactShapeConfig } from '@antv/x6-react-shape';
import { Flex } from 'antd'; import { Flex } from 'antd';
import { CheckCircleFilled, CloseCircleFilled, LoadingOutlined } from '@ant-design/icons'; import { CheckCircleFilled, CloseCircleFilled, LoadingOutlined } from '@ant-design/icons';
import { useTranslation } from 'react-i18next'
import { graphNodeLibrary, edgeAttrs } from '../../constant';
import NodeTools from './NodeTools' import NodeTools from './NodeTools'
const LoopNode: ReactShapeConfig['component'] = ({ node }) => { const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
const data = node.getData() || {}; const data = node.getData() || {};
const { t } = useTranslation() const { t } = useTranslation()
useEffect(() => {
// 使用setTimeout确保在所有节点都添加完成后再创建连线
const timer = setTimeout(() => {
initNodes()
checkAndAddAddNode()
}, 50)
return () => clearTimeout(timer)
}, [graph])
const checkAndAddAddNode = () => {
if (!graph) return;
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === data.id);
const cycleStartNodes = childNodes.filter((n: any) => n.getData()?.type === 'cycle-start');
// 如果只有一个cycle-start节点且没有其他类型的子节点则添加add-node
if (cycleStartNodes.length === 1 && childNodes.length === 1) {
const cycleStartNode = cycleStartNodes[0];
const cycleStartBBox = cycleStartNode.getBBox();
const addNode = graph.addNode({
...graphNodeLibrary.addStart,
x: cycleStartBBox.x + 84,
y: cycleStartBBox.y + 4,
data: {
type: 'add-node',
label: t('workflow.addNode'),
icon: '+',
parentId: node.id,
cycle: data.id,
},
});
node.addChild(addNode);
// 连接cycle-start和add-node
const sourcePorts = cycleStartNode.getPorts();
const targetPorts = addNode.getPorts();
const sourcePort = sourcePorts.find((port: any) => port.group === 'right')?.id || 'right';
const targetPort = targetPorts.find((port: any) => port.group === 'left')?.id || 'left';
// 然后创建连线
graph.addEdge({
source: { cell: cycleStartNode.id, port: sourcePort },
target: { cell: addNode.id, port: targetPort },
...edgeAttrs,
});
cycleStartNode.toFront()
addNode.toFront()
}
}
const initNodes = () => {
// 检查是否存在cycle为当前节点ID的子节点若存在则不调用initNodes避免重复创建
const existingCycleNodes = graph.getNodes().filter((n: any) =>
n.getData()?.cycle === data.id
);
if (existingCycleNodes.length > 0) return;
// 添加默认子节点
const parentBBox = node.getBBox();
const centerX = parentBBox.x + 24;
const centerY = parentBBox.y + 70;
const cycleStartNodeId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
const cycleStartNode = graph.addNode({
...graphNodeLibrary.cycleStart,
x: centerX,
y: centerY,
id: cycleStartNodeId,
data: {
id: cycleStartNodeId,
type: 'cycle-start',
parentId: node.id,
isDefault: true, // 标记为默认节点,不可删除
cycle: data.id,
},
});
const addNode = graph.addNode({
...graphNodeLibrary.addStart,
x: centerX + 84,
y: centerY + 4,
data: {
type: 'add-node',
label: t('workflow.addNode'),
icon: '+',
parentId: node.id,
cycle: data.id,
},
});
node.addChild(cycleStartNode)
node.addChild(addNode)
const sourcePorts = cycleStartNode.getPorts()
const targetPorts = addNode.getPorts()
let sourcePort = sourcePorts.find((port: any) => port.group === 'right')?.id || 'right';
const edgeConfig = {
source: {
cell: cycleStartNode.id,
port: sourcePort
},
target: {
cell: addNode.id,
port: targetPorts.find((port: any) => port.group === 'left')?.id || 'left'
},
...edgeAttrs
}
graph.addEdge(edgeConfig)
setTimeout(() => {
cycleStartNode.toFront()
addNode.toFront()
}, 0)
}
return ( return (
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', { <div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
'rb:border-[#171719]!': data.isSelected && !data.executionStatus, 'rb:border-[#171719]': data.isSelected,
'rb:border-[#FCFCFD]': !data.isSelected && !data.executionStatus, 'rb:border-[#FCFCFD]': !data.isSelected,
'rb:border-[#369F21]!': !data.isSelected && data.executionStatus === 'completed', 'rb:border-[#369F21]!': !data.isSelected && data.executionStatus === 'completed',
'rb:border-[#FF5D34]!': !data.isSelected && data.executionStatus === 'failed', 'rb:border-[#FF5D34]!': !data.isSelected && data.executionStatus === 'failed',
})}> })}>

View File

@@ -12,8 +12,8 @@ const NormalNode: ReactShapeConfig['component'] = ({ node }) => {
return ( return (
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', { <div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
'rb:border-[#171719]!': data.isSelected && !data.executionStatus, 'rb:border-[#171719]!': data.isSelected,
'rb:border-[#FCFCFD]': !data.isSelected && !data.executionStatus, 'rb:border-[#FCFCFD]': !data.isSelected,
'rb:border-[#369F21]!': !data.isSelected && data.executionStatus === 'completed', 'rb:border-[#369F21]!': !data.isSelected && data.executionStatus === 'completed',
'rb:border-[#FF5D34]!': !data.isSelected && data.executionStatus === 'failed', 'rb:border-[#FF5D34]!': !data.isSelected && data.executionStatus === 'failed',
})}> })}>

View File

@@ -2,13 +2,44 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-09 18:30:28 * @Date: 2026-02-09 18:30:28
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-30 15:14:02 * @Last Modified time: 2026-04-28 11:41:17
*/ */
import { useEffect, useState } from 'react'; import { useEffect, useState } from 'react';
import { createPortal } from 'react-dom';
import { Flex, Popover } from 'antd'; import { Flex, Popover } from 'antd';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { nodeLibrary, graphNodeLibrary, edgeAttrs, nodeWidth } from '../constant'; import { nodeLibrary, graphNodeLibrary, edgeAttrs, nodeWidth } from '../constant';
// Shared helper: adjust loop/iteration container size to fit child nodes
export const adjustCycleContainerSize = (graph: any, cycleId: string) => {
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
if (!parentNode) return;
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
if (childNodes.length === 0) return;
const bounds = childNodes.reduce((acc: any, child: any) => {
const bbox = child.getBBox();
return {
minX: Math.min(acc.minX, bbox.x),
minY: Math.min(acc.minY, bbox.y),
maxX: Math.max(acc.maxX, bbox.x + bbox.width),
maxY: Math.max(acc.maxY, bbox.y + bbox.height),
};
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
const padding = 50;
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
parentNode.prop('size', { width: newWidth, height: newHeight });
parentNode.getPorts().forEach((port: any) => {
if (port.group === 'right' && port.args) {
parentNode.portProp(port.id!, 'args/x', newWidth);
}
});
childNodes.forEach((childNode: any) => {
childNode.off('change:position');
childNode.on('change:position', () => adjustCycleContainerSize(graph, cycleId));
});
};
interface PortClickHandlerProps { interface PortClickHandlerProps {
graph: any; graph: any;
} }
@@ -16,7 +47,6 @@ interface PortClickHandlerProps {
const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => { const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
const { t } = useTranslation(); const { t } = useTranslation();
const [popoverVisible, setPopoverVisible] = useState(false); const [popoverVisible, setPopoverVisible] = useState(false);
const [popoverPosition, setPopoverPosition] = useState({ x: 0, y: 0 });
const [sourceNode, setSourceNode] = useState<any>(null); const [sourceNode, setSourceNode] = useState<any>(null);
const [sourcePort, setSourcePort] = useState<string>(''); const [sourcePort, setSourcePort] = useState<string>('');
const [tempElement, setTempElement] = useState<HTMLElement | null>(null); const [tempElement, setTempElement] = useState<HTMLElement | null>(null);
@@ -24,12 +54,11 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
useEffect(() => { useEffect(() => {
const handlePortClick = (event: CustomEvent) => { const handlePortClick = (event: CustomEvent) => {
const { node, port, element, rect, edgeInsertion } = event.detail; const { node, port, element, edgeInsertion } = event.detail;
setSourceNode(node); setSourceNode(node);
setSourcePort(port); setSourcePort(port);
setTempElement(element); setTempElement(element);
setEdgeInsertion(edgeInsertion || null); setEdgeInsertion(edgeInsertion || null);
setPopoverPosition({ x: rect.left, y: rect.top });
setPopoverVisible(true); setPopoverVisible(true);
}; };
@@ -43,52 +72,130 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
}; };
}, []); }, []);
// Handle node selection from popover menu and create new node with edge connection
const handleNodeSelect = (selectedNodeType: any) => { const handleNodeSelect = (selectedNodeType: any) => {
if (!sourceNode || !graph) return; if (!sourceNode || !graph) return;
const sourceNodeData = sourceNode.getData(); const sourceNodeData = sourceNode.getData();
const sourceNodeType = sourceNodeData?.type; const sourceNodeType = sourceNodeData?.type;
const isCycleSubNode = !!sourceNodeData.cycle;
const isCycleContainer = (type: string) => type === 'loop' || type === 'iteration';
const newNodeType = selectedNodeType.type;
// Save add-node placeholder position before disabling history // AddNode placeholder mode: replace the add-node placeholder with the selected node
if (sourceNodeType === 'add-node') {
const placeholderBBox = sourceNode.getBBox();
const cycleId = sourceNodeData.cycle;
const id = `${selectedNodeType.type.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
const newNode = graph.addNode({
...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default),
x: placeholderBBox.x,
y: placeholderBBox.y - 12,
id,
data: {
id,
type: selectedNodeType.type,
icon: selectedNodeType.icon,
name: t(`workflow.${selectedNodeType.type}`),
cycle: cycleId,
parentId: sourceNodeData.parentId,
config: selectedNodeType.config || {},
},
});
if (cycleId) {
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
if (parentNode) parentNode.addChild(newNode);
}
const incomingEdges = graph.getIncomingEdges(sourceNode);
const outgoingEdges = graph.getOutgoingEdges(sourceNode);
const addedEdges: any[] = [];
incomingEdges?.forEach((edge: any) => {
addedEdges.push(graph.addEdge({
source: { cell: edge.getSourceCellId(), port: edge.getSourcePortId() },
target: { cell: newNode.id, port: newNode.getPorts().find((p: any) => p.group === 'left')?.id || 'left' },
...edgeAttrs,
}));
});
outgoingEdges?.forEach((edge: any) => {
const targetCell = graph.getCellById(edge.getTargetCellId()) as any;
const targetPortId = targetCell?.getPorts?.()?.find((p: any) => p.group === 'left')?.id || edge.getTargetPortId();
addedEdges.push(graph.addEdge({
source: { cell: newNode.id, port: newNode.getPorts().find((p: any) => p.group === 'right')?.id || 'right' },
target: { cell: edge.getTargetCellId(), port: targetPortId },
...edgeAttrs,
}));
});
graph.getNodes().forEach((n: any) => {
if (n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId) n.remove();
});
setTimeout(() => {
addedEdges.forEach(e => {
const src = graph.getCellById(e.getSourceCellId());
const tgt = graph.getCellById(e.getTargetCellId());
if (src?.isNode()) src.toFront();
if (tgt?.isNode()) tgt.toFront();
});
}, 50);
if (cycleId) adjustCycleContainerSize(graph, cycleId);
if (tempElement) { document.body.removeChild(tempElement); setTempElement(null); }
setPopoverVisible(false);
return;
}
// If it's a cycle-start node, handle the add-node placeholder
let addNodePosition = null; let addNodePosition = null;
const isCycleSubNode = sourceNodeData.cycle
if (isCycleSubNode && sourceNodeType === 'cycle-start') { if (isCycleSubNode && sourceNodeType === 'cycle-start') {
const cycleId = sourceNodeData.cycle; const cycleId = sourceNodeData.cycle;
const addNodes = graph.getNodes().filter((n: any) => const addNodes = graph.getNodes().filter((n: any) =>
n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId
); );
if (addNodes.length > 0) addNodePosition = addNodes[0].getBBox();
if (addNodes.length > 0) {
const addNode = addNodes[0];
addNodePosition = addNode.getBBox();
addNode.remove();
}
} }
// Calculate position // Calculate new node position to avoid overlapping
const sourceBBox = sourceNode.getBBox(); const sourceBBox = sourceNode.getBBox();
const nw = graphNodeLibrary[newNodeType]?.width || 120; const nodeWidth = graphNodeLibrary[selectedNodeType.type]?.width || 120;
const nh = graphNodeLibrary[newNodeType]?.height || 88; const nodeHeight = graphNodeLibrary[selectedNodeType.type]?.height || 88;
const hSpacing = isCycleSubNode ? 48 : 80; const horizontalSpacing = isCycleSubNode ? 48 : 80;
const vSpacing = 10; const verticalSpacing = 10;
// Get source port group information
const sourcePortInfo = sourceNode.getPorts().find((p: any) => p.id === sourcePort); const sourcePortInfo = sourceNode.getPorts().find((p: any) => p.id === sourcePort);
const sourcePortGroup = sourcePortInfo?.group || sourcePort; const sourcePortGroup = sourcePortInfo?.group || sourcePort;
let newX: number, newY: number; // Calculate new node position
let newX, newY;
if (edgeInsertion) { if (edgeInsertion) {
// Edge insertion: place new node on the same row as target, between source and target
const targetBBox = edgeInsertion.targetCell.getBBox(); const targetBBox = edgeInsertion.targetCell.getBBox();
const gap = targetBBox.x - (sourceBBox.x + sourceBBox.width); const gap = targetBBox.x - (sourceBBox.x + sourceBBox.width);
const requiredSpace = nw + hSpacing * 4; const requiredSpace = nodeWidth + horizontalSpacing * 4;
newX = sourceBBox.x + sourceBBox.width + hSpacing;
newY = targetBBox.y + (targetBBox.height - nh) / 2; // New node x: right after source + spacing
newX = sourceBBox.x + sourceBBox.width + horizontalSpacing;
// Same row as target node
newY = targetBBox.y + (targetBBox.height - nodeHeight) / 2;
// If not enough space, shift target and all downstream nodes to the right
if (gap < requiredSpace) { if (gap < requiredSpace) {
const shiftX = requiredSpace - gap; const shiftX = requiredSpace - gap;
const visited = new Set<string>(); const visited = new Set<string>();
const shiftDownstream = (cell: any) => { const shiftDownstream = (cell: any) => {
if (visited.has(cell.id)) return; const cellId = cell.id;
visited.add(cell.id); if (visited.has(cellId)) return;
visited.add(cellId);
const pos = cell.getPosition(); const pos = cell.getPosition();
cell.setPosition(pos.x + shiftX, pos.y); cell.setPosition(pos.x + shiftX, pos.y);
// Recursively shift nodes connected from right ports
graph.getConnectedEdges(cell, { outgoing: true }).forEach((e: any) => { graph.getConnectedEdges(cell, { outgoing: true }).forEach((e: any) => {
const tCell = graph.getCellById(e.getTargetCellId()); const tId = e.getTargetCellId();
if (tCell?.isNode()) shiftDownstream(tCell); if (tId && !visited.has(tId)) {
const tCell = graph.getCellById(tId);
if (tCell?.isNode()) shiftDownstream(tCell);
}
}); });
}; };
shiftDownstream(edgeInsertion.targetCell); shiftDownstream(edgeInsertion.targetCell);
@@ -96,170 +203,167 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
} else if (addNodePosition) { } else if (addNodePosition) {
newX = addNodePosition.x; newX = addNodePosition.x;
newY = addNodePosition.y; newY = addNodePosition.y;
} else if (sourcePortGroup === 'left') {
newX = sourceBBox.x - nw * 2 - hSpacing;
newY = sourceBBox.y;
} else { } else {
newX = sourceBBox.x + sourceBBox.width + hSpacing; // Determine node placement direction based on port position
newY = sourceBBox.y; if (sourcePortGroup === 'left') {
const connectedNodes = new Set<string>(); // Left port: add node to the left
graph.getConnectedEdges(sourceNode).forEach((e: any) => { newX = sourceBBox.x - nodeWidth*2 - horizontalSpacing;
[e.getSourceCellId(), e.getTargetCellId()].forEach((cid: string) => { newY = sourceBBox.y;
if (cid !== sourceNode.id) connectedNodes.add(cid); } else {
// Right port: add node to the right
newX = sourceBBox.x + sourceBBox.width + horizontalSpacing;
newY = sourceBBox.y;
}
// Check if position overlaps with existing nodes (only consider connected nodes)
const checkOverlap = (x: number, y: number) => {
// Get nodes connected to the source node
const connectedNodes = new Set();
graph.getConnectedEdges(sourceNode).forEach((edge: any) => {
const sourceId = edge.getSourceCellId();
const targetId = edge.getTargetCellId();
if (sourceId !== sourceNode.id) connectedNodes.add(sourceId);
if (targetId !== sourceNode.id) connectedNodes.add(targetId);
}); });
});
const checkOverlap = (x: number, y: number) => return graph.getNodes().some((node: any) => {
graph.getNodes().some((n: any) => { if (node.id === sourceNode.id) return false;
if (n.id === sourceNode.id || !connectedNodes.has(n.id)) return false; if (!connectedNodes.has(node.id)) return false; // Only consider connected nodes
const b = n.getBBox(); const bbox = node.getBBox();
return !(x + nw < b.x || x > b.x + b.width || y + nh < b.y || y > b.y + b.height); return !(x + nodeWidth < bbox.x || x > bbox.x + bbox.width ||
y + nodeHeight < bbox.y || y > bbox.y + bbox.height);
}); });
while (checkOverlap(newX, newY)) newY += nh + vSpacing; };
// If position is occupied, search downward for empty space
while (checkOverlap(newX, newY)) {
newY += nodeHeight + verticalSpacing;
}
} }
// Disable history for all graph mutations // Create new node
graph.disableHistory(); const id = `${selectedNodeType.type.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
// Remove add-node placeholder
if (isCycleSubNode && sourceNodeType === 'cycle-start') {
const cycleId = sourceNodeData.cycle;
graph.getNodes()
.filter((n: any) => n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId)
.forEach((n: any) => n.remove());
}
const id = `${newNodeType.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
const newNode = graph.addNode({ const newNode = graph.addNode({
...(graphNodeLibrary[newNodeType] || graphNodeLibrary.default), ...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default),
x: newX, x: newX,
y: newY - (isCycleSubNode && sourceNodeType === 'cycle-start' ? 12 : 0), y: newY - (isCycleSubNode && sourceNodeType === 'cycle-start' ? 12 : 0),
id, id,
data: { data: {
id, id,
type: newNodeType, type: selectedNodeType.type,
icon: selectedNodeType.icon, icon: selectedNodeType.icon,
name: t(`workflow.${newNodeType}`), name: t(`workflow.${selectedNodeType.type}`),
cycle: sourceNodeData.cycle, cycle: sourceNodeData.cycle, // Inherit cycle from source node
config: selectedNodeType.config || {} config: selectedNodeType.config || {}
}, },
}); });
// Add new node as child of parent node
if (sourceNodeData.cycle) { if (sourceNodeData.cycle) {
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === sourceNodeData.cycle); const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === sourceNodeData.cycle);
if (parentNode) parentNode.addChild(newNode, { silent: true });
}
if (edgeInsertion) {
const { edge: oldEdge } = edgeInsertion;
if (oldEdge.id && graph.getCellById(oldEdge.id)) graph.removeCell(oldEdge.id);
else graph.removeEdge(oldEdge);
}
const newPorts = newNode.getPorts();
const addedCells: any[] = [newNode];
if (edgeInsertion) {
const { targetCell, targetPort: origTargetPort } = edgeInsertion;
const newLeftPort = newPorts.find((p: any) => p.group === 'left')?.id || 'left';
const newRightPort = newPorts.find((p: any) => p.group === 'right')?.id || 'right';
addedCells.push(graph.addEdge({ source: { cell: sourceNode.id, port: sourcePort }, target: { cell: newNode.id, port: newLeftPort }, ...edgeAttrs }));
addedCells.push(graph.addEdge({ source: { cell: newNode.id, port: newRightPort }, target: { cell: targetCell.id, port: origTargetPort }, ...edgeAttrs }));
setEdgeInsertion(null);
} else if (sourcePortGroup === 'left') {
const tp = newPorts.find((p: any) => p.group === 'right')?.id || 'right';
addedCells.push(graph.addEdge({ source: { cell: newNode.id, port: tp }, target: { cell: sourceNode.id, port: sourcePort }, ...edgeAttrs }));
} else {
const tp = newPorts.find((p: any) => p.group === 'left')?.id || 'left';
addedCells.push(graph.addEdge({ source: { cell: sourceNode.id, port: sourcePort }, target: { cell: newNode.id, port: tp }, ...edgeAttrs }));
}
// If adding a loop/iteration node, create cycle-start, add-node and inner edge regardless of source type
if (isCycleContainer(newNodeType)) {
const parentBBox = newNode.getBBox();
const cycleStartId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
const cycleStartNode = graph.addNode({
...graphNodeLibrary.cycleStart,
x: parentBBox.x + 24,
y: parentBBox.y + 70,
id: cycleStartId,
data: { id: cycleStartId, type: 'cycle-start', parentId: id, isDefault: true, cycle: id },
});
const addNodePlaceholder = graph.addNode({
...graphNodeLibrary.addStart,
x: parentBBox.x + 24 + 84,
y: parentBBox.y + 70 + 4,
data: { type: 'add-node', label: t('workflow.addNode'), icon: '+', parentId: id, cycle: id },
});
newNode.addChild(cycleStartNode, { silent: true });
newNode.addChild(addNodePlaceholder, { silent: true });
const innerEdge = graph.addEdge({
source: { cell: cycleStartNode.id, port: cycleStartNode.getPorts().find((p: any) => p.group === 'right')?.id || 'right' },
target: { cell: addNodePlaceholder.id, port: addNodePlaceholder.getPorts().find((p: any) => p.group === 'left')?.id || 'left' },
...edgeAttrs,
});
addedCells.push(cycleStartNode, addNodePlaceholder, innerEdge);
}
// Adjust parent size if adding inside a cycle container
const cycleId = sourceNodeData.cycle;
if (cycleId) {
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
if (parentNode) { if (parentNode) {
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId); parentNode.addChild(newNode);
if (childNodes.length > 0) {
const bounds = childNodes.reduce((acc: any, child: any) => {
const b = child.getBBox();
return { minX: Math.min(acc.minX, b.x), minY: Math.min(acc.minY, b.y), maxX: Math.max(acc.maxX, b.x + b.width), maxY: Math.max(acc.maxY, b.y + b.height) };
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
const padding = 50;
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
parentNode.prop('size', { width: newWidth, height: newHeight });
parentNode.getPorts().forEach((port: any) => {
if (port.group === 'right' && port.args) parentNode.portProp(port.id!, 'args/x', newWidth);
});
}
} }
} }
// toFront // Edge insertion: remove old edge immediately before creating new edges
const bringCycleChildrenToFront = (cycleContainerId: string) => { if (edgeInsertion) {
graph.getEdges().forEach((e: any) => { const { edge: oldEdge } = edgeInsertion;
const src = graph.getCellById(e.getSourceCellId()); if (oldEdge.id && graph.getCellById(oldEdge.id)) {
const tgt = graph.getCellById(e.getTargetCellId()); graph.removeCell(oldEdge.id);
if (src?.getData()?.cycle === cycleContainerId || tgt?.getData()?.cycle === cycleContainerId) e.toFront(); } else {
}); graph.removeEdge(oldEdge);
graph.getNodes().forEach((n: any) => { if (n.getData()?.cycle === cycleContainerId) n.toFront(); }); }
};
if (isCycleContainer(sourceNodeType)) {
newNode.toFront(); sourceNode.toFront(); bringCycleChildrenToFront(sourceNodeData.id);
if (isCycleContainer(newNodeType)) bringCycleChildrenToFront(id);
} else if (isCycleContainer(newNodeType)) {
newNode.toFront(); sourceNode.toFront(); bringCycleChildrenToFront(id);
} else {
addedCells.forEach(c => { if (c.isNode?.()) c.toFront(); });
} }
// Re-enable history and manually push one batch frame for all added cells // Create edge connection
graph.enableHistory(); setTimeout(() => {
const history = graph.getPlugin('history') as any; const newPorts = newNode.getPorts();
if (history) {
const batchFrame = addedCells.map((cell: any) => ({
batch: true,
event: 'cell:added',
data: { id: cell.id, node: cell.isNode(), edge: cell.isEdge(), props: cell.toJSON() },
options: {},
}));
history.undoStack.push(batchFrame);
history.redoStack = [];
graph.trigger('history:change', { cmds: batchFrame, options: { name: 'add-node' } });
}
const addedEdges: any[] = [];
if (edgeInsertion) {
// Edge insertion: create source→new and new→target edges
const { targetCell, targetPort: origTargetPort } = edgeInsertion;
const newLeftPort = newPorts.find((p: any) => p.group === 'left')?.id || 'left';
const newRightPort = newPorts.find((p: any) => p.group === 'right')?.id || 'right';
addedEdges.push(graph.addEdge({
source: { cell: sourceNode.id, port: sourcePort },
target: { cell: newNode.id, port: newLeftPort },
...edgeAttrs
}));
addedEdges.push(graph.addEdge({
source: { cell: newNode.id, port: newRightPort },
target: { cell: targetCell.id, port: origTargetPort },
...edgeAttrs
}));
setEdgeInsertion(null);
} else if (sourcePortGroup === 'left') {
// Connect from left port to new node's right side
const targetPort = newPorts.find((port: any) => port.group === 'right')?.id || 'right';
addedEdges.push(graph.addEdge({
source: { cell: newNode.id, port: targetPort },
target: { cell: sourceNode.id, port: sourcePort },
...edgeAttrs
}));
} else {
// Connect from right port to new node's left side
const targetPort = newPorts.find((port: any) => port.group === 'left')?.id || 'left';
addedEdges.push(graph.addEdge({
source: { cell: sourceNode.id, port: sourcePort },
target: { cell: newNode.id, port: targetPort },
...edgeAttrs
}));
}
// Adjust loop node size when child node is added via port within loop node
const cycleId = sourceNodeData.cycle;
if (cycleId) adjustCycleContainerSize(graph, cycleId);
const isCycleContainer = (type: string) => type === 'loop' || type === 'iteration';
const newNodeType = selectedNodeType.type;
// Helper: bring all child nodes and their edges of a cycle container to front
const bringCycleChildrenToFront = (cycleContainerId: string) => {
graph.getEdges().forEach((e: any) => {
const src = graph.getCellById(e.getSourceCellId());
const tgt = graph.getCellById(e.getTargetCellId());
if (src?.getData()?.cycle === cycleContainerId || tgt?.getData()?.cycle === cycleContainerId) e.toFront();
});
graph.getNodes().forEach((n: any) => {
if (n.getData()?.cycle === cycleContainerId) n.toFront();
});
};
if (isCycleContainer(sourceNodeType)) {
console.log('isCycleContainer(sourceNodeType)')
// Case 4: source is a loop/iteration node — bring new node to front, then its children
newNode.toFront();
sourceNode.toFront();
bringCycleChildrenToFront(sourceNodeData.id);
} else if (isCycleContainer(newNodeType)) {
console.log('isCycleContainer(newNodeType)')
// Case 3: adding a loop/iteration node from a normal node — bring new node to front, then its children
newNode.toFront();
sourceNode.toFront()
bringCycleChildrenToFront(id);
} else {
// Case 2: normal node → normal node
addedEdges.forEach(e => {
const src = graph.getCellById(e.getSourceCellId());
const tgt = graph.getCellById(e.getTargetCellId());
if (src?.isNode()) src.toFront();
if (tgt?.isNode()) tgt.toFront();
});
}
}, 50);
// Clean up temporary element
if (tempElement) { if (tempElement) {
document.body.removeChild(tempElement); document.body.removeChild(tempElement);
setTempElement(null); setTempElement(null);
} }
setPopoverVisible(false); setPopoverVisible(false);
}; };
@@ -316,23 +420,19 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
if (!tempElement) return null; if (!tempElement) return null;
return ( return createPortal(
<Popover <Popover
content={content} content={content}
open={popoverVisible} open={popoverVisible}
onOpenChange={(visible) => { onOpenChange={(visible) => { if (!visible) handlePopoverClose(); }}
if (!visible) handlePopoverClose();
}}
placement="right" placement="right"
overlayStyle={{ autoAdjustOverflow
position: 'fixed', getPopupContainer={() => document.body}
left: popoverPosition.x + 10,
top: popoverPosition.y - 10,
}}
> >
<div /> <div style={{ width: '1px', height: '1px' }} />
</Popover> </Popover>,
tempElement
); );
}; };
export default PortClickHandler; export default PortClickHandler;

View File

@@ -242,11 +242,10 @@ const ToolConfig: FC<{ options: Suggestion[]; }> = ({
className={parameter.type === 'boolean' ? 'rb:mb-0!' : ''} className={parameter.type === 'boolean' ? 'rb:mb-0!' : ''}
> >
{parameter.type === 'string' && parameter.enum && parameter.enum.length > 0 {parameter.type === 'string' && parameter.enum && parameter.enum.length > 0
? <Select key={values.tool_id} size="small" options={parameter.enum.map(vo => ({ value: vo, label: vo }))} placeholder={t('common.pleaseSelect')} /> ? <Select size="small" options={parameter.enum.map(vo => ({ value: vo, label: vo }))} placeholder={t('common.pleaseSelect')} />
: parameter.type === 'boolean' : parameter.type === 'boolean'
? <Switch key={values.tool_id} size="small" /> ? <Switch size="small" />
: <Editor : <Editor
key={values.tool_id}
variant="outlined" variant="outlined"
type="input" type="input"
size="small" size="small"

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 15:06:18 * @Date: 2026-02-03 15:06:18
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-27 14:07:14 * @Last Modified time: 2026-04-21 18:23:31
*/ */
import type { ReactShapeConfig } from '@antv/x6-react-shape'; import type { ReactShapeConfig } from '@antv/x6-react-shape';
import type { GroupMetadata, PortMetadata } from '@antv/x6/lib/model/port'; import type { GroupMetadata, PortMetadata } from '@antv/x6/lib/model/port';
@@ -948,15 +948,6 @@ export const graphNodeLibrary: Record<string, NodeConfig> = {
width: nodeWidth, width: nodeWidth,
height: 120, height: 120,
shape: 'notes-node', shape: 'notes-node',
},
output: {
width: nodeWidth,
height: 76,
shape: 'normal-node',
ports: {
groups: { left: defaultPortGroup },
items: [defaultPortItems[0]],
},
} }
} }

View File

@@ -2,13 +2,16 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 15:17:48 * @Date: 2026-02-03 15:17:48
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-28 13:49:11 * @Last Modified time: 2026-04-28 12:07:33
*/ */
import { Clipboard, Graph, Keyboard, MiniMap, Node, Snapline, History, type Edge } from '@antv/x6'; import { Clipboard, Graph, Keyboard, MiniMap, Node, Snapline, History, type Edge } from '@antv/x6';
import { register } from '@antv/x6-react-shape'; import type { HistoryCommand as Command } from '@antv/x6/lib/plugin/history/type';
import { register as registerReactShape } from '@antv/x6-react-shape';
import type { PortMetadata } from '@antv/x6/lib/model/port'; import type { PortMetadata } from '@antv/x6/lib/model/port';
import { App } from 'antd'; import { App } from 'antd';
import { useEffect, useRef, useState } from 'react'; import { useEffect, useRef, useState, createElement } from 'react';
import type { RefObject, Dispatch, SetStateAction, MutableRefObject, DragEvent } from 'react';
import { createRoot } from 'react-dom/client';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useParams } from 'react-router-dom'; import { useParams } from 'react-router-dom';
@@ -16,18 +19,20 @@ import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application';
import { useUser } from '@/store/user'; import { useUser } from '@/store/user';
import type { FeaturesConfigForm } from '@/views/ApplicationConfig/types'; import type { FeaturesConfigForm } from '@/views/ApplicationConfig/types';
import { conditionNodeHeight, conditionNodeItemHeight, conditionNodePortItemArgsY, defaultAbsolutePortGroups, defaultPortItems, edgeAttrs, edgeHoverTool, edge_color, edge_selected_color, edge_width, graphNodeLibrary, nodeLibrary, nodeRegisterLibrary, nodeWidth, notesConfig, portAttrs, portItemArgsY, portMarkup, portTextAttrs, unknownNode } from '../constant'; import { conditionNodeHeight, conditionNodeItemHeight, conditionNodePortItemArgsY, defaultAbsolutePortGroups, defaultPortItems, edgeAttrs, edgeHoverTool, edge_color, edge_selected_color, edge_width, graphNodeLibrary, nodeLibrary, nodeRegisterLibrary, nodeWidth, notesConfig, portAttrs, portItemArgsY, portMarkup, portTextAttrs, unknownNode } from '../constant';
import type { ChatVariable, HistoryRecord, NodeProperties, WorkflowConfig } from '../types'; import type { ChatVariable, NodeProperties, WorkflowConfig } from '../types';
import { calcConditionNodeTotalHeight, getConditionNodeCasePortY } from '../utils'; import { calcConditionNodeTotalHeight, getConditionNodeCasePortY } from '../utils';
import { useWorkflowStore } from '@/store/workflow'; import { useWorkflowStore } from '@/store/workflow';
const isSafari = /^((?!chrome|android).)*safari/i.test(navigator.userAgent);
/** /**
* Props for useWorkflowGraph hook * Props for useWorkflowGraph hook
*/ */
export interface UseWorkflowGraphProps { export interface UseWorkflowGraphProps {
/** Reference to the main graph container element */ /** Reference to the main graph container element */
containerRef: React.RefObject<HTMLDivElement>; containerRef: RefObject<HTMLDivElement>;
/** Reference to the minimap container element */ /** Reference to the minimap container element */
miniMapRef: React.RefObject<HTMLDivElement>; miniMapRef: RefObject<HTMLDivElement>;
/** Callback when features config is loaded */ /** Callback when features config is loaded */
onFeaturesLoad?: (features: FeaturesConfigForm | undefined) => void; onFeaturesLoad?: (features: FeaturesConfigForm | undefined) => void;
} }
@@ -39,23 +44,23 @@ export interface UseWorkflowGraphReturn {
/** Current workflow configuration */ /** Current workflow configuration */
config: WorkflowConfig | null; config: WorkflowConfig | null;
/** Function to update workflow configuration */ /** Function to update workflow configuration */
setConfig: React.Dispatch<React.SetStateAction<WorkflowConfig | null>>; setConfig: Dispatch<SetStateAction<WorkflowConfig | null>>;
/** Reference to the X6 graph instance */ /** Reference to the X6 graph instance */
graphRef: React.MutableRefObject<Graph | undefined>; graphRef: MutableRefObject<Graph | undefined>;
/** Currently selected node */ /** Currently selected node */
selectedNode: Node | null; selectedNode: Node | null;
/** Function to update selected node */ /** Function to update selected node */
setSelectedNode: React.Dispatch<React.SetStateAction<Node | null>>; setSelectedNode: Dispatch<SetStateAction<Node | null>>;
/** Current zoom level of the graph */ /** Current zoom level of the graph */
zoomLevel: number; zoomLevel: number;
/** Function to update zoom level */ /** Function to update zoom level */
setZoomLevel: React.Dispatch<React.SetStateAction<number>>; setZoomLevel: Dispatch<SetStateAction<number>>;
/** Whether hand/pan mode is enabled */ /** Whether hand/pan mode is enabled */
isHandMode: boolean; isHandMode: boolean;
/** Function to toggle hand mode */ /** Function to toggle hand mode */
setIsHandMode: React.Dispatch<React.SetStateAction<boolean>>; setIsHandMode: Dispatch<SetStateAction<boolean>>;
/** Handler for dropping nodes onto canvas */ /** Handler for dropping nodes onto canvas */
onDrop: (event: React.DragEvent) => void; onDrop: (event: DragEvent) => void;
/** Handler for clicking blank canvas area */ /** Handler for clicking blank canvas area */
blankClick: () => void; blankClick: () => void;
/** Handler for delete keyboard event */ /** Handler for delete keyboard event */
@@ -77,7 +82,7 @@ export interface UseWorkflowGraphReturn {
/** Chat variables for workflow */ /** Chat variables for workflow */
chatVariables: ChatVariable[]; chatVariables: ChatVariable[];
/** Function to update chat variables */ /** Function to update chat variables */
setChatVariables: React.Dispatch<React.SetStateAction<ChatVariable[]>>; setChatVariables: Dispatch<SetStateAction<ChatVariable[]>>;
handleAddNotes: () => void; handleAddNotes: () => void;
handleSaveFeaturesConfig: (value: FeaturesConfigForm) => void; handleSaveFeaturesConfig: (value: FeaturesConfigForm) => void;
@@ -85,10 +90,6 @@ export interface UseWorkflowGraphReturn {
/** Get start node output variable list (user-defined + system variables) */ /** Get start node output variable list (user-defined + system variables) */
getStartNodeVariables: () => Array<{ name: string; type: string; readonly?: boolean }>; getStartNodeVariables: () => Array<{ name: string; type: string; readonly?: boolean }>;
nodeClick: ({ node }: { node: Node }) => void; nodeClick: ({ node }: { node: Node }) => void;
/** All recorded history operations */
historyRecords: HistoryRecord[];
/** Clear history records */
clearHistoryRecords: () => void;
} }
/** /**
@@ -122,19 +123,14 @@ export const useWorkflowGraph = ({
const featuresRef = useRef<FeaturesConfigForm | undefined>(undefined) const featuresRef = useRef<FeaturesConfigForm | undefined>(undefined)
const [canUndo, setCanUndo] = useState(false) const [canUndo, setCanUndo] = useState(false)
const [canRedo, setCanRedo] = useState(false) const [canRedo, setCanRedo] = useState(false)
const [historyRecords, setHistoryRecords] = useState<HistoryRecord[]>([])
const lastHistoryRef = useRef<{ cellIds: string[]; timestamp: number; type: string } | null>(null)
const undoRef = useRef<() => void>(() => {})
const redoRef = useRef<() => void>(() => {})
const syncChildRelationshipsRef = useRef<() => void>(() => {})
const isSyncingRef = useRef(false)
useEffect(() => { useEffect(() => {
if (!graphRef.current) return if (!graphRef.current) return
graphRef.current.getNodes().forEach(node => { graphRef.current.getNodes().forEach(node => {
const data = node.getData() const data = node.getData()
if (data?.type === 'if-else' || data?.type === 'question-classifier') { if (data?.type === 'if-else' || data?.type === 'question-classifier') {
console.log('chatVariables', chatVariables) console.log('chatVariables', chatVariables)
node.setData({ ...data, chatVariables }) node.setData({ ...data, chatVariables }, { silent: true })
} }
}) })
}, [chatVariables]) }, [chatVariables])
@@ -168,6 +164,21 @@ export const useWorkflowGraph = ({
initWorkflow() initWorkflow()
}, [config, graphRef.current]) }, [config, graphRef.current])
/**
* Assign explicit zIndex values to enforce layer order:
* parent nodes (loop/iteration) → child edges → child nodes
* Ports live inside each node's SVG container and are always above
* edges once the node zIndex is higher than the edge zIndex.
*/
const reorderCells = (graph: Graph) => {
// Safari uses x6-html-shape (dual HTML layer architecture).
// zIndex controls order within each HTML layer and SVG layer.
graph.getEdges().forEach(edge => edge.setZIndex(0));
graph.getNodes().forEach(node => {
node.setZIndex(node.getData()?.cycle ? 2 : 1);
});
};
/** /**
* Initialize workflow graph with nodes and edges from configuration * Initialize workflow graph with nodes and edges from configuration
*/ */
@@ -351,7 +362,7 @@ export const useWorkflowGraph = ({
if (parentNode) { if (parentNode) {
const addedChild = graphRef.current?.addNode(childNode) const addedChild = graphRef.current?.addNode(childNode)
if (addedChild) { if (addedChild) {
parentNode.addChild(addedChild, { silent: true }) parentNode.addChild(addedChild)
} }
} }
} }
@@ -382,6 +393,8 @@ export const useWorkflowGraph = ({
const newWidth = Math.max(parentBBox.width, maxX - minX + padding * 2) const newWidth = Math.max(parentBBox.width, maxX - minX + padding * 2)
const newHeight = Math.max(parentBBox.height, maxY - minY + padding * 2 + headerHeight) const newHeight = Math.max(parentBBox.height, maxY - minY + padding * 2 + headerHeight)
console.log('newWidth', newHeight, newWidth)
parentNode.prop('size', { width: newWidth, height: newHeight }) parentNode.prop('size', { width: newWidth, height: newHeight })
// Update x position of right group ports // Update x position of right group ports
@@ -476,95 +489,30 @@ export const useWorkflowGraph = ({
if (nodes.length > 0 || edges.length > 0) { if (nodes.length > 0 || edges.length > 0) {
setTimeout(() => { setTimeout(() => {
if (graphRef.current) { if (graphRef.current) {
graphRef.current.getNodes().forEach(node => { if (isSafari) {
if (!node.getData()?.cycle) node.toFront(); reorderCells(graphRef.current)
}); } else {
// Bring edges to front first, then child nodes above edges; parent nodes stay behind graphRef.current.getNodes().forEach(node => {
graphRef.current.getEdges().forEach(edge => { if (!node.getData()?.cycle) node.toFront();
const sourceCell = graphRef.current?.getCellById(edge.getSourceCellId()); });
const targetCell = graphRef.current?.getCellById(edge.getTargetCellId()); // Bring edges to front first, then child nodes above edges; parent nodes stay behind
if (sourceCell?.getData()?.cycle || targetCell?.getData()?.cycle) { graphRef.current.getEdges().forEach(edge => {
edge.toFront(); const sourceCell = graphRef.current?.getCellById(edge.getSourceCellId());
} const targetCell = graphRef.current?.getCellById(edge.getTargetCellId());
}); if (sourceCell?.getData()?.cycle || targetCell?.getData()?.cycle) {
graphRef.current.getNodes().forEach(node => { edge.toFront();
if (node.getData()?.cycle) node.toFront(); }
}); });
graphRef.current.getNodes().forEach(node => {
if (node.getData()?.cycle) node.toFront();
});
}
graphRef.current.enableHistory() graphRef.current.enableHistory()
graphRef.current.cleanHistory() graphRef.current.cleanHistory()
} }
}, 200) }, isSafari ? 0 : 200)
} else {
graphRef.current.enableHistory()
graphRef.current.cleanHistory()
} }
} }
const resizeGroupNodes = (graph: Graph) => {
graph.getNodes().forEach(parentNode => {
const parentType = parentNode.getData()?.type
if (parentType !== 'loop' && parentType !== 'iteration') return
const children = graph.getNodes().filter(
n => n.getData()?.cycle === parentNode.getData()?.id && n.getData()?.type !== 'add-node'
)
if (!children.length) return
const padding = 24
const headerHeight = 50
const childBounds = children.map(c => c.getBBox())
const minX = Math.min(...childBounds.map(b => b.x))
const minY = Math.min(...childBounds.map(b => b.y))
const maxX = Math.max(...childBounds.map(b => b.x + b.width))
const maxY = Math.max(...childBounds.map(b => b.y + b.height))
const parentBBox = parentNode.getBBox()
const newWidth = Math.max(parentBBox.width, maxX - minX + padding * 2)
const newHeight = Math.max(parentBBox.height, maxY - minY + padding * 2 + headerHeight)
parentNode.prop('size', { width: newWidth, height: newHeight })
parentNode.getPorts().forEach(port => {
if (port.group === 'right' && port.args) {
parentNode.portProp(port.id!, 'args/x', newWidth)
}
})
})
}
const syncChildRelationships = () => {
if (!graphRef.current) return
const graph = graphRef.current
graph.disableHistory()
graph.getNodes().forEach(node => {
const cycleId = node.getData()?.cycle
if (!cycleId) return
const parentNode = graph.getCellById(cycleId) as Node | null
if (!parentNode) return
if (!parentNode.getChildren()?.some(c => c.id === node.id)) {
parentNode.addChild(node, { silent: true })
}
})
graph.getNodes().forEach(node => {
const children = node.getChildren()
if (!children?.length) return
children.forEach(child => {
if (!child.isNode()) return
const childCycleId = (child as Node).getData?.()?.cycle
if (childCycleId !== node.id && childCycleId !== node.getData?.()?.id) {
node.removeChild(child, { silent: true })
}
})
})
resizeGroupNodes(graph)
graph.getEdges().forEach(edge => {
const src = graph.getCellById(edge.getSourceCellId())
const tgt = graph.getCellById(edge.getTargetCellId())
if (src?.getData()?.cycle || tgt?.getData()?.cycle) {
edge.toFront()
}
})
graph.getNodes().forEach(node => {
if (node.getData()?.cycle) node.toFront()
})
graph.enableHistory()
}
syncChildRelationshipsRef.current = syncChildRelationships
/** /**
* Setup X6 graph plugins (MiniMap, Snapline, Clipboard, Keyboard) * Setup X6 graph plugins (MiniMap, Snapline, Clipboard, Keyboard)
*/ */
@@ -600,44 +548,18 @@ export const useWorkflowGraph = ({
new History({ new History({
enabled: false, enabled: false,
beforeAddCommand(_event, args: any) { beforeAddCommand(_event, args: any) {
const key = args?.key const event = args?.key ? `cell:change:${args.key}` : _event;
if (key === 'attrs' || key === 'tools') return false if (event.startsWith('cell:change:') &&
event !== 'cell:change:position' &&
event !== 'cell:change:source' &&
event !== 'cell:change:target') return false;
}, },
}), }),
); );
const MERGE_INTERVAL = 1000 graphRef.current.on('history:change', ({ cmds }: { cmds: Command[] }) => {
graphRef.current.on('history:change', ({ cmds, options }: { cmds: any[]; options: any }) => {
setCanUndo(graphRef.current?.canUndo() ?? false) setCanUndo(graphRef.current?.canUndo() ?? false)
setCanRedo(graphRef.current?.canRedo() ?? false) setCanRedo(graphRef.current?.canRedo() ?? false)
console.log('history:change', cmds, options)
const batchName: string | undefined = options?.name
const actionType = batchName === 'undo' ? 'undo' : batchName === 'redo' ? 'redo' : batchName ? 'batch' : 'change'
const cellIds = [...new Set(cmds?.map((cmd: any) => cmd.data?.id).filter(Boolean))]
const now = Date.now()
const last = lastHistoryRef.current
const canMerge =
actionType === 'change' &&
last?.type === 'change' &&
now - last.timestamp < MERGE_INTERVAL &&
cellIds.length > 0 &&
cellIds.length === last.cellIds.length &&
cellIds.every((id, i) => id === last.cellIds[i])
if (canMerge) {
lastHistoryRef.current!.timestamp = now
setHistoryRecords(prev => {
const next = [...prev]
next[next.length - 1] = { ...next[next.length - 1], timestamp: now }
return next
})
} else {
const record: HistoryRecord = { type: actionType, timestamp: now, batchName, cellIds }
lastHistoryRef.current = { cellIds, timestamp: now, type: actionType }
setHistoryRecords(prev => [...prev, record])
}
}) })
graphRef.current.on('history:undo', () => { if (!isSyncingRef.current) syncChildRelationshipsRef.current() })
graphRef.current.on('history:redo', () => { if (!isSyncingRef.current) syncChildRelationshipsRef.current() })
}; };
// 显示/隐藏连接桩 // 显示/隐藏连接桩
// const showPorts = (show: boolean) => { // const showPorts = (show: boolean) => {
@@ -652,12 +574,33 @@ export const useWorkflowGraph = ({
* @param node - Clicked node * @param node - Clicked node
*/ */
const nodeClick = ({ node }: { node: Node }) => { const nodeClick = ({ node }: { node: Node }) => {
// add-node type: dispatch port:click to open node selection popover
// Must handle before blankClick() to avoid blank:click closing the popover immediately
const nodeData = node.getData()
if (nodeData?.type === 'add-node') {
const bbox = node.getBBox();
const screenPos = graphRef.current!.localToClient(bbox.x + bbox.width, bbox.y + bbox.height / 2);
const tempDiv = document.createElement('div');
tempDiv.style.cssText = `position:fixed;left:${screenPos.x}px;top:${screenPos.y}px;width:1px;height:1px;z-index:9999;`;
document.body.appendChild(tempDiv);
window.dispatchEvent(new CustomEvent('port:click', {
detail: {
node,
port: 'right',
element: tempDiv,
rect: { left: screenPos.x, top: screenPos.y },
edgeInsertion: null,
},
}));
return;
}
blankClick() blankClick()
setTimeout(() => { setTimeout(() => {
// Ignore add-node type node clicks // Ignore add-node type node clicks
const nodeData = node.getData() const nodeData = node.getData()
if (nodeData?.type === 'add-node' || nodeData.type === 'break' || nodeData.type === 'cycle-start') { if (nodeData.type === 'break' || nodeData.type === 'cycle-start') {
setSelectedNode(null) setSelectedNode(null)
return; return;
} }
@@ -670,13 +613,13 @@ export const useWorkflowGraph = ({
vo.setData({ vo.setData({
...data, ...data,
isSelected: false, isSelected: false,
}, { silent: true }); });
} }
}); });
node.setData({ node.setData({
...nodeData, ...nodeData,
isSelected: true, isSelected: true,
}, { silent: true }); });
clearEdgeSelect() clearEdgeSelect()
if (nodeData.type !== 'notes') { if (nodeData.type !== 'notes') {
setSelectedNode(node); setSelectedNode(node);
@@ -690,7 +633,7 @@ export const useWorkflowGraph = ({
const edgeClick = ({ edge }: { edge: Edge }) => { const edgeClick = ({ edge }: { edge: Edge }) => {
clearEdgeSelect(); clearEdgeSelect();
edge.setAttrByPath('line/stroke', edge_selected_color); edge.setAttrByPath('line/stroke', edge_selected_color);
edge.setData({ ...edge.getData(), isSelected: true }, { silent: true }); edge.setData({ ...edge.getData(), isSelected: true });
clearNodeSelect(); clearNodeSelect();
}; };
/** /**
@@ -705,7 +648,7 @@ export const useWorkflowGraph = ({
node.setData({ node.setData({
...data, ...data,
isSelected: false, isSelected: false,
}, { silent: true }); });
} }
}); });
setSelectedNode(null); setSelectedNode(null);
@@ -715,7 +658,7 @@ export const useWorkflowGraph = ({
*/ */
const clearEdgeSelect = () => { const clearEdgeSelect = () => {
graphRef.current?.getEdges().forEach(e => { graphRef.current?.getEdges().forEach(e => {
e.setData({ ...e.getData(), isSelected: false, isNodeHover: false }, { silent: true }); e.setData({ ...e.getData(), isSelected: false, isNodeHover: false });
e.setAttrByPath('line/stroke', edge_color); e.setAttrByPath('line/stroke', edge_color);
e.setAttrByPath('line/strokeWidth', edge_width); e.setAttrByPath('line/strokeWidth', edge_width);
}); });
@@ -745,7 +688,8 @@ export const useWorkflowGraph = ({
const cycle = node.getData()?.cycle; const cycle = node.getData()?.cycle;
if (cycle) { if (cycle) {
const parentNode = graphRef.current!.getNodes().find(n => n.id === cycle); const parentNode = graphRef.current!.getNodes().find(n => n.id === cycle);
if (parentNode?.getData()?.isGroup) { const parentType = parentNode?.getData()?.type;
if (parentNode && (parentType === 'loop' || parentType === 'iteration')) {
// Get parent node and child node bounding boxes // Get parent node and child node bounding boxes
const parentBBox = parentNode.getBBox(); const parentBBox = parentNode.getBBox();
const childBBox = node.getBBox(); const childBBox = node.getBBox();
@@ -854,6 +798,8 @@ export const useWorkflowGraph = ({
// Find corresponding parent node // Find corresponding parent node
const parentNode = nodes?.find(n => n.id === nodeData.cycle); const parentNode = nodes?.find(n => n.id === nodeData.cycle);
if (parentNode) { if (parentNode) {
// Use removeChild method to delete child node
parentNode.removeChild(nodeToDelete);
parentNodesToUpdate.push(parentNode); parentNodesToUpdate.push(parentNode);
} }
// Add child node to deletion list // Add child node to deletion list
@@ -881,51 +827,42 @@ export const useWorkflowGraph = ({
// Delete all collected nodes and edges // Delete all collected nodes and edges
if (cells.length > 0) { if (cells.length > 0) {
// Pre-calculate which parents need an add-node restored (before removal changes the graph)
const parentsNeedingAddNode = parentNodesToUpdate
.filter(parentNode => {
const parentShape = parentNode.shape;
if (parentShape !== 'loop-node' && parentShape !== 'iteration-node') return false;
const parentData = parentNode.getData();
const allChildren = graphRef.current!.getNodes().filter(n => n.getData()?.cycle === parentData.id);
const cycleStartNodes = allChildren.filter(n => n.getData()?.type === 'cycle-start');
// After deletion, only cycle-start will remain
const nonCycleStartToDelete = cells.filter(c =>
c.isNode() &&
(c as Node).getData()?.cycle === parentData.id &&
(c as Node).getData()?.type !== 'cycle-start'
);
return cycleStartNodes.length === 1 && (allChildren.length - nonCycleStartToDelete.length) === 1;
})
.map(parentNode => ({
parentNode,
cycleStartNode: graphRef.current!.getNodes().find(
n => n.getData()?.cycle === parentNode.getData().id && n.getData()?.type === 'cycle-start'
)!
}))
.filter(({ cycleStartNode }) => !!cycleStartNode);
graphRef.current?.startBatch('delete');
graphRef.current?.removeCells(cells); graphRef.current?.removeCells(cells);
parentsNeedingAddNode.forEach(({ parentNode, cycleStartNode }) => { // If parent is iteration/loop and only cycle-start remains, add add-node connected to it
parentNodesToUpdate.forEach(parentNode => {
const parentShape = parentNode.shape;
if (parentShape !== 'loop-node' && parentShape !== 'iteration-node') return;
const parentData = parentNode.getData(); const parentData = parentNode.getData();
const bbox = cycleStartNode.getBBox(); const remainingChildren = graphRef.current!.getNodes().filter(
const addNode = graphRef.current!.addNode({ n => n.getData()?.cycle === parentData.id
...graphNodeLibrary.addStart, );
x: bbox.x + 84, const cycleStartNodes = remainingChildren.filter(n => n.getData()?.type === 'cycle-start');
y: bbox.y + 4, if (cycleStartNodes.length === 1 && remainingChildren.length === 1) {
data: { type: 'add-node', parentId: parentNode.id, cycle: parentData.id, label: t('workflow.addNode'), icon: '+' }, const cycleStartNode = cycleStartNodes[0];
}); const bbox = cycleStartNode.getBBox();
parentNode.addChild(addNode, { silent: true }); const addNode = graphRef.current!.addNode({
graphRef.current!.addEdge({ ...graphNodeLibrary.addStart,
source: { cell: cycleStartNode.id, port: cycleStartNode.getPorts().find(p => p.group === 'right')?.id || 'right' }, x: bbox.x + 84,
target: { cell: addNode.id, port: addNode.getPorts().find(p => p.group === 'left')?.id || 'left' }, y: bbox.y + 4,
...edgeAttrs, data: {
}); type: 'add-node',
parentId: parentNode.id,
cycle: parentData.id,
label: t('workflow.addNode'),
icon: '+',
},
});
parentNode.addChild(addNode);
const sourcePort = cycleStartNode.getPorts().find(p => p.group === 'right')?.id || 'right';
const targetPort = addNode.getPorts().find(p => p.group === 'left')?.id || 'left';
graphRef.current!.addEdge({
source: { cell: cycleStartNode.id, port: sourcePort },
target: { cell: addNode.id, port: targetPort },
...edgeAttrs,
});
}
}); });
graphRef.current?.stopBatch('delete');
} }
return false; return false;
}; };
@@ -965,13 +902,35 @@ export const useWorkflowGraph = ({
/** /**
* Initialize X6 graph with configuration and event listeners * Initialize X6 graph with configuration and event listeners
*/ */
const init = () => { const init = async () => {
if (!containerRef.current || !miniMapRef.current) return; if (!containerRef.current || !miniMapRef.current) return;
// Register React shapes // Register React shapes
nodeRegisterLibrary.forEach((item) => { // Safari: use x6-html-shape to avoid foreignObject rendering issues
register(item); if (isSafari) {
}); const { register: registerHtmlShape } = await import('x6-html-shape');
nodeRegisterLibrary.forEach(({ shape, width, height, component }) => {
registerHtmlShape({
shape,
width,
height,
render(node: Node, _graph: unknown, container: HTMLElement) {
const root = createRoot(container);
const doRender = () => {
root.render(createElement(component as any, { node, graph: node.model?.graph, data: node.getData() }));
};
doRender();
node.on('change:data', doRender);
return () => {
node.off('change:data', doRender);
root.unmount();
};
},
});
});
} else {
nodeRegisterLibrary.forEach((item) => registerReactShape(item));
}
const container = containerRef.current; const container = containerRef.current;
graphRef.current = new Graph({ graphRef.current = new Graph({
@@ -1144,7 +1103,7 @@ export const useWorkflowGraph = ({
graphRef.current?.getConnectedEdges(node).forEach(edge => { graphRef.current?.getConnectedEdges(node).forEach(edge => {
if (!edge.getData()?.isSelected) { if (!edge.getData()?.isSelected) {
edge.setAttrByPath('line/stroke', edge_selected_color); edge.setAttrByPath('line/stroke', edge_selected_color);
edge.setData({ ...edge.getData(), isNodeHover: true }, { silent: true }); edge.setData({ ...edge.getData(), isNodeHover: true });
} }
}); });
}); });
@@ -1152,7 +1111,7 @@ export const useWorkflowGraph = ({
graphRef.current?.getConnectedEdges(node).forEach(edge => { graphRef.current?.getConnectedEdges(node).forEach(edge => {
if (!edge.getData()?.isSelected) { if (!edge.getData()?.isSelected) {
edge.setAttrByPath('line/stroke', edge_color); edge.setAttrByPath('line/stroke', edge_color);
edge.setData({ ...edge.getData(), isNodeHover: false }, { silent: true }); edge.setData({ ...edge.getData(), isNodeHover: false });
} }
}); });
}); });
@@ -1161,10 +1120,71 @@ export const useWorkflowGraph = ({
// Listen to node move event // Listen to node move event
graphRef.current.on('node:moved', nodeMoved); graphRef.current.on('node:moved', nodeMoved);
if (isSafari) {
// When a parent (loop/iteration) node moves, keep child nodes in sync.
// Store each child's offset relative to the parent at drag start, then
// reapply it every frame to avoid cumulative delta errors.
const dragOffsets = new Map<string, { dx: number; dy: number }>();
graphRef.current.on('node:moving', ({ node }: { node: Node }) => {
const data = node.getData();
if (data?.type !== 'loop' && data?.type !== 'iteration') return;
const pos = node.getPosition();
const PORT_RADIUS = 6;
// Update parent componentContainer directly
const parentView = graphRef.current?.findViewByCell(node) as any;
if (parentView?.componentContainer) {
parentView.componentContainer.style.transform =
`translate(${pos.x + PORT_RADIUS}px, ${pos.y}px)`;
}
const children = graphRef.current?.getNodes().filter(child => {
const cycle = child.getData()?.cycle;
return cycle === data.id || cycle === node.id;
}) ?? [];
// First event for this drag: record offsets
if (!dragOffsets.has(node.id)) {
children.forEach(child => {
const cp = child.getPosition();
dragOffsets.set(child.id, { dx: cp.x - pos.x, dy: cp.y - pos.y });
});
}
// Apply stored offsets to keep children in place relative to parent
children.forEach(child => {
const off = dragOffsets.get(child.id);
if (!off) return;
const nx = pos.x + off.dx;
const ny = pos.y + off.dy;
child.setPosition(nx, ny);
const childView = graphRef.current?.findViewByCell(child) as any;
if (childView?.componentContainer) {
childView.componentContainer.style.transform =
`translate(${nx + PORT_RADIUS}px, ${ny}px)`;
}
});
});
graphRef.current.on('node:moved', ({ node }: { node: Node }) => {
// Clear offsets for this parent and all its children
const data = node.getData();
graphRef.current?.getNodes().forEach(child => {
const cycle = child.getData()?.cycle;
if (cycle === data?.id || cycle === node.id) dragOffsets.delete(child.id);
});
dragOffsets.delete(node.id);
nodeMoved({ node });
});
}
graphRef.current.on('node:removed', blankClick) graphRef.current.on('node:removed', blankClick)
// When edge connected, bring connected nodes' ports to front // When edge connected, reorder all cells to maintain correct layer order
graphRef.current.on('edge:connected', ({ isNew, edge }) => { graphRef.current.on('edge:connected', ({ isNew, edge }) => {
if (isNew) { if (isSafari && isNew && graphRef.current) {
reorderCells(graphRef.current);
} else if (!isSafari && isNew) {
const sourceCellId = edge.getSourceCellId() const sourceCellId = edge.getSourceCellId()
const targetCellId = edge.getTargetCellId() const targetCellId = edge.getTargetCellId()
const sourceCell = graphRef.current?.getCellById(sourceCellId); const sourceCell = graphRef.current?.getCellById(sourceCellId);
@@ -1234,8 +1254,8 @@ export const useWorkflowGraph = ({
// Delete selected nodes and edges // Delete selected nodes and edges
graphRef.current.bindKey(['ctrl+d', 'cmd+d', 'delete', 'backspace'], deleteEvent); graphRef.current.bindKey(['ctrl+d', 'cmd+d', 'delete', 'backspace'], deleteEvent);
// Undo / Redo // Undo / Redo
graphRef.current.bindKey(['ctrl+z', 'cmd+z'], () => { undo(); return false; }); graphRef.current.bindKey(['ctrl+z', 'cmd+z'], () => { graphRef.current?.undo(); return false; });
graphRef.current.bindKey(['ctrl+y', 'cmd+y', 'ctrl+shift+z', 'cmd+shift+z'], () => { redo(); return false; }); graphRef.current.bindKey(['ctrl+y', 'cmd+y', 'ctrl+shift+z', 'cmd+shift+z'], () => { graphRef.current?.redo(); return false; });
}; };
@@ -1278,7 +1298,7 @@ export const useWorkflowGraph = ({
* Creates new node at drop position * Creates new node at drop position
* @param event - React drag event * @param event - React drag event
*/ */
const onDrop = (event: React.DragEvent) => { const onDrop = (event: DragEvent) => {
if (!graphRef.current) return; if (!graphRef.current) return;
event.preventDefault(); event.preventDefault();
const dragData = JSON.parse(event.dataTransfer.getData('application/json')); const dragData = JSON.parse(event.dataTransfer.getData('application/json'));
@@ -1301,51 +1321,13 @@ export const useWorkflowGraph = ({
}; };
if (dragData.type === 'loop' || dragData.type === 'iteration') { if (dragData.type === 'loop' || dragData.type === 'iteration') {
graph.disableHistory() graphRef.current.addNode({
const parentNode = graphRef.current.addNode({
...graphNodeLibrary[dragData.type], ...graphNodeLibrary[dragData.type],
x: point.x - 150, x: point.x - 150,
y: point.y - 100, y: point.y - 100,
id: cleanNodeData.id, id: cleanNodeData.id,
data: { ...cleanNodeData, isGroup: true }, data: { ...cleanNodeData, isGroup: true },
}) });
const parentBBox = parentNode.getBBox()
const cycleStartId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
const cycleStartNode = graphRef.current.addNode({
...graphNodeLibrary.cycleStart,
x: parentBBox.x + 24,
y: parentBBox.y + 70,
id: cycleStartId,
data: { id: cycleStartId, type: 'cycle-start', parentId: cleanNodeData.id, isDefault: true, cycle: cleanNodeData.id },
})
const addNode = graphRef.current.addNode({
...graphNodeLibrary.addStart,
x: parentBBox.x + 24 + 84,
y: parentBBox.y + 70 + 4,
data: { type: 'add-node', label: t('workflow.addNode'), icon: '+', parentId: cleanNodeData.id, cycle: cleanNodeData.id },
})
parentNode.addChild(cycleStartNode, { silent: true })
parentNode.addChild(addNode, { silent: true })
const newEdge = graphRef.current.addEdge({
source: { cell: cycleStartNode.id, port: cycleStartNode.getPorts().find(p => p.group === 'right')?.id || 'right' },
target: { cell: addNode.id, port: addNode.getPorts().find(p => p.group === 'left')?.id || 'left' },
...edgeAttrs,
})
cycleStartNode.toFront()
addNode.toFront()
graph.enableHistory()
// Manually push a single batch frame covering all 4 cells into undoStack
const history = graph.getPlugin('history') as History
const makeBatchCmd = (cell: any) => ({
batch: true,
event: 'cell:added',
data: { id: cell.id, node: cell.isNode(), edge: cell.isEdge(), props: cell.toJSON() },
options: {},
})
const batchFrame = [parentNode, cycleStartNode, addNode, newEdge].map(makeBatchCmd)
;(history as any).undoStack.push(batchFrame)
;(history as any).redoStack = []
graph.trigger('history:change', { cmds: batchFrame, options: { name: 'add-group' } })
} else if (dragData.type === 'if-else') { } else if (dragData.type === 'if-else') {
// Create condition node // Create condition node
graphRef.current.addNode({ graphRef.current.addNode({
@@ -1592,80 +1574,8 @@ export const useWorkflowGraph = ({
return userVars return userVars
} }
const clearHistoryRecords = () => { const undo = () => graphRef.current?.undo()
setHistoryRecords([]) const redo = () => graphRef.current?.redo()
lastHistoryRef.current = null
}
const getStackCellIds = (cmds: any): string[] => {
const arr = Array.isArray(cmds) ? cmds : [cmds]
return [...new Set(arr.map((c: any) => c.data?.id).filter(Boolean))]
}
const isSkippableFrame = (frame: any): boolean => {
const arr = Array.isArray(frame) ? frame : [frame]
return arr.every((c: any) => ['zIndex', 'attrs', 'tools'].includes(c.data?.key))
}
const undo = () => {
const history = graphRef.current?.getPlugin('history') as History | undefined
if (!history || history.getUndoSize() === 0) return
const undoStack = (history as any).undoStack as any[]
isSyncingRef.current = true
while (undoStack.length > 0 && isSkippableFrame(undoStack[undoStack.length - 1])) {
graphRef.current!.undo()
}
if (undoStack.length === 0) {
isSyncingRef.current = false
return
}
const topIds = getStackCellIds(undoStack[undoStack.length - 1])
graphRef.current!.undo()
while (undoStack.length > 0) {
if (isSkippableFrame(undoStack[undoStack.length - 1])) {
graphRef.current!.undo()
continue
}
const nextIds = getStackCellIds(undoStack[undoStack.length - 1])
if (nextIds.length === topIds.length && nextIds.every((id, i) => id === topIds[i])) {
graphRef.current!.undo()
} else {
break
}
}
isSyncingRef.current = false
syncChildRelationships()
}
const redo = () => {
const history = graphRef.current?.getPlugin('history') as History | undefined
if (!history || history.getRedoSize() === 0) return
const redoStack = (history as any).redoStack as any[]
isSyncingRef.current = true
while (redoStack.length > 0 && isSkippableFrame(redoStack[redoStack.length - 1])) {
graphRef.current!.redo()
}
if (redoStack.length === 0) {
isSyncingRef.current = false
return
}
const topIds = getStackCellIds(redoStack[redoStack.length - 1])
graphRef.current!.redo()
while (redoStack.length > 0) {
if (isSkippableFrame(redoStack[redoStack.length - 1])) {
graphRef.current!.redo()
continue
}
const nextIds = getStackCellIds(redoStack[redoStack.length - 1])
if (nextIds.length === topIds.length && nextIds.every((id, i) => id === topIds[i])) {
graphRef.current!.redo()
} else {
break
}
}
isSyncingRef.current = false
syncChildRelationships()
}
const handleSaveFeaturesConfig = (value?: FeaturesConfigForm) => { const handleSaveFeaturesConfig = (value?: FeaturesConfigForm) => {
const { statement = '' } = value?.opening_statement || {} const { statement = '' } = value?.opening_statement || {}
@@ -1706,16 +1616,20 @@ export const useWorkflowGraph = ({
if (!graphRef.current) return; if (!graphRef.current) return;
const nodes = graphRef.current.getNodes(); const nodes = graphRef.current.getNodes();
// Reset all node execution status on every chatHistory change const lastWithSub = [...chatHistory].reverse().find(item => item.subContent?.length);
// Reset all node execution status first
nodes.forEach(node => { nodes.forEach(node => {
const data = node.getData(); const data = node.getData();
node.setData({ ...data, executionStatus: '' }); if (typeof data.status === 'string') {
node.setData({ ...data, executionStatus: undefined });
}
}); });
if (!lastWithSub?.subContent) return;
const lastAssistant = [...chatHistory].reverse().find(item => item.role === 'assistant'); // Build a nodeId -> status map first
if (!lastAssistant?.subContent?.length) return; const statusMap: Record<string, string> = {};
lastAssistant.subContent.forEach(sub => { lastWithSub.subContent.forEach(sub => {
if (typeof sub.status === 'string') { if (typeof sub.status === 'string') {
statusMap[sub.node_id] = sub.status;
const node = nodes.find(n => n.getData()?.id === sub.node_id); const node = nodes.find(n => n.getData()?.id === sub.node_id);
if (node) { if (node) {
node.setData({ ...node.getData(), executionStatus: sub.status }); node.setData({ ...node.getData(), executionStatus: sub.status });
@@ -1751,7 +1665,5 @@ export const useWorkflowGraph = ({
canRedo, canRedo,
undo, undo,
redo, redo,
historyRecords,
clearHistoryRecords,
}; };
}; };

View File

@@ -113,13 +113,4 @@ export interface ChatVariable {
} }
export interface AddChatVariableRef { export interface AddChatVariableRef {
handleOpen: (value?: ChatVariable) => void; handleOpen: (value?: ChatVariable) => void;
}
export type HistoryActionType = 'add' | 'remove' | 'change' | 'undo' | 'redo' | 'batch'
export interface HistoryRecord {
type: HistoryActionType;
timestamp: number;
batchName?: string;
cellIds?: string[];
} }

View File

@@ -17,7 +17,6 @@ export const isSubExprSet = (sub: any) => {
* Uses the same per-expression height logic as getConditionNodeCasePortY. * Uses the same per-expression height logic as getConditionNodeCasePortY.
*/ */
export const calcConditionNodeTotalHeight = (cases: any[]) => { export const calcConditionNodeTotalHeight = (cases: any[]) => {
if (!cases?.length) return conditionNodeHeight;
const casesHeight = cases.reduce((acc: number, c: any) => { const casesHeight = cases.reduce((acc: number, c: any) => {
const exprs = c?.expressions ?? []; const exprs = c?.expressions ?? [];
const n = exprs.length; const n = exprs.length;

View File

@@ -44,6 +44,9 @@ export default defineConfig({
resolve: { resolve: {
alias: { alias: {
'@': resolve(__dirname, 'src'), '@': resolve(__dirname, 'src'),
'x6-html-shape': resolve(__dirname, 'src/vendor/x6-html-shape/index.js'),
'x6-html-shape/dist/react': resolve(__dirname, 'src/vendor/x6-html-shape/react.js'),
'x6-html-shape/dist/utils.js': resolve(__dirname, 'src/vendor/x6-html-shape/utils.js'),
}, },
}, },
base: './', // 使用相对路径,确保资源能正确加载 base: './', // 使用相对路径,确保资源能正确加载