Compare commits
71 Commits
feature/sa
...
feature/ra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8d1ed51a7 | ||
|
|
9fa83ed01e | ||
|
|
e222490bce | ||
|
|
ad2e885f72 | ||
|
|
70c6d161c8 | ||
|
|
f85c0594c9 | ||
|
|
5fceba54b4 | ||
|
|
6e89302cb2 | ||
|
|
90aa4cef21 | ||
|
|
6c47bb77ab | ||
|
|
f667936664 | ||
|
|
64e640d882 | ||
|
|
140311048a | ||
|
|
4bef9b578b | ||
|
|
c53fcf3981 | ||
|
|
2997558bc8 | ||
|
|
30cdf229de | ||
|
|
ce4a3daec7 | ||
|
|
c12d06bb07 | ||
|
|
98d8d7b261 | ||
|
|
12a08a487d | ||
|
|
f7fa33c0c4 | ||
|
|
faf8d1a51a | ||
|
|
adb7f873b5 | ||
|
|
b64bcc2c50 | ||
|
|
d9de96cffa | ||
|
|
546bfb9627 | ||
|
|
9301eaf8df | ||
|
|
a268d0f7f1 | ||
|
|
6aef8227b1 | ||
|
|
675c7faf32 | ||
|
|
cd34d5f5ce | ||
|
|
1403b38648 | ||
|
|
b6e27da7b0 | ||
|
|
2c14344d3f | ||
|
|
141fd94513 | ||
|
|
a9413f57d1 | ||
|
|
0fc463036e | ||
|
|
ed5f98a746 | ||
|
|
422af69904 | ||
|
|
6cb48664b7 | ||
|
|
f48bb3cbee | ||
|
|
8dee2eae6a | ||
|
|
f63bcd6321 | ||
|
|
0228e6ad64 | ||
|
|
84ccb1e528 | ||
|
|
caef0fe44e | ||
|
|
21eb500680 | ||
|
|
c70f536acc | ||
|
|
5f96a6380e | ||
|
|
2c864f6337 | ||
|
|
32dfee803a | ||
|
|
4d9cfb70f7 | ||
|
|
4b0afe867a | ||
|
|
676c9a226c | ||
|
|
8f31236303 | ||
|
|
f2aedd29bc | ||
|
|
cf8db47389 | ||
|
|
cedf47b3bc | ||
|
|
be10bab763 | ||
|
|
b33f5951d8 | ||
|
|
2d120a64b1 | ||
|
|
0f7a7263eb | ||
|
|
5c89acced6 | ||
|
|
4619b40d03 | ||
|
|
7ac0eff0b8 | ||
|
|
aac89b172f | ||
|
|
bf9a3503de | ||
|
|
5c836c90c9 | ||
|
|
f93ec8d609 | ||
|
|
c5ae82c3c2 |
@@ -17,6 +17,7 @@ def _mask_url(url: str) -> str:
|
|||||||
"""隐藏 URL 中的密码部分,适用于 redis:// 和 amqp:// 等协议"""
|
"""隐藏 URL 中的密码部分,适用于 redis:// 和 amqp:// 等协议"""
|
||||||
return re.sub(r'(://[^:]*:)[^@]+(@)', r'\1***\2', url)
|
return re.sub(r'(://[^:]*:)[^@]+(@)', r'\1***\2', url)
|
||||||
|
|
||||||
|
|
||||||
# macOS fork() safety - must be set before any Celery initialization
|
# macOS fork() safety - must be set before any Celery initialization
|
||||||
if platform.system() == 'Darwin':
|
if platform.system() == 'Darwin':
|
||||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||||
@@ -29,7 +30,7 @@ if platform.system() == 'Darwin':
|
|||||||
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
|
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
|
||||||
|
|
||||||
_broker_url = os.getenv("CELERY_BROKER_URL") or \
|
_broker_url = os.getenv("CELERY_BROKER_URL") or \
|
||||||
f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
|
f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
|
||||||
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
|
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
|
||||||
os.environ["CELERY_BROKER_URL"] = _broker_url
|
os.environ["CELERY_BROKER_URL"] = _broker_url
|
||||||
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
|
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
|
||||||
@@ -66,11 +67,11 @@ celery_app.conf.update(
|
|||||||
task_serializer='json',
|
task_serializer='json',
|
||||||
accept_content=['json'],
|
accept_content=['json'],
|
||||||
result_serializer='json',
|
result_serializer='json',
|
||||||
|
|
||||||
# # 时区
|
# # 时区
|
||||||
# timezone='Asia/Shanghai',
|
# timezone='Asia/Shanghai',
|
||||||
# enable_utc=False,
|
# enable_utc=False,
|
||||||
|
|
||||||
# 任务追踪
|
# 任务追踪
|
||||||
task_track_started=True,
|
task_track_started=True,
|
||||||
task_ignore_result=False,
|
task_ignore_result=False,
|
||||||
|
|||||||
500
api/app/celery_task_scheduler.py
Normal file
500
api/app/celery_task_scheduler.py
Normal file
@@ -0,0 +1,500 @@
|
|||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import socket
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import redis
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.logging_config import get_named_logger
|
||||||
|
from app.celery_app import celery_app
|
||||||
|
|
||||||
|
logger = get_named_logger("task_scheduler")
|
||||||
|
|
||||||
|
# --- Redis key layout --------------------------------------------------------
# Per-user FIFO queue of pending messages: scheduler:uq:{user_id}
USER_QUEUE_PREFIX = "scheduler:uq:"
# Set of user ids that currently have at least one queued message.
ACTIVE_USERS = "scheduler:active_users"
# Set of user ids whose queue head may be dispatched now (the "ready" signal).
READY_SET = "scheduler:ready_users"
# Hash task_id -> JSON metadata for tasks dispatched to Celery but not finished.
PENDING_HASH = "scheduler:pending_tasks"
# Dynamic sharding: hash instance_id -> last heartbeat timestamp.
REGISTRY_KEY = "scheduler:instances"

# --- Timing (seconds) --------------------------------------------------------
TASK_TIMEOUT = 7800  # Task timeout (seconds), considered lost if exceeded
HEARTBEAT_INTERVAL = 10  # Heartbeat interval (seconds)
INSTANCE_TTL = 30  # Instance timeout (seconds)

# Atomically claims a message for dispatch and takes the per-(task, user) lock.
#   KEYS[1] = dispatch lock for the message, KEYS[2] = per-(task, user) lock
#   ARGV[1] = instance id, ARGV[2] = dispatch-lock TTL, ARGV[3] = user-lock TTL
# Returns 0 when the dispatch lock is already held, -1 when the per-(task, user)
# lock exists (the dispatch lock is released again), and 1 when both locks were
# taken; in that case the user lock holds the placeholder value 'dispatching'.
LUA_ATOMIC_LOCK = """
local dispatch_lock = KEYS[1]
local lock_key = KEYS[2]
local instance_id = ARGV[1]
local dispatch_ttl = tonumber(ARGV[2])
local lock_ttl = tonumber(ARGV[3])

if redis.call('SET', dispatch_lock, instance_id, 'NX', 'EX', dispatch_ttl) == false then
    return 0
end

if redis.call('EXISTS', lock_key) == 1 then
    redis.call('DEL', dispatch_lock)
    return -1
end

redis.call('SET', lock_key, 'dispatching', 'EX', lock_ttl)
return 1
"""

# Compare-and-delete: removes KEYS[1] only while it still holds ARGV[1], so a
# lock that has since been taken over with another value is left untouched.
# Returns the DEL result (1) on a match and 0 otherwise.
LUA_SAFE_DELETE = """
if redis.call('GET', KEYS[1]) == ARGV[1] then
    return redis.call('DEL', KEYS[1])
end
return 0
"""
|
||||||
|
|
||||||
|
|
||||||
|
def stable_hash(value: str) -> int:
    """Return a process-independent integer hash of *value*.

    The built-in ``hash()`` is salted per process for strings, so it cannot be
    used where several processes must agree on the result. This returns the
    MD5 digest of the UTF-8 encoded text interpreted as a big-endian unsigned
    integer (0 <= result < 2**128).
    """
    # MD5 is used purely as a fast, well-distributed hash, not for security.
    # Declaring that keeps the call working on FIPS-restricted Python builds.
    digest = hashlib.md5(value.encode("utf-8"), usedforsecurity=False).digest()
    return int.from_bytes(digest, "big")
|
||||||
|
|
||||||
|
|
||||||
|
def health_check_server(scheduler_ref):
    """Serve ``scheduler_ref.health()`` over HTTP from a background thread.

    A small FastAPI application with a single ``GET /`` route is handed to
    ``uvicorn.run`` in a daemon thread, so this call returns immediately.
    The listening port is read from the ``SCHEDULER_HEALTH_PORT`` environment
    variable and defaults to 8001.

    Args:
        scheduler_ref: Object exposing a ``health()`` method whose return
            value becomes the response body.
    """
    # Imported lazily so merely importing this module does not need the web stack.
    import uvicorn
    from fastapi import FastAPI

    health_app = FastAPI()

    @health_app.get("/")
    def health():
        return scheduler_ref.health()

    port = int(os.environ.get("SCHEDULER_HEALTH_PORT", "8001"))

    server_kwargs = {
        "app": health_app,
        "host": "0.0.0.0",
        "port": port,
        "log_config": None,
    }
    server_thread = threading.Thread(
        target=uvicorn.run,
        kwargs=server_kwargs,
        daemon=True,
    )
    server_thread.start()
    logger.info("[Health] Server started at http://0.0.0.0:%s", port)
|
||||||
|
|
||||||
|
|
||||||
|
class RedisTaskScheduler:
    """Redis-backed scheduler that feeds per-user task queues into Celery.

    Producers call :meth:`push_task`, which appends a message to the user's
    queue and marks the user ready. A scheduler process runs
    :meth:`run_server`, which repeatedly takes ready users, tries to take the
    per-``(task_name, user_id)`` lock and, on success, sends the queue head to
    Celery. Finished or timed-out tasks are reaped by
    :meth:`_cleanup_finished`, which releases the lock and marks the user
    ready again. Several scheduler instances split users between them by
    hashing the user id over the list of instances with a recent heartbeat.
    """

    # Lifetime (seconds) of the short per-message lock taken while dispatching.
    DISPATCH_LOCK_TTL = 300
    # Lifetime (seconds) of the per-(task, user) lock.
    # NOTE(review): this is shorter than TASK_TIMEOUT (7800), so the lock can
    # expire while a long task is still running - confirm this is intended.
    USER_LOCK_TTL = 3600
    # Lifetime (seconds) of the task_tracker:{msg_id} status records.
    TRACKER_TTL = 86400
    # Celery result states after which a task is considered finished.
    _TERMINAL_STATES = ("SUCCESS", "FAILURE", "REVOKED")
    # Fields every queued message must carry to be dispatchable.
    _REQUIRED_FIELDS = frozenset({"msg_id", "task_name", "user_id"})

    def __init__(self):
        self.redis = redis.Redis(
            host=settings.REDIS_HOST,
            port=settings.REDIS_PORT,
            db=settings.REDIS_DB_CELERY_BACKEND,
            password=settings.REDIS_PASSWORD,
            decode_responses=True,
        )
        self.running = False
        self.dispatched = 0
        self.errors = 0

        self.instance_id = f"{socket.gethostname()}-{os.getpid()}"
        self._shard_index = 0
        self._shard_count = 1
        self._last_heartbeat = 0.0

    def push_task(self, task_name, user_id, params):
        """Queue a task for *user_id* and return the generated message id.

        The message is appended to the user's queue, the user is added to the
        active set and a ``QUEUED`` tracker record is written. The user is
        marked ready only when no lock for ``task_name:user_id`` exists.

        Raises:
            Exception: any error is logged and re-raised unchanged.
        """
        try:
            msg_id = str(uuid.uuid4())
            msg = json.dumps({
                "msg_id": msg_id,
                "task_name": task_name,
                "user_id": user_id,
                # Stored as a JSON string; _dispatch decodes it again.
                "params": json.dumps(params),
            })

            lock_key = f"{task_name}:{user_id}"
            queue_key = f"{USER_QUEUE_PREFIX}{user_id}"

            pipe = self.redis.pipeline()
            pipe.rpush(queue_key, msg)
            pipe.sadd(ACTIVE_USERS, user_id)
            pipe.set(
                f"task_tracker:{msg_id}",
                json.dumps({"status": "QUEUED", "task_id": None}),
                ex=self.TRACKER_TTL,
            )
            pipe.execute()

            if not self.redis.exists(lock_key):
                self.redis.sadd(READY_SET, user_id)

            logger.info("Task pushed: msg_id=%s task=%s user=%s", msg_id, task_name, user_id)
            return msg_id
        except Exception as e:
            logger.error("Push task exception %s", e, exc_info=True)
            raise

    def get_task_status(self, msg_id: str) -> dict:
        """Return the status of a pushed message.

        Returns ``{"status": "NOT_FOUND"}`` when no tracker record exists.
        Otherwise returns ``status``, ``task_id`` and ``result``; while the
        tracker says ``DISPATCHED`` the status and result are taken from the
        ``celery-task-meta-{task_id}`` record when one is present.
        """
        raw = self.redis.get(f"task_tracker:{msg_id}")
        if raw is None:
            return {"status": "NOT_FOUND"}

        tracker = json.loads(raw)
        status = tracker["status"]
        task_id = tracker.get("task_id")
        result_content = tracker.get("result") or {}

        if status == "DISPATCHED" and task_id:
            result_raw = self.redis.get(f"celery-task-meta-{task_id}")
            if result_raw:
                result_data = json.loads(result_raw)
                status = result_data.get("status", status)
                result_content = result_data.get("result")

        return {"status": status, "task_id": task_id, "result": result_content}

    def _cleanup_finished(self):
        """Reap pending tasks that finished or exceeded ``TASK_TIMEOUT``.

        For each reaped task the per-(task, user) lock is released (only if it
        still holds this task id), the pending entry is removed, the tracker
        record is finalised and the user is marked ready again. A failure on
        one entry is logged and counted without aborting the others.
        """
        pending = self.redis.hgetall(PENDING_HASH)
        if not pending:
            return

        now = time.time()
        task_ids = list(pending)

        pipe = self.redis.pipeline()
        for task_id in task_ids:
            pipe.get(f"celery-task-meta-{task_id}")
        results = pipe.execute()

        cleanup_pipe = self.redis.pipeline()
        has_cleanup = False
        ready_user_ids = set()

        for task_id, raw_result in zip(task_ids, results):
            try:
                meta = json.loads(pending[task_id])
                lock_key = meta["lock_key"]
                age = now - meta.get("dispatched_at", 0)

                result_data = json.loads(raw_result) if raw_result is not None else {}
                state = result_data.get("status")

                if state in self._TERMINAL_STATES:
                    final_status = state
                    logger.info("Task finished: %s state=%s", task_id, state)
                elif age > TASK_TIMEOUT:
                    # Applies whether or not a (non-terminal) result record
                    # exists, so a task stuck in a running state is reaped too.
                    final_status = "EXPIRED"
                    logger.warning(
                        "Task expired or lost: %s age=%.0fs, force cleanup",
                        task_id, age,
                    )
                else:
                    continue

                # Release the lock only if it still belongs to this task.
                cleanup_pipe.eval(LUA_SAFE_DELETE, 1, lock_key, task_id)
                cleanup_pipe.hdel(PENDING_HASH, task_id)
                has_cleanup = True

                tracker_msg_id = meta.get("msg_id")
                if tracker_msg_id:
                    cleanup_pipe.set(
                        f"task_tracker:{tracker_msg_id}",
                        json.dumps({
                            "status": final_status,
                            "task_id": task_id,
                            "result": result_data.get("result") or {},
                        }),
                        ex=self.TRACKER_TTL,
                    )

                user_id = meta.get("user_id")
                if user_id is None:
                    # Entries written before user_id was recorded: recover it
                    # from the "task_name:user_id" lock key.
                    _, sep, tail = lock_key.partition(":")
                    user_id = tail if sep else None
                if user_id is not None:
                    ready_user_ids.add(user_id)

            except Exception as e:
                logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True)
                self.errors += 1

        if has_cleanup:
            cleanup_pipe.execute()

        if ready_user_ids:
            self.redis.sadd(READY_SET, *ready_user_ids)

    def _heartbeat(self):
        """Refresh this instance's registry entry and recompute its shard.

        Runs at most once per ``HEARTBEAT_INTERVAL``. Instances whose last
        heartbeat is older than ``INSTANCE_TTL`` (or unreadable) are removed
        from the registry. The shard index is this instance's position in the
        sorted list of live instance ids.
        """
        now = time.time()
        if now - self._last_heartbeat < HEARTBEAT_INTERVAL:
            return
        self._last_heartbeat = now

        self.redis.hset(REGISTRY_KEY, self.instance_id, str(now))

        all_instances = self.redis.hgetall(REGISTRY_KEY)

        alive = []
        dead = []
        for iid, ts in all_instances.items():
            try:
                is_alive = now - float(ts) < INSTANCE_TTL
            except (TypeError, ValueError):
                # A corrupt timestamp must not take the scheduler down.
                is_alive = False
            (alive if is_alive else dead).append(iid)

        if dead:
            self.redis.hdel(REGISTRY_KEY, *dead)
            logger.info("Cleaned dead instances: %s", dead)

        alive.sort()
        self._shard_count = max(len(alive), 1)
        self._shard_index = (
            alive.index(self.instance_id) if self.instance_id in alive else 0
        )
        logger.debug(
            "Shard: %s/%s (instance=%s, alive=%d)",
            self._shard_index, self._shard_count,
            self.instance_id, len(alive),
        )

    def _is_mine(self, user_id: str) -> bool:
        """Return True when *user_id* falls into this instance's shard."""
        if self._shard_count <= 1:
            return True
        return stable_hash(user_id) % self._shard_count == self._shard_index

    def _dispatch(self, msg_id, msg_data) -> bool:
        """Try to send one queued message to Celery.

        Returns:
            True when the task was sent (the caller then pops the message),
            False when a lock could not be taken or ``send_task`` failed.
        """
        user_id = msg_data["user_id"]
        task_name = msg_data["task_name"]
        params = json.loads(msg_data.get("params", "{}"))

        lock_key = f"{task_name}:{user_id}"
        dispatch_lock = f"dispatch:{msg_id}"

        result = self.redis.eval(
            LUA_ATOMIC_LOCK, 2,
            dispatch_lock, lock_key,
            self.instance_id,
            str(self.DISPATCH_LOCK_TTL), str(self.USER_LOCK_TTL),
        )

        # 0: the message is already being dispatched; -1: the lock is held.
        if result in (0, -1):
            return False

        try:
            task = celery_app.send_task(task_name, kwargs=params)
        except Exception as e:
            # Nothing was sent: release both locks so the message can be retried.
            pipe = self.redis.pipeline()
            pipe.delete(dispatch_lock)
            pipe.delete(lock_key)
            pipe.execute()
            self.errors += 1
            logger.error(
                "send_task failed for %s:%s msg=%s: %s",
                task_name, user_id, msg_id, e, exc_info=True,
            )
            return False

        try:
            pipe = self.redis.pipeline()
            # Replace the 'dispatching' placeholder with the real task id so
            # that _cleanup_finished can release the lock safely.
            pipe.set(lock_key, task.id, ex=self.USER_LOCK_TTL)
            pipe.hset(PENDING_HASH, task.id, json.dumps({
                "lock_key": lock_key,
                "user_id": user_id,
                "dispatched_at": time.time(),
                "msg_id": msg_id,
            }))
            pipe.delete(dispatch_lock)
            pipe.set(
                f"task_tracker:{msg_id}",
                json.dumps({"status": "DISPATCHED", "task_id": task.id}),
                ex=self.TRACKER_TTL,
            )
            pipe.execute()
        except Exception as e:
            # The task is already with Celery, so the message still counts as
            # dispatched; the locks are left to expire through their TTLs.
            logger.error(
                "Post-dispatch state update failed for %s: %s",
                task.id, e, exc_info=True,
            )
            self.errors += 1

        self.dispatched += 1
        logger.info("Task dispatched: %s (msg=%s)", task.id, msg_id)
        return True

    def _process_batch(self, user_ids):
        """Dispatch the queue head of every user in *user_ids*.

        Users with an empty queue are removed from the active set. Heads that
        are not valid JSON objects carrying the required fields are logged and
        dropped. A head is popped only after it was dispatched successfully.
        """
        if not user_ids:
            return

        pipe = self.redis.pipeline()
        for uid in user_ids:
            pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0)
        heads = pipe.execute()

        candidates = []  # (user_id, msg_dict)
        empty_users = []

        for uid, head in zip(user_ids, heads):
            if head is None:
                empty_users.append(uid)
                continue
            try:
                msg = json.loads(head)
                if not isinstance(msg, dict) or not msg.keys() >= self._REQUIRED_FIELDS:
                    raise ValueError("message is missing required fields")
            except (ValueError, TypeError) as e:
                # json.JSONDecodeError is a ValueError, so this also covers
                # undecodable payloads.
                logger.error("Bad message in queue for user %s: %s", uid, e)
                self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}")
                continue
            candidates.append((uid, msg))

        if empty_users:
            self.redis.srem(ACTIVE_USERS, *empty_users)

        for uid, msg in candidates:
            if self._dispatch(msg["msg_id"], msg):
                self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}")

    def schedule_loop(self):
        """Run one scheduling iteration for this instance's shard."""
        self._heartbeat()
        self._cleanup_finished()

        ready_users = self.redis.smembers(READY_SET) or set()
        my_users = [uid for uid in ready_users if self._is_mine(uid)]

        if not my_users:
            time.sleep(0.5)
            return

        # Remove only the users this instance is about to handle. Clearing the
        # whole set would discard ready signals that belong to other shards.
        self.redis.srem(READY_SET, *my_users)

        self._process_batch(my_users)
        time.sleep(0.1)

    def _full_scan(self):
        """Re-derive ready users from the active set as a safety net.

        Walks the active users of this shard and marks as ready every user
        whose queue head has no lock. Users whose head is malformed are marked
        ready as well, so that the next batch drops the bad message.
        """
        cursor = 0
        ready_batch = []  # (user_id, lock_key)
        poisoned = []
        while True:
            cursor, user_ids = self.redis.sscan(
                ACTIVE_USERS, cursor=cursor, count=1000,
            )
            my_users = [uid for uid in user_ids if self._is_mine(uid)]
            if my_users:
                pipe = self.redis.pipeline()
                for uid in my_users:
                    pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0)
                heads = pipe.execute()

                for uid, head in zip(my_users, heads):
                    if head is None:
                        continue
                    try:
                        msg = json.loads(head)
                        lock_key = f"{msg['task_name']}:{uid}"
                    except (ValueError, TypeError, KeyError):
                        poisoned.append(uid)
                        continue
                    ready_batch.append((uid, lock_key))

            if cursor == 0:
                break

        ready_uids = list(poisoned)
        if ready_batch:
            pipe = self.redis.pipeline()
            for _, lock_key in ready_batch:
                pipe.exists(lock_key)
            lock_exists = pipe.execute()
            ready_uids.extend(
                uid for (uid, _), locked in zip(ready_batch, lock_exists)
                if not locked
            )

        if ready_uids:
            self.redis.sadd(READY_SET, *ready_uids)
            logger.info("Full scan found %d ready users", len(ready_uids))

    def run_server(self):
        """Start the health endpoint and run the loop until :meth:`shutdown`.

        Any exception raised by an iteration is logged and counted, then the
        loop resumes after a 5 second pause.
        """
        health_check_server(self)
        self.running = True

        last_full_scan = 0.0
        full_scan_interval = 30.0

        logger.info(
            "Scheduler started: instance=%s", self.instance_id,
        )

        while self.running:
            try:
                self.schedule_loop()

                now = time.time()
                if now - last_full_scan > full_scan_interval:
                    self._full_scan()
                    last_full_scan = now

            except Exception as e:
                logger.error("Scheduler exception %s", e, exc_info=True)
                self.errors += 1
                time.sleep(5)

    def health(self) -> dict:
        """Return a snapshot of scheduler counters and Redis set sizes."""
        return {
            "running": self.running,
            "active_users": self.redis.scard(ACTIVE_USERS),
            "ready_users": self.redis.scard(READY_SET),
            "pending_tasks": self.redis.hlen(PENDING_HASH),
            "dispatched": self.dispatched,
            "errors": self.errors,
            "shard": f"{self._shard_index}/{self._shard_count}",
            "instance": self.instance_id,
        }

    def shutdown(self):
        """Stop the loop and remove this instance from the registry."""
        logger.info("Scheduler shutting down: instance=%s", self.instance_id)
        self.running = False
        try:
            self.redis.hdel(REGISTRY_KEY, self.instance_id)
        except Exception as e:
            # Best effort: the entry ages out of the registry on its own.
            logger.error("Shutdown cleanup error: %s", e)
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level scheduler instance shared by importers of this module and by the
# command-line entry point below.
scheduler: RedisTaskScheduler = RedisTaskScheduler()

if __name__ == "__main__":
    import signal
    import sys

    def _signal_handler(signum, frame):
        """Deregister the instance, then exit the process with status 0."""
        scheduler.shutdown()
        sys.exit(0)

    signal.signal(signal.SIGTERM, _signal_handler)
    signal.signal(signal.SIGINT, _signal_handler)

    scheduler.run_server()
