Compare commits

..

1 Commits

Author SHA1 Message Date
zhaoying
8476f3b7a8 feat(web): workflow Safari browser compatibility 2026-04-28 12:12:19 +08:00
54 changed files with 1006 additions and 2202 deletions

View File

@@ -17,7 +17,6 @@ def _mask_url(url: str) -> str:
"""隐藏 URL 中的密码部分,适用于 redis:// 和 amqp:// 等协议"""
return re.sub(r'(://[^:]*:)[^@]+(@)', r'\1***\2', url)
# macOS fork() safety - must be set before any Celery initialization
if platform.system() == 'Darwin':
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')

View File

@@ -1,500 +0,0 @@
import hashlib
import json
import os
import socket
import threading
import time
import uuid
import redis
from app.core.config import settings
from app.core.logging_config import get_named_logger
from app.celery_app import celery_app
logger = get_named_logger("task_scheduler")
# per-user queue scheduler:uq:{user_id}
USER_QUEUE_PREFIX = "scheduler:uq:"
# User Collection of Pending Messages
ACTIVE_USERS = "scheduler:active_users"
# Set of users that can dispatch (ready signal)
READY_SET = "scheduler:ready_users"
# Metadata of tasks that have been dispatched and are pending completion
PENDING_HASH = "scheduler:pending_tasks"
# Dynamic Sharding: Instance Registry
REGISTRY_KEY = "scheduler:instances"
TASK_TIMEOUT = 7800 # Task timeout (seconds), considered lost if exceeded
HEARTBEAT_INTERVAL = 10 # Heartbeat interval (seconds)
INSTANCE_TTL = 30 # Instance timeout (seconds)
LUA_ATOMIC_LOCK = """
local dispatch_lock = KEYS[1]
local lock_key = KEYS[2]
local instance_id = ARGV[1]
local dispatch_ttl = tonumber(ARGV[2])
local lock_ttl = tonumber(ARGV[3])
if redis.call('SET', dispatch_lock, instance_id, 'NX', 'EX', dispatch_ttl) == false then
return 0
end
if redis.call('EXISTS', lock_key) == 1 then
redis.call('DEL', dispatch_lock)
return -1
end
redis.call('SET', lock_key, 'dispatching', 'EX', lock_ttl)
return 1
"""
LUA_SAFE_DELETE = """
if redis.call('GET', KEYS[1]) == ARGV[1] then
return redis.call('DEL', KEYS[1])
end
return 0
"""
def stable_hash(value: str) -> int:
return int.from_bytes(
hashlib.md5(value.encode("utf-8")).digest(),
"big"
)
def health_check_server(scheduler_ref):
import uvicorn
from fastapi import FastAPI
health_app = FastAPI()
@health_app.get("/")
def health():
return scheduler_ref.health()
port = int(os.environ.get("SCHEDULER_HEALTH_PORT", "8001"))
threading.Thread(
target=uvicorn.run,
kwargs={
"app": health_app,
"host": "0.0.0.0",
"port": port,
"log_config": None,
},
daemon=True,
).start()
logger.info("[Health] Server started at http://0.0.0.0:%s", port)
class RedisTaskScheduler:
def __init__(self):
self.redis = redis.Redis(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB_CELERY_BACKEND,
password=settings.REDIS_PASSWORD,
decode_responses=True,
)
self.running = False
self.dispatched = 0
self.errors = 0
self.instance_id = f"{socket.gethostname()}-{os.getpid()}"
self._shard_index = 0
self._shard_count = 1
self._last_heartbeat = 0.0
def push_task(self, task_name, user_id, params):
try:
msg_id = str(uuid.uuid4())
msg = json.dumps({
"msg_id": msg_id,
"task_name": task_name,
"user_id": user_id,
"params": json.dumps(params),
})
lock_key = f"{task_name}:{user_id}"
queue_key = f"{USER_QUEUE_PREFIX}{user_id}"
pipe = self.redis.pipeline()
pipe.rpush(queue_key, msg)
pipe.sadd(ACTIVE_USERS, user_id)
pipe.set(
f"task_tracker:{msg_id}",
json.dumps({"status": "QUEUED", "task_id": None}),
ex=86400,
)
pipe.execute()
if not self.redis.exists(lock_key):
self.redis.sadd(READY_SET, user_id)
logger.info("Task pushed: msg_id=%s task=%s user=%s", msg_id, task_name, user_id)
return msg_id
except Exception as e:
logger.error("Push task exception %s", e, exc_info=True)
raise
def get_task_status(self, msg_id: str) -> dict:
raw = self.redis.get(f"task_tracker:{msg_id}")
if raw is None:
return {"status": "NOT_FOUND"}
tracker = json.loads(raw)
status = tracker["status"]
task_id = tracker.get("task_id")
result_content = tracker.get("result") or {}
if status == "DISPATCHED" and task_id:
result_raw = self.redis.get(f"celery-task-meta-{task_id}")
if result_raw:
result_data = json.loads(result_raw)
status = result_data.get("status", status)
result_content = result_data.get("result")
return {"status": status, "task_id": task_id, "result": result_content}
def _cleanup_finished(self):
pending = self.redis.hgetall(PENDING_HASH)
if not pending:
return
now = time.time()
task_ids = list(pending.keys())
pipe = self.redis.pipeline()
for task_id in task_ids:
pipe.get(f"celery-task-meta-{task_id}")
results = pipe.execute()
cleanup_pipe = self.redis.pipeline()
has_cleanup = False
ready_user_ids = set()
for task_id, raw_result in zip(task_ids, results):
try:
meta = json.loads(pending[task_id])
lock_key = meta["lock_key"]
dispatched_at = meta.get("dispatched_at", 0)
age = now - dispatched_at
should_cleanup = False
result_data = {}
if raw_result is not None:
result_data = json.loads(raw_result)
if result_data.get("status") in ("SUCCESS", "FAILURE", "REVOKED"):
should_cleanup = True
logger.info(
"Task finished: %s state=%s", task_id,
result_data.get("status"),
)
elif age > TASK_TIMEOUT:
should_cleanup = True
logger.warning(
"Task expired or lost: %s age=%.0fs, force cleanup",
task_id, age,
)
if should_cleanup:
final_status = (
result_data.get("status", "UNKNOWN") if result_data else "EXPIRED"
)
self.redis.eval(LUA_SAFE_DELETE, 1, lock_key, task_id)
cleanup_pipe.hdel(PENDING_HASH, task_id)
tracker_msg_id = meta.get("msg_id")
if tracker_msg_id:
cleanup_pipe.set(
f"task_tracker:{tracker_msg_id}",
json.dumps({
"status": final_status,
"task_id": task_id,
"result": result_data.get("result") or {},
}),
ex=86400,
)
has_cleanup = True
parts = lock_key.split(":", 1)
if len(parts) == 2:
ready_user_ids.add(parts[1])
except Exception as e:
logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True)
self.errors += 1
if has_cleanup:
cleanup_pipe.execute()
if ready_user_ids:
self.redis.sadd(READY_SET, *ready_user_ids)
def _heartbeat(self):
now = time.time()
if now - self._last_heartbeat < HEARTBEAT_INTERVAL:
return
self._last_heartbeat = now
self.redis.hset(REGISTRY_KEY, self.instance_id, str(now))
all_instances = self.redis.hgetall(REGISTRY_KEY)
alive = []
dead = []
for iid, ts in all_instances.items():
if now - float(ts) < INSTANCE_TTL:
alive.append(iid)
else:
dead.append(iid)
if dead:
pipe = self.redis.pipeline()
for iid in dead:
pipe.hdel(REGISTRY_KEY, iid)
pipe.execute()
logger.info("Cleaned dead instances: %s", dead)
alive.sort()
self._shard_count = max(len(alive), 1)
self._shard_index = (
alive.index(self.instance_id) if self.instance_id in alive else 0
)
logger.debug(
"Shard: %s/%s (instance=%s, alive=%d)",
self._shard_index, self._shard_count,
self.instance_id, len(alive),
)
def _is_mine(self, user_id: str) -> bool:
if self._shard_count <= 1:
return True
return stable_hash(user_id) % self._shard_count == self._shard_index
def _dispatch(self, msg_id, msg_data) -> bool:
user_id = msg_data["user_id"]
task_name = msg_data["task_name"]
params = json.loads(msg_data.get("params", "{}"))
lock_key = f"{task_name}:{user_id}"
dispatch_lock = f"dispatch:{msg_id}"
result = self.redis.eval(
LUA_ATOMIC_LOCK, 2,
dispatch_lock, lock_key,
self.instance_id, str(300), str(3600),
)
if result == 0:
return False
if result == -1:
return False
try:
task = celery_app.send_task(task_name, kwargs=params)
except Exception as e:
pipe = self.redis.pipeline()
pipe.delete(dispatch_lock)
pipe.delete(lock_key)
pipe.execute()
self.errors += 1
logger.error(
"send_task failed for %s:%s msg=%s: %s",
task_name, user_id, msg_id, e, exc_info=True,
)
return False
try:
pipe = self.redis.pipeline()
pipe.set(lock_key, task.id, ex=3600)
pipe.hset(PENDING_HASH, task.id, json.dumps({
"lock_key": lock_key,
"dispatched_at": time.time(),
"msg_id": msg_id,
}))
pipe.delete(dispatch_lock)
pipe.set(
f"task_tracker:{msg_id}",
json.dumps({"status": "DISPATCHED", "task_id": task.id}),
ex=86400,
)
pipe.execute()
except Exception as e:
logger.error(
"Post-dispatch state update failed for %s: %s",
task.id, e, exc_info=True,
)
self.errors += 1
self.dispatched += 1
logger.info("Task dispatched: %s (msg=%s)", task.id, msg_id)
return True
def _process_batch(self, user_ids):
if not user_ids:
return
pipe = self.redis.pipeline()
for uid in user_ids:
pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0)
heads = pipe.execute()
candidates = [] # (user_id, msg_dict)
empty_users = []
for uid, head in zip(user_ids, heads):
if head is None:
empty_users.append(uid)
else:
try:
candidates.append((uid, json.loads(head)))
except (json.JSONDecodeError, TypeError) as e:
logger.error("Bad message in queue for user %s: %s", uid, e)
self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}")
if empty_users:
pipe = self.redis.pipeline()
for uid in empty_users:
pipe.srem(ACTIVE_USERS, uid)
pipe.execute()
if not candidates:
return
for uid, msg in candidates:
if self._dispatch(msg["msg_id"], msg):
self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}")
def schedule_loop(self):
self._heartbeat()
self._cleanup_finished()
pipe = self.redis.pipeline()
pipe.smembers(READY_SET)
pipe.delete(READY_SET)
results = pipe.execute()
ready_users = results[0] or set()
my_users = [uid for uid in ready_users if self._is_mine(uid)]
if not my_users:
time.sleep(0.5)
return
self._process_batch(my_users)
time.sleep(0.1)
def _full_scan(self):
cursor = 0
ready_batch = []
while True:
cursor, user_ids = self.redis.sscan(
ACTIVE_USERS, cursor=cursor, count=1000,
)
if user_ids:
my_users = [uid for uid in user_ids if self._is_mine(uid)]
if my_users:
pipe = self.redis.pipeline()
for uid in my_users:
pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0)
heads = pipe.execute()
for uid, head in zip(my_users, heads):
if head is None:
continue
try:
msg = json.loads(head)
lock_key = f"{msg['task_name']}:{uid}"
ready_batch.append((uid, lock_key))
except (json.JSONDecodeError, TypeError):
continue
if cursor == 0:
break
if not ready_batch:
return
pipe = self.redis.pipeline()
for _, lock_key in ready_batch:
pipe.exists(lock_key)
lock_exists = pipe.execute()
ready_uids = [
uid for (uid, _), locked in zip(ready_batch, lock_exists)
if not locked
]
if ready_uids:
self.redis.sadd(READY_SET, *ready_uids)
logger.info("Full scan found %d ready users", len(ready_uids))
def run_server(self):
health_check_server(self)
self.running = True
last_full_scan = 0.0
full_scan_interval = 30.0
logger.info(
"Scheduler started: instance=%s", self.instance_id,
)
while True:
try:
self.schedule_loop()
now = time.time()
if now - last_full_scan > full_scan_interval:
self._full_scan()
last_full_scan = now
except Exception as e:
logger.error("Scheduler exception %s", e, exc_info=True)
self.errors += 1
time.sleep(5)
def health(self) -> dict:
return {
"running": self.running,
"active_users": self.redis.scard(ACTIVE_USERS),
"ready_users": self.redis.scard(READY_SET),
"pending_tasks": self.redis.hlen(PENDING_HASH),
"dispatched": self.dispatched,
"errors": self.errors,
"shard": f"{self._shard_index}/{self._shard_count}",
"instance": self.instance_id,
}
def shutdown(self):
logger.info("Scheduler shutting down: instance=%s", self.instance_id)
self.running = False
try:
self.redis.hdel(REGISTRY_KEY, self.instance_id)
except Exception as e:
logger.error("Shutdown cleanup error: %s", e)
scheduler: RedisTaskScheduler | None = None
if scheduler is None:
scheduler = RedisTaskScheduler()
if __name__ == "__main__":
import signal
import sys
def _signal_handler(signum, frame):
scheduler.shutdown()
sys.exit(0)
signal.signal(signal.SIGTERM, _signal_handler)
signal.signal(signal.SIGINT, _signal_handler)
scheduler.run_server()

View File

