Compare commits
1 Commits
main
...
feature/sa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8476f3b7a8 |
7
.github/workflows/sync-to-gitee.yml
vendored
7
.github/workflows/sync-to-gitee.yml
vendored
@@ -3,9 +3,12 @@ name: Sync to Gitee
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- '**' # All branchs
|
- main # Production
|
||||||
|
- develop # Integration
|
||||||
|
- 'release/*' # Release preparation
|
||||||
|
- 'hotfix/*' # Urgent fixes
|
||||||
tags:
|
tags:
|
||||||
- '**' # All version tags (v1.0.0, etc.)
|
- '*' # All version tags (v1.0.0, etc.)
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
sync:
|
sync:
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ def _mask_url(url: str) -> str:
|
|||||||
"""隐藏 URL 中的密码部分,适用于 redis:// 和 amqp:// 等协议"""
|
"""隐藏 URL 中的密码部分,适用于 redis:// 和 amqp:// 等协议"""
|
||||||
return re.sub(r'(://[^:]*:)[^@]+(@)', r'\1***\2', url)
|
return re.sub(r'(://[^:]*:)[^@]+(@)', r'\1***\2', url)
|
||||||
|
|
||||||
|
|
||||||
# macOS fork() safety - must be set before any Celery initialization
|
# macOS fork() safety - must be set before any Celery initialization
|
||||||
if platform.system() == 'Darwin':
|
if platform.system() == 'Darwin':
|
||||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||||
@@ -30,7 +29,7 @@ if platform.system() == 'Darwin':
|
|||||||
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
|
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
|
||||||
|
|
||||||
_broker_url = os.getenv("CELERY_BROKER_URL") or \
|
_broker_url = os.getenv("CELERY_BROKER_URL") or \
|
||||||
f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
|
f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
|
||||||
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
|
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
|
||||||
os.environ["CELERY_BROKER_URL"] = _broker_url
|
os.environ["CELERY_BROKER_URL"] = _broker_url
|
||||||
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
|
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
|
||||||
@@ -67,11 +66,11 @@ celery_app.conf.update(
|
|||||||
task_serializer='json',
|
task_serializer='json',
|
||||||
accept_content=['json'],
|
accept_content=['json'],
|
||||||
result_serializer='json',
|
result_serializer='json',
|
||||||
|
|
||||||
# # 时区
|
# # 时区
|
||||||
# timezone='Asia/Shanghai',
|
# timezone='Asia/Shanghai',
|
||||||
# enable_utc=False,
|
# enable_utc=False,
|
||||||
|
|
||||||
# 任务追踪
|
# 任务追踪
|
||||||
task_track_started=True,
|
task_track_started=True,
|
||||||
task_ignore_result=False,
|
task_ignore_result=False,
|
||||||
|
|||||||
@@ -1,500 +0,0 @@
|
|||||||
import hashlib
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import socket
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
import redis
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
from app.core.logging_config import get_named_logger
|
|
||||||
from app.celery_app import celery_app
|
|
||||||
|
|
||||||
logger = get_named_logger("task_scheduler")
|
|
||||||
|
|
||||||
# per-user queue scheduler:uq:{user_id}
|
|
||||||
USER_QUEUE_PREFIX = "scheduler:uq:"
|
|
||||||
# User Collection of Pending Messages
|
|
||||||
ACTIVE_USERS = "scheduler:active_users"
|
|
||||||
# Set of users that can dispatch (ready signal)
|
|
||||||
READY_SET = "scheduler:ready_users"
|
|
||||||
# Metadata of tasks that have been dispatched and are pending completion
|
|
||||||
PENDING_HASH = "scheduler:pending_tasks"
|
|
||||||
# Dynamic Sharding: Instance Registry
|
|
||||||
REGISTRY_KEY = "scheduler:instances"
|
|
||||||
|
|
||||||
TASK_TIMEOUT = 7800 # Task timeout (seconds), considered lost if exceeded
|
|
||||||
HEARTBEAT_INTERVAL = 10 # Heartbeat interval (seconds)
|
|
||||||
INSTANCE_TTL = 30 # Instance timeout (seconds)
|
|
||||||
|
|
||||||
LUA_ATOMIC_LOCK = """
|
|
||||||
local dispatch_lock = KEYS[1]
|
|
||||||
local lock_key = KEYS[2]
|
|
||||||
local instance_id = ARGV[1]
|
|
||||||
local dispatch_ttl = tonumber(ARGV[2])
|
|
||||||
local lock_ttl = tonumber(ARGV[3])
|
|
||||||
|
|
||||||
if redis.call('SET', dispatch_lock, instance_id, 'NX', 'EX', dispatch_ttl) == false then
|
|
||||||
return 0
|
|
||||||
end
|
|
||||||
|
|
||||||
if redis.call('EXISTS', lock_key) == 1 then
|
|
||||||
redis.call('DEL', dispatch_lock)
|
|
||||||
return -1
|
|
||||||
end
|
|
||||||
|
|
||||||
redis.call('SET', lock_key, 'dispatching', 'EX', lock_ttl)
|
|
||||||
return 1
|
|
||||||
"""
|
|
||||||
|
|
||||||
LUA_SAFE_DELETE = """
|
|
||||||
if redis.call('GET', KEYS[1]) == ARGV[1] then
|
|
||||||
return redis.call('DEL', KEYS[1])
|
|
||||||
end
|
|
||||||
return 0
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def stable_hash(value: str) -> int:
|
|
||||||
return int.from_bytes(
|
|
||||||
hashlib.md5(value.encode("utf-8")).digest(),
|
|
||||||
"big"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def health_check_server(scheduler_ref):
|
|
||||||
import uvicorn
|
|
||||||
from fastapi import FastAPI
|
|
||||||
|
|
||||||
health_app = FastAPI()
|
|
||||||
|
|
||||||
@health_app.get("/")
|
|
||||||
def health():
|
|
||||||
return scheduler_ref.health()
|
|
||||||
|
|
||||||
port = int(os.environ.get("SCHEDULER_HEALTH_PORT", "8001"))
|
|
||||||
threading.Thread(
|
|
||||||
target=uvicorn.run,
|
|
||||||
kwargs={
|
|
||||||
"app": health_app,
|
|
||||||
"host": "0.0.0.0",
|
|
||||||
"port": port,
|
|
||||||
"log_config": None,
|
|
||||||
},
|
|
||||||
daemon=True,
|
|
||||||
).start()
|
|
||||||
logger.info("[Health] Server started at http://0.0.0.0:%s", port)
|
|
||||||
|
|
||||||
|
|
||||||
class RedisTaskScheduler:
|
|
||||||
def __init__(self):
|
|
||||||
self.redis = redis.Redis(
|
|
||||||
host=settings.REDIS_HOST,
|
|
||||||
port=settings.REDIS_PORT,
|
|
||||||
db=settings.REDIS_DB_CELERY_BACKEND,
|
|
||||||
password=settings.REDIS_PASSWORD,
|
|
||||||
decode_responses=True,
|
|
||||||
)
|
|
||||||
self.running = False
|
|
||||||
self.dispatched = 0
|
|
||||||
self.errors = 0
|
|
||||||
|
|
||||||
self.instance_id = f"{socket.gethostname()}-{os.getpid()}"
|
|
||||||
self._shard_index = 0
|
|
||||||
self._shard_count = 1
|
|
||||||
self._last_heartbeat = 0.0
|
|
||||||
|
|
||||||
def push_task(self, task_name, user_id, params):
|
|
||||||
try:
|
|
||||||
msg_id = str(uuid.uuid4())
|
|
||||||
msg = json.dumps({
|
|
||||||
"msg_id": msg_id,
|
|
||||||
"task_name": task_name,
|
|
||||||
"user_id": user_id,
|
|
||||||
"params": json.dumps(params),
|
|
||||||
})
|
|
||||||
|
|
||||||
lock_key = f"{task_name}:{user_id}"
|
|
||||||
queue_key = f"{USER_QUEUE_PREFIX}{user_id}"
|
|
||||||
|
|
||||||
pipe = self.redis.pipeline()
|
|
||||||
pipe.rpush(queue_key, msg)
|
|
||||||
pipe.sadd(ACTIVE_USERS, user_id)
|
|
||||||
pipe.set(
|
|
||||||
f"task_tracker:{msg_id}",
|
|
||||||
json.dumps({"status": "QUEUED", "task_id": None}),
|
|
||||||
ex=86400,
|
|
||||||
)
|
|
||||||
pipe.execute()
|
|
||||||
|
|
||||||
if not self.redis.exists(lock_key):
|
|
||||||
self.redis.sadd(READY_SET, user_id)
|
|
||||||
|
|
||||||
logger.info("Task pushed: msg_id=%s task=%s user=%s", msg_id, task_name, user_id)
|
|
||||||
return msg_id
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Push task exception %s", e, exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def get_task_status(self, msg_id: str) -> dict:
|
|
||||||
raw = self.redis.get(f"task_tracker:{msg_id}")
|
|
||||||
if raw is None:
|
|
||||||
return {"status": "NOT_FOUND"}
|
|
||||||
|
|
||||||
tracker = json.loads(raw)
|
|
||||||
status = tracker["status"]
|
|
||||||
task_id = tracker.get("task_id")
|
|
||||||
result_content = tracker.get("result") or {}
|
|
||||||
|
|
||||||
if status == "DISPATCHED" and task_id:
|
|
||||||
result_raw = self.redis.get(f"celery-task-meta-{task_id}")
|
|
||||||
if result_raw:
|
|
||||||
result_data = json.loads(result_raw)
|
|
||||||
status = result_data.get("status", status)
|
|
||||||
result_content = result_data.get("result")
|
|
||||||
|
|
||||||
return {"status": status, "task_id": task_id, "result": result_content}
|
|
||||||
|
|
||||||
def _cleanup_finished(self):
|
|
||||||
pending = self.redis.hgetall(PENDING_HASH)
|
|
||||||
if not pending:
|
|
||||||
return
|
|
||||||
|
|
||||||
now = time.time()
|
|
||||||
task_ids = list(pending.keys())
|
|
||||||
|
|
||||||
pipe = self.redis.pipeline()
|
|
||||||
for task_id in task_ids:
|
|
||||||
pipe.get(f"celery-task-meta-{task_id}")
|
|
||||||
results = pipe.execute()
|
|
||||||
|
|
||||||
cleanup_pipe = self.redis.pipeline()
|
|
||||||
has_cleanup = False
|
|
||||||
ready_user_ids = set()
|
|
||||||
|
|
||||||
for task_id, raw_result in zip(task_ids, results):
|
|
||||||
try:
|
|
||||||
meta = json.loads(pending[task_id])
|
|
||||||
lock_key = meta["lock_key"]
|
|
||||||
dispatched_at = meta.get("dispatched_at", 0)
|
|
||||||
age = now - dispatched_at
|
|
||||||
|
|
||||||
should_cleanup = False
|
|
||||||
result_data = {}
|
|
||||||
|
|
||||||
if raw_result is not None:
|
|
||||||
result_data = json.loads(raw_result)
|
|
||||||
if result_data.get("status") in ("SUCCESS", "FAILURE", "REVOKED"):
|
|
||||||
should_cleanup = True
|
|
||||||
logger.info(
|
|
||||||
"Task finished: %s state=%s", task_id,
|
|
||||||
result_data.get("status"),
|
|
||||||
)
|
|
||||||
elif age > TASK_TIMEOUT:
|
|
||||||
should_cleanup = True
|
|
||||||
logger.warning(
|
|
||||||
"Task expired or lost: %s age=%.0fs, force cleanup",
|
|
||||||
task_id, age,
|
|
||||||
)
|
|
||||||
|
|
||||||
if should_cleanup:
|
|
||||||
final_status = (
|
|
||||||
result_data.get("status", "UNKNOWN") if result_data else "EXPIRED"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.redis.eval(LUA_SAFE_DELETE, 1, lock_key, task_id)
|
|
||||||
|
|
||||||
cleanup_pipe.hdel(PENDING_HASH, task_id)
|
|
||||||
|
|
||||||
tracker_msg_id = meta.get("msg_id")
|
|
||||||
if tracker_msg_id:
|
|
||||||
cleanup_pipe.set(
|
|
||||||
f"task_tracker:{tracker_msg_id}",
|
|
||||||
json.dumps({
|
|
||||||
"status": final_status,
|
|
||||||
"task_id": task_id,
|
|
||||||
"result": result_data.get("result") or {},
|
|
||||||
}),
|
|
||||||
ex=86400,
|
|
||||||
)
|
|
||||||
has_cleanup = True
|
|
||||||
|
|
||||||
parts = lock_key.split(":", 1)
|
|
||||||
if len(parts) == 2:
|
|
||||||
ready_user_ids.add(parts[1])
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True)
|
|
||||||
self.errors += 1
|
|
||||||
|
|
||||||
if has_cleanup:
|
|
||||||
cleanup_pipe.execute()
|
|
||||||
|
|
||||||
if ready_user_ids:
|
|
||||||
self.redis.sadd(READY_SET, *ready_user_ids)
|
|
||||||
|
|
||||||
def _heartbeat(self):
|
|
||||||
now = time.time()
|
|
||||||
if now - self._last_heartbeat < HEARTBEAT_INTERVAL:
|
|
||||||
return
|
|
||||||
self._last_heartbeat = now
|
|
||||||
|
|
||||||
self.redis.hset(REGISTRY_KEY, self.instance_id, str(now))
|
|
||||||
|
|
||||||
all_instances = self.redis.hgetall(REGISTRY_KEY)
|
|
||||||
|
|
||||||
alive = []
|
|
||||||
dead = []
|
|
||||||
for iid, ts in all_instances.items():
|
|
||||||
if now - float(ts) < INSTANCE_TTL:
|
|
||||||
alive.append(iid)
|
|
||||||
else:
|
|
||||||
dead.append(iid)
|
|
||||||
|
|
||||||
if dead:
|
|
||||||
pipe = self.redis.pipeline()
|
|
||||||
for iid in dead:
|
|
||||||
pipe.hdel(REGISTRY_KEY, iid)
|
|
||||||
pipe.execute()
|
|
||||||
logger.info("Cleaned dead instances: %s", dead)
|
|
||||||
|
|
||||||
alive.sort()
|
|
||||||
self._shard_count = max(len(alive), 1)
|
|
||||||
self._shard_index = (
|
|
||||||
alive.index(self.instance_id) if self.instance_id in alive else 0
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
"Shard: %s/%s (instance=%s, alive=%d)",
|
|
||||||
self._shard_index, self._shard_count,
|
|
||||||
self.instance_id, len(alive),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _is_mine(self, user_id: str) -> bool:
|
|
||||||
if self._shard_count <= 1:
|
|
||||||
return True
|
|
||||||
return stable_hash(user_id) % self._shard_count == self._shard_index
|
|
||||||
|
|
||||||
def _dispatch(self, msg_id, msg_data) -> bool:
|
|
||||||
user_id = msg_data["user_id"]
|
|
||||||
task_name = msg_data["task_name"]
|
|
||||||
params = json.loads(msg_data.get("params", "{}"))
|
|
||||||
|
|
||||||
lock_key = f"{task_name}:{user_id}"
|
|
||||||
dispatch_lock = f"dispatch:{msg_id}"
|
|
||||||
|
|
||||||
result = self.redis.eval(
|
|
||||||
LUA_ATOMIC_LOCK, 2,
|
|
||||||
dispatch_lock, lock_key,
|
|
||||||
self.instance_id, str(300), str(3600),
|
|
||||||
)
|
|
||||||
|
|
||||||
if result == 0:
|
|
||||||
return False
|
|
||||||
if result == -1:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
task = celery_app.send_task(task_name, kwargs=params)
|
|
||||||
except Exception as e:
|
|
||||||
pipe = self.redis.pipeline()
|
|
||||||
pipe.delete(dispatch_lock)
|
|
||||||
pipe.delete(lock_key)
|
|
||||||
pipe.execute()
|
|
||||||
self.errors += 1
|
|
||||||
logger.error(
|
|
||||||
"send_task failed for %s:%s msg=%s: %s",
|
|
||||||
task_name, user_id, msg_id, e, exc_info=True,
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
pipe = self.redis.pipeline()
|
|
||||||
pipe.set(lock_key, task.id, ex=3600)
|
|
||||||
pipe.hset(PENDING_HASH, task.id, json.dumps({
|
|
||||||
"lock_key": lock_key,
|
|
||||||
"dispatched_at": time.time(),
|
|
||||||
"msg_id": msg_id,
|
|
||||||
}))
|
|
||||||
pipe.delete(dispatch_lock)
|
|
||||||
pipe.set(
|
|
||||||
f"task_tracker:{msg_id}",
|
|
||||||
json.dumps({"status": "DISPATCHED", "task_id": task.id}),
|
|
||||||
ex=86400,
|
|
||||||
)
|
|
||||||
pipe.execute()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
"Post-dispatch state update failed for %s: %s",
|
|
||||||
task.id, e, exc_info=True,
|
|
||||||
)
|
|
||||||
self.errors += 1
|
|
||||||
|
|
||||||
self.dispatched += 1
|
|
||||||
logger.info("Task dispatched: %s (msg=%s)", task.id, msg_id)
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _process_batch(self, user_ids):
|
|
||||||
if not user_ids:
|
|
||||||
return
|
|
||||||
|
|
||||||
pipe = self.redis.pipeline()
|
|
||||||
for uid in user_ids:
|
|
||||||
pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0)
|
|
||||||
heads = pipe.execute()
|
|
||||||
|
|
||||||
candidates = [] # (user_id, msg_dict)
|
|
||||||
empty_users = []
|
|
||||||
|
|
||||||
for uid, head in zip(user_ids, heads):
|
|
||||||
if head is None:
|
|
||||||
empty_users.append(uid)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
candidates.append((uid, json.loads(head)))
|
|
||||||
except (json.JSONDecodeError, TypeError) as e:
|
|
||||||
logger.error("Bad message in queue for user %s: %s", uid, e)
|
|
||||||
self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}")
|
|
||||||
|
|
||||||
if empty_users:
|
|
||||||
pipe = self.redis.pipeline()
|
|
||||||
for uid in empty_users:
|
|
||||||
pipe.srem(ACTIVE_USERS, uid)
|
|
||||||
pipe.execute()
|
|
||||||
|
|
||||||
if not candidates:
|
|
||||||
return
|
|
||||||
|
|
||||||
for uid, msg in candidates:
|
|
||||||
if self._dispatch(msg["msg_id"], msg):
|
|
||||||
self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}")
|
|
||||||
|
|
||||||
def schedule_loop(self):
|
|
||||||
self._heartbeat()
|
|
||||||
self._cleanup_finished()
|
|
||||||
|
|
||||||
pipe = self.redis.pipeline()
|
|
||||||
pipe.smembers(READY_SET)
|
|
||||||
pipe.delete(READY_SET)
|
|
||||||
results = pipe.execute()
|
|
||||||
ready_users = results[0] or set()
|
|
||||||
|
|
||||||
my_users = [uid for uid in ready_users if self._is_mine(uid)]
|
|
||||||
|
|
||||||
if not my_users:
|
|
||||||
time.sleep(0.5)
|
|
||||||
return
|
|
||||||
|
|
||||||
self._process_batch(my_users)
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
def _full_scan(self):
|
|
||||||
cursor = 0
|
|
||||||
ready_batch = []
|
|
||||||
while True:
|
|
||||||
cursor, user_ids = self.redis.sscan(
|
|
||||||
ACTIVE_USERS, cursor=cursor, count=1000,
|
|
||||||
)
|
|
||||||
if user_ids:
|
|
||||||
my_users = [uid for uid in user_ids if self._is_mine(uid)]
|
|
||||||
if my_users:
|
|
||||||
pipe = self.redis.pipeline()
|
|
||||||
for uid in my_users:
|
|
||||||
pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0)
|
|
||||||
heads = pipe.execute()
|
|
||||||
|
|
||||||
for uid, head in zip(my_users, heads):
|
|
||||||
if head is None:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
msg = json.loads(head)
|
|
||||||
lock_key = f"{msg['task_name']}:{uid}"
|
|
||||||
ready_batch.append((uid, lock_key))
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if cursor == 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not ready_batch:
|
|
||||||
return
|
|
||||||
|
|
||||||
pipe = self.redis.pipeline()
|
|
||||||
for _, lock_key in ready_batch:
|
|
||||||
pipe.exists(lock_key)
|
|
||||||
lock_exists = pipe.execute()
|
|
||||||
|
|
||||||
ready_uids = [
|
|
||||||
uid for (uid, _), locked in zip(ready_batch, lock_exists)
|
|
||||||
if not locked
|
|
||||||
]
|
|
||||||
|
|
||||||
if ready_uids:
|
|
||||||
self.redis.sadd(READY_SET, *ready_uids)
|
|
||||||
logger.info("Full scan found %d ready users", len(ready_uids))
|
|
||||||
|
|
||||||
def run_server(self):
|
|
||||||
health_check_server(self)
|
|
||||||
self.running = True
|
|
||||||
|
|
||||||
last_full_scan = 0.0
|
|
||||||
full_scan_interval = 30.0
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Scheduler started: instance=%s", self.instance_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
self.schedule_loop()
|
|
||||||
|
|
||||||
now = time.time()
|
|
||||||
if now - last_full_scan > full_scan_interval:
|
|
||||||
self._full_scan()
|
|
||||||
last_full_scan = now
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Scheduler exception %s", e, exc_info=True)
|
|
||||||
self.errors += 1
|
|
||||||
time.sleep(5)
|
|
||||||
|
|
||||||
def health(self) -> dict:
|
|
||||||
return {
|
|
||||||
"running": self.running,
|
|
||||||
"active_users": self.redis.scard(ACTIVE_USERS),
|
|
||||||
"ready_users": self.redis.scard(READY_SET),
|
|
||||||
"pending_tasks": self.redis.hlen(PENDING_HASH),
|
|
||||||
"dispatched": self.dispatched,
|
|
||||||
"errors": self.errors,
|
|
||||||
"shard": f"{self._shard_index}/{self._shard_count}",
|
|
||||||
"instance": self.instance_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
def shutdown(self):
|
|
||||||
logger.info("Scheduler shutting down: instance=%s", self.instance_id)
|
|
||||||
self.running = False
|
|
||||||
try:
|
|
||||||
self.redis.hdel(REGISTRY_KEY, self.instance_id)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Shutdown cleanup error: %s", e)
|
|
||||||
|
|
||||||
|
|
||||||
scheduler: RedisTaskScheduler | None = None
|
|
||||||
if scheduler is None:
|
|
||||||
scheduler = RedisTaskScheduler()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import signal
|
|
||||||
import sys
|
|
||||||
|
|
||||||
|
|
||||||
def _signal_handler(signum, frame):
|
|
||||||
scheduler.shutdown()
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
|
|
||||||
signal.signal(signal.SIGTERM, _signal_handler)
|
|
||||||
signal.signal(signal.SIGINT, _signal_handler)
|
|
||||||
|
|
||||||
scheduler.run_server()
|
|
||||||
@@ -9,7 +9,7 @@ from app.core.logging_config import get_business_logger
|
|||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||||
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail, AppLogMessage
|
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail
|
||||||
from app.schemas.response_schema import PageData, PageMeta
|
from app.schemas.response_schema import PageData, PageMeta
|
||||||
from app.services.app_service import AppService
|
from app.services.app_service import AppService
|
||||||
from app.services.app_log_service import AppLogService
|
from app.services.app_log_service import AppLogService
|
||||||
@@ -41,7 +41,7 @@ def list_app_logs(
|
|||||||
|
|
||||||
# 验证应用访问权限
|
# 验证应用访问权限
|
||||||
app_service = AppService(db)
|
app_service = AppService(db)
|
||||||
app = app_service.get_app(app_id, workspace_id)
|
app_service.get_app(app_id, workspace_id)
|
||||||
|
|
||||||
# 使用 Service 层查询
|
# 使用 Service 层查询
|
||||||
log_service = AppLogService(db)
|
log_service = AppLogService(db)
|
||||||
@@ -51,8 +51,7 @@ def list_app_logs(
|
|||||||
page=page,
|
page=page,
|
||||||
pagesize=pagesize,
|
pagesize=pagesize,
|
||||||
is_draft=is_draft,
|
is_draft=is_draft,
|
||||||
keyword=keyword,
|
keyword=keyword
|
||||||
app_type=app.type,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
items = [AppLogConversation.model_validate(c) for c in conversations]
|
items = [AppLogConversation.model_validate(c) for c in conversations]
|
||||||
@@ -79,32 +78,17 @@ def get_app_log_detail(
|
|||||||
|
|
||||||
# 验证应用访问权限
|
# 验证应用访问权限
|
||||||
app_service = AppService(db)
|
app_service = AppService(db)
|
||||||
app = app_service.get_app(app_id, workspace_id)
|
app_service.get_app(app_id, workspace_id)
|
||||||
|
|
||||||
# 使用 Service 层查询
|
# 使用 Service 层查询
|
||||||
log_service = AppLogService(db)
|
log_service = AppLogService(db)
|
||||||
conversation, messages, node_executions_map = log_service.get_conversation_detail(
|
conversation, node_executions_map = log_service.get_conversation_detail(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id
|
||||||
app_type=app.type
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 构建基础会话信息(不经过 ORM relationship)
|
detail = AppLogConversationDetail.model_validate(conversation)
|
||||||
base = AppLogConversation.model_validate(conversation)
|
detail.node_executions_map = node_executions_map
|
||||||
|
|
||||||
# 单独处理 messages,避免触发 SQLAlchemy relationship 校验
|
|
||||||
if messages and isinstance(messages[0], AppLogMessage):
|
|
||||||
# 工作流:已经是 AppLogMessage 实例
|
|
||||||
msg_list = messages
|
|
||||||
else:
|
|
||||||
# Agent:ORM Message 对象逐个转换
|
|
||||||
msg_list = [AppLogMessage.model_validate(m) for m in messages]
|
|
||||||
|
|
||||||
detail = AppLogConversationDetail(
|
|
||||||
**base.model_dump(),
|
|
||||||
messages=msg_list,
|
|
||||||
node_executions_map=node_executions_map,
|
|
||||||
)
|
|
||||||
|
|
||||||
return success(data=detail)
|
return success(data=detail)
|
||||||
|
|||||||
@@ -4,9 +4,7 @@
|
|||||||
处理显性记忆相关的API接口,包括情景记忆和语义记忆的查询。
|
处理显性记忆相关的API接口,包括情景记忆和语义记忆的查询。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query
|
|
||||||
|
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.response_utils import success, fail
|
from app.core.response_utils import success, fail
|
||||||
@@ -71,140 +69,6 @@ async def get_explicit_memory_overview_api(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "显性记忆总览查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "显性记忆总览查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/episodics", response_model=ApiResponse)
|
|
||||||
async def get_episodic_memory_list_api(
|
|
||||||
end_user_id: str = Query(..., description="end user ID"),
|
|
||||||
page: int = Query(1, gt=0, description="page number, starting from 1"),
|
|
||||||
pagesize: int = Query(10, gt=0, le=100, description="number of items per page, max 100"),
|
|
||||||
start_date: Optional[int] = Query(None, description="start timestamp (ms)"),
|
|
||||||
end_date: Optional[int] = Query(None, description="end timestamp (ms)"),
|
|
||||||
episodic_type: str = Query("all", description="episodic type :all/conversation/project_work/learning/decision/important_event"),
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
获取情景记忆分页列表
|
|
||||||
|
|
||||||
返回指定用户的情景记忆列表,支持分页、时间范围筛选和情景类型筛选。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: 终端用户ID(必填)
|
|
||||||
page: 页码(从1开始,默认1)
|
|
||||||
pagesize: 每页数量(默认10,最大100)
|
|
||||||
start_date: 开始时间戳(可选,毫秒),自动扩展到当天 00:00:00
|
|
||||||
end_date: 结束时间戳(可选,毫秒),自动扩展到当天 23:59:59
|
|
||||||
episodic_type: 情景类型筛选(可选,默认all)
|
|
||||||
current_user: 当前用户
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ApiResponse: 包含情景记忆分页列表
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
- 基础分页查询:GET /episodics?end_user_id=xxx&page=1&pagesize=5
|
|
||||||
返回第1页,每页5条数据
|
|
||||||
- 按时间范围筛选:GET /episodics?end_user_id=xxx&page=1&pagesize=5&start_date=1738684800000&end_date=1738771199000
|
|
||||||
返回指定时间范围内的数据
|
|
||||||
- 按情景类型筛选:GET /episodics?end_user_id=xxx&page=1&pagesize=5&episodic_type=important_event
|
|
||||||
返回类型为"重要事件"的数据
|
|
||||||
|
|
||||||
Notes:
|
|
||||||
- start_date 和 end_date 必须同时提供或同时不提供
|
|
||||||
- start_date 不能大于 end_date
|
|
||||||
- episodic_type 可选值:all, conversation, project_work, learning, decision, important_event
|
|
||||||
- total 为该用户情景记忆总数(不受筛选条件影响)
|
|
||||||
- page.total 为筛选后的总条数
|
|
||||||
"""
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
|
|
||||||
# 检查用户是否已选择工作空间
|
|
||||||
if workspace_id is None:
|
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试查询情景记忆列表但未选择工作空间")
|
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"情景记忆分页查询: end_user_id={end_user_id}, "
|
|
||||||
f"start_date={start_date}, end_date={end_date}, episodic_type={episodic_type}, "
|
|
||||||
f"page={page}, pagesize={pagesize}, username={current_user.username}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1. 参数校验
|
|
||||||
if page < 1 or pagesize < 1:
|
|
||||||
api_logger.warning(f"分页参数错误: page={page}, pagesize={pagesize}")
|
|
||||||
return fail(BizCode.INVALID_PARAMETER, "分页参数必须大于0")
|
|
||||||
|
|
||||||
valid_episodic_types = ["all", "conversation", "project_work", "learning", "decision", "important_event"]
|
|
||||||
if episodic_type not in valid_episodic_types:
|
|
||||||
api_logger.warning(f"无效的情景类型参数: {episodic_type}")
|
|
||||||
return fail(BizCode.INVALID_PARAMETER, f"无效的情景类型参数,可选值:{', '.join(valid_episodic_types)}")
|
|
||||||
|
|
||||||
# 时间戳参数校验
|
|
||||||
if (start_date is not None and end_date is None) or (end_date is not None and start_date is None):
|
|
||||||
return fail(BizCode.INVALID_PARAMETER, "start_date和end_date必须同时提供")
|
|
||||||
|
|
||||||
if start_date is not None and end_date is not None and start_date > end_date:
|
|
||||||
return fail(BizCode.INVALID_PARAMETER, "start_date不能大于end_date")
|
|
||||||
|
|
||||||
# 2. 执行查询
|
|
||||||
try:
|
|
||||||
result = await memory_explicit_service.get_episodic_memory_list(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
page=page,
|
|
||||||
pagesize=pagesize,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
episodic_type=episodic_type,
|
|
||||||
)
|
|
||||||
api_logger.info(
|
|
||||||
f"情景记忆分页查询成功: end_user_id={end_user_id}, "
|
|
||||||
f"total={result['total']}, 返回={len(result['items'])}条"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"情景记忆分页查询失败: end_user_id={end_user_id}, error={str(e)}")
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "情景记忆分页查询失败", str(e))
|
|
||||||
|
|
||||||
# 3. 返回结构化响应
|
|
||||||
return success(data=result, msg="查询成功")
|
|
||||||
|
|
||||||
@router.get("/semantics", response_model=ApiResponse)
|
|
||||||
async def get_semantic_memory_list_api(
|
|
||||||
end_user_id: str = Query(..., description="终端用户ID"),
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
获取语义记忆列表
|
|
||||||
|
|
||||||
返回指定用户的全量语义记忆列表。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: 终端用户ID(必填)
|
|
||||||
current_user: 当前用户
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ApiResponse: 包含语义记忆全量列表
|
|
||||||
"""
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
|
|
||||||
if workspace_id is None:
|
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试查询语义记忆列表但未选择工作空间")
|
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"语义记忆列表查询: end_user_id={end_user_id}, username={current_user.username}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = await memory_explicit_service.get_semantic_memory_list(
|
|
||||||
end_user_id=end_user_id
|
|
||||||
)
|
|
||||||
api_logger.info(
|
|
||||||
f"语义记忆列表查询成功: end_user_id={end_user_id}, total={len(result)}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"语义记忆列表查询失败: end_user_id={end_user_id}, error={str(e)}")
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "语义记忆列表查询失败", str(e))
|
|
||||||
|
|
||||||
return success(data=result, msg="查询成功")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/details", response_model=ApiResponse)
|
@router.post("/details", response_model=ApiResponse)
|
||||||
async def get_explicit_memory_details_api(
|
async def get_explicit_memory_details_api(
|
||||||
request: ExplicitMemoryDetailsRequest,
|
request: ExplicitMemoryDetailsRequest,
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from . import (
|
|||||||
rag_api_document_controller,
|
rag_api_document_controller,
|
||||||
rag_api_file_controller,
|
rag_api_file_controller,
|
||||||
rag_api_knowledge_controller,
|
rag_api_knowledge_controller,
|
||||||
user_memory_api_controller,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建 V1 API 路由器
|
# 创建 V1 API 路由器
|
||||||
@@ -29,6 +28,5 @@ service_router.include_router(rag_api_chunk_controller.router)
|
|||||||
service_router.include_router(memory_api_controller.router)
|
service_router.include_router(memory_api_controller.router)
|
||||||
service_router.include_router(end_user_api_controller.router)
|
service_router.include_router(end_user_api_controller.router)
|
||||||
service_router.include_router(memory_config_api_controller.router)
|
service_router.include_router(memory_config_api_controller.router)
|
||||||
service_router.include_router(user_memory_api_controller.router)
|
|
||||||
|
|
||||||
__all__ = ["service_router"]
|
__all__ = ["service_router"]
|
||||||
|
|||||||
@@ -296,7 +296,7 @@ async def chat(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# workflow 非流式返回
|
# 多 Agent 非流式返回
|
||||||
result = await app_chat_service.workflow_chat(
|
result = await app_chat_service.workflow_chat(
|
||||||
|
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
from fastapi import APIRouter, Body, Depends, Query, Request
|
from fastapi import APIRouter, Body, Depends, Query, Request
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.celery_task_scheduler import scheduler
|
|
||||||
from app.core.api_key_auth import require_api_key
|
from app.core.api_key_auth import require_api_key
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.quota_stub import check_end_user_quota
|
from app.core.quota_stub import check_end_user_quota
|
||||||
@@ -87,7 +86,7 @@ async def write_memory(
|
|||||||
user_rag_memory_id=payload.user_rag_memory_id,
|
user_rag_memory_id=payload.user_rag_memory_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Memory write task submitted: task_id: {result['task_id']} end_user_id: {payload.end_user_id}")
|
logger.info(f"Memory write task submitted: task_id={result['task_id']}, end_user_id: {payload.end_user_id}")
|
||||||
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted")
|
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted")
|
||||||
|
|
||||||
|
|
||||||
@@ -106,7 +105,8 @@ async def get_write_task_status(
|
|||||||
"""
|
"""
|
||||||
logger.info(f"Write task status check - task_id: {task_id}")
|
logger.info(f"Write task status check - task_id: {task_id}")
|
||||||
|
|
||||||
result = scheduler.get_task_status(task_id)
|
from app.services.task_service import get_task_memory_write_result
|
||||||
|
result = get_task_memory_write_result(task_id)
|
||||||
|
|
||||||
return success(data=_sanitize_task_result(result), msg="Task status retrieved")
|
return success(data=_sanitize_task_result(result), msg="Task status retrieved")
|
||||||
|
|
||||||
|
|||||||
@@ -1,230 +0,0 @@
|
|||||||
"""User Memory 服务接口 — 基于 API Key 认证
|
|
||||||
|
|
||||||
包装 user_memory_controllers.py 和 memory_agent_controller.py 中的内部接口,
|
|
||||||
提供基于 API Key 认证的对外服务:
|
|
||||||
1./analytics/graph_data - 知识图谱数据接口
|
|
||||||
2./analytics/community_graph - 社区图谱接口
|
|
||||||
3./analytics/node_statistics - 记忆节点统计接口
|
|
||||||
4./analytics/user_summary - 用户摘要接口
|
|
||||||
5./analytics/memory_insight - 记忆洞察接口
|
|
||||||
6./analytics/interest_distribution - 兴趣分布接口
|
|
||||||
7./analytics/end_user_info - 终端用户信息接口
|
|
||||||
8./analytics/generate_cache - 缓存生成接口
|
|
||||||
|
|
||||||
|
|
||||||
路由前缀: /memory
|
|
||||||
子路径: /analytics/...
|
|
||||||
最终路径: /v1/memory/analytics/...
|
|
||||||
认证方式: API Key (@require_api_key)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header, Query, Request, Body
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.core.api_key_auth import require_api_key
|
|
||||||
from app.core.api_key_utils import get_current_user_from_api_key, validate_end_user_in_workspace
|
|
||||||
from app.core.logging_config import get_business_logger
|
|
||||||
from app.db import get_db
|
|
||||||
from app.schemas.api_key_schema import ApiKeyAuth
|
|
||||||
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
|
||||||
|
|
||||||
# 包装内部服务 controller
|
|
||||||
from app.controllers import user_memory_controllers, memory_agent_controller
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/memory", tags=["V1 - User Memory API"])
|
|
||||||
logger = get_business_logger()
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 知识图谱 ====================
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/graph_data")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def get_graph_data(
|
|
||||||
request: Request,
|
|
||||||
end_user_id: str = Query(..., description="End user ID"),
|
|
||||||
node_types: Optional[str] = Query(None, description="Comma-separated node types filter"),
|
|
||||||
limit: int = Query(100, description="Max nodes to return (auto-capped at 1000 in service layer)"),
|
|
||||||
depth: int = Query(1, description="Graph traversal depth (auto-capped at 3 in service layer)"),
|
|
||||||
center_node_id: Optional[str] = Query(None, description="Center node for subgraph"),
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""Get knowledge graph data (nodes + edges) for an end user."""
|
|
||||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
|
||||||
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
|
||||||
|
|
||||||
return await user_memory_controllers.get_graph_data_api(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
node_types=node_types,
|
|
||||||
limit=limit,
|
|
||||||
depth=depth,
|
|
||||||
center_node_id=center_node_id,
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/community_graph")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def get_community_graph(
|
|
||||||
request: Request,
|
|
||||||
end_user_id: str = Query(..., description="End user ID"),
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""Get community clustering graph for an end user."""
|
|
||||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
|
||||||
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
|
||||||
|
|
||||||
return await user_memory_controllers.get_community_graph_data_api(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 节点统计 ====================
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/node_statistics")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def get_node_statistics(
|
|
||||||
request: Request,
|
|
||||||
end_user_id: str = Query(..., description="End user ID"),
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""Get memory node type statistics for an end user."""
|
|
||||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
|
||||||
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
|
||||||
|
|
||||||
return await user_memory_controllers.get_node_statistics_api(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 用户摘要 & 洞察 ====================
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/user_summary")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def get_user_summary(
|
|
||||||
request: Request,
|
|
||||||
end_user_id: str = Query(..., description="End user ID"),
|
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""Get cached user summary for an end user."""
|
|
||||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
|
||||||
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
|
||||||
|
|
||||||
return await user_memory_controllers.get_user_summary_api(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
language_type=language_type,
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/memory_insight")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def get_memory_insight(
|
|
||||||
request: Request,
|
|
||||||
end_user_id: str = Query(..., description="End user ID"),
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""Get cached memory insight report for an end user."""
|
|
||||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
|
||||||
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
|
||||||
|
|
||||||
return await user_memory_controllers.get_memory_insight_report_api(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 兴趣分布 ====================
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/interest_distribution")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def get_interest_distribution(
|
|
||||||
request: Request,
|
|
||||||
end_user_id: str = Query(..., description="End user ID"),
|
|
||||||
limit: int = Query(5, le=5, description="Max interest tags to return"),
|
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""Get interest distribution tags for an end user."""
|
|
||||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
|
||||||
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
|
||||||
|
|
||||||
return await memory_agent_controller.get_interest_distribution_by_user_api(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
limit=limit,
|
|
||||||
language_type=language_type,
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 终端用户信息 ====================
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/end_user_info")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def get_end_user_info(
|
|
||||||
request: Request,
|
|
||||||
end_user_id: str = Query(..., description="End user ID"),
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""Get end user basic information (name, aliases, metadata)."""
|
|
||||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
|
||||||
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
|
||||||
|
|
||||||
return await user_memory_controllers.get_end_user_info(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 缓存生成 ====================
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/analytics/generate_cache")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def generate_cache(
|
|
||||||
request: Request,
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
message: str = Body(None, description="Request body"),
|
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
|
||||||
):
|
|
||||||
"""Trigger cache generation (user summary + memory insight) for an end user or all workspace users."""
|
|
||||||
body = await request.json()
|
|
||||||
cache_request = GenerateCacheRequest(**body)
|
|
||||||
|
|
||||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
|
||||||
|
|
||||||
if cache_request.end_user_id:
|
|
||||||
validate_end_user_in_workspace(db, cache_request.end_user_id, api_key_auth.workspace_id)
|
|
||||||
|
|
||||||
return await user_memory_controllers.generate_cache_api(
|
|
||||||
request=cache_request,
|
|
||||||
language_type=language_type,
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@@ -173,8 +173,6 @@ async def delete_tool(
|
|||||||
return success(msg="工具删除成功")
|
return success(msg="工具删除成功")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -251,8 +249,6 @@ async def parse_openapi_schema(
|
|||||||
if result["success"] is False:
|
if result["success"] is False:
|
||||||
raise HTTPException(status_code=400, detail=result["message"])
|
raise HTTPException(status_code=400, detail=result["message"])
|
||||||
return success(data=result, msg="Schema解析完成")
|
return success(data=result, msg="Schema解析完成")
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|||||||
@@ -221,7 +221,7 @@ def update_workspace_members(
|
|||||||
|
|
||||||
@router.delete("/members/{member_id}", response_model=ApiResponse)
|
@router.delete("/members/{member_id}", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def delete_workspace_member(
|
def delete_workspace_member(
|
||||||
member_id: uuid.UUID,
|
member_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
@@ -230,7 +230,7 @@ async def delete_workspace_member(
|
|||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||||
|
|
||||||
await workspace_service.delete_workspace_member(
|
workspace_service.delete_workspace_member(
|
||||||
db=db,
|
db=db,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
member_id=member_id,
|
member_id=member_id,
|
||||||
|
|||||||
@@ -1,15 +1,8 @@
|
|||||||
"""API Key 工具函数"""
|
"""API Key 工具函数"""
|
||||||
import secrets
|
import secrets
|
||||||
import uuid as _uuid
|
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy.orm import Session as _Session
|
|
||||||
from app.core.error_codes import BizCode as _BizCode
|
|
||||||
from app.core.exceptions import BusinessException as _BusinessException
|
|
||||||
from app.models.end_user_model import EndUser as _EndUser
|
|
||||||
from app.repositories.end_user_repository import EndUserRepository as _EndUserRepository
|
|
||||||
|
|
||||||
from app.models.api_key_model import ApiKeyType
|
from app.models.api_key_model import ApiKeyType
|
||||||
from fastapi import Response
|
from fastapi import Response
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
@@ -72,72 +65,3 @@ def datetime_to_timestamp(dt: Optional[datetime]) -> Optional[int]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
return int(dt.timestamp() * 1000)
|
return int(dt.timestamp() * 1000)
|
||||||
|
|
||||||
|
|
||||||
def get_current_user_from_api_key(db: _Session, api_key_auth):
|
|
||||||
"""通过 API Key 构造 current_user 对象。
|
|
||||||
|
|
||||||
从 API Key 反查创建者(管理员用户),并设置其 workspace 上下文。
|
|
||||||
与内部接口的 Depends(get_current_user) (JWT) 等价。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: 数据库会话
|
|
||||||
api_key_auth: API Key 认证信息(ApiKeyAuth)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
User ORM 对象,已设置 current_workspace_id
|
|
||||||
"""
|
|
||||||
from app.services import api_key_service
|
|
||||||
|
|
||||||
api_key = api_key_service.ApiKeyService.get_api_key(
|
|
||||||
db, api_key_auth.api_key_id, api_key_auth.workspace_id
|
|
||||||
)
|
|
||||||
current_user = api_key.creator
|
|
||||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
|
||||||
return current_user
|
|
||||||
|
|
||||||
|
|
||||||
def validate_end_user_in_workspace(
|
|
||||||
db: _Session,
|
|
||||||
end_user_id: str,
|
|
||||||
workspace_id,
|
|
||||||
) -> _EndUser:
|
|
||||||
"""校验 end_user 是否存在且属于指定 workspace。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: 数据库会话
|
|
||||||
end_user_id: 终端用户 ID
|
|
||||||
workspace_id: 工作空间 ID(UUID 或字符串均可)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
EndUser ORM 对象(校验通过时)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
BusinessException(INVALID_PARAMETER): end_user_id 格式无效
|
|
||||||
BusinessException(USER_NOT_FOUND): end_user 不存在
|
|
||||||
BusinessException(PERMISSION_DENIED): end_user 不属于该 workspace
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
_uuid.UUID(end_user_id)
|
|
||||||
except (ValueError, AttributeError):
|
|
||||||
raise _BusinessException(
|
|
||||||
f"Invalid end_user_id format: {end_user_id}",
|
|
||||||
_BizCode.INVALID_PARAMETER,
|
|
||||||
)
|
|
||||||
|
|
||||||
end_user_repo = _EndUserRepository(db)
|
|
||||||
end_user = end_user_repo.get_end_user_by_id(end_user_id)
|
|
||||||
|
|
||||||
if end_user is None:
|
|
||||||
raise _BusinessException(
|
|
||||||
"End user not found",
|
|
||||||
_BizCode.USER_NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
if str(end_user.workspace_id) != str(workspace_id):
|
|
||||||
raise _BusinessException(
|
|
||||||
"End user does not belong to this workspace",
|
|
||||||
_BizCode.PERMISSION_DENIED,
|
|
||||||
)
|
|
||||||
|
|
||||||
return end_user
|
|
||||||
@@ -241,8 +241,6 @@ class Settings:
|
|||||||
SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587"))
|
SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587"))
|
||||||
SMTP_USER: str = os.getenv("SMTP_USER", "")
|
SMTP_USER: str = os.getenv("SMTP_USER", "")
|
||||||
SMTP_PASSWORD: str = os.getenv("SMTP_PASSWORD", "")
|
SMTP_PASSWORD: str = os.getenv("SMTP_PASSWORD", "")
|
||||||
|
|
||||||
SANDBOX_URL: str = os.getenv("SANDBOX_URL", "")
|
|
||||||
|
|
||||||
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
|
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
|
||||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from app.celery_task_scheduler import scheduler
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||||
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
||||||
@@ -13,6 +12,8 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
|||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||||
|
from app.services.task_service import get_task_memory_write_result
|
||||||
|
from app.tasks import write_message_task
|
||||||
from app.utils.config_utils import resolve_config_id
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
@@ -85,28 +86,16 @@ async def write(
|
|||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||||
# write_id = write_message_task.delay(
|
write_id = write_message_task.delay(
|
||||||
# actual_end_user_id, # end_user_id: User ID
|
actual_end_user_id, # end_user_id: User ID
|
||||||
# structured_messages, # message: JSON string format message list
|
structured_messages, # message: JSON string format message list
|
||||||
# str(actual_config_id), # config_id: Configuration ID string
|
str(actual_config_id), # config_id: Configuration ID string
|
||||||
# storage_type, # storage_type: "neo4j"
|
storage_type, # storage_type: "neo4j"
|
||||||
# user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
||||||
# )
|
|
||||||
scheduler.push_task(
|
|
||||||
"app.core.memory.agent.write_message",
|
|
||||||
str(actual_end_user_id),
|
|
||||||
{
|
|
||||||
"end_user_id": str(actual_end_user_id),
|
|
||||||
"message": structured_messages,
|
|
||||||
"config_id": str(actual_config_id),
|
|
||||||
"storage_type": storage_type,
|
|
||||||
"user_rag_memory_id": user_rag_memory_id or ""
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||||
# logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
write_status = get_task_memory_write_result(str(write_id))
|
||||||
# write_status = get_task_memory_write_result(str(write_id))
|
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||||
# logger.info(f'[WRITE] Task result - user={actual_end_user_id}')
|
|
||||||
|
|
||||||
|
|
||||||
async def term_memory_save(end_user_id, strategy_type, scope):
|
async def term_memory_save(end_user_id, strategy_type, scope):
|
||||||
@@ -175,24 +164,13 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
|||||||
else:
|
else:
|
||||||
config_id = memory_config
|
config_id = memory_config
|
||||||
|
|
||||||
scheduler.push_task(
|
write_message_task.delay(
|
||||||
"app.core.memory.agent.write_message",
|
end_user_id, # end_user_id: User ID
|
||||||
str(end_user_id),
|
redis_messages, # message: JSON string format message list
|
||||||
{
|
config_id, # config_id: Configuration ID string
|
||||||
"end_user_id": str(end_user_id),
|
AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
|
||||||
"message": redis_messages,
|
"" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
||||||
"config_id": str(config_id),
|
|
||||||
"storage_type": AgentMemory_Long_Term.STORAGE_NEO4J,
|
|
||||||
"user_rag_memory_id": ""
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
# write_message_task.delay(
|
|
||||||
# end_user_id, # end_user_id: User ID
|
|
||||||
# redis_messages, # message: JSON string format message list
|
|
||||||
# config_id, # config_id: Configuration ID string
|
|
||||||
# AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
|
|
||||||
# "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
|
||||||
# )
|
|
||||||
count_store.update_sessions_count(end_user_id, 0, [])
|
count_store.update_sessions_count(end_user_id, 0, [])
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
from app.core.memory.enums import SearchStrategy, StorageType
|
from app.core.memory.enums import SearchStrategy, StorageType
|
||||||
from app.core.memory.models.service_models import MemorySearchResult
|
from app.core.memory.models.service_models import MemorySearchResult
|
||||||
from app.core.memory.pipelines.base_pipeline import ModelClientMixin, DBRequiredPipeline
|
from app.core.memory.pipelines.base_pipeline import ModelClientMixin, DBRequiredPipeline
|
||||||
from app.core.memory.read_services.search_engine.content_search import Neo4jSearchService, RAGSearchService
|
from app.core.memory.read_services.content_search import Neo4jSearchService, RAGSearchService
|
||||||
from app.core.memory.read_services.generate_engine.query_preprocessor import QueryPreprocessor
|
from app.core.memory.read_services.query_preprocessor import QueryPreprocessor
|
||||||
|
|
||||||
|
|
||||||
class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
|
class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from neo4j import Session
|
|||||||
from app.core.memory.enums import Neo4jNodeType
|
from app.core.memory.enums import Neo4jNodeType
|
||||||
from app.core.memory.memory_service import MemoryContext
|
from app.core.memory.memory_service import MemoryContext
|
||||||
from app.core.memory.models.service_models import Memory, MemorySearchResult
|
from app.core.memory.models.service_models import Memory, MemorySearchResult
|
||||||
from app.core.memory.read_services.search_engine.result_builder import data_builder_factory
|
from app.core.memory.read_services.result_builder import data_builder_factory
|
||||||
from app.core.models import RedBearEmbeddings
|
from app.core.models import RedBearEmbeddings
|
||||||
from app.core.rag.nlp.search import knowledge_retrieval
|
from app.core.rag.nlp.search import knowledge_retrieval
|
||||||
from app.repositories import knowledge_repository
|
from app.repositories import knowledge_repository
|
||||||
@@ -8,4 +8,4 @@ class RetrievalSummaryProcessor:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def verify(content: str, llm_client: RedBearLLM):
|
def verify(content: str, llm_client: RedBearLLM):
|
||||||
return
|
return
|
||||||
@@ -216,7 +216,7 @@ class RedBearModelFactory:
|
|||||||
# 深度思考模式:Claude 3.7 Sonnet 等支持思考的模型
|
# 深度思考模式:Claude 3.7 Sonnet 等支持思考的模型
|
||||||
# 通过 additional_model_request_fields 传递 thinking 块,关闭时不传(Bedrock 无 disabled 选项)
|
# 通过 additional_model_request_fields 传递 thinking 块,关闭时不传(Bedrock 无 disabled 选项)
|
||||||
if config.deep_thinking:
|
if config.deep_thinking:
|
||||||
budget = config.thinking_budget_tokens or 1024
|
budget = config.thinking_budget_tokens or 10000
|
||||||
params["additional_model_request_fields"] = {
|
params["additional_model_request_fields"] = {
|
||||||
"thinking": {"type": "enabled", "budget_tokens": budget}
|
"thinking": {"type": "enabled", "budget_tokens": budget}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -73,7 +73,6 @@ class CustomTool(BaseTool):
|
|||||||
# 添加通用参数(基于第一个操作的参数)
|
# 添加通用参数(基于第一个操作的参数)
|
||||||
if self._parsed_operations:
|
if self._parsed_operations:
|
||||||
first_operation = next(iter(self._parsed_operations.values()))
|
first_operation = next(iter(self._parsed_operations.values()))
|
||||||
# path/query 参数
|
|
||||||
for param_name, param_info in first_operation.get("parameters", {}).items():
|
for param_name, param_info in first_operation.get("parameters", {}).items():
|
||||||
params.append(ToolParameter(
|
params.append(ToolParameter(
|
||||||
name=param_name,
|
name=param_name,
|
||||||
@@ -86,23 +85,6 @@ class CustomTool(BaseTool):
|
|||||||
maximum=param_info.get("maximum"),
|
maximum=param_info.get("maximum"),
|
||||||
pattern=param_info.get("pattern")
|
pattern=param_info.get("pattern")
|
||||||
))
|
))
|
||||||
# requestBody 参数 — 将 body 字段平铺为独立参数暴露给模型
|
|
||||||
request_body = first_operation.get("request_body")
|
|
||||||
if request_body:
|
|
||||||
body_schema = request_body.get("properties", {})
|
|
||||||
required_fields = request_body.get("required", [])
|
|
||||||
for prop_name, prop_schema in body_schema.items():
|
|
||||||
params.append(ToolParameter(
|
|
||||||
name=prop_name,
|
|
||||||
type=self._convert_openapi_type(prop_schema.get("type", "string")),
|
|
||||||
description=prop_schema.get("description", ""),
|
|
||||||
required=prop_name in required_fields,
|
|
||||||
default=prop_schema.get("default"),
|
|
||||||
enum=prop_schema.get("enum"),
|
|
||||||
minimum=prop_schema.get("minimum"),
|
|
||||||
maximum=prop_schema.get("maximum"),
|
|
||||||
pattern=prop_schema.get("pattern")
|
|
||||||
))
|
|
||||||
|
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ from app.core.workflow.engine.runtime_schema import ExecutionContext
|
|||||||
from app.core.workflow.engine.state_manager import WorkflowStateManager
|
from app.core.workflow.engine.state_manager import WorkflowStateManager
|
||||||
from app.core.workflow.engine.stream_output_coordinator import StreamOutputCoordinator
|
from app.core.workflow.engine.stream_output_coordinator import StreamOutputCoordinator
|
||||||
from app.core.workflow.engine.variable_pool import VariablePool, VariablePoolInitializer
|
from app.core.workflow.engine.variable_pool import VariablePool, VariablePoolInitializer
|
||||||
from app.core.workflow.nodes.base_node import NodeExecutionError
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -327,43 +326,10 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
||||||
exc_info=True)
|
exc_info=True)
|
||||||
|
|
||||||
# 1) 尝试从 checkpoint 回补已成功节点的 node_outputs
|
|
||||||
recovered: dict[str, Any] = {}
|
|
||||||
try:
|
|
||||||
if self.graph is not None:
|
|
||||||
recovered = self.graph.get_state(
|
|
||||||
self.execution_context.checkpoint_config
|
|
||||||
).values or {}
|
|
||||||
except Exception as recover_err:
|
|
||||||
logger.warning(
|
|
||||||
f"Recover state on failure failed: {recover_err}, "
|
|
||||||
f"execution_id={self.execution_context.execution_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
result = dict(recovered) if recovered else {}
|
result = {"error": str(e)}
|
||||||
else:
|
else:
|
||||||
# 已有 result 与 recovered 合并,node_outputs 深度合并
|
result["error"] = str(e)
|
||||||
for k, v in recovered.items():
|
|
||||||
if k == "node_outputs" and isinstance(v, dict):
|
|
||||||
existing = result.get("node_outputs") or {}
|
|
||||||
result["node_outputs"] = {**v, **existing}
|
|
||||||
else:
|
|
||||||
result.setdefault(k, v)
|
|
||||||
|
|
||||||
# 2) 如果是节点抛出的 NodeExecutionError,把失败节点的 node_output 注入 node_outputs
|
|
||||||
failed_node_id: str | None = None
|
|
||||||
if isinstance(e, NodeExecutionError):
|
|
||||||
failed_node_id = e.node_id
|
|
||||||
node_outputs = result.setdefault("node_outputs", {})
|
|
||||||
# 不覆盖已有(理论上不会有),保底写入失败节点记录
|
|
||||||
node_outputs.setdefault(e.node_id, e.node_output)
|
|
||||||
|
|
||||||
result["error"] = str(e)
|
|
||||||
if failed_node_id:
|
|
||||||
result["error_node"] = failed_node_id
|
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"event": "workflow_end",
|
"event": "workflow_end",
|
||||||
"data": self.result_builder.build_final_output(
|
"data": self.result_builder.build_final_output(
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -23,20 +22,6 @@ from app.services.multimodal_service import MultimodalService
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class NodeExecutionError(Exception):
|
|
||||||
"""节点执行失败异常。
|
|
||||||
|
|
||||||
携带失败节点的完整 node_output,供 executor 兜底注入 node_outputs,
|
|
||||||
保证 workflow_executions.output_data 里能看到失败节点的日志记录。
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, node_id: str, node_output: dict[str, Any], error_message: str):
|
|
||||||
super().__init__(f"Node {node_id} execution failed: {error_message}")
|
|
||||||
self.node_id = node_id
|
|
||||||
self.node_output = node_output
|
|
||||||
self.error_message = error_message
|
|
||||||
|
|
||||||
|
|
||||||
class BaseNode(ABC):
|
class BaseNode(ABC):
|
||||||
"""Base class for workflow nodes.
|
"""Base class for workflow nodes.
|
||||||
|
|
||||||
@@ -411,8 +396,6 @@ class BaseNode(ABC):
|
|||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time,
|
||||||
"token_usage": token_usage,
|
"token_usage": token_usage,
|
||||||
"error": None,
|
"error": None,
|
||||||
# 单调递增序号,用于日志按执行顺序排序(JSONB 不保证 key 顺序)
|
|
||||||
"execution_order": time.monotonic_ns(),
|
|
||||||
**self._extract_extra_fields(business_result),
|
**self._extract_extra_fields(business_result),
|
||||||
}
|
}
|
||||||
final_output = {
|
final_output = {
|
||||||
@@ -461,9 +444,7 @@ class BaseNode(ABC):
|
|||||||
"output": None,
|
"output": None,
|
||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time,
|
||||||
"token_usage": None,
|
"token_usage": None,
|
||||||
"error": error_message,
|
"error": error_message
|
||||||
# 单调递增序号,用于日志按执行顺序排序
|
|
||||||
"execution_order": time.monotonic_ns(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# if error_edge:
|
# if error_edge:
|
||||||
@@ -485,12 +466,7 @@ class BaseNode(ABC):
|
|||||||
**node_output
|
**node_output
|
||||||
})
|
})
|
||||||
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
|
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
|
||||||
# 抛出自定义异常,把 node_output 带给 executor,供其写入 node_outputs
|
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
|
||||||
raise NodeExecutionError(
|
|
||||||
node_id=self.node_id,
|
|
||||||
node_output=node_output,
|
|
||||||
error_message=error_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
"""Extracts the input data for this node (used for logging or audit).
|
"""Extracts the input data for this node (used for logging or audit).
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
|||||||
from app.core.workflow.nodes import BaseNode
|
from app.core.workflow.nodes import BaseNode
|
||||||
from app.core.workflow.nodes.code.config import CodeNodeConfig
|
from app.core.workflow.nodes.code.config import CodeNodeConfig
|
||||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||||
from app.core.config import settings
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -132,7 +131,7 @@ class CodeNode(BaseNode):
|
|||||||
|
|
||||||
async with httpx.AsyncClient(timeout=60) as client:
|
async with httpx.AsyncClient(timeout=60) as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{settings.SANDBOX_URL}:8194/v1/sandbox/run",
|
"http://sandbox:8194/v1/sandbox/run",
|
||||||
headers={
|
headers={
|
||||||
"x-api-key": 'redbear-sandbox'
|
"x-api-key": 'redbear-sandbox'
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -174,18 +174,12 @@ class IterationRuntime:
|
|||||||
continue
|
continue
|
||||||
node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type")
|
node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type")
|
||||||
cycle_variable = {"item": item} if node_type == NodeType.CYCLE_START else None
|
cycle_variable = {"item": item} if node_type == NodeType.CYCLE_START else None
|
||||||
node_cfg = next(
|
|
||||||
(n for n in self.cycle_nodes if n.get("id") == node_name), None
|
|
||||||
)
|
|
||||||
self.event_write({
|
self.event_write({
|
||||||
"type": "cycle_item",
|
"type": "cycle_item",
|
||||||
"data": {
|
"data": {
|
||||||
"cycle_id": self.node_id,
|
"cycle_id": self.node_id,
|
||||||
"cycle_idx": idx,
|
"cycle_idx": idx,
|
||||||
"node_id": node_name,
|
"node_id": node_name,
|
||||||
"node_type": node_type,
|
|
||||||
"node_name": node_cfg.get("data", {}).get("label") if node_cfg else node_name,
|
|
||||||
"status": result.get("node_outputs", {}).get(node_name, {}).get("status", "completed"),
|
|
||||||
"input": result.get("node_outputs", {}).get(node_name, {}).get("input")
|
"input": result.get("node_outputs", {}).get(node_name, {}).get("input")
|
||||||
if not cycle_variable else cycle_variable,
|
if not cycle_variable else cycle_variable,
|
||||||
"output": result.get("node_outputs", {}).get(node_name, {}).get("output")
|
"output": result.get("node_outputs", {}).get(node_name, {}).get("output")
|
||||||
|
|||||||
@@ -210,9 +210,6 @@ class LoopRuntime:
|
|||||||
"cycle_id": self.node_id,
|
"cycle_id": self.node_id,
|
||||||
"cycle_idx": idx,
|
"cycle_idx": idx,
|
||||||
"node_id": node_name,
|
"node_id": node_name,
|
||||||
"node_type": node_type,
|
|
||||||
"node_name": node_name,
|
|
||||||
"status": result.get("node_outputs", {}).get(node_name, {}).get("status", "completed"),
|
|
||||||
"input": result.get("node_outputs", {}).get(node_name, {}).get("input")
|
"input": result.get("node_outputs", {}).get(node_name, {}).get("input")
|
||||||
if not cycle_variable else cycle_variable,
|
if not cycle_variable else cycle_variable,
|
||||||
"output": result.get("node_outputs", {}).get(node_name, {}).get("output")
|
"output": result.get("node_outputs", {}).get(node_name, {}).get("output")
|
||||||
|
|||||||
@@ -182,7 +182,7 @@ class DocExtractorNode(BaseNode):
|
|||||||
mime_type=f"image/{ext}",
|
mime_type=f"image/{ext}",
|
||||||
is_file=True,
|
is_file=True,
|
||||||
).model_dump())
|
).model_dump())
|
||||||
text = text + f"\n{placeholder}: <img src=\"{url}\" data-url=\"{url}\">"
|
text = text + f"\n{placeholder}: {url}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Node {self.node_id}: failed to save image {placeholder}: {e}")
|
logger.error(f"Node {self.node_id}: failed to save image {placeholder}: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -272,11 +272,6 @@ class HttpRequestNodeOutput(BaseModel):
|
|||||||
description="HTTP response body",
|
description="HTTP response body",
|
||||||
)
|
)
|
||||||
|
|
||||||
process_data: dict = Field(
|
|
||||||
default_factory=dict,
|
|
||||||
description="Raw HTTP request details for debugging",
|
|
||||||
)
|
|
||||||
|
|
||||||
# files: list[File] = Field(
|
# files: list[File] = Field(
|
||||||
# ...
|
# ...
|
||||||
# )
|
# )
|
||||||
|
|||||||
@@ -160,6 +160,7 @@ class HttpRequestNode(BaseNode):
|
|||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: HttpRequestNodeConfig | None = None
|
self.typed_config: HttpRequestNodeConfig | None = None
|
||||||
|
self.last_request: str = ""
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {
|
return {
|
||||||
@@ -170,6 +171,47 @@ class HttpRequestNode(BaseNode):
|
|||||||
"output": VariableType.STRING
|
"output": VariableType.STRING
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _extract_output(self, business_result: Any) -> Any:
|
||||||
|
if isinstance(business_result, dict):
|
||||||
|
result = {k: v for k, v in business_result.items() if k != "request"}
|
||||||
|
return result
|
||||||
|
return business_result
|
||||||
|
|
||||||
|
def _extract_extra_fields(self, business_result: Any) -> dict[str, Any]:
|
||||||
|
if isinstance(business_result, dict) and "request" in business_result:
|
||||||
|
return {
|
||||||
|
"process": {
|
||||||
|
"request": business_result.get("request", "")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _wrap_error(
|
||||||
|
self,
|
||||||
|
error_message: str,
|
||||||
|
elapsed_time: float,
|
||||||
|
state: WorkflowState,
|
||||||
|
variable_pool: VariablePool
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
input_data = self._extract_input(state, variable_pool)
|
||||||
|
node_output = {
|
||||||
|
"node_id": self.node_id,
|
||||||
|
"node_type": self.node_type,
|
||||||
|
"node_name": self.node_name,
|
||||||
|
"status": "failed",
|
||||||
|
"input": input_data,
|
||||||
|
"output": None,
|
||||||
|
"process": {"request": self.last_request} if self.last_request else None,
|
||||||
|
"elapsed_time": elapsed_time,
|
||||||
|
"token_usage": None,
|
||||||
|
"error": error_message
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"node_outputs": {self.node_id: node_output},
|
||||||
|
"error": error_message,
|
||||||
|
"error_node": self.node_id
|
||||||
|
}
|
||||||
|
|
||||||
def _build_timeout(self) -> Timeout:
|
def _build_timeout(self) -> Timeout:
|
||||||
"""
|
"""
|
||||||
Build httpx Timeout configuration.
|
Build httpx Timeout configuration.
|
||||||
@@ -255,18 +297,13 @@ class HttpRequestNode(BaseNode):
|
|||||||
case HttpContentType.NONE:
|
case HttpContentType.NONE:
|
||||||
return {}
|
return {}
|
||||||
case HttpContentType.JSON:
|
case HttpContentType.JSON:
|
||||||
rendered = self._render_template(
|
rendered_body = self._render_template(
|
||||||
self.typed_config.body.data, variable_pool
|
self.typed_config.body.data, variable_pool
|
||||||
)
|
).strip()
|
||||||
if not rendered or not rendered.strip():
|
if not rendered_body:
|
||||||
# 第三方导入的工作流可能出现 content_type=json 但 data 为空的情况,视为无 body
|
content["json"] = {}
|
||||||
return {}
|
else:
|
||||||
try:
|
content["json"] = json.loads(rendered_body)
|
||||||
content["json"] = json.loads(rendered)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Invalid JSON body for HTTP request node: {e.msg} (data={rendered!r})"
|
|
||||||
)
|
|
||||||
case HttpContentType.FROM_DATA:
|
case HttpContentType.FROM_DATA:
|
||||||
data = {}
|
data = {}
|
||||||
files = []
|
files = []
|
||||||
@@ -334,15 +371,61 @@ class HttpRequestNode(BaseNode):
|
|||||||
case _:
|
case _:
|
||||||
raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}")
|
raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}")
|
||||||
|
|
||||||
def _extract_output(self, business_result: Any) -> Any:
|
def _generate_raw_request(
|
||||||
if isinstance(business_result, dict):
|
self,
|
||||||
return {k: v for k, v in business_result.items() if k != "process_data"}
|
variable_pool: VariablePool,
|
||||||
return business_result
|
url: str,
|
||||||
|
headers: dict[str, str],
|
||||||
|
params: dict[str, str],
|
||||||
|
content: dict[str, Any]
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Generate raw HTTP request format for debugging.
|
||||||
|
|
||||||
def _extract_extra_fields(self, business_result: Any) -> dict:
|
Args:
|
||||||
if isinstance(business_result, dict) and "process_data" in business_result:
|
variable_pool: Variable Pool
|
||||||
return {"process": business_result["process_data"]}
|
url: Rendered URL
|
||||||
return {}
|
headers: Request headers
|
||||||
|
params: Query parameters
|
||||||
|
content: Request body content
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Raw HTTP request string
|
||||||
|
"""
|
||||||
|
method = self.typed_config.method.value
|
||||||
|
|
||||||
|
if params:
|
||||||
|
param_str = "&".join([f"{k}={v}" for k, v in params.items()])
|
||||||
|
full_url = f"{url}?{param_str}" if "?" not in url else f"{url}&{param_str}"
|
||||||
|
else:
|
||||||
|
full_url = url
|
||||||
|
|
||||||
|
lines = [f"{method} {full_url} HTTP/1.1"]
|
||||||
|
|
||||||
|
for key, value in headers.items():
|
||||||
|
lines.append(f"{key}: {value}")
|
||||||
|
|
||||||
|
if "json" in content and content["json"]:
|
||||||
|
json_body = json.dumps(content["json"], ensure_ascii=False)
|
||||||
|
lines.append(f"Content-Length: {len(json_body)}")
|
||||||
|
lines.append("")
|
||||||
|
lines.append(json_body)
|
||||||
|
elif "data" in content and "files" not in content:
|
||||||
|
if isinstance(content["data"], dict):
|
||||||
|
body_str = "&".join([f"{k}={v}" for k, v in content["data"].items()])
|
||||||
|
lines.append(f"Content-Length: {len(body_str)}")
|
||||||
|
lines.append("")
|
||||||
|
lines.append(body_str)
|
||||||
|
elif "content" in content:
|
||||||
|
lines.append(f"Content-Length: {len(content['content'])}")
|
||||||
|
lines.append("")
|
||||||
|
lines.append(content["content"])
|
||||||
|
elif "files" in content:
|
||||||
|
lines.append("Content-Length: 0")
|
||||||
|
lines.append("")
|
||||||
|
lines.append("# Note: This request includes file uploads")
|
||||||
|
|
||||||
|
return "\r\n".join(lines)
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict | str:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict | str:
|
||||||
"""
|
"""
|
||||||
@@ -362,42 +445,47 @@ class HttpRequestNode(BaseNode):
|
|||||||
- str: Branch identifier (e.g. "ERROR") when branching is enabled
|
- str: Branch identifier (e.g. "ERROR") when branching is enabled
|
||||||
"""
|
"""
|
||||||
self.typed_config = HttpRequestNodeConfig(**self.config)
|
self.typed_config = HttpRequestNodeConfig(**self.config)
|
||||||
rendered_url = self._render_template(self.typed_config.url, variable_pool)
|
|
||||||
built_headers = self._build_header(variable_pool) | self._build_auth(variable_pool)
|
# Build request components
|
||||||
built_params = self._build_params(variable_pool)
|
headers = self._build_header(variable_pool) | self._build_auth(variable_pool)
|
||||||
|
params = self._build_params(variable_pool)
|
||||||
|
content = await self._build_content(variable_pool)
|
||||||
|
url = self._render_template(self.typed_config.url, variable_pool)
|
||||||
|
|
||||||
|
logger.info(f"Node {self.node_id}: headers={headers}, params={params}, content keys={list(content.keys())}")
|
||||||
|
|
||||||
|
# Generate raw HTTP request for debugging
|
||||||
|
raw_request = self._generate_raw_request(variable_pool, url, headers, params, content)
|
||||||
|
self.last_request = raw_request
|
||||||
|
logger.info(f"Node {self.node_id}: Generated HTTP request:\n{raw_request}")
|
||||||
|
|
||||||
async with httpx.AsyncClient(
|
async with httpx.AsyncClient(
|
||||||
verify=self.typed_config.verify_ssl,
|
verify=self.typed_config.verify_ssl,
|
||||||
timeout=self._build_timeout(),
|
timeout=self._build_timeout(),
|
||||||
headers=built_headers,
|
headers=headers,
|
||||||
params=built_params,
|
params=params,
|
||||||
follow_redirects=True
|
follow_redirects=True
|
||||||
) as client:
|
) as client:
|
||||||
retries = self.typed_config.retry.max_attempts
|
retries = self.typed_config.retry.max_attempts
|
||||||
while retries > 0:
|
while retries > 0:
|
||||||
try:
|
try:
|
||||||
request_func = self._get_client_method(client)
|
request_func = self._get_client_method(client)
|
||||||
built_content = await self._build_content(variable_pool)
|
|
||||||
resp = await request_func(
|
resp = await request_func(
|
||||||
url=rendered_url,
|
url=url,
|
||||||
**built_content
|
**content
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
||||||
response = HttpResponse(resp)
|
response = HttpResponse(resp)
|
||||||
# Build raw request summary for process_data
|
return {
|
||||||
raw_request = (
|
**HttpRequestNodeOutput(
|
||||||
f"{self.typed_config.method.upper()} {resp.request.url} HTTP/1.1\r\n"
|
body=response.body,
|
||||||
+ "".join(f"{k}: {v}\r\n" for k, v in resp.request.headers.items())
|
status_code=resp.status_code,
|
||||||
+ "\r\n"
|
headers=resp.headers,
|
||||||
+ (resp.request.content.decode(errors="replace") if resp.request.content else "")
|
files=response.files
|
||||||
)
|
).model_dump(),
|
||||||
return HttpRequestNodeOutput(
|
"request": raw_request
|
||||||
body=response.body,
|
}
|
||||||
status_code=resp.status_code,
|
|
||||||
headers=resp.headers,
|
|
||||||
files=response.files,
|
|
||||||
process_data={"request": raw_request},
|
|
||||||
).model_dump()
|
|
||||||
except (httpx.HTTPStatusError, httpx.RequestError) as e:
|
except (httpx.HTTPStatusError, httpx.RequestError) as e:
|
||||||
logger.error(f"HTTP request node exception: {e}")
|
logger.error(f"HTTP request node exception: {e}")
|
||||||
retries -= 1
|
retries -= 1
|
||||||
@@ -413,10 +501,19 @@ class HttpRequestNode(BaseNode):
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
f"Node {self.node_id}: HTTP request failed, returning default result"
|
f"Node {self.node_id}: HTTP request failed, returning default result"
|
||||||
)
|
)
|
||||||
return self.typed_config.error_handle.default.model_dump()
|
error_result = self.typed_config.error_handle.default.model_dump()
|
||||||
|
error_result["request"] = raw_request
|
||||||
|
return error_result
|
||||||
case HttpErrorHandle.BRANCH:
|
case HttpErrorHandle.BRANCH:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
||||||
)
|
)
|
||||||
return {"output": "ERROR"}
|
return {
|
||||||
|
"output": "ERROR",
|
||||||
|
"body": "",
|
||||||
|
"status_code": 500,
|
||||||
|
"headers": {},
|
||||||
|
"files": [],
|
||||||
|
"request": raw_request
|
||||||
|
}
|
||||||
raise RuntimeError("http request failed")
|
raise RuntimeError("http request failed")
|
||||||
|
|||||||
@@ -334,8 +334,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
for kb_config in knowledge_bases:
|
for kb_config in knowledge_bases:
|
||||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
|
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
|
||||||
if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1):
|
if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1):
|
||||||
logger.warning("The knowledge base does not exist or access is denied.")
|
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
||||||
continue
|
|
||||||
tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config))
|
tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config))
|
||||||
if tasks:
|
if tasks:
|
||||||
result = await asyncio.gather(*tasks)
|
result = await asyncio.gather(*tasks)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.celery_task_scheduler import scheduler
|
|
||||||
from app.core.memory.enums import SearchStrategy
|
from app.core.memory.enums import SearchStrategy
|
||||||
from app.core.memory.memory_service import MemoryService
|
from app.core.memory.memory_service import MemoryService
|
||||||
from app.core.workflow.engine.state_manager import WorkflowState
|
from app.core.workflow.engine.state_manager import WorkflowState
|
||||||
@@ -12,6 +11,7 @@ from app.core.workflow.variable.base_variable import VariableType
|
|||||||
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
||||||
from app.db import get_db_read
|
from app.db import get_db_read
|
||||||
from app.schemas import FileInput
|
from app.schemas import FileInput
|
||||||
|
from app.tasks import write_message_task
|
||||||
|
|
||||||
|
|
||||||
class MemoryReadNode(BaseNode):
|
class MemoryReadNode(BaseNode):
|
||||||
@@ -126,23 +126,12 @@ class MemoryWriteNode(BaseNode):
|
|||||||
"files": file_info
|
"files": file_info
|
||||||
})
|
})
|
||||||
|
|
||||||
scheduler.push_task(
|
write_message_task.delay(
|
||||||
"app.core.memory.agent.write_message",
|
end_user_id=end_user_id,
|
||||||
end_user_id,
|
message=messages,
|
||||||
{
|
config_id=str(self.typed_config.config_id),
|
||||||
"end_user_id": end_user_id,
|
storage_type=state["memory_storage_type"],
|
||||||
"message": messages,
|
user_rag_memory_id=state["user_rag_memory_id"]
|
||||||
"config_id": str(self.typed_config.config_id),
|
|
||||||
"storage_type": state["memory_storage_type"],
|
|
||||||
"user_rag_memory_id": state["user_rag_memory_id"]
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
# write_message_task.delay(
|
|
||||||
# end_user_id=end_user_id,
|
|
||||||
# message=messages,
|
|
||||||
# config_id=str(self.typed_config.config_id),
|
|
||||||
# storage_type=state["memory_storage_type"],
|
|
||||||
# user_rag_memory_id=state["user_rag_memory_id"]
|
|
||||||
# )
|
|
||||||
|
|
||||||
return "success"
|
return "success"
|
||||||
|
|||||||
@@ -1,15 +1,13 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy import select, desc, func, or_, cast, Text
|
from sqlalchemy import select, desc, func
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.exceptions import ResourceNotFoundException
|
from app.core.exceptions import ResourceNotFoundException
|
||||||
from app.core.logging_config import get_db_logger
|
from app.core.logging_config import get_db_logger
|
||||||
from app.models import Conversation, Message
|
from app.models import Conversation, Message
|
||||||
from app.models.app_model import AppType
|
|
||||||
from app.models.conversation_model import ConversationDetail
|
from app.models.conversation_model import ConversationDetail
|
||||||
from app.models.workflow_model import WorkflowExecution
|
|
||||||
|
|
||||||
logger = get_db_logger()
|
logger = get_db_logger()
|
||||||
|
|
||||||
@@ -208,8 +206,7 @@ class ConversationRepository:
|
|||||||
is_draft: Optional[bool] = None,
|
is_draft: Optional[bool] = None,
|
||||||
keyword: Optional[str] = None,
|
keyword: Optional[str] = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
pagesize: int = 20,
|
pagesize: int = 20
|
||||||
app_type: Optional[str] = None,
|
|
||||||
) -> tuple[list[Conversation], int]:
|
) -> tuple[list[Conversation], int]:
|
||||||
"""
|
"""
|
||||||
查询应用日志会话列表(带分页和过滤)
|
查询应用日志会话列表(带分页和过滤)
|
||||||
@@ -221,9 +218,6 @@ class ConversationRepository:
|
|||||||
keyword: 搜索关键词(匹配消息内容)
|
keyword: 搜索关键词(匹配消息内容)
|
||||||
page: 页码(从 1 开始)
|
page: 页码(从 1 开始)
|
||||||
pagesize: 每页数量
|
pagesize: 每页数量
|
||||||
app_type: 应用类型。WORKFLOW 类型改用 workflow_executions 的
|
|
||||||
input_data/output_data 做关键词过滤(因为失败的工作流不会写入 messages 表);
|
|
||||||
其他类型仍走 messages 表。
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[List[Conversation], int]: (会话列表,总数)
|
Tuple[List[Conversation], int]: (会话列表,总数)
|
||||||
@@ -240,28 +234,12 @@ class ConversationRepository:
|
|||||||
|
|
||||||
# 如果有关键词搜索,通过子查询过滤包含该关键词的 conversation
|
# 如果有关键词搜索,通过子查询过滤包含该关键词的 conversation
|
||||||
if keyword:
|
if keyword:
|
||||||
kw_pattern = f"%{keyword}%"
|
# 查找包含关键词的 conversation_id 列表
|
||||||
if app_type == AppType.WORKFLOW:
|
keyword_stmt = (
|
||||||
# 工作流:从 workflow_executions 的 input_data / output_data 匹配
|
select(Message.conversation_id)
|
||||||
# (messages 表只存开场白 assistant 消息,失败的工作流也不会写入)
|
.where(Message.content.ilike(f"%{keyword}%"))
|
||||||
keyword_stmt = (
|
.distinct()
|
||||||
select(WorkflowExecution.conversation_id)
|
)
|
||||||
.where(
|
|
||||||
WorkflowExecution.conversation_id.is_not(None),
|
|
||||||
or_(
|
|
||||||
cast(WorkflowExecution.input_data, Text).ilike(kw_pattern),
|
|
||||||
cast(WorkflowExecution.output_data, Text).ilike(kw_pattern),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.distinct()
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Agent 等其他类型:仍走 messages 表(user + assistant 内容)
|
|
||||||
keyword_stmt = (
|
|
||||||
select(Message.conversation_id)
|
|
||||||
.where(Message.content.ilike(kw_pattern))
|
|
||||||
.distinct()
|
|
||||||
)
|
|
||||||
base_stmt = base_stmt.where(Conversation.id.in_(keyword_stmt))
|
base_stmt = base_stmt.where(Conversation.id.in_(keyword_stmt))
|
||||||
|
|
||||||
# Calculate total number of records
|
# Calculate total number of records
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ class AppLogMessage(BaseModel):
|
|||||||
conversation_id: uuid.UUID
|
conversation_id: uuid.UUID
|
||||||
role: str = Field(description="角色: user / assistant / system")
|
role: str = Field(description="角色: user / assistant / system")
|
||||||
content: str
|
content: str
|
||||||
status: Optional[str] = Field(default=None, description="执行状态(工作流专用): completed / failed")
|
|
||||||
meta_data: Optional[Dict[str, Any]] = None
|
meta_data: Optional[Dict[str, Any]] = None
|
||||||
created_at: datetime.datetime
|
created_at: datetime.datetime
|
||||||
|
|
||||||
@@ -59,7 +58,6 @@ class AppLogNodeExecution(BaseModel):
|
|||||||
input: Optional[Any] = None
|
input: Optional[Any] = None
|
||||||
process: Optional[Any] = None
|
process: Optional[Any] = None
|
||||||
output: Optional[Any] = None
|
output: Optional[Any] = None
|
||||||
cycle_items: Optional[List[Any]] = None
|
|
||||||
elapsed_time: Optional[float] = None
|
elapsed_time: Optional[float] = None
|
||||||
token_usage: Optional[Dict[str, Any]] = None
|
token_usage: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import uuid
|
|||||||
from typing import Optional, Any, List, Dict, Union
|
from typing import Optional, Any, List, Dict, Union
|
||||||
from enum import Enum, StrEnum
|
from enum import Enum, StrEnum
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator, model_serializer
|
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
||||||
|
|
||||||
from app.schemas.workflow_schema import WorkflowConfigCreate
|
from app.schemas.workflow_schema import WorkflowConfigCreate
|
||||||
|
|
||||||
@@ -250,7 +250,7 @@ class ModelParameters(BaseModel):
|
|||||||
n: int = Field(default=1, ge=1, le=10, description="生成的回复数量")
|
n: int = Field(default=1, ge=1, le=10, description="生成的回复数量")
|
||||||
stop: Optional[List[str]] = Field(default=None, description="停止序列")
|
stop: Optional[List[str]] = Field(default=None, description="停止序列")
|
||||||
deep_thinking: bool = Field(default=False, description="是否启用深度思考模式(需模型支持,如 DeepSeek-R1、QwQ 等)")
|
deep_thinking: bool = Field(default=False, description="是否启用深度思考模式(需模型支持,如 DeepSeek-R1、QwQ 等)")
|
||||||
thinking_budget_tokens: Optional[int] = Field(default=None, ge=1, le=131072, description="深度思考 token 预算(仅部分模型支持)")
|
thinking_budget_tokens: Optional[int] = Field(default=None, ge=1024, le=131072, description="深度思考 token 预算(仅部分模型支持)")
|
||||||
json_output: bool = Field(default=False, description="是否强制 JSON 格式输出(需模型支持 json_output 能力)")
|
json_output: bool = Field(default=False, description="是否强制 JSON 格式输出(需模型支持 json_output 能力)")
|
||||||
|
|
||||||
|
|
||||||
@@ -661,11 +661,9 @@ class DraftRunResponse(BaseModel):
|
|||||||
suggested_questions: List[str] = Field(default_factory=list, description="下一步建议问题")
|
suggested_questions: List[str] = Field(default_factory=list, description="下一步建议问题")
|
||||||
citations: List[Dict[str, Any]] = Field(default_factory=list, description="引用来源")
|
citations: List[Dict[str, Any]] = Field(default_factory=list, description="引用来源")
|
||||||
audio_url: Optional[str] = Field(default=None, description="TTS 语音URL")
|
audio_url: Optional[str] = Field(default=None, description="TTS 语音URL")
|
||||||
audio_status: Optional[str] = Field(default=None, description="TTS 语音状态")
|
|
||||||
|
|
||||||
@model_serializer(mode="wrap")
|
def model_dump(self, **kwargs):
|
||||||
def _serialize(self, handler):
|
data = super().model_dump(**kwargs)
|
||||||
data = handler(self)
|
|
||||||
if not data.get("reasoning_content"):
|
if not data.get("reasoning_content"):
|
||||||
data.pop("reasoning_content", None)
|
data.pop("reasoning_content", None)
|
||||||
return data
|
return data
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
import uuid
|
import uuid
|
||||||
import datetime
|
import datetime
|
||||||
from typing import Optional, Dict, Any, List
|
from typing import Optional, Dict, Any, List
|
||||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer, model_serializer
|
from pydantic import BaseModel, Field, ConfigDict, field_serializer
|
||||||
|
|
||||||
# 导入 FileInput(用于体验运行)
|
# 导入 FileInput(用于体验运行)
|
||||||
from app.schemas.app_schema import FileInput
|
from app.schemas.app_schema import FileInput
|
||||||
@@ -94,18 +94,6 @@ class ChatResponse(BaseModel):
|
|||||||
message_id: str
|
message_id: str
|
||||||
usage: Optional[Dict[str, Any]] = None
|
usage: Optional[Dict[str, Any]] = None
|
||||||
elapsed_time: Optional[float] = None
|
elapsed_time: Optional[float] = None
|
||||||
reasoning_content: Optional[str] = None
|
|
||||||
suggested_questions: Optional[List[str]] = None
|
|
||||||
citations: Optional[List[Dict[str, Any]]] = None
|
|
||||||
audio_url: Optional[str] = None
|
|
||||||
audio_status: Optional[str] = None
|
|
||||||
|
|
||||||
@model_serializer(mode="wrap")
|
|
||||||
def _serialize(self, handler):
|
|
||||||
data = handler(self)
|
|
||||||
if not data.get("reasoning_content"):
|
|
||||||
data.pop("reasoning_content", None)
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
# ---------- Conversation Summary Schemas ----------
|
# ---------- Conversation Summary Schemas ----------
|
||||||
|
|||||||
@@ -112,12 +112,12 @@ class MemoryWriteResponse(BaseModel):
|
|||||||
"""Response schema for memory write operation.
|
"""Response schema for memory write operation.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
task_id: task ID for status polling
|
task_id: Celery task ID for status polling
|
||||||
status: Initial task status (QUEUED)
|
status: Initial task status (PENDING)
|
||||||
end_user_id: End user ID the write was submitted for
|
end_user_id: End user ID the write was submitted for
|
||||||
"""
|
"""
|
||||||
task_id: str = Field(..., description="task ID for polling")
|
task_id: str = Field(..., description="Celery task ID for polling")
|
||||||
status: str = Field(..., description="Task status: QUEUED")
|
status: str = Field(..., description="Task status: PENDING")
|
||||||
end_user_id: str = Field(..., description="End user ID")
|
end_user_id: str = Field(..., description="End user ID")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from app.aioRedis import aio_redis
|
from app.aioRedis import aio_redis
|
||||||
from app.models.api_key_model import ApiKey, ApiKeyType
|
from app.models.api_key_model import ApiKey
|
||||||
from app.repositories.api_key_repository import ApiKeyRepository, ApiKeyLogRepository
|
from app.repositories.api_key_repository import ApiKeyRepository, ApiKeyLogRepository
|
||||||
from app.schemas import api_key_schema
|
from app.schemas import api_key_schema
|
||||||
from app.schemas.response_schema import PageData, PageMeta
|
from app.schemas.response_schema import PageData, PageMeta
|
||||||
@@ -65,12 +65,6 @@ class ApiKeyService:
|
|||||||
BizCode.BAD_REQUEST
|
BizCode.BAD_REQUEST
|
||||||
)
|
)
|
||||||
|
|
||||||
# SERVICE 类型的 resource_id 指向 workspace,非应用,跳过应用发布校验
|
|
||||||
if data.resource_id and data.type != ApiKeyType.SERVICE.value:
|
|
||||||
app = db.get(App, data.resource_id)
|
|
||||||
if not app or not app.current_release_id:
|
|
||||||
raise BusinessException("该应用未发布", BizCode.APP_NOT_PUBLISHED)
|
|
||||||
|
|
||||||
# 生成 API Key
|
# 生成 API Key
|
||||||
api_key = generate_api_key(data.type)
|
api_key = generate_api_key(data.type)
|
||||||
|
|
||||||
@@ -453,12 +447,9 @@ class ApiKeyAuthService:
|
|||||||
def check_app_published(db: Session, api_key_obj: ApiKey) -> None:
|
def check_app_published(db: Session, api_key_obj: ApiKey) -> None:
|
||||||
"""
|
"""
|
||||||
检查应用是否已发布,未发布则抛出异常
|
检查应用是否已发布,未发布则抛出异常
|
||||||
SERVICE 类型的 api_key 不绑定应用(resource_id 指向 workspace),跳过校验
|
|
||||||
"""
|
"""
|
||||||
if not api_key_obj.resource_id:
|
if not api_key_obj.resource_id:
|
||||||
return
|
return
|
||||||
if api_key_obj.type == ApiKeyType.SERVICE.value:
|
|
||||||
return
|
|
||||||
app = db.get(App, api_key_obj.resource_id)
|
app = db.get(App, api_key_obj.resource_id)
|
||||||
if not app or not app.current_release_id:
|
if not app or not app.current_release_id:
|
||||||
raise BusinessException("应用未发布,不可用", BizCode.APP_NOT_PUBLISHED)
|
raise BusinessException("应用未发布,不可用", BizCode.APP_NOT_PUBLISHED)
|
||||||
|
|||||||
@@ -107,6 +107,23 @@ class AppChatService:
|
|||||||
# 获取模型参数
|
# 获取模型参数
|
||||||
model_parameters = config.model_parameters
|
model_parameters = config.model_parameters
|
||||||
|
|
||||||
|
# 创建 LangChain Agent
|
||||||
|
agent = LangChainAgent(
|
||||||
|
model_name=api_key_obj.model_name,
|
||||||
|
api_key=api_key_obj.api_key,
|
||||||
|
provider=api_key_obj.provider,
|
||||||
|
api_base=api_key_obj.api_base,
|
||||||
|
is_omni=api_key_obj.is_omni,
|
||||||
|
temperature=model_parameters.get("temperature", 0.7),
|
||||||
|
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tools=tools,
|
||||||
|
deep_thinking=model_parameters.get("deep_thinking", False),
|
||||||
|
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
||||||
|
json_output=model_parameters.get("json_output", False),
|
||||||
|
capability=api_key_obj.capability or [],
|
||||||
|
)
|
||||||
|
|
||||||
model_info = ModelInfo(
|
model_info = ModelInfo(
|
||||||
model_name=api_key_obj.model_name,
|
model_name=api_key_obj.model_name,
|
||||||
provider=api_key_obj.provider,
|
provider=api_key_obj.provider,
|
||||||
@@ -160,30 +177,16 @@ class AppChatService:
|
|||||||
if doc_img_recognition and "vision" in (api_key_obj.capability or []) and any(
|
if doc_img_recognition and "vision" in (api_key_obj.capability or []) and any(
|
||||||
f.type == FileType.DOCUMENT for f in files
|
f.type == FileType.DOCUMENT for f in files
|
||||||
):
|
):
|
||||||
system_prompt += (
|
from langchain.agents import create_agent
|
||||||
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>,"
|
agent.system_prompt += (
|
||||||
"请在回答中用 Markdown 格式  展示对应图片。"
|
"\n\n文档中包含图片,图片位置已在文本中以 [第N页 第M张图片]: URL 标记。"
|
||||||
"重要:图片 URL 中包含 UUID(如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx),"
|
"请在回答中用 Markdown 格式  展示相关图片,做到图文并茂。"
|
||||||
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
|
)
|
||||||
|
agent.agent = create_agent(
|
||||||
|
model=agent.llm,
|
||||||
|
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
|
||||||
|
system_prompt=agent.system_prompt
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建 LangChain Agent
|
|
||||||
agent = LangChainAgent(
|
|
||||||
model_name=api_key_obj.model_name,
|
|
||||||
api_key=api_key_obj.api_key,
|
|
||||||
provider=api_key_obj.provider,
|
|
||||||
api_base=api_key_obj.api_base,
|
|
||||||
is_omni=api_key_obj.is_omni,
|
|
||||||
temperature=model_parameters.get("temperature", 0.7),
|
|
||||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
tools=tools,
|
|
||||||
deep_thinking=model_parameters.get("deep_thinking", False),
|
|
||||||
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
|
||||||
json_output=model_parameters.get("json_output", False),
|
|
||||||
capability=api_key_obj.capability or [],
|
|
||||||
)
|
|
||||||
|
|
||||||
# 为需要运行时上下文的工具注入上下文
|
# 为需要运行时上下文的工具注入上下文
|
||||||
for t in tools:
|
for t in tools:
|
||||||
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
||||||
@@ -320,7 +323,7 @@ class AppChatService:
|
|||||||
"suggested_questions": suggested_questions,
|
"suggested_questions": suggested_questions,
|
||||||
"citations": filtered_citations,
|
"citations": filtered_citations,
|
||||||
"audio_url": audio_url,
|
"audio_url": audio_url,
|
||||||
"audio_status": "pending" if audio_url else None
|
"audio_status": "pending"
|
||||||
}
|
}
|
||||||
|
|
||||||
async def agnet_chat_stream(
|
async def agnet_chat_stream(
|
||||||
@@ -396,6 +399,24 @@ class AppChatService:
|
|||||||
# 获取模型参数
|
# 获取模型参数
|
||||||
model_parameters = config.model_parameters
|
model_parameters = config.model_parameters
|
||||||
|
|
||||||
|
# 创建 LangChain Agent
|
||||||
|
agent = LangChainAgent(
|
||||||
|
model_name=api_key_obj.model_name,
|
||||||
|
api_key=api_key_obj.api_key,
|
||||||
|
provider=api_key_obj.provider,
|
||||||
|
api_base=api_key_obj.api_base,
|
||||||
|
is_omni=api_key_obj.is_omni,
|
||||||
|
temperature=model_parameters.get("temperature", 0.7),
|
||||||
|
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tools=tools,
|
||||||
|
streaming=True,
|
||||||
|
deep_thinking=model_parameters.get("deep_thinking", False),
|
||||||
|
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
||||||
|
json_output=model_parameters.get("json_output", False),
|
||||||
|
capability=api_key_obj.capability or [],
|
||||||
|
)
|
||||||
|
|
||||||
model_info = ModelInfo(
|
model_info = ModelInfo(
|
||||||
model_name=api_key_obj.model_name,
|
model_name=api_key_obj.model_name,
|
||||||
provider=api_key_obj.provider,
|
provider=api_key_obj.provider,
|
||||||
@@ -450,30 +471,15 @@ class AppChatService:
|
|||||||
f.type == FileType.DOCUMENT for f in files
|
f.type == FileType.DOCUMENT for f in files
|
||||||
):
|
):
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
system_prompt += (
|
agent.system_prompt += (
|
||||||
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>,"
|
"\n\n文档中包含图片,图片位置已在文本中以 [第N页 第M张图片]: URL 标记。"
|
||||||
"请在回答中用 Markdown 格式  展示对应图片。"
|
"请在回答中用 Markdown 格式  展示相关图片,做到图文并茂。"
|
||||||
"重要:图片 URL 中包含 UUID(如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx),"
|
)
|
||||||
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
|
agent.agent = create_agent(
|
||||||
|
model=agent.llm,
|
||||||
|
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
|
||||||
|
system_prompt=agent.system_prompt
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建 LangChain Agent
|
|
||||||
agent = LangChainAgent(
|
|
||||||
model_name=api_key_obj.model_name,
|
|
||||||
api_key=api_key_obj.api_key,
|
|
||||||
provider=api_key_obj.provider,
|
|
||||||
api_base=api_key_obj.api_base,
|
|
||||||
is_omni=api_key_obj.is_omni,
|
|
||||||
temperature=model_parameters.get("temperature", 0.7),
|
|
||||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
tools=tools,
|
|
||||||
streaming=True,
|
|
||||||
deep_thinking=model_parameters.get("deep_thinking", False),
|
|
||||||
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
|
||||||
json_output=model_parameters.get("json_output", False),
|
|
||||||
capability=api_key_obj.capability or [],
|
|
||||||
)
|
|
||||||
|
|
||||||
# 为需要运行时上下文的工具注入上下文
|
# 为需要运行时上下文的工具注入上下文
|
||||||
for t in tools:
|
for t in tools:
|
||||||
|
|||||||
@@ -1,17 +1,16 @@
|
|||||||
"""应用日志服务层"""
|
"""应用日志服务层"""
|
||||||
import uuid
|
import uuid
|
||||||
import datetime as dt
|
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.models.app_model import AppType
|
|
||||||
from app.models.conversation_model import Conversation, Message
|
from app.models.conversation_model import Conversation, Message
|
||||||
from app.models.workflow_model import WorkflowExecution
|
from app.models.workflow_model import WorkflowExecution
|
||||||
from app.repositories.conversation_repository import ConversationRepository, MessageRepository
|
from app.repositories.conversation_repository import ConversationRepository, MessageRepository
|
||||||
from app.schemas.app_log_schema import AppLogMessage, AppLogNodeExecution
|
from app.schemas.app_log_schema import AppLogNodeExecution
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
@@ -32,7 +31,6 @@ class AppLogService:
|
|||||||
pagesize: int = 20,
|
pagesize: int = 20,
|
||||||
is_draft: Optional[bool] = None,
|
is_draft: Optional[bool] = None,
|
||||||
keyword: Optional[str] = None,
|
keyword: Optional[str] = None,
|
||||||
app_type: Optional[str] = None,
|
|
||||||
) -> Tuple[list[Conversation], int]:
|
) -> Tuple[list[Conversation], int]:
|
||||||
"""
|
"""
|
||||||
查询应用日志会话列表
|
查询应用日志会话列表
|
||||||
@@ -44,7 +42,6 @@ class AppLogService:
|
|||||||
pagesize: 每页数量
|
pagesize: 每页数量
|
||||||
is_draft: 是否草稿会话(None表示返回全部)
|
is_draft: 是否草稿会话(None表示返回全部)
|
||||||
keyword: 搜索关键词(匹配消息内容)
|
keyword: 搜索关键词(匹配消息内容)
|
||||||
app_type: 应用类型(WORKFLOW 时关键词将从 workflow_executions 搜索)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[list[Conversation], int]: (会话列表,总数)
|
Tuple[list[Conversation], int]: (会话列表,总数)
|
||||||
@@ -57,8 +54,7 @@ class AppLogService:
|
|||||||
"page": page,
|
"page": page,
|
||||||
"pagesize": pagesize,
|
"pagesize": pagesize,
|
||||||
"is_draft": is_draft,
|
"is_draft": is_draft,
|
||||||
"keyword": keyword,
|
"keyword": keyword
|
||||||
"app_type": app_type,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -69,8 +65,7 @@ class AppLogService:
|
|||||||
is_draft=is_draft,
|
is_draft=is_draft,
|
||||||
keyword=keyword,
|
keyword=keyword,
|
||||||
page=page,
|
page=page,
|
||||||
pagesize=pagesize,
|
pagesize=pagesize
|
||||||
app_type=app_type,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -88,40 +83,51 @@ class AppLogService:
|
|||||||
self,
|
self,
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID
|
||||||
app_type: str = AppType.AGENT
|
) -> Tuple[Conversation, dict[str, list[AppLogNodeExecution]]]:
|
||||||
) -> Tuple[Conversation, list, dict[str, list[AppLogNodeExecution]]]:
|
|
||||||
"""
|
"""
|
||||||
查询会话详情
|
查询会话详情(包含消息和工作流节点执行记录)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用 ID
|
||||||
|
conversation_id: 会话 ID
|
||||||
|
workspace_id: 工作空间 ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[Conversation, list[AppLogMessage|Message], dict[str, list[AppLogNodeExecution]]]
|
Tuple[Conversation, dict[str, list[AppLogNodeExecution]]]:
|
||||||
|
(包含消息的会话对象, 按消息ID分组的节点执行记录)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ResourceNotFoundException: 当会话不存在时
|
||||||
"""
|
"""
|
||||||
logger.info(
|
logger.info(
|
||||||
"查询应用日志会话详情",
|
"查询应用日志会话详情",
|
||||||
extra={
|
extra={
|
||||||
"app_id": str(app_id),
|
"app_id": str(app_id),
|
||||||
"conversation_id": str(conversation_id),
|
"conversation_id": str(conversation_id),
|
||||||
"workspace_id": str(workspace_id),
|
"workspace_id": str(workspace_id)
|
||||||
"app_type": app_type
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 查询会话
|
||||||
conversation = self.conversation_repository.get_conversation_for_app_log(
|
conversation = self.conversation_repository.get_conversation_for_app_log(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
workspace_id=workspace_id
|
workspace_id=workspace_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if app_type == AppType.WORKFLOW:
|
# 查询消息(按时间正序)
|
||||||
messages, node_executions_map = self._get_workflow_messages_and_nodes(conversation_id)
|
messages = self.message_repository.get_messages_by_conversation(
|
||||||
else:
|
conversation_id=conversation_id
|
||||||
messages = self.message_repository.get_messages_by_conversation(
|
)
|
||||||
conversation_id=conversation_id
|
|
||||||
)
|
# 将消息附加到会话对象
|
||||||
node_executions_map = self._get_workflow_node_executions_with_map(
|
conversation.messages = messages
|
||||||
conversation_id, messages
|
|
||||||
)
|
# 查询工作流节点执行记录(按消息分组)
|
||||||
|
_, node_executions_map = self._get_workflow_node_executions_with_map(
|
||||||
|
conversation_id, messages
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"查询应用日志会话详情成功",
|
"查询应用日志会话详情成功",
|
||||||
@@ -133,129 +139,13 @@ class AppLogService:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return conversation, messages, node_executions_map
|
return conversation, node_executions_map
|
||||||
|
|
||||||
def _get_workflow_messages_and_nodes(
|
|
||||||
self,
|
|
||||||
conversation_id: uuid.UUID,
|
|
||||||
) -> Tuple[list[AppLogMessage], dict[str, list[AppLogNodeExecution]]]:
|
|
||||||
"""
|
|
||||||
工作流应用专用:从 workflow_executions 构建 messages 和节点日志。
|
|
||||||
|
|
||||||
每条 WorkflowExecution 对应一轮对话:
|
|
||||||
- user message:来自 execution.input_data(content 取 message 字段,files 放 meta_data)
|
|
||||||
- assistant message:来自 execution.output_data(失败时内容为错误信息)
|
|
||||||
开场白的 suggested_questions 合并到第一条 assistant message 的 meta_data 里。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(messages 列表, node_executions_map)
|
|
||||||
"""
|
|
||||||
stmt = (
|
|
||||||
select(WorkflowExecution)
|
|
||||||
.where(
|
|
||||||
WorkflowExecution.conversation_id == conversation_id,
|
|
||||||
WorkflowExecution.status.in_(["completed", "failed"])
|
|
||||||
)
|
|
||||||
.order_by(WorkflowExecution.started_at.asc())
|
|
||||||
)
|
|
||||||
executions = list(self.db.scalars(stmt).all())
|
|
||||||
|
|
||||||
# 查开场白:Message 表里 meta_data 含 suggested_questions 的第一条 assistant 消息
|
|
||||||
opening_stmt = (
|
|
||||||
select(Message)
|
|
||||||
.where(
|
|
||||||
Message.conversation_id == conversation_id,
|
|
||||||
Message.role == "assistant",
|
|
||||||
)
|
|
||||||
.order_by(Message.created_at.asc())
|
|
||||||
.limit(10)
|
|
||||||
)
|
|
||||||
early_messages = list(self.db.scalars(opening_stmt).all())
|
|
||||||
suggested_questions: list = []
|
|
||||||
for m in early_messages:
|
|
||||||
if isinstance(m.meta_data, dict) and "suggested_questions" in m.meta_data:
|
|
||||||
suggested_questions = m.meta_data.get("suggested_questions") or []
|
|
||||||
break
|
|
||||||
|
|
||||||
messages: list[AppLogMessage] = []
|
|
||||||
node_executions_map: dict[str, list[AppLogNodeExecution]] = {}
|
|
||||||
|
|
||||||
# 如果有开场白,作为第一条 assistant 消息插入
|
|
||||||
if suggested_questions or early_messages:
|
|
||||||
opening_msg = next(
|
|
||||||
(m for m in early_messages
|
|
||||||
if isinstance(m.meta_data, dict) and "suggested_questions" in m.meta_data),
|
|
||||||
None
|
|
||||||
)
|
|
||||||
if opening_msg:
|
|
||||||
messages.append(AppLogMessage(
|
|
||||||
id=opening_msg.id,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
role="assistant",
|
|
||||||
content=opening_msg.content,
|
|
||||||
status=None,
|
|
||||||
meta_data={"suggested_questions": suggested_questions},
|
|
||||||
created_at=opening_msg.created_at,
|
|
||||||
))
|
|
||||||
|
|
||||||
for execution in executions:
|
|
||||||
started_at = execution.started_at or dt.datetime.now()
|
|
||||||
completed_at = execution.completed_at or started_at
|
|
||||||
|
|
||||||
# assistant message 的 id,同时作为 node_executions_map 的 key
|
|
||||||
assistant_msg_id = uuid.uuid5(execution.id, "assistant")
|
|
||||||
|
|
||||||
# --- user message(输入)---
|
|
||||||
input_data = execution.input_data or {}
|
|
||||||
input_content = input_data.get("message") or _extract_text(input_data)
|
|
||||||
|
|
||||||
# 跳过没有用户输入的 execution(如开场白触发的记录)
|
|
||||||
if not input_content or not input_content.strip():
|
|
||||||
continue
|
|
||||||
|
|
||||||
files = input_data.get("files") or []
|
|
||||||
user_msg = AppLogMessage(
|
|
||||||
id=uuid.uuid5(execution.id, "user"),
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
role="user",
|
|
||||||
content=input_content,
|
|
||||||
meta_data={"files": files} if files else None,
|
|
||||||
created_at=started_at,
|
|
||||||
)
|
|
||||||
messages.append(user_msg)
|
|
||||||
|
|
||||||
# --- assistant message(输出)---
|
|
||||||
if execution.status == "completed":
|
|
||||||
output_content = _extract_text(execution.output_data)
|
|
||||||
meta = {"usage": execution.token_usage or {}, "elapsed_time": execution.elapsed_time}
|
|
||||||
else:
|
|
||||||
output_content = _extract_text(execution.output_data) or ""
|
|
||||||
meta = {"error": execution.error_message, "error_node_id": execution.error_node_id}
|
|
||||||
|
|
||||||
assistant_msg = AppLogMessage(
|
|
||||||
id=assistant_msg_id,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
role="assistant",
|
|
||||||
content=output_content,
|
|
||||||
status=execution.status,
|
|
||||||
meta_data=meta,
|
|
||||||
created_at=completed_at,
|
|
||||||
)
|
|
||||||
messages.append(assistant_msg)
|
|
||||||
|
|
||||||
# --- 节点执行记录,从 workflow_executions.output_data["node_outputs"] 读取 ---
|
|
||||||
execution_nodes = _build_nodes_from_output_data(execution.output_data)
|
|
||||||
|
|
||||||
if execution_nodes:
|
|
||||||
node_executions_map[str(assistant_msg_id)] = execution_nodes
|
|
||||||
|
|
||||||
return messages, node_executions_map
|
|
||||||
|
|
||||||
def _get_workflow_node_executions_with_map(
|
def _get_workflow_node_executions_with_map(
|
||||||
self,
|
self,
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
messages: list[Message]
|
messages: list[Message]
|
||||||
) -> dict[str, list[AppLogNodeExecution]]:
|
) -> Tuple[list[AppLogNodeExecution], dict[str, list[AppLogNodeExecution]]]:
|
||||||
"""
|
"""
|
||||||
从 workflow_executions 表中提取节点执行记录,并按 assistant message 分组
|
从 workflow_executions 表中提取节点执行记录,并按 assistant message 分组
|
||||||
|
|
||||||
@@ -267,12 +157,13 @@ class AppLogService:
|
|||||||
Tuple[list[AppLogNodeExecution], dict[str, list[AppLogNodeExecution]]]:
|
Tuple[list[AppLogNodeExecution], dict[str, list[AppLogNodeExecution]]]:
|
||||||
(所有节点执行记录列表, 按 message_id 分组的节点执行记录字典)
|
(所有节点执行记录列表, 按 message_id 分组的节点执行记录字典)
|
||||||
"""
|
"""
|
||||||
|
node_executions = []
|
||||||
node_executions_map: dict[str, list[AppLogNodeExecution]] = {}
|
node_executions_map: dict[str, list[AppLogNodeExecution]] = {}
|
||||||
|
|
||||||
# 查询该会话关联的所有工作流执行记录(按时间正序)
|
# 查询该会话关联的所有工作流执行记录(按时间正序)
|
||||||
stmt = select(WorkflowExecution).where(
|
stmt = select(WorkflowExecution).where(
|
||||||
WorkflowExecution.conversation_id == conversation_id,
|
WorkflowExecution.conversation_id == conversation_id,
|
||||||
WorkflowExecution.status.in_(["completed", "failed"])
|
WorkflowExecution.status == "completed"
|
||||||
).order_by(WorkflowExecution.started_at.asc())
|
).order_by(WorkflowExecution.started_at.asc())
|
||||||
|
|
||||||
executions = self.db.scalars(stmt).all()
|
executions = self.db.scalars(stmt).all()
|
||||||
@@ -297,18 +188,10 @@ class AppLogService:
|
|||||||
used_message_ids: set[str] = set()
|
used_message_ids: set[str] = set()
|
||||||
|
|
||||||
for execution in executions:
|
for execution in executions:
|
||||||
# 构建节点执行记录列表,从 workflow_executions.output_data["node_outputs"] 读取
|
if not execution.output_data:
|
||||||
execution_nodes = _build_nodes_from_output_data(execution.output_data)
|
|
||||||
|
|
||||||
if not execution_nodes:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 失败的执行没有 assistant message,直接用 execution id 作为 key
|
# 找到该 execution 对应的 assistant message
|
||||||
if execution.status == "failed":
|
|
||||||
node_executions_map[f"execution_{str(execution.id)}"] = execution_nodes
|
|
||||||
continue
|
|
||||||
|
|
||||||
# completed:通过时序匹配关联到对应的 assistant message
|
|
||||||
# 逻辑:找 execution.started_at 之后最近的、未使用的 assistant message
|
# 逻辑:找 execution.started_at 之后最近的、未使用的 assistant message
|
||||||
best_msg = None
|
best_msg = None
|
||||||
best_dt = None
|
best_dt = None
|
||||||
@@ -317,9 +200,9 @@ class AppLogService:
|
|||||||
if msg_id_str in used_message_ids:
|
if msg_id_str in used_message_ids:
|
||||||
continue
|
continue
|
||||||
if msg.created_at and msg.created_at >= execution.started_at:
|
if msg.created_at and msg.created_at >= execution.started_at:
|
||||||
delta = (msg.created_at - execution.started_at).total_seconds()
|
dt = (msg.created_at - execution.started_at).total_seconds()
|
||||||
if best_dt is None or delta < best_dt:
|
if best_dt is None or dt < best_dt:
|
||||||
best_dt = delta
|
best_dt = dt
|
||||||
best_msg = msg
|
best_msg = msg
|
||||||
|
|
||||||
if not best_msg:
|
if not best_msg:
|
||||||
@@ -327,86 +210,31 @@ class AppLogService:
|
|||||||
|
|
||||||
msg_id_str = str(best_msg.id)
|
msg_id_str = str(best_msg.id)
|
||||||
used_message_ids.add(msg_id_str)
|
used_message_ids.add(msg_id_str)
|
||||||
node_executions_map[msg_id_str] = execution_nodes
|
|
||||||
|
|
||||||
return node_executions_map
|
# 提取节点输出
|
||||||
|
output_data = execution.output_data
|
||||||
|
if isinstance(output_data, dict):
|
||||||
|
node_outputs = output_data.get("node_outputs", {})
|
||||||
|
execution_nodes = []
|
||||||
|
for node_id, node_data in node_outputs.items():
|
||||||
|
if not isinstance(node_data, dict):
|
||||||
|
continue
|
||||||
|
node_execution = AppLogNodeExecution(
|
||||||
|
node_id=node_data.get("node_id", node_id),
|
||||||
|
node_type=node_data.get("node_type", "unknown"),
|
||||||
|
node_name=node_data.get("node_name"),
|
||||||
|
status=node_data.get("status", "unknown"),
|
||||||
|
error=node_data.get("error"),
|
||||||
|
input=node_data.get("input"),
|
||||||
|
process=node_data.get("process"),
|
||||||
|
output=node_data.get("output"),
|
||||||
|
elapsed_time=node_data.get("elapsed_time"),
|
||||||
|
token_usage=node_data.get("token_usage"),
|
||||||
|
)
|
||||||
|
node_executions.append(node_execution)
|
||||||
|
execution_nodes.append(node_execution)
|
||||||
|
|
||||||
|
# 将节点记录关联到 message_id
|
||||||
|
node_executions_map[msg_id_str] = execution_nodes
|
||||||
|
|
||||||
def _extract_text(data: Optional[dict]) -> str:
|
return node_executions, node_executions_map
|
||||||
"""从 workflow execution 的 input_data / output_data 中提取可读文本。
|
|
||||||
|
|
||||||
优先取 'text'、'content'、'output' 字段;若都没有则 JSON 序列化整个 dict。
|
|
||||||
"""
|
|
||||||
if not data:
|
|
||||||
return ""
|
|
||||||
for key in ("message", "text", "content", "output", "result", "answer"):
|
|
||||||
if key in data and isinstance(data[key], str):
|
|
||||||
return data[key]
|
|
||||||
import json
|
|
||||||
return json.dumps(data, ensure_ascii=False)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_nodes_from_output_data(output_data: Optional[dict]) -> list[AppLogNodeExecution]:
|
|
||||||
"""从 workflow_executions.output_data["node_outputs"] 构建节点执行记录列表。
|
|
||||||
|
|
||||||
output_data 结构:
|
|
||||||
{
|
|
||||||
"node_outputs": {
|
|
||||||
"<node_id>": {
|
|
||||||
"node_type": ...,
|
|
||||||
"node_name": ...,
|
|
||||||
"status": ...,
|
|
||||||
"input": ...,
|
|
||||||
"output": ...,
|
|
||||||
"elapsed_time": ...,
|
|
||||||
"token_usage": ...,
|
|
||||||
"error": ...,
|
|
||||||
"cycle_items": [...],
|
|
||||||
...
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"error": ...,
|
|
||||||
...
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
if not output_data:
|
|
||||||
return []
|
|
||||||
node_outputs: dict = output_data.get("node_outputs") or {}
|
|
||||||
# 按 execution_order(节点执行时写入的单调递增序号)排序。
|
|
||||||
# PostgreSQL JSONB 不保证 key 顺序,不能依赖 dict 插入顺序;
|
|
||||||
# 缺失 execution_order 的历史数据退化到 0,保持在最前。
|
|
||||||
ordered_items = sorted(
|
|
||||||
node_outputs.items(),
|
|
||||||
key=lambda kv: (kv[1] or {}).get("execution_order", 0)
|
|
||||||
if isinstance(kv[1], dict) else 0
|
|
||||||
)
|
|
||||||
result = []
|
|
||||||
for node_id, node_data in ordered_items:
|
|
||||||
if not isinstance(node_data, dict):
|
|
||||||
continue
|
|
||||||
output = dict(node_data)
|
|
||||||
cycle_items = output.pop("cycle_items", None)
|
|
||||||
# 把已知的顶层字段剥离,剩余的作为 output
|
|
||||||
node_type = output.pop("node_type", "unknown")
|
|
||||||
node_name = output.pop("node_name", None)
|
|
||||||
status = output.pop("status", "completed")
|
|
||||||
error = output.pop("error", None)
|
|
||||||
inp = output.pop("input", None)
|
|
||||||
elapsed_time = output.pop("elapsed_time", None)
|
|
||||||
token_usage = output.pop("token_usage", None)
|
|
||||||
# execution_order 仅用于排序,不返回给前端
|
|
||||||
output.pop("execution_order", None)
|
|
||||||
result.append(AppLogNodeExecution(
|
|
||||||
node_id=node_id,
|
|
||||||
node_type=node_type,
|
|
||||||
node_name=node_name,
|
|
||||||
status=status,
|
|
||||||
error=error,
|
|
||||||
input=inp,
|
|
||||||
process=None,
|
|
||||||
output=output if output else None,
|
|
||||||
cycle_items=cycle_items,
|
|
||||||
elapsed_time=elapsed_time,
|
|
||||||
token_usage=token_usage,
|
|
||||||
))
|
|
||||||
return result
|
|
||||||
|
|||||||
@@ -595,6 +595,23 @@ class AgentRunService:
|
|||||||
)
|
)
|
||||||
tools.extend(memory_tools)
|
tools.extend(memory_tools)
|
||||||
|
|
||||||
|
# 4. 创建 LangChain Agent
|
||||||
|
agent = LangChainAgent(
|
||||||
|
model_name=api_key_config["model_name"],
|
||||||
|
api_key=api_key_config["api_key"],
|
||||||
|
provider=api_key_config.get("provider", "openai"),
|
||||||
|
api_base=api_key_config.get("api_base"),
|
||||||
|
is_omni=api_key_config.get("is_omni", False),
|
||||||
|
temperature=effective_params.get("temperature", 0.7),
|
||||||
|
max_tokens=effective_params.get("max_tokens", 2000),
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tools=tools,
|
||||||
|
deep_thinking=effective_params.get("deep_thinking", False),
|
||||||
|
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
|
||||||
|
json_output=effective_params.get("json_output", False),
|
||||||
|
capability=api_key_config.get("capability", []),
|
||||||
|
)
|
||||||
|
|
||||||
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
||||||
is_new_conversation = not conversation_id
|
is_new_conversation = not conversation_id
|
||||||
opening, suggested_questions = None, None
|
opening, suggested_questions = None, None
|
||||||
@@ -649,29 +666,16 @@ class AgentRunService:
|
|||||||
and any(f.type == FileType.DOCUMENT for f in files)
|
and any(f.type == FileType.DOCUMENT for f in files)
|
||||||
)
|
)
|
||||||
if has_doc_with_images:
|
if has_doc_with_images:
|
||||||
system_prompt += (
|
agent.system_prompt += (
|
||||||
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>,"
|
"\n\n文档中包含图片,图片位置已在文本中以 [第N页 第M张图片]: URL 标记。"
|
||||||
"请在回答中用 Markdown 格式  展示对应图片。"
|
"请在回答中用 Markdown 格式  展示相关图片,做到图文并茂。"
|
||||||
"重要:图片 URL 中包含 UUID(如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx),"
|
)
|
||||||
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
|
# 重建 agent graph 以使新 system_prompt 生效
|
||||||
|
agent.agent = create_agent(
|
||||||
|
model=agent.llm,
|
||||||
|
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
|
||||||
|
system_prompt=agent.system_prompt
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = LangChainAgent(
|
|
||||||
model_name=api_key_config["model_name"],
|
|
||||||
api_key=api_key_config["api_key"],
|
|
||||||
provider=api_key_config.get("provider", "openai"),
|
|
||||||
api_base=api_key_config.get("api_base"),
|
|
||||||
is_omni=api_key_config.get("is_omni", False),
|
|
||||||
temperature=effective_params.get("temperature", 0.7),
|
|
||||||
max_tokens=effective_params.get("max_tokens", 2000),
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
tools=tools,
|
|
||||||
deep_thinking=effective_params.get("deep_thinking", False),
|
|
||||||
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
|
|
||||||
json_output=effective_params.get("json_output", False),
|
|
||||||
capability=api_key_config.get("capability", []),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 为需要运行时上下文的工具注入上下文
|
# 为需要运行时上下文的工具注入上下文
|
||||||
for t in tools:
|
for t in tools:
|
||||||
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
||||||
@@ -757,7 +761,7 @@ class AgentRunService:
|
|||||||
) if not sub_agent else [],
|
) if not sub_agent else [],
|
||||||
"citations": filtered_citations,
|
"citations": filtered_citations,
|
||||||
"audio_url": audio_url,
|
"audio_url": audio_url,
|
||||||
"audio_status": "pending" if audio_url else None
|
"audio_status": "pending"
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -871,6 +875,24 @@ class AgentRunService:
|
|||||||
user_rag_memory_id)
|
user_rag_memory_id)
|
||||||
tools.extend(memory_tools)
|
tools.extend(memory_tools)
|
||||||
|
|
||||||
|
# 4. 创建 LangChain Agent
|
||||||
|
agent = LangChainAgent(
|
||||||
|
model_name=api_key_config["model_name"],
|
||||||
|
api_key=api_key_config["api_key"],
|
||||||
|
provider=api_key_config.get("provider", "openai"),
|
||||||
|
api_base=api_key_config.get("api_base"),
|
||||||
|
is_omni=api_key_config.get("is_omni", False),
|
||||||
|
temperature=effective_params.get("temperature", 0.7),
|
||||||
|
max_tokens=effective_params.get("max_tokens", 2000),
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tools=tools,
|
||||||
|
streaming=True,
|
||||||
|
deep_thinking=effective_params.get("deep_thinking", False),
|
||||||
|
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
|
||||||
|
json_output=effective_params.get("json_output", False),
|
||||||
|
capability=api_key_config.get("capability", []),
|
||||||
|
)
|
||||||
|
|
||||||
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
||||||
is_new_conversation = not conversation_id
|
is_new_conversation = not conversation_id
|
||||||
opening, suggested_questions = None, None
|
opening, suggested_questions = None, None
|
||||||
@@ -926,31 +948,18 @@ class AgentRunService:
|
|||||||
and any(f.type == FileType.DOCUMENT for f in files)
|
and any(f.type == FileType.DOCUMENT for f in files)
|
||||||
)
|
)
|
||||||
if has_doc_with_images:
|
if has_doc_with_images:
|
||||||
system_prompt += (
|
agent.system_prompt += (
|
||||||
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>,"
|
"\n\n文档中包含图片,图片位置已在文本中以 [图片 第N页 第M张图片]: URL 标记。"
|
||||||
"请在回答中用 Markdown 格式  展示对应图片。"
|
"请在回答中用 Markdown 格式  展示相关图片,做到图文并茂。"
|
||||||
"重要:图片 URL 中包含 UUID(如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx),"
|
"**规则1:图片URL必须原封不动、一字不差地复制,禁止修改、禁止省略任何字符**"
|
||||||
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
|
"**规则2:禁止修改URL中UUID里的任何数字和字母**"
|
||||||
|
"**规则3:直接使用  格式输出**"
|
||||||
|
)
|
||||||
|
agent.agent = create_agent(
|
||||||
|
model=agent.llm,
|
||||||
|
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
|
||||||
|
system_prompt=agent.system_prompt
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建 LangChain Agent
|
|
||||||
agent = LangChainAgent(
|
|
||||||
model_name=api_key_config["model_name"],
|
|
||||||
api_key=api_key_config["api_key"],
|
|
||||||
provider=api_key_config.get("provider", "openai"),
|
|
||||||
api_base=api_key_config.get("api_base"),
|
|
||||||
is_omni=api_key_config.get("is_omni", False),
|
|
||||||
temperature=effective_params.get("temperature", 0.7),
|
|
||||||
max_tokens=effective_params.get("max_tokens", 2000),
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
tools=tools,
|
|
||||||
streaming=True,
|
|
||||||
deep_thinking=effective_params.get("deep_thinking", False),
|
|
||||||
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
|
|
||||||
json_output=effective_params.get("json_output", False),
|
|
||||||
capability=api_key_config.get("capability", []),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 为需要运行时上下文的工具注入上下文
|
# 为需要运行时上下文的工具注入上下文
|
||||||
for t in tools:
|
for t in tools:
|
||||||
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from typing import Any, Dict, Optional
|
|||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.celery_task_scheduler import scheduler
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||||
from app.core.logging_config import get_logger
|
from app.core.logging_config import get_logger
|
||||||
@@ -167,31 +166,20 @@ class MemoryAPIService:
|
|||||||
# Convert to message list format expected by write_message_task
|
# Convert to message list format expected by write_message_task
|
||||||
messages = message if isinstance(message, list) else [{"role": "user", "content": message}]
|
messages = message if isinstance(message, list) else [{"role": "user", "content": message}]
|
||||||
|
|
||||||
# from app.tasks import write_message_task
|
from app.tasks import write_message_task
|
||||||
# task = write_message_task.delay(
|
task = write_message_task.delay(
|
||||||
# end_user_id,
|
|
||||||
# messages,
|
|
||||||
# config_id,
|
|
||||||
# storage_type,
|
|
||||||
# user_rag_memory_id or "",
|
|
||||||
# )
|
|
||||||
task_id = scheduler.push_task(
|
|
||||||
"app.core.memory.agent.write_message",
|
|
||||||
end_user_id,
|
end_user_id,
|
||||||
{
|
messages,
|
||||||
"end_user_id": end_user_id,
|
config_id,
|
||||||
"message": messages,
|
storage_type,
|
||||||
"config_id": config_id,
|
user_rag_memory_id or "",
|
||||||
"storage_type": storage_type,
|
|
||||||
"user_rag_memory_id": user_rag_memory_id or ""
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Memory write task submitted, task_id={task_id} end_user_id={end_user_id}")
|
logger.info(f"Memory write task submitted: task_id={task.id}, end_user_id={end_user_id}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"task_id": task_id,
|
"task_id": task.id,
|
||||||
"status": "QUEUED",
|
"status": "PENDING",
|
||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
处理显性记忆相关的业务逻辑,包括情景记忆和语义记忆的查询。
|
处理显性记忆相关的业务逻辑,包括情景记忆和语义记忆的查询。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict
|
||||||
|
|
||||||
from app.core.logging_config import get_logger
|
from app.core.logging_config import get_logger
|
||||||
from app.services.memory_base_service import MemoryBaseService
|
from app.services.memory_base_service import MemoryBaseService
|
||||||
@@ -104,7 +104,7 @@ class MemoryExplicitService(MemoryBaseService):
|
|||||||
e.description AS core_definition
|
e.description AS core_definition
|
||||||
ORDER BY e.name ASC
|
ORDER BY e.name ASC
|
||||||
"""
|
"""
|
||||||
|
|
||||||
semantic_result = await self.neo4j_connector.execute_query(
|
semantic_result = await self.neo4j_connector.execute_query(
|
||||||
semantic_query,
|
semantic_query,
|
||||||
end_user_id=end_user_id
|
end_user_id=end_user_id
|
||||||
@@ -146,209 +146,6 @@ class MemoryExplicitService(MemoryBaseService):
|
|||||||
logger.error(f"获取显性记忆总览时出错: {str(e)}", exc_info=True)
|
logger.error(f"获取显性记忆总览时出错: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def get_episodic_memory_list(
|
|
||||||
self,
|
|
||||||
end_user_id: str,
|
|
||||||
page: int,
|
|
||||||
pagesize: int,
|
|
||||||
start_date: Optional[int] = None,
|
|
||||||
end_date: Optional[int] = None,
|
|
||||||
episodic_type: str = "all",
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
获取情景记忆分页列表
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: 终端用户ID
|
|
||||||
page: 页码
|
|
||||||
pagesize: 每页数量
|
|
||||||
start_date: 开始时间戳(毫秒),可选
|
|
||||||
end_date: 结束时间戳(毫秒),可选
|
|
||||||
episodic_type: 情景类型筛选
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
{
|
|
||||||
"total": int, # 该用户情景记忆总数(不受筛选影响)
|
|
||||||
"items": [...], # 当前页数据
|
|
||||||
"page": {
|
|
||||||
"page": int,
|
|
||||||
"pagesize": int,
|
|
||||||
"total": int, # 筛选后总数
|
|
||||||
"hasnext": bool
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.info(
|
|
||||||
f"情景记忆分页查询: end_user_id={end_user_id}, "
|
|
||||||
f"start_date={start_date}, end_date={end_date}, "
|
|
||||||
f"episodic_type={episodic_type}, page={page}, pagesize={pagesize}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1. 查询情景记忆总数(不受筛选条件限制)
|
|
||||||
total_all_query = """
|
|
||||||
MATCH (s:MemorySummary)
|
|
||||||
WHERE s.end_user_id = $end_user_id
|
|
||||||
RETURN count(s) AS total
|
|
||||||
"""
|
|
||||||
total_all_result = await self.neo4j_connector.execute_query(
|
|
||||||
total_all_query, end_user_id=end_user_id
|
|
||||||
)
|
|
||||||
total_all = total_all_result[0]["total"] if total_all_result else 0
|
|
||||||
|
|
||||||
# 2. 构建筛选条件
|
|
||||||
where_clauses = ["s.end_user_id = $end_user_id"]
|
|
||||||
params = {"end_user_id": end_user_id}
|
|
||||||
|
|
||||||
# 时间戳筛选(毫秒时间戳转为 UTC ISO 字符串,使用 Neo4j datetime() 精确比较)
|
|
||||||
if start_date is not None and end_date is not None:
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
start_dt = datetime.fromtimestamp(start_date / 1000, tz=timezone.utc)
|
|
||||||
end_dt = datetime.fromtimestamp(end_date / 1000, tz=timezone.utc)
|
|
||||||
# 开始时间取当天 UTC 00:00:00,结束时间取当天 UTC 23:59:59.999999
|
|
||||||
start_iso = start_dt.strftime("%Y-%m-%dT") + "00:00:00.000000"
|
|
||||||
end_iso = end_dt.strftime("%Y-%m-%dT") + "23:59:59.999999"
|
|
||||||
|
|
||||||
where_clauses.append("datetime(s.created_at) >= datetime($start_iso) AND datetime(s.created_at) <= datetime($end_iso)")
|
|
||||||
params["start_iso"] = start_iso
|
|
||||||
params["end_iso"] = end_iso
|
|
||||||
|
|
||||||
# 类型筛选下推到 Cypher(兼容中英文)
|
|
||||||
if episodic_type != "all":
|
|
||||||
type_mapping = {
|
|
||||||
"conversation": "对话",
|
|
||||||
"project_work": "项目/工作",
|
|
||||||
"learning": "学习",
|
|
||||||
"decision": "决策",
|
|
||||||
"important_event": "重要事件"
|
|
||||||
}
|
|
||||||
chinese_type = type_mapping.get(episodic_type)
|
|
||||||
if chinese_type:
|
|
||||||
where_clauses.append(
|
|
||||||
"(s.memory_type = $episodic_type OR s.memory_type = $chinese_type)"
|
|
||||||
)
|
|
||||||
params["episodic_type"] = episodic_type
|
|
||||||
params["chinese_type"] = chinese_type
|
|
||||||
else:
|
|
||||||
where_clauses.append("s.memory_type = $episodic_type")
|
|
||||||
params["episodic_type"] = episodic_type
|
|
||||||
|
|
||||||
where_str = " AND ".join(where_clauses)
|
|
||||||
|
|
||||||
# 3. 查询筛选后的总数
|
|
||||||
count_query = f"""
|
|
||||||
MATCH (s:MemorySummary)
|
|
||||||
WHERE {where_str}
|
|
||||||
RETURN count(s) AS total
|
|
||||||
"""
|
|
||||||
count_result = await self.neo4j_connector.execute_query(count_query, **params)
|
|
||||||
filtered_total = count_result[0]["total"] if count_result else 0
|
|
||||||
|
|
||||||
# 4. 查询分页数据
|
|
||||||
skip = (page - 1) * pagesize
|
|
||||||
data_query = f"""
|
|
||||||
MATCH (s:MemorySummary)
|
|
||||||
WHERE {where_str}
|
|
||||||
RETURN elementId(s) AS id,
|
|
||||||
s.name AS title,
|
|
||||||
s.memory_type AS memory_type,
|
|
||||||
s.content AS content,
|
|
||||||
s.created_at AS created_at
|
|
||||||
ORDER BY s.created_at DESC
|
|
||||||
SKIP $skip LIMIT $limit
|
|
||||||
"""
|
|
||||||
params["skip"] = skip
|
|
||||||
params["limit"] = pagesize
|
|
||||||
|
|
||||||
result = await self.neo4j_connector.execute_query(data_query, **params)
|
|
||||||
|
|
||||||
# 5. 处理结果
|
|
||||||
items = []
|
|
||||||
if result:
|
|
||||||
for record in result:
|
|
||||||
raw_created_at = record.get("created_at")
|
|
||||||
created_at_timestamp = self.parse_timestamp(raw_created_at)
|
|
||||||
items.append({
|
|
||||||
"id": record["id"],
|
|
||||||
"title": record.get("title") or "未命名",
|
|
||||||
"memory_type": record.get("memory_type") or "其他",
|
|
||||||
"content": record.get("content") or "",
|
|
||||||
"created_at": created_at_timestamp
|
|
||||||
})
|
|
||||||
|
|
||||||
# 6. 构建返回结果
|
|
||||||
return {
|
|
||||||
"total": total_all,
|
|
||||||
"items": items,
|
|
||||||
"page": {
|
|
||||||
"page": page,
|
|
||||||
"pagesize": pagesize,
|
|
||||||
"total": filtered_total,
|
|
||||||
"hasnext": (page * pagesize) < filtered_total
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"情景记忆分页查询出错: {str(e)}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_semantic_memory_list(
|
|
||||||
self,
|
|
||||||
end_user_id: str
|
|
||||||
) -> list:
|
|
||||||
"""
|
|
||||||
获取语义记忆全量列表
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: 终端用户ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"id": str,
|
|
||||||
"name": str,
|
|
||||||
"entity_type": str,
|
|
||||||
"core_definition": str
|
|
||||||
}
|
|
||||||
]
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.info(f"语义记忆列表查询: end_user_id={end_user_id}")
|
|
||||||
|
|
||||||
semantic_query = """
|
|
||||||
MATCH (e:ExtractedEntity)
|
|
||||||
WHERE e.end_user_id = $end_user_id
|
|
||||||
AND e.is_explicit_memory = true
|
|
||||||
RETURN elementId(e) AS id,
|
|
||||||
e.name AS name,
|
|
||||||
e.entity_type AS entity_type,
|
|
||||||
e.description AS core_definition
|
|
||||||
ORDER BY e.name ASC
|
|
||||||
"""
|
|
||||||
|
|
||||||
result = await self.neo4j_connector.execute_query(
|
|
||||||
semantic_query, end_user_id=end_user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
items = []
|
|
||||||
if result:
|
|
||||||
for record in result:
|
|
||||||
items.append({
|
|
||||||
"id": record["id"],
|
|
||||||
"name": record.get("name") or "未命名",
|
|
||||||
"entity_type": record.get("entity_type") or "未分类",
|
|
||||||
"core_definition": record.get("core_definition") or ""
|
|
||||||
})
|
|
||||||
|
|
||||||
logger.info(f"语义记忆列表查询成功: end_user_id={end_user_id}, total={len(items)}")
|
|
||||||
|
|
||||||
return items
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"语义记忆列表查询出错: {str(e)}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_explicit_memory_details(
|
async def get_explicit_memory_details(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
|
|||||||
"""通义千问文档格式"""
|
"""通义千问文档格式"""
|
||||||
return True, {
|
return True, {
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": f"<document name=\"{file_name}\">\n文档内容:\n{text}\n</document>"
|
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||||
}
|
}
|
||||||
|
|
||||||
async def format_audio(
|
async def format_audio(
|
||||||
@@ -167,7 +167,6 @@ class BedrockFormatStrategy(MultimodalFormatStrategy):
|
|||||||
async def format_document(self, file_name: str, text: str) -> tuple[bool, Dict[str, Any]]:
|
async def format_document(self, file_name: str, text: str) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""Bedrock/Anthropic 文档格式(需要 base64 编码)"""
|
"""Bedrock/Anthropic 文档格式(需要 base64 编码)"""
|
||||||
# Bedrock 文档需要 base64 编码
|
# Bedrock 文档需要 base64 编码
|
||||||
text = f"文档内容:\n{text}\n"
|
|
||||||
text_bytes = text.encode('utf-8')
|
text_bytes = text.encode('utf-8')
|
||||||
base64_text = base64.b64encode(text_bytes).decode('utf-8')
|
base64_text = base64.b64encode(text_bytes).decode('utf-8')
|
||||||
|
|
||||||
@@ -224,7 +223,7 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy):
|
|||||||
"""OpenAI 文档格式"""
|
"""OpenAI 文档格式"""
|
||||||
return True, {
|
return True, {
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": f"<document name=\"{file_name}\">\n文档内容:\n{text}\n</document>"
|
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||||
}
|
}
|
||||||
|
|
||||||
async def format_audio(
|
async def format_audio(
|
||||||
@@ -389,18 +388,17 @@ class MultimodalService:
|
|||||||
from app.models.workspace_model import Workspace as WorkspaceModel
|
from app.models.workspace_model import Workspace as WorkspaceModel
|
||||||
ws = self.db.query(WorkspaceModel).filter(WorkspaceModel.id == workspace_id).first()
|
ws = self.db.query(WorkspaceModel).filter(WorkspaceModel.id == workspace_id).first()
|
||||||
tenant_id = ws.tenant_id if ws else None
|
tenant_id = ws.tenant_id if ws else None
|
||||||
img_result = []
|
|
||||||
for img_info in img_infos:
|
for img_info in img_infos:
|
||||||
page = img_info["page"]
|
page = img_info["page"]
|
||||||
index = img_info["index"]
|
index = img_info["index"]
|
||||||
ext = img_info.get("ext", "png")
|
ext = img_info.get("ext", "png")
|
||||||
try:
|
try:
|
||||||
_, img_url = await self._save_doc_image_to_storage(img_info["bytes"], ext, tenant_id, workspace_id)
|
_, img_url = await self._save_doc_image_to_storage(img_info["bytes"], ext, tenant_id, workspace_id)
|
||||||
placeholder = f"第{page}页 第{index + 1}张" if page > 0 else f"第{index + 1}张"
|
placeholder = f"第{page}页 第{index + 1}张图片" if page > 0 else f"第{index + 1}张图片"
|
||||||
# 在文本内容中追加图片位置标记
|
# 在文本内容中追加图片位置标记
|
||||||
if result and result[-1].get("type") in ("text", "document"):
|
if result and result[-1].get("type") in ("text", "document"):
|
||||||
key = "text" if "text" in result[-1] else list(result[-1].keys())[-1]
|
key = "text" if "text" in result[-1] else list(result[-1].keys())[-1]
|
||||||
result[-1][key] = result[-1].get(key, "") + f"\n[图片 {placeholder}]: <img src=\"{img_url}\" data-url=\"{img_url}\">"
|
result[-1][key] = result[-1].get(key, "") + f"\n[图片 {placeholder}]: {img_url}"
|
||||||
# 将图片以视觉格式追加到消息内容中
|
# 将图片以视觉格式追加到消息内容中
|
||||||
img_file = FileInput(
|
img_file = FileInput(
|
||||||
type=FileType.IMAGE,
|
type=FileType.IMAGE,
|
||||||
@@ -409,10 +407,9 @@ class MultimodalService:
|
|||||||
file_type="image/png",
|
file_type="image/png",
|
||||||
)
|
)
|
||||||
_, img_content = await self._process_image(img_file, strategy_class(img_file))
|
_, img_content = await self._process_image(img_file, strategy_class(img_file))
|
||||||
img_result.append(img_content)
|
result.append(img_content)
|
||||||
except Exception as img_err:
|
except Exception as img_err:
|
||||||
logger.warning(f"文档图片处理失败: {img_err}")
|
logger.warning(f"文档图片处理失败: {img_err}")
|
||||||
result.extend(img_result)
|
|
||||||
elif file.type == FileType.AUDIO and "audio" in self.capability:
|
elif file.type == FileType.AUDIO and "audio" in self.capability:
|
||||||
is_support, content = await self._process_audio(file, strategy)
|
is_support, content = await self._process_audio(file, strategy)
|
||||||
result.append(content)
|
result.append(content)
|
||||||
|
|||||||
@@ -815,12 +815,11 @@ class ToolService:
|
|||||||
"default": param_info.get("default")
|
"default": param_info.get("default")
|
||||||
})
|
})
|
||||||
|
|
||||||
# 请求体参数 — _extract_request_body 返回 {"schema": {...}, "required": bool, ...}
|
# 请求体参数
|
||||||
request_body = operation.get("request_body")
|
request_body = operation.get("request_body")
|
||||||
if request_body:
|
if request_body:
|
||||||
body_schema = request_body.get("schema", {})
|
schema_props = request_body.get("schema", {}).get("properties", {})
|
||||||
schema_props = body_schema.get("properties", {})
|
required_props = request_body.get("schema", {}).get("required", [])
|
||||||
required_props = body_schema.get("required", [])
|
|
||||||
|
|
||||||
for prop_name, prop_schema in schema_props.items():
|
for prop_name, prop_schema in schema_props.items():
|
||||||
parameters.append({
|
parameters.append({
|
||||||
|
|||||||
@@ -17,9 +17,8 @@ from app.core.workflow.executor import execute_workflow, execute_workflow_stream
|
|||||||
from app.core.workflow.nodes.enums import NodeType
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
from app.core.workflow.validator import validate_workflow_config
|
from app.core.workflow.validator import validate_workflow_config
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from sqlalchemy import select
|
|
||||||
from app.models import App
|
from app.models import App
|
||||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution
|
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
||||||
from app.repositories import knowledge_repository
|
from app.repositories import knowledge_repository
|
||||||
from app.repositories.workflow_repository import (
|
from app.repositories.workflow_repository import (
|
||||||
WorkflowConfigRepository,
|
WorkflowConfigRepository,
|
||||||
@@ -554,16 +553,13 @@ class WorkflowService:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "workflow_end":
|
case "workflow_end":
|
||||||
data = {
|
|
||||||
"elapsed_time": payload.get("elapsed_time"),
|
|
||||||
"message_length": len(payload.get("output", "")),
|
|
||||||
"error": payload.get("error", "")
|
|
||||||
}
|
|
||||||
if "citations" in payload and payload["citations"]:
|
|
||||||
data["citations"] = payload["citations"]
|
|
||||||
return {
|
return {
|
||||||
"event": "end",
|
"event": "end",
|
||||||
"data": data
|
"data": {
|
||||||
|
"elapsed_time": payload.get("elapsed_time"),
|
||||||
|
"message_length": len(payload.get("output", "")),
|
||||||
|
"error": payload.get("error", "")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case "node_start" | "node_end" | "node_error" | "cycle_item":
|
case "node_start" | "node_end" | "node_error" | "cycle_item":
|
||||||
return None
|
return None
|
||||||
@@ -922,7 +918,6 @@ class WorkflowService:
|
|||||||
input_data["conv_messages"] = conv_messages
|
input_data["conv_messages"] = conv_messages
|
||||||
init_message_length = len(input_data.get("conv_messages", []))
|
init_message_length = len(input_data.get("conv_messages", []))
|
||||||
message_id = uuid.uuid4()
|
message_id = uuid.uuid4()
|
||||||
_cycle_items: dict[str, list] = {}
|
|
||||||
|
|
||||||
# 新会话时写入开场白
|
# 新会话时写入开场白
|
||||||
is_new_conversation = init_message_length == 0
|
is_new_conversation = init_message_length == 0
|
||||||
@@ -953,15 +948,6 @@ class WorkflowService:
|
|||||||
memory_storage_type=storage_type,
|
memory_storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id
|
||||||
):
|
):
|
||||||
event_type = event.get("event")
|
|
||||||
event_data = event.get("data", {})
|
|
||||||
|
|
||||||
if event_type == "cycle_item":
|
|
||||||
cycle_id = event_data.get("cycle_id")
|
|
||||||
if cycle_id not in _cycle_items:
|
|
||||||
_cycle_items[cycle_id] = []
|
|
||||||
_cycle_items[cycle_id].append(event_data)
|
|
||||||
|
|
||||||
if event.get("event") == "workflow_end":
|
if event.get("event") == "workflow_end":
|
||||||
status = event.get("data", {}).get("status")
|
status = event.get("data", {}).get("status")
|
||||||
token_usage = event.get("data", {}).get("token_usage", {}) or {}
|
token_usage = event.get("data", {}).get("token_usage", {}) or {}
|
||||||
@@ -1033,18 +1019,6 @@ class WorkflowService:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.error(f"unexpect workflow run status, status: {status}")
|
logger.error(f"unexpect workflow run status, status: {status}")
|
||||||
# 把积累的 cycle_item 写入 workflow_executions.output_data["node_outputs"]
|
|
||||||
if _cycle_items and execution.output_data:
|
|
||||||
import copy
|
|
||||||
new_output_data = copy.deepcopy(execution.output_data)
|
|
||||||
node_outputs = new_output_data.setdefault("node_outputs", {})
|
|
||||||
for cycle_node_id, items in _cycle_items.items():
|
|
||||||
if cycle_node_id in node_outputs:
|
|
||||||
node_outputs[cycle_node_id]["cycle_items"] = items
|
|
||||||
else:
|
|
||||||
node_outputs[cycle_node_id] = {"cycle_items": items}
|
|
||||||
execution.output_data = new_output_data
|
|
||||||
self.db.commit()
|
|
||||||
elif event.get("event") == "workflow_start":
|
elif event.get("event") == "workflow_start":
|
||||||
event["data"]["message_id"] = str(message_id)
|
event["data"]["message_id"] = str(message_id)
|
||||||
event = self._emit(public, event)
|
event = self._emit(public, event)
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from app.models.workspace_model import (
|
|||||||
)
|
)
|
||||||
from app.repositories import workspace_repository
|
from app.repositories import workspace_repository
|
||||||
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
|
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
|
||||||
from app.services.session_service import SessionService
|
|
||||||
from app.schemas.workspace_schema import (
|
from app.schemas.workspace_schema import (
|
||||||
InviteAcceptRequest,
|
InviteAcceptRequest,
|
||||||
InviteValidateResponse,
|
InviteValidateResponse,
|
||||||
@@ -59,7 +58,7 @@ def switch_workspace(
|
|||||||
raise BusinessException(f"切换工作空间失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
raise BusinessException(f"切换工作空间失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
||||||
|
|
||||||
|
|
||||||
async def delete_workspace_member(
|
def delete_workspace_member(
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
member_id: uuid.UUID,
|
member_id: uuid.UUID,
|
||||||
@@ -77,29 +76,10 @@ async def delete_workspace_member(
|
|||||||
BizCode.WORKSPACE_NOT_FOUND)
|
BizCode.WORKSPACE_NOT_FOUND)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
deleted_user = workspace_member.user
|
|
||||||
workspace_member.is_active = False
|
workspace_member.is_active = False
|
||||||
deleted_user.current_workspace_id = None
|
workspace_member.user.current_workspace_id = None
|
||||||
|
|
||||||
# 若被删除成员不是超级管理员且没有其他可用工作空间,则禁用该用户
|
|
||||||
if not deleted_user.is_superuser:
|
|
||||||
remaining = (
|
|
||||||
db.query(WorkspaceMember)
|
|
||||||
.filter(
|
|
||||||
WorkspaceMember.user_id == deleted_user.id,
|
|
||||||
WorkspaceMember.workspace_id != workspace_id,
|
|
||||||
WorkspaceMember.is_active.is_(True),
|
|
||||||
)
|
|
||||||
.count()
|
|
||||||
)
|
|
||||||
if remaining == 0:
|
|
||||||
deleted_user.is_active = False
|
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
|
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||||
|
|
||||||
# 使被删除成员的所有 token 立即失效
|
|
||||||
await SessionService.invalidate_all_user_tokens(str(workspace_member.user_id))
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
|
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ from app.core.rag.prompts.generator import question_proposal
|
|||||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
|
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
|
||||||
ElasticSearchVectorFactory,
|
ElasticSearchVectorFactory,
|
||||||
)
|
)
|
||||||
from app.db import get_db_context
|
from app.db import get_db, get_db_context
|
||||||
from app.models import Document, File, Knowledge
|
from app.models import Document, File, Knowledge
|
||||||
from app.models.end_user_model import EndUser
|
from app.models.end_user_model import EndUser
|
||||||
from app.schemas import document_schema, file_schema
|
from app.schemas import document_schema, file_schema
|
||||||
@@ -2025,7 +2025,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
|||||||
end_users = db.query(EndUser).all()
|
end_users = db.query(EndUser).all()
|
||||||
if not end_users:
|
if not end_users:
|
||||||
logger.info("没有终端用户,跳过遗忘周期")
|
logger.info("没有终端用户,跳过遗忘周期")
|
||||||
return {"status": "SUCCESS", "message": "没有终端用户",
|
return {"status": "SUCCESS", "message": "没有终端用户",
|
||||||
"report": {"merged_count": 0, "failed_count": 0, "processed_users": 0},
|
"report": {"merged_count": 0, "failed_count": 0, "processed_users": 0},
|
||||||
"duration_seconds": time.time() - start_time}
|
"duration_seconds": time.time() - start_time}
|
||||||
|
|
||||||
@@ -2039,7 +2039,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
|||||||
# 获取用户配置(自动回退到工作空间默认配置)
|
# 获取用户配置(自动回退到工作空间默认配置)
|
||||||
connected_config = get_end_user_connected_config(str(end_user.id), db)
|
connected_config = get_end_user_connected_config(str(end_user.id), db)
|
||||||
user_config_id = resolve_config_id(connected_config.get("memory_config_id"), db)
|
user_config_id = resolve_config_id(connected_config.get("memory_config_id"), db)
|
||||||
|
|
||||||
if not user_config_id:
|
if not user_config_id:
|
||||||
failed_users.append({"end_user_id": str(end_user.id), "error": "无法获取配置"})
|
failed_users.append({"end_user_id": str(end_user.id), "error": "无法获取配置"})
|
||||||
continue
|
continue
|
||||||
@@ -2048,13 +2048,13 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
|||||||
report = await forget_service.trigger_forgetting_cycle(
|
report = await forget_service.trigger_forgetting_cycle(
|
||||||
db=db, end_user_id=str(end_user.id), config_id=user_config_id
|
db=db, end_user_id=str(end_user.id), config_id=user_config_id
|
||||||
)
|
)
|
||||||
|
|
||||||
total_merged += report.get('merged_count', 0)
|
total_merged += report.get('merged_count', 0)
|
||||||
total_failed += report.get('failed_count', 0)
|
total_failed += report.get('failed_count', 0)
|
||||||
processed_users += 1
|
processed_users += 1
|
||||||
|
|
||||||
logger.info(f"用户 {end_user.id}: 融合 {report.get('merged_count', 0)} 对节点")
|
logger.info(f"用户 {end_user.id}: 融合 {report.get('merged_count', 0)} 对节点")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理用户 {end_user.id} 失败: {e}", exc_info=True)
|
logger.error(f"处理用户 {end_user.id} 失败: {e}", exc_info=True)
|
||||||
failed_users.append({"end_user_id": str(end_user.id), "error": str(e)})
|
failed_users.append({"end_user_id": str(end_user.id), "error": str(e)})
|
||||||
@@ -2801,18 +2801,18 @@ def run_incremental_clustering(
|
|||||||
包含任务执行结果的字典
|
包含任务执行结果的字典
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
async def _run() -> Dict[str, Any]:
|
async def _run() -> Dict[str, Any]:
|
||||||
from app.core.logging_config import get_logger
|
from app.core.logging_config import get_logger
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine
|
from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[IncrementalClustering] 开始增量聚类任务 - end_user_id={end_user_id}, "
|
f"[IncrementalClustering] 开始增量聚类任务 - end_user_id={end_user_id}, "
|
||||||
f"实体数={len(new_entity_ids)}, llm_model_id={llm_model_id}"
|
f"实体数={len(new_entity_ids)}, llm_model_id={llm_model_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
try:
|
try:
|
||||||
engine = LabelPropagationEngine(
|
engine = LabelPropagationEngine(
|
||||||
@@ -2820,12 +2820,12 @@ def run_incremental_clustering(
|
|||||||
llm_model_id=llm_model_id,
|
llm_model_id=llm_model_id,
|
||||||
embedding_model_id=embedding_model_id,
|
embedding_model_id=embedding_model_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 执行增量聚类
|
# 执行增量聚类
|
||||||
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
|
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
|
||||||
|
|
||||||
logger.info(f"[IncrementalClustering] 增量聚类完成 - end_user_id={end_user_id}")
|
logger.info(f"[IncrementalClustering] 增量聚类完成 - end_user_id={end_user_id}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "SUCCESS",
|
"status": "SUCCESS",
|
||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
@@ -2836,18 +2836,18 @@ def run_incremental_clustering(
|
|||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
await connector.close()
|
await connector.close()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
loop = set_asyncio_event_loop()
|
loop = set_asyncio_event_loop()
|
||||||
result = loop.run_until_complete(_run())
|
result = loop.run_until_complete(_run())
|
||||||
result["elapsed_time"] = time.time() - start_time
|
result["elapsed_time"] = time.time() - start_time
|
||||||
result["task_id"] = self.request.id
|
result["task_id"] = self.request.id
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[IncrementalClustering] 任务完成 - task_id={self.request.id}, "
|
f"[IncrementalClustering] 任务完成 - task_id={self.request.id}, "
|
||||||
f"elapsed_time={result['elapsed_time']:.2f}s"
|
f"elapsed_time={result['elapsed_time']:.2f}s"
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|||||||
@@ -63,23 +63,6 @@ services:
|
|||||||
networks:
|
networks:
|
||||||
- celery
|
- celery
|
||||||
|
|
||||||
celery-task-scheduler:
|
|
||||||
image: redbear-mem-open:latest
|
|
||||||
container_name: celery-task-scheduler
|
|
||||||
env_file:
|
|
||||||
- .env
|
|
||||||
volumes:
|
|
||||||
- /etc/localtime:/etc/localtime:ro
|
|
||||||
command: python -m app.celery_task_scheduler
|
|
||||||
restart: unless-stopped
|
|
||||||
healthcheck:
|
|
||||||
test: CMD curl -f 127.0.0.1:8001 || exit 1
|
|
||||||
interval: 30s
|
|
||||||
timeout: 5s
|
|
||||||
retries: 3
|
|
||||||
networks:
|
|
||||||
- celery
|
|
||||||
|
|
||||||
# Celery Beat - scheduler
|
# Celery Beat - scheduler
|
||||||
beat:
|
beat:
|
||||||
image: redbear-mem-open:latest
|
image: redbear-mem-open:latest
|
||||||
|
|||||||
@@ -62,6 +62,7 @@
|
|||||||
"remark-gfm": "^4.0.1",
|
"remark-gfm": "^4.0.1",
|
||||||
"remark-math": "^6.0.0",
|
"remark-math": "^6.0.0",
|
||||||
"tailwindcss": "^4.1.14",
|
"tailwindcss": "^4.1.14",
|
||||||
|
"x6-html-shape": "0.4.9",
|
||||||
"xlsx": "^0.18.5",
|
"xlsx": "^0.18.5",
|
||||||
"zustand": "^5.0.8"
|
"zustand": "^5.0.8"
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -8,11 +8,12 @@ import { type FC, useRef, useEffect, useState } from 'react'
|
|||||||
import clsx from 'clsx'
|
import clsx from 'clsx'
|
||||||
import Markdown from '@/components/Markdown'
|
import Markdown from '@/components/Markdown'
|
||||||
import type { ChatContentProps } from './types'
|
import type { ChatContentProps } from './types'
|
||||||
import { Spin, Flex, Button } from 'antd'
|
import { Spin, Image, Flex, Button } from 'antd'
|
||||||
import { SoundOutlined } from '@ant-design/icons'
|
import { SoundOutlined } from '@ant-design/icons'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
|
|
||||||
import MessageFiles from './MessageFiles'
|
import AudioPlayer from './AudioPlayer'
|
||||||
|
import VideoPlayer from './VideoPlayer'
|
||||||
|
|
||||||
const getFileUrl = (file: any) => {
|
const getFileUrl = (file: any) => {
|
||||||
return file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined)
|
return file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined)
|
||||||
@@ -148,7 +149,72 @@ const ChatContent: FC<ChatContentProps> = ({
|
|||||||
{labelFormat(item)}
|
{labelFormat(item)}
|
||||||
</div>
|
</div>
|
||||||
}
|
}
|
||||||
<MessageFiles files={item.meta_data?.files ?? []} contentClassNames={contentClassNames} onDownload={handleDownload} />
|
{item?.meta_data?.files && item.meta_data?.files.length > 0 && <Flex gap={8} vertical align="end" className="rb:mb-2!">
|
||||||
|
{item.meta_data?.files?.map((file) => {
|
||||||
|
if (file.type.includes('image')) {
|
||||||
|
return (
|
||||||
|
<div key={file.url || file.uid} className={`rb:inline-block rb:group rb:relative rb:rounded-lg ${contentClassNames}`}>
|
||||||
|
<Image src={getFileUrl(file)} alt={file.name} className="rb:w-full rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if (file.type.includes('video')) {
|
||||||
|
return (
|
||||||
|
<div key={file.url || file.uid} className="rb:w-50">
|
||||||
|
{/* <video src={getFileUrl(file)} controls className="rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" /> */}
|
||||||
|
<VideoPlayer key={file.url || file.uid} src={getFileUrl(file)} />
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if (file.type.includes('audio')) {
|
||||||
|
return (
|
||||||
|
<div key={file.url || file.uid} className="rb:w-50">
|
||||||
|
<AudioPlayer key={file.url || file.uid} src={getFileUrl(file)} />
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const documentType = (file.file_type || file.type)?.split('/')
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
key={file.url || file.uid}
|
||||||
|
align="center"
|
||||||
|
gap={10}
|
||||||
|
className="rb:text-left rb:w-45 rb:text-[12px] rb:group rb:relative rb:rounded-lg rb-border rb:py-2! rb:px-2.5! rb:border rb:border-[#F6F6F6]"
|
||||||
|
onClick={() => handleDownload(file)}
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
className={clsx(
|
||||||
|
"rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf_disabled.svg')]",
|
||||||
|
file.type?.includes('pdf')
|
||||||
|
? "rb:bg-[url('@/assets/images/file/pdf.svg')]"
|
||||||
|
: (file.type?.includes('excel') || file.type?.includes('spreadsheetml.sheet')) || file.type?.includes('xls') || file.type?.includes('xlsx')
|
||||||
|
? "rb:bg-[url('@/assets/images/file/excel.svg')]"
|
||||||
|
: file.type?.includes('csv')
|
||||||
|
? "rb:bg-[url('@/assets/images/file/csv.svg')]"
|
||||||
|
: file.type?.includes('html')
|
||||||
|
? "rb:bg-[url('@/assets/images/file/html.svg')]"
|
||||||
|
: file.type?.includes('json')
|
||||||
|
? "rb:bg-[url('@/assets/images/file/json.svg')]"
|
||||||
|
: file.type?.includes('ppt')
|
||||||
|
? "rb:bg-[url('@/assets/images/file/ppt.svg')]"
|
||||||
|
: file.type?.includes('markdown')
|
||||||
|
? "rb:bg-[url('@/assets/images/file/md.svg')]"
|
||||||
|
: file.type?.includes('text')
|
||||||
|
? "rb:bg-[url('@/assets/images/file/txt.svg')]"
|
||||||
|
: (file.type?.includes('doc') || file.type?.includes('docx') || file.type?.includes('word') || file.type?.includes('wordprocessingml.document'))
|
||||||
|
? "rb:bg-[url('@/assets/images/file/word.svg')]"
|
||||||
|
: "rb:bg-[url('@/assets/images/file/txt.svg')]"
|
||||||
|
)}
|
||||||
|
></div>
|
||||||
|
<div className="rb:flex-1 rb:w-32.5">
|
||||||
|
<div className="rb:leading-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{file.name}</div>
|
||||||
|
<div className="rb:leading-3.5 rb:mt-0.5 rb:text-[#5B6167] rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{documentType?.[documentType.length - 1]} · {file.size}</div>
|
||||||
|
</div>
|
||||||
|
</Flex>
|
||||||
|
)
|
||||||
|
})}
|
||||||
|
</Flex>}
|
||||||
{/* Message bubble */}
|
{/* Message bubble */}
|
||||||
<div className={clsx('rb:text-left rb:leading-5 rb:inline-block rb:wrap-break-word rb:relative', item.role === 'user' ? contentClassNames : '', {
|
<div className={clsx('rb:text-left rb:leading-5 rb:inline-block rb:wrap-break-word rb:relative', item.role === 'user' ? contentClassNames : '', {
|
||||||
// Error message style (content is null and not assistant message)
|
// Error message style (content is null and not assistant message)
|
||||||
|
|||||||
@@ -1,87 +0,0 @@
|
|||||||
import { Image, Flex } from 'antd'
|
|
||||||
import clsx from 'clsx'
|
|
||||||
import AudioPlayer from './AudioPlayer'
|
|
||||||
import VideoPlayer from './VideoPlayer'
|
|
||||||
|
|
||||||
const getFileUrl = (file: any) =>
|
|
||||||
file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined)
|
|
||||||
|
|
||||||
const DOC_ICONS: [string[], string][] = [
|
|
||||||
[['pdf'], "rb:bg-[url('@/assets/images/file/pdf.svg')]"],
|
|
||||||
[['excel', 'spreadsheetml.sheet', 'xls', 'xlsx'], "rb:bg-[url('@/assets/images/file/excel.svg')]"],
|
|
||||||
[['csv'], "rb:bg-[url('@/assets/images/file/csv.svg')]"],
|
|
||||||
[['html'], "rb:bg-[url('@/assets/images/file/html.svg')]"],
|
|
||||||
[['json'], "rb:bg-[url('@/assets/images/file/json.svg')]"],
|
|
||||||
[['ppt'], "rb:bg-[url('@/assets/images/file/ppt.svg')]"],
|
|
||||||
[['markdown'], "rb:bg-[url('@/assets/images/file/md.svg')]"],
|
|
||||||
[['text'], "rb:bg-[url('@/assets/images/file/txt.svg')]"],
|
|
||||||
[['doc', 'docx', 'word', 'wordprocessingml.document'], "rb:bg-[url('@/assets/images/file/word.svg')]"],
|
|
||||||
]
|
|
||||||
|
|
||||||
const getDocIcon = (parts: string[]) => {
|
|
||||||
const match = DOC_ICONS.find(([keys]) => keys.some(k => parts.includes(k)))
|
|
||||||
return match ? match[1] : "rb:bg-[url('@/assets/images/file/txt.svg')]"
|
|
||||||
}
|
|
||||||
|
|
||||||
interface MessageFilesProps {
|
|
||||||
files: any[]
|
|
||||||
contentClassNames?: string | Record<string, boolean>
|
|
||||||
onDownload: (file: any) => void
|
|
||||||
}
|
|
||||||
|
|
||||||
const MessageFiles = ({ files, contentClassNames, onDownload }: MessageFilesProps) => {
|
|
||||||
if (!files?.length) return null
|
|
||||||
return (
|
|
||||||
<Flex gap={8} vertical align="end" className="rb:mb-2!">
|
|
||||||
{files.map((file) => {
|
|
||||||
const key = file.url || file.uid
|
|
||||||
if (file.type.includes('image')) {
|
|
||||||
return (
|
|
||||||
<div key={key} className={clsx('rb:inline-block rb:group rb:relative rb:rounded-lg', contentClassNames)}>
|
|
||||||
<Image src={getFileUrl(file)} alt={file.name} className="rb:w-full rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if (file.type.includes('video')) {
|
|
||||||
return (
|
|
||||||
<div key={key} className="rb:w-50">
|
|
||||||
<VideoPlayer src={getFileUrl(file)} />
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if (file.type.includes('audio')) {
|
|
||||||
return (
|
|
||||||
<div key={key} className="rb:w-50">
|
|
||||||
<AudioPlayer src={getFileUrl(file)} />
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
const documentType = (file.file_type || file.type)?.split('/') ?? []
|
|
||||||
return (
|
|
||||||
<Flex
|
|
||||||
key={key}
|
|
||||||
align="center"
|
|
||||||
gap={10}
|
|
||||||
className="rb:text-left rb:w-45 rb:text-[12px] rb:group rb:relative rb:rounded-lg rb-border rb:py-2! rb:px-2.5! rb:border rb:border-[#F6F6F6]"
|
|
||||||
onClick={() => onDownload(file)}
|
|
||||||
>
|
|
||||||
<div
|
|
||||||
className={clsx(
|
|
||||||
"rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf_disabled.svg')]",
|
|
||||||
getDocIcon(documentType)
|
|
||||||
)}
|
|
||||||
/>
|
|
||||||
<div className="rb:flex-1 rb:w-32.5">
|
|
||||||
<div className="rb:leading-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{file.name}</div>
|
|
||||||
<div className="rb:leading-3.5 rb:mt-0.5 rb:text-[#5B6167] rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">
|
|
||||||
{documentType?.[documentType.length - 1]} · {file.size}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</Flex>
|
|
||||||
)
|
|
||||||
})}
|
|
||||||
</Flex>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
export default MessageFiles
|
|
||||||
@@ -3,14 +3,14 @@ import { Popover, type PopoverProps } from 'antd'
|
|||||||
import Tag, { type TagProps } from '@/components/Tag'
|
import Tag, { type TagProps } from '@/components/Tag'
|
||||||
|
|
||||||
interface OverflowTagsProps {
|
interface OverflowTagsProps {
|
||||||
items?: ReactNode[];
|
items: ReactNode[];
|
||||||
gap?: number;
|
gap?: number;
|
||||||
numTagColor?: TagProps['color'];
|
numTagColor?: TagProps['color'];
|
||||||
numTag?: (num?: number) => ReactNode;
|
numTag?: (num?: number) => ReactNode;
|
||||||
popoverProps?: PopoverProps | false;
|
popoverProps?: PopoverProps | false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const OverflowTags = ({ items = [], gap = 8, numTagColor = 'default', numTag, popoverProps }: OverflowTagsProps) => {
|
const OverflowTags = ({ items, gap = 8, numTagColor = 'default', numTag, popoverProps }: OverflowTagsProps) => {
|
||||||
const containerRef = useRef<HTMLDivElement>(null)
|
const containerRef = useRef<HTMLDivElement>(null)
|
||||||
const measureRef = useRef<HTMLDivElement>(null)
|
const measureRef = useRef<HTMLDivElement>(null)
|
||||||
const [visibleCount, setVisibleCount] = useState(items.length)
|
const [visibleCount, setVisibleCount] = useState(items.length)
|
||||||
@@ -20,7 +20,7 @@ const OverflowTags = ({ items = [], gap = 8, numTagColor = 'default', numTag, po
|
|||||||
if (!measure || containerWidth === 0) return
|
if (!measure || containerWidth === 0) return
|
||||||
|
|
||||||
const children = Array.from(measure.children) as HTMLElement[]
|
const children = Array.from(measure.children) as HTMLElement[]
|
||||||
if (!children.length) { setVisibleCount(0); return }
|
if (!children.length) return
|
||||||
|
|
||||||
// last child is the sample +N tag
|
// last child is the sample +N tag
|
||||||
const extraTagWidth = (children[children.length - 1] as HTMLElement).offsetWidth
|
const extraTagWidth = (children[children.length - 1] as HTMLElement).offsetWidth
|
||||||
|
|||||||
@@ -399,7 +399,7 @@ const Menu: FC<{
|
|||||||
className="rb:overflow-y-auto rb:flex-1!"
|
className="rb:overflow-y-auto rb:flex-1!"
|
||||||
/>
|
/>
|
||||||
{/* Return to space button for superusers */}
|
{/* Return to space button for superusers */}
|
||||||
{source === 'space' &&
|
{user?.is_superuser && source === 'space' &&
|
||||||
<Flex gap={4} vertical className="rb:my-3! rb:mx-3!">
|
<Flex gap={4} vertical className="rb:my-3! rb:mx-3!">
|
||||||
<Divider className="rb:mb-2.5! rb:mt-0! rb:border-[#DFE4ED]! rb:mx-2! rb:min-w-[calc(100%-20px)]! rb:w-[calc(100%-20px)]!" />
|
<Divider className="rb:mb-2.5! rb:mt-0! rb:border-[#DFE4ED]! rb:mx-2! rb:min-w-[calc(100%-20px)]! rb:w-[calc(100%-20px)]!" />
|
||||||
<Flex
|
<Flex
|
||||||
@@ -412,18 +412,16 @@ const Menu: FC<{
|
|||||||
<div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/switch.svg')]"></div>
|
<div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/switch.svg')]"></div>
|
||||||
{collapsed ? null : t('common.switchSpace')}
|
{collapsed ? null : t('common.switchSpace')}
|
||||||
</Flex>
|
</Flex>
|
||||||
{user?.is_superuser &&
|
<Flex
|
||||||
<Flex
|
gap={8}
|
||||||
gap={8}
|
align="center"
|
||||||
align="center"
|
justify="start"
|
||||||
justify="start"
|
onClick={goToSpace}
|
||||||
onClick={goToSpace}
|
className="rb:p-2.5! rb:text-[13px] rb:hover:bg-[rgba(223,228,237,0.5)] rb:rounded-lg rb:leading-3.5 rb:font-regular rb:text-center rb:cursor-pointer"
|
||||||
className="rb:p-2.5! rb:text-[13px] rb:hover:bg-[rgba(223,228,237,0.5)] rb:rounded-lg rb:leading-3.5 rb:font-regular rb:text-center rb:cursor-pointer"
|
>
|
||||||
>
|
<div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/return.svg')]"></div>
|
||||||
<div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/return.svg')]"></div>
|
{collapsed ? null : t('common.returnToSpace')}
|
||||||
{collapsed ? null : t('common.returnToSpace')}
|
</Flex>
|
||||||
</Flex>
|
|
||||||
}
|
|
||||||
</Flex>
|
</Flex>
|
||||||
}
|
}
|
||||||
{source === 'manage' && subscription && !collapsed &&
|
{source === 'manage' && subscription && !collapsed &&
|
||||||
|
|||||||
@@ -1538,7 +1538,6 @@ export const en = {
|
|||||||
json_output: 'Support JSON formatted output',
|
json_output: 'Support JSON formatted output',
|
||||||
thinking_budget_tokens: 'thinking budget tokens',
|
thinking_budget_tokens: 'thinking budget tokens',
|
||||||
thinking_budget_tokens_max_error: "Cannot exceed the max tokens limit ({{max}})",
|
thinking_budget_tokens_max_error: "Cannot exceed the max tokens limit ({{max}})",
|
||||||
thinking_budget_tokens_min_error: "Cannot be less than {{min}}",
|
|
||||||
logSearchPlaceholder: 'Search log content',
|
logSearchPlaceholder: 'Search log content',
|
||||||
},
|
},
|
||||||
userMemory: {
|
userMemory: {
|
||||||
|
|||||||
@@ -868,7 +868,6 @@ export const zh = {
|
|||||||
json_output: '支持JSON格式化输出',
|
json_output: '支持JSON格式化输出',
|
||||||
thinking_budget_tokens: '深度思考预算Token数',
|
thinking_budget_tokens: '深度思考预算Token数',
|
||||||
thinking_budget_tokens_max_error: "不能超过 最大令牌数 ({{max}})",
|
thinking_budget_tokens_max_error: "不能超过 最大令牌数 ({{max}})",
|
||||||
thinking_budget_tokens_min_error: "不能小于 {{min}}",
|
|
||||||
logSearchPlaceholder: '搜索日志内容',
|
logSearchPlaceholder: '搜索日志内容',
|
||||||
},
|
},
|
||||||
table: {
|
table: {
|
||||||
|
|||||||
176
web/src/vendor/x6-html-shape/index.js
vendored
Normal file
176
web/src/vendor/x6-html-shape/index.js
vendored
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
// Patched x6-html-shape: replaces View.createElement (removed in X6 3.x) with document.createElement
|
||||||
|
import { Node as p, NodeView as l, Graph as C, Dom as s } from "@antv/x6";
|
||||||
|
import { getConfig as w, clickable as x, isInputElement as y, forwardEvent as S } from "./utils.js";
|
||||||
|
|
||||||
|
const u = "html-shape", h = "html-shape-view", T = p.define(w(h)), m = {};
|
||||||
|
|
||||||
|
export function register(i) {
|
||||||
|
const { shape: e, render: n, inherit: t = u, ...o } = i;
|
||||||
|
if (!e) throw new Error("should specify shape in config");
|
||||||
|
m[e] = n;
|
||||||
|
C.registerNode(e, { inherit: t, ...o }, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
const a = "html";
|
||||||
|
|
||||||
|
// Determine which HTML layer a node belongs to.
|
||||||
|
// Parent (loop/iteration) nodes go behind the SVG layer so edges render above them.
|
||||||
|
// All other nodes go in front of the SVG layer so they render above edges.
|
||||||
|
function isBackNode(cell) {
|
||||||
|
const type = cell.getData?.()?.type;
|
||||||
|
return type === 'loop' || type === 'iteration';
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the two HTML container layers exist and are correctly positioned.
|
||||||
|
function ensureHtmlLayers(graph) {
|
||||||
|
if (!graph._htmlBack) {
|
||||||
|
const back = graph._htmlBack = document.createElement('div');
|
||||||
|
s.css(back, {
|
||||||
|
position: 'absolute', width: '100%', height: '100%',
|
||||||
|
'touch-action': 'none', 'user-select': 'none', 'pointer-events': 'none',
|
||||||
|
'z-index': 0, 'transform-origin': 'left top',
|
||||||
|
});
|
||||||
|
back.classList.add('x6-html-shape-container', 'x6-html-shape-back');
|
||||||
|
const svg = graph.container.querySelector('svg');
|
||||||
|
// back layer: before SVG → visually behind edges
|
||||||
|
graph.container.insertBefore(back, svg || null);
|
||||||
|
}
|
||||||
|
if (!graph._htmlFront) {
|
||||||
|
const front = graph._htmlFront = document.createElement('div');
|
||||||
|
s.css(front, {
|
||||||
|
position: 'absolute', width: '100%', height: '100%',
|
||||||
|
'touch-action': 'none', 'user-select': 'none', 'pointer-events': 'none',
|
||||||
|
'z-index': 0, 'transform-origin': 'left top',
|
||||||
|
});
|
||||||
|
front.classList.add('x6-html-shape-container', 'x6-html-shape-front');
|
||||||
|
// front layer: after SVG → visually above edges
|
||||||
|
graph.container.append(front);
|
||||||
|
}
|
||||||
|
// Keep legacy alias so updateHtmlContainerSize can iterate both
|
||||||
|
graph.htmlContainers = [graph._htmlBack, graph._htmlFront];
|
||||||
|
}
|
||||||
|
|
||||||
|
class BaseHTMLShapeView extends l {
|
||||||
|
confirmUpdate(e) {
|
||||||
|
const n = super.confirmUpdate(e);
|
||||||
|
return this.handleAction(n, a, () => {
|
||||||
|
if (!this.mounted) {
|
||||||
|
const t = m[this.cell.shape], o = this.ensureComponentContainer();
|
||||||
|
t && o && (this.mounted = t(this.cell, this.graph, o) || true,
|
||||||
|
this.onMounted(),
|
||||||
|
o.addEventListener("mousedown", this.prevEvent, true),
|
||||||
|
o.addEventListener("mouseup", this.prevEvent, true));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
prevEvent(e) {
|
||||||
|
(x(e.target) || y(e.target)) && (e.preventDefault(), e.stopPropagation());
|
||||||
|
}
|
||||||
|
ensureComponentContainer() {}
|
||||||
|
onMounted() {}
|
||||||
|
onUnMount() {
|
||||||
|
if (this.onZIndexChange) {
|
||||||
|
this.cell.off("change:zIndex", this.onZIndexChange);
|
||||||
|
}
|
||||||
|
if (this.onNodeMoving) {
|
||||||
|
this.graph.off("node:moving", this.onNodeMoving);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
unmount() {
|
||||||
|
typeof this.mounted == "function" && this.mounted();
|
||||||
|
this.componentContainer && this.componentContainer.remove();
|
||||||
|
this.onUnMount();
|
||||||
|
return super.unmount(), this;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
BaseHTMLShapeView.config({ bootstrap: [a], actions: { component: a } });
|
||||||
|
|
||||||
|
class HTMLShapeView extends BaseHTMLShapeView {
|
||||||
|
constructor(...e) {
|
||||||
|
super(...e);
|
||||||
|
this.cell.on("change:visible", ({ cell: n }) => {
|
||||||
|
if (n.view === h) {
|
||||||
|
const t = this.graph.findViewByCell(n.id);
|
||||||
|
t && Promise.resolve().then(() => {
|
||||||
|
t.componentContainer.style.display = t.container.style.display;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
onMounted() {
|
||||||
|
const listeners = this.graph.listeners;
|
||||||
|
// Always register per-cell zIndex listener regardless of shared transform events
|
||||||
|
this.onZIndexChange = () => this.updateContainerStyle();
|
||||||
|
this.cell.on("change:zIndex", this.onZIndexChange);
|
||||||
|
if (listeners?.hasTransformEvent?.length) return;
|
||||||
|
this.onTranslate = this.updateHtmlContainerSize.bind(this);
|
||||||
|
this.graph.on("translate", this.onTranslate);
|
||||||
|
this.graph.on("scale", this.onTranslate);
|
||||||
|
this.graph.on("node:change:position", this.onTranslate);
|
||||||
|
this.graph.on("hasTransformEvent", this.onTranslate);
|
||||||
|
// While dragging, lift this node's componentContainer to the top of its
|
||||||
|
// layer so its ports are never obscured by a sibling node underneath.
|
||||||
|
this.onNodeMoving = ({ node }) => {
|
||||||
|
if (node === this.cell && this.componentContainer) {
|
||||||
|
const layer = isBackNode(this.cell) ? this.graph._htmlBack : this.graph._htmlFront;
|
||||||
|
layer.append(this.componentContainer);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
this.graph.on("node:moving", this.onNodeMoving);
|
||||||
|
this.updateHtmlContainerSize();
|
||||||
|
}
|
||||||
|
ensureComponentContainer() {
|
||||||
|
ensureHtmlLayers(this.graph);
|
||||||
|
const layer = isBackNode(this.cell) ? this.graph._htmlBack : this.graph._htmlFront;
|
||||||
|
if (!this.componentContainer) {
|
||||||
|
const e = this.componentContainer = document.createElement("div");
|
||||||
|
s.css(e, {
|
||||||
|
"pointer-events": "auto", "touch-action": "none", "user-select": "none",
|
||||||
|
"transform-origin": "center", position: "absolute"
|
||||||
|
});
|
||||||
|
e.classList.add("x6-html-shape-node");
|
||||||
|
"click,dblclick,contextmenu,mousedown,mousemove,mouseup,mouseover,mouseout,mouseenter,mouseleave"
|
||||||
|
.split(",").forEach(t => S(t, e, this.container));
|
||||||
|
layer.append(e);
|
||||||
|
}
|
||||||
|
return this.componentContainer;
|
||||||
|
}
|
||||||
|
resize() { super.resize(); this.updateContainerStyle(); }
|
||||||
|
updateTransform() { super.updateTransform(); this.updateContainerStyle(); }
|
||||||
|
updateContainerStyle() {
|
||||||
|
const e = this.ensureComponentContainer();
|
||||||
|
const { x: n, y: t } = this.cell.getBBox();
|
||||||
|
const { width: o, height: r } = this.cell.getSize();
|
||||||
|
const g = getComputedStyle(this.container).cursor;
|
||||||
|
const f = this.cell.getZIndex() ?? 0;
|
||||||
|
// Shrink the interactive width by the port hover radius (6px) so the right
|
||||||
|
// port circle is fully outside the componentContainer and never blocked by it.
|
||||||
|
// overflow:visible keeps the visual rendering intact.
|
||||||
|
const PORT_RADIUS = 6;
|
||||||
|
s.css(e, {
|
||||||
|
cursor: g, height: r + "px", width: (o - PORT_RADIUS) + "px",
|
||||||
|
overflow: "visible",
|
||||||
|
"z-index": f,
|
||||||
|
transform: `translate(${n}px, ${t}px) rotate(${this.cell.getAngle()}deg)`
|
||||||
|
});
|
||||||
|
}
|
||||||
|
updateHtmlContainerSize() {
|
||||||
|
const { graph: e } = this;
|
||||||
|
const t = e.transform.getMatrix();
|
||||||
|
const { offsetHeight: o, offsetWidth: r } = e.container;
|
||||||
|
const n = e.transform.getZoom();
|
||||||
|
const style = {
|
||||||
|
transform: `matrix(${t.a}, ${t.b}, ${t.c}, ${t.d}, ${t.e}, ${t.f})`,
|
||||||
|
width: r / n + "px",
|
||||||
|
height: o / n + "px",
|
||||||
|
};
|
||||||
|
// Update both layers
|
||||||
|
(e.htmlContainers || [e._htmlBack, e._htmlFront].filter(Boolean)).forEach(c => s.css(c, style));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
l.registry.register(h, HTMLShapeView, true);
|
||||||
|
p.registry.register(u, T, true);
|
||||||
|
|
||||||
|
export { BaseHTMLShapeView, T as HTMLShape, u as HTMLShapeName, HTMLShapeView, h as HTMLView, a as action };
|
||||||
1
web/src/vendor/x6-html-shape/react.js
vendored
Normal file
1
web/src/vendor/x6-html-shape/react.js
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
export { default } from "x6-html-shape/dist/react.js";
|
||||||
98
web/src/vendor/x6-html-shape/utils.js
vendored
Normal file
98
web/src/vendor/x6-html-shape/utils.js
vendored
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
import { Dom as u, ObjectExt as l, Markup as c } from "@antv/x6";
|
||||||
|
const o = "fo-shape-view";
|
||||||
|
function p(t, e, r) {
|
||||||
|
e.addEventListener(t, function(n) {
|
||||||
|
r.dispatchEvent(new n.constructor(n.type, n)), n.preventDefault(), n.stopPropagation();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
function s(t, e = 3) {
|
||||||
|
return !t || !u.isHTMLElement(t) || e <= 0 ? !1 : ["a", "button"].includes(u.tagName(t)) || t.getAttribute("role") === "button" || t.getAttribute("type") === "button" ? !0 : s(t.parentNode, e - 1);
|
||||||
|
}
|
||||||
|
function g(t) {
|
||||||
|
if (u.tagName(t) === "input") {
|
||||||
|
const r = t.getAttribute("type");
|
||||||
|
if (r == null || ["text", "password", "number", "email", "search", "tel", "url"].includes(
|
||||||
|
r
|
||||||
|
))
|
||||||
|
return !0;
|
||||||
|
}
|
||||||
|
return !1;
|
||||||
|
}
|
||||||
|
function f(t = "rect", e = !0) {
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
tagName: t,
|
||||||
|
selector: "body"
|
||||||
|
},
|
||||||
|
e ? c.getForeignObjectMarkup() : null,
|
||||||
|
{
|
||||||
|
tagName: "text",
|
||||||
|
selector: "label"
|
||||||
|
}
|
||||||
|
].filter((r) => r);
|
||||||
|
}
|
||||||
|
function b(t) {
|
||||||
|
return {
|
||||||
|
view: t,
|
||||||
|
markup: f("rect", t === o),
|
||||||
|
attrs: {
|
||||||
|
body: {
|
||||||
|
// fill: "none",
|
||||||
|
// 这里很奇怪,none的时候不能触发节点移动,改成transparent可以触发
|
||||||
|
fill: "transparent",
|
||||||
|
stroke: "none",
|
||||||
|
refWidth: "100%",
|
||||||
|
refHeight: "100%"
|
||||||
|
},
|
||||||
|
label: {
|
||||||
|
fontSize: 14,
|
||||||
|
fill: "#333",
|
||||||
|
refX: "50%",
|
||||||
|
refY: "50%",
|
||||||
|
textAnchor: "middle",
|
||||||
|
textVerticalAnchor: "middle"
|
||||||
|
},
|
||||||
|
fo: {
|
||||||
|
refWidth: "100%",
|
||||||
|
refHeight: "100%"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
propHooks(e) {
|
||||||
|
if (e.markup == null) {
|
||||||
|
const { primer: r, view: n } = e;
|
||||||
|
if (r && r !== "rect") {
|
||||||
|
e.markup = f(r, n === o);
|
||||||
|
let i = {};
|
||||||
|
r === "circle" ? i = {
|
||||||
|
refCx: "50%",
|
||||||
|
refCy: "50%",
|
||||||
|
refR: "50%"
|
||||||
|
} : r === "ellipse" && (i = {
|
||||||
|
refCx: "50%",
|
||||||
|
refCy: "50%",
|
||||||
|
refRx: "50%",
|
||||||
|
refRy: "50%"
|
||||||
|
}), e.attrs = l.merge(
|
||||||
|
{},
|
||||||
|
{
|
||||||
|
body: {
|
||||||
|
refWidth: null,
|
||||||
|
refHeight: null,
|
||||||
|
...i
|
||||||
|
}
|
||||||
|
},
|
||||||
|
e.attrs || {}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return e;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
export {
|
||||||
|
o as FOView,
|
||||||
|
s as clickable,
|
||||||
|
p as forwardEvent,
|
||||||
|
b as getConfig,
|
||||||
|
g as isInputElement
|
||||||
|
};
|
||||||
@@ -7,7 +7,7 @@
|
|||||||
import { type FC, useRef } from 'react';
|
import { type FC, useRef } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useParams } from 'react-router-dom';
|
import { useParams } from 'react-router-dom';
|
||||||
import { Flex, Button, Form } from 'antd';
|
import { Flex, Button } from 'antd';
|
||||||
import type { ColumnsType } from 'antd/es/table';
|
import type { ColumnsType } from 'antd/es/table';
|
||||||
|
|
||||||
import { getAppLogsUrl } from '@/api/application';
|
import { getAppLogsUrl } from '@/api/application';
|
||||||
@@ -15,14 +15,11 @@ import Table from '@/components/Table'
|
|||||||
import { formatDateTime } from '@/utils/format';
|
import { formatDateTime } from '@/utils/format';
|
||||||
import type { LogItem, LogDetailModalRef } from './types'
|
import type { LogItem, LogDetailModalRef } from './types'
|
||||||
import LogDetailModal from './components/LogDetailModal'
|
import LogDetailModal from './components/LogDetailModal'
|
||||||
import SearchInput from '@/components/SearchInput'
|
|
||||||
|
|
||||||
const Statistics: FC = () => {
|
const Statistics: FC = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { id } = useParams();
|
const { id } = useParams();
|
||||||
const logDetailRef = useRef<LogDetailModalRef>(null);
|
const logDetailRef = useRef<LogDetailModalRef>(null);
|
||||||
const [form] = Form.useForm();
|
|
||||||
const values = Form.useWatch([], form);
|
|
||||||
|
|
||||||
const handleViewDetail = (item: LogItem) => {
|
const handleViewDetail = (item: LogItem) => {
|
||||||
logDetailRef.current?.handleOpen(item);
|
logDetailRef.current?.handleOpen(item);
|
||||||
@@ -65,26 +62,15 @@ const Statistics: FC = () => {
|
|||||||
];
|
];
|
||||||
return (
|
return (
|
||||||
<div className="rb:bg-white rb:rounded-lg rb:pt-3 rb:px-3">
|
<div className="rb:bg-white rb:rounded-lg rb:pt-3 rb:px-3">
|
||||||
<Flex justify="flex-end" className="rb:mb-3!">
|
|
||||||
<Form form={form}>
|
|
||||||
<Form.Item name="keyword" noStyle>
|
|
||||||
<SearchInput
|
|
||||||
placeholder={t('application.logSearchPlaceholder')}
|
|
||||||
variant="outlined"
|
|
||||||
/>
|
|
||||||
</Form.Item>
|
|
||||||
</Form>
|
|
||||||
</Flex>
|
|
||||||
<Table<LogItem>
|
<Table<LogItem>
|
||||||
apiUrl={getAppLogsUrl(id || '')}
|
apiUrl={getAppLogsUrl(id || '')}
|
||||||
apiParams={{
|
apiParams={{
|
||||||
is_draft: false,
|
is_draft: false,
|
||||||
...(values ?? {})
|
|
||||||
}}
|
}}
|
||||||
columns={columns}
|
columns={columns}
|
||||||
rowKey="id"
|
rowKey="id"
|
||||||
isScroll={true}
|
isScroll={true}
|
||||||
scrollY="calc(100vh - 242px)"
|
scrollY="calc(100vh - 214px)"
|
||||||
/>
|
/>
|
||||||
<LogDetailModal ref={logDetailRef} />
|
<LogDetailModal ref={logDetailRef} />
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-03-13 17:27:52
|
* @Date: 2026-03-13 17:27:52
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-04-24 18:14:25
|
* @Last Modified time: 2026-04-07 21:48:30
|
||||||
*/
|
*/
|
||||||
import { type FC, useState, useRef, useEffect } from 'react'
|
import { type FC, useState, useRef, useEffect } from 'react'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
@@ -59,7 +59,6 @@ interface NodeData {
|
|||||||
node_type?: string;
|
node_type?: string;
|
||||||
input?: any;
|
input?: any;
|
||||||
output?: any;
|
output?: any;
|
||||||
process?: any;
|
|
||||||
elapsed_time?: string;
|
elapsed_time?: string;
|
||||||
error?: any;
|
error?: any;
|
||||||
state: Record<string, any>;
|
state: Record<string, any>;
|
||||||
@@ -486,7 +485,7 @@ const TestChat: FC<TestChatProps> = ({
|
|||||||
}
|
}
|
||||||
|
|
||||||
const updateWorkflowNodeEndMessage = (data: NodeData) => {
|
const updateWorkflowNodeEndMessage = (data: NodeData) => {
|
||||||
const { node_id, input, output, process, error, elapsed_time, status } = data;
|
const { node_id, input, output, error, elapsed_time, status } = data;
|
||||||
setChatList(prev => {
|
setChatList(prev => {
|
||||||
const newList = [...prev]
|
const newList = [...prev]
|
||||||
const lastIndex = newList.length - 1
|
const lastIndex = newList.length - 1
|
||||||
@@ -499,7 +498,6 @@ const TestChat: FC<TestChatProps> = ({
|
|||||||
content: {
|
content: {
|
||||||
input,
|
input,
|
||||||
output,
|
output,
|
||||||
process,
|
|
||||||
error,
|
error,
|
||||||
},
|
},
|
||||||
status: status || 'completed',
|
status: status || 'completed',
|
||||||
@@ -516,7 +514,7 @@ const TestChat: FC<TestChatProps> = ({
|
|||||||
}
|
}
|
||||||
|
|
||||||
const updateWorkflowCycleMessage = (data: NodeData) => {
|
const updateWorkflowCycleMessage = (data: NodeData) => {
|
||||||
const { node_id, cycle_id, cycle_idx, input, output, process, error, elapsed_time, status } = data;
|
const { node_id, cycle_id, cycle_idx, input, output, error, elapsed_time, status } = data;
|
||||||
const { nodes } = config as WorkflowConfig
|
const { nodes } = config as WorkflowConfig
|
||||||
const node = nodes.find(n => n.id === node_id);
|
const node = nodes.find(n => n.id === node_id);
|
||||||
const { name, type } = node || {}
|
const { name, type } = node || {}
|
||||||
@@ -540,7 +538,6 @@ const TestChat: FC<TestChatProps> = ({
|
|||||||
cycle_idx,
|
cycle_idx,
|
||||||
input,
|
input,
|
||||||
output,
|
output,
|
||||||
process,
|
|
||||||
error,
|
error,
|
||||||
},
|
},
|
||||||
status: status || 'completed',
|
status: status || 'completed',
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
/*
|
/*
|
||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-03-24 16:31:24
|
* @Date: 2026-03-24 16:31:24
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-04-24 17:49:58
|
* @Last Modified time: 2026-03-24 16:31:24
|
||||||
*/
|
*/
|
||||||
import { forwardRef, useImperativeHandle, useState, useEffect } from 'react';
|
import { forwardRef, useImperativeHandle, useState, useEffect } from 'react';
|
||||||
import { Flex, Button, Empty, Skeleton } from 'antd';
|
import { Flex, Button, Empty, Skeleton } from 'antd';
|
||||||
@@ -14,12 +14,6 @@ import { getAppLogDetail } from '@/api/application'
|
|||||||
import ChatContent from '@/components/Chat/ChatContent'
|
import ChatContent from '@/components/Chat/ChatContent'
|
||||||
import { formatDateTime } from '@/utils/format'
|
import { formatDateTime } from '@/utils/format'
|
||||||
import type { ChatItem } from '@/components/Chat/types'
|
import type { ChatItem } from '@/components/Chat/types'
|
||||||
import Runtime from '@/views/Workflow/components/Chat/Runtime'
|
|
||||||
import { nodeLibrary } from '@/views/Workflow/constant'
|
|
||||||
|
|
||||||
const nodeIconMap = Object.fromEntries(
|
|
||||||
nodeLibrary.flatMap(c => c.nodes.map(n => [n.type, n.icon]))
|
|
||||||
)
|
|
||||||
|
|
||||||
/** Log detail data with conversation messages */
|
/** Log detail data with conversation messages */
|
||||||
type Data = LogItem & {
|
type Data = LogItem & {
|
||||||
@@ -60,30 +54,7 @@ const LogDetailModal = forwardRef<LogDetailModalRef>((_props, ref) => {
|
|||||||
if (!vo) return
|
if (!vo) return
|
||||||
setLoading(true)
|
setLoading(true)
|
||||||
getAppLogDetail(vo.app_id, vo.id).then(res => {
|
getAppLogDetail(vo.app_id, vo.id).then(res => {
|
||||||
const { node_executions_map, messages, ...rest } = res as Data;
|
setData(res as Data)
|
||||||
let hasSubContentMessages = messages
|
|
||||||
if (messages && messages.length > 0 && node_executions_map && Object.keys(node_executions_map).length > 0) {
|
|
||||||
hasSubContentMessages = messages.map(item => {
|
|
||||||
if (item.id && node_executions_map[item.id]) {
|
|
||||||
item.subContent = node_executions_map[item.id]?.map(({ input, output, cycle_items = [], error, process, ...node }: any) => {
|
|
||||||
const converted: any = { ...node, icon: nodeIconMap[node.node_type], content: { input, output, process, error } }
|
|
||||||
if (node.node_type === 'loop' && Array.isArray(cycle_items) && cycle_items.length > 0) {
|
|
||||||
converted.subContent = cycle_items.map(({ input: cInput, output: cOutput, error: cError, process: cProcess, ...cNode }: any) => ({
|
|
||||||
...cNode,
|
|
||||||
icon: nodeIconMap[cNode.node_type],
|
|
||||||
content: { input: cInput, output: cOutput, process: cProcess, error: cError }
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
return converted
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return { ...item }
|
|
||||||
})
|
|
||||||
}
|
|
||||||
setData({
|
|
||||||
...rest,
|
|
||||||
messages: hasSubContentMessages
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
.finally(() => {
|
.finally(() => {
|
||||||
setLoading(false)
|
setLoading(false)
|
||||||
@@ -95,8 +66,6 @@ const LogDetailModal = forwardRef<LogDetailModalRef>((_props, ref) => {
|
|||||||
handleClose
|
handleClose
|
||||||
}));
|
}));
|
||||||
|
|
||||||
console.log('data', data)
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<RbModal
|
<RbModal
|
||||||
title={<>
|
title={<>
|
||||||
@@ -123,7 +92,6 @@ const LogDetailModal = forwardRef<LogDetailModalRef>((_props, ref) => {
|
|||||||
data={data.messages || []}
|
data={data.messages || []}
|
||||||
streamLoading={false}
|
streamLoading={false}
|
||||||
labelFormat={(item) => formatDateTime(item.created_at)}
|
labelFormat={(item) => formatDateTime(item.created_at)}
|
||||||
renderRuntime={(item, index) => <Runtime item={item} index={index} />}
|
|
||||||
/>
|
/>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,8 +49,6 @@ const configFields = [
|
|||||||
{ key: 'n', max: 10, min: 1, step: 1, defaultValue: 1 },
|
{ key: 'n', max: 10, min: 1, step: 1, defaultValue: 1 },
|
||||||
]
|
]
|
||||||
|
|
||||||
const minThinkingBudgetTokens = 128;
|
|
||||||
const defaultThinkingBudgetTokens = 1000;
|
|
||||||
const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(({
|
const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(({
|
||||||
refresh,
|
refresh,
|
||||||
data,
|
data,
|
||||||
@@ -110,7 +108,7 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
|
|||||||
const newValues: ModelConfig = {
|
const newValues: ModelConfig = {
|
||||||
capability: (option as Model).capability,
|
capability: (option as Model).capability,
|
||||||
deep_thinking: false,
|
deep_thinking: false,
|
||||||
thinking_budget_tokens: defaultThinkingBudgetTokens,
|
thinking_budget_tokens: undefined,
|
||||||
json_output: false,
|
json_output: false,
|
||||||
}
|
}
|
||||||
if (source === 'chat') {
|
if (source === 'chat') {
|
||||||
@@ -130,12 +128,6 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
|
|||||||
form.setFieldsValue({ ...rest })
|
form.setFieldsValue({ ...rest })
|
||||||
}, [data?.default_model_config_id])
|
}, [data?.default_model_config_id])
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (values?.deep_thinking && !values?.thinking_budget_tokens) {
|
|
||||||
form.setFieldValue('thinking_budget_tokens', defaultThinkingBudgetTokens)
|
|
||||||
}
|
|
||||||
}, [values?.deep_thinking])
|
|
||||||
|
|
||||||
const handleReset = () => {
|
const handleReset = () => {
|
||||||
if (!id) return
|
if (!id) return
|
||||||
resetAppModelConfig(id).then((res) => {
|
resetAppModelConfig(id).then((res) => {
|
||||||
@@ -186,20 +178,15 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
|
|||||||
name="thinking_budget_tokens"
|
name="thinking_budget_tokens"
|
||||||
label={t('application.thinking_budget_tokens')}
|
label={t('application.thinking_budget_tokens')}
|
||||||
hidden={!['model', 'chat'].includes(source) || !(values?.deep_thinking || values?.capability?.includes('thinking'))}
|
hidden={!['model', 'chat'].includes(source) || !(values?.deep_thinking || values?.capability?.includes('thinking'))}
|
||||||
extra={<>{t('application.range')}: [{minThinkingBudgetTokens}, {t(`application.max_tokens`)}: {values?.max_tokens}]</>}
|
extra={<>{t('application.range')}: [{0}, {t(`application.max_tokens`)}: {values?.max_tokens}]</>}
|
||||||
rules={[
|
rules={[
|
||||||
{ required: values?.deep_thinking, message: t('common.pleaseEnter') },
|
{ required: values?.deep_thinking, message: t('common.pleaseEnter') },
|
||||||
{
|
{
|
||||||
validator: (_, value) => {
|
validator: (_, value) => {
|
||||||
const maxTokens = values?.max_tokens
|
const maxTokens = values?.max_tokens
|
||||||
const deep_thinking = values?.deep_thinking;
|
const deep_thinking = values?.deep_thinking;
|
||||||
if (deep_thinking && value !== undefined) {
|
if (deep_thinking && value !== undefined && maxTokens !== undefined && value > maxTokens) {
|
||||||
if (value < minThinkingBudgetTokens) {
|
return Promise.reject(t('application.thinking_budget_tokens_max_error', { max: maxTokens }))
|
||||||
return Promise.reject(t('application.thinking_budget_tokens_min_error', { min: minThinkingBudgetTokens }))
|
|
||||||
}
|
|
||||||
if (maxTokens !== undefined && value > maxTokens) {
|
|
||||||
return Promise.reject(t('application.thinking_budget_tokens_max_error', { max: maxTokens }))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return Promise.resolve()
|
return Promise.resolve()
|
||||||
}
|
}
|
||||||
@@ -208,7 +195,7 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
|
|||||||
>
|
>
|
||||||
<RbSlider
|
<RbSlider
|
||||||
step={1}
|
step={1}
|
||||||
min={minThinkingBudgetTokens}
|
min={0}
|
||||||
max={32000}
|
max={32000}
|
||||||
isInput={true}
|
isInput={true}
|
||||||
disabled={!values?.deep_thinking}
|
disabled={!values?.deep_thinking}
|
||||||
|
|||||||
@@ -166,10 +166,10 @@ const Ontology: FC = () => {
|
|||||||
<div className="rb:h-10 rb:wrap-break-word rb:line-clamp-2 rb:leading-5">{item.scene_description}</div>
|
<div className="rb:h-10 rb:wrap-break-word rb:line-clamp-2 rb:leading-5">{item.scene_description}</div>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
|
|
||||||
<div className="rb:mt-2 rb:h-5.5">
|
<div className="rb:mt-2">
|
||||||
<OverflowTags
|
<OverflowTags
|
||||||
popoverProps={false}
|
popoverProps={false}
|
||||||
items={item.entity_type ? [...item.entity_type.map((type, i) => <Tag key={i} variant="borderless" color="dark">{type}</Tag>), <Tag variant="borderless" color="dark">{`+${item.type_num - 3}`}</Tag>] : []}
|
items={[...item.entity_type?.map((type, i) => <Tag key={i} variant="borderless" color="dark">{type}</Tag>), <Tag variant="borderless" color="dark">{`+${item.type_num - 3}`}</Tag>]}
|
||||||
numTag={(num?: number) => <Tag variant="borderless" color="dark">{`+${item.type_num - 3 + (num ? num - 1 : 0)}`}</Tag>}
|
numTag={(num?: number) => <Tag variant="borderless" color="dark">{`+${item.type_num - 3 + (num ? num - 1 : 0)}`}</Tag>}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -101,7 +101,6 @@ const CustomToolModal = forwardRef<CustomToolModalRef, CustomToolModalProps>(({
|
|||||||
});
|
});
|
||||||
};
|
};
|
||||||
const formatSchema = (value: string) => {
|
const formatSchema = (value: string) => {
|
||||||
if (!value || value.trim() === '') return
|
|
||||||
setParseSchemaData({} as ParseSchemaData)
|
setParseSchemaData({} as ParseSchemaData)
|
||||||
parseSchema({ schema_content: value })
|
parseSchema({ schema_content: value })
|
||||||
.then(res => {
|
.then(res => {
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ const CanvasToolbar: FC<CanvasToolbarProps> = ({
|
|||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
labelRender={(props) => {
|
labelRender={(props) => {
|
||||||
|
console.log('props', props)
|
||||||
return `${props.value}%`
|
return `${props.value}%`
|
||||||
}}
|
}}
|
||||||
className="rb:w-20 rb:h-4!"
|
className="rb:w-20 rb:h-4!"
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-06 21:10:56
|
* @Date: 2026-02-06 21:10:56
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-04-24 18:13:22
|
* @Last Modified time: 2026-04-21 14:59:13
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* Workflow Chat Component
|
* Workflow Chat Component
|
||||||
@@ -66,6 +66,8 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
|
|||||||
const [fileList, setFileList] = useState<any[]>([])
|
const [fileList, setFileList] = useState<any[]>([])
|
||||||
const [message, setMessage] = useState<string | undefined>(undefined)
|
const [message, setMessage] = useState<string | undefined>(undefined)
|
||||||
|
|
||||||
|
console.log('abortRef', abortRef)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Opens the chat drawer and loads workflow variables from the start node
|
* Opens the chat drawer and loads workflow variables from the start node
|
||||||
*/
|
*/
|
||||||
@@ -183,7 +185,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
|
|||||||
*/
|
*/
|
||||||
const handleStreamMessage = (data: SSEMessage[]) => {
|
const handleStreamMessage = (data: SSEMessage[]) => {
|
||||||
data.forEach(item => {
|
data.forEach(item => {
|
||||||
const { content, conversation_id, node_id, cycle_id, cycle_idx, input, output, process, error, elapsed_time, status, citations } = item.data as {
|
const { content, conversation_id, node_id, cycle_id, cycle_idx, input, output, error, elapsed_time, status, citations } = item.data as {
|
||||||
content: string;
|
content: string;
|
||||||
conversation_id: string | null;
|
conversation_id: string | null;
|
||||||
cycle_id: string;
|
cycle_id: string;
|
||||||
@@ -191,7 +193,6 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
|
|||||||
node_id: string;
|
node_id: string;
|
||||||
node_name?: string;
|
node_name?: string;
|
||||||
node_type?: string;
|
node_type?: string;
|
||||||
process?: any;
|
|
||||||
input?: any;
|
input?: any;
|
||||||
output?: any;
|
output?: any;
|
||||||
elapsed_time?: string;
|
elapsed_time?: string;
|
||||||
@@ -276,7 +277,6 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
|
|||||||
content: {
|
content: {
|
||||||
input,
|
input,
|
||||||
output,
|
output,
|
||||||
process,
|
|
||||||
error,
|
error,
|
||||||
},
|
},
|
||||||
status: status || 'completed',
|
status: status || 'completed',
|
||||||
@@ -305,14 +305,13 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
|
|||||||
cycle_id,
|
cycle_id,
|
||||||
cycle_idx,
|
cycle_idx,
|
||||||
node_id,
|
node_id,
|
||||||
node_name: type === 'cycle-start' ? t('workflow.cycle-start') : name,
|
node_name: name,
|
||||||
node_type: type,
|
node_type: type,
|
||||||
icon,
|
icon,
|
||||||
content: {
|
content: {
|
||||||
cycle_idx,
|
cycle_idx,
|
||||||
input,
|
input,
|
||||||
output,
|
output,
|
||||||
process,
|
|
||||||
error,
|
error,
|
||||||
},
|
},
|
||||||
status: status || 'completed',
|
status: status || 'completed',
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-24 17:57:08
|
* @Date: 2026-02-24 17:57:08
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-04-24 18:04:31
|
* @Last Modified time: 2026-04-20 15:33:48
|
||||||
*/
|
*/
|
||||||
/*
|
/*
|
||||||
* Runtime Component
|
* Runtime Component
|
||||||
@@ -184,30 +184,27 @@ const Runtime: FC<{ item: ChatItem; index: number;}> = ({
|
|||||||
</Flex>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
{/* Display input and output data as JSON code blocks */}
|
{/* Display input and output data as JSON code blocks */}
|
||||||
{['input', 'process', 'output'].map(key => {
|
{['input', 'output'].map(key => (
|
||||||
if (vo.node_type !== 'http-request' && key === 'process') return null
|
<div key={key} className="rb:bg-[#EBEBEB] rb:rounded-lg">
|
||||||
return (
|
<div className="rb:py-2 rb:px-3 rb:flex rb:justify-between rb:items-center rb:text-[12px]">
|
||||||
<div key={key} className="rb:bg-[#EBEBEB] rb:rounded-lg">
|
{isLoop ? t(`workflow.runtime.${key}_cycle_vars`) : t(`workflow.${key}_result`)}
|
||||||
<div className="rb:py-2 rb:px-3 rb:flex rb:justify-between rb:items-center rb:text-[12px]">
|
<Button
|
||||||
{isLoop ? t(`workflow.runtime.${key}_cycle_vars`) : t(`workflow.${key}_result`)}
|
className="rb:py-0! rb:px-1! rb:text-[12px]!"
|
||||||
<Button
|
size="small"
|
||||||
className="rb:py-0! rb:px-1! rb:text-[12px]!"
|
onClick={() => handleCopy(typeof vo.content === 'object' && vo.content?.[key] ? JSON.stringify(vo.content[key], null, 2) : '{}')}
|
||||||
size="small"
|
>{t('common.copy')}</Button>
|
||||||
onClick={() => handleCopy(typeof vo.content === 'object' && vo.content?.[key] ? JSON.stringify(vo.content[key], null, 2) : '{}')}
|
|
||||||
>{t('common.copy')}</Button>
|
|
||||||
</div>
|
|
||||||
<div className="rb:max-h-40 rb:overflow-auto">
|
|
||||||
<CodeBlock
|
|
||||||
size="small"
|
|
||||||
value={typeof vo.content === 'object' && vo.content?.[key] ? JSON.stringify(vo.content[key], null, 2) : '{}'}
|
|
||||||
needCopy={false}
|
|
||||||
showLineNumbers={true}
|
|
||||||
background="#EBEBEB"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
)
|
<div className="rb:max-h-40 rb:overflow-auto">
|
||||||
})}
|
<CodeBlock
|
||||||
|
size="small"
|
||||||
|
value={typeof vo.content === 'object' && vo.content?.[key] ? JSON.stringify(vo.content[key], null, 2) : '{}'}
|
||||||
|
needCopy={false}
|
||||||
|
showLineNumbers={true}
|
||||||
|
background="#EBEBEB"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
</Flex>
|
</Flex>
|
||||||
)
|
)
|
||||||
}]}
|
}]}
|
||||||
|
|||||||
@@ -2,183 +2,29 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-09 18:31:30
|
* @Date: 2026-02-09 18:31:30
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-30 11:55:10
|
* @Last Modified time: 2026-04-28 10:24:58
|
||||||
*/
|
*/
|
||||||
import { useState } from 'react';
|
import { Flex } from 'antd';
|
||||||
import { Popover, Flex } from 'antd';
|
|
||||||
import clsx from 'clsx';
|
import clsx from 'clsx';
|
||||||
import type { ReactShapeConfig } from '@antv/x6-react-shape';
|
import type { ReactShapeConfig } from '@antv/x6-react-shape';
|
||||||
import { nodeLibrary, graphNodeLibrary, edgeAttrs, nodeWidth } from '../../constant';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
|
|
||||||
const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
const AddNode: ReactShapeConfig['component'] = ({ node }) => {
|
||||||
const data = node?.getData() || {};
|
const data = node?.getData() || {};
|
||||||
const { t } = useTranslation();
|
|
||||||
const [open, setOpen] = useState(false);
|
|
||||||
|
|
||||||
// Handle node selection from popover and create new node replacing the add-node placeholder
|
|
||||||
const handleNodeSelect = (selectedNodeType: any) => {
|
|
||||||
graph.startBatch('add-node');
|
|
||||||
const parentBBox = node.getBBox();
|
|
||||||
const cycleId = data.cycle;
|
|
||||||
const horizontalSpacing = 0;
|
|
||||||
|
|
||||||
const id = `${selectedNodeType.type.replace(/-/g, '_') }_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
|
|
||||||
const newNode = graph.addNode({
|
|
||||||
...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default),
|
|
||||||
x: parentBBox.x + horizontalSpacing,
|
|
||||||
y: parentBBox.y - 12,
|
|
||||||
id,
|
|
||||||
data: {
|
|
||||||
id,
|
|
||||||
type: selectedNodeType.type,
|
|
||||||
icon: selectedNodeType.icon,
|
|
||||||
name: t(`workflow.${selectedNodeType.type}`),
|
|
||||||
cycle: cycleId,
|
|
||||||
parentId: data.parentId,
|
|
||||||
config: selectedNodeType.config || {}
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Add new node as child of parent node
|
|
||||||
if (cycleId) {
|
|
||||||
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
|
|
||||||
if (parentNode) {
|
|
||||||
parentNode.addChild(newNode, { silent: true });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const incomingEdges = graph.getIncomingEdges(node);
|
|
||||||
const outgoingEdges = graph.getOutgoingEdges(node);
|
|
||||||
const addedEdges: any[] = [];
|
|
||||||
|
|
||||||
incomingEdges?.forEach((edge: any) => {
|
|
||||||
addedEdges.push(graph.addEdge({
|
|
||||||
source: { cell: edge.getSourceCellId(), port: edge.getSourcePortId() },
|
|
||||||
target: { cell: newNode.id, port: newNode.getPorts().find((port: any) => port.group === 'left')?.id || 'left' },
|
|
||||||
...edgeAttrs
|
|
||||||
}));
|
|
||||||
});
|
|
||||||
|
|
||||||
outgoingEdges?.forEach((edge: any) => {
|
|
||||||
const targetCell = graph.getCellById(edge.getTargetCellId()) as any;
|
|
||||||
const targetPortId = targetCell?.getPorts?.()?.find((port: any) => port.group === 'left')?.id || edge.getTargetPortId();
|
|
||||||
addedEdges.push(graph.addEdge({
|
|
||||||
source: { cell: newNode.id, port: newNode.getPorts().find((port: any) => port.group === 'right')?.id || 'right' },
|
|
||||||
target: { cell: edge.getTargetCellId(), port: targetPortId },
|
|
||||||
...edgeAttrs
|
|
||||||
}));
|
|
||||||
});
|
|
||||||
|
|
||||||
// Remove all add-node type nodes
|
|
||||||
graph.getNodes().forEach((n: any) => {
|
|
||||||
if (n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId) {
|
|
||||||
n.remove();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Automatically adjust loop node size
|
|
||||||
const loopNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
|
|
||||||
if (loopNode) {
|
|
||||||
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
|
|
||||||
if (childNodes.length > 0) {
|
|
||||||
const bounds = childNodes.reduce((acc, child) => {
|
|
||||||
const bbox = child.getBBox();
|
|
||||||
return {
|
|
||||||
minX: Math.min(acc.minX, bbox.x),
|
|
||||||
minY: Math.min(acc.minY, bbox.y),
|
|
||||||
maxX: Math.max(acc.maxX, bbox.x + bbox.width),
|
|
||||||
maxY: Math.max(acc.maxY, bbox.y + bbox.height)
|
|
||||||
};
|
|
||||||
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
|
|
||||||
const padding = 50;
|
|
||||||
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
|
|
||||||
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
|
|
||||||
loopNode.prop('size', { width: newWidth, height: newHeight });
|
|
||||||
loopNode.getPorts().forEach(port => {
|
|
||||||
if (port.group === 'right' && port.args) {
|
|
||||||
loopNode.portProp(port.id!, 'args/x', newWidth);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
addedEdges.forEach(e => {
|
|
||||||
const src = graph.getCellById(e.getSourceCellId());
|
|
||||||
const tgt = graph.getCellById(e.getTargetCellId());
|
|
||||||
if (src?.isNode()) src.toFront();
|
|
||||||
if (tgt?.isNode()) tgt.toFront();
|
|
||||||
});
|
|
||||||
|
|
||||||
graph.stopBatch('add-node');
|
|
||||||
setOpen(false);
|
|
||||||
};
|
|
||||||
|
|
||||||
const content = (
|
|
||||||
<div style={{ maxHeight: '300px', overflowY: 'auto', minWidth: `${nodeWidth}px'` }}>
|
|
||||||
{nodeLibrary.map((category, categoryIndex) => {
|
|
||||||
const filteredNodes = category.nodes.filter(nodeType =>
|
|
||||||
nodeType.type !== 'start' && nodeType.type !== 'end' && nodeType.type !== 'iteration' && nodeType.type !== 'loop' && nodeType.type !== 'cycle-start'
|
|
||||||
);
|
|
||||||
|
|
||||||
if (filteredNodes.length === 0) return null;
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div key={category.category}>
|
|
||||||
{categoryIndex > 0 && <div style={{ height: '1px', background: '#f0f0f0', margin: '4px 0' }} />}
|
|
||||||
<div style={{ padding: '4px 12px', fontSize: '12px', color: '#999', fontWeight: 'bold' }}>
|
|
||||||
{t(`workflow.${category.category}`)}
|
|
||||||
</div>
|
|
||||||
{filteredNodes.map((nodeType) => (
|
|
||||||
<div
|
|
||||||
key={nodeType.type}
|
|
||||||
style={{
|
|
||||||
padding: '8px 12px',
|
|
||||||
cursor: 'pointer',
|
|
||||||
display: 'flex',
|
|
||||||
alignItems: 'center',
|
|
||||||
gap: '8px',
|
|
||||||
}}
|
|
||||||
onClick={() => handleNodeSelect(nodeType)}
|
|
||||||
onMouseEnter={(e) => {
|
|
||||||
e.currentTarget.style.background = '#f0f8ff';
|
|
||||||
}}
|
|
||||||
onMouseLeave={(e) => {
|
|
||||||
e.currentTarget.style.background = 'white';
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<div className={`rb:size-4 rb:bg-cover ${nodeType.icon}`} />
|
|
||||||
<span style={{ fontSize: '14px' }}>{t(`workflow.${nodeType.type}`)}</span>
|
|
||||||
</div>
|
|
||||||
))}
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
})}
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Popover
|
<Flex
|
||||||
content={content}
|
align="center"
|
||||||
trigger="click"
|
justify="center"
|
||||||
open={open}
|
gap={4}
|
||||||
onOpenChange={setOpen}
|
className={clsx('rb:text-[#212332] rb:font-medium rb:text-[12px] rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:border rb:rounded-lg rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)] rb:border-[#FCFCFD] rb:flex rb:items-center rb:justify-center', {
|
||||||
placement="bottomLeft"
|
'rb:border-orange-500 rb:border-[3px] rb:bg-[#FCFCFD] rb:text-[#475467]': data.isSelected,
|
||||||
|
'rb:border-[#d1d5db] rb:bg-[#FCFCFD] rb:text-[#374151]': !data.isSelected
|
||||||
|
})}
|
||||||
>
|
>
|
||||||
<Flex
|
<div className="rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/workflow/node_plus.png')]"></div>
|
||||||
align="center"
|
{data.label}
|
||||||
justify="center"
|
</Flex>
|
||||||
gap={4}
|
|
||||||
className={clsx('rb:text-[#212332] rb:font-medium rb:text-[12px] rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:border rb:rounded-lg rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)] rb:border-[#FCFCFD] rb:flex rb:items-center rb:justify-center', {
|
|
||||||
'rb:border-orange-500 rb:border-[3px] rb:bg-[#FCFCFD] rb:text-[#475467]': data.isSelected,
|
|
||||||
'rb:border-[#d1d5db] rb:bg-[#FCFCFD] rb:text-[#374151]': !data.isSelected
|
|
||||||
})}
|
|
||||||
>
|
|
||||||
<div className="rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/workflow/node_plus.png')]"></div>
|
|
||||||
{data.label}
|
|
||||||
</Flex>
|
|
||||||
</Popover>
|
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default AddNode;
|
export default AddNode;
|
||||||
|
|||||||
@@ -65,8 +65,8 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
|
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
|
||||||
'rb:border-[#171719]!': data.isSelected && !data.executionStatus,
|
'rb:border-[#171719]!': data.isSelected,
|
||||||
'rb:border-[#FCFCFD]': !data.isSelected && !data.executionStatus,
|
'rb:border-[#FCFCFD]': !data.isSelected,
|
||||||
'rb:border-[#369F21]!': !data.isSelected && data.executionStatus === 'completed',
|
'rb:border-[#369F21]!': !data.isSelected && data.executionStatus === 'completed',
|
||||||
'rb:border-[#FF5D34]!': !data.isSelected && data.executionStatus === 'failed',
|
'rb:border-[#FF5D34]!': !data.isSelected && data.executionStatus === 'failed',
|
||||||
})}>
|
})}>
|
||||||
@@ -99,7 +99,7 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => {
|
|||||||
{data.type === 'if-else' &&
|
{data.type === 'if-else' &&
|
||||||
<Flex vertical gap={4} className="rb:mt-3!">
|
<Flex vertical gap={4} className="rb:mt-3!">
|
||||||
{data.config?.cases?.defaultValue.map((item: any, index: number) => (
|
{data.config?.cases?.defaultValue.map((item: any, index: number) => (
|
||||||
<div key={index}>
|
<div key={index} className={item.expressions.length > 0 ? '' : 'rb:mb-1'}>
|
||||||
<Flex justify={item.expressions.length > 0 ? "space-between" : 'end'} className="rb:mb-1! rb:leading-4">
|
<Flex justify={item.expressions.length > 0 ? "space-between" : 'end'} className="rb:mb-1! rb:leading-4">
|
||||||
{item.expressions.length > 0 && <span className="rb:text-[#5B6167] rb:text-[10px] rb:pl-1">CASE{index + 1}</span>}
|
{item.expressions.length > 0 && <span className="rb:text-[#5B6167] rb:text-[10px] rb:pl-1">CASE{index + 1}</span>}
|
||||||
<span className="rb:text-[#212332] rb:font-medium rb:text-[12px]">{index === 0 ? 'IF' : `ELIF`}</span>
|
<span className="rb:text-[#212332] rb:font-medium rb:text-[12px]">{index === 0 ? 'IF' : `ELIF`}</span>
|
||||||
|
|||||||
@@ -1,19 +1,138 @@
|
|||||||
|
import { useEffect } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next'
|
||||||
import clsx from 'clsx';
|
import clsx from 'clsx';
|
||||||
import type { ReactShapeConfig } from '@antv/x6-react-shape';
|
import type { ReactShapeConfig } from '@antv/x6-react-shape';
|
||||||
import { Flex } from 'antd';
|
import { Flex } from 'antd';
|
||||||
import { CheckCircleFilled, CloseCircleFilled, LoadingOutlined } from '@ant-design/icons';
|
import { CheckCircleFilled, CloseCircleFilled, LoadingOutlined } from '@ant-design/icons';
|
||||||
import { useTranslation } from 'react-i18next'
|
|
||||||
|
|
||||||
|
import { graphNodeLibrary, edgeAttrs } from '../../constant';
|
||||||
import NodeTools from './NodeTools'
|
import NodeTools from './NodeTools'
|
||||||
|
|
||||||
const LoopNode: ReactShapeConfig['component'] = ({ node }) => {
|
const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
||||||
const data = node.getData() || {};
|
const data = node.getData() || {};
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
// 使用setTimeout确保在所有节点都添加完成后再创建连线
|
||||||
|
const timer = setTimeout(() => {
|
||||||
|
initNodes()
|
||||||
|
checkAndAddAddNode()
|
||||||
|
}, 50)
|
||||||
|
|
||||||
|
return () => clearTimeout(timer)
|
||||||
|
}, [graph])
|
||||||
|
|
||||||
|
const checkAndAddAddNode = () => {
|
||||||
|
if (!graph) return;
|
||||||
|
|
||||||
|
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === data.id);
|
||||||
|
const cycleStartNodes = childNodes.filter((n: any) => n.getData()?.type === 'cycle-start');
|
||||||
|
|
||||||
|
// 如果只有一个cycle-start节点且没有其他类型的子节点,则添加add-node
|
||||||
|
if (cycleStartNodes.length === 1 && childNodes.length === 1) {
|
||||||
|
const cycleStartNode = cycleStartNodes[0];
|
||||||
|
const cycleStartBBox = cycleStartNode.getBBox();
|
||||||
|
|
||||||
|
const addNode = graph.addNode({
|
||||||
|
...graphNodeLibrary.addStart,
|
||||||
|
x: cycleStartBBox.x + 84,
|
||||||
|
y: cycleStartBBox.y + 4,
|
||||||
|
data: {
|
||||||
|
type: 'add-node',
|
||||||
|
label: t('workflow.addNode'),
|
||||||
|
icon: '+',
|
||||||
|
parentId: node.id,
|
||||||
|
cycle: data.id,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
node.addChild(addNode);
|
||||||
|
|
||||||
|
// 连接cycle-start和add-node
|
||||||
|
const sourcePorts = cycleStartNode.getPorts();
|
||||||
|
const targetPorts = addNode.getPorts();
|
||||||
|
const sourcePort = sourcePorts.find((port: any) => port.group === 'right')?.id || 'right';
|
||||||
|
const targetPort = targetPorts.find((port: any) => port.group === 'left')?.id || 'left';
|
||||||
|
|
||||||
|
// 然后创建连线
|
||||||
|
graph.addEdge({
|
||||||
|
source: { cell: cycleStartNode.id, port: sourcePort },
|
||||||
|
target: { cell: addNode.id, port: targetPort },
|
||||||
|
...edgeAttrs,
|
||||||
|
});
|
||||||
|
|
||||||
|
cycleStartNode.toFront()
|
||||||
|
addNode.toFront()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const initNodes = () => {
|
||||||
|
// 检查是否存在cycle为当前节点ID的子节点,若存在则不调用initNodes,避免重复创建
|
||||||
|
const existingCycleNodes = graph.getNodes().filter((n: any) =>
|
||||||
|
n.getData()?.cycle === data.id
|
||||||
|
);
|
||||||
|
if (existingCycleNodes.length > 0) return;
|
||||||
|
// 添加默认子节点
|
||||||
|
const parentBBox = node.getBBox();
|
||||||
|
const centerX = parentBBox.x + 24;
|
||||||
|
const centerY = parentBBox.y + 70;
|
||||||
|
|
||||||
|
const cycleStartNodeId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
|
||||||
|
const cycleStartNode = graph.addNode({
|
||||||
|
...graphNodeLibrary.cycleStart,
|
||||||
|
x: centerX,
|
||||||
|
y: centerY,
|
||||||
|
id: cycleStartNodeId,
|
||||||
|
data: {
|
||||||
|
id: cycleStartNodeId,
|
||||||
|
type: 'cycle-start',
|
||||||
|
parentId: node.id,
|
||||||
|
isDefault: true, // 标记为默认节点,不可删除
|
||||||
|
cycle: data.id,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
const addNode = graph.addNode({
|
||||||
|
...graphNodeLibrary.addStart,
|
||||||
|
x: centerX + 84,
|
||||||
|
y: centerY + 4,
|
||||||
|
data: {
|
||||||
|
type: 'add-node',
|
||||||
|
label: t('workflow.addNode'),
|
||||||
|
icon: '+',
|
||||||
|
parentId: node.id,
|
||||||
|
cycle: data.id,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
node.addChild(cycleStartNode)
|
||||||
|
node.addChild(addNode)
|
||||||
|
const sourcePorts = cycleStartNode.getPorts()
|
||||||
|
const targetPorts = addNode.getPorts()
|
||||||
|
let sourcePort = sourcePorts.find((port: any) => port.group === 'right')?.id || 'right';
|
||||||
|
|
||||||
|
const edgeConfig = {
|
||||||
|
source: {
|
||||||
|
cell: cycleStartNode.id,
|
||||||
|
port: sourcePort
|
||||||
|
},
|
||||||
|
target: {
|
||||||
|
cell: addNode.id,
|
||||||
|
port: targetPorts.find((port: any) => port.group === 'left')?.id || 'left'
|
||||||
|
},
|
||||||
|
...edgeAttrs
|
||||||
|
}
|
||||||
|
graph.addEdge(edgeConfig)
|
||||||
|
|
||||||
|
setTimeout(() => {
|
||||||
|
|
||||||
|
cycleStartNode.toFront()
|
||||||
|
addNode.toFront()
|
||||||
|
}, 0)
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
|
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
|
||||||
'rb:border-[#171719]!': data.isSelected && !data.executionStatus,
|
'rb:border-[#171719]': data.isSelected,
|
||||||
'rb:border-[#FCFCFD]': !data.isSelected && !data.executionStatus,
|
'rb:border-[#FCFCFD]': !data.isSelected,
|
||||||
'rb:border-[#369F21]!': !data.isSelected && data.executionStatus === 'completed',
|
'rb:border-[#369F21]!': !data.isSelected && data.executionStatus === 'completed',
|
||||||
'rb:border-[#FF5D34]!': !data.isSelected && data.executionStatus === 'failed',
|
'rb:border-[#FF5D34]!': !data.isSelected && data.executionStatus === 'failed',
|
||||||
})}>
|
})}>
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ const NormalNode: ReactShapeConfig['component'] = ({ node }) => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
|
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
|
||||||
'rb:border-[#171719]!': data.isSelected && !data.executionStatus,
|
'rb:border-[#171719]!': data.isSelected,
|
||||||
'rb:border-[#FCFCFD]': !data.isSelected && !data.executionStatus,
|
'rb:border-[#FCFCFD]': !data.isSelected,
|
||||||
'rb:border-[#369F21]!': !data.isSelected && data.executionStatus === 'completed',
|
'rb:border-[#369F21]!': !data.isSelected && data.executionStatus === 'completed',
|
||||||
'rb:border-[#FF5D34]!': !data.isSelected && data.executionStatus === 'failed',
|
'rb:border-[#FF5D34]!': !data.isSelected && data.executionStatus === 'failed',
|
||||||
})}>
|
})}>
|
||||||
|
|||||||
@@ -2,13 +2,44 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-09 18:30:28
|
* @Date: 2026-02-09 18:30:28
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-30 15:14:02
|
* @Last Modified time: 2026-04-28 11:41:17
|
||||||
*/
|
*/
|
||||||
import { useEffect, useState } from 'react';
|
import { useEffect, useState } from 'react';
|
||||||
|
import { createPortal } from 'react-dom';
|
||||||
import { Flex, Popover } from 'antd';
|
import { Flex, Popover } from 'antd';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { nodeLibrary, graphNodeLibrary, edgeAttrs, nodeWidth } from '../constant';
|
import { nodeLibrary, graphNodeLibrary, edgeAttrs, nodeWidth } from '../constant';
|
||||||
|
|
||||||
|
// Shared helper: adjust loop/iteration container size to fit child nodes
|
||||||
|
export const adjustCycleContainerSize = (graph: any, cycleId: string) => {
|
||||||
|
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
|
||||||
|
if (!parentNode) return;
|
||||||
|
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
|
||||||
|
if (childNodes.length === 0) return;
|
||||||
|
const bounds = childNodes.reduce((acc: any, child: any) => {
|
||||||
|
const bbox = child.getBBox();
|
||||||
|
return {
|
||||||
|
minX: Math.min(acc.minX, bbox.x),
|
||||||
|
minY: Math.min(acc.minY, bbox.y),
|
||||||
|
maxX: Math.max(acc.maxX, bbox.x + bbox.width),
|
||||||
|
maxY: Math.max(acc.maxY, bbox.y + bbox.height),
|
||||||
|
};
|
||||||
|
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
|
||||||
|
const padding = 50;
|
||||||
|
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
|
||||||
|
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
|
||||||
|
parentNode.prop('size', { width: newWidth, height: newHeight });
|
||||||
|
parentNode.getPorts().forEach((port: any) => {
|
||||||
|
if (port.group === 'right' && port.args) {
|
||||||
|
parentNode.portProp(port.id!, 'args/x', newWidth);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
childNodes.forEach((childNode: any) => {
|
||||||
|
childNode.off('change:position');
|
||||||
|
childNode.on('change:position', () => adjustCycleContainerSize(graph, cycleId));
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
interface PortClickHandlerProps {
|
interface PortClickHandlerProps {
|
||||||
graph: any;
|
graph: any;
|
||||||
}
|
}
|
||||||
@@ -16,7 +47,6 @@ interface PortClickHandlerProps {
|
|||||||
const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
|
const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const [popoverVisible, setPopoverVisible] = useState(false);
|
const [popoverVisible, setPopoverVisible] = useState(false);
|
||||||
const [popoverPosition, setPopoverPosition] = useState({ x: 0, y: 0 });
|
|
||||||
const [sourceNode, setSourceNode] = useState<any>(null);
|
const [sourceNode, setSourceNode] = useState<any>(null);
|
||||||
const [sourcePort, setSourcePort] = useState<string>('');
|
const [sourcePort, setSourcePort] = useState<string>('');
|
||||||
const [tempElement, setTempElement] = useState<HTMLElement | null>(null);
|
const [tempElement, setTempElement] = useState<HTMLElement | null>(null);
|
||||||
@@ -24,12 +54,11 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
|
|||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const handlePortClick = (event: CustomEvent) => {
|
const handlePortClick = (event: CustomEvent) => {
|
||||||
const { node, port, element, rect, edgeInsertion } = event.detail;
|
const { node, port, element, edgeInsertion } = event.detail;
|
||||||
setSourceNode(node);
|
setSourceNode(node);
|
||||||
setSourcePort(port);
|
setSourcePort(port);
|
||||||
setTempElement(element);
|
setTempElement(element);
|
||||||
setEdgeInsertion(edgeInsertion || null);
|
setEdgeInsertion(edgeInsertion || null);
|
||||||
setPopoverPosition({ x: rect.left, y: rect.top });
|
|
||||||
setPopoverVisible(true);
|
setPopoverVisible(true);
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -43,52 +72,130 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
|
|||||||
};
|
};
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
// Handle node selection from popover menu and create new node with edge connection
|
||||||
const handleNodeSelect = (selectedNodeType: any) => {
|
const handleNodeSelect = (selectedNodeType: any) => {
|
||||||
if (!sourceNode || !graph) return;
|
if (!sourceNode || !graph) return;
|
||||||
|
|
||||||
const sourceNodeData = sourceNode.getData();
|
const sourceNodeData = sourceNode.getData();
|
||||||
const sourceNodeType = sourceNodeData?.type;
|
const sourceNodeType = sourceNodeData?.type;
|
||||||
const isCycleSubNode = !!sourceNodeData.cycle;
|
|
||||||
const isCycleContainer = (type: string) => type === 'loop' || type === 'iteration';
|
|
||||||
const newNodeType = selectedNodeType.type;
|
|
||||||
|
|
||||||
// Save add-node placeholder position before disabling history
|
// AddNode placeholder mode: replace the add-node placeholder with the selected node
|
||||||
|
if (sourceNodeType === 'add-node') {
|
||||||
|
const placeholderBBox = sourceNode.getBBox();
|
||||||
|
const cycleId = sourceNodeData.cycle;
|
||||||
|
const id = `${selectedNodeType.type.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
|
||||||
|
const newNode = graph.addNode({
|
||||||
|
...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default),
|
||||||
|
x: placeholderBBox.x,
|
||||||
|
y: placeholderBBox.y - 12,
|
||||||
|
id,
|
||||||
|
data: {
|
||||||
|
id,
|
||||||
|
type: selectedNodeType.type,
|
||||||
|
icon: selectedNodeType.icon,
|
||||||
|
name: t(`workflow.${selectedNodeType.type}`),
|
||||||
|
cycle: cycleId,
|
||||||
|
parentId: sourceNodeData.parentId,
|
||||||
|
config: selectedNodeType.config || {},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
if (cycleId) {
|
||||||
|
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
|
||||||
|
if (parentNode) parentNode.addChild(newNode);
|
||||||
|
}
|
||||||
|
const incomingEdges = graph.getIncomingEdges(sourceNode);
|
||||||
|
const outgoingEdges = graph.getOutgoingEdges(sourceNode);
|
||||||
|
const addedEdges: any[] = [];
|
||||||
|
incomingEdges?.forEach((edge: any) => {
|
||||||
|
addedEdges.push(graph.addEdge({
|
||||||
|
source: { cell: edge.getSourceCellId(), port: edge.getSourcePortId() },
|
||||||
|
target: { cell: newNode.id, port: newNode.getPorts().find((p: any) => p.group === 'left')?.id || 'left' },
|
||||||
|
...edgeAttrs,
|
||||||
|
}));
|
||||||
|
});
|
||||||
|
outgoingEdges?.forEach((edge: any) => {
|
||||||
|
const targetCell = graph.getCellById(edge.getTargetCellId()) as any;
|
||||||
|
const targetPortId = targetCell?.getPorts?.()?.find((p: any) => p.group === 'left')?.id || edge.getTargetPortId();
|
||||||
|
addedEdges.push(graph.addEdge({
|
||||||
|
source: { cell: newNode.id, port: newNode.getPorts().find((p: any) => p.group === 'right')?.id || 'right' },
|
||||||
|
target: { cell: edge.getTargetCellId(), port: targetPortId },
|
||||||
|
...edgeAttrs,
|
||||||
|
}));
|
||||||
|
});
|
||||||
|
graph.getNodes().forEach((n: any) => {
|
||||||
|
if (n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId) n.remove();
|
||||||
|
});
|
||||||
|
setTimeout(() => {
|
||||||
|
addedEdges.forEach(e => {
|
||||||
|
const src = graph.getCellById(e.getSourceCellId());
|
||||||
|
const tgt = graph.getCellById(e.getTargetCellId());
|
||||||
|
if (src?.isNode()) src.toFront();
|
||||||
|
if (tgt?.isNode()) tgt.toFront();
|
||||||
|
});
|
||||||
|
}, 50);
|
||||||
|
if (cycleId) adjustCycleContainerSize(graph, cycleId);
|
||||||
|
if (tempElement) { document.body.removeChild(tempElement); setTempElement(null); }
|
||||||
|
setPopoverVisible(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it's a cycle-start node, handle the add-node placeholder
|
||||||
let addNodePosition = null;
|
let addNodePosition = null;
|
||||||
|
const isCycleSubNode = sourceNodeData.cycle
|
||||||
if (isCycleSubNode && sourceNodeType === 'cycle-start') {
|
if (isCycleSubNode && sourceNodeType === 'cycle-start') {
|
||||||
const cycleId = sourceNodeData.cycle;
|
const cycleId = sourceNodeData.cycle;
|
||||||
const addNodes = graph.getNodes().filter((n: any) =>
|
const addNodes = graph.getNodes().filter((n: any) =>
|
||||||
n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId
|
n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId
|
||||||
);
|
);
|
||||||
if (addNodes.length > 0) addNodePosition = addNodes[0].getBBox();
|
|
||||||
|
if (addNodes.length > 0) {
|
||||||
|
const addNode = addNodes[0];
|
||||||
|
addNodePosition = addNode.getBBox();
|
||||||
|
addNode.remove();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate position
|
// Calculate new node position to avoid overlapping
|
||||||
const sourceBBox = sourceNode.getBBox();
|
const sourceBBox = sourceNode.getBBox();
|
||||||
const nw = graphNodeLibrary[newNodeType]?.width || 120;
|
const nodeWidth = graphNodeLibrary[selectedNodeType.type]?.width || 120;
|
||||||
const nh = graphNodeLibrary[newNodeType]?.height || 88;
|
const nodeHeight = graphNodeLibrary[selectedNodeType.type]?.height || 88;
|
||||||
const hSpacing = isCycleSubNode ? 48 : 80;
|
const horizontalSpacing = isCycleSubNode ? 48 : 80;
|
||||||
const vSpacing = 10;
|
const verticalSpacing = 10;
|
||||||
|
|
||||||
|
// Get source port group information
|
||||||
const sourcePortInfo = sourceNode.getPorts().find((p: any) => p.id === sourcePort);
|
const sourcePortInfo = sourceNode.getPorts().find((p: any) => p.id === sourcePort);
|
||||||
const sourcePortGroup = sourcePortInfo?.group || sourcePort;
|
const sourcePortGroup = sourcePortInfo?.group || sourcePort;
|
||||||
|
|
||||||
let newX: number, newY: number;
|
// Calculate new node position
|
||||||
|
let newX, newY;
|
||||||
if (edgeInsertion) {
|
if (edgeInsertion) {
|
||||||
|
// Edge insertion: place new node on the same row as target, between source and target
|
||||||
const targetBBox = edgeInsertion.targetCell.getBBox();
|
const targetBBox = edgeInsertion.targetCell.getBBox();
|
||||||
const gap = targetBBox.x - (sourceBBox.x + sourceBBox.width);
|
const gap = targetBBox.x - (sourceBBox.x + sourceBBox.width);
|
||||||
const requiredSpace = nw + hSpacing * 4;
|
const requiredSpace = nodeWidth + horizontalSpacing * 4;
|
||||||
newX = sourceBBox.x + sourceBBox.width + hSpacing;
|
|
||||||
newY = targetBBox.y + (targetBBox.height - nh) / 2;
|
// New node x: right after source + spacing
|
||||||
|
newX = sourceBBox.x + sourceBBox.width + horizontalSpacing;
|
||||||
|
// Same row as target node
|
||||||
|
newY = targetBBox.y + (targetBBox.height - nodeHeight) / 2;
|
||||||
|
|
||||||
|
// If not enough space, shift target and all downstream nodes to the right
|
||||||
if (gap < requiredSpace) {
|
if (gap < requiredSpace) {
|
||||||
const shiftX = requiredSpace - gap;
|
const shiftX = requiredSpace - gap;
|
||||||
const visited = new Set<string>();
|
const visited = new Set<string>();
|
||||||
const shiftDownstream = (cell: any) => {
|
const shiftDownstream = (cell: any) => {
|
||||||
if (visited.has(cell.id)) return;
|
const cellId = cell.id;
|
||||||
visited.add(cell.id);
|
if (visited.has(cellId)) return;
|
||||||
|
visited.add(cellId);
|
||||||
const pos = cell.getPosition();
|
const pos = cell.getPosition();
|
||||||
cell.setPosition(pos.x + shiftX, pos.y);
|
cell.setPosition(pos.x + shiftX, pos.y);
|
||||||
|
// Recursively shift nodes connected from right ports
|
||||||
graph.getConnectedEdges(cell, { outgoing: true }).forEach((e: any) => {
|
graph.getConnectedEdges(cell, { outgoing: true }).forEach((e: any) => {
|
||||||
const tCell = graph.getCellById(e.getTargetCellId());
|
const tId = e.getTargetCellId();
|
||||||
if (tCell?.isNode()) shiftDownstream(tCell);
|
if (tId && !visited.has(tId)) {
|
||||||
|
const tCell = graph.getCellById(tId);
|
||||||
|
if (tCell?.isNode()) shiftDownstream(tCell);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
shiftDownstream(edgeInsertion.targetCell);
|
shiftDownstream(edgeInsertion.targetCell);
|
||||||
@@ -96,170 +203,167 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
|
|||||||
} else if (addNodePosition) {
|
} else if (addNodePosition) {
|
||||||
newX = addNodePosition.x;
|
newX = addNodePosition.x;
|
||||||
newY = addNodePosition.y;
|
newY = addNodePosition.y;
|
||||||
} else if (sourcePortGroup === 'left') {
|
|
||||||
newX = sourceBBox.x - nw * 2 - hSpacing;
|
|
||||||
newY = sourceBBox.y;
|
|
||||||
} else {
|
} else {
|
||||||
newX = sourceBBox.x + sourceBBox.width + hSpacing;
|
// Determine node placement direction based on port position
|
||||||
newY = sourceBBox.y;
|
if (sourcePortGroup === 'left') {
|
||||||
const connectedNodes = new Set<string>();
|
// Left port: add node to the left
|
||||||
graph.getConnectedEdges(sourceNode).forEach((e: any) => {
|
newX = sourceBBox.x - nodeWidth*2 - horizontalSpacing;
|
||||||
[e.getSourceCellId(), e.getTargetCellId()].forEach((cid: string) => {
|
newY = sourceBBox.y;
|
||||||
if (cid !== sourceNode.id) connectedNodes.add(cid);
|
} else {
|
||||||
|
// Right port: add node to the right
|
||||||
|
newX = sourceBBox.x + sourceBBox.width + horizontalSpacing;
|
||||||
|
newY = sourceBBox.y;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if position overlaps with existing nodes (only consider connected nodes)
|
||||||
|
const checkOverlap = (x: number, y: number) => {
|
||||||
|
// Get nodes connected to the source node
|
||||||
|
const connectedNodes = new Set();
|
||||||
|
graph.getConnectedEdges(sourceNode).forEach((edge: any) => {
|
||||||
|
const sourceId = edge.getSourceCellId();
|
||||||
|
const targetId = edge.getTargetCellId();
|
||||||
|
if (sourceId !== sourceNode.id) connectedNodes.add(sourceId);
|
||||||
|
if (targetId !== sourceNode.id) connectedNodes.add(targetId);
|
||||||
});
|
});
|
||||||
});
|
|
||||||
const checkOverlap = (x: number, y: number) =>
|
return graph.getNodes().some((node: any) => {
|
||||||
graph.getNodes().some((n: any) => {
|
if (node.id === sourceNode.id) return false;
|
||||||
if (n.id === sourceNode.id || !connectedNodes.has(n.id)) return false;
|
if (!connectedNodes.has(node.id)) return false; // Only consider connected nodes
|
||||||
const b = n.getBBox();
|
const bbox = node.getBBox();
|
||||||
return !(x + nw < b.x || x > b.x + b.width || y + nh < b.y || y > b.y + b.height);
|
return !(x + nodeWidth < bbox.x || x > bbox.x + bbox.width ||
|
||||||
|
y + nodeHeight < bbox.y || y > bbox.y + bbox.height);
|
||||||
});
|
});
|
||||||
while (checkOverlap(newX, newY)) newY += nh + vSpacing;
|
};
|
||||||
|
|
||||||
|
// If position is occupied, search downward for empty space
|
||||||
|
while (checkOverlap(newX, newY)) {
|
||||||
|
newY += nodeHeight + verticalSpacing;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disable history for all graph mutations
|
// Create new node
|
||||||
graph.disableHistory();
|
const id = `${selectedNodeType.type.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
|
||||||
|
|
||||||
// Remove add-node placeholder
|
|
||||||
if (isCycleSubNode && sourceNodeType === 'cycle-start') {
|
|
||||||
const cycleId = sourceNodeData.cycle;
|
|
||||||
graph.getNodes()
|
|
||||||
.filter((n: any) => n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId)
|
|
||||||
.forEach((n: any) => n.remove());
|
|
||||||
}
|
|
||||||
|
|
||||||
const id = `${newNodeType.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
|
|
||||||
const newNode = graph.addNode({
|
const newNode = graph.addNode({
|
||||||
...(graphNodeLibrary[newNodeType] || graphNodeLibrary.default),
|
...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default),
|
||||||
x: newX,
|
x: newX,
|
||||||
y: newY - (isCycleSubNode && sourceNodeType === 'cycle-start' ? 12 : 0),
|
y: newY - (isCycleSubNode && sourceNodeType === 'cycle-start' ? 12 : 0),
|
||||||
id,
|
id,
|
||||||
data: {
|
data: {
|
||||||
id,
|
id,
|
||||||
type: newNodeType,
|
type: selectedNodeType.type,
|
||||||
icon: selectedNodeType.icon,
|
icon: selectedNodeType.icon,
|
||||||
name: t(`workflow.${newNodeType}`),
|
name: t(`workflow.${selectedNodeType.type}`),
|
||||||
cycle: sourceNodeData.cycle,
|
cycle: sourceNodeData.cycle, // Inherit cycle from source node
|
||||||
config: selectedNodeType.config || {}
|
config: selectedNodeType.config || {}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Add new node as child of parent node
|
||||||
if (sourceNodeData.cycle) {
|
if (sourceNodeData.cycle) {
|
||||||
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === sourceNodeData.cycle);
|
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === sourceNodeData.cycle);
|
||||||
if (parentNode) parentNode.addChild(newNode, { silent: true });
|
|
||||||
}
|
|
||||||
|
|
||||||
if (edgeInsertion) {
|
|
||||||
const { edge: oldEdge } = edgeInsertion;
|
|
||||||
if (oldEdge.id && graph.getCellById(oldEdge.id)) graph.removeCell(oldEdge.id);
|
|
||||||
else graph.removeEdge(oldEdge);
|
|
||||||
}
|
|
||||||
|
|
||||||
const newPorts = newNode.getPorts();
|
|
||||||
const addedCells: any[] = [newNode];
|
|
||||||
|
|
||||||
if (edgeInsertion) {
|
|
||||||
const { targetCell, targetPort: origTargetPort } = edgeInsertion;
|
|
||||||
const newLeftPort = newPorts.find((p: any) => p.group === 'left')?.id || 'left';
|
|
||||||
const newRightPort = newPorts.find((p: any) => p.group === 'right')?.id || 'right';
|
|
||||||
addedCells.push(graph.addEdge({ source: { cell: sourceNode.id, port: sourcePort }, target: { cell: newNode.id, port: newLeftPort }, ...edgeAttrs }));
|
|
||||||
addedCells.push(graph.addEdge({ source: { cell: newNode.id, port: newRightPort }, target: { cell: targetCell.id, port: origTargetPort }, ...edgeAttrs }));
|
|
||||||
setEdgeInsertion(null);
|
|
||||||
} else if (sourcePortGroup === 'left') {
|
|
||||||
const tp = newPorts.find((p: any) => p.group === 'right')?.id || 'right';
|
|
||||||
addedCells.push(graph.addEdge({ source: { cell: newNode.id, port: tp }, target: { cell: sourceNode.id, port: sourcePort }, ...edgeAttrs }));
|
|
||||||
} else {
|
|
||||||
const tp = newPorts.find((p: any) => p.group === 'left')?.id || 'left';
|
|
||||||
addedCells.push(graph.addEdge({ source: { cell: sourceNode.id, port: sourcePort }, target: { cell: newNode.id, port: tp }, ...edgeAttrs }));
|
|
||||||
}
|
|
||||||
|
|
||||||
// If adding a loop/iteration node, create cycle-start, add-node and inner edge regardless of source type
|
|
||||||
if (isCycleContainer(newNodeType)) {
|
|
||||||
const parentBBox = newNode.getBBox();
|
|
||||||
const cycleStartId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
|
|
||||||
const cycleStartNode = graph.addNode({
|
|
||||||
...graphNodeLibrary.cycleStart,
|
|
||||||
x: parentBBox.x + 24,
|
|
||||||
y: parentBBox.y + 70,
|
|
||||||
id: cycleStartId,
|
|
||||||
data: { id: cycleStartId, type: 'cycle-start', parentId: id, isDefault: true, cycle: id },
|
|
||||||
});
|
|
||||||
const addNodePlaceholder = graph.addNode({
|
|
||||||
...graphNodeLibrary.addStart,
|
|
||||||
x: parentBBox.x + 24 + 84,
|
|
||||||
y: parentBBox.y + 70 + 4,
|
|
||||||
data: { type: 'add-node', label: t('workflow.addNode'), icon: '+', parentId: id, cycle: id },
|
|
||||||
});
|
|
||||||
newNode.addChild(cycleStartNode, { silent: true });
|
|
||||||
newNode.addChild(addNodePlaceholder, { silent: true });
|
|
||||||
const innerEdge = graph.addEdge({
|
|
||||||
source: { cell: cycleStartNode.id, port: cycleStartNode.getPorts().find((p: any) => p.group === 'right')?.id || 'right' },
|
|
||||||
target: { cell: addNodePlaceholder.id, port: addNodePlaceholder.getPorts().find((p: any) => p.group === 'left')?.id || 'left' },
|
|
||||||
...edgeAttrs,
|
|
||||||
});
|
|
||||||
addedCells.push(cycleStartNode, addNodePlaceholder, innerEdge);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Adjust parent size if adding inside a cycle container
|
|
||||||
const cycleId = sourceNodeData.cycle;
|
|
||||||
if (cycleId) {
|
|
||||||
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
|
|
||||||
if (parentNode) {
|
if (parentNode) {
|
||||||
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
|
parentNode.addChild(newNode);
|
||||||
if (childNodes.length > 0) {
|
|
||||||
const bounds = childNodes.reduce((acc: any, child: any) => {
|
|
||||||
const b = child.getBBox();
|
|
||||||
return { minX: Math.min(acc.minX, b.x), minY: Math.min(acc.minY, b.y), maxX: Math.max(acc.maxX, b.x + b.width), maxY: Math.max(acc.maxY, b.y + b.height) };
|
|
||||||
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
|
|
||||||
const padding = 50;
|
|
||||||
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
|
|
||||||
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
|
|
||||||
parentNode.prop('size', { width: newWidth, height: newHeight });
|
|
||||||
parentNode.getPorts().forEach((port: any) => {
|
|
||||||
if (port.group === 'right' && port.args) parentNode.portProp(port.id!, 'args/x', newWidth);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// toFront
|
// Edge insertion: remove old edge immediately before creating new edges
|
||||||
const bringCycleChildrenToFront = (cycleContainerId: string) => {
|
if (edgeInsertion) {
|
||||||
graph.getEdges().forEach((e: any) => {
|
const { edge: oldEdge } = edgeInsertion;
|
||||||
const src = graph.getCellById(e.getSourceCellId());
|
if (oldEdge.id && graph.getCellById(oldEdge.id)) {
|
||||||
const tgt = graph.getCellById(e.getTargetCellId());
|
graph.removeCell(oldEdge.id);
|
||||||
if (src?.getData()?.cycle === cycleContainerId || tgt?.getData()?.cycle === cycleContainerId) e.toFront();
|
} else {
|
||||||
});
|
graph.removeEdge(oldEdge);
|
||||||
graph.getNodes().forEach((n: any) => { if (n.getData()?.cycle === cycleContainerId) n.toFront(); });
|
}
|
||||||
};
|
|
||||||
|
|
||||||
if (isCycleContainer(sourceNodeType)) {
|
|
||||||
newNode.toFront(); sourceNode.toFront(); bringCycleChildrenToFront(sourceNodeData.id);
|
|
||||||
if (isCycleContainer(newNodeType)) bringCycleChildrenToFront(id);
|
|
||||||
} else if (isCycleContainer(newNodeType)) {
|
|
||||||
newNode.toFront(); sourceNode.toFront(); bringCycleChildrenToFront(id);
|
|
||||||
} else {
|
|
||||||
addedCells.forEach(c => { if (c.isNode?.()) c.toFront(); });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Re-enable history and manually push one batch frame for all added cells
|
// Create edge connection
|
||||||
graph.enableHistory();
|
setTimeout(() => {
|
||||||
const history = graph.getPlugin('history') as any;
|
const newPorts = newNode.getPorts();
|
||||||
if (history) {
|
|
||||||
const batchFrame = addedCells.map((cell: any) => ({
|
|
||||||
batch: true,
|
|
||||||
event: 'cell:added',
|
|
||||||
data: { id: cell.id, node: cell.isNode(), edge: cell.isEdge(), props: cell.toJSON() },
|
|
||||||
options: {},
|
|
||||||
}));
|
|
||||||
history.undoStack.push(batchFrame);
|
|
||||||
history.redoStack = [];
|
|
||||||
graph.trigger('history:change', { cmds: batchFrame, options: { name: 'add-node' } });
|
|
||||||
}
|
|
||||||
|
|
||||||
|
const addedEdges: any[] = [];
|
||||||
|
if (edgeInsertion) {
|
||||||
|
// Edge insertion: create source→new and new→target edges
|
||||||
|
const { targetCell, targetPort: origTargetPort } = edgeInsertion;
|
||||||
|
const newLeftPort = newPorts.find((p: any) => p.group === 'left')?.id || 'left';
|
||||||
|
const newRightPort = newPorts.find((p: any) => p.group === 'right')?.id || 'right';
|
||||||
|
addedEdges.push(graph.addEdge({
|
||||||
|
source: { cell: sourceNode.id, port: sourcePort },
|
||||||
|
target: { cell: newNode.id, port: newLeftPort },
|
||||||
|
...edgeAttrs
|
||||||
|
}));
|
||||||
|
addedEdges.push(graph.addEdge({
|
||||||
|
source: { cell: newNode.id, port: newRightPort },
|
||||||
|
target: { cell: targetCell.id, port: origTargetPort },
|
||||||
|
...edgeAttrs
|
||||||
|
}));
|
||||||
|
setEdgeInsertion(null);
|
||||||
|
} else if (sourcePortGroup === 'left') {
|
||||||
|
// Connect from left port to new node's right side
|
||||||
|
const targetPort = newPorts.find((port: any) => port.group === 'right')?.id || 'right';
|
||||||
|
addedEdges.push(graph.addEdge({
|
||||||
|
source: { cell: newNode.id, port: targetPort },
|
||||||
|
target: { cell: sourceNode.id, port: sourcePort },
|
||||||
|
...edgeAttrs
|
||||||
|
}));
|
||||||
|
} else {
|
||||||
|
// Connect from right port to new node's left side
|
||||||
|
const targetPort = newPorts.find((port: any) => port.group === 'left')?.id || 'left';
|
||||||
|
addedEdges.push(graph.addEdge({
|
||||||
|
source: { cell: sourceNode.id, port: sourcePort },
|
||||||
|
target: { cell: newNode.id, port: targetPort },
|
||||||
|
...edgeAttrs
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adjust loop node size when child node is added via port within loop node
|
||||||
|
const cycleId = sourceNodeData.cycle;
|
||||||
|
if (cycleId) adjustCycleContainerSize(graph, cycleId);
|
||||||
|
|
||||||
|
const isCycleContainer = (type: string) => type === 'loop' || type === 'iteration';
|
||||||
|
const newNodeType = selectedNodeType.type;
|
||||||
|
|
||||||
|
// Helper: bring all child nodes and their edges of a cycle container to front
|
||||||
|
const bringCycleChildrenToFront = (cycleContainerId: string) => {
|
||||||
|
|
||||||
|
graph.getEdges().forEach((e: any) => {
|
||||||
|
const src = graph.getCellById(e.getSourceCellId());
|
||||||
|
const tgt = graph.getCellById(e.getTargetCellId());
|
||||||
|
if (src?.getData()?.cycle === cycleContainerId || tgt?.getData()?.cycle === cycleContainerId) e.toFront();
|
||||||
|
});
|
||||||
|
graph.getNodes().forEach((n: any) => {
|
||||||
|
if (n.getData()?.cycle === cycleContainerId) n.toFront();
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
if (isCycleContainer(sourceNodeType)) {
|
||||||
|
console.log('isCycleContainer(sourceNodeType)')
|
||||||
|
// Case 4: source is a loop/iteration node — bring new node to front, then its children
|
||||||
|
newNode.toFront();
|
||||||
|
sourceNode.toFront();
|
||||||
|
bringCycleChildrenToFront(sourceNodeData.id);
|
||||||
|
} else if (isCycleContainer(newNodeType)) {
|
||||||
|
console.log('isCycleContainer(newNodeType)')
|
||||||
|
// Case 3: adding a loop/iteration node from a normal node — bring new node to front, then its children
|
||||||
|
newNode.toFront();
|
||||||
|
sourceNode.toFront()
|
||||||
|
bringCycleChildrenToFront(id);
|
||||||
|
} else {
|
||||||
|
// Case 2: normal node → normal node
|
||||||
|
addedEdges.forEach(e => {
|
||||||
|
const src = graph.getCellById(e.getSourceCellId());
|
||||||
|
const tgt = graph.getCellById(e.getTargetCellId());
|
||||||
|
if (src?.isNode()) src.toFront();
|
||||||
|
if (tgt?.isNode()) tgt.toFront();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, 50);
|
||||||
|
|
||||||
|
// Clean up temporary element
|
||||||
if (tempElement) {
|
if (tempElement) {
|
||||||
document.body.removeChild(tempElement);
|
document.body.removeChild(tempElement);
|
||||||
setTempElement(null);
|
setTempElement(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
setPopoverVisible(false);
|
setPopoverVisible(false);
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -316,23 +420,19 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
|
|||||||
|
|
||||||
if (!tempElement) return null;
|
if (!tempElement) return null;
|
||||||
|
|
||||||
return (
|
return createPortal(
|
||||||
<Popover
|
<Popover
|
||||||
content={content}
|
content={content}
|
||||||
open={popoverVisible}
|
open={popoverVisible}
|
||||||
onOpenChange={(visible) => {
|
onOpenChange={(visible) => { if (!visible) handlePopoverClose(); }}
|
||||||
if (!visible) handlePopoverClose();
|
|
||||||
}}
|
|
||||||
placement="right"
|
placement="right"
|
||||||
overlayStyle={{
|
autoAdjustOverflow
|
||||||
position: 'fixed',
|
getPopupContainer={() => document.body}
|
||||||
left: popoverPosition.x + 10,
|
|
||||||
top: popoverPosition.y - 10,
|
|
||||||
}}
|
|
||||||
>
|
>
|
||||||
<div />
|
<div style={{ width: '1px', height: '1px' }} />
|
||||||
</Popover>
|
</Popover>,
|
||||||
|
tempElement
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default PortClickHandler;
|
export default PortClickHandler;
|
||||||
@@ -242,11 +242,10 @@ const ToolConfig: FC<{ options: Suggestion[]; }> = ({
|
|||||||
className={parameter.type === 'boolean' ? 'rb:mb-0!' : ''}
|
className={parameter.type === 'boolean' ? 'rb:mb-0!' : ''}
|
||||||
>
|
>
|
||||||
{parameter.type === 'string' && parameter.enum && parameter.enum.length > 0
|
{parameter.type === 'string' && parameter.enum && parameter.enum.length > 0
|
||||||
? <Select key={values.tool_id} size="small" options={parameter.enum.map(vo => ({ value: vo, label: vo }))} placeholder={t('common.pleaseSelect')} />
|
? <Select size="small" options={parameter.enum.map(vo => ({ value: vo, label: vo }))} placeholder={t('common.pleaseSelect')} />
|
||||||
: parameter.type === 'boolean'
|
: parameter.type === 'boolean'
|
||||||
? <Switch key={values.tool_id} size="small" />
|
? <Switch size="small" />
|
||||||
: <Editor
|
: <Editor
|
||||||
key={values.tool_id}
|
|
||||||
variant="outlined"
|
variant="outlined"
|
||||||
type="input"
|
type="input"
|
||||||
size="small"
|
size="small"
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 15:06:18
|
* @Date: 2026-02-03 15:06:18
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-04-27 14:07:14
|
* @Last Modified time: 2026-04-21 18:23:31
|
||||||
*/
|
*/
|
||||||
import type { ReactShapeConfig } from '@antv/x6-react-shape';
|
import type { ReactShapeConfig } from '@antv/x6-react-shape';
|
||||||
import type { GroupMetadata, PortMetadata } from '@antv/x6/lib/model/port';
|
import type { GroupMetadata, PortMetadata } from '@antv/x6/lib/model/port';
|
||||||
@@ -948,15 +948,6 @@ export const graphNodeLibrary: Record<string, NodeConfig> = {
|
|||||||
width: nodeWidth,
|
width: nodeWidth,
|
||||||
height: 120,
|
height: 120,
|
||||||
shape: 'notes-node',
|
shape: 'notes-node',
|
||||||
},
|
|
||||||
output: {
|
|
||||||
width: nodeWidth,
|
|
||||||
height: 76,
|
|
||||||
shape: 'normal-node',
|
|
||||||
ports: {
|
|
||||||
groups: { left: defaultPortGroup },
|
|
||||||
items: [defaultPortItems[0]],
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,13 +2,16 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 15:17:48
|
* @Date: 2026-02-03 15:17:48
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-04-28 13:49:11
|
* @Last Modified time: 2026-04-28 12:07:33
|
||||||
*/
|
*/
|
||||||
import { Clipboard, Graph, Keyboard, MiniMap, Node, Snapline, History, type Edge } from '@antv/x6';
|
import { Clipboard, Graph, Keyboard, MiniMap, Node, Snapline, History, type Edge } from '@antv/x6';
|
||||||
import { register } from '@antv/x6-react-shape';
|
import type { HistoryCommand as Command } from '@antv/x6/lib/plugin/history/type';
|
||||||
|
import { register as registerReactShape } from '@antv/x6-react-shape';
|
||||||
import type { PortMetadata } from '@antv/x6/lib/model/port';
|
import type { PortMetadata } from '@antv/x6/lib/model/port';
|
||||||
import { App } from 'antd';
|
import { App } from 'antd';
|
||||||
import { useEffect, useRef, useState } from 'react';
|
import { useEffect, useRef, useState, createElement } from 'react';
|
||||||
|
import type { RefObject, Dispatch, SetStateAction, MutableRefObject, DragEvent } from 'react';
|
||||||
|
import { createRoot } from 'react-dom/client';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useParams } from 'react-router-dom';
|
import { useParams } from 'react-router-dom';
|
||||||
|
|
||||||
@@ -16,18 +19,20 @@ import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application';
|
|||||||
import { useUser } from '@/store/user';
|
import { useUser } from '@/store/user';
|
||||||
import type { FeaturesConfigForm } from '@/views/ApplicationConfig/types';
|
import type { FeaturesConfigForm } from '@/views/ApplicationConfig/types';
|
||||||
import { conditionNodeHeight, conditionNodeItemHeight, conditionNodePortItemArgsY, defaultAbsolutePortGroups, defaultPortItems, edgeAttrs, edgeHoverTool, edge_color, edge_selected_color, edge_width, graphNodeLibrary, nodeLibrary, nodeRegisterLibrary, nodeWidth, notesConfig, portAttrs, portItemArgsY, portMarkup, portTextAttrs, unknownNode } from '../constant';
|
import { conditionNodeHeight, conditionNodeItemHeight, conditionNodePortItemArgsY, defaultAbsolutePortGroups, defaultPortItems, edgeAttrs, edgeHoverTool, edge_color, edge_selected_color, edge_width, graphNodeLibrary, nodeLibrary, nodeRegisterLibrary, nodeWidth, notesConfig, portAttrs, portItemArgsY, portMarkup, portTextAttrs, unknownNode } from '../constant';
|
||||||
import type { ChatVariable, HistoryRecord, NodeProperties, WorkflowConfig } from '../types';
|
import type { ChatVariable, NodeProperties, WorkflowConfig } from '../types';
|
||||||
import { calcConditionNodeTotalHeight, getConditionNodeCasePortY } from '../utils';
|
import { calcConditionNodeTotalHeight, getConditionNodeCasePortY } from '../utils';
|
||||||
import { useWorkflowStore } from '@/store/workflow';
|
import { useWorkflowStore } from '@/store/workflow';
|
||||||
|
|
||||||
|
const isSafari = /^((?!chrome|android).)*safari/i.test(navigator.userAgent);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Props for useWorkflowGraph hook
|
* Props for useWorkflowGraph hook
|
||||||
*/
|
*/
|
||||||
export interface UseWorkflowGraphProps {
|
export interface UseWorkflowGraphProps {
|
||||||
/** Reference to the main graph container element */
|
/** Reference to the main graph container element */
|
||||||
containerRef: React.RefObject<HTMLDivElement>;
|
containerRef: RefObject<HTMLDivElement>;
|
||||||
/** Reference to the minimap container element */
|
/** Reference to the minimap container element */
|
||||||
miniMapRef: React.RefObject<HTMLDivElement>;
|
miniMapRef: RefObject<HTMLDivElement>;
|
||||||
/** Callback when features config is loaded */
|
/** Callback when features config is loaded */
|
||||||
onFeaturesLoad?: (features: FeaturesConfigForm | undefined) => void;
|
onFeaturesLoad?: (features: FeaturesConfigForm | undefined) => void;
|
||||||
}
|
}
|
||||||
@@ -39,23 +44,23 @@ export interface UseWorkflowGraphReturn {
|
|||||||
/** Current workflow configuration */
|
/** Current workflow configuration */
|
||||||
config: WorkflowConfig | null;
|
config: WorkflowConfig | null;
|
||||||
/** Function to update workflow configuration */
|
/** Function to update workflow configuration */
|
||||||
setConfig: React.Dispatch<React.SetStateAction<WorkflowConfig | null>>;
|
setConfig: Dispatch<SetStateAction<WorkflowConfig | null>>;
|
||||||
/** Reference to the X6 graph instance */
|
/** Reference to the X6 graph instance */
|
||||||
graphRef: React.MutableRefObject<Graph | undefined>;
|
graphRef: MutableRefObject<Graph | undefined>;
|
||||||
/** Currently selected node */
|
/** Currently selected node */
|
||||||
selectedNode: Node | null;
|
selectedNode: Node | null;
|
||||||
/** Function to update selected node */
|
/** Function to update selected node */
|
||||||
setSelectedNode: React.Dispatch<React.SetStateAction<Node | null>>;
|
setSelectedNode: Dispatch<SetStateAction<Node | null>>;
|
||||||
/** Current zoom level of the graph */
|
/** Current zoom level of the graph */
|
||||||
zoomLevel: number;
|
zoomLevel: number;
|
||||||
/** Function to update zoom level */
|
/** Function to update zoom level */
|
||||||
setZoomLevel: React.Dispatch<React.SetStateAction<number>>;
|
setZoomLevel: Dispatch<SetStateAction<number>>;
|
||||||
/** Whether hand/pan mode is enabled */
|
/** Whether hand/pan mode is enabled */
|
||||||
isHandMode: boolean;
|
isHandMode: boolean;
|
||||||
/** Function to toggle hand mode */
|
/** Function to toggle hand mode */
|
||||||
setIsHandMode: React.Dispatch<React.SetStateAction<boolean>>;
|
setIsHandMode: Dispatch<SetStateAction<boolean>>;
|
||||||
/** Handler for dropping nodes onto canvas */
|
/** Handler for dropping nodes onto canvas */
|
||||||
onDrop: (event: React.DragEvent) => void;
|
onDrop: (event: DragEvent) => void;
|
||||||
/** Handler for clicking blank canvas area */
|
/** Handler for clicking blank canvas area */
|
||||||
blankClick: () => void;
|
blankClick: () => void;
|
||||||
/** Handler for delete keyboard event */
|
/** Handler for delete keyboard event */
|
||||||
@@ -77,7 +82,7 @@ export interface UseWorkflowGraphReturn {
|
|||||||
/** Chat variables for workflow */
|
/** Chat variables for workflow */
|
||||||
chatVariables: ChatVariable[];
|
chatVariables: ChatVariable[];
|
||||||
/** Function to update chat variables */
|
/** Function to update chat variables */
|
||||||
setChatVariables: React.Dispatch<React.SetStateAction<ChatVariable[]>>;
|
setChatVariables: Dispatch<SetStateAction<ChatVariable[]>>;
|
||||||
|
|
||||||
handleAddNotes: () => void;
|
handleAddNotes: () => void;
|
||||||
handleSaveFeaturesConfig: (value: FeaturesConfigForm) => void;
|
handleSaveFeaturesConfig: (value: FeaturesConfigForm) => void;
|
||||||
@@ -85,10 +90,6 @@ export interface UseWorkflowGraphReturn {
|
|||||||
/** Get start node output variable list (user-defined + system variables) */
|
/** Get start node output variable list (user-defined + system variables) */
|
||||||
getStartNodeVariables: () => Array<{ name: string; type: string; readonly?: boolean }>;
|
getStartNodeVariables: () => Array<{ name: string; type: string; readonly?: boolean }>;
|
||||||
nodeClick: ({ node }: { node: Node }) => void;
|
nodeClick: ({ node }: { node: Node }) => void;
|
||||||
/** All recorded history operations */
|
|
||||||
historyRecords: HistoryRecord[];
|
|
||||||
/** Clear history records */
|
|
||||||
clearHistoryRecords: () => void;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -122,19 +123,14 @@ export const useWorkflowGraph = ({
|
|||||||
const featuresRef = useRef<FeaturesConfigForm | undefined>(undefined)
|
const featuresRef = useRef<FeaturesConfigForm | undefined>(undefined)
|
||||||
const [canUndo, setCanUndo] = useState(false)
|
const [canUndo, setCanUndo] = useState(false)
|
||||||
const [canRedo, setCanRedo] = useState(false)
|
const [canRedo, setCanRedo] = useState(false)
|
||||||
const [historyRecords, setHistoryRecords] = useState<HistoryRecord[]>([])
|
|
||||||
const lastHistoryRef = useRef<{ cellIds: string[]; timestamp: number; type: string } | null>(null)
|
|
||||||
const undoRef = useRef<() => void>(() => {})
|
|
||||||
const redoRef = useRef<() => void>(() => {})
|
|
||||||
const syncChildRelationshipsRef = useRef<() => void>(() => {})
|
|
||||||
const isSyncingRef = useRef(false)
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!graphRef.current) return
|
if (!graphRef.current) return
|
||||||
graphRef.current.getNodes().forEach(node => {
|
graphRef.current.getNodes().forEach(node => {
|
||||||
const data = node.getData()
|
const data = node.getData()
|
||||||
if (data?.type === 'if-else' || data?.type === 'question-classifier') {
|
if (data?.type === 'if-else' || data?.type === 'question-classifier') {
|
||||||
console.log('chatVariables', chatVariables)
|
console.log('chatVariables', chatVariables)
|
||||||
node.setData({ ...data, chatVariables })
|
node.setData({ ...data, chatVariables }, { silent: true })
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}, [chatVariables])
|
}, [chatVariables])
|
||||||
@@ -168,6 +164,21 @@ export const useWorkflowGraph = ({
|
|||||||
initWorkflow()
|
initWorkflow()
|
||||||
}, [config, graphRef.current])
|
}, [config, graphRef.current])
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Assign explicit zIndex values to enforce layer order:
|
||||||
|
* parent nodes (loop/iteration) → child edges → child nodes
|
||||||
|
* Ports live inside each node's SVG container and are always above
|
||||||
|
* edges once the node zIndex is higher than the edge zIndex.
|
||||||
|
*/
|
||||||
|
const reorderCells = (graph: Graph) => {
|
||||||
|
// Safari uses x6-html-shape (dual HTML layer architecture).
|
||||||
|
// zIndex controls order within each HTML layer and SVG layer.
|
||||||
|
graph.getEdges().forEach(edge => edge.setZIndex(0));
|
||||||
|
graph.getNodes().forEach(node => {
|
||||||
|
node.setZIndex(node.getData()?.cycle ? 2 : 1);
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Initialize workflow graph with nodes and edges from configuration
|
* Initialize workflow graph with nodes and edges from configuration
|
||||||
*/
|
*/
|
||||||
@@ -351,7 +362,7 @@ export const useWorkflowGraph = ({
|
|||||||
if (parentNode) {
|
if (parentNode) {
|
||||||
const addedChild = graphRef.current?.addNode(childNode)
|
const addedChild = graphRef.current?.addNode(childNode)
|
||||||
if (addedChild) {
|
if (addedChild) {
|
||||||
parentNode.addChild(addedChild, { silent: true })
|
parentNode.addChild(addedChild)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -382,6 +393,8 @@ export const useWorkflowGraph = ({
|
|||||||
const newWidth = Math.max(parentBBox.width, maxX - minX + padding * 2)
|
const newWidth = Math.max(parentBBox.width, maxX - minX + padding * 2)
|
||||||
const newHeight = Math.max(parentBBox.height, maxY - minY + padding * 2 + headerHeight)
|
const newHeight = Math.max(parentBBox.height, maxY - minY + padding * 2 + headerHeight)
|
||||||
|
|
||||||
|
console.log('newWidth', newHeight, newWidth)
|
||||||
|
|
||||||
parentNode.prop('size', { width: newWidth, height: newHeight })
|
parentNode.prop('size', { width: newWidth, height: newHeight })
|
||||||
|
|
||||||
// Update x position of right group ports
|
// Update x position of right group ports
|
||||||
@@ -476,95 +489,30 @@ export const useWorkflowGraph = ({
|
|||||||
if (nodes.length > 0 || edges.length > 0) {
|
if (nodes.length > 0 || edges.length > 0) {
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
if (graphRef.current) {
|
if (graphRef.current) {
|
||||||
graphRef.current.getNodes().forEach(node => {
|
if (isSafari) {
|
||||||
if (!node.getData()?.cycle) node.toFront();
|
reorderCells(graphRef.current)
|
||||||
});
|
} else {
|
||||||
// Bring edges to front first, then child nodes above edges; parent nodes stay behind
|
graphRef.current.getNodes().forEach(node => {
|
||||||
graphRef.current.getEdges().forEach(edge => {
|
if (!node.getData()?.cycle) node.toFront();
|
||||||
const sourceCell = graphRef.current?.getCellById(edge.getSourceCellId());
|
});
|
||||||
const targetCell = graphRef.current?.getCellById(edge.getTargetCellId());
|
// Bring edges to front first, then child nodes above edges; parent nodes stay behind
|
||||||
if (sourceCell?.getData()?.cycle || targetCell?.getData()?.cycle) {
|
graphRef.current.getEdges().forEach(edge => {
|
||||||
edge.toFront();
|
const sourceCell = graphRef.current?.getCellById(edge.getSourceCellId());
|
||||||
}
|
const targetCell = graphRef.current?.getCellById(edge.getTargetCellId());
|
||||||
});
|
if (sourceCell?.getData()?.cycle || targetCell?.getData()?.cycle) {
|
||||||
graphRef.current.getNodes().forEach(node => {
|
edge.toFront();
|
||||||
if (node.getData()?.cycle) node.toFront();
|
}
|
||||||
});
|
});
|
||||||
|
graphRef.current.getNodes().forEach(node => {
|
||||||
|
if (node.getData()?.cycle) node.toFront();
|
||||||
|
});
|
||||||
|
}
|
||||||
graphRef.current.enableHistory()
|
graphRef.current.enableHistory()
|
||||||
graphRef.current.cleanHistory()
|
graphRef.current.cleanHistory()
|
||||||
}
|
}
|
||||||
}, 200)
|
}, isSafari ? 0 : 200)
|
||||||
} else {
|
|
||||||
graphRef.current.enableHistory()
|
|
||||||
graphRef.current.cleanHistory()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const resizeGroupNodes = (graph: Graph) => {
|
|
||||||
graph.getNodes().forEach(parentNode => {
|
|
||||||
const parentType = parentNode.getData()?.type
|
|
||||||
if (parentType !== 'loop' && parentType !== 'iteration') return
|
|
||||||
const children = graph.getNodes().filter(
|
|
||||||
n => n.getData()?.cycle === parentNode.getData()?.id && n.getData()?.type !== 'add-node'
|
|
||||||
)
|
|
||||||
if (!children.length) return
|
|
||||||
const padding = 24
|
|
||||||
const headerHeight = 50
|
|
||||||
const childBounds = children.map(c => c.getBBox())
|
|
||||||
const minX = Math.min(...childBounds.map(b => b.x))
|
|
||||||
const minY = Math.min(...childBounds.map(b => b.y))
|
|
||||||
const maxX = Math.max(...childBounds.map(b => b.x + b.width))
|
|
||||||
const maxY = Math.max(...childBounds.map(b => b.y + b.height))
|
|
||||||
const parentBBox = parentNode.getBBox()
|
|
||||||
const newWidth = Math.max(parentBBox.width, maxX - minX + padding * 2)
|
|
||||||
const newHeight = Math.max(parentBBox.height, maxY - minY + padding * 2 + headerHeight)
|
|
||||||
parentNode.prop('size', { width: newWidth, height: newHeight })
|
|
||||||
parentNode.getPorts().forEach(port => {
|
|
||||||
if (port.group === 'right' && port.args) {
|
|
||||||
parentNode.portProp(port.id!, 'args/x', newWidth)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
const syncChildRelationships = () => {
|
|
||||||
if (!graphRef.current) return
|
|
||||||
const graph = graphRef.current
|
|
||||||
graph.disableHistory()
|
|
||||||
graph.getNodes().forEach(node => {
|
|
||||||
const cycleId = node.getData()?.cycle
|
|
||||||
if (!cycleId) return
|
|
||||||
const parentNode = graph.getCellById(cycleId) as Node | null
|
|
||||||
if (!parentNode) return
|
|
||||||
if (!parentNode.getChildren()?.some(c => c.id === node.id)) {
|
|
||||||
parentNode.addChild(node, { silent: true })
|
|
||||||
}
|
|
||||||
})
|
|
||||||
graph.getNodes().forEach(node => {
|
|
||||||
const children = node.getChildren()
|
|
||||||
if (!children?.length) return
|
|
||||||
children.forEach(child => {
|
|
||||||
if (!child.isNode()) return
|
|
||||||
const childCycleId = (child as Node).getData?.()?.cycle
|
|
||||||
if (childCycleId !== node.id && childCycleId !== node.getData?.()?.id) {
|
|
||||||
node.removeChild(child, { silent: true })
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
resizeGroupNodes(graph)
|
|
||||||
graph.getEdges().forEach(edge => {
|
|
||||||
const src = graph.getCellById(edge.getSourceCellId())
|
|
||||||
const tgt = graph.getCellById(edge.getTargetCellId())
|
|
||||||
if (src?.getData()?.cycle || tgt?.getData()?.cycle) {
|
|
||||||
edge.toFront()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
graph.getNodes().forEach(node => {
|
|
||||||
if (node.getData()?.cycle) node.toFront()
|
|
||||||
})
|
|
||||||
graph.enableHistory()
|
|
||||||
}
|
|
||||||
syncChildRelationshipsRef.current = syncChildRelationships
|
|
||||||
/**
|
/**
|
||||||
* Setup X6 graph plugins (MiniMap, Snapline, Clipboard, Keyboard)
|
* Setup X6 graph plugins (MiniMap, Snapline, Clipboard, Keyboard)
|
||||||
*/
|
*/
|
||||||
@@ -600,44 +548,18 @@ export const useWorkflowGraph = ({
|
|||||||
new History({
|
new History({
|
||||||
enabled: false,
|
enabled: false,
|
||||||
beforeAddCommand(_event, args: any) {
|
beforeAddCommand(_event, args: any) {
|
||||||
const key = args?.key
|
const event = args?.key ? `cell:change:${args.key}` : _event;
|
||||||
if (key === 'attrs' || key === 'tools') return false
|
if (event.startsWith('cell:change:') &&
|
||||||
|
event !== 'cell:change:position' &&
|
||||||
|
event !== 'cell:change:source' &&
|
||||||
|
event !== 'cell:change:target') return false;
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
const MERGE_INTERVAL = 1000
|
graphRef.current.on('history:change', ({ cmds }: { cmds: Command[] }) => {
|
||||||
graphRef.current.on('history:change', ({ cmds, options }: { cmds: any[]; options: any }) => {
|
|
||||||
setCanUndo(graphRef.current?.canUndo() ?? false)
|
setCanUndo(graphRef.current?.canUndo() ?? false)
|
||||||
setCanRedo(graphRef.current?.canRedo() ?? false)
|
setCanRedo(graphRef.current?.canRedo() ?? false)
|
||||||
console.log('history:change', cmds, options)
|
|
||||||
const batchName: string | undefined = options?.name
|
|
||||||
const actionType = batchName === 'undo' ? 'undo' : batchName === 'redo' ? 'redo' : batchName ? 'batch' : 'change'
|
|
||||||
const cellIds = [...new Set(cmds?.map((cmd: any) => cmd.data?.id).filter(Boolean))]
|
|
||||||
const now = Date.now()
|
|
||||||
const last = lastHistoryRef.current
|
|
||||||
const canMerge =
|
|
||||||
actionType === 'change' &&
|
|
||||||
last?.type === 'change' &&
|
|
||||||
now - last.timestamp < MERGE_INTERVAL &&
|
|
||||||
cellIds.length > 0 &&
|
|
||||||
cellIds.length === last.cellIds.length &&
|
|
||||||
cellIds.every((id, i) => id === last.cellIds[i])
|
|
||||||
if (canMerge) {
|
|
||||||
lastHistoryRef.current!.timestamp = now
|
|
||||||
setHistoryRecords(prev => {
|
|
||||||
const next = [...prev]
|
|
||||||
next[next.length - 1] = { ...next[next.length - 1], timestamp: now }
|
|
||||||
return next
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
const record: HistoryRecord = { type: actionType, timestamp: now, batchName, cellIds }
|
|
||||||
lastHistoryRef.current = { cellIds, timestamp: now, type: actionType }
|
|
||||||
setHistoryRecords(prev => [...prev, record])
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
graphRef.current.on('history:undo', () => { if (!isSyncingRef.current) syncChildRelationshipsRef.current() })
|
|
||||||
graphRef.current.on('history:redo', () => { if (!isSyncingRef.current) syncChildRelationshipsRef.current() })
|
|
||||||
};
|
};
|
||||||
// 显示/隐藏连接桩
|
// 显示/隐藏连接桩
|
||||||
// const showPorts = (show: boolean) => {
|
// const showPorts = (show: boolean) => {
|
||||||
@@ -652,12 +574,33 @@ export const useWorkflowGraph = ({
|
|||||||
* @param node - Clicked node
|
* @param node - Clicked node
|
||||||
*/
|
*/
|
||||||
const nodeClick = ({ node }: { node: Node }) => {
|
const nodeClick = ({ node }: { node: Node }) => {
|
||||||
|
// add-node type: dispatch port:click to open node selection popover
|
||||||
|
// Must handle before blankClick() to avoid blank:click closing the popover immediately
|
||||||
|
const nodeData = node.getData()
|
||||||
|
if (nodeData?.type === 'add-node') {
|
||||||
|
const bbox = node.getBBox();
|
||||||
|
const screenPos = graphRef.current!.localToClient(bbox.x + bbox.width, bbox.y + bbox.height / 2);
|
||||||
|
const tempDiv = document.createElement('div');
|
||||||
|
tempDiv.style.cssText = `position:fixed;left:${screenPos.x}px;top:${screenPos.y}px;width:1px;height:1px;z-index:9999;`;
|
||||||
|
document.body.appendChild(tempDiv);
|
||||||
|
window.dispatchEvent(new CustomEvent('port:click', {
|
||||||
|
detail: {
|
||||||
|
node,
|
||||||
|
port: 'right',
|
||||||
|
element: tempDiv,
|
||||||
|
rect: { left: screenPos.x, top: screenPos.y },
|
||||||
|
edgeInsertion: null,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
blankClick()
|
blankClick()
|
||||||
|
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
// Ignore add-node type node clicks
|
// Ignore add-node type node clicks
|
||||||
const nodeData = node.getData()
|
const nodeData = node.getData()
|
||||||
if (nodeData?.type === 'add-node' || nodeData.type === 'break' || nodeData.type === 'cycle-start') {
|
if (nodeData.type === 'break' || nodeData.type === 'cycle-start') {
|
||||||
setSelectedNode(null)
|
setSelectedNode(null)
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -670,13 +613,13 @@ export const useWorkflowGraph = ({
|
|||||||
vo.setData({
|
vo.setData({
|
||||||
...data,
|
...data,
|
||||||
isSelected: false,
|
isSelected: false,
|
||||||
}, { silent: true });
|
});
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
node.setData({
|
node.setData({
|
||||||
...nodeData,
|
...nodeData,
|
||||||
isSelected: true,
|
isSelected: true,
|
||||||
}, { silent: true });
|
});
|
||||||
clearEdgeSelect()
|
clearEdgeSelect()
|
||||||
if (nodeData.type !== 'notes') {
|
if (nodeData.type !== 'notes') {
|
||||||
setSelectedNode(node);
|
setSelectedNode(node);
|
||||||
@@ -690,7 +633,7 @@ export const useWorkflowGraph = ({
|
|||||||
const edgeClick = ({ edge }: { edge: Edge }) => {
|
const edgeClick = ({ edge }: { edge: Edge }) => {
|
||||||
clearEdgeSelect();
|
clearEdgeSelect();
|
||||||
edge.setAttrByPath('line/stroke', edge_selected_color);
|
edge.setAttrByPath('line/stroke', edge_selected_color);
|
||||||
edge.setData({ ...edge.getData(), isSelected: true }, { silent: true });
|
edge.setData({ ...edge.getData(), isSelected: true });
|
||||||
clearNodeSelect();
|
clearNodeSelect();
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
@@ -705,7 +648,7 @@ export const useWorkflowGraph = ({
|
|||||||
node.setData({
|
node.setData({
|
||||||
...data,
|
...data,
|
||||||
isSelected: false,
|
isSelected: false,
|
||||||
}, { silent: true });
|
});
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
setSelectedNode(null);
|
setSelectedNode(null);
|
||||||
@@ -715,7 +658,7 @@ export const useWorkflowGraph = ({
|
|||||||
*/
|
*/
|
||||||
const clearEdgeSelect = () => {
|
const clearEdgeSelect = () => {
|
||||||
graphRef.current?.getEdges().forEach(e => {
|
graphRef.current?.getEdges().forEach(e => {
|
||||||
e.setData({ ...e.getData(), isSelected: false, isNodeHover: false }, { silent: true });
|
e.setData({ ...e.getData(), isSelected: false, isNodeHover: false });
|
||||||
e.setAttrByPath('line/stroke', edge_color);
|
e.setAttrByPath('line/stroke', edge_color);
|
||||||
e.setAttrByPath('line/strokeWidth', edge_width);
|
e.setAttrByPath('line/strokeWidth', edge_width);
|
||||||
});
|
});
|
||||||
@@ -745,7 +688,8 @@ export const useWorkflowGraph = ({
|
|||||||
const cycle = node.getData()?.cycle;
|
const cycle = node.getData()?.cycle;
|
||||||
if (cycle) {
|
if (cycle) {
|
||||||
const parentNode = graphRef.current!.getNodes().find(n => n.id === cycle);
|
const parentNode = graphRef.current!.getNodes().find(n => n.id === cycle);
|
||||||
if (parentNode?.getData()?.isGroup) {
|
const parentType = parentNode?.getData()?.type;
|
||||||
|
if (parentNode && (parentType === 'loop' || parentType === 'iteration')) {
|
||||||
// Get parent node and child node bounding boxes
|
// Get parent node and child node bounding boxes
|
||||||
const parentBBox = parentNode.getBBox();
|
const parentBBox = parentNode.getBBox();
|
||||||
const childBBox = node.getBBox();
|
const childBBox = node.getBBox();
|
||||||
@@ -854,6 +798,8 @@ export const useWorkflowGraph = ({
|
|||||||
// Find corresponding parent node
|
// Find corresponding parent node
|
||||||
const parentNode = nodes?.find(n => n.id === nodeData.cycle);
|
const parentNode = nodes?.find(n => n.id === nodeData.cycle);
|
||||||
if (parentNode) {
|
if (parentNode) {
|
||||||
|
// Use removeChild method to delete child node
|
||||||
|
parentNode.removeChild(nodeToDelete);
|
||||||
parentNodesToUpdate.push(parentNode);
|
parentNodesToUpdate.push(parentNode);
|
||||||
}
|
}
|
||||||
// Add child node to deletion list
|
// Add child node to deletion list
|
||||||
@@ -881,51 +827,42 @@ export const useWorkflowGraph = ({
|
|||||||
|
|
||||||
// Delete all collected nodes and edges
|
// Delete all collected nodes and edges
|
||||||
if (cells.length > 0) {
|
if (cells.length > 0) {
|
||||||
// Pre-calculate which parents need an add-node restored (before removal changes the graph)
|
|
||||||
const parentsNeedingAddNode = parentNodesToUpdate
|
|
||||||
.filter(parentNode => {
|
|
||||||
const parentShape = parentNode.shape;
|
|
||||||
if (parentShape !== 'loop-node' && parentShape !== 'iteration-node') return false;
|
|
||||||
const parentData = parentNode.getData();
|
|
||||||
const allChildren = graphRef.current!.getNodes().filter(n => n.getData()?.cycle === parentData.id);
|
|
||||||
const cycleStartNodes = allChildren.filter(n => n.getData()?.type === 'cycle-start');
|
|
||||||
// After deletion, only cycle-start will remain
|
|
||||||
const nonCycleStartToDelete = cells.filter(c =>
|
|
||||||
c.isNode() &&
|
|
||||||
(c as Node).getData()?.cycle === parentData.id &&
|
|
||||||
(c as Node).getData()?.type !== 'cycle-start'
|
|
||||||
);
|
|
||||||
return cycleStartNodes.length === 1 && (allChildren.length - nonCycleStartToDelete.length) === 1;
|
|
||||||
})
|
|
||||||
.map(parentNode => ({
|
|
||||||
parentNode,
|
|
||||||
cycleStartNode: graphRef.current!.getNodes().find(
|
|
||||||
n => n.getData()?.cycle === parentNode.getData().id && n.getData()?.type === 'cycle-start'
|
|
||||||
)!
|
|
||||||
}))
|
|
||||||
.filter(({ cycleStartNode }) => !!cycleStartNode);
|
|
||||||
|
|
||||||
graphRef.current?.startBatch('delete');
|
|
||||||
graphRef.current?.removeCells(cells);
|
graphRef.current?.removeCells(cells);
|
||||||
|
|
||||||
parentsNeedingAddNode.forEach(({ parentNode, cycleStartNode }) => {
|
// If parent is iteration/loop and only cycle-start remains, add add-node connected to it
|
||||||
|
parentNodesToUpdate.forEach(parentNode => {
|
||||||
|
const parentShape = parentNode.shape;
|
||||||
|
if (parentShape !== 'loop-node' && parentShape !== 'iteration-node') return;
|
||||||
const parentData = parentNode.getData();
|
const parentData = parentNode.getData();
|
||||||
const bbox = cycleStartNode.getBBox();
|
const remainingChildren = graphRef.current!.getNodes().filter(
|
||||||
const addNode = graphRef.current!.addNode({
|
n => n.getData()?.cycle === parentData.id
|
||||||
...graphNodeLibrary.addStart,
|
);
|
||||||
x: bbox.x + 84,
|
const cycleStartNodes = remainingChildren.filter(n => n.getData()?.type === 'cycle-start');
|
||||||
y: bbox.y + 4,
|
if (cycleStartNodes.length === 1 && remainingChildren.length === 1) {
|
||||||
data: { type: 'add-node', parentId: parentNode.id, cycle: parentData.id, label: t('workflow.addNode'), icon: '+' },
|
const cycleStartNode = cycleStartNodes[0];
|
||||||
});
|
const bbox = cycleStartNode.getBBox();
|
||||||
parentNode.addChild(addNode, { silent: true });
|
const addNode = graphRef.current!.addNode({
|
||||||
graphRef.current!.addEdge({
|
...graphNodeLibrary.addStart,
|
||||||
source: { cell: cycleStartNode.id, port: cycleStartNode.getPorts().find(p => p.group === 'right')?.id || 'right' },
|
x: bbox.x + 84,
|
||||||
target: { cell: addNode.id, port: addNode.getPorts().find(p => p.group === 'left')?.id || 'left' },
|
y: bbox.y + 4,
|
||||||
...edgeAttrs,
|
data: {
|
||||||
});
|
type: 'add-node',
|
||||||
|
parentId: parentNode.id,
|
||||||
|
cycle: parentData.id,
|
||||||
|
label: t('workflow.addNode'),
|
||||||
|
icon: '+',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
parentNode.addChild(addNode);
|
||||||
|
const sourcePort = cycleStartNode.getPorts().find(p => p.group === 'right')?.id || 'right';
|
||||||
|
const targetPort = addNode.getPorts().find(p => p.group === 'left')?.id || 'left';
|
||||||
|
graphRef.current!.addEdge({
|
||||||
|
source: { cell: cycleStartNode.id, port: sourcePort },
|
||||||
|
target: { cell: addNode.id, port: targetPort },
|
||||||
|
...edgeAttrs,
|
||||||
|
});
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
graphRef.current?.stopBatch('delete');
|
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
@@ -965,13 +902,35 @@ export const useWorkflowGraph = ({
|
|||||||
/**
|
/**
|
||||||
* Initialize X6 graph with configuration and event listeners
|
* Initialize X6 graph with configuration and event listeners
|
||||||
*/
|
*/
|
||||||
const init = () => {
|
const init = async () => {
|
||||||
if (!containerRef.current || !miniMapRef.current) return;
|
if (!containerRef.current || !miniMapRef.current) return;
|
||||||
|
|
||||||
// Register React shapes
|
// Register React shapes
|
||||||
nodeRegisterLibrary.forEach((item) => {
|
// Safari: use x6-html-shape to avoid foreignObject rendering issues
|
||||||
register(item);
|
if (isSafari) {
|
||||||
});
|
const { register: registerHtmlShape } = await import('x6-html-shape');
|
||||||
|
nodeRegisterLibrary.forEach(({ shape, width, height, component }) => {
|
||||||
|
registerHtmlShape({
|
||||||
|
shape,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
render(node: Node, _graph: unknown, container: HTMLElement) {
|
||||||
|
const root = createRoot(container);
|
||||||
|
const doRender = () => {
|
||||||
|
root.render(createElement(component as any, { node, graph: node.model?.graph, data: node.getData() }));
|
||||||
|
};
|
||||||
|
doRender();
|
||||||
|
node.on('change:data', doRender);
|
||||||
|
return () => {
|
||||||
|
node.off('change:data', doRender);
|
||||||
|
root.unmount();
|
||||||
|
};
|
||||||
|
},
|
||||||
|
});
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
nodeRegisterLibrary.forEach((item) => registerReactShape(item));
|
||||||
|
}
|
||||||
|
|
||||||
const container = containerRef.current;
|
const container = containerRef.current;
|
||||||
graphRef.current = new Graph({
|
graphRef.current = new Graph({
|
||||||
@@ -1144,7 +1103,7 @@ export const useWorkflowGraph = ({
|
|||||||
graphRef.current?.getConnectedEdges(node).forEach(edge => {
|
graphRef.current?.getConnectedEdges(node).forEach(edge => {
|
||||||
if (!edge.getData()?.isSelected) {
|
if (!edge.getData()?.isSelected) {
|
||||||
edge.setAttrByPath('line/stroke', edge_selected_color);
|
edge.setAttrByPath('line/stroke', edge_selected_color);
|
||||||
edge.setData({ ...edge.getData(), isNodeHover: true }, { silent: true });
|
edge.setData({ ...edge.getData(), isNodeHover: true });
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -1152,7 +1111,7 @@ export const useWorkflowGraph = ({
|
|||||||
graphRef.current?.getConnectedEdges(node).forEach(edge => {
|
graphRef.current?.getConnectedEdges(node).forEach(edge => {
|
||||||
if (!edge.getData()?.isSelected) {
|
if (!edge.getData()?.isSelected) {
|
||||||
edge.setAttrByPath('line/stroke', edge_color);
|
edge.setAttrByPath('line/stroke', edge_color);
|
||||||
edge.setData({ ...edge.getData(), isNodeHover: false }, { silent: true });
|
edge.setData({ ...edge.getData(), isNodeHover: false });
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -1161,10 +1120,71 @@ export const useWorkflowGraph = ({
|
|||||||
// Listen to node move event
|
// Listen to node move event
|
||||||
graphRef.current.on('node:moved', nodeMoved);
|
graphRef.current.on('node:moved', nodeMoved);
|
||||||
|
|
||||||
|
if (isSafari) {
|
||||||
|
// When a parent (loop/iteration) node moves, keep child nodes in sync.
|
||||||
|
// Store each child's offset relative to the parent at drag start, then
|
||||||
|
// reapply it every frame to avoid cumulative delta errors.
|
||||||
|
const dragOffsets = new Map<string, { dx: number; dy: number }>();
|
||||||
|
|
||||||
|
graphRef.current.on('node:moving', ({ node }: { node: Node }) => {
|
||||||
|
const data = node.getData();
|
||||||
|
if (data?.type !== 'loop' && data?.type !== 'iteration') return;
|
||||||
|
const pos = node.getPosition();
|
||||||
|
const PORT_RADIUS = 6;
|
||||||
|
|
||||||
|
// Update parent componentContainer directly
|
||||||
|
const parentView = graphRef.current?.findViewByCell(node) as any;
|
||||||
|
if (parentView?.componentContainer) {
|
||||||
|
parentView.componentContainer.style.transform =
|
||||||
|
`translate(${pos.x + PORT_RADIUS}px, ${pos.y}px)`;
|
||||||
|
}
|
||||||
|
|
||||||
|
const children = graphRef.current?.getNodes().filter(child => {
|
||||||
|
const cycle = child.getData()?.cycle;
|
||||||
|
return cycle === data.id || cycle === node.id;
|
||||||
|
}) ?? [];
|
||||||
|
|
||||||
|
// First event for this drag: record offsets
|
||||||
|
if (!dragOffsets.has(node.id)) {
|
||||||
|
children.forEach(child => {
|
||||||
|
const cp = child.getPosition();
|
||||||
|
dragOffsets.set(child.id, { dx: cp.x - pos.x, dy: cp.y - pos.y });
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply stored offsets to keep children in place relative to parent
|
||||||
|
children.forEach(child => {
|
||||||
|
const off = dragOffsets.get(child.id);
|
||||||
|
if (!off) return;
|
||||||
|
const nx = pos.x + off.dx;
|
||||||
|
const ny = pos.y + off.dy;
|
||||||
|
child.setPosition(nx, ny);
|
||||||
|
const childView = graphRef.current?.findViewByCell(child) as any;
|
||||||
|
if (childView?.componentContainer) {
|
||||||
|
childView.componentContainer.style.transform =
|
||||||
|
`translate(${nx + PORT_RADIUS}px, ${ny}px)`;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
graphRef.current.on('node:moved', ({ node }: { node: Node }) => {
|
||||||
|
// Clear offsets for this parent and all its children
|
||||||
|
const data = node.getData();
|
||||||
|
graphRef.current?.getNodes().forEach(child => {
|
||||||
|
const cycle = child.getData()?.cycle;
|
||||||
|
if (cycle === data?.id || cycle === node.id) dragOffsets.delete(child.id);
|
||||||
|
});
|
||||||
|
dragOffsets.delete(node.id);
|
||||||
|
nodeMoved({ node });
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
graphRef.current.on('node:removed', blankClick)
|
graphRef.current.on('node:removed', blankClick)
|
||||||
// When edge connected, bring connected nodes' ports to front
|
// When edge connected, reorder all cells to maintain correct layer order
|
||||||
graphRef.current.on('edge:connected', ({ isNew, edge }) => {
|
graphRef.current.on('edge:connected', ({ isNew, edge }) => {
|
||||||
if (isNew) {
|
if (isSafari && isNew && graphRef.current) {
|
||||||
|
reorderCells(graphRef.current);
|
||||||
|
} else if (!isSafari && isNew) {
|
||||||
const sourceCellId = edge.getSourceCellId()
|
const sourceCellId = edge.getSourceCellId()
|
||||||
const targetCellId = edge.getTargetCellId()
|
const targetCellId = edge.getTargetCellId()
|
||||||
const sourceCell = graphRef.current?.getCellById(sourceCellId);
|
const sourceCell = graphRef.current?.getCellById(sourceCellId);
|
||||||
@@ -1234,8 +1254,8 @@ export const useWorkflowGraph = ({
|
|||||||
// Delete selected nodes and edges
|
// Delete selected nodes and edges
|
||||||
graphRef.current.bindKey(['ctrl+d', 'cmd+d', 'delete', 'backspace'], deleteEvent);
|
graphRef.current.bindKey(['ctrl+d', 'cmd+d', 'delete', 'backspace'], deleteEvent);
|
||||||
// Undo / Redo
|
// Undo / Redo
|
||||||
graphRef.current.bindKey(['ctrl+z', 'cmd+z'], () => { undo(); return false; });
|
graphRef.current.bindKey(['ctrl+z', 'cmd+z'], () => { graphRef.current?.undo(); return false; });
|
||||||
graphRef.current.bindKey(['ctrl+y', 'cmd+y', 'ctrl+shift+z', 'cmd+shift+z'], () => { redo(); return false; });
|
graphRef.current.bindKey(['ctrl+y', 'cmd+y', 'ctrl+shift+z', 'cmd+shift+z'], () => { graphRef.current?.redo(); return false; });
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -1278,7 +1298,7 @@ export const useWorkflowGraph = ({
|
|||||||
* Creates new node at drop position
|
* Creates new node at drop position
|
||||||
* @param event - React drag event
|
* @param event - React drag event
|
||||||
*/
|
*/
|
||||||
const onDrop = (event: React.DragEvent) => {
|
const onDrop = (event: DragEvent) => {
|
||||||
if (!graphRef.current) return;
|
if (!graphRef.current) return;
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
const dragData = JSON.parse(event.dataTransfer.getData('application/json'));
|
const dragData = JSON.parse(event.dataTransfer.getData('application/json'));
|
||||||
@@ -1301,51 +1321,13 @@ export const useWorkflowGraph = ({
|
|||||||
};
|
};
|
||||||
|
|
||||||
if (dragData.type === 'loop' || dragData.type === 'iteration') {
|
if (dragData.type === 'loop' || dragData.type === 'iteration') {
|
||||||
graph.disableHistory()
|
graphRef.current.addNode({
|
||||||
const parentNode = graphRef.current.addNode({
|
|
||||||
...graphNodeLibrary[dragData.type],
|
...graphNodeLibrary[dragData.type],
|
||||||
x: point.x - 150,
|
x: point.x - 150,
|
||||||
y: point.y - 100,
|
y: point.y - 100,
|
||||||
id: cleanNodeData.id,
|
id: cleanNodeData.id,
|
||||||
data: { ...cleanNodeData, isGroup: true },
|
data: { ...cleanNodeData, isGroup: true },
|
||||||
})
|
});
|
||||||
const parentBBox = parentNode.getBBox()
|
|
||||||
const cycleStartId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
|
|
||||||
const cycleStartNode = graphRef.current.addNode({
|
|
||||||
...graphNodeLibrary.cycleStart,
|
|
||||||
x: parentBBox.x + 24,
|
|
||||||
y: parentBBox.y + 70,
|
|
||||||
id: cycleStartId,
|
|
||||||
data: { id: cycleStartId, type: 'cycle-start', parentId: cleanNodeData.id, isDefault: true, cycle: cleanNodeData.id },
|
|
||||||
})
|
|
||||||
const addNode = graphRef.current.addNode({
|
|
||||||
...graphNodeLibrary.addStart,
|
|
||||||
x: parentBBox.x + 24 + 84,
|
|
||||||
y: parentBBox.y + 70 + 4,
|
|
||||||
data: { type: 'add-node', label: t('workflow.addNode'), icon: '+', parentId: cleanNodeData.id, cycle: cleanNodeData.id },
|
|
||||||
})
|
|
||||||
parentNode.addChild(cycleStartNode, { silent: true })
|
|
||||||
parentNode.addChild(addNode, { silent: true })
|
|
||||||
const newEdge = graphRef.current.addEdge({
|
|
||||||
source: { cell: cycleStartNode.id, port: cycleStartNode.getPorts().find(p => p.group === 'right')?.id || 'right' },
|
|
||||||
target: { cell: addNode.id, port: addNode.getPorts().find(p => p.group === 'left')?.id || 'left' },
|
|
||||||
...edgeAttrs,
|
|
||||||
})
|
|
||||||
cycleStartNode.toFront()
|
|
||||||
addNode.toFront()
|
|
||||||
graph.enableHistory()
|
|
||||||
// Manually push a single batch frame covering all 4 cells into undoStack
|
|
||||||
const history = graph.getPlugin('history') as History
|
|
||||||
const makeBatchCmd = (cell: any) => ({
|
|
||||||
batch: true,
|
|
||||||
event: 'cell:added',
|
|
||||||
data: { id: cell.id, node: cell.isNode(), edge: cell.isEdge(), props: cell.toJSON() },
|
|
||||||
options: {},
|
|
||||||
})
|
|
||||||
const batchFrame = [parentNode, cycleStartNode, addNode, newEdge].map(makeBatchCmd)
|
|
||||||
;(history as any).undoStack.push(batchFrame)
|
|
||||||
;(history as any).redoStack = []
|
|
||||||
graph.trigger('history:change', { cmds: batchFrame, options: { name: 'add-group' } })
|
|
||||||
} else if (dragData.type === 'if-else') {
|
} else if (dragData.type === 'if-else') {
|
||||||
// Create condition node
|
// Create condition node
|
||||||
graphRef.current.addNode({
|
graphRef.current.addNode({
|
||||||
@@ -1592,80 +1574,8 @@ export const useWorkflowGraph = ({
|
|||||||
return userVars
|
return userVars
|
||||||
}
|
}
|
||||||
|
|
||||||
const clearHistoryRecords = () => {
|
const undo = () => graphRef.current?.undo()
|
||||||
setHistoryRecords([])
|
const redo = () => graphRef.current?.redo()
|
||||||
lastHistoryRef.current = null
|
|
||||||
}
|
|
||||||
|
|
||||||
const getStackCellIds = (cmds: any): string[] => {
|
|
||||||
const arr = Array.isArray(cmds) ? cmds : [cmds]
|
|
||||||
return [...new Set(arr.map((c: any) => c.data?.id).filter(Boolean))]
|
|
||||||
}
|
|
||||||
|
|
||||||
const isSkippableFrame = (frame: any): boolean => {
|
|
||||||
const arr = Array.isArray(frame) ? frame : [frame]
|
|
||||||
return arr.every((c: any) => ['zIndex', 'attrs', 'tools'].includes(c.data?.key))
|
|
||||||
}
|
|
||||||
|
|
||||||
const undo = () => {
|
|
||||||
const history = graphRef.current?.getPlugin('history') as History | undefined
|
|
||||||
if (!history || history.getUndoSize() === 0) return
|
|
||||||
const undoStack = (history as any).undoStack as any[]
|
|
||||||
isSyncingRef.current = true
|
|
||||||
while (undoStack.length > 0 && isSkippableFrame(undoStack[undoStack.length - 1])) {
|
|
||||||
graphRef.current!.undo()
|
|
||||||
}
|
|
||||||
if (undoStack.length === 0) {
|
|
||||||
isSyncingRef.current = false
|
|
||||||
return
|
|
||||||
}
|
|
||||||
const topIds = getStackCellIds(undoStack[undoStack.length - 1])
|
|
||||||
graphRef.current!.undo()
|
|
||||||
while (undoStack.length > 0) {
|
|
||||||
if (isSkippableFrame(undoStack[undoStack.length - 1])) {
|
|
||||||
graphRef.current!.undo()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
const nextIds = getStackCellIds(undoStack[undoStack.length - 1])
|
|
||||||
if (nextIds.length === topIds.length && nextIds.every((id, i) => id === topIds[i])) {
|
|
||||||
graphRef.current!.undo()
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
isSyncingRef.current = false
|
|
||||||
syncChildRelationships()
|
|
||||||
}
|
|
||||||
|
|
||||||
const redo = () => {
|
|
||||||
const history = graphRef.current?.getPlugin('history') as History | undefined
|
|
||||||
if (!history || history.getRedoSize() === 0) return
|
|
||||||
const redoStack = (history as any).redoStack as any[]
|
|
||||||
isSyncingRef.current = true
|
|
||||||
while (redoStack.length > 0 && isSkippableFrame(redoStack[redoStack.length - 1])) {
|
|
||||||
graphRef.current!.redo()
|
|
||||||
}
|
|
||||||
if (redoStack.length === 0) {
|
|
||||||
isSyncingRef.current = false
|
|
||||||
return
|
|
||||||
}
|
|
||||||
const topIds = getStackCellIds(redoStack[redoStack.length - 1])
|
|
||||||
graphRef.current!.redo()
|
|
||||||
while (redoStack.length > 0) {
|
|
||||||
if (isSkippableFrame(redoStack[redoStack.length - 1])) {
|
|
||||||
graphRef.current!.redo()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
const nextIds = getStackCellIds(redoStack[redoStack.length - 1])
|
|
||||||
if (nextIds.length === topIds.length && nextIds.every((id, i) => id === topIds[i])) {
|
|
||||||
graphRef.current!.redo()
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
isSyncingRef.current = false
|
|
||||||
syncChildRelationships()
|
|
||||||
}
|
|
||||||
|
|
||||||
const handleSaveFeaturesConfig = (value?: FeaturesConfigForm) => {
|
const handleSaveFeaturesConfig = (value?: FeaturesConfigForm) => {
|
||||||
const { statement = '' } = value?.opening_statement || {}
|
const { statement = '' } = value?.opening_statement || {}
|
||||||
@@ -1706,16 +1616,20 @@ export const useWorkflowGraph = ({
|
|||||||
if (!graphRef.current) return;
|
if (!graphRef.current) return;
|
||||||
const nodes = graphRef.current.getNodes();
|
const nodes = graphRef.current.getNodes();
|
||||||
|
|
||||||
// Reset all node execution status on every chatHistory change
|
const lastWithSub = [...chatHistory].reverse().find(item => item.subContent?.length);
|
||||||
|
// Reset all node execution status first
|
||||||
nodes.forEach(node => {
|
nodes.forEach(node => {
|
||||||
const data = node.getData();
|
const data = node.getData();
|
||||||
node.setData({ ...data, executionStatus: '' });
|
if (typeof data.status === 'string') {
|
||||||
|
node.setData({ ...data, executionStatus: undefined });
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
if (!lastWithSub?.subContent) return;
|
||||||
const lastAssistant = [...chatHistory].reverse().find(item => item.role === 'assistant');
|
// Build a nodeId -> status map first
|
||||||
if (!lastAssistant?.subContent?.length) return;
|
const statusMap: Record<string, string> = {};
|
||||||
lastAssistant.subContent.forEach(sub => {
|
lastWithSub.subContent.forEach(sub => {
|
||||||
if (typeof sub.status === 'string') {
|
if (typeof sub.status === 'string') {
|
||||||
|
statusMap[sub.node_id] = sub.status;
|
||||||
const node = nodes.find(n => n.getData()?.id === sub.node_id);
|
const node = nodes.find(n => n.getData()?.id === sub.node_id);
|
||||||
if (node) {
|
if (node) {
|
||||||
node.setData({ ...node.getData(), executionStatus: sub.status });
|
node.setData({ ...node.getData(), executionStatus: sub.status });
|
||||||
@@ -1751,7 +1665,5 @@ export const useWorkflowGraph = ({
|
|||||||
canRedo,
|
canRedo,
|
||||||
undo,
|
undo,
|
||||||
redo,
|
redo,
|
||||||
historyRecords,
|
|
||||||
clearHistoryRecords,
|
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -113,13 +113,4 @@ export interface ChatVariable {
|
|||||||
}
|
}
|
||||||
export interface AddChatVariableRef {
|
export interface AddChatVariableRef {
|
||||||
handleOpen: (value?: ChatVariable) => void;
|
handleOpen: (value?: ChatVariable) => void;
|
||||||
}
|
|
||||||
|
|
||||||
export type HistoryActionType = 'add' | 'remove' | 'change' | 'undo' | 'redo' | 'batch'
|
|
||||||
|
|
||||||
export interface HistoryRecord {
|
|
||||||
type: HistoryActionType;
|
|
||||||
timestamp: number;
|
|
||||||
batchName?: string;
|
|
||||||
cellIds?: string[];
|
|
||||||
}
|
}
|
||||||
@@ -17,7 +17,6 @@ export const isSubExprSet = (sub: any) => {
|
|||||||
* Uses the same per-expression height logic as getConditionNodeCasePortY.
|
* Uses the same per-expression height logic as getConditionNodeCasePortY.
|
||||||
*/
|
*/
|
||||||
export const calcConditionNodeTotalHeight = (cases: any[]) => {
|
export const calcConditionNodeTotalHeight = (cases: any[]) => {
|
||||||
if (!cases?.length) return conditionNodeHeight;
|
|
||||||
const casesHeight = cases.reduce((acc: number, c: any) => {
|
const casesHeight = cases.reduce((acc: number, c: any) => {
|
||||||
const exprs = c?.expressions ?? [];
|
const exprs = c?.expressions ?? [];
|
||||||
const n = exprs.length;
|
const n = exprs.length;
|
||||||
|
|||||||
@@ -44,6 +44,9 @@ export default defineConfig({
|
|||||||
resolve: {
|
resolve: {
|
||||||
alias: {
|
alias: {
|
||||||
'@': resolve(__dirname, 'src'),
|
'@': resolve(__dirname, 'src'),
|
||||||
|
'x6-html-shape': resolve(__dirname, 'src/vendor/x6-html-shape/index.js'),
|
||||||
|
'x6-html-shape/dist/react': resolve(__dirname, 'src/vendor/x6-html-shape/react.js'),
|
||||||
|
'x6-html-shape/dist/utils.js': resolve(__dirname, 'src/vendor/x6-html-shape/utils.js'),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
base: './', // 使用相对路径,确保资源能正确加载
|
base: './', // 使用相对路径,确保资源能正确加载
|
||||||
|
|||||||
Reference in New Issue
Block a user