|
||||||
@@ -9,7 +9,7 @@ from app.core.logging_config import get_business_logger
|
|||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||||
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail
|
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail, AppLogMessage
|
||||||
from app.schemas.response_schema import PageData, PageMeta
|
from app.schemas.response_schema import PageData, PageMeta
|
||||||
from app.services.app_service import AppService
|
from app.services.app_service import AppService
|
||||||
from app.services.app_log_service import AppLogService
|
from app.services.app_log_service import AppLogService
|
||||||
@@ -41,7 +41,7 @@ def list_app_logs(
|
|||||||
|
|
||||||
# 验证应用访问权限
|
# 验证应用访问权限
|
||||||
app_service = AppService(db)
|
app_service = AppService(db)
|
||||||
app_service.get_app(app_id, workspace_id)
|
app = app_service.get_app(app_id, workspace_id)
|
||||||
|
|
||||||
# 使用 Service 层查询
|
# 使用 Service 层查询
|
||||||
log_service = AppLogService(db)
|
log_service = AppLogService(db)
|
||||||
@@ -51,7 +51,8 @@ def list_app_logs(
|
|||||||
page=page,
|
page=page,
|
||||||
pagesize=pagesize,
|
pagesize=pagesize,
|
||||||
is_draft=is_draft,
|
is_draft=is_draft,
|
||||||
keyword=keyword
|
keyword=keyword,
|
||||||
|
app_type=app.type,
|
||||||
)
|
)
|
||||||
|
|
||||||
items = [AppLogConversation.model_validate(c) for c in conversations]
|
items = [AppLogConversation.model_validate(c) for c in conversations]
|
||||||
@@ -78,17 +79,32 @@ def get_app_log_detail(
|
|||||||
|
|
||||||
# 验证应用访问权限
|
# 验证应用访问权限
|
||||||
app_service = AppService(db)
|
app_service = AppService(db)
|
||||||
app_service.get_app(app_id, workspace_id)
|
app = app_service.get_app(app_id, workspace_id)
|
||||||
|
|
||||||
# 使用 Service 层查询
|
# 使用 Service 层查询
|
||||||
log_service = AppLogService(db)
|
log_service = AppLogService(db)
|
||||||
conversation, node_executions_map = log_service.get_conversation_detail(
|
conversation, messages, node_executions_map = log_service.get_conversation_detail(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
workspace_id=workspace_id
|
workspace_id=workspace_id,
|
||||||
|
app_type=app.type
|
||||||
)
|
)
|
||||||
|
|
||||||
detail = AppLogConversationDetail.model_validate(conversation)
|
# 构建基础会话信息(不经过 ORM relationship)
|
||||||
detail.node_executions_map = node_executions_map
|
base = AppLogConversation.model_validate(conversation)
|
||||||
|
|
||||||
|
# 单独处理 messages,避免触发 SQLAlchemy relationship 校验
|
||||||
|
if messages and isinstance(messages[0], AppLogMessage):
|
||||||
|
# 工作流:已经是 AppLogMessage 实例
|
||||||
|
msg_list = messages
|
||||||
|
else:
|
||||||
|
# Agent:ORM Message 对象逐个转换
|
||||||
|
msg_list = [AppLogMessage.model_validate(m) for m in messages]
|
||||||
|
|
||||||
|
detail = AppLogConversationDetail(
|
||||||
|
**base.model_dump(),
|
||||||
|
messages=msg_list,
|
||||||
|
node_executions_map=node_executions_map,
|
||||||
|
)
|
||||||
|
|
||||||
return success(data=detail)
|
return success(data=detail)
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
|
import csv
|
||||||
|
import io
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -23,6 +25,7 @@ from app.models.user_model import User
|
|||||||
from app.schemas import chunk_schema
|
from app.schemas import chunk_schema
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
||||||
|
from app.services.file_storage_service import FileStorageService, get_file_storage_service, generate_kb_file_key
|
||||||
from app.services.model_service import ModelApiKeyService
|
from app.services.model_service import ModelApiKeyService
|
||||||
|
|
||||||
# Obtain a dedicated API logger
|
# Obtain a dedicated API logger
|
||||||
@@ -82,19 +85,32 @@ async def get_preview_chunks(
|
|||||||
detail="The file does not exist or you do not have permission to access it"
|
detail="The file does not exist or you do not have permission to access it"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5. Construct file path:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
|
# 5. Get file content from storage backend
|
||||||
file_path = os.path.join(
|
if not db_file.file_key:
|
||||||
settings.FILE_PATH,
|
|
||||||
str(db_file.kb_id),
|
|
||||||
str(db_file.parent_id),
|
|
||||||
f"{db_file.id}{db_file.file_ext}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 6. Check if the file exists
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail="File not found (possibly deleted)"
|
detail="File has no storage key (legacy data not migrated)"
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.services.file_storage_service import FileStorageService
|
||||||
|
import asyncio
|
||||||
|
storage_service = FileStorageService()
|
||||||
|
|
||||||
|
async def _download():
|
||||||
|
return await storage_service.download_file(db_file.file_key)
|
||||||
|
|
||||||
|
try:
|
||||||
|
file_binary = asyncio.run(_download())
|
||||||
|
except RuntimeError:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
try:
|
||||||
|
file_binary = loop.run_until_complete(_download())
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"File not found in storage: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 7. Document parsing & segmentation
|
# 7. Document parsing & segmentation
|
||||||
@@ -104,11 +120,12 @@ async def get_preview_chunks(
|
|||||||
vision_model = QWenCV(
|
vision_model = QWenCV(
|
||||||
key=db_knowledge.image2text.api_keys[0].api_key,
|
key=db_knowledge.image2text.api_keys[0].api_key,
|
||||||
model_name=db_knowledge.image2text.api_keys[0].model_name,
|
model_name=db_knowledge.image2text.api_keys[0].model_name,
|
||||||
lang="Chinese", # Default to Chinese
|
lang="Chinese",
|
||||||
base_url=db_knowledge.image2text.api_keys[0].api_base
|
base_url=db_knowledge.image2text.api_keys[0].api_base
|
||||||
)
|
)
|
||||||
from app.core.rag.app.naive import chunk
|
from app.core.rag.app.naive import chunk
|
||||||
res = chunk(filename=file_path,
|
res = chunk(filename=db_file.file_name,
|
||||||
|
binary=file_binary,
|
||||||
from_page=0,
|
from_page=0,
|
||||||
to_page=5,
|
to_page=5,
|
||||||
callback=progress_callback,
|
callback=progress_callback,
|
||||||
@@ -257,6 +274,9 @@ async def create_chunk(
|
|||||||
"sort_id": sort_id,
|
"sort_id": sort_id,
|
||||||
"status": 1,
|
"status": 1,
|
||||||
}
|
}
|
||||||
|
# QA chunk: 注入 chunk_type/question/answer 到 metadata
|
||||||
|
if create_data.is_qa:
|
||||||
|
metadata.update(create_data.qa_metadata)
|
||||||
chunk = DocumentChunk(page_content=content, metadata=metadata)
|
chunk = DocumentChunk(page_content=content, metadata=metadata)
|
||||||
# 3. Segmented vector storage
|
# 3. Segmented vector storage
|
||||||
vector_service.add_chunks([chunk])
|
vector_service.add_chunks([chunk])
|
||||||
@@ -268,6 +288,187 @@ async def create_chunk(
|
|||||||
return success(data=jsonable_encoder(chunk), msg="Document chunk creation successful")
|
return success(data=jsonable_encoder(chunk), msg="Document chunk creation successful")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{kb_id}/{document_id}/chunk/batch", response_model=ApiResponse)
async def create_chunks_batch(
    kb_id: uuid.UUID,
    document_id: uuid.UUID,
    batch_data: chunk_schema.ChunkBatchCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Batch-create chunks for a document.

    At most ``settings.MAX_CHUNK_BATCH_SIZE`` items are accepted per request.
    New chunks receive consecutive ``sort_id`` values continuing after the
    highest one currently stored for the document, and the document's
    ``chunk_num`` is increased by the number of chunks created.

    Raises:
        HTTPException 400: the batch exceeds the configured size limit.
        HTTPException 404: the knowledge base or the document was not found.
    """
    api_logger.info(f"Batch create chunks: kb_id={kb_id}, document_id={document_id}, count={len(batch_data.items)}, username: {current_user.username}")

    if len(batch_data.items) > settings.MAX_CHUNK_BATCH_SIZE:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Batch size exceeds limit: max {settings.MAX_CHUNK_BATCH_SIZE}, got {len(batch_data.items)}"
        )

    db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
    if not db_knowledge:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The knowledge base does not exist or access is denied")

    # NOTE(review): the document is looked up by id only and is not constrained
    # to kb_id - confirm whether it should also be filtered by knowledge base.
    db_document = db.query(Document).filter(Document.id == document_id).first()
    if not db_document:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The document does not exist or you do not have permission to access it")

    if not batch_data.items:
        # Nothing to create: skip the vector store and the database write.
        return success(data=[], msg="Batch created 0 chunks successfully")

    vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)

    # Continue numbering after the highest existing sort_id (0 when the
    # document has no chunks yet or the top chunk carries no sort_id).
    sort_id = 0
    total, items = vector_service.search_by_segment(document_id=str(document_id), pagesize=1, page=1, asc=False)
    if items:
        sort_id = items[0].metadata.get("sort_id") or 0

    file_created_at = int(db_document.created_at.timestamp() * 1000)
    chunks = []
    for create_data in batch_data.items:
        sort_id += 1
        metadata = {
            "doc_id": uuid.uuid4().hex,
            "file_id": str(db_document.file_id),
            "file_name": db_document.file_name,
            "file_created_at": file_created_at,
            "document_id": str(document_id),
            "knowledge_id": str(kb_id),
            "sort_id": sort_id,
            "status": 1,
        }
        # QA chunk: merge the chunk's QA metadata into the stored metadata.
        if create_data.is_qa:
            metadata.update(create_data.qa_metadata)
        chunks.append(DocumentChunk(page_content=create_data.chunk_content, metadata=metadata))

    vector_service.add_chunks(chunks)

    # chunk_num may be NULL on rows that never had chunks counted.
    db_document.chunk_num = (db_document.chunk_num or 0) + len(chunks)
    db.commit()

    return success(data=jsonable_encoder(chunks), msg=f"Batch created {len(chunks)} chunks successfully")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{kb_id}/import_qa", response_model=ApiResponse)
async def import_qa_new_doc(
    kb_id: uuid.UUID,
    file: UploadFile = File(..., description="CSV 或 Excel 文件(第一行标题跳过,第一列问题,第二列答案)"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
    storage_service: FileStorageService = Depends(get_file_storage_service),
):
    """
    Import QA pairs from a CSV/Excel upload into a newly created document.

    The handler validates the upload, creates a File record, stores the raw
    bytes in the storage backend, creates a Document record marked as QA
    (parser_id="qa"), and then dispatches the Celery task
    "app.core.rag.tasks.import_qa_chunks" on the "qa_import" queue. The
    handler itself does not parse the file.

    Args:
        kb_id: Knowledge base that will own the new file and document.
        file: Uploaded CSV/Excel file (per the field description: header row
            skipped, first column question, second column answer).
        db: Database session.
        current_user: Authenticated user; recorded as creator of both records.
        storage_service: Storage backend wrapper used to persist the upload.

    Returns:
        A success response whose data holds "task_id", "document_id" and
        "file_id".

    Raises:
        HTTPException: 400 for an unsupported extension, an empty file, or a
            file larger than settings.MAX_FILE_SIZE; 404 when the knowledge
            base lookup returns nothing; 500 when the storage upload fails.
    """
    # Function-scope imports, following the original block's style.
    from app.core.config import settings
    from app.schemas import file_schema, document_schema

    api_logger.info(f"Import QA (new doc): kb_id={kb_id}, file={file.filename}, username: {current_user.username}")

    # 1. Validate the file format (same case-sensitive suffix match as before).
    filename = file.filename or ""
    if not filename.endswith((".csv", ".xlsx", ".xls")):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="仅支持 CSV (.csv) 或 Excel (.xlsx) 格式")

    # 2. Validate the knowledge base.
    db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
    if not db_knowledge:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库不存在或无权访问")

    # 3. Read the file and check its size before anything is persisted.
    contents = await file.read()
    file_size = len(contents)
    if file_size == 0:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="文件为空")
    # Same limit as the regular upload endpoint; the bytes are also passed as
    # a Celery task argument below, so an unbounded payload must be rejected.
    if file_size > settings.MAX_FILE_SIZE:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"文件大小超过 {settings.MAX_FILE_SIZE} 字节限制",
        )

    _, file_extension = os.path.splitext(filename)
    file_ext = file_extension.lower()

    # 4. Create the File record (all-zero UUID is used as the parent id).
    file_data = file_schema.FileCreate(
        kb_id=kb_id,
        created_by=current_user.id,
        parent_id=uuid.UUID("00000000-0000-0000-0000-000000000000"),
        file_name=filename,
        file_ext=file_ext,
        file_size=file_size,
    )
    db_file = file_service.create_file(db=db, file=file_data, current_user=current_user)

    # 5. Upload the raw bytes to the storage backend.
    file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=file_ext)
    try:
        await storage_service.storage.upload(file_key=file_key, content=contents, content_type=file.content_type)
    except Exception as e:
        api_logger.error(f"Storage upload failed: {e}")
        # Chain the cause so the original traceback is preserved.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"文件存储失败: {str(e)}",
        ) from e

    # Persist the storage key on the File record only after a successful upload.
    db_file.file_key = file_key
    db.commit()
    db.refresh(db_file)

    # 6. Create the Document record, marked as QA type.
    doc_data = document_schema.DocumentCreate(
        kb_id=kb_id,
        created_by=current_user.id,
        file_id=db_file.id,
        file_name=filename,
        file_ext=file_ext,
        file_size=file_size,
        file_meta={},
        parser_id="qa",
        parser_config={"doc_type": "qa", "auto_questions": 0},
    )
    db_document = document_service.create_document(db=db, document=doc_data, current_user=current_user)

    api_logger.info(f"Created doc for QA import: file_id={db_file.id}, document_id={db_document.id}, file_key={file_key}")

    # 7. Dispatch the asynchronous import task.
    from app.celery_app import celery_app

    task = celery_app.send_task(
        "app.core.rag.tasks.import_qa_chunks",
        args=[str(kb_id), str(db_document.id), filename, contents],
        queue="qa_import",
    )

    return success(
        data={
            "task_id": task.id,
            "document_id": str(db_document.id),
            "file_id": str(db_file.id),
        },
        msg="QA 导入任务已提交,后台处理中",
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{kb_id}/{document_id}/import_qa", response_model=ApiResponse)
async def import_qa_chunks(
    kb_id: uuid.UUID,
    document_id: uuid.UUID,
    file: UploadFile = File(..., description="CSV 或 Excel 文件(第一行标题跳过,第一列问题,第二列答案)"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Import QA pairs from a CSV/Excel upload into an existing document.

    The handler validates the upload and the target knowledge base/document,
    then dispatches the Celery task "app.core.rag.tasks.import_qa_chunks" on
    the "qa_import" queue with the raw file bytes. The handler itself does
    not parse the file.

    Args:
        kb_id: Knowledge base the document must belong to.
        document_id: Existing document that receives the imported chunks.
        file: Uploaded CSV/Excel file (per the field description: header row
            skipped, first column question, second column answer).
        db: Database session.
        current_user: Authenticated user.

    Returns:
        A success response whose data holds the Celery "task_id".

    Raises:
        HTTPException: 400 for an unsupported extension, an empty file, or a
            file larger than settings.MAX_FILE_SIZE; 404 when the knowledge
            base lookup returns nothing or the document is not found in that
            knowledge base.
    """
    # Function-scope import, matching how this handler imports celery_app.
    from app.core.config import settings

    api_logger.info(f"Import QA chunks: kb_id={kb_id}, document_id={document_id}, file={file.filename}, username: {current_user.username}")

    # 1. Validate the file format (same case-sensitive suffix match as before).
    filename = file.filename or ""
    if not filename.endswith((".csv", ".xlsx", ".xls")):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="仅支持 CSV (.csv) 或 Excel (.xlsx) 格式")

    # 2. Validate the knowledge base and the document.
    db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
    if not db_knowledge:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库不存在或无权访问")

    # Scope the lookup to kb_id: access was checked for this knowledge base
    # only, so a document id from another knowledge base must not match.
    db_document = (
        db.query(Document)
        .filter(Document.id == document_id, Document.kb_id == kb_id)
        .first()
    )
    if not db_document:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="文档不存在或无权访问")

    # 3. Read the file content and reject payloads that should not be queued.
    contents = await file.read()
    file_size = len(contents)
    if file_size == 0:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="文件为空")
    # The bytes are passed as a Celery task argument, so apply the same upper
    # bound that the regular upload endpoint enforces.
    if file_size > settings.MAX_FILE_SIZE:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"文件大小超过 {settings.MAX_FILE_SIZE} 字节限制",
        )

    # 4. Dispatch the asynchronous import task.
    from app.celery_app import celery_app

    task = celery_app.send_task(
        "app.core.rag.tasks.import_qa_chunks",
        args=[str(kb_id), str(document_id), filename, contents],
        queue="qa_import",
    )

    return success(data={"task_id": task.id}, msg="QA 导入任务已提交,后台处理中")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
||||||
async def get_chunk(
|
async def get_chunk(
|
||||||
kb_id: uuid.UUID,
|
kb_id: uuid.UUID,
|
||||||
@@ -328,6 +529,9 @@ async def update_chunk(
|
|||||||
if total:
|
if total:
|
||||||
chunk = items[0]
|
chunk = items[0]
|
||||||
chunk.page_content = content
|
chunk.page_content = content
|
||||||
|
# QA chunk: 更新 metadata 中的 question/answer
|
||||||
|
if update_data.is_qa:
|
||||||
|
chunk.metadata.update(update_data.qa_metadata)
|
||||||
vector_service.update_by_segment(chunk)
|
vector_service.update_by_segment(chunk)
|
||||||
return success(data=jsonable_encoder(chunk), msg="The document chunk has been successfully updated")
|
return success(data=jsonable_encoder(chunk), msg="The document chunk has been successfully updated")
|
||||||
else:
|
else:
|
||||||
@@ -342,6 +546,7 @@ async def delete_chunk(
|
|||||||
kb_id: uuid.UUID,
|
kb_id: uuid.UUID,
|
||||||
document_id: uuid.UUID,
|
document_id: uuid.UUID,
|
||||||
doc_id: str,
|
doc_id: str,
|
||||||
|
force_refresh: bool = Query(False, description="Force Elasticsearch refresh after deletion"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
@@ -359,7 +564,7 @@ async def delete_chunk(
|
|||||||
|
|
||||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||||
if vector_service.text_exists(doc_id):
|
if vector_service.text_exists(doc_id):
|
||||||
vector_service.delete_by_ids([doc_id])
|
vector_service.delete_by_ids([doc_id], refresh=force_refresh)
|
||||||
# 更新 chunk_num
|
# 更新 chunk_num
|
||||||
db_document = db.query(Document).filter(Document.id == document_id).first()
|
db_document = db.query(Document).filter(Document.id == document_id).first()
|
||||||
db_document.chunk_num -= 1
|
db_document.chunk_num -= 1
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from app.models.user_model import User
|
|||||||
from app.schemas import document_schema
|
from app.schemas import document_schema
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import document_service, file_service, knowledge_service
|
from app.services import document_service, file_service, knowledge_service
|
||||||
|
from app.services.file_storage_service import FileStorageService, get_file_storage_service
|
||||||
|
|
||||||
|
|
||||||
# Obtain a dedicated API logger
|
# Obtain a dedicated API logger
|
||||||
@@ -231,7 +232,8 @@ async def update_document(
|
|||||||
async def delete_document(
|
async def delete_document(
|
||||||
document_id: uuid.UUID,
|
document_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
|
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Delete document
|
Delete document
|
||||||
@@ -257,7 +259,7 @@ async def delete_document(
|
|||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
# 3. Delete file
|
# 3. Delete file
|
||||||
await file_controller._delete_file(db=db, file_id=file_id, current_user=current_user)
|
await file_controller._delete_file(db=db, file_id=file_id, current_user=current_user, storage_service=storage_service)
|
||||||
|
|
||||||
# 4. Delete vector index
|
# 4. Delete vector index
|
||||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
|
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
|
||||||
@@ -305,38 +307,25 @@ async def parse_documents(
|
|||||||
detail="The file does not exist or you do not have permission to access it"
|
detail="The file does not exist or you do not have permission to access it"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. Construct file path:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
|
# 3. Get file_key for storage backend
|
||||||
file_path = os.path.join(
|
if not db_file.file_key:
|
||||||
settings.FILE_PATH,
|
api_logger.error(f"File has no storage key (legacy data not migrated): file_id={db_file.id}")
|
||||||
str(db_file.kb_id),
|
|
||||||
str(db_file.parent_id),
|
|
||||||
f"{db_file.id}{db_file.file_ext}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4. Check if the file exists
|
|
||||||
api_logger.debug(f"Constructed file path: {file_path}")
|
|
||||||
api_logger.debug(f"File metadata - kb_id: {db_file.kb_id}, parent_id: {db_file.parent_id}, file_id: {db_file.id}, extension: {db_file.file_ext}")
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
api_logger.error(f"File not found (possibly deleted): file_path={file_path}, file_id={db_file.id}, document_id={document_id}")
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail="File not found (possibly deleted)"
|
detail="File has no storage key (legacy data not migrated)"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5. Obtain knowledge base information
|
# 4. Obtain knowledge base information
|
||||||
api_logger.info( f"Obtain details of the knowledge base: knowledge_id={db_document.kb_id}")
|
api_logger.info(f"Obtain details of the knowledge base: knowledge_id={db_document.kb_id}")
|
||||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
|
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
|
||||||
if not db_knowledge:
|
if not db_knowledge:
|
||||||
api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={db_document.kb_id}")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="The knowledge base does not exist or access is denied"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 6. Task: Document parsing, vectorization, and storage
|
# 5. Dispatch parse task with file_key (not file_path)
|
||||||
# from app.tasks import parse_document
|
task = celery_app.send_task(
|
||||||
# parse_document(file_path, document_id)
|
"app.core.rag.tasks.parse_document",
|
||||||
task = celery_app.send_task("app.core.rag.tasks.parse_document", args=[file_path, document_id])
|
args=[db_file.file_key, document_id, db_file.file_name]
|
||||||
|
)
|
||||||
result = {
|
result = {
|
||||||
"task_id": task.id
|
"task_id": task.id
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
|
||||||
import shutil
|
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
|
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import Response
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
@@ -19,10 +17,14 @@ from app.models.user_model import User
|
|||||||
from app.schemas import file_schema, document_schema
|
from app.schemas import file_schema, document_schema
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import file_service, document_service
|
from app.services import file_service, document_service
|
||||||
|
from app.services.knowledge_service import get_knowledge_by_id as get_kb_by_id
|
||||||
|
from app.services.file_storage_service import (
|
||||||
|
FileStorageService,
|
||||||
|
generate_kb_file_key,
|
||||||
|
get_file_storage_service,
|
||||||
|
)
|
||||||
from app.core.quota_stub import check_knowledge_capacity_quota
|
from app.core.quota_stub import check_knowledge_capacity_quota
|
||||||
|
|
||||||
|
|
||||||
# Obtain a dedicated API logger
|
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
@@ -35,67 +37,37 @@ router = APIRouter(
|
|||||||
async def get_files(
|
async def get_files(
|
||||||
kb_id: uuid.UUID,
|
kb_id: uuid.UUID,
|
||||||
parent_id: uuid.UUID,
|
parent_id: uuid.UUID,
|
||||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
page: int = Query(1, gt=0),
|
||||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
pagesize: int = Query(20, gt=0, le=100),
|
||||||
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at"),
|
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at"),
|
||||||
desc: Optional[bool] = Query(False, description="Is it descending order"),
|
desc: Optional[bool] = Query(False, description="Is it descending order"),
|
||||||
keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
|
keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""Paged query file list"""
|
||||||
Paged query file list
|
api_logger.info(f"Query file list: kb_id={kb_id}, parent_id={parent_id}, page={page}, pagesize={pagesize}")
|
||||||
- Support filtering by kb_id and parent_id
|
|
||||||
- Support keyword search for file names
|
|
||||||
- Support dynamic sorting
|
|
||||||
- Return paging metadata + file list
|
|
||||||
"""
|
|
||||||
api_logger.info(f"Query file list: kb_id={kb_id}, parent_id={parent_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
|
|
||||||
# 1. parameter validation
|
|
||||||
if page < 1 or pagesize < 1:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="The paging parameter must be greater than 0"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Construct query conditions
|
if page < 1 or pagesize < 1:
|
||||||
filters = [
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The paging parameter must be greater than 0")
|
||||||
file_model.File.kb_id == kb_id
|
|
||||||
]
|
filters = [file_model.File.kb_id == kb_id]
|
||||||
if parent_id:
|
if parent_id:
|
||||||
filters.append(file_model.File.parent_id == parent_id)
|
filters.append(file_model.File.parent_id == parent_id)
|
||||||
# Keyword search (fuzzy matching of file name)
|
|
||||||
if keywords:
|
if keywords:
|
||||||
filters.append(file_model.File.file_name.ilike(f"%{keywords}%"))
|
filters.append(file_model.File.file_name.ilike(f"%{keywords}%"))
|
||||||
|
|
||||||
# 3. Execute paged query
|
|
||||||
try:
|
try:
|
||||||
api_logger.debug("Start executing file paging query")
|
|
||||||
total, items = file_service.get_files_paginated(
|
total, items = file_service.get_files_paginated(
|
||||||
db=db,
|
db=db, filters=filters, page=page, pagesize=pagesize,
|
||||||
filters=filters,
|
orderby=orderby, desc=desc, current_user=current_user
|
||||||
page=page,
|
|
||||||
pagesize=pagesize,
|
|
||||||
orderby=orderby,
|
|
||||||
desc=desc,
|
|
||||||
current_user=current_user
|
|
||||||
)
|
)
|
||||||
api_logger.info(f"File query successful: total={total}, returned={len(items)} records")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Query failed: {str(e)}")
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"Query failed: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4. Return structured response
|
|
||||||
result = {
|
result = {
|
||||||
"items": items,
|
"items": items,
|
||||||
"page": {
|
"page": {"page": page, "pagesize": pagesize, "total": total, "has_next": page * pagesize < total}
|
||||||
"page": page,
|
|
||||||
"pagesize": pagesize,
|
|
||||||
"total": total,
|
|
||||||
"has_next": True if page * pagesize < total else False
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return success(data=jsonable_encoder(result), msg="Query of file list succeeded")
|
return success(data=jsonable_encoder(result), msg="Query of file list succeeded")
|
||||||
|
|
||||||
@@ -108,23 +80,14 @@ async def create_folder(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""
|
"""Create a new folder"""
|
||||||
Create a new folder
|
api_logger.info(f"Create folder request: kb_id={kb_id}, parent_id={parent_id}, folder_name={folder_name}")
|
||||||
"""
|
|
||||||
api_logger.info(f"Create folder request: kb_id={kb_id}, parent_id={parent_id}, folder_name={folder_name}, username: {current_user.username}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
api_logger.debug(f"Start creating a folder: {folder_name}")
|
create_folder_data = file_schema.FileCreate(
|
||||||
create_folder = file_schema.FileCreate(
|
kb_id=kb_id, created_by=current_user.id, parent_id=parent_id,
|
||||||
kb_id=kb_id,
|
file_name=folder_name, file_ext='folder', file_size=0,
|
||||||
created_by=current_user.id,
|
|
||||||
parent_id=parent_id,
|
|
||||||
file_name=folder_name,
|
|
||||||
file_ext='folder',
|
|
||||||
file_size=0,
|
|
||||||
)
|
)
|
||||||
db_file = file_service.create_file(db=db, file=create_folder, current_user=current_user)
|
db_file = file_service.create_file(db=db, file=create_folder_data, current_user=current_user)
|
||||||
api_logger.info(f"Folder created successfully: {db_file.file_name} (ID: {db_file.id})")
|
|
||||||
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="Folder creation successful")
|
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="Folder creation successful")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Folder creation failed: {folder_name} - {str(e)}")
|
api_logger.error(f"Folder creation failed: {folder_name} - {str(e)}")
|
||||||
@@ -138,76 +101,58 @@ async def upload_file(
|
|||||||
parent_id: uuid.UUID,
|
parent_id: uuid.UUID,
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
|
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||||
):
|
):
|
||||||
"""
|
"""Upload file to storage backend"""
|
||||||
upload file
|
api_logger.info(f"upload file request: kb_id={kb_id}, parent_id={parent_id}, filename={file.filename}")
|
||||||
"""
|
|
||||||
api_logger.info(f"upload file request: kb_id={kb_id}, parent_id={parent_id}, filename={file.filename}, username: {current_user.username}")
|
|
||||||
|
|
||||||
# Read the contents of the file
|
|
||||||
contents = await file.read()
|
contents = await file.read()
|
||||||
# Check file size
|
|
||||||
file_size = len(contents)
|
file_size = len(contents)
|
||||||
print(f"file size: {file_size} byte")
|
|
||||||
if file_size == 0:
|
if file_size == 0:
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The file is empty.")
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="The file is empty."
|
|
||||||
)
|
|
||||||
# If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
|
|
||||||
if file_size > settings.MAX_FILE_SIZE:
|
if file_size > settings.MAX_FILE_SIZE:
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"File size exceeds {settings.MAX_FILE_SIZE} byte limit")
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE}byte limit"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract the extension using `os.path.splitext`
|
|
||||||
_, file_extension = os.path.splitext(file.filename)
|
_, file_extension = os.path.splitext(file.filename)
|
||||||
upload_file = file_schema.FileCreate(
|
file_ext = file_extension.lower()
|
||||||
kb_id=kb_id,
|
|
||||||
created_by=current_user.id,
|
# Create File record
|
||||||
parent_id=parent_id,
|
upload_file_data = file_schema.FileCreate(
|
||||||
file_name=file.filename,
|
kb_id=kb_id, created_by=current_user.id, parent_id=parent_id,
|
||||||
file_ext=file_extension.lower(),
|
file_name=file.filename, file_ext=file_ext, file_size=file_size,
|
||||||
file_size=file_size,
|
|
||||||
)
|
)
|
||||||
db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user)
|
db_file = file_service.create_file(db=db, file=upload_file_data, current_user=current_user)
|
||||||
|
|
||||||
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
# Upload to storage backend
|
||||||
save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id))
|
file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=file_ext)
|
||||||
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
|
try:
|
||||||
save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}")
|
await storage_service.storage.upload(file_key=file_key, content=contents, content_type=file.content_type)
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"Storage upload failed: {e}")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"File storage failed: {str(e)}")
|
||||||
|
|
||||||
# Save file
|
# Save file_key
|
||||||
with open(save_path, "wb") as f:
|
db_file.file_key = file_key
|
||||||
f.write(contents)
|
db.commit()
|
||||||
|
db.refresh(db_file)
|
||||||
|
|
||||||
# Verify whether the file has been saved successfully
|
# Create document (inherit parser_config from knowledge base)
|
||||||
if not os.path.exists(save_path):
|
default_parser_config = {
|
||||||
raise HTTPException(
|
"layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n",
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
"auto_keywords": 0, "auto_questions": 0, "html4excel": "false"
|
||||||
detail="File save failed"
|
}
|
||||||
)
|
try:
|
||||||
|
db_knowledge = get_kb_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||||
|
if db_knowledge and db_knowledge.parser_config:
|
||||||
|
default_parser_config.update(dict(db_knowledge.parser_config))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Create a document
|
|
||||||
create_data = document_schema.DocumentCreate(
|
create_data = document_schema.DocumentCreate(
|
||||||
kb_id=kb_id,
|
kb_id=kb_id, created_by=current_user.id, file_id=db_file.id,
|
||||||
created_by=current_user.id,
|
file_name=db_file.file_name, file_ext=db_file.file_ext, file_size=db_file.file_size,
|
||||||
file_id=db_file.id,
|
file_meta={}, parser_id="naive", parser_config=default_parser_config
|
||||||
file_name=db_file.file_name,
|
|
||||||
file_ext=db_file.file_ext,
|
|
||||||
file_size=db_file.file_size,
|
|
||||||
file_meta={},
|
|
||||||
parser_id="naive",
|
|
||||||
parser_config={
|
|
||||||
"layout_recognize": "DeepDOC",
|
|
||||||
"chunk_token_num": 128,
|
|
||||||
"delimiter": "\n",
|
|
||||||
"auto_keywords": 0,
|
|
||||||
"auto_questions": 0,
|
|
||||||
"html4excel": "false"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)
|
db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)
|
||||||
|
|
||||||
@@ -221,123 +166,73 @@ async def custom_text(
|
|||||||
parent_id: uuid.UUID,
|
parent_id: uuid.UUID,
|
||||||
create_data: file_schema.CustomTextFileCreate,
|
create_data: file_schema.CustomTextFileCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
|
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||||
):
|
):
|
||||||
"""
|
"""Custom text upload"""
|
||||||
custom text
|
|
||||||
"""
|
|
||||||
api_logger.info(f"custom text upload request: kb_id={kb_id}, parent_id={parent_id}, title={create_data.title}, content={create_data.content}, username: {current_user.username}")
|
|
||||||
|
|
||||||
# Check file content size
|
|
||||||
# 将内容编码为字节(UTF-8)
|
|
||||||
content_bytes = create_data.content.encode('utf-8')
|
content_bytes = create_data.content.encode('utf-8')
|
||||||
file_size = len(content_bytes)
|
file_size = len(content_bytes)
|
||||||
print(f"file size: {file_size} byte")
|
|
||||||
if file_size == 0:
|
if file_size == 0:
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The content is empty.")
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="The content is empty."
|
|
||||||
)
|
|
||||||
# If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
|
|
||||||
if file_size > settings.MAX_FILE_SIZE:
|
if file_size > settings.MAX_FILE_SIZE:
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Content size exceeds {settings.MAX_FILE_SIZE} byte limit")
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"The content size exceeds the {settings.MAX_FILE_SIZE}byte limit"
|
|
||||||
)
|
|
||||||
|
|
||||||
upload_file = file_schema.FileCreate(
|
upload_file_data = file_schema.FileCreate(
|
||||||
kb_id=kb_id,
|
kb_id=kb_id, created_by=current_user.id, parent_id=parent_id,
|
||||||
created_by=current_user.id,
|
file_name=f"{create_data.title}.txt", file_ext=".txt", file_size=file_size,
|
||||||
parent_id=parent_id,
|
|
||||||
file_name=f"{create_data.title}.txt",
|
|
||||||
file_ext=".txt",
|
|
||||||
file_size=file_size,
|
|
||||||
)
|
)
|
||||||
db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user)
|
db_file = file_service.create_file(db=db, file=upload_file_data, current_user=current_user)
|
||||||
|
|
||||||
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
# Upload to storage backend
|
||||||
save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id))
|
file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=".txt")
|
||||||
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
|
try:
|
||||||
save_path = os.path.join(save_dir, f"{db_file.id}.txt")
|
await storage_service.storage.upload(file_key=file_key, content=content_bytes, content_type="text/plain")
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"Storage upload failed: {e}")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"File storage failed: {str(e)}")
|
||||||
|
|
||||||
# Save file
|
db_file.file_key = file_key
|
||||||
with open(save_path, "wb") as f:
|
db.commit()
|
||||||
f.write(content_bytes)
|
db.refresh(db_file)
|
||||||
|
|
||||||
# Verify whether the file has been saved successfully
|
|
||||||
if not os.path.exists(save_path):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail="File save failed"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a document
|
|
||||||
create_document_data = document_schema.DocumentCreate(
|
create_document_data = document_schema.DocumentCreate(
|
||||||
kb_id=kb_id,
|
kb_id=kb_id, created_by=current_user.id, file_id=db_file.id,
|
||||||
created_by=current_user.id,
|
file_name=db_file.file_name, file_ext=db_file.file_ext, file_size=db_file.file_size,
|
||||||
file_id=db_file.id,
|
file_meta={}, parser_id="naive",
|
||||||
file_name=db_file.file_name,
|
parser_config={"layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n",
|
||||||
file_ext=db_file.file_ext,
|
"auto_keywords": 0, "auto_questions": 0, "html4excel": "false"}
|
||||||
file_size=db_file.file_size,
|
|
||||||
file_meta={},
|
|
||||||
parser_id="naive",
|
|
||||||
parser_config={
|
|
||||||
"layout_recognize": "DeepDOC",
|
|
||||||
"chunk_token_num": 128,
|
|
||||||
"delimiter": "\n",
|
|
||||||
"auto_keywords": 0,
|
|
||||||
"auto_questions": 0,
|
|
||||||
"html4excel": "false"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
|
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
|
||||||
|
|
||||||
api_logger.info(f"custom text upload successfully: {create_data.title} (file_id: {db_file.id}, document_id: {db_document.id})")
|
|
||||||
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="custom text upload successful")
|
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="custom text upload successful")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{file_id}", response_model=Any)
async def get_file(
    file_id: uuid.UUID,
    db: Session = Depends(get_db),
    storage_service: FileStorageService = Depends(get_file_storage_service),
) -> Any:
    """Download a file's content by ``file_id``.

    Looks the file record up in the database, fetches its bytes from the
    storage backend via ``file_key`` and returns them as an attachment.

    Args:
        file_id: Identifier of the file record.
        db: Database session.
        storage_service: Storage backend used to download the content.

    Returns:
        A ``Response`` carrying the file bytes, a guessed media type and a
        ``Content-Disposition: attachment`` header.

    Raises:
        HTTPException: 404 when the record does not exist, when it has no
            ``file_key``, or when the storage download fails.
    """
    # Local imports keep this block self-contained (the file already imports
    # ``mimetypes`` at function scope).
    import mimetypes
    from urllib.parse import quote

    db_file = file_service.get_file_by_id(db, file_id=file_id)
    if not db_file:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")

    if not db_file.file_key:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="File has no storage key (legacy data not migrated)",
        )

    try:
        content = await storage_service.download_file(db_file.file_key)
    except Exception as e:
        api_logger.error(f"Storage download failed: {e}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="File not found in storage",
        ) from e

    media_type = mimetypes.guess_type(db_file.file_name)[0] or "application/octet-stream"

    # Header values must be latin-1 encodable and must not contain a raw
    # double quote. Interpolating the stored name directly fails for
    # non-ASCII names, so percent-encode it and switch to the RFC 5987
    # ``filename*`` form whenever encoding changed anything.
    encoded_name = quote(db_file.file_name)
    if encoded_name == db_file.file_name:
        content_disposition = f'attachment; filename="{db_file.file_name}"'
    else:
        content_disposition = f"attachment; filename*=UTF-8''{encoded_name}"

    return Response(
        content=content,
        media_type=media_type,
        headers={"Content-Disposition": content_disposition},
    )
|
||||||
|
|
||||||
|
|
||||||
@@ -348,50 +243,22 @@ async def update_file(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""Update file information (such as file name)"""
|
||||||
Update file information (such as file name)
|
|
||||||
- Only specified fields such as file_name are allowed to be modified
|
|
||||||
"""
|
|
||||||
api_logger.debug(f"Query the file to be updated: {file_id}")
|
|
||||||
|
|
||||||
# 1. Check if the file exists
|
|
||||||
db_file = file_service.get_file_by_id(db, file_id=file_id)
|
db_file = file_service.get_file_by_id(db, file_id=file_id)
|
||||||
|
|
||||||
if not db_file:
|
if not db_file:
|
||||||
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="The file does not exist or you do not have permission to access it"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Update fields (only update non-null fields)
|
|
||||||
api_logger.debug(f"Start updating the file fields: {file_id}")
|
|
||||||
updated_fields = []
|
|
||||||
for field, value in update_data.dict(exclude_unset=True).items():
|
for field, value in update_data.dict(exclude_unset=True).items():
|
||||||
if hasattr(db_file, field):
|
if hasattr(db_file, field):
|
||||||
old_value = getattr(db_file, field)
|
setattr(db_file, field, value)
|
||||||
if old_value != value:
|
|
||||||
# update value
|
|
||||||
setattr(db_file, field, value)
|
|
||||||
updated_fields.append(f"{field}: {old_value} -> {value}")
|
|
||||||
|
|
||||||
if updated_fields:
|
|
||||||
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
|
||||||
|
|
||||||
# 3. Save to database
|
|
||||||
try:
|
try:
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(db_file)
|
db.refresh(db_file)
|
||||||
api_logger.info(f"The file has been successfully updated: {db_file.file_name} (ID: {db_file.id})")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
api_logger.error(f"File update failed: file_id={file_id} - {str(e)}")
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"File update failed: {str(e)}")
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"File update failed: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4. Return the updated file
|
|
||||||
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="File information updated successfully")
|
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="File information updated successfully")
|
||||||
|
|
||||||
|
|
||||||
@@ -399,60 +266,43 @@ async def update_file(
|
|||||||
async def delete_file(
    file_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
    storage_service: FileStorageService = Depends(get_file_storage_service),
):
    """Delete a file or folder.

    Logs the request, delegates the actual removal to ``_delete_file`` and
    returns the standard success envelope.
    """
    api_logger.info(f"Request to delete file: file_id={file_id}")
    await _delete_file(
        db=db,
        file_id=file_id,
        current_user=current_user,
        storage_service=storage_service,
    )
    return success(msg="File deleted successfully")
|
||||||
|
|
||||||
|
|
||||||
async def _delete_file(
    file_id: uuid.UUID,
    db: Session,
    current_user: User,
    storage_service: FileStorageService,
) -> None:
    """Delete a file or folder from storage and from the database.

    For a folder, the stored content of each direct child is removed first
    and the child rows are deleted; for a plain file, only its own stored
    content is removed. The record itself is then deleted and the
    transaction committed.

    Storage failures are logged as warnings and do not abort the database
    deletion.

    Raises:
        HTTPException: 404 when no record matches ``file_id``.
    """
    db_file = file_service.get_file_by_id(db, file_id=file_id)
    if not db_file:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")

    async def _remove_from_storage(storage_key, failure_prefix):
        # Best-effort: records without a key are skipped, errors only logged.
        if not storage_key:
            return
        try:
            await storage_service.delete_file(storage_key)
        except Exception as e:
            api_logger.warning(f"{failure_prefix}: {storage_key} - {e}")

    if db_file.file_ext == 'folder':
        # Clear the children's stored content before dropping their rows.
        children = db.query(file_model.File).filter(file_model.File.parent_id == db_file.id).all()
        for child in children:
            await _remove_from_storage(child.file_key, "Failed to delete child file from storage")
        db.query(file_model.File).filter(file_model.File.parent_id == db_file.id).delete()
    else:
        await _remove_from_storage(db_file.file_key, "Failed to delete file from storage")

    db.delete(db_file)
    db.commit()
|
||||||
|
|||||||
@@ -4,7 +4,9 @@
|
|||||||
处理显性记忆相关的API接口,包括情景记忆和语义记忆的查询。
|
处理显性记忆相关的API接口,包括情景记忆和语义记忆的查询。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.response_utils import success, fail
|
from app.core.response_utils import success, fail
|
||||||
@@ -69,6 +71,140 @@ async def get_explicit_memory_overview_api(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "显性记忆总览查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "显性记忆总览查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/episodics", response_model=ApiResponse)
async def get_episodic_memory_list_api(
    end_user_id: str = Query(..., description="end user ID"),
    page: int = Query(1, gt=0, description="page number, starting from 1"),
    pagesize: int = Query(10, gt=0, le=100, description="number of items per page, max 100"),
    start_date: Optional[int] = Query(None, description="start timestamp (ms)"),
    end_date: Optional[int] = Query(None, description="end timestamp (ms)"),
    episodic_type: str = Query("all", description="episodic type :all/conversation/project_work/learning/decision/important_event"),
    current_user: User = Depends(get_current_user),
) -> dict:
    """Return a paginated list of episodic memories for an end user.

    Supports pagination, an optional time-range filter and an optional
    episodic-type filter.

    Args:
        end_user_id: End user ID (required).
        page: Page number, starting from 1 (default 1).
        pagesize: Items per page (default 10, max 100).
        start_date: Optional start timestamp in milliseconds; per the
            original notes it is expanded to 00:00:00 of that day.
        end_date: Optional end timestamp in milliseconds; per the original
            notes it is expanded to 23:59:59 of that day.
        episodic_type: Episodic type filter (default "all").
        current_user: The authenticated user.

    Returns:
        ApiResponse wrapping the paginated episodic memory list.

    Examples:
        - Basic pagination:
          GET /episodics?end_user_id=xxx&page=1&pagesize=5
        - Time-range filter:
          GET /episodics?end_user_id=xxx&page=1&pagesize=5&start_date=1738684800000&end_date=1738771199000
        - Type filter:
          GET /episodics?end_user_id=xxx&page=1&pagesize=5&episodic_type=important_event

    Notes:
        - start_date and end_date must be supplied together or not at all.
        - start_date must not be greater than end_date.
        - episodic_type is one of: all, conversation, project_work,
          learning, decision, important_event.
        - ``total`` is the user's overall episodic memory count (unaffected
          by filters); ``page.total`` is the count after filtering.
    """
    # A workspace must be selected before querying.
    if current_user.current_workspace_id is None:
        api_logger.warning(f"用户 {current_user.username} 尝试查询情景记忆列表但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    api_logger.info(
        f"情景记忆分页查询: end_user_id={end_user_id}, "
        f"start_date={start_date}, end_date={end_date}, episodic_type={episodic_type}, "
        f"page={page}, pagesize={pagesize}, username={current_user.username}"
    )

    # --- parameter validation ---
    if page <= 0 or pagesize <= 0:
        api_logger.warning(f"分页参数错误: page={page}, pagesize={pagesize}")
        return fail(BizCode.INVALID_PARAMETER, "分页参数必须大于0")

    valid_episodic_types = ["all", "conversation", "project_work", "learning", "decision", "important_event"]
    if episodic_type not in valid_episodic_types:
        api_logger.warning(f"无效的情景类型参数: {episodic_type}")
        return fail(BizCode.INVALID_PARAMETER, f"无效的情景类型参数,可选值:{', '.join(valid_episodic_types)}")

    # Exactly one of the two timestamps being present is an error.
    if (start_date is None) != (end_date is None):
        return fail(BizCode.INVALID_PARAMETER, "start_date和end_date必须同时提供")

    # Past the check above, both are set or both are None.
    if start_date is not None and start_date > end_date:
        return fail(BizCode.INVALID_PARAMETER, "start_date不能大于end_date")

    # --- query ---
    try:
        result = await memory_explicit_service.get_episodic_memory_list(
            end_user_id=end_user_id,
            page=page,
            pagesize=pagesize,
            start_date=start_date,
            end_date=end_date,
            episodic_type=episodic_type,
        )
        # Logged inside the try so a malformed result is reported as a failure too.
        api_logger.info(
            f"情景记忆分页查询成功: end_user_id={end_user_id}, "
            f"total={result['total']}, 返回={len(result['items'])}条"
        )
    except Exception as e:
        api_logger.error(f"情景记忆分页查询失败: end_user_id={end_user_id}, error={str(e)}")
        return fail(BizCode.INTERNAL_ERROR, "情景记忆分页查询失败", str(e))

    return success(data=result, msg="查询成功")
|
||||||
|
|
||||||
|
@router.get("/semantics", response_model=ApiResponse)
async def get_semantic_memory_list_api(
    end_user_id: str = Query(..., description="终端用户ID"),
    current_user: User = Depends(get_current_user),
) -> dict:
    """Return the full semantic memory list for an end user.

    Args:
        end_user_id: End user ID (required).
        current_user: The authenticated user.

    Returns:
        ApiResponse wrapping the complete semantic memory list.
    """
    # A workspace must be selected before querying.
    if current_user.current_workspace_id is None:
        api_logger.warning(f"用户 {current_user.username} 尝试查询语义记忆列表但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    api_logger.info(f"语义记忆列表查询: end_user_id={end_user_id}, username={current_user.username}")

    try:
        result = await memory_explicit_service.get_semantic_memory_list(end_user_id=end_user_id)
        # Logged inside the try so a result without a length is reported as a failure.
        api_logger.info(f"语义记忆列表查询成功: end_user_id={end_user_id}, total={len(result)}")
    except Exception as e:
        api_logger.error(f"语义记忆列表查询失败: end_user_id={end_user_id}, error={str(e)}")
        return fail(BizCode.INTERNAL_ERROR, "语义记忆列表查询失败", str(e))

    return success(data=result, msg="查询成功")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/details", response_model=ApiResponse)
|
@router.post("/details", response_model=ApiResponse)
|
||||||
async def get_explicit_memory_details_api(
|
async def get_explicit_memory_details_api(
|
||||||
request: ExplicitMemoryDetailsRequest,
|
request: ExplicitMemoryDetailsRequest,
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from . import (
|
|||||||
rag_api_document_controller,
|
rag_api_document_controller,
|
||||||
rag_api_file_controller,
|
rag_api_file_controller,
|
||||||
rag_api_knowledge_controller,
|
rag_api_knowledge_controller,
|
||||||
|
user_memory_api_controller,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建 V1 API 路由器
|
# 创建 V1 API 路由器
|
||||||
@@ -28,5 +29,6 @@ service_router.include_router(rag_api_chunk_controller.router)
|
|||||||
service_router.include_router(memory_api_controller.router)
|
service_router.include_router(memory_api_controller.router)
|
||||||
service_router.include_router(end_user_api_controller.router)
|
service_router.include_router(end_user_api_controller.router)
|
||||||
service_router.include_router(memory_config_api_controller.router)
|
service_router.include_router(memory_config_api_controller.router)
|
||||||
|
service_router.include_router(user_memory_api_controller.router)
|
||||||
|
|
||||||
__all__ = ["service_router"]
|
__all__ = ["service_router"]
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
from fastapi import APIRouter, Body, Depends, Query, Request
|
from fastapi import APIRouter, Body, Depends, Query, Request
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.celery_task_scheduler import scheduler
|
||||||
from app.core.api_key_auth import require_api_key
|
from app.core.api_key_auth import require_api_key
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.quota_stub import check_end_user_quota
|
from app.core.quota_stub import check_end_user_quota
|
||||||
@@ -86,7 +87,7 @@ async def write_memory(
|
|||||||
user_rag_memory_id=payload.user_rag_memory_id,
|
user_rag_memory_id=payload.user_rag_memory_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Memory write task submitted: task_id={result['task_id']}, end_user_id: {payload.end_user_id}")
|
logger.info(f"Memory write task submitted: task_id: {result['task_id']} end_user_id: {payload.end_user_id}")
|
||||||
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted")
|
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted")
|
||||||
|
|
||||||
|
|
||||||
@@ -105,8 +106,7 @@ async def get_write_task_status(
|
|||||||
"""
|
"""
|
||||||
logger.info(f"Write task status check - task_id: {task_id}")
|
logger.info(f"Write task status check - task_id: {task_id}")
|
||||||
|
|
||||||
from app.services.task_service import get_task_memory_write_result
|
result = scheduler.get_task_status(task_id)
|
||||||
result = get_task_memory_write_result(task_id)
|
|
||||||
|
|
||||||
return success(data=_sanitize_task_result(result), msg="Task status retrieved")
|
return success(data=_sanitize_task_result(result), msg="Task status retrieved")
|
||||||
|
|
||||||
|
|||||||
@@ -113,6 +113,33 @@ async def create_chunk(
|
|||||||
current_user=current_user)
|
current_user=current_user)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{kb_id}/{document_id}/chunk/batch", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def create_chunks_batch(
    kb_id: uuid.UUID,
    document_id: uuid.UUID,
    request: Request,
    api_key_auth: ApiKeyAuth = None,
    db: Session = Depends(get_db),
    items: list = Body(..., description="chunk items list"),
):
    """Batch create chunks (max 8) for a document, authenticated by API key.

    The JSON request body is parsed into ``ChunkBatchCreate`` and the call
    is forwarded to the internal chunk controller on behalf of the API
    key's creator.

    NOTE(review): ``items`` is declared as a Body parameter but never read;
    the payload is taken from ``request.json()`` instead. Confirm that the
    declared body shape and ``ChunkBatchCreate`` agree.
    """
    payload = await request.json()
    batch_data = chunk_schema.ChunkBatchCreate(**payload)

    # Act as the user who created the API key, scoped to the key's workspace.
    key_record = api_key_service.ApiKeyService.get_api_key(
        db, api_key_auth.api_key_id, api_key_auth.workspace_id
    )
    current_user = key_record.creator
    current_user.current_workspace_id = api_key_auth.workspace_id

    return await chunk_controller.create_chunks_batch(
        kb_id=kb_id,
        document_id=document_id,
        batch_data=batch_data,
        db=db,
        current_user=current_user,
    )
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
||||||
@require_api_key(scopes=["rag"])
|
@require_api_key(scopes=["rag"])
|
||||||
async def get_chunk(
|
async def get_chunk(
|
||||||
@@ -176,6 +203,7 @@ async def delete_chunk(
|
|||||||
request: Request,
|
request: Request,
|
||||||
api_key_auth: ApiKeyAuth = None,
|
api_key_auth: ApiKeyAuth = None,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
force_refresh: bool = Query(False, description="Force Elasticsearch refresh after deletion"),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
delete document chunk
|
delete document chunk
|
||||||
@@ -188,6 +216,7 @@ async def delete_chunk(
|
|||||||
return await chunk_controller.delete_chunk(kb_id=kb_id,
|
return await chunk_controller.delete_chunk(kb_id=kb_id,
|
||||||
document_id=document_id,
|
document_id=document_id,
|
||||||
doc_id=doc_id,
|
doc_id=doc_id,
|
||||||
|
force_refresh=force_refresh,
|
||||||
db=db,
|
db=db,
|
||||||
current_user=current_user)
|
current_user=current_user)
|
||||||
|
|
||||||
|
|||||||
230
api/app/controllers/service/user_memory_api_controller.py
Normal file
230
api/app/controllers/service/user_memory_api_controller.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
"""User Memory 服务接口 — 基于 API Key 认证
|
||||||
|
|
||||||
|
包装 user_memory_controllers.py 和 memory_agent_controller.py 中的内部接口,
|
||||||
|
提供基于 API Key 认证的对外服务:
|
||||||
|
1./analytics/graph_data - 知识图谱数据接口
|
||||||
|
2./analytics/community_graph - 社区图谱接口
|
||||||
|
3./analytics/node_statistics - 记忆节点统计接口
|
||||||
|
4./analytics/user_summary - 用户摘要接口
|
||||||
|
5./analytics/memory_insight - 记忆洞察接口
|
||||||
|
6./analytics/interest_distribution - 兴趣分布接口
|
||||||
|
7./analytics/end_user_info - 终端用户信息接口
|
||||||
|
8./analytics/generate_cache - 缓存生成接口
|
||||||
|
|
||||||
|
|
||||||
|
路由前缀: /memory
|
||||||
|
子路径: /analytics/...
|
||||||
|
最终路径: /v1/memory/analytics/...
|
||||||
|
认证方式: API Key (@require_api_key)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Header, Query, Request, Body
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.api_key_auth import require_api_key
|
||||||
|
from app.core.api_key_utils import get_current_user_from_api_key, validate_end_user_in_workspace
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.db import get_db
|
||||||
|
from app.schemas.api_key_schema import ApiKeyAuth
|
||||||
|
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
||||||
|
|
||||||
|
# 包装内部服务 controller
|
||||||
|
from app.controllers import user_memory_controllers, memory_agent_controller
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/memory", tags=["V1 - User Memory API"])
|
||||||
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 知识图谱 ====================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analytics/graph_data")
@require_api_key(scopes=["memory"])
async def get_graph_data(
    request: Request,
    end_user_id: str = Query(..., description="End user ID"),
    node_types: Optional[str] = Query(None, description="Comma-separated node types filter"),
    limit: int = Query(100, description="Max nodes to return (auto-capped at 1000 in service layer)"),
    depth: int = Query(1, description="Graph traversal depth (auto-capped at 3 in service layer)"),
    center_node_id: Optional[str] = Query(None, description="Center node for subgraph"),
    api_key_auth: ApiKeyAuth = None,
    db: Session = Depends(get_db),
):
    """Get knowledge graph data (nodes + edges) for an end user.

    API-key wrapper around ``user_memory_controllers.get_graph_data_api``:
    builds the acting user from the API key, validates the end user against
    the key's workspace, then forwards the query parameters unchanged.
    """
    # Resolve the caller first, then check the end user against the workspace.
    current_user = get_current_user_from_api_key(db, api_key_auth)
    validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)

    graph_response = await user_memory_controllers.get_graph_data_api(
        end_user_id=end_user_id,
        node_types=node_types,
        limit=limit,
        depth=depth,
        center_node_id=center_node_id,
        current_user=current_user,
        db=db,
    )
    return graph_response
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analytics/community_graph")
@require_api_key(scopes=["memory"])
async def get_community_graph(
    request: Request,
    end_user_id: str = Query(..., description="End user ID"),
    api_key_auth: ApiKeyAuth = None,
    db: Session = Depends(get_db),
):
    """Get community clustering graph for an end user.

    API-key wrapper around
    ``user_memory_controllers.get_community_graph_data_api``: builds the
    acting user from the API key, validates the end user against the key's
    workspace, then delegates.
    """
    # Resolve the caller first, then check the end user against the workspace.
    current_user = get_current_user_from_api_key(db, api_key_auth)
    validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)

    community_response = await user_memory_controllers.get_community_graph_data_api(
        end_user_id=end_user_id,
        current_user=current_user,
        db=db,
    )
    return community_response
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 节点统计 ====================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analytics/node_statistics")
@require_api_key(scopes=["memory"])
async def get_node_statistics(
    request: Request,
    end_user_id: str = Query(..., description="End user ID"),
    api_key_auth: ApiKeyAuth = None,
    db: Session = Depends(get_db),
):
    """Get memory node type statistics for an end user.

    API-key wrapper around
    ``user_memory_controllers.get_node_statistics_api``: builds the acting
    user from the API key, validates the end user against the key's
    workspace, then delegates.
    """
    # Resolve the caller first, then check the end user against the workspace.
    current_user = get_current_user_from_api_key(db, api_key_auth)
    validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)

    statistics_response = await user_memory_controllers.get_node_statistics_api(
        end_user_id=end_user_id,
        current_user=current_user,
        db=db,
    )
    return statistics_response
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 用户摘要 & 洞察 ====================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analytics/user_summary")
@require_api_key(scopes=["memory"])
async def get_user_summary(
    request: Request,
    end_user_id: str = Query(..., description="End user ID"),
    language_type: str = Header(default=None, alias="X-Language-Type"),
    api_key_auth: ApiKeyAuth = None,
    db: Session = Depends(get_db),
):
    """Get cached user summary for an end user.

    API-key wrapper around ``user_memory_controllers.get_user_summary_api``.
    The ``X-Language-Type`` header value is passed through as
    ``language_type`` (``None`` when the header is absent).
    """
    # Resolve the caller first, then check the end user against the workspace.
    current_user = get_current_user_from_api_key(db, api_key_auth)
    validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)

    summary_response = await user_memory_controllers.get_user_summary_api(
        end_user_id=end_user_id,
        language_type=language_type,
        current_user=current_user,
        db=db,
    )
    return summary_response
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analytics/memory_insight")
@require_api_key(scopes=["memory"])
async def get_memory_insight(
    request: Request,
    end_user_id: str = Query(..., description="End user ID"),
    api_key_auth: ApiKeyAuth = None,
    db: Session = Depends(get_db),
):
    """Get cached memory insight report for an end user.

    API-key wrapper around
    ``user_memory_controllers.get_memory_insight_report_api``: builds the
    acting user from the API key, validates the end user against the key's
    workspace, then delegates.
    """
    # Resolve the caller first, then check the end user against the workspace.
    current_user = get_current_user_from_api_key(db, api_key_auth)
    validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)

    insight_response = await user_memory_controllers.get_memory_insight_report_api(
        end_user_id=end_user_id,
        current_user=current_user,
        db=db,
    )
    return insight_response
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 兴趣分布 ====================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analytics/interest_distribution")
@require_api_key(scopes=["memory"])
async def get_interest_distribution(
    request: Request,
    end_user_id: str = Query(..., description="End user ID"),
    limit: int = Query(5, le=5, description="Max interest tags to return"),
    language_type: str = Header(default=None, alias="X-Language-Type"),
    api_key_auth: ApiKeyAuth = None,
    db: Session = Depends(get_db),
):
    """Get interest distribution tags for an end user.

    API-key wrapper around
    ``memory_agent_controller.get_interest_distribution_by_user_api``.
    ``limit`` is bounded above at 5 by the query declaration; the
    ``X-Language-Type`` header is passed through as ``language_type``.
    """
    # Resolve the caller first, then check the end user against the workspace.
    current_user = get_current_user_from_api_key(db, api_key_auth)
    validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)

    distribution_response = await memory_agent_controller.get_interest_distribution_by_user_api(
        end_user_id=end_user_id,
        limit=limit,
        language_type=language_type,
        current_user=current_user,
        db=db,
    )
    return distribution_response
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 终端用户信息 ====================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analytics/end_user_info")
@require_api_key(scopes=["memory"])
async def get_end_user_info(
    request: Request,
    end_user_id: str = Query(..., description="End user ID"),
    api_key_auth: ApiKeyAuth = None,
    db: Session = Depends(get_db),
):
    """Get end user basic information (name, aliases, metadata).

    API-key wrapper around ``user_memory_controllers.get_end_user_info``:
    builds the acting user from the API key, validates the end user against
    the key's workspace, then delegates.
    """
    # Resolve the caller first, then check the end user against the workspace.
    current_user = get_current_user_from_api_key(db, api_key_auth)
    validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)

    info_response = await user_memory_controllers.get_end_user_info(
        end_user_id=end_user_id,
        current_user=current_user,
        db=db,
    )
    return info_response
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 缓存生成 ====================
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/analytics/generate_cache")
@require_api_key(scopes=["memory"])
async def generate_cache(
    request: Request,
    api_key_auth: ApiKeyAuth = None,
    db: Session = Depends(get_db),
    message: str = Body(None, description="Request body"),
    language_type: str = Header(default=None, alias="X-Language-Type"),
):
    """Trigger cache generation (user summary + memory insight) for an end user or all workspace users.

    The JSON request body is parsed into ``GenerateCacheRequest``. When it
    names an ``end_user_id``, that end user is validated against the API
    key's workspace before delegating to
    ``user_memory_controllers.generate_cache_api``.

    NOTE(review): ``message`` is declared as a Body parameter but never
    read; the payload is taken from ``request.json()`` instead. Confirm the
    declared body type matches what clients actually send.
    """
    raw_payload = await request.json()
    cache_request = GenerateCacheRequest(**raw_payload)

    current_user = get_current_user_from_api_key(db, api_key_auth)

    # Validate only when the request targets one specific end user.
    target_end_user_id = cache_request.end_user_id
    if target_end_user_id:
        validate_end_user_in_workspace(db, target_end_user_id, api_key_auth.workspace_id)

    cache_response = await user_memory_controllers.generate_cache_api(
        request=cache_request,
        language_type=language_type,
        current_user=current_user,
        db=db,
    )
    return cache_response
|
||||||
|
|
||||||
|
|
||||||
@@ -173,6 +173,8 @@ async def delete_tool(
|
|||||||
return success(msg="工具删除成功")
|
return success(msg="工具删除成功")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -249,6 +251,8 @@ async def parse_openapi_schema(
|
|||||||
if result["success"] is False:
|
if result["success"] is False:
|
||||||
raise HTTPException(status_code=400, detail=result["message"])
|
raise HTTPException(status_code=400, detail=result["message"])
|
||||||
return success(data=result, msg="Schema解析完成")
|
return success(data=result, msg="Schema解析完成")
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,15 @@
|
|||||||
"""API Key 工具函数"""
|
"""API Key 工具函数"""
|
||||||
import secrets
|
import secrets
|
||||||
|
import uuid as _uuid
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session as _Session
|
||||||
|
from app.core.error_codes import BizCode as _BizCode
|
||||||
|
from app.core.exceptions import BusinessException as _BusinessException
|
||||||
|
from app.models.end_user_model import EndUser as _EndUser
|
||||||
|
from app.repositories.end_user_repository import EndUserRepository as _EndUserRepository
|
||||||
|
|
||||||
from app.models.api_key_model import ApiKeyType
|
from app.models.api_key_model import ApiKeyType
|
||||||
from fastapi import Response
|
from fastapi import Response
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
@@ -65,3 +72,72 @@ def datetime_to_timestamp(dt: Optional[datetime]) -> Optional[int]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
return int(dt.timestamp() * 1000)
|
return int(dt.timestamp() * 1000)
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user_from_api_key(db: _Session, api_key_auth):
    """Build the ``current_user`` object for an API-key-authenticated request.

    Looks up the API key record, takes its creator (per the original notes,
    the admin user who created the key) and stamps the key's workspace onto
    that user as the active workspace. Intended as the API-key counterpart of
    the JWT-based ``Depends(get_current_user)`` used by internal endpoints.

    Args:
        db: Database session.
        api_key_auth: API-key auth info (``ApiKeyAuth``); ``api_key_id`` and
            ``workspace_id`` are read from it.

    Returns:
        The creator's User ORM object with ``current_workspace_id`` set.
    """
    # Imported at call time rather than at module import.
    # NOTE(review): presumably to avoid a circular import - confirm.
    from app.services import api_key_service

    key_id, workspace_id = api_key_auth.api_key_id, api_key_auth.workspace_id
    key_record = api_key_service.ApiKeyService.get_api_key(db, key_id, workspace_id)

    creator = key_record.creator
    creator.current_workspace_id = workspace_id
    return creator
|
||||||
|
|
||||||
|
|
||||||
|
def validate_end_user_in_workspace(
    db: _Session,
    end_user_id: str,
    workspace_id,
) -> _EndUser:
    """Check that an end user exists and belongs to the given workspace.

    Args:
        db: Database session.
        end_user_id: End-user ID; must be parseable as a UUID.
        workspace_id: Workspace ID (UUID or string; compared via ``str()``).

    Returns:
        The EndUser ORM object when validation passes.

    Raises:
        BusinessException(INVALID_PARAMETER): ``end_user_id`` is missing or is
            not a valid UUID string.
        BusinessException(USER_NOT_FOUND): no such end user.
        BusinessException(PERMISSION_DENIED): the end user belongs to another
            workspace.
    """
    try:
        _uuid.UUID(end_user_id)
    except (ValueError, AttributeError, TypeError) as exc:
        # ValueError: malformed string; AttributeError: non-string input;
        # TypeError: None (uuid.UUID requires one of its arguments).
        raise _BusinessException(
            f"Invalid end_user_id format: {end_user_id}",
            _BizCode.INVALID_PARAMETER,
        ) from exc

    end_user_repo = _EndUserRepository(db)
    end_user = end_user_repo.get_end_user_by_id(end_user_id)

    if end_user is None:
        raise _BusinessException(
            "End user not found",
            _BizCode.USER_NOT_FOUND,
        )

    # Compare as strings so UUID objects and plain strings are interchangeable.
    if str(end_user.workspace_id) != str(workspace_id):
        raise _BusinessException(
            "End user does not belong to this workspace",
            _BizCode.PERMISSION_DENIED,
        )

    return end_user
|
||||||
@@ -98,6 +98,7 @@ class Settings:
|
|||||||
# File Upload
|
# File Upload
|
||||||
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
||||||
MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20"))
|
MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20"))
|
||||||
|
MAX_CHUNK_BATCH_SIZE: int = int(os.getenv("MAX_CHUNK_BATCH_SIZE", "8"))
|
||||||
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
||||||
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from app.celery_task_scheduler import scheduler
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||||
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
||||||
@@ -12,8 +13,6 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
|||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||||
from app.services.task_service import get_task_memory_write_result
|
|
||||||
from app.tasks import write_message_task
|
|
||||||
from app.utils.config_utils import resolve_config_id
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
@@ -86,16 +85,28 @@ async def write(
|
|||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||||
write_id = write_message_task.delay(
|
# write_id = write_message_task.delay(
|
||||||
actual_end_user_id, # end_user_id: User ID
|
# actual_end_user_id, # end_user_id: User ID
|
||||||
structured_messages, # message: JSON string format message list
|
# structured_messages, # message: JSON string format message list
|
||||||
str(actual_config_id), # config_id: Configuration ID string
|
# str(actual_config_id), # config_id: Configuration ID string
|
||||||
storage_type, # storage_type: "neo4j"
|
# storage_type, # storage_type: "neo4j"
|
||||||
user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
# user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
||||||
|
# )
|
||||||
|
scheduler.push_task(
|
||||||
|
"app.core.memory.agent.write_message",
|
||||||
|
str(actual_end_user_id),
|
||||||
|
{
|
||||||
|
"end_user_id": str(actual_end_user_id),
|
||||||
|
"message": structured_messages,
|
||||||
|
"config_id": str(actual_config_id),
|
||||||
|
"storage_type": storage_type,
|
||||||
|
"user_rag_memory_id": user_rag_memory_id or ""
|
||||||
|
}
|
||||||
)
|
)
|
||||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
|
||||||
write_status = get_task_memory_write_result(str(write_id))
|
# logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
# write_status = get_task_memory_write_result(str(write_id))
|
||||||
|
# logger.info(f'[WRITE] Task result - user={actual_end_user_id}')
|
||||||
|
|
||||||
|
|
||||||
async def term_memory_save(end_user_id, strategy_type, scope):
|
async def term_memory_save(end_user_id, strategy_type, scope):
|
||||||
@@ -164,13 +175,24 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
|||||||
else:
|
else:
|
||||||
config_id = memory_config
|
config_id = memory_config
|
||||||
|
|
||||||
write_message_task.delay(
|
scheduler.push_task(
|
||||||
end_user_id, # end_user_id: User ID
|
"app.core.memory.agent.write_message",
|
||||||
redis_messages, # message: JSON string format message list
|
str(end_user_id),
|
||||||
config_id, # config_id: Configuration ID string
|
{
|
||||||
AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
|
"end_user_id": str(end_user_id),
|
||||||
"" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
"message": redis_messages,
|
||||||
|
"config_id": str(config_id),
|
||||||
|
"storage_type": AgentMemory_Long_Term.STORAGE_NEO4J,
|
||||||
|
"user_rag_memory_id": ""
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
# write_message_task.delay(
|
||||||
|
# end_user_id, # end_user_id: User ID
|
||||||
|
# redis_messages, # message: JSON string format message list
|
||||||
|
# config_id, # config_id: Configuration ID string
|
||||||
|
# AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
|
||||||
|
# "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
||||||
|
# )
|
||||||
count_store.update_sessions_count(end_user_id, 0, [])
|
count_store.update_sessions_count(end_user_id, 0, [])
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
from app.core.memory.enums import SearchStrategy, StorageType
|
from app.core.memory.enums import SearchStrategy, StorageType
|
||||||
from app.core.memory.models.service_models import MemorySearchResult
|
from app.core.memory.models.service_models import MemorySearchResult
|
||||||
from app.core.memory.pipelines.base_pipeline import ModelClientMixin, DBRequiredPipeline
|
from app.core.memory.pipelines.base_pipeline import ModelClientMixin, DBRequiredPipeline
|
||||||
from app.core.memory.read_services.content_search import Neo4jSearchService, RAGSearchService
|
from app.core.memory.read_services.search_engine.content_search import Neo4jSearchService, RAGSearchService
|
||||||
from app.core.memory.read_services.query_preprocessor import QueryPreprocessor
|
from app.core.memory.read_services.generate_engine.query_preprocessor import QueryPreprocessor
|
||||||
|
|
||||||
|
|
||||||
class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
|
class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
|
||||||
|
|||||||
@@ -8,4 +8,4 @@ class RetrievalSummaryProcessor:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def verify(content: str, llm_client: RedBearLLM):
|
def verify(content: str, llm_client: RedBearLLM):
|
||||||
return
|
return
|
||||||
@@ -8,7 +8,7 @@ from neo4j import Session
|
|||||||
from app.core.memory.enums import Neo4jNodeType
|
from app.core.memory.enums import Neo4jNodeType
|
||||||
from app.core.memory.memory_service import MemoryContext
|
from app.core.memory.memory_service import MemoryContext
|
||||||
from app.core.memory.models.service_models import Memory, MemorySearchResult
|
from app.core.memory.models.service_models import Memory, MemorySearchResult
|
||||||
from app.core.memory.read_services.result_builder import data_builder_factory
|
from app.core.memory.read_services.search_engine.result_builder import data_builder_factory
|
||||||
from app.core.models import RedBearEmbeddings
|
from app.core.models import RedBearEmbeddings
|
||||||
from app.core.rag.nlp.search import knowledge_retrieval
|
from app.core.rag.nlp.search import knowledge_retrieval
|
||||||
from app.repositories import knowledge_repository
|
from app.repositories import knowledge_repository
|
||||||
@@ -46,7 +46,10 @@ async def run_graphrag(
|
|||||||
start = trio.current_time()
|
start = trio.current_time()
|
||||||
workspace_id, kb_id, document_id = row["workspace_id"], str(row["kb_id"]), row["document_id"]
|
workspace_id, kb_id, document_id = row["workspace_id"], str(row["kb_id"]), row["document_id"]
|
||||||
chunks = []
|
chunks = []
|
||||||
for d in settings.retriever.chunk_list(document_id, workspace_id, [kb_id], fields=["page_content", "document_id"], sort_by_position=True):
|
for d in settings.retriever.chunk_list(document_id, workspace_id, [kb_id], fields=["page_content", "document_id", "chunk_type"], sort_by_position=True):
|
||||||
|
# 跳过 QA chunks,只用原文 chunks 构建图谱
|
||||||
|
if d.get("chunk_type") == "qa":
|
||||||
|
continue
|
||||||
chunks.append(d["page_content"])
|
chunks.append(d["page_content"])
|
||||||
|
|
||||||
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
|
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
|
||||||
@@ -150,6 +153,9 @@ async def run_graphrag_for_kb(
|
|||||||
|
|
||||||
total, items = vector_service.search_by_segment(document_id=str(document_id), query=None, pagesize=9999, page=1, asc=True)
|
total, items = vector_service.search_by_segment(document_id=str(document_id), query=None, pagesize=9999, page=1, asc=True)
|
||||||
for doc in items:
|
for doc in items:
|
||||||
|
# 跳过 QA chunks,只用原文 chunks 构建图谱
|
||||||
|
if (doc.metadata or {}).get("chunk_type") == "qa":
|
||||||
|
continue
|
||||||
content = doc.page_content
|
content = doc.page_content
|
||||||
if num_tokens_from_string(current_chunk + content) < 1024:
|
if num_tokens_from_string(current_chunk + content) < 1024:
|
||||||
current_chunk += content
|
current_chunk += content
|
||||||
|
|||||||
@@ -131,18 +131,52 @@ def keyword_extraction(chat_mdl, content, topn=3):
|
|||||||
|
|
||||||
|
|
||||||
def question_proposal(chat_mdl, content, topn=3):
|
def question_proposal(chat_mdl, content, topn=3):
|
||||||
template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
|
"""生成问题(向后兼容,返回纯文本问题列表)"""
|
||||||
rendered_prompt = template.render(content=content, topn=topn)
|
pairs = qa_proposal(chat_mdl, content, topn)
|
||||||
|
if not pairs:
|
||||||
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
|
|
||||||
_, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096))
|
|
||||||
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
|
|
||||||
if isinstance(kwd, tuple):
|
|
||||||
kwd = kwd[0]
|
|
||||||
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
|
|
||||||
if kwd.find("**ERROR**") >= 0:
|
|
||||||
return ""
|
return ""
|
||||||
return kwd
|
return "\n".join([p["question"] for p in pairs])
|
||||||
|
|
||||||
|
|
||||||
|
def qa_proposal(chat_mdl, content, topn=3, custom_prompt=None):
    """Generate question/answer pairs for ``content`` with an LLM.

    Args:
        chat_mdl: LLM client exposing ``chat(system, history, gen_conf)``; its
            optional ``max_length`` attribute bounds the prompt (default 8096).
        content: Source text, sent as the user message.
        topn: Number of QA pairs to ask for.
        custom_prompt: Optional Jinja2 template that replaces the built-in
            system prompt. Available variables: ``content``, ``topn``.

    Returns:
        A list of ``{"question": ..., "answer": ...}`` dicts; empty when the
        model output contains ``**ERROR**`` or no pair could be parsed.
    """
    # Both the custom and the built-in prompt are Jinja templates, so the
    # selected one is always rendered. Using the built-in template raw would
    # send a literal ``{{ topn }}`` placeholder to the model and ignore `topn`.
    template = PROMPT_JINJA_ENV.from_string(custom_prompt or QUESTION_PROMPT_TEMPLATE)
    sys_prompt = template.render(content=content, topn=topn)

    msg = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": content}]
    # NOTE(review): message_fit_in presumably trims messages to the budget - confirm.
    _, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096))
    raw = chat_mdl.chat(sys_prompt, msg[1:], {"temperature": 0.2})
    if isinstance(raw, tuple):
        # Some clients return a tuple; the first element is the text.
        raw = raw[0]
    # Drop everything up to and including a closing </think> tag.
    raw = re.sub(r"^.*</think>", "", raw, flags=re.DOTALL)
    if "**ERROR**" in raw:
        return []
    return parse_qa_pairs(raw)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_qa_pairs(text: str) -> list[dict[str, str]]:
    """Parse ``Q: ... A: ...`` pairs out of raw LLM output.

    Accepted layouts (case-insensitive, ASCII ``:`` or full-width ``:``):

    * one pair per line: ``Q: <question> A: <answer>``
    * a ``Q: <question>`` line followed by an ``A: <answer>`` line (blank
      lines in between are ignored)

    Lines that fit neither layout are skipped, and a question that is not
    immediately followed by its answer is discarded.

    Args:
        text: Raw model output; ``None`` or an empty string yields ``[]``.

    Returns:
        A list of ``{"question": ..., "answer": ...}`` dicts in input order.
    """
    single_line = r'^Q[::]\s*(.+?)\s+A[::]\s*(.+)$'
    question_only = r'^Q[::]\s*(.+)$'
    answer_only = r'^A[::]\s*(.+)$'

    pairs = []
    pending_question = None  # question waiting for an answer on the next line
    for line in (text or "").strip().split("\n"):
        line = line.strip()
        if not line:
            continue

        # Preferred layout: question and answer on the same line.
        match = re.match(single_line, line, re.IGNORECASE)
        if match:
            pending_question = None
            q, a = match.group(1).strip(), match.group(2).strip()
            if q and a:
                pairs.append({"question": q, "answer": a})
            continue

        # Fallback layout: question on its own line, answer on the next one.
        match = re.match(question_only, line, re.IGNORECASE)
        if match:
            pending_question = match.group(1).strip() or None
            continue

        match = re.match(answer_only, line, re.IGNORECASE)
        if match and pending_question:
            a = match.group(1).strip()
            if a:
                pairs.append({"question": pending_question, "answer": a})
        # An answer, or any unrelated line, ends the pending question.
        pending_question = None
    return pairs
|
||||||
|
|
||||||
|
|
||||||
def graph_entity_types(chat_mdl, scenario):
|
def graph_entity_types(chat_mdl, scenario):
|
||||||
|
|||||||
@@ -1,19 +1,20 @@
|
|||||||
## Role
|
## Role
|
||||||
You are a text analyzer.
|
You are a text analyzer and knowledge extraction expert.
|
||||||
|
|
||||||
## Task
|
## Task
|
||||||
Propose {{ topn }} questions about a given piece of text content.
|
Generate question-answer pairs from the given text content.
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
- Understand and summarize the text content, and propose the top {{ topn }} important questions.
|
- Understand and summarize the text content, then generate up to {{ topn }} important question-answer pairs.
|
||||||
|
- Each question-answer pair MUST be on a single line, formatted as: Q: <question> A: <answer>
|
||||||
- The questions SHOULD NOT have overlapping meanings.
|
- The questions SHOULD NOT have overlapping meanings.
|
||||||
- The questions SHOULD cover the main content of the text as much as possible.
|
- The questions SHOULD cover the main content of the text as much as possible.
|
||||||
- The questions MUST be in the same language as the given piece of text content.
|
- The answers MUST be concise, accurate, and directly derived from the text content.
|
||||||
- One question per line.
|
- The answers SHOULD be self-contained and understandable without additional context.
|
||||||
- Output questions ONLY.
|
- Both questions and answers MUST be in the same language as the given text content.
|
||||||
|
- If the text is too short or lacks substantive content, generate fewer pairs rather than padding.
|
||||||
---
|
- Output question-answer pairs ONLY, no extra explanation or commentary.
|
||||||
|
|
||||||
## Text Content
|
|
||||||
{{ content }}
|
|
||||||
|
|
||||||
|
## Example Output
|
||||||
|
Q: What is the capital of France? A: The capital of France is Paris.
|
||||||
|
Q: When was the Eiffel Tower built? A: The Eiffel Tower was built in 1889.
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ Transcribe the content from the provided PDF page image into clean Markdown form
|
|||||||
6. Do NOT wrap the output in ```markdown or ``` blocks.
|
6. Do NOT wrap the output in ```markdown or ``` blocks.
|
||||||
7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image.
|
7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image.
|
||||||
8. Preserve the original language, information, and order exactly as shown in the image.
|
8. Preserve the original language, information, and order exactly as shown in the image.
|
||||||
|
9. Your output language MUST match the language of the content in the image. If the image contains Chinese text, output in Chinese. If English, output in English. Never translate.
|
||||||
|
|
||||||
{% if page %}
|
{% if page %}
|
||||||
At the end of the transcription, add the page divider: `--- Page {{ page }} ---`.
|
At the end of the transcription, add the page divider: `--- Page {{ page }} ---`.
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import Any
|
|||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from elasticsearch import Elasticsearch, helpers
|
from elasticsearch import Elasticsearch, helpers, NotFoundError
|
||||||
from elasticsearch.helpers import BulkIndexError
|
from elasticsearch.helpers import BulkIndexError
|
||||||
from packaging.version import parse as parse_version
|
from packaging.version import parse as parse_version
|
||||||
# langchain-community
|
# langchain-community
|
||||||
@@ -53,13 +53,30 @@ class ElasticSearchVector(BaseVector):
|
|||||||
return "elasticsearch"
|
return "elasticsearch"
|
||||||
|
|
||||||
def add_chunks(self, chunks: list[DocumentChunk], **kwargs):
|
def add_chunks(self, chunks: list[DocumentChunk], **kwargs):
|
||||||
# 实现 Elasticsearch 保存向量
|
# QA chunks: embedding 只对 question 字段做;source chunks: 不做 embedding
|
||||||
texts = [chunk.page_content for chunk in chunks]
|
texts_for_embedding = []
|
||||||
|
for chunk in chunks:
|
||||||
|
chunk_type = (chunk.metadata or {}).get("chunk_type", "chunk")
|
||||||
|
if chunk_type == "source":
|
||||||
|
# source chunk 不需要向量索引
|
||||||
|
texts_for_embedding.append("")
|
||||||
|
elif chunk_type == "qa":
|
||||||
|
# QA chunk: 用 question 字段做 embedding
|
||||||
|
texts_for_embedding.append((chunk.metadata or {}).get("question", chunk.page_content))
|
||||||
|
else:
|
||||||
|
# 普通 chunk: 用 page_content 做 embedding
|
||||||
|
texts_for_embedding.append(chunk.page_content)
|
||||||
|
|
||||||
if self.is_multimodal_embedding:
|
if self.is_multimodal_embedding:
|
||||||
# 火山引擎多模态 Embedding
|
embeddings = self.embeddings.embed_batch(texts_for_embedding)
|
||||||
embeddings = self.embeddings.embed_batch(texts)
|
|
||||||
else:
|
else:
|
||||||
embeddings = self.embeddings.embed_documents(list(texts))
|
embeddings = self.embeddings.embed_documents(texts_for_embedding)
|
||||||
|
|
||||||
|
# source chunk 的向量置空
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
if (chunk.metadata or {}).get("chunk_type") == "source":
|
||||||
|
embeddings[i] = None
|
||||||
|
|
||||||
self.create(chunks, embeddings, **kwargs)
|
self.create(chunks, embeddings, **kwargs)
|
||||||
|
|
||||||
def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs):
|
def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs):
|
||||||
@@ -72,13 +89,25 @@ class ElasticSearchVector(BaseVector):
|
|||||||
uuids = self._get_uuids(chunks)
|
uuids = self._get_uuids(chunks)
|
||||||
actions = []
|
actions = []
|
||||||
for i, chunk in enumerate(chunks):
|
for i, chunk in enumerate(chunks):
|
||||||
|
source = {
|
||||||
|
Field.CONTENT_KEY.value: chunk.page_content,
|
||||||
|
Field.METADATA_KEY.value: chunk.metadata or {},
|
||||||
|
Field.VECTOR.value: embeddings[i] or None
|
||||||
|
}
|
||||||
|
# 写入 QA 相关字段
|
||||||
|
meta = chunk.metadata or {}
|
||||||
|
if meta.get("chunk_type"):
|
||||||
|
source[Field.CHUNK_TYPE.value] = meta["chunk_type"]
|
||||||
|
if meta.get("question"):
|
||||||
|
source[Field.QUESTION.value] = meta["question"]
|
||||||
|
if meta.get("answer"):
|
||||||
|
source[Field.ANSWER.value] = meta["answer"]
|
||||||
|
if meta.get("source_chunk_id"):
|
||||||
|
source[Field.SOURCE_CHUNK_ID.value] = meta["source_chunk_id"]
|
||||||
|
|
||||||
action = {
|
action = {
|
||||||
"_index": self._collection_name,
|
"_index": self._collection_name,
|
||||||
"_source": {
|
"_source": source
|
||||||
Field.CONTENT_KEY.value: chunk.page_content,
|
|
||||||
Field.METADATA_KEY.value: chunk.metadata or {},
|
|
||||||
Field.VECTOR.value: embeddings[i] or None
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
actions.append(action)
|
actions.append(action)
|
||||||
# using bulk mode
|
# using bulk mode
|
||||||
@@ -113,7 +142,7 @@ class ElasticSearchVector(BaseVector):
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]):
|
def delete_by_ids(self, ids: list[str], *, refresh: bool = False):
|
||||||
if not ids:
|
if not ids:
|
||||||
return
|
return
|
||||||
if not self._client.indices.exists(index=self._collection_name):
|
if not self._client.indices.exists(index=self._collection_name):
|
||||||
@@ -134,6 +163,8 @@ class ElasticSearchVector(BaseVector):
|
|||||||
actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids]
|
actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids]
|
||||||
try:
|
try:
|
||||||
helpers.bulk(self._client, actions)
|
helpers.bulk(self._client, actions)
|
||||||
|
if refresh:
|
||||||
|
self._client.indices.refresh(index=self._collection_name)
|
||||||
except BulkIndexError as e:
|
except BulkIndexError as e:
|
||||||
for error in e.errors:
|
for error in e.errors:
|
||||||
delete_error = error.get('delete', {})
|
delete_error = error.get('delete', {})
|
||||||
@@ -153,7 +184,7 @@ class ElasticSearchVector(BaseVector):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def delete_by_metadata_field(self, key: str, value: str):
|
def delete_by_metadata_field(self, key: str, value: str, *, refresh: bool = False):
|
||||||
if not self._client.indices.exists(index=self._collection_name):
|
if not self._client.indices.exists(index=self._collection_name):
|
||||||
return False
|
return False
|
||||||
actual_ids = self.get_ids_by_metadata_field(key, value)
|
actual_ids = self.get_ids_by_metadata_field(key, value)
|
||||||
@@ -162,6 +193,8 @@ class ElasticSearchVector(BaseVector):
|
|||||||
actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids]
|
actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids]
|
||||||
try:
|
try:
|
||||||
helpers.bulk(self._client, actions)
|
helpers.bulk(self._client, actions)
|
||||||
|
if refresh:
|
||||||
|
self._client.indices.refresh(index=self._collection_name)
|
||||||
except BulkIndexError as e:
|
except BulkIndexError as e:
|
||||||
for error in e.errors:
|
for error in e.errors:
|
||||||
delete_error = error.get('delete', {})
|
delete_error = error.get('delete', {})
|
||||||
@@ -192,6 +225,8 @@ class ElasticSearchVector(BaseVector):
|
|||||||
List of DocumentChunk objects that match the query.
|
List of DocumentChunk objects that match the query.
|
||||||
"""
|
"""
|
||||||
indices = kwargs.get("indices", self._collection_name) # Default single index, multiple indexes are also supported, such as "index1, index2, index3"
|
indices = kwargs.get("indices", self._collection_name) # Default single index, multiple indexes are also supported, such as "index1, index2, index3"
|
||||||
|
if not self._client.indices.exists(index=indices):
|
||||||
|
return 0, []
|
||||||
|
|
||||||
# Calculate the start position for the current page
|
# Calculate the start position for the current page
|
||||||
from_ = pagesize * (page-1)
|
from_ = pagesize * (page-1)
|
||||||
@@ -226,12 +261,15 @@ class ElasticSearchVector(BaseVector):
|
|||||||
})
|
})
|
||||||
|
|
||||||
# For simplicity, we use from/size here which has a limit (usually up to 10,000).
|
# For simplicity, we use from/size here which has a limit (usually up to 10,000).
|
||||||
result = self._client.search(
|
try:
|
||||||
index=indices,
|
result = self._client.search(
|
||||||
from_=from_, # Only use from_ for the first page (simplified)
|
index=indices,
|
||||||
size=pagesize,
|
from_=from_, # Only use from_ for the first page (simplified)
|
||||||
body=query_str,
|
size=pagesize,
|
||||||
)
|
body=query_str,
|
||||||
|
)
|
||||||
|
except NotFoundError:
|
||||||
|
return 0, []
|
||||||
|
|
||||||
if "errors" in result:
|
if "errors" in result:
|
||||||
raise ValueError(f"Error during query: {result['errors']}")
|
raise ValueError(f"Error during query: {result['errors']}")
|
||||||
@@ -241,10 +279,19 @@ class ElasticSearchVector(BaseVector):
|
|||||||
for res in result["hits"]["hits"]:
|
for res in result["hits"]["hits"]:
|
||||||
source = res["_source"]
|
source = res["_source"]
|
||||||
page_content = source.get(Field.CONTENT_KEY.value)
|
page_content = source.get(Field.CONTENT_KEY.value)
|
||||||
# vector = source.get(Field.VECTOR.value)
|
|
||||||
vector = None
|
vector = None
|
||||||
metadata = source.get(Field.METADATA_KEY.value, {})
|
metadata = source.get(Field.METADATA_KEY.value, {})
|
||||||
|
chunk_type = source.get(Field.CHUNK_TYPE.value)
|
||||||
score = res["_score"]
|
score = res["_score"]
|
||||||
|
|
||||||
|
# 将 QA 字段注入 metadata 供前端展示
|
||||||
|
if chunk_type:
|
||||||
|
metadata["chunk_type"] = chunk_type
|
||||||
|
if chunk_type == "qa":
|
||||||
|
metadata["question"] = source.get(Field.QUESTION.value, "")
|
||||||
|
metadata["answer"] = source.get(Field.ANSWER.value, "")
|
||||||
|
page_content = f"Q: {metadata['question']}\nA: {metadata['answer']}"
|
||||||
|
|
||||||
docs_and_scores.append((DocumentChunk(page_content=page_content, vector=vector, metadata=metadata), score))
|
docs_and_scores.append((DocumentChunk(page_content=page_content, vector=vector, metadata=metadata), score))
|
||||||
|
|
||||||
docs = []
|
docs = []
|
||||||
@@ -267,13 +314,18 @@ class ElasticSearchVector(BaseVector):
|
|||||||
List of DocumentChunk objects that match the query.
|
List of DocumentChunk objects that match the query.
|
||||||
"""
|
"""
|
||||||
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
||||||
|
if not self._client.indices.exists(index=indices):
|
||||||
|
return 0, []
|
||||||
query_str = {"query": {"term": {f"{Field.DOC_ID.value}": doc_id}}}
|
query_str = {"query": {"term": {f"{Field.DOC_ID.value}": doc_id}}}
|
||||||
result = self._client.search(
|
try:
|
||||||
index=indices,
|
result = self._client.search(
|
||||||
from_=0, # Only use from_ for the first page (simplified)
|
index=indices,
|
||||||
size=1,
|
from_=0, # Only use from_ for the first page (simplified)
|
||||||
body=query_str,
|
size=1,
|
||||||
)
|
body=query_str,
|
||||||
|
)
|
||||||
|
except NotFoundError:
|
||||||
|
return 0, []
|
||||||
# print(result)
|
# print(result)
|
||||||
if "errors" in result:
|
if "errors" in result:
|
||||||
raise ValueError(f"Error during query: {result['errors']}")
|
raise ValueError(f"Error during query: {result['errors']}")
|
||||||
@@ -308,27 +360,43 @@ class ElasticSearchVector(BaseVector):
|
|||||||
Returns:
|
Returns:
|
||||||
updated count.
|
updated count.
|
||||||
"""
|
"""
|
||||||
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
indices = kwargs.get("indices", self._collection_name)
|
||||||
if self.is_multimodal_embedding:
|
chunk_type = (chunk.metadata or {}).get("chunk_type")
|
||||||
# 火山引擎多模态 Embedding
|
|
||||||
chunk.vector = self.embeddings.embed_text(chunk.page_content)
|
# QA chunk: embedding 基于 question;source chunk: 不更新向量
|
||||||
|
if chunk_type == "source":
|
||||||
|
embed_text = ""
|
||||||
|
elif chunk_type == "qa":
|
||||||
|
embed_text = (chunk.metadata or {}).get("question", chunk.page_content)
|
||||||
else:
|
else:
|
||||||
chunk.vector = self.embeddings.embed_query(chunk.page_content)
|
embed_text = chunk.page_content
|
||||||
|
|
||||||
|
if chunk_type != "source":
|
||||||
|
if self.is_multimodal_embedding:
|
||||||
|
chunk.vector = self.embeddings.embed_text(embed_text)
|
||||||
|
else:
|
||||||
|
chunk.vector = self.embeddings.embed_query(embed_text)
|
||||||
|
|
||||||
|
script_source = "ctx._source.page_content = params.new_content; ctx._source.vector = params.new_vector;"
|
||||||
|
params = {
|
||||||
|
"new_content": chunk.page_content,
|
||||||
|
"new_vector": chunk.vector if chunk_type != "source" else None
|
||||||
|
}
|
||||||
|
|
||||||
|
# QA chunk: 同时更新 question/answer 字段
|
||||||
|
if chunk_type == "qa":
|
||||||
|
script_source += " ctx._source.question = params.new_question; ctx._source.answer = params.new_answer;"
|
||||||
|
params["new_question"] = (chunk.metadata or {}).get("question", "")
|
||||||
|
params["new_answer"] = (chunk.metadata or {}).get("answer", "")
|
||||||
|
|
||||||
body = {
|
body = {
|
||||||
"script": {
|
"script": {
|
||||||
"source": """
|
"source": script_source,
|
||||||
ctx._source.page_content = params.new_content;
|
"params": params
|
||||||
ctx._source.vector = params.new_vector;
|
|
||||||
""",
|
|
||||||
"params": {
|
|
||||||
"new_content": chunk.page_content,
|
|
||||||
"new_vector": chunk.vector
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"query": {
|
"query": {
|
||||||
"term": {
|
"term": {
|
||||||
Field.DOC_ID.value: chunk.metadata["doc_id"] # exact match doc_id
|
Field.DOC_ID.value: chunk.metadata["doc_id"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -336,9 +404,6 @@ class ElasticSearchVector(BaseVector):
|
|||||||
index=indices,
|
index=indices,
|
||||||
body=body,
|
body=body,
|
||||||
)
|
)
|
||||||
# Remove debug printing and use logging instead
|
|
||||||
# print(result)
|
|
||||||
# print(f"Update successful, number of affected documents: {result['updated']}")
|
|
||||||
return result['updated']
|
return result['updated']
|
||||||
|
|
||||||
def change_status_by_document_id(self, document_id: str, status: int, **kwargs) -> str:
|
def change_status_by_document_id(self, document_id: str, status: int, **kwargs) -> str:
|
||||||
@@ -397,11 +462,11 @@ class ElasticSearchVector(BaseVector):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"filter": { # Add the filter condition of status=1
|
"filter": [
|
||||||
"term": {
|
{"term": {"metadata.status": 1}},
|
||||||
"metadata.status": 1
|
# 排除 source chunk(仅供 GraphRAG 使用,不参与检索)
|
||||||
}
|
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
|
||||||
}
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
# If file_names_filter is passed in, merge the filtering conditions
|
# If file_names_filter is passed in, merge the filtering conditions
|
||||||
@@ -415,22 +480,14 @@ class ElasticSearchVector(BaseVector):
|
|||||||
},
|
},
|
||||||
"script": {
|
"script": {
|
||||||
"source": f"cosineSimilarity(params.query_vector, '{Field.VECTOR.value}') + 1.0",
|
"source": f"cosineSimilarity(params.query_vector, '{Field.VECTOR.value}') + 1.0",
|
||||||
# The script_score query calculates the cosine similarity between the embedding field of each document and the query vector. The addition of +1.0 is to ensure that the scores returned by the script are non-negative, as the range of cosine similarity is [-1, 1]
|
|
||||||
"params": {"query_vector": query_vector}
|
"params": {"query_vector": query_vector}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"filter": [
|
"filter": [
|
||||||
{
|
{"term": {"metadata.status": 1}},
|
||||||
"term": {
|
{"terms": {"metadata.file_name": file_names_filter}},
|
||||||
"metadata.status": 1
|
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"terms": {
|
|
||||||
"metadata.file_name": file_names_filter # Additional file_name filtering
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -451,8 +508,19 @@ class ElasticSearchVector(BaseVector):
|
|||||||
source = res["_source"]
|
source = res["_source"]
|
||||||
page_content = source.get(Field.CONTENT_KEY.value)
|
page_content = source.get(Field.CONTENT_KEY.value)
|
||||||
metadata = source.get(Field.METADATA_KEY.value, {})
|
metadata = source.get(Field.METADATA_KEY.value, {})
|
||||||
|
chunk_type = source.get(Field.CHUNK_TYPE.value)
|
||||||
score = res["_score"]
|
score = res["_score"]
|
||||||
score = score / 2 # Normalized [0-1]
|
score = score / 2 # Normalized [0-1]
|
||||||
|
|
||||||
|
# QA chunk: 返回 Q+A 拼接作为上下文
|
||||||
|
if chunk_type == "qa":
|
||||||
|
question = source.get(Field.QUESTION.value, "")
|
||||||
|
answer = source.get(Field.ANSWER.value, "")
|
||||||
|
page_content = f"Q: {question}\nA: {answer}"
|
||||||
|
metadata["chunk_type"] = "qa"
|
||||||
|
metadata["question"] = question
|
||||||
|
metadata["answer"] = answer
|
||||||
|
|
||||||
docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), score))
|
docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), score))
|
||||||
|
|
||||||
docs = []
|
docs = []
|
||||||
@@ -491,11 +559,10 @@ class ElasticSearchVector(BaseVector):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"filter": { # Add the filter condition of status=1
|
"filter": [
|
||||||
"term": {
|
{"term": {"metadata.status": 1}},
|
||||||
"metadata.status": 1
|
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
|
||||||
}
|
]
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -512,16 +579,9 @@ class ElasticSearchVector(BaseVector):
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"filter": [
|
"filter": [
|
||||||
{
|
{"term": {"metadata.status": 1}},
|
||||||
"term": {
|
{"terms": {"metadata.file_name": file_names_filter}},
|
||||||
"metadata.status": 1
|
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"terms": {
|
|
||||||
"metadata.file_name": file_names_filter # Additional file_name filtering
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -543,6 +603,17 @@ class ElasticSearchVector(BaseVector):
|
|||||||
source = res["_source"]
|
source = res["_source"]
|
||||||
page_content = source.get(Field.CONTENT_KEY.value)
|
page_content = source.get(Field.CONTENT_KEY.value)
|
||||||
metadata = source.get(Field.METADATA_KEY.value, {})
|
metadata = source.get(Field.METADATA_KEY.value, {})
|
||||||
|
chunk_type = source.get(Field.CHUNK_TYPE.value)
|
||||||
|
|
||||||
|
# QA chunk: 返回 Q+A 拼接作为上下文
|
||||||
|
if chunk_type == "qa":
|
||||||
|
question = source.get(Field.QUESTION.value, "")
|
||||||
|
answer = source.get(Field.ANSWER.value, "")
|
||||||
|
page_content = f"Q: {question}\nA: {answer}"
|
||||||
|
metadata["chunk_type"] = "qa"
|
||||||
|
metadata["question"] = question
|
||||||
|
metadata["answer"] = answer
|
||||||
|
|
||||||
# Normalize the score to the [0,1] interval
|
# Normalize the score to the [0,1] interval
|
||||||
normalized_score = res["_score"] / max_score
|
normalized_score = res["_score"] / max_score
|
||||||
docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), normalized_score))
|
docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), normalized_score))
|
||||||
@@ -652,7 +723,7 @@ class ElasticSearchVector(BaseVector):
|
|||||||
},
|
},
|
||||||
Field.VECTOR.value: {
|
Field.VECTOR.value: {
|
||||||
"type": "dense_vector",
|
"type": "dense_vector",
|
||||||
"dims": len(embeddings[0]), # Make sure the dimension is correct here,The dimension size of the vector. When index is true, it cannot exceed 1024; when index is false or not specified, it cannot exceed 2048, which can improve retrieval efficiency
|
"dims": len(next((e for e in embeddings if e is not None), [0]*768)), # 跳过 None 获取向量维度,fallback 768
|
||||||
"index": True,
|
"index": True,
|
||||||
"similarity": "cosine"
|
"similarity": "cosine"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,3 +14,8 @@ class Field(StrEnum):
|
|||||||
DOCUMENT_ID = "metadata.document_id"
|
DOCUMENT_ID = "metadata.document_id"
|
||||||
KNOWLEDGE_ID = "metadata.knowledge_id"
|
KNOWLEDGE_ID = "metadata.knowledge_id"
|
||||||
SORT_ID = "metadata.sort_id"
|
SORT_ID = "metadata.sort_id"
|
||||||
|
# QA fields
|
||||||
|
CHUNK_TYPE = "chunk_type" # "chunk" | "source" | "qa"
|
||||||
|
QUESTION = "question"
|
||||||
|
ANSWER = "answer"
|
||||||
|
SOURCE_CHUNK_ID = "source_chunk_id"
|
||||||
|
|||||||
@@ -27,14 +27,14 @@ class BaseVector(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete_by_ids(self, ids: list[str]):
|
def delete_by_ids(self, ids: list[str], *, refresh: bool = False):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete_by_metadata_field(self, key: str, value: str):
|
def delete_by_metadata_field(self, key: str, value: str, *, refresh: bool = False):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ class CustomTool(BaseTool):
|
|||||||
# 添加通用参数(基于第一个操作的参数)
|
# 添加通用参数(基于第一个操作的参数)
|
||||||
if self._parsed_operations:
|
if self._parsed_operations:
|
||||||
first_operation = next(iter(self._parsed_operations.values()))
|
first_operation = next(iter(self._parsed_operations.values()))
|
||||||
|
# path/query 参数
|
||||||
for param_name, param_info in first_operation.get("parameters", {}).items():
|
for param_name, param_info in first_operation.get("parameters", {}).items():
|
||||||
params.append(ToolParameter(
|
params.append(ToolParameter(
|
||||||
name=param_name,
|
name=param_name,
|
||||||
@@ -85,6 +86,23 @@ class CustomTool(BaseTool):
|
|||||||
maximum=param_info.get("maximum"),
|
maximum=param_info.get("maximum"),
|
||||||
pattern=param_info.get("pattern")
|
pattern=param_info.get("pattern")
|
||||||
))
|
))
|
||||||
|
# requestBody 参数 — 将 body 字段平铺为独立参数暴露给模型
|
||||||
|
request_body = first_operation.get("request_body")
|
||||||
|
if request_body:
|
||||||
|
body_schema = request_body.get("properties", {})
|
||||||
|
required_fields = request_body.get("required", [])
|
||||||
|
for prop_name, prop_schema in body_schema.items():
|
||||||
|
params.append(ToolParameter(
|
||||||
|
name=prop_name,
|
||||||
|
type=self._convert_openapi_type(prop_schema.get("type", "string")),
|
||||||
|
description=prop_schema.get("description", ""),
|
||||||
|
required=prop_name in required_fields,
|
||||||
|
default=prop_schema.get("default"),
|
||||||
|
enum=prop_schema.get("enum"),
|
||||||
|
minimum=prop_schema.get("minimum"),
|
||||||
|
maximum=prop_schema.get("maximum"),
|
||||||
|
pattern=prop_schema.get("pattern")
|
||||||
|
))
|
||||||
|
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from app.core.workflow.engine.runtime_schema import ExecutionContext
|
|||||||
from app.core.workflow.engine.state_manager import WorkflowStateManager
|
from app.core.workflow.engine.state_manager import WorkflowStateManager
|
||||||
from app.core.workflow.engine.stream_output_coordinator import StreamOutputCoordinator
|
from app.core.workflow.engine.stream_output_coordinator import StreamOutputCoordinator
|
||||||
from app.core.workflow.engine.variable_pool import VariablePool, VariablePoolInitializer
|
from app.core.workflow.engine.variable_pool import VariablePool, VariablePoolInitializer
|
||||||
|
from app.core.workflow.nodes.base_node import NodeExecutionError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -326,10 +327,43 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
||||||
exc_info=True)
|
exc_info=True)
|
||||||
|
|
||||||
|
# 1) 尝试从 checkpoint 回补已成功节点的 node_outputs
|
||||||
|
recovered: dict[str, Any] = {}
|
||||||
|
try:
|
||||||
|
if self.graph is not None:
|
||||||
|
recovered = self.graph.get_state(
|
||||||
|
self.execution_context.checkpoint_config
|
||||||
|
).values or {}
|
||||||
|
except Exception as recover_err:
|
||||||
|
logger.warning(
|
||||||
|
f"Recover state on failure failed: {recover_err}, "
|
||||||
|
f"execution_id={self.execution_context.execution_id}"
|
||||||
|
)
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
result = {"error": str(e)}
|
result = dict(recovered) if recovered else {}
|
||||||
else:
|
else:
|
||||||
result["error"] = str(e)
|
# 已有 result 与 recovered 合并,node_outputs 深度合并
|
||||||
|
for k, v in recovered.items():
|
||||||
|
if k == "node_outputs" and isinstance(v, dict):
|
||||||
|
existing = result.get("node_outputs") or {}
|
||||||
|
result["node_outputs"] = {**v, **existing}
|
||||||
|
else:
|
||||||
|
result.setdefault(k, v)
|
||||||
|
|
||||||
|
# 2) 如果是节点抛出的 NodeExecutionError,把失败节点的 node_output 注入 node_outputs
|
||||||
|
failed_node_id: str | None = None
|
||||||
|
if isinstance(e, NodeExecutionError):
|
||||||
|
failed_node_id = e.node_id
|
||||||
|
node_outputs = result.setdefault("node_outputs", {})
|
||||||
|
# 不覆盖已有(理论上不会有),保底写入失败节点记录
|
||||||
|
node_outputs.setdefault(e.node_id, e.node_output)
|
||||||
|
|
||||||
|
result["error"] = str(e)
|
||||||
|
if failed_node_id:
|
||||||
|
result["error_node"] = failed_node_id
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"event": "workflow_end",
|
"event": "workflow_end",
|
||||||
"data": self.result_builder.build_final_output(
|
"data": self.result_builder.build_final_output(
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -22,6 +23,20 @@ from app.services.multimodal_service import MultimodalService
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class NodeExecutionError(Exception):
|
||||||
|
"""节点执行失败异常。
|
||||||
|
|
||||||
|
携带失败节点的完整 node_output,供 executor 兜底注入 node_outputs,
|
||||||
|
保证 workflow_executions.output_data 里能看到失败节点的日志记录。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, node_id: str, node_output: dict[str, Any], error_message: str):
|
||||||
|
super().__init__(f"Node {node_id} execution failed: {error_message}")
|
||||||
|
self.node_id = node_id
|
||||||
|
self.node_output = node_output
|
||||||
|
self.error_message = error_message
|
||||||
|
|
||||||
|
|
||||||
class BaseNode(ABC):
|
class BaseNode(ABC):
|
||||||
"""Base class for workflow nodes.
|
"""Base class for workflow nodes.
|
||||||
|
|
||||||
@@ -396,6 +411,8 @@ class BaseNode(ABC):
|
|||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time,
|
||||||
"token_usage": token_usage,
|
"token_usage": token_usage,
|
||||||
"error": None,
|
"error": None,
|
||||||
|
# 单调递增序号,用于日志按执行顺序排序(JSONB 不保证 key 顺序)
|
||||||
|
"execution_order": time.monotonic_ns(),
|
||||||
**self._extract_extra_fields(business_result),
|
**self._extract_extra_fields(business_result),
|
||||||
}
|
}
|
||||||
final_output = {
|
final_output = {
|
||||||
@@ -444,7 +461,9 @@ class BaseNode(ABC):
|
|||||||
"output": None,
|
"output": None,
|
||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time,
|
||||||
"token_usage": None,
|
"token_usage": None,
|
||||||
"error": error_message
|
"error": error_message,
|
||||||
|
# 单调递增序号,用于日志按执行顺序排序
|
||||||
|
"execution_order": time.monotonic_ns(),
|
||||||
}
|
}
|
||||||
|
|
||||||
# if error_edge:
|
# if error_edge:
|
||||||
@@ -466,7 +485,12 @@ class BaseNode(ABC):
|
|||||||
**node_output
|
**node_output
|
||||||
})
|
})
|
||||||
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
|
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
|
||||||
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
|
# 抛出自定义异常,把 node_output 带给 executor,供其写入 node_outputs
|
||||||
|
raise NodeExecutionError(
|
||||||
|
node_id=self.node_id,
|
||||||
|
node_output=node_output,
|
||||||
|
error_message=error_message,
|
||||||
|
)
|
||||||
|
|
||||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
"""Extracts the input data for this node (used for logging or audit).
|
"""Extracts the input data for this node (used for logging or audit).
|
||||||
|
|||||||
@@ -174,12 +174,18 @@ class IterationRuntime:
|
|||||||
continue
|
continue
|
||||||
node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type")
|
node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type")
|
||||||
cycle_variable = {"item": item} if node_type == NodeType.CYCLE_START else None
|
cycle_variable = {"item": item} if node_type == NodeType.CYCLE_START else None
|
||||||
|
node_cfg = next(
|
||||||
|
(n for n in self.cycle_nodes if n.get("id") == node_name), None
|
||||||
|
)
|
||||||
self.event_write({
|
self.event_write({
|
||||||
"type": "cycle_item",
|
"type": "cycle_item",
|
||||||
"data": {
|
"data": {
|
||||||
"cycle_id": self.node_id,
|
"cycle_id": self.node_id,
|
||||||
"cycle_idx": idx,
|
"cycle_idx": idx,
|
||||||
"node_id": node_name,
|
"node_id": node_name,
|
||||||
|
"node_type": node_type,
|
||||||
|
"node_name": node_cfg.get("data", {}).get("label") if node_cfg else node_name,
|
||||||
|
"status": result.get("node_outputs", {}).get(node_name, {}).get("status", "completed"),
|
||||||
"input": result.get("node_outputs", {}).get(node_name, {}).get("input")
|
"input": result.get("node_outputs", {}).get(node_name, {}).get("input")
|
||||||
if not cycle_variable else cycle_variable,
|
if not cycle_variable else cycle_variable,
|
||||||
"output": result.get("node_outputs", {}).get(node_name, {}).get("output")
|
"output": result.get("node_outputs", {}).get(node_name, {}).get("output")
|
||||||
|
|||||||
@@ -210,6 +210,9 @@ class LoopRuntime:
|
|||||||
"cycle_id": self.node_id,
|
"cycle_id": self.node_id,
|
||||||
"cycle_idx": idx,
|
"cycle_idx": idx,
|
||||||
"node_id": node_name,
|
"node_id": node_name,
|
||||||
|
"node_type": node_type,
|
||||||
|
"node_name": node_name,
|
||||||
|
"status": result.get("node_outputs", {}).get(node_name, {}).get("status", "completed"),
|
||||||
"input": result.get("node_outputs", {}).get(node_name, {}).get("input")
|
"input": result.get("node_outputs", {}).get(node_name, {}).get("input")
|
||||||
if not cycle_variable else cycle_variable,
|
if not cycle_variable else cycle_variable,
|
||||||
"output": result.get("node_outputs", {}).get(node_name, {}).get("output")
|
"output": result.get("node_outputs", {}).get(node_name, {}).get("output")
|
||||||
|
|||||||
@@ -272,6 +272,11 @@ class HttpRequestNodeOutput(BaseModel):
|
|||||||
description="HTTP response body",
|
description="HTTP response body",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
process_data: dict = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Raw HTTP request details for debugging",
|
||||||
|
)
|
||||||
|
|
||||||
# files: list[File] = Field(
|
# files: list[File] = Field(
|
||||||
# ...
|
# ...
|
||||||
# )
|
# )
|
||||||
|
|||||||
@@ -160,7 +160,6 @@ class HttpRequestNode(BaseNode):
|
|||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: HttpRequestNodeConfig | None = None
|
self.typed_config: HttpRequestNodeConfig | None = None
|
||||||
self.last_request: str = ""
|
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {
|
return {
|
||||||
@@ -171,47 +170,6 @@ class HttpRequestNode(BaseNode):
|
|||||||
"output": VariableType.STRING
|
"output": VariableType.STRING
|
||||||
}
|
}
|
||||||
|
|
||||||
def _extract_output(self, business_result: Any) -> Any:
|
|
||||||
if isinstance(business_result, dict):
|
|
||||||
result = {k: v for k, v in business_result.items() if k != "request"}
|
|
||||||
return result
|
|
||||||
return business_result
|
|
||||||
|
|
||||||
def _extract_extra_fields(self, business_result: Any) -> dict[str, Any]:
|
|
||||||
if isinstance(business_result, dict) and "request" in business_result:
|
|
||||||
return {
|
|
||||||
"process": {
|
|
||||||
"request": business_result.get("request", "")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def _wrap_error(
|
|
||||||
self,
|
|
||||||
error_message: str,
|
|
||||||
elapsed_time: float,
|
|
||||||
state: WorkflowState,
|
|
||||||
variable_pool: VariablePool
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
input_data = self._extract_input(state, variable_pool)
|
|
||||||
node_output = {
|
|
||||||
"node_id": self.node_id,
|
|
||||||
"node_type": self.node_type,
|
|
||||||
"node_name": self.node_name,
|
|
||||||
"status": "failed",
|
|
||||||
"input": input_data,
|
|
||||||
"output": None,
|
|
||||||
"process": {"request": self.last_request} if self.last_request else None,
|
|
||||||
"elapsed_time": elapsed_time,
|
|
||||||
"token_usage": None,
|
|
||||||
"error": error_message
|
|
||||||
}
|
|
||||||
return {
|
|
||||||
"node_outputs": {self.node_id: node_output},
|
|
||||||
"error": error_message,
|
|
||||||
"error_node": self.node_id
|
|
||||||
}
|
|
||||||
|
|
||||||
def _build_timeout(self) -> Timeout:
|
def _build_timeout(self) -> Timeout:
|
||||||
"""
|
"""
|
||||||
Build httpx Timeout configuration.
|
Build httpx Timeout configuration.
|
||||||
@@ -297,13 +255,18 @@ class HttpRequestNode(BaseNode):
|
|||||||
case HttpContentType.NONE:
|
case HttpContentType.NONE:
|
||||||
return {}
|
return {}
|
||||||
case HttpContentType.JSON:
|
case HttpContentType.JSON:
|
||||||
rendered_body = self._render_template(
|
rendered = self._render_template(
|
||||||
self.typed_config.body.data, variable_pool
|
self.typed_config.body.data, variable_pool
|
||||||
).strip()
|
)
|
||||||
if not rendered_body:
|
if not rendered or not rendered.strip():
|
||||||
content["json"] = {}
|
# 第三方导入的工作流可能出现 content_type=json 但 data 为空的情况,视为无 body
|
||||||
else:
|
return {}
|
||||||
content["json"] = json.loads(rendered_body)
|
try:
|
||||||
|
content["json"] = json.loads(rendered)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Invalid JSON body for HTTP request node: {e.msg} (data={rendered!r})"
|
||||||
|
)
|
||||||
case HttpContentType.FROM_DATA:
|
case HttpContentType.FROM_DATA:
|
||||||
data = {}
|
data = {}
|
||||||
files = []
|
files = []
|
||||||
@@ -371,61 +334,15 @@ class HttpRequestNode(BaseNode):
|
|||||||
case _:
|
case _:
|
||||||
raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}")
|
raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}")
|
||||||
|
|
||||||
def _generate_raw_request(
|
def _extract_output(self, business_result: Any) -> Any:
|
||||||
self,
|
if isinstance(business_result, dict):
|
||||||
variable_pool: VariablePool,
|
return {k: v for k, v in business_result.items() if k != "process_data"}
|
||||||
url: str,
|
return business_result
|
||||||
headers: dict[str, str],
|
|
||||||
params: dict[str, str],
|
|
||||||
content: dict[str, Any]
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Generate raw HTTP request format for debugging.
|
|
||||||
|
|
||||||
Args:
|
def _extract_extra_fields(self, business_result: Any) -> dict:
|
||||||
variable_pool: Variable Pool
|
if isinstance(business_result, dict) and "process_data" in business_result:
|
||||||
url: Rendered URL
|
return {"process": business_result["process_data"]}
|
||||||
headers: Request headers
|
return {}
|
||||||
params: Query parameters
|
|
||||||
content: Request body content
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Raw HTTP request string
|
|
||||||
"""
|
|
||||||
method = self.typed_config.method.value
|
|
||||||
|
|
||||||
if params:
|
|
||||||
param_str = "&".join([f"{k}={v}" for k, v in params.items()])
|
|
||||||
full_url = f"{url}?{param_str}" if "?" not in url else f"{url}&{param_str}"
|
|
||||||
else:
|
|
||||||
full_url = url
|
|
||||||
|
|
||||||
lines = [f"{method} {full_url} HTTP/1.1"]
|
|
||||||
|
|
||||||
for key, value in headers.items():
|
|
||||||
lines.append(f"{key}: {value}")
|
|
||||||
|
|
||||||
if "json" in content and content["json"]:
|
|
||||||
json_body = json.dumps(content["json"], ensure_ascii=False)
|
|
||||||
lines.append(f"Content-Length: {len(json_body)}")
|
|
||||||
lines.append("")
|
|
||||||
lines.append(json_body)
|
|
||||||
elif "data" in content and "files" not in content:
|
|
||||||
if isinstance(content["data"], dict):
|
|
||||||
body_str = "&".join([f"{k}={v}" for k, v in content["data"].items()])
|
|
||||||
lines.append(f"Content-Length: {len(body_str)}")
|
|
||||||
lines.append("")
|
|
||||||
lines.append(body_str)
|
|
||||||
elif "content" in content:
|
|
||||||
lines.append(f"Content-Length: {len(content['content'])}")
|
|
||||||
lines.append("")
|
|
||||||
lines.append(content["content"])
|
|
||||||
elif "files" in content:
|
|
||||||
lines.append("Content-Length: 0")
|
|
||||||
lines.append("")
|
|
||||||
lines.append("# Note: This request includes file uploads")
|
|
||||||
|
|
||||||
return "\r\n".join(lines)
|
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict | str:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict | str:
|
||||||
"""
|
"""
|
||||||
@@ -445,47 +362,42 @@ class HttpRequestNode(BaseNode):
|
|||||||
- str: Branch identifier (e.g. "ERROR") when branching is enabled
|
- str: Branch identifier (e.g. "ERROR") when branching is enabled
|
||||||
"""
|
"""
|
||||||
self.typed_config = HttpRequestNodeConfig(**self.config)
|
self.typed_config = HttpRequestNodeConfig(**self.config)
|
||||||
|
rendered_url = self._render_template(self.typed_config.url, variable_pool)
|
||||||
# Build request components
|
built_headers = self._build_header(variable_pool) | self._build_auth(variable_pool)
|
||||||
headers = self._build_header(variable_pool) | self._build_auth(variable_pool)
|
built_params = self._build_params(variable_pool)
|
||||||
params = self._build_params(variable_pool)
|
|
||||||
content = await self._build_content(variable_pool)
|
|
||||||
url = self._render_template(self.typed_config.url, variable_pool)
|
|
||||||
|
|
||||||
logger.info(f"Node {self.node_id}: headers={headers}, params={params}, content keys={list(content.keys())}")
|
|
||||||
|
|
||||||
# Generate raw HTTP request for debugging
|
|
||||||
raw_request = self._generate_raw_request(variable_pool, url, headers, params, content)
|
|
||||||
self.last_request = raw_request
|
|
||||||
logger.info(f"Node {self.node_id}: Generated HTTP request:\n{raw_request}")
|
|
||||||
|
|
||||||
async with httpx.AsyncClient(
|
async with httpx.AsyncClient(
|
||||||
verify=self.typed_config.verify_ssl,
|
verify=self.typed_config.verify_ssl,
|
||||||
timeout=self._build_timeout(),
|
timeout=self._build_timeout(),
|
||||||
headers=headers,
|
headers=built_headers,
|
||||||
params=params,
|
params=built_params,
|
||||||
follow_redirects=True
|
follow_redirects=True
|
||||||
) as client:
|
) as client:
|
||||||
retries = self.typed_config.retry.max_attempts
|
retries = self.typed_config.retry.max_attempts
|
||||||
while retries > 0:
|
while retries > 0:
|
||||||
try:
|
try:
|
||||||
request_func = self._get_client_method(client)
|
request_func = self._get_client_method(client)
|
||||||
|
built_content = await self._build_content(variable_pool)
|
||||||
resp = await request_func(
|
resp = await request_func(
|
||||||
url=url,
|
url=rendered_url,
|
||||||
**content
|
**built_content
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
||||||
response = HttpResponse(resp)
|
response = HttpResponse(resp)
|
||||||
return {
|
# Build raw request summary for process_data
|
||||||
**HttpRequestNodeOutput(
|
raw_request = (
|
||||||
body=response.body,
|
f"{self.typed_config.method.upper()} {resp.request.url} HTTP/1.1\r\n"
|
||||||
status_code=resp.status_code,
|
+ "".join(f"{k}: {v}\r\n" for k, v in resp.request.headers.items())
|
||||||
headers=resp.headers,
|
+ "\r\n"
|
||||||
files=response.files
|
+ (resp.request.content.decode(errors="replace") if resp.request.content else "")
|
||||||
).model_dump(),
|
)
|
||||||
"request": raw_request
|
return HttpRequestNodeOutput(
|
||||||
}
|
body=response.body,
|
||||||
|
status_code=resp.status_code,
|
||||||
|
headers=resp.headers,
|
||||||
|
files=response.files,
|
||||||
|
process_data={"request": raw_request},
|
||||||
|
).model_dump()
|
||||||
except (httpx.HTTPStatusError, httpx.RequestError) as e:
|
except (httpx.HTTPStatusError, httpx.RequestError) as e:
|
||||||
logger.error(f"HTTP request node exception: {e}")
|
logger.error(f"HTTP request node exception: {e}")
|
||||||
retries -= 1
|
retries -= 1
|
||||||
@@ -501,19 +413,10 @@ class HttpRequestNode(BaseNode):
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
f"Node {self.node_id}: HTTP request failed, returning default result"
|
f"Node {self.node_id}: HTTP request failed, returning default result"
|
||||||
)
|
)
|
||||||
error_result = self.typed_config.error_handle.default.model_dump()
|
return self.typed_config.error_handle.default.model_dump()
|
||||||
error_result["request"] = raw_request
|
|
||||||
return error_result
|
|
||||||
case HttpErrorHandle.BRANCH:
|
case HttpErrorHandle.BRANCH:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
||||||
)
|
)
|
||||||
return {
|
return {"output": "ERROR"}
|
||||||
"output": "ERROR",
|
|
||||||
"body": "",
|
|
||||||
"status_code": 500,
|
|
||||||
"headers": {},
|
|
||||||
"files": [],
|
|
||||||
"request": raw_request
|
|
||||||
}
|
|
||||||
raise RuntimeError("http request failed")
|
raise RuntimeError("http request failed")
|
||||||
|
|||||||
@@ -334,7 +334,8 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
for kb_config in knowledge_bases:
|
for kb_config in knowledge_bases:
|
||||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
|
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
|
||||||
if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1):
|
if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1):
|
||||||
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
logger.warning("The knowledge base does not exist or access is denied.")
|
||||||
|
continue
|
||||||
tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config))
|
tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config))
|
||||||
if tasks:
|
if tasks:
|
||||||
result = await asyncio.gather(*tasks)
|
result = await asyncio.gather(*tasks)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from app.celery_task_scheduler import scheduler
|
||||||
from app.core.memory.enums import SearchStrategy
|
from app.core.memory.enums import SearchStrategy
|
||||||
from app.core.memory.memory_service import MemoryService
|
from app.core.memory.memory_service import MemoryService
|
||||||
from app.core.workflow.engine.state_manager import WorkflowState
|
from app.core.workflow.engine.state_manager import WorkflowState
|
||||||
@@ -11,7 +12,6 @@ from app.core.workflow.variable.base_variable import VariableType
|
|||||||
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
||||||
from app.db import get_db_read
|
from app.db import get_db_read
|
||||||
from app.schemas import FileInput
|
from app.schemas import FileInput
|
||||||
from app.tasks import write_message_task
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryReadNode(BaseNode):
|
class MemoryReadNode(BaseNode):
|
||||||
@@ -126,12 +126,23 @@ class MemoryWriteNode(BaseNode):
|
|||||||
"files": file_info
|
"files": file_info
|
||||||
})
|
})
|
||||||
|
|
||||||
write_message_task.delay(
|
scheduler.push_task(
|
||||||
end_user_id=end_user_id,
|
"app.core.memory.agent.write_message",
|
||||||
message=messages,
|
end_user_id,
|
||||||
config_id=str(self.typed_config.config_id),
|
{
|
||||||
storage_type=state["memory_storage_type"],
|
"end_user_id": end_user_id,
|
||||||
user_rag_memory_id=state["user_rag_memory_id"]
|
"message": messages,
|
||||||
|
"config_id": str(self.typed_config.config_id),
|
||||||
|
"storage_type": state["memory_storage_type"],
|
||||||
|
"user_rag_memory_id": state["user_rag_memory_id"]
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
# write_message_task.delay(
|
||||||
|
# end_user_id=end_user_id,
|
||||||
|
# message=messages,
|
||||||
|
# config_id=str(self.typed_config.config_id),
|
||||||
|
# storage_type=state["memory_storage_type"],
|
||||||
|
# user_rag_memory_id=state["user_rag_memory_id"]
|
||||||
|
# )
|
||||||
|
|
||||||
return "success"
|
return "success"
|
||||||
|
|||||||
@@ -15,4 +15,5 @@ class File(Base):
|
|||||||
file_ext = Column(String, index=True, nullable=False, comment="file extension:folder|pdf")
|
file_ext = Column(String, index=True, nullable=False, comment="file extension:folder|pdf")
|
||||||
file_size = Column(Integer, default=0, comment="file size(byte)")
|
file_size = Column(Integer, default=0, comment="file size(byte)")
|
||||||
file_url = Column(String, index=True, nullable=True, comment="file comes from a website url")
|
file_url = Column(String, index=True, nullable=True, comment="file comes from a website url")
|
||||||
|
file_key = Column(String(512), nullable=True, index=True, comment="storage file key for FileStorageService")
|
||||||
created_at = Column(DateTime, default=datetime.datetime.now)
|
created_at = Column(DateTime, default=datetime.datetime.now)
|
||||||
@@ -1,13 +1,15 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy import select, desc, func
|
from sqlalchemy import select, desc, func, or_, cast, Text
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.exceptions import ResourceNotFoundException
|
from app.core.exceptions import ResourceNotFoundException
|
||||||
from app.core.logging_config import get_db_logger
|
from app.core.logging_config import get_db_logger
|
||||||
from app.models import Conversation, Message
|
from app.models import Conversation, Message
|
||||||
|
from app.models.app_model import AppType
|
||||||
from app.models.conversation_model import ConversationDetail
|
from app.models.conversation_model import ConversationDetail
|
||||||
|
from app.models.workflow_model import WorkflowExecution
|
||||||
|
|
||||||
logger = get_db_logger()
|
logger = get_db_logger()
|
||||||
|
|
||||||
@@ -206,7 +208,8 @@ class ConversationRepository:
|
|||||||
is_draft: Optional[bool] = None,
|
is_draft: Optional[bool] = None,
|
||||||
keyword: Optional[str] = None,
|
keyword: Optional[str] = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
pagesize: int = 20
|
pagesize: int = 20,
|
||||||
|
app_type: Optional[str] = None,
|
||||||
) -> tuple[list[Conversation], int]:
|
) -> tuple[list[Conversation], int]:
|
||||||
"""
|
"""
|
||||||
查询应用日志会话列表(带分页和过滤)
|
查询应用日志会话列表(带分页和过滤)
|
||||||
@@ -218,6 +221,9 @@ class ConversationRepository:
|
|||||||
keyword: 搜索关键词(匹配消息内容)
|
keyword: 搜索关键词(匹配消息内容)
|
||||||
page: 页码(从 1 开始)
|
page: 页码(从 1 开始)
|
||||||
pagesize: 每页数量
|
pagesize: 每页数量
|
||||||
|
app_type: 应用类型。WORKFLOW 类型改用 workflow_executions 的
|
||||||
|
input_data/output_data 做关键词过滤(因为失败的工作流不会写入 messages 表);
|
||||||
|
其他类型仍走 messages 表。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[List[Conversation], int]: (会话列表,总数)
|
Tuple[List[Conversation], int]: (会话列表,总数)
|
||||||
@@ -234,12 +240,28 @@ class ConversationRepository:
|
|||||||
|
|
||||||
# 如果有关键词搜索,通过子查询过滤包含该关键词的 conversation
|
# 如果有关键词搜索,通过子查询过滤包含该关键词的 conversation
|
||||||
if keyword:
|
if keyword:
|
||||||
# 查找包含关键词的 conversation_id 列表
|
kw_pattern = f"%{keyword}%"
|
||||||
keyword_stmt = (
|
if app_type == AppType.WORKFLOW:
|
||||||
select(Message.conversation_id)
|
# 工作流:从 workflow_executions 的 input_data / output_data 匹配
|
||||||
.where(Message.content.ilike(f"%{keyword}%"))
|
# (messages 表只存开场白 assistant 消息,失败的工作流也不会写入)
|
||||||
.distinct()
|
keyword_stmt = (
|
||||||
)
|
select(WorkflowExecution.conversation_id)
|
||||||
|
.where(
|
||||||
|
WorkflowExecution.conversation_id.is_not(None),
|
||||||
|
or_(
|
||||||
|
cast(WorkflowExecution.input_data, Text).ilike(kw_pattern),
|
||||||
|
cast(WorkflowExecution.output_data, Text).ilike(kw_pattern),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.distinct()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Agent 等其他类型:仍走 messages 表(user + assistant 内容)
|
||||||
|
keyword_stmt = (
|
||||||
|
select(Message.conversation_id)
|
||||||
|
.where(Message.content.ilike(kw_pattern))
|
||||||
|
.distinct()
|
||||||
|
)
|
||||||
base_stmt = base_stmt.where(Conversation.id.in_(keyword_stmt))
|
base_stmt = base_stmt.where(Conversation.id.in_(keyword_stmt))
|
||||||
|
|
||||||
# Calculate total number of records
|
# Calculate total number of records
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ class AppLogMessage(BaseModel):
|
|||||||
conversation_id: uuid.UUID
|
conversation_id: uuid.UUID
|
||||||
role: str = Field(description="角色: user / assistant / system")
|
role: str = Field(description="角色: user / assistant / system")
|
||||||
content: str
|
content: str
|
||||||
|
status: Optional[str] = Field(default=None, description="执行状态(工作流专用): completed / failed")
|
||||||
meta_data: Optional[Dict[str, Any]] = None
|
meta_data: Optional[Dict[str, Any]] = None
|
||||||
created_at: datetime.datetime
|
created_at: datetime.datetime
|
||||||
|
|
||||||
@@ -58,6 +59,7 @@ class AppLogNodeExecution(BaseModel):
|
|||||||
input: Optional[Any] = None
|
input: Optional[Any] = None
|
||||||
process: Optional[Any] = None
|
process: Optional[Any] = None
|
||||||
output: Optional[Any] = None
|
output: Optional[Any] = None
|
||||||
|
cycle_items: Optional[List[Any]] = None
|
||||||
elapsed_time: Optional[float] = None
|
elapsed_time: Optional[float] = None
|
||||||
token_usage: Optional[Dict[str, Any]] = None
|
token_usage: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import uuid
|
|||||||
from typing import Optional, Any, List, Dict, Union
|
from typing import Optional, Any, List, Dict, Union
|
||||||
from enum import Enum, StrEnum
|
from enum import Enum, StrEnum
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator, model_serializer
|
||||||
|
|
||||||
from app.schemas.workflow_schema import WorkflowConfigCreate
|
from app.schemas.workflow_schema import WorkflowConfigCreate
|
||||||
|
|
||||||
@@ -661,9 +661,11 @@ class DraftRunResponse(BaseModel):
|
|||||||
suggested_questions: List[str] = Field(default_factory=list, description="下一步建议问题")
|
suggested_questions: List[str] = Field(default_factory=list, description="下一步建议问题")
|
||||||
citations: List[Dict[str, Any]] = Field(default_factory=list, description="引用来源")
|
citations: List[Dict[str, Any]] = Field(default_factory=list, description="引用来源")
|
||||||
audio_url: Optional[str] = Field(default=None, description="TTS 语音URL")
|
audio_url: Optional[str] = Field(default=None, description="TTS 语音URL")
|
||||||
|
audio_status: Optional[str] = Field(default=None, description="TTS 语音状态")
|
||||||
|
|
||||||
def model_dump(self, **kwargs):
|
@model_serializer(mode="wrap")
|
||||||
data = super().model_dump(**kwargs)
|
def _serialize(self, handler):
|
||||||
|
data = handler(self)
|
||||||
if not data.get("reasoning_content"):
|
if not data.get("reasoning_content"):
|
||||||
data.pop("reasoning_content", None)
|
data.pop("reasoning_content", None)
|
||||||
return data
|
return data
|
||||||
|
|||||||
@@ -20,13 +20,26 @@ class ChunkCreate(BaseModel):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def chunk_content(self) -> str:
|
def chunk_content(self) -> str:
|
||||||
"""
|
"""Get the actual content string regardless of input type"""
|
||||||
Get the actual content string regardless of input type
|
|
||||||
"""
|
|
||||||
if isinstance(self.content, QAChunk):
|
if isinstance(self.content, QAChunk):
|
||||||
return f"question: {self.content.question} answer: {self.content.answer}"
|
return self.content.question # QA 模式下 page_content 存 question
|
||||||
return self.content
|
return self.content
|
||||||
|
|
||||||
|
@property
def is_qa(self) -> bool:
    """Tell whether ``content`` holds a ``QAChunk`` rather than another value."""
    current = self.content
    return isinstance(current, QAChunk)
|
||||||
|
|
||||||
|
@property
def qa_metadata(self) -> dict:
    """Return the QA-related metadata fields for this chunk.

    Returns:
        A dict with ``chunk_type`` set to ``"qa"`` plus the ``question`` and
        ``answer`` taken from ``content`` when ``content`` is a ``QAChunk``;
        an empty dict otherwise.
    """
    payload = self.content
    # Anything that is not a QAChunk carries no QA metadata.
    if not isinstance(payload, QAChunk):
        return {}
    return {
        "chunk_type": "qa",
        "question": payload.question,
        "answer": payload.answer,
    }
|
||||||
|
|
||||||
|
|
||||||
class ChunkUpdate(BaseModel):
|
class ChunkUpdate(BaseModel):
|
||||||
content: Union[str, QAChunk] = Field(
|
content: Union[str, QAChunk] = Field(
|
||||||
@@ -35,13 +48,26 @@ class ChunkUpdate(BaseModel):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def chunk_content(self) -> str:
|
def chunk_content(self) -> str:
|
||||||
"""
|
"""Get the actual content string regardless of input type"""
|
||||||
Get the actual content string regardless of input type
|
|
||||||
"""
|
|
||||||
if isinstance(self.content, QAChunk):
|
if isinstance(self.content, QAChunk):
|
||||||
return f"question: {self.content.question} answer: {self.content.answer}"
|
return self.content.question # QA 模式下 page_content 存 question
|
||||||
return self.content
|
return self.content
|
||||||
|
|
||||||
|
@property
def is_qa(self) -> bool:
    """Tell whether ``content`` holds a ``QAChunk`` rather than a plain string."""
    current = self.content
    return isinstance(current, QAChunk)
|
||||||
|
|
||||||
|
@property
def qa_metadata(self) -> dict:
    """Return the QA-related metadata fields for this chunk update.

    Returns:
        A dict with ``chunk_type`` set to ``"qa"`` plus the ``question`` and
        ``answer`` taken from ``content`` when ``content`` is a ``QAChunk``;
        an empty dict otherwise.
    """
    payload = self.content
    # Anything that is not a QAChunk carries no QA metadata.
    if not isinstance(payload, QAChunk):
        return {}
    return {
        "chunk_type": "qa",
        "question": payload.question,
        "answer": payload.answer,
    }
|
||||||
|
|
||||||
|
|
||||||
class ChunkRetrieve(BaseModel):
|
class ChunkRetrieve(BaseModel):
|
||||||
query: str
|
query: str
|
||||||
@@ -51,3 +77,8 @@ class ChunkRetrieve(BaseModel):
|
|||||||
vector_similarity_weight: float | None = Field(None)
|
vector_similarity_weight: float | None = Field(None)
|
||||||
top_k: int | None = Field(None)
|
top_k: int | None = Field(None)
|
||||||
retrieve_type: RetrieveType | None = Field(None)
|
retrieve_type: RetrieveType | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkBatchCreate(BaseModel):
    """Request schema for creating several chunks in one call.

    ``items`` is required and must contain at least one ``ChunkCreate``.
    """

    items: list[ChunkCreate] = Field(
        ...,
        min_length=1,
        description="chunk 列表",
    )
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
import uuid
|
import uuid
|
||||||
import datetime
|
import datetime
|
||||||
from typing import Optional, Dict, Any, List
|
from typing import Optional, Dict, Any, List
|
||||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer
|
from pydantic import BaseModel, Field, ConfigDict, field_serializer, model_serializer
|
||||||
|
|
||||||
# 导入 FileInput(用于体验运行)
|
# 导入 FileInput(用于体验运行)
|
||||||
from app.schemas.app_schema import FileInput
|
from app.schemas.app_schema import FileInput
|
||||||
@@ -94,6 +94,18 @@ class ChatResponse(BaseModel):
|
|||||||
message_id: str
|
message_id: str
|
||||||
usage: Optional[Dict[str, Any]] = None
|
usage: Optional[Dict[str, Any]] = None
|
||||||
elapsed_time: Optional[float] = None
|
elapsed_time: Optional[float] = None
|
||||||
|
reasoning_content: Optional[str] = None
|
||||||
|
suggested_questions: Optional[List[str]] = None
|
||||||
|
citations: Optional[List[Dict[str, Any]]] = None
|
||||||
|
audio_url: Optional[str] = None
|
||||||
|
audio_status: Optional[str] = None
|
||||||
|
|
||||||
|
@model_serializer(mode="wrap")
def _serialize(self, handler):
    """Serialize the model, omitting ``reasoning_content`` when it is empty.

    Args:
        handler: The wrapped serializer supplied by Pydantic; called with
            ``self`` to produce the serialized mapping.

    Returns:
        The serialized mapping, with the ``reasoning_content`` key removed
        when its value is missing or falsy.
    """
    serialized = handler(self)
    # A non-empty reasoning trace is kept as-is.
    if serialized.get("reasoning_content"):
        return serialized
    serialized.pop("reasoning_content", None)
    return serialized
|
||||||
|
|
||||||
|
|
||||||
# ---------- Conversation Summary Schemas ----------
|
# ---------- Conversation Summary Schemas ----------
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ class FileBase(BaseModel):
|
|||||||
file_ext: str
|
file_ext: str
|
||||||
file_size: int
|
file_size: int
|
||||||
file_url: str | None = None
|
file_url: str | None = None
|
||||||
|
file_key: str | None = None
|
||||||
created_at: datetime.datetime | None = None
|
created_at: datetime.datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -112,12 +112,12 @@ class MemoryWriteResponse(BaseModel):
|
|||||||
"""Response schema for memory write operation.
|
"""Response schema for memory write operation.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
task_id: Celery task ID for status polling
|
task_id: task ID for status polling
|
||||||
status: Initial task status (PENDING)
|
status: Initial task status (QUEUED)
|
||||||
end_user_id: End user ID the write was submitted for
|
end_user_id: End user ID the write was submitted for
|
||||||
"""
|
"""
|
||||||
task_id: str = Field(..., description="Celery task ID for polling")
|
task_id: str = Field(..., description="task ID for polling")
|
||||||
status: str = Field(..., description="Task status: PENDING")
|
status: str = Field(..., description="Task status: QUEUED")
|
||||||
end_user_id: str = Field(..., description="End user ID")
|
end_user_id: str = Field(..., description="End user ID")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from app.aioRedis import aio_redis
|
from app.aioRedis import aio_redis
|
||||||
from app.models.api_key_model import ApiKey
|
from app.models.api_key_model import ApiKey, ApiKeyType
|
||||||
from app.repositories.api_key_repository import ApiKeyRepository, ApiKeyLogRepository
|
from app.repositories.api_key_repository import ApiKeyRepository, ApiKeyLogRepository
|
||||||
from app.schemas import api_key_schema
|
from app.schemas import api_key_schema
|
||||||
from app.schemas.response_schema import PageData, PageMeta
|
from app.schemas.response_schema import PageData, PageMeta
|
||||||
@@ -65,6 +65,12 @@ class ApiKeyService:
|
|||||||
BizCode.BAD_REQUEST
|
BizCode.BAD_REQUEST
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# SERVICE 类型的 resource_id 指向 workspace,非应用,跳过应用发布校验
|
||||||
|
if data.resource_id and data.type != ApiKeyType.SERVICE.value:
|
||||||
|
app = db.get(App, data.resource_id)
|
||||||
|
if not app or not app.current_release_id:
|
||||||
|
raise BusinessException("该应用未发布", BizCode.APP_NOT_PUBLISHED)
|
||||||
|
|
||||||
# 生成 API Key
|
# 生成 API Key
|
||||||
api_key = generate_api_key(data.type)
|
api_key = generate_api_key(data.type)
|
||||||
|
|
||||||
@@ -447,9 +453,12 @@ class ApiKeyAuthService:
|
|||||||
def check_app_published(db: Session, api_key_obj: ApiKey) -> None:
|
def check_app_published(db: Session, api_key_obj: ApiKey) -> None:
|
||||||
"""
|
"""
|
||||||
检查应用是否已发布,未发布则抛出异常
|
检查应用是否已发布,未发布则抛出异常
|
||||||
|
SERVICE 类型的 api_key 不绑定应用(resource_id 指向 workspace),跳过校验
|
||||||
"""
|
"""
|
||||||
if not api_key_obj.resource_id:
|
if not api_key_obj.resource_id:
|
||||||
return
|
return
|
||||||
|
if api_key_obj.type == ApiKeyType.SERVICE.value:
|
||||||
|
return
|
||||||
app = db.get(App, api_key_obj.resource_id)
|
app = db.get(App, api_key_obj.resource_id)
|
||||||
if not app or not app.current_release_id:
|
if not app or not app.current_release_id:
|
||||||
raise BusinessException("应用未发布,不可用", BizCode.APP_NOT_PUBLISHED)
|
raise BusinessException("应用未发布,不可用", BizCode.APP_NOT_PUBLISHED)
|
||||||
|
|||||||
@@ -107,23 +107,6 @@ class AppChatService:
|
|||||||
# 获取模型参数
|
# 获取模型参数
|
||||||
model_parameters = config.model_parameters
|
model_parameters = config.model_parameters
|
||||||
|
|
||||||
# 创建 LangChain Agent
|
|
||||||
agent = LangChainAgent(
|
|
||||||
model_name=api_key_obj.model_name,
|
|
||||||
api_key=api_key_obj.api_key,
|
|
||||||
provider=api_key_obj.provider,
|
|
||||||
api_base=api_key_obj.api_base,
|
|
||||||
is_omni=api_key_obj.is_omni,
|
|
||||||
temperature=model_parameters.get("temperature", 0.7),
|
|
||||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
tools=tools,
|
|
||||||
deep_thinking=model_parameters.get("deep_thinking", False),
|
|
||||||
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
|
||||||
json_output=model_parameters.get("json_output", False),
|
|
||||||
capability=api_key_obj.capability or [],
|
|
||||||
)
|
|
||||||
|
|
||||||
model_info = ModelInfo(
|
model_info = ModelInfo(
|
||||||
model_name=api_key_obj.model_name,
|
model_name=api_key_obj.model_name,
|
||||||
provider=api_key_obj.provider,
|
provider=api_key_obj.provider,
|
||||||
@@ -177,16 +160,27 @@ class AppChatService:
|
|||||||
if doc_img_recognition and "vision" in (api_key_obj.capability or []) and any(
|
if doc_img_recognition and "vision" in (api_key_obj.capability or []) and any(
|
||||||
f.type == FileType.DOCUMENT for f in files
|
f.type == FileType.DOCUMENT for f in files
|
||||||
):
|
):
|
||||||
from langchain.agents import create_agent
|
system_prompt += (
|
||||||
agent.system_prompt += (
|
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式  展示对应图片。"
|
||||||
"\n\n文档中包含图片,图片位置已在文本中以 [第N页 第M张图片]: URL 标记。"
|
|
||||||
"请在回答中用 Markdown 格式  展示相关图片,做到图文并茂。"
|
|
||||||
)
|
|
||||||
agent.agent = create_agent(
|
|
||||||
model=agent.llm,
|
|
||||||
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
|
|
||||||
system_prompt=agent.system_prompt
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 创建 LangChain Agent
|
||||||
|
agent = LangChainAgent(
|
||||||
|
model_name=api_key_obj.model_name,
|
||||||
|
api_key=api_key_obj.api_key,
|
||||||
|
provider=api_key_obj.provider,
|
||||||
|
api_base=api_key_obj.api_base,
|
||||||
|
is_omni=api_key_obj.is_omni,
|
||||||
|
temperature=model_parameters.get("temperature", 0.7),
|
||||||
|
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tools=tools,
|
||||||
|
deep_thinking=model_parameters.get("deep_thinking", False),
|
||||||
|
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
||||||
|
json_output=model_parameters.get("json_output", False),
|
||||||
|
capability=api_key_obj.capability or [],
|
||||||
|
)
|
||||||
|
|
||||||
# 为需要运行时上下文的工具注入上下文
|
# 为需要运行时上下文的工具注入上下文
|
||||||
for t in tools:
|
for t in tools:
|
||||||
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
||||||
@@ -323,7 +317,7 @@ class AppChatService:
|
|||||||
"suggested_questions": suggested_questions,
|
"suggested_questions": suggested_questions,
|
||||||
"citations": filtered_citations,
|
"citations": filtered_citations,
|
||||||
"audio_url": audio_url,
|
"audio_url": audio_url,
|
||||||
"audio_status": "pending"
|
"audio_status": "pending" if audio_url else None
|
||||||
}
|
}
|
||||||
|
|
||||||
async def agnet_chat_stream(
|
async def agnet_chat_stream(
|
||||||
@@ -399,24 +393,6 @@ class AppChatService:
|
|||||||
# 获取模型参数
|
# 获取模型参数
|
||||||
model_parameters = config.model_parameters
|
model_parameters = config.model_parameters
|
||||||
|
|
||||||
# 创建 LangChain Agent
|
|
||||||
agent = LangChainAgent(
|
|
||||||
model_name=api_key_obj.model_name,
|
|
||||||
api_key=api_key_obj.api_key,
|
|
||||||
provider=api_key_obj.provider,
|
|
||||||
api_base=api_key_obj.api_base,
|
|
||||||
is_omni=api_key_obj.is_omni,
|
|
||||||
temperature=model_parameters.get("temperature", 0.7),
|
|
||||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
tools=tools,
|
|
||||||
streaming=True,
|
|
||||||
deep_thinking=model_parameters.get("deep_thinking", False),
|
|
||||||
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
|
||||||
json_output=model_parameters.get("json_output", False),
|
|
||||||
capability=api_key_obj.capability or [],
|
|
||||||
)
|
|
||||||
|
|
||||||
model_info = ModelInfo(
|
model_info = ModelInfo(
|
||||||
model_name=api_key_obj.model_name,
|
model_name=api_key_obj.model_name,
|
||||||
provider=api_key_obj.provider,
|
provider=api_key_obj.provider,
|
||||||
@@ -471,16 +447,28 @@ class AppChatService:
|
|||||||
f.type == FileType.DOCUMENT for f in files
|
f.type == FileType.DOCUMENT for f in files
|
||||||
):
|
):
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
agent.system_prompt += (
|
system_prompt += (
|
||||||
"\n\n文档中包含图片,图片位置已在文本中以 [第N页 第M张图片]: URL 标记。"
|
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式  展示对应图片。"
|
||||||
"请在回答中用 Markdown 格式  展示相关图片,做到图文并茂。"
|
|
||||||
)
|
|
||||||
agent.agent = create_agent(
|
|
||||||
model=agent.llm,
|
|
||||||
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
|
|
||||||
system_prompt=agent.system_prompt
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 创建 LangChain Agent
|
||||||
|
agent = LangChainAgent(
|
||||||
|
model_name=api_key_obj.model_name,
|
||||||
|
api_key=api_key_obj.api_key,
|
||||||
|
provider=api_key_obj.provider,
|
||||||
|
api_base=api_key_obj.api_base,
|
||||||
|
is_omni=api_key_obj.is_omni,
|
||||||
|
temperature=model_parameters.get("temperature", 0.7),
|
||||||
|
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tools=tools,
|
||||||
|
streaming=True,
|
||||||
|
deep_thinking=model_parameters.get("deep_thinking", False),
|
||||||
|
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
||||||
|
json_output=model_parameters.get("json_output", False),
|
||||||
|
capability=api_key_obj.capability or [],
|
||||||
|
)
|
||||||
|
|
||||||
# 为需要运行时上下文的工具注入上下文
|
# 为需要运行时上下文的工具注入上下文
|
||||||
for t in tools:
|
for t in tools:
|
||||||
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
||||||
|
|||||||
@@ -1,16 +1,17 @@
|
|||||||
"""应用日志服务层"""
|
"""应用日志服务层"""
|
||||||
import uuid
|
import uuid
|
||||||
|
import datetime as dt
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.models.app_model import AppType
|
||||||
from app.models.conversation_model import Conversation, Message
|
from app.models.conversation_model import Conversation, Message
|
||||||
from app.models.workflow_model import WorkflowExecution
|
from app.models.workflow_model import WorkflowExecution
|
||||||
from app.repositories.conversation_repository import ConversationRepository, MessageRepository
|
from app.repositories.conversation_repository import ConversationRepository, MessageRepository
|
||||||
from app.schemas.app_log_schema import AppLogNodeExecution
|
from app.schemas.app_log_schema import AppLogMessage, AppLogNodeExecution
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
@@ -31,6 +32,7 @@ class AppLogService:
|
|||||||
pagesize: int = 20,
|
pagesize: int = 20,
|
||||||
is_draft: Optional[bool] = None,
|
is_draft: Optional[bool] = None,
|
||||||
keyword: Optional[str] = None,
|
keyword: Optional[str] = None,
|
||||||
|
app_type: Optional[str] = None,
|
||||||
) -> Tuple[list[Conversation], int]:
|
) -> Tuple[list[Conversation], int]:
|
||||||
"""
|
"""
|
||||||
查询应用日志会话列表
|
查询应用日志会话列表
|
||||||
@@ -42,6 +44,7 @@ class AppLogService:
|
|||||||
pagesize: 每页数量
|
pagesize: 每页数量
|
||||||
is_draft: 是否草稿会话(None表示返回全部)
|
is_draft: 是否草稿会话(None表示返回全部)
|
||||||
keyword: 搜索关键词(匹配消息内容)
|
keyword: 搜索关键词(匹配消息内容)
|
||||||
|
app_type: 应用类型(WORKFLOW 时关键词将从 workflow_executions 搜索)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[list[Conversation], int]: (会话列表,总数)
|
Tuple[list[Conversation], int]: (会话列表,总数)
|
||||||
@@ -54,7 +57,8 @@ class AppLogService:
|
|||||||
"page": page,
|
"page": page,
|
||||||
"pagesize": pagesize,
|
"pagesize": pagesize,
|
||||||
"is_draft": is_draft,
|
"is_draft": is_draft,
|
||||||
"keyword": keyword
|
"keyword": keyword,
|
||||||
|
"app_type": app_type,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -65,7 +69,8 @@ class AppLogService:
|
|||||||
is_draft=is_draft,
|
is_draft=is_draft,
|
||||||
keyword=keyword,
|
keyword=keyword,
|
||||||
page=page,
|
page=page,
|
||||||
pagesize=pagesize
|
pagesize=pagesize,
|
||||||
|
app_type=app_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -83,51 +88,40 @@ class AppLogService:
|
|||||||
self,
|
self,
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
workspace_id: uuid.UUID
|
workspace_id: uuid.UUID,
|
||||||
) -> Tuple[Conversation, dict[str, list[AppLogNodeExecution]]]:
|
app_type: str = AppType.AGENT
|
||||||
|
) -> Tuple[Conversation, list, dict[str, list[AppLogNodeExecution]]]:
|
||||||
"""
|
"""
|
||||||
查询会话详情(包含消息和工作流节点执行记录)
|
查询会话详情
|
||||||
|
|
||||||
Args:
|
|
||||||
app_id: 应用 ID
|
|
||||||
conversation_id: 会话 ID
|
|
||||||
workspace_id: 工作空间 ID
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[Conversation, dict[str, list[AppLogNodeExecution]]]:
|
Tuple[Conversation, list[AppLogMessage|Message], dict[str, list[AppLogNodeExecution]]]
|
||||||
(包含消息的会话对象, 按消息ID分组的节点执行记录)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ResourceNotFoundException: 当会话不存在时
|
|
||||||
"""
|
"""
|
||||||
logger.info(
|
logger.info(
|
||||||
"查询应用日志会话详情",
|
"查询应用日志会话详情",
|
||||||
extra={
|
extra={
|
||||||
"app_id": str(app_id),
|
"app_id": str(app_id),
|
||||||
"conversation_id": str(conversation_id),
|
"conversation_id": str(conversation_id),
|
||||||
"workspace_id": str(workspace_id)
|
"workspace_id": str(workspace_id),
|
||||||
|
"app_type": app_type
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 查询会话
|
|
||||||
conversation = self.conversation_repository.get_conversation_for_app_log(
|
conversation = self.conversation_repository.get_conversation_for_app_log(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
workspace_id=workspace_id
|
workspace_id=workspace_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# 查询消息(按时间正序)
|
if app_type == AppType.WORKFLOW:
|
||||||
messages = self.message_repository.get_messages_by_conversation(
|
messages, node_executions_map = self._get_workflow_messages_and_nodes(conversation_id)
|
||||||
conversation_id=conversation_id
|
else:
|
||||||
)
|
messages = self.message_repository.get_messages_by_conversation(
|
||||||
|
conversation_id=conversation_id
|
||||||
# 将消息附加到会话对象
|
)
|
||||||
conversation.messages = messages
|
node_executions_map = self._get_workflow_node_executions_with_map(
|
||||||
|
conversation_id, messages
|
||||||
# 查询工作流节点执行记录(按消息分组)
|
)
|
||||||
_, node_executions_map = self._get_workflow_node_executions_with_map(
|
|
||||||
conversation_id, messages
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"查询应用日志会话详情成功",
|
"查询应用日志会话详情成功",
|
||||||
@@ -139,13 +133,129 @@ class AppLogService:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return conversation, node_executions_map
|
return conversation, messages, node_executions_map
|
||||||
|
|
||||||
|
def _get_workflow_messages_and_nodes(
    self,
    conversation_id: uuid.UUID,
) -> Tuple[list[AppLogMessage], dict[str, list[AppLogNodeExecution]]]:
    """Build log messages and node logs for a workflow app conversation.

    Workflow apps are reconstructed from ``WorkflowExecution`` rows rather
    than from the messages table. Each execution with status ``completed``
    or ``failed`` that carries user input yields one user message (from
    ``input_data``) followed by one assistant message (from ``output_data``).
    If one of the first ten assistant ``Message`` rows of the conversation
    has ``suggested_questions`` in its ``meta_data``, it is emitted first as
    the opening message.

    Args:
        conversation_id: ID of the conversation whose executions are read.

    Returns:
        ``(messages, node_executions_map)``. ``messages`` is ordered as
        opening message (if any), then user/assistant pairs in execution
        start order. ``node_executions_map`` maps the string form of each
        synthesized assistant message ID to that execution's node records;
        executions without node records have no entry.
    """
    execution_stmt = (
        select(WorkflowExecution)
        .where(
            WorkflowExecution.conversation_id == conversation_id,
            WorkflowExecution.status.in_(["completed", "failed"]),
        )
        .order_by(WorkflowExecution.started_at.asc())
    )
    executions = list(self.db.scalars(execution_stmt).all())

    # Opening message: the earliest assistant Message (among the first ten)
    # whose meta_data contains a "suggested_questions" key.
    opening_stmt = (
        select(Message)
        .where(
            Message.conversation_id == conversation_id,
            Message.role == "assistant",
        )
        .order_by(Message.created_at.asc())
        .limit(10)
    )
    opening_msg = next(
        (
            candidate
            for candidate in self.db.scalars(opening_stmt).all()
            if isinstance(candidate.meta_data, dict)
            and "suggested_questions" in candidate.meta_data
        ),
        None,
    )

    messages: list[AppLogMessage] = []
    node_executions_map: dict[str, list[AppLogNodeExecution]] = {}

    if opening_msg is not None:
        messages.append(AppLogMessage(
            id=opening_msg.id,
            conversation_id=conversation_id,
            role="assistant",
            content=opening_msg.content,
            status=None,
            # A present-but-empty value is normalized to an empty list.
            meta_data={
                "suggested_questions": opening_msg.meta_data.get("suggested_questions") or []
            },
            created_at=opening_msg.created_at,
        ))

    for execution in executions:
        # Fall back to "now" / the start time when timestamps are missing.
        started_at = execution.started_at or dt.datetime.now()
        completed_at = execution.completed_at or started_at

        # Deterministic ID for the assistant message; also the key used in
        # node_executions_map.
        assistant_msg_id = uuid.uuid5(execution.id, "assistant")

        # --- user message (input) ---
        input_data = execution.input_data or {}
        input_content = input_data.get("message") or _extract_text(input_data)

        # Skip executions without user input (e.g. records triggered by the
        # opening statement).
        if not input_content or not input_content.strip():
            continue

        files = input_data.get("files") or []
        messages.append(AppLogMessage(
            id=uuid.uuid5(execution.id, "user"),
            conversation_id=conversation_id,
            role="user",
            content=input_content,
            meta_data={"files": files} if files else None,
            created_at=started_at,
        ))

        # --- assistant message (output) ---
        # AppLogMessage.content is a required str, so an empty extraction is
        # coalesced to "" for both statuses (previously only for failures).
        output_content = _extract_text(execution.output_data) or ""
        if execution.status == "completed":
            meta = {
                "usage": execution.token_usage or {},
                "elapsed_time": execution.elapsed_time,
            }
        else:
            meta = {
                "error": execution.error_message,
                "error_node_id": execution.error_node_id,
            }

        messages.append(AppLogMessage(
            id=assistant_msg_id,
            conversation_id=conversation_id,
            role="assistant",
            content=output_content,
            status=execution.status,
            meta_data=meta,
            created_at=completed_at,
        ))

        # --- node execution records, built from the execution's output_data ---
        execution_nodes = _build_nodes_from_output_data(execution.output_data)
        if execution_nodes:
            node_executions_map[str(assistant_msg_id)] = execution_nodes

    return messages, node_executions_map
|
||||||
|
|
||||||
def _get_workflow_node_executions_with_map(
|
def _get_workflow_node_executions_with_map(
|
||||||
self,
|
self,
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
messages: list[Message]
|
messages: list[Message]
|
||||||
) -> Tuple[list[AppLogNodeExecution], dict[str, list[AppLogNodeExecution]]]:
|
) -> dict[str, list[AppLogNodeExecution]]:
|
||||||
"""
|
"""
|
||||||
从 workflow_executions 表中提取节点执行记录,并按 assistant message 分组
|
从 workflow_executions 表中提取节点执行记录,并按 assistant message 分组
|
||||||
|
|
||||||
@@ -157,13 +267,12 @@ class AppLogService:
|
|||||||
Tuple[list[AppLogNodeExecution], dict[str, list[AppLogNodeExecution]]]:
|
Tuple[list[AppLogNodeExecution], dict[str, list[AppLogNodeExecution]]]:
|
||||||
(所有节点执行记录列表, 按 message_id 分组的节点执行记录字典)
|
(所有节点执行记录列表, 按 message_id 分组的节点执行记录字典)
|
||||||
"""
|
"""
|
||||||
node_executions = []
|
|
||||||
node_executions_map: dict[str, list[AppLogNodeExecution]] = {}
|
node_executions_map: dict[str, list[AppLogNodeExecution]] = {}
|
||||||
|
|
||||||
# 查询该会话关联的所有工作流执行记录(按时间正序)
|
# 查询该会话关联的所有工作流执行记录(按时间正序)
|
||||||
stmt = select(WorkflowExecution).where(
|
stmt = select(WorkflowExecution).where(
|
||||||
WorkflowExecution.conversation_id == conversation_id,
|
WorkflowExecution.conversation_id == conversation_id,
|
||||||
WorkflowExecution.status == "completed"
|
WorkflowExecution.status.in_(["completed", "failed"])
|
||||||
).order_by(WorkflowExecution.started_at.asc())
|
).order_by(WorkflowExecution.started_at.asc())
|
||||||
|
|
||||||
executions = self.db.scalars(stmt).all()
|
executions = self.db.scalars(stmt).all()
|
||||||
@@ -188,10 +297,18 @@ class AppLogService:
|
|||||||
used_message_ids: set[str] = set()
|
used_message_ids: set[str] = set()
|
||||||
|
|
||||||
for execution in executions:
|
for execution in executions:
|
||||||
if not execution.output_data:
|
# 构建节点执行记录列表,从 workflow_executions.output_data["node_outputs"] 读取
|
||||||
|
execution_nodes = _build_nodes_from_output_data(execution.output_data)
|
||||||
|
|
||||||
|
if not execution_nodes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 找到该 execution 对应的 assistant message
|
# 失败的执行没有 assistant message,直接用 execution id 作为 key
|
||||||
|
if execution.status == "failed":
|
||||||
|
node_executions_map[f"execution_{str(execution.id)}"] = execution_nodes
|
||||||
|
continue
|
||||||
|
|
||||||
|
# completed:通过时序匹配关联到对应的 assistant message
|
||||||
# 逻辑:找 execution.started_at 之后最近的、未使用的 assistant message
|
# 逻辑:找 execution.started_at 之后最近的、未使用的 assistant message
|
||||||
best_msg = None
|
best_msg = None
|
||||||
best_dt = None
|
best_dt = None
|
||||||
@@ -200,9 +317,9 @@ class AppLogService:
|
|||||||
if msg_id_str in used_message_ids:
|
if msg_id_str in used_message_ids:
|
||||||
continue
|
continue
|
||||||
if msg.created_at and msg.created_at >= execution.started_at:
|
if msg.created_at and msg.created_at >= execution.started_at:
|
||||||
dt = (msg.created_at - execution.started_at).total_seconds()
|
delta = (msg.created_at - execution.started_at).total_seconds()
|
||||||
if best_dt is None or dt < best_dt:
|
if best_dt is None or delta < best_dt:
|
||||||
best_dt = dt
|
best_dt = delta
|
||||||
best_msg = msg
|
best_msg = msg
|
||||||
|
|
||||||
if not best_msg:
|
if not best_msg:
|
||||||
@@ -210,31 +327,86 @@ class AppLogService:
|
|||||||
|
|
||||||
msg_id_str = str(best_msg.id)
|
msg_id_str = str(best_msg.id)
|
||||||
used_message_ids.add(msg_id_str)
|
used_message_ids.add(msg_id_str)
|
||||||
|
node_executions_map[msg_id_str] = execution_nodes
|
||||||
|
|
||||||
# 提取节点输出
|
return node_executions_map
|
||||||
output_data = execution.output_data
|
|
||||||
if isinstance(output_data, dict):
|
|
||||||
node_outputs = output_data.get("node_outputs", {})
|
|
||||||
execution_nodes = []
|
|
||||||
for node_id, node_data in node_outputs.items():
|
|
||||||
if not isinstance(node_data, dict):
|
|
||||||
continue
|
|
||||||
node_execution = AppLogNodeExecution(
|
|
||||||
node_id=node_data.get("node_id", node_id),
|
|
||||||
node_type=node_data.get("node_type", "unknown"),
|
|
||||||
node_name=node_data.get("node_name"),
|
|
||||||
status=node_data.get("status", "unknown"),
|
|
||||||
error=node_data.get("error"),
|
|
||||||
input=node_data.get("input"),
|
|
||||||
process=node_data.get("process"),
|
|
||||||
output=node_data.get("output"),
|
|
||||||
elapsed_time=node_data.get("elapsed_time"),
|
|
||||||
token_usage=node_data.get("token_usage"),
|
|
||||||
)
|
|
||||||
node_executions.append(node_execution)
|
|
||||||
execution_nodes.append(node_execution)
|
|
||||||
|
|
||||||
# 将节点记录关联到 message_id
|
|
||||||
node_executions_map[msg_id_str] = execution_nodes
|
|
||||||
|
|
||||||
return node_executions, node_executions_map
|
def _extract_text(data: Optional[dict]) -> str:
|
||||||
|
"""从 workflow execution 的 input_data / output_data 中提取可读文本。
|
||||||
|
|
||||||
|
优先取 'text'、'content'、'output' 字段;若都没有则 JSON 序列化整个 dict。
|
||||||
|
"""
|
||||||
|
if not data:
|
||||||
|
return ""
|
||||||
|
for key in ("message", "text", "content", "output", "result", "answer"):
|
||||||
|
if key in data and isinstance(data[key], str):
|
||||||
|
return data[key]
|
||||||
|
import json
|
||||||
|
return json.dumps(data, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_nodes_from_output_data(output_data: Optional[dict]) -> list[AppLogNodeExecution]:
    """Build node execution records from ``output_data["node_outputs"]``.

    Expected ``output_data`` layout::

        {
            "node_outputs": {
                "<node_id>": {
                    "node_type": ...,
                    "node_name": ...,
                    "status": ...,
                    "input": ...,
                    "output": ...,
                    "elapsed_time": ...,
                    "token_usage": ...,
                    "error": ...,
                    "cycle_items": [...],
                    ...
                }
            },
            "error": ...,
            ...
        }

    Args:
        output_data: The ``output_data`` value of a workflow execution.

    Returns:
        One ``AppLogNodeExecution`` per dict-valued entry of ``node_outputs``,
        ordered by ``execution_order``. Returns an empty list when
        ``output_data`` or ``node_outputs`` is missing, empty, or not a dict.
    """
    if not output_data or not isinstance(output_data, dict):
        return []
    node_outputs = output_data.get("node_outputs") or {}
    if not isinstance(node_outputs, dict):
        return []

    def _execution_order(item: tuple) -> float:
        # Entries without a usable numeric execution_order (legacy data,
        # explicit null, wrong type) sort as 0 so they stay at the front
        # instead of raising TypeError while sorting.
        node_data = item[1]
        if isinstance(node_data, dict):
            order = node_data.get("execution_order")
            if isinstance(order, (int, float)) and not isinstance(order, bool):
                return order
        return 0

    # Sort by execution_order rather than by dict order; the original
    # comment notes that PostgreSQL JSONB does not guarantee key order.
    ordered_items = sorted(node_outputs.items(), key=_execution_order)

    result = []
    for node_id, node_data in ordered_items:
        if not isinstance(node_data, dict):
            continue
        # Work on a copy so the caller's payload is not mutated by pop().
        output = dict(node_data)
        cycle_items = output.pop("cycle_items", None)
        # Strip the known top-level fields; whatever remains is the output.
        node_type = output.pop("node_type", "unknown")
        node_name = output.pop("node_name", None)
        status = output.pop("status", "completed")
        error = output.pop("error", None)
        inp = output.pop("input", None)
        elapsed_time = output.pop("elapsed_time", None)
        token_usage = output.pop("token_usage", None)
        # execution_order is only used for sorting and is not returned.
        output.pop("execution_order", None)
        result.append(AppLogNodeExecution(
            node_id=node_id,
            node_type=node_type,
            node_name=node_name,
            status=status,
            error=error,
            input=inp,
            process=None,
            output=output if output else None,
            cycle_items=cycle_items,
            elapsed_time=elapsed_time,
            token_usage=token_usage,
        ))
    return result
|
||||||
|
|||||||
@@ -595,23 +595,6 @@ class AgentRunService:
|
|||||||
)
|
)
|
||||||
tools.extend(memory_tools)
|
tools.extend(memory_tools)
|
||||||
|
|
||||||
# 4. 创建 LangChain Agent
|
|
||||||
agent = LangChainAgent(
|
|
||||||
model_name=api_key_config["model_name"],
|
|
||||||
api_key=api_key_config["api_key"],
|
|
||||||
provider=api_key_config.get("provider", "openai"),
|
|
||||||
api_base=api_key_config.get("api_base"),
|
|
||||||
is_omni=api_key_config.get("is_omni", False),
|
|
||||||
temperature=effective_params.get("temperature", 0.7),
|
|
||||||
max_tokens=effective_params.get("max_tokens", 2000),
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
tools=tools,
|
|
||||||
deep_thinking=effective_params.get("deep_thinking", False),
|
|
||||||
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
|
|
||||||
json_output=effective_params.get("json_output", False),
|
|
||||||
capability=api_key_config.get("capability", []),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
||||||
is_new_conversation = not conversation_id
|
is_new_conversation = not conversation_id
|
||||||
opening, suggested_questions = None, None
|
opening, suggested_questions = None, None
|
||||||
@@ -666,16 +649,26 @@ class AgentRunService:
|
|||||||
and any(f.type == FileType.DOCUMENT for f in files)
|
and any(f.type == FileType.DOCUMENT for f in files)
|
||||||
)
|
)
|
||||||
if has_doc_with_images:
|
if has_doc_with_images:
|
||||||
agent.system_prompt += (
|
system_prompt += (
|
||||||
"\n\n文档中包含图片,图片位置已在文本中以 [第N页 第M张图片]: URL 标记。"
|
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式  展示对应图片。"
|
||||||
"请在回答中用 Markdown 格式  展示相关图片,做到图文并茂。"
|
|
||||||
)
|
|
||||||
# 重建 agent graph 以使新 system_prompt 生效
|
|
||||||
agent.agent = create_agent(
|
|
||||||
model=agent.llm,
|
|
||||||
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
|
|
||||||
system_prompt=agent.system_prompt
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
agent = LangChainAgent(
|
||||||
|
model_name=api_key_config["model_name"],
|
||||||
|
api_key=api_key_config["api_key"],
|
||||||
|
provider=api_key_config.get("provider", "openai"),
|
||||||
|
api_base=api_key_config.get("api_base"),
|
||||||
|
is_omni=api_key_config.get("is_omni", False),
|
||||||
|
temperature=effective_params.get("temperature", 0.7),
|
||||||
|
max_tokens=effective_params.get("max_tokens", 2000),
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tools=tools,
|
||||||
|
deep_thinking=effective_params.get("deep_thinking", False),
|
||||||
|
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
|
||||||
|
json_output=effective_params.get("json_output", False),
|
||||||
|
capability=api_key_config.get("capability", []),
|
||||||
|
)
|
||||||
|
|
||||||
# 为需要运行时上下文的工具注入上下文
|
# 为需要运行时上下文的工具注入上下文
|
||||||
for t in tools:
|
for t in tools:
|
||||||
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
||||||
@@ -761,7 +754,7 @@ class AgentRunService:
|
|||||||
) if not sub_agent else [],
|
) if not sub_agent else [],
|
||||||
"citations": filtered_citations,
|
"citations": filtered_citations,
|
||||||
"audio_url": audio_url,
|
"audio_url": audio_url,
|
||||||
"audio_status": "pending"
|
"audio_status": "pending" if audio_url else None
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -875,24 +868,6 @@ class AgentRunService:
|
|||||||
user_rag_memory_id)
|
user_rag_memory_id)
|
||||||
tools.extend(memory_tools)
|
tools.extend(memory_tools)
|
||||||
|
|
||||||
# 4. 创建 LangChain Agent
|
|
||||||
agent = LangChainAgent(
|
|
||||||
model_name=api_key_config["model_name"],
|
|
||||||
api_key=api_key_config["api_key"],
|
|
||||||
provider=api_key_config.get("provider", "openai"),
|
|
||||||
api_base=api_key_config.get("api_base"),
|
|
||||||
is_omni=api_key_config.get("is_omni", False),
|
|
||||||
temperature=effective_params.get("temperature", 0.7),
|
|
||||||
max_tokens=effective_params.get("max_tokens", 2000),
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
tools=tools,
|
|
||||||
streaming=True,
|
|
||||||
deep_thinking=effective_params.get("deep_thinking", False),
|
|
||||||
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
|
|
||||||
json_output=effective_params.get("json_output", False),
|
|
||||||
capability=api_key_config.get("capability", []),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
||||||
is_new_conversation = not conversation_id
|
is_new_conversation = not conversation_id
|
||||||
opening, suggested_questions = None, None
|
opening, suggested_questions = None, None
|
||||||
@@ -948,18 +923,28 @@ class AgentRunService:
|
|||||||
and any(f.type == FileType.DOCUMENT for f in files)
|
and any(f.type == FileType.DOCUMENT for f in files)
|
||||||
)
|
)
|
||||||
if has_doc_with_images:
|
if has_doc_with_images:
|
||||||
agent.system_prompt += (
|
system_prompt += (
|
||||||
"\n\n文档中包含图片,图片位置已在文本中以 [图片 第N页 第M张图片]: URL 标记。"
|
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式  展示对应图片。"
|
||||||
"请在回答中用 Markdown 格式  展示相关图片,做到图文并茂。"
|
|
||||||
"**规则1:图片URL必须原封不动、一字不差地复制,禁止修改、禁止省略任何字符**"
|
|
||||||
"**规则2:禁止修改URL中UUID里的任何数字和字母**"
|
|
||||||
"**规则3:直接使用  格式输出**"
|
|
||||||
)
|
|
||||||
agent.agent = create_agent(
|
|
||||||
model=agent.llm,
|
|
||||||
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
|
|
||||||
system_prompt=agent.system_prompt
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 创建 LangChain Agent
|
||||||
|
agent = LangChainAgent(
|
||||||
|
model_name=api_key_config["model_name"],
|
||||||
|
api_key=api_key_config["api_key"],
|
||||||
|
provider=api_key_config.get("provider", "openai"),
|
||||||
|
api_base=api_key_config.get("api_base"),
|
||||||
|
is_omni=api_key_config.get("is_omni", False),
|
||||||
|
temperature=effective_params.get("temperature", 0.7),
|
||||||
|
max_tokens=effective_params.get("max_tokens", 2000),
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tools=tools,
|
||||||
|
streaming=True,
|
||||||
|
deep_thinking=effective_params.get("deep_thinking", False),
|
||||||
|
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
|
||||||
|
json_output=effective_params.get("json_output", False),
|
||||||
|
capability=api_key_config.get("capability", []),
|
||||||
|
)
|
||||||
|
|
||||||
# 为需要运行时上下文的工具注入上下文
|
# 为需要运行时上下文的工具注入上下文
|
||||||
for t in tools:
|
for t in tools:
|
||||||
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
||||||
|
|||||||
@@ -34,26 +34,7 @@ def generate_file_key(
|
|||||||
Generate a unique file key for storage.
|
Generate a unique file key for storage.
|
||||||
|
|
||||||
The file key follows the format: {tenant_id}/{workspace_id}/{file_id}{file_ext}
|
The file key follows the format: {tenant_id}/{workspace_id}/{file_id}{file_ext}
|
||||||
|
|
||||||
Args:
|
|
||||||
tenant_id: The tenant UUID.
|
|
||||||
workspace_id: The workspace UUID.
|
|
||||||
file_id: The file UUID.
|
|
||||||
file_ext: The file extension (e.g., '.pdf', '.txt').
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A unique file key string.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> generate_file_key(
|
|
||||||
... uuid.UUID('550e8400-e29b-41d4-a716-446655440000'),
|
|
||||||
... uuid.UUID('660e8400-e29b-41d4-a716-446655440001'),
|
|
||||||
... uuid.UUID('770e8400-e29b-41d4-a716-446655440002'),
|
|
||||||
... '.pdf'
|
|
||||||
... )
|
|
||||||
'550e8400-e29b-41d4-a716-446655440000/660e8400-e29b-41d4-a716-446655440001/770e8400-e29b-41d4-a716-446655440002.pdf'
|
|
||||||
"""
|
"""
|
||||||
# Ensure file_ext starts with a dot
|
|
||||||
if file_ext and not file_ext.startswith('.'):
|
if file_ext and not file_ext.startswith('.'):
|
||||||
file_ext = f'.{file_ext}'
|
file_ext = f'.{file_ext}'
|
||||||
if workspace_id:
|
if workspace_id:
|
||||||
@@ -61,6 +42,21 @@ def generate_file_key(
|
|||||||
return f"{tenant_id}/{file_id}{file_ext}"
|
return f"{tenant_id}/{file_id}{file_ext}"
|
||||||
|
|
||||||
|
|
||||||
|
def generate_kb_file_key(
    kb_id: uuid.UUID,
    file_id: uuid.UUID,
    file_ext: str,
) -> str:
    """
    Generate a file key for knowledge base files.

    Format: kb/{kb_id}/{file_id}{file_ext}

    Args:
        kb_id: The knowledge base UUID.
        file_id: The file UUID.
        file_ext: The file extension, with or without a leading dot
            (e.g. '.pdf' or 'pdf'). An empty or missing extension produces
            a key with no suffix.

    Returns:
        The storage key string.
    """
    # Treat a missing extension as "no extension" so None is never rendered
    # into the key as the literal text "None".
    file_ext = file_ext or ""
    if file_ext and not file_ext.startswith('.'):
        file_ext = f'.{file_ext}'
    return f"kb/{kb_id}/{file_id}{file_ext}"
|
||||||
|
|
||||||
|
|
||||||
class FileStorageService:
|
class FileStorageService:
|
||||||
"""
|
"""
|
||||||
High-level service for file storage operations.
|
High-level service for file storage operations.
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from typing import Any, Dict, Optional
|
|||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.celery_task_scheduler import scheduler
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||||
from app.core.logging_config import get_logger
|
from app.core.logging_config import get_logger
|
||||||
@@ -166,20 +167,31 @@ class MemoryAPIService:
|
|||||||
# Convert to message list format expected by write_message_task
|
# Convert to message list format expected by write_message_task
|
||||||
messages = message if isinstance(message, list) else [{"role": "user", "content": message}]
|
messages = message if isinstance(message, list) else [{"role": "user", "content": message}]
|
||||||
|
|
||||||
from app.tasks import write_message_task
|
# from app.tasks import write_message_task
|
||||||
task = write_message_task.delay(
|
# task = write_message_task.delay(
|
||||||
|
# end_user_id,
|
||||||
|
# messages,
|
||||||
|
# config_id,
|
||||||
|
# storage_type,
|
||||||
|
# user_rag_memory_id or "",
|
||||||
|
# )
|
||||||
|
task_id = scheduler.push_task(
|
||||||
|
"app.core.memory.agent.write_message",
|
||||||
end_user_id,
|
end_user_id,
|
||||||
messages,
|
{
|
||||||
config_id,
|
"end_user_id": end_user_id,
|
||||||
storage_type,
|
"message": messages,
|
||||||
user_rag_memory_id or "",
|
"config_id": config_id,
|
||||||
|
"storage_type": storage_type,
|
||||||
|
"user_rag_memory_id": user_rag_memory_id or ""
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Memory write task submitted: task_id={task.id}, end_user_id={end_user_id}")
|
logger.info(f"Memory write task submitted, task_id={task_id} end_user_id={end_user_id}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"task_id": task.id,
|
"task_id": task_id,
|
||||||
"status": "PENDING",
|
"status": "QUEUED",
|
||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
处理显性记忆相关的业务逻辑,包括情景记忆和语义记忆的查询。
|
处理显性记忆相关的业务逻辑,包括情景记忆和语义记忆的查询。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from app.core.logging_config import get_logger
|
from app.core.logging_config import get_logger
|
||||||
from app.services.memory_base_service import MemoryBaseService
|
from app.services.memory_base_service import MemoryBaseService
|
||||||
@@ -104,7 +104,7 @@ class MemoryExplicitService(MemoryBaseService):
|
|||||||
e.description AS core_definition
|
e.description AS core_definition
|
||||||
ORDER BY e.name ASC
|
ORDER BY e.name ASC
|
||||||
"""
|
"""
|
||||||
|
|
||||||
semantic_result = await self.neo4j_connector.execute_query(
|
semantic_result = await self.neo4j_connector.execute_query(
|
||||||
semantic_query,
|
semantic_query,
|
||||||
end_user_id=end_user_id
|
end_user_id=end_user_id
|
||||||
@@ -146,6 +146,209 @@ class MemoryExplicitService(MemoryBaseService):
|
|||||||
logger.error(f"获取显性记忆总览时出错: {str(e)}", exc_info=True)
|
logger.error(f"获取显性记忆总览时出错: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def get_episodic_memory_list(
    self,
    end_user_id: str,
    page: int,
    pagesize: int,
    start_date: Optional[int] = None,
    end_date: Optional[int] = None,
    episodic_type: str = "all",
) -> Dict[str, Any]:
    """
    Return one page of a user's episodic memories (MemorySummary nodes).

    Args:
        end_user_id: End-user ID whose memories are listed.
        page: 1-based page number; values below 1 read from the first row.
        pagesize: Number of items per page.
        start_date: Optional lower bound, millisecond timestamp. It is
            widened to 00:00:00 of its UTC calendar day.
        end_date: Optional upper bound, millisecond timestamp. It is
            widened to 23:59:59.999999 of its UTC calendar day.
        episodic_type: Memory type filter; "all" disables the filter.

    Returns:
        {
            "total": int,       # all episodic memories of the user (unfiltered)
            "items": [...],     # records of the current page
            "page": {
                "page": int,
                "pagesize": int,
                "total": int,   # number of records after filtering
                "hasnext": bool
            }
        }

    Raises:
        Exception: Any error raised while querying is logged and re-raised.
    """
    try:
        logger.info(
            f"情景记忆分页查询: end_user_id={end_user_id}, "
            f"start_date={start_date}, end_date={end_date}, "
            f"episodic_type={episodic_type}, page={page}, pagesize={pagesize}"
        )

        # 1. Total number of episodic memories, ignoring every filter.
        total_all_query = """
        MATCH (s:MemorySummary)
        WHERE s.end_user_id = $end_user_id
        RETURN count(s) AS total
        """
        total_all_result = await self.neo4j_connector.execute_query(
            total_all_query, end_user_id=end_user_id
        )
        total_all = total_all_result[0]["total"] if total_all_result else 0

        # 2. Build the filter conditions.
        where_clauses = ["s.end_user_id = $end_user_id"]
        params = {"end_user_id": end_user_id}

        # Each bound is applied on its own, so supplying only one of
        # start_date / end_date still filters instead of being ignored.
        if start_date is not None or end_date is not None:
            from datetime import datetime, timezone

            if start_date is not None:
                start_dt = datetime.fromtimestamp(start_date / 1000, tz=timezone.utc)
                # Start of that UTC day.
                params["start_iso"] = start_dt.strftime("%Y-%m-%dT") + "00:00:00.000000"
                where_clauses.append("datetime(s.created_at) >= datetime($start_iso)")
            if end_date is not None:
                end_dt = datetime.fromtimestamp(end_date / 1000, tz=timezone.utc)
                # End of that UTC day.
                params["end_iso"] = end_dt.strftime("%Y-%m-%dT") + "23:59:59.999999"
                where_clauses.append("datetime(s.created_at) <= datetime($end_iso)")

        # The type filter is pushed down into Cypher and matches both the
        # English key and its Chinese label.
        if episodic_type != "all":
            type_mapping = {
                "conversation": "对话",
                "project_work": "项目/工作",
                "learning": "学习",
                "decision": "决策",
                "important_event": "重要事件"
            }
            chinese_type = type_mapping.get(episodic_type)
            params["episodic_type"] = episodic_type
            if chinese_type:
                where_clauses.append(
                    "(s.memory_type = $episodic_type OR s.memory_type = $chinese_type)"
                )
                params["chinese_type"] = chinese_type
            else:
                where_clauses.append("s.memory_type = $episodic_type")

        where_str = " AND ".join(where_clauses)

        # 3. Number of records after filtering.
        count_query = f"""
        MATCH (s:MemorySummary)
        WHERE {where_str}
        RETURN count(s) AS total
        """
        count_result = await self.neo4j_connector.execute_query(count_query, **params)
        filtered_total = count_result[0]["total"] if count_result else 0

        # 4. Fetch the requested page. The offset is clamped at 0 so a page
        # number below 1 does not turn into a negative SKIP.
        skip = max(page - 1, 0) * pagesize
        data_query = f"""
        MATCH (s:MemorySummary)
        WHERE {where_str}
        RETURN elementId(s) AS id,
               s.name AS title,
               s.memory_type AS memory_type,
               s.content AS content,
               s.created_at AS created_at
        ORDER BY s.created_at DESC
        SKIP $skip LIMIT $limit
        """
        params["skip"] = skip
        params["limit"] = pagesize

        result = await self.neo4j_connector.execute_query(data_query, **params)

        # 5. Convert records to response items, filling in display defaults.
        items = []
        for record in result or []:
            created_at_timestamp = self.parse_timestamp(record.get("created_at"))
            items.append({
                "id": record["id"],
                "title": record.get("title") or "未命名",
                "memory_type": record.get("memory_type") or "其他",
                "content": record.get("content") or "",
                "created_at": created_at_timestamp
            })

        # 6. Assemble the response.
        return {
            "total": total_all,
            "items": items,
            "page": {
                "page": page,
                "pagesize": pagesize,
                "total": filtered_total,
                "hasnext": (page * pagesize) < filtered_total
            }
        }

    except Exception as e:
        logger.error(f"情景记忆分页查询出错: {str(e)}", exc_info=True)
        raise
|
||||||
|
|
||||||
|
async def get_semantic_memory_list(
    self,
    end_user_id: str
) -> list:
    """
    Return the full list of a user's semantic memories.

    Queries ExtractedEntity nodes of the user that are flagged as explicit
    memory, ordered by name.

    Args:
        end_user_id: End-user ID whose semantic memories are listed.

    Returns:
        [
            {
                "id": str,
                "name": str,
                "entity_type": str,
                "core_definition": str
            }
        ]

    Raises:
        Exception: Any error raised while querying is logged and re-raised.
    """
    try:
        logger.info(f"语义记忆列表查询: end_user_id={end_user_id}")

        semantic_query = """
        MATCH (e:ExtractedEntity)
        WHERE e.end_user_id = $end_user_id
          AND e.is_explicit_memory = true
        RETURN elementId(e) AS id,
               e.name AS name,
               e.entity_type AS entity_type,
               e.description AS core_definition
        ORDER BY e.name ASC
        """

        records = await self.neo4j_connector.execute_query(
            semantic_query, end_user_id=end_user_id
        )

        # An empty or missing result yields an empty list; absent fields
        # fall back to their display defaults.
        items = [
            {
                "id": row["id"],
                "name": row.get("name") or "未命名",
                "entity_type": row.get("entity_type") or "未分类",
                "core_definition": row.get("core_definition") or ""
            }
            for row in (records if records else [])
        ]

        logger.info(f"语义记忆列表查询成功: end_user_id={end_user_id}, total={len(items)}")

        return items

    except Exception as e:
        logger.error(f"语义记忆列表查询出错: {str(e)}", exc_info=True)
        raise
|
||||||
|
|
||||||
async def get_explicit_memory_details(
|
async def get_explicit_memory_details(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
|
|||||||
"""通义千问文档格式"""
|
"""通义千问文档格式"""
|
||||||
return True, {
|
return True, {
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
"text": f"<document name=\"{file_name}\">\n文档内容:\n{text}\n</document>"
|
||||||
}
|
}
|
||||||
|
|
||||||
async def format_audio(
|
async def format_audio(
|
||||||
@@ -167,6 +167,7 @@ class BedrockFormatStrategy(MultimodalFormatStrategy):
|
|||||||
async def format_document(self, file_name: str, text: str) -> tuple[bool, Dict[str, Any]]:
|
async def format_document(self, file_name: str, text: str) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""Bedrock/Anthropic 文档格式(需要 base64 编码)"""
|
"""Bedrock/Anthropic 文档格式(需要 base64 编码)"""
|
||||||
# Bedrock 文档需要 base64 编码
|
# Bedrock 文档需要 base64 编码
|
||||||
|
text = f"文档内容:\n{text}\n"
|
||||||
text_bytes = text.encode('utf-8')
|
text_bytes = text.encode('utf-8')
|
||||||
base64_text = base64.b64encode(text_bytes).decode('utf-8')
|
base64_text = base64.b64encode(text_bytes).decode('utf-8')
|
||||||
|
|
||||||
@@ -223,7 +224,7 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy):
|
|||||||
"""OpenAI 文档格式"""
|
"""OpenAI 文档格式"""
|
||||||
return True, {
|
return True, {
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
"text": f"<document name=\"{file_name}\">\n文档内容:\n{text}\n</document>"
|
||||||
}
|
}
|
||||||
|
|
||||||
async def format_audio(
|
async def format_audio(
|
||||||
@@ -388,13 +389,14 @@ class MultimodalService:
|
|||||||
from app.models.workspace_model import Workspace as WorkspaceModel
|
from app.models.workspace_model import Workspace as WorkspaceModel
|
||||||
ws = self.db.query(WorkspaceModel).filter(WorkspaceModel.id == workspace_id).first()
|
ws = self.db.query(WorkspaceModel).filter(WorkspaceModel.id == workspace_id).first()
|
||||||
tenant_id = ws.tenant_id if ws else None
|
tenant_id = ws.tenant_id if ws else None
|
||||||
|
img_result = []
|
||||||
for img_info in img_infos:
|
for img_info in img_infos:
|
||||||
page = img_info["page"]
|
page = img_info["page"]
|
||||||
index = img_info["index"]
|
index = img_info["index"]
|
||||||
ext = img_info.get("ext", "png")
|
ext = img_info.get("ext", "png")
|
||||||
try:
|
try:
|
||||||
_, img_url = await self._save_doc_image_to_storage(img_info["bytes"], ext, tenant_id, workspace_id)
|
_, img_url = await self._save_doc_image_to_storage(img_info["bytes"], ext, tenant_id, workspace_id)
|
||||||
placeholder = f"第{page}页 第{index + 1}张图片" if page > 0 else f"第{index + 1}张图片"
|
placeholder = f"第{page}页 第{index + 1}张" if page > 0 else f"第{index + 1}张"
|
||||||
# 在文本内容中追加图片位置标记
|
# 在文本内容中追加图片位置标记
|
||||||
if result and result[-1].get("type") in ("text", "document"):
|
if result and result[-1].get("type") in ("text", "document"):
|
||||||
key = "text" if "text" in result[-1] else list(result[-1].keys())[-1]
|
key = "text" if "text" in result[-1] else list(result[-1].keys())[-1]
|
||||||
@@ -407,9 +409,10 @@ class MultimodalService:
|
|||||||
file_type="image/png",
|
file_type="image/png",
|
||||||
)
|
)
|
||||||
_, img_content = await self._process_image(img_file, strategy_class(img_file))
|
_, img_content = await self._process_image(img_file, strategy_class(img_file))
|
||||||
result.append(img_content)
|
img_result.append(img_content)
|
||||||
except Exception as img_err:
|
except Exception as img_err:
|
||||||
logger.warning(f"文档图片处理失败: {img_err}")
|
logger.warning(f"文档图片处理失败: {img_err}")
|
||||||
|
result.extend(img_result)
|
||||||
elif file.type == FileType.AUDIO and "audio" in self.capability:
|
elif file.type == FileType.AUDIO and "audio" in self.capability:
|
||||||
is_support, content = await self._process_audio(file, strategy)
|
is_support, content = await self._process_audio(file, strategy)
|
||||||
result.append(content)
|
result.append(content)
|
||||||
|
|||||||
@@ -815,11 +815,12 @@ class ToolService:
|
|||||||
"default": param_info.get("default")
|
"default": param_info.get("default")
|
||||||
})
|
})
|
||||||
|
|
||||||
# 请求体参数
|
# 请求体参数 — _extract_request_body 返回 {"schema": {...}, "required": bool, ...}
|
||||||
request_body = operation.get("request_body")
|
request_body = operation.get("request_body")
|
||||||
if request_body:
|
if request_body:
|
||||||
schema_props = request_body.get("schema", {}).get("properties", {})
|
body_schema = request_body.get("schema", {})
|
||||||
required_props = request_body.get("schema", {}).get("required", [])
|
schema_props = body_schema.get("properties", {})
|
||||||
|
required_props = body_schema.get("required", [])
|
||||||
|
|
||||||
for prop_name, prop_schema in schema_props.items():
|
for prop_name, prop_schema in schema_props.items():
|
||||||
parameters.append({
|
parameters.append({
|
||||||
|
|||||||
@@ -17,8 +17,9 @@ from app.core.workflow.executor import execute_workflow, execute_workflow_stream
|
|||||||
from app.core.workflow.nodes.enums import NodeType
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
from app.core.workflow.validator import validate_workflow_config
|
from app.core.workflow.validator import validate_workflow_config
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
|
from sqlalchemy import select
|
||||||
from app.models import App
|
from app.models import App
|
||||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
from app.models.workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution
|
||||||
from app.repositories import knowledge_repository
|
from app.repositories import knowledge_repository
|
||||||
from app.repositories.workflow_repository import (
|
from app.repositories.workflow_repository import (
|
||||||
WorkflowConfigRepository,
|
WorkflowConfigRepository,
|
||||||
@@ -918,6 +919,7 @@ class WorkflowService:
|
|||||||
input_data["conv_messages"] = conv_messages
|
input_data["conv_messages"] = conv_messages
|
||||||
init_message_length = len(input_data.get("conv_messages", []))
|
init_message_length = len(input_data.get("conv_messages", []))
|
||||||
message_id = uuid.uuid4()
|
message_id = uuid.uuid4()
|
||||||
|
_cycle_items: dict[str, list] = {}
|
||||||
|
|
||||||
# 新会话时写入开场白
|
# 新会话时写入开场白
|
||||||
is_new_conversation = init_message_length == 0
|
is_new_conversation = init_message_length == 0
|
||||||
@@ -948,6 +950,15 @@ class WorkflowService:
|
|||||||
memory_storage_type=storage_type,
|
memory_storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id
|
||||||
):
|
):
|
||||||
|
event_type = event.get("event")
|
||||||
|
event_data = event.get("data", {})
|
||||||
|
|
||||||
|
if event_type == "cycle_item":
|
||||||
|
cycle_id = event_data.get("cycle_id")
|
||||||
|
if cycle_id not in _cycle_items:
|
||||||
|
_cycle_items[cycle_id] = []
|
||||||
|
_cycle_items[cycle_id].append(event_data)
|
||||||
|
|
||||||
if event.get("event") == "workflow_end":
|
if event.get("event") == "workflow_end":
|
||||||
status = event.get("data", {}).get("status")
|
status = event.get("data", {}).get("status")
|
||||||
token_usage = event.get("data", {}).get("token_usage", {}) or {}
|
token_usage = event.get("data", {}).get("token_usage", {}) or {}
|
||||||
@@ -1019,6 +1030,18 @@ class WorkflowService:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.error(f"unexpect workflow run status, status: {status}")
|
logger.error(f"unexpect workflow run status, status: {status}")
|
||||||
|
# 把积累的 cycle_item 写入 workflow_executions.output_data["node_outputs"]
|
||||||
|
if _cycle_items and execution.output_data:
|
||||||
|
import copy
|
||||||
|
new_output_data = copy.deepcopy(execution.output_data)
|
||||||
|
node_outputs = new_output_data.setdefault("node_outputs", {})
|
||||||
|
for cycle_node_id, items in _cycle_items.items():
|
||||||
|
if cycle_node_id in node_outputs:
|
||||||
|
node_outputs[cycle_node_id]["cycle_items"] = items
|
||||||
|
else:
|
||||||
|
node_outputs[cycle_node_id] = {"cycle_items": items}
|
||||||
|
execution.output_data = new_output_data
|
||||||
|
self.db.commit()
|
||||||
elif event.get("event") == "workflow_start":
|
elif event.get("event") == "workflow_start":
|
||||||
event["data"]["message_id"] = str(message_id)
|
event["data"]["message_id"] = str(message_id)
|
||||||
event = self._emit(public, event)
|
event = self._emit(public, event)
|
||||||
|
|||||||
355
api/app/tasks.py
355
api/app/tasks.py
@@ -30,11 +30,11 @@ from app.core.rag.llm.cv_model import QWenCV
|
|||||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||||
from app.core.rag.llm.sequence2txt_model import QWenSeq2txt
|
from app.core.rag.llm.sequence2txt_model import QWenSeq2txt
|
||||||
from app.core.rag.models.chunk import DocumentChunk
|
from app.core.rag.models.chunk import DocumentChunk
|
||||||
from app.core.rag.prompts.generator import question_proposal
|
from app.core.rag.prompts.generator import question_proposal, qa_proposal
|
||||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
|
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
|
||||||
ElasticSearchVectorFactory,
|
ElasticSearchVectorFactory,
|
||||||
)
|
)
|
||||||
from app.db import get_db, get_db_context
|
from app.db import get_db_context
|
||||||
from app.models import Document, File, Knowledge
|
from app.models import Document, File, Knowledge
|
||||||
from app.models.end_user_model import EndUser
|
from app.models.end_user_model import EndUser
|
||||||
from app.schemas import document_schema, file_schema
|
from app.schemas import document_schema, file_schema
|
||||||
@@ -210,9 +210,14 @@ def _build_vision_model(file_path: str, db_knowledge):
|
|||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="app.core.rag.tasks.parse_document")
|
@celery_app.task(name="app.core.rag.tasks.parse_document")
|
||||||
def parse_document(file_path: str, document_id: uuid.UUID):
|
def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
|
||||||
"""
|
"""
|
||||||
Document parsing, vectorization, and storage
|
Document parsing, vectorization, and storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_key: Storage key for FileStorageService (e.g. "kb/{kb_id}/{file_id}.docx")
|
||||||
|
document_id: Document UUID
|
||||||
|
file_name: Original file name (used for extension detection in chunk())
|
||||||
"""
|
"""
|
||||||
|
|
||||||
db_document = None
|
db_document = None
|
||||||
@@ -223,7 +228,6 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
|||||||
|
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
try:
|
try:
|
||||||
# Celery JSON 序列化会将 UUID 转为字符串,需要确保类型正确
|
|
||||||
if not isinstance(document_id, uuid.UUID):
|
if not isinstance(document_id, uuid.UUID):
|
||||||
document_id = uuid.UUID(str(document_id))
|
document_id = uuid.UUID(str(document_id))
|
||||||
|
|
||||||
@@ -234,7 +238,11 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
|||||||
if db_knowledge is None:
|
if db_knowledge is None:
|
||||||
raise ValueError(f"Knowledge {db_document.kb_id} not found")
|
raise ValueError(f"Knowledge {db_document.kb_id} not found")
|
||||||
|
|
||||||
# 1. Document parsing & segmentation
|
# Use file_name from argument or fall back to document record
|
||||||
|
if not file_name:
|
||||||
|
file_name = db_document.file_name
|
||||||
|
|
||||||
|
# 1. Download file from storage backend
|
||||||
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Start to parse.")
|
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Start to parse.")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
db_document.progress = 0.0
|
db_document.progress = 0.0
|
||||||
@@ -245,45 +253,36 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
|||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(db_document)
|
db.refresh(db_document)
|
||||||
|
|
||||||
|
# Read file content from storage backend (no NFS dependency)
|
||||||
|
from app.services.file_storage_service import FileStorageService
|
||||||
|
import asyncio
|
||||||
|
storage_service = FileStorageService()
|
||||||
|
|
||||||
|
async def _download():
|
||||||
|
return await storage_service.download_file(file_key)
|
||||||
|
|
||||||
|
try:
|
||||||
|
file_binary = asyncio.run(_download())
|
||||||
|
except RuntimeError:
|
||||||
|
# If there's already a running loop (e.g. in some worker configurations)
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
try:
|
||||||
|
file_binary = loop.run_until_complete(_download())
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
if not file_binary:
|
||||||
|
raise IOError(f"Downloaded empty file from storage: {file_key}")
|
||||||
|
logger.info(f"[ParseDoc] Downloaded {len(file_binary)} bytes from storage key: {file_key}")
|
||||||
|
|
||||||
def progress_callback(prog=None, msg=None):
|
def progress_callback(prog=None, msg=None):
|
||||||
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.")
|
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.")
|
||||||
|
|
||||||
# Prepare vision_model for parsing
|
# Prepare vision_model for parsing
|
||||||
vision_model = _build_vision_model(file_path, db_knowledge)
|
vision_model = _build_vision_model(file_name, db_knowledge)
|
||||||
|
|
||||||
# 先将文件读入内存,避免解析过程中依赖 NFS 文件持续可访问
|
|
||||||
# python-docx 等库在 binary=None 时会用路径直接打开文件,
|
|
||||||
# 在 NFS/共享存储上可能因缓存失效导致 "Package not found"
|
|
||||||
max_wait_seconds = 30
|
|
||||||
wait_interval = 2
|
|
||||||
waited = 0
|
|
||||||
file_binary = None
|
|
||||||
while waited <= max_wait_seconds:
|
|
||||||
# os.listdir 强制 NFS 客户端刷新目录缓存
|
|
||||||
parent_dir = os.path.dirname(file_path)
|
|
||||||
try:
|
|
||||||
os.listdir(parent_dir)
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
with open(file_path, "rb") as f:
|
|
||||||
file_binary = f.read()
|
|
||||||
if not file_binary:
|
|
||||||
# NFS 上文件存在但内容为空(可能还在同步中)
|
|
||||||
raise IOError(f"File is empty (0 bytes), NFS may still be syncing: {file_path}")
|
|
||||||
break
|
|
||||||
except (FileNotFoundError, IOError) as e:
|
|
||||||
if waited >= max_wait_seconds:
|
|
||||||
raise type(e)(
|
|
||||||
f"File not accessible at '{file_path}' after waiting {max_wait_seconds}s: {e}"
|
|
||||||
)
|
|
||||||
logger.warning(f"File not ready on this node, retrying in {wait_interval}s: {file_path} ({e})")
|
|
||||||
time.sleep(wait_interval)
|
|
||||||
waited += wait_interval
|
|
||||||
|
|
||||||
from app.core.rag.app.naive import chunk
|
from app.core.rag.app.naive import chunk
|
||||||
logger.info(f"[ParseDoc] file_binary size={len(file_binary)} bytes, type={type(file_binary).__name__}, bool={bool(file_binary)}")
|
logger.info(f"[ParseDoc] file_binary size={len(file_binary)} bytes, type={type(file_binary).__name__}, bool={bool(file_binary)}")
|
||||||
res = chunk(filename=file_path,
|
res = chunk(filename=file_name,
|
||||||
binary=file_binary,
|
binary=file_binary,
|
||||||
from_page=0,
|
from_page=0,
|
||||||
to_page=DEFAULT_PARSE_TO_PAGE,
|
to_page=DEFAULT_PARSE_TO_PAGE,
|
||||||
@@ -312,6 +311,7 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
|||||||
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
|
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
|
||||||
# 2.2 Vectorize and import batch documents
|
# 2.2 Vectorize and import batch documents
|
||||||
auto_questions_topn = db_document.parser_config.get("auto_questions", 0)
|
auto_questions_topn = db_document.parser_config.get("auto_questions", 0)
|
||||||
|
qa_prompt = db_document.parser_config.get("qa_prompt", None)
|
||||||
chat_model = None
|
chat_model = None
|
||||||
if auto_questions_topn:
|
if auto_questions_topn:
|
||||||
chat_model = Base(
|
chat_model = Base(
|
||||||
@@ -319,62 +319,123 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
|||||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||||
base_url=db_knowledge.llm.api_keys[0].api_base,
|
base_url=db_knowledge.llm.api_keys[0].api_base,
|
||||||
)
|
)
|
||||||
|
logger.info(f"[QA] LLM model: {db_knowledge.llm.api_keys[0].model_name}, base_url: {db_knowledge.llm.api_keys[0].api_base}")
|
||||||
|
if qa_prompt:
|
||||||
|
logger.info(f"[QA] Using custom prompt ({len(qa_prompt)} chars)")
|
||||||
|
|
||||||
# 预先构建所有 batch 的 chunks,保证 sort_id 全局有序
|
# 预先构建所有 batch 的 chunks,保证 sort_id 全局有序
|
||||||
all_batch_chunks: list[list[DocumentChunk]] = []
|
all_batch_chunks: list[list[DocumentChunk]] = []
|
||||||
|
|
||||||
if auto_questions_topn:
|
if auto_questions_topn:
|
||||||
# auto_questions 开启:先并发生成所有 chunk 的问题,再按 batch 分组
|
# QA 模式(FastGPT 方案):
|
||||||
# 构建 (global_idx, item) 列表
|
# 1. 原 chunk 标记为 source(保留供 GraphRAG 使用,不参与检索)
|
||||||
|
# 2. LLM 生成 QA 对,每个 QA 对独立存储为 qa chunk
|
||||||
indexed_items = list(enumerate(res))
|
indexed_items = list(enumerate(res))
|
||||||
|
|
||||||
def _generate_question(idx_item: tuple[int, dict]) -> tuple[int, str]:
|
def _generate_qa(idx_item: tuple[int, dict]) -> tuple[int, list]:
|
||||||
"""为单个 chunk 生成问题(带缓存),返回 (global_idx, question_text)"""
|
"""为单个 chunk 生成 QA 对(带缓存),返回 (global_idx, qa_pairs)"""
|
||||||
global_idx, item = idx_item
|
global_idx, item = idx_item
|
||||||
content = item["content_with_weight"]
|
content = item["content_with_weight"]
|
||||||
cached = get_llm_cache(chat_model.model_name, content, "question",
|
cache_params = {"topn": auto_questions_topn}
|
||||||
{"topn": auto_questions_topn})
|
if qa_prompt:
|
||||||
|
import hashlib
|
||||||
|
cache_params["prompt_hash"] = hashlib.md5(qa_prompt.encode()).hexdigest()[:8]
|
||||||
|
cached = get_llm_cache(chat_model.model_name, content, "qa", cache_params)
|
||||||
if not cached:
|
if not cached:
|
||||||
cached = question_proposal(chat_model, content, auto_questions_topn)
|
logger.info(f"[QA] Cache miss for chunk {global_idx}, calling LLM. cache_params={cache_params}")
|
||||||
set_llm_cache(chat_model.model_name, content, cached, "question",
|
try:
|
||||||
{"topn": auto_questions_topn})
|
pairs = qa_proposal(chat_model, content, auto_questions_topn, custom_prompt=qa_prompt)
|
||||||
return global_idx, cached
|
except Exception as e:
|
||||||
|
logger.error(f"[QA] LLM call failed: model={chat_model.model_name}, base_url={getattr(chat_model, 'base_url', 'N/A')}, error={e}")
|
||||||
|
return global_idx, []
|
||||||
|
logger.info(f"[QA] Chunk {global_idx} generated {len(pairs)} QA pairs")
|
||||||
|
# 缓存存 JSON 字符串
|
||||||
|
set_llm_cache(chat_model.model_name, content, json.dumps(pairs, ensure_ascii=False), "qa",
|
||||||
|
cache_params)
|
||||||
|
return global_idx, pairs
|
||||||
|
logger.info(f"[QA] Cache hit for chunk {global_idx}, cache_params={cache_params}, cached_type={type(cached).__name__}")
|
||||||
|
# 从缓存读取:可能是 JSON 字符串或旧格式纯文本
|
||||||
|
if isinstance(cached, str):
|
||||||
|
try:
|
||||||
|
parsed = json.loads(cached)
|
||||||
|
if isinstance(parsed, list):
|
||||||
|
logger.info(f"[QA] Chunk {global_idx} loaded {len(parsed)} QA pairs from cache")
|
||||||
|
return global_idx, parsed
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
# 旧缓存格式(纯文本问题),尝试解析
|
||||||
|
from app.core.rag.prompts.generator import parse_qa_pairs
|
||||||
|
return global_idx, parse_qa_pairs(cached) if cached else []
|
||||||
|
return global_idx, cached if isinstance(cached, list) else []
|
||||||
|
|
||||||
# 并发调用 LLM 生成问题
|
# 并发调用 LLM 生成 QA 对
|
||||||
question_map: dict[int, str] = {}
|
qa_map: dict[int, list] = {}
|
||||||
with ThreadPoolExecutor(max_workers=AUTO_QUESTIONS_MAX_WORKERS) as q_executor:
|
with ThreadPoolExecutor(max_workers=AUTO_QUESTIONS_MAX_WORKERS) as q_executor:
|
||||||
futures = {q_executor.submit(_generate_question, item): item[0]
|
futures = {q_executor.submit(_generate_qa, item): item[0]
|
||||||
for item in indexed_items}
|
for item in indexed_items}
|
||||||
for future in futures:
|
for future in futures:
|
||||||
global_idx, cached = future.result()
|
global_idx, pairs = future.result()
|
||||||
question_map[global_idx] = cached
|
qa_map[global_idx] = pairs
|
||||||
|
|
||||||
progress_lines.append(
|
progress_lines.append(
|
||||||
f"{datetime.now().strftime('%H:%M:%S')} Auto questions generated for {total_chunks} chunks "
|
f"{datetime.now().strftime('%H:%M:%S')} QA pairs generated for {total_chunks} chunks "
|
||||||
f"(workers={AUTO_QUESTIONS_MAX_WORKERS}).")
|
f"(workers={AUTO_QUESTIONS_MAX_WORKERS}).")
|
||||||
|
|
||||||
# 按 batch 分组组装 DocumentChunk
|
# 组装 chunks:source chunks + qa chunks
|
||||||
for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE):
|
source_chunks = []
|
||||||
batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, total_chunks)
|
qa_chunks = []
|
||||||
chunks = []
|
qa_sort_id = 0
|
||||||
for global_idx in range(batch_start, batch_end):
|
|
||||||
item = res[global_idx]
|
for global_idx in range(total_chunks):
|
||||||
metadata = {
|
item = res[global_idx]
|
||||||
|
source_chunk_id = uuid.uuid4().hex
|
||||||
|
|
||||||
|
# source chunk:保留原文,供 GraphRAG 使用,不参与向量检索
|
||||||
|
source_meta = {
|
||||||
|
"doc_id": source_chunk_id,
|
||||||
|
"file_id": str(db_document.file_id),
|
||||||
|
"file_name": db_document.file_name,
|
||||||
|
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
||||||
|
"document_id": str(db_document.id),
|
||||||
|
"knowledge_id": str(db_document.kb_id),
|
||||||
|
"sort_id": global_idx,
|
||||||
|
"status": 1,
|
||||||
|
"chunk_type": "source",
|
||||||
|
}
|
||||||
|
source_chunks.append(
|
||||||
|
DocumentChunk(page_content=item["content_with_weight"], metadata=source_meta))
|
||||||
|
|
||||||
|
# qa chunks:每个 QA 对独立存储
|
||||||
|
pairs = qa_map.get(global_idx, [])
|
||||||
|
for pair in pairs:
|
||||||
|
qa_meta = {
|
||||||
"doc_id": uuid.uuid4().hex,
|
"doc_id": uuid.uuid4().hex,
|
||||||
"file_id": str(db_document.file_id),
|
"file_id": str(db_document.file_id),
|
||||||
"file_name": db_document.file_name,
|
"file_name": db_document.file_name,
|
||||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
||||||
"document_id": str(db_document.id),
|
"document_id": str(db_document.id),
|
||||||
"knowledge_id": str(db_document.kb_id),
|
"knowledge_id": str(db_document.kb_id),
|
||||||
"sort_id": global_idx,
|
"sort_id": qa_sort_id,
|
||||||
"status": 1,
|
"status": 1,
|
||||||
|
"chunk_type": "qa",
|
||||||
|
"question": pair["question"],
|
||||||
|
"answer": pair["answer"],
|
||||||
|
"source_chunk_id": source_chunk_id,
|
||||||
}
|
}
|
||||||
cached = question_map[global_idx]
|
# page_content 存 question,用于向量索引
|
||||||
chunks.append(
|
qa_chunks.append(
|
||||||
DocumentChunk(
|
DocumentChunk(page_content=pair["question"], metadata=qa_meta))
|
||||||
page_content=f"question: {cached} answer: {item['content_with_weight']}",
|
qa_sort_id += 1
|
||||||
metadata=metadata))
|
|
||||||
all_batch_chunks.append(chunks)
|
# 按 batch 分组(source + qa 一起)
|
||||||
|
all_chunks = source_chunks + qa_chunks
|
||||||
|
for batch_start in range(0, len(all_chunks), EMBEDDING_BATCH_SIZE):
|
||||||
|
batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, len(all_chunks))
|
||||||
|
all_batch_chunks.append(all_chunks[batch_start:batch_end])
|
||||||
|
|
||||||
|
progress_lines.append(
|
||||||
|
f"{datetime.now().strftime('%H:%M:%S')} QA mode: {len(source_chunks)} source chunks + "
|
||||||
|
f"{len(qa_chunks)} QA chunks prepared.")
|
||||||
else:
|
else:
|
||||||
# 无 auto_questions:直接构建 chunks
|
# 无 auto_questions:直接构建 chunks
|
||||||
for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE):
|
for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE):
|
||||||
@@ -636,6 +697,136 @@ def build_graphrag_for_document(document_id: str, knowledge_id: str):
|
|||||||
return f"build_graphrag_for_document '{document_id}' failed: {e}"
|
return f"build_graphrag_for_document '{document_id}' failed: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(name="app.core.rag.tasks.import_qa_chunks", queue="qa_import")
|
||||||
|
def import_qa_chunks(kb_id: str, document_id: str, filename: str, contents: bytes):
|
||||||
|
"""
|
||||||
|
异步导入 QA 问答对(CSV/Excel)
|
||||||
|
|
||||||
|
文件格式:第一行标题(跳过),第一列问题,第二列答案
|
||||||
|
"""
|
||||||
|
import csv as csv_module
|
||||||
|
import io
|
||||||
|
|
||||||
|
db = None
|
||||||
|
try:
|
||||||
|
from app.db import get_db_context
|
||||||
|
with get_db_context() as db:
|
||||||
|
db_document = db.query(Document).filter(Document.id == uuid.UUID(document_id)).first()
|
||||||
|
db_knowledge = db.query(Knowledge).filter(Knowledge.id == uuid.UUID(kb_id)).first()
|
||||||
|
if not db_document or not db_knowledge:
|
||||||
|
logger.error(f"[ImportQA] document={document_id} or knowledge={kb_id} not found")
|
||||||
|
return {"error": "document or knowledge not found", "imported": 0}
|
||||||
|
|
||||||
|
# 1. 解析文件
|
||||||
|
qa_pairs = []
|
||||||
|
failed_rows = []
|
||||||
|
|
||||||
|
if filename.endswith(".csv"):
|
||||||
|
try:
|
||||||
|
text = contents.decode("utf-8-sig")
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
text = contents.decode("gbk", errors="ignore")
|
||||||
|
|
||||||
|
sniffer = csv_module.Sniffer()
|
||||||
|
try:
|
||||||
|
dialect = sniffer.sniff(text[:2048])
|
||||||
|
delimiter = dialect.delimiter
|
||||||
|
except csv_module.Error:
|
||||||
|
delimiter = "," if "," in text[:500] else "\t"
|
||||||
|
|
||||||
|
reader = csv_module.reader(io.StringIO(text), delimiter=delimiter)
|
||||||
|
for i, row in enumerate(reader):
|
||||||
|
if i == 0:
|
||||||
|
continue
|
||||||
|
if len(row) >= 2 and row[0].strip() and row[1].strip():
|
||||||
|
qa_pairs.append({"question": row[0].strip(), "answer": row[1].strip()})
|
||||||
|
elif len(row) >= 1 and row[0].strip():
|
||||||
|
failed_rows.append(i + 1)
|
||||||
|
|
||||||
|
elif filename.endswith(".xlsx") or filename.endswith(".xls"):
|
||||||
|
try:
|
||||||
|
import openpyxl
|
||||||
|
wb = openpyxl.load_workbook(io.BytesIO(contents), read_only=True)
|
||||||
|
for sheet in wb.worksheets:
|
||||||
|
for i, row in enumerate(sheet.iter_rows(values_only=True)):
|
||||||
|
if i == 0:
|
||||||
|
continue
|
||||||
|
if len(row) >= 2 and row[0] and row[1]:
|
||||||
|
q = str(row[0]).strip()
|
||||||
|
a = str(row[1]).strip()
|
||||||
|
if q and a:
|
||||||
|
qa_pairs.append({"question": q, "answer": a})
|
||||||
|
elif len(row) >= 1 and row[0]:
|
||||||
|
failed_rows.append(i + 1)
|
||||||
|
wb.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[ImportQA] Excel parse failed: {e}")
|
||||||
|
return {"error": f"Excel parse failed: {e}", "imported": 0}
|
||||||
|
|
||||||
|
if not qa_pairs:
|
||||||
|
logger.warning(f"[ImportQA] No valid QA pairs found in {filename}")
|
||||||
|
return {"error": "No valid QA pairs found", "imported": 0}
|
||||||
|
|
||||||
|
logger.info(f"[ImportQA] Parsed {len(qa_pairs)} QA pairs from {filename}, failed_rows={failed_rows}")
|
||||||
|
|
||||||
|
# 2. 写入 ES
|
||||||
|
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||||
|
|
||||||
|
sort_id = 0
|
||||||
|
total, items = vector_service.search_by_segment(document_id=document_id, pagesize=1, page=1, asc=False)
|
||||||
|
if items:
|
||||||
|
sort_id = items[0].metadata["sort_id"]
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
for pair in qa_pairs:
|
||||||
|
sort_id += 1
|
||||||
|
doc_id = uuid.uuid4().hex
|
||||||
|
metadata = {
|
||||||
|
"doc_id": doc_id,
|
||||||
|
"file_id": str(db_document.file_id),
|
||||||
|
"file_name": db_document.file_name,
|
||||||
|
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
||||||
|
"document_id": document_id,
|
||||||
|
"knowledge_id": kb_id,
|
||||||
|
"sort_id": sort_id,
|
||||||
|
"status": 1,
|
||||||
|
"chunk_type": "qa",
|
||||||
|
"question": pair["question"],
|
||||||
|
"answer": pair["answer"],
|
||||||
|
}
|
||||||
|
chunks.append(DocumentChunk(page_content=pair["question"], metadata=metadata))
|
||||||
|
|
||||||
|
batch_size = 50
|
||||||
|
for i in range(0, len(chunks), batch_size):
|
||||||
|
batch = chunks[i:i + batch_size]
|
||||||
|
vector_service.add_chunks(batch)
|
||||||
|
|
||||||
|
# 3. 更新 chunk_num 和 progress
|
||||||
|
db_document.chunk_num += len(chunks)
|
||||||
|
db_document.progress = 1.0
|
||||||
|
db_document.progress_msg = f"QA 导入完成: {len(chunks)} 条"
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
result = {"imported": len(chunks), "failed_rows": failed_rows}
|
||||||
|
logger.info(f"[ImportQA] Done: imported={len(chunks)}, failed={len(failed_rows)}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[ImportQA] Failed: {e}", exc_info=True)
|
||||||
|
# 尝试更新文档状态为失败
|
||||||
|
try:
|
||||||
|
from app.db import get_db_context
|
||||||
|
with get_db_context() as err_db:
|
||||||
|
doc = err_db.query(Document).filter(Document.id == uuid.UUID(document_id)).first()
|
||||||
|
if doc:
|
||||||
|
doc.progress = -1.0
|
||||||
|
doc.progress_msg = f"QA 导入失败: {str(e)[:200]}"
|
||||||
|
err_db.commit()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return {"error": str(e), "imported": 0}
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb")
|
@celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb")
|
||||||
def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
||||||
"""
|
"""
|
||||||
@@ -2025,7 +2216,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
|||||||
end_users = db.query(EndUser).all()
|
end_users = db.query(EndUser).all()
|
||||||
if not end_users:
|
if not end_users:
|
||||||
logger.info("没有终端用户,跳过遗忘周期")
|
logger.info("没有终端用户,跳过遗忘周期")
|
||||||
return {"status": "SUCCESS", "message": "没有终端用户",
|
return {"status": "SUCCESS", "message": "没有终端用户",
|
||||||
"report": {"merged_count": 0, "failed_count": 0, "processed_users": 0},
|
"report": {"merged_count": 0, "failed_count": 0, "processed_users": 0},
|
||||||
"duration_seconds": time.time() - start_time}
|
"duration_seconds": time.time() - start_time}
|
||||||
|
|
||||||
@@ -2039,7 +2230,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
|||||||
# 获取用户配置(自动回退到工作空间默认配置)
|
# 获取用户配置(自动回退到工作空间默认配置)
|
||||||
connected_config = get_end_user_connected_config(str(end_user.id), db)
|
connected_config = get_end_user_connected_config(str(end_user.id), db)
|
||||||
user_config_id = resolve_config_id(connected_config.get("memory_config_id"), db)
|
user_config_id = resolve_config_id(connected_config.get("memory_config_id"), db)
|
||||||
|
|
||||||
if not user_config_id:
|
if not user_config_id:
|
||||||
failed_users.append({"end_user_id": str(end_user.id), "error": "无法获取配置"})
|
failed_users.append({"end_user_id": str(end_user.id), "error": "无法获取配置"})
|
||||||
continue
|
continue
|
||||||
@@ -2048,13 +2239,13 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
|||||||
report = await forget_service.trigger_forgetting_cycle(
|
report = await forget_service.trigger_forgetting_cycle(
|
||||||
db=db, end_user_id=str(end_user.id), config_id=user_config_id
|
db=db, end_user_id=str(end_user.id), config_id=user_config_id
|
||||||
)
|
)
|
||||||
|
|
||||||
total_merged += report.get('merged_count', 0)
|
total_merged += report.get('merged_count', 0)
|
||||||
total_failed += report.get('failed_count', 0)
|
total_failed += report.get('failed_count', 0)
|
||||||
processed_users += 1
|
processed_users += 1
|
||||||
|
|
||||||
logger.info(f"用户 {end_user.id}: 融合 {report.get('merged_count', 0)} 对节点")
|
logger.info(f"用户 {end_user.id}: 融合 {report.get('merged_count', 0)} 对节点")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理用户 {end_user.id} 失败: {e}", exc_info=True)
|
logger.error(f"处理用户 {end_user.id} 失败: {e}", exc_info=True)
|
||||||
failed_users.append({"end_user_id": str(end_user.id), "error": str(e)})
|
failed_users.append({"end_user_id": str(end_user.id), "error": str(e)})
|
||||||
@@ -2801,18 +2992,18 @@ def run_incremental_clustering(
|
|||||||
包含任务执行结果的字典
|
包含任务执行结果的字典
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
async def _run() -> Dict[str, Any]:
|
async def _run() -> Dict[str, Any]:
|
||||||
from app.core.logging_config import get_logger
|
from app.core.logging_config import get_logger
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine
|
from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[IncrementalClustering] 开始增量聚类任务 - end_user_id={end_user_id}, "
|
f"[IncrementalClustering] 开始增量聚类任务 - end_user_id={end_user_id}, "
|
||||||
f"实体数={len(new_entity_ids)}, llm_model_id={llm_model_id}"
|
f"实体数={len(new_entity_ids)}, llm_model_id={llm_model_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
try:
|
try:
|
||||||
engine = LabelPropagationEngine(
|
engine = LabelPropagationEngine(
|
||||||
@@ -2820,12 +3011,12 @@ def run_incremental_clustering(
|
|||||||
llm_model_id=llm_model_id,
|
llm_model_id=llm_model_id,
|
||||||
embedding_model_id=embedding_model_id,
|
embedding_model_id=embedding_model_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 执行增量聚类
|
# 执行增量聚类
|
||||||
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
|
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
|
||||||
|
|
||||||
logger.info(f"[IncrementalClustering] 增量聚类完成 - end_user_id={end_user_id}")
|
logger.info(f"[IncrementalClustering] 增量聚类完成 - end_user_id={end_user_id}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "SUCCESS",
|
"status": "SUCCESS",
|
||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
@@ -2836,18 +3027,18 @@ def run_incremental_clustering(
|
|||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
await connector.close()
|
await connector.close()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
loop = set_asyncio_event_loop()
|
loop = set_asyncio_event_loop()
|
||||||
result = loop.run_until_complete(_run())
|
result = loop.run_until_complete(_run())
|
||||||
result["elapsed_time"] = time.time() - start_time
|
result["elapsed_time"] = time.time() - start_time
|
||||||
result["task_id"] = self.request.id
|
result["task_id"] = self.request.id
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[IncrementalClustering] 任务完成 - task_id={self.request.id}, "
|
f"[IncrementalClustering] 任务完成 - task_id={self.request.id}, "
|
||||||
f"elapsed_time={result['elapsed_time']:.2f}s"
|
f"elapsed_time={result['elapsed_time']:.2f}s"
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|||||||
@@ -63,6 +63,23 @@ services:
|
|||||||
networks:
|
networks:
|
||||||
- celery
|
- celery
|
||||||
|
|
||||||
|
celery-task-scheduler:
|
||||||
|
image: redbear-mem-open:latest
|
||||||
|
container_name: celery-task-scheduler
|
||||||
|
env_file:
|
||||||
|
- .env
|
||||||
|
volumes:
|
||||||
|
- /etc/localtime:/etc/localtime:ro
|
||||||
|
command: python -m app.celery_task_scheduler
|
||||||
|
restart: unless-stopped
|
||||||
|
healthcheck:
|
||||||
|
test: CMD curl -f 127.0.0.1:8001 || exit 1
|
||||||
|
interval: 30s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 3
|
||||||
|
networks:
|
||||||
|
- celery
|
||||||
|
|
||||||
# Celery Beat - scheduler
|
# Celery Beat - scheduler
|
||||||
beat:
|
beat:
|
||||||
image: redbear-mem-open:latest
|
image: redbear-mem-open:latest
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
import { type FC, useRef } from 'react';
|
import { type FC, useRef } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useParams } from 'react-router-dom';
|
import { useParams } from 'react-router-dom';
|
||||||
import { Flex, Button } from 'antd';
|
import { Flex, Button, Form } from 'antd';
|
||||||
import type { ColumnsType } from 'antd/es/table';
|
import type { ColumnsType } from 'antd/es/table';
|
||||||
|
|
||||||
import { getAppLogsUrl } from '@/api/application';
|
import { getAppLogsUrl } from '@/api/application';
|
||||||
@@ -15,11 +15,14 @@ import Table from '@/components/Table'
|
|||||||
import { formatDateTime } from '@/utils/format';
|
import { formatDateTime } from '@/utils/format';
|
||||||
import type { LogItem, LogDetailModalRef } from './types'
|
import type { LogItem, LogDetailModalRef } from './types'
|
||||||
import LogDetailModal from './components/LogDetailModal'
|
import LogDetailModal from './components/LogDetailModal'
|
||||||
|
import SearchInput from '@/components/SearchInput'
|
||||||
|
|
||||||
const Statistics: FC = () => {
|
const Statistics: FC = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { id } = useParams();
|
const { id } = useParams();
|
||||||
const logDetailRef = useRef<LogDetailModalRef>(null);
|
const logDetailRef = useRef<LogDetailModalRef>(null);
|
||||||
|
const [form] = Form.useForm();
|
||||||
|
const values = Form.useWatch([], form);
|
||||||
|
|
||||||
const handleViewDetail = (item: LogItem) => {
|
const handleViewDetail = (item: LogItem) => {
|
||||||
logDetailRef.current?.handleOpen(item);
|
logDetailRef.current?.handleOpen(item);
|
||||||
@@ -62,15 +65,26 @@ const Statistics: FC = () => {
|
|||||||
];
|
];
|
||||||
return (
|
return (
|
||||||
<div className="rb:bg-white rb:rounded-lg rb:pt-3 rb:px-3">
|
<div className="rb:bg-white rb:rounded-lg rb:pt-3 rb:px-3">
|
||||||
|
<Flex justify="flex-end" className="rb:mb-3!">
|
||||||
|
<Form form={form}>
|
||||||
|
<Form.Item name="keyword" noStyle>
|
||||||
|
<SearchInput
|
||||||
|
placeholder={t('application.logSearchPlaceholder')}
|
||||||
|
variant="outlined"
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
</Form>
|
||||||
|
</Flex>
|
||||||
<Table<LogItem>
|
<Table<LogItem>
|
||||||
apiUrl={getAppLogsUrl(id || '')}
|
apiUrl={getAppLogsUrl(id || '')}
|
||||||
apiParams={{
|
apiParams={{
|
||||||
is_draft: false,
|
is_draft: false,
|
||||||
|
...(values ?? {})
|
||||||
}}
|
}}
|
||||||
columns={columns}
|
columns={columns}
|
||||||
rowKey="id"
|
rowKey="id"
|
||||||
isScroll={true}
|
isScroll={true}
|
||||||
scrollY="calc(100vh - 214px)"
|
scrollY="calc(100vh - 242px)"
|
||||||
/>
|
/>
|
||||||
<LogDetailModal ref={logDetailRef} />
|
<LogDetailModal ref={logDetailRef} />
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-03-13 17:27:52
|
* @Date: 2026-03-13 17:27:52
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-04-07 21:48:30
|
* @Last Modified time: 2026-04-24 18:14:25
|
||||||
*/
|
*/
|
||||||
import { type FC, useState, useRef, useEffect } from 'react'
|
import { type FC, useState, useRef, useEffect } from 'react'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
@@ -59,6 +59,7 @@ interface NodeData {
|
|||||||
node_type?: string;
|
node_type?: string;
|
||||||
input?: any;
|
input?: any;
|
||||||
output?: any;
|
output?: any;
|
||||||
|
process?: any;
|
||||||
elapsed_time?: string;
|
elapsed_time?: string;
|
||||||
error?: any;
|
error?: any;
|
||||||
state: Record<string, any>;
|
state: Record<string, any>;
|
||||||
@@ -485,7 +486,7 @@ const TestChat: FC<TestChatProps> = ({
|
|||||||
}
|
}
|
||||||
|
|
||||||
const updateWorkflowNodeEndMessage = (data: NodeData) => {
|
const updateWorkflowNodeEndMessage = (data: NodeData) => {
|
||||||
const { node_id, input, output, error, elapsed_time, status } = data;
|
const { node_id, input, output, process, error, elapsed_time, status } = data;
|
||||||
setChatList(prev => {
|
setChatList(prev => {
|
||||||
const newList = [...prev]
|
const newList = [...prev]
|
||||||
const lastIndex = newList.length - 1
|
const lastIndex = newList.length - 1
|
||||||
@@ -498,6 +499,7 @@ const TestChat: FC<TestChatProps> = ({
|
|||||||
content: {
|
content: {
|
||||||
input,
|
input,
|
||||||
output,
|
output,
|
||||||
|
process,
|
||||||
error,
|
error,
|
||||||
},
|
},
|
||||||
status: status || 'completed',
|
status: status || 'completed',
|
||||||
@@ -514,7 +516,7 @@ const TestChat: FC<TestChatProps> = ({
|
|||||||
}
|
}
|
||||||
|
|
||||||
const updateWorkflowCycleMessage = (data: NodeData) => {
|
const updateWorkflowCycleMessage = (data: NodeData) => {
|
||||||
const { node_id, cycle_id, cycle_idx, input, output, error, elapsed_time, status } = data;
|
const { node_id, cycle_id, cycle_idx, input, output, process, error, elapsed_time, status } = data;
|
||||||
const { nodes } = config as WorkflowConfig
|
const { nodes } = config as WorkflowConfig
|
||||||
const node = nodes.find(n => n.id === node_id);
|
const node = nodes.find(n => n.id === node_id);
|
||||||
const { name, type } = node || {}
|
const { name, type } = node || {}
|
||||||
@@ -538,6 +540,7 @@ const TestChat: FC<TestChatProps> = ({
|
|||||||
cycle_idx,
|
cycle_idx,
|
||||||
input,
|
input,
|
||||||
output,
|
output,
|
||||||
|
process,
|
||||||
error,
|
error,
|
||||||
},
|
},
|
||||||
status: status || 'completed',
|
status: status || 'completed',
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
/*
|
/*
|
||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-03-24 16:31:24
|
* @Date: 2026-03-24 16:31:24
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-24 16:31:24
|
* @Last Modified time: 2026-04-24 17:49:58
|
||||||
*/
|
*/
|
||||||
import { forwardRef, useImperativeHandle, useState, useEffect } from 'react';
|
import { forwardRef, useImperativeHandle, useState, useEffect } from 'react';
|
||||||
import { Flex, Button, Empty, Skeleton } from 'antd';
|
import { Flex, Button, Empty, Skeleton } from 'antd';
|
||||||
@@ -14,6 +14,12 @@ import { getAppLogDetail } from '@/api/application'
|
|||||||
import ChatContent from '@/components/Chat/ChatContent'
|
import ChatContent from '@/components/Chat/ChatContent'
|
||||||
import { formatDateTime } from '@/utils/format'
|
import { formatDateTime } from '@/utils/format'
|
||||||
import type { ChatItem } from '@/components/Chat/types'
|
import type { ChatItem } from '@/components/Chat/types'
|
||||||
|
import Runtime from '@/views/Workflow/components/Chat/Runtime'
|
||||||
|
import { nodeLibrary } from '@/views/Workflow/constant'
|
||||||
|
|
||||||
|
const nodeIconMap = Object.fromEntries(
|
||||||
|
nodeLibrary.flatMap(c => c.nodes.map(n => [n.type, n.icon]))
|
||||||
|
)
|
||||||
|
|
||||||
/** Log detail data with conversation messages */
|
/** Log detail data with conversation messages */
|
||||||
type Data = LogItem & {
|
type Data = LogItem & {
|
||||||
@@ -54,7 +60,30 @@ const LogDetailModal = forwardRef<LogDetailModalRef>((_props, ref) => {
|
|||||||
if (!vo) return
|
if (!vo) return
|
||||||
setLoading(true)
|
setLoading(true)
|
||||||
getAppLogDetail(vo.app_id, vo.id).then(res => {
|
getAppLogDetail(vo.app_id, vo.id).then(res => {
|
||||||
setData(res as Data)
|
const { node_executions_map, messages, ...rest } = res as Data;
|
||||||
|
let hasSubContentMessages = messages
|
||||||
|
if (messages && messages.length > 0 && node_executions_map && Object.keys(node_executions_map).length > 0) {
|
||||||
|
hasSubContentMessages = messages.map(item => {
|
||||||
|
if (item.id && node_executions_map[item.id]) {
|
||||||
|
item.subContent = node_executions_map[item.id]?.map(({ input, output, cycle_items = [], error, process, ...node }: any) => {
|
||||||
|
const converted: any = { ...node, icon: nodeIconMap[node.node_type], content: { input, output, process, error } }
|
||||||
|
if (node.node_type === 'loop' && Array.isArray(cycle_items) && cycle_items.length > 0) {
|
||||||
|
converted.subContent = cycle_items.map(({ input: cInput, output: cOutput, error: cError, process: cProcess, ...cNode }: any) => ({
|
||||||
|
...cNode,
|
||||||
|
icon: nodeIconMap[cNode.node_type],
|
||||||
|
content: { input: cInput, output: cOutput, process: cProcess, error: cError }
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
return converted
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return { ...item }
|
||||||
|
})
|
||||||
|
}
|
||||||
|
setData({
|
||||||
|
...rest,
|
||||||
|
messages: hasSubContentMessages
|
||||||
|
})
|
||||||
})
|
})
|
||||||
.finally(() => {
|
.finally(() => {
|
||||||
setLoading(false)
|
setLoading(false)
|
||||||
@@ -66,6 +95,8 @@ const LogDetailModal = forwardRef<LogDetailModalRef>((_props, ref) => {
|
|||||||
handleClose
|
handleClose
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
console.log('data', data)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<RbModal
|
<RbModal
|
||||||
title={<>
|
title={<>
|
||||||
@@ -92,6 +123,7 @@ const LogDetailModal = forwardRef<LogDetailModalRef>((_props, ref) => {
|
|||||||
data={data.messages || []}
|
data={data.messages || []}
|
||||||
streamLoading={false}
|
streamLoading={false}
|
||||||
labelFormat={(item) => formatDateTime(item.created_at)}
|
labelFormat={(item) => formatDateTime(item.created_at)}
|
||||||
|
renderRuntime={(item, index) => <Runtime item={item} index={index} />}
|
||||||
/>
|
/>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-06 21:10:56
|
* @Date: 2026-02-06 21:10:56
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-04-21 14:59:13
|
* @Last Modified time: 2026-04-24 18:13:22
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* Workflow Chat Component
|
* Workflow Chat Component
|
||||||
@@ -66,7 +66,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
|
|||||||
const [fileList, setFileList] = useState<any[]>([])
|
const [fileList, setFileList] = useState<any[]>([])
|
||||||
const [message, setMessage] = useState<string | undefined>(undefined)
|
const [message, setMessage] = useState<string | undefined>(undefined)
|
||||||
|
|
||||||
console.log('abortRef', abortRef)
|
console.log('abortRef', abortRef, chatList)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Opens the chat drawer and loads workflow variables from the start node
|
* Opens the chat drawer and loads workflow variables from the start node
|
||||||
@@ -185,7 +185,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
|
|||||||
*/
|
*/
|
||||||
const handleStreamMessage = (data: SSEMessage[]) => {
|
const handleStreamMessage = (data: SSEMessage[]) => {
|
||||||
data.forEach(item => {
|
data.forEach(item => {
|
||||||
const { content, conversation_id, node_id, cycle_id, cycle_idx, input, output, error, elapsed_time, status, citations } = item.data as {
|
const { content, conversation_id, node_id, cycle_id, cycle_idx, input, output, process, error, elapsed_time, status, citations } = item.data as {
|
||||||
content: string;
|
content: string;
|
||||||
conversation_id: string | null;
|
conversation_id: string | null;
|
||||||
cycle_id: string;
|
cycle_id: string;
|
||||||
@@ -193,6 +193,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
|
|||||||
node_id: string;
|
node_id: string;
|
||||||
node_name?: string;
|
node_name?: string;
|
||||||
node_type?: string;
|
node_type?: string;
|
||||||
|
process?: any;
|
||||||
input?: any;
|
input?: any;
|
||||||
output?: any;
|
output?: any;
|
||||||
elapsed_time?: string;
|
elapsed_time?: string;
|
||||||
@@ -277,6 +278,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
|
|||||||
content: {
|
content: {
|
||||||
input,
|
input,
|
||||||
output,
|
output,
|
||||||
|
process,
|
||||||
error,
|
error,
|
||||||
},
|
},
|
||||||
status: status || 'completed',
|
status: status || 'completed',
|
||||||
@@ -305,13 +307,14 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
|
|||||||
cycle_id,
|
cycle_id,
|
||||||
cycle_idx,
|
cycle_idx,
|
||||||
node_id,
|
node_id,
|
||||||
node_name: name,
|
node_name: type === 'cycle-start' ? t('workflow.cycle-start') : name,
|
||||||
node_type: type,
|
node_type: type,
|
||||||
icon,
|
icon,
|
||||||
content: {
|
content: {
|
||||||
cycle_idx,
|
cycle_idx,
|
||||||
input,
|
input,
|
||||||
output,
|
output,
|
||||||
|
process,
|
||||||
error,
|
error,
|
||||||
},
|
},
|
||||||
status: status || 'completed',
|
status: status || 'completed',
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-24 17:57:08
|
* @Date: 2026-02-24 17:57:08
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-04-20 15:33:48
|
* @Last Modified time: 2026-04-24 18:04:31
|
||||||
*/
|
*/
|
||||||
/*
|
/*
|
||||||
* Runtime Component
|
* Runtime Component
|
||||||
@@ -184,27 +184,30 @@ const Runtime: FC<{ item: ChatItem; index: number;}> = ({
|
|||||||
</Flex>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
{/* Display input and output data as JSON code blocks */}
|
{/* Display input and output data as JSON code blocks */}
|
||||||
{['input', 'output'].map(key => (
|
{['input', 'process', 'output'].map(key => {
|
||||||
<div key={key} className="rb:bg-[#EBEBEB] rb:rounded-lg">
|
if (vo.node_type !== 'http-request' && key === 'process') return null
|
||||||
<div className="rb:py-2 rb:px-3 rb:flex rb:justify-between rb:items-center rb:text-[12px]">
|
return (
|
||||||
{isLoop ? t(`workflow.runtime.${key}_cycle_vars`) : t(`workflow.${key}_result`)}
|
<div key={key} className="rb:bg-[#EBEBEB] rb:rounded-lg">
|
||||||
<Button
|
<div className="rb:py-2 rb:px-3 rb:flex rb:justify-between rb:items-center rb:text-[12px]">
|
||||||
className="rb:py-0! rb:px-1! rb:text-[12px]!"
|
{isLoop ? t(`workflow.runtime.${key}_cycle_vars`) : t(`workflow.${key}_result`)}
|
||||||
size="small"
|
<Button
|
||||||
onClick={() => handleCopy(typeof vo.content === 'object' && vo.content?.[key] ? JSON.stringify(vo.content[key], null, 2) : '{}')}
|
className="rb:py-0! rb:px-1! rb:text-[12px]!"
|
||||||
>{t('common.copy')}</Button>
|
size="small"
|
||||||
|
onClick={() => handleCopy(typeof vo.content === 'object' && vo.content?.[key] ? JSON.stringify(vo.content[key], null, 2) : '{}')}
|
||||||
|
>{t('common.copy')}</Button>
|
||||||
|
</div>
|
||||||
|
<div className="rb:max-h-40 rb:overflow-auto">
|
||||||
|
<CodeBlock
|
||||||
|
size="small"
|
||||||
|
value={typeof vo.content === 'object' && vo.content?.[key] ? JSON.stringify(vo.content[key], null, 2) : '{}'}
|
||||||
|
needCopy={false}
|
||||||
|
showLineNumbers={true}
|
||||||
|
background="#EBEBEB"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div className="rb:max-h-40 rb:overflow-auto">
|
)
|
||||||
<CodeBlock
|
})}
|
||||||
size="small"
|
|
||||||
value={typeof vo.content === 'object' && vo.content?.[key] ? JSON.stringify(vo.content[key], null, 2) : '{}'}
|
|
||||||
needCopy={false}
|
|
||||||
showLineNumbers={true}
|
|
||||||
background="#EBEBEB"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
))}
|
|
||||||
</Flex>
|
</Flex>
|
||||||
)
|
)
|
||||||
}]}
|
}]}
|
||||||
|
|||||||
@@ -65,8 +65,8 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
|
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
|
||||||
'rb:border-[#171719]!': data.isSelected,
|
'rb:border-[#171719]!': data.isSelected && !data.executionStatus,
|
||||||
'rb:border-[#FCFCFD]': !data.isSelected,
|
'rb:border-[#FCFCFD]': !data.isSelected && !data.executionStatus,
|
||||||
'rb:border-[#369F21]!': !data.isSelected && data.executionStatus === 'completed',
|
'rb:border-[#369F21]!': !data.isSelected && data.executionStatus === 'completed',
|
||||||
'rb:border-[#FF5D34]!': !data.isSelected && data.executionStatus === 'failed',
|
'rb:border-[#FF5D34]!': !data.isSelected && data.executionStatus === 'failed',
|
||||||
})}>
|
})}>
|
||||||
|
|||||||
@@ -131,8 +131,8 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
|
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
|
||||||
'rb:border-[#171719]': data.isSelected,
|
'rb:border-[#171719]!': data.isSelected && !data.executionStatus,
|
||||||
'rb:border-[#FCFCFD]': !data.isSelected,
|
'rb:border-[#FCFCFD]': !data.isSelected && !data.executionStatus,
|
||||||
'rb:border-[#369F21]!': !data.isSelected && data.executionStatus === 'completed',
|
'rb:border-[#369F21]!': !data.isSelected && data.executionStatus === 'completed',
|
||||||
'rb:border-[#FF5D34]!': !data.isSelected && data.executionStatus === 'failed',
|
'rb:border-[#FF5D34]!': !data.isSelected && data.executionStatus === 'failed',
|
||||||
})}>
|
})}>
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ const NormalNode: ReactShapeConfig['component'] = ({ node }) => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
|
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
|
||||||
'rb:border-[#171719]!': data.isSelected,
|
'rb:border-[#171719]!': data.isSelected && !data.executionStatus,
|
||||||
'rb:border-[#FCFCFD]': !data.isSelected,
|
'rb:border-[#FCFCFD]': !data.isSelected && !data.executionStatus,
|
||||||
'rb:border-[#369F21]!': !data.isSelected && data.executionStatus === 'completed',
|
'rb:border-[#369F21]!': !data.isSelected && data.executionStatus === 'completed',
|
||||||
'rb:border-[#FF5D34]!': !data.isSelected && data.executionStatus === 'failed',
|
'rb:border-[#FF5D34]!': !data.isSelected && data.executionStatus === 'failed',
|
||||||
})}>
|
})}>
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 15:17:48
|
* @Date: 2026-02-03 15:17:48
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-04-20 16:00:26
|
* @Last Modified time: 2026-04-24 17:21:09
|
||||||
*/
|
*/
|
||||||
import { Clipboard, Graph, Keyboard, MiniMap, Node, Snapline, History, type Edge } from '@antv/x6';
|
import { Clipboard, Graph, Keyboard, MiniMap, Node, Snapline, History, type Edge } from '@antv/x6';
|
||||||
import type { HistoryCommand as Command } from '@antv/x6/lib/plugin/history/type';
|
import type { HistoryCommand as Command } from '@antv/x6/lib/plugin/history/type';
|
||||||
@@ -1492,7 +1492,7 @@ export const useWorkflowGraph = ({
|
|||||||
// Reset all node execution status first
|
// Reset all node execution status first
|
||||||
nodes.forEach(node => {
|
nodes.forEach(node => {
|
||||||
const data = node.getData();
|
const data = node.getData();
|
||||||
if (typeof data.status === 'string') {
|
if (typeof data.executionStatus === 'string') {
|
||||||
node.setData({ ...data, executionStatus: undefined });
|
node.setData({ ...data, executionStatus: undefined });
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|||||||
Reference in New Issue
Block a user