@@ -9,7 +9,7 @@ from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail, AppLogMessage
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail
from app.schemas.response_schema import PageData, PageMeta
from app.services.app_service import AppService
from app.services.app_log_service import AppLogService
@@ -41,7 +41,7 @@ def list_app_logs(
# 验证应用访问权限
app_service = AppService(db)
app = app_service.get_app(app_id, workspace_id)
app_service.get_app(app_id, workspace_id)
# 使用 Service 层查询
log_service = AppLogService(db)
@@ -51,8 +51,7 @@ def list_app_logs(
page=page,
pagesize=pagesize,
is_draft=is_draft,
keyword=keyword,
app_type=app.type,
keyword=keyword
)
items = [AppLogConversation.model_validate(c) for c in conversations]
@@ -79,32 +78,17 @@ def get_app_log_detail(
# 验证应用访问权限
app_service = AppService(db)
app = app_service.get_app(app_id, workspace_id)
app_service.get_app(app_id, workspace_id)
# 使用 Service 层查询
log_service = AppLogService(db)
conversation, messages, node_executions_map = log_service.get_conversation_detail(
conversation, node_executions_map = log_service.get_conversation_detail(
app_id=app_id,
conversation_id=conversation_id,
workspace_id=workspace_id,
app_type=app.type
workspace_id=workspace_id
)
# 构建基础会话信息(不经过 ORM relationship
base = AppLogConversation.model_validate(conversation)
# 单独处理 messages避免触发 SQLAlchemy relationship 校验
if messages and isinstance(messages[0], AppLogMessage):
# 工作流:已经是 AppLogMessage 实例
msg_list = messages
else:
# AgentORM Message 对象逐个转换
msg_list = [AppLogMessage.model_validate(m) for m in messages]
detail = AppLogConversationDetail(
**base.model_dump(),
messages=msg_list,
node_executions_map=node_executions_map,
)
detail = AppLogConversationDetail.model_validate(conversation)
detail.node_executions_map = node_executions_map
return success(data=detail)

View File

@@ -4,9 +4,7 @@
处理显性记忆相关的API接口包括情景记忆和语义记忆的查询。
"""
from typing import Optional
from fastapi import APIRouter, Depends, Query
from fastapi import APIRouter, Depends
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
@@ -71,140 +69,6 @@ async def get_explicit_memory_overview_api(
return fail(BizCode.INTERNAL_ERROR, "显性记忆总览查询失败", str(e))
@router.get("/episodics", response_model=ApiResponse)
async def get_episodic_memory_list_api(
end_user_id: str = Query(..., description="end user ID"),
page: int = Query(1, gt=0, description="page number, starting from 1"),
pagesize: int = Query(10, gt=0, le=100, description="number of items per page, max 100"),
start_date: Optional[int] = Query(None, description="start timestamp (ms)"),
end_date: Optional[int] = Query(None, description="end timestamp (ms)"),
episodic_type: str = Query("all", description="episodic type all/conversation/project_work/learning/decision/important_event"),
current_user: User = Depends(get_current_user),
) -> dict:
"""
获取情景记忆分页列表
返回指定用户的情景记忆列表,支持分页、时间范围筛选和情景类型筛选。
Args:
end_user_id: 终端用户ID必填
page: 页码从1开始默认1
pagesize: 每页数量默认10最大100
start_date: 开始时间戳(可选,毫秒),自动扩展到当天 00:00:00
end_date: 结束时间戳(可选,毫秒),自动扩展到当天 23:59:59
episodic_type: 情景类型筛选可选默认all
current_user: 当前用户
Returns:
ApiResponse: 包含情景记忆分页列表
Examples:
- 基础分页查询GET /episodics?end_user_id=xxx&page=1&pagesize=5
返回第1页每页5条数据
- 按时间范围筛选GET /episodics?end_user_id=xxx&page=1&pagesize=5&start_date=1738684800000&end_date=1738771199000
返回指定时间范围内的数据
- 按情景类型筛选GET /episodics?end_user_id=xxx&page=1&pagesize=5&episodic_type=important_event
返回类型为"重要事件"的数据
Notes:
- start_date 和 end_date 必须同时提供或同时不提供
- start_date 不能大于 end_date
- episodic_type 可选值all, conversation, project_work, learning, decision, important_event
- total 为该用户情景记忆总数(不受筛选条件影响)
- page.total 为筛选后的总条数
"""
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询情景记忆列表但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(
f"情景记忆分页查询: end_user_id={end_user_id}, "
f"start_date={start_date}, end_date={end_date}, episodic_type={episodic_type}, "
f"page={page}, pagesize={pagesize}, username={current_user.username}"
)
# 1. 参数校验
if page < 1 or pagesize < 1:
api_logger.warning(f"分页参数错误: page={page}, pagesize={pagesize}")
return fail(BizCode.INVALID_PARAMETER, "分页参数必须大于0")
valid_episodic_types = ["all", "conversation", "project_work", "learning", "decision", "important_event"]
if episodic_type not in valid_episodic_types:
api_logger.warning(f"无效的情景类型参数: {episodic_type}")
return fail(BizCode.INVALID_PARAMETER, f"无效的情景类型参数,可选值:{', '.join(valid_episodic_types)}")
# 时间戳参数校验
if (start_date is not None and end_date is None) or (end_date is not None and start_date is None):
return fail(BizCode.INVALID_PARAMETER, "start_date和end_date必须同时提供")
if start_date is not None and end_date is not None and start_date > end_date:
return fail(BizCode.INVALID_PARAMETER, "start_date不能大于end_date")
# 2. 执行查询
try:
result = await memory_explicit_service.get_episodic_memory_list(
end_user_id=end_user_id,
page=page,
pagesize=pagesize,
start_date=start_date,
end_date=end_date,
episodic_type=episodic_type,
)
api_logger.info(
f"情景记忆分页查询成功: end_user_id={end_user_id}, "
f"total={result['total']}, 返回={len(result['items'])}"
)
except Exception as e:
api_logger.error(f"情景记忆分页查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "情景记忆分页查询失败", str(e))
# 3. 返回结构化响应
return success(data=result, msg="查询成功")
@router.get("/semantics", response_model=ApiResponse)
async def get_semantic_memory_list_api(
end_user_id: str = Query(..., description="终端用户ID"),
current_user: User = Depends(get_current_user),
) -> dict:
"""
获取语义记忆列表
返回指定用户的全量语义记忆列表。
Args:
end_user_id: 终端用户ID必填
current_user: 当前用户
Returns:
ApiResponse: 包含语义记忆全量列表
"""
workspace_id = current_user.current_workspace_id
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询语义记忆列表但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(
f"语义记忆列表查询: end_user_id={end_user_id}, username={current_user.username}"
)
try:
result = await memory_explicit_service.get_semantic_memory_list(
end_user_id=end_user_id
)
api_logger.info(
f"语义记忆列表查询成功: end_user_id={end_user_id}, total={len(result)}"
)
except Exception as e:
api_logger.error(f"语义记忆列表查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "语义记忆列表查询失败", str(e))
return success(data=result, msg="查询成功")
@router.post("/details", response_model=ApiResponse)
async def get_explicit_memory_details_api(
request: ExplicitMemoryDetailsRequest,

View File

@@ -14,7 +14,6 @@ from . import (
rag_api_document_controller,
rag_api_file_controller,
rag_api_knowledge_controller,
user_memory_api_controller,
)
# 创建 V1 API 路由器
@@ -29,6 +28,5 @@ service_router.include_router(rag_api_chunk_controller.router)
service_router.include_router(memory_api_controller.router)
service_router.include_router(end_user_api_controller.router)
service_router.include_router(memory_config_api_controller.router)
service_router.include_router(user_memory_api_controller.router)
__all__ = ["service_router"]

View File

@@ -3,7 +3,6 @@
from fastapi import APIRouter, Body, Depends, Query, Request
from sqlalchemy.orm import Session
from app.celery_task_scheduler import scheduler
from app.core.api_key_auth import require_api_key
from app.core.logging_config import get_business_logger
from app.core.quota_stub import check_end_user_quota
@@ -87,7 +86,7 @@ async def write_memory(
user_rag_memory_id=payload.user_rag_memory_id,
)
logger.info(f"Memory write task submitted: task_id: {result['task_id']} end_user_id: {payload.end_user_id}")
logger.info(f"Memory write task submitted: task_id={result['task_id']}, end_user_id: {payload.end_user_id}")
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted")
@@ -106,7 +105,8 @@ async def get_write_task_status(
"""
logger.info(f"Write task status check - task_id: {task_id}")
result = scheduler.get_task_status(task_id)
from app.services.task_service import get_task_memory_write_result
result = get_task_memory_write_result(task_id)
return success(data=_sanitize_task_result(result), msg="Task status retrieved")

View File

@@ -1,230 +0,0 @@
"""User Memory 服务接口 — 基于 API Key 认证
包装 user_memory_controllers.py 和 memory_agent_controller.py 中的内部接口,
提供基于 API Key 认证的对外服务:
1./analytics/graph_data - 知识图谱数据接口
2./analytics/community_graph - 社区图谱接口
3./analytics/node_statistics - 记忆节点统计接口
4./analytics/user_summary - 用户摘要接口
5./analytics/memory_insight - 记忆洞察接口
6./analytics/interest_distribution - 兴趣分布接口
7./analytics/end_user_info - 终端用户信息接口
8./analytics/generate_cache - 缓存生成接口
路由前缀: /memory
子路径: /analytics/...
最终路径: /v1/memory/analytics/...
认证方式: API Key (@require_api_key)
"""
from typing import Optional
from fastapi import APIRouter, Depends, Header, Query, Request, Body
from sqlalchemy.orm import Session
from app.core.api_key_auth import require_api_key
from app.core.api_key_utils import get_current_user_from_api_key, validate_end_user_in_workspace
from app.core.logging_config import get_business_logger
from app.db import get_db
from app.schemas.api_key_schema import ApiKeyAuth
from app.schemas.memory_storage_schema import GenerateCacheRequest
# 包装内部服务 controller
from app.controllers import user_memory_controllers, memory_agent_controller
router = APIRouter(prefix="/memory", tags=["V1 - User Memory API"])
logger = get_business_logger()
# ==================== 知识图谱 ====================
@router.get("/analytics/graph_data")
@require_api_key(scopes=["memory"])
async def get_graph_data(
request: Request,
end_user_id: str = Query(..., description="End user ID"),
node_types: Optional[str] = Query(None, description="Comma-separated node types filter"),
limit: int = Query(100, description="Max nodes to return (auto-capped at 1000 in service layer)"),
depth: int = Query(1, description="Graph traversal depth (auto-capped at 3 in service layer)"),
center_node_id: Optional[str] = Query(None, description="Center node for subgraph"),
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""Get knowledge graph data (nodes + edges) for an end user."""
current_user = get_current_user_from_api_key(db, api_key_auth)
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
return await user_memory_controllers.get_graph_data_api(
end_user_id=end_user_id,
node_types=node_types,
limit=limit,
depth=depth,
center_node_id=center_node_id,
current_user=current_user,
db=db,
)
@router.get("/analytics/community_graph")
@require_api_key(scopes=["memory"])
async def get_community_graph(
request: Request,
end_user_id: str = Query(..., description="End user ID"),
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""Get community clustering graph for an end user."""
current_user = get_current_user_from_api_key(db, api_key_auth)
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
return await user_memory_controllers.get_community_graph_data_api(
end_user_id=end_user_id,
current_user=current_user,
db=db,
)
# ==================== 节点统计 ====================
@router.get("/analytics/node_statistics")
@require_api_key(scopes=["memory"])
async def get_node_statistics(
request: Request,
end_user_id: str = Query(..., description="End user ID"),
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""Get memory node type statistics for an end user."""
current_user = get_current_user_from_api_key(db, api_key_auth)
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
return await user_memory_controllers.get_node_statistics_api(
end_user_id=end_user_id,
current_user=current_user,
db=db,
)
# ==================== 用户摘要 & 洞察 ====================
@router.get("/analytics/user_summary")
@require_api_key(scopes=["memory"])
async def get_user_summary(
request: Request,
end_user_id: str = Query(..., description="End user ID"),
language_type: str = Header(default=None, alias="X-Language-Type"),
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""Get cached user summary for an end user."""
current_user = get_current_user_from_api_key(db, api_key_auth)
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
return await user_memory_controllers.get_user_summary_api(
end_user_id=end_user_id,
language_type=language_type,
current_user=current_user,
db=db,
)
@router.get("/analytics/memory_insight")
@require_api_key(scopes=["memory"])
async def get_memory_insight(
request: Request,
end_user_id: str = Query(..., description="End user ID"),
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""Get cached memory insight report for an end user."""
current_user = get_current_user_from_api_key(db, api_key_auth)
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
return await user_memory_controllers.get_memory_insight_report_api(
end_user_id=end_user_id,
current_user=current_user,
db=db,
)
# ==================== 兴趣分布 ====================
@router.get("/analytics/interest_distribution")
@require_api_key(scopes=["memory"])
async def get_interest_distribution(
request: Request,
end_user_id: str = Query(..., description="End user ID"),
limit: int = Query(5, le=5, description="Max interest tags to return"),
language_type: str = Header(default=None, alias="X-Language-Type"),
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""Get interest distribution tags for an end user."""
current_user = get_current_user_from_api_key(db, api_key_auth)
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
return await memory_agent_controller.get_interest_distribution_by_user_api(
end_user_id=end_user_id,
limit=limit,
language_type=language_type,
current_user=current_user,
db=db,
)
# ==================== 终端用户信息 ====================
@router.get("/analytics/end_user_info")
@require_api_key(scopes=["memory"])
async def get_end_user_info(
request: Request,
end_user_id: str = Query(..., description="End user ID"),
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""Get end user basic information (name, aliases, metadata)."""
current_user = get_current_user_from_api_key(db, api_key_auth)
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
return await user_memory_controllers.get_end_user_info(
end_user_id=end_user_id,
current_user=current_user,
db=db,
)
# ==================== 缓存生成 ====================
@router.post("/analytics/generate_cache")
@require_api_key(scopes=["memory"])
async def generate_cache(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
message: str = Body(None, description="Request body"),
language_type: str = Header(default=None, alias="X-Language-Type"),
):
"""Trigger cache generation (user summary + memory insight) for an end user or all workspace users."""
body = await request.json()
cache_request = GenerateCacheRequest(**body)
current_user = get_current_user_from_api_key(db, api_key_auth)
if cache_request.end_user_id:
validate_end_user_in_workspace(db, cache_request.end_user_id, api_key_auth.workspace_id)
return await user_memory_controllers.generate_cache_api(
request=cache_request,
language_type=language_type,
current_user=current_user,
db=db,
)

View File

@@ -1,15 +1,8 @@
"""API Key 工具函数"""
import secrets
import uuid as _uuid
from typing import Optional, Union
from datetime import datetime
from sqlalchemy.orm import Session as _Session
from app.core.error_codes import BizCode as _BizCode
from app.core.exceptions import BusinessException as _BusinessException
from app.models.end_user_model import EndUser as _EndUser
from app.repositories.end_user_repository import EndUserRepository as _EndUserRepository
from app.models.api_key_model import ApiKeyType
from fastapi import Response
from fastapi.responses import JSONResponse
@@ -72,72 +65,3 @@ def datetime_to_timestamp(dt: Optional[datetime]) -> Optional[int]:
return None
return int(dt.timestamp() * 1000)
def get_current_user_from_api_key(db: _Session, api_key_auth):
"""通过 API Key 构造 current_user 对象。
从 API Key 反查创建者(管理员用户),并设置其 workspace 上下文。
与内部接口的 Depends(get_current_user) (JWT) 等价。
Args:
db: 数据库会话
api_key_auth: API Key 认证信息ApiKeyAuth
Returns:
User ORM 对象,已设置 current_workspace_id
"""
from app.services import api_key_service
api_key = api_key_service.ApiKeyService.get_api_key(
db, api_key_auth.api_key_id, api_key_auth.workspace_id
)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return current_user
def validate_end_user_in_workspace(
db: _Session,
end_user_id: str,
workspace_id,
) -> _EndUser:
"""校验 end_user 是否存在且属于指定 workspace。
Args:
db: 数据库会话
end_user_id: 终端用户 ID
workspace_id: 工作空间 IDUUID 或字符串均可)
Returns:
EndUser ORM 对象(校验通过时)
Raises:
BusinessException(INVALID_PARAMETER): end_user_id 格式无效
BusinessException(USER_NOT_FOUND): end_user 不存在
BusinessException(PERMISSION_DENIED): end_user 不属于该 workspace
"""
try:
_uuid.UUID(end_user_id)
except (ValueError, AttributeError):
raise _BusinessException(
f"Invalid end_user_id format: {end_user_id}",
_BizCode.INVALID_PARAMETER,
)
end_user_repo = _EndUserRepository(db)
end_user = end_user_repo.get_end_user_by_id(end_user_id)
if end_user is None:
raise _BusinessException(
"End user not found",
_BizCode.USER_NOT_FOUND,
)
if str(end_user.workspace_id) != str(workspace_id):
raise _BusinessException(
"End user does not belong to this workspace",
_BizCode.PERMISSION_DENIED,
)
return end_user

View File

@@ -1,7 +1,6 @@
import json
import os
from app.celery_task_scheduler import scheduler
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
@@ -13,6 +12,8 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from app.utils.config_utils import resolve_config_id
logger = get_agent_logger(__name__)
@@ -85,28 +86,16 @@ async def write(
logger.info(
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
# write_id = write_message_task.delay(
# actual_end_user_id, # end_user_id: User ID
# structured_messages, # message: JSON string format message list
# str(actual_config_id), # config_id: Configuration ID string
# storage_type, # storage_type: "neo4j"
# user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
# )
scheduler.push_task(
"app.core.memory.agent.write_message",
str(actual_end_user_id),
{
"end_user_id": str(actual_end_user_id),
"message": structured_messages,
"config_id": str(actual_config_id),
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id or ""
}
write_id = write_message_task.delay(
actual_end_user_id, # end_user_id: User ID
structured_messages, # message: JSON string format message list
str(actual_config_id), # config_id: Configuration ID string
storage_type, # storage_type: "neo4j"
user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
)
# logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
# write_status = get_task_memory_write_result(str(write_id))
# logger.info(f'[WRITE] Task result - user={actual_end_user_id}')
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
async def term_memory_save(end_user_id, strategy_type, scope):
@@ -175,24 +164,13 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
else:
config_id = memory_config
scheduler.push_task(
"app.core.memory.agent.write_message",
str(end_user_id),
{
"end_user_id": str(end_user_id),
"message": redis_messages,
"config_id": str(config_id),
"storage_type": AgentMemory_Long_Term.STORAGE_NEO4J,
"user_rag_memory_id": ""
}
write_message_task.delay(
end_user_id, # end_user_id: User ID
redis_messages, # message: JSON string format message list
config_id, # config_id: Configuration ID string
AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
"" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
)
# write_message_task.delay(
# end_user_id, # end_user_id: User ID
# redis_messages, # message: JSON string format message list
# config_id, # config_id: Configuration ID string
# AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
# "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
# )
count_store.update_sessions_count(end_user_id, 0, [])

View File

@@ -1,8 +1,8 @@
from app.core.memory.enums import SearchStrategy, StorageType
from app.core.memory.models.service_models import MemorySearchResult
from app.core.memory.pipelines.base_pipeline import ModelClientMixin, DBRequiredPipeline
from app.core.memory.read_services.search_engine.content_search import Neo4jSearchService, RAGSearchService
from app.core.memory.read_services.generate_engine.query_preprocessor import QueryPreprocessor
from app.core.memory.read_services.content_search import Neo4jSearchService, RAGSearchService
from app.core.memory.read_services.query_preprocessor import QueryPreprocessor
class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):

View File

@@ -8,7 +8,7 @@ from neo4j import Session
from app.core.memory.enums import Neo4jNodeType
from app.core.memory.memory_service import MemoryContext
from app.core.memory.models.service_models import Memory, MemorySearchResult
from app.core.memory.read_services.search_engine.result_builder import data_builder_factory
from app.core.memory.read_services.result_builder import data_builder_factory
from app.core.models import RedBearEmbeddings
from app.core.rag.nlp.search import knowledge_retrieval
from app.repositories import knowledge_repository

View File

@@ -73,7 +73,6 @@ class CustomTool(BaseTool):
# 添加通用参数(基于第一个操作的参数)
if self._parsed_operations:
first_operation = next(iter(self._parsed_operations.values()))
# path/query 参数
for param_name, param_info in first_operation.get("parameters", {}).items():
params.append(ToolParameter(
name=param_name,
@@ -86,23 +85,6 @@ class CustomTool(BaseTool):
maximum=param_info.get("maximum"),
pattern=param_info.get("pattern")
))
# requestBody 参数 — 将 body 字段平铺为独立参数暴露给模型
request_body = first_operation.get("request_body")
if request_body:
body_schema = request_body.get("properties", {})
required_fields = request_body.get("required", [])
for prop_name, prop_schema in body_schema.items():
params.append(ToolParameter(
name=prop_name,
type=self._convert_openapi_type(prop_schema.get("type", "string")),
description=prop_schema.get("description", ""),
required=prop_name in required_fields,
default=prop_schema.get("default"),
enum=prop_schema.get("enum"),
minimum=prop_schema.get("minimum"),
maximum=prop_schema.get("maximum"),
pattern=prop_schema.get("pattern")
))
return params

View File

@@ -16,7 +16,6 @@ from app.core.workflow.engine.runtime_schema import ExecutionContext
from app.core.workflow.engine.state_manager import WorkflowStateManager
from app.core.workflow.engine.stream_output_coordinator import StreamOutputCoordinator
from app.core.workflow.engine.variable_pool import VariablePool, VariablePoolInitializer
from app.core.workflow.nodes.base_node import NodeExecutionError
logger = logging.getLogger(__name__)
@@ -327,43 +326,10 @@ class WorkflowExecutor:
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
exc_info=True)
# 1) 尝试从 checkpoint 回补已成功节点的 node_outputs
recovered: dict[str, Any] = {}
try:
if self.graph is not None:
recovered = self.graph.get_state(
self.execution_context.checkpoint_config
).values or {}
except Exception as recover_err:
logger.warning(
f"Recover state on failure failed: {recover_err}, "
f"execution_id={self.execution_context.execution_id}"
)
if result is None:
result = dict(recovered) if recovered else {}
result = {"error": str(e)}
else:
# 已有 result 与 recovered 合并node_outputs 深度合并
for k, v in recovered.items():
if k == "node_outputs" and isinstance(v, dict):
existing = result.get("node_outputs") or {}
result["node_outputs"] = {**v, **existing}
else:
result.setdefault(k, v)
# 2) 如果是节点抛出的 NodeExecutionError把失败节点的 node_output 注入 node_outputs
failed_node_id: str | None = None
if isinstance(e, NodeExecutionError):
failed_node_id = e.node_id
node_outputs = result.setdefault("node_outputs", {})
# 不覆盖已有(理论上不会有),保底写入失败节点记录
node_outputs.setdefault(e.node_id, e.node_output)
result["error"] = str(e)
if failed_node_id:
result["error_node"] = failed_node_id
yield {
"event": "workflow_end",
"data": self.result_builder.build_final_output(

View File

@@ -1,6 +1,5 @@
import asyncio
import logging
import time
import uuid
from abc import ABC, abstractmethod
from datetime import datetime
@@ -23,20 +22,6 @@ from app.services.multimodal_service import MultimodalService
logger = logging.getLogger(__name__)
class NodeExecutionError(Exception):
"""节点执行失败异常。
携带失败节点的完整 node_output供 executor 兜底注入 node_outputs
保证 workflow_executions.output_data 里能看到失败节点的日志记录。
"""
def __init__(self, node_id: str, node_output: dict[str, Any], error_message: str):
super().__init__(f"Node {node_id} execution failed: {error_message}")
self.node_id = node_id
self.node_output = node_output
self.error_message = error_message
class BaseNode(ABC):
"""Base class for workflow nodes.
@@ -411,8 +396,6 @@ class BaseNode(ABC):
"elapsed_time": elapsed_time,
"token_usage": token_usage,
"error": None,
# 单调递增序号用于日志按执行顺序排序JSONB 不保证 key 顺序)
"execution_order": time.monotonic_ns(),
**self._extract_extra_fields(business_result),
}
final_output = {
@@ -461,9 +444,7 @@ class BaseNode(ABC):
"output": None,
"elapsed_time": elapsed_time,
"token_usage": None,
"error": error_message,
# 单调递增序号,用于日志按执行顺序排序
"execution_order": time.monotonic_ns(),
"error": error_message
}
# if error_edge:
@@ -485,12 +466,7 @@ class BaseNode(ABC):
**node_output
})
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
# 抛出自定义异常,把 node_output 带给 executor供其写入 node_outputs
raise NodeExecutionError(
node_id=self.node_id,
node_output=node_output,
error_message=error_message,
)
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""Extracts the input data for this node (used for logging or audit).

View File

@@ -174,18 +174,12 @@ class IterationRuntime:
continue
node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type")
cycle_variable = {"item": item} if node_type == NodeType.CYCLE_START else None
node_cfg = next(
(n for n in self.cycle_nodes if n.get("id") == node_name), None
)
self.event_write({
"type": "cycle_item",
"data": {
"cycle_id": self.node_id,
"cycle_idx": idx,
"node_id": node_name,
"node_type": node_type,
"node_name": node_cfg.get("data", {}).get("label") if node_cfg else node_name,
"status": result.get("node_outputs", {}).get(node_name, {}).get("status", "completed"),
"input": result.get("node_outputs", {}).get(node_name, {}).get("input")
if not cycle_variable else cycle_variable,
"output": result.get("node_outputs", {}).get(node_name, {}).get("output")

View File

@@ -210,9 +210,6 @@ class LoopRuntime:
"cycle_id": self.node_id,
"cycle_idx": idx,
"node_id": node_name,
"node_type": node_type,
"node_name": node_name,
"status": result.get("node_outputs", {}).get(node_name, {}).get("status", "completed"),
"input": result.get("node_outputs", {}).get(node_name, {}).get("input")
if not cycle_variable else cycle_variable,
"output": result.get("node_outputs", {}).get(node_name, {}).get("output")

View File

@@ -272,11 +272,6 @@ class HttpRequestNodeOutput(BaseModel):
description="HTTP response body",
)
process_data: dict = Field(
default_factory=dict,
description="Raw HTTP request details for debugging",
)
# files: list[File] = Field(
# ...
# )

View File

@@ -160,6 +160,7 @@ class HttpRequestNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: HttpRequestNodeConfig | None = None
self.last_request: str = ""
def _output_types(self) -> dict[str, VariableType]:
return {
@@ -170,6 +171,47 @@ class HttpRequestNode(BaseNode):
"output": VariableType.STRING
}
def _extract_output(self, business_result: Any) -> Any:
if isinstance(business_result, dict):
result = {k: v for k, v in business_result.items() if k != "request"}
return result
return business_result
def _extract_extra_fields(self, business_result: Any) -> dict[str, Any]:
if isinstance(business_result, dict) and "request" in business_result:
return {
"process": {
"request": business_result.get("request", "")
}
}
return {}
def _wrap_error(
self,
error_message: str,
elapsed_time: float,
state: WorkflowState,
variable_pool: VariablePool
) -> dict[str, Any]:
input_data = self._extract_input(state, variable_pool)
node_output = {
"node_id": self.node_id,
"node_type": self.node_type,
"node_name": self.node_name,
"status": "failed",
"input": input_data,
"output": None,
"process": {"request": self.last_request} if self.last_request else None,
"elapsed_time": elapsed_time,
"token_usage": None,
"error": error_message
}
return {
"node_outputs": {self.node_id: node_output},
"error": error_message,
"error_node": self.node_id
}
def _build_timeout(self) -> Timeout:
"""
Build httpx Timeout configuration.
@@ -255,18 +297,13 @@ class HttpRequestNode(BaseNode):
case HttpContentType.NONE:
return {}
case HttpContentType.JSON:
rendered = self._render_template(
rendered_body = self._render_template(
self.typed_config.body.data, variable_pool
)
if not rendered or not rendered.strip():
# 第三方导入的工作流可能出现 content_type=json 但 data 为空的情况,视为无 body
return {}
try:
content["json"] = json.loads(rendered)
except json.JSONDecodeError as e:
raise RuntimeError(
f"Invalid JSON body for HTTP request node: {e.msg} (data={rendered!r})"
)
).strip()
if not rendered_body:
content["json"] = {}
else:
content["json"] = json.loads(rendered_body)
case HttpContentType.FROM_DATA:
data = {}
files = []
@@ -334,15 +371,61 @@ class HttpRequestNode(BaseNode):
case _:
raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}")
def _extract_output(self, business_result: Any) -> Any:
if isinstance(business_result, dict):
return {k: v for k, v in business_result.items() if k != "process_data"}
return business_result
def _generate_raw_request(
self,
variable_pool: VariablePool,
url: str,
headers: dict[str, str],
params: dict[str, str],
content: dict[str, Any]
) -> str:
"""
Generate raw HTTP request format for debugging.
def _extract_extra_fields(self, business_result: Any) -> dict:
if isinstance(business_result, dict) and "process_data" in business_result:
return {"process": business_result["process_data"]}
return {}
Args:
variable_pool: Variable Pool
url: Rendered URL
headers: Request headers
params: Query parameters
content: Request body content
Returns:
Raw HTTP request string
"""
method = self.typed_config.method.value
if params:
param_str = "&".join([f"{k}={v}" for k, v in params.items()])
full_url = f"{url}?{param_str}" if "?" not in url else f"{url}&{param_str}"
else:
full_url = url
lines = [f"{method} {full_url} HTTP/1.1"]
for key, value in headers.items():
lines.append(f"{key}: {value}")
if "json" in content and content["json"]:
json_body = json.dumps(content["json"], ensure_ascii=False)
lines.append(f"Content-Length: {len(json_body)}")
lines.append("")
lines.append(json_body)
elif "data" in content and "files" not in content:
if isinstance(content["data"], dict):
body_str = "&".join([f"{k}={v}" for k, v in content["data"].items()])
lines.append(f"Content-Length: {len(body_str)}")
lines.append("")
lines.append(body_str)
elif "content" in content:
lines.append(f"Content-Length: {len(content['content'])}")
lines.append("")
lines.append(content["content"])
elif "files" in content:
lines.append("Content-Length: 0")
lines.append("")
lines.append("# Note: This request includes file uploads")
return "\r\n".join(lines)
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict | str:
"""
@@ -362,42 +445,47 @@ class HttpRequestNode(BaseNode):
- str: Branch identifier (e.g. "ERROR") when branching is enabled
"""
self.typed_config = HttpRequestNodeConfig(**self.config)
rendered_url = self._render_template(self.typed_config.url, variable_pool)
built_headers = self._build_header(variable_pool) | self._build_auth(variable_pool)
built_params = self._build_params(variable_pool)
# Build request components
headers = self._build_header(variable_pool) | self._build_auth(variable_pool)
params = self._build_params(variable_pool)
content = await self._build_content(variable_pool)
url = self._render_template(self.typed_config.url, variable_pool)
logger.info(f"Node {self.node_id}: headers={headers}, params={params}, content keys={list(content.keys())}")
# Generate raw HTTP request for debugging
raw_request = self._generate_raw_request(variable_pool, url, headers, params, content)
self.last_request = raw_request
logger.info(f"Node {self.node_id}: Generated HTTP request:\n{raw_request}")
async with httpx.AsyncClient(
verify=self.typed_config.verify_ssl,
timeout=self._build_timeout(),
headers=built_headers,
params=built_params,
headers=headers,
params=params,
follow_redirects=True
) as client:
retries = self.typed_config.retry.max_attempts
while retries > 0:
try:
request_func = self._get_client_method(client)
built_content = await self._build_content(variable_pool)
resp = await request_func(
url=rendered_url,
**built_content
url=url,
**content
)
resp.raise_for_status()
logger.info(f"Node {self.node_id}: HTTP request succeeded")
response = HttpResponse(resp)
# Build raw request summary for process_data
raw_request = (
f"{self.typed_config.method.upper()} {resp.request.url} HTTP/1.1\r\n"
+ "".join(f"{k}: {v}\r\n" for k, v in resp.request.headers.items())
+ "\r\n"
+ (resp.request.content.decode(errors="replace") if resp.request.content else "")
)
return HttpRequestNodeOutput(
return {
**HttpRequestNodeOutput(
body=response.body,
status_code=resp.status_code,
headers=resp.headers,
files=response.files,
process_data={"request": raw_request},
).model_dump()
files=response.files
).model_dump(),
"request": raw_request
}
except (httpx.HTTPStatusError, httpx.RequestError) as e:
logger.error(f"HTTP request node exception: {e}")
retries -= 1
@@ -413,10 +501,19 @@ class HttpRequestNode(BaseNode):
logger.warning(
f"Node {self.node_id}: HTTP request failed, returning default result"
)
return self.typed_config.error_handle.default.model_dump()
error_result = self.typed_config.error_handle.default.model_dump()
error_result["request"] = raw_request
return error_result
case HttpErrorHandle.BRANCH:
logger.warning(
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
)
return {"output": "ERROR"}
return {
"output": "ERROR",
"body": "",
"status_code": 500,
"headers": {},
"files": [],
"request": raw_request
}
raise RuntimeError("http request failed")

View File

@@ -1,7 +1,6 @@
import re
from typing import Any
from app.celery_task_scheduler import scheduler
from app.core.memory.enums import SearchStrategy
from app.core.memory.memory_service import MemoryService
from app.core.workflow.engine.state_manager import WorkflowState
@@ -12,6 +11,7 @@ from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
from app.db import get_db_read
from app.schemas import FileInput
from app.tasks import write_message_task
class MemoryReadNode(BaseNode):
@@ -126,23 +126,12 @@ class MemoryWriteNode(BaseNode):
"files": file_info
})
scheduler.push_task(
"app.core.memory.agent.write_message",
end_user_id,
{
"end_user_id": end_user_id,
"message": messages,
"config_id": str(self.typed_config.config_id),
"storage_type": state["memory_storage_type"],
"user_rag_memory_id": state["user_rag_memory_id"]
}
write_message_task.delay(
end_user_id=end_user_id,
message=messages,
config_id=str(self.typed_config.config_id),
storage_type=state["memory_storage_type"],
user_rag_memory_id=state["user_rag_memory_id"]
)
# write_message_task.delay(
# end_user_id=end_user_id,
# message=messages,
# config_id=str(self.typed_config.config_id),
# storage_type=state["memory_storage_type"],
# user_rag_memory_id=state["user_rag_memory_id"]
# )
return "success"

View File

@@ -1,15 +1,13 @@
import uuid
from typing import Optional
from sqlalchemy import select, desc, func, or_, cast, Text
from sqlalchemy import select, desc, func
from sqlalchemy.orm import Session
from app.core.exceptions import ResourceNotFoundException
from app.core.logging_config import get_db_logger
from app.models import Conversation, Message
from app.models.app_model import AppType
from app.models.conversation_model import ConversationDetail
from app.models.workflow_model import WorkflowExecution
logger = get_db_logger()
@@ -208,8 +206,7 @@ class ConversationRepository:
is_draft: Optional[bool] = None,
keyword: Optional[str] = None,
page: int = 1,
pagesize: int = 20,
app_type: Optional[str] = None,
pagesize: int = 20
) -> tuple[list[Conversation], int]:
"""
查询应用日志会话列表(带分页和过滤)
@@ -221,9 +218,6 @@ class ConversationRepository:
keyword: 搜索关键词(匹配消息内容)
page: 页码(从 1 开始)
pagesize: 每页数量
app_type: 应用类型。WORKFLOW 类型改用 workflow_executions 的
input_data/output_data 做关键词过滤(因为失败的工作流不会写入 messages 表);
其他类型仍走 messages 表。
Returns:
Tuple[List[Conversation], int]: (会话列表,总数)
@@ -240,26 +234,10 @@ class ConversationRepository:
# 如果有关键词搜索,通过子查询过滤包含该关键词的 conversation
if keyword:
kw_pattern = f"%{keyword}%"
if app_type == AppType.WORKFLOW:
# 工作流:从 workflow_executions 的 input_data / output_data 匹配
# messages 表只存开场白 assistant 消息,失败的工作流也不会写入)
keyword_stmt = (
select(WorkflowExecution.conversation_id)
.where(
WorkflowExecution.conversation_id.is_not(None),
or_(
cast(WorkflowExecution.input_data, Text).ilike(kw_pattern),
cast(WorkflowExecution.output_data, Text).ilike(kw_pattern),
),
)
.distinct()
)
else:
# Agent 等其他类型:仍走 messages 表user + assistant 内容)
# 查找包含关键词的 conversation_id 列表
keyword_stmt = (
select(Message.conversation_id)
.where(Message.content.ilike(kw_pattern))
.where(Message.content.ilike(f"%{keyword}%"))
.distinct()
)
base_stmt = base_stmt.where(Conversation.id.in_(keyword_stmt))

View File

@@ -14,7 +14,6 @@ class AppLogMessage(BaseModel):
conversation_id: uuid.UUID
role: str = Field(description="角色: user / assistant / system")
content: str
status: Optional[str] = Field(default=None, description="执行状态(工作流专用): completed / failed")
meta_data: Optional[Dict[str, Any]] = None
created_at: datetime.datetime
@@ -59,7 +58,6 @@ class AppLogNodeExecution(BaseModel):
input: Optional[Any] = None
process: Optional[Any] = None
output: Optional[Any] = None
cycle_items: Optional[List[Any]] = None
elapsed_time: Optional[float] = None
token_usage: Optional[Dict[str, Any]] = None

View File

@@ -112,12 +112,12 @@ class MemoryWriteResponse(BaseModel):
"""Response schema for memory write operation.
Attributes:
task_id: task ID for status polling
status: Initial task status (QUEUED)
task_id: Celery task ID for status polling
status: Initial task status (PENDING)
end_user_id: End user ID the write was submitted for
"""
task_id: str = Field(..., description="task ID for polling")
status: str = Field(..., description="Task status: QUEUED")
task_id: str = Field(..., description="Celery task ID for polling")
status: str = Field(..., description="Task status: PENDING")
end_user_id: str = Field(..., description="End user ID")

View File

@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from sqlalchemy import select
from app.aioRedis import aio_redis
from app.models.api_key_model import ApiKey, ApiKeyType
from app.models.api_key_model import ApiKey
from app.repositories.api_key_repository import ApiKeyRepository, ApiKeyLogRepository
from app.schemas import api_key_schema
from app.schemas.response_schema import PageData, PageMeta
@@ -65,12 +65,6 @@ class ApiKeyService:
BizCode.BAD_REQUEST
)
# SERVICE 类型的 resource_id 指向 workspace非应用跳过应用发布校验
if data.resource_id and data.type != ApiKeyType.SERVICE.value:
app = db.get(App, data.resource_id)
if not app or not app.current_release_id:
raise BusinessException("该应用未发布", BizCode.APP_NOT_PUBLISHED)
# 生成 API Key
api_key = generate_api_key(data.type)
@@ -453,12 +447,9 @@ class ApiKeyAuthService:
def check_app_published(db: Session, api_key_obj: ApiKey) -> None:
"""
检查应用是否已发布,未发布则抛出异常
SERVICE 类型的 api_key 不绑定应用resource_id 指向 workspace跳过校验
"""
if not api_key_obj.resource_id:
return
if api_key_obj.type == ApiKeyType.SERVICE.value:
return
app = db.get(App, api_key_obj.resource_id)
if not app or not app.current_release_id:
raise BusinessException("应用未发布,不可用", BizCode.APP_NOT_PUBLISHED)

View File

@@ -107,6 +107,23 @@ class AppChatService:
# 获取模型参数
model_parameters = config.model_parameters
# 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_obj.model_name,
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
is_omni=api_key_obj.is_omni,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
deep_thinking=model_parameters.get("deep_thinking", False),
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
json_output=model_parameters.get("json_output", False),
capability=api_key_obj.capability or [],
)
model_info = ModelInfo(
model_name=api_key_obj.model_name,
provider=api_key_obj.provider,
@@ -160,27 +177,16 @@ class AppChatService:
if doc_img_recognition and "vision" in (api_key_obj.capability or []) and any(
f.type == FileType.DOCUMENT for f in files
):
system_prompt += (
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
from langchain.agents import create_agent
agent.system_prompt += (
"\n\n文档中包含图片,图片位置已在文本中以 [第N页 第M张图片]: URL 标记。"
"请在回答中用 Markdown 格式 ![描述](URL) 展示相关图片,做到图文并茂。"
)
# 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_obj.model_name,
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
is_omni=api_key_obj.is_omni,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
deep_thinking=model_parameters.get("deep_thinking", False),
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
json_output=model_parameters.get("json_output", False),
capability=api_key_obj.capability or [],
agent.agent = create_agent(
model=agent.llm,
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
system_prompt=agent.system_prompt
)
# 为需要运行时上下文的工具注入上下文
for t in tools:
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
@@ -393,6 +399,24 @@ class AppChatService:
# 获取模型参数
model_parameters = config.model_parameters
# 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_obj.model_name,
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
is_omni=api_key_obj.is_omni,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
streaming=True,
deep_thinking=model_parameters.get("deep_thinking", False),
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
json_output=model_parameters.get("json_output", False),
capability=api_key_obj.capability or [],
)
model_info = ModelInfo(
model_name=api_key_obj.model_name,
provider=api_key_obj.provider,
@@ -447,26 +471,14 @@ class AppChatService:
f.type == FileType.DOCUMENT for f in files
):
from langchain.agents import create_agent
system_prompt += (
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片"
agent.system_prompt += (
"\n\n文档中包含图片,图片位置已在文本中以 [第N页 第M张图片]: URL 标记"
"请在回答中用 Markdown 格式 ![描述](URL) 展示相关图片,做到图文并茂。"
)
# 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_obj.model_name,
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
is_omni=api_key_obj.is_omni,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
streaming=True,
deep_thinking=model_parameters.get("deep_thinking", False),
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
json_output=model_parameters.get("json_output", False),
capability=api_key_obj.capability or [],
agent.agent = create_agent(
model=agent.llm,
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
system_prompt=agent.system_prompt
)
# 为需要运行时上下文的工具注入上下文

View File

@@ -1,17 +1,16 @@
"""应用日志服务层"""
import uuid
import datetime as dt
from typing import Optional, Tuple
from datetime import datetime
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.logging_config import get_business_logger
from app.models.app_model import AppType
from app.models.conversation_model import Conversation, Message
from app.models.workflow_model import WorkflowExecution
from app.repositories.conversation_repository import ConversationRepository, MessageRepository
from app.schemas.app_log_schema import AppLogMessage, AppLogNodeExecution
from app.schemas.app_log_schema import AppLogNodeExecution
logger = get_business_logger()
@@ -32,7 +31,6 @@ class AppLogService:
pagesize: int = 20,
is_draft: Optional[bool] = None,
keyword: Optional[str] = None,
app_type: Optional[str] = None,
) -> Tuple[list[Conversation], int]:
"""
查询应用日志会话列表
@@ -44,7 +42,6 @@ class AppLogService:
pagesize: 每页数量
is_draft: 是否草稿会话None表示返回全部
keyword: 搜索关键词(匹配消息内容)
app_type: 应用类型WORKFLOW 时关键词将从 workflow_executions 搜索)
Returns:
Tuple[list[Conversation], int]: (会话列表,总数)
@@ -57,8 +54,7 @@ class AppLogService:
"page": page,
"pagesize": pagesize,
"is_draft": is_draft,
"keyword": keyword,
"app_type": app_type,
"keyword": keyword
}
)
@@ -69,8 +65,7 @@ class AppLogService:
is_draft=is_draft,
keyword=keyword,
page=page,
pagesize=pagesize,
app_type=app_type,
pagesize=pagesize
)
logger.info(
@@ -88,38 +83,49 @@ class AppLogService:
self,
app_id: uuid.UUID,
conversation_id: uuid.UUID,
workspace_id: uuid.UUID,
app_type: str = AppType.AGENT
) -> Tuple[Conversation, list, dict[str, list[AppLogNodeExecution]]]:
workspace_id: uuid.UUID
) -> Tuple[Conversation, dict[str, list[AppLogNodeExecution]]]:
"""
查询会话详情
查询会话详情(包含消息和工作流节点执行记录)
Args:
app_id: 应用 ID
conversation_id: 会话 ID
workspace_id: 工作空间 ID
Returns:
Tuple[Conversation, list[AppLogMessage|Message], dict[str, list[AppLogNodeExecution]]]
Tuple[Conversation, dict[str, list[AppLogNodeExecution]]]:
(包含消息的会话对象, 按消息ID分组的节点执行记录)
Raises:
ResourceNotFoundException: 当会话不存在时
"""
logger.info(
"查询应用日志会话详情",
extra={
"app_id": str(app_id),
"conversation_id": str(conversation_id),
"workspace_id": str(workspace_id),
"app_type": app_type
"workspace_id": str(workspace_id)
}
)
# 查询会话
conversation = self.conversation_repository.get_conversation_for_app_log(
conversation_id=conversation_id,
app_id=app_id,
workspace_id=workspace_id
)
if app_type == AppType.WORKFLOW:
messages, node_executions_map = self._get_workflow_messages_and_nodes(conversation_id)
else:
# 查询消息(按时间正序)
messages = self.message_repository.get_messages_by_conversation(
conversation_id=conversation_id
)
node_executions_map = self._get_workflow_node_executions_with_map(
# 将消息附加到会话对象
conversation.messages = messages
# 查询工作流节点执行记录(按消息分组)
_, node_executions_map = self._get_workflow_node_executions_with_map(
conversation_id, messages
)
@@ -133,129 +139,13 @@ class AppLogService:
}
)
return conversation, messages, node_executions_map
def _get_workflow_messages_and_nodes(
self,
conversation_id: uuid.UUID,
) -> Tuple[list[AppLogMessage], dict[str, list[AppLogNodeExecution]]]:
"""
工作流应用专用:从 workflow_executions 构建 messages 和节点日志。
每条 WorkflowExecution 对应一轮对话:
- user message来自 execution.input_datacontent 取 message 字段files 放 meta_data
- assistant message来自 execution.output_data失败时内容为错误信息
开场白的 suggested_questions 合并到第一条 assistant message 的 meta_data 里。
Returns:
(messages 列表, node_executions_map)
"""
stmt = (
select(WorkflowExecution)
.where(
WorkflowExecution.conversation_id == conversation_id,
WorkflowExecution.status.in_(["completed", "failed"])
)
.order_by(WorkflowExecution.started_at.asc())
)
executions = list(self.db.scalars(stmt).all())
# 查开场白Message 表里 meta_data 含 suggested_questions 的第一条 assistant 消息
opening_stmt = (
select(Message)
.where(
Message.conversation_id == conversation_id,
Message.role == "assistant",
)
.order_by(Message.created_at.asc())
.limit(10)
)
early_messages = list(self.db.scalars(opening_stmt).all())
suggested_questions: list = []
for m in early_messages:
if isinstance(m.meta_data, dict) and "suggested_questions" in m.meta_data:
suggested_questions = m.meta_data.get("suggested_questions") or []
break
messages: list[AppLogMessage] = []
node_executions_map: dict[str, list[AppLogNodeExecution]] = {}
# 如果有开场白,作为第一条 assistant 消息插入
if suggested_questions or early_messages:
opening_msg = next(
(m for m in early_messages
if isinstance(m.meta_data, dict) and "suggested_questions" in m.meta_data),
None
)
if opening_msg:
messages.append(AppLogMessage(
id=opening_msg.id,
conversation_id=conversation_id,
role="assistant",
content=opening_msg.content,
status=None,
meta_data={"suggested_questions": suggested_questions},
created_at=opening_msg.created_at,
))
for execution in executions:
started_at = execution.started_at or dt.datetime.now()
completed_at = execution.completed_at or started_at
# assistant message 的 id同时作为 node_executions_map 的 key
assistant_msg_id = uuid.uuid5(execution.id, "assistant")
# --- user message输入---
input_data = execution.input_data or {}
input_content = input_data.get("message") or _extract_text(input_data)
# 跳过没有用户输入的 execution如开场白触发的记录
if not input_content or not input_content.strip():
continue
files = input_data.get("files") or []
user_msg = AppLogMessage(
id=uuid.uuid5(execution.id, "user"),
conversation_id=conversation_id,
role="user",
content=input_content,
meta_data={"files": files} if files else None,
created_at=started_at,
)
messages.append(user_msg)
# --- assistant message输出---
if execution.status == "completed":
output_content = _extract_text(execution.output_data)
meta = {"usage": execution.token_usage or {}, "elapsed_time": execution.elapsed_time}
else:
output_content = _extract_text(execution.output_data) or ""
meta = {"error": execution.error_message, "error_node_id": execution.error_node_id}
assistant_msg = AppLogMessage(
id=assistant_msg_id,
conversation_id=conversation_id,
role="assistant",
content=output_content,
status=execution.status,
meta_data=meta,
created_at=completed_at,
)
messages.append(assistant_msg)
# --- 节点执行记录,从 workflow_executions.output_data["node_outputs"] 读取 ---
execution_nodes = _build_nodes_from_output_data(execution.output_data)
if execution_nodes:
node_executions_map[str(assistant_msg_id)] = execution_nodes
return messages, node_executions_map
return conversation, node_executions_map
def _get_workflow_node_executions_with_map(
self,
conversation_id: uuid.UUID,
messages: list[Message]
) -> dict[str, list[AppLogNodeExecution]]:
) -> Tuple[list[AppLogNodeExecution], dict[str, list[AppLogNodeExecution]]]:
"""
从 workflow_executions 表中提取节点执行记录,并按 assistant message 分组
@@ -267,12 +157,13 @@ class AppLogService:
Tuple[list[AppLogNodeExecution], dict[str, list[AppLogNodeExecution]]]:
(所有节点执行记录列表, 按 message_id 分组的节点执行记录字典)
"""
node_executions = []
node_executions_map: dict[str, list[AppLogNodeExecution]] = {}
# 查询该会话关联的所有工作流执行记录(按时间正序)
stmt = select(WorkflowExecution).where(
WorkflowExecution.conversation_id == conversation_id,
WorkflowExecution.status.in_(["completed", "failed"])
WorkflowExecution.status == "completed"
).order_by(WorkflowExecution.started_at.asc())
executions = self.db.scalars(stmt).all()
@@ -297,18 +188,10 @@ class AppLogService:
used_message_ids: set[str] = set()
for execution in executions:
# 构建节点执行记录列表,从 workflow_executions.output_data["node_outputs"] 读取
execution_nodes = _build_nodes_from_output_data(execution.output_data)
if not execution_nodes:
if not execution.output_data:
continue
# 失败的执行没有 assistant message,直接用 execution id 作为 key
if execution.status == "failed":
node_executions_map[f"execution_{str(execution.id)}"] = execution_nodes
continue
# completed通过时序匹配关联到对应的 assistant message
# 找到该 execution 对应的 assistant message
# 逻辑:找 execution.started_at 之后最近的、未使用的 assistant message
best_msg = None
best_dt = None
@@ -317,9 +200,9 @@ class AppLogService:
if msg_id_str in used_message_ids:
continue
if msg.created_at and msg.created_at >= execution.started_at:
delta = (msg.created_at - execution.started_at).total_seconds()
if best_dt is None or delta < best_dt:
best_dt = delta
dt = (msg.created_at - execution.started_at).total_seconds()
if best_dt is None or dt < best_dt:
best_dt = dt
best_msg = msg
if not best_msg:
@@ -327,86 +210,31 @@ class AppLogService:
msg_id_str = str(best_msg.id)
used_message_ids.add(msg_id_str)
node_executions_map[msg_id_str] = execution_nodes
return node_executions_map
def _extract_text(data: Optional[dict]) -> str:
"""从 workflow execution 的 input_data / output_data 中提取可读文本。
优先取 'text''content''output' 字段;若都没有则 JSON 序列化整个 dict。
"""
if not data:
return ""
for key in ("message", "text", "content", "output", "result", "answer"):
if key in data and isinstance(data[key], str):
return data[key]
import json
return json.dumps(data, ensure_ascii=False)
def _build_nodes_from_output_data(output_data: Optional[dict]) -> list[AppLogNodeExecution]:
"""从 workflow_executions.output_data["node_outputs"] 构建节点执行记录列表。
output_data 结构:
{
"node_outputs": {
"<node_id>": {
"node_type": ...,
"node_name": ...,
"status": ...,
"input": ...,
"output": ...,
"elapsed_time": ...,
"token_usage": ...,
"error": ...,
"cycle_items": [...],
...
}
},
"error": ...,
...
}
"""
if not output_data:
return []
node_outputs: dict = output_data.get("node_outputs") or {}
# 按 execution_order节点执行时写入的单调递增序号排序。
# PostgreSQL JSONB 不保证 key 顺序,不能依赖 dict 插入顺序;
# 缺失 execution_order 的历史数据退化到 0保持在最前。
ordered_items = sorted(
node_outputs.items(),
key=lambda kv: (kv[1] or {}).get("execution_order", 0)
if isinstance(kv[1], dict) else 0
)
result = []
for node_id, node_data in ordered_items:
# 提取节点输出
output_data = execution.output_data
if isinstance(output_data, dict):
node_outputs = output_data.get("node_outputs", {})
execution_nodes = []
for node_id, node_data in node_outputs.items():
if not isinstance(node_data, dict):
continue
output = dict(node_data)
cycle_items = output.pop("cycle_items", None)
# 把已知的顶层字段剥离,剩余的作为 output
node_type = output.pop("node_type", "unknown")
node_name = output.pop("node_name", None)
status = output.pop("status", "completed")
error = output.pop("error", None)
inp = output.pop("input", None)
elapsed_time = output.pop("elapsed_time", None)
token_usage = output.pop("token_usage", None)
# execution_order 仅用于排序,不返回给前端
output.pop("execution_order", None)
result.append(AppLogNodeExecution(
node_id=node_id,
node_type=node_type,
node_name=node_name,
status=status,
error=error,
input=inp,
process=None,
output=output if output else None,
cycle_items=cycle_items,
elapsed_time=elapsed_time,
token_usage=token_usage,
))
return result
node_execution = AppLogNodeExecution(
node_id=node_data.get("node_id", node_id),
node_type=node_data.get("node_type", "unknown"),
node_name=node_data.get("node_name"),
status=node_data.get("status", "unknown"),
error=node_data.get("error"),
input=node_data.get("input"),
process=node_data.get("process"),
output=node_data.get("output"),
elapsed_time=node_data.get("elapsed_time"),
token_usage=node_data.get("token_usage"),
)
node_executions.append(node_execution)
execution_nodes.append(node_execution)
# 将节点记录关联到 message_id
node_executions_map[msg_id_str] = execution_nodes
return node_executions, node_executions_map

View File

@@ -595,6 +595,23 @@ class AgentRunService:
)
tools.extend(memory_tools)
# 4. 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_config["model_name"],
api_key=api_key_config["api_key"],
provider=api_key_config.get("provider", "openai"),
api_base=api_key_config.get("api_base"),
is_omni=api_key_config.get("is_omni", False),
temperature=effective_params.get("temperature", 0.7),
max_tokens=effective_params.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
deep_thinking=effective_params.get("deep_thinking", False),
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
json_output=effective_params.get("json_output", False),
capability=api_key_config.get("capability", []),
)
# 5. 处理会话ID创建或验证新会话时写入开场白
is_new_conversation = not conversation_id
opening, suggested_questions = None, None
@@ -649,26 +666,16 @@ class AgentRunService:
and any(f.type == FileType.DOCUMENT for f in files)
)
if has_doc_with_images:
system_prompt += (
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片"
agent.system_prompt += (
"\n\n文档中包含图片,图片位置已在文本中以 [第N页 第M张图片]: URL 标记"
"请在回答中用 Markdown 格式 ![描述](URL) 展示相关图片,做到图文并茂。"
)
agent = LangChainAgent(
model_name=api_key_config["model_name"],
api_key=api_key_config["api_key"],
provider=api_key_config.get("provider", "openai"),
api_base=api_key_config.get("api_base"),
is_omni=api_key_config.get("is_omni", False),
temperature=effective_params.get("temperature", 0.7),
max_tokens=effective_params.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
deep_thinking=effective_params.get("deep_thinking", False),
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
json_output=effective_params.get("json_output", False),
capability=api_key_config.get("capability", []),
# 重建 agent graph 以使新 system_prompt 生效
agent.agent = create_agent(
model=agent.llm,
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
system_prompt=agent.system_prompt
)
# 为需要运行时上下文的工具注入上下文
for t in tools:
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
@@ -868,6 +875,24 @@ class AgentRunService:
user_rag_memory_id)
tools.extend(memory_tools)
# 4. 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_config["model_name"],
api_key=api_key_config["api_key"],
provider=api_key_config.get("provider", "openai"),
api_base=api_key_config.get("api_base"),
is_omni=api_key_config.get("is_omni", False),
temperature=effective_params.get("temperature", 0.7),
max_tokens=effective_params.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
streaming=True,
deep_thinking=effective_params.get("deep_thinking", False),
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
json_output=effective_params.get("json_output", False),
capability=api_key_config.get("capability", []),
)
# 5. 处理会话ID创建或验证新会话时写入开场白
is_new_conversation = not conversation_id
opening, suggested_questions = None, None
@@ -923,28 +948,18 @@ class AgentRunService:
and any(f.type == FileType.DOCUMENT for f in files)
)
if has_doc_with_images:
system_prompt += (
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片"
agent.system_prompt += (
"\n\n文档中包含图片,图片位置已在文本中以 [图片 第N页 第M张图片]: URL 标记"
"请在回答中用 Markdown 格式 ![描述](URL) 展示相关图片,做到图文并茂。"
"**规则1图片URL必须原封不动、一字不差地复制禁止修改、禁止省略任何字符**"
"**规则2禁止修改URL中UUID里的任何数字和字母**"
"**规则3直接使用 ![描述](完整URL) 格式输出**"
)
# 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_config["model_name"],
api_key=api_key_config["api_key"],
provider=api_key_config.get("provider", "openai"),
api_base=api_key_config.get("api_base"),
is_omni=api_key_config.get("is_omni", False),
temperature=effective_params.get("temperature", 0.7),
max_tokens=effective_params.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
streaming=True,
deep_thinking=effective_params.get("deep_thinking", False),
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
json_output=effective_params.get("json_output", False),
capability=api_key_config.get("capability", []),
agent.agent = create_agent(
model=agent.llm,
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
system_prompt=agent.system_prompt
)
# 为需要运行时上下文的工具注入上下文
for t in tools:
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):

View File

@@ -10,7 +10,6 @@ from typing import Any, Dict, Optional
from sqlalchemy.orm import Session
from app.celery_task_scheduler import scheduler
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.core.logging_config import get_logger
@@ -167,31 +166,20 @@ class MemoryAPIService:
# Convert to message list format expected by write_message_task
messages = message if isinstance(message, list) else [{"role": "user", "content": message}]
# from app.tasks import write_message_task
# task = write_message_task.delay(
# end_user_id,
# messages,
# config_id,
# storage_type,
# user_rag_memory_id or "",
# )
task_id = scheduler.push_task(
"app.core.memory.agent.write_message",
from app.tasks import write_message_task
task = write_message_task.delay(
end_user_id,
{
"end_user_id": end_user_id,
"message": messages,
"config_id": config_id,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id or ""
}
messages,
config_id,
storage_type,
user_rag_memory_id or "",
)
logger.info(f"Memory write task submitted, task_id={task_id} end_user_id={end_user_id}")
logger.info(f"Memory write task submitted: task_id={task.id}, end_user_id={end_user_id}")
return {
"task_id": task_id,
"status": "QUEUED",
"task_id": task.id,
"status": "PENDING",
"end_user_id": end_user_id,
}

View File

@@ -4,7 +4,7 @@
处理显性记忆相关的业务逻辑,包括情景记忆和语义记忆的查询。
"""
from typing import Any, Dict, Optional
from typing import Any, Dict
from app.core.logging_config import get_logger
from app.services.memory_base_service import MemoryBaseService
@@ -146,209 +146,6 @@ class MemoryExplicitService(MemoryBaseService):
logger.error(f"获取显性记忆总览时出错: {str(e)}", exc_info=True)
raise
async def get_episodic_memory_list(
self,
end_user_id: str,
page: int,
pagesize: int,
start_date: Optional[int] = None,
end_date: Optional[int] = None,
episodic_type: str = "all",
) -> Dict[str, Any]:
"""
获取情景记忆分页列表
Args:
end_user_id: 终端用户ID
page: 页码
pagesize: 每页数量
start_date: 开始时间戳(毫秒),可选
end_date: 结束时间戳(毫秒),可选
episodic_type: 情景类型筛选
Returns:
{
"total": int, # 该用户情景记忆总数(不受筛选影响)
"items": [...], # 当前页数据
"page": {
"page": int,
"pagesize": int,
"total": int, # 筛选后总数
"hasnext": bool
}
}
"""
try:
logger.info(
f"情景记忆分页查询: end_user_id={end_user_id}, "
f"start_date={start_date}, end_date={end_date}, "
f"episodic_type={episodic_type}, page={page}, pagesize={pagesize}"
)
# 1. 查询情景记忆总数(不受筛选条件限制)
total_all_query = """
MATCH (s:MemorySummary)
WHERE s.end_user_id = $end_user_id
RETURN count(s) AS total
"""
total_all_result = await self.neo4j_connector.execute_query(
total_all_query, end_user_id=end_user_id
)
total_all = total_all_result[0]["total"] if total_all_result else 0
# 2. 构建筛选条件
where_clauses = ["s.end_user_id = $end_user_id"]
params = {"end_user_id": end_user_id}
# 时间戳筛选(毫秒时间戳转为 UTC ISO 字符串,使用 Neo4j datetime() 精确比较)
if start_date is not None and end_date is not None:
from datetime import datetime, timezone
start_dt = datetime.fromtimestamp(start_date / 1000, tz=timezone.utc)
end_dt = datetime.fromtimestamp(end_date / 1000, tz=timezone.utc)
# 开始时间取当天 UTC 00:00:00结束时间取当天 UTC 23:59:59.999999
start_iso = start_dt.strftime("%Y-%m-%dT") + "00:00:00.000000"
end_iso = end_dt.strftime("%Y-%m-%dT") + "23:59:59.999999"
where_clauses.append("datetime(s.created_at) >= datetime($start_iso) AND datetime(s.created_at) <= datetime($end_iso)")
params["start_iso"] = start_iso
params["end_iso"] = end_iso
# 类型筛选下推到 Cypher兼容中英文
if episodic_type != "all":
type_mapping = {
"conversation": "对话",
"project_work": "项目/工作",
"learning": "学习",
"decision": "决策",
"important_event": "重要事件"
}
chinese_type = type_mapping.get(episodic_type)
if chinese_type:
where_clauses.append(
"(s.memory_type = $episodic_type OR s.memory_type = $chinese_type)"
)
params["episodic_type"] = episodic_type
params["chinese_type"] = chinese_type
else:
where_clauses.append("s.memory_type = $episodic_type")
params["episodic_type"] = episodic_type
where_str = " AND ".join(where_clauses)
# 3. 查询筛选后的总数
count_query = f"""
MATCH (s:MemorySummary)
WHERE {where_str}
RETURN count(s) AS total
"""
count_result = await self.neo4j_connector.execute_query(count_query, **params)
filtered_total = count_result[0]["total"] if count_result else 0
# 4. 查询分页数据
skip = (page - 1) * pagesize
data_query = f"""
MATCH (s:MemorySummary)
WHERE {where_str}
RETURN elementId(s) AS id,
s.name AS title,
s.memory_type AS memory_type,
s.content AS content,
s.created_at AS created_at
ORDER BY s.created_at DESC
SKIP $skip LIMIT $limit
"""
params["skip"] = skip
params["limit"] = pagesize
result = await self.neo4j_connector.execute_query(data_query, **params)
# 5. 处理结果
items = []
if result:
for record in result:
raw_created_at = record.get("created_at")
created_at_timestamp = self.parse_timestamp(raw_created_at)
items.append({
"id": record["id"],
"title": record.get("title") or "未命名",
"memory_type": record.get("memory_type") or "其他",
"content": record.get("content") or "",
"created_at": created_at_timestamp
})
# 6. 构建返回结果
return {
"total": total_all,
"items": items,
"page": {
"page": page,
"pagesize": pagesize,
"total": filtered_total,
"hasnext": (page * pagesize) < filtered_total
}
}
except Exception as e:
logger.error(f"情景记忆分页查询出错: {str(e)}", exc_info=True)
raise
async def get_semantic_memory_list(
self,
end_user_id: str
) -> list:
"""
获取语义记忆全量列表
Args:
end_user_id: 终端用户ID
Returns:
[
{
"id": str,
"name": str,
"entity_type": str,
"core_definition": str
}
]
"""
try:
logger.info(f"语义记忆列表查询: end_user_id={end_user_id}")
semantic_query = """
MATCH (e:ExtractedEntity)
WHERE e.end_user_id = $end_user_id
AND e.is_explicit_memory = true
RETURN elementId(e) AS id,
e.name AS name,
e.entity_type AS entity_type,
e.description AS core_definition
ORDER BY e.name ASC
"""
result = await self.neo4j_connector.execute_query(
semantic_query, end_user_id=end_user_id
)
items = []
if result:
for record in result:
items.append({
"id": record["id"],
"name": record.get("name") or "未命名",
"entity_type": record.get("entity_type") or "未分类",
"core_definition": record.get("core_definition") or ""
})
logger.info(f"语义记忆列表查询成功: end_user_id={end_user_id}, total={len(items)}")
return items
except Exception as e:
logger.error(f"语义记忆列表查询出错: {str(e)}", exc_info=True)
raise
async def get_explicit_memory_details(
self,
end_user_id: str,

View File

@@ -95,7 +95,7 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
"""通义千问文档格式"""
return True, {
"type": "text",
"text": f"<document name=\"{file_name}\">\n文档内容:\n{text}\n</document>"
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
}
async def format_audio(
@@ -167,7 +167,6 @@ class BedrockFormatStrategy(MultimodalFormatStrategy):
async def format_document(self, file_name: str, text: str) -> tuple[bool, Dict[str, Any]]:
"""Bedrock/Anthropic 文档格式(需要 base64 编码)"""
# Bedrock 文档需要 base64 编码
text = f"文档内容:\n{text}\n"
text_bytes = text.encode('utf-8')
base64_text = base64.b64encode(text_bytes).decode('utf-8')
@@ -224,7 +223,7 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy):
"""OpenAI 文档格式"""
return True, {
"type": "text",
"text": f"<document name=\"{file_name}\">\n文档内容:\n{text}\n</document>"
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
}
async def format_audio(
@@ -389,14 +388,13 @@ class MultimodalService:
from app.models.workspace_model import Workspace as WorkspaceModel
ws = self.db.query(WorkspaceModel).filter(WorkspaceModel.id == workspace_id).first()
tenant_id = ws.tenant_id if ws else None
img_result = []
for img_info in img_infos:
page = img_info["page"]
index = img_info["index"]
ext = img_info.get("ext", "png")
try:
_, img_url = await self._save_doc_image_to_storage(img_info["bytes"], ext, tenant_id, workspace_id)
placeholder = f"{page}页 第{index + 1}" if page > 0 else f"{index + 1}"
placeholder = f"{page}页 第{index + 1}图片" if page > 0 else f"{index + 1}图片"
# 在文本内容中追加图片位置标记
if result and result[-1].get("type") in ("text", "document"):
key = "text" if "text" in result[-1] else list(result[-1].keys())[-1]
@@ -409,10 +407,9 @@ class MultimodalService:
file_type="image/png",
)
_, img_content = await self._process_image(img_file, strategy_class(img_file))
img_result.append(img_content)
result.append(img_content)
except Exception as img_err:
logger.warning(f"文档图片处理失败: {img_err}")
result.extend(img_result)
elif file.type == FileType.AUDIO and "audio" in self.capability:
is_support, content = await self._process_audio(file, strategy)
result.append(content)

View File

@@ -815,12 +815,11 @@ class ToolService:
"default": param_info.get("default")
})
# 请求体参数 — _extract_request_body 返回 {"schema": {...}, "required": bool, ...}
# 请求体参数
request_body = operation.get("request_body")
if request_body:
body_schema = request_body.get("schema", {})
schema_props = body_schema.get("properties", {})
required_props = body_schema.get("required", [])
schema_props = request_body.get("schema", {}).get("properties", {})
required_props = request_body.get("schema", {}).get("required", [])
for prop_name, prop_schema in schema_props.items():
parameters.append({

View File

@@ -17,9 +17,8 @@ from app.core.workflow.executor import execute_workflow, execute_workflow_stream
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.validator import validate_workflow_config
from app.db import get_db
from sqlalchemy import select
from app.models import App
from app.models.workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
from app.repositories import knowledge_repository
from app.repositories.workflow_repository import (
WorkflowConfigRepository,
@@ -919,7 +918,6 @@ class WorkflowService:
input_data["conv_messages"] = conv_messages
init_message_length = len(input_data.get("conv_messages", []))
message_id = uuid.uuid4()
_cycle_items: dict[str, list] = {}
# 新会话时写入开场白
is_new_conversation = init_message_length == 0
@@ -950,15 +948,6 @@ class WorkflowService:
memory_storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
event_type = event.get("event")
event_data = event.get("data", {})
if event_type == "cycle_item":
cycle_id = event_data.get("cycle_id")
if cycle_id not in _cycle_items:
_cycle_items[cycle_id] = []
_cycle_items[cycle_id].append(event_data)
if event.get("event") == "workflow_end":
status = event.get("data", {}).get("status")
token_usage = event.get("data", {}).get("token_usage", {}) or {}
@@ -1030,18 +1019,6 @@ class WorkflowService:
)
else:
logger.error(f"unexpect workflow run status, status: {status}")
# 把积累的 cycle_item 写入 workflow_executions.output_data["node_outputs"]
if _cycle_items and execution.output_data:
import copy
new_output_data = copy.deepcopy(execution.output_data)
node_outputs = new_output_data.setdefault("node_outputs", {})
for cycle_node_id, items in _cycle_items.items():
if cycle_node_id in node_outputs:
node_outputs[cycle_node_id]["cycle_items"] = items
else:
node_outputs[cycle_node_id] = {"cycle_items": items}
execution.output_data = new_output_data
self.db.commit()
elif event.get("event") == "workflow_start":
event["data"]["message_id"] = str(message_id)
event = self._emit(public, event)

View File

@@ -34,7 +34,7 @@ from app.core.rag.prompts.generator import question_proposal
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
ElasticSearchVectorFactory,
)
from app.db import get_db_context
from app.db import get_db, get_db_context
from app.models import Document, File, Knowledge
from app.models.end_user_model import EndUser
from app.schemas import document_schema, file_schema

View File

@@ -63,23 +63,6 @@ services:
networks:
- celery
celery-task-scheduler:
image: redbear-mem-open:latest
container_name: celery-task-scheduler
env_file:
- .env
volumes:
- /etc/localtime:/etc/localtime:ro
command: python -m app.celery_task_scheduler
restart: unless-stopped
healthcheck:
test: CMD curl -f 127.0.0.1:8001 || exit 1
interval: 30s
timeout: 5s
retries: 3
networks:
- celery
# Celery Beat - scheduler
beat:
image: redbear-mem-open:latest

View File

@@ -62,6 +62,7 @@
"remark-gfm": "^4.0.1",
"remark-math": "^6.0.0",
"tailwindcss": "^4.1.14",
"x6-html-shape": "0.4.9",
"xlsx": "^0.18.5",
"zustand": "^5.0.8"
},

176
web/src/vendor/x6-html-shape/index.js vendored Normal file
View File

@@ -0,0 +1,176 @@
// Patched x6-html-shape: replaces View.createElement (removed in X6 3.x) with document.createElement
import { Node as p, NodeView as l, Graph as C, Dom as s } from "@antv/x6";
import { getConfig as w, clickable as x, isInputElement as y, forwardEvent as S } from "./utils.js";
const u = "html-shape", h = "html-shape-view", T = p.define(w(h)), m = {};
export function register(i) {
const { shape: e, render: n, inherit: t = u, ...o } = i;
if (!e) throw new Error("should specify shape in config");
m[e] = n;
C.registerNode(e, { inherit: t, ...o }, true);
}
const a = "html";
// Determine which HTML layer a node belongs to.
// Parent (loop/iteration) nodes go behind the SVG layer so edges render above them.
// All other nodes go in front of the SVG layer so they render above edges.
function isBackNode(cell) {
const type = cell.getData?.()?.type;
return type === 'loop' || type === 'iteration';
}
// Ensure the two HTML container layers exist and are correctly positioned.
function ensureHtmlLayers(graph) {
if (!graph._htmlBack) {
const back = graph._htmlBack = document.createElement('div');
s.css(back, {
position: 'absolute', width: '100%', height: '100%',
'touch-action': 'none', 'user-select': 'none', 'pointer-events': 'none',
'z-index': 0, 'transform-origin': 'left top',
});
back.classList.add('x6-html-shape-container', 'x6-html-shape-back');
const svg = graph.container.querySelector('svg');
// back layer: before SVG → visually behind edges
graph.container.insertBefore(back, svg || null);
}
if (!graph._htmlFront) {
const front = graph._htmlFront = document.createElement('div');
s.css(front, {
position: 'absolute', width: '100%', height: '100%',
'touch-action': 'none', 'user-select': 'none', 'pointer-events': 'none',
'z-index': 0, 'transform-origin': 'left top',
});
front.classList.add('x6-html-shape-container', 'x6-html-shape-front');
// front layer: after SVG → visually above edges
graph.container.append(front);
}
// Keep legacy alias so updateHtmlContainerSize can iterate both
graph.htmlContainers = [graph._htmlBack, graph._htmlFront];
}
class BaseHTMLShapeView extends l {
confirmUpdate(e) {
const n = super.confirmUpdate(e);
return this.handleAction(n, a, () => {
if (!this.mounted) {
const t = m[this.cell.shape], o = this.ensureComponentContainer();
t && o && (this.mounted = t(this.cell, this.graph, o) || true,
this.onMounted(),
o.addEventListener("mousedown", this.prevEvent, true),
o.addEventListener("mouseup", this.prevEvent, true));
}
});
}
prevEvent(e) {
(x(e.target) || y(e.target)) && (e.preventDefault(), e.stopPropagation());
}
ensureComponentContainer() {}
onMounted() {}
onUnMount() {
if (this.onZIndexChange) {
this.cell.off("change:zIndex", this.onZIndexChange);
}
if (this.onNodeMoving) {
this.graph.off("node:moving", this.onNodeMoving);
}
}
unmount() {
typeof this.mounted == "function" && this.mounted();
this.componentContainer && this.componentContainer.remove();
this.onUnMount();
return super.unmount(), this;
}
}
BaseHTMLShapeView.config({ bootstrap: [a], actions: { component: a } });
class HTMLShapeView extends BaseHTMLShapeView {
constructor(...e) {
super(...e);
this.cell.on("change:visible", ({ cell: n }) => {
if (n.view === h) {
const t = this.graph.findViewByCell(n.id);
t && Promise.resolve().then(() => {
t.componentContainer.style.display = t.container.style.display;
});
}
});
}
onMounted() {
const listeners = this.graph.listeners;
// Always register per-cell zIndex listener regardless of shared transform events
this.onZIndexChange = () => this.updateContainerStyle();
this.cell.on("change:zIndex", this.onZIndexChange);
if (listeners?.hasTransformEvent?.length) return;
this.onTranslate = this.updateHtmlContainerSize.bind(this);
this.graph.on("translate", this.onTranslate);
this.graph.on("scale", this.onTranslate);
this.graph.on("node:change:position", this.onTranslate);
this.graph.on("hasTransformEvent", this.onTranslate);
// While dragging, lift this node's componentContainer to the top of its
// layer so its ports are never obscured by a sibling node underneath.
this.onNodeMoving = ({ node }) => {
if (node === this.cell && this.componentContainer) {
const layer = isBackNode(this.cell) ? this.graph._htmlBack : this.graph._htmlFront;
layer.append(this.componentContainer);
}
};
this.graph.on("node:moving", this.onNodeMoving);
this.updateHtmlContainerSize();
}
ensureComponentContainer() {
ensureHtmlLayers(this.graph);
const layer = isBackNode(this.cell) ? this.graph._htmlBack : this.graph._htmlFront;
if (!this.componentContainer) {
const e = this.componentContainer = document.createElement("div");
s.css(e, {
"pointer-events": "auto", "touch-action": "none", "user-select": "none",
"transform-origin": "center", position: "absolute"
});
e.classList.add("x6-html-shape-node");
"click,dblclick,contextmenu,mousedown,mousemove,mouseup,mouseover,mouseout,mouseenter,mouseleave"
.split(",").forEach(t => S(t, e, this.container));
layer.append(e);
}
return this.componentContainer;
}
resize() { super.resize(); this.updateContainerStyle(); }
updateTransform() { super.updateTransform(); this.updateContainerStyle(); }
updateContainerStyle() {
const e = this.ensureComponentContainer();
const { x: n, y: t } = this.cell.getBBox();
const { width: o, height: r } = this.cell.getSize();
const g = getComputedStyle(this.container).cursor;
const f = this.cell.getZIndex() ?? 0;
// Shrink the interactive width by the port hover radius (6px) so the right
// port circle is fully outside the componentContainer and never blocked by it.
// overflow:visible keeps the visual rendering intact.
const PORT_RADIUS = 6;
s.css(e, {
cursor: g, height: r + "px", width: (o - PORT_RADIUS) + "px",
overflow: "visible",
"z-index": f,
transform: `translate(${n}px, ${t}px) rotate(${this.cell.getAngle()}deg)`
});
}
updateHtmlContainerSize() {
const { graph: e } = this;
const t = e.transform.getMatrix();
const { offsetHeight: o, offsetWidth: r } = e.container;
const n = e.transform.getZoom();
const style = {
transform: `matrix(${t.a}, ${t.b}, ${t.c}, ${t.d}, ${t.e}, ${t.f})`,
width: r / n + "px",
height: o / n + "px",
};
// Update both layers
(e.htmlContainers || [e._htmlBack, e._htmlFront].filter(Boolean)).forEach(c => s.css(c, style));
}
}
l.registry.register(h, HTMLShapeView, true);
p.registry.register(u, T, true);
export { BaseHTMLShapeView, T as HTMLShape, u as HTMLShapeName, HTMLShapeView, h as HTMLView, a as action };

1
web/src/vendor/x6-html-shape/react.js vendored Normal file
View File

@@ -0,0 +1 @@
export { default } from "x6-html-shape/dist/react.js";

98
web/src/vendor/x6-html-shape/utils.js vendored Normal file
View File

@@ -0,0 +1,98 @@
import { Dom as u, ObjectExt as l, Markup as c } from "@antv/x6";
const o = "fo-shape-view";
function p(t, e, r) {
e.addEventListener(t, function(n) {
r.dispatchEvent(new n.constructor(n.type, n)), n.preventDefault(), n.stopPropagation();
});
}
function s(t, e = 3) {
return !t || !u.isHTMLElement(t) || e <= 0 ? !1 : ["a", "button"].includes(u.tagName(t)) || t.getAttribute("role") === "button" || t.getAttribute("type") === "button" ? !0 : s(t.parentNode, e - 1);
}
function g(t) {
if (u.tagName(t) === "input") {
const r = t.getAttribute("type");
if (r == null || ["text", "password", "number", "email", "search", "tel", "url"].includes(
r
))
return !0;
}
return !1;
}
function f(t = "rect", e = !0) {
return [
{
tagName: t,
selector: "body"
},
e ? c.getForeignObjectMarkup() : null,
{
tagName: "text",
selector: "label"
}
].filter((r) => r);
}
function b(t) {
return {
view: t,
markup: f("rect", t === o),
attrs: {
body: {
// fill: "none",
// 这里很奇怪none的时候不能触发节点移动改成transparent可以触发
fill: "transparent",
stroke: "none",
refWidth: "100%",
refHeight: "100%"
},
label: {
fontSize: 14,
fill: "#333",
refX: "50%",
refY: "50%",
textAnchor: "middle",
textVerticalAnchor: "middle"
},
fo: {
refWidth: "100%",
refHeight: "100%"
}
},
propHooks(e) {
if (e.markup == null) {
const { primer: r, view: n } = e;
if (r && r !== "rect") {
e.markup = f(r, n === o);
let i = {};
r === "circle" ? i = {
refCx: "50%",
refCy: "50%",
refR: "50%"
} : r === "ellipse" && (i = {
refCx: "50%",
refCy: "50%",
refRx: "50%",
refRy: "50%"
}), e.attrs = l.merge(
{},
{
body: {
refWidth: null,
refHeight: null,
...i
}
},
e.attrs || {}
);
}
}
return e;
}
};
}
export {
o as FOView,
s as clickable,
p as forwardEvent,
b as getConfig,
g as isInputElement
};

View File

@@ -7,7 +7,7 @@
import { type FC, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { useParams } from 'react-router-dom';
import { Flex, Button, Form } from 'antd';
import { Flex, Button } from 'antd';
import type { ColumnsType } from 'antd/es/table';
import { getAppLogsUrl } from '@/api/application';
@@ -15,14 +15,11 @@ import Table from '@/components/Table'
import { formatDateTime } from '@/utils/format';
import type { LogItem, LogDetailModalRef } from './types'
import LogDetailModal from './components/LogDetailModal'
import SearchInput from '@/components/SearchInput'
const Statistics: FC = () => {
const { t } = useTranslation();
const { id } = useParams();
const logDetailRef = useRef<LogDetailModalRef>(null);
const [form] = Form.useForm();
const values = Form.useWatch([], form);
const handleViewDetail = (item: LogItem) => {
logDetailRef.current?.handleOpen(item);
@@ -65,26 +62,15 @@ const Statistics: FC = () => {
];
return (
<div className="rb:bg-white rb:rounded-lg rb:pt-3 rb:px-3">
<Flex justify="flex-end" className="rb:mb-3!">
<Form form={form}>
<Form.Item name="keyword" noStyle>
<SearchInput
placeholder={t('application.logSearchPlaceholder')}
variant="outlined"
/>
</Form.Item>
</Form>
</Flex>
<Table<LogItem>
apiUrl={getAppLogsUrl(id || '')}
apiParams={{
is_draft: false,
...(values ?? {})
}}
columns={columns}
rowKey="id"
isScroll={true}
scrollY="calc(100vh - 242px)"
scrollY="calc(100vh - 214px)"
/>
<LogDetailModal ref={logDetailRef} />
</div>

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-03-13 17:27:52
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-24 18:14:25
* @Last Modified time: 2026-04-07 21:48:30
*/
import { type FC, useState, useRef, useEffect } from 'react'
import { useTranslation } from 'react-i18next'
@@ -59,7 +59,6 @@ interface NodeData {
node_type?: string;
input?: any;
output?: any;
process?: any;
elapsed_time?: string;
error?: any;
state: Record<string, any>;
@@ -486,7 +485,7 @@ const TestChat: FC<TestChatProps> = ({
}
const updateWorkflowNodeEndMessage = (data: NodeData) => {
const { node_id, input, output, process, error, elapsed_time, status } = data;
const { node_id, input, output, error, elapsed_time, status } = data;
setChatList(prev => {
const newList = [...prev]
const lastIndex = newList.length - 1
@@ -499,7 +498,6 @@ const TestChat: FC<TestChatProps> = ({
content: {
input,
output,
process,
error,
},
status: status || 'completed',
@@ -516,7 +514,7 @@ const TestChat: FC<TestChatProps> = ({
}
const updateWorkflowCycleMessage = (data: NodeData) => {
const { node_id, cycle_id, cycle_idx, input, output, process, error, elapsed_time, status } = data;
const { node_id, cycle_id, cycle_idx, input, output, error, elapsed_time, status } = data;
const { nodes } = config as WorkflowConfig
const node = nodes.find(n => n.id === node_id);
const { name, type } = node || {}
@@ -540,7 +538,6 @@ const TestChat: FC<TestChatProps> = ({
cycle_idx,
input,
output,
process,
error,
},
status: status || 'completed',

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-03-24 16:31:24
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-24 17:49:58
* @Last Modified time: 2026-03-24 16:31:24
*/
import { forwardRef, useImperativeHandle, useState, useEffect } from 'react';
import { Flex, Button, Empty, Skeleton } from 'antd';
@@ -14,12 +14,6 @@ import { getAppLogDetail } from '@/api/application'
import ChatContent from '@/components/Chat/ChatContent'
import { formatDateTime } from '@/utils/format'
import type { ChatItem } from '@/components/Chat/types'
import Runtime from '@/views/Workflow/components/Chat/Runtime'
import { nodeLibrary } from '@/views/Workflow/constant'
const nodeIconMap = Object.fromEntries(
nodeLibrary.flatMap(c => c.nodes.map(n => [n.type, n.icon]))
)
/** Log detail data with conversation messages */
type Data = LogItem & {
@@ -60,30 +54,7 @@ const LogDetailModal = forwardRef<LogDetailModalRef>((_props, ref) => {
if (!vo) return
setLoading(true)
getAppLogDetail(vo.app_id, vo.id).then(res => {
const { node_executions_map, messages, ...rest } = res as Data;
let hasSubContentMessages = messages
if (messages && messages.length > 0 && node_executions_map && Object.keys(node_executions_map).length > 0) {
hasSubContentMessages = messages.map(item => {
if (item.id && node_executions_map[item.id]) {
item.subContent = node_executions_map[item.id]?.map(({ input, output, cycle_items = [], error, process, ...node }: any) => {
const converted: any = { ...node, icon: nodeIconMap[node.node_type], content: { input, output, process, error } }
if (node.node_type === 'loop' && Array.isArray(cycle_items) && cycle_items.length > 0) {
converted.subContent = cycle_items.map(({ input: cInput, output: cOutput, error: cError, process: cProcess, ...cNode }: any) => ({
...cNode,
icon: nodeIconMap[cNode.node_type],
content: { input: cInput, output: cOutput, process: cProcess, error: cError }
}))
}
return converted
})
}
return { ...item }
})
}
setData({
...rest,
messages: hasSubContentMessages
})
setData(res as Data)
})
.finally(() => {
setLoading(false)
@@ -95,8 +66,6 @@ const LogDetailModal = forwardRef<LogDetailModalRef>((_props, ref) => {
handleClose
}));
console.log('data', data)
return (
<RbModal
title={<>
@@ -123,7 +92,6 @@ const LogDetailModal = forwardRef<LogDetailModalRef>((_props, ref) => {
data={data.messages || []}
streamLoading={false}
labelFormat={(item) => formatDateTime(item.created_at)}
renderRuntime={(item, index) => <Runtime item={item} index={index} />}
/>
)
}

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-06 21:10:56
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-24 18:13:22
* @Last Modified time: 2026-04-21 14:59:13
*/
/**
* Workflow Chat Component
@@ -66,7 +66,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
const [fileList, setFileList] = useState<any[]>([])
const [message, setMessage] = useState<string | undefined>(undefined)
console.log('abortRef', abortRef, chatList)
console.log('abortRef', abortRef)
/**
* Opens the chat drawer and loads workflow variables from the start node
@@ -185,7 +185,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
*/
const handleStreamMessage = (data: SSEMessage[]) => {
data.forEach(item => {
const { content, conversation_id, node_id, cycle_id, cycle_idx, input, output, process, error, elapsed_time, status, citations } = item.data as {
const { content, conversation_id, node_id, cycle_id, cycle_idx, input, output, error, elapsed_time, status, citations } = item.data as {
content: string;
conversation_id: string | null;
cycle_id: string;
@@ -193,7 +193,6 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
node_id: string;
node_name?: string;
node_type?: string;
process?: any;
input?: any;
output?: any;
elapsed_time?: string;
@@ -278,7 +277,6 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
content: {
input,
output,
process,
error,
},
status: status || 'completed',
@@ -307,14 +305,13 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
cycle_id,
cycle_idx,
node_id,
node_name: type === 'cycle-start' ? t('workflow.cycle-start') : name,
node_name: name,
node_type: type,
icon,
content: {
cycle_idx,
input,
output,
process,
error,
},
status: status || 'completed',

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-24 17:57:08
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-24 18:04:31
* @Last Modified time: 2026-04-20 15:33:48
*/
/*
* Runtime Component
@@ -184,9 +184,7 @@ const Runtime: FC<{ item: ChatItem; index: number;}> = ({
</Flex>
)}
{/* Display input and output data as JSON code blocks */}
{['input', 'process', 'output'].map(key => {
if (vo.node_type !== 'http-request' && key === 'process') return null
return (
{['input', 'output'].map(key => (
<div key={key} className="rb:bg-[#EBEBEB] rb:rounded-lg">
<div className="rb:py-2 rb:px-3 rb:flex rb:justify-between rb:items-center rb:text-[12px]">
{isLoop ? t(`workflow.runtime.${key}_cycle_vars`) : t(`workflow.${key}_result`)}
@@ -206,8 +204,7 @@ const Runtime: FC<{ item: ChatItem; index: number;}> = ({
/>
</div>
</div>
)
})}
))}
</Flex>
)
}]}

View File

@@ -2,183 +2,16 @@
* @Author: ZhaoYing
* @Date: 2026-02-09 18:31:30
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-30 11:55:10
* @Last Modified time: 2026-04-28 10:24:58
*/
import { useState } from 'react';
import { Popover, Flex } from 'antd';
import { Flex } from 'antd';
import clsx from 'clsx';
import type { ReactShapeConfig } from '@antv/x6-react-shape';
import { nodeLibrary, graphNodeLibrary, edgeAttrs, nodeWidth } from '../../constant';
import { useTranslation } from 'react-i18next';
const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
const AddNode: ReactShapeConfig['component'] = ({ node }) => {
const data = node?.getData() || {};
const { t } = useTranslation();
const [open, setOpen] = useState(false);
// Handle node selection from popover and create new node replacing the add-node placeholder
const handleNodeSelect = (selectedNodeType: any) => {
const parentBBox = node.getBBox();
const cycleId = data.cycle;
const horizontalSpacing = 0;
const id = `${selectedNodeType.type.replace(/-/g, '_') }_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
const newNode = graph.addNode({
...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default),
x: parentBBox.x + horizontalSpacing,
y: parentBBox.y - 12,
id,
data: {
id,
type: selectedNodeType.type,
icon: selectedNodeType.icon,
name: t(`workflow.${selectedNodeType.type}`),
cycle: cycleId,
parentId: data.parentId,
config: selectedNodeType.config || {}
},
});
// Add new node as child of parent node
if (cycleId) {
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
if (parentNode) {
parentNode.addChild(newNode);
}
}
const incomingEdges = graph.getIncomingEdges(node);
const outgoingEdges = graph.getOutgoingEdges(node);
const addedEdges: any[] = [];
incomingEdges?.forEach((edge: any) => {
addedEdges.push(graph.addEdge({
source: { cell: edge.getSourceCellId(), port: edge.getSourcePortId() },
target: { cell: newNode.id, port: newNode.getPorts().find((port: any) => port.group === 'left')?.id || 'left' },
...edgeAttrs
}));
});
outgoingEdges?.forEach((edge: any) => {
const targetCell = graph.getCellById(edge.getTargetCellId()) as any;
const targetPortId = targetCell?.getPorts?.()?.find((port: any) => port.group === 'left')?.id || edge.getTargetPortId();
addedEdges.push(graph.addEdge({
source: { cell: newNode.id, port: newNode.getPorts().find((port: any) => port.group === 'right')?.id || 'right' },
target: { cell: edge.getTargetCellId(), port: targetPortId },
...edgeAttrs
}));
});
// Remove all add-node type nodes
graph.getNodes().forEach((n: any) => {
if (n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId) {
n.remove();
}
});
setTimeout(() => {
addedEdges.forEach(e => {
const src = graph.getCellById(e.getSourceCellId());
const tgt = graph.getCellById(e.getTargetCellId());
if (src?.isNode()) src.toFront();
if (tgt?.isNode()) tgt.toFront();
});
}, 50);
// Automatically adjust loop node size
const loopNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
if (loopNode) {
const adjustLoopSize = () => {
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
if (childNodes.length > 0) {
const bounds = childNodes.reduce((acc, child) => {
const bbox = child.getBBox();
return {
minX: Math.min(acc.minX, bbox.x),
minY: Math.min(acc.minY, bbox.y),
maxX: Math.max(acc.maxX, bbox.x + bbox.width),
maxY: Math.max(acc.maxY, bbox.y + bbox.height)
};
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
const padding = 50;
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
loopNode.prop('size', { width: newWidth, height: newHeight });
// Update right port x position
const ports = loopNode.getPorts();
ports.forEach(port => {
if (port.group === 'right' && port.args) {
loopNode.portProp(port.id!, 'args/x', newWidth);
}
});
}
};
adjustLoopSize();
// Listen to child node movement events
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
childNodes.forEach((childNode: any) => {
childNode.on('change:position', adjustLoopSize);
});
}
setOpen(false);
};
const content = (
<div style={{ maxHeight: '300px', overflowY: 'auto', minWidth: `${nodeWidth}px'` }}>
{nodeLibrary.map((category, categoryIndex) => {
const filteredNodes = category.nodes.filter(nodeType =>
nodeType.type !== 'start' && nodeType.type !== 'end' && nodeType.type !== 'iteration' && nodeType.type !== 'loop' && nodeType.type !== 'cycle-start'
);
if (filteredNodes.length === 0) return null;
return (
<div key={category.category}>
{categoryIndex > 0 && <div style={{ height: '1px', background: '#f0f0f0', margin: '4px 0' }} />}
<div style={{ padding: '4px 12px', fontSize: '12px', color: '#999', fontWeight: 'bold' }}>
{t(`workflow.${category.category}`)}
</div>
{filteredNodes.map((nodeType) => (
<div
key={nodeType.type}
style={{
padding: '8px 12px',
cursor: 'pointer',
display: 'flex',
alignItems: 'center',
gap: '8px',
}}
onClick={() => handleNodeSelect(nodeType)}
onMouseEnter={(e) => {
e.currentTarget.style.background = '#f0f8ff';
}}
onMouseLeave={(e) => {
e.currentTarget.style.background = 'white';
}}
>
<div className={`rb:size-4 rb:bg-cover ${nodeType.icon}`} />
<span style={{ fontSize: '14px' }}>{t(`workflow.${nodeType.type}`)}</span>
</div>
))}
</div>
);
})}
</div>
);
return (
<Popover
content={content}
trigger="click"
open={open}
onOpenChange={setOpen}
placement="bottomLeft"
>
<Flex
align="center"
justify="center"
@@ -191,7 +24,6 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
<div className="rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/workflow/node_plus.png')]"></div>
{data.label}
</Flex>
</Popover>
);
};

View File

@@ -65,8 +65,8 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => {
return (
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
'rb:border-[#171719]!': data.isSelected && !data.executionStatus,
'rb:border-[#FCFCFD]': !data.isSelected && !data.executionStatus,
'rb:border-[#171719]!': data.isSelected,
'rb:border-[#FCFCFD]': !data.isSelected,
'rb:border-[#369F21]!': !data.isSelected && data.executionStatus === 'completed',
'rb:border-[#FF5D34]!': !data.isSelected && data.executionStatus === 'failed',
})}>

View File

@@ -131,8 +131,8 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
return (
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
'rb:border-[#171719]!': data.isSelected && !data.executionStatus,
'rb:border-[#FCFCFD]': !data.isSelected && !data.executionStatus,
'rb:border-[#171719]': data.isSelected,
'rb:border-[#FCFCFD]': !data.isSelected,
'rb:border-[#369F21]!': !data.isSelected && data.executionStatus === 'completed',
'rb:border-[#FF5D34]!': !data.isSelected && data.executionStatus === 'failed',
})}>

View File

@@ -12,8 +12,8 @@ const NormalNode: ReactShapeConfig['component'] = ({ node }) => {
return (
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
'rb:border-[#171719]!': data.isSelected && !data.executionStatus,
'rb:border-[#FCFCFD]': !data.isSelected && !data.executionStatus,
'rb:border-[#171719]!': data.isSelected,
'rb:border-[#FCFCFD]': !data.isSelected,
'rb:border-[#369F21]!': !data.isSelected && data.executionStatus === 'completed',
'rb:border-[#FF5D34]!': !data.isSelected && data.executionStatus === 'failed',
})}>

View File

@@ -2,13 +2,44 @@
* @Author: ZhaoYing
* @Date: 2026-02-09 18:30:28
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-30 15:14:02
* @Last Modified time: 2026-04-28 11:41:17
*/
import { useEffect, useState } from 'react';
import { createPortal } from 'react-dom';
import { Flex, Popover } from 'antd';
import { useTranslation } from 'react-i18next';
import { nodeLibrary, graphNodeLibrary, edgeAttrs, nodeWidth } from '../constant';
// Shared helper: adjust loop/iteration container size to fit child nodes
export const adjustCycleContainerSize = (graph: any, cycleId: string) => {
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
if (!parentNode) return;
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
if (childNodes.length === 0) return;
const bounds = childNodes.reduce((acc: any, child: any) => {
const bbox = child.getBBox();
return {
minX: Math.min(acc.minX, bbox.x),
minY: Math.min(acc.minY, bbox.y),
maxX: Math.max(acc.maxX, bbox.x + bbox.width),
maxY: Math.max(acc.maxY, bbox.y + bbox.height),
};
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
const padding = 50;
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
parentNode.prop('size', { width: newWidth, height: newHeight });
parentNode.getPorts().forEach((port: any) => {
if (port.group === 'right' && port.args) {
parentNode.portProp(port.id!, 'args/x', newWidth);
}
});
childNodes.forEach((childNode: any) => {
childNode.off('change:position');
childNode.on('change:position', () => adjustCycleContainerSize(graph, cycleId));
});
};
interface PortClickHandlerProps {
graph: any;
}
@@ -16,7 +47,6 @@ interface PortClickHandlerProps {
const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
const { t } = useTranslation();
const [popoverVisible, setPopoverVisible] = useState(false);
const [popoverPosition, setPopoverPosition] = useState({ x: 0, y: 0 });
const [sourceNode, setSourceNode] = useState<any>(null);
const [sourcePort, setSourcePort] = useState<string>('');
const [tempElement, setTempElement] = useState<HTMLElement | null>(null);
@@ -24,12 +54,11 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
useEffect(() => {
const handlePortClick = (event: CustomEvent) => {
const { node, port, element, rect, edgeInsertion } = event.detail;
const { node, port, element, edgeInsertion } = event.detail;
setSourceNode(node);
setSourcePort(port);
setTempElement(element);
setEdgeInsertion(edgeInsertion || null);
setPopoverPosition({ x: rect.left, y: rect.top });
setPopoverVisible(true);
};
@@ -50,6 +79,66 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
const sourceNodeData = sourceNode.getData();
const sourceNodeType = sourceNodeData?.type;
// AddNode placeholder mode: replace the add-node placeholder with the selected node
if (sourceNodeType === 'add-node') {
const placeholderBBox = sourceNode.getBBox();
const cycleId = sourceNodeData.cycle;
const id = `${selectedNodeType.type.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
const newNode = graph.addNode({
...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default),
x: placeholderBBox.x,
y: placeholderBBox.y - 12,
id,
data: {
id,
type: selectedNodeType.type,
icon: selectedNodeType.icon,
name: t(`workflow.${selectedNodeType.type}`),
cycle: cycleId,
parentId: sourceNodeData.parentId,
config: selectedNodeType.config || {},
},
});
if (cycleId) {
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
if (parentNode) parentNode.addChild(newNode);
}
const incomingEdges = graph.getIncomingEdges(sourceNode);
const outgoingEdges = graph.getOutgoingEdges(sourceNode);
const addedEdges: any[] = [];
incomingEdges?.forEach((edge: any) => {
addedEdges.push(graph.addEdge({
source: { cell: edge.getSourceCellId(), port: edge.getSourcePortId() },
target: { cell: newNode.id, port: newNode.getPorts().find((p: any) => p.group === 'left')?.id || 'left' },
...edgeAttrs,
}));
});
outgoingEdges?.forEach((edge: any) => {
const targetCell = graph.getCellById(edge.getTargetCellId()) as any;
const targetPortId = targetCell?.getPorts?.()?.find((p: any) => p.group === 'left')?.id || edge.getTargetPortId();
addedEdges.push(graph.addEdge({
source: { cell: newNode.id, port: newNode.getPorts().find((p: any) => p.group === 'right')?.id || 'right' },
target: { cell: edge.getTargetCellId(), port: targetPortId },
...edgeAttrs,
}));
});
graph.getNodes().forEach((n: any) => {
if (n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId) n.remove();
});
setTimeout(() => {
addedEdges.forEach(e => {
const src = graph.getCellById(e.getSourceCellId());
const tgt = graph.getCellById(e.getTargetCellId());
if (src?.isNode()) src.toFront();
if (tgt?.isNode()) tgt.toFront();
});
}, 50);
if (cycleId) adjustCycleContainerSize(graph, cycleId);
if (tempElement) { document.body.removeChild(tempElement); setTempElement(null); }
setPopoverVisible(false);
return;
}
// If it's a cycle-start node, handle the add-node placeholder
let addNodePosition = null;
const isCycleSubNode = sourceNodeData.cycle
@@ -228,48 +317,7 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
// Adjust loop node size when child node is added via port within loop node
const cycleId = sourceNodeData.cycle;
if (cycleId) {
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
if (parentNode) {
const adjustLoopSize = () => {
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
if (childNodes.length > 0) {
const bounds = childNodes.reduce((acc: any, child: any) => {
const bbox = child.getBBox();
return {
minX: Math.min(acc.minX, bbox.x),
minY: Math.min(acc.minY, bbox.y),
maxX: Math.max(acc.maxX, bbox.x + bbox.width),
maxY: Math.max(acc.maxY, bbox.y + bbox.height)
};
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
const padding = 50;
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
parentNode.prop('size', { width: newWidth, height: newHeight });
// Update right port x position
const ports = parentNode.getPorts();
ports.forEach((port: any) => {
if (port.group === 'right' && port.args) {
parentNode.portProp(port.id!, 'args/x', newWidth);
}
});
}
};
adjustLoopSize();
// Listen to child node movement events
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
childNodes.forEach((childNode: any) => {
childNode.on('change:position', adjustLoopSize);
});
}
}
if (cycleId) adjustCycleContainerSize(graph, cycleId);
const isCycleContainer = (type: string) => type === 'loop' || type === 'iteration';
const newNodeType = selectedNodeType.type;
@@ -372,22 +420,18 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
if (!tempElement) return null;
return (
return createPortal(
<Popover
content={content}
open={popoverVisible}
onOpenChange={(visible) => {
if (!visible) handlePopoverClose();
}}
onOpenChange={(visible) => { if (!visible) handlePopoverClose(); }}
placement="right"
overlayStyle={{
position: 'fixed',
left: popoverPosition.x + 10,
top: popoverPosition.y - 10,
}}
autoAdjustOverflow
getPopupContainer={() => document.body}
>
<div />
</Popover>
<div style={{ width: '1px', height: '1px' }} />
</Popover>,
tempElement
);
};

View File

@@ -2,14 +2,16 @@
* @Author: ZhaoYing
* @Date: 2026-02-03 15:17:48
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-24 17:21:09
* @Last Modified time: 2026-04-28 12:07:33
*/
import { Clipboard, Graph, Keyboard, MiniMap, Node, Snapline, History, type Edge } from '@antv/x6';
import type { HistoryCommand as Command } from '@antv/x6/lib/plugin/history/type';
import { register } from '@antv/x6-react-shape';
import { register as registerReactShape } from '@antv/x6-react-shape';
import type { PortMetadata } from '@antv/x6/lib/model/port';
import { App } from 'antd';
import { useEffect, useRef, useState } from 'react';
import { useEffect, useRef, useState, createElement } from 'react';
import type { RefObject, Dispatch, SetStateAction, MutableRefObject, DragEvent } from 'react';
import { createRoot } from 'react-dom/client';
import { useTranslation } from 'react-i18next';
import { useParams } from 'react-router-dom';
@@ -21,14 +23,16 @@ import type { ChatVariable, NodeProperties, WorkflowConfig } from '../types';
import { calcConditionNodeTotalHeight, getConditionNodeCasePortY } from '../utils';
import { useWorkflowStore } from '@/store/workflow';
const isSafari = /^((?!chrome|android).)*safari/i.test(navigator.userAgent);
/**
* Props for useWorkflowGraph hook
*/
export interface UseWorkflowGraphProps {
/** Reference to the main graph container element */
containerRef: React.RefObject<HTMLDivElement>;
containerRef: RefObject<HTMLDivElement>;
/** Reference to the minimap container element */
miniMapRef: React.RefObject<HTMLDivElement>;
miniMapRef: RefObject<HTMLDivElement>;
/** Callback when features config is loaded */
onFeaturesLoad?: (features: FeaturesConfigForm | undefined) => void;
}
@@ -40,23 +44,23 @@ export interface UseWorkflowGraphReturn {
/** Current workflow configuration */
config: WorkflowConfig | null;
/** Function to update workflow configuration */
setConfig: React.Dispatch<React.SetStateAction<WorkflowConfig | null>>;
setConfig: Dispatch<SetStateAction<WorkflowConfig | null>>;
/** Reference to the X6 graph instance */
graphRef: React.MutableRefObject<Graph | undefined>;
graphRef: MutableRefObject<Graph | undefined>;
/** Currently selected node */
selectedNode: Node | null;
/** Function to update selected node */
setSelectedNode: React.Dispatch<React.SetStateAction<Node | null>>;
setSelectedNode: Dispatch<SetStateAction<Node | null>>;
/** Current zoom level of the graph */
zoomLevel: number;
/** Function to update zoom level */
setZoomLevel: React.Dispatch<React.SetStateAction<number>>;
setZoomLevel: Dispatch<SetStateAction<number>>;
/** Whether hand/pan mode is enabled */
isHandMode: boolean;
/** Function to toggle hand mode */
setIsHandMode: React.Dispatch<React.SetStateAction<boolean>>;
setIsHandMode: Dispatch<SetStateAction<boolean>>;
/** Handler for dropping nodes onto canvas */
onDrop: (event: React.DragEvent) => void;
onDrop: (event: DragEvent) => void;
/** Handler for clicking blank canvas area */
blankClick: () => void;
/** Handler for delete keyboard event */
@@ -78,7 +82,7 @@ export interface UseWorkflowGraphReturn {
/** Chat variables for workflow */
chatVariables: ChatVariable[];
/** Function to update chat variables */
setChatVariables: React.Dispatch<React.SetStateAction<ChatVariable[]>>;
setChatVariables: Dispatch<SetStateAction<ChatVariable[]>>;
handleAddNotes: () => void;
handleSaveFeaturesConfig: (value: FeaturesConfigForm) => void;
@@ -160,6 +164,21 @@ export const useWorkflowGraph = ({
initWorkflow()
}, [config, graphRef.current])
/**
* Assign explicit zIndex values to enforce layer order:
* parent nodes (loop/iteration) → child edges → child nodes
* Ports live inside each node's SVG container and are always above
* edges once the node zIndex is higher than the edge zIndex.
*/
const reorderCells = (graph: Graph) => {
// Safari uses x6-html-shape (dual HTML layer architecture).
// zIndex controls order within each HTML layer and SVG layer.
graph.getEdges().forEach(edge => edge.setZIndex(0));
graph.getNodes().forEach(node => {
node.setZIndex(node.getData()?.cycle ? 2 : 1);
});
};
/**
* Initialize workflow graph with nodes and edges from configuration
*/
@@ -470,6 +489,9 @@ export const useWorkflowGraph = ({
if (nodes.length > 0 || edges.length > 0) {
setTimeout(() => {
if (graphRef.current) {
if (isSafari) {
reorderCells(graphRef.current)
} else {
graphRef.current.getNodes().forEach(node => {
if (!node.getData()?.cycle) node.toFront();
});
@@ -484,10 +506,11 @@ export const useWorkflowGraph = ({
graphRef.current.getNodes().forEach(node => {
if (node.getData()?.cycle) node.toFront();
});
}
graphRef.current.enableHistory()
graphRef.current.cleanHistory()
}
}, 200)
}, isSafari ? 0 : 200)
}
}
/**
@@ -551,12 +574,33 @@ export const useWorkflowGraph = ({
* @param node - Clicked node
*/
const nodeClick = ({ node }: { node: Node }) => {
// add-node type: dispatch port:click to open node selection popover
// Must handle before blankClick() to avoid blank:click closing the popover immediately
const nodeData = node.getData()
if (nodeData?.type === 'add-node') {
const bbox = node.getBBox();
const screenPos = graphRef.current!.localToClient(bbox.x + bbox.width, bbox.y + bbox.height / 2);
const tempDiv = document.createElement('div');
tempDiv.style.cssText = `position:fixed;left:${screenPos.x}px;top:${screenPos.y}px;width:1px;height:1px;z-index:9999;`;
document.body.appendChild(tempDiv);
window.dispatchEvent(new CustomEvent('port:click', {
detail: {
node,
port: 'right',
element: tempDiv,
rect: { left: screenPos.x, top: screenPos.y },
edgeInsertion: null,
},
}));
return;
}
blankClick()
setTimeout(() => {
// Ignore add-node type node clicks
const nodeData = node.getData()
if (nodeData?.type === 'add-node' || nodeData.type === 'break' || nodeData.type === 'cycle-start') {
if (nodeData.type === 'break' || nodeData.type === 'cycle-start') {
setSelectedNode(null)
return;
}
@@ -644,7 +688,8 @@ export const useWorkflowGraph = ({
const cycle = node.getData()?.cycle;
if (cycle) {
const parentNode = graphRef.current!.getNodes().find(n => n.id === cycle);
if (parentNode?.getData()?.isGroup) {
const parentType = parentNode?.getData()?.type;
if (parentNode && (parentType === 'loop' || parentType === 'iteration')) {
// Get parent node and child node bounding boxes
const parentBBox = parentNode.getBBox();
const childBBox = node.getBBox();
@@ -857,13 +902,35 @@ export const useWorkflowGraph = ({
/**
* Initialize X6 graph with configuration and event listeners
*/
const init = () => {
const init = async () => {
if (!containerRef.current || !miniMapRef.current) return;
// Register React shapes
nodeRegisterLibrary.forEach((item) => {
register(item);
// Safari: use x6-html-shape to avoid foreignObject rendering issues
if (isSafari) {
const { register: registerHtmlShape } = await import('x6-html-shape');
nodeRegisterLibrary.forEach(({ shape, width, height, component }) => {
registerHtmlShape({
shape,
width,
height,
render(node: Node, _graph: unknown, container: HTMLElement) {
const root = createRoot(container);
const doRender = () => {
root.render(createElement(component as any, { node, graph: node.model?.graph, data: node.getData() }));
};
doRender();
node.on('change:data', doRender);
return () => {
node.off('change:data', doRender);
root.unmount();
};
},
});
});
} else {
nodeRegisterLibrary.forEach((item) => registerReactShape(item));
}
const container = containerRef.current;
graphRef.current = new Graph({
@@ -1053,10 +1120,71 @@ export const useWorkflowGraph = ({
// Listen to node move event
graphRef.current.on('node:moved', nodeMoved);
if (isSafari) {
// When a parent (loop/iteration) node moves, keep child nodes in sync.
// Store each child's offset relative to the parent at drag start, then
// reapply it every frame to avoid cumulative delta errors.
const dragOffsets = new Map<string, { dx: number; dy: number }>();
graphRef.current.on('node:moving', ({ node }: { node: Node }) => {
const data = node.getData();
if (data?.type !== 'loop' && data?.type !== 'iteration') return;
const pos = node.getPosition();
const PORT_RADIUS = 6;
// Update parent componentContainer directly
const parentView = graphRef.current?.findViewByCell(node) as any;
if (parentView?.componentContainer) {
parentView.componentContainer.style.transform =
`translate(${pos.x + PORT_RADIUS}px, ${pos.y}px)`;
}
const children = graphRef.current?.getNodes().filter(child => {
const cycle = child.getData()?.cycle;
return cycle === data.id || cycle === node.id;
}) ?? [];
// First event for this drag: record offsets
if (!dragOffsets.has(node.id)) {
children.forEach(child => {
const cp = child.getPosition();
dragOffsets.set(child.id, { dx: cp.x - pos.x, dy: cp.y - pos.y });
});
}
// Apply stored offsets to keep children in place relative to parent
children.forEach(child => {
const off = dragOffsets.get(child.id);
if (!off) return;
const nx = pos.x + off.dx;
const ny = pos.y + off.dy;
child.setPosition(nx, ny);
const childView = graphRef.current?.findViewByCell(child) as any;
if (childView?.componentContainer) {
childView.componentContainer.style.transform =
`translate(${nx + PORT_RADIUS}px, ${ny}px)`;
}
});
});
graphRef.current.on('node:moved', ({ node }: { node: Node }) => {
// Clear offsets for this parent and all its children
const data = node.getData();
graphRef.current?.getNodes().forEach(child => {
const cycle = child.getData()?.cycle;
if (cycle === data?.id || cycle === node.id) dragOffsets.delete(child.id);
});
dragOffsets.delete(node.id);
nodeMoved({ node });
});
}
graphRef.current.on('node:removed', blankClick)
// When edge connected, bring connected nodes' ports to front
// When edge connected, reorder all cells to maintain correct layer order
graphRef.current.on('edge:connected', ({ isNew, edge }) => {
if (isNew) {
if (isSafari && isNew && graphRef.current) {
reorderCells(graphRef.current);
} else if (!isSafari && isNew) {
const sourceCellId = edge.getSourceCellId()
const targetCellId = edge.getTargetCellId()
const sourceCell = graphRef.current?.getCellById(sourceCellId);
@@ -1170,7 +1298,7 @@ export const useWorkflowGraph = ({
* Creates new node at drop position
* @param event - React drag event
*/
const onDrop = (event: React.DragEvent) => {
const onDrop = (event: DragEvent) => {
if (!graphRef.current) return;
event.preventDefault();
const dragData = JSON.parse(event.dataTransfer.getData('application/json'));
@@ -1492,7 +1620,7 @@ export const useWorkflowGraph = ({
// Reset all node execution status first
nodes.forEach(node => {
const data = node.getData();
if (typeof data.executionStatus === 'string') {
if (typeof data.status === 'string') {
node.setData({ ...data, executionStatus: undefined });
}
});

View File

@@ -44,6 +44,9 @@ export default defineConfig({
resolve: {
alias: {
'@': resolve(__dirname, 'src'),
'x6-html-shape': resolve(__dirname, 'src/vendor/x6-html-shape/index.js'),
'x6-html-shape/dist/react': resolve(__dirname, 'src/vendor/x6-html-shape/react.js'),
'x6-html-shape/dist/utils.js': resolve(__dirname, 'src/vendor/x6-html-shape/utils.js'),
},
},
base: './', // 使用相对路径,确保资源能正确加载