Compare commits
3 Commits
refactor/w
...
release/v0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a7eaf563d7 | ||
|
|
4c7809ce4a | ||
|
|
51847955cd |
164
.github/workflows/release-notify-wechat.yml
vendored
164
.github/workflows/release-notify-wechat.yml
vendored
@@ -1,164 +0,0 @@
|
|||||||
name: Release Notify Workflow
|
|
||||||
|
|
||||||
on:
|
|
||||||
pull_request:
|
|
||||||
types: [closed]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
notify:
|
|
||||||
if: >
|
|
||||||
github.event.pull_request.merged == true &&
|
|
||||||
startsWith(github.event.pull_request.base.ref, 'release')
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
steps:
|
|
||||||
# 防止 GitHub HEAD 未同步
|
|
||||||
- run: sleep 3
|
|
||||||
|
|
||||||
# 1️⃣ 获取分支 HEAD
|
|
||||||
- name: Get HEAD
|
|
||||||
id: head
|
|
||||||
run: |
|
|
||||||
HEAD_SHA=$(curl -s \
|
|
||||||
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
|
|
||||||
https://api.github.com/repos/${{ github.repository }}/git/ref/heads/${{ github.event.pull_request.base.ref }} \
|
|
||||||
| jq -r '.object.sha')
|
|
||||||
echo "head_sha=$HEAD_SHA" >> $GITHUB_OUTPUT
|
|
||||||
|
|
||||||
# 2️⃣ 判断是否最终PR
|
|
||||||
- name: Check Latest
|
|
||||||
id: check
|
|
||||||
run: |
|
|
||||||
if [ "${{ github.event.pull_request.merge_commit_sha }}" = "${{ steps.head.outputs.head_sha }}" ]; then
|
|
||||||
echo "ok=true" >> $GITHUB_OUTPUT
|
|
||||||
else
|
|
||||||
echo "ok=false" >> $GITHUB_OUTPUT
|
|
||||||
fi
|
|
||||||
|
|
||||||
# 3️⃣ 尝试从 PR body 提取 Sourcery 摘要
|
|
||||||
- name: Extract Sourcery Summary
|
|
||||||
if: steps.check.outputs.ok == 'true'
|
|
||||||
id: sourcery
|
|
||||||
env:
|
|
||||||
PR_BODY: ${{ github.event.pull_request.body }}
|
|
||||||
run: |
|
|
||||||
python3 << 'PYEOF'
|
|
||||||
import os, re
|
|
||||||
|
|
||||||
body = os.environ.get("PR_BODY", "") or ""
|
|
||||||
match = re.search(
|
|
||||||
r"## Summary by Sourcery\s*\n(.*?)(?=\n## |\Z)",
|
|
||||||
body,
|
|
||||||
re.DOTALL
|
|
||||||
)
|
|
||||||
|
|
||||||
if match:
|
|
||||||
summary = match.group(1).strip()
|
|
||||||
found = "true"
|
|
||||||
else:
|
|
||||||
summary = ""
|
|
||||||
found = "false"
|
|
||||||
|
|
||||||
with open("sourcery_summary.txt", "w", encoding="utf-8") as f:
|
|
||||||
f.write(summary)
|
|
||||||
|
|
||||||
with open(os.environ["GITHUB_OUTPUT"], "a") as gh:
|
|
||||||
gh.write(f"found={found}\n")
|
|
||||||
gh.write("summary<<EOF\n")
|
|
||||||
gh.write(summary + "\n")
|
|
||||||
gh.write("EOF\n")
|
|
||||||
PYEOF
|
|
||||||
|
|
||||||
# 4️⃣ Fallback: 获取 commits + 通义千问总结
|
|
||||||
- name: Get Commits
|
|
||||||
if: steps.check.outputs.ok == 'true' && steps.sourcery.outputs.found == 'false'
|
|
||||||
run: |
|
|
||||||
curl -s \
|
|
||||||
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
|
|
||||||
${{ github.event.pull_request.commits_url }} \
|
|
||||||
| jq -r '.[].commit.message' | head -n 20 > commits.txt
|
|
||||||
|
|
||||||
- name: AI Summary (Qwen Fallback)
|
|
||||||
if: steps.check.outputs.ok == 'true' && steps.sourcery.outputs.found == 'false'
|
|
||||||
id: qwen
|
|
||||||
env:
|
|
||||||
DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }}
|
|
||||||
run: |
|
|
||||||
python3 << 'PYEOF'
|
|
||||||
import json, os, urllib.request
|
|
||||||
|
|
||||||
with open("commits.txt", "r") as f:
|
|
||||||
commits = f.read().strip()
|
|
||||||
|
|
||||||
prompt = "请用中文总结以下代码提交,输出3-5条要点,面向测试人员。直接输出编号列表,不要输出标题或前言:\n" + commits
|
|
||||||
payload = {"model": "qwen-plus", "input": {"prompt": prompt}}
|
|
||||||
data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
|
|
||||||
|
|
||||||
req = urllib.request.Request(
|
|
||||||
"https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation",
|
|
||||||
data=data,
|
|
||||||
headers={
|
|
||||||
"Authorization": "Bearer " + os.environ["DASHSCOPE_API_KEY"],
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
resp = urllib.request.urlopen(req)
|
|
||||||
result = json.loads(resp.read().decode())
|
|
||||||
summary = result.get("output", {}).get("text", "AI 摘要生成失败")
|
|
||||||
|
|
||||||
with open(os.environ["GITHUB_OUTPUT"], "a") as gh:
|
|
||||||
gh.write("summary<<EOF\n")
|
|
||||||
gh.write(summary + "\n")
|
|
||||||
gh.write("EOF\n")
|
|
||||||
PYEOF
|
|
||||||
|
|
||||||
# 5️⃣ 企业微信通知(Markdown)
|
|
||||||
- name: Notify WeChat
|
|
||||||
if: steps.check.outputs.ok == 'true'
|
|
||||||
env:
|
|
||||||
WECHAT_WEBHOOK: ${{ secrets.WECHAT_WEBHOOK }}
|
|
||||||
BRANCH: ${{ github.event.pull_request.base.ref }}
|
|
||||||
AUTHOR: ${{ github.event.pull_request.user.login }}
|
|
||||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
|
||||||
PR_URL: ${{ github.event.pull_request.html_url }}
|
|
||||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
|
||||||
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
|
||||||
SOURCERY_FOUND: ${{ steps.sourcery.outputs.found }}
|
|
||||||
SOURCERY_SUMMARY: ${{ steps.sourcery.outputs.summary }}
|
|
||||||
QWEN_SUMMARY: ${{ steps.qwen.outputs.summary }}
|
|
||||||
run: |
|
|
||||||
python3 << 'PYEOF'
|
|
||||||
import json, os, urllib.request
|
|
||||||
|
|
||||||
if os.environ.get("SOURCERY_FOUND") == "true":
|
|
||||||
label = "Summary by Sourcery"
|
|
||||||
summary = os.environ.get("SOURCERY_SUMMARY", "")
|
|
||||||
else:
|
|
||||||
label = "AI变更摘要"
|
|
||||||
summary = os.environ.get("QWEN_SUMMARY", "AI 摘要生成失败")
|
|
||||||
|
|
||||||
pr_number = os.environ.get("PR_NUMBER", "")
|
|
||||||
short_sha = os.environ.get("MERGE_SHA", "")[:7]
|
|
||||||
|
|
||||||
content = (
|
|
||||||
"## 🚀 Release 发布通知\n"
|
|
||||||
"> <20> **分支**: " + os.environ["BRANCH"] + "\n"
|
|
||||||
"> 👤 **提交人**: " + os.environ["AUTHOR"] + "\n"
|
|
||||||
"> 📝 **标题**: " + os.environ["PR_TITLE"] + "\n"
|
|
||||||
"> 🔢 **PR编号**: #" + pr_number + "\n"
|
|
||||||
"> 🔖 **Commit**: " + short_sha + "\n\n"
|
|
||||||
"### 🧠 " + label + "\n" +
|
|
||||||
summary + "\n\n"
|
|
||||||
"---\n"
|
|
||||||
"🔗 [查看PR详情](" + os.environ["PR_URL"] + ")"
|
|
||||||
)
|
|
||||||
payload = {"msgtype": "markdown", "markdown": {"content": content}}
|
|
||||||
data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
|
|
||||||
req = urllib.request.Request(
|
|
||||||
os.environ["WECHAT_WEBHOOK"],
|
|
||||||
data=data,
|
|
||||||
headers={"Content-Type": "application/json"}
|
|
||||||
)
|
|
||||||
resp = urllib.request.urlopen(req)
|
|
||||||
print(resp.read().decode())
|
|
||||||
PYEOF
|
|
||||||
33
.github/workflows/sync-to-gitee.yml
vendored
33
.github/workflows/sync-to-gitee.yml
vendored
@@ -1,33 +0,0 @@
|
|||||||
name: Sync to Gitee
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- '**' # All branchs
|
|
||||||
tags:
|
|
||||||
- '**' # All version tags (v1.0.0, etc.)
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
sync:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout Source Code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
fetch-depth: 0
|
|
||||||
|
|
||||||
- name: Sync to Gitee
|
|
||||||
run: |
|
|
||||||
GITEE_URL="https://${{ secrets.GITEE_USERNAME }}:${{ secrets.GITEE_TOKEN }}@gitee.com/hangzhou-hongxiong-intelligent_1/MemoryBear.git"
|
|
||||||
git remote add gitee "$GITEE_URL"
|
|
||||||
|
|
||||||
# 遍历并推送所有分支
|
|
||||||
for branch in $(git branch -r | grep -v HEAD | sed 's/origin\///'); do
|
|
||||||
echo "Syncing branch: $branch"
|
|
||||||
git push -f gitee "origin/$branch:refs/heads/$branch"
|
|
||||||
done
|
|
||||||
|
|
||||||
# 推送所有标签
|
|
||||||
echo "Syncing tags..."
|
|
||||||
git push gitee --tags --force
|
|
||||||
12
.gitignore
vendored
12
.gitignore
vendored
@@ -18,22 +18,16 @@ examples/
|
|||||||
.kiro
|
.kiro
|
||||||
.vscode
|
.vscode
|
||||||
.idea
|
.idea
|
||||||
.claude
|
|
||||||
|
|
||||||
# Temporary outputs
|
# Temporary outputs
|
||||||
.DS_Store
|
.DS_Store
|
||||||
.hypothesis/
|
|
||||||
time.log
|
time.log
|
||||||
celerybeat-schedule.db
|
celerybeat-schedule.db
|
||||||
search_results.json
|
search_results.json
|
||||||
redbear-mem-metrics/
|
|
||||||
redbear-mem-benchmark/
|
|
||||||
pitch-deck/
|
|
||||||
|
|
||||||
api/migrations/versions
|
api/migrations/versions
|
||||||
tmp
|
tmp
|
||||||
files
|
files
|
||||||
powers/
|
|
||||||
|
|
||||||
# Exclude dep files
|
# Exclude dep files
|
||||||
huggingface.co/
|
huggingface.co/
|
||||||
@@ -42,7 +36,5 @@ tika-server*.jar*
|
|||||||
cl100k_base.tiktoken
|
cl100k_base.tiktoken
|
||||||
libssl*.deb
|
libssl*.deb
|
||||||
|
|
||||||
sandbox/lib/seccomp_redbear/target
|
sandbox/lib/seccomp_python/target
|
||||||
|
sandbox/lib/seccomp_nodejs/target
|
||||||
# Qoder repowiki generated content
|
|
||||||
.qoder/repowiki/zh/
|
|
||||||
|
|||||||
@@ -2,10 +2,6 @@
|
|||||||
|
|
||||||
# MemoryBear empowers AI with human-like memory capabilities
|
# MemoryBear empowers AI with human-like memory capabilities
|
||||||
|
|
||||||
[](LICENSE)
|
|
||||||
[](https://www.python.org/)
|
|
||||||
[](https://github.com/SuanmoSuanyangTechnology/MemoryBear/actions/workflows/sync-to-gitee.yml)
|
|
||||||
|
|
||||||
[中文](./README_CN.md) | English
|
[中文](./README_CN.md) | English
|
||||||
|
|
||||||
### [Installation Guide](#memorybear-installation-guide)
|
### [Installation Guide](#memorybear-installation-guide)
|
||||||
@@ -230,8 +226,8 @@ REDIS_PORT=6379
|
|||||||
REDIS_DB=1
|
REDIS_DB=1
|
||||||
|
|
||||||
# Celery (Using Redis as broker)
|
# Celery (Using Redis as broker)
|
||||||
REDIS_DB_CELERY_BROKER=1
|
BROKER_URL=redis://127.0.0.1:6379/0
|
||||||
REDIS_DB_CELERY_BACKEND=2
|
RESULT_BACKEND=redis://127.0.0.1:6379/0
|
||||||
|
|
||||||
# JWT Secret Key (Formation method: openssl rand -hex 32)
|
# JWT Secret Key (Formation method: openssl rand -hex 32)
|
||||||
SECRET_KEY=your-secret-key-here
|
SECRET_KEY=your-secret-key-here
|
||||||
|
|||||||
@@ -2,10 +2,6 @@
|
|||||||
|
|
||||||
# MemoryBear 让AI拥有如同人类一样的记忆
|
# MemoryBear 让AI拥有如同人类一样的记忆
|
||||||
|
|
||||||
[](LICENSE)
|
|
||||||
[](https://www.python.org/)
|
|
||||||
[](https://github.com/SuanmoSuanyangTechnology/MemoryBear/actions/workflows/sync-to-gitee.yml)
|
|
||||||
|
|
||||||
中文 | [English](./README.md)
|
中文 | [English](./README.md)
|
||||||
|
|
||||||
### [安装教程](#memorybear安装教程)
|
### [安装教程](#memorybear安装教程)
|
||||||
@@ -205,8 +201,8 @@ REDIS_PORT=6379
|
|||||||
REDIS_DB=1
|
REDIS_DB=1
|
||||||
|
|
||||||
# Celery (使用Redis作为broker)
|
# Celery (使用Redis作为broker)
|
||||||
REDIS_DB_CELERY_BROKER=1
|
BROKER_URL=redis://127.0.0.1:6379/0
|
||||||
REDIS_DB_CELERY_BACKEND=2
|
RESULT_BACKEND=redis://127.0.0.1:6379/0
|
||||||
|
|
||||||
# JWT密钥 (生成方式: openssl rand -hex 32)
|
# JWT密钥 (生成方式: openssl rand -hex 32)
|
||||||
SECRET_KEY=your-secret-key-here
|
SECRET_KEY=your-secret-key-here
|
||||||
|
|||||||
@@ -45,8 +45,7 @@ RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \
|
|||||||
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
|
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
|
||||||
apt install -y libjemalloc-dev && \
|
apt install -y libjemalloc-dev && \
|
||||||
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
|
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
|
||||||
apt install -y ghostscript && \
|
apt install -y ghostscript
|
||||||
apt install -y libmagic1
|
|
||||||
|
|
||||||
RUN if [ "$NEED_MIRROR" == "1" ]; then \
|
RUN if [ "$NEED_MIRROR" == "1" ]; then \
|
||||||
pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
|
pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
|
||||||
|
|||||||
@@ -60,12 +60,7 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
|
|||||||
# are written from script.py.mako
|
# are written from script.py.mako
|
||||||
# output_encoding = utf-8
|
# output_encoding = utf-8
|
||||||
|
|
||||||
# Database connection URL - DO NOT hardcode credentials here!
|
sqlalchemy.url = postgresql://user:password@localhost/dbname
|
||||||
# Connection string is set dynamically from environment variables in migrations/env.py
|
|
||||||
# Required env vars: DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME
|
|
||||||
# Example: postgresql://user:password@localhost:5432/dbname
|
|
||||||
; sqlalchemy.url = postgresql://user:password@host:port/dbname
|
|
||||||
sqlalchemy.url = driver://user:password@host:port/dbname
|
|
||||||
|
|
||||||
|
|
||||||
[post_write_hooks]
|
[post_write_hooks]
|
||||||
|
|||||||
@@ -1,18 +1,16 @@
|
|||||||
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import threading
|
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
|
|
||||||
import redis.asyncio as redis
|
import redis.asyncio as redis
|
||||||
from redis.asyncio import ConnectionPool
|
from redis.asyncio import ConnectionPool
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
# 设置日志记录器
|
# 设置日志记录器
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# 创建连接池
|
# 创建连接池
|
||||||
pool = ConnectionPool.from_url(
|
pool = ConnectionPool.from_url(
|
||||||
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}",
|
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}",
|
||||||
@@ -23,51 +21,6 @@ pool = ConnectionPool.from_url(
|
|||||||
)
|
)
|
||||||
aio_redis = redis.StrictRedis(connection_pool=pool)
|
aio_redis = redis.StrictRedis(connection_pool=pool)
|
||||||
|
|
||||||
_REDIS_URL = f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}"
|
|
||||||
|
|
||||||
# Thread-local storage for connection pools.
|
|
||||||
# Each thread (and each forked process) gets its own pool to avoid
|
|
||||||
# "Future attached to a different loop" errors in Celery --pool=threads
|
|
||||||
# and stale connections after fork in --pool=prefork.
|
|
||||||
_thread_local = threading.local()
|
|
||||||
|
|
||||||
|
|
||||||
def get_thread_safe_redis() -> redis.StrictRedis:
|
|
||||||
"""Return a Redis client whose connection pool is bound to the current
|
|
||||||
thread, process **and** event loop.
|
|
||||||
|
|
||||||
The pool is recreated when:
|
|
||||||
- The PID changes (fork, Celery --pool=prefork)
|
|
||||||
- The thread has no pool yet (Celery --pool=threads)
|
|
||||||
- The previously-cached event loop has been closed (Celery tasks call
|
|
||||||
``_shutdown_loop_gracefully`` which closes the loop after each run)
|
|
||||||
"""
|
|
||||||
current_pid = os.getpid()
|
|
||||||
cached_loop = getattr(_thread_local, "loop", None)
|
|
||||||
loop_stale = cached_loop is not None and cached_loop.is_closed()
|
|
||||||
|
|
||||||
if not hasattr(_thread_local, "pool") \
|
|
||||||
or getattr(_thread_local, "pid", None) != current_pid \
|
|
||||||
or loop_stale:
|
|
||||||
_thread_local.pid = current_pid
|
|
||||||
# Python 3.10+: get_event_loop() raises RuntimeError in threads
|
|
||||||
# where no loop has been set yet (e.g. Celery --pool=threads).
|
|
||||||
try:
|
|
||||||
_thread_local.loop = asyncio.get_event_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
_thread_local.loop = None
|
|
||||||
_thread_local.pool = ConnectionPool.from_url(
|
|
||||||
_REDIS_URL,
|
|
||||||
db=settings.REDIS_DB,
|
|
||||||
password=settings.REDIS_PASSWORD,
|
|
||||||
decode_responses=True,
|
|
||||||
max_connections=5,
|
|
||||||
health_check_interval=30,
|
|
||||||
)
|
|
||||||
|
|
||||||
return redis.StrictRedis(connection_pool=_thread_local.pool)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_redis_connection():
|
async def get_redis_connection():
|
||||||
"""获取Redis连接"""
|
"""获取Redis连接"""
|
||||||
try:
|
try:
|
||||||
@@ -76,7 +29,6 @@ async def get_redis_connection():
|
|||||||
logger.error(f"Redis连接失败: {str(e)}")
|
logger.error(f"Redis连接失败: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def aio_redis_set(key: str, val: str|dict, expire: int = None):
|
async def aio_redis_set(key: str, val: str|dict, expire: int = None):
|
||||||
"""设置Redis键值
|
"""设置Redis键值
|
||||||
|
|
||||||
@@ -90,13 +42,14 @@ async def aio_redis_set(key: str, val: str | dict, expire: int = None):
|
|||||||
val = json.dumps(val, ensure_ascii=False)
|
val = json.dumps(val, ensure_ascii=False)
|
||||||
|
|
||||||
if expire is not None:
|
if expire is not None:
|
||||||
|
# 设置带过期时间的键值
|
||||||
await aio_redis.set(key, val, ex=expire)
|
await aio_redis.set(key, val, ex=expire)
|
||||||
else:
|
else:
|
||||||
|
# 设置永久键值
|
||||||
await aio_redis.set(key, val)
|
await aio_redis.set(key, val)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Redis set错误: {str(e)}")
|
logger.error(f"Redis set错误: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
async def aio_redis_get(key: str):
|
async def aio_redis_get(key: str):
|
||||||
"""获取Redis键值"""
|
"""获取Redis键值"""
|
||||||
try:
|
try:
|
||||||
@@ -105,7 +58,6 @@ async def aio_redis_get(key: str):
|
|||||||
logger.error(f"Redis get错误: {str(e)}")
|
logger.error(f"Redis get错误: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def aio_redis_delete(key: str):
|
async def aio_redis_delete(key: str):
|
||||||
"""删除Redis键"""
|
"""删除Redis键"""
|
||||||
try:
|
try:
|
||||||
@@ -114,7 +66,6 @@ async def aio_redis_delete(key: str):
|
|||||||
logger.error(f"Redis delete错误: {str(e)}")
|
logger.error(f"Redis delete错误: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
|
async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
|
||||||
"""发布消息到Redis频道"""
|
"""发布消息到Redis频道"""
|
||||||
try:
|
try:
|
||||||
@@ -127,7 +78,6 @@ async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
|
|||||||
logger.error(f"Redis发布错误: {str(e)}")
|
logger.error(f"Redis发布错误: {str(e)}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class RedisSubscriber:
|
class RedisSubscriber:
|
||||||
"""Redis订阅器"""
|
"""Redis订阅器"""
|
||||||
|
|
||||||
@@ -213,7 +163,6 @@ class RedisSubscriber:
|
|||||||
self._task.cancel()
|
self._task.cancel()
|
||||||
await self._cleanup()
|
await self._cleanup()
|
||||||
|
|
||||||
|
|
||||||
class RedisPubSubManager:
|
class RedisPubSubManager:
|
||||||
"""Redis发布订阅管理器"""
|
"""Redis发布订阅管理器"""
|
||||||
|
|
||||||
@@ -247,6 +196,6 @@ class RedisPubSubManager:
|
|||||||
self.subscribers.clear()
|
self.subscribers.clear()
|
||||||
return count
|
return count
|
||||||
|
|
||||||
|
|
||||||
# 全局实例
|
# 全局实例
|
||||||
pubsub_manager = RedisPubSubManager()
|
pubsub_manager = RedisPubSubManager()
|
||||||
|
|
||||||
|
|||||||
5
api/app/cache/__init__.py
vendored
5
api/app/cache/__init__.py
vendored
@@ -3,8 +3,9 @@ Cache 缓存模块
|
|||||||
|
|
||||||
提供各种缓存功能的统一入口
|
提供各种缓存功能的统一入口
|
||||||
"""
|
"""
|
||||||
from .memory import InterestMemoryCache
|
from .memory import EmotionMemoryCache, ImplicitMemoryCache
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"InterestMemoryCache",
|
"EmotionMemoryCache",
|
||||||
|
"ImplicitMemoryCache",
|
||||||
]
|
]
|
||||||
|
|||||||
8
api/app/cache/memory/__init__.py
vendored
8
api/app/cache/memory/__init__.py
vendored
@@ -3,10 +3,10 @@ Memory 缓存模块
|
|||||||
|
|
||||||
提供记忆系统相关的缓存功能
|
提供记忆系统相关的缓存功能
|
||||||
"""
|
"""
|
||||||
from .interest_memory import InterestMemoryCache
|
from .emotion_memory import EmotionMemoryCache
|
||||||
from .activity_stats_cache import ActivityStatsCache
|
from .implicit_memory import ImplicitMemoryCache
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"InterestMemoryCache",
|
"EmotionMemoryCache",
|
||||||
"ActivityStatsCache",
|
"ImplicitMemoryCache",
|
||||||
]
|
]
|
||||||
|
|||||||
124
api/app/cache/memory/activity_stats_cache.py
vendored
124
api/app/cache/memory/activity_stats_cache.py
vendored
@@ -1,124 +0,0 @@
|
|||||||
"""
|
|
||||||
Recent Activity Stats Cache
|
|
||||||
|
|
||||||
记忆提取活动统计缓存模块
|
|
||||||
用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储,24小时后释放
|
|
||||||
查询命令:cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843
|
|
||||||
"""
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from typing import Optional, Dict, Any
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from app.aioRedis import get_thread_safe_redis
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# 缓存过期时间:24小时
|
|
||||||
ACTIVITY_STATS_CACHE_EXPIRE = 86400
|
|
||||||
|
|
||||||
|
|
||||||
class ActivityStatsCache:
|
|
||||||
"""记忆提取活动统计缓存类"""
|
|
||||||
|
|
||||||
PREFIX = "cache:memory:activity_stats"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _get_key(cls, workspace_id: str) -> str:
|
|
||||||
"""生成 Redis key
|
|
||||||
|
|
||||||
Args:
|
|
||||||
workspace_id: 工作空间ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
完整的 Redis key
|
|
||||||
"""
|
|
||||||
return f"{cls.PREFIX}:by_workspace:{workspace_id}"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def set_activity_stats(
|
|
||||||
cls,
|
|
||||||
workspace_id: str,
|
|
||||||
stats: Dict[str, Any],
|
|
||||||
expire: int = ACTIVITY_STATS_CACHE_EXPIRE,
|
|
||||||
) -> bool:
|
|
||||||
"""设置记忆提取活动统计缓存
|
|
||||||
|
|
||||||
Args:
|
|
||||||
workspace_id: 工作空间ID
|
|
||||||
stats: 统计数据,格式:
|
|
||||||
{
|
|
||||||
"chunk_count": int,
|
|
||||||
"statements_count": int,
|
|
||||||
"triplet_entities_count": int,
|
|
||||||
"triplet_relations_count": int,
|
|
||||||
"temporal_count": int,
|
|
||||||
}
|
|
||||||
expire: 过期时间(秒),默认24小时
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
是否设置成功
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
key = cls._get_key(workspace_id)
|
|
||||||
payload = {
|
|
||||||
"stats": stats,
|
|
||||||
"generated_at": datetime.now().isoformat(),
|
|
||||||
"workspace_id": workspace_id,
|
|
||||||
"cached": True,
|
|
||||||
}
|
|
||||||
value = json.dumps(payload, ensure_ascii=False)
|
|
||||||
await get_thread_safe_redis().set(key, value, ex=expire)
|
|
||||||
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒")
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"设置活动统计缓存失败: {e}", exc_info=True)
|
|
||||||
return False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get_activity_stats(
|
|
||||||
cls,
|
|
||||||
workspace_id: str,
|
|
||||||
) -> Optional[Dict[str, Any]]:
|
|
||||||
"""获取记忆提取活动统计缓存
|
|
||||||
|
|
||||||
Args:
|
|
||||||
workspace_id: 工作空间ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
统计数据字典,缓存不存在或已过期返回 None
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
key = cls._get_key(workspace_id)
|
|
||||||
value = await get_thread_safe_redis().get(key)
|
|
||||||
if value:
|
|
||||||
payload = json.loads(value)
|
|
||||||
logger.info(f"命中活动统计缓存: {key}")
|
|
||||||
return payload
|
|
||||||
logger.info(f"活动统计缓存不存在或已过期: {key}")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取活动统计缓存失败: {e}", exc_info=True)
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def delete_activity_stats(
|
|
||||||
cls,
|
|
||||||
workspace_id: str,
|
|
||||||
) -> bool:
|
|
||||||
"""删除记忆提取活动统计缓存
|
|
||||||
|
|
||||||
Args:
|
|
||||||
workspace_id: 工作空间ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
是否删除成功
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
key = cls._get_key(workspace_id)
|
|
||||||
result = await get_thread_safe_redis().delete(key)
|
|
||||||
logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
|
|
||||||
return result > 0
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"删除活动统计缓存失败: {e}", exc_info=True)
|
|
||||||
return False
|
|
||||||
134
api/app/cache/memory/emotion_memory.py
vendored
Normal file
134
api/app/cache/memory/emotion_memory.py
vendored
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
"""
|
||||||
|
Emotion Suggestions Cache
|
||||||
|
|
||||||
|
情绪个性化建议缓存模块
|
||||||
|
用于缓存用户的情绪个性化建议数据
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from app.aioRedis import aio_redis
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EmotionMemoryCache:
|
||||||
|
"""情绪建议缓存类"""
|
||||||
|
|
||||||
|
# Key 前缀
|
||||||
|
PREFIX = "cache:memory:emotion_memory"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_key(cls, *parts: str) -> str:
|
||||||
|
"""生成 Redis key
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*parts: key 的各个部分
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
完整的 Redis key
|
||||||
|
"""
|
||||||
|
return ":".join([cls.PREFIX] + list(parts))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def set_emotion_suggestions(
|
||||||
|
cls,
|
||||||
|
user_id: str,
|
||||||
|
suggestions_data: Dict[str, Any],
|
||||||
|
expire: int = 86400
|
||||||
|
) -> bool:
|
||||||
|
"""设置用户情绪建议缓存
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID(end_user_id)
|
||||||
|
suggestions_data: 建议数据字典,包含:
|
||||||
|
- health_summary: 健康状态摘要
|
||||||
|
- suggestions: 建议列表
|
||||||
|
- generated_at: 生成时间(可选)
|
||||||
|
expire: 过期时间(秒),默认24小时(86400秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否设置成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
key = cls._get_key("suggestions", user_id)
|
||||||
|
|
||||||
|
# 添加生成时间戳
|
||||||
|
if "generated_at" not in suggestions_data:
|
||||||
|
suggestions_data["generated_at"] = datetime.now().isoformat()
|
||||||
|
|
||||||
|
# 添加缓存标记
|
||||||
|
suggestions_data["cached"] = True
|
||||||
|
|
||||||
|
value = json.dumps(suggestions_data, ensure_ascii=False)
|
||||||
|
await aio_redis.set(key, value, ex=expire)
|
||||||
|
logger.info(f"设置情绪建议缓存成功: {key}, 过期时间: {expire}秒")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"设置情绪建议缓存失败: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_emotion_suggestions(cls, user_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""获取用户情绪建议缓存
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID(end_user_id)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
建议数据字典,如果不存在或已过期返回 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
key = cls._get_key("suggestions", user_id)
|
||||||
|
value = await aio_redis.get(key)
|
||||||
|
|
||||||
|
if value:
|
||||||
|
data = json.loads(value)
|
||||||
|
logger.info(f"成功获取情绪建议缓存: {key}")
|
||||||
|
return data
|
||||||
|
|
||||||
|
logger.info(f"情绪建议缓存不存在或已过期: {key}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取情绪建议缓存失败: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def delete_emotion_suggestions(cls, user_id: str) -> bool:
|
||||||
|
"""删除用户情绪建议缓存
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID(end_user_id)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否删除成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
key = cls._get_key("suggestions", user_id)
|
||||||
|
result = await aio_redis.delete(key)
|
||||||
|
logger.info(f"删除情绪建议缓存: {key}, 结果: {result}")
|
||||||
|
return result > 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除情绪建议缓存失败: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_suggestions_ttl(cls, user_id: str) -> int:
|
||||||
|
"""获取情绪建议缓存的剩余过期时间
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID(end_user_id)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
剩余秒数,-1表示永不过期,-2表示key不存在
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
key = cls._get_key("suggestions", user_id)
|
||||||
|
ttl = await aio_redis.ttl(key)
|
||||||
|
logger.debug(f"情绪建议缓存TTL: {key} = {ttl}秒")
|
||||||
|
return ttl
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取情绪建议缓存TTL失败: {e}")
|
||||||
|
return -2
|
||||||
136
api/app/cache/memory/implicit_memory.py
vendored
Normal file
136
api/app/cache/memory/implicit_memory.py
vendored
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
"""
|
||||||
|
Implicit Memory Profile Cache
|
||||||
|
|
||||||
|
隐式记忆用户画像缓存模块
|
||||||
|
用于缓存用户的完整画像数据(偏好标签、四维画像、兴趣领域、行为习惯)
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from app.aioRedis import aio_redis
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ImplicitMemoryCache:
|
||||||
|
"""隐式记忆用户画像缓存类"""
|
||||||
|
|
||||||
|
# Key 前缀
|
||||||
|
PREFIX = "cache:memory:implicit_memory"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_key(cls, *parts: str) -> str:
|
||||||
|
"""生成 Redis key
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*parts: key 的各个部分
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
完整的 Redis key
|
||||||
|
"""
|
||||||
|
return ":".join([cls.PREFIX] + list(parts))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def set_user_profile(
|
||||||
|
cls,
|
||||||
|
user_id: str,
|
||||||
|
profile_data: Dict[str, Any],
|
||||||
|
expire: int = 86400
|
||||||
|
) -> bool:
|
||||||
|
"""设置用户完整画像缓存
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID(end_user_id)
|
||||||
|
profile_data: 画像数据字典,包含:
|
||||||
|
- preferences: 偏好标签列表
|
||||||
|
- portrait: 四维画像对象
|
||||||
|
- interest_areas: 兴趣领域分布对象
|
||||||
|
- habits: 行为习惯列表
|
||||||
|
- generated_at: 生成时间(可选)
|
||||||
|
expire: 过期时间(秒),默认24小时(86400秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否设置成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
key = cls._get_key("profile", user_id)
|
||||||
|
|
||||||
|
# 添加生成时间戳
|
||||||
|
if "generated_at" not in profile_data:
|
||||||
|
profile_data["generated_at"] = datetime.now().isoformat()
|
||||||
|
|
||||||
|
# 添加缓存标记
|
||||||
|
profile_data["cached"] = True
|
||||||
|
|
||||||
|
value = json.dumps(profile_data, ensure_ascii=False)
|
||||||
|
await aio_redis.set(key, value, ex=expire)
|
||||||
|
logger.info(f"设置用户画像缓存成功: {key}, 过期时间: {expire}秒")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"设置用户画像缓存失败: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_user_profile(cls, user_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""获取用户完整画像缓存
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID(end_user_id)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
画像数据字典,如果不存在或已过期返回 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
key = cls._get_key("profile", user_id)
|
||||||
|
value = await aio_redis.get(key)
|
||||||
|
|
||||||
|
if value:
|
||||||
|
data = json.loads(value)
|
||||||
|
logger.info(f"成功获取用户画像缓存: {key}")
|
||||||
|
return data
|
||||||
|
|
||||||
|
logger.info(f"用户画像缓存不存在或已过期: {key}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取用户画像缓存失败: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def delete_user_profile(cls, user_id: str) -> bool:
|
||||||
|
"""删除用户完整画像缓存
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID(end_user_id)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否删除成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
key = cls._get_key("profile", user_id)
|
||||||
|
result = await aio_redis.delete(key)
|
||||||
|
logger.info(f"删除用户画像缓存: {key}, 结果: {result}")
|
||||||
|
return result > 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除用户画像缓存失败: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_profile_ttl(cls, user_id: str) -> int:
|
||||||
|
"""获取用户画像缓存的剩余过期时间
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID(end_user_id)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
剩余秒数,-1表示永不过期,-2表示key不存在
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
key = cls._get_key("profile", user_id)
|
||||||
|
ttl = await aio_redis.ttl(key)
|
||||||
|
logger.debug(f"用户画像缓存TTL: {key} = {ttl}秒")
|
||||||
|
return ttl
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取用户画像缓存TTL失败: {e}")
|
||||||
|
return -2
|
||||||
122
api/app/cache/memory/interest_memory.py
vendored
122
api/app/cache/memory/interest_memory.py
vendored
@@ -1,122 +0,0 @@
|
|||||||
"""
|
|
||||||
Interest Distribution Cache
|
|
||||||
|
|
||||||
兴趣分布缓存模块
|
|
||||||
用于缓存用户的兴趣分布标签数据,避免重复调用模型生成
|
|
||||||
"""
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from typing import Optional, List, Dict, Any
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from app.aioRedis import get_thread_safe_redis
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# 缓存过期时间:24小时
|
|
||||||
INTEREST_CACHE_EXPIRE = 86400
|
|
||||||
|
|
||||||
|
|
||||||
class InterestMemoryCache:
|
|
||||||
"""兴趣分布缓存类"""
|
|
||||||
|
|
||||||
PREFIX = "cache:memory:interest_distribution"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _get_key(cls, end_user_id: str, language: str) -> str:
|
|
||||||
"""生成 Redis key
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: 用户ID
|
|
||||||
language: 语言类型
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
完整的 Redis key
|
|
||||||
"""
|
|
||||||
return f"{cls.PREFIX}:by_user:{end_user_id}:{language}"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def set_interest_distribution(
|
|
||||||
cls,
|
|
||||||
end_user_id: str,
|
|
||||||
language: str,
|
|
||||||
data: List[Dict[str, Any]],
|
|
||||||
expire: int = INTEREST_CACHE_EXPIRE,
|
|
||||||
) -> bool:
|
|
||||||
"""设置用户兴趣分布缓存
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: 用户ID
|
|
||||||
language: 语言类型
|
|
||||||
data: 兴趣分布列表,格式 [{"name": "...", "frequency": ...}, ...]
|
|
||||||
expire: 过期时间(秒),默认24小时
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
是否设置成功
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
key = cls._get_key(end_user_id, language)
|
|
||||||
payload = {
|
|
||||||
"data": data,
|
|
||||||
"generated_at": datetime.now().isoformat(),
|
|
||||||
"cached": True,
|
|
||||||
}
|
|
||||||
value = json.dumps(payload, ensure_ascii=False)
|
|
||||||
await get_thread_safe_redis().set(key, value, ex=expire)
|
|
||||||
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒")
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"设置兴趣分布缓存失败: {e}", exc_info=True)
|
|
||||||
return False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get_interest_distribution(
|
|
||||||
cls,
|
|
||||||
end_user_id: str,
|
|
||||||
language: str,
|
|
||||||
) -> Optional[List[Dict[str, Any]]]:
|
|
||||||
"""获取用户兴趣分布缓存
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: 用户ID
|
|
||||||
language: 语言类型
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
兴趣分布列表,缓存不存在或已过期返回 None
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
key = cls._get_key(end_user_id, language)
|
|
||||||
value = await get_thread_safe_redis().get(key)
|
|
||||||
if value:
|
|
||||||
payload = json.loads(value)
|
|
||||||
logger.info(f"命中兴趣分布缓存: {key}")
|
|
||||||
return payload.get("data")
|
|
||||||
logger.info(f"兴趣分布缓存不存在或已过期: {key}")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取兴趣分布缓存失败: {e}", exc_info=True)
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def delete_interest_distribution(
|
|
||||||
cls,
|
|
||||||
end_user_id: str,
|
|
||||||
language: str,
|
|
||||||
) -> bool:
|
|
||||||
"""删除用户兴趣分布缓存
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: 用户ID
|
|
||||||
language: 语言类型
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
是否删除成功
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
key = cls._get_key(end_user_id, language)
|
|
||||||
result = await get_thread_safe_redis().delete(key)
|
|
||||||
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
|
|
||||||
return result > 0
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"删除兴趣分布缓存失败: {e}", exc_info=True)
|
|
||||||
return False
|
|
||||||
@@ -1,59 +1,25 @@
|
|||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import re
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
from celery import Celery
|
from celery import Celery
|
||||||
from celery.schedules import crontab
|
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.logging_config import get_logger
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def _mask_url(url: str) -> str:
|
|
||||||
"""隐藏 URL 中的密码部分,适用于 redis:// 和 amqp:// 等协议"""
|
|
||||||
return re.sub(r'(://[^:]*:)[^@]+(@)', r'\1***\2', url)
|
|
||||||
|
|
||||||
|
|
||||||
# macOS fork() safety - must be set before any Celery initialization
|
# macOS fork() safety - must be set before any Celery initialization
|
||||||
if platform.system() == 'Darwin':
|
if platform.system() == 'Darwin':
|
||||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||||
|
|
||||||
# 创建 Celery 应用实例
|
# 创建 Celery 应用实例
|
||||||
# broker: 优先使用环境变量 CELERY_BROKER_URL(支持 amqp:// 等任意协议),
|
# broker: 任务队列(使用 Redis DB 0)
|
||||||
# 未配置则回退到 Redis 方案
|
# backend: 结果存储(使用 Redis DB 10)
|
||||||
# backend: 结果存储(使用 Redis)
|
|
||||||
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND,
|
|
||||||
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
|
|
||||||
|
|
||||||
_broker_url = os.getenv("CELERY_BROKER_URL") or \
|
|
||||||
f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
|
|
||||||
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
|
|
||||||
os.environ["CELERY_BROKER_URL"] = _broker_url
|
|
||||||
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
|
|
||||||
# Neutralize legacy Celery env vars that can be hijacked by Celery's CLI/Click
|
|
||||||
# integration and accidentally override our canonical URLs.
|
|
||||||
os.environ.pop("BROKER_URL", None)
|
|
||||||
os.environ.pop("RESULT_BACKEND", None)
|
|
||||||
os.environ.pop("CELERY_BROKER", None)
|
|
||||||
os.environ.pop("CELERY_BACKEND", None)
|
|
||||||
|
|
||||||
celery_app = Celery(
|
celery_app = Celery(
|
||||||
"redbear_tasks",
|
"redbear_tasks",
|
||||||
broker=_broker_url,
|
broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}",
|
||||||
backend=_backend_url,
|
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Celery app initialized",
|
|
||||||
extra={
|
|
||||||
"broker": _mask_url(_broker_url),
|
|
||||||
"backend": _mask_url(_backend_url),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# Default queue for unrouted tasks
|
# Default queue for unrouted tasks
|
||||||
celery_app.conf.task_default_queue = 'memory_tasks'
|
celery_app.conf.task_default_queue = 'memory_tasks'
|
||||||
|
|
||||||
@@ -68,21 +34,20 @@ celery_app.conf.update(
|
|||||||
accept_content=['json'],
|
accept_content=['json'],
|
||||||
result_serializer='json',
|
result_serializer='json',
|
||||||
|
|
||||||
# # 时区
|
# 时区
|
||||||
# timezone='Asia/Shanghai',
|
timezone='Asia/Shanghai',
|
||||||
# enable_utc=False,
|
enable_utc=True,
|
||||||
|
|
||||||
# 任务追踪
|
# 任务追踪
|
||||||
task_track_started=True,
|
task_track_started=True,
|
||||||
task_ignore_result=False,
|
task_ignore_result=False,
|
||||||
|
|
||||||
# 超时设置
|
# 超时设置
|
||||||
task_time_limit=3600, # 60分钟硬超时
|
task_time_limit=1800, # 30分钟硬超时
|
||||||
task_soft_time_limit=3000, # 50分钟软超时
|
task_soft_time_limit=1500, # 25分钟软超时
|
||||||
|
|
||||||
# Worker 设置 (per-worker settings are in docker-compose command line)
|
# Worker 设置 (per-worker settings are in docker-compose command line)
|
||||||
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
||||||
worker_redirect_stdouts_level='INFO', # stdout/print → INFO instead of WARNING
|
|
||||||
|
|
||||||
# 结果过期时间
|
# 结果过期时间
|
||||||
result_expires=3600, # 结果保存1小时
|
result_expires=3600, # 结果保存1小时
|
||||||
@@ -108,38 +73,15 @@ celery_app.conf.update(
|
|||||||
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
||||||
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
||||||
|
|
||||||
# Clustering tasks → memory_tasks queue (使用相同的 worker,避免 macOS fork 问题)
|
|
||||||
'app.tasks.run_incremental_clustering': {'queue': 'memory_tasks'},
|
|
||||||
|
|
||||||
# Metadata extraction → memory_tasks queue
|
|
||||||
'app.tasks.extract_user_metadata': {'queue': 'memory_tasks'},
|
|
||||||
|
|
||||||
# Async emotion extraction → memory_tasks queue (IO-bound LLM calls)
|
|
||||||
'app.tasks.extract_emotion_batch': {'queue': 'memory_tasks'},
|
|
||||||
|
|
||||||
# Post-store dedup + alias merge → memory_tasks queue
|
|
||||||
'app.tasks.post_store_dedup_and_alias_merge': {'queue': 'memory_tasks'},
|
|
||||||
|
|
||||||
# Async metadata extraction → memory_tasks queue
|
|
||||||
'app.tasks.extract_metadata_batch': {'queue': 'memory_tasks'},
|
|
||||||
|
|
||||||
# Document tasks → document_tasks queue (prefork worker)
|
# Document tasks → document_tasks queue (prefork worker)
|
||||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||||
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
|
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||||
|
|
||||||
# GraphRAG tasks → graphrag_tasks queue (独立队列,避免阻塞文档解析)
|
|
||||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'graphrag_tasks'},
|
|
||||||
'app.core.rag.tasks.build_graphrag_for_document': {'queue': 'graphrag_tasks'},
|
|
||||||
|
|
||||||
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
||||||
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
||||||
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
||||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
|
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
|
||||||
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
|
'app.controllers.memory_storage_controller.search_all': {'queue': 'periodic_tasks'},
|
||||||
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
|
|
||||||
'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'},
|
|
||||||
'app.tasks.init_interest_distribution_for_users': {'queue': 'periodic_tasks'},
|
|
||||||
'app.tasks.init_community_clustering_for_users': {'queue': 'periodic_tasks'},
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -147,44 +89,40 @@ celery_app.conf.update(
|
|||||||
celery_app.autodiscover_tasks(['app'])
|
celery_app.autodiscover_tasks(['app'])
|
||||||
|
|
||||||
# Celery Beat schedule for periodic tasks
|
# Celery Beat schedule for periodic tasks
|
||||||
memory_increment_schedule = crontab(hour=settings.MEMORY_INCREMENT_HOUR, minute=settings.MEMORY_INCREMENT_MINUTE)
|
# memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||||
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
# memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||||
workspace_reflection_schedule = timedelta(seconds=settings.WORKSPACE_REFLECTION_INTERVAL_SECONDS)
|
# workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||||
forgetting_cycle_schedule = timedelta(hours=settings.FORGETTING_CYCLE_INTERVAL_HOURS)
|
# forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
|
||||||
implicit_emotions_update_schedule = crontab(
|
|
||||||
hour=settings.IMPLICIT_EMOTIONS_UPDATE_HOUR,
|
|
||||||
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 构建定时任务配置
|
# 构建定时任务配置
|
||||||
beat_schedule_config = {
|
# beat_schedule_config = {
|
||||||
"run-workspace-reflection": {
|
# "run-workspace-reflection": {
|
||||||
"task": "app.tasks.workspace_reflection_task",
|
# "task": "app.tasks.workspace_reflection_task",
|
||||||
"schedule": workspace_reflection_schedule,
|
# "schedule": workspace_reflection_schedule,
|
||||||
"args": (),
|
# "args": (),
|
||||||
},
|
# },
|
||||||
"regenerate-memory-cache": {
|
# "regenerate-memory-cache": {
|
||||||
"task": "app.tasks.regenerate_memory_cache",
|
# "task": "app.tasks.regenerate_memory_cache",
|
||||||
"schedule": memory_cache_regeneration_schedule,
|
# "schedule": memory_cache_regeneration_schedule,
|
||||||
"args": (),
|
# "args": (),
|
||||||
},
|
# },
|
||||||
"run-forgetting-cycle": {
|
# "run-forgetting-cycle": {
|
||||||
"task": "app.tasks.run_forgetting_cycle_task",
|
# "task": "app.tasks.run_forgetting_cycle_task",
|
||||||
"schedule": forgetting_cycle_schedule,
|
# "schedule": forgetting_cycle_schedule,
|
||||||
"kwargs": {
|
# "kwargs": {
|
||||||
"config_id": None, # 使用默认配置,可以通过环境变量配置
|
# "config_id": None, # 使用默认配置,可以通过环境变量配置
|
||||||
},
|
# },
|
||||||
},
|
# },
|
||||||
"write-all-workspaces-memory": {
|
# }
|
||||||
"task": "app.tasks.write_all_workspaces_memory_task",
|
|
||||||
"schedule": memory_increment_schedule,
|
|
||||||
"args": (),
|
|
||||||
},
|
|
||||||
"update-implicit-emotions-storage": {
|
|
||||||
"task": "app.tasks.update_implicit_emotions_storage",
|
|
||||||
"schedule": implicit_emotions_update_schedule,
|
|
||||||
"args": (),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
celery_app.conf.beat_schedule = beat_schedule_config
|
# 如果配置了默认工作空间ID,则添加记忆总量统计任务
|
||||||
|
# if settings.DEFAULT_WORKSPACE_ID:
|
||||||
|
# beat_schedule_config["write-total-memory"] = {
|
||||||
|
# "task": "app.controllers.memory_storage_controller.search_all",
|
||||||
|
# "schedule": memory_increment_schedule,
|
||||||
|
# "kwargs": {
|
||||||
|
# "workspace_id": settings.DEFAULT_WORKSPACE_ID,
|
||||||
|
# },
|
||||||
|
# }
|
||||||
|
|
||||||
|
# celery_app.conf.beat_schedule = beat_schedule_config
|
||||||
|
|||||||
@@ -1,500 +0,0 @@
|
|||||||
import hashlib
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import socket
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
import redis
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
from app.core.logging_config import get_named_logger
|
|
||||||
from app.celery_app import celery_app
|
|
||||||
|
|
||||||
logger = get_named_logger("task_scheduler")
|
|
||||||
|
|
||||||
# per-user queue scheduler:uq:{user_id}
|
|
||||||
USER_QUEUE_PREFIX = "scheduler:uq:"
|
|
||||||
# User Collection of Pending Messages
|
|
||||||
ACTIVE_USERS = "scheduler:active_users"
|
|
||||||
# Set of users that can dispatch (ready signal)
|
|
||||||
READY_SET = "scheduler:ready_users"
|
|
||||||
# Metadata of tasks that have been dispatched and are pending completion
|
|
||||||
PENDING_HASH = "scheduler:pending_tasks"
|
|
||||||
# Dynamic Sharding: Instance Registry
|
|
||||||
REGISTRY_KEY = "scheduler:instances"
|
|
||||||
|
|
||||||
TASK_TIMEOUT = 7800 # Task timeout (seconds), considered lost if exceeded
|
|
||||||
HEARTBEAT_INTERVAL = 10 # Heartbeat interval (seconds)
|
|
||||||
INSTANCE_TTL = 30 # Instance timeout (seconds)
|
|
||||||
|
|
||||||
LUA_ATOMIC_LOCK = """
|
|
||||||
local dispatch_lock = KEYS[1]
|
|
||||||
local lock_key = KEYS[2]
|
|
||||||
local instance_id = ARGV[1]
|
|
||||||
local dispatch_ttl = tonumber(ARGV[2])
|
|
||||||
local lock_ttl = tonumber(ARGV[3])
|
|
||||||
|
|
||||||
if redis.call('SET', dispatch_lock, instance_id, 'NX', 'EX', dispatch_ttl) == false then
|
|
||||||
return 0
|
|
||||||
end
|
|
||||||
|
|
||||||
if redis.call('EXISTS', lock_key) == 1 then
|
|
||||||
redis.call('DEL', dispatch_lock)
|
|
||||||
return -1
|
|
||||||
end
|
|
||||||
|
|
||||||
redis.call('SET', lock_key, 'dispatching', 'EX', lock_ttl)
|
|
||||||
return 1
|
|
||||||
"""
|
|
||||||
|
|
||||||
LUA_SAFE_DELETE = """
|
|
||||||
if redis.call('GET', KEYS[1]) == ARGV[1] then
|
|
||||||
return redis.call('DEL', KEYS[1])
|
|
||||||
end
|
|
||||||
return 0
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def stable_hash(value: str) -> int:
|
|
||||||
return int.from_bytes(
|
|
||||||
hashlib.md5(value.encode("utf-8")).digest(),
|
|
||||||
"big"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def health_check_server(scheduler_ref):
|
|
||||||
import uvicorn
|
|
||||||
from fastapi import FastAPI
|
|
||||||
|
|
||||||
health_app = FastAPI()
|
|
||||||
|
|
||||||
@health_app.get("/")
|
|
||||||
def health():
|
|
||||||
return scheduler_ref.health()
|
|
||||||
|
|
||||||
port = int(os.environ.get("SCHEDULER_HEALTH_PORT", "8001"))
|
|
||||||
threading.Thread(
|
|
||||||
target=uvicorn.run,
|
|
||||||
kwargs={
|
|
||||||
"app": health_app,
|
|
||||||
"host": "0.0.0.0",
|
|
||||||
"port": port,
|
|
||||||
"log_config": None,
|
|
||||||
},
|
|
||||||
daemon=True,
|
|
||||||
).start()
|
|
||||||
logger.info("[Health] Server started at http://0.0.0.0:%s", port)
|
|
||||||
|
|
||||||
|
|
||||||
class RedisTaskScheduler:
|
|
||||||
def __init__(self):
|
|
||||||
self.redis = redis.Redis(
|
|
||||||
host=settings.REDIS_HOST,
|
|
||||||
port=settings.REDIS_PORT,
|
|
||||||
db=settings.REDIS_DB_CELERY_BACKEND,
|
|
||||||
password=settings.REDIS_PASSWORD,
|
|
||||||
decode_responses=True,
|
|
||||||
)
|
|
||||||
self.running = False
|
|
||||||
self.dispatched = 0
|
|
||||||
self.errors = 0
|
|
||||||
|
|
||||||
self.instance_id = f"{socket.gethostname()}-{os.getpid()}"
|
|
||||||
self._shard_index = 0
|
|
||||||
self._shard_count = 1
|
|
||||||
self._last_heartbeat = 0.0
|
|
||||||
|
|
||||||
def push_task(self, task_name, user_id, params):
|
|
||||||
try:
|
|
||||||
msg_id = str(uuid.uuid4())
|
|
||||||
msg = json.dumps({
|
|
||||||
"msg_id": msg_id,
|
|
||||||
"task_name": task_name,
|
|
||||||
"user_id": user_id,
|
|
||||||
"params": json.dumps(params),
|
|
||||||
})
|
|
||||||
|
|
||||||
lock_key = f"{task_name}:{user_id}"
|
|
||||||
queue_key = f"{USER_QUEUE_PREFIX}{user_id}"
|
|
||||||
|
|
||||||
pipe = self.redis.pipeline()
|
|
||||||
pipe.rpush(queue_key, msg)
|
|
||||||
pipe.sadd(ACTIVE_USERS, user_id)
|
|
||||||
pipe.set(
|
|
||||||
f"task_tracker:{msg_id}",
|
|
||||||
json.dumps({"status": "QUEUED", "task_id": None}),
|
|
||||||
ex=86400,
|
|
||||||
)
|
|
||||||
pipe.execute()
|
|
||||||
|
|
||||||
if not self.redis.exists(lock_key):
|
|
||||||
self.redis.sadd(READY_SET, user_id)
|
|
||||||
|
|
||||||
logger.info("Task pushed: msg_id=%s task=%s user=%s", msg_id, task_name, user_id)
|
|
||||||
return msg_id
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Push task exception %s", e, exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def get_task_status(self, msg_id: str) -> dict:
|
|
||||||
raw = self.redis.get(f"task_tracker:{msg_id}")
|
|
||||||
if raw is None:
|
|
||||||
return {"status": "NOT_FOUND"}
|
|
||||||
|
|
||||||
tracker = json.loads(raw)
|
|
||||||
status = tracker["status"]
|
|
||||||
task_id = tracker.get("task_id")
|
|
||||||
result_content = tracker.get("result") or {}
|
|
||||||
|
|
||||||
if status == "DISPATCHED" and task_id:
|
|
||||||
result_raw = self.redis.get(f"celery-task-meta-{task_id}")
|
|
||||||
if result_raw:
|
|
||||||
result_data = json.loads(result_raw)
|
|
||||||
status = result_data.get("status", status)
|
|
||||||
result_content = result_data.get("result")
|
|
||||||
|
|
||||||
return {"status": status, "task_id": task_id, "result": result_content}
|
|
||||||
|
|
||||||
def _cleanup_finished(self):
|
|
||||||
pending = self.redis.hgetall(PENDING_HASH)
|
|
||||||
if not pending:
|
|
||||||
return
|
|
||||||
|
|
||||||
now = time.time()
|
|
||||||
task_ids = list(pending.keys())
|
|
||||||
|
|
||||||
pipe = self.redis.pipeline()
|
|
||||||
for task_id in task_ids:
|
|
||||||
pipe.get(f"celery-task-meta-{task_id}")
|
|
||||||
results = pipe.execute()
|
|
||||||
|
|
||||||
cleanup_pipe = self.redis.pipeline()
|
|
||||||
has_cleanup = False
|
|
||||||
ready_user_ids = set()
|
|
||||||
|
|
||||||
for task_id, raw_result in zip(task_ids, results):
|
|
||||||
try:
|
|
||||||
meta = json.loads(pending[task_id])
|
|
||||||
lock_key = meta["lock_key"]
|
|
||||||
dispatched_at = meta.get("dispatched_at", 0)
|
|
||||||
age = now - dispatched_at
|
|
||||||
|
|
||||||
should_cleanup = False
|
|
||||||
result_data = {}
|
|
||||||
|
|
||||||
if raw_result is not None:
|
|
||||||
result_data = json.loads(raw_result)
|
|
||||||
if result_data.get("status") in ("SUCCESS", "FAILURE", "REVOKED"):
|
|
||||||
should_cleanup = True
|
|
||||||
logger.info(
|
|
||||||
"Task finished: %s state=%s", task_id,
|
|
||||||
result_data.get("status"),
|
|
||||||
)
|
|
||||||
elif age > TASK_TIMEOUT:
|
|
||||||
should_cleanup = True
|
|
||||||
logger.warning(
|
|
||||||
"Task expired or lost: %s age=%.0fs, force cleanup",
|
|
||||||
task_id, age,
|
|
||||||
)
|
|
||||||
|
|
||||||
if should_cleanup:
|
|
||||||
final_status = (
|
|
||||||
result_data.get("status", "UNKNOWN") if result_data else "EXPIRED"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.redis.eval(LUA_SAFE_DELETE, 1, lock_key, task_id)
|
|
||||||
|
|
||||||
cleanup_pipe.hdel(PENDING_HASH, task_id)
|
|
||||||
|
|
||||||
tracker_msg_id = meta.get("msg_id")
|
|
||||||
if tracker_msg_id:
|
|
||||||
cleanup_pipe.set(
|
|
||||||
f"task_tracker:{tracker_msg_id}",
|
|
||||||
json.dumps({
|
|
||||||
"status": final_status,
|
|
||||||
"task_id": task_id,
|
|
||||||
"result": result_data.get("result") or {},
|
|
||||||
}),
|
|
||||||
ex=86400,
|
|
||||||
)
|
|
||||||
has_cleanup = True
|
|
||||||
|
|
||||||
parts = lock_key.split(":", 1)
|
|
||||||
if len(parts) == 2:
|
|
||||||
ready_user_ids.add(parts[1])
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True)
|
|
||||||
self.errors += 1
|
|
||||||
|
|
||||||
if has_cleanup:
|
|
||||||
cleanup_pipe.execute()
|
|
||||||
|
|
||||||
if ready_user_ids:
|
|
||||||
self.redis.sadd(READY_SET, *ready_user_ids)
|
|
||||||
|
|
||||||
def _heartbeat(self):
|
|
||||||
now = time.time()
|
|
||||||
if now - self._last_heartbeat < HEARTBEAT_INTERVAL:
|
|
||||||
return
|
|
||||||
self._last_heartbeat = now
|
|
||||||
|
|
||||||
self.redis.hset(REGISTRY_KEY, self.instance_id, str(now))
|
|
||||||
|
|
||||||
all_instances = self.redis.hgetall(REGISTRY_KEY)
|
|
||||||
|
|
||||||
alive = []
|
|
||||||
dead = []
|
|
||||||
for iid, ts in all_instances.items():
|
|
||||||
if now - float(ts) < INSTANCE_TTL:
|
|
||||||
alive.append(iid)
|
|
||||||
else:
|
|
||||||
dead.append(iid)
|
|
||||||
|
|
||||||
if dead:
|
|
||||||
pipe = self.redis.pipeline()
|
|
||||||
for iid in dead:
|
|
||||||
pipe.hdel(REGISTRY_KEY, iid)
|
|
||||||
pipe.execute()
|
|
||||||
logger.info("Cleaned dead instances: %s", dead)
|
|
||||||
|
|
||||||
alive.sort()
|
|
||||||
self._shard_count = max(len(alive), 1)
|
|
||||||
self._shard_index = (
|
|
||||||
alive.index(self.instance_id) if self.instance_id in alive else 0
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
"Shard: %s/%s (instance=%s, alive=%d)",
|
|
||||||
self._shard_index, self._shard_count,
|
|
||||||
self.instance_id, len(alive),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _is_mine(self, user_id: str) -> bool:
|
|
||||||
if self._shard_count <= 1:
|
|
||||||
return True
|
|
||||||
return stable_hash(user_id) % self._shard_count == self._shard_index
|
|
||||||
|
|
||||||
def _dispatch(self, msg_id, msg_data) -> bool:
|
|
||||||
user_id = msg_data["user_id"]
|
|
||||||
task_name = msg_data["task_name"]
|
|
||||||
params = json.loads(msg_data.get("params", "{}"))
|
|
||||||
|
|
||||||
lock_key = f"{task_name}:{user_id}"
|
|
||||||
dispatch_lock = f"dispatch:{msg_id}"
|
|
||||||
|
|
||||||
result = self.redis.eval(
|
|
||||||
LUA_ATOMIC_LOCK, 2,
|
|
||||||
dispatch_lock, lock_key,
|
|
||||||
self.instance_id, str(300), str(3600),
|
|
||||||
)
|
|
||||||
|
|
||||||
if result == 0:
|
|
||||||
return False
|
|
||||||
if result == -1:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
task = celery_app.send_task(task_name, kwargs=params)
|
|
||||||
except Exception as e:
|
|
||||||
pipe = self.redis.pipeline()
|
|
||||||
pipe.delete(dispatch_lock)
|
|
||||||
pipe.delete(lock_key)
|
|
||||||
pipe.execute()
|
|
||||||
self.errors += 1
|
|
||||||
logger.error(
|
|
||||||
"send_task failed for %s:%s msg=%s: %s",
|
|
||||||
task_name, user_id, msg_id, e, exc_info=True,
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
pipe = self.redis.pipeline()
|
|
||||||
pipe.set(lock_key, task.id, ex=3600)
|
|
||||||
pipe.hset(PENDING_HASH, task.id, json.dumps({
|
|
||||||
"lock_key": lock_key,
|
|
||||||
"dispatched_at": time.time(),
|
|
||||||
"msg_id": msg_id,
|
|
||||||
}))
|
|
||||||
pipe.delete(dispatch_lock)
|
|
||||||
pipe.set(
|
|
||||||
f"task_tracker:{msg_id}",
|
|
||||||
json.dumps({"status": "DISPATCHED", "task_id": task.id}),
|
|
||||||
ex=86400,
|
|
||||||
)
|
|
||||||
pipe.execute()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
"Post-dispatch state update failed for %s: %s",
|
|
||||||
task.id, e, exc_info=True,
|
|
||||||
)
|
|
||||||
self.errors += 1
|
|
||||||
|
|
||||||
self.dispatched += 1
|
|
||||||
logger.info("Task dispatched: %s (msg=%s)", task.id, msg_id)
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _process_batch(self, user_ids):
|
|
||||||
if not user_ids:
|
|
||||||
return
|
|
||||||
|
|
||||||
pipe = self.redis.pipeline()
|
|
||||||
for uid in user_ids:
|
|
||||||
pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0)
|
|
||||||
heads = pipe.execute()
|
|
||||||
|
|
||||||
candidates = [] # (user_id, msg_dict)
|
|
||||||
empty_users = []
|
|
||||||
|
|
||||||
for uid, head in zip(user_ids, heads):
|
|
||||||
if head is None:
|
|
||||||
empty_users.append(uid)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
candidates.append((uid, json.loads(head)))
|
|
||||||
except (json.JSONDecodeError, TypeError) as e:
|
|
||||||
logger.error("Bad message in queue for user %s: %s", uid, e)
|
|
||||||
self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}")
|
|
||||||
|
|
||||||
if empty_users:
|
|
||||||
pipe = self.redis.pipeline()
|
|
||||||
for uid in empty_users:
|
|
||||||
pipe.srem(ACTIVE_USERS, uid)
|
|
||||||
pipe.execute()
|
|
||||||
|
|
||||||
if not candidates:
|
|
||||||
return
|
|
||||||
|
|
||||||
for uid, msg in candidates:
|
|
||||||
if self._dispatch(msg["msg_id"], msg):
|
|
||||||
self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}")
|
|
||||||
|
|
||||||
def schedule_loop(self):
|
|
||||||
self._heartbeat()
|
|
||||||
self._cleanup_finished()
|
|
||||||
|
|
||||||
pipe = self.redis.pipeline()
|
|
||||||
pipe.smembers(READY_SET)
|
|
||||||
pipe.delete(READY_SET)
|
|
||||||
results = pipe.execute()
|
|
||||||
ready_users = results[0] or set()
|
|
||||||
|
|
||||||
my_users = [uid for uid in ready_users if self._is_mine(uid)]
|
|
||||||
|
|
||||||
if not my_users:
|
|
||||||
time.sleep(0.5)
|
|
||||||
return
|
|
||||||
|
|
||||||
self._process_batch(my_users)
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
def _full_scan(self):
|
|
||||||
cursor = 0
|
|
||||||
ready_batch = []
|
|
||||||
while True:
|
|
||||||
cursor, user_ids = self.redis.sscan(
|
|
||||||
ACTIVE_USERS, cursor=cursor, count=1000,
|
|
||||||
)
|
|
||||||
if user_ids:
|
|
||||||
my_users = [uid for uid in user_ids if self._is_mine(uid)]
|
|
||||||
if my_users:
|
|
||||||
pipe = self.redis.pipeline()
|
|
||||||
for uid in my_users:
|
|
||||||
pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0)
|
|
||||||
heads = pipe.execute()
|
|
||||||
|
|
||||||
for uid, head in zip(my_users, heads):
|
|
||||||
if head is None:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
msg = json.loads(head)
|
|
||||||
lock_key = f"{msg['task_name']}:{uid}"
|
|
||||||
ready_batch.append((uid, lock_key))
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if cursor == 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not ready_batch:
|
|
||||||
return
|
|
||||||
|
|
||||||
pipe = self.redis.pipeline()
|
|
||||||
for _, lock_key in ready_batch:
|
|
||||||
pipe.exists(lock_key)
|
|
||||||
lock_exists = pipe.execute()
|
|
||||||
|
|
||||||
ready_uids = [
|
|
||||||
uid for (uid, _), locked in zip(ready_batch, lock_exists)
|
|
||||||
if not locked
|
|
||||||
]
|
|
||||||
|
|
||||||
if ready_uids:
|
|
||||||
self.redis.sadd(READY_SET, *ready_uids)
|
|
||||||
logger.info("Full scan found %d ready users", len(ready_uids))
|
|
||||||
|
|
||||||
def run_server(self):
|
|
||||||
health_check_server(self)
|
|
||||||
self.running = True
|
|
||||||
|
|
||||||
last_full_scan = 0.0
|
|
||||||
full_scan_interval = 30.0
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Scheduler started: instance=%s", self.instance_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
self.schedule_loop()
|
|
||||||
|
|
||||||
now = time.time()
|
|
||||||
if now - last_full_scan > full_scan_interval:
|
|
||||||
self._full_scan()
|
|
||||||
last_full_scan = now
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Scheduler exception %s", e, exc_info=True)
|
|
||||||
self.errors += 1
|
|
||||||
time.sleep(5)
|
|
||||||
|
|
||||||
def health(self) -> dict:
|
|
||||||
return {
|
|
||||||
"running": self.running,
|
|
||||||
"active_users": self.redis.scard(ACTIVE_USERS),
|
|
||||||
"ready_users": self.redis.scard(READY_SET),
|
|
||||||
"pending_tasks": self.redis.hlen(PENDING_HASH),
|
|
||||||
"dispatched": self.dispatched,
|
|
||||||
"errors": self.errors,
|
|
||||||
"shard": f"{self._shard_index}/{self._shard_count}",
|
|
||||||
"instance": self.instance_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
def shutdown(self):
|
|
||||||
logger.info("Scheduler shutting down: instance=%s", self.instance_id)
|
|
||||||
self.running = False
|
|
||||||
try:
|
|
||||||
self.redis.hdel(REGISTRY_KEY, self.instance_id)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Shutdown cleanup error: %s", e)
|
|
||||||
|
|
||||||
|
|
||||||
scheduler: RedisTaskScheduler | None = None
|
|
||||||
if scheduler is None:
|
|
||||||
scheduler = RedisTaskScheduler()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import signal
|
|
||||||
import sys
|
|
||||||
|
|
||||||
|
|
||||||
def _signal_handler(signum, frame):
|
|
||||||
scheduler.shutdown()
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
|
|
||||||
signal.signal(signal.SIGTERM, _signal_handler)
|
|
||||||
signal.signal(signal.SIGINT, _signal_handler)
|
|
||||||
|
|
||||||
scheduler.run_server()
|
|
||||||
@@ -2,9 +2,6 @@
|
|||||||
Celery Worker 入口点
|
Celery Worker 入口点
|
||||||
用于启动 Celery Worker: celery -A app.celery_worker worker --loglevel=info
|
用于启动 Celery Worker: celery -A app.celery_worker worker --loglevel=info
|
||||||
"""
|
"""
|
||||||
# 必须在导入任何使用 DashScope SDK 的模块之前应用补丁
|
|
||||||
import app.plugins.dashscope_patch # noqa: F401
|
|
||||||
|
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
from app.core.logging_config import LoggingConfig, get_logger
|
from app.core.logging_config import LoggingConfig, get_logger
|
||||||
|
|
||||||
@@ -16,39 +13,4 @@ logger.info("Celery worker logging initialized")
|
|||||||
# 导入任务模块以注册任务
|
# 导入任务模块以注册任务
|
||||||
import app.tasks
|
import app.tasks
|
||||||
|
|
||||||
|
|
||||||
@worker_process_init.connect
|
|
||||||
def _reinit_db_pool(**kwargs):
|
|
||||||
"""
|
|
||||||
prefork 子进程启动时重建被 fork 污染的资源。
|
|
||||||
|
|
||||||
fork() 后子进程继承了父进程的:
|
|
||||||
1. SQLAlchemy 连接池 — 多进程共享 TCP socket 导致 DB 连接损坏
|
|
||||||
2. ThreadPoolExecutor — fork 后线程状态不确定,第二个任务会死锁
|
|
||||||
"""
|
|
||||||
# 重建 DB 连接池
|
|
||||||
from app.db import engine
|
|
||||||
engine.dispose()
|
|
||||||
logger.info("DB connection pool disposed for forked worker process")
|
|
||||||
|
|
||||||
# 重建模块级 ThreadPoolExecutor(fork 后线程池不可用)
|
|
||||||
try:
|
|
||||||
from app.core.rag.deepdoc.parser import figure_parser
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
figure_parser.shared_executor = ThreadPoolExecutor(max_workers=10)
|
|
||||||
logger.info("figure_parser.shared_executor recreated")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to recreate figure_parser.shared_executor: {e}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
from app.core.rag.utils import libre_office
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
import os
|
|
||||||
max_workers = os.cpu_count() * 2 if os.cpu_count() else 4
|
|
||||||
libre_office.executor = ThreadPoolExecutor(max_workers=max_workers)
|
|
||||||
logger.info("libre_office.executor recreated")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to recreate libre_office.executor: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['celery_app']
|
__all__ = ['celery_app']
|
||||||
@@ -1 +0,0 @@
|
|||||||
"""Configuration module for application settings."""
|
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
"""
|
|
||||||
社区版默认免费套餐配置
|
|
||||||
当无法从 SaaS 版获取 premium 模块时,使用此配置作为兜底
|
|
||||||
|
|
||||||
可通过环境变量覆盖配额配置,格式:QUOTA_<QUOTA_NAME>
|
|
||||||
例如:QUOTA_END_USER_QUOTA=100
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
def _get_quota_from_env():
|
|
||||||
"""从环境变量获取配额配置"""
|
|
||||||
quota_keys = [
|
|
||||||
"workspace_quota",
|
|
||||||
"skill_quota",
|
|
||||||
"app_quota",
|
|
||||||
"knowledge_capacity_quota",
|
|
||||||
"memory_engine_quota",
|
|
||||||
"end_user_quota",
|
|
||||||
"ontology_project_quota",
|
|
||||||
"model_quota",
|
|
||||||
"api_ops_rate_limit",
|
|
||||||
]
|
|
||||||
quotas = {}
|
|
||||||
for key in quota_keys:
|
|
||||||
env_key = f"QUOTA_{key.upper()}"
|
|
||||||
env_value = os.getenv(env_key)
|
|
||||||
if env_value is not None:
|
|
||||||
try:
|
|
||||||
quotas[key] = float(env_value) if '.' in env_value else int(env_value)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
return quotas
|
|
||||||
|
|
||||||
|
|
||||||
def _build_default_free_plan():
|
|
||||||
"""构建默认免费套餐配置"""
|
|
||||||
base = {
|
|
||||||
"name": "记忆体验版",
|
|
||||||
"name_en": "Memory Experience",
|
|
||||||
"category": "saas_personal",
|
|
||||||
"tier_level": 0,
|
|
||||||
"version": "1.0",
|
|
||||||
"status": True,
|
|
||||||
"price": 0,
|
|
||||||
"billing_cycle": "permanent_free",
|
|
||||||
"core_value": "感受永久记忆",
|
|
||||||
"core_value_en": "Experience Permanent Memory",
|
|
||||||
"tech_support": "社群交流",
|
|
||||||
"tech_support_en": "Community Support",
|
|
||||||
"sla_compliance": "无",
|
|
||||||
"sla_compliance_en": "None",
|
|
||||||
"page_customization": "无",
|
|
||||||
"page_customization_en": "None",
|
|
||||||
"theme_color": "#64748B",
|
|
||||||
"quotas": {
|
|
||||||
"workspace_quota": 1,
|
|
||||||
"skill_quota": 5,
|
|
||||||
"app_quota": 2,
|
|
||||||
"knowledge_capacity_quota": 0.3,
|
|
||||||
"memory_engine_quota": 1,
|
|
||||||
"end_user_quota": 10,
|
|
||||||
"ontology_project_quota": 3,
|
|
||||||
"model_quota": 1,
|
|
||||||
"api_ops_rate_limit": 50,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
env_quotas = _get_quota_from_env()
|
|
||||||
if env_quotas:
|
|
||||||
base["quotas"].update(env_quotas)
|
|
||||||
|
|
||||||
return base
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_FREE_PLAN = _build_default_free_plan()
|
|
||||||
@@ -1,239 +0,0 @@
|
|||||||
"""默认本体场景配置
|
|
||||||
|
|
||||||
本模块定义系统预设的本体场景和实体类型配置。
|
|
||||||
这些配置用于在工作空间创建时自动初始化默认场景。
|
|
||||||
支持中英文双语配置,根据用户语言偏好创建对应语言的场景。
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 在线教育场景配置
|
|
||||||
ONLINE_EDUCATION_SCENE = {
|
|
||||||
"name_chinese": "在线教育",
|
|
||||||
"name_english": "Online Education",
|
|
||||||
"description_chinese": "适用于在线教育平台的本体建模,包含学生、教师、课程等核心实体类型",
|
|
||||||
"description_english": "Ontology modeling for online education platforms, including core entity types such as students, teachers, and courses",
|
|
||||||
"types": [
|
|
||||||
{
|
|
||||||
"name_chinese": "学生",
|
|
||||||
"name_english": "Student",
|
|
||||||
"description_chinese": "在教育系统中接受教育的个体,包含姓名、学号、年级、班级等属性",
|
|
||||||
"description_english": "Individuals receiving education in the education system, including attributes such as name, student ID, grade, and class"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "教师",
|
|
||||||
"name_english": "Teacher",
|
|
||||||
"description_chinese": "在教育系统中提供教学服务的个体,包含姓名、工号、任教学科、职称等属性",
|
|
||||||
"description_english": "Individuals providing teaching services in the education system, including attributes such as name, employee ID, teaching subject, and title"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "课程",
|
|
||||||
"name_english": "Course",
|
|
||||||
"description_chinese": "教育系统中的教学内容单元,包含课程名称、课程代码、学分、学时等属性",
|
|
||||||
"description_english": "Teaching content units in the education system, including attributes such as course name, course code, credits, and class hours"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "作业",
|
|
||||||
"name_english": "Assignment",
|
|
||||||
"description_chinese": "课程中布置的学习任务,包含作业标题、截止日期、所属课程、提交状态等属性",
|
|
||||||
"description_english": "Learning tasks assigned in courses, including attributes such as assignment title, deadline, course, and submission status"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "成绩",
|
|
||||||
"name_english": "Grade",
|
|
||||||
"description_chinese": "学生学习成果的评价结果,包含分数、评级、考试类型、所属课程等属性",
|
|
||||||
"description_english": "Evaluation results of student learning outcomes, including attributes such as score, rating, exam type, and course"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "考试",
|
|
||||||
"name_english": "Exam",
|
|
||||||
"description_chinese": "评估学生学习成果的测试活动,包含考试名称、时间、地点、科目等属性",
|
|
||||||
"description_english": "Test activities to assess student learning outcomes, including attributes such as exam name, time, location, and subject"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "教室",
|
|
||||||
"name_english": "Classroom",
|
|
||||||
"description_chinese": "进行教学活动的物理或虚拟空间,包含教室编号、容量、设备等属性",
|
|
||||||
"description_english": "Physical or virtual spaces for teaching activities, including attributes such as classroom number, capacity, and equipment"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "学科",
|
|
||||||
"name_english": "Subject",
|
|
||||||
"description_chinese": "知识的分类领域,包含学科名称、代码、所属院系等属性",
|
|
||||||
"description_english": "Classification domains of knowledge, including attributes such as subject name, code, and department"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "教材",
|
|
||||||
"name_english": "Textbook",
|
|
||||||
"description_chinese": "教学使用的书籍或资料,包含书名、作者、出版社、ISBN等属性",
|
|
||||||
"description_english": "Books or materials used for teaching, including attributes such as title, author, publisher, and ISBN"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "班级",
|
|
||||||
"name_english": "Class",
|
|
||||||
"description_chinese": "学生的组织单位,包含班级名称、年级、人数、班主任等属性",
|
|
||||||
"description_english": "Organizational units of students, including attributes such as class name, grade, number of students, and class teacher"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "学期",
|
|
||||||
"name_english": "Semester",
|
|
||||||
"description_chinese": "教学时间的划分单位,包含学期名称、开始时间、结束时间等属性",
|
|
||||||
"description_english": "Time division units for teaching, including attributes such as semester name, start time, and end time"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "课时",
|
|
||||||
"name_english": "Class Hour",
|
|
||||||
"description_chinese": "课程的时间单位,包含上课时间、地点、教师、课程等属性",
|
|
||||||
"description_english": "Time units of courses, including attributes such as class time, location, teacher, and course"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "教学计划",
|
|
||||||
"name_english": "Teaching Plan",
|
|
||||||
"description_chinese": "课程的教学安排,包含教学目标、内容安排、进度计划等属性",
|
|
||||||
"description_english": "Teaching arrangements for courses, including attributes such as teaching objectives, content arrangement, and progress plan"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
# 情感陪伴场景配置
|
|
||||||
EMOTIONAL_COMPANION_SCENE = {
|
|
||||||
"name_chinese": "情感陪伴",
|
|
||||||
"name_english": "Emotional Companion",
|
|
||||||
"description_chinese": "适用于情感陪伴应用的本体建模,包含用户、情绪、活动等核心实体类型",
|
|
||||||
"description_english": "Ontology modeling for emotional companion applications, including core entity types such as users, emotions, and activities",
|
|
||||||
"types": [
|
|
||||||
{
|
|
||||||
"name_chinese": "用户",
|
|
||||||
"name_english": "User",
|
|
||||||
"description_chinese": "使用情感陪伴服务的个体,包含姓名、昵称、性格特征、偏好等属性",
|
|
||||||
"description_english": "Individuals using emotional companion services, including attributes such as name, nickname, personality traits, and preferences"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "情绪",
|
|
||||||
"name_english": "Emotion",
|
|
||||||
"description_chinese": "用户的情感状态,包含情绪类型、强度、触发原因、持续时间等属性",
|
|
||||||
"description_english": "Emotional states of users, including attributes such as emotion type, intensity, trigger cause, and duration"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "活动",
|
|
||||||
"name_english": "Activity",
|
|
||||||
"description_chinese": "用户参与的各类活动,包含活动名称、类型、参与者、时间地点等属性",
|
|
||||||
"description_english": "Various activities users participate in, including attributes such as activity name, type, participants, time, and location"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "对话",
|
|
||||||
"name_english": "Conversation",
|
|
||||||
"description_chinese": "用户之间的交流记录,包含对话主题、参与者、时间、关键内容等属性",
|
|
||||||
"description_english": "Communication records between users, including attributes such as conversation topic, participants, time, and key content"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "兴趣爱好",
|
|
||||||
"name_english": "Hobby",
|
|
||||||
"description_chinese": "用户的兴趣和爱好,包含爱好名称、类别、熟练程度、相关活动等属性",
|
|
||||||
"description_english": "User interests and hobbies, including attributes such as hobby name, category, proficiency level, and related activities"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "日常事件",
|
|
||||||
"name_english": "Daily Event",
|
|
||||||
"description_chinese": "用户日常生活中的事件,包含事件描述、时间、地点、相关人物等属性",
|
|
||||||
"description_english": "Events in users' daily lives, including attributes such as event description, time, location, and related people"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "关系",
|
|
||||||
"name_english": "Relationship",
|
|
||||||
"description_chinese": "用户之间的社会关系,包含关系类型、亲密度、建立时间等属性",
|
|
||||||
"description_english": "Social relationships between users, including attributes such as relationship type, intimacy, and establishment time"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "回忆",
|
|
||||||
"name_english": "Memory",
|
|
||||||
"description_chinese": "用户的重要记忆片段,包含回忆内容、时间、地点、相关人物等属性",
|
|
||||||
"description_english": "Important memory fragments of users, including attributes such as memory content, time, location, and related people"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "地点",
|
|
||||||
"name_english": "Location",
|
|
||||||
"description_chinese": "用户活动的地理位置,包含地点名称、地址、类型、相关事件等属性",
|
|
||||||
"description_english": "Geographic locations of user activities, including attributes such as location name, address, type, and related events"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "时间节点",
|
|
||||||
"name_english": "Time Point",
|
|
||||||
"description_chinese": "重要的时间标记,包含日期、事件、意义等属性",
|
|
||||||
"description_english": "Important time markers, including attributes such as date, event, and significance"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "目标",
|
|
||||||
"name_english": "Goal",
|
|
||||||
"description_chinese": "用户设定的目标,包含目标描述、截止时间、完成状态、相关活动等属性",
|
|
||||||
"description_english": "Goals set by users, including attributes such as goal description, deadline, completion status, and related activities"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name_chinese": "成就",
|
|
||||||
"name_english": "Achievement",
|
|
||||||
"description_chinese": "用户获得的成就,包含成就名称、获得时间、描述、相关目标等属性",
|
|
||||||
"description_english": "Achievements obtained by users, including attributes such as achievement name, acquisition time, description, and related goals"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
# 导出默认场景列表
|
|
||||||
DEFAULT_SCENES = [ONLINE_EDUCATION_SCENE, EMOTIONAL_COMPANION_SCENE]
|
|
||||||
|
|
||||||
|
|
||||||
def get_scene_name(scene_config: dict, language: str = "zh") -> str:
|
|
||||||
"""获取场景名称(根据语言)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scene_config: 场景配置字典
|
|
||||||
language: 语言类型 ("zh" 或 "en")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
对应语言的场景名称
|
|
||||||
"""
|
|
||||||
if language == "en":
|
|
||||||
return scene_config.get("name_english", scene_config.get("name_chinese"))
|
|
||||||
return scene_config.get("name_chinese")
|
|
||||||
|
|
||||||
|
|
||||||
def get_scene_description(scene_config: dict, language: str = "zh") -> str:
|
|
||||||
"""获取场景描述(根据语言)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scene_config: 场景配置字典
|
|
||||||
language: 语言类型 ("zh" 或 "en")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
对应语言的场景描述
|
|
||||||
"""
|
|
||||||
if language == "en":
|
|
||||||
return scene_config.get("description_english", scene_config.get("description_chinese"))
|
|
||||||
return scene_config.get("description_chinese")
|
|
||||||
|
|
||||||
|
|
||||||
def get_type_name(type_config: dict, language: str = "zh") -> str:
|
|
||||||
"""获取类型名称(根据语言)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
type_config: 类型配置字典
|
|
||||||
language: 语言类型 ("zh" 或 "en")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
对应语言的类型名称
|
|
||||||
"""
|
|
||||||
if language == "en":
|
|
||||||
return type_config.get("name_english", type_config.get("name_chinese"))
|
|
||||||
return type_config.get("name_chinese")
|
|
||||||
|
|
||||||
|
|
||||||
def get_type_description(type_config: dict, language: str = "zh") -> str:
|
|
||||||
"""获取类型描述(根据语言)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
type_config: 类型配置字典
|
|
||||||
language: 语言类型 ("zh" 或 "en")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
对应语言的类型描述
|
|
||||||
"""
|
|
||||||
if language == "en":
|
|
||||||
return type_config.get("description_english", type_config.get("description_chinese"))
|
|
||||||
return type_config.get("description_chinese")
|
|
||||||
@@ -1,249 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""默认本体场景初始化器
|
|
||||||
|
|
||||||
本模块提供默认本体场景和类型的自动初始化功能。
|
|
||||||
在工作空间创建时,自动添加预设的本体场景和实体类型。
|
|
||||||
|
|
||||||
Classes:
|
|
||||||
DefaultOntologyInitializer: 默认本体场景初始化器
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.config.default_ontology_config import (
|
|
||||||
DEFAULT_SCENES,
|
|
||||||
get_scene_name,
|
|
||||||
get_scene_description,
|
|
||||||
get_type_name,
|
|
||||||
get_type_description,
|
|
||||||
)
|
|
||||||
from app.core.logging_config import get_business_logger
|
|
||||||
from app.repositories.ontology_scene_repository import OntologySceneRepository
|
|
||||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
|
||||||
|
|
||||||
|
|
||||||
class DefaultOntologyInitializer:
|
|
||||||
"""默认本体场景初始化器
|
|
||||||
|
|
||||||
负责在工作空间创建时自动初始化默认的本体场景和类型。
|
|
||||||
遵循最小侵入原则,确保初始化失败不阻止工作空间创建。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
db: 数据库会话
|
|
||||||
scene_repo: 场景Repository
|
|
||||||
class_repo: 类型Repository
|
|
||||||
logger: 业务日志记录器
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, db: Session):
|
|
||||||
"""初始化
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: 数据库会话
|
|
||||||
"""
|
|
||||||
self.db = db
|
|
||||||
self.scene_repo = OntologySceneRepository(db)
|
|
||||||
self.class_repo = OntologyClassRepository(db)
|
|
||||||
self.logger = get_business_logger()
|
|
||||||
|
|
||||||
def initialize_default_scenes(
|
|
||||||
self,
|
|
||||||
workspace_id: UUID,
|
|
||||||
language: str = "zh"
|
|
||||||
) -> Tuple[bool, str]:
|
|
||||||
"""为工作空间初始化默认场景
|
|
||||||
|
|
||||||
创建两个默认场景(在线教育、情感陪伴)及其对应的实体类型。
|
|
||||||
如果创建失败,记录错误日志但不抛出异常。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
workspace_id: 工作空间ID
|
|
||||||
language: 语言类型 ("zh" 或 "en"),默认为 "zh"
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[bool, str]: (是否成功, 错误信息)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
self.logger.info(
|
|
||||||
f"开始初始化默认本体场景 - workspace_id={workspace_id}, language={language}"
|
|
||||||
)
|
|
||||||
|
|
||||||
scenes_created = 0
|
|
||||||
total_types_created = 0
|
|
||||||
|
|
||||||
# 遍历默认场景配置
|
|
||||||
for scene_config in DEFAULT_SCENES:
|
|
||||||
scene_name = get_scene_name(scene_config, language)
|
|
||||||
|
|
||||||
# 创建场景及其类型
|
|
||||||
scene_id = self._create_scene_with_types(workspace_id, scene_config, language)
|
|
||||||
|
|
||||||
if scene_id:
|
|
||||||
scenes_created += 1
|
|
||||||
# 统计类型数量
|
|
||||||
types_count = len(scene_config.get("types", []))
|
|
||||||
total_types_created += types_count
|
|
||||||
|
|
||||||
self.logger.info(
|
|
||||||
f"场景创建成功 - scene_name={scene_name}, "
|
|
||||||
f"scene_id={scene_id}, types_count={types_count}, language={language}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.logger.warning(
|
|
||||||
f"场景创建失败 - scene_name={scene_name}, "
|
|
||||||
f"workspace_id={workspace_id}, language={language}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 记录总体结果
|
|
||||||
self.logger.info(
|
|
||||||
f"默认场景初始化完成 - workspace_id={workspace_id}, "
|
|
||||||
f"language={language}, scenes_created={scenes_created}, "
|
|
||||||
f"total_types_created={total_types_created}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 如果至少创建了一个场景,视为成功
|
|
||||||
if scenes_created > 0:
|
|
||||||
return True, ""
|
|
||||||
else:
|
|
||||||
error_msg = "所有默认场景创建失败"
|
|
||||||
self.logger.error(
|
|
||||||
f"默认场景初始化失败 - workspace_id={workspace_id}, "
|
|
||||||
f"language={language}, error={error_msg}"
|
|
||||||
)
|
|
||||||
return False, error_msg
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"默认场景初始化异常: {str(e)}"
|
|
||||||
self.logger.error(
|
|
||||||
f"默认场景初始化异常 - workspace_id={workspace_id}, "
|
|
||||||
f"language={language}, error={str(e)}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
return False, error_msg
|
|
||||||
|
|
||||||
def _create_scene_with_types(
|
|
||||||
self,
|
|
||||||
workspace_id: UUID,
|
|
||||||
scene_config: dict,
|
|
||||||
language: str = "zh"
|
|
||||||
) -> Optional[UUID]:
|
|
||||||
"""创建场景及其类型
|
|
||||||
|
|
||||||
Args:
|
|
||||||
workspace_id: 工作空间ID
|
|
||||||
scene_config: 场景配置字典
|
|
||||||
language: 语言类型 ("zh" 或 "en")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[UUID]: 创建的场景ID,失败返回None
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
scene_name = get_scene_name(scene_config, language)
|
|
||||||
scene_description = get_scene_description(scene_config, language)
|
|
||||||
|
|
||||||
# 检查是否已存在同名场景(支持向后兼容)
|
|
||||||
existing_scene = self.scene_repo.get_by_name(scene_name, workspace_id)
|
|
||||||
if existing_scene:
|
|
||||||
self.logger.info(
|
|
||||||
f"场景已存在,跳过创建 - scene_name={scene_name}, "
|
|
||||||
f"workspace_id={workspace_id}, scene_id={existing_scene.scene_id}, "
|
|
||||||
f"language={language}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 创建场景记录,设置 is_system_default=true
|
|
||||||
scene_data = {
|
|
||||||
"scene_name": scene_name,
|
|
||||||
"scene_description": scene_description
|
|
||||||
}
|
|
||||||
|
|
||||||
scene = self.scene_repo.create(scene_data, workspace_id)
|
|
||||||
|
|
||||||
# 设置系统默认标识
|
|
||||||
scene.is_system_default = True
|
|
||||||
self.db.flush()
|
|
||||||
|
|
||||||
self.logger.info(
|
|
||||||
f"场景创建成功 - scene_name={scene_name}, "
|
|
||||||
f"scene_id={scene.scene_id}, is_system_default=True, language={language}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 批量创建类型
|
|
||||||
types_config = scene_config.get("types", [])
|
|
||||||
types_created = self._batch_create_types(scene.scene_id, types_config, language)
|
|
||||||
|
|
||||||
self.logger.info(
|
|
||||||
f"场景类型创建完成 - scene_id={scene.scene_id}, "
|
|
||||||
f"types_created={types_created}/{len(types_config)}, language={language}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return scene.scene_id
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
scene_name = get_scene_name(scene_config, language)
|
|
||||||
self.logger.error(
|
|
||||||
f"场景创建失败 - scene_name={scene_name}, "
|
|
||||||
f"workspace_id={workspace_id}, language={language}, error={str(e)}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _batch_create_types(
|
|
||||||
self,
|
|
||||||
scene_id: UUID,
|
|
||||||
types_config: List[dict],
|
|
||||||
language: str = "zh"
|
|
||||||
) -> int:
|
|
||||||
"""批量创建实体类型
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scene_id: 场景ID
|
|
||||||
types_config: 类型配置列表
|
|
||||||
language: 语言类型 ("zh" 或 "en")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: 成功创建的类型数量
|
|
||||||
"""
|
|
||||||
created_count = 0
|
|
||||||
|
|
||||||
for type_config in types_config:
|
|
||||||
try:
|
|
||||||
type_name = get_type_name(type_config, language)
|
|
||||||
type_description = get_type_description(type_config, language)
|
|
||||||
|
|
||||||
# 创建类型数据
|
|
||||||
class_data = {
|
|
||||||
"class_name": type_name,
|
|
||||||
"class_description": type_description
|
|
||||||
}
|
|
||||||
|
|
||||||
# 创建类型
|
|
||||||
ontology_class = self.class_repo.create(class_data, scene_id)
|
|
||||||
|
|
||||||
# 设置系统默认标识
|
|
||||||
ontology_class.is_system_default = True
|
|
||||||
self.db.flush()
|
|
||||||
|
|
||||||
created_count += 1
|
|
||||||
|
|
||||||
self.logger.debug(
|
|
||||||
f"类型创建成功 - class_name={type_name}, "
|
|
||||||
f"class_id={ontology_class.class_id}, "
|
|
||||||
f"scene_id={scene_id}, is_system_default=True, language={language}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
type_name = get_type_name(type_config, language)
|
|
||||||
self.logger.warning(
|
|
||||||
f"单个类型创建失败,继续创建其他类型 - "
|
|
||||||
f"class_name={type_name}, scene_id={scene_id}, "
|
|
||||||
f"language={language}, error={str(e)}"
|
|
||||||
)
|
|
||||||
# 继续创建其他类型
|
|
||||||
continue
|
|
||||||
|
|
||||||
return created_count
|
|
||||||
@@ -8,7 +8,6 @@ from fastapi import APIRouter
|
|||||||
from . import (
|
from . import (
|
||||||
api_key_controller,
|
api_key_controller,
|
||||||
app_controller,
|
app_controller,
|
||||||
app_log_controller,
|
|
||||||
auth_controller,
|
auth_controller,
|
||||||
chunk_controller,
|
chunk_controller,
|
||||||
document_controller,
|
document_controller,
|
||||||
@@ -17,22 +16,17 @@ from . import (
|
|||||||
file_controller,
|
file_controller,
|
||||||
file_storage_controller,
|
file_storage_controller,
|
||||||
home_page_controller,
|
home_page_controller,
|
||||||
i18n_controller,
|
|
||||||
implicit_memory_controller,
|
implicit_memory_controller,
|
||||||
knowledge_controller,
|
knowledge_controller,
|
||||||
knowledgeshare_controller,
|
knowledgeshare_controller,
|
||||||
mcp_market_controller,
|
|
||||||
mcp_market_config_controller,
|
|
||||||
memory_agent_controller,
|
memory_agent_controller,
|
||||||
memory_dashboard_controller,
|
memory_dashboard_controller,
|
||||||
memory_episodic_controller,
|
memory_episodic_controller,
|
||||||
memory_explicit_controller,
|
memory_explicit_controller,
|
||||||
memory_forget_controller,
|
memory_forget_controller,
|
||||||
memory_perceptual_controller,
|
|
||||||
memory_reflection_controller,
|
memory_reflection_controller,
|
||||||
memory_short_term_controller,
|
memory_short_term_controller,
|
||||||
memory_storage_controller,
|
memory_storage_controller,
|
||||||
memory_working_controller,
|
|
||||||
model_controller,
|
model_controller,
|
||||||
multi_agent_controller,
|
multi_agent_controller,
|
||||||
prompt_optimizer_controller,
|
prompt_optimizer_controller,
|
||||||
@@ -45,10 +39,13 @@ from . import (
|
|||||||
upload_controller,
|
upload_controller,
|
||||||
user_controller,
|
user_controller,
|
||||||
user_memory_controllers,
|
user_memory_controllers,
|
||||||
|
workflow_controller,
|
||||||
workspace_controller,
|
workspace_controller,
|
||||||
|
memory_forget_controller,
|
||||||
|
home_page_controller,
|
||||||
|
memory_perceptual_controller,
|
||||||
|
memory_working_controller,
|
||||||
ontology_controller,
|
ontology_controller,
|
||||||
skill_controller,
|
|
||||||
tenant_subscription_controller,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建管理端 API 路由器
|
# 创建管理端 API 路由器
|
||||||
@@ -65,13 +62,10 @@ manager_router.include_router(model_controller.router)
|
|||||||
manager_router.include_router(file_controller.router)
|
manager_router.include_router(file_controller.router)
|
||||||
manager_router.include_router(document_controller.router)
|
manager_router.include_router(document_controller.router)
|
||||||
manager_router.include_router(knowledge_controller.router)
|
manager_router.include_router(knowledge_controller.router)
|
||||||
manager_router.include_router(mcp_market_controller.router)
|
|
||||||
manager_router.include_router(mcp_market_config_controller.router)
|
|
||||||
manager_router.include_router(chunk_controller.router)
|
manager_router.include_router(chunk_controller.router)
|
||||||
manager_router.include_router(test_controller.router)
|
manager_router.include_router(test_controller.router)
|
||||||
manager_router.include_router(knowledgeshare_controller.router)
|
manager_router.include_router(knowledgeshare_controller.router)
|
||||||
manager_router.include_router(app_controller.router)
|
manager_router.include_router(app_controller.router)
|
||||||
manager_router.include_router(app_log_controller.router)
|
|
||||||
manager_router.include_router(upload_controller.router)
|
manager_router.include_router(upload_controller.router)
|
||||||
manager_router.include_router(memory_agent_controller.router)
|
manager_router.include_router(memory_agent_controller.router)
|
||||||
manager_router.include_router(memory_dashboard_controller.router)
|
manager_router.include_router(memory_dashboard_controller.router)
|
||||||
@@ -84,6 +78,7 @@ manager_router.include_router(release_share_controller.router)
|
|||||||
manager_router.include_router(public_share_controller.router) # 公开路由(无需认证)
|
manager_router.include_router(public_share_controller.router) # 公开路由(无需认证)
|
||||||
manager_router.include_router(memory_dashboard_controller.router)
|
manager_router.include_router(memory_dashboard_controller.router)
|
||||||
manager_router.include_router(multi_agent_controller.router)
|
manager_router.include_router(multi_agent_controller.router)
|
||||||
|
manager_router.include_router(workflow_controller.router)
|
||||||
manager_router.include_router(emotion_controller.router)
|
manager_router.include_router(emotion_controller.router)
|
||||||
manager_router.include_router(emotion_config_controller.router)
|
manager_router.include_router(emotion_config_controller.router)
|
||||||
manager_router.include_router(prompt_optimizer_controller.router)
|
manager_router.include_router(prompt_optimizer_controller.router)
|
||||||
@@ -97,9 +92,5 @@ manager_router.include_router(memory_perceptual_controller.router)
|
|||||||
manager_router.include_router(memory_working_controller.router)
|
manager_router.include_router(memory_working_controller.router)
|
||||||
manager_router.include_router(file_storage_controller.router)
|
manager_router.include_router(file_storage_controller.router)
|
||||||
manager_router.include_router(ontology_controller.router)
|
manager_router.include_router(ontology_controller.router)
|
||||||
manager_router.include_router(skill_controller.router)
|
|
||||||
manager_router.include_router(i18n_controller.router)
|
|
||||||
manager_router.include_router(tenant_subscription_controller.router)
|
|
||||||
manager_router.include_router(tenant_subscription_controller.public_router)
|
|
||||||
|
|
||||||
__all__ = ["manager_router"]
|
__all__ = ["manager_router"]
|
||||||
|
|||||||
@@ -167,8 +167,6 @@ def update_api_key(
|
|||||||
|
|
||||||
return success(data=api_key_schema.ApiKey.model_validate(api_key), msg="API Key 更新成功")
|
return success(data=api_key_schema.ApiKey.model_validate(api_key), msg="API Key 更新成功")
|
||||||
|
|
||||||
except BusinessException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"未知错误: {str(e)}", extra={
|
logger.error(f"未知错误: {str(e)}", extra={
|
||||||
"api_key_id": str(api_key_id),
|
"api_key_id": str(api_key_id),
|
||||||
|
|||||||
@@ -1,12 +1,9 @@
|
|||||||
import uuid
|
import uuid
|
||||||
import io
|
|
||||||
from typing import Optional, Annotated
|
from typing import Optional, Annotated
|
||||||
|
|
||||||
import yaml
|
from fastapi import APIRouter, Depends, Path
|
||||||
from fastapi import APIRouter, Depends, Path, Form, UploadFile, File
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from urllib.parse import quote
|
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
@@ -20,15 +17,11 @@ from app.repositories.end_user_repository import EndUserRepository
|
|||||||
from app.schemas import app_schema
|
from app.schemas import app_schema
|
||||||
from app.schemas.response_schema import PageData, PageMeta
|
from app.schemas.response_schema import PageData, PageMeta
|
||||||
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
|
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
|
||||||
from app.schemas.workflow_schema import WorkflowConfigUpdate, WorkflowImportSave
|
from app.schemas.workflow_schema import WorkflowConfigUpdate
|
||||||
from app.services import app_service, workspace_service
|
from app.services import app_service, workspace_service
|
||||||
from app.services.agent_config_helper import enrich_agent_config
|
from app.services.agent_config_helper import enrich_agent_config
|
||||||
from app.services.app_service import AppService
|
from app.services.app_service import AppService
|
||||||
from app.services.app_statistics_service import AppStatisticsService
|
|
||||||
from app.services.workflow_import_service import WorkflowImportService
|
|
||||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||||
from app.services.app_dsl_service import AppDslService
|
|
||||||
from app.core.quota_stub import check_app_quota
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
@@ -36,7 +29,6 @@ logger = get_business_logger()
|
|||||||
|
|
||||||
@router.post("", summary="创建应用(可选创建 Agent 配置)")
|
@router.post("", summary="创建应用(可选创建 Agent 配置)")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
@check_app_quota
|
|
||||||
def create_app(
|
def create_app(
|
||||||
payload: app_schema.AppCreate,
|
payload: app_schema.AppCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
@@ -55,7 +47,6 @@ def list_apps(
|
|||||||
status: str | None = None,
|
status: str | None = None,
|
||||||
search: str | None = None,
|
search: str | None = None,
|
||||||
include_shared: bool = True,
|
include_shared: bool = True,
|
||||||
shared_only: bool = False,
|
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
pagesize: int = 10,
|
pagesize: int = 10,
|
||||||
ids: Optional[str] = None,
|
ids: Optional[str] = None,
|
||||||
@@ -67,42 +58,16 @@ def list_apps(
|
|||||||
- 默认包含本工作空间的应用和分享给本工作空间的应用
|
- 默认包含本工作空间的应用和分享给本工作空间的应用
|
||||||
- 设置 include_shared=false 可以只查看本工作空间的应用
|
- 设置 include_shared=false 可以只查看本工作空间的应用
|
||||||
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页
|
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页
|
||||||
- search 参数支持:应用名称模糊搜索、API Key 精确搜索
|
|
||||||
"""
|
"""
|
||||||
from sqlalchemy import select as sa_select
|
|
||||||
from app.models.api_key_model import ApiKey
|
|
||||||
|
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
service = app_service.AppService(db)
|
service = app_service.AppService(db)
|
||||||
|
|
||||||
# 通过 search 参数搜索:支持应用名称模糊搜索和 API Key 精确搜索
|
# 当 ids 存在且不为 None 时,根据 ids 获取应用
|
||||||
if search:
|
|
||||||
search = search.strip()
|
|
||||||
# 尝试作为 API Key 精确匹配(API Key 通常较长)
|
|
||||||
if len(search) >= 10:
|
|
||||||
matched_id = db.execute(
|
|
||||||
sa_select(ApiKey.resource_id).where(
|
|
||||||
ApiKey.workspace_id == workspace_id,
|
|
||||||
ApiKey.api_key == search,
|
|
||||||
ApiKey.resource_id.isnot(None),
|
|
||||||
)
|
|
||||||
).scalar_one_or_none()
|
|
||||||
if matched_id:
|
|
||||||
# 找到 API Key,直接返回关联的应用
|
|
||||||
ids = str(matched_id)
|
|
||||||
|
|
||||||
# 当 ids 存在时,根据 ids 获取应用(不分页)
|
|
||||||
if ids is not None:
|
if ids is not None:
|
||||||
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
|
app_ids = [id.strip() for id in ids.split(',') if id.strip()]
|
||||||
if app_ids:
|
|
||||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||||
# 返回标准分页格式
|
return success(data=items)
|
||||||
meta = PageMeta(page=1, pagesize=len(items), total=len(items), hasnext=False)
|
|
||||||
return success(data=PageData(page=meta, items=items))
|
|
||||||
# ids 为空时,返回空列表
|
|
||||||
meta = PageMeta(page=1, pagesize=0, total=0, hasnext=False)
|
|
||||||
return success(data=PageData(page=meta, items=[]))
|
|
||||||
|
|
||||||
# 正常分页查询
|
# 正常分页查询
|
||||||
items_orm, total = app_service.list_apps(
|
items_orm, total = app_service.list_apps(
|
||||||
@@ -113,7 +78,6 @@ def list_apps(
|
|||||||
status=status,
|
status=status,
|
||||||
search=search,
|
search=search,
|
||||||
include_shared=include_shared,
|
include_shared=include_shared,
|
||||||
shared_only=shared_only,
|
|
||||||
page=page,
|
page=page,
|
||||||
pagesize=pagesize,
|
pagesize=pagesize,
|
||||||
)
|
)
|
||||||
@@ -123,37 +87,6 @@ def list_apps(
|
|||||||
return success(data=PageData(page=meta, items=items))
|
return success(data=PageData(page=meta, items=items))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/my-shared-out", summary="列出本工作空间主动分享出去的记录")
|
|
||||||
@cur_workspace_access_guard()
|
|
||||||
def list_my_shared_out(
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user=Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""列出本工作空间主动分享给其他工作空间的所有记录(我的共享)"""
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
service = app_service.AppService(db)
|
|
||||||
shares = service.list_my_shared_out(workspace_id=workspace_id)
|
|
||||||
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
|
||||||
return success(data=data)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/share/{target_workspace_id}", summary="取消对某工作空间的所有应用分享")
|
|
||||||
@cur_workspace_access_guard()
|
|
||||||
def unshare_all_apps_to_workspace(
|
|
||||||
target_workspace_id: uuid.UUID,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user=Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""Cancel all app shares from current workspace to a target workspace."""
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
service = app_service.AppService(db)
|
|
||||||
count = service.unshare_all_apps_to_workspace(
|
|
||||||
target_workspace_id=target_workspace_id,
|
|
||||||
workspace_id=workspace_id
|
|
||||||
)
|
|
||||||
return success(msg=f"已取消 {count} 个应用的分享", data={"count": count})
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{app_id}", summary="获取应用详情")
|
@router.get("/{app_id}", summary="获取应用详情")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def get_app(
|
def get_app(
|
||||||
@@ -219,11 +152,9 @@ def delete_app(
|
|||||||
|
|
||||||
@router.post("/{app_id}/copy", summary="复制应用")
|
@router.post("/{app_id}/copy", summary="复制应用")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
@check_app_quota
|
|
||||||
def copy_app(
|
def copy_app(
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
new_name: Optional[str] = None,
|
new_name: Optional[str] = None,
|
||||||
payload: app_schema.CopyAppRequest = None,
|
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user=Depends(get_current_user),
|
current_user=Depends(get_current_user),
|
||||||
):
|
):
|
||||||
@@ -235,8 +166,6 @@ def copy_app(
|
|||||||
- 不影响原应用
|
- 不影响原应用
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
# body takes precedence over query param for backward compatibility
|
|
||||||
new_name = (payload.new_name if payload else None) or new_name
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"用户请求复制应用",
|
"用户请求复制应用",
|
||||||
extra={
|
extra={
|
||||||
@@ -272,19 +201,6 @@ def update_agent_config(
|
|||||||
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{app_id}/model/parameters/default", summary="获取 Agent 模型参数默认配置")
|
|
||||||
@cur_workspace_access_guard()
|
|
||||||
def get_agent_model_parameters(
|
|
||||||
app_id: uuid.UUID,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user=Depends(get_current_user),
|
|
||||||
):
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
service = AppService(db)
|
|
||||||
model_parameters = service.get_default_model_parameters(app_id=app_id)
|
|
||||||
return success(data=model_parameters, msg="获取 Agent 模型参数默认配置")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{app_id}/config", summary="获取 Agent 配置")
|
@router.get("/{app_id}/config", summary="获取 Agent 配置")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def get_agent_config(
|
def get_agent_config(
|
||||||
@@ -299,36 +215,6 @@ def get_agent_config(
|
|||||||
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{app_id}/opening", summary="获取应用开场白配置")
|
|
||||||
@cur_workspace_access_guard()
|
|
||||||
def get_opening(
|
|
||||||
app_id: uuid.UUID,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user=Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""返回开场白文本和预设问题,供前端对话界面初始化时展示"""
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
|
|
||||||
# 根据应用类型获取 features
|
|
||||||
from app.models.app_model import App as AppModel
|
|
||||||
app = db.get(AppModel, app_id)
|
|
||||||
if app and app.type == "workflow":
|
|
||||||
cfg = app_service.get_workflow_config(db=db, app_id=app_id, workspace_id=workspace_id)
|
|
||||||
features = cfg.features or {}
|
|
||||||
else:
|
|
||||||
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
|
|
||||||
features = cfg.features or {}
|
|
||||||
if hasattr(features, "model_dump"):
|
|
||||||
features = features.model_dump()
|
|
||||||
|
|
||||||
opening = features.get("opening_statement", {})
|
|
||||||
return success(data=app_schema.OpeningResponse(
|
|
||||||
enabled=opening.get("enabled", False),
|
|
||||||
statement=opening.get("statement"),
|
|
||||||
suggested_questions=opening.get("suggested_questions", []),
|
|
||||||
))
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
|
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def publish_app(
|
def publish_app(
|
||||||
@@ -410,8 +296,7 @@ def share_app(
|
|||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
target_workspace_ids=payload.target_workspace_ids,
|
target_workspace_ids=payload.target_workspace_ids,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id
|
||||||
permission=payload.permission
|
|
||||||
)
|
)
|
||||||
|
|
||||||
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
||||||
@@ -442,32 +327,6 @@ def unshare_app(
|
|||||||
return success(msg="应用分享已取消")
|
return success(msg="应用分享已取消")
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/{app_id}/share/{target_workspace_id}", summary="更新共享权限")
|
|
||||||
@cur_workspace_access_guard()
|
|
||||||
def update_share_permission(
|
|
||||||
app_id: uuid.UUID,
|
|
||||||
target_workspace_id: uuid.UUID,
|
|
||||||
payload: app_schema.UpdateSharePermissionRequest,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user=Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""更新共享权限(readonly <-> editable)
|
|
||||||
|
|
||||||
- 只能修改自己工作空间应用的共享权限
|
|
||||||
"""
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
|
|
||||||
service = app_service.AppService(db)
|
|
||||||
share = service.update_share_permission(
|
|
||||||
app_id=app_id,
|
|
||||||
target_workspace_id=target_workspace_id,
|
|
||||||
permission=payload.permission,
|
|
||||||
workspace_id=workspace_id
|
|
||||||
)
|
|
||||||
|
|
||||||
return success(data=app_schema.AppShare.model_validate(share))
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{app_id}/shares", summary="列出应用的分享记录")
|
@router.get("/{app_id}/shares", summary="列出应用的分享记录")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def list_app_shares(
|
def list_app_shares(
|
||||||
@@ -491,46 +350,6 @@ def list_app_shares(
|
|||||||
return success(data=data)
|
return success(data=data)
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/shared/{source_workspace_id}", summary="批量移除某来源工作空间的所有共享应用")
|
|
||||||
@cur_workspace_access_guard()
|
|
||||||
def remove_all_shared_apps_from_workspace(
|
|
||||||
source_workspace_id: uuid.UUID,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user=Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""Remove all shared apps from a specific source workspace (recipient operation)."""
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
service = app_service.AppService(db)
|
|
||||||
count = service.remove_all_shared_apps_from_workspace(
|
|
||||||
source_workspace_id=source_workspace_id,
|
|
||||||
workspace_id=workspace_id
|
|
||||||
)
|
|
||||||
return success(msg=f"已移除 {count} 个共享应用", data={"count": count})
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{app_id}/shared", summary="移除共享给我的应用")
|
|
||||||
@cur_workspace_access_guard()
|
|
||||||
def remove_shared_app(
|
|
||||||
app_id: uuid.UUID,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user=Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""被共享者从自己的工作空间移除共享应用
|
|
||||||
|
|
||||||
- 不会删除源应用,只删除共享记录
|
|
||||||
- 只能移除共享给自己工作空间的应用
|
|
||||||
"""
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
|
|
||||||
service = app_service.AppService(db)
|
|
||||||
service.remove_shared_app(
|
|
||||||
app_id=app_id,
|
|
||||||
workspace_id=workspace_id
|
|
||||||
)
|
|
||||||
|
|
||||||
return success(msg="已移除共享应用")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{app_id}/draft/run", summary="试运行 Agent(使用当前草稿配置)")
|
@router.post("/{app_id}/draft/run", summary="试运行 Agent(使用当前草稿配置)")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def draft_run(
|
async def draft_run(
|
||||||
@@ -571,13 +390,13 @@ async def draft_run(
|
|||||||
# 提前验证和准备(在流式响应开始前完成)
|
# 提前验证和准备(在流式响应开始前完成)
|
||||||
from app.services.app_service import AppService
|
from app.services.app_service import AppService
|
||||||
from app.services.multi_agent_service import MultiAgentService
|
from app.services.multi_agent_service import MultiAgentService
|
||||||
from app.models import AgentConfig, ModelConfig, AppRelease
|
from app.models import AgentConfig, ModelConfig
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.services.draft_run_service import AgentRunService
|
from app.services.draft_run_service import DraftRunService
|
||||||
|
|
||||||
service = AppService(db)
|
service = AppService(db)
|
||||||
draft_service = AgentRunService(db)
|
draft_service = DraftRunService(db)
|
||||||
|
|
||||||
# 1. 验证应用
|
# 1. 验证应用
|
||||||
app = service._get_app_or_404(app_id)
|
app = service._get_app_or_404(app_id)
|
||||||
@@ -588,12 +407,11 @@ async def draft_run(
|
|||||||
service._validate_app_accessible(app, workspace_id)
|
service._validate_app_accessible(app, workspace_id)
|
||||||
|
|
||||||
if payload.user_id is None:
|
if payload.user_id is None:
|
||||||
# 先获取 app 的 workspace_id
|
|
||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
workspace_id=app.workspace_id,
|
|
||||||
other_id=str(current_user.id),
|
other_id=str(current_user.id),
|
||||||
|
original_user_id=str(current_user.id) # Save original user_id to other_id
|
||||||
)
|
)
|
||||||
payload.user_id = str(new_end_user.id)
|
payload.user_id = str(new_end_user.id)
|
||||||
|
|
||||||
@@ -610,17 +428,6 @@ async def draft_run(
|
|||||||
service._check_agent_config(app_id)
|
service._check_agent_config(app_id)
|
||||||
|
|
||||||
# 2. 获取 Agent 配置
|
# 2. 获取 Agent 配置
|
||||||
# 共享应用:从最新发布版本读配置快照,而非草稿
|
|
||||||
is_shared = app.workspace_id != workspace_id
|
|
||||||
if is_shared:
|
|
||||||
if not app.current_release_id:
|
|
||||||
raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING)
|
|
||||||
release = db.get(AppRelease, app.current_release_id)
|
|
||||||
if not release:
|
|
||||||
raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING)
|
|
||||||
agent_cfg = service._agent_config_from_release(release)
|
|
||||||
model_config = db.get(ModelConfig, release.default_model_config_id) if release.default_model_config_id else None
|
|
||||||
else:
|
|
||||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
||||||
agent_cfg = db.scalars(stmt).first()
|
agent_cfg = db.scalars(stmt).first()
|
||||||
if not agent_cfg:
|
if not agent_cfg:
|
||||||
@@ -647,8 +454,7 @@ async def draft_run(
|
|||||||
user_id=payload.user_id or str(current_user.id),
|
user_id=payload.user_id or str(current_user.id),
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id
|
||||||
files=payload.files # 传递多模态文件
|
|
||||||
):
|
):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
@@ -669,13 +475,12 @@ async def draft_run(
|
|||||||
"app_id": str(app_id),
|
"app_id": str(app_id),
|
||||||
"message_length": len(payload.message),
|
"message_length": len(payload.message),
|
||||||
"has_conversation_id": bool(payload.conversation_id),
|
"has_conversation_id": bool(payload.conversation_id),
|
||||||
"has_variables": bool(payload.variables),
|
"has_variables": bool(payload.variables)
|
||||||
"has_files": bool(payload.files)
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
from app.services.draft_run_service import AgentRunService
|
from app.services.draft_run_service import DraftRunService
|
||||||
draft_service = AgentRunService(db)
|
draft_service = DraftRunService(db)
|
||||||
result = await draft_service.run(
|
result = await draft_service.run(
|
||||||
agent_config=agent_cfg,
|
agent_config=agent_cfg,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
@@ -685,8 +490,7 @@ async def draft_run(
|
|||||||
user_id=payload.user_id or str(current_user.id),
|
user_id=payload.user_id or str(current_user.id),
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id
|
||||||
files=payload.files # 传递多模态文件
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -788,16 +592,6 @@ async def draft_run(
|
|||||||
msg="多 Agent 任务执行成功"
|
msg="多 Agent 任务执行成功"
|
||||||
)
|
)
|
||||||
elif app.type == AppType.WORKFLOW: # 工作流
|
elif app.type == AppType.WORKFLOW: # 工作流
|
||||||
# 共享应用:从最新发布版本读配置快照,而非草稿
|
|
||||||
is_shared = app.workspace_id != workspace_id
|
|
||||||
if is_shared:
|
|
||||||
if not app.current_release_id:
|
|
||||||
raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING)
|
|
||||||
release = db.get(AppRelease, app.current_release_id)
|
|
||||||
if not release:
|
|
||||||
raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING)
|
|
||||||
config = service._workflow_config_from_release(release)
|
|
||||||
else:
|
|
||||||
config = workflow_service.check_config(app_id)
|
config = workflow_service.check_config(app_id)
|
||||||
# 3. 流式返回
|
# 3. 流式返回
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
@@ -941,16 +735,6 @@ async def draft_run_compare(
|
|||||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||||
service._validate_app_accessible(app, workspace_id)
|
service._validate_app_accessible(app, workspace_id)
|
||||||
|
|
||||||
if payload.user_id is None:
|
|
||||||
# 先获取 app 的 workspace_id
|
|
||||||
end_user_repo = EndUserRepository(db)
|
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
|
||||||
app_id=app_id,
|
|
||||||
workspace_id=app.workspace_id,
|
|
||||||
other_id=str(current_user.id),
|
|
||||||
)
|
|
||||||
payload.user_id = str(new_end_user.id)
|
|
||||||
|
|
||||||
# 2. 获取 Agent 配置
|
# 2. 获取 Agent 配置
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from app.models import AgentConfig
|
from app.models import AgentConfig
|
||||||
@@ -996,33 +780,25 @@ async def draft_run_compare(
|
|||||||
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
|
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
|
||||||
})
|
})
|
||||||
|
|
||||||
# 从 features 中读取功能开关(与 draft_run 保持一致)
|
|
||||||
features_config: dict = agent_cfg.features or {}
|
|
||||||
if hasattr(features_config, 'model_dump'):
|
|
||||||
features_config = features_config.model_dump()
|
|
||||||
web_search_feature = features_config.get("web_search", {})
|
|
||||||
web_search = isinstance(web_search_feature, dict) and web_search_feature.get("enabled", False)
|
|
||||||
|
|
||||||
# 流式返回
|
# 流式返回
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
from app.services.draft_run_service import AgentRunService
|
from app.services.draft_run_service import DraftRunService
|
||||||
draft_service = AgentRunService(db)
|
draft_service = DraftRunService(db)
|
||||||
async for event in draft_service.run_compare_stream(
|
async for event in draft_service.run_compare_stream(
|
||||||
agent_config=agent_cfg,
|
agent_config=agent_cfg,
|
||||||
models=model_configs,
|
models=model_configs,
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
conversation_id=payload.conversation_id,
|
conversation_id=payload.conversation_id,
|
||||||
user_id=payload.user_id,
|
user_id=payload.user_id or str(current_user.id),
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
web_search=web_search,
|
web_search=True,
|
||||||
memory=True,
|
memory=True,
|
||||||
parallel=payload.parallel,
|
parallel=payload.parallel,
|
||||||
timeout=payload.timeout or 60,
|
timeout=payload.timeout or 60
|
||||||
files=payload.files
|
|
||||||
):
|
):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
@@ -1037,23 +813,22 @@ async def draft_run_compare(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 非流式返回
|
# 非流式返回
|
||||||
from app.services.draft_run_service import AgentRunService
|
from app.services.draft_run_service import DraftRunService
|
||||||
draft_service = AgentRunService(db)
|
draft_service = DraftRunService(db)
|
||||||
result = await draft_service.run_compare(
|
result = await draft_service.run_compare(
|
||||||
agent_config=agent_cfg,
|
agent_config=agent_cfg,
|
||||||
models=model_configs,
|
models=model_configs,
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
conversation_id=payload.conversation_id,
|
conversation_id=payload.conversation_id,
|
||||||
user_id=payload.user_id,
|
user_id=payload.user_id or str(current_user.id),
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
web_search=web_search,
|
web_search=True,
|
||||||
memory=True,
|
memory=True,
|
||||||
parallel=payload.parallel,
|
parallel=payload.parallel,
|
||||||
timeout=payload.timeout or 60,
|
timeout=payload.timeout or 60
|
||||||
files=payload.files
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -1095,73 +870,10 @@ async def update_workflow_config(
|
|||||||
current_user: Annotated[User, Depends(get_current_user)]
|
current_user: Annotated[User, Depends(get_current_user)]
|
||||||
):
|
):
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
if payload.variables:
|
|
||||||
from app.services.workflow_service import WorkflowService
|
|
||||||
resolved = await WorkflowService(db)._resolve_variables_file_defaults(
|
|
||||||
[v.model_dump() for v in payload.variables]
|
|
||||||
)
|
|
||||||
# Patch default values back into VariableDefinition objects
|
|
||||||
for var_def, resolved_def in zip(payload.variables, resolved):
|
|
||||||
var_def.default = resolved_def.get("default", var_def.default)
|
|
||||||
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{app_id}/workflow/export")
|
|
||||||
@cur_workspace_access_guard()
|
|
||||||
async def export_workflow_config(
|
|
||||||
app_id: uuid.UUID,
|
|
||||||
db: Annotated[Session, Depends(get_db)],
|
|
||||||
current_user: Annotated[User, Depends(get_current_user)]
|
|
||||||
):
|
|
||||||
"""导出工作流配置为YAML文件"""
|
|
||||||
workflow_service = WorkflowService(db)
|
|
||||||
|
|
||||||
return success(data={
|
|
||||||
"content": workflow_service.export_workflow_dsl(app_id=app_id),
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/workflow/import")
|
|
||||||
@cur_workspace_access_guard()
|
|
||||||
async def import_workflow_config(
|
|
||||||
file: UploadFile = File(...),
|
|
||||||
platform: str = Form(...),
|
|
||||||
app_id: str = Form(None),
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
|
|
||||||
):
|
|
||||||
"""从YAML内容导入工作流配置"""
|
|
||||||
if not file.filename.lower().endswith((".yaml", ".yml")):
|
|
||||||
return fail(msg="Only yaml file is allowed", code=BizCode.BAD_REQUEST)
|
|
||||||
|
|
||||||
raw_text = (await file.read()).decode("utf-8")
|
|
||||||
import_service = WorkflowImportService(db)
|
|
||||||
config = yaml.safe_load(raw_text)
|
|
||||||
result = await import_service.upload_config(platform, config)
|
|
||||||
return success(data=result)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/workflow/import/save")
|
|
||||||
@cur_workspace_access_guard()
|
|
||||||
@check_app_quota
|
|
||||||
async def save_workflow_import(
|
|
||||||
data: WorkflowImportSave,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
import_service = WorkflowImportService(db)
|
|
||||||
app = await import_service.save_workflow(
|
|
||||||
user_id=current_user.id,
|
|
||||||
workspace_id=current_user.current_workspace_id,
|
|
||||||
temp_id=data.temp_id,
|
|
||||||
name=data.name,
|
|
||||||
description=data.description,
|
|
||||||
)
|
|
||||||
return success(data=app_schema.App.model_validate(app))
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{app_id}/statistics", summary="应用统计数据")
|
@router.get("/{app_id}/statistics", summary="应用统计数据")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def get_app_statistics(
|
def get_app_statistics(
|
||||||
@@ -1177,8 +889,6 @@ def get_app_statistics(
|
|||||||
app_id: 应用ID
|
app_id: 应用ID
|
||||||
start_date: 开始时间戳(毫秒)
|
start_date: 开始时间戳(毫秒)
|
||||||
end_date: 结束时间戳(毫秒)
|
end_date: 结束时间戳(毫秒)
|
||||||
db: 数据库连接
|
|
||||||
current_user: 当前用户
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- daily_conversations: 每日会话数统计
|
- daily_conversations: 每日会话数统计
|
||||||
@@ -1191,6 +901,8 @@ def get_app_statistics(
|
|||||||
- total_tokens: 总token消耗
|
- total_tokens: 总token消耗
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
|
from app.services.app_statistics_service import AppStatisticsService
|
||||||
stats_service = AppStatisticsService(db)
|
stats_service = AppStatisticsService(db)
|
||||||
|
|
||||||
result = stats_service.get_app_statistics(
|
result = stats_service.get_app_statistics(
|
||||||
@@ -1201,143 +913,3 @@ def get_app_statistics(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return success(data=result)
|
return success(data=result)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/workspace/api-statistics", summary="工作空间API调用统计")
|
|
||||||
@cur_workspace_access_guard()
|
|
||||||
def get_workspace_api_statistics(
|
|
||||||
start_date: int,
|
|
||||||
end_date: int,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user=Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""获取工作空间API调用统计
|
|
||||||
|
|
||||||
Args:
|
|
||||||
start_date: 开始时间戳(毫秒)
|
|
||||||
end_date: 结束时间戳(毫秒)
|
|
||||||
db: 数据库连接
|
|
||||||
current_user: 当前用户
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
每日统计数据列表,每项包含:
|
|
||||||
- date: 日期
|
|
||||||
- total_calls: 当日总调用次数
|
|
||||||
- app_calls: 当日应用调用次数
|
|
||||||
- service_calls: 当日服务调用次数
|
|
||||||
"""
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
stats_service = AppStatisticsService(db)
|
|
||||||
|
|
||||||
result = stats_service.get_workspace_api_statistics(
|
|
||||||
workspace_id=workspace_id,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date
|
|
||||||
)
|
|
||||||
|
|
||||||
return success(data=result)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{app_id}/export", summary="导出应用配置为 YAML 文件")
|
|
||||||
@cur_workspace_access_guard()
|
|
||||||
async def export_app(
|
|
||||||
app_id: uuid.UUID,
|
|
||||||
db: Annotated[Session, Depends(get_db)],
|
|
||||||
current_user: Annotated[User, Depends(get_current_user)],
|
|
||||||
release_id: Optional[uuid.UUID] = None
|
|
||||||
):
|
|
||||||
"""导出 agent / multi_agent / workflow 应用配置为 YAML 文件流。
|
|
||||||
release_id: 指定发布版本id,不传则导出当前草稿配置。
|
|
||||||
"""
|
|
||||||
yaml_str, filename = AppDslService(db).export_dsl(app_id, release_id)
|
|
||||||
encoded = quote(filename, safe=".")
|
|
||||||
yaml_bytes = yaml_str.encode("utf-8")
|
|
||||||
file_stream = io.BytesIO(yaml_bytes)
|
|
||||||
file_stream.seek(0)
|
|
||||||
return StreamingResponse(
|
|
||||||
file_stream,
|
|
||||||
media_type="application/octet-stream; charset=utf-8",
|
|
||||||
headers={"Content-Disposition": f"attachment; filename={encoded}",
|
|
||||||
"Content-Length": str(len(yaml_bytes))}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/import", summary="从 YAML 文件导入应用")
|
|
||||||
@cur_workspace_access_guard()
|
|
||||||
async def import_app(
|
|
||||||
file: UploadFile = File(...),
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
app_id: Optional[str] = Form(None),
|
|
||||||
):
|
|
||||||
"""从 YAML 文件导入 agent / multi_agent / workflow 应用。
|
|
||||||
传入 app_id 时覆盖该应用的配置(类型必须一致),否则创建新应用。
|
|
||||||
跨空间/跨租户导入时,模型/工具/知识库会按名称匹配,匹配不到则置空并返回 warnings。
|
|
||||||
"""
|
|
||||||
if not file.filename.lower().endswith((".yaml", ".yml")):
|
|
||||||
return fail(msg="仅支持 YAML 文件", code=BizCode.BAD_REQUEST)
|
|
||||||
|
|
||||||
raw = (await file.read()).decode("utf-8")
|
|
||||||
dsl = yaml.safe_load(raw)
|
|
||||||
if not dsl or "app" not in dsl:
|
|
||||||
return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)
|
|
||||||
|
|
||||||
target_app_id = uuid.UUID(app_id) if app_id else None
|
|
||||||
# 仅新建应用时检查配额,覆盖已有应用时跳过
|
|
||||||
if target_app_id is None:
|
|
||||||
from app.core.quota_manager import _check_quota
|
|
||||||
_check_quota(db, current_user.tenant_id, "app_quota", "app", workspace_id=current_user.current_workspace_id)
|
|
||||||
result_app, warnings = AppDslService(db).import_dsl(
|
|
||||||
dsl=dsl,
|
|
||||||
workspace_id=current_user.current_workspace_id,
|
|
||||||
tenant_id=current_user.tenant_id,
|
|
||||||
user_id=current_user.id,
|
|
||||||
app_id=target_app_id,
|
|
||||||
)
|
|
||||||
return success(
|
|
||||||
data={"app": app_schema.App.model_validate(result_app), "warnings": warnings},
|
|
||||||
msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/citations/{document_id}/download", summary="下载引用文档原始文件")
|
|
||||||
async def download_citation_file(
|
|
||||||
document_id: uuid.UUID = Path(..., description="引用文档ID"),
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
下载引用文档的原始文件。
|
|
||||||
仅当应用功能特性 citation.allow_download=true 时,前端才会展示此下载链接。
|
|
||||||
路由本身不做权限校验,由业务层通过 allow_download 开关控制入口。
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
from fastapi import HTTPException, status as http_status
|
|
||||||
from fastapi.responses import FileResponse
|
|
||||||
from app.core.config import settings
|
|
||||||
from app.models.document_model import Document
|
|
||||||
from app.models.file_model import File as FileModel
|
|
||||||
|
|
||||||
doc = db.query(Document).filter(Document.id == document_id).first()
|
|
||||||
if not doc:
|
|
||||||
raise HTTPException(status_code=http_status.HTTP_404_NOT_FOUND, detail="文档不存在")
|
|
||||||
|
|
||||||
file_record = db.query(FileModel).filter(FileModel.id == doc.file_id).first()
|
|
||||||
if not file_record:
|
|
||||||
raise HTTPException(status_code=http_status.HTTP_404_NOT_FOUND, detail="原始文件不存在")
|
|
||||||
|
|
||||||
file_path = os.path.join(
|
|
||||||
settings.FILE_PATH,
|
|
||||||
str(file_record.kb_id),
|
|
||||||
str(file_record.parent_id),
|
|
||||||
f"{file_record.id}{file_record.file_ext}"
|
|
||||||
)
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
raise HTTPException(status_code=http_status.HTTP_404_NOT_FOUND, detail="文件未找到")
|
|
||||||
|
|
||||||
encoded_name = quote(doc.file_name)
|
|
||||||
return FileResponse(
|
|
||||||
path=file_path,
|
|
||||||
filename=doc.file_name,
|
|
||||||
media_type="application/octet-stream",
|
|
||||||
headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_name}"}
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,110 +0,0 @@
|
|||||||
"""应用日志(消息记录)接口"""
|
|
||||||
import uuid
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.core.logging_config import get_business_logger
|
|
||||||
from app.core.response_utils import success
|
|
||||||
from app.db import get_db
|
|
||||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
|
||||||
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail, AppLogMessage
|
|
||||||
from app.schemas.response_schema import PageData, PageMeta
|
|
||||||
from app.services.app_service import AppService
|
|
||||||
from app.services.app_log_service import AppLogService
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/apps", tags=["App Logs"])
|
|
||||||
logger = get_business_logger()
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{app_id}/logs", summary="应用日志 - 会话列表")
|
|
||||||
@cur_workspace_access_guard()
|
|
||||||
def list_app_logs(
|
|
||||||
app_id: uuid.UUID,
|
|
||||||
page: int = Query(1, ge=1),
|
|
||||||
pagesize: int = Query(20, ge=1, le=100),
|
|
||||||
is_draft: Optional[bool] = Query(None, description="是否草稿会话(不传则返回全部)"),
|
|
||||||
keyword: Optional[str] = Query(None, description="搜索关键词(匹配消息内容)"),
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user=Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""查看应用下所有会话记录(分页)
|
|
||||||
|
|
||||||
- is_draft 不传则返回所有会话(草稿 + 正式)
|
|
||||||
- is_draft=True 只返回草稿会话
|
|
||||||
- is_draft=False 只返回发布会话
|
|
||||||
- 支持按 keyword 搜索(匹配消息内容)
|
|
||||||
- 按最新更新时间倒序排列
|
|
||||||
"""
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
|
|
||||||
# 验证应用访问权限
|
|
||||||
app_service = AppService(db)
|
|
||||||
app = app_service.get_app(app_id, workspace_id)
|
|
||||||
|
|
||||||
# 使用 Service 层查询
|
|
||||||
log_service = AppLogService(db)
|
|
||||||
conversations, total = log_service.list_conversations(
|
|
||||||
app_id=app_id,
|
|
||||||
workspace_id=workspace_id,
|
|
||||||
page=page,
|
|
||||||
pagesize=pagesize,
|
|
||||||
is_draft=is_draft,
|
|
||||||
keyword=keyword,
|
|
||||||
app_type=app.type,
|
|
||||||
)
|
|
||||||
|
|
||||||
items = [AppLogConversation.model_validate(c) for c in conversations]
|
|
||||||
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
|
||||||
|
|
||||||
return success(data=PageData(page=meta, items=items))
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{app_id}/logs/{conversation_id}", summary="应用日志 - 会话消息详情")
|
|
||||||
@cur_workspace_access_guard()
|
|
||||||
def get_app_log_detail(
|
|
||||||
app_id: uuid.UUID,
|
|
||||||
conversation_id: uuid.UUID,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user=Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""查看某会话的完整消息记录
|
|
||||||
|
|
||||||
- 返回会话基本信息 + 所有消息(按时间正序)
|
|
||||||
- 消息 meta_data 包含模型名、token 用量等信息
|
|
||||||
- 所有人(包括共享者和被共享者)都只能查看自己的会话详情
|
|
||||||
"""
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
|
|
||||||
# 验证应用访问权限
|
|
||||||
app_service = AppService(db)
|
|
||||||
app = app_service.get_app(app_id, workspace_id)
|
|
||||||
|
|
||||||
# 使用 Service 层查询
|
|
||||||
log_service = AppLogService(db)
|
|
||||||
conversation, messages, node_executions_map = log_service.get_conversation_detail(
|
|
||||||
app_id=app_id,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
workspace_id=workspace_id,
|
|
||||||
app_type=app.type
|
|
||||||
)
|
|
||||||
|
|
||||||
# 构建基础会话信息(不经过 ORM relationship)
|
|
||||||
base = AppLogConversation.model_validate(conversation)
|
|
||||||
|
|
||||||
# 单独处理 messages,避免触发 SQLAlchemy relationship 校验
|
|
||||||
if messages and isinstance(messages[0], AppLogMessage):
|
|
||||||
# 工作流:已经是 AppLogMessage 实例
|
|
||||||
msg_list = messages
|
|
||||||
else:
|
|
||||||
# Agent:ORM Message 对象逐个转换
|
|
||||||
msg_list = [AppLogMessage.model_validate(m) for m in messages]
|
|
||||||
|
|
||||||
detail = AppLogConversationDetail(
|
|
||||||
**base.model_dump(),
|
|
||||||
messages=msg_list,
|
|
||||||
node_executions_map=node_executions_map,
|
|
||||||
)
|
|
||||||
|
|
||||||
return success(data=detail)
|
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Callable
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -17,7 +16,6 @@ from app.core.exceptions import BusinessException
|
|||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.dependencies import get_current_user, oauth2_scheme
|
from app.dependencies import get_current_user, oauth2_scheme
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.i18n.dependencies import get_translator
|
|
||||||
|
|
||||||
# 获取专用日志器
|
# 获取专用日志器
|
||||||
auth_logger = get_auth_logger()
|
auth_logger = get_auth_logger()
|
||||||
@@ -28,8 +26,7 @@ router = APIRouter(tags=["Authentication"])
|
|||||||
@router.post("/token", response_model=ApiResponse)
|
@router.post("/token", response_model=ApiResponse)
|
||||||
async def login_for_access_token(
|
async def login_for_access_token(
|
||||||
form_data: TokenRequest,
|
form_data: TokenRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db)
|
||||||
t: Callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""用户登录获取token"""
|
"""用户登录获取token"""
|
||||||
auth_logger.info(f"用户登录请求: {form_data.email}")
|
auth_logger.info(f"用户登录请求: {form_data.email}")
|
||||||
@@ -43,22 +40,20 @@ async def login_for_access_token(
|
|||||||
invite_info = workspace_service.validate_invite_token(db, form_data.invite)
|
invite_info = workspace_service.validate_invite_token(db, form_data.invite)
|
||||||
|
|
||||||
if not invite_info.is_valid:
|
if not invite_info.is_valid:
|
||||||
raise BusinessException(t("auth.invite.invalid"), code=BizCode.BAD_REQUEST)
|
raise BusinessException("邀请码无效或已过期", code=BizCode.BAD_REQUEST)
|
||||||
|
|
||||||
if invite_info.email != form_data.email:
|
if invite_info.email != form_data.email:
|
||||||
raise BusinessException(t("auth.invite.email_mismatch"), code=BizCode.BAD_REQUEST)
|
raise BusinessException("邀请邮箱与登录邮箱不匹配", code=BizCode.BAD_REQUEST)
|
||||||
auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}")
|
auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}")
|
||||||
try:
|
try:
|
||||||
# 尝试认证用户
|
# 尝试认证用户
|
||||||
user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
|
user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
|
||||||
auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
|
auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
|
||||||
if form_data.invite:
|
if form_data.invite:
|
||||||
auth_service.bind_workspace_with_invite(
|
auth_service.bind_workspace_with_invite(db=db,
|
||||||
db=db,
|
|
||||||
user=user,
|
user=user,
|
||||||
invite_token=form_data.invite,
|
invite_token=form_data.invite,
|
||||||
workspace_id=invite_info.workspace_id
|
workspace_id=invite_info.workspace_id)
|
||||||
)
|
|
||||||
except BusinessException as e:
|
except BusinessException as e:
|
||||||
# 用户不存在且有邀请码,尝试注册
|
# 用户不存在且有邀请码,尝试注册
|
||||||
if e.code == BizCode.USER_NOT_FOUND:
|
if e.code == BizCode.USER_NOT_FOUND:
|
||||||
@@ -66,7 +61,6 @@ async def login_for_access_token(
|
|||||||
user = auth_service.register_user_with_invite(
|
user = auth_service.register_user_with_invite(
|
||||||
db=db,
|
db=db,
|
||||||
email=form_data.email,
|
email=form_data.email,
|
||||||
username=form_data.username,
|
|
||||||
password=form_data.password,
|
password=form_data.password,
|
||||||
invite_token=form_data.invite,
|
invite_token=form_data.invite,
|
||||||
workspace_id=invite_info.workspace_id
|
workspace_id=invite_info.workspace_id
|
||||||
@@ -74,7 +68,7 @@ async def login_for_access_token(
|
|||||||
elif e.code == BizCode.PASSWORD_ERROR:
|
elif e.code == BizCode.PASSWORD_ERROR:
|
||||||
# 用户存在但密码错误
|
# 用户存在但密码错误
|
||||||
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")
|
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")
|
||||||
raise BusinessException(t("auth.invite.password_verification_failed"), BizCode.LOGIN_FAILED)
|
raise BusinessException("接受邀请失败,密码验证错误", BizCode.LOGIN_FAILED)
|
||||||
else:
|
else:
|
||||||
# 其他认证失败情况,直接抛出
|
# 其他认证失败情况,直接抛出
|
||||||
raise
|
raise
|
||||||
@@ -115,15 +109,14 @@ async def login_for_access_token(
|
|||||||
expires_at=access_expires_at,
|
expires_at=access_expires_at,
|
||||||
refresh_expires_at=refresh_expires_at
|
refresh_expires_at=refresh_expires_at
|
||||||
),
|
),
|
||||||
msg=t("auth.login.success")
|
msg="登录成功"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/refresh", response_model=ApiResponse)
|
@router.post("/refresh", response_model=ApiResponse)
|
||||||
async def refresh_token(
|
async def refresh_token(
|
||||||
refresh_request: RefreshTokenRequest,
|
refresh_request: RefreshTokenRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db)
|
||||||
t: Callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""刷新token"""
|
"""刷新token"""
|
||||||
auth_logger.info("收到token刷新请求")
|
auth_logger.info("收到token刷新请求")
|
||||||
@@ -131,18 +124,18 @@ async def refresh_token(
|
|||||||
# 验证 refresh token
|
# 验证 refresh token
|
||||||
userId = security.verify_token(refresh_request.refresh_token, "refresh")
|
userId = security.verify_token(refresh_request.refresh_token, "refresh")
|
||||||
if not userId:
|
if not userId:
|
||||||
raise BusinessException(t("auth.token.invalid_refresh_token"), code=BizCode.TOKEN_INVALID)
|
raise BusinessException("无效的refresh token", code=BizCode.TOKEN_INVALID)
|
||||||
|
|
||||||
# 检查用户是否存在
|
# 检查用户是否存在
|
||||||
user = auth_service.get_user_by_id(db, userId)
|
user = auth_service.get_user_by_id(db, userId)
|
||||||
if not user:
|
if not user:
|
||||||
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NO_ACCESS)
|
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||||
|
|
||||||
# 检查 refresh token 黑名单
|
# 检查 refresh token 黑名单
|
||||||
if settings.ENABLE_SINGLE_SESSION:
|
if settings.ENABLE_SINGLE_SESSION:
|
||||||
refresh_token_id = security.get_token_id(refresh_request.refresh_token)
|
refresh_token_id = security.get_token_id(refresh_request.refresh_token)
|
||||||
if refresh_token_id and await SessionService.is_token_blacklisted(refresh_token_id):
|
if refresh_token_id and await SessionService.is_token_blacklisted(refresh_token_id):
|
||||||
raise BusinessException(t("auth.token.refresh_token_blacklisted"), code=BizCode.TOKEN_BLACKLISTED)
|
raise BusinessException("Refresh token已失效", code=BizCode.TOKEN_BLACKLISTED)
|
||||||
|
|
||||||
# 生成新 tokens
|
# 生成新 tokens
|
||||||
new_access_token, new_access_token_id = security.create_access_token(subject=user.id)
|
new_access_token, new_access_token_id = security.create_access_token(subject=user.id)
|
||||||
@@ -173,7 +166,7 @@ async def refresh_token(
|
|||||||
expires_at=access_expires_at,
|
expires_at=access_expires_at,
|
||||||
refresh_expires_at=refresh_expires_at
|
refresh_expires_at=refresh_expires_at
|
||||||
),
|
),
|
||||||
msg=t("auth.token.refresh_success")
|
msg="token刷新成功"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -181,15 +174,14 @@ async def refresh_token(
|
|||||||
async def logout(
|
async def logout(
|
||||||
token: str = Depends(oauth2_scheme),
|
token: str = Depends(oauth2_scheme),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db)
|
||||||
t: Callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""登出当前用户:加入token黑名单并清理会话"""
|
"""登出当前用户:加入token黑名单并清理会话"""
|
||||||
auth_logger.info(f"用户 {current_user.username} 请求登出")
|
auth_logger.info(f"用户 {current_user.username} 请求登出")
|
||||||
|
|
||||||
token_id = security.get_token_id(token)
|
token_id = security.get_token_id(token)
|
||||||
if not token_id:
|
if not token_id:
|
||||||
raise BusinessException(t("auth.token.invalid"), code=BizCode.TOKEN_INVALID)
|
raise BusinessException("无效的access token", code=BizCode.TOKEN_INVALID)
|
||||||
|
|
||||||
# 加入黑名单
|
# 加入黑名单
|
||||||
await SessionService.blacklist_token(token_id)
|
await SessionService.blacklist_token(token_id)
|
||||||
@@ -199,5 +191,5 @@ async def logout(
|
|||||||
await SessionService.clear_user_session(current_user.username)
|
await SessionService.clear_user_session(current_user.username)
|
||||||
|
|
||||||
auth_logger.info(f"用户 {current_user.username} 登出成功")
|
auth_logger.info(f"用户 {current_user.username} 登出成功")
|
||||||
return success(msg=t("auth.logout.success"))
|
return success(msg="登出成功")
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ from app.models.user_model import User
|
|||||||
from app.schemas import chunk_schema
|
from app.schemas import chunk_schema
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
||||||
from app.services.model_service import ModelApiKeyService
|
|
||||||
|
|
||||||
# Obtain a dedicated API logger
|
# Obtain a dedicated API logger
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -442,14 +441,14 @@ async def retrieve_chunks(
|
|||||||
# 1 participle search, 2 semantic search, 3 hybrid search
|
# 1 participle search, 2 semantic search, 3 hybrid search
|
||||||
match retrieve_data.retrieve_type:
|
match retrieve_data.retrieve_type:
|
||||||
case chunk_schema.RetrieveType.PARTICIPLE:
|
case chunk_schema.RetrieveType.PARTICIPLE:
|
||||||
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
|
||||||
return success(data=jsonable_encoder(rs), msg="retrieval successful")
|
return success(data=rs, msg="retrieval successful")
|
||||||
case chunk_schema.RetrieveType.SEMANTIC:
|
case chunk_schema.RetrieveType.SEMANTIC:
|
||||||
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
|
||||||
return success(data=jsonable_encoder(rs), msg="retrieval successful")
|
return success(data=rs, msg="retrieval successful")
|
||||||
case _:
|
case _:
|
||||||
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
|
||||||
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
|
||||||
# Efficient deduplication
|
# Efficient deduplication
|
||||||
seen_ids = set()
|
seen_ids = set()
|
||||||
unique_rs = []
|
unique_rs = []
|
||||||
@@ -457,22 +456,20 @@ async def retrieve_chunks(
|
|||||||
if doc.metadata["doc_id"] not in seen_ids:
|
if doc.metadata["doc_id"] not in seen_ids:
|
||||||
seen_ids.add(doc.metadata["doc_id"])
|
seen_ids.add(doc.metadata["doc_id"])
|
||||||
unique_rs.append(doc)
|
unique_rs.append(doc)
|
||||||
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k) if unique_rs else []
|
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k)
|
||||||
if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph:
|
if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph:
|
||||||
kb_ids = [str(kb_id) for kb_id in private_kb_ids]
|
kb_ids = [str(kb_id) for kb_id in private_kb_ids]
|
||||||
workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids]
|
workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids]
|
||||||
llm_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.llm_id)
|
|
||||||
emb_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.embedding_id)
|
|
||||||
# Prepare to configure chat_mdl、embedding_model、vision_model information
|
# Prepare to configure chat_mdl、embedding_model、vision_model information
|
||||||
chat_model = Base(
|
chat_model = Base(
|
||||||
key=llm_key.api_key,
|
key=db_knowledge.llm.api_keys[0].api_key,
|
||||||
model_name=llm_key.model_name,
|
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||||
base_url=llm_key.api_base
|
base_url=db_knowledge.llm.api_keys[0].api_base
|
||||||
)
|
)
|
||||||
embedding_model = OpenAIEmbed(
|
embedding_model = OpenAIEmbed(
|
||||||
key=emb_key.api_key,
|
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||||
model_name=emb_key.model_name,
|
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||||
base_url=emb_key.api_base
|
base_url=db_knowledge.embedding.api_keys[0].api_base
|
||||||
)
|
)
|
||||||
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids= kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids= kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||||
if doc:
|
if doc:
|
||||||
|
|||||||
@@ -314,10 +314,8 @@ async def parse_documents(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 4. Check if the file exists
|
# 4. Check if the file exists
|
||||||
api_logger.debug(f"Constructed file path: {file_path}")
|
|
||||||
api_logger.debug(f"File metadata - kb_id: {db_file.kb_id}, parent_id: {db_file.parent_id}, file_id: {db_file.id}, extension: {db_file.file_ext}")
|
|
||||||
if not os.path.exists(file_path):
|
if not os.path.exists(file_path):
|
||||||
api_logger.error(f"File not found (possibly deleted): file_path={file_path}, file_id={db_file.id}, document_id={document_id}")
|
api_logger.warning(f"File not found (possibly deleted): file_path={file_path}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail="File not found (possibly deleted)"
|
detail="File not found (possibly deleted)"
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ Routes:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.language_utils import get_language_from_header
|
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.response_utils import fail, success
|
from app.core.response_utils import fail, success
|
||||||
from app.dependencies import get_current_user, get_db
|
from app.dependencies import get_current_user, get_db
|
||||||
@@ -46,14 +45,11 @@ emotion_service = EmotionAnalyticsService()
|
|||||||
@router.post("/tags", response_model=ApiResponse)
|
@router.post("/tags", response_model=ApiResponse)
|
||||||
async def get_emotion_tags(
|
async def get_emotion_tags(
|
||||||
request: EmotionTagsRequest,
|
request: EmotionTagsRequest,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用集中化的语言校验
|
|
||||||
language = get_language_from_header(language_type)
|
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"用户 {current_user.username} 请求获取情绪标签统计",
|
f"用户 {current_user.username} 请求获取情绪标签统计",
|
||||||
extra={
|
extra={
|
||||||
@@ -61,8 +57,7 @@ async def get_emotion_tags(
|
|||||||
"emotion_type": request.emotion_type,
|
"emotion_type": request.emotion_type,
|
||||||
"start_date": request.start_date,
|
"start_date": request.start_date,
|
||||||
"end_date": request.end_date,
|
"end_date": request.end_date,
|
||||||
"limit": request.limit,
|
"limit": request.limit
|
||||||
"language_type": language
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -72,8 +67,7 @@ async def get_emotion_tags(
|
|||||||
emotion_type=request.emotion_type,
|
emotion_type=request.emotion_type,
|
||||||
start_date=request.start_date,
|
start_date=request.start_date,
|
||||||
end_date=request.end_date,
|
end_date=request.end_date,
|
||||||
limit=request.limit,
|
limit=request.limit
|
||||||
language=language
|
|
||||||
)
|
)
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
@@ -103,14 +97,11 @@ async def get_emotion_tags(
|
|||||||
@router.post("/wordcloud", response_model=ApiResponse)
|
@router.post("/wordcloud", response_model=ApiResponse)
|
||||||
async def get_emotion_wordcloud(
|
async def get_emotion_wordcloud(
|
||||||
request: EmotionWordcloudRequest,
|
request: EmotionWordcloudRequest,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用集中化的语言校验
|
|
||||||
language = get_language_from_header(language_type)
|
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"用户 {current_user.username} 请求获取情绪词云数据",
|
f"用户 {current_user.username} 请求获取情绪词云数据",
|
||||||
extra={
|
extra={
|
||||||
@@ -153,14 +144,11 @@ async def get_emotion_wordcloud(
|
|||||||
@router.post("/health", response_model=ApiResponse)
|
@router.post("/health", response_model=ApiResponse)
|
||||||
async def get_emotion_health(
|
async def get_emotion_health(
|
||||||
request: EmotionHealthRequest,
|
request: EmotionHealthRequest,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用集中化的语言校验
|
|
||||||
language = get_language_from_header(language_type)
|
|
||||||
|
|
||||||
# 验证时间范围参数
|
# 验证时间范围参数
|
||||||
if request.time_range not in ["7d", "30d", "90d"]:
|
if request.time_range not in ["7d", "30d", "90d"]:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -186,7 +174,7 @@ async def get_emotion_health(
|
|||||||
"情绪健康指数获取成功",
|
"情绪健康指数获取成功",
|
||||||
extra={
|
extra={
|
||||||
"end_user_id": request.end_user_id,
|
"end_user_id": request.end_user_id,
|
||||||
"health_score": data.get("health_score") or 0,
|
"health_score": data.get("health_score", 0),
|
||||||
"level": data.get("level", "未知")
|
"level": data.get("level", "未知")
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -208,64 +196,14 @@ async def get_emotion_health(
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
# @router.post("/check-data", response_model=ApiResponse)
|
|
||||||
# async def check_emotion_data_exists(
|
|
||||||
# request: EmotionSuggestionsRequest,
|
|
||||||
# db: Session = Depends(get_db),
|
|
||||||
# current_user: User = Depends(get_current_user),
|
|
||||||
# ):
|
|
||||||
# """检查用户情绪建议数据是否存在
|
|
||||||
|
|
||||||
# Args:
|
|
||||||
# request: 包含 end_user_id
|
|
||||||
# db: 数据库会话
|
|
||||||
# current_user: 当前用户
|
|
||||||
|
|
||||||
# Returns:
|
|
||||||
# 数据存在状态
|
|
||||||
# """
|
|
||||||
# try:
|
|
||||||
# api_logger.info(
|
|
||||||
# f"检查用户情绪建议数据是否存在: {request.end_user_id}",
|
|
||||||
# extra={"end_user_id": request.end_user_id}
|
|
||||||
# )
|
|
||||||
|
|
||||||
# # 从数据库获取建议
|
|
||||||
# data = await emotion_service.get_cached_suggestions(
|
|
||||||
# end_user_id=request.end_user_id,
|
|
||||||
# db=db
|
|
||||||
# )
|
|
||||||
|
|
||||||
# if data is None:
|
|
||||||
# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据不存在")
|
|
||||||
# return fail(
|
|
||||||
# BizCode.NOT_FOUND,
|
|
||||||
# "情绪建议数据不存在,请点击右上角刷新进行初始化",
|
|
||||||
# {"exists": False}
|
|
||||||
# )
|
|
||||||
|
|
||||||
# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据存在")
|
|
||||||
# return success(data={"exists": True}, msg="情绪建议数据已存在")
|
|
||||||
|
|
||||||
# except Exception as e:
|
|
||||||
# api_logger.error(
|
|
||||||
# f"检查情绪建议数据失败: {str(e)}",
|
|
||||||
# extra={"end_user_id": request.end_user_id},
|
|
||||||
# exc_info=True
|
|
||||||
# )
|
|
||||||
# raise HTTPException(
|
|
||||||
# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
# detail=f"检查情绪建议数据失败: {str(e)}"
|
|
||||||
# )
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/suggestions", response_model=ApiResponse)
|
@router.post("/suggestions", response_model=ApiResponse)
|
||||||
async def get_emotion_suggestions(
|
async def get_emotion_suggestions(
|
||||||
request: EmotionSuggestionsRequest,
|
request: EmotionSuggestionsRequest,
|
||||||
|
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""获取个性化情绪建议(从数据库读取)
|
"""获取个性化情绪建议(从缓存读取)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: 包含 end_user_id 和可选的 config_id
|
request: 包含 end_user_id 和可选的 config_id
|
||||||
@@ -273,42 +211,44 @@ async def get_emotion_suggestions(
|
|||||||
current_user: 当前用户
|
current_user: 当前用户
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
存储的个性化情绪建议响应
|
缓存的个性化情绪建议响应
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"用户 {current_user.username} 请求获取个性化情绪建议",
|
f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)",
|
||||||
extra={
|
extra={
|
||||||
"end_user_id": request.end_user_id,
|
"end_user_id": request.end_user_id,
|
||||||
"config_id": request.config_id
|
"config_id": request.config_id
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 从数据库获取建议
|
# 从缓存获取建议
|
||||||
data = await emotion_service.get_cached_suggestions(
|
data = await emotion_service.get_cached_suggestions(
|
||||||
end_user_id=request.end_user_id,
|
end_user_id=request.end_user_id,
|
||||||
db=db
|
db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
if data is None:
|
if data is None:
|
||||||
|
# 缓存不存在或已过期
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"用户 {request.end_user_id} 的建议数据不存在",
|
f"用户 {request.end_user_id} 的建议缓存不存在或已过期",
|
||||||
extra={"end_user_id": request.end_user_id}
|
extra={"end_user_id": request.end_user_id}
|
||||||
)
|
)
|
||||||
return success(
|
return fail(
|
||||||
data={"exists": False},
|
BizCode.NOT_FOUND,
|
||||||
msg="情绪建议数据不存在,请点击右上角刷新进行初始化"
|
"建议缓存不存在或已过期,请右上角刷新生成新建议",
|
||||||
|
""
|
||||||
)
|
)
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
"个性化建议获取成功",
|
"个性化建议获取成功(缓存)",
|
||||||
extra={
|
extra={
|
||||||
"end_user_id": request.end_user_id,
|
"end_user_id": request.end_user_id,
|
||||||
"suggestions_count": len(data.get("suggestions", []))
|
"suggestions_count": len(data.get("suggestions", []))
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return success(data=data, msg="个性化建议获取成功")
|
return success(data=data, msg="个性化建议获取成功(缓存)")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(
|
api_logger.error(
|
||||||
@@ -325,11 +265,11 @@ async def get_emotion_suggestions(
|
|||||||
@router.post("/generate_suggestions", response_model=ApiResponse)
|
@router.post("/generate_suggestions", response_model=ApiResponse)
|
||||||
async def generate_emotion_suggestions(
|
async def generate_emotion_suggestions(
|
||||||
request: EmotionGenerateSuggestionsRequest,
|
request: EmotionGenerateSuggestionsRequest,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""生成个性化情绪建议(调用LLM并保存到数据库)
|
"""生成个性化情绪建议(调用LLM并缓存)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: 包含 end_user_id
|
request: 包含 end_user_id
|
||||||
@@ -340,9 +280,6 @@ async def generate_emotion_suggestions(
|
|||||||
新生成的个性化情绪建议响应
|
新生成的个性化情绪建议响应
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 使用集中化的语言校验
|
|
||||||
language = get_language_from_header(language_type)
|
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"用户 {current_user.username} 请求生成个性化情绪建议",
|
f"用户 {current_user.username} 请求生成个性化情绪建议",
|
||||||
extra={
|
extra={
|
||||||
@@ -353,15 +290,15 @@ async def generate_emotion_suggestions(
|
|||||||
# 调用服务层生成建议
|
# 调用服务层生成建议
|
||||||
data = await emotion_service.generate_emotion_suggestions(
|
data = await emotion_service.generate_emotion_suggestions(
|
||||||
end_user_id=request.end_user_id,
|
end_user_id=request.end_user_id,
|
||||||
db=db,
|
db=db
|
||||||
language=language
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 保存到数据库
|
# 保存到缓存
|
||||||
await emotion_service.save_suggestions_cache(
|
await emotion_service.save_suggestions_cache(
|
||||||
end_user_id=request.end_user_id,
|
end_user_id=request.end_user_id,
|
||||||
suggestions_data=data,
|
suggestions_data=data,
|
||||||
db=db
|
db=db,
|
||||||
|
expires_hours=24
|
||||||
)
|
)
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ from app.models.user_model import User
|
|||||||
from app.schemas import file_schema, document_schema
|
from app.schemas import file_schema, document_schema
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import file_service, document_service
|
from app.services import file_service, document_service
|
||||||
from app.core.quota_stub import check_knowledge_capacity_quota
|
|
||||||
|
|
||||||
|
|
||||||
# Obtain a dedicated API logger
|
# Obtain a dedicated API logger
|
||||||
@@ -132,7 +131,6 @@ async def create_folder(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/file", response_model=ApiResponse)
|
@router.post("/file", response_model=ApiResponse)
|
||||||
@check_knowledge_capacity_quota
|
|
||||||
async def upload_file(
|
async def upload_file(
|
||||||
kb_id: uuid.UUID,
|
kb_id: uuid.UUID,
|
||||||
parent_id: uuid.UUID,
|
parent_id: uuid.UUID,
|
||||||
|
|||||||
@@ -14,11 +14,8 @@ Routes:
|
|||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
import httpx
|
|
||||||
import mimetypes
|
|
||||||
from urllib.parse import urlparse, unquote
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
|
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
||||||
from fastapi.responses import FileResponse, RedirectResponse
|
from fastapi.responses import FileResponse, RedirectResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -32,7 +29,7 @@ from app.core.storage_exceptions import (
|
|||||||
StorageUploadError,
|
StorageUploadError,
|
||||||
)
|
)
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user, get_share_user_id, ShareTokenData
|
from app.dependencies import get_current_user
|
||||||
from app.models.file_metadata_model import FileMetadata
|
from app.models.file_metadata_model import FileMetadata
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
@@ -50,19 +47,6 @@ router = APIRouter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _match_scheme(request: Request, url: str) -> str:
|
|
||||||
"""
|
|
||||||
将 presigned URL 的协议替换为与当前请求一致的协议(http/https)。
|
|
||||||
解决反向代理场景下 presigned URL 协议与请求协议不匹配的问题。
|
|
||||||
"""
|
|
||||||
incoming_scheme = request.headers.get("x-forwarded-proto") or request.url.scheme
|
|
||||||
if url.startswith("http://") and incoming_scheme == "https":
|
|
||||||
return "https://" + url[7:]
|
|
||||||
if url.startswith("https://") and incoming_scheme == "http":
|
|
||||||
return "http://" + url[8:]
|
|
||||||
return url
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/files", response_model=ApiResponse)
|
@router.post("/files", response_model=ApiResponse)
|
||||||
async def upload_file(
|
async def upload_file(
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
@@ -94,7 +78,7 @@ async def upload_file(
|
|||||||
|
|
||||||
if file_size > settings.MAX_FILE_SIZE:
|
if file_size > settings.MAX_FILE_SIZE:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_413_CONTENT_TOO_LARGE,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
|
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -159,238 +143,8 @@ async def upload_file(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/share/files", response_model=ApiResponse)
|
|
||||||
async def upload_file_with_share_token(
|
|
||||||
file: UploadFile = File(...),
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
|
||||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Upload a file to the configured storage backend using share_token authentication.
|
|
||||||
"""
|
|
||||||
from app.services.release_share_service import ReleaseShareService
|
|
||||||
from app.models.app_model import App
|
|
||||||
from app.models.workspace_model import Workspace
|
|
||||||
|
|
||||||
# Get share and release info from share_token
|
|
||||||
service = ReleaseShareService(db)
|
|
||||||
|
|
||||||
# Get share object to access app_id
|
|
||||||
share = service.repo.get_by_share_token(share_data.share_token)
|
|
||||||
if not share:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Shared app not found"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get app to access workspace_id
|
|
||||||
app = db.query(App).filter(
|
|
||||||
App.id == share.app_id,
|
|
||||||
App.is_active.is_(True)
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not app:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="App not found"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get workspace to access tenant_id
|
|
||||||
workspace = db.query(Workspace).filter(
|
|
||||||
Workspace.id == app.workspace_id
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not workspace:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Workspace not found"
|
|
||||||
)
|
|
||||||
|
|
||||||
tenant_id = workspace.tenant_id
|
|
||||||
workspace_id = app.workspace_id
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"Storage upload request (share): tenant_id={tenant_id}, workspace_id={workspace_id}, "
|
|
||||||
f"filename={file.filename}, share_token={share_data.share_token}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Read file contents
|
|
||||||
contents = await file.read()
|
|
||||||
file_size = len(contents)
|
|
||||||
|
|
||||||
# Validate file size
|
|
||||||
if file_size == 0:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="The file is empty."
|
|
||||||
)
|
|
||||||
|
|
||||||
if file_size > settings.MAX_FILE_SIZE:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract file extension
|
|
||||||
_, file_extension = os.path.splitext(file.filename)
|
|
||||||
file_ext = file_extension.lower()
|
|
||||||
|
|
||||||
# Generate file_id and file_key
|
|
||||||
file_id = uuid.uuid4()
|
|
||||||
file_key = generate_file_key(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
workspace_id=workspace_id,
|
|
||||||
file_id=file_id,
|
|
||||||
file_ext=file_ext,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create file metadata record with pending status
|
|
||||||
file_metadata = FileMetadata(
|
|
||||||
id=file_id,
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
workspace_id=workspace_id,
|
|
||||||
file_key=file_key,
|
|
||||||
file_name=file.filename,
|
|
||||||
file_ext=file_ext,
|
|
||||||
file_size=file_size,
|
|
||||||
content_type=file.content_type,
|
|
||||||
status="pending",
|
|
||||||
)
|
|
||||||
db.add(file_metadata)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(file_metadata)
|
|
||||||
|
|
||||||
# Upload file to storage backend
|
|
||||||
try:
|
|
||||||
await storage_service.upload_file(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
workspace_id=workspace_id,
|
|
||||||
file_id=file_id,
|
|
||||||
file_ext=file_ext,
|
|
||||||
content=contents,
|
|
||||||
content_type=file.content_type,
|
|
||||||
)
|
|
||||||
# Update status to completed
|
|
||||||
file_metadata.status = "completed"
|
|
||||||
db.commit()
|
|
||||||
api_logger.info(f"File uploaded to storage (share): file_key={file_key}")
|
|
||||||
except StorageUploadError as e:
|
|
||||||
# Update status to failed
|
|
||||||
file_metadata.status = "failed"
|
|
||||||
db.commit()
|
|
||||||
api_logger.error(f"Storage upload failed (share): {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"File storage failed: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(f"File upload successful (share): {file.filename} (file_id: {file_id})")
|
|
||||||
|
|
||||||
return success(
|
|
||||||
data={"file_id": str(file_id), "file_key": file_key},
|
|
||||||
msg="File upload successful"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/files/info-by-url", response_model=ApiResponse)
|
|
||||||
async def get_file_info_by_url(
|
|
||||||
url: str,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get file information by network URL (no authentication required).
|
|
||||||
|
|
||||||
Fetches file metadata from a remote URL via HTTP HEAD request.
|
|
||||||
Falls back to GET request if HEAD is not supported.
|
|
||||||
Returns file type, name, and size.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
url: The network URL of the file.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ApiResponse with file information.
|
|
||||||
"""
|
|
||||||
api_logger.info(f"File info by URL request: url={url}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
|
||||||
# Try HEAD request first
|
|
||||||
response = await client.head(url, follow_redirects=True)
|
|
||||||
|
|
||||||
# If HEAD fails, try GET request (some servers don't support HEAD)
|
|
||||||
if response.status_code != 200:
|
|
||||||
api_logger.info(f"HEAD request failed with {response.status_code}, trying GET request")
|
|
||||||
response = await client.get(url, follow_redirects=True)
|
|
||||||
|
|
||||||
if response.status_code != 200:
|
|
||||||
api_logger.error(f"Failed to fetch file info: HTTP {response.status_code}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"Unable to access file: HTTP {response.status_code}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get file size from Content-Length header or actual content
|
|
||||||
file_size = response.headers.get("Content-Length")
|
|
||||||
if file_size:
|
|
||||||
file_size = int(file_size)
|
|
||||||
elif hasattr(response, 'content'):
|
|
||||||
file_size = len(response.content)
|
|
||||||
else:
|
|
||||||
file_size = None
|
|
||||||
|
|
||||||
# Get content type from Content-Type header
|
|
||||||
content_type = response.headers.get("Content-Type", "application/octet-stream")
|
|
||||||
# Remove charset and other parameters from content type
|
|
||||||
content_type = content_type.split(';')[0].strip()
|
|
||||||
|
|
||||||
# Extract filename from Content-Disposition or URL
|
|
||||||
file_name = None
|
|
||||||
content_disposition = response.headers.get("Content-Disposition")
|
|
||||||
if content_disposition and "filename=" in content_disposition:
|
|
||||||
parts = content_disposition.split("filename=")
|
|
||||||
if len(parts) > 1:
|
|
||||||
file_name = parts[1].strip('"').strip("'")
|
|
||||||
|
|
||||||
if not file_name:
|
|
||||||
parsed_url = urlparse(url)
|
|
||||||
file_name = unquote(os.path.basename(parsed_url.path)) or "unknown"
|
|
||||||
|
|
||||||
# Extract file extension from filename
|
|
||||||
_, file_ext = os.path.splitext(file_name)
|
|
||||||
|
|
||||||
# If no extension found, infer from content type
|
|
||||||
if not file_ext:
|
|
||||||
ext = mimetypes.guess_extension(content_type)
|
|
||||||
if ext:
|
|
||||||
file_ext = ext
|
|
||||||
file_name = f"{file_name}{file_ext}"
|
|
||||||
|
|
||||||
api_logger.info(f"File info retrieved: name={file_name}, size={file_size}, type={content_type}")
|
|
||||||
|
|
||||||
return success(
|
|
||||||
data={
|
|
||||||
"url": url,
|
|
||||||
"file_name": file_name,
|
|
||||||
"file_ext": file_ext.lower() if file_ext else "",
|
|
||||||
"file_size": file_size,
|
|
||||||
"content_type": content_type,
|
|
||||||
},
|
|
||||||
msg="File information retrieved successfully"
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"Unexpected error: {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"Failed to retrieve file information: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/files/{file_id}", response_model=Any)
|
@router.get("/files/{file_id}", response_model=Any)
|
||||||
async def download_file(
|
async def download_file(
|
||||||
request: Request,
|
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
@@ -438,7 +192,6 @@ async def download_file(
|
|||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
||||||
presigned_url = _match_scheme(request, presigned_url)
|
|
||||||
api_logger.info(f"Redirecting to presigned URL: file_key={file_key}")
|
api_logger.info(f"Redirecting to presigned URL: file_key={file_key}")
|
||||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
@@ -512,7 +265,6 @@ async def delete_file(
|
|||||||
|
|
||||||
@router.get("/files/{file_id}/url", response_model=ApiResponse)
|
@router.get("/files/{file_id}/url", response_model=ApiResponse)
|
||||||
async def get_file_url(
|
async def get_file_url(
|
||||||
request: Request,
|
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
expires: int = None,
|
expires: int = None,
|
||||||
permanent: bool = False,
|
permanent: bool = False,
|
||||||
@@ -574,13 +326,8 @@ async def get_file_url(
|
|||||||
# For local storage, generate signed URL with expiration
|
# For local storage, generate signed URL with expiration
|
||||||
url = generate_signed_url(str(file_id), expires)
|
url = generate_signed_url(str(file_id), expires)
|
||||||
else:
|
else:
|
||||||
# For remote storage (OSS/S3), get presigned URL with forced download
|
# For remote storage (OSS/S3), get presigned URL
|
||||||
url = await storage_service.get_file_url(
|
url = await storage_service.get_file_url(file_key, expires=expires)
|
||||||
file_key,
|
|
||||||
expires=expires,
|
|
||||||
file_name=file_metadata.file_name,
|
|
||||||
)
|
|
||||||
url = _match_scheme(request, url)
|
|
||||||
|
|
||||||
api_logger.info(f"Generated file URL: file_id={file_id}")
|
api_logger.info(f"Generated file URL: file_id={file_id}")
|
||||||
return success(
|
return success(
|
||||||
@@ -600,54 +347,8 @@ async def get_file_url(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/files/{file_id}/public-url", response_model=ApiResponse)
|
|
||||||
async def get_permanent_file_url(
|
|
||||||
file_id: uuid.UUID,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
获取文件的永久公开 URL(无过期时间)。
|
|
||||||
|
|
||||||
- 本地存储:返回 API 永久访问地址(基于 FILE_LOCAL_SERVER_URL 配置)
|
|
||||||
- 远程存储(OSS/S3):返回 bucket 公读地址(需 bucket 已配置公共读权限)
|
|
||||||
"""
|
|
||||||
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
|
||||||
if not file_metadata:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The file does not exist")
|
|
||||||
|
|
||||||
if file_metadata.status != "completed":
|
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"File upload not completed, status: {file_metadata.status}")
|
|
||||||
|
|
||||||
file_key = file_metadata.file_key
|
|
||||||
storage = storage_service.storage
|
|
||||||
|
|
||||||
try:
|
|
||||||
if isinstance(storage, LocalStorage):
|
|
||||||
url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{file_id}"
|
|
||||||
else:
|
|
||||||
url = await storage.get_permanent_url(file_key)
|
|
||||||
if not url:
|
|
||||||
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
|
||||||
detail="Permanent URL not supported for current storage backend")
|
|
||||||
|
|
||||||
api_logger.info(f"Generated permanent URL: file_id={file_id}")
|
|
||||||
return success(
|
|
||||||
data={"url": url, "expires_in": None, "permanent": True, "file_name": file_metadata.file_name},
|
|
||||||
msg="Permanent file URL generated successfully"
|
|
||||||
)
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"Failed to generate permanent URL: {e}")
|
|
||||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"Failed to generate permanent URL: {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/public/{file_id}", response_model=Any)
|
@router.get("/public/{file_id}", response_model=Any)
|
||||||
async def public_download_file(
|
async def public_download_file(
|
||||||
request: Request,
|
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
expires: int = 0,
|
expires: int = 0,
|
||||||
signature: str = "",
|
signature: str = "",
|
||||||
@@ -719,7 +420,6 @@ async def public_download_file(
|
|||||||
# For remote storage, redirect to presigned URL
|
# For remote storage, redirect to presigned URL
|
||||||
try:
|
try:
|
||||||
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
||||||
presigned_url = _match_scheme(request, presigned_url)
|
|
||||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||||
@@ -731,7 +431,6 @@ async def public_download_file(
|
|||||||
|
|
||||||
@router.get("/permanent/{file_id}", response_model=Any)
|
@router.get("/permanent/{file_id}", response_model=Any)
|
||||||
async def permanent_download_file(
|
async def permanent_download_file(
|
||||||
request: Request,
|
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||||
@@ -790,8 +489,7 @@ async def permanent_download_file(
|
|||||||
# For remote storage, redirect to presigned URL with long expiration
|
# For remote storage, redirect to presigned URL with long expiration
|
||||||
try:
|
try:
|
||||||
# Use a very long expiration (7 days max for most cloud providers)
|
# Use a very long expiration (7 days max for most cloud providers)
|
||||||
presigned_url = await storage_service.get_file_url(file_key, expires=604800, file_name=file_metadata.file_name)
|
presigned_url = await storage_service.get_file_url(file_key, expires=604800)
|
||||||
presigned_url = _match_scheme(request, presigned_url)
|
|
||||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||||
@@ -799,44 +497,3 @@ async def permanent_download_file(
|
|||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=f"Failed to retrieve file: {str(e)}"
|
detail=f"Failed to retrieve file: {str(e)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/files/{file_id}/status", response_model=ApiResponse)
|
|
||||||
async def get_file_status(
|
|
||||||
file_id: uuid.UUID,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get file upload/processing status (no authentication required).
|
|
||||||
|
|
||||||
This endpoint is used to check if a file (e.g., TTS audio) is ready.
|
|
||||||
Returns status: pending, completed, or failed.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_id: The UUID of the file.
|
|
||||||
db: Database session.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ApiResponse with file status and metadata.
|
|
||||||
"""
|
|
||||||
api_logger.info(f"File status request: file_id={file_id}")
|
|
||||||
|
|
||||||
# Query file metadata from database
|
|
||||||
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
|
||||||
if not file_metadata:
|
|
||||||
api_logger.warning(f"File not found in database: file_id={file_id}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="The file does not exist"
|
|
||||||
)
|
|
||||||
|
|
||||||
return success(
|
|
||||||
data={
|
|
||||||
"file_id": str(file_id),
|
|
||||||
"status": file_metadata.status,
|
|
||||||
"file_name": file_metadata.file_name,
|
|
||||||
"file_size": file_metadata.file_size,
|
|
||||||
"content_type": file_metadata.content_type,
|
|
||||||
},
|
|
||||||
msg="File status retrieved successfully"
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -3,10 +3,9 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.db import get_db, SessionLocal
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.repositories.home_page_repository import HomePageRepository
|
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services.home_page_service import HomePageService
|
from app.services.home_page_service import HomePageService
|
||||||
|
|
||||||
@@ -33,31 +32,8 @@ def get_workspace_list(
|
|||||||
@router.get("/version", response_model=ApiResponse)
|
@router.get("/version", response_model=ApiResponse)
|
||||||
def get_system_version():
|
def get_system_version():
|
||||||
"""获取系统版本号+说明"""
|
"""获取系统版本号+说明"""
|
||||||
current_version = None
|
|
||||||
version_info = None
|
|
||||||
|
|
||||||
# 1️⃣ 优先从数据库获取最新已发布的版本
|
|
||||||
try:
|
|
||||||
db = SessionLocal()
|
|
||||||
try:
|
|
||||||
current_version, version_info = HomePageRepository.get_latest_version_introduction(db)
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
except Exception as e:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 2️⃣ 降级:使用环境变量中的版本号
|
|
||||||
if not current_version:
|
|
||||||
current_version = settings.SYSTEM_VERSION
|
current_version = settings.SYSTEM_VERSION
|
||||||
version_info = HomePageService.load_version_introduction(current_version)
|
version_info = HomePageService.load_version_introduction(current_version)
|
||||||
|
|
||||||
# 3️⃣ 如果数据库和 JSON 都没有,返回基本信息
|
|
||||||
if not version_info:
|
|
||||||
version_info = {
|
|
||||||
"introduction": {"codeName": "", "releaseDate": "", "upgradePosition": "", "coreUpgrades": []},
|
|
||||||
"introduction_en": {"codeName": "", "releaseDate": "", "upgradePosition": "", "coreUpgrades": []}
|
|
||||||
}
|
|
||||||
|
|
||||||
return success(
|
return success(
|
||||||
data={
|
data={
|
||||||
"version": current_version,
|
"version": current_version,
|
||||||
|
|||||||
@@ -1,833 +0,0 @@
|
|||||||
"""
|
|
||||||
I18n Management API Controller
|
|
||||||
|
|
||||||
This module provides management APIs for:
|
|
||||||
- Language management (list, get, add, update languages)
|
|
||||||
- Translation management (get, update, reload translations)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from typing import Callable, Optional
|
|
||||||
|
|
||||||
from app.core.logging_config import get_api_logger
|
|
||||||
from app.core.response_utils import success
|
|
||||||
from app.db import get_db
|
|
||||||
from app.dependencies import get_current_user, get_current_superuser
|
|
||||||
from app.i18n.dependencies import get_translator
|
|
||||||
from app.i18n.service import get_translation_service
|
|
||||||
from app.models.user_model import User
|
|
||||||
from app.schemas.i18n_schema import (
|
|
||||||
LanguageInfo,
|
|
||||||
LanguageListResponse,
|
|
||||||
LanguageCreateRequest,
|
|
||||||
LanguageUpdateRequest,
|
|
||||||
TranslationResponse,
|
|
||||||
TranslationUpdateRequest,
|
|
||||||
MissingTranslationsResponse,
|
|
||||||
ReloadResponse
|
|
||||||
)
|
|
||||||
from app.schemas.response_schema import ApiResponse
|
|
||||||
|
|
||||||
api_logger = get_api_logger()
|
|
||||||
|
|
||||||
router = APIRouter(
|
|
||||||
prefix="/i18n",
|
|
||||||
tags=["I18n Management"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Language Management APIs
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
@router.get("/languages", response_model=ApiResponse)
|
|
||||||
def get_languages(
|
|
||||||
t: Callable = Depends(get_translator),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get list of all supported languages.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of language information including code, name, and status
|
|
||||||
"""
|
|
||||||
api_logger.info(f"Get languages request from user: {current_user.username}")
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
translation_service = get_translation_service()
|
|
||||||
|
|
||||||
# Get available locales from translation service
|
|
||||||
available_locales = translation_service.get_available_locales()
|
|
||||||
|
|
||||||
# Build language info list
|
|
||||||
languages = []
|
|
||||||
for locale in available_locales:
|
|
||||||
is_default = locale == settings.I18N_DEFAULT_LANGUAGE
|
|
||||||
is_enabled = locale in settings.I18N_SUPPORTED_LANGUAGES
|
|
||||||
|
|
||||||
# Get native names
|
|
||||||
native_names = {
|
|
||||||
"zh": "中文(简体)",
|
|
||||||
"en": "English",
|
|
||||||
"ja": "日本語",
|
|
||||||
"ko": "한국어",
|
|
||||||
"fr": "Français",
|
|
||||||
"de": "Deutsch",
|
|
||||||
"es": "Español"
|
|
||||||
}
|
|
||||||
|
|
||||||
language_info = LanguageInfo(
|
|
||||||
code=locale,
|
|
||||||
name=f"{locale.upper()}",
|
|
||||||
native_name=native_names.get(locale, locale),
|
|
||||||
is_enabled=is_enabled,
|
|
||||||
is_default=is_default
|
|
||||||
)
|
|
||||||
languages.append(language_info)
|
|
||||||
|
|
||||||
response = LanguageListResponse(languages=languages)
|
|
||||||
|
|
||||||
api_logger.info(f"Returning {len(languages)} languages")
|
|
||||||
return success(data=response.dict(), msg=t("common.success.retrieved"))
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/languages/{locale}", response_model=ApiResponse)
|
|
||||||
def get_language(
|
|
||||||
locale: str,
|
|
||||||
t: Callable = Depends(get_translator),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get information about a specific language.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
locale: Language code (e.g., 'zh', 'en')
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Language information
|
|
||||||
"""
|
|
||||||
api_logger.info(f"Get language info request: locale={locale}, user={current_user.username}")
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
translation_service = get_translation_service()
|
|
||||||
|
|
||||||
# Check if locale exists
|
|
||||||
available_locales = translation_service.get_available_locales()
|
|
||||||
if locale not in available_locales:
|
|
||||||
api_logger.warning(f"Language not found: {locale}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=t("i18n.language.not_found", locale=locale)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build language info
|
|
||||||
is_default = locale == settings.I18N_DEFAULT_LANGUAGE
|
|
||||||
is_enabled = locale in settings.I18N_SUPPORTED_LANGUAGES
|
|
||||||
|
|
||||||
native_names = {
|
|
||||||
"zh": "中文(简体)",
|
|
||||||
"en": "English",
|
|
||||||
"ja": "日本語",
|
|
||||||
"ko": "한국어",
|
|
||||||
"fr": "Français",
|
|
||||||
"de": "Deutsch",
|
|
||||||
"es": "Español"
|
|
||||||
}
|
|
||||||
|
|
||||||
language_info = LanguageInfo(
|
|
||||||
code=locale,
|
|
||||||
name=f"{locale.upper()}",
|
|
||||||
native_name=native_names.get(locale, locale),
|
|
||||||
is_enabled=is_enabled,
|
|
||||||
is_default=is_default
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(f"Returning language info for: {locale}")
|
|
||||||
return success(data=language_info.dict(), msg=t("common.success.retrieved"))
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/languages", response_model=ApiResponse)
|
|
||||||
def add_language(
|
|
||||||
request: LanguageCreateRequest,
|
|
||||||
t: Callable = Depends(get_translator),
|
|
||||||
current_user: User = Depends(get_current_superuser)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Add a new language (admin only).
|
|
||||||
|
|
||||||
Note: This endpoint validates the request but actual language addition
|
|
||||||
requires creating translation files in the locales directory.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: Language creation request
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Success message
|
|
||||||
"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Add language request: code={request.code}, admin={current_user.username}"
|
|
||||||
)
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
translation_service = get_translation_service()
|
|
||||||
|
|
||||||
# Check if language already exists
|
|
||||||
available_locales = translation_service.get_available_locales()
|
|
||||||
if request.code in available_locales:
|
|
||||||
api_logger.warning(f"Language already exists: {request.code}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=t("i18n.language.already_exists", locale=request.code)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Note: Actual language addition requires creating translation files
|
|
||||||
# This endpoint serves as a validation and documentation point
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"Language addition validated: {request.code}. "
|
|
||||||
"Translation files need to be created manually."
|
|
||||||
)
|
|
||||||
|
|
||||||
return success(
|
|
||||||
msg=t(
|
|
||||||
"i18n.language.add_instructions",
|
|
||||||
locale=request.code,
|
|
||||||
dir=settings.I18N_CORE_LOCALES_DIR
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/languages/{locale}", response_model=ApiResponse)
|
|
||||||
def update_language(
|
|
||||||
locale: str,
|
|
||||||
request: LanguageUpdateRequest,
|
|
||||||
t: Callable = Depends(get_translator),
|
|
||||||
current_user: User = Depends(get_current_superuser)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Update language configuration (admin only).
|
|
||||||
|
|
||||||
Note: This endpoint validates the request but actual configuration
|
|
||||||
changes require updating environment variables or config files.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
locale: Language code
|
|
||||||
request: Language update request
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Success message
|
|
||||||
"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Update language request: locale={locale}, admin={current_user.username}"
|
|
||||||
)
|
|
||||||
|
|
||||||
translation_service = get_translation_service()
|
|
||||||
|
|
||||||
# Check if language exists
|
|
||||||
available_locales = translation_service.get_available_locales()
|
|
||||||
if locale not in available_locales:
|
|
||||||
api_logger.warning(f"Language not found: {locale}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=t("i18n.language.not_found", locale=locale)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Note: Actual configuration changes require updating settings
|
|
||||||
# This endpoint serves as a validation and documentation point
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"Language update validated: {locale}. "
|
|
||||||
"Configuration changes require environment variable updates."
|
|
||||||
)
|
|
||||||
|
|
||||||
return success(msg=t("i18n.language.update_instructions", locale=locale))
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Translation Management APIs
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
@router.get("/translations", response_model=ApiResponse)
|
|
||||||
def get_all_translations(
|
|
||||||
locale: Optional[str] = None,
|
|
||||||
t: Callable = Depends(get_translator),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get all translations for all or specific locale.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
locale: Optional locale filter
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
All translations organized by locale and namespace
|
|
||||||
"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Get all translations request: locale={locale}, user={current_user.username}"
|
|
||||||
)
|
|
||||||
|
|
||||||
translation_service = get_translation_service()
|
|
||||||
|
|
||||||
if locale:
|
|
||||||
# Get translations for specific locale
|
|
||||||
available_locales = translation_service.get_available_locales()
|
|
||||||
if locale not in available_locales:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=t("i18n.language.not_found", locale=locale)
|
|
||||||
)
|
|
||||||
|
|
||||||
translations = {
|
|
||||||
locale: translation_service._cache.get(locale, {})
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
# Get all translations
|
|
||||||
translations = translation_service._cache
|
|
||||||
|
|
||||||
response = TranslationResponse(translations=translations)
|
|
||||||
|
|
||||||
api_logger.info(f"Returning translations for: {locale or 'all locales'}")
|
|
||||||
return success(data=response.dict(), msg=t("common.success.retrieved"))
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/translations/{locale}", response_model=ApiResponse)
|
|
||||||
def get_locale_translations(
|
|
||||||
locale: str,
|
|
||||||
t: Callable = Depends(get_translator),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get all translations for a specific locale.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
locale: Language code
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
All translations for the locale organized by namespace
|
|
||||||
"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Get locale translations request: locale={locale}, user={current_user.username}"
|
|
||||||
)
|
|
||||||
|
|
||||||
translation_service = get_translation_service()
|
|
||||||
|
|
||||||
# Check if locale exists
|
|
||||||
available_locales = translation_service.get_available_locales()
|
|
||||||
if locale not in available_locales:
|
|
||||||
api_logger.warning(f"Language not found: {locale}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=t("i18n.language.not_found", locale=locale)
|
|
||||||
)
|
|
||||||
|
|
||||||
translations = translation_service._cache.get(locale, {})
|
|
||||||
|
|
||||||
api_logger.info(f"Returning {len(translations)} namespaces for locale: {locale}")
|
|
||||||
return success(data={"locale": locale, "translations": translations}, msg=t("common.success.retrieved"))
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/translations/{locale}/{namespace}", response_model=ApiResponse)
|
|
||||||
def get_namespace_translations(
|
|
||||||
locale: str,
|
|
||||||
namespace: str,
|
|
||||||
t: Callable = Depends(get_translator),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get translations for a specific namespace in a locale.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
locale: Language code
|
|
||||||
namespace: Translation namespace (e.g., 'common', 'auth')
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Translations for the specified namespace
|
|
||||||
"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Get namespace translations request: locale={locale}, "
|
|
||||||
f"namespace={namespace}, user={current_user.username}"
|
|
||||||
)
|
|
||||||
|
|
||||||
translation_service = get_translation_service()
|
|
||||||
|
|
||||||
# Check if locale exists
|
|
||||||
available_locales = translation_service.get_available_locales()
|
|
||||||
if locale not in available_locales:
|
|
||||||
api_logger.warning(f"Language not found: {locale}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=t("i18n.language.not_found", locale=locale)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get namespace translations
|
|
||||||
locale_translations = translation_service._cache.get(locale, {})
|
|
||||||
namespace_translations = locale_translations.get(namespace, {})
|
|
||||||
|
|
||||||
if not namespace_translations:
|
|
||||||
api_logger.warning(f"Namespace not found: {namespace} in locale: {locale}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=t("i18n.namespace.not_found", namespace=namespace, locale=locale)
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"Returning translations for namespace: {namespace} in locale: {locale}"
|
|
||||||
)
|
|
||||||
return success(
|
|
||||||
data={
|
|
||||||
"locale": locale,
|
|
||||||
"namespace": namespace,
|
|
||||||
"translations": namespace_translations
|
|
||||||
},
|
|
||||||
msg=t("common.success.retrieved")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/translations/{locale}/{key:path}", response_model=ApiResponse)
|
|
||||||
def update_translation(
|
|
||||||
locale: str,
|
|
||||||
key: str,
|
|
||||||
request: TranslationUpdateRequest,
|
|
||||||
t: Callable = Depends(get_translator),
|
|
||||||
current_user: User = Depends(get_current_superuser)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Update a single translation (admin only).
|
|
||||||
|
|
||||||
Note: This endpoint validates the request but actual translation updates
|
|
||||||
require modifying translation files in the locales directory.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
locale: Language code
|
|
||||||
key: Translation key (format: "namespace.key.subkey")
|
|
||||||
request: Translation update request
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Success message
|
|
||||||
"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Update translation request: locale={locale}, key={key}, "
|
|
||||||
f"admin={current_user.username}"
|
|
||||||
)
|
|
||||||
|
|
||||||
translation_service = get_translation_service()
|
|
||||||
|
|
||||||
# Check if locale exists
|
|
||||||
available_locales = translation_service.get_available_locales()
|
|
||||||
if locale not in available_locales:
|
|
||||||
api_logger.warning(f"Language not found: {locale}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=t("i18n.language.not_found", locale=locale)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate key format
|
|
||||||
if "." not in key:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=t("i18n.translation.invalid_key_format", key=key)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Note: Actual translation updates require modifying JSON files
|
|
||||||
# This endpoint serves as a validation and documentation point
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"Translation update validated: {locale}/{key}. "
|
|
||||||
"Translation files need to be updated manually."
|
|
||||||
)
|
|
||||||
|
|
||||||
return success(
|
|
||||||
msg=t("i18n.translation.update_instructions", locale=locale, key=key)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/translations/missing", response_model=ApiResponse)
|
|
||||||
def get_missing_translations(
|
|
||||||
locale: Optional[str] = None,
|
|
||||||
t: Callable = Depends(get_translator),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get list of missing translations.
|
|
||||||
|
|
||||||
Compares translations across locales to find missing keys.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
locale: Optional locale to check (defaults to checking all non-default locales)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of missing translation keys
|
|
||||||
"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Get missing translations request: locale={locale}, user={current_user.username}"
|
|
||||||
)
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
translation_service = get_translation_service()
|
|
||||||
|
|
||||||
default_locale = settings.I18N_DEFAULT_LANGUAGE
|
|
||||||
available_locales = translation_service.get_available_locales()
|
|
||||||
|
|
||||||
# Get default locale translations as reference
|
|
||||||
default_translations = translation_service._cache.get(default_locale, {})
|
|
||||||
|
|
||||||
# Collect all keys from default locale
|
|
||||||
def collect_keys(data, prefix=""):
|
|
||||||
keys = []
|
|
||||||
for key, value in data.items():
|
|
||||||
full_key = f"{prefix}.{key}" if prefix else key
|
|
||||||
if isinstance(value, dict):
|
|
||||||
keys.extend(collect_keys(value, full_key))
|
|
||||||
else:
|
|
||||||
keys.append(full_key)
|
|
||||||
return keys
|
|
||||||
|
|
||||||
default_keys = set()
|
|
||||||
for namespace, translations in default_translations.items():
|
|
||||||
namespace_keys = collect_keys(translations, namespace)
|
|
||||||
default_keys.update(namespace_keys)
|
|
||||||
|
|
||||||
# Find missing keys in target locale(s)
|
|
||||||
missing_by_locale = {}
|
|
||||||
|
|
||||||
target_locales = [locale] if locale else [
|
|
||||||
loc for loc in available_locales if loc != default_locale
|
|
||||||
]
|
|
||||||
|
|
||||||
for target_locale in target_locales:
|
|
||||||
if target_locale not in available_locales:
|
|
||||||
continue
|
|
||||||
|
|
||||||
target_translations = translation_service._cache.get(target_locale, {})
|
|
||||||
target_keys = set()
|
|
||||||
|
|
||||||
for namespace, translations in target_translations.items():
|
|
||||||
namespace_keys = collect_keys(translations, namespace)
|
|
||||||
target_keys.update(namespace_keys)
|
|
||||||
|
|
||||||
missing_keys = default_keys - target_keys
|
|
||||||
if missing_keys:
|
|
||||||
missing_by_locale[target_locale] = sorted(list(missing_keys))
|
|
||||||
|
|
||||||
response = MissingTranslationsResponse(missing_translations=missing_by_locale)
|
|
||||||
|
|
||||||
total_missing = sum(len(keys) for keys in missing_by_locale.values())
|
|
||||||
api_logger.info(f"Found {total_missing} missing translations across {len(missing_by_locale)} locales")
|
|
||||||
|
|
||||||
return success(data=response.dict(), msg=t("common.success.retrieved"))
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/reload", response_model=ApiResponse)
|
|
||||||
def reload_translations(
|
|
||||||
locale: Optional[str] = None,
|
|
||||||
t: Callable = Depends(get_translator),
|
|
||||||
current_user: User = Depends(get_current_superuser)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Trigger hot reload of translation files (admin only).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
locale: Optional locale to reload (defaults to reloading all locales)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Reload status and statistics
|
|
||||||
"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Reload translations request: locale={locale or 'all'}, "
|
|
||||||
f"admin={current_user.username}"
|
|
||||||
)
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
|
|
||||||
if not settings.I18N_ENABLE_HOT_RELOAD:
|
|
||||||
api_logger.warning("Hot reload is disabled in configuration")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=t("i18n.reload.disabled")
|
|
||||||
)
|
|
||||||
|
|
||||||
translation_service = get_translation_service()
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Reload translations
|
|
||||||
translation_service.reload(locale)
|
|
||||||
|
|
||||||
# Get statistics
|
|
||||||
available_locales = translation_service.get_available_locales()
|
|
||||||
reloaded_locales = [locale] if locale else available_locales
|
|
||||||
|
|
||||||
response = ReloadResponse(
|
|
||||||
success=True,
|
|
||||||
reloaded_locales=reloaded_locales,
|
|
||||||
total_locales=len(available_locales)
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"Successfully reloaded translations for: {', '.join(reloaded_locales)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return success(data=response.dict(), msg=t("i18n.reload.success"))
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"Failed to reload translations: {str(e)}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=t("i18n.reload.failed", error=str(e))
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Performance Monitoring APIs
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
@router.get("/metrics", response_model=ApiResponse)
|
|
||||||
def get_metrics(
|
|
||||||
t: Callable = Depends(get_translator),
|
|
||||||
current_user: User = Depends(get_current_superuser)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get i18n performance metrics (admin only).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Performance metrics including:
|
|
||||||
- Request counts
|
|
||||||
- Missing translations
|
|
||||||
- Timing statistics
|
|
||||||
- Locale usage
|
|
||||||
- Error counts
|
|
||||||
"""
|
|
||||||
api_logger.info(f"Get metrics request: admin={current_user.username}")
|
|
||||||
|
|
||||||
translation_service = get_translation_service()
|
|
||||||
metrics = translation_service.get_metrics_summary()
|
|
||||||
|
|
||||||
api_logger.info("Returning i18n metrics")
|
|
||||||
return success(data=metrics, msg=t("common.success.retrieved"))
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/metrics/cache", response_model=ApiResponse)
|
|
||||||
def get_cache_stats(
|
|
||||||
t: Callable = Depends(get_translator),
|
|
||||||
current_user: User = Depends(get_current_superuser)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get cache statistics (admin only).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Cache statistics including:
|
|
||||||
- Hit/miss rates
|
|
||||||
- LRU cache performance
|
|
||||||
- Loaded locales
|
|
||||||
- Memory usage
|
|
||||||
"""
|
|
||||||
api_logger.info(f"Get cache stats request: admin={current_user.username}")
|
|
||||||
|
|
||||||
translation_service = get_translation_service()
|
|
||||||
cache_stats = translation_service.get_cache_stats()
|
|
||||||
memory_usage = translation_service.get_memory_usage()
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"cache": cache_stats,
|
|
||||||
"memory": memory_usage
|
|
||||||
}
|
|
||||||
|
|
||||||
api_logger.info("Returning cache statistics")
|
|
||||||
return success(data=data, msg=t("common.success.retrieved"))
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/metrics/prometheus")
|
|
||||||
def get_prometheus_metrics(
|
|
||||||
current_user: User = Depends(get_current_superuser)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get metrics in Prometheus format (admin only).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Prometheus-formatted metrics as plain text
|
|
||||||
"""
|
|
||||||
api_logger.info(f"Get Prometheus metrics request: admin={current_user.username}")
|
|
||||||
|
|
||||||
from app.i18n.metrics import get_metrics
|
|
||||||
metrics = get_metrics()
|
|
||||||
prometheus_output = metrics.export_prometheus()
|
|
||||||
|
|
||||||
from fastapi.responses import PlainTextResponse
|
|
||||||
return PlainTextResponse(content=prometheus_output)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/metrics/reset", response_model=ApiResponse)
|
|
||||||
def reset_metrics(
|
|
||||||
t: Callable = Depends(get_translator),
|
|
||||||
current_user: User = Depends(get_current_superuser)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Reset all metrics (admin only).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Success message
|
|
||||||
"""
|
|
||||||
api_logger.info(f"Reset metrics request: admin={current_user.username}")
|
|
||||||
|
|
||||||
from app.i18n.metrics import get_metrics
|
|
||||||
metrics = get_metrics()
|
|
||||||
metrics.reset()
|
|
||||||
|
|
||||||
translation_service = get_translation_service()
|
|
||||||
translation_service.cache.reset_stats()
|
|
||||||
|
|
||||||
api_logger.info("Metrics reset completed")
|
|
||||||
return success(msg=t("i18n.metrics.reset_success"))
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Missing Translation Logging and Reporting APIs
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
@router.get("/logs/missing", response_model=ApiResponse)
|
|
||||||
def get_missing_translation_logs(
|
|
||||||
locale: Optional[str] = None,
|
|
||||||
limit: Optional[int] = 100,
|
|
||||||
t: Callable = Depends(get_translator),
|
|
||||||
current_user: User = Depends(get_current_superuser)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get missing translation logs (admin only).
|
|
||||||
|
|
||||||
Returns logged missing translations with context information.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
locale: Optional locale filter
|
|
||||||
limit: Maximum number of entries to return (default: 100)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Missing translation logs with context
|
|
||||||
"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Get missing translation logs request: locale={locale}, "
|
|
||||||
f"limit={limit}, admin={current_user.username}"
|
|
||||||
)
|
|
||||||
|
|
||||||
translation_service = get_translation_service()
|
|
||||||
translation_logger = translation_service.translation_logger
|
|
||||||
|
|
||||||
# Get missing translations
|
|
||||||
missing_translations = translation_logger.get_missing_translations(locale)
|
|
||||||
|
|
||||||
# Get missing with context
|
|
||||||
missing_with_context = translation_logger.get_missing_with_context(locale, limit)
|
|
||||||
|
|
||||||
# Get statistics
|
|
||||||
statistics = translation_logger.get_statistics()
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"missing_translations": missing_translations,
|
|
||||||
"recent_context": missing_with_context,
|
|
||||||
"statistics": statistics
|
|
||||||
}
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"Returning {statistics['total_missing']} missing translations"
|
|
||||||
)
|
|
||||||
return success(data=data, msg=t("common.success.retrieved"))
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/logs/missing/report", response_model=ApiResponse)
|
|
||||||
def generate_missing_translation_report(
|
|
||||||
locale: Optional[str] = None,
|
|
||||||
t: Callable = Depends(get_translator),
|
|
||||||
current_user: User = Depends(get_current_superuser)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Generate a comprehensive missing translation report (admin only).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
locale: Optional locale filter
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Comprehensive report with missing translations and statistics
|
|
||||||
"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Generate missing translation report request: locale={locale}, "
|
|
||||||
f"admin={current_user.username}"
|
|
||||||
)
|
|
||||||
|
|
||||||
translation_service = get_translation_service()
|
|
||||||
translation_logger = translation_service.translation_logger
|
|
||||||
|
|
||||||
# Generate report
|
|
||||||
report = translation_logger.generate_report(locale)
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"Generated report with {report['total_missing']} missing translations"
|
|
||||||
)
|
|
||||||
return success(data=report, msg=t("common.success.retrieved"))
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/logs/missing/export", response_model=ApiResponse)
|
|
||||||
def export_missing_translations(
|
|
||||||
locale: Optional[str] = None,
|
|
||||||
t: Callable = Depends(get_translator),
|
|
||||||
current_user: User = Depends(get_current_superuser)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Export missing translations to JSON file (admin only).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
locale: Optional locale filter
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Export status and file path
|
|
||||||
"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Export missing translations request: locale={locale}, "
|
|
||||||
f"admin={current_user.username}"
|
|
||||||
)
|
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
translation_service = get_translation_service()
|
|
||||||
translation_logger = translation_service.translation_logger
|
|
||||||
|
|
||||||
# Generate filename with timestamp
|
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
locale_suffix = f"_{locale}" if locale else "_all"
|
|
||||||
output_file = f"logs/i18n/missing_translations{locale_suffix}_{timestamp}.json"
|
|
||||||
|
|
||||||
# Export to file
|
|
||||||
translation_logger.export_to_json(output_file)
|
|
||||||
|
|
||||||
api_logger.info(f"Missing translations exported to: {output_file}")
|
|
||||||
return success(
|
|
||||||
data={"file_path": output_file},
|
|
||||||
msg=t("i18n.logs.export_success", file=output_file)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/logs/missing", response_model=ApiResponse)
|
|
||||||
def clear_missing_translation_logs(
|
|
||||||
locale: Optional[str] = None,
|
|
||||||
t: Callable = Depends(get_translator),
|
|
||||||
current_user: User = Depends(get_current_superuser)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Clear missing translation logs (admin only).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
locale: Optional locale to clear (clears all if not specified)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Success message
|
|
||||||
"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Clear missing translation logs request: locale={locale or 'all'}, "
|
|
||||||
f"admin={current_user.username}"
|
|
||||||
)
|
|
||||||
|
|
||||||
translation_service = get_translation_service()
|
|
||||||
translation_logger = translation_service.translation_logger
|
|
||||||
|
|
||||||
# Clear logs
|
|
||||||
translation_logger.clear(locale)
|
|
||||||
|
|
||||||
api_logger.info(f"Cleared missing translation logs for: {locale or 'all locales'}")
|
|
||||||
return success(msg=t("i18n.logs.clear_success"))
|
|
||||||
@@ -122,48 +122,6 @@ def validate_confidence_threshold(threshold: float) -> None:
|
|||||||
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
|
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/check-data/{end_user_id}", response_model=ApiResponse)
|
|
||||||
@cur_workspace_access_guard()
|
|
||||||
async def check_user_data_exists(
|
|
||||||
end_user_id: str,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
) -> ApiResponse:
|
|
||||||
"""
|
|
||||||
检查用户画像数据是否存在
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: 目标用户ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
数据存在状态
|
|
||||||
"""
|
|
||||||
api_logger.info(f"检查用户画像数据是否存在: {end_user_id}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Validate inputs
|
|
||||||
validate_user_id(end_user_id)
|
|
||||||
|
|
||||||
# Create service with user-specific config
|
|
||||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
|
||||||
|
|
||||||
# Get cached profile
|
|
||||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
|
||||||
|
|
||||||
if cached_profile is None:
|
|
||||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
|
||||||
return success(
|
|
||||||
data={"exists": False},
|
|
||||||
msg="画像数据不存在,请点击右上角刷新进行初始化"
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(f"用户 {end_user_id} 的画像数据存在")
|
|
||||||
return success(data={"exists": True}, msg="画像数据已存在")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return handle_implicit_memory_error(e, "检查画像数据", end_user_id)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/preferences/{end_user_id}", response_model=ApiResponse)
|
@router.get("/preferences/{end_user_id}", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def get_preference_tags(
|
async def get_preference_tags(
|
||||||
@@ -201,8 +159,12 @@ async def get_preference_tags(
|
|||||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||||
|
|
||||||
if cached_profile is None:
|
if cached_profile is None:
|
||||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||||
return fail(BizCode.NOT_FOUND, "", "")
|
return fail(
|
||||||
|
BizCode.NOT_FOUND,
|
||||||
|
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||||
|
""
|
||||||
|
)
|
||||||
|
|
||||||
# Extract preferences from cache
|
# Extract preferences from cache
|
||||||
preferences = cached_profile.get("preferences", [])
|
preferences = cached_profile.get("preferences", [])
|
||||||
@@ -268,8 +230,12 @@ async def get_dimension_portrait(
|
|||||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||||
|
|
||||||
if cached_profile is None:
|
if cached_profile is None:
|
||||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||||
return fail(BizCode.NOT_FOUND, "", "")
|
return fail(
|
||||||
|
BizCode.NOT_FOUND,
|
||||||
|
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||||
|
""
|
||||||
|
)
|
||||||
|
|
||||||
# Extract portrait from cache
|
# Extract portrait from cache
|
||||||
portrait = cached_profile.get("portrait", {})
|
portrait = cached_profile.get("portrait", {})
|
||||||
@@ -312,8 +278,12 @@ async def get_interest_area_distribution(
|
|||||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||||
|
|
||||||
if cached_profile is None:
|
if cached_profile is None:
|
||||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||||
return fail(BizCode.NOT_FOUND, "", "")
|
return fail(
|
||||||
|
BizCode.NOT_FOUND,
|
||||||
|
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||||
|
""
|
||||||
|
)
|
||||||
|
|
||||||
# Extract interest areas from cache
|
# Extract interest areas from cache
|
||||||
interest_areas = cached_profile.get("interest_areas", {})
|
interest_areas = cached_profile.get("interest_areas", {})
|
||||||
@@ -360,8 +330,12 @@ async def get_behavior_habits(
|
|||||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||||
|
|
||||||
if cached_profile is None:
|
if cached_profile is None:
|
||||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||||
return fail(BizCode.NOT_FOUND, "", "")
|
return fail(
|
||||||
|
BizCode.NOT_FOUND,
|
||||||
|
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||||
|
""
|
||||||
|
)
|
||||||
|
|
||||||
# Extract habits from cache
|
# Extract habits from cache
|
||||||
habits = cached_profile.get("habits", [])
|
habits = cached_profile.get("habits", [])
|
||||||
|
|||||||
@@ -9,16 +9,13 @@ from sqlalchemy import or_
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
from app.core.error_codes import BizCode
|
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.rag.common import settings
|
from app.core.rag.common import settings
|
||||||
from app.core.rag.integrations.feishu.client import FeishuAPIClient
|
|
||||||
from app.core.rag.integrations.yuque.client import YuqueAPIClient
|
|
||||||
from app.core.rag.llm.chat_model import Base
|
from app.core.rag.llm.chat_model import Base
|
||||||
from app.core.rag.nlp import rag_tokenizer, search
|
from app.core.rag.nlp import rag_tokenizer, search
|
||||||
from app.core.rag.prompts.generator import graph_entity_types
|
from app.core.rag.prompts.generator import graph_entity_types
|
||||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||||
from app.core.response_utils import success, fail
|
from app.core.response_utils import success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.models import knowledge_model
|
from app.models import knowledge_model
|
||||||
@@ -27,7 +24,6 @@ from app.schemas import knowledge_schema
|
|||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import knowledge_service, document_service
|
from app.services import knowledge_service, document_service
|
||||||
from app.services.model_service import ModelConfigService
|
from app.services.model_service import ModelConfigService
|
||||||
from app.core.quota_stub import check_knowledge_capacity_quota
|
|
||||||
|
|
||||||
# Obtain a dedicated API logger
|
# Obtain a dedicated API logger
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -180,7 +176,6 @@ async def get_knowledges(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/knowledge", response_model=ApiResponse)
|
@router.post("/knowledge", response_model=ApiResponse)
|
||||||
@check_knowledge_capacity_quota
|
|
||||||
async def create_knowledge(
|
async def create_knowledge(
|
||||||
create_data: knowledge_schema.KnowledgeCreate,
|
create_data: knowledge_schema.KnowledgeCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
@@ -354,7 +349,6 @@ async def delete_knowledge(
|
|||||||
# 2. Soft-delete knowledge base
|
# 2. Soft-delete knowledge base
|
||||||
api_logger.debug(f"Perform a soft delete: {db_knowledge.name} (ID: {knowledge_id})")
|
api_logger.debug(f"Perform a soft delete: {db_knowledge.name} (ID: {knowledge_id})")
|
||||||
db_knowledge.status = 2
|
db_knowledge.status = 2
|
||||||
db_knowledge.updated_at = datetime.datetime.now()
|
|
||||||
db.commit()
|
db.commit()
|
||||||
api_logger.info(f"The knowledge base has been successfully deleted: {db_knowledge.name} (ID: {knowledge_id})")
|
api_logger.info(f"The knowledge base has been successfully deleted: {db_knowledge.name} (ID: {knowledge_id})")
|
||||||
return success(msg="The knowledge base has been successfully deleted")
|
return success(msg="The knowledge base has been successfully deleted")
|
||||||
@@ -490,99 +484,3 @@ async def rebuild_knowledge_graph(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Failed to rebuild knowledge graph: knowledge_id={knowledge_id} - {str(e)}")
|
api_logger.error(f"Failed to rebuild knowledge graph: knowledge_id={knowledge_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@router.get("/check/yuque/auth", response_model=ApiResponse)
|
|
||||||
async def check_yuque_auth(
|
|
||||||
yuque_user_id: str,
|
|
||||||
yuque_token: str,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
check yuque auth info
|
|
||||||
"""
|
|
||||||
api_logger.info(f"check yuque auth info, username: {current_user.username}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
api_client = YuqueAPIClient(
|
|
||||||
user_id=yuque_user_id,
|
|
||||||
token=yuque_token
|
|
||||||
)
|
|
||||||
async with api_client as client:
|
|
||||||
repos = await client.get_user_repos()
|
|
||||||
if repos:
|
|
||||||
return success(msg="Successfully auth yuque info")
|
|
||||||
return fail(BizCode.UNAUTHORIZED, msg="auth yuque info failed", error="user_id or token is incorrect")
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"auth yuque info failed: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/check/feishu/auth", response_model=ApiResponse)
|
|
||||||
async def check_feishu_auth(
|
|
||||||
feishu_app_id: str,
|
|
||||||
feishu_app_secret: str,
|
|
||||||
feishu_folder_token: str,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
check feishu auth info
|
|
||||||
"""
|
|
||||||
api_logger.info(f"check feishu auth info, username: {current_user.username}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
api_client = FeishuAPIClient(
|
|
||||||
app_id=feishu_app_id,
|
|
||||||
app_secret=feishu_app_secret
|
|
||||||
)
|
|
||||||
async with api_client as client:
|
|
||||||
files = await client.list_all_folder_files(feishu_folder_token, recursive=True)
|
|
||||||
if files:
|
|
||||||
return success(msg="Successfully auth feishu info")
|
|
||||||
return fail(BizCode.UNAUTHORIZED, msg="auth feishu info failed", error="app_id or app_secret or feishu_folder_token is incorrect")
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"auth feishu info failed: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{knowledge_id}/sync", response_model=ApiResponse)
|
|
||||||
async def sync_knowledge(
|
|
||||||
knowledge_id: uuid.UUID,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
sync knowledge base information based on knowledge_id
|
|
||||||
"""
|
|
||||||
api_logger.info(f"Obtain details of the knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 1. Query knowledge base information from the database
|
|
||||||
api_logger.debug(f"Query knowledge base: {knowledge_id}")
|
|
||||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
|
|
||||||
if not db_knowledge:
|
|
||||||
api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={knowledge_id}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="The knowledge base does not exist or access is denied"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. sync knowledge
|
|
||||||
# from app.tasks import sync_knowledge_for_kb
|
|
||||||
# sync_knowledge_for_kb(kb_id)
|
|
||||||
task = celery_app.send_task("app.core.rag.tasks.sync_knowledge_for_kb", args=[knowledge_id])
|
|
||||||
result = {
|
|
||||||
"task_id": task.id
|
|
||||||
}
|
|
||||||
return success(data=result, msg="Task accepted. sync knowledge is being processed in the background.")
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"Failed to sync knowledge: knowledge_id={knowledge_id} - {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|||||||
@@ -1,474 +0,0 @@
|
|||||||
import datetime
|
|
||||||
import json
|
|
||||||
from typing import Optional
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
|
||||||
from fastapi.encoders import jsonable_encoder
|
|
||||||
import requests
|
|
||||||
from sqlalchemy import or_
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from modelscope.hub.errors import raise_for_http_status
|
|
||||||
from modelscope.hub.mcp_api import MCPApi
|
|
||||||
|
|
||||||
from app.core.logging_config import get_api_logger
|
|
||||||
from app.core.response_utils import success, fail
|
|
||||||
from app.db import get_db
|
|
||||||
from app.dependencies import get_current_user
|
|
||||||
from app.models import mcp_market_config_model
|
|
||||||
from app.models.user_model import User
|
|
||||||
from app.schemas import mcp_market_config_schema
|
|
||||||
from app.schemas.response_schema import ApiResponse
|
|
||||||
from app.services import mcp_market_config_service, mcp_market_service
|
|
||||||
|
|
||||||
# Obtain a dedicated API logger
|
|
||||||
api_logger = get_api_logger()
|
|
||||||
|
|
||||||
router = APIRouter(
|
|
||||||
prefix="/mcp_market_configs",
|
|
||||||
tags=["mcp_market_configs"],
|
|
||||||
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/mcp_servers", response_model=ApiResponse)
|
|
||||||
async def get_mcp_servers(
|
|
||||||
mcp_market_config_id: uuid.UUID,
|
|
||||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
|
||||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
|
||||||
keywords: Optional[str] = Query(None, description="Search keywords (Optional search query string,e.g. Chinese service name, English service name, author/owner username)"),
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Query the mcp servers list in pages
|
|
||||||
- Support keyword search for name,author,owner
|
|
||||||
- Return paging metadata + mcp server list
|
|
||||||
"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Query mcp server list: tenant_id={current_user.tenant_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
|
|
||||||
|
|
||||||
# 1. parameter validation
|
|
||||||
if page < 1 or pagesize < 1:
|
|
||||||
api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="The paging parameter must be greater than 0"
|
|
||||||
)
|
|
||||||
if page * pagesize > 100:
|
|
||||||
api_logger.warning(f"Paging parameters exceed ModelScope limit: page={page}, pagesize={pagesize}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"The maximum number of MCP services can view is 100. Please visit the ModelScope MCP Plaza."
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Query mcp market config information from the database
|
|
||||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
|
||||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
|
|
||||||
mcp_market_config_id=mcp_market_config_id,
|
|
||||||
current_user=current_user)
|
|
||||||
if not db_mcp_market_config:
|
|
||||||
api_logger.warning(
|
|
||||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
|
||||||
return success(msg='The mcp market config does not exist or access is denied')
|
|
||||||
|
|
||||||
# 3. Execute paged query
|
|
||||||
token = db_mcp_market_config.token
|
|
||||||
if not token:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="MCP market config token is not configured"
|
|
||||||
)
|
|
||||||
api = MCPApi()
|
|
||||||
api.login(token)
|
|
||||||
|
|
||||||
body = {
|
|
||||||
'filter': {},
|
|
||||||
'page_number': page,
|
|
||||||
'page_size': pagesize,
|
|
||||||
'search': keywords
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
cookies = api.get_cookies(token)
|
|
||||||
headers=api.builder_headers(api.headers)
|
|
||||||
headers['Authorization'] = f'Bearer {token}'
|
|
||||||
r = api.session.put(
|
|
||||||
url=api.mcp_base_url,
|
|
||||||
headers=headers,
|
|
||||||
json=body,
|
|
||||||
cookies=cookies)
|
|
||||||
raise_for_http_status(r)
|
|
||||||
except requests.exceptions.RequestException as e:
|
|
||||||
api_logger.error(f"Failed to get MCP servers: {str(e)}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"Failed to get MCP servers: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
data = api._handle_response(r)
|
|
||||||
total = data.get('total_count', 0)
|
|
||||||
mcp_server_list = data.get('mcp_server_list', [])
|
|
||||||
# items = [{
|
|
||||||
# 'name': item.get('name', ''),
|
|
||||||
# 'id': item.get('id', ''),
|
|
||||||
# 'description': item.get('description', '')
|
|
||||||
# } for item in mcp_server_list]
|
|
||||||
|
|
||||||
# 4. Return structured response
|
|
||||||
result = {
|
|
||||||
"items": mcp_server_list,
|
|
||||||
"page": {
|
|
||||||
"page": page,
|
|
||||||
"pagesize": pagesize,
|
|
||||||
"total": total,
|
|
||||||
"has_next": True if page * pagesize < total else False
|
|
||||||
}
|
|
||||||
}
|
|
||||||
# 5. Update mck_market.mcp_count
|
|
||||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=db_mcp_market_config.mcp_market_id, current_user=current_user)
|
|
||||||
if not db_mcp_market:
|
|
||||||
api_logger.warning(f"The mcp market does not exist or access is denied: mcp_market_id={db_mcp_market_config.mcp_market_id}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="The mcp market does not exist or access is denied"
|
|
||||||
)
|
|
||||||
db_mcp_market.mcp_count = total
|
|
||||||
db.commit()
|
|
||||||
db.refresh(db_mcp_market)
|
|
||||||
return success(data=result, msg="Query of mcp servers list successful")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/operational_mcp_servers", response_model=ApiResponse)
|
|
||||||
async def get_operational_mcp_servers(
|
|
||||||
mcp_market_config_id: uuid.UUID,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Query the operational mcp servers list in pages
|
|
||||||
- Support keyword search for name,author,owner
|
|
||||||
- Return paging metadata + operational mcp server list
|
|
||||||
"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Query operational mcp server list: tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
|
||||||
|
|
||||||
# 1. Query mcp market config information from the database
|
|
||||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
|
||||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
|
|
||||||
mcp_market_config_id=mcp_market_config_id,
|
|
||||||
current_user=current_user)
|
|
||||||
if not db_mcp_market_config:
|
|
||||||
api_logger.warning(
|
|
||||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
|
||||||
return success(msg='The mcp market config does not exist or access is denied')
|
|
||||||
|
|
||||||
# 2. Execute paged query
|
|
||||||
token = db_mcp_market_config.token
|
|
||||||
if not token:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="MCP market config token is not configured"
|
|
||||||
)
|
|
||||||
api = MCPApi()
|
|
||||||
api.login(token)
|
|
||||||
|
|
||||||
url = f'{api.mcp_base_url}/operational'
|
|
||||||
headers = api.builder_headers(api.headers)
|
|
||||||
headers['Authorization'] = f'Bearer {token}'
|
|
||||||
|
|
||||||
try:
|
|
||||||
cookies = api.get_cookies(access_token=token, cookies_required=True)
|
|
||||||
r = api.session.get(url, headers=headers, cookies=cookies)
|
|
||||||
raise_for_http_status(r)
|
|
||||||
except requests.exceptions.RequestException as e:
|
|
||||||
api_logger.error(f"Failed to get operational MCP servers: {str(e)}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"Failed to get operational MCP servers: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
data = api._handle_response(r)
|
|
||||||
total = data.get('total_count', 0)
|
|
||||||
mcp_server_list = data.get('mcp_server_list', [])
|
|
||||||
# items = [{
|
|
||||||
# 'name': item.get('name', ''),
|
|
||||||
# 'id': item.get('id', ''),
|
|
||||||
# 'description': item.get('description', '')
|
|
||||||
# } for item in mcp_server_list]
|
|
||||||
|
|
||||||
# 3. Return structured response
|
|
||||||
return success(data=mcp_server_list, msg="Query of operational mcp servers list successful")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/mcp_server", response_model=ApiResponse)
|
|
||||||
async def get_mcp_server(
|
|
||||||
mcp_market_config_id: uuid.UUID,
|
|
||||||
server_id: str,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get detailed information for a specific MCP Server
|
|
||||||
"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Query mcp server: tenant_id={current_user.tenant_id}, mcp_market_config_id={mcp_market_config_id}, server_id={server_id}, username: {current_user.username}")
|
|
||||||
|
|
||||||
# 1. Query mcp market config information from the database
|
|
||||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
|
||||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
|
|
||||||
mcp_market_config_id=mcp_market_config_id,
|
|
||||||
current_user=current_user)
|
|
||||||
if not db_mcp_market_config:
|
|
||||||
api_logger.warning(
|
|
||||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
|
||||||
return success(msg='The mcp market config does not exist or access is denied')
|
|
||||||
|
|
||||||
# 2. Get detailed information for a specific MCP Server
|
|
||||||
token = db_mcp_market_config.token
|
|
||||||
if not token:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="MCP market config token is not configured"
|
|
||||||
)
|
|
||||||
api = MCPApi()
|
|
||||||
api.login(token)
|
|
||||||
|
|
||||||
result = api.get_mcp_server(server_id=server_id)
|
|
||||||
return success(data=result, msg="Query of mcp servers list successful")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/mcp_market_config", response_model=ApiResponse)
|
|
||||||
async def create_mcp_market_config(
|
|
||||||
create_data: mcp_market_config_schema.McpMarketConfigCreate,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
create mcp market config
|
|
||||||
"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Request to create a mcp market config: mcp_market_id={create_data.mcp_market_id}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
api_logger.debug(f"Start creating the mcp market config: {create_data.mcp_market_id}")
|
|
||||||
# 1. Validate token can access ModelScope MCP market
|
|
||||||
if not create_data.token:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="Token is required to access ModelScope MCP market"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
api = MCPApi()
|
|
||||||
api.login(create_data.token)
|
|
||||||
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
|
||||||
cookies = api.get_cookies(create_data.token)
|
|
||||||
headers = api.builder_headers(api.headers)
|
|
||||||
headers['Authorization'] = f'Bearer {create_data.token}'
|
|
||||||
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
|
|
||||||
raise_for_http_status(r)
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"Unable to access ModelScope MCP market with the provided token: {str(e)}"
|
|
||||||
)
|
|
||||||
# 2. Check if the mcp market name already exists
|
|
||||||
db_mcp_market_config_exist = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=create_data.mcp_market_id, current_user=current_user)
|
|
||||||
if db_mcp_market_config_exist:
|
|
||||||
api_logger.warning(f"The mcp market id already exists: {create_data.mcp_market_id}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"The mcp market id already exists: {create_data.mcp_market_id}"
|
|
||||||
)
|
|
||||||
# 2. verify token
|
|
||||||
create_data.status = 1
|
|
||||||
try:
|
|
||||||
api = MCPApi()
|
|
||||||
token = create_data.token
|
|
||||||
api.login(token)
|
|
||||||
|
|
||||||
body = {
|
|
||||||
'filter': {},
|
|
||||||
'page_number': 1,
|
|
||||||
'page_size': 20,
|
|
||||||
'search': ""
|
|
||||||
}
|
|
||||||
cookies = api.get_cookies(token)
|
|
||||||
headers = api.builder_headers(api.headers)
|
|
||||||
headers['Authorization'] = f'Bearer {token}'
|
|
||||||
r = api.session.put(
|
|
||||||
url=api.mcp_base_url,
|
|
||||||
headers=headers,
|
|
||||||
json=body,
|
|
||||||
cookies=cookies)
|
|
||||||
raise_for_http_status(r)
|
|
||||||
except requests.exceptions.RequestException as e:
|
|
||||||
api_logger.error(f"Failed to get MCP servers: {str(e)}")
|
|
||||||
create_data.status = 0
|
|
||||||
# 3. create mcp_market_config
|
|
||||||
db_mcp_market_config = mcp_market_config_service.create_mcp_market_config(db=db, mcp_market_config=create_data, current_user=current_user)
|
|
||||||
api_logger.info(
|
|
||||||
f"The mcp market config has been successfully created: (ID: {db_mcp_market_config.id})")
|
|
||||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
|
||||||
msg="The mcp market config has been successfully created")
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"The creation of the mcp market config failed: {create_data.mcp_market_id} - {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{mcp_market_config_id}", response_model=ApiResponse)
|
|
||||||
async def get_mcp_market_config(
|
|
||||||
mcp_market_config_id: uuid.UUID,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Retrieve mcp market config information based on mcp_market_config_id
|
|
||||||
"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Obtain details of the mcp market config: mcp_market_config_id={mcp_market_config_id}, username: {current_user.username}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 1. Query mcp market config information from the database
|
|
||||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
|
||||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
|
||||||
if not db_mcp_market_config:
|
|
||||||
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
|
||||||
return success(msg='The mcp market config does not exist or access is denied')
|
|
||||||
|
|
||||||
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
|
|
||||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
|
||||||
msg="Successfully obtained mcp market config information")
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"mcp market config query failed: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/mcp_market_id/{mcp_market_id}", response_model=ApiResponse)
|
|
||||||
async def get_mcp_market_config_by_mcp_market_id(
|
|
||||||
mcp_market_id: uuid.UUID,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Retrieve mcp market config information based on mcp_market_id
|
|
||||||
"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Request to create a mcp market config: mcp_market_id={mcp_market_id}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 1. Query mcp market config information from the database
|
|
||||||
api_logger.debug(f"Query mcp market config: mcp_market_id={mcp_market_id}")
|
|
||||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
|
||||||
if not db_mcp_market_config:
|
|
||||||
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_id={mcp_market_id}")
|
|
||||||
return success(msg='The mcp market config does not exist or access is denied')
|
|
||||||
|
|
||||||
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
|
|
||||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
|
||||||
msg="Successfully obtained mcp market config information")
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"mcp market config query failed: mcp_market_id={mcp_market_id} - {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{mcp_market_config_id}", response_model=ApiResponse)
|
|
||||||
async def update_mcp_market_config(
|
|
||||||
mcp_market_config_id: uuid.UUID,
|
|
||||||
update_data: mcp_market_config_schema.McpMarketConfigUpdate,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
# 1. Check if the mcp market config exists
|
|
||||||
api_logger.debug(f"Query the mcp market config to be updated: {mcp_market_config_id}")
|
|
||||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
|
||||||
|
|
||||||
if not db_mcp_market_config:
|
|
||||||
api_logger.warning(
|
|
||||||
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
|
||||||
return success(msg='The mcp market config does not exist or access is denied')
|
|
||||||
|
|
||||||
# 2. Validate new token if provided
|
|
||||||
if update_data.token is not None:
|
|
||||||
try:
|
|
||||||
api = MCPApi()
|
|
||||||
api.login(update_data.token)
|
|
||||||
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
|
||||||
cookies = api.get_cookies(update_data.token)
|
|
||||||
headers = api.builder_headers(api.headers)
|
|
||||||
headers['Authorization'] = f'Bearer {update_data.token}'
|
|
||||||
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
|
|
||||||
raise_for_http_status(r)
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"Unable to access ModelScope MCP market with the provided token: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Update fields (only update non-null fields)
|
|
||||||
api_logger.debug(f"Start updating the mcp market config fields: {mcp_market_config_id}")
|
|
||||||
update_dict = update_data.dict(exclude_unset=True)
|
|
||||||
updated_fields = []
|
|
||||||
for field, value in update_dict.items():
|
|
||||||
if hasattr(db_mcp_market_config, field):
|
|
||||||
old_value = getattr(db_mcp_market_config, field)
|
|
||||||
if old_value != value:
|
|
||||||
# update value
|
|
||||||
setattr(db_mcp_market_config, field, value)
|
|
||||||
updated_fields.append(f"{field}: {old_value} -> {value}")
|
|
||||||
|
|
||||||
if updated_fields:
|
|
||||||
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
|
||||||
|
|
||||||
# 4. Save to database
|
|
||||||
try:
|
|
||||||
db.commit()
|
|
||||||
db.refresh(db_mcp_market_config)
|
|
||||||
api_logger.info(f"The mcp market config has been successfully updated: (ID: {db_mcp_market_config.id})")
|
|
||||||
except Exception as e:
|
|
||||||
db.rollback()
|
|
||||||
api_logger.error(f"The mcp market config update failed: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"The mcp market config update failed: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 5. Return the updated mcp market config
|
|
||||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
|
||||||
msg="The mcp market config information updated successfully")
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{mcp_market_config_id}", response_model=ApiResponse)
|
|
||||||
async def delete_mcp_market_config(
|
|
||||||
mcp_market_config_id: uuid.UUID,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
delete mcp market config
|
|
||||||
"""
|
|
||||||
api_logger.info(f"Request to delete mcp market config: mcp_market_config_id={mcp_market_config_id}, username: {current_user.username}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 1. Check whether the mcp market config exists
|
|
||||||
api_logger.debug(f"Check whether the mcp market config exists: {mcp_market_config_id}")
|
|
||||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
|
||||||
|
|
||||||
if not db_mcp_market_config:
|
|
||||||
api_logger.warning(
|
|
||||||
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
|
||||||
return success(msg='The mcp market config does not exist or access is denied')
|
|
||||||
|
|
||||||
# 2. Deleting mcp market config
|
|
||||||
mcp_market_config_service.delete_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
|
||||||
api_logger.info(f"The mcp market config has been successfully deleted: (ID: {mcp_market_config_id})")
|
|
||||||
return success(msg="The mcp market config has been successfully deleted")
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"Failed to delete from the mcp market config: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
|
|
||||||
raise
|
|
||||||
@@ -1,262 +0,0 @@
|
|||||||
import datetime
|
|
||||||
import json
|
|
||||||
from typing import Optional
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
|
||||||
from fastapi.encoders import jsonable_encoder
|
|
||||||
from sqlalchemy import or_
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.core.logging_config import get_api_logger
|
|
||||||
from app.core.response_utils import success, fail
|
|
||||||
from app.db import get_db
|
|
||||||
from app.dependencies import get_current_user
|
|
||||||
from app.models import mcp_market_model
|
|
||||||
from app.models.user_model import User
|
|
||||||
from app.schemas import mcp_market_schema
|
|
||||||
from app.schemas.response_schema import ApiResponse
|
|
||||||
from app.services import mcp_market_service
|
|
||||||
|
|
||||||
# Obtain a dedicated API logger
|
|
||||||
api_logger = get_api_logger()
|
|
||||||
|
|
||||||
router = APIRouter(
|
|
||||||
prefix="/mcp_markets",
|
|
||||||
tags=["mcp_markets"],
|
|
||||||
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/mcp_markets", response_model=ApiResponse)
|
|
||||||
async def get_mcp_markets(
|
|
||||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
|
||||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
|
||||||
orderby: Optional[str] = Query(None, description="Sort fields, such as: category, created_at"),
|
|
||||||
desc: Optional[bool] = Query(False, description="Is it descending order"),
|
|
||||||
keywords: Optional[str] = Query(None, description="Search keywords (mcp_market base name)"),
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Query the mcp markets list in pages
|
|
||||||
- Support keyword search for name,description
|
|
||||||
- Support dynamic sorting
|
|
||||||
- Return paging metadata + mcp_market list
|
|
||||||
"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Query mcp market list: tenant_id={current_user.tenant_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
|
|
||||||
|
|
||||||
# 1. parameter validation
|
|
||||||
if page < 1 or pagesize < 1:
|
|
||||||
api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="The paging parameter must be greater than 0"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Construct query conditions
|
|
||||||
filters = []
|
|
||||||
|
|
||||||
# Keyword search (fuzzy matching of mcp market name,description)
|
|
||||||
if keywords:
|
|
||||||
api_logger.debug(f"Add keyword search criteria: {keywords}")
|
|
||||||
filters.append(
|
|
||||||
or_(
|
|
||||||
mcp_market_model.McpMarket.name.ilike(f"%{keywords}%"),
|
|
||||||
mcp_market_model.McpMarket.description.ilike(f"%{keywords}%")
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# 3. Execute paged query
|
|
||||||
try:
|
|
||||||
api_logger.debug("Start executing mcp market paging query")
|
|
||||||
total, items = mcp_market_service.get_mcp_markets_paginated(
|
|
||||||
db=db,
|
|
||||||
filters=filters,
|
|
||||||
page=page,
|
|
||||||
pagesize=pagesize,
|
|
||||||
orderby=orderby,
|
|
||||||
desc=desc,
|
|
||||||
current_user=current_user
|
|
||||||
)
|
|
||||||
api_logger.info(f"mcp market query successful: total={total}, returned={len(items)} records")
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"mcp market query failed: {str(e)}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"Query failed: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4. Return structured response
|
|
||||||
result = {
|
|
||||||
"items": items,
|
|
||||||
"page": {
|
|
||||||
"page": page,
|
|
||||||
"pagesize": pagesize,
|
|
||||||
"total": total,
|
|
||||||
"has_next": True if page * pagesize < total else False
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return success(data=jsonable_encoder(result), msg="Query of mcp market list successful")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/mcp_market", response_model=ApiResponse)
|
|
||||||
async def create_mcp_market(
|
|
||||||
create_data: mcp_market_schema.McpMarketCreate,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
create mcp market
|
|
||||||
"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Request to create a mcp market: name={create_data.name}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
api_logger.debug(f"Start creating the mcp market: {create_data.name}")
|
|
||||||
# 1. Check if the mcp market name already exists
|
|
||||||
db_mcp_market_exist = mcp_market_service.get_mcp_market_by_name(db, name=create_data.name, current_user=current_user)
|
|
||||||
if db_mcp_market_exist:
|
|
||||||
api_logger.warning(f"The mcp market name already exists: {create_data.name}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"The mcp market name already exists: {create_data.name}"
|
|
||||||
)
|
|
||||||
db_mcp_market = mcp_market_service.create_mcp_market(db=db, mcp_market=create_data, current_user=current_user)
|
|
||||||
api_logger.info(
|
|
||||||
f"The mcp market has been successfully created: {db_mcp_market.name} (ID: {db_mcp_market.id})")
|
|
||||||
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
|
|
||||||
msg="The mcp market has been successfully created")
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"The creation of the mcp market failed: {create_data.name} - {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{mcp_market_id}", response_model=ApiResponse)
|
|
||||||
async def get_mcp_market(
|
|
||||||
mcp_market_id: uuid.UUID,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Retrieve mcp market information based on mcp_market_id
|
|
||||||
"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Obtain details of the mcp market: mcp_market_id={mcp_market_id}, username: {current_user.username}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 1. Query mcp market information from the database
|
|
||||||
api_logger.debug(f"Query mcp market: {mcp_market_id}")
|
|
||||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
|
||||||
if not db_mcp_market:
|
|
||||||
api_logger.warning(f"The mcp market does not exist or access is denied: mcp_market_id={mcp_market_id}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="The mcp market does not exist or access is denied"
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(f"mcp market query successful: {db_mcp_market.name} (ID: {db_mcp_market.id})")
|
|
||||||
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
|
|
||||||
msg="Successfully obtained mcp market information")
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"mcp market query failed: mcp_market_id={mcp_market_id} - {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{mcp_market_id}", response_model=ApiResponse)
|
|
||||||
async def update_mcp_market(
|
|
||||||
mcp_market_id: uuid.UUID,
|
|
||||||
update_data: mcp_market_schema.McpMarketUpdate,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
# 1. Check if the mcp market exists
|
|
||||||
api_logger.debug(f"Query the mcp market to be updated: {mcp_market_id}")
|
|
||||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
|
||||||
|
|
||||||
if not db_mcp_market:
|
|
||||||
api_logger.warning(
|
|
||||||
f"The mcp market does not exist or you do not have permission to access it: mcp_market_id={mcp_market_id}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="The mcp market does not exist or you do not have permission to access it"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. not updating the name (name already exists)
|
|
||||||
update_dict = update_data.dict(exclude_unset=True)
|
|
||||||
if "name" in update_dict:
|
|
||||||
name = update_dict["name"]
|
|
||||||
if name != db_mcp_market.name:
|
|
||||||
# Check if the mcp market name already exists
|
|
||||||
db_mcp_market_exist = mcp_market_service.get_mcp_market_by_name(db, name=name, current_user=current_user)
|
|
||||||
if db_mcp_market_exist:
|
|
||||||
api_logger.warning(f"The mcp market name already exists: {name}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"The mcp market name already exists: {name}"
|
|
||||||
)
|
|
||||||
# 3. Update fields (only update non-null fields)
|
|
||||||
api_logger.debug(f"Start updating the mcp market fields: {mcp_market_id}")
|
|
||||||
updated_fields = []
|
|
||||||
for field, value in update_dict.items():
|
|
||||||
if hasattr(db_mcp_market, field):
|
|
||||||
old_value = getattr(db_mcp_market, field)
|
|
||||||
if old_value != value:
|
|
||||||
# update value
|
|
||||||
setattr(db_mcp_market, field, value)
|
|
||||||
updated_fields.append(f"{field}: {old_value} -> {value}")
|
|
||||||
|
|
||||||
if updated_fields:
|
|
||||||
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
|
||||||
|
|
||||||
# 4. Save to database
|
|
||||||
try:
|
|
||||||
db.commit()
|
|
||||||
db.refresh(db_mcp_market)
|
|
||||||
api_logger.info(f"The mcp market has been successfully updated: {db_mcp_market.name} (ID: {db_mcp_market.id})")
|
|
||||||
except Exception as e:
|
|
||||||
db.rollback()
|
|
||||||
api_logger.error(f"The mcp market update failed: mcp_market_id={mcp_market_id} - {str(e)}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"The mcp market update failed: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 5. Return the updated mcp market
|
|
||||||
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
|
|
||||||
msg="The mcp market information updated successfully")
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{mcp_market_id}", response_model=ApiResponse)
|
|
||||||
async def delete_mcp_market(
|
|
||||||
mcp_market_id: uuid.UUID,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
delete mcp market
|
|
||||||
"""
|
|
||||||
api_logger.info(f"Request to delete mcp market: mcp_market_id={mcp_market_id}, username: {current_user.username}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 1. Check whether the mcp market exists
|
|
||||||
api_logger.debug(f"Check whether the mcp market exists: {mcp_market_id}")
|
|
||||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
|
||||||
|
|
||||||
if not db_mcp_market:
|
|
||||||
api_logger.warning(
|
|
||||||
f"The mcp market does not exist or you do not have permission to access it: mcp_market_id={mcp_market_id}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="The mcp market does not exist or you do not have permission to access it"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Deleting mcp market
|
|
||||||
mcp_market_service.delete_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
|
||||||
api_logger.info(f"The mcp market has been successfully deleted: (ID: {mcp_market_id})")
|
|
||||||
return success(msg="The mcp market has been successfully deleted")
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"Failed to delete from the mcp market: mcp_market_id={mcp_market_id} - {str(e)}")
|
|
||||||
raise
|
|
||||||
@@ -1,32 +1,26 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, Header
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from starlette.responses import StreamingResponse
|
|
||||||
|
|
||||||
from app.cache.memory.interest_memory import InterestMemoryCache
|
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.language_utils import get_language_from_header
|
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.memory.agent.utils.redis_tool import store
|
|
||||||
from app.core.memory.agent.utils.session_tools import SessionService
|
|
||||||
from app.core.memory.enums import SearchStrategy, Neo4jNodeType
|
|
||||||
from app.core.memory.memory_service import MemoryService
|
|
||||||
from app.core.rag.llm.cv_model import QWenCV
|
from app.core.rag.llm.cv_model import QWenCV
|
||||||
from app.core.response_utils import fail, success
|
from app.core.response_utils import fail, success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import cur_workspace_access_guard, get_current_user
|
from app.dependencies import cur_workspace_access_guard, get_current_user
|
||||||
from app.models import ModelApiKey
|
from app.models import ModelApiKey
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.repositories import knowledge_repository
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
from app.schemas.memory_agent_schema import StorageType, UserInput, Write_UserInput, WriteMemoryRequest
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
|
from app.repositories import knowledge_repository, WorkspaceRepository
|
||||||
|
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import task_service, workspace_service
|
from app.services import task_service, workspace_service
|
||||||
from app.services.memory_agent_service import MemoryAgentService
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
from app.services.memory_agent_service import get_end_user_connected_config as get_config
|
|
||||||
from app.services.model_service import ModelConfigService
|
from app.services.model_service import ModelConfigService
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile,Header
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from starlette.responses import StreamingResponse
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -59,8 +53,7 @@ async def get_health_status(
|
|||||||
|
|
||||||
@router.get("/download_log")
|
@router.get("/download_log")
|
||||||
async def download_log(
|
async def download_log(
|
||||||
log_type: str = Query("file", regex="^(file|transmission)$",
|
log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"),
|
||||||
description="日志类型: file=完整文件, transmission=实时流式传输"),
|
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -121,142 +114,128 @@ async def download_log(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
# @router.post("/writer_service", response_model=ApiResponse)
|
@router.post("/writer_service", response_model=ApiResponse)
|
||||||
# @cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
# async def write_server(
|
async def write_server(
|
||||||
# user_input: Write_UserInput,
|
user_input: Write_UserInput,
|
||||||
# language_type: str = Header(default=None, alias="X-Language-Type"),
|
db: Session = Depends(get_db),
|
||||||
# db: Session = Depends(get_db),
|
current_user: User = Depends(get_current_user)
|
||||||
# current_user: User = Depends(get_current_user)
|
):
|
||||||
# ):
|
"""
|
||||||
# """
|
Write service endpoint - processes write operations synchronously
|
||||||
# Write service endpoint - processes write operations synchronously
|
|
||||||
#
|
Args:
|
||||||
# Args:
|
user_input: Write request containing message and end_user_id
|
||||||
# user_input: Write request containing message and end_user_id
|
|
||||||
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
Returns:
|
||||||
#
|
Response with write operation status
|
||||||
# Returns:
|
"""
|
||||||
# Response with write operation status
|
config_id = user_input.config_id
|
||||||
# """
|
workspace_id = current_user.current_workspace_id
|
||||||
# # 使用集中化的语言校验
|
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}")
|
||||||
# language = get_language_from_header(language_type)
|
|
||||||
#
|
# 获取 storage_type,如果为 None 则使用默认值
|
||||||
# config_id = user_input.config_id
|
storage_type = workspace_service.get_workspace_storage_type(
|
||||||
# workspace_id = current_user.current_workspace_id
|
db=db,
|
||||||
# api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
workspace_id=workspace_id,
|
||||||
#
|
user=current_user
|
||||||
# # 获取 storage_type,如果为 None 则使用默认值
|
)
|
||||||
# storage_type = workspace_service.get_workspace_storage_type(
|
if storage_type is None: storage_type = 'neo4j'
|
||||||
# db=db,
|
user_rag_memory_id = ''
|
||||||
# workspace_id=workspace_id,
|
|
||||||
# user=current_user
|
# 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
||||||
# )
|
if storage_type == 'rag':
|
||||||
# if storage_type is None: storage_type = 'neo4j'
|
if workspace_id:
|
||||||
# user_rag_memory_id = ''
|
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||||
#
|
db=db,
|
||||||
# # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
name="USER_RAG_MERORY",
|
||||||
# if storage_type == 'rag':
|
workspace_id=workspace_id
|
||||||
# if workspace_id:
|
)
|
||||||
# knowledge = knowledge_repository.get_knowledge_by_name(
|
if knowledge:
|
||||||
# db=db,
|
user_rag_memory_id = str(knowledge.id)
|
||||||
# name="USER_RAG_MERORY",
|
else:
|
||||||
# workspace_id=workspace_id
|
api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||||
# )
|
storage_type = 'neo4j'
|
||||||
# if knowledge:
|
else:
|
||||||
# user_rag_memory_id = str(knowledge.id)
|
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||||
# else:
|
storage_type = 'neo4j'
|
||||||
# api_logger.warning(
|
|
||||||
# f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||||
# storage_type = 'neo4j'
|
try:
|
||||||
# else:
|
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||||
# api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
result = await memory_agent_service.write_memory(
|
||||||
# storage_type = 'neo4j'
|
user_input.end_user_id,
|
||||||
#
|
messages_list,
|
||||||
# api_logger.info(
|
config_id,
|
||||||
# f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
db,
|
||||||
# try:
|
storage_type,
|
||||||
# messages_list = memory_agent_service.get_messages_list(user_input)
|
user_rag_memory_id
|
||||||
# result = await memory_agent_service.write_memory(
|
)
|
||||||
# user_input.end_user_id,
|
|
||||||
# messages_list,
|
return success(data=result, msg="写入成功")
|
||||||
# config_id,
|
except BaseException as e:
|
||||||
# db,
|
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||||
# storage_type,
|
if hasattr(e, 'exceptions'):
|
||||||
# user_rag_memory_id,
|
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||||
# language
|
detailed_error = "; ".join(error_messages)
|
||||||
# )
|
api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
||||||
#
|
return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
||||||
# return success(data=result, msg="写入成功")
|
api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
||||||
# except BaseException as e:
|
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||||
# # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
|
||||||
# if hasattr(e, 'exceptions'):
|
|
||||||
# error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
@router.post("/writer_service_async", response_model=ApiResponse)
|
||||||
# detailed_error = "; ".join(error_messages)
|
@cur_workspace_access_guard()
|
||||||
# api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
async def write_server_async(
|
||||||
# return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
user_input: Write_UserInput,
|
||||||
# api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
db: Session = Depends(get_db),
|
||||||
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
current_user: User = Depends(get_current_user)
|
||||||
#
|
):
|
||||||
#
|
"""
|
||||||
# @router.post("/writer_service_async", response_model=ApiResponse)
|
Async write service endpoint - enqueues write processing to Celery
|
||||||
# @cur_workspace_access_guard()
|
|
||||||
# async def write_server_async(
|
Args:
|
||||||
# user_input: Write_UserInput,
|
user_input: Write request containing message and end_user_id
|
||||||
# language_type: str = Header(default=None, alias="X-Language-Type"),
|
|
||||||
# db: Session = Depends(get_db),
|
Returns:
|
||||||
# current_user: User = Depends(get_current_user)
|
Task ID for tracking async operation
|
||||||
# ):
|
Use GET /memory/write_result/{task_id} to check task status and get result
|
||||||
# """
|
"""
|
||||||
# Async write service endpoint - enqueues write processing to Celery
|
config_id = user_input.config_id
|
||||||
#
|
workspace_id = current_user.current_workspace_id
|
||||||
# Args:
|
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}")
|
||||||
# user_input: Write request containing message and end_user_id
|
|
||||||
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
# 获取 storage_type,如果为 None 则使用默认值
|
||||||
#
|
storage_type = workspace_service.get_workspace_storage_type(
|
||||||
# Returns:
|
db=db,
|
||||||
# Task ID for tracking async operation
|
workspace_id=workspace_id,
|
||||||
# Use GET /memory/write_result/{task_id} to check task status and get result
|
user=current_user
|
||||||
# """
|
)
|
||||||
# # 使用集中化的语言校验
|
if storage_type is None: storage_type = 'neo4j'
|
||||||
# language = get_language_from_header(language_type)
|
user_rag_memory_id = ''
|
||||||
#
|
if workspace_id:
|
||||||
# config_id = user_input.config_id
|
|
||||||
# workspace_id = current_user.current_workspace_id
|
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||||
# api_logger.info(
|
db=db,
|
||||||
# f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
name="USER_RAG_MERORY",
|
||||||
#
|
workspace_id=workspace_id
|
||||||
# # 获取 storage_type,如果为 None 则使用默认值
|
)
|
||||||
# storage_type = workspace_service.get_workspace_storage_type(
|
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||||
# db=db,
|
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||||
# workspace_id=workspace_id,
|
try:
|
||||||
# user=current_user
|
# 获取标准化的消息列表
|
||||||
# )
|
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||||
# if storage_type is None: storage_type = 'neo4j'
|
|
||||||
# user_rag_memory_id = ''
|
task = celery_app.send_task(
|
||||||
# if workspace_id:
|
"app.core.memory.agent.write_message",
|
||||||
#
|
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id]
|
||||||
# knowledge = knowledge_repository.get_knowledge_by_name(
|
)
|
||||||
# db=db,
|
api_logger.info(f"Write task queued: {task.id}")
|
||||||
# name="USER_RAG_MERORY",
|
|
||||||
# workspace_id=workspace_id
|
return success(data={"task_id": task.id}, msg="写入任务已提交")
|
||||||
# )
|
except Exception as e:
|
||||||
# if knowledge: user_rag_memory_id = str(knowledge.id)
|
api_logger.error(f"Async write operation failed: {str(e)}")
|
||||||
# api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||||
# try:
|
|
||||||
# # 获取标准化的消息列表
|
|
||||||
# messages_list = memory_agent_service.get_messages_list(user_input)
|
|
||||||
#
|
|
||||||
# task = celery_app.send_task(
|
|
||||||
# "app.core.memory.agent.write_message",
|
|
||||||
# args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
|
||||||
# )
|
|
||||||
# api_logger.info(f"Write task queued: {task.id}")
|
|
||||||
#
|
|
||||||
# return success(data={"task_id": task.id}, msg="写入任务已提交")
|
|
||||||
# except Exception as e:
|
|
||||||
# api_logger.error(f"Async write operation failed: {str(e)}")
|
|
||||||
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/read_service", response_model=ApiResponse)
|
@router.post("/read_service", response_model=ApiResponse)
|
||||||
@@ -300,93 +279,34 @@ async def read_server(
|
|||||||
if knowledge:
|
if knowledge:
|
||||||
user_rag_memory_id = str(knowledge.id)
|
user_rag_memory_id = str(knowledge.id)
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||||
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
|
||||||
try:
|
try:
|
||||||
# result = await memory_agent_service.read_memory(
|
result = await memory_agent_service.read_memory(
|
||||||
# user_input.end_user_id,
|
user_input.end_user_id,
|
||||||
# user_input.message,
|
|
||||||
# user_input.history,
|
|
||||||
# user_input.search_switch,
|
|
||||||
# config_id,
|
|
||||||
# db,
|
|
||||||
# storage_type,
|
|
||||||
# user_rag_memory_id
|
|
||||||
# )
|
|
||||||
# if str(user_input.search_switch) == "2":
|
|
||||||
# retrieve_info = result['answer']
|
|
||||||
# history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
|
|
||||||
# user_input.end_user_id)
|
|
||||||
# query = user_input.message
|
|
||||||
#
|
|
||||||
# # 调用 memory_agent_service 的方法生成最终答案
|
|
||||||
# result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
|
|
||||||
# end_user_id=user_input.end_user_id,
|
|
||||||
# retrieve_info=retrieve_info,
|
|
||||||
# history=history,
|
|
||||||
# query=query,
|
|
||||||
# config_id=config_id,
|
|
||||||
# db=db
|
|
||||||
# )
|
|
||||||
# if "信息不足,无法回答" in result['answer']:
|
|
||||||
# result['answer'] = retrieve_info
|
|
||||||
memory_config = get_config(user_input.end_user_id, db)
|
|
||||||
service = MemoryService(
|
|
||||||
db,
|
|
||||||
memory_config["memory_config_id"],
|
|
||||||
end_user_id=user_input.end_user_id
|
|
||||||
)
|
|
||||||
search_result = await service.read(
|
|
||||||
user_input.message,
|
user_input.message,
|
||||||
SearchStrategy(user_input.search_switch)
|
user_input.history,
|
||||||
|
user_input.search_switch,
|
||||||
|
config_id,
|
||||||
|
db,
|
||||||
|
storage_type,
|
||||||
|
user_rag_memory_id
|
||||||
)
|
)
|
||||||
intermediate_outputs = []
|
if str(user_input.search_switch) == "2":
|
||||||
sub_queries = set()
|
retrieve_info = result['answer']
|
||||||
for memory in search_result.memories:
|
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
|
||||||
sub_queries.add(str(memory.query))
|
query = user_input.message
|
||||||
if user_input.search_switch in [SearchStrategy.DEEP, SearchStrategy.NORMAL]:
|
|
||||||
intermediate_outputs.append({
|
|
||||||
"type": "problem_split",
|
|
||||||
"title": "问题拆分",
|
|
||||||
"data": [
|
|
||||||
{
|
|
||||||
"id": f"Q{idx+1}",
|
|
||||||
"question": question
|
|
||||||
}
|
|
||||||
for idx, question in enumerate(sub_queries)
|
|
||||||
]
|
|
||||||
})
|
|
||||||
perceptual_data = [
|
|
||||||
memory.data
|
|
||||||
for memory in search_result.memories
|
|
||||||
if memory.source == Neo4jNodeType.PERCEPTUAL
|
|
||||||
]
|
|
||||||
|
|
||||||
intermediate_outputs.append({
|
# 调用 memory_agent_service 的方法生成最终答案
|
||||||
"type": "perceptual_retrieve",
|
result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
|
||||||
"title": "感知记忆检索",
|
|
||||||
"data": perceptual_data,
|
|
||||||
"total": len(perceptual_data),
|
|
||||||
})
|
|
||||||
intermediate_outputs.append({
|
|
||||||
"type": "search_result",
|
|
||||||
"title": f"合并检索结果 (共{len(sub_queries)}个查询,{len(search_result.memories)}条结果)",
|
|
||||||
"result": search_result.content,
|
|
||||||
"raw_result": search_result.memories,
|
|
||||||
"total": len(search_result.memories),
|
|
||||||
})
|
|
||||||
result = {
|
|
||||||
'answer': await memory_agent_service.generate_summary_from_retrieve(
|
|
||||||
end_user_id=user_input.end_user_id,
|
end_user_id=user_input.end_user_id,
|
||||||
retrieve_info=search_result.content,
|
retrieve_info=retrieve_info,
|
||||||
history=[],
|
history=history,
|
||||||
query=user_input.message,
|
query=query,
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
db=db
|
db=db
|
||||||
),
|
)
|
||||||
"intermediate_outputs": intermediate_outputs
|
if "信息不足,无法回答" in result['answer']:
|
||||||
}
|
result['answer']=retrieve_info
|
||||||
|
|
||||||
return success(data=result, msg="回复对话消息成功")
|
return success(data=result, msg="回复对话消息成功")
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||||
@@ -404,8 +324,7 @@ async def file_update(
|
|||||||
files: List[UploadFile] = File(..., description="要上传的文件"),
|
files: List[UploadFile] = File(..., description="要上传的文件"),
|
||||||
model_id:str = Form(..., description="模型ID"),
|
model_id:str = Form(..., description="模型ID"),
|
||||||
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
|
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user)
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
文件上传接口 - 支持图片识别
|
文件上传接口 - 支持图片识别
|
||||||
@@ -418,6 +337,9 @@ async def file_update(
|
|||||||
Returns:
|
Returns:
|
||||||
文件处理结果
|
文件处理结果
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
db_gen = get_db() # get_db 通常是一个生成器
|
||||||
|
db = next(db_gen)
|
||||||
api_logger.info(f"File upload requested, file count: {len(files)}")
|
api_logger.info(f"File upload requested, file count: {len(files)}")
|
||||||
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||||
apiConfig: ModelApiKey = config.api_keys[0]
|
apiConfig: ModelApiKey = config.api_keys[0]
|
||||||
@@ -696,19 +618,24 @@ async def status_type(
|
|||||||
async def get_knowledge_type_stats_api(
|
async def get_knowledge_type_stats_api(
|
||||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||||
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
|
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user)
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
|
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder | memory。
|
||||||
会对缺失类型补 0,返回字典形式。
|
会对缺失类型补 0,返回字典形式。
|
||||||
可选按状态过滤。
|
可选按状态过滤。
|
||||||
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
||||||
- 如果用户没有当前工作空间,对应的统计返回 0
|
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤
|
||||||
|
- 如果用户没有当前工作空间或未提供 end_user_id,对应的统计返回 0
|
||||||
"""
|
"""
|
||||||
api_logger.info(
|
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||||
f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
|
||||||
try:
|
try:
|
||||||
|
from app.db import get_db
|
||||||
|
|
||||||
|
# 获取数据库会话
|
||||||
|
db_gen = get_db()
|
||||||
|
db = next(db_gen)
|
||||||
|
|
||||||
# 调用service层函数
|
# 调用service层函数
|
||||||
result = await memory_agent_service.get_knowledge_type_stats(
|
result = await memory_agent_service.get_knowledge_type_stats(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
@@ -723,56 +650,45 @@ async def get_knowledge_type_stats_api(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
|
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
|
||||||
async def get_interest_distribution_by_user_api(
|
async def get_hot_memory_tags_by_user_api(
|
||||||
end_user_id: str = Query(..., description="用户ID(必填)"),
|
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||||
limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"),
|
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
limit: int = Query(20, description="返回标签数量限制"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session=Depends(get_db),
|
db: Session=Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取指定用户的兴趣分布标签
|
获取指定用户的热门记忆标签
|
||||||
|
|
||||||
与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习、创作等),
|
|
||||||
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
|
|
||||||
|
|
||||||
返回格式:
|
返回格式:
|
||||||
[
|
[
|
||||||
{"name": "兴趣活动名", "frequency": 频次},
|
{"name": "标签名", "frequency": 频次},
|
||||||
...
|
...
|
||||||
]
|
]
|
||||||
"""
|
"""
|
||||||
language = get_language_from_header(language_type)
|
|
||||||
api_logger.info(f"Interest distribution by user requested: end_user_id={end_user_id}, language={language}")
|
workspace_id=current_user.current_workspace_id
|
||||||
|
workspace_repo = WorkspaceRepository(db)
|
||||||
|
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||||
|
|
||||||
|
if workspace_models:
|
||||||
|
model_id = workspace_models.get("llm", None)
|
||||||
|
else:
|
||||||
|
model_id = None
|
||||||
|
|
||||||
|
api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
|
||||||
try:
|
try:
|
||||||
# 优先读取缓存
|
result = await memory_agent_service.get_hot_memory_tags_by_user(
|
||||||
cached = await InterestMemoryCache.get_interest_distribution(
|
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
language=language,
|
language_type=language_type,
|
||||||
|
model_id=model_id,
|
||||||
|
limit=limit
|
||||||
)
|
)
|
||||||
if cached is not None:
|
return success(data=result, msg="获取热门记忆标签成功")
|
||||||
api_logger.info(f"Interest distribution cache hit: end_user_id={end_user_id}")
|
|
||||||
return success(data=cached, msg="获取兴趣分布标签成功")
|
|
||||||
|
|
||||||
# 缓存未命中,调用模型生成
|
|
||||||
result = await memory_agent_service.get_interest_distribution_by_user(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
limit=limit,
|
|
||||||
language=language
|
|
||||||
)
|
|
||||||
|
|
||||||
# 写入缓存,24小时过期
|
|
||||||
await InterestMemoryCache.set_interest_distribution(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
language=language,
|
|
||||||
data=result,
|
|
||||||
)
|
|
||||||
|
|
||||||
return success(data=result, msg="获取兴趣分布标签成功")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Interest distribution by user failed: {str(e)}")
|
api_logger.error(f"Hot memory tags by user failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "获取兴趣分布标签失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "获取热门记忆标签失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/user_profile", response_model=ApiResponse)
|
@router.get("/analytics/user_profile", response_model=ApiResponse)
|
||||||
@@ -861,8 +777,11 @@ async def get_end_user_connected_config(
|
|||||||
Returns:
|
Returns:
|
||||||
包含 memory_config_id 和相关信息的响应
|
包含 memory_config_id 和相关信息的响应
|
||||||
"""
|
"""
|
||||||
|
from app.services.memory_agent_service import (
|
||||||
|
get_end_user_connected_config as get_config,
|
||||||
|
)
|
||||||
|
|
||||||
api_logger.info(f"Getting connected config for end_user_id: {end_user_id}")
|
api_logger.info(f"Getting connected config for end_user: {end_user_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = get_config(end_user_id, db)
|
result = get_config(end_user_id, db)
|
||||||
|
|||||||
@@ -1,7 +1,4 @@
|
|||||||
import asyncio
|
|
||||||
import uuid
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
@@ -12,7 +9,6 @@ from app.schemas.response_schema import ApiResponse
|
|||||||
|
|
||||||
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
|
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
|
||||||
from app.services.memory_agent_service import get_end_users_connected_configs_batch
|
from app.services.memory_agent_service import get_end_users_connected_configs_batch
|
||||||
from app.services.app_statistics_service import AppStatisticsService
|
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
|
|
||||||
# 获取API专用日志器
|
# 获取API专用日志器
|
||||||
@@ -49,61 +45,61 @@ def get_workspace_total_end_users(
|
|||||||
|
|
||||||
@router.get("/end_users", response_model=ApiResponse)
|
@router.get("/end_users", response_model=ApiResponse)
|
||||||
async def get_workspace_end_users(
|
async def get_workspace_end_users(
|
||||||
workspace_id: Optional[uuid.UUID] = Query(None, description="工作空间ID(可选,默认当前用户工作空间)"),
|
|
||||||
keyword: Optional[str] = Query(None, description="搜索关键词(同时模糊匹配 other_name 和 id)"),
|
|
||||||
page: int = Query(1, ge=1, description="页码,从1开始"),
|
|
||||||
pagesize: int = Query(10, ge=1, description="每页数量"),
|
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取工作空间的宿主列表(分页查询,支持模糊搜索)
|
获取工作空间的宿主列表(高性能优化版本 v2)
|
||||||
|
|
||||||
返回工作空间下的宿主列表,支持分页查询和模糊搜索。
|
优化策略:
|
||||||
通过 keyword 参数同时模糊匹配 other_name 和 id 字段。
|
1. 批量查询 end_users(一次查询而非循环)
|
||||||
|
2. 并发查询所有用户的记忆数量(Neo4j)
|
||||||
|
3. RAG 模式使用批量查询(一次 SQL)
|
||||||
|
4. 只返回必要字段减少数据传输
|
||||||
|
5. 添加短期缓存减少重复查询
|
||||||
|
6. 并发执行配置查询和记忆数量查询
|
||||||
|
|
||||||
Args:
|
返回格式:
|
||||||
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
{
|
||||||
keyword: 搜索关键词(可选,同时模糊匹配 other_name 和 id)
|
"end_user": {"id": "uuid", "other_name": "名称"},
|
||||||
page: 页码(从1开始,默认1)
|
"memory_num": {"total": 数量},
|
||||||
pagesize: 每页数量(默认10)
|
"memory_config": {"memory_config_id": "id", "memory_config_name": "名称"}
|
||||||
db: 数据库会话
|
}
|
||||||
current_user: 当前用户
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ApiResponse: 包含宿主列表和分页信息
|
|
||||||
"""
|
"""
|
||||||
# 如果未提供 workspace_id,使用当前用户的工作空间
|
import asyncio
|
||||||
if workspace_id is None:
|
import json
|
||||||
|
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||||
|
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
|
# 尝试从缓存获取(30秒缓存)
|
||||||
|
cache_key = f"end_users:workspace:{workspace_id}"
|
||||||
|
try:
|
||||||
|
cached_data = await aio_redis_get(cache_key)
|
||||||
|
if cached_data:
|
||||||
|
api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}")
|
||||||
|
return success(data=json.loads(cached_data), msg="宿主列表获取成功")
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.warning(f"Redis 缓存读取失败: {str(e)}")
|
||||||
|
|
||||||
# 获取当前空间类型
|
# 获取当前空间类型
|
||||||
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
|
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
|
||||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表, 类型: {current_workspace_type}")
|
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
|
||||||
|
|
||||||
# 获取分页的 end_users
|
# 获取 end_users(已优化为批量查询)
|
||||||
end_users_result = memory_dashboard_service.get_workspace_end_users_paginated(
|
end_users = memory_dashboard_service.get_workspace_end_users(
|
||||||
db=db,
|
db=db,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
current_user=current_user,
|
current_user=current_user
|
||||||
page=page,
|
|
||||||
pagesize=pagesize,
|
|
||||||
keyword=keyword
|
|
||||||
)
|
)
|
||||||
|
|
||||||
end_users = end_users_result.get("items", [])
|
|
||||||
total = end_users_result.get("total", 0)
|
|
||||||
|
|
||||||
if not end_users:
|
if not end_users:
|
||||||
api_logger.info(f"工作空间下没有宿主或当前页无数据: total={total}, page={page}")
|
api_logger.info("工作空间下没有宿主")
|
||||||
return success(data={
|
# 缓存空结果,避免重复查询
|
||||||
"items": [],
|
try:
|
||||||
"page": {
|
await aio_redis_set(cache_key, json.dumps([]), expire=30)
|
||||||
"page": page,
|
except Exception as e:
|
||||||
"pagesize": pagesize,
|
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||||
"total": total,
|
return success(data=[], msg="宿主列表获取成功")
|
||||||
"hasnext": (page * pagesize) < total
|
|
||||||
}
|
|
||||||
}, msg="宿主列表获取成功")
|
|
||||||
|
|
||||||
end_user_ids = [str(user.id) for user in end_users]
|
end_user_ids = [str(user.id) for user in end_users]
|
||||||
|
|
||||||
@@ -134,43 +130,36 @@ async def get_workspace_end_users(
|
|||||||
return {uid: {"total": 0} for uid in end_user_ids}
|
return {uid: {"total": 0} for uid in end_user_ids}
|
||||||
|
|
||||||
elif current_workspace_type == "neo4j":
|
elif current_workspace_type == "neo4j":
|
||||||
# Neo4j 模式:批量查询(简化版本,只返回total)
|
# Neo4j 模式:并发查询(带并发限制)
|
||||||
|
# 使用信号量限制并发数,避免大量用户时压垮 Neo4j
|
||||||
|
MAX_CONCURRENT_QUERIES = 10
|
||||||
|
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES)
|
||||||
|
|
||||||
|
async def get_neo4j_memory_num(end_user_id: str):
|
||||||
|
async with semaphore:
|
||||||
try:
|
try:
|
||||||
batch_result = await memory_storage_service.search_all_batch(end_user_ids)
|
return await memory_storage_service.search_all(end_user_id)
|
||||||
return {uid: {"total": count} for uid, count in batch_result.items()}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"批量获取 Neo4j 记忆数量失败: {str(e)}")
|
api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}")
|
||||||
return {uid: {"total": 0} for uid in end_user_ids}
|
return {"total": 0}
|
||||||
|
|
||||||
|
memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
|
||||||
|
return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
|
||||||
|
|
||||||
return {uid: {"total": 0} for uid in end_user_ids}
|
return {uid: {"total": 0} for uid in end_user_ids}
|
||||||
|
|
||||||
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
|
|
||||||
try:
|
|
||||||
from app.celery_app import celery_app as _celery_app
|
|
||||||
_celery_app.send_task(
|
|
||||||
"app.tasks.init_implicit_emotions_for_users",
|
|
||||||
kwargs={"end_user_ids": end_user_ids},
|
|
||||||
)
|
|
||||||
_celery_app.send_task(
|
|
||||||
"app.tasks.init_interest_distribution_for_users",
|
|
||||||
kwargs={"end_user_ids": end_user_ids},
|
|
||||||
)
|
|
||||||
api_logger.info(f"已触发按需初始化任务,候选用户数: {len(end_user_ids)}")
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}")
|
|
||||||
|
|
||||||
# 并发执行配置查询和记忆数量查询
|
# 并发执行配置查询和记忆数量查询
|
||||||
memory_configs_map, memory_nums_map = await asyncio.gather(
|
memory_configs_map, memory_nums_map = await asyncio.gather(
|
||||||
get_memory_configs(),
|
get_memory_configs(),
|
||||||
get_memory_nums()
|
get_memory_nums()
|
||||||
)
|
)
|
||||||
|
|
||||||
# 构建结果列表
|
# 构建结果(优化:使用列表推导式)
|
||||||
items = []
|
result = []
|
||||||
for end_user in end_users:
|
for end_user in end_users:
|
||||||
user_id = str(end_user.id)
|
user_id = str(end_user.id)
|
||||||
config_info = memory_configs_map.get(user_id, {})
|
config_info = memory_configs_map.get(user_id, {})
|
||||||
items.append({
|
result.append({
|
||||||
'end_user': {
|
'end_user': {
|
||||||
'id': user_id,
|
'id': user_id,
|
||||||
'other_name': end_user.other_name
|
'other_name': end_user.other_name
|
||||||
@@ -182,26 +171,13 @@ async def get_workspace_end_users(
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
# 触发社区聚类补全任务(异步,不阻塞接口响应)
|
# 写入缓存(30秒过期)
|
||||||
try:
|
try:
|
||||||
from app.tasks import init_community_clustering_for_users
|
await aio_redis_set(cache_key, json.dumps(result), expire=30)
|
||||||
init_community_clustering_for_users.delay(end_user_ids=end_user_ids, workspace_id=str(workspace_id))
|
|
||||||
api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
|
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||||
|
|
||||||
# 构建分页响应
|
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||||
result = {
|
|
||||||
"items": items,
|
|
||||||
"page": {
|
|
||||||
"page": page,
|
|
||||||
"pagesize": pagesize,
|
|
||||||
"total": total,
|
|
||||||
"hasnext": (page * pagesize) < total
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录,总计 {total} 条")
|
|
||||||
return success(data=result, msg="宿主列表获取成功")
|
return success(data=result, msg="宿主列表获取成功")
|
||||||
|
|
||||||
|
|
||||||
@@ -410,15 +386,14 @@ def get_current_user_rag_total_num(
|
|||||||
@router.get("/rag_content", response_model=ApiResponse)
|
@router.get("/rag_content", response_model=ApiResponse)
|
||||||
def get_rag_content(
|
def get_rag_content(
|
||||||
end_user_id: str = Query(..., description="宿主ID"),
|
end_user_id: str = Query(..., description="宿主ID"),
|
||||||
page: int = Query(1, gt=0, description="页码,从1开始"),
|
limit: int = Query(15, description="返回记录数"),
|
||||||
pagesize: int = Query(15, gt=0, le=100, description="每页返回记录数"),
|
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取当前宿主知识库中的chunk内容(分页)
|
获取当前宿主知识库中的chunk内容
|
||||||
"""
|
"""
|
||||||
data = memory_dashboard_service.get_rag_content(end_user_id, page, pagesize, db, current_user)
|
data = memory_dashboard_service.get_rag_content(end_user_id, limit, db, current_user)
|
||||||
return success(data=data, msg="宿主RAGchunk数据获取成功")
|
return success(data=data, msg="宿主RAGchunk数据获取成功")
|
||||||
|
|
||||||
|
|
||||||
@@ -431,17 +406,25 @@ async def get_chunk_summary_tag(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
读取RAG摘要、标签和人物形象(纯读库,不触发生成)。
|
获取chunk总结、提取的标签和人物形象
|
||||||
|
|
||||||
返回格式:
|
返回格式:
|
||||||
{
|
{
|
||||||
"summary": "用户摘要",
|
"summary": "chunk内容的总结",
|
||||||
"tags": [{"tag": "标签1", "frequency": 5}, ...],
|
"tags": [
|
||||||
"personas": ["产品设计师", ...],
|
{"tag": "标签1", "frequency": 5},
|
||||||
"generated": true/false // false表示尚未生产,请调用 /generate_rag_profile
|
{"tag": "标签2", "frequency": 3},
|
||||||
|
...
|
||||||
|
],
|
||||||
|
"personas": [
|
||||||
|
"产品设计师",
|
||||||
|
"旅行爱好者",
|
||||||
|
"摄影发烧友",
|
||||||
|
...
|
||||||
|
]
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"用户 {current_user.username} 读取宿主 {end_user_id} 的RAG摘要/标签/人物形象")
|
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk摘要、标签和人物形象")
|
||||||
|
|
||||||
data = await memory_dashboard_service.get_chunk_summary_and_tags(
|
data = await memory_dashboard_service.get_chunk_summary_and_tags(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
@@ -451,7 +434,8 @@ async def get_chunk_summary_tag(
|
|||||||
current_user=current_user
|
current_user=current_user
|
||||||
)
|
)
|
||||||
|
|
||||||
return success(data=data, msg="获取成功")
|
api_logger.info(f"成功获取chunk摘要、{len(data.get('tags', []))} 个标签和 {len(data.get('personas', []))} 个人物形象")
|
||||||
|
return success(data=data, msg="chunk摘要、标签和人物形象获取成功")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/chunk_insight", response_model=ApiResponse)
|
@router.get("/chunk_insight", response_model=ApiResponse)
|
||||||
@@ -462,18 +446,14 @@ async def get_chunk_insight(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
读取RAG洞察报告(纯读库,不触发生成)。
|
获取chunk的洞察内容
|
||||||
|
|
||||||
返回格式:
|
返回格式:
|
||||||
{
|
{
|
||||||
"insight": "总体概述",
|
"insight": "对chunk内容的深度洞察分析"
|
||||||
"behavior_pattern": "行为模式",
|
|
||||||
"key_findings": "关键发现",
|
|
||||||
"growth_trajectory": "成长轨迹",
|
|
||||||
"generated": true/false // false表示尚未生产,请调用 /generate_rag_profile
|
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"用户 {current_user.username} 读取宿主 {end_user_id} 的RAG洞察")
|
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk洞察")
|
||||||
|
|
||||||
data = await memory_dashboard_service.get_chunk_insight(
|
data = await memory_dashboard_service.get_chunk_insight(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
@@ -482,44 +462,13 @@ async def get_chunk_insight(
|
|||||||
current_user=current_user
|
current_user=current_user
|
||||||
)
|
)
|
||||||
|
|
||||||
return success(data=data, msg="获取成功")
|
api_logger.info("成功获取chunk洞察")
|
||||||
|
return success(data=data, msg="chunk洞察获取成功")
|
||||||
|
|
||||||
class GenerateRagProfileRequest(BaseModel):
|
|
||||||
end_user_id: str = Field(..., description="宿主ID")
|
|
||||||
limit: int = Field(15, description="参与生成的chunk数量上限")
|
|
||||||
max_tags: int = Field(10, description="最大标签数量")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/generate_rag_profile", response_model=ApiResponse)
|
|
||||||
async def generate_rag_profile(
|
|
||||||
body: GenerateRagProfileRequest,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
生产接口:为RAG存储模式的宿主全量重新生成完整画像并持久化到end_user表。
|
|
||||||
每次请求都会重新生成,覆盖已有数据。
|
|
||||||
"""
|
|
||||||
api_logger.info(f"用户 {current_user.username} 触发RAG画像生产: end_user_id={body.end_user_id}")
|
|
||||||
|
|
||||||
data = await memory_dashboard_service.generate_rag_profile(
|
|
||||||
end_user_id=body.end_user_id,
|
|
||||||
limit=body.limit,
|
|
||||||
max_tags=body.max_tags,
|
|
||||||
db=db,
|
|
||||||
current_user=current_user,
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(f"RAG画像生产完成: {data}")
|
|
||||||
return success(data=data, msg="RAG画像生产完成")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/dashboard_data", response_model=ApiResponse)
|
@router.get("/dashboard_data", response_model=ApiResponse)
|
||||||
async def dashboard_data(
|
async def dashboard_data(
|
||||||
end_user_id: Optional[str] = Query(None, description="可选的用户ID"),
|
end_user_id: Optional[str] = Query(None, description="可选的用户ID"),
|
||||||
start_date: Optional[int] = Query(None, description="开始时间戳(毫秒)"),
|
|
||||||
end_date: Optional[int] = Query(None, description="结束时间戳(毫秒)"),
|
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
@@ -554,15 +503,6 @@ async def dashboard_data(
|
|||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的dashboard整合数据")
|
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的dashboard整合数据")
|
||||||
|
|
||||||
# 如果没有提供时间范围,默认使用最近30天
|
|
||||||
if start_date is None or end_date is None:
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
end_dt = datetime.now()
|
|
||||||
start_dt = end_dt - timedelta(days=30)
|
|
||||||
end_date = int(end_dt.timestamp() * 1000)
|
|
||||||
start_date = int(start_dt.timestamp() * 1000)
|
|
||||||
api_logger.info(f"使用默认时间范围: {start_dt} 到 {end_dt}")
|
|
||||||
|
|
||||||
# 获取 storage_type,如果为 None 则使用默认值
|
# 获取 storage_type,如果为 None 则使用默认值
|
||||||
storage_type = workspace_service.get_workspace_storage_type(
|
storage_type = workspace_service.get_workspace_storage_type(
|
||||||
db=db,
|
db=db,
|
||||||
@@ -591,7 +531,7 @@ async def dashboard_data(
|
|||||||
"total_api_call": None
|
"total_api_call": None
|
||||||
}
|
}
|
||||||
|
|
||||||
# 1. 获取记忆总量(total_memory)—— neo4j 独有逻辑:查询 neo4j 存储节点
|
# 1. 获取记忆总量(total_memory)
|
||||||
try:
|
try:
|
||||||
total_memory_data = await memory_dashboard_service.get_workspace_total_memory_count(
|
total_memory_data = await memory_dashboard_service.get_workspace_total_memory_count(
|
||||||
db=db,
|
db=db,
|
||||||
@@ -600,32 +540,40 @@ async def dashboard_data(
|
|||||||
end_user_id=end_user_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0)
|
neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0)
|
||||||
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}")
|
# total_app: 统计当前空间下的所有app数量
|
||||||
|
from app.repositories import app_repository
|
||||||
|
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
||||||
|
neo4j_data["total_app"] = len(apps_orm)
|
||||||
|
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}, 应用数量: {neo4j_data['total_app']}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"获取记忆总量失败: {str(e)}")
|
api_logger.warning(f"获取记忆总量失败: {str(e)}")
|
||||||
|
|
||||||
# 2. 获取共享统计数据(total_app、total_knowledge、total_api_call)
|
# 2. 获取知识库类型统计(total_knowledge)
|
||||||
common_stats = memory_dashboard_service.get_dashboard_common_stats(db, workspace_id)
|
|
||||||
neo4j_data.update(common_stats)
|
|
||||||
api_logger.info(f"成功获取共享统计: app={common_stats['total_app']}, knowledge={common_stats['total_knowledge']}, api_call={common_stats['total_api_call']}")
|
|
||||||
|
|
||||||
# 计算昨日对比
|
|
||||||
try:
|
try:
|
||||||
changes = memory_dashboard_service.get_dashboard_yesterday_changes(
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
|
memory_agent_service = MemoryAgentService()
|
||||||
|
knowledge_stats = await memory_agent_service.get_knowledge_type_stats(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
only_active=True,
|
||||||
|
current_workspace_id=workspace_id,
|
||||||
|
db=db
|
||||||
|
)
|
||||||
|
neo4j_data["total_knowledge"] = knowledge_stats.get("total", 0)
|
||||||
|
api_logger.info(f"成功获取知识库类型统计total: {neo4j_data['total_knowledge']}")
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.warning(f"获取知识库类型统计失败: {str(e)}")
|
||||||
|
|
||||||
|
# 3. 获取API调用增量(total_api_call,转换为整数)
|
||||||
|
try:
|
||||||
|
api_increment = memory_dashboard_service.get_workspace_api_increment(
|
||||||
db=db,
|
db=db,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
storage_type=storage_type,
|
current_user=current_user
|
||||||
today_data=neo4j_data
|
|
||||||
)
|
)
|
||||||
neo4j_data.update(changes)
|
neo4j_data["total_api_call"] = api_increment
|
||||||
|
api_logger.info(f"成功获取API调用增量: {neo4j_data['total_api_call']}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"计算neo4j昨日对比失败: {str(e)}")
|
api_logger.warning(f"获取API调用增量失败: {str(e)}")
|
||||||
neo4j_data.update({
|
|
||||||
"total_memory_change": None,
|
|
||||||
"total_app_change": None,
|
|
||||||
"total_knowledge_change": None,
|
|
||||||
"total_api_call_change": None,
|
|
||||||
})
|
|
||||||
|
|
||||||
result["neo4j_data"] = neo4j_data
|
result["neo4j_data"] = neo4j_data
|
||||||
api_logger.info("成功获取neo4j_data")
|
api_logger.info("成功获取neo4j_data")
|
||||||
@@ -639,36 +587,27 @@ async def dashboard_data(
|
|||||||
"total_api_call": None
|
"total_api_call": None
|
||||||
}
|
}
|
||||||
|
|
||||||
# 1. 获取记忆总量(total_memory)—— rag 独有逻辑:查询 document 表的 chunk_num
|
# 获取RAG相关数据
|
||||||
try:
|
try:
|
||||||
total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
|
# total_memory: 使用 total_chunk(总chunk数)
|
||||||
|
total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
|
||||||
rag_data["total_memory"] = total_chunk
|
rag_data["total_memory"] = total_chunk
|
||||||
api_logger.info(f"成功获取RAG记忆总量: {total_chunk}")
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.warning(f"获取RAG记忆总量失败: {str(e)}")
|
|
||||||
|
|
||||||
# 2. 获取共享统计数据(total_app、total_knowledge、total_api_call)
|
# total_app: 统计当前空间下的所有app数量
|
||||||
common_stats = memory_dashboard_service.get_dashboard_common_stats(db, workspace_id)
|
from app.repositories import app_repository
|
||||||
rag_data.update(common_stats)
|
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
||||||
api_logger.info(f"成功获取共享统计: app={common_stats['total_app']}, knowledge={common_stats['total_knowledge']}, api_call={common_stats['total_api_call']}")
|
rag_data["total_app"] = len(apps_orm)
|
||||||
|
|
||||||
# 计算昨日对比
|
# total_knowledge: 使用 total_kb(总知识库数)
|
||||||
try:
|
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
|
||||||
changes = memory_dashboard_service.get_dashboard_yesterday_changes(
|
rag_data["total_knowledge"] = total_kb
|
||||||
db=db,
|
|
||||||
workspace_id=workspace_id,
|
# total_api_call: 固定值
|
||||||
storage_type=storage_type,
|
rag_data["total_api_call"] = 1024
|
||||||
today_data=rag_data
|
|
||||||
)
|
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}")
|
||||||
rag_data.update(changes)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"计算RAG昨日对比失败: {str(e)}")
|
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")
|
||||||
rag_data.update({
|
|
||||||
"total_memory_change": None,
|
|
||||||
"total_app_change": None,
|
|
||||||
"total_knowledge_change": None,
|
|
||||||
"total_api_call_change": None,
|
|
||||||
})
|
|
||||||
|
|
||||||
result["rag_data"] = rag_data
|
result["rag_data"] = rag_data
|
||||||
api_logger.info("成功获取rag_data")
|
api_logger.info("成功获取rag_data")
|
||||||
|
|||||||
@@ -3,10 +3,9 @@
|
|||||||
包含情景记忆总览和详情查询接口
|
包含情景记忆总览和详情查询接口
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.language_utils import get_language_from_header
|
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.response_utils import fail, success
|
from app.core.response_utils import fail, success
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
@@ -15,7 +14,6 @@ from app.schemas.response_schema import ApiResponse
|
|||||||
from app.schemas.memory_episodic_schema import (
|
from app.schemas.memory_episodic_schema import (
|
||||||
EpisodicMemoryOverviewRequest,
|
EpisodicMemoryOverviewRequest,
|
||||||
EpisodicMemoryDetailsRequest,
|
EpisodicMemoryDetailsRequest,
|
||||||
translate_episodic_type,
|
|
||||||
)
|
)
|
||||||
from app.services.memory_episodic_service import memory_episodic_service
|
from app.services.memory_episodic_service import memory_episodic_service
|
||||||
|
|
||||||
@@ -86,7 +84,6 @@ async def get_episodic_memory_overview_api(
|
|||||||
@router.post("/details", response_model=ApiResponse)
|
@router.post("/details", response_model=ApiResponse)
|
||||||
async def get_episodic_memory_details_api(
|
async def get_episodic_memory_details_api(
|
||||||
request: EpisodicMemoryDetailsRequest,
|
request: EpisodicMemoryDetailsRequest,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
@@ -114,11 +111,6 @@ async def get_episodic_memory_details_api(
|
|||||||
summary_id=request.summary_id
|
summary_id=request.summary_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# 根据语言参数翻译 episodic_type
|
|
||||||
language = get_language_from_header(language_type)
|
|
||||||
if "episodic_type" in result:
|
|
||||||
result["episodic_type"] = translate_episodic_type(result["episodic_type"], language)
|
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"成功获取情景记忆详情: end_user_id={request.end_user_id}, summary_id={request.summary_id}"
|
f"成功获取情景记忆详情: end_user_id={request.end_user_id}, summary_id={request.summary_id}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,9 +4,7 @@
|
|||||||
处理显性记忆相关的API接口,包括情景记忆和语义记忆的查询。
|
处理显性记忆相关的API接口,包括情景记忆和语义记忆的查询。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query
|
|
||||||
|
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.response_utils import success, fail
|
from app.core.response_utils import success, fail
|
||||||
@@ -71,140 +69,6 @@ async def get_explicit_memory_overview_api(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "显性记忆总览查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "显性记忆总览查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/episodics", response_model=ApiResponse)
|
|
||||||
async def get_episodic_memory_list_api(
|
|
||||||
end_user_id: str = Query(..., description="end user ID"),
|
|
||||||
page: int = Query(1, gt=0, description="page number, starting from 1"),
|
|
||||||
pagesize: int = Query(10, gt=0, le=100, description="number of items per page, max 100"),
|
|
||||||
start_date: Optional[int] = Query(None, description="start timestamp (ms)"),
|
|
||||||
end_date: Optional[int] = Query(None, description="end timestamp (ms)"),
|
|
||||||
episodic_type: str = Query("all", description="episodic type :all/conversation/project_work/learning/decision/important_event"),
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
获取情景记忆分页列表
|
|
||||||
|
|
||||||
返回指定用户的情景记忆列表,支持分页、时间范围筛选和情景类型筛选。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: 终端用户ID(必填)
|
|
||||||
page: 页码(从1开始,默认1)
|
|
||||||
pagesize: 每页数量(默认10,最大100)
|
|
||||||
start_date: 开始时间戳(可选,毫秒),自动扩展到当天 00:00:00
|
|
||||||
end_date: 结束时间戳(可选,毫秒),自动扩展到当天 23:59:59
|
|
||||||
episodic_type: 情景类型筛选(可选,默认all)
|
|
||||||
current_user: 当前用户
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ApiResponse: 包含情景记忆分页列表
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
- 基础分页查询:GET /episodics?end_user_id=xxx&page=1&pagesize=5
|
|
||||||
返回第1页,每页5条数据
|
|
||||||
- 按时间范围筛选:GET /episodics?end_user_id=xxx&page=1&pagesize=5&start_date=1738684800000&end_date=1738771199000
|
|
||||||
返回指定时间范围内的数据
|
|
||||||
- 按情景类型筛选:GET /episodics?end_user_id=xxx&page=1&pagesize=5&episodic_type=important_event
|
|
||||||
返回类型为"重要事件"的数据
|
|
||||||
|
|
||||||
Notes:
|
|
||||||
- start_date 和 end_date 必须同时提供或同时不提供
|
|
||||||
- start_date 不能大于 end_date
|
|
||||||
- episodic_type 可选值:all, conversation, project_work, learning, decision, important_event
|
|
||||||
- total 为该用户情景记忆总数(不受筛选条件影响)
|
|
||||||
- page.total 为筛选后的总条数
|
|
||||||
"""
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
|
|
||||||
# 检查用户是否已选择工作空间
|
|
||||||
if workspace_id is None:
|
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试查询情景记忆列表但未选择工作空间")
|
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"情景记忆分页查询: end_user_id={end_user_id}, "
|
|
||||||
f"start_date={start_date}, end_date={end_date}, episodic_type={episodic_type}, "
|
|
||||||
f"page={page}, pagesize={pagesize}, username={current_user.username}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1. 参数校验
|
|
||||||
if page < 1 or pagesize < 1:
|
|
||||||
api_logger.warning(f"分页参数错误: page={page}, pagesize={pagesize}")
|
|
||||||
return fail(BizCode.INVALID_PARAMETER, "分页参数必须大于0")
|
|
||||||
|
|
||||||
valid_episodic_types = ["all", "conversation", "project_work", "learning", "decision", "important_event"]
|
|
||||||
if episodic_type not in valid_episodic_types:
|
|
||||||
api_logger.warning(f"无效的情景类型参数: {episodic_type}")
|
|
||||||
return fail(BizCode.INVALID_PARAMETER, f"无效的情景类型参数,可选值:{', '.join(valid_episodic_types)}")
|
|
||||||
|
|
||||||
# 时间戳参数校验
|
|
||||||
if (start_date is not None and end_date is None) or (end_date is not None and start_date is None):
|
|
||||||
return fail(BizCode.INVALID_PARAMETER, "start_date和end_date必须同时提供")
|
|
||||||
|
|
||||||
if start_date is not None and end_date is not None and start_date > end_date:
|
|
||||||
return fail(BizCode.INVALID_PARAMETER, "start_date不能大于end_date")
|
|
||||||
|
|
||||||
# 2. 执行查询
|
|
||||||
try:
|
|
||||||
result = await memory_explicit_service.get_episodic_memory_list(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
page=page,
|
|
||||||
pagesize=pagesize,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
episodic_type=episodic_type,
|
|
||||||
)
|
|
||||||
api_logger.info(
|
|
||||||
f"情景记忆分页查询成功: end_user_id={end_user_id}, "
|
|
||||||
f"total={result['total']}, 返回={len(result['items'])}条"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"情景记忆分页查询失败: end_user_id={end_user_id}, error={str(e)}")
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "情景记忆分页查询失败", str(e))
|
|
||||||
|
|
||||||
# 3. 返回结构化响应
|
|
||||||
return success(data=result, msg="查询成功")
|
|
||||||
|
|
||||||
@router.get("/semantics", response_model=ApiResponse)
|
|
||||||
async def get_semantic_memory_list_api(
|
|
||||||
end_user_id: str = Query(..., description="终端用户ID"),
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
获取语义记忆列表
|
|
||||||
|
|
||||||
返回指定用户的全量语义记忆列表。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: 终端用户ID(必填)
|
|
||||||
current_user: 当前用户
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ApiResponse: 包含语义记忆全量列表
|
|
||||||
"""
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
|
|
||||||
if workspace_id is None:
|
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试查询语义记忆列表但未选择工作空间")
|
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"语义记忆列表查询: end_user_id={end_user_id}, username={current_user.username}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = await memory_explicit_service.get_semantic_memory_list(
|
|
||||||
end_user_id=end_user_id
|
|
||||||
)
|
|
||||||
api_logger.info(
|
|
||||||
f"语义记忆列表查询成功: end_user_id={end_user_id}, total={len(result)}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"语义记忆列表查询失败: end_user_id={end_user_id}, error={str(e)}")
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "语义记忆列表查询失败", str(e))
|
|
||||||
|
|
||||||
return success(data=result, msg="查询成功")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/details", response_model=ApiResponse)
|
@router.post("/details", response_model=ApiResponse)
|
||||||
async def get_explicit_memory_details_api(
|
async def get_explicit_memory_details_api(
|
||||||
request: ExplicitMemoryDetailsRequest,
|
request: ExplicitMemoryDetailsRequest,
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ from app.schemas.memory_storage_schema import (
|
|||||||
ForgettingCurveRequest,
|
ForgettingCurveRequest,
|
||||||
ForgettingCurveResponse,
|
ForgettingCurveResponse,
|
||||||
ForgettingCurvePoint,
|
ForgettingCurvePoint,
|
||||||
PendingNodesResponse,
|
|
||||||
)
|
)
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services.memory_forget_service import MemoryForgetService
|
from app.services.memory_forget_service import MemoryForgetService
|
||||||
@@ -309,100 +308,6 @@ async def get_forgetting_stats(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/pending-nodes", response_model=ApiResponse)
|
|
||||||
async def get_pending_nodes(
|
|
||||||
end_user_id: str,
|
|
||||||
page: int = 1,
|
|
||||||
pagesize: int = 10,
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
db: Session = Depends(get_db)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
获取待遗忘节点列表(独立分页接口)
|
|
||||||
|
|
||||||
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
|
|
||||||
此接口独立分页,与 /stats 接口分离。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: 组ID(即 end_user_id,必填)
|
|
||||||
page: 页码(从1开始,默认1)
|
|
||||||
pagesize: 每页数量(默认10)
|
|
||||||
current_user: 当前用户
|
|
||||||
db: 数据库会话
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ApiResponse: 包含待遗忘节点列表和分页信息的响应
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
- 第1页,每页10条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=1&pagesize=10
|
|
||||||
- 第2页,每页20条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=2&pagesize=20
|
|
||||||
|
|
||||||
Notes:
|
|
||||||
- page 从1开始,pagesize 必须大于0
|
|
||||||
- 返回格式:{"items": [...], "page": {"page": 1, "pagesize": 10, "total": 100, "hasnext": true}}
|
|
||||||
"""
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
# 检查用户是否已选择工作空间
|
|
||||||
if workspace_id is None:
|
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未选择工作空间")
|
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
|
||||||
|
|
||||||
# 验证 end_user_id 必填
|
|
||||||
if not end_user_id:
|
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未提供 end_user_id")
|
|
||||||
return fail(BizCode.INVALID_PARAMETER, "end_user_id 不能为空", "end_user_id is required")
|
|
||||||
|
|
||||||
# 通过 end_user_id 获取关联的 config_id
|
|
||||||
try:
|
|
||||||
from app.services.memory_agent_service import get_end_user_connected_config
|
|
||||||
|
|
||||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
|
||||||
config_id = connected_config.get("memory_config_id")
|
|
||||||
config_id = resolve_config_id(config_id, db)
|
|
||||||
|
|
||||||
if config_id is None:
|
|
||||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
|
||||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
|
||||||
|
|
||||||
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
|
|
||||||
except ValueError as e:
|
|
||||||
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
|
||||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
|
|
||||||
|
|
||||||
# 验证分页参数
|
|
||||||
if page < 1:
|
|
||||||
return fail(BizCode.INVALID_PARAMETER, "page 必须大于等于1", "page < 1")
|
|
||||||
if pagesize < 1:
|
|
||||||
return fail(BizCode.INVALID_PARAMETER, "pagesize 必须大于等于1", "pagesize < 1")
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取待遗忘节点: "
|
|
||||||
f"end_user_id={end_user_id}, page={page}, pagesize={pagesize}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 调用服务层获取待遗忘节点列表
|
|
||||||
result = await forget_service.get_pending_nodes(
|
|
||||||
db=db,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
config_id=config_id,
|
|
||||||
page=page,
|
|
||||||
pagesize=pagesize
|
|
||||||
)
|
|
||||||
|
|
||||||
# 构建响应
|
|
||||||
response_data = PendingNodesResponse(**result)
|
|
||||||
|
|
||||||
return success(data=response_data.model_dump(), msg="查询成功")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"获取待遗忘节点列表失败: {str(e)}")
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "获取待遗忘节点列表失败", str(e))
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/forgetting_curve", response_model=ApiResponse)
|
@router.post("/forgetting_curve", response_model=ApiResponse)
|
||||||
async def get_forgetting_curve(
|
async def get_forgetting_curve(
|
||||||
request: ForgettingCurveRequest,
|
request: ForgettingCurveRequest,
|
||||||
|
|||||||
@@ -1,25 +1,8 @@
|
|||||||
"""
|
|
||||||
Memory Reflection Controller
|
|
||||||
|
|
||||||
This module provides REST API endpoints for managing memory reflection configurations
|
|
||||||
and operations. It handles reflection engine setup, configuration management, and
|
|
||||||
execution of self-reflection processes across memory systems.
|
|
||||||
|
|
||||||
Key Features:
|
|
||||||
- Reflection configuration management (save, retrieve, update)
|
|
||||||
- Workspace-wide reflection execution across multiple applications
|
|
||||||
- Individual configuration-based reflection runs
|
|
||||||
- Multi-language support for reflection outputs
|
|
||||||
- Integration with Neo4j memory storage and LLM models
|
|
||||||
- Comprehensive error handling and logging
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from app.core.language_utils import get_language_from_header
|
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
|
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
|
||||||
ReflectionConfig,
|
ReflectionConfig,
|
||||||
@@ -44,13 +27,9 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from app.utils.config_utils import resolve_config_id
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
# Load environment variables for configuration
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# Initialize API logger for request tracking and debugging
|
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
# Configure router with prefix and tags for API organization
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/memory",
|
prefix="/memory",
|
||||||
tags=["Memory"],
|
tags=["Memory"],
|
||||||
@@ -63,38 +42,7 @@ async def save_reflection_config(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""Save reflection configuration to data_comfig table"""
|
||||||
Save reflection configuration to memory config table
|
|
||||||
|
|
||||||
Persists reflection engine configuration settings to the data_config table,
|
|
||||||
including reflection parameters, model settings, and evaluation criteria.
|
|
||||||
Validates configuration parameters and ensures data consistency.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: Memory reflection configuration data including:
|
|
||||||
- config_id: Configuration identifier to update
|
|
||||||
- reflection_enabled: Whether reflection is enabled
|
|
||||||
- reflection_period_in_hours: Reflection execution interval
|
|
||||||
- reflexion_range: Scope of reflection (partial/all)
|
|
||||||
- baseline: Reflection strategy (time/fact/hybrid)
|
|
||||||
- reflection_model_id: LLM model for reflection operations
|
|
||||||
- memory_verify: Enable memory verification checks
|
|
||||||
- quality_assessment: Enable quality assessment evaluation
|
|
||||||
current_user: Authenticated user saving the configuration
|
|
||||||
db: Database session for data operations
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Success response with saved reflection configuration data
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException 400: If config_id is missing or parameters are invalid
|
|
||||||
HTTPException 500: If configuration save operation fails
|
|
||||||
|
|
||||||
Database Operations:
|
|
||||||
- Updates memory_config table with reflection settings
|
|
||||||
- Commits transaction and refreshes entity
|
|
||||||
- Maintains configuration consistency
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
config_id = request.config_id
|
config_id = request.config_id
|
||||||
config_id = resolve_config_id(config_id, db)
|
config_id = resolve_config_id(config_id, db)
|
||||||
@@ -105,7 +53,6 @@ async def save_reflection_config(
|
|||||||
)
|
)
|
||||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||||
|
|
||||||
# Update reflection configuration in database
|
|
||||||
memory_config = MemoryConfigRepository.update_reflection_config(
|
memory_config = MemoryConfigRepository.update_reflection_config(
|
||||||
db,
|
db,
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
@@ -118,7 +65,6 @@ async def save_reflection_config(
|
|||||||
quality_assessment=request.quality_assessment
|
quality_assessment=request.quality_assessment
|
||||||
)
|
)
|
||||||
|
|
||||||
# Commit transaction and refresh entity
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(memory_config)
|
db.refresh(memory_config)
|
||||||
|
|
||||||
@@ -155,65 +101,18 @@ async def start_workspace_reflection(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""启动工作空间中所有匹配应用的反思功能"""
|
||||||
Start reflection functionality for all matching applications in workspace
|
|
||||||
|
|
||||||
Initiates reflection processes across all applications within the user's current
|
|
||||||
workspace that have valid memory configurations. Processes each application's
|
|
||||||
configurations and associated end users, executing reflection operations
|
|
||||||
with proper error isolation and transaction management.
|
|
||||||
|
|
||||||
This endpoint serves as a workspace-wide reflection orchestrator, ensuring
|
|
||||||
that reflection failures for individual users don't affect other operations.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
current_user: Authenticated user initiating workspace reflection
|
|
||||||
db: Database session for configuration queries
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Success response with reflection results for all processed applications:
|
|
||||||
- app_id: Application identifier
|
|
||||||
- config_id: Memory configuration identifier
|
|
||||||
- end_user_id: End user identifier
|
|
||||||
- reflection_result: Individual reflection operation result
|
|
||||||
|
|
||||||
Processing Logic:
|
|
||||||
1. Retrieve all applications in the current workspace
|
|
||||||
2. Filter applications with valid memory configurations
|
|
||||||
3. For each configuration, find matching releases
|
|
||||||
4. Execute reflection for each end user with isolated transactions
|
|
||||||
5. Aggregate results with error handling per user
|
|
||||||
|
|
||||||
Error Handling:
|
|
||||||
- Individual user reflection failures are isolated
|
|
||||||
- Failed operations are logged and included in results
|
|
||||||
- Database transactions are isolated per user to prevent cascading failures
|
|
||||||
- Comprehensive error reporting for debugging
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException 500: If workspace reflection initialization fails
|
|
||||||
|
|
||||||
Performance Notes:
|
|
||||||
- Uses independent database sessions for each user operation
|
|
||||||
- Prevents transaction failures from affecting other users
|
|
||||||
- Comprehensive logging for operation tracking
|
|
||||||
"""
|
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
reflection_service = MemoryReflectionService(db)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}")
|
api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}")
|
||||||
|
|
||||||
# Use independent database session to get workspace app details, avoiding transaction failures
|
service = WorkspaceAppService(db)
|
||||||
from app.db import get_db_context
|
|
||||||
with get_db_context() as query_db:
|
|
||||||
service = WorkspaceAppService(query_db)
|
|
||||||
result = service.get_workspace_apps_detailed(workspace_id)
|
result = service.get_workspace_apps_detailed(workspace_id)
|
||||||
|
|
||||||
reflection_results = []
|
reflection_results = []
|
||||||
|
|
||||||
# Process each application in the workspace
|
|
||||||
for data in result['apps_detailed_info']:
|
for data in result['apps_detailed_info']:
|
||||||
# Skip applications without configurations
|
# 跳过没有配置的应用
|
||||||
if not data['memory_configs']:
|
if not data['memory_configs']:
|
||||||
api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过")
|
api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过")
|
||||||
continue
|
continue
|
||||||
@@ -222,25 +121,22 @@ async def start_workspace_reflection(
|
|||||||
memory_configs = data['memory_configs']
|
memory_configs = data['memory_configs']
|
||||||
end_users = data['end_users']
|
end_users = data['end_users']
|
||||||
|
|
||||||
# Execute reflection for each configuration and user combination
|
# 为每个配置和用户组合执行反思
|
||||||
for config in memory_configs:
|
for config in memory_configs:
|
||||||
config_id_str = str(config['config_id'])
|
config_id_str = str(config['config_id'])
|
||||||
|
|
||||||
# Find all releases matching this configuration
|
# 找到匹配此配置的所有release
|
||||||
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
|
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
|
||||||
|
|
||||||
if not matching_releases:
|
if not matching_releases:
|
||||||
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
|
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Execute reflection for each user - using independent database sessions
|
# 为每个用户执行反思
|
||||||
for user in end_users:
|
for user in end_users:
|
||||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}")
|
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}")
|
||||||
|
|
||||||
# Create independent database session for each user to avoid transaction failure impact
|
|
||||||
with get_db_context() as user_db:
|
|
||||||
try:
|
try:
|
||||||
reflection_service = MemoryReflectionService(user_db)
|
|
||||||
reflection_result = await reflection_service.start_text_reflection(
|
reflection_result = await reflection_service.start_text_reflection(
|
||||||
config_data=config,
|
config_data=config,
|
||||||
end_user_id=user['id']
|
end_user_id=user['id']
|
||||||
@@ -280,51 +176,14 @@ async def start_reflection_configs(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""通过config_id查询memory_config表中的反思配置信息"""
|
||||||
Query reflection configuration information by config_id
|
|
||||||
|
|
||||||
Retrieves detailed reflection configuration settings from the memory_config
|
|
||||||
table for a specific configuration ID. Provides comprehensive reflection
|
|
||||||
parameters including model settings, evaluation criteria, and operational flags.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_id: Configuration identifier (UUID or integer) to query
|
|
||||||
current_user: Authenticated user making the request
|
|
||||||
db: Database session for data operations
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Success response with detailed reflection configuration:
|
|
||||||
- config_id: Resolved configuration identifier
|
|
||||||
- reflection_enabled: Whether reflection is enabled for this config
|
|
||||||
- reflection_period_in_hours: Reflection execution interval
|
|
||||||
- reflexion_range: Scope of reflection operations (partial/all)
|
|
||||||
- baseline: Reflection strategy (time/fact/hybrid)
|
|
||||||
- reflection_model_id: LLM model identifier for reflection
|
|
||||||
- memory_verify: Memory verification flag
|
|
||||||
- quality_assessment: Quality assessment flag
|
|
||||||
|
|
||||||
Database Operations:
|
|
||||||
- Queries memory_config table by resolved config_id
|
|
||||||
- Retrieves all reflection-related configuration fields
|
|
||||||
- Resolves configuration ID for consistent formatting
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException 404: If configuration with specified ID is not found
|
|
||||||
HTTPException 500: If configuration query operation fails
|
|
||||||
|
|
||||||
ID Resolution:
|
|
||||||
- Supports both UUID and integer config_id formats
|
|
||||||
- Automatically resolves to appropriate internal format
|
|
||||||
- Maintains consistency across different ID representations
|
|
||||||
"""
|
|
||||||
config_id = resolve_config_id(config_id, db)
|
config_id = resolve_config_id(config_id, db)
|
||||||
try:
|
try:
|
||||||
config_id=resolve_config_id(config_id,db)
|
config_id=resolve_config_id(config_id,db)
|
||||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||||
memory_config_id = resolve_config_id(result.config_id, db)
|
memory_config_id = resolve_config_id(result.config_id, db)
|
||||||
|
# 构建返回数据
|
||||||
# Build response data with comprehensive configuration details
|
|
||||||
reflection_config = {
|
reflection_config = {
|
||||||
"config_id": memory_config_id,
|
"config_id": memory_config_id,
|
||||||
"reflection_enabled": result.enable_self_reflexion,
|
"reflection_enabled": result.enable_self_reflexion,
|
||||||
@@ -338,11 +197,9 @@ async def start_reflection_configs(
|
|||||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||||
return success(data=reflection_config, msg="反思配置查询成功")
|
return success(data=reflection_config, msg="反思配置查询成功")
|
||||||
|
|
||||||
api_logger.info(f"Successfully queried reflection config, config_id: {config_id}")
|
|
||||||
return success(data=reflection_config, msg="Reflection configuration query successful")
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
# Re-raise HTTP exceptions without modification
|
# 重新抛出HTTP异常
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"查询反思配置失败: {str(e)}")
|
api_logger.error(f"查询反思配置失败: {str(e)}")
|
||||||
@@ -354,70 +211,15 @@ async def start_reflection_configs(
|
|||||||
@router.get("/reflection/run")
|
@router.get("/reflection/run")
|
||||||
async def reflection_run(
|
async def reflection_run(
|
||||||
config_id: UUID|int,
|
config_id: UUID|int,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""Activate the reflection function for all matching applications in the workspace"""
|
||||||
Execute reflection engine with specified configuration
|
|
||||||
|
|
||||||
Runs the reflection engine using configuration parameters from the database.
|
|
||||||
Validates model availability, sets up the reflection engine with proper
|
|
||||||
configuration, and executes the reflection process with multi-language support.
|
|
||||||
|
|
||||||
This endpoint provides a test run capability for reflection configurations,
|
|
||||||
allowing users to validate their reflection settings and see results before
|
|
||||||
deploying to production environments.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_id: Configuration identifier (UUID or integer) for reflection settings
|
|
||||||
language_type: Language preference header for output localization (optional)
|
|
||||||
current_user: Authenticated user executing the reflection
|
|
||||||
db: Database session for configuration queries
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Success response with reflection execution results including:
|
|
||||||
- baseline: Reflection strategy used
|
|
||||||
- source_data: Input data processed
|
|
||||||
- memory_verifies: Memory verification results (if enabled)
|
|
||||||
- quality_assessments: Quality assessment results (if enabled)
|
|
||||||
- reflexion_data: Generated reflection insights and solutions
|
|
||||||
|
|
||||||
Configuration Validation:
|
|
||||||
- Verifies configuration exists in database
|
|
||||||
- Validates LLM model availability
|
|
||||||
- Falls back to default model if specified model is unavailable
|
|
||||||
- Ensures all required parameters are properly set
|
|
||||||
|
|
||||||
Reflection Engine Setup:
|
|
||||||
- Creates ReflectionConfig with database parameters
|
|
||||||
- Initializes Neo4j connector for memory access
|
|
||||||
- Sets up ReflectionEngine with validated model
|
|
||||||
- Configures language preferences for output
|
|
||||||
|
|
||||||
Error Handling:
|
|
||||||
- Model validation with fallback to default
|
|
||||||
- Configuration validation and error reporting
|
|
||||||
- Comprehensive logging for debugging
|
|
||||||
- Graceful handling of missing configurations
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException 404: If configuration is not found
|
|
||||||
HTTPException 500: If reflection execution fails
|
|
||||||
|
|
||||||
Performance Notes:
|
|
||||||
- Direct database query for configuration retrieval
|
|
||||||
- Model validation to prevent runtime failures
|
|
||||||
- Efficient reflection engine initialization
|
|
||||||
- Language-aware output processing
|
|
||||||
"""
|
|
||||||
# Use centralized language validation for consistent localization
|
|
||||||
language = get_language_from_header(language_type)
|
|
||||||
|
|
||||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||||
config_id = resolve_config_id(config_id, db)
|
config_id = resolve_config_id(config_id, db)
|
||||||
|
# 使用MemoryConfigRepository查询反思配置
|
||||||
# Query reflection configuration using MemoryConfigRepository
|
|
||||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||||
if not result:
|
if not result:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -427,7 +229,7 @@ async def reflection_run(
|
|||||||
|
|
||||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||||
|
|
||||||
# Validate model ID existence
|
# 验证模型ID是否存在
|
||||||
model_id = result.reflection_model_id
|
model_id = result.reflection_model_id
|
||||||
if model_id:
|
if model_id:
|
||||||
try:
|
try:
|
||||||
@@ -438,7 +240,6 @@ async def reflection_run(
|
|||||||
# 可以设置为None,让反思引擎使用默认模型
|
# 可以设置为None,让反思引擎使用默认模型
|
||||||
model_id = None
|
model_id = None
|
||||||
|
|
||||||
# Create reflection configuration with database parameters
|
|
||||||
config = ReflectionConfig(
|
config = ReflectionConfig(
|
||||||
enabled=result.enable_self_reflexion,
|
enabled=result.enable_self_reflexion,
|
||||||
iteration_period=result.iteration_period,
|
iteration_period=result.iteration_period,
|
||||||
@@ -451,13 +252,11 @@ async def reflection_run(
|
|||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
language_type=language_type
|
language_type=language_type
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize Neo4j connector and reflection engine
|
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
engine = ReflectionEngine(
|
engine = ReflectionEngine(
|
||||||
config=config,
|
config=config,
|
||||||
neo4j_connector=connector,
|
neo4j_connector=connector,
|
||||||
llm_client=model_id # Pass validated model_id
|
llm_client=model_id # 传入验证后的 model_id
|
||||||
)
|
)
|
||||||
|
|
||||||
result=await (engine.reflection_run())
|
result=await (engine.reflection_run())
|
||||||
|
|||||||
@@ -1,40 +1,18 @@
|
|||||||
"""
|
from fastapi import APIRouter, Depends, HTTPException, status,Header
|
||||||
Memory Short Term Controller
|
|
||||||
|
|
||||||
This module provides REST API endpoints for managing short-term and long-term memory
|
|
||||||
data retrieval and analysis. It handles memory system statistics, data aggregation,
|
|
||||||
and provides comprehensive memory insights for end users.
|
|
||||||
|
|
||||||
Key Features:
|
|
||||||
- Short-term memory data retrieval and statistics
|
|
||||||
- Long-term memory data aggregation
|
|
||||||
- Entity count integration
|
|
||||||
- Multi-language response support
|
|
||||||
- Memory system analytics and reporting
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, status
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.core.language_utils import get_language_from_header
|
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.services.memory_short_service import LongService, ShortService
|
|
||||||
from app.services.memory_storage_service import search_entity
|
from app.services.memory_storage_service import search_entity
|
||||||
|
from app.services.memory_short_service import ShortService,LongService
|
||||||
# Load environment variables for configuration
|
from dotenv import load_dotenv
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import Optional
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# Initialize API logger for request tracking and debugging
|
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
# Configure router with prefix and tags for API organization
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/memory/short",
|
prefix="/memory/short",
|
||||||
tags=["Memory"],
|
tags=["Memory"],
|
||||||
@@ -42,77 +20,25 @@ router = APIRouter(
|
|||||||
@router.get("/short_term")
|
@router.get("/short_term")
|
||||||
async def short_term_configs(
|
async def short_term_configs(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
language_type:str = Header(default=None, alias="X-Language-Type"),
|
language_type:str = Header(default="zh", alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
# 获取短期记忆数据
|
||||||
Retrieve comprehensive short-term and long-term memory statistics
|
short_term=ShortService(end_user_id)
|
||||||
|
short_result=short_term.get_short_databasets()
|
||||||
|
short_count=short_term.get_short_count()
|
||||||
|
|
||||||
Provides a comprehensive overview of memory system data for a specific end user,
|
long_term=LongService(end_user_id)
|
||||||
including short-term memory entries, long-term memory aggregations, entity counts,
|
long_result=long_term.get_long_databasets()
|
||||||
and retrieval statistics. Supports multi-language responses based on request headers.
|
|
||||||
|
|
||||||
This endpoint serves as a central dashboard for memory system analytics, combining
|
|
||||||
data from multiple memory subsystems to provide a holistic view of user memory state.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: Unique identifier for the end user whose memory data to retrieve
|
|
||||||
language_type: Language preference header for response localization (optional)
|
|
||||||
current_user: Authenticated user making the request (injected by dependency)
|
|
||||||
db: Database session for data operations (injected by dependency)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Success response containing comprehensive memory statistics:
|
|
||||||
- short_term: List of short-term memory entries with detailed data
|
|
||||||
- long_term: List of long-term memory aggregations and summaries
|
|
||||||
- entity: Count of entities associated with the end user
|
|
||||||
- retrieval_number: Total count of short-term memory retrievals
|
|
||||||
- long_term_number: Total count of long-term memory entries
|
|
||||||
|
|
||||||
Response Structure:
|
|
||||||
{
|
|
||||||
"code": 200,
|
|
||||||
"msg": "Short-term memory system data retrieved successfully",
|
|
||||||
"data": {
|
|
||||||
"short_term": [...], # Short-term memory entries
|
|
||||||
"long_term": [...], # Long-term memory data
|
|
||||||
"entity": 42, # Entity count
|
|
||||||
"retrieval_number": 156, # Short-term retrieval count
|
|
||||||
"long_term_number": 23 # Long-term memory count
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If end_user_id is invalid or data retrieval fails
|
|
||||||
|
|
||||||
Performance Notes:
|
|
||||||
- Combines multiple service calls for comprehensive data
|
|
||||||
- Entity search is performed asynchronously for better performance
|
|
||||||
- Response time depends on memory data volume for the specified user
|
|
||||||
"""
|
|
||||||
# Use centralized language validation for consistent localization
|
|
||||||
language = get_language_from_header(language_type)
|
|
||||||
|
|
||||||
# Retrieve short-term memory data and statistics
|
|
||||||
short_term = ShortService(end_user_id, db)
|
|
||||||
short_result = short_term.get_short_databasets() # Get short-term memory entries
|
|
||||||
short_count = short_term.get_short_count() # Get short-term retrieval count
|
|
||||||
|
|
||||||
# Retrieve long-term memory data and aggregations
|
|
||||||
long_term = LongService(end_user_id, db)
|
|
||||||
long_result = long_term.get_long_databasets() # Get long-term memory entries
|
|
||||||
|
|
||||||
# Get entity count for the specified end user
|
|
||||||
entity_result = await search_entity(end_user_id)
|
entity_result = await search_entity(end_user_id)
|
||||||
|
|
||||||
# Compile comprehensive memory statistics response
|
|
||||||
result = {
|
result = {
|
||||||
'short_term': short_result, # Short-term memory entries
|
'short_term': short_result,
|
||||||
'long_term': long_result, # Long-term memory data
|
'long_term': long_result,
|
||||||
'entity': entity_result.get('num', 0), # Entity count (default to 0 if not found)
|
'entity': entity_result.get('num', 0),
|
||||||
"retrieval_number": short_count, # Short-term retrieval statistics
|
"retrieval_number":short_count,
|
||||||
"long_term_number": len(long_result) # Long-term memory entry count
|
"long_term_number":len(long_result)
|
||||||
}
|
}
|
||||||
|
|
||||||
return success(data=result, msg="短期记忆系统数据获取成功")
|
return success(data=result, msg="短期记忆系统数据获取成功")
|
||||||
@@ -1,12 +1,8 @@
|
|||||||
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query
|
|
||||||
from fastapi.responses import StreamingResponse, JSONResponse
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.language_utils import get_language_from_header
|
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.response_utils import fail, success
|
from app.core.response_utils import fail, success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
@@ -15,6 +11,7 @@ from app.models.user_model import User
|
|||||||
from app.schemas.memory_storage_schema import (
|
from app.schemas.memory_storage_schema import (
|
||||||
ConfigKey,
|
ConfigKey,
|
||||||
ConfigParamsCreate,
|
ConfigParamsCreate,
|
||||||
|
ConfigParamsDelete,
|
||||||
ConfigPilotRun,
|
ConfigPilotRun,
|
||||||
ConfigUpdate,
|
ConfigUpdate,
|
||||||
ConfigUpdateExtracted,
|
ConfigUpdateExtracted,
|
||||||
@@ -26,7 +23,7 @@ from app.services.memory_storage_service import (
|
|||||||
analytics_hot_memory_tags,
|
analytics_hot_memory_tags,
|
||||||
analytics_recent_activity_stats,
|
analytics_recent_activity_stats,
|
||||||
kb_type_distribution,
|
kb_type_distribution,
|
||||||
search_all_batch,
|
search_all,
|
||||||
search_chunk,
|
search_chunk,
|
||||||
search_detials,
|
search_detials,
|
||||||
search_dialogue,
|
search_dialogue,
|
||||||
@@ -34,8 +31,7 @@ from app.services.memory_storage_service import (
|
|||||||
search_entity,
|
search_entity,
|
||||||
search_statement,
|
search_statement,
|
||||||
)
|
)
|
||||||
from app.core.quota_stub import check_memory_engine_quota
|
from fastapi import APIRouter, Depends
|
||||||
from fastapi import APIRouter, Depends, Header
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -76,13 +72,75 @@ async def get_storage_info(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
# --- DB connection dependency ---
|
||||||
|
_CONN: Optional[object] = None
|
||||||
|
|
||||||
|
|
||||||
|
"""PostgreSQL 连接生成与管理(使用 psycopg2)。"""
|
||||||
|
# 这个可以转移,可能是已经有的
|
||||||
|
# PostgreSQL 数据库连接
|
||||||
|
def _make_pgsql_conn() -> Optional[object]: # 创建 PostgreSQL 数据库连接
|
||||||
|
host = os.getenv("DB_HOST")
|
||||||
|
user = os.getenv("DB_USER")
|
||||||
|
password = os.getenv("DB_PASSWORD")
|
||||||
|
database = os.getenv("DB_NAME")
|
||||||
|
port_str = os.getenv("DB_PORT")
|
||||||
|
try:
|
||||||
|
import psycopg2 # type: ignore
|
||||||
|
port = int(port_str) if port_str else 5432
|
||||||
|
conn = psycopg2.connect(
|
||||||
|
host=host or "localhost",
|
||||||
|
port=port,
|
||||||
|
user=user,
|
||||||
|
password=password,
|
||||||
|
dbname=database,
|
||||||
|
)
|
||||||
|
# 设置自动提交,避免显式事务管理
|
||||||
|
conn.autocommit = True
|
||||||
|
# 设置会话时区为中国标准时间(Asia/Shanghai),便于直接以本地时区展示
|
||||||
|
try:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute("SET TIME ZONE 'Asia/Shanghai'")
|
||||||
|
cur.close()
|
||||||
|
except Exception:
|
||||||
|
# 时区设置失败不影响连接,仅记录但不抛出
|
||||||
|
pass
|
||||||
|
return conn
|
||||||
|
except Exception as e:
|
||||||
|
try:
|
||||||
|
print(f"[PostgreSQL] 连接失败: {e}")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_db_conn() -> Optional[object]: # 获取 PostgreSQL 数据库连接
|
||||||
|
global _CONN
|
||||||
|
if _CONN is None:
|
||||||
|
_CONN = _make_pgsql_conn()
|
||||||
|
return _CONN
|
||||||
|
|
||||||
|
|
||||||
|
def reset_db_conn() -> bool: # 重置 PostgreSQL 数据库连接
|
||||||
|
"""Close and recreate the global DB connection."""
|
||||||
|
global _CONN
|
||||||
|
try:
|
||||||
|
if _CONN:
|
||||||
|
try:
|
||||||
|
_CONN.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
_CONN = _make_pgsql_conn()
|
||||||
|
return _CONN is not None
|
||||||
|
except Exception:
|
||||||
|
_CONN = None
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
||||||
@check_memory_engine_quota
|
|
||||||
def create_config(
|
def create_config(
|
||||||
payload: ConfigParamsCreate,
|
payload: ConfigParamsCreate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
@@ -97,33 +155,7 @@ def create_config(
|
|||||||
svc = DataConfigService(db)
|
svc = DataConfigService(db)
|
||||||
result = svc.create(payload)
|
result = svc.create(payload)
|
||||||
return success(data=result, msg="创建成功")
|
return success(data=result, msg="创建成功")
|
||||||
except ValueError as e:
|
|
||||||
err_str = str(e)
|
|
||||||
if err_str.startswith("DUPLICATE_CONFIG_NAME:"):
|
|
||||||
config_name = err_str.split(":", 1)[1]
|
|
||||||
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
|
|
||||||
lang = get_language_from_header(x_language_type)
|
|
||||||
if lang == "en":
|
|
||||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
|
|
||||||
f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
|
|
||||||
else:
|
|
||||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
|
|
||||||
f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
|
|
||||||
return JSONResponse(status_code=400, content=msg)
|
|
||||||
api_logger.error(f"Create config failed: {err_str}")
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
from sqlalchemy.exc import IntegrityError
|
|
||||||
if isinstance(e, IntegrityError) and "uq_workspace_config_name" in str(getattr(e, 'orig', '')):
|
|
||||||
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
|
|
||||||
lang = get_language_from_header(x_language_type)
|
|
||||||
if lang == "en":
|
|
||||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
|
|
||||||
f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
|
|
||||||
else:
|
|
||||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
|
|
||||||
f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
|
|
||||||
return JSONResponse(status_code=400, content=msg)
|
|
||||||
api_logger.error(f"Create config failed: {str(e)}")
|
api_logger.error(f"Create config failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
|
||||||
|
|
||||||
@@ -131,20 +163,9 @@ def create_config(
|
|||||||
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
||||||
def delete_config(
|
def delete_config(
|
||||||
config_id: UUID|int,
|
config_id: UUID|int,
|
||||||
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""删除记忆配置(带终端用户保护)
|
|
||||||
|
|
||||||
- 检查是否为默认配置,默认配置不允许删除
|
|
||||||
- 检查是否有终端用户连接到该配置
|
|
||||||
- 如果有连接且 force=False,返回警告
|
|
||||||
- 如果 force=True,清除终端用户引用后删除配置
|
|
||||||
|
|
||||||
Query Parameters:
|
|
||||||
force: 设置为 true 可强制删除(即使有终端用户正在使用)
|
|
||||||
"""
|
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
config_id=resolve_config_id(config_id, db)
|
config_id=resolve_config_id(config_id, db)
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
@@ -152,56 +173,15 @@ def delete_config(
|
|||||||
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: {config_id}")
|
||||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: "
|
|
||||||
f"config_id={config_id}, force={force}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用带保护的删除服务
|
svc = DataConfigService(db)
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
result = svc.delete(ConfigParamsDelete(config_id=config_id))
|
||||||
|
return success(data=result, msg="删除成功")
|
||||||
config_service = MemoryConfigService(db)
|
|
||||||
result = config_service.delete_config(config_id=config_id, force=force)
|
|
||||||
|
|
||||||
if result["status"] == "error":
|
|
||||||
api_logger.warning(
|
|
||||||
f"记忆配置删除被拒绝: config_id={config_id}, reason={result['message']}"
|
|
||||||
)
|
|
||||||
return fail(
|
|
||||||
code=BizCode.FORBIDDEN,
|
|
||||||
msg=result["message"],
|
|
||||||
data={"config_id": str(config_id), "is_default": result.get("is_default", False)}
|
|
||||||
)
|
|
||||||
|
|
||||||
if result["status"] == "warning":
|
|
||||||
api_logger.warning(
|
|
||||||
f"记忆配置正在使用,无法删除: config_id={config_id}, "
|
|
||||||
f"connected_count={result['connected_count']}"
|
|
||||||
)
|
|
||||||
return fail(
|
|
||||||
code=BizCode.RESOURCE_IN_USE,
|
|
||||||
msg=result["message"],
|
|
||||||
data={
|
|
||||||
"connected_count": result["connected_count"],
|
|
||||||
"force_required": result["force_required"]
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"记忆配置删除成功: config_id={config_id}, "
|
|
||||||
f"affected_users={result['affected_users']}"
|
|
||||||
)
|
|
||||||
return success(
|
|
||||||
msg=result["message"],
|
|
||||||
data={"affected_users": result["affected_users"]}
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Delete config failed: {str(e)}", exc_info=True)
|
api_logger.error(f"Delete config failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
|
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
|
||||||
def update_config(
|
def update_config(
|
||||||
payload: ConfigUpdate,
|
payload: ConfigUpdate,
|
||||||
@@ -218,8 +198,7 @@ def update_config(
|
|||||||
# 校验至少有一个字段需要更新
|
# 校验至少有一个字段需要更新
|
||||||
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
|
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
|
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段",
|
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
|
||||||
"config_name, config_desc, scene_id 均为空")
|
|
||||||
|
|
||||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
||||||
try:
|
try:
|
||||||
@@ -280,7 +259,6 @@ def read_config_extracted(
|
|||||||
api_logger.error(f"Read config extracted failed: {str(e)}")
|
api_logger.error(f"Read config extracted failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
||||||
def read_all_config(
|
def read_all_config(
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
@@ -307,22 +285,17 @@ def read_all_config(
|
|||||||
@router.post("/pilot_run", response_model=None)
|
@router.post("/pilot_run", response_model=None)
|
||||||
async def pilot_run(
|
async def pilot_run(
|
||||||
payload: ConfigPilotRun,
|
payload: ConfigPilotRun,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> StreamingResponse:
|
) -> StreamingResponse:
|
||||||
# 使用集中化的语言校验
|
|
||||||
language = get_language_from_header(language_type)
|
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"Pilot run requested: config_id={payload.config_id}, "
|
f"Pilot run requested: config_id={payload.config_id}, "
|
||||||
f"dialogue_text_length={len(payload.dialogue_text)}, "
|
f"dialogue_text_length={len(payload.dialogue_text)}"
|
||||||
f"custom_text_length={len(payload.custom_text) if payload.custom_text else 0}"
|
|
||||||
)
|
)
|
||||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||||
svc = DataConfigService(db)
|
svc = DataConfigService(db)
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
svc.pilot_run_stream(payload, language=language),
|
svc.pilot_run_stream(payload),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers={
|
headers={
|
||||||
"Cache-Control": "no-cache",
|
"Cache-Control": "no-cache",
|
||||||
@@ -331,8 +304,9 @@ async def pilot_run(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
# ==================== Search & Analytics ====================
|
以下为搜索与分析接口,直接挂载到同一 router,统一响应为 ApiResponse。
|
||||||
|
"""
|
||||||
|
|
||||||
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
|
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
|
||||||
async def get_kb_type_distribution(
|
async def get_kb_type_distribution(
|
||||||
@@ -411,10 +385,7 @@ async def search_all_num(
|
|||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
|
api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
if not end_user_id:
|
result = await search_all(end_user_id)
|
||||||
return success(data={"total": 0}, msg="查询成功")
|
|
||||||
batch_result = await search_all_batch([end_user_id])
|
|
||||||
result = {"total": batch_result.get(end_user_id, 0)}
|
|
||||||
return success(data=result, msg="查询成功")
|
return success(data=result, msg="查询成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Search all failed: {str(e)}")
|
api_logger.error(f"Search all failed: {str(e)}")
|
||||||
@@ -449,6 +420,8 @@ async def search_entity_edges(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
|
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
|
||||||
async def get_hot_memory_tags_api(
|
async def get_hot_memory_tags_api(
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
@@ -473,9 +446,8 @@ async def get_hot_memory_tags_api(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 尝试从Redis缓存获取
|
# 尝试从Redis缓存获取
|
||||||
import json
|
|
||||||
|
|
||||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||||
|
import json
|
||||||
|
|
||||||
cached_result = await aio_redis_get(cache_key)
|
cached_result = await aio_redis_get(cache_key)
|
||||||
if cached_result:
|
if cached_result:
|
||||||
@@ -549,11 +521,11 @@ async def clear_hot_memory_tags_cache(
|
|||||||
async def get_recent_activity_stats_api(
|
async def get_recent_activity_stats_api(
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
|
api_logger.info("Recent activity stats requested")
|
||||||
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
|
|
||||||
try:
|
try:
|
||||||
result = await analytics_recent_activity_stats(workspace_id=workspace_id)
|
result = await analytics_recent_activity_stats()
|
||||||
return success(data=result, msg="查询成功")
|
return success(data=result, msg="查询成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Recent activity stats failed: {str(e)}")
|
api_logger.error(f"Recent activity stats failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from app.core.response_utils import success
|
|||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.models import User
|
from app.models import User
|
||||||
from app.schemas import conversation_schema
|
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
|
|
||||||
@@ -33,47 +32,35 @@ def get_memory_count(
|
|||||||
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
|
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
|
||||||
def get_conversations(
|
def get_conversations(
|
||||||
end_user_id: uuid.UUID,
|
end_user_id: uuid.UUID,
|
||||||
page: int = 1,
|
|
||||||
pagesize: int = 20,
|
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Retrieve conversations for the current user in a specific group with pagination.
|
Retrieve all conversations for the current user in a specific group.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id (UUID): The group identifier.
|
end_user_id (UUID): The group identifier.
|
||||||
page (int): Page number (1-based). Defaults to 1.
|
|
||||||
pagesize (int): Number of items per page. Defaults to 20.
|
|
||||||
current_user (User, optional): The authenticated user.
|
current_user (User, optional): The authenticated user.
|
||||||
db (Session, optional): SQLAlchemy session.
|
db (Session, optional): SQLAlchemy session.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ApiResponse: Contains a paginated list of conversations.
|
ApiResponse: Contains a list of conversation IDs.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- Initializes the ConversationService with the current DB session.
|
||||||
|
- Returns only conversation IDs for lightweight response.
|
||||||
|
- Logs can be added to trace requests in production.
|
||||||
"""
|
"""
|
||||||
page = max(1, page)
|
|
||||||
page_size = max(1, min(pagesize, 100)) # Limit page size between 1 and 100
|
|
||||||
conversation_service = ConversationService(db)
|
conversation_service = ConversationService(db)
|
||||||
conversations, total = conversation_service.get_user_conversations(
|
conversations = conversation_service.get_user_conversations(
|
||||||
end_user_id,
|
end_user_id
|
||||||
page=page,
|
|
||||||
page_size=page_size
|
|
||||||
)
|
)
|
||||||
return success(data={
|
return success(data=[
|
||||||
"items": [
|
|
||||||
{
|
{
|
||||||
"id": conversation.id,
|
"id": conversation.id,
|
||||||
"title": conversation.title
|
"title": conversation.title
|
||||||
} for conversation in conversations
|
} for conversation in conversations
|
||||||
],
|
], msg="get conversations success")
|
||||||
"total": total,
|
|
||||||
"page": {
|
|
||||||
"page": page,
|
|
||||||
"pagesize": page_size,
|
|
||||||
"total": total,
|
|
||||||
"hasnext": (page * page_size) < total
|
|
||||||
},
|
|
||||||
}, msg="get conversations success")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
|
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
|
||||||
@@ -103,7 +90,11 @@ def get_messages(
|
|||||||
conversation_id,
|
conversation_id,
|
||||||
)
|
)
|
||||||
messages = [
|
messages = [
|
||||||
conversation_schema.Message.model_validate(message)
|
{
|
||||||
|
"role": message.role,
|
||||||
|
"content": message.content,
|
||||||
|
"created_at": int(message.created_at.timestamp() * 1000),
|
||||||
|
}
|
||||||
for message in messages_obj
|
for message in messages_obj
|
||||||
]
|
]
|
||||||
return success(data=messages, msg="get conversation history success")
|
return success(data=messages, msg="get conversation history success")
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ from app.core.response_utils import success
|
|||||||
from app.schemas.response_schema import ApiResponse, PageData
|
from app.schemas.response_schema import ApiResponse, PageData
|
||||||
from app.services.model_service import ModelConfigService, ModelApiKeyService, ModelBaseService
|
from app.services.model_service import ModelConfigService, ModelApiKeyService, ModelBaseService
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.quota_stub import check_model_quota, check_model_activation_quota
|
|
||||||
|
|
||||||
# 获取API专用日志器
|
# 获取API专用日志器
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -43,7 +42,6 @@ def get_model_strategies():
|
|||||||
@router.get("", response_model=ApiResponse)
|
@router.get("", response_model=ApiResponse)
|
||||||
def get_model_list(
|
def get_model_list(
|
||||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||||
capability: Optional[list[str]] = Query(None, description="能力筛选(支持多个,如 ?capability=chat 或 ?capability=chat, embedding)"),
|
|
||||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||||
@@ -76,21 +74,10 @@ def get_model_list(
|
|||||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||||
|
|
||||||
capability_list = []
|
|
||||||
if capability is not None:
|
|
||||||
flat_capability = []
|
|
||||||
for item in capability:
|
|
||||||
split_items = [c.strip() for c in item.split(', ') if c.strip()]
|
|
||||||
flat_capability.extend(split_items)
|
|
||||||
|
|
||||||
unique_flat_capability = list(dict.fromkeys(flat_capability))
|
|
||||||
capability_list = unique_flat_capability
|
|
||||||
|
|
||||||
api_logger.error(f"获取模型type_list: {type_list}")
|
api_logger.error(f"获取模型type_list: {type_list}")
|
||||||
query = model_schema.ModelConfigQuery(
|
query = model_schema.ModelConfigQuery(
|
||||||
type=type_list,
|
type=type_list,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
capability=capability_list,
|
|
||||||
is_active=is_active,
|
is_active=is_active,
|
||||||
is_public=is_public,
|
is_public=is_public,
|
||||||
search=search,
|
search=search,
|
||||||
@@ -304,7 +291,6 @@ async def create_model(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/composite", response_model=ApiResponse)
|
@router.post("/composite", response_model=ApiResponse)
|
||||||
@check_model_quota
|
|
||||||
async def create_composite_model(
|
async def create_composite_model(
|
||||||
model_data: model_schema.CompositeModelCreate,
|
model_data: model_schema.CompositeModelCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
@@ -331,7 +317,6 @@ async def create_composite_model(
|
|||||||
|
|
||||||
|
|
||||||
@router.put("/composite/{model_id}", response_model=ApiResponse)
|
@router.put("/composite/{model_id}", response_model=ApiResponse)
|
||||||
@check_model_activation_quota
|
|
||||||
async def update_composite_model(
|
async def update_composite_model(
|
||||||
model_id: uuid.UUID,
|
model_id: uuid.UUID,
|
||||||
model_data: model_schema.CompositeModelCreate,
|
model_data: model_schema.CompositeModelCreate,
|
||||||
@@ -343,7 +328,7 @@ async def update_composite_model(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if model_data.type is not None:
|
if model_data.type is not None:
|
||||||
raise BusinessException("不允许更改模型类型", BizCode.INVALID_PARAMETER)
|
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||||
result_orm = await ModelConfigService.update_composite_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
|
result_orm = await ModelConfigService.update_composite_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
|
||||||
api_logger.info(f"组合模型更新成功: {result_orm.name} (ID: {model_id})")
|
api_logger.info(f"组合模型更新成功: {result_orm.name} (ID: {model_id})")
|
||||||
|
|
||||||
@@ -384,14 +369,6 @@ def update_model(
|
|||||||
"""
|
"""
|
||||||
api_logger.info(f"更新模型配置请求: model_id={model_id}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
|
api_logger.info(f"更新模型配置请求: model_id={model_id}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
|
||||||
|
|
||||||
if model_data.type is not None or model_data.provider is not None:
|
|
||||||
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
|
||||||
|
|
||||||
if model_data.is_active:
|
|
||||||
active_keys = ModelApiKeyService.get_api_keys_by_model(db=db, model_config_id=model_id, is_active=model_data.is_active)
|
|
||||||
if not active_keys:
|
|
||||||
raise BusinessException("请先为该模型配置可用的 API Key", BizCode.INVALID_PARAMETER)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
api_logger.debug(f"开始更新模型配置: model_id={model_id}")
|
api_logger.debug(f"开始更新模型配置: model_id={model_id}")
|
||||||
result_orm = ModelConfigService.update_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
|
result_orm = ModelConfigService.update_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
|
||||||
@@ -489,9 +466,7 @@ async def create_model_api_key_by_provider(
|
|||||||
config=api_key_data.config,
|
config=api_key_data.config,
|
||||||
is_active=api_key_data.is_active,
|
is_active=api_key_data.is_active,
|
||||||
priority=api_key_data.priority,
|
priority=api_key_data.priority,
|
||||||
model_config_ids=model_config_ids,
|
model_config_ids=model_config_ids
|
||||||
capability=api_key_data.capability,
|
|
||||||
is_omni=api_key_data.is_omni
|
|
||||||
)
|
)
|
||||||
created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)
|
created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)
|
||||||
|
|
||||||
|
|||||||
@@ -4,14 +4,13 @@
|
|||||||
|
|
||||||
Endpoints:
|
Endpoints:
|
||||||
POST /api/memory/ontology/extract - 提取本体类
|
POST /api/memory/ontology/extract - 提取本体类
|
||||||
POST /api/memory/ontology/export - 按场景导出OWL文件
|
POST /api/memory/ontology/export - 导出OWL文件
|
||||||
POST /api/memory/ontology/import - 导入OWL文件到指定场景
|
|
||||||
POST /api/memory/ontology/scene - 创建本体场景
|
POST /api/memory/ontology/scene - 创建本体场景
|
||||||
PUT /api/memory/ontology/scene/{scene_id} - 更新本体场景
|
PUT /api/memory/ontology/scene/{scene_id} - 更新本体场景
|
||||||
DELETE /api/memory/ontology/scene/{scene_id} - 删除本体场景
|
DELETE /api/memory/ontology/scene/{scene_id} - 删除本体场景
|
||||||
GET /api/memory/ontology/scene/{scene_id} - 获取单个场景
|
GET /api/memory/ontology/scene/{scene_id} - 获取单个场景
|
||||||
GET /api/memory/ontology/scenes - 获取场景列表
|
GET /api/memory/ontology/scenes - 获取场景列表
|
||||||
POST /api/memory/ontology/class - 创建本体类型(支持批量)
|
POST /api/memory/ontology/class - 创建本体类型
|
||||||
PUT /api/memory/ontology/class/{class_id} - 更新本体类型
|
PUT /api/memory/ontology/class/{class_id} - 更新本体类型
|
||||||
DELETE /api/memory/ontology/class/{class_id} - 删除本体类型
|
DELETE /api/memory/ontology/class/{class_id} - 删除本体类型
|
||||||
GET /api/memory/ontology/class/{class_id} - 获取单个类型
|
GET /api/memory/ontology/class/{class_id} - 获取单个类型
|
||||||
@@ -20,28 +19,23 @@ Endpoints:
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import tempfile
|
import tempfile
|
||||||
import io
|
from typing import Dict, Optional
|
||||||
from typing import Dict, Optional, List
|
|
||||||
from urllib.parse import quote
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, Header
|
from fastapi import APIRouter, Depends, HTTPException, Header
|
||||||
from fastapi.responses import StreamingResponse, JSONResponse
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.quota_stub import check_ontology_project_quota
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.language_utils import get_language_from_header
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.logging_config import get_api_logger, get_business_logger
|
|
||||||
from app.core.response_utils import fail, success
|
from app.core.response_utils import fail, success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.core.memory.models.ontology_scenario_models import OntologyClass
|
from app.services.memory_base_service import Translation_English
|
||||||
|
from app.core.memory.models.ontology_models import OntologyClass
|
||||||
|
from typing import List
|
||||||
from app.schemas.ontology_schemas import (
|
from app.schemas.ontology_schemas import (
|
||||||
ExportBySceneRequest,
|
ExportRequest,
|
||||||
ExportBySceneResponse,
|
ExportResponse,
|
||||||
ExtractionRequest,
|
ExtractionRequest,
|
||||||
ExtractionResponse,
|
ExtractionResponse,
|
||||||
SceneCreateRequest,
|
SceneCreateRequest,
|
||||||
@@ -52,7 +46,6 @@ from app.schemas.ontology_schemas import (
|
|||||||
ClassUpdateRequest,
|
ClassUpdateRequest,
|
||||||
ClassResponse,
|
ClassResponse,
|
||||||
ClassListResponse,
|
ClassListResponse,
|
||||||
ImportOwlResponse,
|
|
||||||
)
|
)
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services.ontology_service import OntologyService
|
from app.services.ontology_service import OntologyService
|
||||||
@@ -63,7 +56,6 @@ from app.repositories.ontology_scene_repository import OntologySceneRepository
|
|||||||
|
|
||||||
|
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
business_logger = get_business_logger()
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
@@ -72,6 +64,72 @@ router = APIRouter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def translate_ontology_classes(
|
||||||
|
classes: List[OntologyClass],
|
||||||
|
model_id: str
|
||||||
|
) -> List[OntologyClass]:
|
||||||
|
"""翻译本体类列表
|
||||||
|
|
||||||
|
将本体类的中文字段翻译为英文,包括:
|
||||||
|
- name_chinese: 中文名称
|
||||||
|
- description: 描述
|
||||||
|
- examples: 示例列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
classes: 本体类列表
|
||||||
|
model_id: LLM模型ID,用于翻译
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[OntologyClass]: 翻译后的本体类列表
|
||||||
|
"""
|
||||||
|
translated_classes = []
|
||||||
|
|
||||||
|
for ontology_class in classes:
|
||||||
|
# 创建类的副本,避免修改原对象
|
||||||
|
translated_class = ontology_class.model_copy(deep=True)
|
||||||
|
|
||||||
|
# 翻译 name_chinese 字段
|
||||||
|
if translated_class.name_chinese:
|
||||||
|
try:
|
||||||
|
translated_class.name_chinese = await Translation_English(
|
||||||
|
model_id,
|
||||||
|
translated_class.name_chinese
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to translate name_chinese: {e}")
|
||||||
|
# 保留原文
|
||||||
|
|
||||||
|
# 翻译 description 字段
|
||||||
|
if translated_class.description:
|
||||||
|
try:
|
||||||
|
translated_class.description = await Translation_English(
|
||||||
|
model_id,
|
||||||
|
translated_class.description
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to translate description: {e}")
|
||||||
|
# 保留原文
|
||||||
|
|
||||||
|
# 翻译 examples 列表
|
||||||
|
if translated_class.examples:
|
||||||
|
translated_examples = []
|
||||||
|
for example in translated_class.examples:
|
||||||
|
try:
|
||||||
|
translated_example = await Translation_English(
|
||||||
|
model_id,
|
||||||
|
example
|
||||||
|
)
|
||||||
|
translated_examples.append(translated_example)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to translate example: {e}")
|
||||||
|
translated_examples.append(example) # 保留原文
|
||||||
|
translated_class.examples = translated_examples
|
||||||
|
|
||||||
|
translated_classes.append(translated_class)
|
||||||
|
|
||||||
|
return translated_classes
|
||||||
|
|
||||||
|
|
||||||
def _get_ontology_service(
|
def _get_ontology_service(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
@@ -126,23 +184,15 @@ def _get_ontology_service(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 通过 Repository 获取可用的 API Key(负载均衡逻辑由 Repository 处理)
|
# 通过 Repository 获取可用的 API Key(负载均衡逻辑由 Repository 处理)
|
||||||
# from app.repositories.model_repository import ModelApiKeyRepository
|
from app.repositories.model_repository import ModelApiKeyRepository
|
||||||
from app.services.model_service import ModelApiKeyService
|
api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id)
|
||||||
api_key_config = ModelApiKeyService.get_available_api_key(db, model_config.id)
|
if not api_keys:
|
||||||
if not api_key_config:
|
|
||||||
logger.error(f"Model {llm_id} has no active API key")
|
logger.error(f"Model {llm_id} has no active API key")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail="指定的LLM模型没有可用的API密钥"
|
detail="指定的LLM模型没有可用的API密钥"
|
||||||
)
|
)
|
||||||
# api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id)
|
api_key_config = api_keys[0]
|
||||||
# if not api_keys:
|
|
||||||
# logger.error(f"Model {llm_id} has no active API key")
|
|
||||||
# raise HTTPException(
|
|
||||||
# status_code=400,
|
|
||||||
# detail="指定的LLM模型没有可用的API密钥"
|
|
||||||
# )
|
|
||||||
# api_key_config = api_keys[0]
|
|
||||||
|
|
||||||
is_composite = getattr(model_config, 'is_composite', False)
|
is_composite = getattr(model_config, 'is_composite', False)
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -164,8 +214,6 @@ def _get_ontology_service(
|
|||||||
provider=actual_provider,
|
provider=actual_provider,
|
||||||
api_key=api_key_config.api_key,
|
api_key=api_key_config.api_key,
|
||||||
base_url=api_key_config.api_base,
|
base_url=api_key_config.api_base,
|
||||||
is_omni=api_key_config.is_omni,
|
|
||||||
capability=api_key_config.capability,
|
|
||||||
max_retries=3,
|
max_retries=3,
|
||||||
timeout=60.0
|
timeout=60.0
|
||||||
)
|
)
|
||||||
@@ -196,7 +244,7 @@ def _get_ontology_service(
|
|||||||
@router.post("/extract", response_model=ApiResponse)
|
@router.post("/extract", response_model=ApiResponse)
|
||||||
async def extract_ontology(
|
async def extract_ontology(
|
||||||
request: ExtractionRequest,
|
request: ExtractionRequest,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
@@ -205,25 +253,50 @@ async def extract_ontology(
|
|||||||
从场景描述中提取符合OWL规范的本体类。
|
从场景描述中提取符合OWL规范的本体类。
|
||||||
提取结果仅返回给前端,不会自动保存到数据库。
|
提取结果仅返回给前端,不会自动保存到数据库。
|
||||||
前端可以从返回结果中选择需要的类型,然后调用 /class 接口创建类型。
|
前端可以从返回结果中选择需要的类型,然后调用 /class 接口创建类型。
|
||||||
|
支持中英文切换,通过 X-Language-Type Header 指定语言。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: 提取请求,包含scenario、domain、llm_id和scene_id
|
request: 提取请求,包含scenario、domain、llm_id和scene_id
|
||||||
language_type: 语言类型 Header (zh/en)
|
language_type: 语言类型,'zh'(中文)或 'en'(英文),默认 'zh'
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
current_user: 当前用户
|
current_user: 当前用户
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse: 包含提取结果的响应
|
||||||
|
|
||||||
|
Response format:
|
||||||
|
{
|
||||||
|
"code": 200,
|
||||||
|
"msg": "本体提取成功",
|
||||||
|
"data": {
|
||||||
|
"classes": [
|
||||||
|
{
|
||||||
|
"id": "147d9db50b524a9e909e01a753d3acdd",
|
||||||
|
"name": "Patient",
|
||||||
|
"name_chinese": "患者",
|
||||||
|
"description": "在医疗机构中接受诊疗、护理或健康管理的个体",
|
||||||
|
"examples": ["糖尿病患者", "术后康复患者", "门诊初诊患者"],
|
||||||
|
"parent_class": null,
|
||||||
|
"entity_type": "Person",
|
||||||
|
"domain": "Healthcare"
|
||||||
|
},
|
||||||
|
...
|
||||||
|
],
|
||||||
|
"domain": "Healthcare",
|
||||||
|
"extracted_count": 7
|
||||||
|
}
|
||||||
|
}
|
||||||
"""
|
"""
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"Ontology extraction requested by user {current_user.id}, "
|
f"Ontology extraction requested by user {current_user.id}, "
|
||||||
f"scenario_length={len(request.scenario)}, "
|
f"scenario_length={len(request.scenario)}, "
|
||||||
f"domain={request.domain}, "
|
f"domain={request.domain}, "
|
||||||
f"llm_id={request.llm_id}, "
|
f"llm_id={request.llm_id}, "
|
||||||
f"scene_id={request.scene_id}"
|
f"scene_id={request.scene_id}, "
|
||||||
|
f"language_type={language_type}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用集中化的语言校验
|
|
||||||
language = get_language_from_header(language_type)
|
|
||||||
|
|
||||||
# 获取当前工作空间ID
|
# 获取当前工作空间ID
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
if not workspace_id:
|
if not workspace_id:
|
||||||
@@ -237,22 +310,36 @@ async def extract_ontology(
|
|||||||
llm_id=request.llm_id
|
llm_id=request.llm_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# 调用服务层执行提取
|
# 调用服务层执行提取,传入scene_id和workspace_id
|
||||||
result = await service.extract_ontology(
|
result = await service.extract_ontology(
|
||||||
scenario=request.scenario,
|
scenario=request.scenario,
|
||||||
domain=request.domain,
|
domain=request.domain,
|
||||||
scene_id=request.scene_id,
|
scene_id=request.scene_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id
|
||||||
language=language
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 根据语言类型统一 name 字段
|
# ===== 新增:翻译逻辑 =====
|
||||||
# zh: name 使用 name_chinese(中文名)
|
# 如果需要英文,则翻译数据
|
||||||
# en: name 保持原值(英文 PascalCase)
|
if language_type != 'zh':
|
||||||
if language == "zh":
|
api_logger.info(f"Translating extraction result to English")
|
||||||
for cls in result.classes:
|
|
||||||
if cls.name_chinese:
|
# 翻译 classes 列表
|
||||||
cls.name = cls.name_chinese
|
result.classes = await translate_ontology_classes(
|
||||||
|
result.classes,
|
||||||
|
request.llm_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# 翻译 domain 字段
|
||||||
|
if result.domain:
|
||||||
|
try:
|
||||||
|
result.domain = await Translation_English(
|
||||||
|
request.llm_id,
|
||||||
|
result.domain
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to translate domain: {e}")
|
||||||
|
# 保留原文
|
||||||
|
# ===== 翻译逻辑结束 =====
|
||||||
|
|
||||||
# 构建响应
|
# 构建响应
|
||||||
response = ExtractionResponse(
|
response = ExtractionResponse(
|
||||||
@@ -263,7 +350,7 @@ async def extract_ontology(
|
|||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"Ontology extraction completed, extracted {len(result.classes)} classes, "
|
f"Ontology extraction completed, extracted {len(result.classes)} classes, "
|
||||||
f"scene_id={request.scene_id}, language={language}"
|
f"saved to scene {request.scene_id}, language={language_type}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return success(data=response.model_dump(), msg="本体提取成功")
|
return success(data=response.model_dump(), msg="本体提取成功")
|
||||||
@@ -284,17 +371,155 @@ async def extract_ontology(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "本体提取失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "本体提取失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/export", response_model=ApiResponse)
|
||||||
|
async def export_owl(
|
||||||
|
request: ExportRequest,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""导出OWL文件
|
||||||
|
|
||||||
|
将提取的本体类导出为OWL文件,支持多种格式。
|
||||||
|
导出操作不需要LLM,只使用OWL验证器和Owlready2库。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 导出请求,包含classes、format和include_metadata
|
||||||
|
db: 数据库会话
|
||||||
|
current_user: 当前用户
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse: 包含OWL文件内容的响应
|
||||||
|
|
||||||
|
Supported formats:
|
||||||
|
- rdfxml: 标准OWL RDF/XML格式(完整)
|
||||||
|
- turtle: Turtle格式(可读性好)
|
||||||
|
- ntriples: N-Triples格式(简单)
|
||||||
|
- json: JSON格式(简化,只包含类信息)
|
||||||
|
|
||||||
|
Response format:
|
||||||
|
{
|
||||||
|
"code": 200,
|
||||||
|
"msg": "OWL文件导出成功",
|
||||||
|
"data": {
|
||||||
|
"owl_content": "...",
|
||||||
|
"format": "rdfxml",
|
||||||
|
"classes_count": 7
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
api_logger.info(
|
||||||
|
f"OWL export requested by user {current_user.id}, "
|
||||||
|
f"classes_count={len(request.classes)}, "
|
||||||
|
f"format={request.format}, "
|
||||||
|
f"include_metadata={request.include_metadata}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 验证格式
|
||||||
|
valid_formats = ["rdfxml", "turtle", "ntriples", "json"]
|
||||||
|
if request.format not in valid_formats:
|
||||||
|
api_logger.warning(f"Invalid export format: {request.format}")
|
||||||
|
return fail(
|
||||||
|
BizCode.BAD_REQUEST,
|
||||||
|
"不支持的导出格式",
|
||||||
|
f"format必须是以下之一: {', '.join(valid_formats)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# JSON格式直接导出,不需要OWL验证
|
||||||
|
if request.format == "json":
|
||||||
|
owl_validator = OWLValidator()
|
||||||
|
owl_content = owl_validator.export_to_owl(
|
||||||
|
world=None,
|
||||||
|
format="json",
|
||||||
|
classes=request.classes
|
||||||
|
)
|
||||||
|
|
||||||
|
response = ExportResponse(
|
||||||
|
owl_content=owl_content,
|
||||||
|
format=request.format,
|
||||||
|
classes_count=len(request.classes)
|
||||||
|
)
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"JSON export completed, content_length={len(owl_content)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(data=response.model_dump(), msg="OWL文件导出成功")
|
||||||
|
|
||||||
|
# 创建临时文件路径
|
||||||
|
with tempfile.NamedTemporaryFile(
|
||||||
|
mode='w',
|
||||||
|
suffix='.owl',
|
||||||
|
delete=False
|
||||||
|
) as tmp_file:
|
||||||
|
output_path = tmp_file.name
|
||||||
|
|
||||||
|
# 导出操作不需要LLM,直接使用OWL验证器
|
||||||
|
owl_validator = OWLValidator()
|
||||||
|
|
||||||
|
# 验证本体类
|
||||||
|
logger.debug("Validating ontology classes")
|
||||||
|
is_valid, errors, world = owl_validator.validate_ontology_classes(
|
||||||
|
classes=request.classes,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_valid:
|
||||||
|
logger.warning(
|
||||||
|
f"OWL validation found {len(errors)} issues during export: {errors}"
|
||||||
|
)
|
||||||
|
# 继续导出,但记录警告
|
||||||
|
|
||||||
|
if not world:
|
||||||
|
error_msg = "Failed to create OWL world for export"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "创建OWL世界失败", error_msg)
|
||||||
|
|
||||||
|
# 导出OWL文件
|
||||||
|
logger.info(f"Exporting to {request.format} format")
|
||||||
|
owl_content = owl_validator.export_to_owl(
|
||||||
|
world=world,
|
||||||
|
output_path=output_path,
|
||||||
|
format=request.format,
|
||||||
|
classes=request.classes
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建响应
|
||||||
|
response = ExportResponse(
|
||||||
|
owl_content=owl_content,
|
||||||
|
format=request.format,
|
||||||
|
classes_count=len(request.classes)
|
||||||
|
)
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"OWL export completed, format={request.format}, "
|
||||||
|
f"content_length={len(owl_content)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(data=response.model_dump(), msg="OWL文件导出成功")
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
# 验证错误 (400)
|
||||||
|
api_logger.warning(f"Validation error in export: {str(e)}")
|
||||||
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||||
|
|
||||||
|
except RuntimeError as e:
|
||||||
|
# 运行时错误 (500)
|
||||||
|
api_logger.error(f"Runtime error in export: {str(e)}", exc_info=True)
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "OWL文件导出失败", str(e))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# 未知错误 (500)
|
||||||
|
api_logger.error(f"Unexpected error in export: {str(e)}", exc_info=True)
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "OWL文件导出失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
# ==================== 本体场景管理接口 ====================
|
# ==================== 本体场景管理接口 ====================
|
||||||
|
|
||||||
@router.post("/scene", response_model=ApiResponse)
|
@router.post("/scene", response_model=ApiResponse)
|
||||||
@check_ontology_project_quota
|
|
||||||
async def create_scene(
|
async def create_scene(
|
||||||
request: SceneCreateRequest,
|
request: SceneCreateRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user)
|
||||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type")
|
|
||||||
):
|
):
|
||||||
"""创建本体场景
|
"""创建本体场景
|
||||||
|
|
||||||
@@ -365,18 +590,8 @@ async def create_scene(
|
|||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||||
|
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
err_str = str(e)
|
api_logger.error(f"Runtime error in scene creation: {str(e)}", exc_info=True)
|
||||||
if "UniqueViolation" in err_str or "uq_workspace_scene_name" in err_str:
|
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", str(e))
|
||||||
api_logger.warning(f"Duplicate scene name '{request.scene_name}' in workspace {current_user.current_workspace_id}")
|
|
||||||
from app.core.language_utils import get_language_from_header
|
|
||||||
lang = get_language_from_header(x_language_type)
|
|
||||||
if lang == "en":
|
|
||||||
msg = fail(BizCode.BAD_REQUEST, "Scene name already exists", f"A scene named \"{request.scene_name}\" already exists in the current workspace. Please use a different name.")
|
|
||||||
else:
|
|
||||||
msg = fail(BizCode.BAD_REQUEST, "场景名称已存在", f"当前工作空间下已存在名为「{request.scene_name}」的场景,请使用其他名称")
|
|
||||||
return JSONResponse(status_code=400, content=msg)
|
|
||||||
api_logger.error(f"Runtime error in scene creation: {err_str}", exc_info=True)
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", err_str)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Unexpected error in scene creation: {str(e)}", exc_info=True)
|
api_logger.error(f"Unexpected error in scene creation: {str(e)}", exc_info=True)
|
||||||
@@ -424,20 +639,6 @@ async def update_scene(
|
|||||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||||
|
|
||||||
# 检查是否为系统默认场景
|
|
||||||
scene_repo = OntologySceneRepository(db)
|
|
||||||
scene = scene_repo.get_by_id(scene_uuid)
|
|
||||||
if scene and scene.is_system_default:
|
|
||||||
business_logger.warning(
|
|
||||||
f"尝试修改系统默认场景: user_id={current_user.id}, "
|
|
||||||
f"scene_id={scene_id}, scene_name={scene.scene_name}"
|
|
||||||
)
|
|
||||||
return fail(
|
|
||||||
BizCode.BAD_REQUEST,
|
|
||||||
"系统默认场景不可修改",
|
|
||||||
"该场景为系统预设场景,不允许修改"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建OntologyService实例
|
# 创建OntologyService实例
|
||||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||||
from app.core.models.base import RedBearModelConfig
|
from app.core.models.base import RedBearModelConfig
|
||||||
@@ -530,19 +731,6 @@ async def delete_scene(
|
|||||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||||
|
|
||||||
# 检查是否为系统默认场景
|
|
||||||
scene_repo = OntologySceneRepository(db)
|
|
||||||
scene = scene_repo.get_by_id(scene_uuid)
|
|
||||||
if scene and scene.is_system_default:
|
|
||||||
business_logger.warning(
|
|
||||||
f"尝试删除系统默认场景: user_id={current_user.id}, "
|
|
||||||
f"scene_id={scene_id}, scene_name={scene.scene_name}"
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="SYSTEM_DEFAULT_SCENE_CANNOT_DELETE"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建OntologyService实例
|
# 创建OntologyService实例
|
||||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||||
from app.core.models.base import RedBearModelConfig
|
from app.core.models.base import RedBearModelConfig
|
||||||
@@ -566,9 +754,6 @@ async def delete_scene(
|
|||||||
|
|
||||||
return success(data={"deleted": success_flag}, msg="场景删除成功")
|
return success(data={"deleted": success_flag}, msg="场景删除成功")
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
api_logger.warning(f"Validation error in scene deletion: {str(e)}")
|
api_logger.warning(f"Validation error in scene deletion: {str(e)}")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||||
@@ -676,8 +861,7 @@ async def get_scenes(
|
|||||||
async def create_class(
|
async def create_class(
|
||||||
request: ClassCreateRequest,
|
request: ClassCreateRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user)
|
||||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type")
|
|
||||||
):
|
):
|
||||||
"""创建本体类型
|
"""创建本体类型
|
||||||
|
|
||||||
@@ -692,7 +876,7 @@ async def create_class(
|
|||||||
ApiResponse: 包含创建的类型信息
|
ApiResponse: 包含创建的类型信息
|
||||||
"""
|
"""
|
||||||
from app.controllers.ontology_secondary_routes import create_class_handler
|
from app.controllers.ontology_secondary_routes import create_class_handler
|
||||||
return await create_class_handler(request, db, current_user, x_language_type)
|
return await create_class_handler(request, db, current_user)
|
||||||
|
|
||||||
|
|
||||||
@router.put("/class/{class_id}", response_model=ApiResponse)
|
@router.put("/class/{class_id}", response_model=ApiResponse)
|
||||||
@@ -819,370 +1003,3 @@ async def get_class(
|
|||||||
"""
|
"""
|
||||||
from app.controllers.ontology_secondary_routes import get_class_handler
|
from app.controllers.ontology_secondary_routes import get_class_handler
|
||||||
return await get_class_handler(class_id, db, current_user)
|
return await get_class_handler(class_id, db, current_user)
|
||||||
|
|
||||||
|
|
||||||
# ==================== OWL 导入接口 ====================
|
|
||||||
|
|
||||||
@router.post("/import", response_model=ApiResponse)
|
|
||||||
async def import_owl_file(
|
|
||||||
scene_name: str = Form(..., description="场景名称"),
|
|
||||||
scene_description: Optional[str] = Form(None, description="场景描述(可选)"),
|
|
||||||
file: UploadFile = File(..., description="OWL/TTL文件"),
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""导入 OWL/TTL 文件并创建新场景
|
|
||||||
|
|
||||||
上传 OWL 或 TTL 文件,解析其中定义的本体类型(owl:Class),
|
|
||||||
解析成功后创建新场景,并将类型保存到该场景的 ontology_class 表中。
|
|
||||||
|
|
||||||
文件格式根据文件扩展名自动识别:
|
|
||||||
- .owl, .rdf, .xml -> rdfxml 格式
|
|
||||||
- .ttl -> turtle 格式
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scene_name: 场景名称(表单字段)
|
|
||||||
scene_description: 场景描述(表单字段,可选)
|
|
||||||
file: 上传的文件(支持 .owl, .ttl, .rdf, .xml)
|
|
||||||
db: 数据库会话
|
|
||||||
current_user: 当前用户
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ApiResponse: 包含导入结果
|
|
||||||
"""
|
|
||||||
from app.repositories.ontology_scene_repository import OntologySceneRepository
|
|
||||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
|
||||||
|
|
||||||
# 根据文件扩展名确定格式
|
|
||||||
filename = file.filename.lower() if file.filename else ""
|
|
||||||
if filename.endswith('.ttl'):
|
|
||||||
owl_format = "turtle"
|
|
||||||
file_type = "ttl"
|
|
||||||
elif filename.endswith(('.owl', '.rdf', '.xml')):
|
|
||||||
owl_format = "rdfxml"
|
|
||||||
file_type = "owl"
|
|
||||||
else:
|
|
||||||
return fail(
|
|
||||||
BizCode.BAD_REQUEST,
|
|
||||||
"文件格式不支持",
|
|
||||||
f"不支持的文件格式: {filename},支持的格式: .owl, .ttl, .rdf, .xml"
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"OWL import requested by user {current_user.id}, "
|
|
||||||
f"scene_name={scene_name}, "
|
|
||||||
f"filename={file.filename}, "
|
|
||||||
f"format={owl_format}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 获取当前工作空间ID
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
if not workspace_id:
|
|
||||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
|
||||||
|
|
||||||
# 1. 验证场景名称不为空
|
|
||||||
if not scene_name or not scene_name.strip():
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "场景名称不能为空")
|
|
||||||
|
|
||||||
scene_name = scene_name.strip()
|
|
||||||
|
|
||||||
# 2. 检查场景名称是否已存在
|
|
||||||
scene_repo = OntologySceneRepository(db)
|
|
||||||
existing_scene = scene_repo.get_by_name(scene_name, workspace_id)
|
|
||||||
if existing_scene:
|
|
||||||
api_logger.warning(f"Scene name already exists: {scene_name}")
|
|
||||||
return fail(
|
|
||||||
BizCode.BAD_REQUEST,
|
|
||||||
"场景名称已存在",
|
|
||||||
f"工作空间下已存在名为 '{scene_name}' 的场景"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. 读取文件内容
|
|
||||||
try:
|
|
||||||
content = await file.read()
|
|
||||||
owl_content = content.decode('utf-8')
|
|
||||||
except UnicodeDecodeError:
|
|
||||||
return fail(
|
|
||||||
BizCode.BAD_REQUEST,
|
|
||||||
f"{file_type}文件导入失败",
|
|
||||||
"文件编码错误,请确保文件使用 UTF-8 编码"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4. 解析 OWL 内容(先解析,成功后再创建场景)
|
|
||||||
owl_validator = OWLValidator()
|
|
||||||
parsed_classes = owl_validator.parse_owl_content(
|
|
||||||
owl_content=owl_content,
|
|
||||||
format=owl_format
|
|
||||||
)
|
|
||||||
|
|
||||||
if not parsed_classes:
|
|
||||||
api_logger.warning("No classes found in OWL content")
|
|
||||||
return fail(
|
|
||||||
BizCode.BAD_REQUEST,
|
|
||||||
"未找到本体类型",
|
|
||||||
"文件中没有定义任何本体类型(owl:Class)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 5. 文件解析成功,创建场景
|
|
||||||
scene = scene_repo.create(
|
|
||||||
scene_data={
|
|
||||||
"scene_name": scene_name,
|
|
||||||
"scene_description": scene_description
|
|
||||||
},
|
|
||||||
workspace_id=workspace_id
|
|
||||||
)
|
|
||||||
scene_uuid = scene.scene_id
|
|
||||||
|
|
||||||
api_logger.info(f"Scene created for import: {scene_uuid}")
|
|
||||||
|
|
||||||
# 6. 批量创建类型(去重同一批次内的重复类型)
|
|
||||||
class_repo = OntologyClassRepository(db)
|
|
||||||
created_items = []
|
|
||||||
existing_names = set()
|
|
||||||
skipped_count = 0
|
|
||||||
|
|
||||||
for cls in parsed_classes:
|
|
||||||
class_name = cls["name"]
|
|
||||||
class_description = cls.get("description")
|
|
||||||
|
|
||||||
# 检查同一批次内是否重复
|
|
||||||
if class_name in existing_names:
|
|
||||||
skipped_count += 1
|
|
||||||
api_logger.debug(f"Skipping duplicate class in batch: {class_name}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 创建类型
|
|
||||||
ontology_class = class_repo.create(
|
|
||||||
class_data={
|
|
||||||
"class_name": class_name,
|
|
||||||
"class_description": class_description
|
|
||||||
},
|
|
||||||
scene_id=scene_uuid
|
|
||||||
)
|
|
||||||
|
|
||||||
# 添加到已存在集合,防止同一批次内重复
|
|
||||||
existing_names.add(class_name)
|
|
||||||
|
|
||||||
created_items.append(ClassResponse(
|
|
||||||
class_id=ontology_class.class_id,
|
|
||||||
class_name=ontology_class.class_name,
|
|
||||||
class_description=ontology_class.class_description,
|
|
||||||
scene_id=ontology_class.scene_id,
|
|
||||||
created_at=ontology_class.created_at,
|
|
||||||
updated_at=ontology_class.updated_at
|
|
||||||
))
|
|
||||||
|
|
||||||
# 7. 提交事务
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
# 8. 构建响应
|
|
||||||
response = ImportOwlResponse(
|
|
||||||
scene_id=scene_uuid,
|
|
||||||
scene_name=scene.scene_name,
|
|
||||||
imported_count=len(created_items),
|
|
||||||
skipped_count=skipped_count,
|
|
||||||
items=created_items
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"{file_type} import completed, "
|
|
||||||
f"scene_id={scene_uuid}, "
|
|
||||||
f"scene_name={scene_name}, "
|
|
||||||
f"format={owl_format}, "
|
|
||||||
f"imported={len(created_items)}, "
|
|
||||||
f"skipped={skipped_count}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return success(data=response.model_dump(), msg=f"{file_type}文件导入成功")
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
db.rollback()
|
|
||||||
api_logger.warning(f"Validation error in import: {str(e)}")
|
|
||||||
return fail(BizCode.BAD_REQUEST, f"{file_type}文件导入失败", str(e))
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
db.rollback()
|
|
||||||
api_logger.error(f"Unexpected error in import: {str(e)}", exc_info=True)
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, f"{file_type}文件导入失败", str(e))
|
|
||||||
|
|
||||||
# ==================== OWL 导出接口 ====================
|
|
||||||
@router.post("/export")
|
|
||||||
async def export_owl_by_scene(
|
|
||||||
request: ExportBySceneRequest,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""按场景导出OWL/TTL文件
|
|
||||||
|
|
||||||
根据scene_id从数据库查询该场景下的所有本体类型,并导出为文件下载。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: 导出请求,包含 scene_id 和 format
|
|
||||||
db: 数据库会话
|
|
||||||
current_user: 当前用户
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
StreamingResponse: 文件流响应,浏览器会直接下载文件
|
|
||||||
"""
|
|
||||||
from uuid import UUID
|
|
||||||
from app.repositories.ontology_scene_repository import OntologySceneRepository
|
|
||||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"OWL export by scene requested by user {current_user.id}, "
|
|
||||||
f"scene_id={request.scene_id}, "
|
|
||||||
f"format={request.format}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 验证格式参数
|
|
||||||
valid_formats = ["rdfxml", "turtle"]
|
|
||||||
owl_format = request.format.lower() if request.format else "rdfxml"
|
|
||||||
if owl_format not in valid_formats:
|
|
||||||
api_logger.warning(f"Invalid format: {request.format}")
|
|
||||||
return fail(
|
|
||||||
BizCode.BAD_REQUEST,
|
|
||||||
"格式参数无效",
|
|
||||||
f"不支持的格式: {request.format},支持的格式: rdfxml, turtle"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取当前工作空间ID
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
if not workspace_id:
|
|
||||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
|
||||||
|
|
||||||
# 1. 查询场景信息
|
|
||||||
scene_repo = OntologySceneRepository(db)
|
|
||||||
scene = scene_repo.get_by_id(request.scene_id)
|
|
||||||
|
|
||||||
if not scene:
|
|
||||||
api_logger.warning(f"Scene not found: {request.scene_id}")
|
|
||||||
return fail(BizCode.NOT_FOUND, "场景不存在", f"找不到场景: {request.scene_id}")
|
|
||||||
|
|
||||||
# 验证场景属于当前工作空间
|
|
||||||
if scene.workspace_id != workspace_id:
|
|
||||||
api_logger.warning(
|
|
||||||
f"Scene {request.scene_id} does not belong to workspace {workspace_id}"
|
|
||||||
)
|
|
||||||
return fail(BizCode.FORBIDDEN, "无权访问", "该场景不属于当前工作空间")
|
|
||||||
|
|
||||||
# 2. 查询场景下的所有本体类型
|
|
||||||
class_repo = OntologyClassRepository(db)
|
|
||||||
ontology_classes_db = class_repo.get_classes_by_scene(request.scene_id)
|
|
||||||
|
|
||||||
if not ontology_classes_db:
|
|
||||||
api_logger.warning(f"No classes found in scene: {request.scene_id}")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "场景为空", "该场景下没有定义任何本体类型")
|
|
||||||
|
|
||||||
# 3. 将数据库模型转换为OWL导出所需的OntologyClass格式
|
|
||||||
ontology_classes: List[OntologyClass] = []
|
|
||||||
for db_class in ontology_classes_db:
|
|
||||||
owl_class = OntologyClass(
|
|
||||||
id=str(db_class.class_id),
|
|
||||||
name=db_class.class_name,
|
|
||||||
name_chinese=db_class.class_name if _is_chinese(db_class.class_name) else None,
|
|
||||||
description=db_class.class_description or "",
|
|
||||||
examples=[],
|
|
||||||
parent_class=None,
|
|
||||||
entity_type="Concept",
|
|
||||||
domain=scene.scene_name
|
|
||||||
)
|
|
||||||
ontology_classes.append(owl_class)
|
|
||||||
|
|
||||||
# 4. 确定文件名、扩展名和 MIME 类型
|
|
||||||
file_ext = ".ttl" if owl_format == "turtle" else ".owl"
|
|
||||||
filename = _sanitize_filename(scene.scene_name) + file_ext
|
|
||||||
media_type = "text/turtle" if owl_format == "turtle" else "application/rdf+xml"
|
|
||||||
file_type = "ttl" if owl_format == "turtle" else "owl"
|
|
||||||
|
|
||||||
# 5. 导出OWL文件
|
|
||||||
with tempfile.NamedTemporaryFile(
|
|
||||||
mode='w',
|
|
||||||
suffix='.owl',
|
|
||||||
delete=False
|
|
||||||
) as tmp_file:
|
|
||||||
output_path = tmp_file.name
|
|
||||||
|
|
||||||
owl_validator = OWLValidator()
|
|
||||||
|
|
||||||
# 验证本体类
|
|
||||||
is_valid, errors, world = owl_validator.validate_ontology_classes(
|
|
||||||
classes=ontology_classes,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not is_valid:
|
|
||||||
logger.warning(
|
|
||||||
f"OWL validation found {len(errors)} issues during export: {errors}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not world:
|
|
||||||
error_msg = "Failed to create OWL world for export"
|
|
||||||
logger.error(error_msg)
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "创建OWL世界失败", error_msg)
|
|
||||||
|
|
||||||
# 导出OWL文件(使用请求指定的格式)
|
|
||||||
owl_content = owl_validator.export_to_owl(
|
|
||||||
world=world,
|
|
||||||
output_path=output_path,
|
|
||||||
format=owl_format,
|
|
||||||
classes=ontology_classes
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"{file_type} export by scene completed, "
|
|
||||||
f"scene={scene.scene_name}, "
|
|
||||||
f"filename={filename}, "
|
|
||||||
f"format={owl_format}, "
|
|
||||||
f"classes_count={len(ontology_classes)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 6. 返回文件流响应
|
|
||||||
# filename 使用 ASCII 安全的默认名,filename* 使用 UTF-8 编码的实际名称
|
|
||||||
ascii_filename = f"ontology{file_ext}"
|
|
||||||
encoded_filename = quote(filename)
|
|
||||||
return StreamingResponse(
|
|
||||||
io.BytesIO(owl_content.encode('utf-8')),
|
|
||||||
media_type=media_type,
|
|
||||||
headers={
|
|
||||||
"Content-Disposition": f"attachment; filename=\"{ascii_filename}\"; filename*=UTF-8''{encoded_filename}"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
api_logger.warning(f"Validation error in export by scene: {str(e)}")
|
|
||||||
file_type = "ttl" if (request.format and request.format.lower() == "turtle") else "owl"
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
|
||||||
|
|
||||||
except RuntimeError as e:
|
|
||||||
api_logger.error(f"Runtime error in export by scene: {str(e)}", exc_info=True)
|
|
||||||
file_type = "ttl" if (request.format and request.format.lower() == "turtle") else "owl"
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, f"{file_type}文件导出失败", str(e))
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"Unexpected error in export by scene: {str(e)}", exc_info=True)
|
|
||||||
file_type = "ttl" if (request.format and request.format.lower() == "turtle") else "owl"
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, f"{file_type}文件导出失败", str(e))
|
|
||||||
|
|
||||||
|
|
||||||
def _is_chinese(text: str) -> bool:
|
|
||||||
"""检查文本是否包含中文字符"""
|
|
||||||
for char in text:
|
|
||||||
if '\u4e00' <= char <= '\u9fff':
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_filename(name: str) -> str:
|
|
||||||
"""清理文件名,移除不合法字符"""
|
|
||||||
import re
|
|
||||||
# 移除或替换不合法的文件名字符
|
|
||||||
sanitized = re.sub(r'[<>:"/\\|?*]', '_', name)
|
|
||||||
# 移除前后空格
|
|
||||||
sanitized = sanitized.strip()
|
|
||||||
# 如果为空,使用默认名称
|
|
||||||
if not sanitized:
|
|
||||||
sanitized = "ontology_export"
|
|
||||||
return sanitized
|
|
||||||
|
|||||||
@@ -7,11 +7,11 @@
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import Depends, Header
|
from fastapi import Depends
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.logging_config import get_api_logger, get_business_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.response_utils import fail, success
|
from app.core.response_utils import fail, success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
@@ -30,11 +30,9 @@ from app.schemas.response_schema import ApiResponse
|
|||||||
from app.services.ontology_service import OntologyService
|
from app.services.ontology_service import OntologyService
|
||||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||||
from app.core.models.base import RedBearModelConfig
|
from app.core.models.base import RedBearModelConfig
|
||||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
|
||||||
|
|
||||||
|
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
business_logger = get_business_logger()
|
|
||||||
|
|
||||||
|
|
||||||
def _get_dummy_ontology_service(db: Session) -> OntologyService:
|
def _get_dummy_ontology_service(db: Session) -> OntologyService:
|
||||||
@@ -58,7 +56,7 @@ async def scenes_handler(
|
|||||||
workspace_id: Optional[str] = None,
|
workspace_id: Optional[str] = None,
|
||||||
scene_name: Optional[str] = None,
|
scene_name: Optional[str] = None,
|
||||||
page: Optional[int] = None,
|
page: Optional[int] = None,
|
||||||
pagesize: Optional[int] = None,
|
page_size: Optional[int] = None,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
@@ -71,14 +69,14 @@ async def scenes_handler(
|
|||||||
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
||||||
scene_name: 场景名称关键词(可选,支持模糊匹配)
|
scene_name: 场景名称关键词(可选,支持模糊匹配)
|
||||||
page: 页码(可选,从1开始,仅在全量查询时有效)
|
page: 页码(可选,从1开始,仅在全量查询时有效)
|
||||||
pagesize: 每页数量(可选,仅在全量查询时有效)
|
page_size: 每页数量(可选,仅在全量查询时有效)
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
current_user: 当前用户
|
current_user: 当前用户
|
||||||
"""
|
"""
|
||||||
operation = "search" if scene_name else "list"
|
operation = "search" if scene_name else "list"
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"Scene {operation} requested by user {current_user.id}, "
|
f"Scene {operation} requested by user {current_user.id}, "
|
||||||
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, pagesize={pagesize}"
|
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -105,13 +103,13 @@ async def scenes_handler(
|
|||||||
api_logger.warning(f"Invalid page number: {page}")
|
api_logger.warning(f"Invalid page number: {page}")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||||
|
|
||||||
if pagesize is not None and pagesize < 1:
|
if page_size is not None and page_size < 1:
|
||||||
api_logger.warning(f"Invalid pagesize: {pagesize}")
|
api_logger.warning(f"Invalid page_size: {page_size}")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||||
|
|
||||||
# 如果只提供了page或pagesize中的一个,返回错误
|
# 如果只提供了page或page_size中的一个,返回错误
|
||||||
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
|
if (page is not None and page_size is None) or (page is None and page_size is not None):
|
||||||
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
|
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||||
|
|
||||||
# 模糊搜索场景(支持分页)
|
# 模糊搜索场景(支持分页)
|
||||||
@@ -119,15 +117,17 @@ async def scenes_handler(
|
|||||||
total = len(scenes)
|
total = len(scenes)
|
||||||
|
|
||||||
# 如果提供了分页参数,进行分页处理
|
# 如果提供了分页参数,进行分页处理
|
||||||
if page is not None and pagesize is not None:
|
if page is not None and page_size is not None:
|
||||||
start_idx = (page - 1) * pagesize
|
start_idx = (page - 1) * page_size
|
||||||
end_idx = start_idx + pagesize
|
end_idx = start_idx + page_size
|
||||||
scenes = scenes[start_idx:end_idx]
|
scenes = scenes[start_idx:end_idx]
|
||||||
|
|
||||||
# 构建响应
|
# 构建响应
|
||||||
items = []
|
items = []
|
||||||
for scene in scenes:
|
for scene in scenes:
|
||||||
|
# 获取前3个class_name作为entity_type
|
||||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||||
|
# 动态计算 type_num
|
||||||
type_num = len(scene.classes) if scene.classes else 0
|
type_num = len(scene.classes) if scene.classes else 0
|
||||||
|
|
||||||
items.append(SceneResponse(
|
items.append(SceneResponse(
|
||||||
@@ -139,16 +139,17 @@ async def scenes_handler(
|
|||||||
workspace_id=scene.workspace_id,
|
workspace_id=scene.workspace_id,
|
||||||
created_at=scene.created_at,
|
created_at=scene.created_at,
|
||||||
updated_at=scene.updated_at,
|
updated_at=scene.updated_at,
|
||||||
classes_count=type_num,
|
classes_count=type_num
|
||||||
is_system_default=scene.is_system_default
|
|
||||||
))
|
))
|
||||||
|
|
||||||
# 构建响应(包含分页信息)
|
# 构建响应(包含分页信息)
|
||||||
if page is not None and pagesize is not None:
|
if page is not None and page_size is not None:
|
||||||
hasnext = (page * pagesize) < total
|
# 计算是否有下一页
|
||||||
|
hasnext = (page * page_size) < total
|
||||||
|
|
||||||
pagination_info = PaginationInfo(
|
pagination_info = PaginationInfo(
|
||||||
page=page,
|
page=page,
|
||||||
pagesize=pagesize,
|
pagesize=page_size,
|
||||||
total=total,
|
total=total,
|
||||||
hasnext=hasnext
|
hasnext=hasnext
|
||||||
)
|
)
|
||||||
@@ -162,25 +163,28 @@ async def scenes_handler(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 获取所有场景(支持分页)
|
# 获取所有场景(支持分页)
|
||||||
|
# 验证分页参数
|
||||||
if page is not None and page < 1:
|
if page is not None and page < 1:
|
||||||
api_logger.warning(f"Invalid page number: {page}")
|
api_logger.warning(f"Invalid page number: {page}")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||||
|
|
||||||
if pagesize is not None and pagesize < 1:
|
if page_size is not None and page_size < 1:
|
||||||
api_logger.warning(f"Invalid pagesize: {pagesize}")
|
api_logger.warning(f"Invalid page_size: {page_size}")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||||
|
|
||||||
# 如果只提供了page或pagesize中的一个,返回错误
|
# 如果只提供了page或page_size中的一个,返回错误
|
||||||
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
|
if (page is not None and page_size is None) or (page is None and page_size is not None):
|
||||||
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
|
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||||
|
|
||||||
scenes, total = service.list_scenes(ws_uuid, page, pagesize)
|
scenes, total = service.list_scenes(ws_uuid, page, page_size)
|
||||||
|
|
||||||
# 构建响应
|
# 构建响应
|
||||||
items = []
|
items = []
|
||||||
for scene in scenes:
|
for scene in scenes:
|
||||||
|
# 获取前3个class_name作为entity_type
|
||||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||||
|
# 动态计算 type_num
|
||||||
type_num = len(scene.classes) if scene.classes else 0
|
type_num = len(scene.classes) if scene.classes else 0
|
||||||
|
|
||||||
items.append(SceneResponse(
|
items.append(SceneResponse(
|
||||||
@@ -192,16 +196,17 @@ async def scenes_handler(
|
|||||||
workspace_id=scene.workspace_id,
|
workspace_id=scene.workspace_id,
|
||||||
created_at=scene.created_at,
|
created_at=scene.created_at,
|
||||||
updated_at=scene.updated_at,
|
updated_at=scene.updated_at,
|
||||||
classes_count=type_num,
|
classes_count=type_num
|
||||||
is_system_default=scene.is_system_default
|
|
||||||
))
|
))
|
||||||
|
|
||||||
# 构建响应(包含分页信息)
|
# 构建响应(包含分页信息)
|
||||||
if page is not None and pagesize is not None:
|
if page is not None and page_size is not None:
|
||||||
hasnext = (page * pagesize) < total
|
# 计算是否有下一页
|
||||||
|
hasnext = (page * page_size) < total
|
||||||
|
|
||||||
pagination_info = PaginationInfo(
|
pagination_info = PaginationInfo(
|
||||||
page=page,
|
page=page,
|
||||||
pagesize=pagesize,
|
pagesize=page_size,
|
||||||
total=total,
|
total=total,
|
||||||
hasnext=hasnext
|
hasnext=hasnext
|
||||||
)
|
)
|
||||||
@@ -231,8 +236,7 @@ async def scenes_handler(
|
|||||||
async def create_class_handler(
|
async def create_class_handler(
|
||||||
request: ClassCreateRequest,
|
request: ClassCreateRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user)
|
||||||
x_language_type: Optional[str] = None
|
|
||||||
):
|
):
|
||||||
"""创建本体类型(统一使用列表形式,支持单个或批量)"""
|
"""创建本体类型(统一使用列表形式,支持单个或批量)"""
|
||||||
|
|
||||||
@@ -265,11 +269,8 @@ async def create_class_handler(
|
|||||||
]
|
]
|
||||||
|
|
||||||
if count == 1:
|
if count == 1:
|
||||||
# 单个创建 - 先检查重名
|
# 单个创建
|
||||||
class_data = classes_data[0]
|
class_data = classes_data[0]
|
||||||
existing = OntologyClassRepository(db).get_by_name(class_data["class_name"], request.scene_id)
|
|
||||||
if existing:
|
|
||||||
raise ValueError(f"DUPLICATE_CLASS_NAME:{class_data['class_name']}")
|
|
||||||
ontology_class = service.create_class(
|
ontology_class = service.create_class(
|
||||||
scene_id=request.scene_id,
|
scene_id=request.scene_id,
|
||||||
class_name=class_data["class_name"],
|
class_name=class_data["class_name"],
|
||||||
@@ -327,36 +328,12 @@ async def create_class_handler(
|
|||||||
return success(data=response.model_dump(mode='json'), msg="批量创建完成")
|
return success(data=response.model_dump(mode='json'), msg="批量创建完成")
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
err_str = str(e)
|
api_logger.warning(f"Validation error in class creation: {str(e)}")
|
||||||
if err_str.startswith("DUPLICATE_CLASS_NAME:"):
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||||
class_name = err_str.split(":", 1)[1]
|
|
||||||
api_logger.warning(f"Duplicate class name '{class_name}' in scene {request.scene_id}")
|
|
||||||
from app.core.language_utils import get_language_from_header
|
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
lang = get_language_from_header(x_language_type)
|
|
||||||
if lang == "en":
|
|
||||||
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
|
|
||||||
else:
|
|
||||||
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
|
|
||||||
return JSONResponse(status_code=400, content=msg)
|
|
||||||
api_logger.warning(f"Validation error in class creation: {err_str}")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", err_str)
|
|
||||||
|
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
err_str = str(e)
|
api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True)
|
||||||
if "UniqueViolation" in err_str or "uq_scene_class_name" in err_str:
|
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
|
||||||
api_logger.warning(f"Duplicate class name in scene {request.scene_id}")
|
|
||||||
from app.core.language_utils import get_language_from_header
|
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
lang = get_language_from_header(x_language_type)
|
|
||||||
class_name = request.classes[0].class_name if request.classes else ""
|
|
||||||
if lang == "en":
|
|
||||||
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
|
|
||||||
else:
|
|
||||||
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
|
|
||||||
return JSONResponse(status_code=400, content=msg)
|
|
||||||
api_logger.error(f"Runtime error in class creation: {err_str}", exc_info=True)
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", err_str)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
|
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
|
||||||
@@ -389,20 +366,6 @@ async def update_class_handler(
|
|||||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||||
|
|
||||||
# 检查是否为系统默认类型
|
|
||||||
class_repo = OntologyClassRepository(db)
|
|
||||||
ontology_class = class_repo.get_by_id(class_uuid)
|
|
||||||
if ontology_class and ontology_class.is_system_default:
|
|
||||||
business_logger.warning(
|
|
||||||
f"尝试修改系统默认类型: user_id={current_user.id}, "
|
|
||||||
f"class_id={class_id}, class_name={ontology_class.class_name}"
|
|
||||||
)
|
|
||||||
return fail(
|
|
||||||
BizCode.BAD_REQUEST,
|
|
||||||
"系统默认类型不可修改",
|
|
||||||
"该类型为系统预设类型,不允许修改"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建Service
|
# 创建Service
|
||||||
service = _get_dummy_ontology_service(db)
|
service = _get_dummy_ontology_service(db)
|
||||||
|
|
||||||
@@ -466,20 +429,6 @@ async def delete_class_handler(
|
|||||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||||
|
|
||||||
# 检查是否为系统默认类型
|
|
||||||
class_repo = OntologyClassRepository(db)
|
|
||||||
ontology_class = class_repo.get_by_id(class_uuid)
|
|
||||||
if ontology_class and ontology_class.is_system_default:
|
|
||||||
business_logger.warning(
|
|
||||||
f"尝试删除系统默认类型: user_id={current_user.id}, "
|
|
||||||
f"class_id={class_id}, class_name={ontology_class.class_name}"
|
|
||||||
)
|
|
||||||
return fail(
|
|
||||||
BizCode.BAD_REQUEST,
|
|
||||||
"系统默认类型不可删除",
|
|
||||||
"该类型为系统预设类型,不允许删除"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建Service
|
# 创建Service
|
||||||
service = _get_dummy_ontology_service(db)
|
service = _get_dummy_ontology_service(db)
|
||||||
|
|
||||||
@@ -636,7 +585,6 @@ async def classes_handler(
|
|||||||
scene_id=scene_uuid,
|
scene_id=scene_uuid,
|
||||||
scene_name=scene.scene_name,
|
scene_name=scene.scene_name,
|
||||||
scene_description=scene.scene_description,
|
scene_description=scene.scene_description,
|
||||||
is_system_default=scene.is_system_default,
|
|
||||||
items=items
|
items=items
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -120,15 +120,13 @@ async def get_prompt_opt(
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
current_prompt=data.current_prompt,
|
current_prompt=data.current_prompt,
|
||||||
user_require=data.message,
|
user_require=data.message
|
||||||
skill=data.skill
|
|
||||||
):
|
):
|
||||||
# chunk 是 prompt 的增量内容
|
# chunk 是 prompt 的增量内容
|
||||||
yield f"event:message\ndata: {json.dumps(chunk, ensure_ascii=False)}\n\n"
|
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield f"event:error\ndata: {json.dumps(
|
yield f"event:error\ndata: {json.dumps(
|
||||||
{"error": str(e)},
|
{"error": str(e)}
|
||||||
ensure_ascii=False
|
|
||||||
)}\n\n"
|
)}\n\n"
|
||||||
yield "event:end\ndata: {}\n\n"
|
yield "event:end\ndata: {}\n\n"
|
||||||
|
|
||||||
|
|||||||
@@ -2,34 +2,25 @@ import hashlib
|
|||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query, Request
|
from fastapi import APIRouter, Depends, Query, Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
|
||||||
from app.core.exceptions import BusinessException
|
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.quota_manager import check_end_user_quota
|
from app.core.response_utils import success
|
||||||
from app.core.response_utils import success, fail
|
|
||||||
from app.db import get_db, get_db_read
|
from app.db import get_db, get_db_read
|
||||||
from app.dependencies import get_share_user_id, ShareTokenData
|
from app.dependencies import get_share_user_id, ShareTokenData
|
||||||
from app.models.app_model import AppType
|
|
||||||
from app.repositories import knowledge_repository
|
from app.repositories import knowledge_repository
|
||||||
from app.repositories.end_user_repository import EndUserRepository
|
|
||||||
from app.repositories.workflow_repository import WorkflowConfigRepository
|
from app.repositories.workflow_repository import WorkflowConfigRepository
|
||||||
from app.schemas import release_share_schema, conversation_schema
|
from app.schemas import release_share_schema, conversation_schema
|
||||||
from app.schemas.response_schema import PageData, PageMeta
|
from app.schemas.response_schema import PageData, PageMeta
|
||||||
from app.services import workspace_service
|
from app.services import workspace_service
|
||||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
|
||||||
from app.services.app_service import AppService
|
|
||||||
from app.services.auth_service import create_access_token
|
from app.services.auth_service import create_access_token
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
from app.services.release_share_service import ReleaseShareService
|
from app.services.release_share_service import ReleaseShareService
|
||||||
from app.services.shared_chat_service import SharedChatService
|
from app.services.shared_chat_service import SharedChatService
|
||||||
from app.services.workflow_service import WorkflowService
|
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||||
from app.models.file_metadata_model import FileMetadata
|
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, \
|
||||||
from app.utils.app_config_utils import workflow_config_4_app_release, \
|
|
||||||
agent_config_4_app_release, multi_agent_config_4_app_release
|
agent_config_4_app_release, multi_agent_config_4_app_release
|
||||||
|
|
||||||
router = APIRouter(prefix="/public/share", tags=["Public Share"])
|
router = APIRouter(prefix="/public/share", tags=["Public Share"])
|
||||||
@@ -215,27 +206,15 @@ def list_conversations(
|
|||||||
logger.debug(f"share_data:{share_data.user_id}")
|
logger.debug(f"share_data:{share_data.user_id}")
|
||||||
other_id = share_data.user_id
|
other_id = share_data.user_id
|
||||||
service = SharedChatService(db)
|
service = SharedChatService(db)
|
||||||
share, release = service.get_release_by_share_token(share_data.share_token, password)
|
share, release = service._get_release_by_share_token(share_data.share_token, password)
|
||||||
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
app_service = AppService(db)
|
|
||||||
app = app_service._get_app_or_404(share.app_id)
|
|
||||||
workspace_id = app.workspace_id
|
|
||||||
|
|
||||||
# 仅在新建终端用户时检查配额
|
|
||||||
existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id)
|
|
||||||
if existing_end_user is None:
|
|
||||||
from app.core.quota_manager import _check_quota
|
|
||||||
from app.models.workspace_model import Workspace
|
|
||||||
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
|
||||||
if ws:
|
|
||||||
_check_quota(db, ws.tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
|
|
||||||
|
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
app_id=share.app_id,
|
app_id=share.app_id,
|
||||||
workspace_id=workspace_id,
|
|
||||||
other_id=other_id
|
other_id=other_id
|
||||||
)
|
)
|
||||||
logger.debug(new_end_user.id)
|
logger.debug(new_end_user.id)
|
||||||
|
service = SharedChatService(db)
|
||||||
conversations, total = service.list_conversations(
|
conversations, total = service.list_conversations(
|
||||||
share_token=share_data.share_token,
|
share_token=share_data.share_token,
|
||||||
user_id=str(new_end_user.id),
|
user_id=str(new_end_user.id),
|
||||||
@@ -272,41 +251,8 @@ def get_conversation(
|
|||||||
conv_service = ConversationService(db)
|
conv_service = ConversationService(db)
|
||||||
messages = conv_service.get_messages(conversation_id)
|
messages = conv_service.get_messages(conversation_id)
|
||||||
|
|
||||||
file_ids = []
|
# 构建响应
|
||||||
message_file_id_map = {}
|
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump()
|
||||||
|
|
||||||
# 第一次遍历:解析 audio_url,收集所有有效的 file_id
|
|
||||||
for idx, m in enumerate(messages):
|
|
||||||
if m.role == "assistant" and m.meta_data:
|
|
||||||
audio_url = m.meta_data.get("audio_url")
|
|
||||||
if not audio_url:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
file_id = uuid.UUID(audio_url.rstrip("/").split("/")[-1])
|
|
||||||
except (ValueError, IndexError):
|
|
||||||
# audio_url 无法解析为 UUID,标记为 unknown
|
|
||||||
m.meta_data["audio_status"] = "unknown"
|
|
||||||
continue
|
|
||||||
|
|
||||||
file_ids.append(file_id)
|
|
||||||
message_file_id_map[idx] = file_id
|
|
||||||
|
|
||||||
# 批量查询所有相关的 FileMetadata
|
|
||||||
file_status_map = {}
|
|
||||||
if file_ids:
|
|
||||||
file_metas = (
|
|
||||||
db.query(FileMetadata)
|
|
||||||
.filter(FileMetadata.id.in_(set(file_ids)))
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
file_status_map = {fm.id: fm.status for fm in file_metas}
|
|
||||||
|
|
||||||
# 第二次遍历:将查询结果映射回消息
|
|
||||||
for idx, file_id in message_file_id_map.items():
|
|
||||||
m = messages[idx]
|
|
||||||
m.meta_data["audio_status"] = file_status_map.get(file_id, "unknown")
|
|
||||||
|
|
||||||
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump(mode="json")
|
|
||||||
conv_dict["messages"] = [
|
conv_dict["messages"] = [
|
||||||
conversation_schema.Message.model_validate(m) for m in messages
|
conversation_schema.Message.model_validate(m) for m in messages
|
||||||
]
|
]
|
||||||
@@ -347,61 +293,40 @@ async def chat(
|
|||||||
|
|
||||||
# 提前验证和准备(在流式响应开始前完成)
|
# 提前验证和准备(在流式响应开始前完成)
|
||||||
# 这样可以确保错误能正确返回,而不是在流式响应中间出错
|
# 这样可以确保错误能正确返回,而不是在流式响应中间出错
|
||||||
|
from app.models.app_model import AppType
|
||||||
try:
|
try:
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.services.app_service import AppService
|
||||||
# 验证分享链接和密码
|
# 验证分享链接和密码
|
||||||
share, release = service.get_release_by_share_token(share_token, password)
|
share, release = service._get_release_by_share_token(share_token, password)
|
||||||
|
|
||||||
# # Create end_user_id by concatenating app_id with user_id
|
# # Create end_user_id by concatenating app_id with user_id
|
||||||
# end_user_id = f"{share.app_id}_{user_id}"
|
# end_user_id = f"{share.app_id}_{user_id}"
|
||||||
|
|
||||||
# Store end_user_id in database with original user_id
|
# Store end_user_id in database with original user_id
|
||||||
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
app_service = AppService(db)
|
|
||||||
app = app_service._get_app_or_404(share.app_id)
|
|
||||||
workspace_id = app.workspace_id
|
|
||||||
|
|
||||||
# 仅在新建终端用户时检查配额,已有用户复用不受限制
|
|
||||||
existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id)
|
|
||||||
logger.info(f"终端用户配额检查: workspace_id={workspace_id}, other_id={other_id}, existing={existing_end_user is not None}")
|
|
||||||
if existing_end_user is None:
|
|
||||||
from app.core.quota_manager import _check_quota
|
|
||||||
from app.models.workspace_model import Workspace
|
|
||||||
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
|
||||||
if ws:
|
|
||||||
logger.info(f"新终端用户,执行配额检查: tenant_id={ws.tenant_id}")
|
|
||||||
_check_quota(db, ws.tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
|
|
||||||
|
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
app_id=share.app_id,
|
app_id=share.app_id,
|
||||||
workspace_id=workspace_id,
|
|
||||||
other_id=other_id,
|
other_id=other_id,
|
||||||
original_user_id=user_id
|
original_user_id=user_id # Save original user_id to other_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only extract and set memory_config_id when the end user doesn't have one yet
|
|
||||||
if not new_end_user.memory_config_id:
|
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
|
||||||
memory_config_service = MemoryConfigService(db)
|
|
||||||
memory_config_id, _ = memory_config_service.extract_memory_config_id(release.type, release.config or {})
|
|
||||||
if memory_config_id:
|
|
||||||
new_end_user.memory_config_id = memory_config_id
|
|
||||||
db.commit()
|
|
||||||
db.refresh(new_end_user)
|
|
||||||
end_user_id = str(new_end_user.id)
|
end_user_id = str(new_end_user.id)
|
||||||
|
|
||||||
# appid = share.app_id
|
appid = share.app_id
|
||||||
"""获取存储类型和工作空间的ID"""
|
"""获取存储类型和工作空间的ID"""
|
||||||
|
|
||||||
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
|
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
|
||||||
# app = db.query(App).filter(
|
from app.models.app_model import App
|
||||||
# App.id == appid,
|
app = db.query(App).filter(
|
||||||
# App.is_active.is_(True)
|
App.id == appid,
|
||||||
# ).first()
|
App.is_active.is_(True)
|
||||||
# if not app:
|
).first()
|
||||||
# raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
if not app:
|
||||||
|
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
||||||
|
|
||||||
# workspace_id = app.workspace_id
|
workspace_id = app.workspace_id
|
||||||
|
|
||||||
# 直接从 workspace 获取 storage_type(公开分享场景无需权限检查)
|
# 直接从 workspace 获取 storage_type(公开分享场景无需权限检查)
|
||||||
storage_type = workspace_service.get_workspace_storage_type_without_auth(
|
storage_type = workspace_service.get_workspace_storage_type_without_auth(
|
||||||
@@ -434,12 +359,12 @@ async def chat(
|
|||||||
app_type = release.app.type if release.app else None
|
app_type = release.app.type if release.app else None
|
||||||
|
|
||||||
# 根据应用类型验证配置
|
# 根据应用类型验证配置
|
||||||
if app_type == AppType.AGENT:
|
if app_type == "agent":
|
||||||
# Agent 类型:验证模型配置
|
# Agent 类型:验证模型配置
|
||||||
model_config_id = release.default_model_config_id
|
model_config_id = release.default_model_config_id
|
||||||
if not model_config_id:
|
if not model_config_id:
|
||||||
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||||
elif app_type == AppType.MULTI_AGENT:
|
elif app_type == "multi_agent":
|
||||||
# Multi-Agent 类型:验证多 Agent 配置
|
# Multi-Agent 类型:验证多 Agent 配置
|
||||||
config = release.config or {}
|
config = release.config or {}
|
||||||
if not config.get("sub_agents"):
|
if not config.get("sub_agents"):
|
||||||
@@ -477,10 +402,31 @@ async def chat(
|
|||||||
# 流式返回
|
# 流式返回
|
||||||
agent_config = agent_config_4_app_release(release)
|
agent_config = agent_config_4_app_release(release)
|
||||||
|
|
||||||
if not (agent_config.model_parameters.get("deep_thinking", False) and payload.thinking):
|
|
||||||
agent_config.model_parameters["deep_thinking"] = False
|
|
||||||
|
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
|
# async def event_generator():
|
||||||
|
# async for event in service.chat_stream(
|
||||||
|
# share_token=share_token,
|
||||||
|
# message=payload.message,
|
||||||
|
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
|
# user_id=str(new_end_user.id), # 转换为字符串
|
||||||
|
# variables=payload.variables,
|
||||||
|
# password=password,
|
||||||
|
# web_search=payload.web_search,
|
||||||
|
# memory=payload.memory,
|
||||||
|
# storage_type=storage_type,
|
||||||
|
# user_rag_memory_id=user_rag_memory_id
|
||||||
|
# ):
|
||||||
|
# yield event
|
||||||
|
|
||||||
|
# return StreamingResponse(
|
||||||
|
# event_generator(),
|
||||||
|
# media_type="text/event-stream",
|
||||||
|
# headers={
|
||||||
|
# "Cache-Control": "no-cache",
|
||||||
|
# "Connection": "keep-alive",
|
||||||
|
# "X-Accel-Buffering": "no"
|
||||||
|
# }
|
||||||
|
# )
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
async for event in app_chat_service.agnet_chat_stream(
|
async for event in app_chat_service.agnet_chat_stream(
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
@@ -492,8 +438,7 @@ async def chat(
|
|||||||
memory=payload.memory,
|
memory=payload.memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id
|
||||||
files=payload.files # 传递多模态文件
|
|
||||||
):
|
):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
@@ -506,6 +451,20 @@ async def chat(
|
|||||||
"X-Accel-Buffering": "no"
|
"X-Accel-Buffering": "no"
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
# 非流式返回
|
||||||
|
# result = await service.chat(
|
||||||
|
# share_token=share_token,
|
||||||
|
# message=payload.message,
|
||||||
|
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
|
# user_id=str(new_end_user.id), # 转换为字符串
|
||||||
|
# variables=payload.variables,
|
||||||
|
# password=password,
|
||||||
|
# web_search=payload.web_search,
|
||||||
|
# memory=payload.memory,
|
||||||
|
# storage_type=storage_type,
|
||||||
|
# user_rag_memory_id=user_rag_memory_id
|
||||||
|
# )
|
||||||
|
# return success(data=conversation_schema.ChatResponse(**result))
|
||||||
result = await app_chat_service.agnet_chat(
|
result = await app_chat_service.agnet_chat(
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
@@ -516,8 +475,7 @@ async def chat(
|
|||||||
memory=payload.memory,
|
memory=payload.memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id
|
||||||
files=payload.files # 传递多模态文件
|
|
||||||
)
|
)
|
||||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||||
elif app_type == AppType.MULTI_AGENT:
|
elif app_type == AppType.MULTI_AGENT:
|
||||||
@@ -564,6 +522,48 @@ async def chat(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||||
|
# 多 Agent 流式返回
|
||||||
|
# if payload.stream:
|
||||||
|
# async def event_generator():
|
||||||
|
# async for event in service.multi_agent_chat_stream(
|
||||||
|
# share_token=share_token,
|
||||||
|
# message=payload.message,
|
||||||
|
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
|
# user_id=str(new_end_user.id), # 转换为字符串
|
||||||
|
# variables=payload.variables,
|
||||||
|
# password=password,
|
||||||
|
# web_search=payload.web_search,
|
||||||
|
# memory=payload.memory,
|
||||||
|
# storage_type=storage_type,
|
||||||
|
# user_rag_memory_id=user_rag_memory_id
|
||||||
|
# ):
|
||||||
|
# yield event
|
||||||
|
|
||||||
|
# return StreamingResponse(
|
||||||
|
# event_generator(),
|
||||||
|
# media_type="text/event-stream",
|
||||||
|
# headers={
|
||||||
|
# "Cache-Control": "no-cache",
|
||||||
|
# "Connection": "keep-alive",
|
||||||
|
# "X-Accel-Buffering": "no"
|
||||||
|
# }
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # 多 Agent 非流式返回
|
||||||
|
# result = await service.multi_agent_chat(
|
||||||
|
# share_token=share_token,
|
||||||
|
# message=payload.message,
|
||||||
|
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
|
# user_id=str(new_end_user.id), # 转换为字符串
|
||||||
|
# variables=payload.variables,
|
||||||
|
# password=password,
|
||||||
|
# web_search=payload.web_search,
|
||||||
|
# memory=payload.memory,
|
||||||
|
# storage_type=storage_type,
|
||||||
|
# user_rag_memory_id=user_rag_memory_id
|
||||||
|
# )
|
||||||
|
|
||||||
|
# return success(data=conversation_schema.ChatResponse(**result))
|
||||||
elif app_type == AppType.WORKFLOW:
|
elif app_type == AppType.WORKFLOW:
|
||||||
config = workflow_config_4_app_release(release)
|
config = workflow_config_4_app_release(release)
|
||||||
if not config.id:
|
if not config.id:
|
||||||
@@ -578,7 +578,6 @@ async def chat(
|
|||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
user_id=end_user_id, # 转换为字符串
|
user_id=end_user_id, # 转换为字符串
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
files=payload.files,
|
|
||||||
config=config,
|
config=config,
|
||||||
web_search=payload.web_search,
|
web_search=payload.web_search,
|
||||||
memory=payload.memory,
|
memory=payload.memory,
|
||||||
@@ -586,8 +585,7 @@ async def chat(
|
|||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
app_id=release.app_id,
|
app_id=release.app_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
release_id=release.id,
|
release_id=release.id
|
||||||
public=True
|
|
||||||
):
|
):
|
||||||
event_type = event.get("event", "message")
|
event_type = event.get("event", "message")
|
||||||
event_data = event.get("data", {})
|
event_data = event.get("data", {})
|
||||||
@@ -608,11 +606,11 @@ async def chat(
|
|||||||
|
|
||||||
# 多 Agent 非流式返回
|
# 多 Agent 非流式返回
|
||||||
result = await app_chat_service.workflow_chat(
|
result = await app_chat_service.workflow_chat(
|
||||||
|
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
user_id=end_user_id, # 转换为字符串
|
user_id=end_user_id, # 转换为字符串
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
files=payload.files,
|
|
||||||
config=config,
|
config=config,
|
||||||
web_search=payload.web_search,
|
web_search=payload.web_search,
|
||||||
memory=payload.memory,
|
memory=payload.memory,
|
||||||
@@ -636,40 +634,6 @@ async def chat(
|
|||||||
# return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
# return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/config", summary="获取应用启动配置")
|
|
||||||
async def config_query(
|
|
||||||
password: str = Query(None, description="访问密码"),
|
|
||||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
share_service = SharedChatService(db)
|
|
||||||
share_token = share_data.share_token
|
|
||||||
share, release = share_service.get_release_by_share_token(share_token, password)
|
|
||||||
if release.app.type == AppType.WORKFLOW:
|
|
||||||
workflow_service = WorkflowService(db)
|
|
||||||
content = {
|
|
||||||
"app_type": release.app.type,
|
|
||||||
"variables": workflow_service.get_start_node_variables(release.config),
|
|
||||||
"memory": workflow_service.is_memory_enable(release.config),
|
|
||||||
"features": release.config.get("features")
|
|
||||||
}
|
|
||||||
elif release.app.type == AppType.AGENT:
|
|
||||||
content = {
|
|
||||||
"app_type": release.app.type,
|
|
||||||
"variables": release.config.get("variables"),
|
|
||||||
"memory": release.config.get("memory", {}).get("enabled"),
|
|
||||||
"features": release.config.get("features"),
|
|
||||||
"model_parameters": release.config.get("model_parameters")
|
|
||||||
}
|
|
||||||
elif release.app.type == AppType.MULTI_AGENT:
|
|
||||||
content = {
|
|
||||||
"app_type": release.app.type,
|
|
||||||
"variables": [],
|
|
||||||
"features": release.config.get("features")
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED)
|
|
||||||
return success(data=content)
|
|
||||||
|
|||||||
@@ -4,18 +4,7 @@
|
|||||||
认证方式: API Key
|
认证方式: API Key
|
||||||
"""
|
"""
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller
|
||||||
from . import (
|
|
||||||
app_api_controller,
|
|
||||||
end_user_api_controller,
|
|
||||||
memory_api_controller,
|
|
||||||
memory_config_api_controller,
|
|
||||||
rag_api_chunk_controller,
|
|
||||||
rag_api_document_controller,
|
|
||||||
rag_api_file_controller,
|
|
||||||
rag_api_knowledge_controller,
|
|
||||||
user_memory_api_controller,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建 V1 API 路由器
|
# 创建 V1 API 路由器
|
||||||
service_router = APIRouter()
|
service_router = APIRouter()
|
||||||
@@ -27,8 +16,5 @@ service_router.include_router(rag_api_document_controller.router)
|
|||||||
service_router.include_router(rag_api_file_controller.router)
|
service_router.include_router(rag_api_file_controller.router)
|
||||||
service_router.include_router(rag_api_chunk_controller.router)
|
service_router.include_router(rag_api_chunk_controller.router)
|
||||||
service_router.include_router(memory_api_controller.router)
|
service_router.include_router(memory_api_controller.router)
|
||||||
service_router.include_router(end_user_api_controller.router)
|
|
||||||
service_router.include_router(memory_config_api_controller.router)
|
|
||||||
service_router.include_router(user_memory_api_controller.router)
|
|
||||||
|
|
||||||
__all__ = ["service_router"]
|
__all__ = ["service_router"]
|
||||||
|
|||||||
@@ -12,19 +12,18 @@ from app.core.exceptions import BusinessException
|
|||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_app_or_workspace
|
||||||
from app.models.app_model import App
|
from app.models.app_model import App
|
||||||
from app.models.app_model import AppType
|
from app.models.app_model import AppType
|
||||||
from app.models.app_release_model import AppRelease
|
|
||||||
from app.repositories import knowledge_repository
|
from app.repositories import knowledge_repository
|
||||||
from app.repositories.end_user_repository import EndUserRepository
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
from app.schemas import AppChatRequest, conversation_schema
|
from app.schemas import AppChatRequest, conversation_schema
|
||||||
from app.schemas.api_key_schema import ApiKeyAuth
|
from app.schemas.api_key_schema import ApiKeyAuth
|
||||||
from app.services import workspace_service
|
from app.services import workspace_service
|
||||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||||
from app.services.app_service import get_app_service, AppService
|
|
||||||
from app.services.conversation_service import ConversationService, get_conversation_service
|
from app.services.conversation_service import ConversationService, get_conversation_service
|
||||||
from app.utils.app_config_utils import workflow_config_4_app_release, \
|
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release
|
||||||
agent_config_4_app_release, multi_agent_config_4_app_release
|
from app.services.app_service import get_app_service, AppService
|
||||||
|
|
||||||
router = APIRouter(prefix="/app", tags=["V1 - App API"])
|
router = APIRouter(prefix="/app", tags=["V1 - App API"])
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
@@ -35,7 +34,6 @@ async def list_apps():
|
|||||||
"""列出可访问的应用(占位)"""
|
"""列出可访问的应用(占位)"""
|
||||||
return success(data=[], msg="App API - Coming Soon")
|
return success(data=[], msg="App API - Coming Soon")
|
||||||
|
|
||||||
|
|
||||||
# /v1/app/chat
|
# /v1/app/chat
|
||||||
|
|
||||||
# @router.post("/chat")
|
# @router.post("/chat")
|
||||||
@@ -62,19 +60,18 @@ async def list_apps():
|
|||||||
# return success(data={"received": True}, msg="消息已接收")
|
# return success(data={"received": True}, msg="消息已接收")
|
||||||
|
|
||||||
|
|
||||||
def _checkAppConfig(release: AppRelease):
|
def _checkAppConfig(app: App):
|
||||||
if release.type == AppType.AGENT:
|
if app.type == AppType.AGENT:
|
||||||
if not release.config:
|
if not app.current_release.config:
|
||||||
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||||
elif release.type == AppType.MULTI_AGENT:
|
elif app.type == AppType.MULTI_AGENT:
|
||||||
if not release.config:
|
if not app.current_release.config:
|
||||||
raise BusinessException("Multi-Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
raise BusinessException("Multi-Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||||
elif release.type == AppType.WORKFLOW:
|
elif app.type == AppType.WORKFLOW:
|
||||||
if not release.config:
|
if not app.current_release.config:
|
||||||
raise BusinessException("工作流应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
raise BusinessException("工作流应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||||
else:
|
else:
|
||||||
raise BusinessException("不支持的应用类型", BizCode.APP_TYPE_NOT_SUPPORTED)
|
raise BusinessException("不支持的应用类型", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/chat")
|
@router.post("/chat")
|
||||||
@require_api_key(scopes=["app"])
|
@require_api_key(scopes=["app"])
|
||||||
@@ -87,39 +84,18 @@ async def chat(
|
|||||||
app_service: Annotated[AppService, Depends(get_app_service)] = None,
|
app_service: Annotated[AppService, Depends(get_app_service)] = None,
|
||||||
message: str = Body(..., description="聊天消息内容"),
|
message: str = Body(..., description="聊天消息内容"),
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
Agent/Workflow 聊天接口
|
|
||||||
|
|
||||||
- 不传 version:使用当前生效版本(current_release,回滚后为回滚目标版本)
|
|
||||||
- 传 version=release_id:使用指定版本uuid的历史快照,例如 {"version": "{{release_id}}"}
|
|
||||||
"""
|
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
payload = AppChatRequest(**body)
|
payload = AppChatRequest(**body)
|
||||||
|
|
||||||
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
|
|
||||||
|
|
||||||
# 版本切换:指定 release_id 时查找对应历史快照,否则使用当前激活版本
|
|
||||||
if payload.version is not None:
|
|
||||||
active_release = app_service.get_release_by_id(app.id, payload.version)
|
|
||||||
else:
|
|
||||||
active_release = app.current_release
|
|
||||||
other_id = payload.user_id
|
other_id = payload.user_id
|
||||||
workspace_id = api_key_auth.workspace_id
|
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
|
||||||
|
other_id = payload.user_id
|
||||||
|
workspace_id = app.workspace_id
|
||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
|
|
||||||
# 仅在新建终端用户时检查配额,已有用户复用不受限制
|
|
||||||
existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id)
|
|
||||||
if existing_end_user is None:
|
|
||||||
from app.core.quota_manager import _check_quota
|
|
||||||
from app.models.workspace_model import Workspace
|
|
||||||
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
|
||||||
if ws:
|
|
||||||
_check_quota(db, ws.tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
|
|
||||||
|
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
app_id=app.id,
|
app_id=app.id,
|
||||||
workspace_id=workspace_id,
|
|
||||||
other_id=other_id,
|
other_id=other_id,
|
||||||
|
original_user_id=other_id # Save original user_id to other_id
|
||||||
)
|
)
|
||||||
end_user_id = str(new_end_user.id)
|
end_user_id = str(new_end_user.id)
|
||||||
web_search=True
|
web_search=True
|
||||||
@@ -150,28 +126,22 @@ async def chat(
|
|||||||
storage_type = 'neo4j'
|
storage_type = 'neo4j'
|
||||||
app_type = app.type
|
app_type = app.type
|
||||||
# check app config
|
# check app config
|
||||||
_checkAppConfig(active_release)
|
_checkAppConfig(app)
|
||||||
|
|
||||||
# 获取或创建会话(提前验证)
|
# 获取或创建会话(提前验证)
|
||||||
conversation = conversation_service.create_or_get_conversation(
|
conversation = conversation_service.create_or_get_conversation(
|
||||||
app_id=app.id,
|
app_id=app.id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
user_id=end_user_id,
|
user_id=end_user_id,
|
||||||
is_draft=False,
|
is_draft=False
|
||||||
conversation_id=payload.conversation_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if app_type == AppType.AGENT:
|
if app_type == AppType.AGENT:
|
||||||
|
|
||||||
# print("="*50)
|
# print("="*50)
|
||||||
# print(app.current_release.default_model_config_id)
|
# print(app.current_release.default_model_config_id)
|
||||||
agent_config = agent_config_4_app_release(active_release)
|
agent_config = agent_config_4_app_release(app.current_release)
|
||||||
# print(agent_config.default_model_config_id)
|
# print(agent_config.default_model_config_id)
|
||||||
|
|
||||||
# thinking 开关:仅当 agent 配置了 deep_thinking 且请求 thinking=True 时才启用
|
|
||||||
if not (agent_config.model_parameters.get("deep_thinking", False) and payload.thinking):
|
|
||||||
agent_config.model_parameters["deep_thinking"] = False
|
|
||||||
|
|
||||||
# 流式返回
|
# 流式返回
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
@@ -185,8 +155,7 @@ async def chat(
|
|||||||
memory=memory,
|
memory=memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id
|
||||||
files=payload.files # 传递多模态文件
|
|
||||||
):
|
):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
@@ -211,13 +180,12 @@ async def chat(
|
|||||||
memory=memory,
|
memory=memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id
|
||||||
files=payload.files # 传递多模态文件
|
|
||||||
)
|
)
|
||||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||||
elif app_type == AppType.MULTI_AGENT:
|
elif app_type == AppType.MULTI_AGENT:
|
||||||
# 多 Agent 流式返回
|
# 多 Agent 流式返回
|
||||||
config = multi_agent_config_4_app_release(active_release)
|
config = multi_agent_config_4_app_release(app.current_release)
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
async for event in app_chat_service.multi_agent_chat_stream(
|
async for event in app_chat_service.multi_agent_chat_stream(
|
||||||
@@ -260,15 +228,15 @@ async def chat(
|
|||||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||||
elif app_type == AppType.WORKFLOW:
|
elif app_type == AppType.WORKFLOW:
|
||||||
# 多 Agent 流式返回
|
# 多 Agent 流式返回
|
||||||
config = workflow_config_4_app_release(active_release)
|
config = workflow_config_4_app_release(app.current_release)
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
async for event in app_chat_service.workflow_chat_stream(
|
async for event in app_chat_service.workflow_chat_stream(
|
||||||
|
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
user_id=end_user_id, # 转换为字符串
|
user_id=end_user_id, # 转换为字符串
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
files=payload.files,
|
|
||||||
config=config,
|
config=config,
|
||||||
web_search=web_search,
|
web_search=web_search,
|
||||||
memory=memory,
|
memory=memory,
|
||||||
@@ -276,8 +244,7 @@ async def chat(
|
|||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
app_id=app.id,
|
app_id=app.id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
release_id=active_release.id,
|
release_id=app.current_release.id,
|
||||||
public=True
|
|
||||||
):
|
):
|
||||||
event_type = event.get("event", "message")
|
event_type = event.get("event", "message")
|
||||||
event_data = event.get("data", {})
|
event_data = event.get("data", {})
|
||||||
@@ -296,7 +263,7 @@ async def chat(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# workflow 非流式返回
|
# 多 Agent 非流式返回
|
||||||
result = await app_chat_service.workflow_chat(
|
result = await app_chat_service.workflow_chat(
|
||||||
|
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
@@ -308,10 +275,9 @@ async def chat(
|
|||||||
memory=memory,
|
memory=memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
files=payload.files,
|
|
||||||
app_id=app.id,
|
app_id=app.id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
release_id=active_release.id
|
release_id=app.current_release.id
|
||||||
)
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"工作流试运行返回结果",
|
"工作流试运行返回结果",
|
||||||
@@ -325,4 +291,7 @@ async def chat(
|
|||||||
msg="工作流任务执行成功"
|
msg="工作流任务执行成功"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||||
|
|
||||||
|
|||||||
@@ -1,173 +0,0 @@
|
|||||||
"""End User 服务接口 - 基于 API Key 认证"""
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Body, Depends, Request
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.controllers import user_memory_controllers
|
|
||||||
from app.core.api_key_auth import require_api_key
|
|
||||||
from app.core.error_codes import BizCode
|
|
||||||
from app.core.exceptions import BusinessException
|
|
||||||
from app.core.logging_config import get_business_logger
|
|
||||||
from app.core.quota_stub import check_end_user_quota
|
|
||||||
from app.core.response_utils import success
|
|
||||||
from app.db import get_db
|
|
||||||
from app.repositories.end_user_repository import EndUserRepository
|
|
||||||
from app.schemas.api_key_schema import ApiKeyAuth
|
|
||||||
from app.schemas.end_user_info_schema import EndUserInfoUpdate
|
|
||||||
from app.schemas.memory_api_schema import CreateEndUserRequest, CreateEndUserResponse
|
|
||||||
from app.services import api_key_service
|
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/end_user", tags=["V1 - End User API"])
|
|
||||||
logger = get_business_logger()
|
|
||||||
|
|
||||||
|
|
||||||
def _get_current_user(api_key_auth: ApiKeyAuth, db: Session):
|
|
||||||
"""Build a current_user object from API key auth
|
|
||||||
|
|
||||||
Args:
|
|
||||||
api_key_auth: Validated API key auth info
|
|
||||||
db: Database session
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
User object with current_workspace_id set
|
|
||||||
"""
|
|
||||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
|
||||||
current_user = api_key.creator
|
|
||||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
|
||||||
return current_user
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/create")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
@check_end_user_quota
|
|
||||||
async def create_end_user(
|
|
||||||
request: Request,
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
message: str = Body(None, description="Request body"),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create or retrieve an end user for the workspace.
|
|
||||||
|
|
||||||
Creates a new end user and connects it to a memory configuration.
|
|
||||||
If an end user with the same other_id already exists in the workspace,
|
|
||||||
returns the existing one.
|
|
||||||
|
|
||||||
Optionally accepts a memory_config_id to connect the end user to a specific
|
|
||||||
memory configuration. If not provided, falls back to the workspace default config.
|
|
||||||
Optionally accepts an app_id to bind the end user to a specific app.
|
|
||||||
"""
|
|
||||||
body = await request.json()
|
|
||||||
payload = CreateEndUserRequest(**body)
|
|
||||||
workspace_id = api_key_auth.workspace_id
|
|
||||||
|
|
||||||
logger.info("Create end user request - other_id: %s, workspace_id: %s", payload.other_id, workspace_id)
|
|
||||||
|
|
||||||
# Resolve memory_config_id: explicit > workspace default
|
|
||||||
memory_config_id = None
|
|
||||||
config_service = MemoryConfigService(db)
|
|
||||||
|
|
||||||
if payload.memory_config_id:
|
|
||||||
try:
|
|
||||||
memory_config_id = uuid.UUID(payload.memory_config_id)
|
|
||||||
except ValueError:
|
|
||||||
raise BusinessException(
|
|
||||||
f"Invalid memory_config_id format: {payload.memory_config_id}",
|
|
||||||
BizCode.INVALID_PARAMETER
|
|
||||||
)
|
|
||||||
config = config_service.get_config_with_fallback(memory_config_id, workspace_id)
|
|
||||||
if not config:
|
|
||||||
raise BusinessException(
|
|
||||||
f"Memory config not found: {payload.memory_config_id}",
|
|
||||||
BizCode.MEMORY_CONFIG_NOT_FOUND
|
|
||||||
)
|
|
||||||
memory_config_id = config.config_id
|
|
||||||
else:
|
|
||||||
default_config = config_service.get_workspace_default_config(workspace_id)
|
|
||||||
if default_config:
|
|
||||||
memory_config_id = default_config.config_id
|
|
||||||
logger.info(f"Using workspace default memory config: {memory_config_id}")
|
|
||||||
else:
|
|
||||||
logger.warning(f"No default memory config found for workspace: {workspace_id}")
|
|
||||||
|
|
||||||
# Resolve app_id: explicit from payload, otherwise None
|
|
||||||
app_id = None
|
|
||||||
if payload.app_id:
|
|
||||||
try:
|
|
||||||
app_id = uuid.UUID(payload.app_id)
|
|
||||||
except ValueError:
|
|
||||||
raise BusinessException(
|
|
||||||
f"Invalid app_id format: {payload.app_id}",
|
|
||||||
BizCode.INVALID_PARAMETER
|
|
||||||
)
|
|
||||||
|
|
||||||
end_user_repo = EndUserRepository(db)
|
|
||||||
end_user = end_user_repo.get_or_create_end_user_with_config(
|
|
||||||
app_id=app_id,
|
|
||||||
workspace_id=workspace_id,
|
|
||||||
other_id=payload.other_id,
|
|
||||||
memory_config_id=memory_config_id,
|
|
||||||
other_name=payload.other_name,
|
|
||||||
)
|
|
||||||
end_user.other_name = payload.other_name
|
|
||||||
logger.info(f"End user ready: {end_user.id}")
|
|
||||||
|
|
||||||
result = {
|
|
||||||
"id": str(end_user.id),
|
|
||||||
"other_id": end_user.other_id or "",
|
|
||||||
"other_name": end_user.other_name or "",
|
|
||||||
"workspace_id": str(end_user.workspace_id),
|
|
||||||
"memory_config_id": str(end_user.memory_config_id) if end_user.memory_config_id else None,
|
|
||||||
}
|
|
||||||
|
|
||||||
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/info")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def get_end_user_info(
|
|
||||||
request: Request,
|
|
||||||
end_user_id: str,
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get end user info.
|
|
||||||
|
|
||||||
Retrieves the info record (aliases, meta_data, etc.) for the specified end user.
|
|
||||||
Delegates to the manager-side controller for shared logic.
|
|
||||||
"""
|
|
||||||
current_user = _get_current_user(api_key_auth, db)
|
|
||||||
return await user_memory_controllers.get_end_user_info(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/info/update")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def update_end_user_info(
|
|
||||||
request: Request,
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
message: str = Body(None, description="Request body"),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Update end user info.
|
|
||||||
|
|
||||||
Updates the info record (other_name, aliases, meta_data) for the specified end user.
|
|
||||||
Delegates to the manager-side controller for shared logic.
|
|
||||||
"""
|
|
||||||
body = await request.json()
|
|
||||||
payload = EndUserInfoUpdate(**body)
|
|
||||||
|
|
||||||
current_user = _get_current_user(api_key_auth, db)
|
|
||||||
return await user_memory_controllers.update_end_user_info(
|
|
||||||
info_update=payload,
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
@@ -1,84 +1,49 @@
|
|||||||
"""Memory 服务接口 - 基于 API Key 认证"""
|
"""Memory 服务接口 - 基于 API Key 认证"""
|
||||||
|
|
||||||
from fastapi import APIRouter, Body, Depends, Query, Request
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.celery_task_scheduler import scheduler
|
|
||||||
from app.core.api_key_auth import require_api_key
|
from app.core.api_key_auth import require_api_key
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.quota_stub import check_end_user_quota
|
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.schemas.api_key_schema import ApiKeyAuth
|
from app.schemas.api_key_schema import ApiKeyAuth
|
||||||
from app.schemas.memory_api_schema import (
|
from app.schemas.memory_api_schema import (
|
||||||
MemoryReadRequest,
|
MemoryReadRequest,
|
||||||
MemoryReadResponse,
|
MemoryReadResponse,
|
||||||
MemoryReadSyncResponse,
|
|
||||||
MemoryWriteRequest,
|
MemoryWriteRequest,
|
||||||
MemoryWriteResponse,
|
MemoryWriteResponse,
|
||||||
MemoryWriteSyncResponse,
|
|
||||||
)
|
)
|
||||||
from app.services.memory_api_service import MemoryAPIService
|
from app.services.memory_api_service import MemoryAPIService
|
||||||
|
from fastapi import APIRouter, Body, Depends, Request
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
router = APIRouter(prefix="/memory", tags=["V1 - Memory API"])
|
router = APIRouter(prefix="/memory", tags=["V1 - Memory API"])
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_task_result(result: dict) -> dict:
|
|
||||||
"""Make Celery task result JSON-serializable.
|
|
||||||
|
|
||||||
Converts UUID and other non-serializable values to strings.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
result: Raw task result dict from task_service
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
JSON-safe dict
|
|
||||||
"""
|
|
||||||
import uuid as _uuid
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
def _convert(obj):
|
|
||||||
if isinstance(obj, dict):
|
|
||||||
return {k: _convert(v) for k, v in obj.items()}
|
|
||||||
if isinstance(obj, list):
|
|
||||||
return [_convert(i) for i in obj]
|
|
||||||
if isinstance(obj, _uuid.UUID):
|
|
||||||
return str(obj)
|
|
||||||
if isinstance(obj, datetime):
|
|
||||||
return obj.isoformat()
|
|
||||||
return obj
|
|
||||||
|
|
||||||
return _convert(result)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("")
|
@router.get("")
|
||||||
async def get_memory_info():
|
async def get_memory_info():
|
||||||
"""获取记忆服务信息(占位)"""
|
"""获取记忆服务信息(占位)"""
|
||||||
return success(data={}, msg="Memory API - Coming Soon")
|
return success(data={}, msg="Memory API - Coming Soon")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/write")
|
@router.post("/write_api_service")
|
||||||
@require_api_key(scopes=["memory"])
|
@require_api_key(scopes=["memory"])
|
||||||
async def write_memory(
|
async def write_memory_api_service(
|
||||||
request: Request,
|
request: Request,
|
||||||
api_key_auth: ApiKeyAuth = None,
|
api_key_auth: ApiKeyAuth = None,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
message: str = Body(..., description="Message content"),
|
payload: MemoryWriteRequest = Body(..., embed=False),
|
||||||
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Submit a memory write task.
|
Write memory to storage.
|
||||||
|
|
||||||
Validates the end user, then dispatches the write to a Celery background task
|
Stores memory content for the specified end user using the Memory API Service.
|
||||||
with per-user fair locking. Returns a task_id for status polling.
|
|
||||||
"""
|
"""
|
||||||
body = await request.json()
|
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
|
||||||
payload = MemoryWriteRequest(**body)
|
|
||||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
|
|
||||||
|
|
||||||
memory_api_service = MemoryAPIService(db)
|
memory_api_service = MemoryAPIService(db)
|
||||||
|
|
||||||
result = memory_api_service.write_memory(
|
result = await memory_api_service.write_memory(
|
||||||
workspace_id=api_key_auth.workspace_id,
|
workspace_id=api_key_auth.workspace_id,
|
||||||
end_user_id=payload.end_user_id,
|
end_user_id=payload.end_user_id,
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
@@ -87,51 +52,28 @@ async def write_memory(
|
|||||||
user_rag_memory_id=payload.user_rag_memory_id,
|
user_rag_memory_id=payload.user_rag_memory_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Memory write task submitted: task_id: {result['task_id']} end_user_id: {payload.end_user_id}")
|
logger.info(f"Memory write successful for end_user: {payload.end_user_id}")
|
||||||
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted")
|
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory written successfully")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/write/status")
|
@router.post("/read_api_service")
|
||||||
@require_api_key(scopes=["memory"])
|
@require_api_key(scopes=["memory"])
|
||||||
async def get_write_task_status(
|
async def read_memory_api_service(
|
||||||
request: Request,
|
|
||||||
task_id: str = Query(..., description="Celery task ID"),
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Check the status of a memory write task.
|
|
||||||
|
|
||||||
Returns the current status and result (if completed) of a previously submitted write task.
|
|
||||||
"""
|
|
||||||
logger.info(f"Write task status check - task_id: {task_id}")
|
|
||||||
|
|
||||||
result = scheduler.get_task_status(task_id)
|
|
||||||
|
|
||||||
return success(data=_sanitize_task_result(result), msg="Task status retrieved")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/read")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def read_memory(
|
|
||||||
request: Request,
|
request: Request,
|
||||||
api_key_auth: ApiKeyAuth = None,
|
api_key_auth: ApiKeyAuth = None,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
message: str = Body(..., description="Query message"),
|
payload: MemoryReadRequest = Body(..., embed=False),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Submit a memory read task.
|
Read memory from storage.
|
||||||
|
|
||||||
Validates the end user, then dispatches the read to a Celery background task.
|
Queries and retrieves memories for the specified end user with context-aware responses.
|
||||||
Returns a task_id for status polling.
|
|
||||||
"""
|
"""
|
||||||
body = await request.json()
|
|
||||||
payload = MemoryReadRequest(**body)
|
|
||||||
logger.info(f"Memory read request - end_user_id: {payload.end_user_id}")
|
logger.info(f"Memory read request - end_user_id: {payload.end_user_id}")
|
||||||
|
|
||||||
memory_api_service = MemoryAPIService(db)
|
memory_api_service = MemoryAPIService(db)
|
||||||
|
|
||||||
result = memory_api_service.read_memory(
|
result = await memory_api_service.read_memory(
|
||||||
workspace_id=api_key_auth.workspace_id,
|
workspace_id=api_key_auth.workspace_id,
|
||||||
end_user_id=payload.end_user_id,
|
end_user_id=payload.end_user_id,
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
@@ -141,94 +83,5 @@ async def read_memory(
|
|||||||
user_rag_memory_id=payload.user_rag_memory_id,
|
user_rag_memory_id=payload.user_rag_memory_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Memory read task submitted: task_id={result['task_id']}, end_user_id: {payload.end_user_id}")
|
logger.info(f"Memory read successful for end_user: {payload.end_user_id}")
|
||||||
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read task submitted")
|
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read successfully")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/read/status")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def get_read_task_status(
|
|
||||||
request: Request,
|
|
||||||
task_id: str = Query(..., description="Celery task ID"),
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Check the status of a memory read task.
|
|
||||||
|
|
||||||
Returns the current status and result (if completed) of a previously submitted read task.
|
|
||||||
"""
|
|
||||||
logger.info(f"Read task status check - task_id: {task_id}")
|
|
||||||
|
|
||||||
from app.services.task_service import get_task_memory_read_result
|
|
||||||
result = get_task_memory_read_result(task_id)
|
|
||||||
|
|
||||||
return success(data=_sanitize_task_result(result), msg="Task status retrieved")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/write/sync")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
@check_end_user_quota
|
|
||||||
async def write_memory_sync(
|
|
||||||
request: Request,
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
message: str = Body(..., description="Message content"),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Write memory synchronously.
|
|
||||||
|
|
||||||
Blocks until the write completes and returns the result directly.
|
|
||||||
For async processing with task polling, use /write instead.
|
|
||||||
"""
|
|
||||||
body = await request.json()
|
|
||||||
payload = MemoryWriteRequest(**body)
|
|
||||||
logger.info(f"Memory write (sync) request - end_user_id: {payload.end_user_id}")
|
|
||||||
|
|
||||||
memory_api_service = MemoryAPIService(db)
|
|
||||||
|
|
||||||
result = await memory_api_service.write_memory_sync(
|
|
||||||
workspace_id=api_key_auth.workspace_id,
|
|
||||||
end_user_id=payload.end_user_id,
|
|
||||||
message=payload.message,
|
|
||||||
config_id=payload.config_id,
|
|
||||||
storage_type=payload.storage_type,
|
|
||||||
user_rag_memory_id=payload.user_rag_memory_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Memory write (sync) successful for end_user: {payload.end_user_id}")
|
|
||||||
return success(data=MemoryWriteSyncResponse(**result).model_dump(), msg="Memory written successfully")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/read/sync")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def read_memory_sync(
|
|
||||||
request: Request,
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
message: str = Body(..., description="Query message"),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Read memory synchronously.
|
|
||||||
|
|
||||||
Blocks until the read completes and returns the answer directly.
|
|
||||||
For async processing with task polling, use /read instead.
|
|
||||||
"""
|
|
||||||
body = await request.json()
|
|
||||||
payload = MemoryReadRequest(**body)
|
|
||||||
logger.info(f"Memory read (sync) request - end_user_id: {payload.end_user_id}")
|
|
||||||
|
|
||||||
memory_api_service = MemoryAPIService(db)
|
|
||||||
|
|
||||||
result = await memory_api_service.read_memory_sync(
|
|
||||||
workspace_id=api_key_auth.workspace_id,
|
|
||||||
end_user_id=payload.end_user_id,
|
|
||||||
message=payload.message,
|
|
||||||
search_switch=payload.search_switch,
|
|
||||||
config_id=payload.config_id,
|
|
||||||
storage_type=payload.storage_type,
|
|
||||||
user_rag_memory_id=payload.user_rag_memory_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Memory read (sync) successful for end_user: {payload.end_user_id}")
|
|
||||||
return success(data=MemoryReadSyncResponse(**result).model_dump(), msg="Memory read successfully")
|
|
||||||
|
|||||||
@@ -1,491 +0,0 @@
|
|||||||
"""Memory Config 服务接口 - 基于 API Key 认证"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Body, Depends, Header, Query, Request
|
|
||||||
from fastapi.encoders import jsonable_encoder
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.controllers import memory_storage_controller
|
|
||||||
from app.controllers import memory_forget_controller
|
|
||||||
from app.controllers import ontology_controller
|
|
||||||
from app.controllers import emotion_config_controller
|
|
||||||
from app.controllers import memory_reflection_controller
|
|
||||||
from app.schemas.memory_storage_schema import ForgettingConfigUpdateRequest
|
|
||||||
from app.controllers.emotion_config_controller import EmotionConfigUpdate
|
|
||||||
from app.schemas.memory_reflection_schemas import Memory_Reflection
|
|
||||||
from app.core.api_key_auth import require_api_key
|
|
||||||
from app.core.error_codes import BizCode
|
|
||||||
from app.core.exceptions import BusinessException
|
|
||||||
from app.core.logging_config import get_business_logger
|
|
||||||
from app.core.response_utils import success
|
|
||||||
from app.db import get_db
|
|
||||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
|
||||||
from app.schemas.api_key_schema import ApiKeyAuth
|
|
||||||
from app.schemas.memory_api_schema import (
|
|
||||||
ConfigUpdateExtractedRequest,
|
|
||||||
ConfigUpdateRequest,
|
|
||||||
ListConfigsResponse,
|
|
||||||
ConfigCreateRequest,
|
|
||||||
ConfigUpdateForgettingRequest,
|
|
||||||
EmotionConfigUpdateRequest,
|
|
||||||
ReflectionConfigUpdateRequest,
|
|
||||||
)
|
|
||||||
from app.schemas.memory_storage_schema import (
|
|
||||||
ConfigUpdate,
|
|
||||||
ConfigUpdateExtracted,
|
|
||||||
ConfigParamsCreate,
|
|
||||||
)
|
|
||||||
from app.services import api_key_service
|
|
||||||
from app.services.memory_api_service import MemoryAPIService
|
|
||||||
from app.utils.config_utils import resolve_config_id
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/memory_config", tags=["V1 - Memory Config API"])
|
|
||||||
logger = get_business_logger()
|
|
||||||
|
|
||||||
|
|
||||||
def _get_current_user(api_key_auth: ApiKeyAuth, db: Session):
|
|
||||||
"""Build a current_user object from API key auth
|
|
||||||
|
|
||||||
Args:
|
|
||||||
api_key_auth: Validated API key auth info
|
|
||||||
db: Database session
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
User object with current_workspace_id set
|
|
||||||
"""
|
|
||||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
|
||||||
current_user = api_key.creator
|
|
||||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
|
||||||
return current_user
|
|
||||||
|
|
||||||
|
|
||||||
def _verify_config_ownership(config_id:str, workspace_id:uuid.UUID, db:Session):
|
|
||||||
"""Verify that the config belongs to the workspace.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_id: The ID of the config to verify
|
|
||||||
workspace_id: The workspace ID tocheck against
|
|
||||||
db: Database session for querying
|
|
||||||
Raises:
|
|
||||||
BusinessException: If the config does not exist or does not belong to the workspace
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
resolved_id = resolve_config_id(config_id, db)
|
|
||||||
except ValueError as e:
|
|
||||||
raise BusinessException(
|
|
||||||
message=f"Invalid config_id: {e}",
|
|
||||||
code=BizCode.INVALID_PARAMETER,
|
|
||||||
)
|
|
||||||
config = MemoryConfigRepository.get_by_id(db, resolved_id)
|
|
||||||
if not config or config.workspace_id != workspace_id:
|
|
||||||
raise BusinessException(
|
|
||||||
message="Config not found or access denied",
|
|
||||||
code=BizCode.MEMORY_CONFIG_NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
# @router.get("/configs")
|
|
||||||
# @require_api_key(scopes=["memory"])
|
|
||||||
# async def list_memory_configs(
|
|
||||||
# request: Request,
|
|
||||||
# api_key_auth: ApiKeyAuth = None,
|
|
||||||
# db: Session = Depends(get_db),
|
|
||||||
# ):
|
|
||||||
# """
|
|
||||||
# List all memory configs for the workspace.
|
|
||||||
|
|
||||||
# Returns all available memory configurations associated with the authorized workspace.
|
|
||||||
# """
|
|
||||||
# logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}")
|
|
||||||
|
|
||||||
# memory_api_service = MemoryAPIService(db)
|
|
||||||
|
|
||||||
# result = memory_api_service.list_memory_configs(
|
|
||||||
# workspace_id=api_key_auth.workspace_id,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}")
|
|
||||||
# return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully")
|
|
||||||
|
|
||||||
@router.get("/read_all_config")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def read_all_config(
|
|
||||||
request:Request,
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
List all memory configs with full details (enhanced version).
|
|
||||||
|
|
||||||
Returns complete config fields for the authorized workspace.
|
|
||||||
No config_id ownership check needed — results are filtered by workspace.
|
|
||||||
"""
|
|
||||||
logger.info(f"V1 get all configs (full) - workspace: {api_key_auth.workspace_id}")
|
|
||||||
|
|
||||||
current_user = _get_current_user(api_key_auth, db)
|
|
||||||
|
|
||||||
return memory_storage_controller.read_all_config(
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
|
|
||||||
@router.get("/scenes/simple")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def get_ontology_scenes(
|
|
||||||
request: Request,
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get available ontology scenes for the workspace.
|
|
||||||
|
|
||||||
Returns a simple list of scene_id and scene_name for dropdown selection.
|
|
||||||
Used before creating a memory config to choose which ontology scene to associate.
|
|
||||||
"""
|
|
||||||
logger.info(f"V1 get scenes - workspace: {api_key_auth.workspace_id}")
|
|
||||||
|
|
||||||
current_user = _get_current_user(api_key_auth, db)
|
|
||||||
|
|
||||||
return await ontology_controller.get_scenes_simple(
|
|
||||||
db=db,
|
|
||||||
current_user=current_user,
|
|
||||||
)
|
|
||||||
|
|
||||||
@router.get("/read_config_extracted")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def read_config_extracted(
|
|
||||||
request: Request,
|
|
||||||
config_id: str = Query(..., description="config_id"),
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get extraction engine config details for a specific config.
|
|
||||||
|
|
||||||
Only configs belonging to the authorized workspace can be queried.
|
|
||||||
"""
|
|
||||||
logger.info(f"V1 read extracted config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
|
||||||
|
|
||||||
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
|
||||||
|
|
||||||
current_user = _get_current_user(api_key_auth, db)
|
|
||||||
|
|
||||||
return memory_storage_controller.read_config_extracted(
|
|
||||||
config_id = config_id,
|
|
||||||
current_user = current_user,
|
|
||||||
db = db,
|
|
||||||
)
|
|
||||||
|
|
||||||
@router.get("/read_config_forgetting")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def read_config_forgetting(
|
|
||||||
request: Request,
|
|
||||||
config_id: str = Query(..., description="config_id"),
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get forgetting settings for a specific memory config.
|
|
||||||
|
|
||||||
Only configs belonging to the authorized workspace can be queried.
|
|
||||||
"""
|
|
||||||
logger.info(f"V1 read forgetting config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
|
||||||
|
|
||||||
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
|
||||||
|
|
||||||
current_user = _get_current_user(api_key_auth, db)
|
|
||||||
|
|
||||||
result = await memory_forget_controller.read_forgetting_config(
|
|
||||||
config_id = config_id,
|
|
||||||
current_user = current_user,
|
|
||||||
db = db,
|
|
||||||
)
|
|
||||||
return jsonable_encoder(result)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/read_config_emotion")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def read_config_emotion(
|
|
||||||
request: Request,
|
|
||||||
config_id: str = Query(..., description="config_id"),
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get emotion engine config details for a specific config.
|
|
||||||
|
|
||||||
Only configs belonging to the authorized workspace can be queried.
|
|
||||||
"""
|
|
||||||
logger.info(f"V1 read emotion config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
|
||||||
|
|
||||||
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
|
||||||
|
|
||||||
current_user = _get_current_user(api_key_auth, db)
|
|
||||||
|
|
||||||
return jsonable_encoder(emotion_config_controller.get_emotion_config(
|
|
||||||
config_id=config_id,
|
|
||||||
db=db,
|
|
||||||
current_user=current_user,
|
|
||||||
))
|
|
||||||
|
|
||||||
@router.get("/read_config_reflection")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def read_config_reflection(
|
|
||||||
request: Request,
|
|
||||||
config_id: str = Query(..., description="config_id"),
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get reflection engine config details for a specific config.
|
|
||||||
|
|
||||||
Only configs belonging to the authorized workspace can be queried.
|
|
||||||
"""
|
|
||||||
logger.info(f"V1 read reflection config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
|
||||||
|
|
||||||
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
|
||||||
|
|
||||||
current_user = _get_current_user(api_key_auth, db)
|
|
||||||
|
|
||||||
return jsonable_encoder(await memory_reflection_controller.start_reflection_configs(
|
|
||||||
config_id=config_id,
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
))
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/create_config")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def create_memory_config(
|
|
||||||
request: Request,
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
message: str = Body(None, description="Request body"),
|
|
||||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create a new memory config for the workspace.
|
|
||||||
|
|
||||||
The config will be associated with the workspace of the API Key.
|
|
||||||
config_name is required, other fields are optional.
|
|
||||||
"""
|
|
||||||
body = await request.json()
|
|
||||||
payload = ConfigCreateRequest(**body)
|
|
||||||
|
|
||||||
logger.info(f"V1 create config - workspace: {api_key_auth.workspace_id}, config_name: {payload.config_name}")
|
|
||||||
|
|
||||||
# 构造管理端 Schema,workspace_id 从 API Key 注入
|
|
||||||
current_user = _get_current_user(api_key_auth, db)
|
|
||||||
mgmt_payload = ConfigParamsCreate(
|
|
||||||
config_name=payload.config_name,
|
|
||||||
config_desc=payload.config_desc or "",
|
|
||||||
scene_id=payload.scene_id,
|
|
||||||
llm_id=payload.llm_id,
|
|
||||||
embedding_id=payload.embedding_id,
|
|
||||||
rerank_id=payload.rerank_id,
|
|
||||||
reflection_model_id=payload.reflection_model_id,
|
|
||||||
emotion_model_id=payload.emotion_model_id,
|
|
||||||
)
|
|
||||||
#将返回数据中UUID序列化处理
|
|
||||||
result =memory_storage_controller.create_config(
|
|
||||||
payload=mgmt_payload,
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
x_language_type=x_language_type,
|
|
||||||
)
|
|
||||||
return jsonable_encoder(result)
|
|
||||||
|
|
||||||
@router.put("/update_config")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def update_memory_config(
|
|
||||||
request: Request,
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
message: str = Body(None, description="Request body"),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Update memory config basic info (name, description, scene).
|
|
||||||
|
|
||||||
Requires API Key with 'memory' scope
|
|
||||||
Only configs belonging to the authorized workspace can be updated.
|
|
||||||
"""
|
|
||||||
body = await request.json()
|
|
||||||
payload = ConfigUpdateRequest(**body)
|
|
||||||
|
|
||||||
logger.info(f"V1 update config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
|
||||||
|
|
||||||
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
|
||||||
|
|
||||||
current_user = _get_current_user(api_key_auth, db)
|
|
||||||
mgmt_payload = ConfigUpdate(
|
|
||||||
config_id = payload.config_id,
|
|
||||||
config_name = payload.config_name,
|
|
||||||
config_desc = payload.config_desc,
|
|
||||||
scene_id = payload.scene_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
return memory_storage_controller.update_config(
|
|
||||||
payload = mgmt_payload,
|
|
||||||
current_user = current_user,
|
|
||||||
db = db,
|
|
||||||
)
|
|
||||||
|
|
||||||
@router.put("/update_config_extracted")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def update_memory_config_extracted(
|
|
||||||
request: Request,
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
message: str = Body(None, description="Request body"),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
update memory config extraction engine config (models, thresholds, chunking, pruning, etc.).
|
|
||||||
|
|
||||||
Requires API Key with 'memory' scope.
|
|
||||||
Only configs belonging to the authorized workspace can be updated.
|
|
||||||
"""
|
|
||||||
body = await request.json()
|
|
||||||
payload = ConfigUpdateExtractedRequest(**body)
|
|
||||||
|
|
||||||
logger.info(f"V1 update extracted config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
|
||||||
|
|
||||||
#校验权限
|
|
||||||
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
|
||||||
|
|
||||||
current_user = _get_current_user(api_key_auth, db)
|
|
||||||
update_fields = payload.model_dump(exclude_unset=True)
|
|
||||||
mgmt_payload = ConfigUpdateExtracted(**update_fields)
|
|
||||||
|
|
||||||
return memory_storage_controller.update_config_extracted(
|
|
||||||
payload = mgmt_payload,
|
|
||||||
current_user = current_user,
|
|
||||||
db = db,
|
|
||||||
)
|
|
||||||
|
|
||||||
@router.put("/update_config_forgetting")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def update_memory_config_forgetting(
|
|
||||||
request: Request,
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
message: str = Body(None, description="Request body"),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
update memory config forgetting settings (forgetting strategy, parameters, etc.).
|
|
||||||
|
|
||||||
Requires API Key with 'memory' scope.
|
|
||||||
Only configs belonging to the authorized workspace can be updated.
|
|
||||||
"""
|
|
||||||
body = await request.json()
|
|
||||||
payload = ConfigUpdateForgettingRequest(**body)
|
|
||||||
|
|
||||||
logger.info(f"V1 update forgetting config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
|
||||||
|
|
||||||
#校验权限
|
|
||||||
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
|
||||||
|
|
||||||
current_user = _get_current_user(api_key_auth, db)
|
|
||||||
update_fields = payload.model_dump(exclude_unset=True)
|
|
||||||
mgmt_payload = ForgettingConfigUpdateRequest(**update_fields)
|
|
||||||
|
|
||||||
#将返回数据中UUID序列化处理
|
|
||||||
result = await memory_forget_controller.update_forgetting_config(
|
|
||||||
payload = mgmt_payload,
|
|
||||||
current_user = current_user,
|
|
||||||
db = db,
|
|
||||||
)
|
|
||||||
return jsonable_encoder(result)
|
|
||||||
|
|
||||||
@router.put("/update_config_emotion")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def update_config_emotion(
|
|
||||||
request: Request,
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
message: str = Body(None, description="Request body"),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Update emotion engine config (full update).
|
|
||||||
|
|
||||||
All fields except emotion_model_id are required.
|
|
||||||
Only configs belonging to the authorized workspace can be updated.
|
|
||||||
"""
|
|
||||||
body = await request.json()
|
|
||||||
payload = EmotionConfigUpdateRequest(**body)
|
|
||||||
|
|
||||||
logger.info(f"V1 update emotion config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
|
||||||
|
|
||||||
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
|
||||||
|
|
||||||
current_user = _get_current_user(api_key_auth, db)
|
|
||||||
update_fields = payload.model_dump(exclude_unset=True)
|
|
||||||
mgmt_payload = EmotionConfigUpdate(**update_fields)
|
|
||||||
return jsonable_encoder(emotion_config_controller.update_emotion_config(
|
|
||||||
config=mgmt_payload,
|
|
||||||
db=db,
|
|
||||||
current_user=current_user,
|
|
||||||
))
|
|
||||||
|
|
||||||
@router.put("/update_config_reflection")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def update_config_reflection(
|
|
||||||
request: Request,
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
message: str = Body(None, description="Request body"),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Update reflection engine config (full update).
|
|
||||||
|
|
||||||
All fields are required.
|
|
||||||
Only configs belonging to the authorized workspace can be updated.
|
|
||||||
"""
|
|
||||||
body = await request.json()
|
|
||||||
payload = ReflectionConfigUpdateRequest(**body)
|
|
||||||
|
|
||||||
logger.info(f"V1 update reflection config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
|
||||||
|
|
||||||
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
|
||||||
|
|
||||||
current_user = _get_current_user(api_key_auth, db)
|
|
||||||
update_fields = payload.model_dump(exclude_unset=True)
|
|
||||||
mgmt_payload = Memory_Reflection(**update_fields)
|
|
||||||
|
|
||||||
return jsonable_encoder(await memory_reflection_controller.save_reflection_config(
|
|
||||||
request=mgmt_payload,
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
))
|
|
||||||
|
|
||||||
@router.delete("/delete_config")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def delete_memory_config(
|
|
||||||
config_id: str,
|
|
||||||
request: Request,
|
|
||||||
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Delete a memory config.
|
|
||||||
|
|
||||||
- Default configs cannot be deleted.
|
|
||||||
- If end users are connected and force=False, returns a warning.
|
|
||||||
- If force=True, clears end user references and deletes the config.
|
|
||||||
|
|
||||||
Only configs belonging to the authorized workspace can be deleted.
|
|
||||||
"""
|
|
||||||
logger.info(f"V1 delete config - config_id: {config_id}, force: {force}, workspace: {api_key_auth.workspace_id}")
|
|
||||||
|
|
||||||
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
|
||||||
|
|
||||||
current_user = _get_current_user(api_key_auth, db)
|
|
||||||
|
|
||||||
return memory_storage_controller.delete_config(
|
|
||||||
config_id=config_id,
|
|
||||||
force=force,
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
@@ -246,73 +246,3 @@ async def rebuild_knowledge_graph(
|
|||||||
db=db,
|
db=db,
|
||||||
current_user=current_user)
|
current_user=current_user)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/check/yuque/auth", response_model=ApiResponse)
|
|
||||||
@require_api_key(scopes=["rag"])
|
|
||||||
async def check_yuque_auth(
|
|
||||||
yuque_user_id: str,
|
|
||||||
yuque_token: str,
|
|
||||||
request: Request,
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
check yuque auth info
|
|
||||||
"""
|
|
||||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
|
||||||
current_user = api_key.creator
|
|
||||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
|
||||||
|
|
||||||
api_logger.info(f"check yuque auth info, username: {current_user.username}")
|
|
||||||
|
|
||||||
return await knowledge_controller.check_yuque_auth(yuque_user_id=yuque_user_id,
|
|
||||||
yuque_token=yuque_token,
|
|
||||||
db=db,
|
|
||||||
current_user=current_user)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/check/feishu/auth", response_model=ApiResponse)
|
|
||||||
@require_api_key(scopes=["rag"])
|
|
||||||
async def check_feishu_auth(
|
|
||||||
feishu_app_id: str,
|
|
||||||
feishu_app_secret: str,
|
|
||||||
feishu_folder_token: str,
|
|
||||||
request: Request,
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
check feishu auth info
|
|
||||||
"""
|
|
||||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
|
||||||
current_user = api_key.creator
|
|
||||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
|
||||||
|
|
||||||
api_logger.info(f"check feishu auth info, username: {current_user.username}")
|
|
||||||
|
|
||||||
return await knowledge_controller.check_feishu_auth(feishu_app_id=feishu_app_id,
|
|
||||||
feishu_app_secret=feishu_app_secret,
|
|
||||||
feishu_folder_token=feishu_folder_token,
|
|
||||||
db=db,
|
|
||||||
current_user=current_user)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{knowledge_id}/sync", response_model=ApiResponse)
|
|
||||||
@require_api_key(scopes=["rag"])
|
|
||||||
async def sync_knowledge(
|
|
||||||
knowledge_id: uuid.UUID,
|
|
||||||
request: Request,
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
sync knowledge base information based on knowledge_id
|
|
||||||
"""
|
|
||||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
|
||||||
current_user = api_key.creator
|
|
||||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
|
||||||
|
|
||||||
return await knowledge_controller.sync_knowledge(knowledge_id=knowledge_id,
|
|
||||||
db=db,
|
|
||||||
current_user=current_user)
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,230 +0,0 @@
|
|||||||
"""User Memory 服务接口 — 基于 API Key 认证
|
|
||||||
|
|
||||||
包装 user_memory_controllers.py 和 memory_agent_controller.py 中的内部接口,
|
|
||||||
提供基于 API Key 认证的对外服务:
|
|
||||||
1./analytics/graph_data - 知识图谱数据接口
|
|
||||||
2./analytics/community_graph - 社区图谱接口
|
|
||||||
3./analytics/node_statistics - 记忆节点统计接口
|
|
||||||
4./analytics/user_summary - 用户摘要接口
|
|
||||||
5./analytics/memory_insight - 记忆洞察接口
|
|
||||||
6./analytics/interest_distribution - 兴趣分布接口
|
|
||||||
7./analytics/end_user_info - 终端用户信息接口
|
|
||||||
8./analytics/generate_cache - 缓存生成接口
|
|
||||||
|
|
||||||
|
|
||||||
路由前缀: /memory
|
|
||||||
子路径: /analytics/...
|
|
||||||
最终路径: /v1/memory/analytics/...
|
|
||||||
认证方式: API Key (@require_api_key)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header, Query, Request, Body
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.core.api_key_auth import require_api_key
|
|
||||||
from app.core.api_key_utils import get_current_user_from_api_key, validate_end_user_in_workspace
|
|
||||||
from app.core.logging_config import get_business_logger
|
|
||||||
from app.db import get_db
|
|
||||||
from app.schemas.api_key_schema import ApiKeyAuth
|
|
||||||
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
|
||||||
|
|
||||||
# 包装内部服务 controller
|
|
||||||
from app.controllers import user_memory_controllers, memory_agent_controller
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/memory", tags=["V1 - User Memory API"])
|
|
||||||
logger = get_business_logger()
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 知识图谱 ====================
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/graph_data")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def get_graph_data(
|
|
||||||
request: Request,
|
|
||||||
end_user_id: str = Query(..., description="End user ID"),
|
|
||||||
node_types: Optional[str] = Query(None, description="Comma-separated node types filter"),
|
|
||||||
limit: int = Query(100, description="Max nodes to return (auto-capped at 1000 in service layer)"),
|
|
||||||
depth: int = Query(1, description="Graph traversal depth (auto-capped at 3 in service layer)"),
|
|
||||||
center_node_id: Optional[str] = Query(None, description="Center node for subgraph"),
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""Get knowledge graph data (nodes + edges) for an end user."""
|
|
||||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
|
||||||
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
|
||||||
|
|
||||||
return await user_memory_controllers.get_graph_data_api(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
node_types=node_types,
|
|
||||||
limit=limit,
|
|
||||||
depth=depth,
|
|
||||||
center_node_id=center_node_id,
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/community_graph")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def get_community_graph(
|
|
||||||
request: Request,
|
|
||||||
end_user_id: str = Query(..., description="End user ID"),
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""Get community clustering graph for an end user."""
|
|
||||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
|
||||||
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
|
||||||
|
|
||||||
return await user_memory_controllers.get_community_graph_data_api(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 节点统计 ====================
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/node_statistics")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def get_node_statistics(
|
|
||||||
request: Request,
|
|
||||||
end_user_id: str = Query(..., description="End user ID"),
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""Get memory node type statistics for an end user."""
|
|
||||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
|
||||||
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
|
||||||
|
|
||||||
return await user_memory_controllers.get_node_statistics_api(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 用户摘要 & 洞察 ====================
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/user_summary")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def get_user_summary(
|
|
||||||
request: Request,
|
|
||||||
end_user_id: str = Query(..., description="End user ID"),
|
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""Get cached user summary for an end user."""
|
|
||||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
|
||||||
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
|
||||||
|
|
||||||
return await user_memory_controllers.get_user_summary_api(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
language_type=language_type,
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/memory_insight")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def get_memory_insight(
|
|
||||||
request: Request,
|
|
||||||
end_user_id: str = Query(..., description="End user ID"),
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""Get cached memory insight report for an end user."""
|
|
||||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
|
||||||
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
|
||||||
|
|
||||||
return await user_memory_controllers.get_memory_insight_report_api(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 兴趣分布 ====================
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/interest_distribution")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def get_interest_distribution(
|
|
||||||
request: Request,
|
|
||||||
end_user_id: str = Query(..., description="End user ID"),
|
|
||||||
limit: int = Query(5, le=5, description="Max interest tags to return"),
|
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""Get interest distribution tags for an end user."""
|
|
||||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
|
||||||
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
|
||||||
|
|
||||||
return await memory_agent_controller.get_interest_distribution_by_user_api(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
limit=limit,
|
|
||||||
language_type=language_type,
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 终端用户信息 ====================
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/end_user_info")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def get_end_user_info(
|
|
||||||
request: Request,
|
|
||||||
end_user_id: str = Query(..., description="End user ID"),
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""Get end user basic information (name, aliases, metadata)."""
|
|
||||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
|
||||||
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
|
||||||
|
|
||||||
return await user_memory_controllers.get_end_user_info(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 缓存生成 ====================
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/analytics/generate_cache")
|
|
||||||
@require_api_key(scopes=["memory"])
|
|
||||||
async def generate_cache(
|
|
||||||
request: Request,
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
message: str = Body(None, description="Request body"),
|
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
|
||||||
):
|
|
||||||
"""Trigger cache generation (user summary + memory insight) for an end user or all workspace users."""
|
|
||||||
body = await request.json()
|
|
||||||
cache_request = GenerateCacheRequest(**body)
|
|
||||||
|
|
||||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
|
||||||
|
|
||||||
if cache_request.end_user_id:
|
|
||||||
validate_end_user_in_workspace(db, cache_request.end_user_id, api_key_auth.workspace_id)
|
|
||||||
|
|
||||||
return await user_memory_controllers.generate_cache_api(
|
|
||||||
request=cache_request,
|
|
||||||
language_type=language_type,
|
|
||||||
current_user=current_user,
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,87 +0,0 @@
|
|||||||
"""Skill Controller - 技能市场管理"""
|
|
||||||
from fastapi import APIRouter, Depends, Query
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from typing import Optional
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from app.db import get_db
|
|
||||||
from app.dependencies import get_current_user
|
|
||||||
from app.models import User
|
|
||||||
from app.schemas import skill_schema
|
|
||||||
from app.schemas.response_schema import PageData, PageMeta
|
|
||||||
from app.services.skill_service import SkillService
|
|
||||||
from app.core.response_utils import success
|
|
||||||
from app.core.quota_stub import check_skill_quota
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/skills", tags=["Skills"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("", summary="创建技能")
|
|
||||||
@check_skill_quota
|
|
||||||
def create_skill(
|
|
||||||
data: skill_schema.SkillCreate,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""创建技能 - 可以关联现有工具(内置、MCP、自定义)"""
|
|
||||||
tenant_id = current_user.tenant_id
|
|
||||||
skill = SkillService.create_skill(db, data, tenant_id)
|
|
||||||
return success(data=skill_schema.Skill.model_validate(skill), msg="技能创建成功")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("", summary="技能列表")
|
|
||||||
def list_skills(
|
|
||||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
|
||||||
is_active: Optional[bool] = Query(None, description="是否激活"),
|
|
||||||
is_public: Optional[bool] = Query(None, description="是否公开"),
|
|
||||||
page: int = Query(1, ge=1, description="页码"),
|
|
||||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""技能市场列表 - 包含本工作空间和公开的技能"""
|
|
||||||
tenant_id = current_user.tenant_id
|
|
||||||
skills, total = SkillService.list_skills(
|
|
||||||
db, tenant_id, search, is_active, is_public, page, pagesize
|
|
||||||
)
|
|
||||||
|
|
||||||
items = [skill_schema.Skill.model_validate(s) for s in skills]
|
|
||||||
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
|
||||||
return success(data=PageData(page=meta, items=items), msg="技能市场列表获取成功")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{skill_id}", summary="获取技能详情")
|
|
||||||
def get_skill(
|
|
||||||
skill_id: uuid.UUID,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""获取技能详情"""
|
|
||||||
tenant_id = current_user.tenant_id
|
|
||||||
skill = SkillService.get_skill(db, skill_id, tenant_id)
|
|
||||||
return success(data=skill_schema.Skill.model_validate(skill), msg="获取技能详情成功")
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{skill_id}", summary="更新技能")
|
|
||||||
def update_skill(
|
|
||||||
skill_id: uuid.UUID,
|
|
||||||
data: skill_schema.SkillUpdate,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""更新技能"""
|
|
||||||
tenant_id = current_user.tenant_id
|
|
||||||
skill = SkillService.update_skill(db, skill_id, data, tenant_id)
|
|
||||||
return success(data=skill_schema.Skill.model_validate(skill), msg="技能更新成功")
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{skill_id}", summary="删除技能")
|
|
||||||
def delete_skill(
|
|
||||||
skill_id: uuid.UUID,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""删除技能"""
|
|
||||||
tenant_id = current_user.tenant_id
|
|
||||||
SkillService.delete_skill(db, skill_id, tenant_id)
|
|
||||||
return success(msg="技能删除成功")
|
|
||||||
@@ -1,173 +0,0 @@
|
|||||||
"""
|
|
||||||
租户套餐查询接口(普通用户可访问)
|
|
||||||
"""
|
|
||||||
import datetime
|
|
||||||
from typing import Callable, Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.core.logging_config import get_api_logger
|
|
||||||
from app.core.response_utils import success, fail
|
|
||||||
from app.db import get_db
|
|
||||||
from app.dependencies import get_current_user
|
|
||||||
from app.i18n.dependencies import get_translator
|
|
||||||
from app.models.user_model import User
|
|
||||||
from app.schemas.response_schema import ApiResponse
|
|
||||||
|
|
||||||
logger = get_api_logger()
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/tenant", tags=["Tenant"])
|
|
||||||
public_router = APIRouter(tags=["Tenant"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/subscription", response_model=ApiResponse, summary="获取当前用户所属租户的套餐信息")
|
|
||||||
async def get_my_tenant_subscription(
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
t: Callable = Depends(get_translator),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
获取当前登录用户所属租户的有效套餐订阅信息。
|
|
||||||
包含套餐名称、版本、配额、到期时间等。
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from premium.platform_admin.package_plan_service import TenantSubscriptionService
|
|
||||||
|
|
||||||
if not current_user.tenant:
|
|
||||||
return JSONResponse(status_code=404, content=fail(code=404, msg="用户未关联租户"))
|
|
||||||
|
|
||||||
tenant_id = current_user.tenant.id
|
|
||||||
svc = TenantSubscriptionService(db)
|
|
||||||
sub = svc.get_subscription(tenant_id)
|
|
||||||
|
|
||||||
if not sub:
|
|
||||||
# 无订阅记录时,兜底返回免费套餐信息
|
|
||||||
free_plan = svc.plan_repo.get_free_plan()
|
|
||||||
if not free_plan:
|
|
||||||
return success(data=None, msg="暂无有效套餐")
|
|
||||||
return success(data={
|
|
||||||
"subscription_id": None,
|
|
||||||
"tenant_id": str(tenant_id),
|
|
||||||
"package_plan_id": str(free_plan.id),
|
|
||||||
"package_version": free_plan.version,
|
|
||||||
"package_plan": {
|
|
||||||
"id": str(free_plan.id),
|
|
||||||
"name": free_plan.name,
|
|
||||||
"name_en": free_plan.name_en,
|
|
||||||
"version": free_plan.version,
|
|
||||||
"category": free_plan.category,
|
|
||||||
"tier_level": free_plan.tier_level,
|
|
||||||
"price": float(free_plan.price) if free_plan.price is not None else 0.0,
|
|
||||||
"billing_cycle": free_plan.billing_cycle,
|
|
||||||
"core_value": free_plan.core_value,
|
|
||||||
"core_value_en": free_plan.core_value_en,
|
|
||||||
"tech_support": free_plan.tech_support,
|
|
||||||
"tech_support_en": free_plan.tech_support_en,
|
|
||||||
"sla_compliance": free_plan.sla_compliance,
|
|
||||||
"sla_compliance_en": free_plan.sla_compliance_en,
|
|
||||||
"page_customization": free_plan.page_customization,
|
|
||||||
"page_customization_en": free_plan.page_customization_en,
|
|
||||||
"theme_color": free_plan.theme_color,
|
|
||||||
},
|
|
||||||
"started_at": None,
|
|
||||||
"expired_at": None,
|
|
||||||
"status": "active",
|
|
||||||
"quotas": free_plan.quotas or {},
|
|
||||||
"created_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
|
||||||
"updated_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
|
||||||
}, msg="免费套餐")
|
|
||||||
|
|
||||||
return success(data=svc.build_response(sub))
|
|
||||||
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
# 社区版无 premium 模块,从配置文件读取免费套餐
|
|
||||||
if not current_user.tenant:
|
|
||||||
return JSONResponse(status_code=404, content=fail(code=404, msg="用户未关联租户"))
|
|
||||||
|
|
||||||
from app.config.default_free_plan import DEFAULT_FREE_PLAN
|
|
||||||
|
|
||||||
plan = DEFAULT_FREE_PLAN
|
|
||||||
response_data = {
|
|
||||||
"subscription_id": None,
|
|
||||||
"tenant_id": str(current_user.tenant.id),
|
|
||||||
"package_plan_id": None,
|
|
||||||
"package_version": plan["version"],
|
|
||||||
"package_plan": {
|
|
||||||
"id": None,
|
|
||||||
"name": plan["name"],
|
|
||||||
"name_en": plan.get("name_en"),
|
|
||||||
"version": plan["version"],
|
|
||||||
"category": plan["category"],
|
|
||||||
"tier_level": plan["tier_level"],
|
|
||||||
"price": float(plan["price"]),
|
|
||||||
"billing_cycle": plan["billing_cycle"],
|
|
||||||
"core_value": plan.get("core_value"),
|
|
||||||
"core_value_en": plan.get("core_value_en"),
|
|
||||||
"tech_support": plan.get("tech_support"),
|
|
||||||
"tech_support_en": plan.get("tech_support_en"),
|
|
||||||
"sla_compliance": plan.get("sla_compliance"),
|
|
||||||
"sla_compliance_en": plan.get("sla_compliance_en"),
|
|
||||||
"page_customization": plan.get("page_customization"),
|
|
||||||
"page_customization_en": plan.get("page_customization_en"),
|
|
||||||
"theme_color": plan.get("theme_color"),
|
|
||||||
},
|
|
||||||
"started_at": None,
|
|
||||||
"expired_at": None,
|
|
||||||
"status": "active",
|
|
||||||
"quotas": plan["quotas"],
|
|
||||||
"created_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
|
||||||
"updated_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
|
||||||
}
|
|
||||||
return success(data=response_data, msg="社区版免费套餐")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取租户套餐信息失败: {e}", exc_info=True)
|
|
||||||
return JSONResponse(status_code=500, content=fail(code=500, msg="获取套餐信息失败"))
|
|
||||||
|
|
||||||
|
|
||||||
@public_router.get("/package-plans", response_model=ApiResponse, summary="获取套餐列表(公开)")
|
|
||||||
async def list_package_plans_public(
|
|
||||||
category: Optional[str] = None,
|
|
||||||
status: Optional[bool] = None,
|
|
||||||
search: Optional[str] = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
公开接口,无需鉴权。
|
|
||||||
SaaS 版从数据库读取套餐列表;社区版降级返回 default_free_plan.py 中的免费套餐。
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from premium.platform_admin.package_plan_service import PackagePlanService
|
|
||||||
from premium.platform_admin.package_plan_schema import PackagePlanResponse
|
|
||||||
svc = PackagePlanService(db)
|
|
||||||
result = svc.get_list(page=1, size=9999, category=category, status=status, search=search)
|
|
||||||
return success(data=[PackagePlanResponse.model_validate(p).model_dump(mode="json") for p in result["items"]])
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
from app.config.default_free_plan import DEFAULT_FREE_PLAN
|
|
||||||
plan = DEFAULT_FREE_PLAN
|
|
||||||
return success(data=[{
|
|
||||||
"id": None,
|
|
||||||
"name": plan["name"],
|
|
||||||
"name_en": plan.get("name_en"),
|
|
||||||
"version": plan["version"],
|
|
||||||
"category": plan["category"],
|
|
||||||
"tier_level": plan["tier_level"],
|
|
||||||
"price": float(plan["price"]),
|
|
||||||
"billing_cycle": plan["billing_cycle"],
|
|
||||||
"core_value": plan.get("core_value"),
|
|
||||||
"core_value_en": plan.get("core_value_en"),
|
|
||||||
"tech_support": plan.get("tech_support"),
|
|
||||||
"tech_support_en": plan.get("tech_support_en"),
|
|
||||||
"sla_compliance": plan.get("sla_compliance"),
|
|
||||||
"sla_compliance_en": plan.get("sla_compliance_en"),
|
|
||||||
"page_customization": plan.get("page_customization"),
|
|
||||||
"page_customization_en": plan.get("page_customization_en"),
|
|
||||||
"theme_color": plan.get("theme_color"),
|
|
||||||
"status": plan.get("status", True),
|
|
||||||
"quotas": plan["quotas"],
|
|
||||||
}])
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取套餐列表失败: {e}", exc_info=True)
|
|
||||||
return JSONResponse(status_code=500, content=fail(code=500, msg="获取套餐列表失败"))
|
|
||||||
@@ -3,11 +3,8 @@ from typing import Optional
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
|
||||||
from app.schemas.tool_schema import (
|
from app.schemas.tool_schema import (
|
||||||
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest,
|
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest, CustomToolTestRequest
|
||||||
CustomToolTestRequest, ToolActiveUpdate
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
@@ -17,7 +14,6 @@ from app.models import User
|
|||||||
from app.models.tool_model import ToolType, ToolStatus, AuthType
|
from app.models.tool_model import ToolType, ToolStatus, AuthType
|
||||||
from app.services.tool_service import ToolService
|
from app.services.tool_service import ToolService
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.core.exceptions import BusinessException
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/tools", tags=["Tool System"])
|
router = APIRouter(prefix="/tools", tags=["Tool System"])
|
||||||
|
|
||||||
@@ -76,8 +72,6 @@ async def get_tool_methods(
|
|||||||
if methods is None:
|
if methods is None:
|
||||||
raise HTTPException(status_code=404, detail="工具不存在")
|
raise HTTPException(status_code=404, detail="工具不存在")
|
||||||
return success(data=methods, msg="获取工具方法成功")
|
return success(data=methods, msg="获取工具方法成功")
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -103,13 +97,7 @@ async def create_tool(
|
|||||||
):
|
):
|
||||||
"""创建工具"""
|
"""创建工具"""
|
||||||
try:
|
try:
|
||||||
# 将 MCP 来源字段合并进 config
|
tool_id = service.create_tool(
|
||||||
if request.tool_type == ToolType.MCP:
|
|
||||||
for key in ("source_channel", "market_id", "market_config_id", "mcp_service_id"):
|
|
||||||
val = getattr(request, key, None)
|
|
||||||
if val is not None:
|
|
||||||
request.config[key] = val
|
|
||||||
tool_id = await service.create_tool(
|
|
||||||
name=request.name,
|
name=request.name,
|
||||||
tool_type=request.tool_type,
|
tool_type=request.tool_type,
|
||||||
tenant_id=current_user.tenant_id,
|
tenant_id=current_user.tenant_id,
|
||||||
@@ -119,12 +107,8 @@ async def create_tool(
|
|||||||
tags=request.tags
|
tags=request.tags
|
||||||
)
|
)
|
||||||
return success(data={"tool_id": tool_id}, msg="工具创建成功")
|
return success(data={"tool_id": tool_id}, msg="工具创建成功")
|
||||||
except BusinessException as e:
|
|
||||||
raise HTTPException(status_code=400, detail=e.message)
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -153,8 +137,6 @@ async def update_tool(
|
|||||||
return success(msg="工具更新成功")
|
return success(msg="工具更新成功")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -165,7 +147,7 @@ async def delete_tool(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
service: ToolService = Depends(get_tool_service)
|
service: ToolService = Depends(get_tool_service)
|
||||||
):
|
):
|
||||||
"""删除工具(逻辑删除,is_active=False)"""
|
"""删除工具"""
|
||||||
try:
|
try:
|
||||||
success_flag = service.delete_tool(tool_id, current_user.tenant_id)
|
success_flag = service.delete_tool(tool_id, current_user.tenant_id)
|
||||||
if not success_flag:
|
if not success_flag:
|
||||||
@@ -173,34 +155,6 @@ async def delete_tool(
|
|||||||
return success(msg="工具删除成功")
|
return success(msg="工具删除成功")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/{tool_id}/active", response_model=ApiResponse)
|
|
||||||
async def set_tool_active(
|
|
||||||
tool_id: str,
|
|
||||||
request: ToolActiveUpdate,
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
service: ToolService = Depends(get_tool_service)
|
|
||||||
):
|
|
||||||
"""设置工具可用状态(启用/禁用)
|
|
||||||
|
|
||||||
- is_active=true: 启用工具
|
|
||||||
- is_active=false: 禁用工具(等同于删除,但可恢复)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
success_flag = service.set_tool_active(tool_id, current_user.tenant_id, request.is_active)
|
|
||||||
if not success_flag:
|
|
||||||
raise HTTPException(status_code=404, detail="工具不存在")
|
|
||||||
action = "启用" if request.is_active else "禁用"
|
|
||||||
return success(msg=f"工具已{action}")
|
|
||||||
except ValueError as e:
|
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -233,8 +187,6 @@ async def execute_tool(
|
|||||||
},
|
},
|
||||||
msg="工具执行完成"
|
msg="工具执行完成"
|
||||||
)
|
)
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -251,8 +203,6 @@ async def parse_openapi_schema(
|
|||||||
if result["success"] is False:
|
if result["success"] is False:
|
||||||
raise HTTPException(status_code=400, detail=result["message"])
|
raise HTTPException(status_code=400, detail=result["message"])
|
||||||
return success(data=result, msg="Schema解析完成")
|
return success(data=result, msg="Schema解析完成")
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -266,10 +216,8 @@ async def sync_mcp_tools(
|
|||||||
try:
|
try:
|
||||||
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
|
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
|
||||||
if not result.get("success", False):
|
if not result.get("success", False):
|
||||||
raise BusinessException(result.get("message", "工具列表同步失败"), BizCode.BAD_REQUEST)
|
raise HTTPException(status_code=400, detail=result.get("message", "同步失败"))
|
||||||
return success(data=result, msg="MCP工具列表同步完成")
|
return success(data=result, msg="MCP工具列表同步完成")
|
||||||
except BusinessException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -292,10 +240,8 @@ async def test_tool_connection(
|
|||||||
# 普通连接测试
|
# 普通连接测试
|
||||||
result = await service.test_connection(tool_id, current_user.tenant_id)
|
result = await service.test_connection(tool_id, current_user.tenant_id)
|
||||||
if result["success"] is False:
|
if result["success"] is False:
|
||||||
raise BusinessException(result["message"], BizCode.SERVICE_UNAVAILABLE)
|
raise HTTPException(status_code=400, detail=result["message"])
|
||||||
return success(data=result, msg="连接测试完成")
|
return success(data=result, msg="连接测试完成")
|
||||||
except BusinessException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|||||||
@@ -1,26 +1,16 @@
|
|||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
|
||||||
from app.core.exceptions import BusinessException
|
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user, get_current_superuser
|
from app.dependencies import get_current_user, get_current_superuser
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.schemas import user_schema
|
from app.schemas import user_schema
|
||||||
from app.schemas.user_schema import (
|
from app.schemas.user_schema import ChangePasswordRequest, AdminChangePasswordRequest
|
||||||
ChangePasswordRequest,
|
|
||||||
AdminChangePasswordRequest,
|
|
||||||
SendEmailCodeRequest,
|
|
||||||
VerifyEmailCodeRequest,
|
|
||||||
VerifyPasswordRequest)
|
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import user_service
|
from app.services import user_service
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.core.security import verify_password
|
|
||||||
from app.i18n.dependencies import get_translator
|
|
||||||
|
|
||||||
# 获取API专用日志器
|
# 获取API专用日志器
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -35,8 +25,7 @@ router = APIRouter(
|
|||||||
def create_superuser(
|
def create_superuser(
|
||||||
user: user_schema.UserCreate,
|
user: user_schema.UserCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_superuser: User = Depends(get_current_superuser),
|
current_superuser: User = Depends(get_current_superuser)
|
||||||
t: Callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""创建超级管理员(仅超级管理员可访问)"""
|
"""创建超级管理员(仅超级管理员可访问)"""
|
||||||
api_logger.info(f"超级管理员创建请求: {user.username}, email: {user.email}")
|
api_logger.info(f"超级管理员创建请求: {user.username}, email: {user.email}")
|
||||||
@@ -45,7 +34,7 @@ def create_superuser(
|
|||||||
api_logger.info(f"超级管理员创建成功: {result.username} (ID: {result.id})")
|
api_logger.info(f"超级管理员创建成功: {result.username} (ID: {result.id})")
|
||||||
|
|
||||||
result_schema = user_schema.User.model_validate(result)
|
result_schema = user_schema.User.model_validate(result)
|
||||||
return success(data=result_schema, msg=t("users.create.superuser_success"))
|
return success(data=result_schema, msg="超级管理员创建成功")
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{user_id}", response_model=ApiResponse)
|
@router.delete("/{user_id}", response_model=ApiResponse)
|
||||||
@@ -53,7 +42,6 @@ def delete_user(
|
|||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
t: Callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""停用用户(软删除)"""
|
"""停用用户(软删除)"""
|
||||||
api_logger.info(f"用户停用请求: user_id={user_id}, 操作者: {current_user.username}")
|
api_logger.info(f"用户停用请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||||
@@ -61,14 +49,13 @@ def delete_user(
|
|||||||
db=db, user_id_to_deactivate=user_id, current_user=current_user
|
db=db, user_id_to_deactivate=user_id, current_user=current_user
|
||||||
)
|
)
|
||||||
api_logger.info(f"用户停用成功: {result.username} (ID: {result.id})")
|
api_logger.info(f"用户停用成功: {result.username} (ID: {result.id})")
|
||||||
return success(msg=t("users.delete.deactivate_success"))
|
return success(msg="用户停用成功")
|
||||||
|
|
||||||
@router.post("/{user_id}/activate", response_model=ApiResponse)
|
@router.post("/{user_id}/activate", response_model=ApiResponse)
|
||||||
def activate_user(
|
def activate_user(
|
||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
t: Callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""激活用户"""
|
"""激活用户"""
|
||||||
api_logger.info(f"用户激活请求: user_id={user_id}, 操作者: {current_user.username}")
|
api_logger.info(f"用户激活请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||||
@@ -79,14 +66,13 @@ def activate_user(
|
|||||||
api_logger.info(f"用户激活成功: {result.username} (ID: {result.id})")
|
api_logger.info(f"用户激活成功: {result.username} (ID: {result.id})")
|
||||||
|
|
||||||
result_schema = user_schema.User.model_validate(result)
|
result_schema = user_schema.User.model_validate(result)
|
||||||
return success(data=result_schema, msg=t("users.activate.success"))
|
return success(data=result_schema, msg="用户激活成功")
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=ApiResponse)
|
@router.get("", response_model=ApiResponse)
|
||||||
def get_current_user_info(
|
def get_current_user_info(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
t: Callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""获取当前用户信息"""
|
"""获取当前用户信息"""
|
||||||
api_logger.info(f"当前用户信息请求: {current_user.username}")
|
api_logger.info(f"当前用户信息请求: {current_user.username}")
|
||||||
@@ -106,27 +92,12 @@ def get_current_user_info(
|
|||||||
result_schema.current_workspace_name = current_workspace.name
|
result_schema.current_workspace_name = current_workspace.name
|
||||||
|
|
||||||
for ws in result.workspaces:
|
for ws in result.workspaces:
|
||||||
if ws.workspace_id == current_user.current_workspace_id and ws.is_active:
|
if ws.workspace_id == current_user.current_workspace_id:
|
||||||
result_schema.role = ws.role
|
result_schema.role = ws.role
|
||||||
break
|
break
|
||||||
|
|
||||||
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
|
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
|
||||||
|
return success(data=result_schema, msg="用户信息获取成功")
|
||||||
# 设置权限:如果用户来自 SSO Source,则使用该 Source 的 permissions;否则返回 "all" 表示拥有所有权限
|
|
||||||
if current_user.external_source:
|
|
||||||
try:
|
|
||||||
from premium.sso.models import SSOSource
|
|
||||||
source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first()
|
|
||||||
if source and source.permissions:
|
|
||||||
result_schema.permissions = source.permissions
|
|
||||||
else:
|
|
||||||
result_schema.permissions = []
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
result_schema.permissions = []
|
|
||||||
else:
|
|
||||||
result_schema.permissions = ["all"]
|
|
||||||
|
|
||||||
return success(data=result_schema, msg=t("users.info.get_success"))
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/superusers", response_model=ApiResponse)
|
@router.get("/superusers", response_model=ApiResponse)
|
||||||
@@ -134,7 +105,6 @@ def get_tenant_superusers(
|
|||||||
include_inactive: bool = False,
|
include_inactive: bool = False,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_superuser),
|
current_user: User = Depends(get_current_superuser),
|
||||||
t: Callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""获取当前租户下的超管账号列表(仅超级管理员可访问)"""
|
"""获取当前租户下的超管账号列表(仅超级管理员可访问)"""
|
||||||
api_logger.info(f"获取租户超管列表请求: {current_user.username}")
|
api_logger.info(f"获取租户超管列表请求: {current_user.username}")
|
||||||
@@ -147,7 +117,7 @@ def get_tenant_superusers(
|
|||||||
api_logger.info(f"租户超管列表获取成功: count={len(superusers)}")
|
api_logger.info(f"租户超管列表获取成功: count={len(superusers)}")
|
||||||
|
|
||||||
superusers_schema = [user_schema.User.model_validate(u) for u in superusers]
|
superusers_schema = [user_schema.User.model_validate(u) for u in superusers]
|
||||||
return success(data=superusers_schema, msg=t("users.list.superusers_success"))
|
return success(data=superusers_schema, msg="租户超管列表获取成功")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{user_id}", response_model=ApiResponse)
|
@router.get("/{user_id}", response_model=ApiResponse)
|
||||||
@@ -155,7 +125,6 @@ def get_user_info_by_id(
|
|||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
t: Callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""根据用户ID获取用户信息"""
|
"""根据用户ID获取用户信息"""
|
||||||
api_logger.info(f"获取用户信息请求: user_id={user_id}, 操作者: {current_user.username}")
|
api_logger.info(f"获取用户信息请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||||
@@ -166,7 +135,7 @@ def get_user_info_by_id(
|
|||||||
api_logger.info(f"用户信息获取成功: {result.username}")
|
api_logger.info(f"用户信息获取成功: {result.username}")
|
||||||
|
|
||||||
result_schema = user_schema.User.model_validate(result)
|
result_schema = user_schema.User.model_validate(result)
|
||||||
return success(data=result_schema, msg=t("users.info.get_success"))
|
return success(data=result_schema, msg="用户信息获取成功")
|
||||||
|
|
||||||
|
|
||||||
@router.put("/change-password", response_model=ApiResponse)
|
@router.put("/change-password", response_model=ApiResponse)
|
||||||
@@ -174,7 +143,6 @@ async def change_password(
|
|||||||
request: ChangePasswordRequest,
|
request: ChangePasswordRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
t: Callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""修改当前用户密码"""
|
"""修改当前用户密码"""
|
||||||
api_logger.info(f"用户密码修改请求: {current_user.username}")
|
api_logger.info(f"用户密码修改请求: {current_user.username}")
|
||||||
@@ -187,7 +155,7 @@ async def change_password(
|
|||||||
current_user=current_user
|
current_user=current_user
|
||||||
)
|
)
|
||||||
api_logger.info(f"用户密码修改成功: {current_user.username}")
|
api_logger.info(f"用户密码修改成功: {current_user.username}")
|
||||||
return success(msg=t("auth.password.change_success"))
|
return success(msg="密码修改成功")
|
||||||
|
|
||||||
|
|
||||||
@router.put("/admin/change-password", response_model=ApiResponse)
|
@router.put("/admin/change-password", response_model=ApiResponse)
|
||||||
@@ -195,7 +163,6 @@ async def admin_change_password(
|
|||||||
request: AdminChangePasswordRequest,
|
request: AdminChangePasswordRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_superuser),
|
current_user: User = Depends(get_current_superuser),
|
||||||
t: Callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""超级管理员修改指定用户的密码"""
|
"""超级管理员修改指定用户的密码"""
|
||||||
api_logger.info(f"管理员密码修改请求: 管理员 {current_user.username} 修改用户 {request.user_id}")
|
api_logger.info(f"管理员密码修改请求: 管理员 {current_user.username} 修改用户 {request.user_id}")
|
||||||
@@ -210,107 +177,7 @@ async def admin_change_password(
|
|||||||
# 根据是否生成了随机密码来构造响应
|
# 根据是否生成了随机密码来构造响应
|
||||||
if request.new_password:
|
if request.new_password:
|
||||||
api_logger.info(f"管理员密码修改成功: 用户 {request.user_id}")
|
api_logger.info(f"管理员密码修改成功: 用户 {request.user_id}")
|
||||||
return success(msg=t("auth.password.change_success"))
|
return success(msg="密码修改成功")
|
||||||
else:
|
else:
|
||||||
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
|
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
|
||||||
return success(data=generated_password, msg=t("auth.password.reset_success"))
|
return success(data=generated_password, msg="密码重置成功")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/verify_pwd", response_model=ApiResponse)
|
|
||||||
def verify_pwd(
|
|
||||||
request: VerifyPasswordRequest,
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
t: Callable = Depends(get_translator)
|
|
||||||
):
|
|
||||||
"""验证当前用户密码"""
|
|
||||||
api_logger.info(f"用户验证密码请求: {current_user.username}")
|
|
||||||
|
|
||||||
is_valid = verify_password(request.password, current_user.hashed_password)
|
|
||||||
api_logger.info(f"用户密码验证结果: {current_user.username}, valid={is_valid}")
|
|
||||||
if not is_valid:
|
|
||||||
raise BusinessException(t("users.errors.password_verification_failed"), code=BizCode.VALIDATION_FAILED)
|
|
||||||
return success(data={"valid": is_valid}, msg=t("common.success.retrieved"))
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/send-email-code", response_model=ApiResponse)
|
|
||||||
async def send_email_code(
|
|
||||||
request: SendEmailCodeRequest,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
t: Callable = Depends(get_translator)
|
|
||||||
):
|
|
||||||
"""发送邮箱验证码"""
|
|
||||||
api_logger.info(f"用户请求发送邮箱验证码: {current_user.username}, email={request.email}")
|
|
||||||
|
|
||||||
await user_service.send_email_code_method(db=db, email=request.email, user_id=current_user.id)
|
|
||||||
|
|
||||||
api_logger.info(f"邮箱验证码已发送: {current_user.username}")
|
|
||||||
return success(msg=t("users.email.code_sent"))
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/change-email", response_model=ApiResponse)
|
|
||||||
async def change_email(
|
|
||||||
request: VerifyEmailCodeRequest,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
t: Callable = Depends(get_translator)
|
|
||||||
):
|
|
||||||
"""验证验证码并修改邮箱"""
|
|
||||||
api_logger.info(f"用户修改邮箱: {current_user.username}, new_email={request.new_email}")
|
|
||||||
|
|
||||||
await user_service.verify_and_change_email(
|
|
||||||
db=db,
|
|
||||||
user_id=current_user.id,
|
|
||||||
new_email=request.new_email,
|
|
||||||
code=request.code
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(f"用户邮箱修改成功: {current_user.username}")
|
|
||||||
return success(msg=t("users.email.change_success"))
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me/language", response_model=ApiResponse)
|
|
||||||
def get_current_user_language(
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
t: Callable = Depends(get_translator)
|
|
||||||
):
|
|
||||||
"""获取当前用户的语言偏好"""
|
|
||||||
api_logger.info(f"获取用户语言偏好: {current_user.username}")
|
|
||||||
|
|
||||||
language = user_service.get_user_language_preference(
|
|
||||||
db=db,
|
|
||||||
user_id=current_user.id,
|
|
||||||
current_user=current_user
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(f"用户语言偏好获取成功: {current_user.username}, language={language}")
|
|
||||||
return success(
|
|
||||||
data=user_schema.LanguagePreferenceResponse(language=language),
|
|
||||||
msg=t("users.language.get_success")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/me/language", response_model=ApiResponse)
|
|
||||||
def update_current_user_language(
|
|
||||||
request: user_schema.LanguagePreferenceRequest,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
t: Callable = Depends(get_translator)
|
|
||||||
):
|
|
||||||
"""设置当前用户的语言偏好"""
|
|
||||||
api_logger.info(f"更新用户语言偏好: {current_user.username}, language={request.language}")
|
|
||||||
|
|
||||||
updated_user = user_service.update_user_language_preference(
|
|
||||||
db=db,
|
|
||||||
user_id=current_user.id,
|
|
||||||
language=request.language,
|
|
||||||
current_user=current_user
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(f"用户语言偏好更新成功: {current_user.username}, language={request.language}")
|
|
||||||
return success(
|
|
||||||
data=user_schema.LanguagePreferenceResponse(language=updated_user.preferred_language),
|
|
||||||
msg=t("users.language.update_success")
|
|
||||||
)
|
|
||||||
@@ -8,26 +8,23 @@ from sqlalchemy.orm import Session
|
|||||||
from fastapi import APIRouter, Depends,Header
|
from fastapi import APIRouter, Depends,Header
|
||||||
|
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.core.language_utils import get_language_from_header
|
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.response_utils import success, fail
|
from app.core.response_utils import success, fail
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.api_key_utils import timestamp_to_datetime
|
from app.core.api_key_utils import timestamp_to_datetime
|
||||||
|
from app.services.memory_base_service import Translation_English
|
||||||
from app.services.user_memory_service import (
|
from app.services.user_memory_service import (
|
||||||
UserMemoryService,
|
UserMemoryService,
|
||||||
analytics_memory_types,
|
analytics_memory_types,
|
||||||
analytics_graph_data,
|
analytics_graph_data,
|
||||||
analytics_community_graph_data,
|
|
||||||
)
|
)
|
||||||
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
||||||
from app.repositories.workspace_repository import WorkspaceRepository
|
from app.repositories.workspace_repository import WorkspaceRepository
|
||||||
from app.repositories.end_user_repository import EndUserRepository
|
from app.schemas.end_user_schema import (
|
||||||
from app.schemas.end_user_info_schema import (
|
EndUserProfileResponse,
|
||||||
EndUserInfoResponse,
|
EndUserProfileUpdate,
|
||||||
EndUserInfoCreate,
|
|
||||||
EndUserInfoUpdate,
|
|
||||||
)
|
)
|
||||||
from app.models.end_user_model import EndUser
|
from app.models.end_user_model import EndUser
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
@@ -48,6 +45,7 @@ router = APIRouter(
|
|||||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||||
async def get_memory_insight_report_api(
|
async def get_memory_insight_report_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
|
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
@@ -57,10 +55,18 @@ async def get_memory_insight_report_api(
|
|||||||
此接口仅查询数据库中已缓存的记忆洞察数据,不执行生成操作。
|
此接口仅查询数据库中已缓存的记忆洞察数据,不执行生成操作。
|
||||||
如需生成新的洞察报告,请使用专门的生成接口。
|
如需生成新的洞察报告,请使用专门的生成接口。
|
||||||
"""
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
workspace_repo = WorkspaceRepository(db)
|
||||||
|
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||||
|
|
||||||
|
if workspace_models:
|
||||||
|
model_id = workspace_models.get("llm", None)
|
||||||
|
else:
|
||||||
|
model_id = None
|
||||||
api_logger.info(f"记忆洞察报告查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
api_logger.info(f"记忆洞察报告查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||||
try:
|
try:
|
||||||
# 调用服务层获取缓存数据
|
# 调用服务层获取缓存数据
|
||||||
result = await user_memory_service.get_cached_memory_insight(db, end_user_id)
|
result = await user_memory_service.get_cached_memory_insight(db, end_user_id,model_id,language_type)
|
||||||
|
|
||||||
if result["is_cached"]:
|
if result["is_cached"]:
|
||||||
api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}")
|
api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}")
|
||||||
@@ -76,7 +82,7 @@ async def get_memory_insight_report_api(
|
|||||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||||
async def get_user_summary_api(
|
async def get_user_summary_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
@@ -85,14 +91,7 @@ async def get_user_summary_api(
|
|||||||
|
|
||||||
此接口仅查询数据库中已缓存的用户摘要数据,不执行生成操作。
|
此接口仅查询数据库中已缓存的用户摘要数据,不执行生成操作。
|
||||||
如需生成新的用户摘要,请使用专门的生成接口。
|
如需生成新的用户摘要,请使用专门的生成接口。
|
||||||
|
|
||||||
语言控制:
|
|
||||||
- 使用 X-Language-Type Header 指定语言
|
|
||||||
- 如果未传 Header,默认使用中文 (zh)
|
|
||||||
"""
|
"""
|
||||||
# 使用集中化的语言校验
|
|
||||||
language = get_language_from_header(language_type)
|
|
||||||
|
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
workspace_repo = WorkspaceRepository(db)
|
workspace_repo = WorkspaceRepository(db)
|
||||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||||
@@ -104,7 +103,7 @@ async def get_user_summary_api(
|
|||||||
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||||
try:
|
try:
|
||||||
# 调用服务层获取缓存数据
|
# 调用服务层获取缓存数据
|
||||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id, model_id, language)
|
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language_type)
|
||||||
|
|
||||||
if result["is_cached"]:
|
if result["is_cached"]:
|
||||||
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
||||||
@@ -120,7 +119,6 @@ async def get_user_summary_api(
|
|||||||
@router.post("/analytics/generate_cache", response_model=ApiResponse)
|
@router.post("/analytics/generate_cache", response_model=ApiResponse)
|
||||||
async def generate_cache_api(
|
async def generate_cache_api(
|
||||||
request: GenerateCacheRequest,
|
request: GenerateCacheRequest,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
@@ -129,14 +127,7 @@ async def generate_cache_api(
|
|||||||
|
|
||||||
- 如果提供 end_user_id,只为该用户生成
|
- 如果提供 end_user_id,只为该用户生成
|
||||||
- 如果不提供,为当前工作空间的所有用户生成
|
- 如果不提供,为当前工作空间的所有用户生成
|
||||||
|
|
||||||
语言控制:
|
|
||||||
- 使用 X-Language-Type Header 指定语言 ("zh" 中文, "en" 英文)
|
|
||||||
- 如果未传 Header,默认使用中文 (zh)
|
|
||||||
"""
|
"""
|
||||||
# 使用集中化的语言校验
|
|
||||||
language = get_language_from_header(language_type)
|
|
||||||
|
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
@@ -148,7 +139,7 @@ async def generate_cache_api(
|
|||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, "
|
f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, "
|
||||||
f"end_user_id={end_user_id if end_user_id else '全部用户'}, language={language}"
|
f"end_user_id={end_user_id if end_user_id else '全部用户'}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -157,12 +148,10 @@ async def generate_cache_api(
|
|||||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
||||||
|
|
||||||
# 生成记忆洞察
|
# 生成记忆洞察
|
||||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id,
|
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id)
|
||||||
language=language)
|
|
||||||
|
|
||||||
# 生成用户摘要
|
# 生成用户摘要
|
||||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id,
|
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id)
|
||||||
language=language)
|
|
||||||
|
|
||||||
# 构建响应
|
# 构建响应
|
||||||
result = {
|
result = {
|
||||||
@@ -196,7 +185,7 @@ async def generate_cache_api(
|
|||||||
# 为整个工作空间生成
|
# 为整个工作空间生成
|
||||||
api_logger.info(f"开始为工作空间 {workspace_id} 批量生成缓存")
|
api_logger.info(f"开始为工作空间 {workspace_id} 批量生成缓存")
|
||||||
|
|
||||||
result = await user_memory_service.generate_cache_for_workspace(db, workspace_id, language=language)
|
result = await user_memory_service.generate_cache_for_workspace(db, workspace_id)
|
||||||
|
|
||||||
# 记录统计信息
|
# 记录统计信息
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
@@ -224,8 +213,7 @@ async def get_node_statistics_api(
|
|||||||
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
||||||
f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 调用新的记忆类型统计函数
|
# 调用新的记忆类型统计函数
|
||||||
@@ -233,14 +221,12 @@ async def get_node_statistics_api(
|
|||||||
|
|
||||||
# 计算总数用于日志
|
# 计算总数用于日志
|
||||||
total_count = sum(item["count"] for item in result)
|
total_count = sum(item["count"] for item in result)
|
||||||
api_logger.info(
|
api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
||||||
f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
|
||||||
return success(data=result, msg="查询成功")
|
return success(data=result, msg="查询成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}")
|
api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/graph_data", response_model=ApiResponse)
|
@router.get("/analytics/graph_data", response_model=ApiResponse)
|
||||||
async def get_graph_data_api(
|
async def get_graph_data_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
@@ -303,164 +289,106 @@ async def get_graph_data_api(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/community_graph", response_model=ApiResponse)
|
@router.get("/read_end_user/profile", response_model=ApiResponse)
|
||||||
async def get_community_graph_data_api(
|
async def get_end_user_profile(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
workspace_repo = WorkspaceRepository(db)
|
||||||
|
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||||
|
|
||||||
|
if workspace_models:
|
||||||
|
model_id = workspace_models.get("llm", None)
|
||||||
|
else:
|
||||||
|
model_id = None
|
||||||
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试查询社区图谱但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"社区图谱查询请求: end_user_id={end_user_id}, user={current_user.username}, "
|
f"用户信息查询请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||||
f"workspace={workspace_id}"
|
f"workspace={workspace_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await analytics_community_graph_data(db=db, end_user_id=end_user_id)
|
# 查询终端用户
|
||||||
|
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||||
|
|
||||||
if "message" in result and result["statistics"]["total_nodes"] == 0:
|
if not end_user:
|
||||||
api_logger.warning(f"社区图谱查询返回空结果: {result.get('message')}")
|
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
|
||||||
return success(data=result, msg=result.get("message", "查询成功"))
|
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
|
||||||
|
# 构建响应数据
|
||||||
api_logger.info(
|
profile_data = EndUserProfileResponse(
|
||||||
f"成功获取社区图谱: end_user_id={end_user_id}, "
|
id=end_user.id,
|
||||||
f"nodes={result['statistics']['total_nodes']}, "
|
other_name=end_user.other_name,
|
||||||
f"edges={result['statistics']['total_edges']}"
|
position=end_user.position,
|
||||||
|
department=end_user.department,
|
||||||
|
contact=end_user.contact,
|
||||||
|
phone=end_user.phone,
|
||||||
|
hire_date=end_user.hire_date,
|
||||||
|
updatetime_profile=end_user.updatetime_profile
|
||||||
)
|
)
|
||||||
return success(data=result, msg="查询成功")
|
|
||||||
|
api_logger.info(f"成功获取用户信息: end_user_id={end_user_id}")
|
||||||
|
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="查询成功")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"社区图谱查询失败: end_user_id={end_user_id}, error={str(e)}")
|
api_logger.error(f"用户信息查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "社区图谱查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "用户信息查询失败", str(e))
|
||||||
|
|
||||||
#=======================终端用户信息接口=======================
|
|
||||||
|
|
||||||
@router.get("/end_user_info", response_model=ApiResponse)
|
@router.post("/updated_end_user/profile", response_model=ApiResponse)
|
||||||
async def get_end_user_info(
|
async def update_end_user_profile(
|
||||||
end_user_id: str,
|
profile_update: EndUserProfileUpdate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
查询终端用户信息记录
|
更新终端用户的基本信息
|
||||||
|
|
||||||
根据 end_user_id 查询单条终端用户信息记录。
|
该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。
|
||||||
|
所有字段都是可选的,只更新提供的字段。
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
end_user_id = profile_update.end_user_id
|
||||||
|
|
||||||
|
# 验证工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试查询终端用户信息但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"查询终端用户信息请求: end_user_id={end_user_id}, user={current_user.username}, "
|
f"用户信息更新请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||||
f"workspace={workspace_id}"
|
f"workspace={workspace_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 校验 end_user 是否属于当前工作空间
|
# 调用 Service 层处理业务逻辑
|
||||||
end_user_repo = EndUserRepository(db)
|
result = user_memory_service.update_end_user_profile(db, end_user_id, profile_update)
|
||||||
end_user = end_user_repo.get_end_user_by_id(end_user_id)
|
|
||||||
if end_user is None:
|
|
||||||
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", "end_user not found")
|
|
||||||
if str(end_user.workspace_id) != str(workspace_id):
|
|
||||||
api_logger.warning(
|
|
||||||
f"用户 {current_user.username} 尝试查询不属于工作空间 {workspace_id} 的终端用户 {end_user_id}"
|
|
||||||
)
|
|
||||||
return fail(BizCode.PERMISSION_DENIED, "该终端用户不属于当前工作空间", "end_user workspace mismatch")
|
|
||||||
|
|
||||||
result = user_memory_service.get_end_user_info(db, end_user_id)
|
|
||||||
|
|
||||||
if result["success"]:
|
if result["success"]:
|
||||||
api_logger.info(f"成功查询终端用户信息: end_user_id={end_user_id}")
|
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}")
|
||||||
return success(data=result["data"], msg="查询成功")
|
|
||||||
else:
|
|
||||||
error_msg = result["error"]
|
|
||||||
api_logger.error(f"查询终端用户信息失败: end_user_id={end_user_id}, error={error_msg}")
|
|
||||||
|
|
||||||
if error_msg == "终端用户信息记录不存在":
|
|
||||||
return fail(BizCode.USER_NOT_FOUND, "终端用户信息记录不存在", error_msg)
|
|
||||||
elif error_msg == "无效的终端用户ID格式":
|
|
||||||
return fail(BizCode.INVALID_USER_ID, "无效的终端用户ID格式", error_msg)
|
|
||||||
else:
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "查询终端用户信息失败", error_msg)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/end_user_info/updated", response_model=ApiResponse)
|
|
||||||
async def update_end_user_info(
|
|
||||||
info_update: EndUserInfoUpdate,
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
更新终端用户信息记录
|
|
||||||
|
|
||||||
根据 end_user_id 更新终端用户信息记录,支持批量更新多个别名。
|
|
||||||
|
|
||||||
示例请求体:
|
|
||||||
{
|
|
||||||
"end_user_id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
|
|
||||||
"other_name": "张三1",
|
|
||||||
"aliases": ["小张", "张工"],
|
|
||||||
"meta_data": {"position": "工程师", "department": "技术部"}
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
end_user_id = info_update.end_user_id
|
|
||||||
|
|
||||||
if workspace_id is None:
|
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试更新终端用户信息但未选择工作空间")
|
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"更新终端用户信息请求: end_user_id={end_user_id}, user={current_user.username}, "
|
|
||||||
f"workspace={workspace_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 校验 end_user 是否属于当前工作空间
|
|
||||||
end_user_repo = EndUserRepository(db)
|
|
||||||
end_user = end_user_repo.get_end_user_by_id(end_user_id)
|
|
||||||
if end_user is None:
|
|
||||||
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", "end_user not found")
|
|
||||||
if str(end_user.workspace_id) != str(workspace_id):
|
|
||||||
api_logger.warning(
|
|
||||||
f"用户 {current_user.username} 尝试更新不属于工作空间 {workspace_id} 的终端用户 {end_user_id}"
|
|
||||||
)
|
|
||||||
return fail(BizCode.PERMISSION_DENIED, "该终端用户不属于当前工作空间", "end_user workspace mismatch")
|
|
||||||
|
|
||||||
# 获取更新数据(排除 end_user_id)
|
|
||||||
update_data = info_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
|
|
||||||
|
|
||||||
result = user_memory_service.update_end_user_info(db, end_user_id, update_data)
|
|
||||||
|
|
||||||
if result["success"]:
|
|
||||||
api_logger.info(f"成功更新终端用户信息: end_user_id={end_user_id}")
|
|
||||||
return success(data=result["data"], msg="更新成功")
|
return success(data=result["data"], msg="更新成功")
|
||||||
else:
|
else:
|
||||||
error_msg = result["error"]
|
error_msg = result["error"]
|
||||||
api_logger.error(f"终端用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
|
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
|
||||||
|
|
||||||
if error_msg == "终端用户信息记录不存在":
|
# 根据错误类型映射到合适的业务错误码
|
||||||
return fail(BizCode.USER_NOT_FOUND, "终端用户信息记录不存在", error_msg)
|
if error_msg == "终端用户不存在":
|
||||||
elif error_msg == "无效的终端用户ID格式":
|
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg)
|
||||||
return fail(BizCode.INVALID_USER_ID, "无效的终端用户ID格式", error_msg)
|
elif error_msg == "无效的用户ID格式":
|
||||||
|
return fail(BizCode.INVALID_USER_ID, "无效的用户ID格式", error_msg)
|
||||||
else:
|
else:
|
||||||
return fail(BizCode.INTERNAL_ERROR, "终端用户信息更新失败", error_msg)
|
# 只有未预期的错误才使用 INTERNAL_ERROR
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
|
||||||
|
|
||||||
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
||||||
async def memory_space_timeline_of_shared_memories(
|
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||||
id: str, label: str,
|
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
# 使用集中化的语言校验
|
|
||||||
language = get_language_from_header(language_type)
|
|
||||||
|
|
||||||
workspace_id=current_user.current_workspace_id
|
workspace_id=current_user.current_workspace_id
|
||||||
workspace_repo = WorkspaceRepository(db)
|
workspace_repo = WorkspaceRepository(db)
|
||||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||||
@@ -470,11 +398,9 @@ async def memory_space_timeline_of_shared_memories(
|
|||||||
else:
|
else:
|
||||||
model_id = None
|
model_id = None
|
||||||
MemoryEntity = MemoryEntityService(id, label)
|
MemoryEntity = MemoryEntityService(id, label)
|
||||||
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
|
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language_type)
|
||||||
|
|
||||||
return success(data=timeline_memories_result, msg="共同记忆时间线")
|
return success(data=timeline_memories_result, msg="共同记忆时间线")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
|
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
|
||||||
async def memory_space_relationship_evolution(id: str, label: str,
|
async def memory_space_relationship_evolution(id: str, label: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
|||||||
610
api/app/controllers/workflow_controller.py
Normal file
610
api/app/controllers/workflow_controller.py
Normal file
@@ -0,0 +1,610 @@
|
|||||||
|
"""
|
||||||
|
工作流 API 控制器
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Path, Query
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||||
|
|
||||||
|
from app.models.user_model import User
|
||||||
|
from app.models.app_model import App
|
||||||
|
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||||
|
from app.schemas.workflow_schema import (
|
||||||
|
WorkflowConfigCreate,
|
||||||
|
WorkflowConfigUpdate,
|
||||||
|
WorkflowConfig,
|
||||||
|
WorkflowValidationResponse,
|
||||||
|
WorkflowExecution,
|
||||||
|
WorkflowNodeExecution,
|
||||||
|
WorkflowExecutionRequest,
|
||||||
|
WorkflowExecutionResponse
|
||||||
|
)
|
||||||
|
from app.core.response_utils import success, fail
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/apps", tags=["workflow"])
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 工作流配置管理 ====================
|
||||||
|
|
||||||
|
@router.post("/{app_id}/workflow")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
async def create_workflow_config(
|
||||||
|
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
|
config: WorkflowConfigCreate,
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||||
|
):
|
||||||
|
"""创建工作流配置
|
||||||
|
|
||||||
|
创建或更新应用的工作流配置。配置会进行基础验证,但允许保存不完整的配置(草稿)。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 验证应用是否存在且属于当前工作空间
|
||||||
|
app = db.query(App).filter(
|
||||||
|
App.id == app_id,
|
||||||
|
App.workspace_id == current_user.current_workspace_id,
|
||||||
|
App.is_active.is_(True)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not app:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
msg="应用不存在或无权访问"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证应用类型
|
||||||
|
if app.type != "workflow":
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INVALID_PARAMETER,
|
||||||
|
msg=f"应用类型必须为 workflow,当前为 {app.type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建工作流配置
|
||||||
|
workflow_config = service.create_workflow_config(
|
||||||
|
app_id=app_id,
|
||||||
|
nodes=[node.model_dump() for node in config.nodes],
|
||||||
|
edges=[edge.model_dump() for edge in config.edges],
|
||||||
|
variables=[var.model_dump() for var in config.variables],
|
||||||
|
execution_config=config.execution_config.model_dump(),
|
||||||
|
triggers=[trigger.model_dump() for trigger in config.triggers],
|
||||||
|
validate=True # 进行基础验证
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data=WorkflowConfig.model_validate(workflow_config),
|
||||||
|
msg="工作流配置创建成功"
|
||||||
|
)
|
||||||
|
|
||||||
|
except BusinessException as e:
|
||||||
|
logger.warning(f"创建工作流配置失败: {e.message}")
|
||||||
|
return fail(code=e.error_code, msg=e.message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建工作流配置异常: {e}", exc_info=True)
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INTERNAL_ERROR,
|
||||||
|
msg=f"创建工作流配置失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# @router.get("/{app_id}/workflow")
|
||||||
|
# async def get_workflow_config(
|
||||||
|
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
|
# db: Annotated[Session, Depends(get_db)],
|
||||||
|
# current_user: Annotated[User, Depends(get_current_user)]
|
||||||
|
#
|
||||||
|
# ):
|
||||||
|
# """获取工作流配置
|
||||||
|
#
|
||||||
|
# 获取应用的工作流配置详情。
|
||||||
|
# """
|
||||||
|
# try:
|
||||||
|
# # 验证应用是否存在且属于当前工作空间
|
||||||
|
# app = db.query(App).filter(
|
||||||
|
# App.id == app_id,
|
||||||
|
# App.workspace_id == current_user.current_workspace_id,
|
||||||
|
# App.is_active == True
|
||||||
|
# ).first()
|
||||||
|
#
|
||||||
|
# if not app:
|
||||||
|
# return fail(
|
||||||
|
# code=BizCode.NOT_FOUND,
|
||||||
|
# msg="应用不存在或无权访问"
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# # 获取工作流配置
|
||||||
|
# service = WorkflowService(db)
|
||||||
|
# workflow_config = service.get_workflow_config(app_id)
|
||||||
|
#
|
||||||
|
# if not workflow_config:
|
||||||
|
# return fail(
|
||||||
|
# code=BizCode.NOT_FOUND,
|
||||||
|
# msg="工作流配置不存在"
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# return success(
|
||||||
|
# data=WorkflowConfig.model_validate(workflow_config)
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"获取工作流配置异常: {e}", exc_info=True)
|
||||||
|
# return fail(
|
||||||
|
# code=BizCode.INTERNAL_ERROR,
|
||||||
|
# msg=f"获取工作流配置失败: {str(e)}"
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
# @router.put("/{app_id}/workflow")
|
||||||
|
# async def update_workflow_config(
|
||||||
|
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
|
# config: WorkflowConfigUpdate,
|
||||||
|
# db: Annotated[Session, Depends(get_db)],
|
||||||
|
# current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
# service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||||
|
# ):
|
||||||
|
# """更新工作流配置
|
||||||
|
|
||||||
|
# 更新应用的工作流配置。可以部分更新,未提供的字段保持不变。
|
||||||
|
# """
|
||||||
|
# try:
|
||||||
|
# # 验证应用是否存在且属于当前工作空间
|
||||||
|
# app = db.query(App).filter(
|
||||||
|
# App.id == app_id,
|
||||||
|
# App.workspace_id == current_user.current_workspace_id,
|
||||||
|
# App.is_active == True
|
||||||
|
# ).first()
|
||||||
|
|
||||||
|
# if not app:
|
||||||
|
# return fail(
|
||||||
|
# code=BizCode.NOT_FOUND,
|
||||||
|
# msg="应用不存在或无权访问"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # 更新工作流配置
|
||||||
|
# workflow_config = service.update_workflow_config(
|
||||||
|
# app_id=app_id,
|
||||||
|
# nodes=[node.model_dump() for node in config.nodes] if config.nodes else None,
|
||||||
|
# edges=[edge.model_dump() for edge in config.edges] if config.edges else None,
|
||||||
|
# variables=[var.model_dump() for var in config.variables] if config.variables else None,
|
||||||
|
# execution_config=config.execution_config.model_dump() if config.execution_config else None,
|
||||||
|
# triggers=[trigger.model_dump() for trigger in config.triggers] if config.triggers else None,
|
||||||
|
# validate=True
|
||||||
|
# )
|
||||||
|
|
||||||
|
# return success(
|
||||||
|
# data=WorkflowConfig.model_validate(workflow_config),
|
||||||
|
# msg="工作流配置更新成功"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# except BusinessException as e:
|
||||||
|
# logger.warning(f"更新工作流配置失败: {e.message}")
|
||||||
|
# return fail(code=e.error_code, msg=e.message)
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"更新工作流配置异常: {e}", exc_info=True)
|
||||||
|
# return fail(
|
||||||
|
# code=BizCode.INTERNAL_ERROR,
|
||||||
|
# msg=f"更新工作流配置失败: {str(e)}"
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{app_id}/workflow")
|
||||||
|
async def delete_workflow_config(
|
||||||
|
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||||
|
):
|
||||||
|
"""删除工作流配置
|
||||||
|
|
||||||
|
删除应用的工作流配置。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 验证应用是否存在且属于当前工作空间
|
||||||
|
app = db.query(App).filter(
|
||||||
|
App.id == app_id,
|
||||||
|
App.workspace_id == current_user.current_workspace_id,
|
||||||
|
App.is_active.is_(True)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not app:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
msg="应用不存在或无权访问"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 删除工作流配置
|
||||||
|
deleted = service.delete_workflow_config(app_id)
|
||||||
|
|
||||||
|
if not deleted:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
msg="工作流配置不存在"
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(msg="工作流配置删除成功")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除工作流配置异常: {e}", exc_info=True)
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INTERNAL_ERROR,
|
||||||
|
msg=f"删除工作流配置失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{app_id}/workflow/validate")
|
||||||
|
async def validate_workflow_config(
|
||||||
|
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||||
|
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
|
||||||
|
):
|
||||||
|
"""验证工作流配置
|
||||||
|
|
||||||
|
验证工作流配置是否有效。可以选择是否进行发布级别的严格验证。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 验证应用是否存在且属于当前工作空间
|
||||||
|
app = db.query(App).filter(
|
||||||
|
App.id == app_id,
|
||||||
|
App.workspace_id == current_user.current_workspace_id,
|
||||||
|
App.is_active.is_(True)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not app:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
msg="应用不存在或无权访问"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证工作流配置
|
||||||
|
|
||||||
|
if for_publish:
|
||||||
|
is_valid, errors = service.validate_workflow_config_for_publish(app_id)
|
||||||
|
else:
|
||||||
|
workflow_config = service.get_workflow_config(app_id)
|
||||||
|
if not workflow_config:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
msg="工作流配置不存在"
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.core.workflow.validator import validate_workflow_config as validate_config
|
||||||
|
config_dict = {
|
||||||
|
"nodes": workflow_config.nodes,
|
||||||
|
"edges": workflow_config.edges,
|
||||||
|
"variables": workflow_config.variables,
|
||||||
|
"execution_config": workflow_config.execution_config,
|
||||||
|
"triggers": workflow_config.triggers
|
||||||
|
}
|
||||||
|
is_valid, errors = validate_config(config_dict, for_publish=False)
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data=WorkflowValidationResponse(
|
||||||
|
is_valid=is_valid,
|
||||||
|
errors=errors,
|
||||||
|
warnings=[]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
except BusinessException as e:
|
||||||
|
logger.warning(f"验证工作流配置失败: {e.message}")
|
||||||
|
return fail(code=e.error_code, msg=e.message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"验证工作流配置异常: {e}", exc_info=True)
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INTERNAL_ERROR,
|
||||||
|
msg=f"验证工作流配置失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 工作流执行管理 ====================
|
||||||
|
|
||||||
|
@router.get("/{app_id}/workflow/executions")
|
||||||
|
async def get_workflow_executions(
|
||||||
|
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||||
|
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||||
|
offset: Annotated[int, Query(ge=0)] = 0
|
||||||
|
):
|
||||||
|
"""获取工作流执行记录列表
|
||||||
|
|
||||||
|
获取应用的工作流执行历史记录。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 验证应用是否存在且属于当前工作空间
|
||||||
|
app = db.query(App).filter(
|
||||||
|
App.id == app_id,
|
||||||
|
App.workspace_id == current_user.current_workspace_id,
|
||||||
|
App.is_active.is_(True)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not app:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
msg="应用不存在或无权访问"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取执行记录
|
||||||
|
executions = service.get_executions_by_app(app_id, limit, offset)
|
||||||
|
|
||||||
|
# 获取统计信息
|
||||||
|
statistics = service.get_execution_statistics(app_id)
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data={
|
||||||
|
"executions": [WorkflowExecution.model_validate(e) for e in executions],
|
||||||
|
"statistics": statistics,
|
||||||
|
"pagination": {
|
||||||
|
"limit": limit,
|
||||||
|
"offset": offset,
|
||||||
|
"total": statistics["total"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取工作流执行记录异常: {e}", exc_info=True)
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INTERNAL_ERROR,
|
||||||
|
msg=f"获取工作流执行记录失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/workflow/executions/{execution_id}")
|
||||||
|
async def get_workflow_execution(
|
||||||
|
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||||
|
):
|
||||||
|
"""获取工作流执行详情
|
||||||
|
|
||||||
|
获取单个工作流执行的详细信息,包括所有节点的执行记录。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 获取执行记录
|
||||||
|
execution = service.get_execution(execution_id)
|
||||||
|
|
||||||
|
if not execution:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
msg="执行记录不存在"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证应用是否属于当前工作空间
|
||||||
|
app = db.query(App).filter(
|
||||||
|
App.id == execution.app_id,
|
||||||
|
App.workspace_id == current_user.current_workspace_id,
|
||||||
|
App.is_active.is_(True)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not app:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
msg="无权访问该执行记录"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取节点执行记录
|
||||||
|
node_executions = service.node_execution_repo.get_by_execution_id(execution.id)
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data={
|
||||||
|
"execution": WorkflowExecution.model_validate(execution),
|
||||||
|
"node_executions": [
|
||||||
|
WorkflowNodeExecution.model_validate(ne) for ne in node_executions
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取工作流执行详情异常: {e}", exc_info=True)
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INTERNAL_ERROR,
|
||||||
|
msg=f"获取工作流执行详情失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 工作流执行 ====================
|
||||||
|
@router.post("/{app_id}/workflow/run")
|
||||||
|
async def run_workflow(
|
||||||
|
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
|
request: WorkflowExecutionRequest,
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||||
|
):
|
||||||
|
"""执行工作流
|
||||||
|
|
||||||
|
执行工作流并返回结果。支持流式和非流式两种模式。
|
||||||
|
|
||||||
|
**非流式模式**:等待工作流执行完成后返回完整结果。
|
||||||
|
|
||||||
|
**流式模式**:实时返回执行过程中的事件(节点开始、节点完成、工作流完成等)。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 验证应用是否存在且属于当前工作空间
|
||||||
|
app = db.query(App).filter(
|
||||||
|
App.id == app_id,
|
||||||
|
App.workspace_id == current_user.current_workspace_id,
|
||||||
|
App.is_active.is_(True)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not app:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
msg="应用不存在或无权访问"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证应用类型
|
||||||
|
if app.type != "workflow":
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INVALID_PARAMETER,
|
||||||
|
msg=f"应用类型必须为 workflow,当前为 {app.type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 准备输入数据
|
||||||
|
input_data = {
|
||||||
|
"message": request.message or "",
|
||||||
|
"variables": request.variables
|
||||||
|
}
|
||||||
|
|
||||||
|
# 执行工作流
|
||||||
|
|
||||||
|
if request.stream:
|
||||||
|
# 流式执行
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
import json
|
||||||
|
|
||||||
|
async def event_generator():
|
||||||
|
"""生成 SSE 事件
|
||||||
|
|
||||||
|
SSE 格式:
|
||||||
|
event: <event_type>
|
||||||
|
data: <json_data>
|
||||||
|
|
||||||
|
支持的事件类型:
|
||||||
|
- workflow_start: 工作流开始
|
||||||
|
- workflow_end: 工作流结束
|
||||||
|
- node_start: 节点开始执行
|
||||||
|
- node_end: 节点执行完成
|
||||||
|
- node_chunk: 中间节点的流式输出
|
||||||
|
- message: 最终消息的流式输出(End 节点及其相邻节点)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async for event in await service.run_workflow(
|
||||||
|
app_id=app_id,
|
||||||
|
input_data=input_data,
|
||||||
|
triggered_by=current_user.id,
|
||||||
|
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||||
|
stream=True
|
||||||
|
):
|
||||||
|
# 提取事件类型和数据
|
||||||
|
event_type = event.get("event", "message")
|
||||||
|
event_data = event.get("data", {})
|
||||||
|
|
||||||
|
# 转换为标准 SSE 格式(字符串)
|
||||||
|
# event: <type>
|
||||||
|
# data: <json>
|
||||||
|
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||||
|
yield sse_message
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"流式执行异常: {e}", exc_info=True)
|
||||||
|
# 发送错误事件
|
||||||
|
sse_error = f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
|
||||||
|
yield sse_error
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
event_generator(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no" # 禁用 nginx 缓冲
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 非流式执行
|
||||||
|
result = await service.run_workflow(
|
||||||
|
app_id=app_id,
|
||||||
|
input_data=input_data,
|
||||||
|
triggered_by=current_user.id,
|
||||||
|
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||||
|
stream=False
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data=WorkflowExecutionResponse(
|
||||||
|
execution_id=result["execution_id"],
|
||||||
|
status=result["status"],
|
||||||
|
output=result.get("output"),
|
||||||
|
output_data=result.get("output_data"),
|
||||||
|
error_message=result.get("error_message"),
|
||||||
|
elapsed_time=result.get("elapsed_time"),
|
||||||
|
token_usage=result.get("token_usage")
|
||||||
|
),
|
||||||
|
msg="工作流执行完成"
|
||||||
|
)
|
||||||
|
|
||||||
|
except BusinessException as e:
|
||||||
|
logger.warning(f"执行工作流失败: {e.message}")
|
||||||
|
return fail(code=e.error_code, msg=e.message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"执行工作流异常: {e}", exc_info=True)
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INTERNAL_ERROR,
|
||||||
|
msg=f"执行工作流失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/workflow/executions/{execution_id}/cancel")
|
||||||
|
async def cancel_workflow_execution(
|
||||||
|
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||||
|
):
|
||||||
|
"""取消工作流执行
|
||||||
|
|
||||||
|
取消正在运行的工作流执行。
|
||||||
|
|
||||||
|
**注意**:当前版本仅更新状态为 cancelled,实际的执行取消功能待实现。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 获取执行记录
|
||||||
|
execution = service.get_execution(execution_id)
|
||||||
|
|
||||||
|
if not execution:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
msg="执行记录不存在"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证应用是否属于当前工作空间
|
||||||
|
app = db.query(App).filter(
|
||||||
|
App.id == execution.app_id,
|
||||||
|
App.workspace_id == current_user.current_workspace_id,
|
||||||
|
App.is_active.is_(True)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not app:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
msg="无权访问该执行记录"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查执行状态
|
||||||
|
if execution.status not in ["pending", "running"]:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INVALID_PARAMETER,
|
||||||
|
msg=f"无法取消状态为 {execution.status} 的执行"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新状态为 cancelled
|
||||||
|
service.update_execution_status(execution_id, "cancelled")
|
||||||
|
|
||||||
|
return success(msg="工作流执行已取消")
|
||||||
|
|
||||||
|
except BusinessException as e:
|
||||||
|
logger.warning(f"取消工作流执行失败: {e.message}")
|
||||||
|
return fail(code=e.code, msg=e.message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"取消工作流执行异常: {e}", exc_info=True)
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INTERNAL_ERROR,
|
||||||
|
msg=f"取消工作流执行失败: {str(e)}"
|
||||||
|
)
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, status
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
@@ -14,12 +14,6 @@ from app.dependencies import (
|
|||||||
get_current_user,
|
get_current_user,
|
||||||
workspace_access_guard,
|
workspace_access_guard,
|
||||||
)
|
)
|
||||||
from app.i18n.dependencies import get_current_language, get_translator
|
|
||||||
from app.i18n.serializers import (
|
|
||||||
WorkspaceSerializer,
|
|
||||||
WorkspaceMemberSerializer,
|
|
||||||
WorkspaceInviteSerializer
|
|
||||||
)
|
|
||||||
from app.models.tenant_model import Tenants
|
from app.models.tenant_model import Tenants
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.models.workspace_model import InviteStatus
|
from app.models.workspace_model import InviteStatus
|
||||||
@@ -35,7 +29,6 @@ from app.schemas.workspace_schema import (
|
|||||||
WorkspaceUpdate,
|
WorkspaceUpdate,
|
||||||
)
|
)
|
||||||
from app.services import workspace_service
|
from app.services import workspace_service
|
||||||
from app.core.quota_stub import check_workspace_quota
|
|
||||||
|
|
||||||
# 获取API专用日志器
|
# 获取API专用日志器
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -72,9 +65,7 @@ def get_workspaces(
|
|||||||
include_current: bool = Query(True, description="是否包含当前工作空间"),
|
include_current: bool = Query(True, description="是否包含当前工作空间"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
current_tenant: Tenants = Depends(get_current_tenant),
|
current_tenant: Tenants = Depends(get_current_tenant)
|
||||||
language: str = Depends(get_current_language),
|
|
||||||
t: callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""获取当前租户下用户参与的所有工作空间
|
"""获取当前租户下用户参与的所有工作空间
|
||||||
|
|
||||||
@@ -97,51 +88,25 @@ def get_workspaces(
|
|||||||
)
|
)
|
||||||
|
|
||||||
api_logger.info(f"成功获取 {len(workspaces)} 个工作空间")
|
api_logger.info(f"成功获取 {len(workspaces)} 个工作空间")
|
||||||
|
workspaces_schema = [WorkspaceResponse.model_validate(w) for w in workspaces]
|
||||||
# 使用序列化器添加国际化字段
|
return success(data=workspaces_schema, msg="工作空间列表获取成功")
|
||||||
serializer = WorkspaceSerializer()
|
|
||||||
workspaces_data = [WorkspaceResponse.model_validate(w).model_dump() for w in workspaces]
|
|
||||||
workspaces_i18n = serializer.serialize_list(workspaces_data, language)
|
|
||||||
|
|
||||||
return success(data=workspaces_i18n, msg=t("workspace.list_retrieved"))
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=ApiResponse)
|
@router.post("", response_model=ApiResponse)
|
||||||
@check_workspace_quota
|
|
||||||
def create_workspace(
|
def create_workspace(
|
||||||
workspace: WorkspaceCreate,
|
workspace: WorkspaceCreate,
|
||||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_superuser),
|
current_user: User = Depends(get_current_superuser),
|
||||||
language: str = Depends(get_current_language),
|
|
||||||
t: callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""创建新的工作空间"""
|
"""创建新的工作空间"""
|
||||||
from app.core.language_utils import get_language_from_header
|
api_logger.info(f"用户 {current_user.username} 请求创建工作空间: {workspace.name}")
|
||||||
|
|
||||||
# 验证并获取语言参数
|
|
||||||
language = get_language_from_header(language_type)
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"用户 {current_user.username} 请求创建工作空间: {workspace.name}, "
|
|
||||||
f"language={language}"
|
|
||||||
)
|
|
||||||
|
|
||||||
result = workspace_service.create_workspace(
|
result = workspace_service.create_workspace(
|
||||||
db=db, workspace=workspace, user=current_user, language=language
|
db=db, workspace=workspace, user=current_user)
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, 创建者: {current_user.username}")
|
||||||
f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, "
|
result_schema = WorkspaceResponse.model_validate(result)
|
||||||
f"创建者: {current_user.username}, language={language}"
|
return success(data=result_schema, msg="工作空间创建成功")
|
||||||
)
|
|
||||||
|
|
||||||
# 使用序列化器添加国际化字段
|
|
||||||
serializer = WorkspaceSerializer()
|
|
||||||
result_data = WorkspaceResponse.model_validate(result).model_dump()
|
|
||||||
result_i18n = serializer.serialize(result_data, language)
|
|
||||||
|
|
||||||
return success(data=result_i18n, msg=t("workspace.created"))
|
|
||||||
|
|
||||||
@router.put("", response_model=ApiResponse)
|
@router.put("", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
@@ -149,8 +114,6 @@ def update_workspace(
|
|||||||
workspace: WorkspaceUpdate,
|
workspace: WorkspaceUpdate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
language: str = Depends(get_current_language),
|
|
||||||
t: callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""更新工作空间"""
|
"""更新工作空间"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
@@ -163,21 +126,14 @@ def update_workspace(
|
|||||||
user=current_user,
|
user=current_user,
|
||||||
)
|
)
|
||||||
api_logger.info(f"工作空间更新成功 - ID: {workspace_id}, 用户: {current_user.username}")
|
api_logger.info(f"工作空间更新成功 - ID: {workspace_id}, 用户: {current_user.username}")
|
||||||
|
result_schema = WorkspaceResponse.model_validate(result)
|
||||||
# 使用序列化器添加国际化字段
|
return success(data=result_schema, msg="工作空间更新成功")
|
||||||
serializer = WorkspaceSerializer()
|
|
||||||
result_data = WorkspaceResponse.model_validate(result).model_dump()
|
|
||||||
result_i18n = serializer.serialize(result_data, language)
|
|
||||||
|
|
||||||
return success(data=result_i18n, msg=t("workspace.updated"))
|
|
||||||
|
|
||||||
@router.get("/members", response_model=ApiResponse)
|
@router.get("/members", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def get_cur_workspace_members(
|
def get_cur_workspace_members(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
language: str = Depends(get_current_language),
|
|
||||||
t: callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""获取工作空间成员列表(关系序列化)"""
|
"""获取工作空间成员列表(关系序列化)"""
|
||||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {current_user.current_workspace_id} 的成员列表")
|
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {current_user.current_workspace_id} 的成员列表")
|
||||||
@@ -188,14 +144,8 @@ def get_cur_workspace_members(
|
|||||||
user=current_user,
|
user=current_user,
|
||||||
)
|
)
|
||||||
api_logger.info(f"工作空间成员列表获取成功 - ID: {current_user.current_workspace_id}, 数量: {len(members)}")
|
api_logger.info(f"工作空间成员列表获取成功 - ID: {current_user.current_workspace_id}, 数量: {len(members)}")
|
||||||
|
|
||||||
# 转换为表格项并使用序列化器添加国际化字段
|
|
||||||
table_items = _convert_members_to_table_items(members)
|
table_items = _convert_members_to_table_items(members)
|
||||||
serializer = WorkspaceMemberSerializer()
|
return success(data=table_items, msg="工作空间成员列表获取成功")
|
||||||
members_data = [item.model_dump() for item in table_items]
|
|
||||||
members_i18n = serializer.serialize_list(members_data, language)
|
|
||||||
|
|
||||||
return success(data=members_i18n, msg=t("workspace.members.list_retrieved"))
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/members", response_model=ApiResponse)
|
@router.put("/members", response_model=ApiResponse)
|
||||||
@@ -205,7 +155,6 @@ def update_workspace_members(
|
|||||||
updates: List[WorkspaceMemberUpdate],
|
updates: List[WorkspaceMemberUpdate],
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
t: callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的成员角色")
|
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的成员角色")
|
||||||
@@ -216,28 +165,27 @@ def update_workspace_members(
|
|||||||
user=current_user,
|
user=current_user,
|
||||||
)
|
)
|
||||||
api_logger.info(f"工作空间成员角色更新成功 - ID: {workspace_id}, 数量: {len(members)}")
|
api_logger.info(f"工作空间成员角色更新成功 - ID: {workspace_id}, 数量: {len(members)}")
|
||||||
return success(msg=t("workspace.members.role_updated"))
|
return success(msg="成员角色更新成功")
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/members/{member_id}", response_model=ApiResponse)
|
@router.delete("/members/{member_id}", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def delete_workspace_member(
|
def delete_workspace_member(
|
||||||
member_id: uuid.UUID,
|
member_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
t: callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||||
|
|
||||||
await workspace_service.delete_workspace_member(
|
workspace_service.delete_workspace_member(
|
||||||
db=db,
|
db=db,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
member_id=member_id,
|
member_id=member_id,
|
||||||
user=current_user,
|
user=current_user,
|
||||||
)
|
)
|
||||||
api_logger.info(f"工作空间成员删除成功 - ID: {workspace_id}, 成员: {member_id}")
|
api_logger.info(f"工作空间成员删除成功 - ID: {workspace_id}, 成员: {member_id}")
|
||||||
return success(msg=t("workspace.members.deleted"))
|
return success(msg="成员删除成功")
|
||||||
|
|
||||||
|
|
||||||
# 创建空间协作邀请
|
# 创建空间协作邀请
|
||||||
@@ -247,8 +195,6 @@ def create_workspace_invite(
|
|||||||
invite_data: WorkspaceInviteCreate,
|
invite_data: WorkspaceInviteCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
language: str = Depends(get_current_language),
|
|
||||||
t: callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""创建工作空间邀请"""
|
"""创建工作空间邀请"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
@@ -261,12 +207,7 @@ def create_workspace_invite(
|
|||||||
user=current_user
|
user=current_user
|
||||||
)
|
)
|
||||||
api_logger.info(f"工作空间邀请创建成功 - 工作空间: {workspace_id}, 邮箱: {invite_data.email}")
|
api_logger.info(f"工作空间邀请创建成功 - 工作空间: {workspace_id}, 邮箱: {invite_data.email}")
|
||||||
|
return success(data=result, msg="邀请创建成功")
|
||||||
# 使用序列化器添加国际化字段
|
|
||||||
serializer = WorkspaceInviteSerializer()
|
|
||||||
result_i18n = serializer.serialize(result, language)
|
|
||||||
|
|
||||||
return success(data=result_i18n, msg=t("workspace.invites.created"))
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/invites", response_model=ApiResponse)
|
@router.get("/invites", response_model=ApiResponse)
|
||||||
@@ -278,8 +219,6 @@ def get_workspace_invites(
|
|||||||
offset: int = Query(0, ge=0),
|
offset: int = Query(0, ge=0),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
language: str = Depends(get_current_language),
|
|
||||||
t: callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""获取工作空间邀请列表"""
|
"""获取工作空间邀请列表"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
@@ -294,30 +233,18 @@ def get_workspace_invites(
|
|||||||
offset=offset
|
offset=offset
|
||||||
)
|
)
|
||||||
api_logger.info(f"成功获取 {len(invites)} 个邀请记录")
|
api_logger.info(f"成功获取 {len(invites)} 个邀请记录")
|
||||||
|
return success(data=invites, msg="邀请列表获取成功")
|
||||||
# 使用序列化器添加国际化字段
|
|
||||||
serializer = WorkspaceInviteSerializer()
|
|
||||||
invites_i18n = serializer.serialize_list(invites, language)
|
|
||||||
|
|
||||||
return success(data=invites_i18n, msg=t("workspace.invites.list_retrieved"))
|
|
||||||
|
|
||||||
|
|
||||||
@public_router.get("/invites/validate/{token}", response_model=ApiResponse)
|
@public_router.get("/invites/validate/{token}", response_model=ApiResponse)
|
||||||
def get_workspace_invite_info(
|
def get_workspace_invite_info(
|
||||||
token: str,
|
token: str,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
language: str = Depends(get_current_language),
|
|
||||||
t: callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""获取工作空间邀请用户信息(无需认证)"""
|
"""获取工作空间邀请用户信息(无需认证)"""
|
||||||
result = workspace_service.validate_invite_token(db=db, token=token)
|
result = workspace_service.validate_invite_token(db=db, token=token)
|
||||||
api_logger.info(f"工作空间邀请验证成功 - 邀请: {token}")
|
api_logger.info(f"工作空间邀请验证成功 - 邀请: {token}")
|
||||||
|
return success(data=result, msg="邀请验证成功")
|
||||||
# 使用序列化器添加国际化字段
|
|
||||||
serializer = WorkspaceInviteSerializer()
|
|
||||||
result_i18n = serializer.serialize(result, language)
|
|
||||||
|
|
||||||
return success(data=result_i18n, msg=t("workspace.invites.validated"))
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/invites/{invite_id}", response_model=ApiResponse)
|
@router.delete("/invites/{invite_id}", response_model=ApiResponse)
|
||||||
@@ -327,8 +254,6 @@ def revoke_workspace_invite(
|
|||||||
invite_id: uuid.UUID,
|
invite_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
language: str = Depends(get_current_language),
|
|
||||||
t: callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""撤销工作空间邀请"""
|
"""撤销工作空间邀请"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
@@ -341,12 +266,7 @@ def revoke_workspace_invite(
|
|||||||
user=current_user
|
user=current_user
|
||||||
)
|
)
|
||||||
api_logger.info(f"工作空间邀请撤销成功 - 邀请: {invite_id}")
|
api_logger.info(f"工作空间邀请撤销成功 - 邀请: {invite_id}")
|
||||||
|
return success(data=result, msg="邀请撤销成功")
|
||||||
# 使用序列化器添加国际化字段
|
|
||||||
serializer = WorkspaceInviteSerializer()
|
|
||||||
result_i18n = serializer.serialize(result, language)
|
|
||||||
|
|
||||||
return success(data=result_i18n, msg=t("workspace.invites.revoked"))
|
|
||||||
|
|
||||||
# ==================== 公开邀请接口(无需认证) ====================
|
# ==================== 公开邀请接口(无需认证) ====================
|
||||||
|
|
||||||
@@ -369,7 +289,6 @@ def switch_workspace(
|
|||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
t: callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""切换工作空间"""
|
"""切换工作空间"""
|
||||||
api_logger.info(f"用户 {current_user.username} 请求切换工作空间为 {workspace_id}")
|
api_logger.info(f"用户 {current_user.username} 请求切换工作空间为 {workspace_id}")
|
||||||
@@ -380,7 +299,7 @@ def switch_workspace(
|
|||||||
user=current_user,
|
user=current_user,
|
||||||
)
|
)
|
||||||
api_logger.info(f"成功切换工作空间为 {workspace_id}")
|
api_logger.info(f"成功切换工作空间为 {workspace_id}")
|
||||||
return success(msg=t("workspace.switched"))
|
return success(msg="工作空间切换成功")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/storage", response_model=ApiResponse)
|
@router.get("/storage", response_model=ApiResponse)
|
||||||
@@ -388,7 +307,6 @@ def switch_workspace(
|
|||||||
def get_workspace_storage_type(
|
def get_workspace_storage_type(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
t: callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""获取当前工作空间的存储类型"""
|
"""获取当前工作空间的存储类型"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
@@ -400,7 +318,7 @@ def get_workspace_storage_type(
|
|||||||
user=current_user
|
user=current_user
|
||||||
)
|
)
|
||||||
api_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {storage_type}")
|
api_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {storage_type}")
|
||||||
return success(data={"storage_type": storage_type}, msg=t("workspace.storage.type_retrieved"))
|
return success(data={"storage_type": storage_type}, msg="存储类型获取成功")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/workspace_models", response_model=ApiResponse)
|
@router.get("/workspace_models", response_model=ApiResponse)
|
||||||
@@ -408,8 +326,6 @@ def get_workspace_storage_type(
|
|||||||
def workspace_models_configs(
|
def workspace_models_configs(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
language: str = Depends(get_current_language),
|
|
||||||
t: callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""获取当前工作空间的模型配置(llm, embedding, rerank)"""
|
"""获取当前工作空间的模型配置(llm, embedding, rerank)"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
@@ -425,14 +341,14 @@ def workspace_models_configs(
|
|||||||
api_logger.warning(f"工作空间 {workspace_id} 不存在或无权访问")
|
api_logger.warning(f"工作空间 {workspace_id} 不存在或无权访问")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail=t("workspace.not_found")
|
detail="工作空间不存在或无权访问"
|
||||||
)
|
)
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"成功获取工作空间 {workspace_id} 的模型配置: "
|
f"成功获取工作空间 {workspace_id} 的模型配置: "
|
||||||
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
|
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
|
||||||
)
|
)
|
||||||
return success(data=WorkspaceModelsConfig.model_validate(configs), msg=t("workspace.models.config_retrieved"))
|
return success(data=WorkspaceModelsConfig.model_validate(configs), msg="模型配置获取成功")
|
||||||
|
|
||||||
|
|
||||||
@router.put("/workspace_models", response_model=ApiResponse)
|
@router.put("/workspace_models", response_model=ApiResponse)
|
||||||
@@ -441,7 +357,6 @@ def update_workspace_models_configs(
|
|||||||
models_update: WorkspaceModelsUpdate,
|
models_update: WorkspaceModelsUpdate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
t: callable = Depends(get_translator)
|
|
||||||
):
|
):
|
||||||
"""更新当前工作空间的模型配置(llm, embedding, rerank)"""
|
"""更新当前工作空间的模型配置(llm, embedding, rerank)"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
@@ -458,5 +373,5 @@ def update_workspace_models_configs(
|
|||||||
f"成功更新工作空间 {workspace_id} 的模型配置: "
|
f"成功更新工作空间 {workspace_id} 的模型配置: "
|
||||||
f"llm={updated_workspace.llm}, embedding={updated_workspace.embedding}, rerank={updated_workspace.rerank}"
|
f"llm={updated_workspace.llm}, embedding={updated_workspace.embedding}, rerank={updated_workspace.rerank}"
|
||||||
)
|
)
|
||||||
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg=t("workspace.models.config_updated"))
|
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg="模型配置更新成功")
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +0,0 @@
|
|||||||
# -*- coding: UTF-8 -*-
|
|
||||||
# Author: Eternity
|
|
||||||
# @Email: 1533512157@qq.com
|
|
||||||
# @Time : 2026/2/9 16:24
|
|
||||||
@@ -1,162 +0,0 @@
|
|||||||
"""Agent Middleware - 动态技能过滤"""
|
|
||||||
import uuid
|
|
||||||
from typing import List, Dict, Any, Optional
|
|
||||||
from langchain_core.runnables import RunnablePassthrough
|
|
||||||
|
|
||||||
from app.services.skill_service import SkillService
|
|
||||||
from app.repositories.skill_repository import SkillRepository
|
|
||||||
|
|
||||||
|
|
||||||
class AgentMiddleware:
|
|
||||||
"""Agent 中间件 - 用于动态过滤和加载技能"""
|
|
||||||
|
|
||||||
def __init__(self, skills: Optional[dict] = None):
|
|
||||||
"""
|
|
||||||
初始化中间件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
skills: 技能配置字典 {"enabled": bool, "all_skills": bool, "skill_ids": [...]}
|
|
||||||
"""
|
|
||||||
self.skills = skills or {}
|
|
||||||
self.enabled = self.skills.get('enabled', False)
|
|
||||||
self.all_skills = self.skills.get('all_skills', False)
|
|
||||||
self.skill_ids = self.skills.get('skill_ids', [])
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def filter_tools(
|
|
||||||
tools: List,
|
|
||||||
message: str = "",
|
|
||||||
skill_configs: Dict[str, Any] = None,
|
|
||||||
tool_to_skill_map: Dict[str, str] = None
|
|
||||||
) -> tuple[List, List[str]]:
|
|
||||||
"""
|
|
||||||
根据消息内容和技能配置动态过滤工具
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tools: 所有可用工具列表
|
|
||||||
message: 用户消息(可用于智能过滤)
|
|
||||||
skill_configs: 技能配置字典 {skill_id: {"keywords": [...], "enabled": True, "prompt": "..."}}
|
|
||||||
tool_to_skill_map: 工具到技能的映射 {tool_name: skill_id}
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(过滤后的工具列表, 激活的技能ID列表)
|
|
||||||
"""
|
|
||||||
if not tools:
|
|
||||||
return [], []
|
|
||||||
|
|
||||||
# 如果没有技能配置,返回所有工具
|
|
||||||
if not skill_configs:
|
|
||||||
return tools, []
|
|
||||||
|
|
||||||
# 基于关键词匹配激活技能
|
|
||||||
activated_skill_ids = []
|
|
||||||
message_lower = message.lower()
|
|
||||||
|
|
||||||
for skill_id, config in skill_configs.items():
|
|
||||||
if not config.get('enabled', True):
|
|
||||||
continue
|
|
||||||
|
|
||||||
keywords = config.get('keywords', [])
|
|
||||||
# 如果没有关键词限制,或消息包含关键词,则激活该技能
|
|
||||||
if not keywords or any(kw.lower() in message_lower for kw in keywords):
|
|
||||||
activated_skill_ids.append(skill_id)
|
|
||||||
|
|
||||||
# 如果没有工具映射关系,返回所有工具
|
|
||||||
if not tool_to_skill_map:
|
|
||||||
return tools, activated_skill_ids
|
|
||||||
|
|
||||||
# 根据激活的技能过滤工具
|
|
||||||
filtered_tools = []
|
|
||||||
for tool in tools:
|
|
||||||
tool_name = getattr(tool, 'name', str(id(tool)))
|
|
||||||
# 如果工具不属于任何skill(base_tools),或者工具所属的skill被激活,则保留
|
|
||||||
if tool_name not in tool_to_skill_map or tool_to_skill_map[tool_name] in activated_skill_ids:
|
|
||||||
filtered_tools.append(tool)
|
|
||||||
|
|
||||||
return filtered_tools, activated_skill_ids
|
|
||||||
|
|
||||||
def load_skill_tools(self, db, tenant_id: uuid.UUID, base_tools: List = None) -> tuple[List, Dict[str, Any], Dict[str, str]]:
|
|
||||||
"""
|
|
||||||
加载技能关联的工具
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: 数据库会话
|
|
||||||
tenant_id: 租户id
|
|
||||||
base_tools: 基础工具列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(工具列表, 技能配置字典, 工具到技能的映射 {tool_name: skill_id})
|
|
||||||
"""
|
|
||||||
|
|
||||||
tools_dict = {}
|
|
||||||
tool_to_skill_map = {} # 工具名称到技能ID的映射
|
|
||||||
|
|
||||||
if base_tools:
|
|
||||||
for tool in base_tools:
|
|
||||||
tool_name = getattr(tool, 'name', str(id(tool)))
|
|
||||||
tools_dict[tool_name] = tool
|
|
||||||
# base_tools 不属于任何 skill,不加入映射
|
|
||||||
|
|
||||||
skill_configs = {}
|
|
||||||
skill_ids_to_load = []
|
|
||||||
|
|
||||||
# 如果启用技能且 all_skills 为 True,加载租户下所有激活的技能
|
|
||||||
if self.enabled and self.all_skills:
|
|
||||||
skills, _ = SkillRepository.list_skills(db, tenant_id, is_active=True, page=1, pagesize=1000)
|
|
||||||
skill_ids_to_load = [str(skill.id) for skill in skills]
|
|
||||||
elif self.enabled and self.skill_ids:
|
|
||||||
skill_ids_to_load = self.skill_ids
|
|
||||||
|
|
||||||
if skill_ids_to_load:
|
|
||||||
for skill_id in skill_ids_to_load:
|
|
||||||
try:
|
|
||||||
skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id)
|
|
||||||
if skill and skill.is_active:
|
|
||||||
# 保存技能配置(包含prompt)
|
|
||||||
config = skill.config or {}
|
|
||||||
config['prompt'] = skill.prompt
|
|
||||||
config['name'] = skill.name
|
|
||||||
skill_configs[skill_id] = config
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 加载技能工具并获取映射关系
|
|
||||||
skill_tools, skill_tool_map = SkillService.load_skill_tools(db, skill_ids_to_load, tenant_id)
|
|
||||||
|
|
||||||
# 只添加不冲突的 skill_tools
|
|
||||||
for tool in skill_tools:
|
|
||||||
tool_name = getattr(tool, 'name', str(id(tool)))
|
|
||||||
if tool_name not in tools_dict:
|
|
||||||
tools_dict[tool_name] = tool
|
|
||||||
# 复制映射关系
|
|
||||||
if tool_name in skill_tool_map:
|
|
||||||
tool_to_skill_map[tool_name] = skill_tool_map[tool_name]
|
|
||||||
|
|
||||||
return list(tools_dict.values()), skill_configs, tool_to_skill_map
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_active_prompts(activated_skill_ids: List[str], skill_configs: Dict[str, Any]) -> str:
|
|
||||||
"""
|
|
||||||
根据激活的技能ID获取对应的提示词
|
|
||||||
|
|
||||||
Args:
|
|
||||||
activated_skill_ids: 被激活的技能ID列表
|
|
||||||
skill_configs: 技能配置字典
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
合并后的提示词
|
|
||||||
"""
|
|
||||||
prompts = []
|
|
||||||
for skill_id in activated_skill_ids:
|
|
||||||
config = skill_configs.get(skill_id, {})
|
|
||||||
prompt = config.get('prompt')
|
|
||||||
name = config.get('name', 'Skill')
|
|
||||||
if prompt:
|
|
||||||
prompts.append(f"# {name}\n{prompt}")
|
|
||||||
|
|
||||||
return "\n\n".join(prompts) if prompts else ""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_runnable():
|
|
||||||
"""创建可运行的中间件"""
|
|
||||||
return RunnablePassthrough()
|
|
||||||
@@ -11,15 +11,17 @@ LangChain Agent 封装
|
|||||||
import time
|
import time
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
from langchain.agents import create_agent
|
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
from app.db import get_db
|
||||||
from langchain_core.tools import BaseTool
|
|
||||||
from langgraph.errors import GraphRecursionError
|
|
||||||
|
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
from app.models.models_model import ModelType
|
from app.models.models_model import ModelType
|
||||||
|
from app.services.memory_agent_service import (
|
||||||
|
get_end_user_connected_config,
|
||||||
|
)
|
||||||
|
from langchain.agents import create_agent
|
||||||
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
@@ -31,18 +33,11 @@ class LangChainAgent:
|
|||||||
api_key: str,
|
api_key: str,
|
||||||
provider: str = "openai",
|
provider: str = "openai",
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
is_omni: bool = False,
|
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
max_tokens: int = 2000,
|
max_tokens: int = 2000,
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
tools: Optional[Sequence[BaseTool]] = None,
|
tools: Optional[Sequence[BaseTool]] = None,
|
||||||
streaming: bool = False,
|
streaming: bool = False
|
||||||
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
|
|
||||||
max_tool_consecutive_calls: int = 3, # 单个工具最大连续调用次数
|
|
||||||
deep_thinking: bool = False, # 是否启用深度思考模式
|
|
||||||
thinking_budget_tokens: Optional[int] = None, # 深度思考 token 预算
|
|
||||||
json_output: bool = False, # 是否强制 JSON 输出
|
|
||||||
capability: Optional[List[str]] = None # 模型能力列表,用于校验是否支持深度思考
|
|
||||||
):
|
):
|
||||||
"""初始化 LangChain Agent
|
"""初始化 LangChain Agent
|
||||||
|
|
||||||
@@ -55,71 +50,28 @@ class LangChainAgent:
|
|||||||
max_tokens: 最大 token 数
|
max_tokens: 最大 token 数
|
||||||
system_prompt: 系统提示词
|
system_prompt: 系统提示词
|
||||||
tools: 工具列表(可选,框架自动走 ReAct 循环)
|
tools: 工具列表(可选,框架自动走 ReAct 循环)
|
||||||
streaming: 是否启用流式输出
|
streaming: 是否启用流式输出(默认 True)
|
||||||
max_iterations: 最大迭代次数(None 表示自动计算:基础 5 次 + 每个工具 2 次)
|
|
||||||
max_tool_consecutive_calls: 单个工具最大连续调用次数(默认 3 次)
|
|
||||||
"""
|
"""
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
|
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||||
self.tools = tools or []
|
self.tools = tools or []
|
||||||
self.streaming = streaming
|
self.streaming = streaming
|
||||||
self.is_omni = is_omni
|
|
||||||
self.max_tool_consecutive_calls = max_tool_consecutive_calls
|
|
||||||
|
|
||||||
# 工具调用计数器:记录每个工具的连续调用次数
|
# 创建 RedBearLLM(支持多提供商)
|
||||||
self.tool_call_counter: Dict[str, int] = {}
|
|
||||||
self.last_tool_called: Optional[str] = None
|
|
||||||
|
|
||||||
# 根据工具数量动态调整最大迭代次数
|
|
||||||
# 基础值 + 每个工具额外的调用机会
|
|
||||||
if max_iterations is None:
|
|
||||||
# 自动计算:基础 5 次 + 每个工具 2 次额外机会
|
|
||||||
self.max_iterations = 5 + len(self.tools) * 2
|
|
||||||
else:
|
|
||||||
self.max_iterations = max_iterations
|
|
||||||
|
|
||||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
|
||||||
|
|
||||||
# ChatTongyi 要求 messages 含 'json' 字样才能使用 response_format
|
|
||||||
# 在 system prompt 中注入 JSON 要求
|
|
||||||
from app.models.models_model import ModelProvider
|
|
||||||
if json_output and (
|
|
||||||
(provider.lower() == ModelProvider.DASHSCOPE and not is_omni)
|
|
||||||
or provider.lower() == ModelProvider.VOLCANO
|
|
||||||
# 有工具时 response_format 会被移除,所有 provider 都需要 system prompt 注入保证 JSON 输出
|
|
||||||
or bool(tools)
|
|
||||||
):
|
|
||||||
self.system_prompt += "\n请以JSON格式输出。"
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
|
|
||||||
f"tool_count={len(self.tools)}, "
|
|
||||||
f"max_tool_consecutive_calls={self.max_tool_consecutive_calls}, "
|
|
||||||
f"auto_calculated={max_iterations is None}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建 RedBearLLM,capability 校验由 RedBearModelConfig 统一处理
|
|
||||||
model_config = RedBearModelConfig(
|
model_config = RedBearModelConfig(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=api_base,
|
base_url=api_base,
|
||||||
is_omni=is_omni,
|
|
||||||
capability=capability,
|
|
||||||
deep_thinking=deep_thinking,
|
|
||||||
thinking_budget_tokens=thinking_budget_tokens,
|
|
||||||
json_output=json_output,
|
|
||||||
extra_params={
|
extra_params={
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"max_tokens": max_tokens,
|
"max_tokens": max_tokens,
|
||||||
"streaming": streaming
|
"streaming": streaming # 使用参数控制流式
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.llm = RedBearLLM(model_config, type=ModelType.CHAT)
|
self.llm = RedBearLLM(model_config, type=ModelType.CHAT)
|
||||||
# 从经过校验的 config 读取实际生效的能力开关
|
|
||||||
self.deep_thinking = model_config.deep_thinking
|
|
||||||
self.json_output = model_config.json_output
|
|
||||||
|
|
||||||
# 获取底层模型用于真正的流式调用
|
# 获取底层模型用于真正的流式调用
|
||||||
self._underlying_llm = self.llm._model if hasattr(self.llm, '_model') else self.llm
|
self._underlying_llm = self.llm._model if hasattr(self.llm, '_model') else self.llm
|
||||||
@@ -128,14 +80,11 @@ class LangChainAgent:
|
|||||||
if streaming and hasattr(self._underlying_llm, 'streaming'):
|
if streaming and hasattr(self._underlying_llm, 'streaming'):
|
||||||
self._underlying_llm.streaming = True
|
self._underlying_llm.streaming = True
|
||||||
|
|
||||||
# 包装工具以跟踪连续调用次数
|
|
||||||
wrapped_tools = self._wrap_tools_with_tracking(self.tools) if self.tools else None
|
|
||||||
|
|
||||||
# 使用 create_agent 创建 agent graph(LangChain 1.x 标准方式)
|
# 使用 create_agent 创建 agent graph(LangChain 1.x 标准方式)
|
||||||
# 无论是否有工具,都使用 agent 统一处理
|
# 无论是否有工具,都使用 agent 统一处理
|
||||||
self.agent = create_agent(
|
self.agent = create_agent(
|
||||||
model=self.llm,
|
model=self.llm,
|
||||||
tools=wrapped_tools,
|
tools=self.tools if self.tools else None,
|
||||||
system_prompt=self.system_prompt
|
system_prompt=self.system_prompt
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -147,92 +96,17 @@ class LangChainAgent:
|
|||||||
"has_api_base": bool(api_base),
|
"has_api_base": bool(api_base),
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"streaming": streaming,
|
"streaming": streaming,
|
||||||
"max_iterations": self.max_iterations,
|
|
||||||
"max_tool_consecutive_calls": self.max_tool_consecutive_calls,
|
|
||||||
"tool_count": len(self.tools),
|
"tool_count": len(self.tools),
|
||||||
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
|
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
|
||||||
# "tool_count": len(self.tools)
|
# "tool_count": len(self.tools)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def _wrap_tools_with_tracking(self, tools: Sequence[BaseTool]) -> List[BaseTool]:
|
|
||||||
"""包装工具以跟踪连续调用次数
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tools: 原始工具列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[BaseTool]: 包装后的工具列表
|
|
||||||
"""
|
|
||||||
from langchain_core.tools import StructuredTool
|
|
||||||
from functools import wraps
|
|
||||||
|
|
||||||
wrapped_tools = []
|
|
||||||
|
|
||||||
for original_tool in tools:
|
|
||||||
tool_name = original_tool.name
|
|
||||||
original_func = original_tool.func if hasattr(original_tool, 'func') else None
|
|
||||||
|
|
||||||
if not original_func:
|
|
||||||
# 如果无法获取原始函数,直接使用原工具
|
|
||||||
wrapped_tools.append(original_tool)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 创建包装函数
|
|
||||||
def make_wrapped_func(tool_name, original_func):
|
|
||||||
"""创建包装函数的工厂函数,避免闭包问题"""
|
|
||||||
|
|
||||||
@wraps(original_func)
|
|
||||||
def wrapped_func(*args, **kwargs):
|
|
||||||
"""包装后的工具函数,跟踪连续调用次数"""
|
|
||||||
# 检查是否是连续调用同一个工具
|
|
||||||
if self.last_tool_called == tool_name:
|
|
||||||
self.tool_call_counter[tool_name] = self.tool_call_counter.get(tool_name, 0) + 1
|
|
||||||
else:
|
|
||||||
# 切换到新工具,重置计数器
|
|
||||||
self.tool_call_counter[tool_name] = 1
|
|
||||||
self.last_tool_called = tool_name
|
|
||||||
|
|
||||||
current_count = self.tool_call_counter[tool_name]
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"工具调用: {tool_name}, 连续调用次数: {current_count}/{self.max_tool_consecutive_calls}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 检查是否超过最大连续调用次数
|
|
||||||
if current_count > self.max_tool_consecutive_calls:
|
|
||||||
logger.warning(
|
|
||||||
f"工具 '{tool_name}' 连续调用次数已达上限 ({self.max_tool_consecutive_calls}),"
|
|
||||||
f"返回提示信息"
|
|
||||||
)
|
|
||||||
return (
|
|
||||||
f"工具 '{tool_name}' 已连续调用 {self.max_tool_consecutive_calls} 次,"
|
|
||||||
f"未找到有效结果。请尝试其他方法或直接回答用户的问题。"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 调用原始工具函数
|
|
||||||
return original_func(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrapped_func
|
|
||||||
|
|
||||||
# 使用 StructuredTool 创建新工具
|
|
||||||
wrapped_tool = StructuredTool(
|
|
||||||
name=original_tool.name,
|
|
||||||
description=original_tool.description,
|
|
||||||
func=make_wrapped_func(tool_name, original_func),
|
|
||||||
args_schema=original_tool.args_schema if hasattr(original_tool, 'args_schema') else None
|
|
||||||
)
|
|
||||||
|
|
||||||
wrapped_tools.append(wrapped_tool)
|
|
||||||
|
|
||||||
return wrapped_tools
|
|
||||||
|
|
||||||
def _prepare_messages(
|
def _prepare_messages(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
history: Optional[List[Dict[str, str]]] = None,
|
history: Optional[List[Dict[str, str]]] = None,
|
||||||
context: Optional[str] = None,
|
context: Optional[str] = None
|
||||||
files: Optional[List[Dict[str, Any]]] = None
|
|
||||||
) -> List[BaseMessage]:
|
) -> List[BaseMessage]:
|
||||||
"""准备消息列表
|
"""准备消息列表
|
||||||
|
|
||||||
@@ -240,12 +114,14 @@ class LangChainAgent:
|
|||||||
message: 用户消息
|
message: 用户消息
|
||||||
history: 历史消息列表
|
history: 历史消息列表
|
||||||
context: 上下文信息
|
context: 上下文信息
|
||||||
files: 多模态文件内容列表(已处理)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[BaseMessage]: 消息列表
|
List[BaseMessage]: 消息列表
|
||||||
"""
|
"""
|
||||||
messages: list = []
|
messages = []
|
||||||
|
|
||||||
|
# 添加系统提示词
|
||||||
|
messages.append(SystemMessage(content=self.system_prompt))
|
||||||
|
|
||||||
# 添加历史消息
|
# 添加历史消息
|
||||||
if history:
|
if history:
|
||||||
@@ -260,94 +136,19 @@ class LangChainAgent:
|
|||||||
if context:
|
if context:
|
||||||
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
|
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
|
||||||
|
|
||||||
# 构建用户消息(支持多模态)
|
|
||||||
if files and len(files) > 0:
|
|
||||||
content_parts = self._build_multimodal_content(user_content, files)
|
|
||||||
messages.append(HumanMessage(content=content_parts))
|
|
||||||
else:
|
|
||||||
# 纯文本消息
|
|
||||||
messages.append(HumanMessage(content=user_content))
|
messages.append(HumanMessage(content=user_content))
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _extract_tokens_from_message(msg) -> int:
|
|
||||||
"""从 AIMessage 或类似对象中提取 total_tokens,兼容多种 provider 格式
|
|
||||||
|
|
||||||
支持的格式:
|
|
||||||
- response_metadata.token_usage.total_tokens (OpenAI/ChatOpenAI)
|
|
||||||
- response_metadata.usage.total_tokens (部分 provider)
|
|
||||||
- usage_metadata.total_tokens (LangChain 新版)
|
|
||||||
"""
|
|
||||||
total = 0
|
|
||||||
# 1. response_metadata
|
|
||||||
response_meta = getattr(msg, "response_metadata", None)
|
|
||||||
if response_meta and isinstance(response_meta, dict):
|
|
||||||
# 尝试 token_usage 路径
|
|
||||||
token_usage = response_meta.get("token_usage") or response_meta.get("usage", {})
|
|
||||||
if isinstance(token_usage, dict):
|
|
||||||
total = token_usage.get("total_tokens", 0)
|
|
||||||
# 2. usage_metadata(LangChain 新版 AIMessage 属性)
|
|
||||||
if not total:
|
|
||||||
usage_meta = getattr(msg, "usage_metadata", None)
|
|
||||||
if usage_meta:
|
|
||||||
if isinstance(usage_meta, dict):
|
|
||||||
total = usage_meta.get("total_tokens", 0)
|
|
||||||
else:
|
|
||||||
total = getattr(usage_meta, "total_tokens", 0)
|
|
||||||
return total or 0
|
|
||||||
|
|
||||||
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
构建多模态消息内容
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: 文本内容
|
|
||||||
files: 文件列表(已由 MultimodalService 处理为对应 provider 的格式)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict]: 消息内容列表
|
|
||||||
"""
|
|
||||||
# 根据 provider 使用不同的文本格式
|
|
||||||
# if (self.provider.lower() in [ModelProvider.BEDROCK, ModelProvider.OPENAI, ModelProvider.XINFERENCE,
|
|
||||||
# ModelProvider.GPUSTACK] or (
|
|
||||||
# self.provider.lower() == ModelProvider.DASHSCOPE and self.is_omni)):
|
|
||||||
# # Anthropic/Bedrock/Xinference/Gpustack/Openai: {"type": "text", "text": "..."}
|
|
||||||
# content_parts = [{"type": "text", "text": text}]
|
|
||||||
# else:
|
|
||||||
# # 通义千问等: {"text": "..."}
|
|
||||||
# content_parts = [{"type": "text", "text": text}]
|
|
||||||
content_parts = [{"type": "text", "text": text}]
|
|
||||||
|
|
||||||
# 添加文件内容
|
|
||||||
# MultimodalService 已经根据 provider 返回了正确格式,直接使用
|
|
||||||
content_parts.extend(files)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"构建多模态消息: provider={self.provider}, "
|
|
||||||
f"parts={len(content_parts)}, "
|
|
||||||
f"files={len(files)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return content_parts
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _extract_reasoning_content(msg) -> str:
|
|
||||||
"""从 AIMessage 中提取深度思考内容(reasoning_content)
|
|
||||||
|
|
||||||
所有 provider 统一通过 additional_kwargs.reasoning_content 传递:
|
|
||||||
- DeepSeek-R1 / QwQ: 原生字段
|
|
||||||
- Volcano (Doubao-thinking): 由 VolcanoChatOpenAI 从 delta.reasoning_content 注入
|
|
||||||
"""
|
|
||||||
additional = getattr(msg, "additional_kwargs", None) or {}
|
|
||||||
return additional.get("reasoning_content") or additional.get("reasoning", "")
|
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
history: Optional[List[Dict[str, str]]] = None,
|
history: Optional[List[Dict[str, str]]] = None,
|
||||||
context: Optional[str] = None,
|
context: Optional[str] = None,
|
||||||
files: Optional[List[Dict[str, Any]]] = None
|
end_user_id: Optional[str] = None,
|
||||||
|
config_id: Optional[str] = None, # 添加这个参数
|
||||||
|
storage_type: Optional[str] = None,
|
||||||
|
user_rag_memory_id: Optional[str] = None,
|
||||||
|
memory_flag: Optional[bool] = True
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""执行对话
|
"""执行对话
|
||||||
|
|
||||||
@@ -355,15 +156,35 @@ class LangChainAgent:
|
|||||||
message: 用户消息
|
message: 用户消息
|
||||||
history: 历史消息列表 [{"role": "user/assistant", "content": "..."}]
|
history: 历史消息列表 [{"role": "user/assistant", "content": "..."}]
|
||||||
context: 上下文信息(如知识库检索结果)
|
context: 上下文信息(如知识库检索结果)
|
||||||
files: 多模态文件
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 包含 content 和元数据的字典
|
Dict: 包含 content 和元数据的字典
|
||||||
"""
|
"""
|
||||||
|
message_chat= message
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
actual_config_id = config_id
|
||||||
|
# If config_id is None, try to get from end_user's connected config
|
||||||
|
if actual_config_id is None and end_user_id:
|
||||||
try:
|
try:
|
||||||
# 准备消息列表(支持多模态)
|
from app.services.memory_agent_service import (
|
||||||
messages = self._prepare_messages(message, history, context, files)
|
get_end_user_connected_config,
|
||||||
|
)
|
||||||
|
db = next(get_db())
|
||||||
|
try:
|
||||||
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
|
actual_config_id = connected_config.get("memory_config_id")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to get db session: {e}")
|
||||||
|
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||||
|
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||||
|
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||||
|
try:
|
||||||
|
# 准备消息列表
|
||||||
|
messages = self._prepare_messages(message, history, context)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"准备调用 LangChain Agent",
|
"准备调用 LangChain Agent",
|
||||||
@@ -371,84 +192,27 @@ class LangChainAgent:
|
|||||||
"has_context": bool(context),
|
"has_context": bool(context),
|
||||||
"has_history": bool(history),
|
"has_history": bool(history),
|
||||||
"has_tools": bool(self.tools),
|
"has_tools": bool(self.tools),
|
||||||
"has_files": bool(files),
|
"message_count": len(messages)
|
||||||
"message_count": len(messages),
|
|
||||||
"max_iterations": self.max_iterations
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 统一使用 agent.invoke 调用
|
# 统一使用 agent.invoke 调用
|
||||||
# 通过 recursion_limit 限制最大迭代次数,防止工具调用死循环
|
result = await self.agent.ainvoke({"messages": messages})
|
||||||
try:
|
|
||||||
result = await self.agent.ainvoke(
|
|
||||||
{"messages": messages},
|
|
||||||
config={"recursion_limit": self.max_iterations}
|
|
||||||
)
|
|
||||||
except (RecursionError, GraphRecursionError) as e:
|
|
||||||
logger.warning(
|
|
||||||
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),可能存在工具调用循环",
|
|
||||||
extra={"error": str(e)}
|
|
||||||
)
|
|
||||||
# 返回一个友好的错误提示
|
|
||||||
return {
|
|
||||||
"content": f"抱歉,我在处理您的请求时遇到了问题。已达到最大处理步骤限制({self.max_iterations}次)。请尝试简化您的问题或稍后再试。",
|
|
||||||
"model": self.model_name,
|
|
||||||
"elapsed_time": time.time() - start_time,
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"completion_tokens": 0,
|
|
||||||
"total_tokens": 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# 获取最后的 AI 消息
|
# 获取最后的 AI 消息
|
||||||
output_messages = result.get("messages", [])
|
output_messages = result.get("messages", [])
|
||||||
content = ""
|
content = ""
|
||||||
|
|
||||||
logger.debug(f"输出消息数量: {len(output_messages)}")
|
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
reasoning_content = ""
|
|
||||||
for msg in reversed(output_messages):
|
for msg in reversed(output_messages):
|
||||||
if isinstance(msg, AIMessage):
|
if isinstance(msg, AIMessage):
|
||||||
logger.debug(f"找到 AI 消息,content 类型: {type(msg.content)}")
|
|
||||||
logger.debug(f"AI 消息内容: {msg.content}")
|
|
||||||
|
|
||||||
# 处理多模态响应:content 可能是字符串或列表
|
|
||||||
if isinstance(msg.content, str):
|
|
||||||
content = msg.content
|
content = msg.content
|
||||||
logger.debug(f"提取字符串内容,长度: {len(content)}")
|
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||||
elif isinstance(msg.content, list):
|
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
||||||
# 多模态响应:提取文本部分
|
|
||||||
logger.debug(f"多模态响应,列表长度: {len(msg.content)}")
|
|
||||||
text_parts = []
|
|
||||||
for item in msg.content:
|
|
||||||
logger.debug(f"处理项: {item}")
|
|
||||||
if isinstance(item, dict):
|
|
||||||
# 通义千问格式: {"text": "..."}
|
|
||||||
if "text" in item:
|
|
||||||
text = item.get("text", "")
|
|
||||||
text_parts.append(text)
|
|
||||||
logger.debug(f"提取文本: {text[:100]}...")
|
|
||||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
|
||||||
elif item.get("type") == "text":
|
|
||||||
text = item.get("text", "")
|
|
||||||
text_parts.append(text)
|
|
||||||
logger.debug(f"提取文本: {text[:100]}...")
|
|
||||||
elif isinstance(item, str):
|
|
||||||
text_parts.append(item)
|
|
||||||
logger.debug(f"提取字符串: {item[:100]}...")
|
|
||||||
content = "".join(text_parts)
|
|
||||||
logger.debug(f"合并后内容长度: {len(content)}")
|
|
||||||
else:
|
|
||||||
content = str(msg.content)
|
|
||||||
logger.debug(f"转换为字符串: {content[:100]}...")
|
|
||||||
total_tokens = self._extract_tokens_from_message(msg)
|
|
||||||
reasoning_content = self._extract_reasoning_content(msg) if self.deep_thinking else ""
|
|
||||||
break
|
break
|
||||||
|
|
||||||
logger.info(f"最终提取的内容长度: {len(content)}")
|
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
if memory_flag:
|
||||||
|
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id)
|
||||||
response = {
|
response = {
|
||||||
"content": content,
|
"content": content,
|
||||||
"model": self.model_name,
|
"model": self.model_name,
|
||||||
@@ -459,8 +223,6 @@ class LangChainAgent:
|
|||||||
"total_tokens": total_tokens
|
"total_tokens": total_tokens
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if reasoning_content:
|
|
||||||
response["reasoning_content"] = reasoning_content
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Agent 调用完成",
|
"Agent 调用完成",
|
||||||
@@ -481,20 +243,21 @@ class LangChainAgent:
|
|||||||
message: str,
|
message: str,
|
||||||
history: Optional[List[Dict[str, str]]] = None,
|
history: Optional[List[Dict[str, str]]] = None,
|
||||||
context: Optional[str] = None,
|
context: Optional[str] = None,
|
||||||
files: Optional[List[Dict[str, Any]]] = None
|
end_user_id:Optional[str] = None,
|
||||||
) -> AsyncGenerator[str | int | dict[str, str], None]:
|
config_id: Optional[str] = None,
|
||||||
|
storage_type:Optional[str] = None,
|
||||||
|
user_rag_memory_id:Optional[str] = None,
|
||||||
|
memory_flag: Optional[bool] = True
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
"""执行流式对话
|
"""执行流式对话
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message: 用户消息
|
message: 用户消息
|
||||||
history: 历史消息列表
|
history: 历史消息列表
|
||||||
context: 上下文信息
|
context: 上下文信息
|
||||||
files: 多模态文件
|
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
str: 消息内容块
|
str: 消息内容块
|
||||||
int: token 统计
|
|
||||||
Dict: 深度思考内容 {"type": "reasoning", "content": "..."}
|
|
||||||
"""
|
"""
|
||||||
logger.info("=" * 80)
|
logger.info("=" * 80)
|
||||||
logger.info(" chat_stream 方法开始执行")
|
logger.info(" chat_stream 方法开始执行")
|
||||||
@@ -502,28 +265,43 @@ class LangChainAgent:
|
|||||||
logger.info(f" Has tools: {bool(self.tools)}")
|
logger.info(f" Has tools: {bool(self.tools)}")
|
||||||
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
|
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
|
||||||
logger.info("=" * 80)
|
logger.info("=" * 80)
|
||||||
|
message_chat = message
|
||||||
|
actual_config_id = config_id
|
||||||
|
# If config_id is None, try to get from end_user's connected config
|
||||||
|
if actual_config_id is None and end_user_id:
|
||||||
try:
|
try:
|
||||||
# 准备消息列表(支持多模态)
|
db = next(get_db())
|
||||||
messages = self._prepare_messages(message, history, context, files)
|
try:
|
||||||
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
|
actual_config_id = connected_config.get("memory_config_id")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to get db session: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||||
|
try:
|
||||||
|
# 准备消息列表
|
||||||
|
messages = self._prepare_messages(message, history, context)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"准备流式调用,has_tools={bool(self.tools)}, has_files={bool(files)}, message_count={len(messages)}"
|
f"准备流式调用,has_tools={bool(self.tools)}, message_count={len(messages)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
|
yielded_content = False
|
||||||
|
|
||||||
# 统一使用 agent 的 astream_events 实现流式输出
|
# 统一使用 agent 的 astream_events 实现流式输出
|
||||||
logger.debug("使用 Agent astream_events 实现流式输出")
|
logger.debug("使用 Agent astream_events 实现流式输出")
|
||||||
full_content = ''
|
full_content = ''
|
||||||
full_reasoning = ''
|
|
||||||
try:
|
try:
|
||||||
last_event = {}
|
|
||||||
async for event in self.agent.astream_events(
|
async for event in self.agent.astream_events(
|
||||||
{"messages": messages},
|
{"messages": messages},
|
||||||
version="v2",
|
version="v2"
|
||||||
config={"recursion_limit": self.max_iterations}
|
|
||||||
):
|
):
|
||||||
last_event = event
|
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
kind = event.get("event")
|
kind = event.get("event")
|
||||||
|
|
||||||
@@ -531,77 +309,22 @@ class LangChainAgent:
|
|||||||
if kind == "on_chat_model_stream":
|
if kind == "on_chat_model_stream":
|
||||||
# LLM 流式输出
|
# LLM 流式输出
|
||||||
chunk = event.get("data", {}).get("chunk")
|
chunk = event.get("data", {}).get("chunk")
|
||||||
if chunk and hasattr(chunk, "content"):
|
full_content+=chunk.content
|
||||||
# 提取深度思考内容(仅在启用深度思考时)
|
if chunk and hasattr(chunk, "content") and chunk.content:
|
||||||
if self.deep_thinking:
|
yield chunk.content
|
||||||
reasoning_chunk = self._extract_reasoning_content(chunk)
|
yielded_content = True
|
||||||
if reasoning_chunk:
|
|
||||||
full_reasoning += reasoning_chunk
|
|
||||||
yield {"type": "reasoning", "content": reasoning_chunk}
|
|
||||||
|
|
||||||
# 处理多模态响应:content 可能是字符串或列表
|
|
||||||
chunk_content = chunk.content
|
|
||||||
if isinstance(chunk_content, str) and chunk_content:
|
|
||||||
full_content += chunk_content
|
|
||||||
yield chunk_content
|
|
||||||
elif isinstance(chunk_content, list):
|
|
||||||
# 多模态响应:提取文本部分
|
|
||||||
for item in chunk_content:
|
|
||||||
if isinstance(item, dict):
|
|
||||||
# 通义千问格式: {"text": "..."}
|
|
||||||
if "text" in item:
|
|
||||||
text = item.get("text", "")
|
|
||||||
if text:
|
|
||||||
full_content += text
|
|
||||||
yield text
|
|
||||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
|
||||||
elif item.get("type") == "text":
|
|
||||||
text = item.get("text", "")
|
|
||||||
if text:
|
|
||||||
full_content += text
|
|
||||||
yield text
|
|
||||||
elif isinstance(item, str):
|
|
||||||
full_content += item
|
|
||||||
yield item
|
|
||||||
|
|
||||||
elif kind == "on_llm_stream":
|
elif kind == "on_llm_stream":
|
||||||
# 另一种 LLM 流式事件
|
# 另一种 LLM 流式事件
|
||||||
chunk = event.get("data", {}).get("chunk")
|
chunk = event.get("data", {}).get("chunk")
|
||||||
if chunk:
|
if chunk:
|
||||||
if hasattr(chunk, "content"):
|
if hasattr(chunk, "content") and chunk.content:
|
||||||
# 提取深度思考内容(仅在启用深度思考时)
|
full_content+=chunk.content
|
||||||
if self.deep_thinking:
|
yield chunk.content
|
||||||
reasoning_chunk = self._extract_reasoning_content(chunk)
|
yielded_content = True
|
||||||
if reasoning_chunk:
|
|
||||||
full_reasoning += reasoning_chunk
|
|
||||||
yield {"type": "reasoning", "content": reasoning_chunk}
|
|
||||||
|
|
||||||
chunk_content = chunk.content
|
|
||||||
if isinstance(chunk_content, str) and chunk_content:
|
|
||||||
full_content += chunk_content
|
|
||||||
yield chunk_content
|
|
||||||
elif isinstance(chunk_content, list):
|
|
||||||
# 多模态响应:提取文本部分
|
|
||||||
for item in chunk_content:
|
|
||||||
if isinstance(item, dict):
|
|
||||||
# 通义千问格式: {"text": "..."}
|
|
||||||
if "text" in item:
|
|
||||||
text = item.get("text", "")
|
|
||||||
if text:
|
|
||||||
full_content += text
|
|
||||||
yield text
|
|
||||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
|
||||||
elif item.get("type") == "text":
|
|
||||||
text = item.get("text", "")
|
|
||||||
if text:
|
|
||||||
full_content += text
|
|
||||||
yield text
|
|
||||||
elif isinstance(item, str):
|
|
||||||
full_content += item
|
|
||||||
yield item
|
|
||||||
elif isinstance(chunk, str):
|
elif isinstance(chunk, str):
|
||||||
full_content += chunk
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
yielded_content = True
|
||||||
|
|
||||||
# 记录工具调用(可选)
|
# 记录工具调用(可选)
|
||||||
elif kind == "on_tool_start":
|
elif kind == "on_tool_start":
|
||||||
@@ -611,20 +334,16 @@ class LangChainAgent:
|
|||||||
|
|
||||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||||
# 统计token消耗
|
# 统计token消耗
|
||||||
output_messages = last_event.get("data", {}).get("output", {}).get("messages", [])
|
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
||||||
for msg in reversed(output_messages):
|
for msg in reversed(output_messages):
|
||||||
if isinstance(msg, AIMessage):
|
if isinstance(msg, AIMessage):
|
||||||
stream_total_tokens = self._extract_tokens_from_message(msg)
|
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||||
logger.info(f"流式 token 统计: total_tokens={stream_total_tokens}")
|
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
|
||||||
yield stream_total_tokens
|
0) if response_meta else 0
|
||||||
|
yield total_tokens
|
||||||
break
|
break
|
||||||
|
if memory_flag:
|
||||||
except GraphRecursionError:
|
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id)
|
||||||
logger.warning(
|
|
||||||
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),模型可能不支持正确的工具调用停止判断"
|
|
||||||
)
|
|
||||||
if not full_content:
|
|
||||||
yield "抱歉,我在处理您的请求时遇到了问题(已达最大处理步骤限制)。请尝试简化问题或更换模型后重试。"
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
@@ -638,3 +357,5 @@ class LangChainAgent:
|
|||||||
logger.info("=" * 80)
|
logger.info("=" * 80)
|
||||||
logger.info("chat_stream 方法执行结束")
|
logger.info("chat_stream 方法执行结束")
|
||||||
logger.info("=" * 80)
|
logger.info("=" * 80)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -70,8 +70,6 @@ def require_api_key(
|
|||||||
})
|
})
|
||||||
raise BusinessException("API Key 无效或已过期", BizCode.API_KEY_INVALID)
|
raise BusinessException("API Key 无效或已过期", BizCode.API_KEY_INVALID)
|
||||||
|
|
||||||
ApiKeyAuthService.check_app_published(db, api_key_obj)
|
|
||||||
|
|
||||||
if scopes:
|
if scopes:
|
||||||
missing_scopes = []
|
missing_scopes = []
|
||||||
for scope in scopes:
|
for scope in scopes:
|
||||||
@@ -99,7 +97,7 @@ def require_api_key(
|
|||||||
)
|
)
|
||||||
|
|
||||||
rate_limiter = RateLimiterService()
|
rate_limiter = RateLimiterService()
|
||||||
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj, db=db)
|
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj)
|
||||||
if not is_allowed:
|
if not is_allowed:
|
||||||
logger.warning("API Key 限流触发", extra={
|
logger.warning("API Key 限流触发", extra={
|
||||||
"api_key_id": str(api_key_obj.id),
|
"api_key_id": str(api_key_obj.id),
|
||||||
@@ -108,12 +106,10 @@ def require_api_key(
|
|||||||
"error_msg": error_msg
|
"error_msg": error_msg
|
||||||
})
|
})
|
||||||
# 根据错误消息判断限流类型
|
# 根据错误消息判断限流类型
|
||||||
if "Daily" in error_msg:
|
if "QPS" in error_msg:
|
||||||
code = BizCode.API_KEY_DAILY_LIMIT_EXCEEDED
|
|
||||||
elif "Tenant" in error_msg:
|
|
||||||
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED # 租户套餐速率超限,同属 QPS 类
|
|
||||||
elif "QPS" in error_msg:
|
|
||||||
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
|
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
|
||||||
|
elif "Daily" in error_msg:
|
||||||
|
code = BizCode.API_KEY_DAILY_LIMIT_EXCEEDED
|
||||||
else:
|
else:
|
||||||
code = BizCode.API_KEY_QUOTA_EXCEEDED
|
code = BizCode.API_KEY_QUOTA_EXCEEDED
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,8 @@
|
|||||||
"""API Key 工具函数"""
|
"""API Key 工具函数"""
|
||||||
import secrets
|
import secrets
|
||||||
import uuid as _uuid
|
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy.orm import Session as _Session
|
|
||||||
from app.core.error_codes import BizCode as _BizCode
|
|
||||||
from app.core.exceptions import BusinessException as _BusinessException
|
|
||||||
from app.models.end_user_model import EndUser as _EndUser
|
|
||||||
from app.repositories.end_user_repository import EndUserRepository as _EndUserRepository
|
|
||||||
|
|
||||||
from app.models.api_key_model import ApiKeyType
|
from app.models.api_key_model import ApiKeyType
|
||||||
from fastapi import Response
|
from fastapi import Response
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
@@ -72,72 +65,3 @@ def datetime_to_timestamp(dt: Optional[datetime]) -> Optional[int]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
return int(dt.timestamp() * 1000)
|
return int(dt.timestamp() * 1000)
|
||||||
|
|
||||||
|
|
||||||
def get_current_user_from_api_key(db: _Session, api_key_auth):
|
|
||||||
"""通过 API Key 构造 current_user 对象。
|
|
||||||
|
|
||||||
从 API Key 反查创建者(管理员用户),并设置其 workspace 上下文。
|
|
||||||
与内部接口的 Depends(get_current_user) (JWT) 等价。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: 数据库会话
|
|
||||||
api_key_auth: API Key 认证信息(ApiKeyAuth)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
User ORM 对象,已设置 current_workspace_id
|
|
||||||
"""
|
|
||||||
from app.services import api_key_service
|
|
||||||
|
|
||||||
api_key = api_key_service.ApiKeyService.get_api_key(
|
|
||||||
db, api_key_auth.api_key_id, api_key_auth.workspace_id
|
|
||||||
)
|
|
||||||
current_user = api_key.creator
|
|
||||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
|
||||||
return current_user
|
|
||||||
|
|
||||||
|
|
||||||
def validate_end_user_in_workspace(
|
|
||||||
db: _Session,
|
|
||||||
end_user_id: str,
|
|
||||||
workspace_id,
|
|
||||||
) -> _EndUser:
|
|
||||||
"""校验 end_user 是否存在且属于指定 workspace。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: 数据库会话
|
|
||||||
end_user_id: 终端用户 ID
|
|
||||||
workspace_id: 工作空间 ID(UUID 或字符串均可)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
EndUser ORM 对象(校验通过时)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
BusinessException(INVALID_PARAMETER): end_user_id 格式无效
|
|
||||||
BusinessException(USER_NOT_FOUND): end_user 不存在
|
|
||||||
BusinessException(PERMISSION_DENIED): end_user 不属于该 workspace
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
_uuid.UUID(end_user_id)
|
|
||||||
except (ValueError, AttributeError):
|
|
||||||
raise _BusinessException(
|
|
||||||
f"Invalid end_user_id format: {end_user_id}",
|
|
||||||
_BizCode.INVALID_PARAMETER,
|
|
||||||
)
|
|
||||||
|
|
||||||
end_user_repo = _EndUserRepository(db)
|
|
||||||
end_user = end_user_repo.get_end_user_by_id(end_user_id)
|
|
||||||
|
|
||||||
if end_user is None:
|
|
||||||
raise _BusinessException(
|
|
||||||
"End user not found",
|
|
||||||
_BizCode.USER_NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
if str(end_user.workspace_id) != str(workspace_id):
|
|
||||||
raise _BusinessException(
|
|
||||||
"End user does not belong to this workspace",
|
|
||||||
_BizCode.PERMISSION_DENIED,
|
|
||||||
)
|
|
||||||
|
|
||||||
return end_user
|
|
||||||
@@ -1,9 +1,9 @@
|
|||||||
|
import json
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from pydantic import Field, TypeAdapter
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
@@ -58,6 +58,7 @@ class Settings:
|
|||||||
REDIS_DB: int = int(os.getenv("REDIS_DB", "1"))
|
REDIS_DB: int = int(os.getenv("REDIS_DB", "1"))
|
||||||
REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "")
|
REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "")
|
||||||
|
|
||||||
|
|
||||||
# ElasticSearch configuration
|
# ElasticSearch configuration
|
||||||
ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1")
|
ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1")
|
||||||
ELASTICSEARCH_PORT: int = int(os.getenv("ELASTICSEARCH_PORT", "9200"))
|
ELASTICSEARCH_PORT: int = int(os.getenv("ELASTICSEARCH_PORT", "9200"))
|
||||||
@@ -97,7 +98,6 @@ class Settings:
|
|||||||
|
|
||||||
# File Upload
|
# File Upload
|
||||||
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
||||||
MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20"))
|
|
||||||
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
||||||
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
||||||
|
|
||||||
@@ -115,7 +115,6 @@ class Settings:
|
|||||||
S3_ACCESS_KEY_ID: str = os.getenv("S3_ACCESS_KEY_ID", "")
|
S3_ACCESS_KEY_ID: str = os.getenv("S3_ACCESS_KEY_ID", "")
|
||||||
S3_SECRET_ACCESS_KEY: str = os.getenv("S3_SECRET_ACCESS_KEY", "")
|
S3_SECRET_ACCESS_KEY: str = os.getenv("S3_SECRET_ACCESS_KEY", "")
|
||||||
S3_BUCKET_NAME: str = os.getenv("S3_BUCKET_NAME", "")
|
S3_BUCKET_NAME: str = os.getenv("S3_BUCKET_NAME", "")
|
||||||
S3_ENDPOINT_URL: str = os.getenv("S3_ENDPOINT_URL", "")
|
|
||||||
|
|
||||||
# VOLC ASR settings
|
# VOLC ASR settings
|
||||||
VOLC_APP_KEY: str = os.getenv("VOLC_APP_KEY", "")
|
VOLC_APP_KEY: str = os.getenv("VOLC_APP_KEY", "")
|
||||||
@@ -163,44 +162,6 @@ class Settings:
|
|||||||
# This controls the language used for memory summary titles and other generated content
|
# This controls the language used for memory summary titles and other generated content
|
||||||
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
|
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
|
||||||
|
|
||||||
# ========================================================================
|
|
||||||
# Internationalization (i18n) Configuration
|
|
||||||
# ========================================================================
|
|
||||||
# Default language for API responses
|
|
||||||
I18N_DEFAULT_LANGUAGE: str = os.getenv("I18N_DEFAULT_LANGUAGE", "zh")
|
|
||||||
|
|
||||||
# Supported languages (comma-separated)
|
|
||||||
I18N_SUPPORTED_LANGUAGES: list[str] = [
|
|
||||||
lang.strip()
|
|
||||||
for lang in os.getenv("I18N_SUPPORTED_LANGUAGES", "zh,en").split(",")
|
|
||||||
if lang.strip()
|
|
||||||
]
|
|
||||||
|
|
||||||
# Core locales directory (community edition)
|
|
||||||
# Use absolute path to work from any working directory
|
|
||||||
I18N_CORE_LOCALES_DIR: str = os.getenv(
|
|
||||||
"I18N_CORE_LOCALES_DIR",
|
|
||||||
os.path.join(os.path.dirname(os.path.dirname(__file__)), "locales")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Premium locales directory (enterprise edition, optional)
|
|
||||||
I18N_PREMIUM_LOCALES_DIR: Optional[str] = os.getenv("I18N_PREMIUM_LOCALES_DIR", None)
|
|
||||||
|
|
||||||
# Enable translation cache
|
|
||||||
I18N_ENABLE_TRANSLATION_CACHE: bool = os.getenv("I18N_ENABLE_TRANSLATION_CACHE", "true").lower() == "true"
|
|
||||||
|
|
||||||
# LRU cache size for hot translations
|
|
||||||
I18N_LRU_CACHE_SIZE: int = int(os.getenv("I18N_LRU_CACHE_SIZE", "1000"))
|
|
||||||
|
|
||||||
# Enable hot reload of translation files
|
|
||||||
I18N_ENABLE_HOT_RELOAD: bool = os.getenv("I18N_ENABLE_HOT_RELOAD", "false").lower() == "true"
|
|
||||||
|
|
||||||
# Fallback language when translation is missing
|
|
||||||
I18N_FALLBACK_LANGUAGE: str = os.getenv("I18N_FALLBACK_LANGUAGE", "zh")
|
|
||||||
|
|
||||||
# Log missing translations
|
|
||||||
I18N_LOG_MISSING_TRANSLATIONS: bool = os.getenv("I18N_LOG_MISSING_TRANSLATIONS", "true").lower() == "true"
|
|
||||||
|
|
||||||
# Logging settings
|
# Logging settings
|
||||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
||||||
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||||
@@ -229,47 +190,19 @@ class Settings:
|
|||||||
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
|
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
|
||||||
|
|
||||||
# Celery configuration (internal)
|
# Celery configuration (internal)
|
||||||
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
|
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
|
||||||
# 详见 docs/celery-env-bug-report.md
|
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
|
||||||
# 默认使用 Redis 作为 broker 和 backend,与业务缓存隔离
|
|
||||||
# 如需使用 RabbitMQ,在 .env 中设置 CELERY_BROKER_URL=amqp://user:pass@host:5672/vhost
|
|
||||||
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3"))
|
|
||||||
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4"))
|
|
||||||
|
|
||||||
# SMTP Email Configuration
|
|
||||||
SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com")
|
|
||||||
SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587"))
|
|
||||||
SMTP_USER: str = os.getenv("SMTP_USER", "")
|
|
||||||
SMTP_PASSWORD: str = os.getenv("SMTP_PASSWORD", "")
|
|
||||||
|
|
||||||
SANDBOX_URL: str = os.getenv("SANDBOX_URL", "")
|
|
||||||
|
|
||||||
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
|
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
|
||||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||||
|
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||||
|
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
|
||||||
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||||
|
|
||||||
# Memory Cache Regeneration Configuration
|
# Memory Cache Regeneration Configuration
|
||||||
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
||||||
|
|
||||||
# Celery Beat Schedule Configuration (定时任务执行频率)
|
|
||||||
MEMORY_INCREMENT_HOUR: int = TypeAdapter(
|
|
||||||
Annotated[int, Field(ge=0, le=23, description="cron hour [0, 23]")]
|
|
||||||
).validate_python(int(os.getenv("MEMORY_INCREMENT_HOUR", "2")))
|
|
||||||
MEMORY_INCREMENT_MINUTE: int = TypeAdapter(
|
|
||||||
Annotated[int, Field(ge=0, le=59, description="cron minute [0, 59]")]
|
|
||||||
).validate_python(int(os.getenv("MEMORY_INCREMENT_MINUTE", "0")))
|
|
||||||
WORKSPACE_REFLECTION_INTERVAL_SECONDS: int = TypeAdapter(
|
|
||||||
Annotated[int, Field(ge=1, description="reflection interval in seconds, must be >= 1")]
|
|
||||||
).validate_python(int(os.getenv("WORKSPACE_REFLECTION_INTERVAL_SECONDS", "30")))
|
|
||||||
FORGETTING_CYCLE_INTERVAL_HOURS: int = TypeAdapter(
|
|
||||||
Annotated[int, Field(ge=1, description="forgetting cycle interval in hours, must be >= 1")]
|
|
||||||
).validate_python(int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24")))
|
|
||||||
|
|
||||||
IMPLICIT_EMOTIONS_UPDATE_HOUR: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_HOUR", "2"))
|
|
||||||
# implicit_emotions_update: 每天几分执行(分钟,0-59)
|
|
||||||
IMPLICIT_EMOTIONS_UPDATE_MINUTE: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_MINUTE", "0"))
|
|
||||||
# Memory Module Configuration (internal)
|
# Memory Module Configuration (internal)
|
||||||
|
|
||||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||||
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
|
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
|
||||||
|
|
||||||
@@ -282,35 +215,9 @@ class Settings:
|
|||||||
# official environment system version
|
# official environment system version
|
||||||
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1")
|
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1")
|
||||||
|
|
||||||
# model square loading
|
|
||||||
LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true"
|
|
||||||
|
|
||||||
# workflow config
|
# workflow config
|
||||||
WORKFLOW_IMPORT_CACHE_TIMEOUT: int = int(os.getenv("WORKFLOW_IMPORT_CACHE_TIMEOUT", 1800))
|
|
||||||
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
||||||
|
|
||||||
# ========================================================================
|
|
||||||
# General Ontology Type Configuration
|
|
||||||
# ========================================================================
|
|
||||||
# 通用本体文件路径列表(逗号分隔)
|
|
||||||
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "api/app/core/memory/ontology_services/General_purpose_entity.ttl")
|
|
||||||
|
|
||||||
# 是否启用通用本体类型功能
|
|
||||||
ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true"
|
|
||||||
|
|
||||||
# Prompt 中最大类型数量
|
|
||||||
MAX_ONTOLOGY_TYPES_IN_PROMPT: int = int(os.getenv("MAX_ONTOLOGY_TYPES_IN_PROMPT", "50"))
|
|
||||||
|
|
||||||
# 核心通用类型列表(逗号分隔)—— 与 ontology.md Entity Ontology 保持一致的 13 类
|
|
||||||
CORE_GENERAL_TYPES: str = os.getenv(
|
|
||||||
"CORE_GENERAL_TYPES",
|
|
||||||
"人物,组织,群体,角色职业,地点设施,物品设备,软件平台,识别联系信息,"
|
|
||||||
"文档媒体,知识能力,偏好习惯,具体目标,称呼别名"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 实验模式开关(允许通过 API 动态切换本体配置)
|
|
||||||
ONTOLOGY_EXPERIMENT_MODE: bool = os.getenv("ONTOLOGY_EXPERIMENT_MODE", "true").lower() == "true"
|
|
||||||
|
|
||||||
def get_memory_output_path(self, filename: str = "") -> str:
|
def get_memory_output_path(self, filename: str = "") -> str:
|
||||||
"""
|
"""
|
||||||
Get the full path for memory module output files.
|
Get the full path for memory module output files.
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ class BizCode(IntEnum):
|
|||||||
TENANT_NOT_FOUND = 3002
|
TENANT_NOT_FOUND = 3002
|
||||||
WORKSPACE_NO_ACCESS = 3003
|
WORKSPACE_NO_ACCESS = 3003
|
||||||
WORKSPACE_INVITE_NOT_FOUND = 3004
|
WORKSPACE_INVITE_NOT_FOUND = 3004
|
||||||
WORKSPACE_ACCESS_DENIED = 3005
|
|
||||||
# API Key 管理(3xxx)
|
# API Key 管理(3xxx)
|
||||||
API_KEY_NOT_FOUND = 3007
|
API_KEY_NOT_FOUND = 3007
|
||||||
API_KEY_DUPLICATE_NAME = 3008
|
API_KEY_DUPLICATE_NAME = 3008
|
||||||
@@ -31,9 +30,6 @@ class BizCode(IntEnum):
|
|||||||
API_KEY_QPS_LIMIT_EXCEEDED = 3014
|
API_KEY_QPS_LIMIT_EXCEEDED = 3014
|
||||||
API_KEY_DAILY_LIMIT_EXCEEDED = 3015
|
API_KEY_DAILY_LIMIT_EXCEEDED = 3015
|
||||||
API_KEY_QUOTA_EXCEEDED = 3016
|
API_KEY_QUOTA_EXCEEDED = 3016
|
||||||
API_KEY_RATE_LIMIT_EXCEEDED = 3017
|
|
||||||
QUOTA_EXCEEDED = 3018
|
|
||||||
RATE_LIMIT_EXCEEDED = 3019
|
|
||||||
# 资源(4xxx)
|
# 资源(4xxx)
|
||||||
NOT_FOUND = 4000
|
NOT_FOUND = 4000
|
||||||
USER_NOT_FOUND = 4001
|
USER_NOT_FOUND = 4001
|
||||||
@@ -44,14 +40,12 @@ class BizCode(IntEnum):
|
|||||||
FILE_NOT_FOUND = 4006
|
FILE_NOT_FOUND = 4006
|
||||||
APP_NOT_FOUND = 4007
|
APP_NOT_FOUND = 4007
|
||||||
RELEASE_NOT_FOUND = 4008
|
RELEASE_NOT_FOUND = 4008
|
||||||
USER_NO_ACCESS = 4009
|
|
||||||
|
|
||||||
# 冲突/状态(5xxx)
|
# 冲突/状态(5xxx)
|
||||||
DUPLICATE_NAME = 5001
|
DUPLICATE_NAME = 5001
|
||||||
RESOURCE_ALREADY_EXISTS = 5002
|
RESOURCE_ALREADY_EXISTS = 5002
|
||||||
VERSION_ALREADY_EXISTS = 5003
|
VERSION_ALREADY_EXISTS = 5003
|
||||||
STATE_CONFLICT = 5004
|
STATE_CONFLICT = 5004
|
||||||
RESOURCE_IN_USE = 5005
|
|
||||||
|
|
||||||
# 应用发布(6xxx)
|
# 应用发布(6xxx)
|
||||||
PUBLISH_FAILED = 6001
|
PUBLISH_FAILED = 6001
|
||||||
@@ -66,7 +60,6 @@ class BizCode(IntEnum):
|
|||||||
PERMISSION_DENIED = 6010
|
PERMISSION_DENIED = 6010
|
||||||
INVALID_CONVERSATION = 6011
|
INVALID_CONVERSATION = 6011
|
||||||
CONFIG_MISSING = 6012
|
CONFIG_MISSING = 6012
|
||||||
APP_NOT_PUBLISHED = 6013
|
|
||||||
|
|
||||||
# 模型(7xxx)
|
# 模型(7xxx)
|
||||||
MODEL_CONFIG_INVALID = 7001
|
MODEL_CONFIG_INVALID = 7001
|
||||||
@@ -119,11 +112,8 @@ HTTP_MAPPING = {
|
|||||||
BizCode.FORBIDDEN: 403,
|
BizCode.FORBIDDEN: 403,
|
||||||
BizCode.TENANT_NOT_FOUND: 400,
|
BizCode.TENANT_NOT_FOUND: 400,
|
||||||
BizCode.WORKSPACE_NO_ACCESS: 403,
|
BizCode.WORKSPACE_NO_ACCESS: 403,
|
||||||
BizCode.WORKSPACE_INVITE_NOT_FOUND: 400,
|
|
||||||
BizCode.WORKSPACE_ACCESS_DENIED: 403,
|
|
||||||
BizCode.NOT_FOUND: 400,
|
BizCode.NOT_FOUND: 400,
|
||||||
BizCode.USER_NOT_FOUND: 200,
|
BizCode.USER_NOT_FOUND: 200,
|
||||||
BizCode.USER_NO_ACCESS: 401,
|
|
||||||
BizCode.WORKSPACE_NOT_FOUND: 400,
|
BizCode.WORKSPACE_NOT_FOUND: 400,
|
||||||
BizCode.MODEL_NOT_FOUND: 400,
|
BizCode.MODEL_NOT_FOUND: 400,
|
||||||
BizCode.KNOWLEDGE_NOT_FOUND: 400,
|
BizCode.KNOWLEDGE_NOT_FOUND: 400,
|
||||||
@@ -135,7 +125,6 @@ HTTP_MAPPING = {
|
|||||||
BizCode.RESOURCE_ALREADY_EXISTS: 409,
|
BizCode.RESOURCE_ALREADY_EXISTS: 409,
|
||||||
BizCode.VERSION_ALREADY_EXISTS: 409,
|
BizCode.VERSION_ALREADY_EXISTS: 409,
|
||||||
BizCode.STATE_CONFLICT: 409,
|
BizCode.STATE_CONFLICT: 409,
|
||||||
BizCode.RESOURCE_IN_USE: 409,
|
|
||||||
BizCode.PUBLISH_FAILED: 500,
|
BizCode.PUBLISH_FAILED: 500,
|
||||||
BizCode.NO_DRAFT_TO_PUBLISH: 400,
|
BizCode.NO_DRAFT_TO_PUBLISH: 400,
|
||||||
BizCode.ROLLBACK_TARGET_NOT_FOUND: 400,
|
BizCode.ROLLBACK_TARGET_NOT_FOUND: 400,
|
||||||
@@ -159,7 +148,6 @@ HTTP_MAPPING = {
|
|||||||
BizCode.API_KEY_QPS_LIMIT_EXCEEDED: 429,
|
BizCode.API_KEY_QPS_LIMIT_EXCEEDED: 429,
|
||||||
BizCode.API_KEY_DAILY_LIMIT_EXCEEDED: 429,
|
BizCode.API_KEY_DAILY_LIMIT_EXCEEDED: 429,
|
||||||
BizCode.API_KEY_QUOTA_EXCEEDED: 429,
|
BizCode.API_KEY_QUOTA_EXCEEDED: 429,
|
||||||
BizCode.QUOTA_EXCEEDED: 402,
|
|
||||||
|
|
||||||
BizCode.MODEL_CONFIG_INVALID: 400,
|
BizCode.MODEL_CONFIG_INVALID: 400,
|
||||||
BizCode.API_KEY_MISSING: 400,
|
BizCode.API_KEY_MISSING: 400,
|
||||||
@@ -189,21 +177,4 @@ HTTP_MAPPING = {
|
|||||||
BizCode.DB_ERROR: 500,
|
BizCode.DB_ERROR: 500,
|
||||||
BizCode.SERVICE_UNAVAILABLE: 503,
|
BizCode.SERVICE_UNAVAILABLE: 503,
|
||||||
BizCode.RATE_LIMITED: 429,
|
BizCode.RATE_LIMITED: 429,
|
||||||
BizCode.RATE_LIMIT_EXCEEDED: 429,
|
|
||||||
}
|
|
||||||
|
|
||||||
ERROR_CODE_TO_BIZ_CODE = {
|
|
||||||
"QUOTA_EXCEEDED": BizCode.QUOTA_EXCEEDED,
|
|
||||||
"RATE_LIMIT_EXCEEDED": BizCode.RATE_LIMIT_EXCEEDED,
|
|
||||||
"API_KEY_NOT_FOUND": BizCode.API_KEY_NOT_FOUND,
|
|
||||||
"API_KEY_INVALID": BizCode.API_KEY_INVALID,
|
|
||||||
"API_KEY_EXPIRED": BizCode.API_KEY_EXPIRED,
|
|
||||||
"WORKSPACE_NOT_FOUND": BizCode.WORKSPACE_NOT_FOUND,
|
|
||||||
"WORKSPACE_NO_ACCESS": BizCode.WORKSPACE_NO_ACCESS,
|
|
||||||
"PERMISSION_DENIED": BizCode.PERMISSION_DENIED,
|
|
||||||
"TOKEN_EXPIRED": BizCode.TOKEN_EXPIRED,
|
|
||||||
"TOKEN_INVALID": BizCode.TOKEN_INVALID,
|
|
||||||
"VALIDATION_FAILED": BizCode.VALIDATION_FAILED,
|
|
||||||
"INVALID_PARAMETER": BizCode.INVALID_PARAMETER,
|
|
||||||
"MISSING_PARAMETER": BizCode.MISSING_PARAMETER,
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,86 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""语言处理工具模块
|
|
||||||
|
|
||||||
本模块提供集中化的语言校验和处理功能,确保整个应用中语言参数的一致性。
|
|
||||||
|
|
||||||
Functions:
|
|
||||||
validate_language: 校验语言参数,确保其为有效值
|
|
||||||
get_language_from_header: 从请求头获取并校验语言参数
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from app.core.logging_config import get_logger
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
# 支持的语言列表
|
|
||||||
SUPPORTED_LANGUAGES = {"zh", "en"}
|
|
||||||
|
|
||||||
# 默认回退语言
|
|
||||||
DEFAULT_LANGUAGE = "zh"
|
|
||||||
|
|
||||||
|
|
||||||
def validate_language(language: Optional[str]) -> str:
|
|
||||||
"""
|
|
||||||
校验语言参数,确保其为有效值。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
language: 待校验的语言代码,可以是 None、"zh"、"en" 或其他值
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
有效的语言代码("zh" 或 "en")
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> validate_language("zh")
|
|
||||||
'zh'
|
|
||||||
>>> validate_language("en")
|
|
||||||
'en'
|
|
||||||
>>> validate_language("EN") # 大小写不敏感
|
|
||||||
'en'
|
|
||||||
>>> validate_language(None) # None 回退到默认值
|
|
||||||
'zh'
|
|
||||||
>>> validate_language("fr") # 不支持的语言回退到默认值
|
|
||||||
'zh'
|
|
||||||
"""
|
|
||||||
if language is None:
|
|
||||||
return DEFAULT_LANGUAGE
|
|
||||||
|
|
||||||
# 处理枚举类型:优先取 .value,避免 str(Language.ZH) → "Language.ZH"
|
|
||||||
if hasattr(language, "value"):
|
|
||||||
language = language.value
|
|
||||||
|
|
||||||
# 标准化:转小写并去除空白
|
|
||||||
lang = str(language).lower().strip()
|
|
||||||
|
|
||||||
if lang in SUPPORTED_LANGUAGES:
|
|
||||||
return lang
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
f"无效的语言参数 '{language}',已回退到默认值 '{DEFAULT_LANGUAGE}'。"
|
|
||||||
f"支持的语言: {SUPPORTED_LANGUAGES}"
|
|
||||||
)
|
|
||||||
return DEFAULT_LANGUAGE
|
|
||||||
|
|
||||||
|
|
||||||
def get_language_from_header(language_type: Optional[str]) -> str:
|
|
||||||
"""
|
|
||||||
从请求头获取并校验语言参数。
|
|
||||||
|
|
||||||
这是一个便捷函数,用于在 controller 层统一处理 X-Language-Type Header。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
language_type: 从 X-Language-Type Header 获取的语言值
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
有效的语言代码("zh" 或 "en")
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> get_language_from_header(None) # Header 未传递
|
|
||||||
'zh'
|
|
||||||
>>> get_language_from_header("en")
|
|
||||||
'en'
|
|
||||||
>>> get_language_from_header("invalid") # 无效值回退
|
|
||||||
'zh'
|
|
||||||
"""
|
|
||||||
return validate_language(language_type)
|
|
||||||
@@ -38,56 +38,6 @@ class SensitiveDataLoggingFilter(logging.Filter):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class Neo4jSuccessNotificationFilter(logging.Filter):
|
|
||||||
"""Neo4j 日志过滤器:过滤成功/信息性状态的通知,保留真正的警告和错误
|
|
||||||
|
|
||||||
Neo4j 驱动会以 WARNING 级别记录所有数据库通知,包括成功的操作。
|
|
||||||
这个过滤器会过滤掉以下 GQL 状态码的通知,只保留真正的警告和错误:
|
|
||||||
- 00000: 成功完成 (successful completion)
|
|
||||||
- 00N00: 无数据 (no data)
|
|
||||||
- 00NA0: 无数据,信息性通知 (no data, informational notification)
|
|
||||||
|
|
||||||
使用正则表达式进行更严格的匹配,避免误过滤无关的警告。
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
# 编译正则表达式以提高性能
|
|
||||||
# 匹配所有"成功/信息性"的 GQL 状态码:
|
|
||||||
# 00000 = 成功完成, 00N00 = 无数据, 00NA0 = 无数据信息性通知
|
|
||||||
GQL_STATUS_PATTERN = re.compile(r"gql_status=['\"](00000|00N00|00NA0)['\"]")
|
|
||||||
|
|
||||||
# 匹配 status_description 中的成功完成或信息性通知消息
|
|
||||||
SUCCESS_DESC_PATTERN = re.compile(r"status_description=['\"]note:\s*(successful\s+completion|no\s+data)['\"]", re.IGNORECASE)
|
|
||||||
|
|
||||||
def filter(self, record: logging.LogRecord) -> bool:
|
|
||||||
"""
|
|
||||||
过滤 Neo4j 成功通知
|
|
||||||
|
|
||||||
Args:
|
|
||||||
record: 日志记录
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True表示允许记录,False表示拒绝(过滤掉)
|
|
||||||
"""
|
|
||||||
# 只处理 INFO 和 WARNING 级别的日志
|
|
||||||
# Neo4j 驱动对 severity='INFORMATION' 的通知使用 INFO 级别,
|
|
||||||
# 对 severity='WARNING' 的通知使用 WARNING 级别
|
|
||||||
if record.levelno not in (logging.INFO, logging.WARNING):
|
|
||||||
return True
|
|
||||||
|
|
||||||
# 检查是否是 Neo4j 的成功通知
|
|
||||||
message = str(record.msg)
|
|
||||||
|
|
||||||
# 使用正则表达式进行更严格的匹配
|
|
||||||
# 这样可以避免误过滤包含这些子字符串但不是 Neo4j 通知的日志
|
|
||||||
if self.GQL_STATUS_PATTERN.search(message) or self.SUCCESS_DESC_PATTERN.search(message):
|
|
||||||
return False # 过滤掉这条日志
|
|
||||||
|
|
||||||
# 保留其他所有日志(包括真正的警告和错误)
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class LoggingConfig:
|
class LoggingConfig:
|
||||||
"""全局日志配置类"""
|
"""全局日志配置类"""
|
||||||
|
|
||||||
@@ -115,26 +65,6 @@ class LoggingConfig:
|
|||||||
# 清除现有处理器
|
# 清除现有处理器
|
||||||
root_logger.handlers.clear()
|
root_logger.handlers.clear()
|
||||||
|
|
||||||
# Neo4j 通知过滤器 - 挂在 handler 上确保所有传播上来的日志都能被过滤
|
|
||||||
neo4j_filter = Neo4jSuccessNotificationFilter()
|
|
||||||
|
|
||||||
# 抑制 Neo4j 通知日志
|
|
||||||
# Neo4j 驱动内部会给 neo4j.notifications logger 配置自己的 handler,
|
|
||||||
# 导致日志绕过根 logger 的 filter 直接输出。
|
|
||||||
# 多管齐下确保过滤生效:
|
|
||||||
# 1. 设置 neo4j.notifications 级别为 WARNING(过滤 INFO 级别的 00NA0 通知)
|
|
||||||
# 2. 在所有 neo4j logger 上添加 filter(过滤 WARNING 级别的成功通知)
|
|
||||||
# 3. 在根 handler 上也添加 filter(兜底)
|
|
||||||
neo4j_notifications_logger = logging.getLogger("neo4j.notifications")
|
|
||||||
neo4j_notifications_logger.setLevel(logging.WARNING)
|
|
||||||
for neo4j_logger_name in ["neo4j", "neo4j.io", "neo4j.pool", "neo4j.notifications"]:
|
|
||||||
neo4j_logger = logging.getLogger(neo4j_logger_name)
|
|
||||||
neo4j_logger.addFilter(neo4j_filter)
|
|
||||||
|
|
||||||
# 压制 httpx / httpcore 的请求级日志(大量 HTTP Request: POST ... 噪音)
|
|
||||||
for noisy_logger in ["httpx", "httpcore", "httpcore.http11", "httpcore.connection"]:
|
|
||||||
logging.getLogger(noisy_logger).setLevel(logging.WARNING)
|
|
||||||
|
|
||||||
# 创建格式化器
|
# 创建格式化器
|
||||||
formatter = logging.Formatter(
|
formatter = logging.Formatter(
|
||||||
fmt=settings.LOG_FORMAT,
|
fmt=settings.LOG_FORMAT,
|
||||||
@@ -150,7 +80,6 @@ class LoggingConfig:
|
|||||||
console_handler.setFormatter(formatter)
|
console_handler.setFormatter(formatter)
|
||||||
console_handler.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
|
console_handler.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
|
||||||
console_handler.addFilter(sensitive_filter)
|
console_handler.addFilter(sensitive_filter)
|
||||||
console_handler.addFilter(neo4j_filter)
|
|
||||||
root_logger.addHandler(console_handler)
|
root_logger.addHandler(console_handler)
|
||||||
|
|
||||||
# 文件处理器(带轮转)
|
# 文件处理器(带轮转)
|
||||||
@@ -164,7 +93,6 @@ class LoggingConfig:
|
|||||||
file_handler.setFormatter(formatter)
|
file_handler.setFormatter(formatter)
|
||||||
file_handler.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
|
file_handler.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
|
||||||
file_handler.addFilter(sensitive_filter)
|
file_handler.addFilter(sensitive_filter)
|
||||||
file_handler.addFilter(neo4j_filter)
|
|
||||||
root_logger.addHandler(file_handler)
|
root_logger.addHandler(file_handler)
|
||||||
|
|
||||||
cls._initialized = True
|
cls._initialized = True
|
||||||
@@ -533,9 +461,8 @@ def log_time(step_name: str, duration: float, log_file: str = "logs/time.log") -
|
|||||||
# Fallback to console only if file write fails
|
# Fallback to console only if file write fails
|
||||||
print(f"Warning: Could not write to timing log: {e}")
|
print(f"Warning: Could not write to timing log: {e}")
|
||||||
|
|
||||||
# Always log at INFO level (avoids Celery treating stdout as WARNING)
|
# Always print to console (backward compatible behavior)
|
||||||
_timing_logger = logging.getLogger(__name__)
|
print(f"✓ {step_name}: {duration:.2f}s")
|
||||||
_timing_logger.info(f"✓ {step_name}: {duration:.2f}s")
|
|
||||||
|
|
||||||
|
|
||||||
def get_agent_logger(name: str = "agent_service",
|
def get_agent_logger(name: str = "agent_service",
|
||||||
|
|||||||
@@ -1,45 +1,16 @@
|
|||||||
from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
|
from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
|
||||||
from app.schemas.memory_agent_schema import AgentMemoryDataset
|
|
||||||
|
|
||||||
|
|
||||||
def content_input_node(state: ReadState) -> ReadState:
|
def content_input_node(state: ReadState) -> ReadState:
|
||||||
"""
|
"""开始节点 - 提取内容并保持状态信息"""
|
||||||
Start node - Extract content and maintain state information
|
|
||||||
|
|
||||||
Extracts the content from the first message in the state and returns it
|
|
||||||
as the data field while preserving all other state information.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: ReadState containing messages and other state data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ReadState: Updated state with extracted content in data field
|
|
||||||
"""
|
|
||||||
|
|
||||||
content = state['messages'][0].content if state.get('messages') else ''
|
content = state['messages'][0].content if state.get('messages') else ''
|
||||||
# Return content and maintain all state information
|
# 返回内容并保持所有状态信息
|
||||||
for pronoun in AgentMemoryDataset.PRONOUN:
|
|
||||||
content = content.replace(pronoun, AgentMemoryDataset.NAME)
|
|
||||||
|
|
||||||
return {"data": content}
|
return {"data": content}
|
||||||
|
|
||||||
|
|
||||||
def content_input_write(state: WriteState) -> WriteState:
|
def content_input_write(state: WriteState) -> WriteState:
|
||||||
"""
|
"""开始节点 - 提取内容并保持状态信息"""
|
||||||
Start node - Extract content and maintain state information for write operations
|
|
||||||
|
|
||||||
Extracts the content from the first message in the state for write operations.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: WriteState containing messages and other state data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
WriteState: Updated state with extracted content in data field
|
|
||||||
"""
|
|
||||||
|
|
||||||
content = state['messages'][0].content if state.get('messages') else ''
|
content = state['messages'][0].content if state.get('messages') else ''
|
||||||
# Return content and maintain all state information
|
# 返回内容并保持所有状态信息
|
||||||
for pronoun in AgentMemoryDataset.PRONOUN:
|
|
||||||
content = content.replace(pronoun, AgentMemoryDataset.NAME)
|
|
||||||
|
|
||||||
return {"data": content}
|
return {"data": content}
|
||||||
@@ -1,408 +0,0 @@
|
|||||||
"""
|
|
||||||
Perceptual Memory Retrieval Node & Service
|
|
||||||
|
|
||||||
Provides PerceptualSearchService for searching perceptual memories (vision, audio,
|
|
||||||
text, conversation) from Neo4j using keyword fulltext + embedding semantic search
|
|
||||||
with BM25+embedding fusion reranking.
|
|
||||||
|
|
||||||
Also provides the perceptual_retrieve_node for use as a LangGraph node.
|
|
||||||
"""
|
|
||||||
import asyncio
|
|
||||||
import math
|
|
||||||
from typing import List, Dict, Any, Optional
|
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger
|
|
||||||
from app.core.memory.agent.utils.llm_tools import ReadState
|
|
||||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
|
||||||
from app.repositories.neo4j.graph_search import (
|
|
||||||
search_perceptual_by_fulltext,
|
|
||||||
search_perceptual_by_embedding,
|
|
||||||
)
|
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class PerceptualSearchService:
|
|
||||||
"""
|
|
||||||
感知记忆检索服务。
|
|
||||||
|
|
||||||
封装关键词全文检索 + 向量语义检索 + BM25/embedding 融合排序的完整流程。
|
|
||||||
调用方只需提供 query / keywords、end_user_id、memory_config,即可获得
|
|
||||||
格式化并排序后的感知记忆列表和拼接文本。
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
service = PerceptualSearchService(end_user_id=..., memory_config=...)
|
|
||||||
results = await service.search(query="...", keywords=[...], limit=10)
|
|
||||||
# results = {"memories": [...], "content": "...", "keyword_raw": N, "embedding_raw": M}
|
|
||||||
"""
|
|
||||||
|
|
||||||
DEFAULT_ALPHA = 0.6
|
|
||||||
DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
end_user_id: str,
|
|
||||||
memory_config: Any,
|
|
||||||
alpha: float = DEFAULT_ALPHA,
|
|
||||||
content_score_threshold: float = DEFAULT_CONTENT_SCORE_THRESHOLD,
|
|
||||||
):
|
|
||||||
self.end_user_id = end_user_id
|
|
||||||
self.memory_config = memory_config
|
|
||||||
self.alpha = alpha
|
|
||||||
self.content_score_threshold = content_score_threshold
|
|
||||||
|
|
||||||
async def search(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
keywords: Optional[List[str]] = None,
|
|
||||||
limit: int = 10,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
执行感知记忆检索(关键词 + 向量并行),融合排序后返回结果。
|
|
||||||
|
|
||||||
对 embedding 命中但 keyword 未命中的结果,补查全文索引获取 BM25 分数,
|
|
||||||
确保所有结果都同时具备 BM25 和 embedding 两个维度的评分。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: 原始用户查询(用于向量检索和 BM25 补查)
|
|
||||||
keywords: 关键词列表(用于全文检索),为 None 时使用 [query]
|
|
||||||
limit: 最大返回数量
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
{
|
|
||||||
"memories": [格式化后的记忆 dict, ...],
|
|
||||||
"content": "拼接的纯文本摘要",
|
|
||||||
"keyword_raw": int,
|
|
||||||
"embedding_raw": int,
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
if keywords is None:
|
|
||||||
keywords = [query] if query else []
|
|
||||||
|
|
||||||
connector = Neo4jConnector()
|
|
||||||
try:
|
|
||||||
kw_task = self._keyword_search(connector, keywords, limit)
|
|
||||||
emb_task = self._embedding_search(connector, query, limit)
|
|
||||||
|
|
||||||
kw_results, emb_results = await asyncio.gather(
|
|
||||||
kw_task, emb_task, return_exceptions=True
|
|
||||||
)
|
|
||||||
if isinstance(kw_results, Exception):
|
|
||||||
logger.warning(f"[PerceptualSearch] keyword search error: {kw_results}")
|
|
||||||
kw_results = []
|
|
||||||
if isinstance(emb_results, Exception):
|
|
||||||
logger.warning(f"[PerceptualSearch] embedding search error: {emb_results}")
|
|
||||||
emb_results = []
|
|
||||||
|
|
||||||
# 补查 BM25:找出 embedding 命中但 keyword 未命中的 id,
|
|
||||||
# 用原始 query 对这些节点补查全文索引拿 BM25 score
|
|
||||||
kw_ids = {r.get("id") for r in kw_results if r.get("id")}
|
|
||||||
emb_only_ids = {r.get("id") for r in emb_results if r.get("id") and r.get("id") not in kw_ids}
|
|
||||||
|
|
||||||
if emb_only_ids and query:
|
|
||||||
backfill = await self._bm25_backfill(connector, query, emb_only_ids, limit)
|
|
||||||
# 把补查到的 BM25 score 注入到 embedding 结果中
|
|
||||||
backfill_map = {r["id"]: r.get("score", 0) for r in backfill}
|
|
||||||
for r in emb_results:
|
|
||||||
rid = r.get("id", "")
|
|
||||||
if rid in backfill_map:
|
|
||||||
r["bm25_backfill_score"] = backfill_map[rid]
|
|
||||||
logger.info(
|
|
||||||
f"[PerceptualSearch] BM25 backfill: {len(emb_only_ids)} embedding-only ids, "
|
|
||||||
f"{len(backfill_map)} got BM25 scores"
|
|
||||||
)
|
|
||||||
|
|
||||||
reranked = self._rerank(kw_results, emb_results, limit)
|
|
||||||
|
|
||||||
memories = []
|
|
||||||
content_parts = []
|
|
||||||
for record in reranked:
|
|
||||||
fmt = self._format_result(record)
|
|
||||||
fmt["score"] = round(record.get("content_score", 0), 4)
|
|
||||||
memories.append(fmt)
|
|
||||||
content_parts.append(self._build_content_text(fmt))
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[PerceptualSearch] {len(memories)} results after rerank "
|
|
||||||
f"(keyword_raw={len(kw_results)}, embedding_raw={len(emb_results)})"
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"memories": memories,
|
|
||||||
"content": "\n\n".join(content_parts),
|
|
||||||
"keyword_raw": len(kw_results),
|
|
||||||
"embedding_raw": len(emb_results),
|
|
||||||
}
|
|
||||||
finally:
|
|
||||||
await connector.close()
|
|
||||||
|
|
||||||
async def _bm25_backfill(
|
|
||||||
self,
|
|
||||||
connector: Neo4jConnector,
|
|
||||||
query: str,
|
|
||||||
target_ids: set,
|
|
||||||
limit: int,
|
|
||||||
) -> List[dict]:
|
|
||||||
"""
|
|
||||||
对指定 id 集合补查全文索引 BM25 score。
|
|
||||||
|
|
||||||
用原始 query 查全文索引,只保留 id 在 target_ids 中的结果。
|
|
||||||
"""
|
|
||||||
escaped = escape_lucene_query(query)
|
|
||||||
if not escaped.strip():
|
|
||||||
return []
|
|
||||||
try:
|
|
||||||
r = await search_perceptual_by_fulltext(
|
|
||||||
connector=connector, query=escaped,
|
|
||||||
end_user_id=self.end_user_id,
|
|
||||||
limit=limit * 5, # 多查一些以提高命中率
|
|
||||||
)
|
|
||||||
all_hits = r.get("perceptuals", [])
|
|
||||||
return [h for h in all_hits if h.get("id") in target_ids]
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[PerceptualSearch] BM25 backfill failed: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def _keyword_search(
|
|
||||||
self,
|
|
||||||
connector: Neo4jConnector,
|
|
||||||
keywords: List[str],
|
|
||||||
limit: int,
|
|
||||||
) -> List[dict]:
|
|
||||||
"""并发对每个关键词做全文检索,去重后按 score 降序返回 top N 原始结果。"""
|
|
||||||
seen_ids: set = set()
|
|
||||||
all_results: List[dict] = []
|
|
||||||
|
|
||||||
async def _one(kw: str):
|
|
||||||
escaped = escape_lucene_query(kw)
|
|
||||||
if not escaped.strip():
|
|
||||||
return []
|
|
||||||
r = await search_perceptual_by_fulltext(
|
|
||||||
connector=connector, query=escaped,
|
|
||||||
end_user_id=self.end_user_id, limit=limit,
|
|
||||||
)
|
|
||||||
return r.get("perceptuals", [])
|
|
||||||
|
|
||||||
tasks = [_one(kw) for kw in keywords[:10]]
|
|
||||||
batch = await asyncio.gather(*tasks, return_exceptions=True)
|
|
||||||
|
|
||||||
for result in batch:
|
|
||||||
if isinstance(result, Exception):
|
|
||||||
logger.warning(f"[PerceptualSearch] keyword sub-query error: {result}")
|
|
||||||
continue
|
|
||||||
for rec in result:
|
|
||||||
rid = rec.get("id", "")
|
|
||||||
if rid and rid not in seen_ids:
|
|
||||||
seen_ids.add(rid)
|
|
||||||
all_results.append(rec)
|
|
||||||
|
|
||||||
all_results.sort(key=lambda x: float(x.get("score", 0)), reverse=True)
|
|
||||||
return all_results[:limit]
|
|
||||||
|
|
||||||
async def _embedding_search(
|
|
||||||
self,
|
|
||||||
connector: Neo4jConnector,
|
|
||||||
query_text: str,
|
|
||||||
limit: int,
|
|
||||||
) -> List[dict]:
|
|
||||||
"""向量语义检索,返回原始结果(不做阈值过滤)。"""
|
|
||||||
try:
|
|
||||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
|
||||||
from app.core.models.base import RedBearModelConfig
|
|
||||||
from app.db import get_db_context
|
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
|
||||||
|
|
||||||
with get_db_context() as db:
|
|
||||||
cfg = MemoryConfigService(db).get_embedder_config(
|
|
||||||
str(self.memory_config.embedding_model_id)
|
|
||||||
)
|
|
||||||
client = OpenAIEmbedderClient(RedBearModelConfig(**cfg))
|
|
||||||
|
|
||||||
r = await search_perceptual_by_embedding(
|
|
||||||
connector=connector, embedder_client=client,
|
|
||||||
query_text=query_text, end_user_id=self.end_user_id,
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
return r.get("perceptuals", [])
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[PerceptualSearch] embedding search failed: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def _rerank(
|
|
||||||
self,
|
|
||||||
keyword_results: List[dict],
|
|
||||||
embedding_results: List[dict],
|
|
||||||
limit: int,
|
|
||||||
) -> List[dict]:
|
|
||||||
"""BM25 + embedding 融合排序。
|
|
||||||
|
|
||||||
对 embedding 结果中带有 bm25_backfill_score 的条目,
|
|
||||||
将其与 keyword 结果合并后统一归一化,确保 BM25 分数在同一尺度上。
|
|
||||||
"""
|
|
||||||
# 把补查的 BM25 score 合并到 keyword_results 中统一归一化
|
|
||||||
emb_backfill_items = []
|
|
||||||
for item in embedding_results:
|
|
||||||
backfill_score = item.get("bm25_backfill_score")
|
|
||||||
if backfill_score is not None and item.get("id"):
|
|
||||||
emb_backfill_items.append({"id": item["id"], "score": backfill_score})
|
|
||||||
|
|
||||||
# 合并后统一归一化 BM25 scores
|
|
||||||
all_bm25_items = keyword_results + emb_backfill_items
|
|
||||||
all_bm25_items = self._normalize_scores(all_bm25_items)
|
|
||||||
|
|
||||||
# 建立 id -> normalized BM25 score 的映射
|
|
||||||
bm25_norm_map: Dict[str, float] = {}
|
|
||||||
for item in all_bm25_items:
|
|
||||||
item_id = item.get("id", "")
|
|
||||||
if item_id:
|
|
||||||
bm25_norm_map[item_id] = float(item.get("normalized_score", 0))
|
|
||||||
|
|
||||||
# 归一化 embedding scores
|
|
||||||
embedding_results = self._normalize_scores(embedding_results)
|
|
||||||
|
|
||||||
# 合并
|
|
||||||
combined: Dict[str, dict] = {}
|
|
||||||
for item in keyword_results:
|
|
||||||
item_id = item.get("id", "")
|
|
||||||
if not item_id:
|
|
||||||
continue
|
|
||||||
combined[item_id] = item.copy()
|
|
||||||
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
|
|
||||||
combined[item_id]["embedding_score"] = 0.0
|
|
||||||
|
|
||||||
for item in embedding_results:
|
|
||||||
item_id = item.get("id", "")
|
|
||||||
if not item_id:
|
|
||||||
continue
|
|
||||||
if item_id in combined:
|
|
||||||
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
|
||||||
else:
|
|
||||||
combined[item_id] = item.copy()
|
|
||||||
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
|
|
||||||
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
|
||||||
|
|
||||||
for item in combined.values():
|
|
||||||
bm25 = float(item.get("bm25_score", 0) or 0)
|
|
||||||
emb = float(item.get("embedding_score", 0) or 0)
|
|
||||||
item["content_score"] = self.alpha * bm25 + (1 - self.alpha) * emb
|
|
||||||
|
|
||||||
results = list(combined.values())
|
|
||||||
before = len(results)
|
|
||||||
results = [r for r in results if r["content_score"] >= self.content_score_threshold]
|
|
||||||
results.sort(key=lambda x: x["content_score"], reverse=True)
|
|
||||||
results = results[:limit]
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[PerceptualSearch] rerank: merged={before}, after_threshold={len(results)} "
|
|
||||||
f"(alpha={self.alpha}, threshold={self.content_score_threshold})"
|
|
||||||
)
|
|
||||||
return results
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _normalize_scores(items: List[dict], field: str = "score") -> List[dict]:
|
|
||||||
"""Z-score + sigmoid 归一化。"""
|
|
||||||
if not items:
|
|
||||||
return items
|
|
||||||
scores = [float(it.get(field, 0) or 0) for it in items]
|
|
||||||
if len(scores) <= 1:
|
|
||||||
for it in items:
|
|
||||||
it[f"normalized_{field}"] = 1.0
|
|
||||||
return items
|
|
||||||
mean = sum(scores) / len(scores)
|
|
||||||
var = sum((s - mean) ** 2 for s in scores) / len(scores)
|
|
||||||
std = math.sqrt(var)
|
|
||||||
if std == 0:
|
|
||||||
for it in items:
|
|
||||||
it[f"normalized_{field}"] = 1.0
|
|
||||||
else:
|
|
||||||
for it, s in zip(items, scores):
|
|
||||||
z = (s - mean) / std
|
|
||||||
it[f"normalized_{field}"] = 1 / (1 + math.exp(-z))
|
|
||||||
return items
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _format_result(record: dict) -> dict:
|
|
||||||
return {
|
|
||||||
"id": record.get("id", ""),
|
|
||||||
"perceptual_type": record.get("perceptual_type", ""),
|
|
||||||
"file_name": record.get("file_name", ""),
|
|
||||||
"file_path": record.get("file_path", ""),
|
|
||||||
"summary": record.get("summary", ""),
|
|
||||||
"topic": record.get("topic", ""),
|
|
||||||
"domain": record.get("domain", ""),
|
|
||||||
"keywords": record.get("keywords", []),
|
|
||||||
"created_at": str(record.get("created_at", "")),
|
|
||||||
"file_type": record.get("file_type", ""),
|
|
||||||
"score": record.get("score", 0),
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _build_content_text(formatted: dict) -> str:
|
|
||||||
parts = []
|
|
||||||
if formatted["summary"]:
|
|
||||||
parts.append(formatted["summary"])
|
|
||||||
if formatted["topic"]:
|
|
||||||
parts.append(f"[主题: {formatted['topic']}]")
|
|
||||||
if formatted["keywords"]:
|
|
||||||
kw_list = formatted["keywords"]
|
|
||||||
if isinstance(kw_list, list):
|
|
||||||
parts.append(f"[关键词: {', '.join(kw_list)}]")
|
|
||||||
if formatted["file_name"]:
|
|
||||||
parts.append(f"[文件: {formatted['file_name']}]")
|
|
||||||
return " ".join(parts)
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_keywords_from_problems(problem_extension: dict) -> List[str]:
|
|
||||||
"""Extract search keywords from problem extension results."""
|
|
||||||
keywords = []
|
|
||||||
context = problem_extension.get("context", {})
|
|
||||||
if isinstance(context, dict):
|
|
||||||
for original_q, extended_qs in context.items():
|
|
||||||
keywords.append(original_q)
|
|
||||||
if isinstance(extended_qs, list):
|
|
||||||
keywords.extend(extended_qs)
|
|
||||||
return keywords
|
|
||||||
|
|
||||||
|
|
||||||
async def perceptual_retrieve_node(state: ReadState) -> ReadState:
|
|
||||||
"""
|
|
||||||
LangGraph node: perceptual memory retrieval.
|
|
||||||
|
|
||||||
Uses PerceptualSearchService to run keyword + embedding search with
|
|
||||||
BM25 fusion reranking, then writes results to state['perceptual_data'].
|
|
||||||
"""
|
|
||||||
end_user_id = state.get("end_user_id", "")
|
|
||||||
problem_extension = state.get("problem_extension", {})
|
|
||||||
original_query = state.get("data", "")
|
|
||||||
memory_config = state.get("memory_config", None)
|
|
||||||
|
|
||||||
logger.info(f"Perceptual_Retrieve: start, end_user_id={end_user_id}")
|
|
||||||
|
|
||||||
keywords = _extract_keywords_from_problems(problem_extension)
|
|
||||||
if not keywords:
|
|
||||||
keywords = [original_query] if original_query else []
|
|
||||||
|
|
||||||
logger.info(f"Perceptual_Retrieve: {len(keywords)} keywords extracted")
|
|
||||||
|
|
||||||
service = PerceptualSearchService(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
memory_config=memory_config,
|
|
||||||
)
|
|
||||||
search_result = await service.search(
|
|
||||||
query=original_query,
|
|
||||||
keywords=keywords,
|
|
||||||
limit=10,
|
|
||||||
)
|
|
||||||
|
|
||||||
result = {
|
|
||||||
"memories": search_result["memories"],
|
|
||||||
"content": search_result["content"],
|
|
||||||
"_intermediate": {
|
|
||||||
"type": "perceptual_retrieve",
|
|
||||||
"title": "感知记忆检索",
|
|
||||||
"data": search_result["memories"],
|
|
||||||
"query": original_query,
|
|
||||||
"result_count": len(search_result["memories"]),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return {"perceptual_data": result}
|
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
import json
|
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
|
from app.db import get_db
|
||||||
|
|
||||||
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
|
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
|
||||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
|
||||||
from app.core.memory.agent.utils.llm_tools import (
|
from app.core.memory.agent.utils.llm_tools import (
|
||||||
PROJECT_ROOT_,
|
PROJECT_ROOT_,
|
||||||
ReadState,
|
ReadState,
|
||||||
@@ -12,46 +12,27 @@ from app.core.memory.agent.utils.llm_tools import (
|
|||||||
from app.core.memory.agent.utils.redis_tool import store
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
from app.core.memory.agent.utils.session_tools import SessionService
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.db import get_db_context
|
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||||
|
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
|
db_session = next(get_db())
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ProblemNodeService(LLMServiceMixin):
|
class ProblemNodeService(LLMServiceMixin):
|
||||||
"""
|
"""问题处理节点服务类"""
|
||||||
Problem processing node service class
|
|
||||||
|
|
||||||
Handles problem decomposition and extension operations using LLM services.
|
|
||||||
Inherits from LLMServiceMixin to provide structured LLM calling capabilities.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
template_service: Service for rendering Jinja2 templates
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.template_service = TemplateService(template_root)
|
self.template_service = TemplateService(template_root)
|
||||||
|
|
||||||
|
|
||||||
# Create global service instance
|
# 创建全局服务实例
|
||||||
problem_service = ProblemNodeService()
|
problem_service = ProblemNodeService()
|
||||||
|
|
||||||
|
|
||||||
async def Split_The_Problem(state: ReadState) -> ReadState:
|
async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||||
"""
|
"""问题分解节点"""
|
||||||
Problem decomposition node
|
|
||||||
|
|
||||||
Breaks down complex user queries into smaller, more manageable sub-problems.
|
|
||||||
Uses LLM to analyze the input and generate structured problem decomposition
|
|
||||||
with question types and reasoning.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: ReadState containing user input and configuration
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ReadState: Updated state with problem decomposition results
|
|
||||||
"""
|
|
||||||
# 从状态中获取数据
|
# 从状态中获取数据
|
||||||
content = state.get('data', '')
|
content = state.get('data', '')
|
||||||
end_user_id = state.get('end_user_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
@@ -72,7 +53,6 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用优化的LLM服务
|
# 使用优化的LLM服务
|
||||||
with get_db_context() as db_session:
|
|
||||||
structured = await problem_service.call_llm_structured(
|
structured = await problem_service.call_llm_structured(
|
||||||
state=state,
|
state=state,
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
@@ -84,7 +64,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
# 添加更详细的日志记录
|
# 添加更详细的日志记录
|
||||||
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
||||||
|
|
||||||
# Validate structured response
|
# 验证结构化响应
|
||||||
if not structured or not hasattr(structured, 'root'):
|
if not structured or not hasattr(structured, 'root'):
|
||||||
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
|
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
|
||||||
split_result = json.dumps([], ensure_ascii=False)
|
split_result = json.dumps([], ensure_ascii=False)
|
||||||
@@ -126,17 +106,17 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
exc_info=True
|
exc_info=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Provide more detailed error information
|
# 提供更详细的错误信息
|
||||||
error_details = {
|
error_details = {
|
||||||
"error_type": type(e).__name__,
|
"error_type": type(e).__name__,
|
||||||
"error_message": str(e),
|
"error_message": str(e),
|
||||||
"content_length": len(content),
|
"content_length": len(content),
|
||||||
"llm_model_id": str(memory_config.llm_model_id) if memory_config else None
|
"llm_model_id": memory_config.llm_model_id if memory_config else None
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.error(f"Split_The_Problem error details: {error_details}")
|
logger.error(f"Split_The_Problem error details: {error_details}")
|
||||||
|
|
||||||
# Create default empty result
|
# 创建默认的空结果
|
||||||
result = {
|
result = {
|
||||||
"context": json.dumps([], ensure_ascii=False),
|
"context": json.dumps([], ensure_ascii=False),
|
||||||
"original": content,
|
"original": content,
|
||||||
@@ -150,25 +130,13 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# Return updated state including spit_context field
|
# 返回更新后的状态,包含spit_context字段
|
||||||
return {"spit_data": result}
|
return {"spit_data": result}
|
||||||
|
|
||||||
|
|
||||||
async def Problem_Extension(state: ReadState) -> ReadState:
|
async def Problem_Extension(state: ReadState) -> ReadState:
|
||||||
"""
|
"""问题扩展节点"""
|
||||||
Problem extension node
|
# 获取原始数据和分解结果
|
||||||
|
|
||||||
Extends the decomposed problems from Split_The_Problem node by generating
|
|
||||||
additional related questions and organizing them by original question.
|
|
||||||
Uses LLM to create comprehensive question extensions for better memory retrieval.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: ReadState containing decomposed problems and configuration
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ReadState: Updated state with extended problem results
|
|
||||||
"""
|
|
||||||
# Get original data and decomposition results
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
content = state.get('data', '')
|
content = state.get('data', '')
|
||||||
data = state.get('spit_data', '')['context']
|
data = state.get('spit_data', '')['context']
|
||||||
@@ -203,7 +171,6 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用优化的LLM服务
|
# 使用优化的LLM服务
|
||||||
with get_db_context() as db_session:
|
|
||||||
response_content = await problem_service.call_llm_structured(
|
response_content = await problem_service.call_llm_structured(
|
||||||
state=state,
|
state=state,
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
@@ -214,7 +181,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
|
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
|
||||||
|
|
||||||
# Validate structured response
|
# 验证结构化响应
|
||||||
if not response_content or not hasattr(response_content, 'root'):
|
if not response_content or not hasattr(response_content, 'root'):
|
||||||
logger.warning("Problem_Extension: 结构化响应为空或格式不正确")
|
logger.warning("Problem_Extension: 结构化响应为空或格式不正确")
|
||||||
aggregated_dict = {}
|
aggregated_dict = {}
|
||||||
@@ -248,12 +215,12 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
exc_info=True
|
exc_info=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Provide more detailed error information
|
# 提供更详细的错误信息
|
||||||
error_details = {
|
error_details = {
|
||||||
"error_type": type(e).__name__,
|
"error_type": type(e).__name__,
|
||||||
"error_message": str(e),
|
"error_message": str(e),
|
||||||
"questions_count": len(databasets),
|
"questions_count": len(databasets),
|
||||||
"llm_model_id": str(memory_config.llm_model_id) if memory_config else None
|
"llm_model_id": memory_config.llm_model_id if memory_config else None
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.error(f"Problem_Extension error details: {error_details}")
|
logger.error(f"Problem_Extension error details: {error_details}")
|
||||||
@@ -263,6 +230,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
logger.info(f"Problem extension result: {aggregated_dict}")
|
logger.info(f"Problem extension result: {aggregated_dict}")
|
||||||
|
|
||||||
# Emit intermediate output for frontend
|
# Emit intermediate output for frontend
|
||||||
|
print(time.time() - start)
|
||||||
result = {
|
result = {
|
||||||
"context": aggregated_dict,
|
"context": aggregated_dict,
|
||||||
"original": data,
|
"original": data,
|
||||||
|
|||||||
@@ -6,41 +6,34 @@ import os
|
|||||||
# ===== 第三方库 =====
|
# ===== 第三方库 =====
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
|
from app.db import get_db, get_db_context
|
||||||
|
|
||||||
|
from app.schemas import model_schema
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
from app.services.model_service import ModelConfigService
|
||||||
|
|
||||||
|
from app.core.memory.agent.services.search_service import SearchService
|
||||||
|
from app.core.memory.agent.utils.llm_tools import (
|
||||||
|
COUNTState,
|
||||||
|
ReadState,
|
||||||
|
deduplicate_entries,
|
||||||
|
merge_to_key_value_pairs,
|
||||||
|
)
|
||||||
from app.core.memory.agent.langgraph_graph.tools.tool import (
|
from app.core.memory.agent.langgraph_graph.tools.tool import (
|
||||||
create_hybrid_retrieval_tool_sync,
|
create_hybrid_retrieval_tool_sync,
|
||||||
create_time_retrieval_tool,
|
create_time_retrieval_tool,
|
||||||
extract_tool_message_content,
|
extract_tool_message_content,
|
||||||
)
|
)
|
||||||
from app.core.memory.agent.services.search_service import SearchService
|
|
||||||
from app.core.memory.agent.utils.llm_tools import (
|
|
||||||
ReadState,
|
|
||||||
deduplicate_entries,
|
|
||||||
merge_to_key_value_pairs,
|
|
||||||
)
|
|
||||||
from app.core.rag.nlp.search import knowledge_retrieval
|
from app.core.rag.nlp.search import knowledge_retrieval
|
||||||
from app.db import get_db_context
|
|
||||||
from app.schemas import model_schema
|
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
|
||||||
from app.services.model_service import ModelConfigService
|
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
db = next(get_db())
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def rag_config(state):
|
async def rag_config(state):
|
||||||
"""
|
|
||||||
Configure RAG (Retrieval-Augmented Generation) settings
|
|
||||||
|
|
||||||
Creates configuration for knowledge base retrieval including similarity thresholds,
|
|
||||||
weights, and reranker settings.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: Current state containing user_rag_memory_id
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: RAG configuration dictionary
|
|
||||||
"""
|
|
||||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
kb_config = {
|
kb_config = {
|
||||||
"knowledge_bases": [
|
"knowledge_bases": [
|
||||||
@@ -57,22 +50,7 @@ async def rag_config(state):
|
|||||||
"reranker_top_k": 10
|
"reranker_top_k": 10
|
||||||
}
|
}
|
||||||
return kb_config
|
return kb_config
|
||||||
|
|
||||||
|
|
||||||
async def rag_knowledge(state,question):
|
async def rag_knowledge(state,question):
|
||||||
"""
|
|
||||||
Retrieve knowledge using RAG approach
|
|
||||||
|
|
||||||
Performs knowledge retrieval from configured knowledge bases using the
|
|
||||||
provided question and returns formatted results.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: Current state containing configuration
|
|
||||||
question: Question to search for
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
|
|
||||||
"""
|
|
||||||
kb_config = await rag_config(state)
|
kb_config = await rag_config(state)
|
||||||
end_user_id = state.get('end_user_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||||
@@ -93,24 +71,12 @@ async def rag_knowledge(state, question):
|
|||||||
|
|
||||||
|
|
||||||
async def llm_infomation(state: ReadState) -> ReadState:
|
async def llm_infomation(state: ReadState) -> ReadState:
|
||||||
"""
|
|
||||||
Get LLM configuration information from state
|
|
||||||
|
|
||||||
Retrieves model configuration details including model ID and tenant ID
|
|
||||||
from the memory configuration in the current state.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: ReadState containing memory configuration
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ReadState: Model configuration as Pydantic model
|
|
||||||
"""
|
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
model_id = memory_config.llm_model_id
|
model_id = memory_config.llm_model_id
|
||||||
tenant_id = memory_config.tenant_id
|
tenant_id = memory_config.tenant_id
|
||||||
|
|
||||||
# Use existing memory_config instead of re-querying database
|
# 使用现有的 memory_config 而不是重新查询数据库
|
||||||
# or use thread-safe database access
|
# 或者使用线程安全的数据库访问
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id)
|
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id)
|
||||||
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
|
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
|
||||||
@@ -119,20 +85,16 @@ async def llm_infomation(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
async def clean_databases(data) -> str:
|
async def clean_databases(data) -> str:
|
||||||
"""
|
"""
|
||||||
Simplified database search result cleaning function
|
简化的数据库搜索结果清理函数
|
||||||
|
|
||||||
Processes and cleans search results from various sources including
|
|
||||||
reranked results and time-based search results. Extracts text content
|
|
||||||
from structured data and returns as formatted string.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: Search result data (can be string, dict, or other types)
|
data: 搜索结果数据
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: Cleaned content string
|
清理后的内容字符串
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Parse JSON string
|
# 解析JSON字符串
|
||||||
if isinstance(data, str):
|
if isinstance(data, str):
|
||||||
try:
|
try:
|
||||||
data = json.loads(data)
|
data = json.loads(data)
|
||||||
@@ -142,24 +104,24 @@ async def clean_databases(data) -> str:
|
|||||||
if not isinstance(data, dict):
|
if not isinstance(data, dict):
|
||||||
return str(data)
|
return str(data)
|
||||||
|
|
||||||
# Get result data
|
# 获取结果数据
|
||||||
# with open("搜索结果.json","w",encoding='utf-8') as f:
|
# with open("搜索结果.json","w",encoding='utf-8') as f:
|
||||||
# f.write(json.dumps(data, indent=4, ensure_ascii=False))
|
# f.write(json.dumps(data, indent=4, ensure_ascii=False))
|
||||||
results = data.get('results', data)
|
results = data.get('results', data)
|
||||||
if not isinstance(results, dict):
|
if not isinstance(results, dict):
|
||||||
return str(results)
|
return str(results)
|
||||||
|
|
||||||
# Collect all content
|
# 收集所有内容
|
||||||
content_list = []
|
content_list = []
|
||||||
|
|
||||||
# Process reranked results
|
# 处理重排序结果
|
||||||
reranked = results.get('reranked_results', {})
|
reranked = results.get('reranked_results', {})
|
||||||
if reranked:
|
if reranked:
|
||||||
for category in ['summaries', 'communities', 'statements', 'chunks', 'entities']:
|
for category in ['summaries', 'statements', 'chunks', 'entities']:
|
||||||
items = reranked.get(category, [])
|
items = reranked.get(category, [])
|
||||||
if isinstance(items, list):
|
if isinstance(items, list):
|
||||||
content_list.extend(items)
|
content_list.extend(items)
|
||||||
# Process time search results
|
# 处理时间搜索结果
|
||||||
time_search = results.get('time_search', {})
|
time_search = results.get('time_search', {})
|
||||||
if time_search:
|
if time_search:
|
||||||
if isinstance(time_search, dict):
|
if isinstance(time_search, dict):
|
||||||
@@ -169,23 +131,17 @@ async def clean_databases(data) -> str:
|
|||||||
elif isinstance(time_search, list):
|
elif isinstance(time_search, list):
|
||||||
content_list.extend(time_search)
|
content_list.extend(time_search)
|
||||||
|
|
||||||
# Extract text content,对 community 按 name 去重(多次 tool 调用会产生重复)
|
# 提取文本内容
|
||||||
text_parts = []
|
text_parts = []
|
||||||
seen_community_names = set()
|
|
||||||
for item in content_list:
|
for item in content_list:
|
||||||
if isinstance(item, dict):
|
if isinstance(item, dict):
|
||||||
# community 节点用 name 去重
|
text = item.get('statement') or item.get('content', '')
|
||||||
if 'member_count' in item or 'core_entities' in item:
|
|
||||||
community_name = item.get('name') or item.get('id', '')
|
|
||||||
if community_name in seen_community_names:
|
|
||||||
continue
|
|
||||||
seen_community_names.add(community_name)
|
|
||||||
text = item.get('statement') or item.get('content') or item.get('summary', '')
|
|
||||||
if text:
|
if text:
|
||||||
text_parts.append(text)
|
text_parts.append(text)
|
||||||
elif isinstance(item, str):
|
elif isinstance(item, str):
|
||||||
text_parts.append(item)
|
text_parts.append(item)
|
||||||
|
|
||||||
|
|
||||||
return '\n'.join(text_parts).strip()
|
return '\n'.join(text_parts).strip()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -194,19 +150,11 @@ async def clean_databases(data) -> str:
|
|||||||
|
|
||||||
|
|
||||||
async def retrieve_nodes(state: ReadState) -> ReadState:
|
async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||||
"""
|
|
||||||
Retrieve information using simplified search approach
|
|
||||||
|
|
||||||
Processes extended problems from previous nodes and performs retrieval
|
'''
|
||||||
using either RAG or hybrid search based on storage type. Handles concurrent
|
|
||||||
processing of multiple questions and deduplicates results.
|
|
||||||
|
|
||||||
Args:
|
模型信息
|
||||||
state: ReadState containing problem extensions and configuration
|
'''
|
||||||
|
|
||||||
Returns:
|
|
||||||
ReadState: Updated state with retrieval results and intermediate outputs
|
|
||||||
"""
|
|
||||||
|
|
||||||
problem_extension=state.get('problem_extension', '')['context']
|
problem_extension=state.get('problem_extension', '')['context']
|
||||||
storage_type=state.get('storage_type', '')
|
storage_type=state.get('storage_type', '')
|
||||||
@@ -219,8 +167,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
|||||||
for data in values:
|
for data in values:
|
||||||
problem_list.append(data)
|
problem_list.append(data)
|
||||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||||
|
# 创建异步任务处理单个问题
|
||||||
# Create async task to process individual questions
|
|
||||||
async def process_question_nodes(idx, question):
|
async def process_question_nodes(idx, question):
|
||||||
try:
|
try:
|
||||||
# Prepare search parameters based on storage type
|
# Prepare search parameters based on storage type
|
||||||
@@ -266,7 +213,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# Process all questions concurrently
|
# 并发处理所有问题
|
||||||
tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)]
|
tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)]
|
||||||
databases_anser = await asyncio.gather(*tasks)
|
databases_anser = await asyncio.gather(*tasks)
|
||||||
databases_data = {
|
databases_data = {
|
||||||
@@ -313,21 +260,10 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
|||||||
return {'retrieve':dup_databases}
|
return {'retrieve':dup_databases}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def retrieve(state: ReadState) -> ReadState:
|
async def retrieve(state: ReadState) -> ReadState:
|
||||||
"""
|
# 从state中获取end_user_id
|
||||||
Advanced retrieve function using LangChain agents and tools
|
|
||||||
|
|
||||||
Uses LangChain agents with specialized retrieval tools (time-based and hybrid)
|
|
||||||
to perform sophisticated information retrieval. Supports both RAG and traditional
|
|
||||||
memory storage approaches with concurrent processing and result deduplication.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: ReadState containing problem extensions and configuration
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ReadState: Updated state with retrieval results and intermediate outputs
|
|
||||||
"""
|
|
||||||
# Get end_user_id from state
|
|
||||||
import time
|
import time
|
||||||
start=time.time()
|
start=time.time()
|
||||||
problem_extension = state.get('problem_extension', '')['context']
|
problem_extension = state.get('problem_extension', '')['context']
|
||||||
@@ -347,7 +283,6 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
with get_db_context() as db: # 使用同步数据库上下文管理器
|
with get_db_context() as db: # 使用同步数据库上下文管理器
|
||||||
config_service = MemoryConfigService(db)
|
config_service = MemoryConfigService(db)
|
||||||
return await llm_infomation(state)
|
return await llm_infomation(state)
|
||||||
|
|
||||||
llm_config = await get_llm_info()
|
llm_config = await get_llm_info()
|
||||||
api_key_obj = llm_config.api_keys[0]
|
api_key_obj = llm_config.api_keys[0]
|
||||||
api_key = api_key_obj.api_key
|
api_key = api_key_obj.api_key
|
||||||
@@ -361,11 +296,7 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
)
|
)
|
||||||
|
|
||||||
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
||||||
search_params = {
|
search_params = { "end_user_id": end_user_id, "return_raw_results": True }
|
||||||
"end_user_id": end_user_id,
|
|
||||||
"return_raw_results": True,
|
|
||||||
"include": ["summaries", "statements", "chunks", "entities", "communities"],
|
|
||||||
}
|
|
||||||
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||||
agent = create_agent(
|
agent = create_agent(
|
||||||
llm,
|
llm,
|
||||||
@@ -373,21 +304,20 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create async task to process individual questions
|
# 创建异步任务处理单个问题
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
# Define semaphore at module level to limit maximum concurrency
|
# 在模块级别定义信号量,限制最大并发数
|
||||||
SEMAPHORE = asyncio.Semaphore(5) # Limit to maximum 5 concurrent database operations
|
SEMAPHORE = asyncio.Semaphore(5) # 限制最多5个并发数据库操作
|
||||||
|
|
||||||
async def process_question(idx, question):
|
async def process_question(idx, question):
|
||||||
async with SEMAPHORE: # Limit concurrency
|
async with SEMAPHORE: # 限制并发
|
||||||
try:
|
try:
|
||||||
if storage_type == "rag" and user_rag_memory_id:
|
if storage_type == "rag" and user_rag_memory_id:
|
||||||
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
|
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question)
|
||||||
question)
|
|
||||||
else:
|
else:
|
||||||
cleaned_query = question
|
cleaned_query = question
|
||||||
# Use asyncio to run synchronous agent.invoke in thread pool
|
# 使用 asyncio 在线程池中运行同步的 agent.invoke
|
||||||
import asyncio
|
import asyncio
|
||||||
response = await asyncio.get_event_loop().run_in_executor(
|
response = await asyncio.get_event_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
@@ -401,32 +331,8 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
raw_results = tool_results['content']
|
raw_results = tool_results['content']
|
||||||
clean_content = await clean_databases(raw_results)
|
clean_content = await clean_databases(raw_results)
|
||||||
|
|
||||||
# 社区展开:从 tool 返回结果中提取命中的 community,
|
|
||||||
# 沿 BELONGS_TO_COMMUNITY 关系拉取关联 Statement 追加到 clean_content
|
|
||||||
_expanded_stmts_to_write = []
|
|
||||||
try:
|
|
||||||
results_dict = raw_results.get('results', {}) if isinstance(raw_results, dict) else {}
|
|
||||||
reranked = results_dict.get('reranked_results', {})
|
|
||||||
community_hits = reranked.get('communities', [])
|
|
||||||
if not community_hits:
|
|
||||||
community_hits = results_dict.get('communities', [])
|
|
||||||
if community_hits:
|
|
||||||
from app.core.memory.agent.services.search_service import expand_communities_to_statements
|
|
||||||
_expanded_stmts_to_write, new_texts = await expand_communities_to_statements(
|
|
||||||
community_results=community_hits,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
existing_content=clean_content,
|
|
||||||
)
|
|
||||||
if new_texts:
|
|
||||||
clean_content = clean_content + '\n' + '\n'.join(new_texts)
|
|
||||||
except Exception as parse_err:
|
|
||||||
logger.warning(f"[Retrieve] 解析社区命中结果失败,跳过展开: {parse_err}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
raw_results = raw_results['results']
|
raw_results = raw_results['results']
|
||||||
# 写回展开结果,接口返回中可见(已在 helper 中清洗过字段)
|
|
||||||
if _expanded_stmts_to_write and isinstance(raw_results, dict):
|
|
||||||
raw_results.setdefault('reranked_results', {})['expanded_statements'] = _expanded_stmts_to_write
|
|
||||||
except Exception:
|
except Exception:
|
||||||
raw_results = []
|
raw_results = []
|
||||||
|
|
||||||
@@ -460,7 +366,7 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# Process all questions concurrently
|
# 并发处理所有问题
|
||||||
import asyncio
|
import asyncio
|
||||||
tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)]
|
tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)]
|
||||||
databases_anser = await asyncio.gather(*tasks)
|
databases_anser = await asyncio.gather(*tasks)
|
||||||
@@ -507,3 +413,5 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
# json.dump(dup_databases, f, indent=4)
|
# json.dump(dup_databases, f, indent=4)
|
||||||
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
||||||
return {'retrieve': dup_databases}
|
return {'retrieve': dup_databases}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,9 @@
|
|||||||
import asyncio
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger, log_time
|
from app.core.logging_config import get_agent_logger, log_time
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
|
|
||||||
PerceptualSearchService,
|
|
||||||
)
|
|
||||||
from app.core.memory.agent.models.summary_models import (
|
from app.core.memory.agent.models.summary_models import (
|
||||||
RetrieveSummaryResponse,
|
RetrieveSummaryResponse,
|
||||||
SummaryResponse,
|
SummaryResponse,
|
||||||
@@ -19,144 +17,34 @@ from app.core.memory.agent.utils.llm_tools import (
|
|||||||
from app.core.memory.agent.utils.redis_tool import store
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
from app.core.memory.agent.utils.session_tools import SessionService
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.memory.enums import Neo4jNodeType
|
from app.db import get_db
|
||||||
from app.core.rag.nlp.search import knowledge_retrieval
|
|
||||||
from app.db import get_db_context
|
|
||||||
|
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
db_session = next(get_db())
|
||||||
|
|
||||||
class SummaryNodeService(LLMServiceMixin):
|
class SummaryNodeService(LLMServiceMixin):
|
||||||
"""
|
"""总结节点服务类"""
|
||||||
Summary node service class
|
|
||||||
|
|
||||||
Handles summary generation operations using LLM services. Inherits from
|
|
||||||
LLMServiceMixin to provide structured LLM calling capabilities for
|
|
||||||
generating summaries from retrieved information.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
template_service: Service for rendering Jinja2 templates
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.template_service = TemplateService(template_root)
|
self.template_service = TemplateService(template_root)
|
||||||
|
|
||||||
|
# 创建全局服务实例
|
||||||
# Create global service instance
|
|
||||||
summary_service = SummaryNodeService()
|
summary_service = SummaryNodeService()
|
||||||
|
|
||||||
|
|
||||||
async def rag_config(state):
|
|
||||||
"""
|
|
||||||
Configure RAG (Retrieval-Augmented Generation) settings for summary operations
|
|
||||||
|
|
||||||
Creates configuration for knowledge base retrieval including similarity thresholds,
|
|
||||||
weights, and reranker settings specifically for summary generation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: Current state containing user_rag_memory_id
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: RAG configuration dictionary with knowledge base settings
|
|
||||||
"""
|
|
||||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
|
||||||
kb_config = {
|
|
||||||
"knowledge_bases": [
|
|
||||||
{
|
|
||||||
"kb_id": user_rag_memory_id,
|
|
||||||
"similarity_threshold": 0.7,
|
|
||||||
"vector_similarity_weight": 0.5,
|
|
||||||
"top_k": 10,
|
|
||||||
"retrieve_type": "participle"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"merge_strategy": "weight",
|
|
||||||
"reranker_id": os.getenv('reranker_id'),
|
|
||||||
"reranker_top_k": 10
|
|
||||||
}
|
|
||||||
return kb_config
|
|
||||||
|
|
||||||
|
|
||||||
async def rag_knowledge(state, question):
|
|
||||||
"""
|
|
||||||
Retrieve knowledge using RAG approach for summary generation
|
|
||||||
|
|
||||||
Performs knowledge retrieval from configured knowledge bases using the
|
|
||||||
provided question and returns formatted results for summary processing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: Current state containing configuration
|
|
||||||
question: Question to search for in knowledge base
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
|
|
||||||
- retrieval_knowledge: List of retrieved knowledge chunks
|
|
||||||
- clean_content: Formatted content string
|
|
||||||
- cleaned_query: Processed query string
|
|
||||||
- raw_results: Raw retrieval results
|
|
||||||
"""
|
|
||||||
kb_config = await rag_config(state)
|
|
||||||
end_user_id = state.get('end_user_id', '')
|
|
||||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
|
||||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
|
||||||
try:
|
|
||||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
|
||||||
clean_content = '\n\n'.join(retrieval_knowledge)
|
|
||||||
cleaned_query = question
|
|
||||||
raw_results = clean_content
|
|
||||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
|
||||||
except Exception:
|
|
||||||
retrieval_knowledge = []
|
|
||||||
clean_content = ''
|
|
||||||
raw_results = ''
|
|
||||||
cleaned_query = question
|
|
||||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
|
||||||
return retrieval_knowledge, clean_content, cleaned_query, raw_results
|
|
||||||
|
|
||||||
|
|
||||||
async def summary_history(state: ReadState) -> ReadState:
|
async def summary_history(state: ReadState) -> ReadState:
|
||||||
"""
|
|
||||||
Retrieve conversation history for summary context
|
|
||||||
|
|
||||||
Gets the conversation history for the current user to provide context
|
|
||||||
for summary generation operations.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: ReadState containing end_user_id
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ReadState: Conversation history data
|
|
||||||
"""
|
|
||||||
end_user_id = state.get("end_user_id", '')
|
end_user_id = state.get("end_user_id", '')
|
||||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||||
return history
|
return history
|
||||||
|
|
||||||
|
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
|
||||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
|
|
||||||
search_mode) -> str:
|
|
||||||
"""
|
"""
|
||||||
Enhanced summary_llm function with better error handling and data validation
|
增强的summary_llm函数,包含更好的错误处理和数据验证
|
||||||
|
|
||||||
Generates summaries using LLM with structured output. Includes fallback mechanisms
|
|
||||||
for handling LLM failures and provides robust error recovery.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: ReadState containing current context
|
|
||||||
history: Conversation history for context
|
|
||||||
retrieve_info: Retrieved information to summarize
|
|
||||||
template_name: Jinja2 template name for prompt generation
|
|
||||||
operation_name: Type of operation (summary, input_summary, retrieve_summary)
|
|
||||||
response_model: Pydantic model for structured output
|
|
||||||
search_mode: Search mode flag ("0" for simple, "1" for complex)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Generated summary text or fallback message
|
|
||||||
"""
|
"""
|
||||||
data = state.get("data", '')
|
data = state.get("data", '')
|
||||||
|
|
||||||
# Build system prompt
|
# 构建系统提示词
|
||||||
if str(search_mode) == "0":
|
if str(search_mode) == "0":
|
||||||
system_prompt = await summary_service.template_service.render_template(
|
system_prompt = await summary_service.template_service.render_template(
|
||||||
template_name=template_name,
|
template_name=template_name,
|
||||||
@@ -173,8 +61,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
retrieve_info=retrieve_info
|
retrieve_info=retrieve_info
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
# Use optimized LLM service for structured output
|
# 使用优化的LLM服务进行结构化输出
|
||||||
with get_db_context() as db_session:
|
|
||||||
structured = await summary_service.call_llm_structured(
|
structured = await summary_service.call_llm_structured(
|
||||||
state=state,
|
state=state,
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
@@ -182,23 +69,23 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
fallback_value=None
|
fallback_value=None
|
||||||
)
|
)
|
||||||
# Validate structured response
|
# 验证结构化响应
|
||||||
if structured is None:
|
if structured is None:
|
||||||
logger.warning("LLM返回None,使用默认回答")
|
logger.warning(f"LLM返回None,使用默认回答")
|
||||||
return "信息不足,无法回答"
|
return "信息不足,无法回答"
|
||||||
|
|
||||||
# Extract answer based on operation type
|
# 根据操作类型提取答案
|
||||||
if operation_name == "summary":
|
if operation_name == "summary":
|
||||||
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
|
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
|
||||||
else:
|
else:
|
||||||
# Handle RetrieveSummaryResponse
|
# 处理RetrieveSummaryResponse
|
||||||
if hasattr(structured, 'data') and structured.data:
|
if hasattr(structured, 'data') and structured.data:
|
||||||
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
|
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
|
||||||
else:
|
else:
|
||||||
logger.warning("结构化响应缺少data字段")
|
logger.warning(f"结构化响应缺少data字段")
|
||||||
aimessages = "信息不足,无法回答"
|
aimessages = "信息不足,无法回答"
|
||||||
|
|
||||||
# Validate answer is not empty
|
# 验证答案不为空
|
||||||
if not aimessages or aimessages.strip() == "":
|
if not aimessages or aimessages.strip() == "":
|
||||||
aimessages = "信息不足,无法回答"
|
aimessages = "信息不足,无法回答"
|
||||||
|
|
||||||
@@ -207,7 +94,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"结构化输出失败: {e}", exc_info=True)
|
logger.error(f"结构化输出失败: {e}", exc_info=True)
|
||||||
|
|
||||||
# Try unstructured output as fallback
|
# 尝试非结构化输出作为fallback
|
||||||
try:
|
try:
|
||||||
logger.info("尝试非结构化输出作为fallback")
|
logger.info("尝试非结构化输出作为fallback")
|
||||||
response = await summary_service.call_llm_simple(
|
response = await summary_service.call_llm_simple(
|
||||||
@@ -218,9 +105,9 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
)
|
)
|
||||||
|
|
||||||
if response and response.strip():
|
if response and response.strip():
|
||||||
# Simple response cleaning
|
# 简单清理响应
|
||||||
cleaned_response = response.strip()
|
cleaned_response = response.strip()
|
||||||
# Remove possible JSON markers
|
# 移除可能的JSON标记
|
||||||
if cleaned_response.startswith('```'):
|
if cleaned_response.startswith('```'):
|
||||||
lines = cleaned_response.split('\n')
|
lines = cleaned_response.split('\n')
|
||||||
cleaned_response = '\n'.join(lines[1:-1])
|
cleaned_response = '\n'.join(lines[1:-1])
|
||||||
@@ -233,21 +120,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
logger.error(f"Fallback也失败: {fallback_error}")
|
logger.error(f"Fallback也失败: {fallback_error}")
|
||||||
return "信息不足,无法回答"
|
return "信息不足,无法回答"
|
||||||
|
|
||||||
|
|
||||||
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
||||||
"""
|
|
||||||
Save summary results to Redis session storage
|
|
||||||
|
|
||||||
Stores the generated summary and user query in Redis for session management
|
|
||||||
and conversation history tracking.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: ReadState containing user and query information
|
|
||||||
aimessages: Generated summary message to save
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ReadState: Updated state after saving to Redis
|
|
||||||
"""
|
|
||||||
data = state.get("data", '')
|
data = state.get("data", '')
|
||||||
end_user_id = state.get("end_user_id", '')
|
end_user_id = state.get("end_user_id", '')
|
||||||
await SessionService(store).save_session(
|
await SessionService(store).save_session(
|
||||||
@@ -259,23 +132,7 @@ async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
|||||||
)
|
)
|
||||||
await SessionService(store).cleanup_duplicates()
|
await SessionService(store).cleanup_duplicates()
|
||||||
logger.info(f"sessionid: {aimessages} 写入成功")
|
logger.info(f"sessionid: {aimessages} 写入成功")
|
||||||
|
|
||||||
|
|
||||||
async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
|
async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
|
||||||
"""
|
|
||||||
Format summary results for different output types
|
|
||||||
|
|
||||||
Creates structured output formats for both input summary and retrieval summary
|
|
||||||
operations, including metadata and intermediate results for frontend display.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: ReadState containing storage and user information
|
|
||||||
aimessages: Generated summary message
|
|
||||||
raw_results: Raw search/retrieval results
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (input_summary, retrieve_summary) formatted result dictionaries
|
|
||||||
"""
|
|
||||||
storage_type=state.get("storage_type",'')
|
storage_type=state.get("storage_type",'')
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||||
data=state.get("data", '')
|
data=state.get("data", '')
|
||||||
@@ -312,21 +169,7 @@ async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState
|
|||||||
|
|
||||||
return input_summary,retrieve
|
return input_summary,retrieve
|
||||||
|
|
||||||
|
|
||||||
async def Input_Summary(state: ReadState) -> ReadState:
|
async def Input_Summary(state: ReadState) -> ReadState:
|
||||||
"""
|
|
||||||
Generate quick input summary from retrieved information
|
|
||||||
|
|
||||||
Performs fast retrieval and generates a quick summary response for user queries.
|
|
||||||
This function prioritizes speed by only searching summary nodes and provides
|
|
||||||
immediate feedback to users.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: ReadState containing user query, storage configuration, and context
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ReadState: Dictionary containing summary results with status and metadata
|
|
||||||
"""
|
|
||||||
start=time.time()
|
start=time.time()
|
||||||
storage_type=state.get("storage_type",'')
|
storage_type=state.get("storage_type",'')
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
@@ -339,61 +182,16 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"question": data,
|
"question": data,
|
||||||
"return_raw_results": True,
|
"return_raw_results": True,
|
||||||
"include": [Neo4jNodeType.MEMORYSUMMARY, Neo4jNodeType.COMMUNITY] # MemorySummary 和 Community 同为高维度概括节点
|
"include": ["summaries"] # Only search summary nodes for faster performance
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if storage_type != "rag":
|
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config)
|
||||||
|
|
||||||
async def _perceptual_search():
|
|
||||||
service = PerceptualSearchService(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
memory_config=memory_config,
|
|
||||||
)
|
|
||||||
return await service.search(query=data, limit=5)
|
|
||||||
|
|
||||||
hybrid_task = SearchService().execute_hybrid_search(
|
|
||||||
**search_params,
|
|
||||||
memory_config=memory_config,
|
|
||||||
expand_communities=False,
|
|
||||||
)
|
|
||||||
perceptual_task = _perceptual_search()
|
|
||||||
|
|
||||||
gather_results = await asyncio.gather(
|
|
||||||
hybrid_task, perceptual_task, return_exceptions=True
|
|
||||||
)
|
|
||||||
hybrid_result = gather_results[0]
|
|
||||||
perceptual_results = gather_results[1]
|
|
||||||
|
|
||||||
# 处理 hybrid search 异常
|
|
||||||
if isinstance(hybrid_result, Exception):
|
|
||||||
raise hybrid_result
|
|
||||||
retrieve_info, question, raw_results = hybrid_result
|
|
||||||
|
|
||||||
# 处理感知记忆结果
|
|
||||||
if isinstance(perceptual_results, Exception):
|
|
||||||
logger.warning(f"[Input_Summary] perceptual search failed: {perceptual_results}")
|
|
||||||
perceptual_results = []
|
|
||||||
|
|
||||||
# 拼接感知记忆内容到 retrieve_info
|
|
||||||
if perceptual_results and isinstance(perceptual_results, dict):
|
|
||||||
perceptual_content = perceptual_results.get("content", "")
|
|
||||||
if perceptual_content:
|
|
||||||
retrieve_info = f"{retrieve_info}\n\n<history-files>\n{perceptual_content}"
|
|
||||||
count = len(perceptual_results.get("memories", []))
|
|
||||||
logger.info(f"[Input_Summary] appended {count} perceptual memories (reranked)")
|
|
||||||
|
|
||||||
# 调试:打印 community 检索结果数量
|
|
||||||
if raw_results and isinstance(raw_results, dict):
|
|
||||||
reranked = raw_results.get('reranked_results', {})
|
|
||||||
community_hits = reranked.get('communities', [])
|
|
||||||
logger.debug(f"[Input_Summary] community 命中数: {len(community_hits)}, "
|
|
||||||
f"summary 命中数: {len(reranked.get('summaries', []))}")
|
|
||||||
else:
|
|
||||||
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True )
|
logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True )
|
||||||
retrieve_info, question, raw_results = "", data, []
|
retrieve_info, question, raw_results = "", data, []
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
|
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
|
||||||
# 'input_summary',RetrieveSummaryResponse)
|
# 'input_summary',RetrieveSummaryResponse)
|
||||||
@@ -410,25 +208,14 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
"error": str(e)
|
"error": str(e)
|
||||||
}
|
}
|
||||||
end = time.time()
|
end = time.time()
|
||||||
|
try:
|
||||||
duration = end - start
|
duration = end - start
|
||||||
|
except Exception:
|
||||||
|
duration = 0.0
|
||||||
log_time('检索', duration)
|
log_time('检索', duration)
|
||||||
return {"summary":summary}
|
return {"summary":summary}
|
||||||
|
|
||||||
|
|
||||||
async def Retrieve_Summary(state: ReadState)-> ReadState:
|
async def Retrieve_Summary(state: ReadState)-> ReadState:
|
||||||
"""
|
|
||||||
Generate comprehensive summary from retrieved expansion issues
|
|
||||||
|
|
||||||
Processes retrieved expansion issues and generates a detailed summary using LLM.
|
|
||||||
This function handles complex retrieval results and provides comprehensive answers
|
|
||||||
based on expanded query results.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: ReadState containing retrieve data with expansion issues
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ReadState: Dictionary containing comprehensive summary results
|
|
||||||
"""
|
|
||||||
retrieve=state.get("retrieve", '')
|
retrieve=state.get("retrieve", '')
|
||||||
history = await summary_history( state)
|
history = await summary_history( state)
|
||||||
import json
|
import json
|
||||||
@@ -448,20 +235,8 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
|
|||||||
retrieve_info_str=list(set(retrieve_info_str))
|
retrieve_info_str=list(set(retrieve_info_str))
|
||||||
retrieve_info_str='\n'.join(retrieve_info_str)
|
retrieve_info_str='\n'.join(retrieve_info_str)
|
||||||
|
|
||||||
# Merge perceptual memory content
|
aimessages=await summary_llm(state,history,retrieve_info_str,
|
||||||
perceptual_data = state.get("perceptual_data", {})
|
'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1")
|
||||||
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
|
||||||
if perceptual_content:
|
|
||||||
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
|
||||||
|
|
||||||
aimessages = await summary_llm(
|
|
||||||
state,
|
|
||||||
history,
|
|
||||||
retrieve_info_str,
|
|
||||||
'direct_summary_prompt.jinja2',
|
|
||||||
'retrieve_summary', RetrieveSummaryResponse,
|
|
||||||
"1"
|
|
||||||
)
|
|
||||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||||
await summary_redis_save(state, aimessages)
|
await summary_redis_save(state, aimessages)
|
||||||
if aimessages == '':
|
if aimessages == '':
|
||||||
@@ -474,26 +249,13 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
|
|||||||
duration = 0.0
|
duration = 0.0
|
||||||
log_time('Retrieval summary', duration)
|
log_time('Retrieval summary', duration)
|
||||||
|
|
||||||
# Fixed coroutine call - await first, then access return value
|
# 修复协程调用 - 先await,然后访问返回值
|
||||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||||
summary = summary_result[1]
|
summary = summary_result[1]
|
||||||
return {"summary":summary}
|
return {"summary":summary}
|
||||||
|
|
||||||
|
|
||||||
async def Summary(state: ReadState)-> ReadState:
|
async def Summary(state: ReadState)-> ReadState:
|
||||||
"""
|
|
||||||
Generate final comprehensive summary from verified data
|
|
||||||
|
|
||||||
Creates the final summary using verified expansion issues and conversation history.
|
|
||||||
This function processes verified data to generate the most comprehensive and
|
|
||||||
accurate response to user queries.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: ReadState containing verified data and query information
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ReadState: Dictionary containing final summary results
|
|
||||||
"""
|
|
||||||
start=time.time()
|
start=time.time()
|
||||||
query = state.get("data", '')
|
query = state.get("data", '')
|
||||||
verify=state.get("verify", '')
|
verify=state.get("verify", '')
|
||||||
@@ -506,12 +268,6 @@ async def Summary(state: ReadState) -> ReadState:
|
|||||||
retrieve_info_str+=i+'\n'
|
retrieve_info_str+=i+'\n'
|
||||||
history=await summary_history(state)
|
history=await summary_history(state)
|
||||||
|
|
||||||
# Merge perceptual memory content
|
|
||||||
perceptual_data = state.get("perceptual_data", {})
|
|
||||||
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
|
||||||
if perceptual_content:
|
|
||||||
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"query": query,
|
"query": query,
|
||||||
"history": history,
|
"history": history,
|
||||||
@@ -530,26 +286,12 @@ async def Summary(state: ReadState) -> ReadState:
|
|||||||
duration = 0.0
|
duration = 0.0
|
||||||
log_time('Retrieval summary', duration)
|
log_time('Retrieval summary', duration)
|
||||||
|
|
||||||
# Fixed coroutine call - await first, then access return value
|
# 修复协程调用 - 先await,然后访问返回值
|
||||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||||
summary = summary_result[1]
|
summary = summary_result[1]
|
||||||
return {"summary":summary}
|
return {"summary":summary}
|
||||||
|
|
||||||
|
|
||||||
async def Summary_fails(state: ReadState)-> ReadState:
|
async def Summary_fails(state: ReadState)-> ReadState:
|
||||||
"""
|
|
||||||
Generate fallback summary when normal summary process fails
|
|
||||||
|
|
||||||
Provides a fallback summary generation mechanism when the standard summary
|
|
||||||
process encounters errors or fails to produce satisfactory results. Uses
|
|
||||||
a specialized failure template to handle edge cases.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: ReadState containing verified data and failure context
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ReadState: Dictionary containing fallback summary results
|
|
||||||
"""
|
|
||||||
storage_type=state.get("storage_type", '')
|
storage_type=state.get("storage_type", '')
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id", '')
|
user_rag_memory_id=state.get("user_rag_memory_id", '')
|
||||||
history = await summary_history(state)
|
history = await summary_history(state)
|
||||||
@@ -562,13 +304,6 @@ async def Summary_fails(state: ReadState) -> ReadState:
|
|||||||
if key == 'answer_small':
|
if key == 'answer_small':
|
||||||
for i in value:
|
for i in value:
|
||||||
retrieve_info_str += i + '\n'
|
retrieve_info_str += i + '\n'
|
||||||
|
|
||||||
# Merge perceptual memory content
|
|
||||||
perceptual_data = state.get("perceptual_data", {})
|
|
||||||
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
|
||||||
if perceptual_content:
|
|
||||||
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"query": query,
|
"query": query,
|
||||||
"history": history,
|
"history": history,
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
import asyncio
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
|
from app.db import get_db
|
||||||
|
|
||||||
from app.core.memory.agent.models.verification_models import VerificationResult
|
from app.core.memory.agent.models.verification_models import VerificationResult
|
||||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
|
||||||
from app.core.memory.agent.utils.llm_tools import (
|
from app.core.memory.agent.utils.llm_tools import (
|
||||||
PROJECT_ROOT_,
|
PROJECT_ROOT_,
|
||||||
ReadState,
|
ReadState,
|
||||||
@@ -11,53 +10,29 @@ from app.core.memory.agent.utils.llm_tools import (
|
|||||||
from app.core.memory.agent.utils.redis_tool import store
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
from app.core.memory.agent.utils.session_tools import SessionService
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.db import get_db_context
|
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||||
|
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
|
db_session = next(get_db())
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class VerificationNodeService(LLMServiceMixin):
|
class VerificationNodeService(LLMServiceMixin):
|
||||||
"""
|
"""验证节点服务类"""
|
||||||
Verification node service class
|
|
||||||
|
|
||||||
Handles data verification operations using LLM services. Inherits from
|
|
||||||
LLMServiceMixin to provide structured LLM calling capabilities for
|
|
||||||
verifying and validating retrieved information.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
template_service: Service for rendering Jinja2 templates
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.template_service = TemplateService(template_root)
|
self.template_service = TemplateService(template_root)
|
||||||
|
|
||||||
|
# 创建全局服务实例
|
||||||
# Create global service instance
|
|
||||||
verification_service = VerificationNodeService()
|
verification_service = VerificationNodeService()
|
||||||
|
|
||||||
|
|
||||||
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||||
"""
|
"""处理验证结果并生成输出格式"""
|
||||||
Process verification results and generate output format
|
|
||||||
|
|
||||||
Transforms VerificationResult objects into structured output format suitable
|
|
||||||
for frontend consumption. Handles conversion of VerificationItem objects to
|
|
||||||
dictionary format and adds metadata for tracking.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: ReadState containing storage and user configuration
|
|
||||||
messages_deal: VerificationResult containing verification outcomes
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Formatted verification result with status and metadata
|
|
||||||
"""
|
|
||||||
storage_type = state.get('storage_type', '')
|
storage_type = state.get('storage_type', '')
|
||||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
data = state.get('data', '')
|
data = state.get('data', '')
|
||||||
|
|
||||||
# Convert VerificationItem objects to dictionary list
|
# 将 VerificationItem 对象转换为字典列表
|
||||||
verified_data = []
|
verified_data = []
|
||||||
if messages_deal.expansion_issue:
|
if messages_deal.expansion_issue:
|
||||||
for item in messages_deal.expansion_issue:
|
for item in messages_deal.expansion_issue:
|
||||||
@@ -83,8 +58,6 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Verify_result
|
return Verify_result
|
||||||
|
|
||||||
|
|
||||||
async def Verify(state: ReadState):
|
async def Verify(state: ReadState):
|
||||||
logger.info("=== Verify 节点开始执行 ===")
|
logger.info("=== Verify 节点开始执行 ===")
|
||||||
try:
|
try:
|
||||||
@@ -98,8 +71,7 @@ async def Verify(state: ReadState):
|
|||||||
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
||||||
|
|
||||||
retrieve = state.get("retrieve", {})
|
retrieve = state.get("retrieve", {})
|
||||||
logger.info(
|
logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
|
||||||
f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
|
|
||||||
|
|
||||||
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
|
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
|
||||||
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
|
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
|
||||||
@@ -111,7 +83,7 @@ async def Verify(state: ReadState):
|
|||||||
|
|
||||||
logger.info("Verify: 开始渲染模板")
|
logger.info("Verify: 开始渲染模板")
|
||||||
|
|
||||||
# Generate JSON schema to guide LLM output format
|
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||||
json_schema = VerificationResult.model_json_schema()
|
json_schema = VerificationResult.model_json_schema()
|
||||||
|
|
||||||
system_prompt = await verification_service.template_service.render_template(
|
system_prompt = await verification_service.template_service.render_template(
|
||||||
@@ -126,10 +98,9 @@ async def Verify(state: ReadState):
|
|||||||
# 使用优化的LLM服务,添加超时保护
|
# 使用优化的LLM服务,添加超时保护
|
||||||
logger.info("Verify: 开始调用 LLM")
|
logger.info("Verify: 开始调用 LLM")
|
||||||
try:
|
try:
|
||||||
# Add asyncio.wait_for timeout wrapper to prevent infinite waiting
|
# 添加 asyncio.wait_for 超时包裹,防止无限等待
|
||||||
# Timeout set to 150 seconds (slightly longer than LLM config's 120 seconds)
|
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
|
||||||
|
import asyncio
|
||||||
with get_db_context() as db_session:
|
|
||||||
structured = await asyncio.wait_for(
|
structured = await asyncio.wait_for(
|
||||||
verification_service.call_llm_structured(
|
verification_service.call_llm_structured(
|
||||||
state=state,
|
state=state,
|
||||||
@@ -144,7 +115,7 @@ async def Verify(state: ReadState):
|
|||||||
"reason": "验证失败或超时"
|
"reason": "验证失败或超时"
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
timeout=150.0 # 150 second timeout
|
timeout=150.0 # 150秒超时
|
||||||
)
|
)
|
||||||
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
|||||||
@@ -0,0 +1,55 @@
|
|||||||
|
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||||
|
from app.core.memory.agent.utils.write_tools import write
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
|
|
||||||
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def write_node(state: WriteState) -> WriteState:
|
||||||
|
"""
|
||||||
|
Write data to the database/file system.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: WriteState containing messages, end_user_id, and memory_config
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Contains 'write_result' with status and data fields
|
||||||
|
"""
|
||||||
|
messages = state.get('messages', [])
|
||||||
|
end_user_id = state.get('end_user_id', '')
|
||||||
|
memory_config = state.get('memory_config', '')
|
||||||
|
|
||||||
|
# Convert LangChain messages to structured format expected by write()
|
||||||
|
structured_messages = []
|
||||||
|
for msg in messages:
|
||||||
|
if hasattr(msg, 'type') and hasattr(msg, 'content'):
|
||||||
|
# Map LangChain message types to role names
|
||||||
|
role = 'user' if msg.type == 'human' else 'assistant' if msg.type == 'ai' else msg.type
|
||||||
|
structured_messages.append({
|
||||||
|
"role": role,
|
||||||
|
"content": msg.content # content is now guaranteed to be a string
|
||||||
|
})
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await write(
|
||||||
|
messages=structured_messages,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
memory_config=memory_config,
|
||||||
|
)
|
||||||
|
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
||||||
|
|
||||||
|
write_result = {
|
||||||
|
"status": "success",
|
||||||
|
"data": structured_messages,
|
||||||
|
"config_id": memory_config.config_id,
|
||||||
|
"config_name": memory_config.config_name,
|
||||||
|
}
|
||||||
|
return {"write_result": write_result}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Data_write failed: {e}", exc_info=True)
|
||||||
|
write_result = {
|
||||||
|
"status": "error",
|
||||||
|
"message": str(e),
|
||||||
|
}
|
||||||
|
return {"write_result": write_result}
|
||||||
@@ -1,20 +1,22 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import logging
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
from langgraph.constants import START, END
|
from langgraph.constants import START, END
|
||||||
from langgraph.graph import StateGraph
|
from langgraph.graph import StateGraph
|
||||||
|
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
|
from app.core.memory.agent.utils.llm_tools import ReadState
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_node
|
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_node
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
|
|
||||||
perceptual_retrieve_node,
|
|
||||||
)
|
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
|
from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
|
||||||
Split_The_Problem,
|
Split_The_Problem,
|
||||||
Problem_Extension,
|
Problem_Extension,
|
||||||
)
|
)
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
|
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
|
||||||
retrieve_nodes,
|
retrieve,
|
||||||
)
|
)
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
|
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
|
||||||
Input_Summary,
|
Input_Summary,
|
||||||
@@ -28,26 +30,12 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
|
|||||||
Retrieve_continue,
|
Retrieve_continue,
|
||||||
Verify_continue,
|
Verify_continue,
|
||||||
)
|
)
|
||||||
from app.core.memory.agent.utils.llm_tools import ReadState
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def make_read_graph():
|
async def make_read_graph():
|
||||||
"""
|
"""创建并返回 LangGraph 工作流"""
|
||||||
Create and return a LangGraph workflow for memory reading operations
|
|
||||||
|
|
||||||
Builds a state graph workflow that handles memory retrieval, problem analysis,
|
|
||||||
verification, and summarization. The workflow includes nodes for content input,
|
|
||||||
problem splitting, retrieval, verification, and various summary operations.
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
StateGraph: Compiled LangGraph workflow for memory reading
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
Exception: If workflow creation fails
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Build workflow graph
|
# Build workflow graph
|
||||||
workflow = StateGraph(ReadState)
|
workflow = StateGraph(ReadState)
|
||||||
@@ -55,34 +43,135 @@ async def make_read_graph():
|
|||||||
workflow.add_node("Split_The_Problem", Split_The_Problem)
|
workflow.add_node("Split_The_Problem", Split_The_Problem)
|
||||||
workflow.add_node("Problem_Extension", Problem_Extension)
|
workflow.add_node("Problem_Extension", Problem_Extension)
|
||||||
workflow.add_node("Input_Summary", Input_Summary)
|
workflow.add_node("Input_Summary", Input_Summary)
|
||||||
workflow.add_node("Retrieve", retrieve_nodes)
|
# workflow.add_node("Retrieve", retrieve_nodes)
|
||||||
# workflow.add_node("Retrieve", retrieve)
|
workflow.add_node("Retrieve", retrieve)
|
||||||
workflow.add_node("Perceptual_Retrieve", perceptual_retrieve_node)
|
|
||||||
workflow.add_node("Verify", Verify)
|
workflow.add_node("Verify", Verify)
|
||||||
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
|
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
|
||||||
workflow.add_node("Summary", Summary)
|
workflow.add_node("Summary", Summary)
|
||||||
workflow.add_node("Summary_fails", Summary_fails)
|
workflow.add_node("Summary_fails", Summary_fails)
|
||||||
|
|
||||||
# Add edges to define workflow flow
|
# 添加边
|
||||||
workflow.add_edge(START, "content_input")
|
workflow.add_edge(START, "content_input")
|
||||||
workflow.add_conditional_edges("content_input", Split_continue)
|
workflow.add_conditional_edges("content_input", Split_continue)
|
||||||
workflow.add_edge("Input_Summary", END)
|
workflow.add_edge("Input_Summary", END)
|
||||||
workflow.add_edge("Split_The_Problem", "Problem_Extension")
|
workflow.add_edge("Split_The_Problem", "Problem_Extension")
|
||||||
# After Problem_Extension, retrieve perceptual memory first, then main Retrieve
|
workflow.add_edge("Problem_Extension", "Retrieve")
|
||||||
workflow.add_edge("Problem_Extension", "Perceptual_Retrieve")
|
|
||||||
workflow.add_edge("Perceptual_Retrieve", "Retrieve")
|
|
||||||
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
||||||
workflow.add_edge("Retrieve_Summary", END)
|
workflow.add_edge("Retrieve_Summary", END)
|
||||||
workflow.add_conditional_edges("Verify", Verify_continue)
|
workflow.add_conditional_edges("Verify", Verify_continue)
|
||||||
workflow.add_edge("Summary_fails", END)
|
workflow.add_edge("Summary_fails", END)
|
||||||
workflow.add_edge("Summary", END)
|
workflow.add_edge("Summary", END)
|
||||||
|
|
||||||
|
|
||||||
|
'''-----'''
|
||||||
# workflow.add_edge("Retrieve", END)
|
# workflow.add_edge("Retrieve", END)
|
||||||
|
|
||||||
# Compile workflow
|
# 编译工作流
|
||||||
graph = workflow.compile()
|
graph = workflow.compile()
|
||||||
yield graph
|
yield graph
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建工作流失败: {e}")
|
print(f"创建工作流失败: {e}")
|
||||||
raise
|
raise
|
||||||
|
finally:
|
||||||
|
print("工作流创建完成")
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""主函数 - 运行工作流"""
|
||||||
|
message = "昨天有什么好看的电影"
|
||||||
|
end_user_id = '88a459f5_text09' # 组ID
|
||||||
|
storage_type = 'neo4j' # 存储类型
|
||||||
|
search_switch = '1' # 搜索开关
|
||||||
|
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
|
||||||
|
|
||||||
|
# 获取数据库会话
|
||||||
|
db_session = next(get_db())
|
||||||
|
config_service = MemoryConfigService(db_session)
|
||||||
|
memory_config = config_service.load_memory_config(
|
||||||
|
config_id=17, # 改为整数
|
||||||
|
service_name="MemoryAgentService"
|
||||||
|
)
|
||||||
|
import time
|
||||||
|
start=time.time()
|
||||||
|
try:
|
||||||
|
async with make_read_graph() as graph:
|
||||||
|
config = {"configurable": {"thread_id": end_user_id}}
|
||||||
|
# 初始状态 - 包含所有必要字段
|
||||||
|
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
|
||||||
|
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
|
||||||
|
# 获取节点更新信息
|
||||||
|
_intermediate_outputs = []
|
||||||
|
summary = ''
|
||||||
|
|
||||||
|
async for update_event in graph.astream(
|
||||||
|
initial_state,
|
||||||
|
stream_mode="updates",
|
||||||
|
config=config
|
||||||
|
):
|
||||||
|
for node_name, node_data in update_event.items():
|
||||||
|
print(f"处理节点: {node_name}")
|
||||||
|
|
||||||
|
# 处理不同Summary节点的返回结构
|
||||||
|
if 'Summary' in node_name:
|
||||||
|
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
|
||||||
|
summary = node_data['InputSummary']['summary_result']
|
||||||
|
elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']:
|
||||||
|
summary = node_data['RetrieveSummary']['summary_result']
|
||||||
|
elif 'summary' in node_data and 'summary_result' in node_data['summary']:
|
||||||
|
summary = node_data['summary']['summary_result']
|
||||||
|
elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']:
|
||||||
|
summary = node_data['SummaryFails']['summary_result']
|
||||||
|
|
||||||
|
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
|
||||||
|
if spit_data and spit_data != [] and spit_data != {}:
|
||||||
|
_intermediate_outputs.append(spit_data)
|
||||||
|
|
||||||
|
# Problem_Extension 节点
|
||||||
|
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
|
||||||
|
if problem_extension and problem_extension != [] and problem_extension != {}:
|
||||||
|
_intermediate_outputs.append(problem_extension)
|
||||||
|
|
||||||
|
# Retrieve 节点
|
||||||
|
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
|
||||||
|
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
|
||||||
|
_intermediate_outputs.extend(retrieve_node)
|
||||||
|
|
||||||
|
# Verify 节点
|
||||||
|
verify_n = node_data.get('verify', {}).get('_intermediate', None)
|
||||||
|
if verify_n and verify_n != [] and verify_n != {}:
|
||||||
|
_intermediate_outputs.append(verify_n)
|
||||||
|
|
||||||
|
|
||||||
|
# Summary 节点
|
||||||
|
summary_n = node_data.get('summary', {}).get('_intermediate', None)
|
||||||
|
if summary_n and summary_n != [] and summary_n != {}:
|
||||||
|
_intermediate_outputs.append(summary_n)
|
||||||
|
|
||||||
|
# # 过滤掉空值
|
||||||
|
# _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
|
||||||
|
#
|
||||||
|
# # 优化搜索结果
|
||||||
|
# print("=== 开始优化搜索结果 ===")
|
||||||
|
# optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
|
||||||
|
# result=reorder_output_results(optimized_outputs)
|
||||||
|
# # 保存优化后的结果到文件
|
||||||
|
# with open('_intermediate_outputs_optimized.json', 'w', encoding='utf-8') as f:
|
||||||
|
# import json
|
||||||
|
# f.write(json.dumps(result, indent=4, ensure_ascii=False))
|
||||||
|
#
|
||||||
|
print(f"=== 最终摘要 ===")
|
||||||
|
print(summary)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
end=time.time()
|
||||||
|
print(100*'y')
|
||||||
|
print(f"总耗时: {end-start}s")
|
||||||
|
print(100*'y')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
asyncio.run(main())
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
|
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
||||||
|
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
counter = COUNTState(limit=3)
|
counter = COUNTState(limit=3)
|
||||||
|
|
||||||
|
|
||||||
def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
|
def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||||
"""
|
"""
|
||||||
Determine routing based on search_switch value.
|
Determine routing based on search_switch value.
|
||||||
@@ -25,7 +25,6 @@ def Split_continue(state: ReadState) -> Literal["Split_The_Problem", "Input_Summ
|
|||||||
return 'Input_Summary'
|
return 'Input_Summary'
|
||||||
return 'Split_The_Problem' # 默认情况
|
return 'Split_The_Problem' # 默认情况
|
||||||
|
|
||||||
|
|
||||||
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||||
"""
|
"""
|
||||||
Determine routing based on search_switch value.
|
Determine routing based on search_switch value.
|
||||||
@@ -44,8 +43,6 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
|||||||
elif search_switch == '1':
|
elif search_switch == '1':
|
||||||
return 'Retrieve_Summary'
|
return 'Retrieve_Summary'
|
||||||
return 'Retrieve_Summary' # Default based on business logic
|
return 'Retrieve_Summary' # Default based on business logic
|
||||||
|
|
||||||
|
|
||||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||||
status=state.get('verify', '')['status']
|
status=state.get('verify', '')['status']
|
||||||
# loop_count = counter.get_total()
|
# loop_count = counter.get_total()
|
||||||
|
|||||||
@@ -1,244 +1,184 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from app.celery_task_scheduler import scheduler
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||||
|
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
|
||||||
|
|
||||||
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
||||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||||
from app.core.memory.agent.utils.redis_tool import count_store
|
|
||||||
from app.core.memory.agent.utils.redis_tool import write_store
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
|
from app.core.memory.agent.utils.redis_tool import count_store
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context, get_db
|
||||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||||
|
from app.services.memory_konwledges_server import write_rag
|
||||||
|
from app.services.task_service import get_task_memory_write_result
|
||||||
|
from app.tasks import write_message_task
|
||||||
from app.utils.config_utils import resolve_config_id
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
|
|
||||||
|
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
||||||
async def write(
|
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||||
storage_type,
|
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||||
end_user_id,
|
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||||
user_message,
|
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||||
ai_message,
|
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
|
||||||
user_rag_memory_id,
|
actual_config_id, long_term_messages=[]):
|
||||||
actual_end_user_id,
|
|
||||||
actual_config_id,
|
|
||||||
long_term_messages=None
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Write memory with structured message support
|
写入记忆(支持结构化消息)
|
||||||
|
|
||||||
Handles memory writing operations for different storage types (Neo4j/RAG).
|
|
||||||
Supports both individual message pairs and batch long-term message processing.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
storage_type: Storage type identifier ("neo4j" or "rag")
|
storage_type: 存储类型 (neo4j/rag)
|
||||||
end_user_id: Terminal user identifier
|
end_user_id: 终端用户ID
|
||||||
user_message: User message content
|
user_message: 用户消息内容
|
||||||
ai_message: AI response content
|
ai_message: AI 回复内容
|
||||||
user_rag_memory_id: RAG memory identifier
|
user_rag_memory_id: RAG 记忆ID
|
||||||
actual_end_user_id: Actual user identifier for storage
|
actual_end_user_id: 实际用户ID
|
||||||
actual_config_id: Configuration identifier
|
actual_config_id: 配置ID
|
||||||
long_term_messages: Optional list of structured messages for batch processing
|
|
||||||
|
|
||||||
Logic explanation:
|
逻辑说明:
|
||||||
- RAG mode: Combines user_message and ai_message into string format, maintains original logic
|
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||||
- Neo4j mode: Uses structured message lists
|
- Neo4j 模式:使用结构化消息列表
|
||||||
1. If both user_message and ai_message are not empty: Creates paired messages [user, assistant]
|
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||||
2. If only user_message exists: Creates single user message [user] (for historical memory scenarios)
|
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||||
3. Each message is converted to independent Chunk, preserving speaker field
|
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if long_term_messages is None:
|
db = next(get_db())
|
||||||
long_term_messages = []
|
try:
|
||||||
with get_db_context() as db:
|
|
||||||
actual_config_id = resolve_config_id(actual_config_id, db)
|
actual_config_id = resolve_config_id(actual_config_id, db)
|
||||||
# Neo4j mode: Use structured message lists
|
# Neo4j 模式:使用结构化消息列表
|
||||||
structured_messages = []
|
structured_messages = []
|
||||||
|
|
||||||
# Always add user message (if not empty)
|
# 始终添加用户消息(如果不为空)
|
||||||
if isinstance(user_message, str) and user_message.strip() != "":
|
if isinstance(user_message, str) and user_message.strip() != "":
|
||||||
structured_messages.append({"role": "user", "content": user_message})
|
structured_messages.append({"role": "user", "content": user_message})
|
||||||
|
|
||||||
# Only add assistant message when AI reply is not empty
|
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||||
if isinstance(ai_message, str) and ai_message.strip() != "":
|
if isinstance(ai_message, str) and ai_message.strip() != "":
|
||||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||||
|
|
||||||
# If long_term_messages provided, use it to replace structured_messages
|
# 如果提供了 long_term_messages,使用它替代 structured_messages
|
||||||
if long_term_messages and isinstance(long_term_messages, list):
|
if long_term_messages and isinstance(long_term_messages, list):
|
||||||
structured_messages = long_term_messages
|
structured_messages = long_term_messages
|
||||||
elif long_term_messages and isinstance(long_term_messages, str):
|
elif long_term_messages and isinstance(long_term_messages, str):
|
||||||
# If it's a JSON string, parse it first
|
# 如果是 JSON 字符串,先解析
|
||||||
try:
|
try:
|
||||||
structured_messages = json.loads(long_term_messages)
|
structured_messages = json.loads(long_term_messages)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
|
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
|
||||||
|
|
||||||
# If no messages, return directly
|
# 如果没有消息,直接返回
|
||||||
if not structured_messages:
|
if not structured_messages:
|
||||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||||
# write_id = write_message_task.delay(
|
write_id = write_message_task.delay(
|
||||||
# actual_end_user_id, # end_user_id: User ID
|
actual_end_user_id, # end_user_id: 用户ID
|
||||||
# structured_messages, # message: JSON string format message list
|
structured_messages, # message: JSON 字符串格式的消息列表
|
||||||
# str(actual_config_id), # config_id: Configuration ID string
|
str(actual_config_id), # config_id: 配置ID字符串
|
||||||
# storage_type, # storage_type: "neo4j"
|
storage_type, # storage_type: "neo4j"
|
||||||
# user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||||
# )
|
|
||||||
scheduler.push_task(
|
|
||||||
"app.core.memory.agent.write_message",
|
|
||||||
str(actual_end_user_id),
|
|
||||||
{
|
|
||||||
"end_user_id": str(actual_end_user_id),
|
|
||||||
"message": structured_messages,
|
|
||||||
"config_id": str(actual_config_id),
|
|
||||||
"storage_type": storage_type,
|
|
||||||
"user_rag_memory_id": user_rag_memory_id or ""
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||||
|
write_status = get_task_memory_write_result(str(write_id))
|
||||||
|
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
# logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
|
||||||
# write_status = get_task_memory_write_result(str(write_id))
|
|
||||||
# logger.info(f'[WRITE] Task result - user={actual_end_user_id}')
|
|
||||||
|
|
||||||
|
|
||||||
async def term_memory_save(end_user_id, strategy_type, scope):
|
|
||||||
"""
|
|
||||||
Save long-term memory data to database
|
|
||||||
|
|
||||||
Handles the storage of long-term memory data based on different strategies
|
|
||||||
(chunk-based or aggregate-based) and manages the transition from short-term
|
|
||||||
to long-term memory storage.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: User identifier for memory association
|
|
||||||
strategy_type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
|
|
||||||
scope: Scope/window size for memory processing
|
|
||||||
"""
|
|
||||||
with get_db_context() as db_session:
|
with get_db_context() as db_session:
|
||||||
repo = LongTermMemoryRepository(db_session)
|
repo = LongTermMemoryRepository(db_session)
|
||||||
|
|
||||||
|
|
||||||
from app.core.memory.agent.utils.redis_tool import write_store
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
result = write_store.get_session_by_userid(end_user_id)
|
result = write_store.get_session_by_userid(end_user_id)
|
||||||
if not result:
|
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||||
logger.warning(f"No write data found for user {end_user_id}")
|
|
||||||
return
|
|
||||||
if strategy_type in [AgentMemory_Long_Term.STRATEGY_CHUNK, AgentMemory_Long_Term.STRATEGY_AGGREGATE]:
|
|
||||||
data = await format_parsing(result, "dict")
|
data = await format_parsing(result, "dict")
|
||||||
chunk_data = data[:scope]
|
chunk_data = data[:scope]
|
||||||
if len(chunk_data)==scope:
|
if len(chunk_data)==scope:
|
||||||
repo.upsert(end_user_id, chunk_data)
|
repo.upsert(end_user_id, chunk_data)
|
||||||
logger.info('---------写入短长期-----------')
|
logger.info(f'---------写入短长期-----------')
|
||||||
else:
|
else:
|
||||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
||||||
long_messages = await messages_parse(long_time_data)
|
long_messages = await messages_parse(long_time_data)
|
||||||
repo.upsert(end_user_id, long_messages)
|
repo.upsert(end_user_id, long_messages)
|
||||||
logger.info('写入短长期:')
|
logger.info(f'写入短长期:')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
'''根据窗口'''
|
||||||
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||||
"""
|
'''
|
||||||
TODO 考虑作为滑动窗口写入的函数
|
根据窗口获取redis数据,写入neo4j:
|
||||||
Process dialogue based on window size and write to Neo4j
|
|
||||||
|
|
||||||
Manages conversation data based on a sliding window approach. When the window
|
|
||||||
reaches the specified scope size, it triggers long-term memory storage to Neo4j.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id: Terminal user identifier
|
end_user_id: 终端用户ID
|
||||||
memory_config: Memory configuration object containing settings
|
memory_config: 内存配置对象
|
||||||
langchain_messages: Original message data list
|
langchain_messages:原始数据LIST
|
||||||
scope: Window size determining when to trigger long-term storage
|
scope:窗口大小
|
||||||
"""
|
'''
|
||||||
is_end_user_has_history = count_store.get_sessions_count(end_user_id)
|
scope=scope
|
||||||
if is_end_user_has_history:
|
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||||
end_user_visit_count, redis_messages = is_end_user_has_history
|
if is_end_user_id is not False:
|
||||||
else:
|
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
redis_messages = count_store.get_sessions_count(end_user_id)[1]
|
||||||
return
|
if is_end_user_id and int(is_end_user_id) != int(scope):
|
||||||
end_user_visit_count += 1
|
is_end_user_id += 1
|
||||||
if end_user_visit_count < scope:
|
langchain_messages += redis_messages
|
||||||
redis_messages.extend(langchain_messages)
|
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
||||||
count_store.update_sessions_count(end_user_id, end_user_visit_count, redis_messages)
|
elif int(is_end_user_id) == int(scope):
|
||||||
else:
|
|
||||||
logger.info('写入长期记忆NEO4J')
|
logger.info('写入长期记忆NEO4J')
|
||||||
redis_messages.extend(langchain_messages)
|
formatted_messages = (redis_messages)
|
||||||
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
|
# 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用)
|
||||||
if hasattr(memory_config, 'config_id'):
|
if hasattr(memory_config, 'config_id'):
|
||||||
config_id = memory_config.config_id
|
config_id = memory_config.config_id
|
||||||
else:
|
else:
|
||||||
config_id = memory_config
|
config_id = memory_config
|
||||||
|
|
||||||
scheduler.push_task(
|
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||||
"app.core.memory.agent.write_message",
|
config_id, formatted_messages)
|
||||||
str(end_user_id),
|
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
||||||
{
|
else:
|
||||||
"end_user_id": str(end_user_id),
|
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||||
"message": redis_messages,
|
|
||||||
"config_id": str(config_id),
|
|
||||||
"storage_type": AgentMemory_Long_Term.STORAGE_NEO4J,
|
|
||||||
"user_rag_memory_id": ""
|
|
||||||
}
|
|
||||||
)
|
|
||||||
# write_message_task.delay(
|
|
||||||
# end_user_id, # end_user_id: User ID
|
|
||||||
# redis_messages, # message: JSON string format message list
|
|
||||||
# config_id, # config_id: Configuration ID string
|
|
||||||
# AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
|
|
||||||
# "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
|
||||||
# )
|
|
||||||
count_store.update_sessions_count(end_user_id, 0, [])
|
|
||||||
|
|
||||||
|
|
||||||
|
"""根据时间"""
|
||||||
async def memory_long_term_storage(end_user_id,memory_config,time):
|
async def memory_long_term_storage(end_user_id,memory_config,time):
|
||||||
"""
|
'''
|
||||||
Process memory storage based on time intervals and write to Neo4j
|
根据时间获取redis数据,写入neo4j:
|
||||||
|
|
||||||
Retrieves Redis data based on time intervals and writes it to Neo4j for
|
|
||||||
long-term storage. This function handles time-based memory consolidation.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id: Terminal user identifier
|
end_user_id: 终端用户ID
|
||||||
memory_config: Memory configuration object containing settings
|
memory_config: 内存配置对象
|
||||||
time: Time interval for data retrieval
|
'''
|
||||||
"""
|
|
||||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
||||||
format_messages = long_time_data
|
format_messages = (long_time_data)
|
||||||
messages=[]
|
messages=[]
|
||||||
memory_config=memory_config.config_id
|
memory_config=memory_config.config_id
|
||||||
for i in format_messages:
|
for i in format_messages:
|
||||||
message=json.loads(i['Query'])
|
message=json.loads(i['Query'])
|
||||||
messages+= message
|
messages+= message
|
||||||
if format_messages:
|
if format_messages!=[]:
|
||||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||||
memory_config, messages)
|
memory_config, messages)
|
||||||
|
'''聚合判断'''
|
||||||
|
|
||||||
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
||||||
"""
|
"""
|
||||||
Aggregation judgment function: determine if input sentence and historical messages describe the same event
|
聚合判断函数:判断输入句子和历史消息是否描述同一事件
|
||||||
|
|
||||||
Uses LLM-based analysis to determine whether new messages should be aggregated with existing
|
|
||||||
historical data or stored as separate events. This helps optimize memory storage and retrieval.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id: Terminal user identifier
|
end_user_id: 终端用户ID
|
||||||
ori_messages: Original message list, format like [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||||
memory_config: Memory configuration object containing LLM settings
|
memory_config: 内存配置对象
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Aggregation judgment result containing is_same_event flag and processed output
|
|
||||||
"""
|
"""
|
||||||
history = None
|
|
||||||
try:
|
try:
|
||||||
# 1. Get historical session data (using new method)
|
# 1. 获取历史会话数据(使用新方法)
|
||||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
||||||
history = await format_parsing(result)
|
history = await format_parsing(result)
|
||||||
if not result:
|
if not result:
|
||||||
@@ -285,7 +225,9 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
|||||||
return result_dict
|
return result_dict
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[aggregate_judgment] 发生错误: {e}", exc_info=True)
|
print(f"[aggregate_judgment] 发生错误: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"is_same_event": False,
|
"is_same_event": False,
|
||||||
|
|||||||
@@ -2,53 +2,41 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
|
||||||
from langchain.tools import tool
|
from langchain.tools import tool
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
from app.core.memory.src.search import (
|
from app.core.memory.src.search import (
|
||||||
search_by_temporal,
|
search_by_temporal,
|
||||||
search_by_keyword_temporal,
|
search_by_keyword_temporal,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def extract_tool_message_content(response):
|
def extract_tool_message_content(response):
|
||||||
"""
|
"""从agent响应中提取ToolMessage内容和工具名称"""
|
||||||
Extract ToolMessage content and tool names from agent response
|
|
||||||
|
|
||||||
Parses agent response messages to extract tool execution results and metadata.
|
|
||||||
Handles JSON parsing and provides structured access to tool output data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response: Agent response dictionary containing messages
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Dictionary containing tool_name and parsed content, or None if no tool message found
|
|
||||||
- tool_name: Name of the executed tool
|
|
||||||
- content: Parsed tool execution result (JSON or raw text)
|
|
||||||
"""
|
|
||||||
messages = response.get('messages', [])
|
messages = response.get('messages', [])
|
||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if hasattr(message, 'tool_call_id') and hasattr(message, 'content'):
|
if hasattr(message, 'tool_call_id') and hasattr(message, 'content'):
|
||||||
# This is a ToolMessage
|
# 这是一个ToolMessage
|
||||||
tool_content = message.content
|
tool_content = message.content
|
||||||
tool_name = None
|
tool_name = None
|
||||||
|
|
||||||
# Try to get tool name
|
# 尝试获取工具名称
|
||||||
if hasattr(message, 'name'):
|
if hasattr(message, 'name'):
|
||||||
tool_name = message.name
|
tool_name = message.name
|
||||||
elif hasattr(message, 'tool_name'):
|
elif hasattr(message, 'tool_name'):
|
||||||
tool_name = message.tool_name
|
tool_name = message.tool_name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Parse JSON content
|
# 解析JSON内容
|
||||||
parsed_content = json.loads(tool_content)
|
parsed_content = json.loads(tool_content)
|
||||||
return {
|
return {
|
||||||
'tool_name': tool_name,
|
'tool_name': tool_name,
|
||||||
'content': parsed_content
|
'content': parsed_content
|
||||||
}
|
}
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# If not JSON format, return content directly
|
# 如果不是JSON格式,直接返回内容
|
||||||
return {
|
return {
|
||||||
'tool_name': tool_name,
|
'tool_name': tool_name,
|
||||||
'content': tool_content
|
'content': tool_content
|
||||||
@@ -58,49 +46,26 @@ def extract_tool_message_content(response):
|
|||||||
|
|
||||||
|
|
||||||
class TimeRetrievalInput(BaseModel):
|
class TimeRetrievalInput(BaseModel):
|
||||||
"""
|
"""时间检索工具的输入模式"""
|
||||||
Input schema for time retrieval tool
|
|
||||||
|
|
||||||
Defines the expected input parameters for time-based retrieval operations.
|
|
||||||
Used for validation and documentation of tool parameters.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
context: User input query content for search
|
|
||||||
end_user_id: Group ID for filtering search results, defaults to test user
|
|
||||||
"""
|
|
||||||
context: str = Field(description="用户输入的查询内容")
|
context: str = Field(description="用户输入的查询内容")
|
||||||
end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
||||||
|
|
||||||
|
|
||||||
def create_time_retrieval_tool(end_user_id: str):
|
def create_time_retrieval_tool(end_user_id: str):
|
||||||
"""
|
"""
|
||||||
Create a TimeRetrieval tool with specific end_user_id (synchronous version) for searching statements by time range
|
创建一个带有特定end_user_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
||||||
|
|
||||||
Creates a specialized time-based retrieval tool that searches for statements within
|
|
||||||
specified time ranges. Includes field cleaning functionality to remove unnecessary
|
|
||||||
metadata from search results.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: User identifier for scoping search results
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
function: Configured TimeRetrievalWithGroupId tool function
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def clean_temporal_result_fields(data):
|
def clean_temporal_result_fields(data):
|
||||||
"""
|
"""
|
||||||
Clean unnecessary fields from temporal search results and modify structure
|
清理时间搜索结果中不需要的字段,并修改结构
|
||||||
|
|
||||||
Removes metadata fields that are not needed for end-user consumption and
|
|
||||||
restructures the response format for better usability.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: Data to be cleaned (dict, list, or other types)
|
data: 要清理的数据
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Cleaned data with unnecessary fields removed
|
清理后的数据
|
||||||
"""
|
"""
|
||||||
# List of fields to filter out
|
# 需要过滤的字段列表
|
||||||
fields_to_remove = {
|
fields_to_remove = {
|
||||||
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
|
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
|
||||||
'valid_at', 'invalid_at', 'statement_ids'
|
'valid_at', 'invalid_at', 'statement_ids'
|
||||||
@@ -110,9 +75,9 @@ def create_time_retrieval_tool(end_user_id: str):
|
|||||||
cleaned = {}
|
cleaned = {}
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
if key == 'statements' and isinstance(value, dict) and 'statements' in value:
|
if key == 'statements' and isinstance(value, dict) and 'statements' in value:
|
||||||
# Change statements: {"statements": [...]} to time_search: {"statements": [...]}
|
# 将 statements: {"statements": [...]} 改为 time_search: {"statements": [...]}
|
||||||
cleaned_value = clean_temporal_result_fields(value)
|
cleaned_value = clean_temporal_result_fields(value)
|
||||||
# Further change internal statements to time_search
|
# 进一步将内部的 statements 改为 time_search
|
||||||
if 'statements' in cleaned_value:
|
if 'statements' in cleaned_value:
|
||||||
cleaned['results'] = {
|
cleaned['results'] = {
|
||||||
'time_search': cleaned_value['statements']
|
'time_search': cleaned_value['statements']
|
||||||
@@ -128,33 +93,24 @@ def create_time_retrieval_tool(end_user_id: str):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None,
|
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||||
end_user_id_param: str = None, clean_output: bool = True) -> str:
|
|
||||||
"""
|
"""
|
||||||
Optimized time retrieval tool, combines time range search only (synchronous version), automatically filters unnecessary metadata fields
|
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||||
|
显式接收参数:
|
||||||
Performs time-based search operations with automatic metadata filtering. Supports
|
- context: 查询上下文内容
|
||||||
flexible date range specification and provides clean, user-friendly output.
|
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||||
|
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||||
Explicit parameters:
|
- end_user_id_param: 组ID(可选,用于覆盖默认组ID)
|
||||||
- context: Query context content
|
- clean_output: 是否清理输出中的元数据字段
|
||||||
- start_date: Start time (optional, format: YYYY-MM-DD)
|
-end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||||
- end_date: End time (optional, format: YYYY-MM-DD)
|
|
||||||
- end_user_id_param: Group ID (optional, overrides default group ID)
|
|
||||||
- clean_output: Whether to clean metadata fields from output
|
|
||||||
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: JSON formatted search results with temporal data
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _async_search():
|
async def _async_search():
|
||||||
# Use passed parameters or default values
|
# 使用传入的参数或默认值
|
||||||
actual_end_user_id = end_user_id_param or end_user_id
|
actual_end_user_id = end_user_id_param or end_user_id
|
||||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||||
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
|
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||||
|
|
||||||
# Basic time search
|
# 基本时间搜索
|
||||||
results = await search_by_temporal(
|
results = await search_by_temporal(
|
||||||
end_user_id=actual_end_user_id,
|
end_user_id=actual_end_user_id,
|
||||||
start_date=actual_start_date,
|
start_date=actual_start_date,
|
||||||
@@ -162,7 +118,7 @@ def create_time_retrieval_tool(end_user_id: str):
|
|||||||
limit=10
|
limit=10
|
||||||
)
|
)
|
||||||
|
|
||||||
# Clean unnecessary fields from results
|
# 清理结果中不需要的字段
|
||||||
if clean_output:
|
if clean_output:
|
||||||
cleaned_results = clean_temporal_result_fields(results)
|
cleaned_results = clean_temporal_result_fields(results)
|
||||||
else:
|
else:
|
||||||
@@ -173,32 +129,22 @@ def create_time_retrieval_tool(end_user_id: str):
|
|||||||
return asyncio.run(_async_search())
|
return asyncio.run(_async_search())
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None,
|
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str:
|
||||||
clean_output: bool = True) -> str:
|
|
||||||
"""
|
"""
|
||||||
Optimized keyword time retrieval tool, combines keyword and time range search (synchronous version), automatically filters unnecessary metadata fields
|
优化的关键词时间检索工具,结合关键词和时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||||
|
显式接收参数:
|
||||||
Performs combined keyword and temporal search operations with automatic metadata
|
- context: 查询内容
|
||||||
filtering. Provides more targeted search results by combining content relevance
|
- days_back: 向前搜索的天数,默认7天
|
||||||
with time-based filtering.
|
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||||
|
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||||
Explicit parameters:
|
- clean_output: 是否清理输出中的元数据字段
|
||||||
- context: Query content for keyword matching
|
- end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||||
- days_back: Number of days to search backwards, default 7 days
|
|
||||||
- start_date: Start time (optional, format: YYYY-MM-DD)
|
|
||||||
- end_date: End time (optional, format: YYYY-MM-DD)
|
|
||||||
- clean_output: Whether to clean metadata fields from output
|
|
||||||
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: JSON formatted search results combining keyword and temporal data
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _async_search():
|
async def _async_search():
|
||||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||||
actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
|
actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
|
||||||
|
|
||||||
# Keyword time search
|
# 关键词时间搜索
|
||||||
results = await search_by_keyword_temporal(
|
results = await search_by_keyword_temporal(
|
||||||
query_text=context,
|
query_text=context,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
@@ -207,7 +153,7 @@ def create_time_retrieval_tool(end_user_id: str):
|
|||||||
limit=15
|
limit=15
|
||||||
)
|
)
|
||||||
|
|
||||||
# Clean unnecessary fields from results
|
# 清理结果中不需要的字段
|
||||||
if clean_output:
|
if clean_output:
|
||||||
cleaned_results = clean_temporal_result_fields(results)
|
cleaned_results = clean_temporal_result_fields(results)
|
||||||
else:
|
else:
|
||||||
@@ -222,53 +168,43 @@ def create_time_retrieval_tool(end_user_id: str):
|
|||||||
|
|
||||||
def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||||
"""
|
"""
|
||||||
Create hybrid retrieval tool using run_hybrid_search for hybrid retrieval, optimize output format and filter unnecessary fields
|
创建混合检索工具,使用run_hybrid_search进行混合检索,优化输出格式并过滤不需要的字段
|
||||||
|
|
||||||
Creates an advanced hybrid search tool that combines multiple search strategies
|
|
||||||
(keyword, vector, hybrid) with automatic result cleaning and formatting.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_config: Memory configuration object containing LLM and search settings
|
memory_config: 内存配置对象
|
||||||
**search_params: Search parameters including end_user_id, limit, include, etc.
|
**search_params: 搜索参数,包含end_user_id, limit, include等
|
||||||
|
|
||||||
Returns:
|
|
||||||
function: Configured HybridSearch tool function with async capabilities
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def clean_result_fields(data):
|
def clean_result_fields(data):
|
||||||
"""
|
"""
|
||||||
Recursively clean unnecessary fields from results
|
递归清理结果中不需要的字段
|
||||||
|
|
||||||
Removes metadata fields that are not needed for end-user consumption,
|
|
||||||
improving readability and reducing response size.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: Data to be cleaned (can be dict, list, or other types)
|
data: 要清理的数据(可能是字典、列表或其他类型)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Cleaned data with unnecessary fields removed
|
清理后的数据
|
||||||
"""
|
"""
|
||||||
# List of fields to filter out
|
# 需要过滤的字段列表
|
||||||
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
fields_to_remove = {
|
fields_to_remove = {
|
||||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||||
'created_at', 'chunk_id', 'apply_id',
|
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
||||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
|
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
|
||||||
}
|
}
|
||||||
# 注意:'id' 字段保留,community 展开时需要用 community id 查询成员 statements
|
|
||||||
|
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
# Clean dictionary
|
# 对字典进行清理
|
||||||
cleaned = {}
|
cleaned = {}
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
if key not in fields_to_remove:
|
if key not in fields_to_remove:
|
||||||
cleaned[key] = clean_result_fields(value) # Recursively clean nested data
|
cleaned[key] = clean_result_fields(value) # 递归清理嵌套数据
|
||||||
return cleaned
|
return cleaned
|
||||||
elif isinstance(data, list):
|
elif isinstance(data, list):
|
||||||
# Clean each element in list
|
# 对列表中的每个元素进行清理
|
||||||
return [clean_result_fields(item) for item in data]
|
return [clean_result_fields(item) for item in data]
|
||||||
else:
|
else:
|
||||||
# Return other types directly
|
# 其他类型直接返回
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -280,55 +216,49 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
|||||||
rerank_alpha: float = 0.6,
|
rerank_alpha: float = 0.6,
|
||||||
use_forgetting_rerank: bool = False,
|
use_forgetting_rerank: bool = False,
|
||||||
use_llm_rerank: bool = False,
|
use_llm_rerank: bool = False,
|
||||||
clean_output: bool = True # New: whether to clean output fields
|
clean_output: bool = True # 新增:是否清理输出字段
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Optimized hybrid retrieval tool, supports keyword, vector and hybrid search, automatically filters unnecessary metadata fields
|
优化的混合检索工具,支持关键词、向量和混合搜索,自动过滤不需要的元数据字段
|
||||||
|
|
||||||
Provides comprehensive search capabilities combining multiple search strategies
|
|
||||||
with intelligent result ranking and automatic metadata filtering for clean output.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
context: Query content for search
|
context: 查询内容
|
||||||
search_type: Search type ('keyword', 'embedding', 'hybrid')
|
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||||
limit: Result quantity limit
|
limit: 结果数量限制
|
||||||
end_user_id: Group ID for filtering search results
|
end_user_id: 组ID,用于过滤搜索结果
|
||||||
rerank_alpha: Reranking weight parameter for result scoring
|
rerank_alpha: 重排序权重参数
|
||||||
use_forgetting_rerank: Whether to use forgetting-based reranking
|
use_forgetting_rerank: 是否使用遗忘重排序
|
||||||
use_llm_rerank: Whether to use LLM-based reranking
|
use_llm_rerank: 是否使用LLM重排序
|
||||||
clean_output: Whether to clean metadata fields from output
|
clean_output: 是否清理输出中的元数据字段
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: JSON formatted comprehensive search results
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Import run_hybrid_search function
|
# 导入run_hybrid_search函数
|
||||||
from app.core.memory.src.search import run_hybrid_search
|
from app.core.memory.src.search import run_hybrid_search
|
||||||
|
|
||||||
# Merge parameters, prioritize passed parameters
|
# 合并参数,优先使用传入的参数
|
||||||
final_params = {
|
final_params = {
|
||||||
"query_text": context,
|
"query_text": context,
|
||||||
"search_type": search_type,
|
"search_type": search_type,
|
||||||
"end_user_id": end_user_id or search_params.get("end_user_id"),
|
"end_user_id": end_user_id or search_params.get("end_user_id"),
|
||||||
"limit": limit or search_params.get("limit", 10),
|
"limit": limit or search_params.get("limit", 10),
|
||||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities", "communities"]),
|
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
|
||||||
"output_path": None, # Don't save to file
|
"output_path": None, # 不保存到文件
|
||||||
"memory_config": memory_config,
|
"memory_config": memory_config,
|
||||||
"rerank_alpha": rerank_alpha,
|
"rerank_alpha": rerank_alpha,
|
||||||
"use_forgetting_rerank": use_forgetting_rerank,
|
"use_forgetting_rerank": use_forgetting_rerank,
|
||||||
"use_llm_rerank": use_llm_rerank
|
"use_llm_rerank": use_llm_rerank
|
||||||
}
|
}
|
||||||
|
|
||||||
# Execute hybrid retrieval
|
# 执行混合检索
|
||||||
raw_results = await run_hybrid_search(**final_params)
|
raw_results = await run_hybrid_search(**final_params)
|
||||||
|
|
||||||
# Clean unnecessary fields from results
|
# 清理结果中不需要的字段
|
||||||
if clean_output:
|
if clean_output:
|
||||||
cleaned_results = clean_result_fields(raw_results)
|
cleaned_results = clean_result_fields(raw_results)
|
||||||
else:
|
else:
|
||||||
cleaned_results = raw_results
|
cleaned_results = raw_results
|
||||||
|
|
||||||
# Format return results
|
# 格式化返回结果
|
||||||
formatted_results = {
|
formatted_results = {
|
||||||
"search_query": context,
|
"search_query": context,
|
||||||
"search_type": search_type,
|
"search_type": search_type,
|
||||||
@@ -351,19 +281,12 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
|||||||
|
|
||||||
def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||||
"""
|
"""
|
||||||
Create synchronous version of hybrid retrieval tool, optimize output format and filter unnecessary fields
|
创建同步版本的混合检索工具,优化输出格式并过滤不需要的字段
|
||||||
|
|
||||||
Creates a synchronous wrapper around the async hybrid search functionality,
|
|
||||||
making it compatible with synchronous tool execution environments.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_config: Memory configuration object containing search settings
|
memory_config: 内存配置对象
|
||||||
**search_params: Search parameters for configuration
|
**search_params: 搜索参数
|
||||||
|
|
||||||
Returns:
|
|
||||||
function: Configured HybridSearchSync tool function
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def HybridSearchSync(
|
def HybridSearchSync(
|
||||||
context: str,
|
context: str,
|
||||||
@@ -373,24 +296,17 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
|||||||
clean_output: bool = True
|
clean_output: bool = True
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Optimized hybrid retrieval tool (synchronous version), automatically filters unnecessary metadata fields
|
优化的混合检索工具(同步版本),自动过滤不需要的元数据字段
|
||||||
|
|
||||||
Provides the same hybrid search capabilities as the async version but in a
|
|
||||||
synchronous execution context. Automatically handles async-to-sync conversion.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
context: Query content for search
|
context: 查询内容
|
||||||
search_type: Search type ('keyword', 'embedding', 'hybrid')
|
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||||
limit: Result quantity limit
|
limit: 结果数量限制
|
||||||
end_user_id: Group ID for filtering search results
|
end_user_id: 组ID,用于过滤搜索结果
|
||||||
clean_output: Whether to clean metadata fields from output
|
clean_output: 是否清理输出中的元数据字段
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: JSON formatted search results
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _async_search():
|
async def _async_search():
|
||||||
# Create async tool and execute
|
# 创建异步工具并执行
|
||||||
async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
|
async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
|
||||||
return await async_tool.ainvoke({
|
return await async_tool.ainvoke({
|
||||||
"context": context,
|
"context": context,
|
||||||
|
|||||||
@@ -1,24 +1,16 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, AIMessage
|
from langchain_core.messages import HumanMessage, AIMessage
|
||||||
|
|
||||||
|
|
||||||
async def format_parsing(messages: list,type:str='string'):
|
async def format_parsing(messages: list,type:str='string'):
|
||||||
"""
|
"""
|
||||||
Format and parse message lists into different output types
|
格式化解析消息列表
|
||||||
|
|
||||||
Processes message lists from storage and converts them into either string format
|
|
||||||
or dictionary format based on the specified type parameter. Handles JSON parsing
|
|
||||||
and role-based message organization.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: List of message objects from storage containing message data
|
messages: 消息列表
|
||||||
type: Return type specification ('string' for text format, 'dict' for key-value pairs)
|
type: 返回类型 ('string' 或 'dict')
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list: Formatted message list in the specified format
|
格式化后的消息列表
|
||||||
- 'string': List of formatted text messages with role prefixes
|
|
||||||
- 'dict': List of dictionaries mapping user messages to AI responses
|
|
||||||
"""
|
"""
|
||||||
result = []
|
result = []
|
||||||
user=[]
|
user=[]
|
||||||
@@ -47,20 +39,7 @@ async def format_parsing(messages: list, type: str = 'string'):
|
|||||||
result.append({key:values})
|
result.append({key:values})
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def messages_parse(messages: list | dict):
|
async def messages_parse(messages: list | dict):
|
||||||
"""
|
|
||||||
Parse messages from storage format into user-AI conversation pairs
|
|
||||||
|
|
||||||
Extracts and organizes conversation data from stored message format,
|
|
||||||
separating user and AI messages and pairing them for database storage.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: List or dictionary containing stored message data with Query fields
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: List of dictionaries containing user-AI message pairs for database storage
|
|
||||||
"""
|
|
||||||
user=[]
|
user=[]
|
||||||
ai=[]
|
ai=[]
|
||||||
database=[]
|
database=[]
|
||||||
@@ -79,19 +58,6 @@ async def messages_parse(messages: list | dict):
|
|||||||
|
|
||||||
|
|
||||||
async def agent_chat_messages(user_content,ai_content):
|
async def agent_chat_messages(user_content,ai_content):
|
||||||
"""
|
|
||||||
Create structured chat message format for agent conversations
|
|
||||||
|
|
||||||
Formats user and AI content into a standardized message structure suitable
|
|
||||||
for agent processing and storage. Creates role-based message objects.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_content: User's message content string
|
|
||||||
ai_content: AI's response content string
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: List of structured message dictionaries with role and content fields
|
|
||||||
"""
|
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
|
|||||||
@@ -1,106 +1,103 @@
|
|||||||
import warnings
|
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from langgraph.constants import END, START
|
||||||
|
from langgraph.graph import StateGraph
|
||||||
|
|
||||||
|
from app.db import get_db, get_db_context
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
|
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||||
aggregate_judgment
|
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||||
from app.core.memory.agent.utils.redis_tool import write_store
|
|
||||||
from app.db import get_db_context
|
|
||||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
from app.services.memory_konwledges_server import write_rag
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
if sys.platform.startswith("win"):
|
||||||
async def long_term_storage(
|
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||||
long_term_type: str,
|
@asynccontextmanager
|
||||||
langchain_messages: list,
|
async def make_write_graph():
|
||||||
memory_config_id: str,
|
|
||||||
end_user_id: str,
|
|
||||||
scope: int = 6
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Handle long-term memory storage with different strategies
|
Create a write graph workflow for memory operations.
|
||||||
|
|
||||||
Supports multiple storage strategies including chunk-based, time-based,
|
|
||||||
and aggregate judgment approaches for long-term memory persistence.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
|
user_id: User identifier
|
||||||
langchain_messages: List of messages to store
|
tools: MCP tools loaded from session
|
||||||
memory_config_id: Memory configuration identifier
|
apply_id: Application identifier
|
||||||
end_user_id: User group identifier
|
end_user_id: Group identifier
|
||||||
scope: Scope parameter for chunk-based storage (default: 6)
|
memory_config: MemoryConfig object containing all configuration
|
||||||
"""
|
"""
|
||||||
if langchain_messages is None:
|
workflow = StateGraph(WriteState)
|
||||||
langchain_messages = []
|
workflow.add_node("save_neo4j", write_node)
|
||||||
|
workflow.add_edge(START, "save_neo4j")
|
||||||
|
workflow.add_edge("save_neo4j", END)
|
||||||
|
|
||||||
write_store.save_session_write(end_user_id, langchain_messages)
|
graph = workflow.compile()
|
||||||
|
|
||||||
|
yield graph
|
||||||
|
|
||||||
|
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
||||||
|
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
||||||
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
|
write_store.save_session_write(end_user_id, (langchain_messages))
|
||||||
# 获取数据库会话
|
# 获取数据库会话
|
||||||
with get_db_context() as db_session:
|
with get_db_context() as db_session:
|
||||||
config_service = MemoryConfigService(db_session)
|
config_service = MemoryConfigService(db_session)
|
||||||
# 通过 end_user_id 获取 workspace_id,确保日志和 fallback 逻辑完整
|
|
||||||
from app.services.memory_agent_service import get_end_user_connected_config
|
|
||||||
import uuid as _uuid
|
|
||||||
workspace_id = None
|
|
||||||
try:
|
|
||||||
connected = get_end_user_connected_config(end_user_id, db_session)
|
|
||||||
raw = connected.get("workspace_id")
|
|
||||||
if raw and raw != "None":
|
|
||||||
workspace_id = _uuid.UUID(str(raw))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
memory_config = config_service.load_memory_config(
|
memory_config = config_service.load_memory_config(
|
||||||
config_id=memory_config_id,
|
config_id=memory_config, # 改为整数
|
||||||
workspace_id=workspace_id,
|
|
||||||
service_name="MemoryAgentService"
|
service_name="MemoryAgentService"
|
||||||
)
|
)
|
||||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
|
if long_term_type=='chunk':
|
||||||
# Dialogue window with 6 rounds of conversation
|
'''方案一:对话窗口6轮对话'''
|
||||||
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
||||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
|
if long_term_type=='time':
|
||||||
# Time-based strategy
|
"""时间"""
|
||||||
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
|
await memory_long_term_storage(end_user_id, memory_config,5)
|
||||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
if long_term_type=='aggregate':
|
||||||
# Aggregate judgment
|
"""方案三:聚合判断"""
|
||||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||||
|
|
||||||
|
|
||||||
async def write_long_term(
|
|
||||||
storage_type: str,
|
|
||||||
end_user_id: str,
|
|
||||||
messages: list[dict],
|
|
||||||
user_rag_memory_id: str,
|
|
||||||
actual_config_id: str
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Write long-term memory with different storage types
|
|
||||||
|
|
||||||
Handles both RAG-based storage and traditional memory storage approaches.
|
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
|
||||||
For traditional storage, uses chunk-based strategy with paired user-AI messages.
|
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
||||||
|
|
||||||
Args:
|
|
||||||
storage_type: Type of storage (RAG or traditional)
|
|
||||||
end_user_id: User group identifier
|
|
||||||
messages: message list
|
|
||||||
user_rag_memory_id: RAG memory identifier
|
|
||||||
actual_config_id: Actual configuration ID
|
|
||||||
"""
|
|
||||||
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||||
|
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||||
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||||
message_content = []
|
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
||||||
for message in messages:
|
|
||||||
message_content.append(f'{message.get("role")}:{message.get("content")}')
|
|
||||||
messages_string = "\n".join(message_content)
|
|
||||||
await write_rag(end_user_id, messages_string, user_rag_memory_id)
|
|
||||||
else:
|
else:
|
||||||
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
|
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||||
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||||
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||||
await long_term_storage(long_term_type=CHUNK,
|
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
||||||
langchain_messages=messages,
|
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
|
||||||
memory_config_id=actual_config_id,
|
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
|
||||||
end_user_id=end_user_id,
|
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
|
||||||
scope=SCOPE)
|
|
||||||
await term_memory_save(end_user_id, CHUNK, scope=SCOPE)
|
# async def main():
|
||||||
|
# """主函数 - 运行工作流"""
|
||||||
|
# langchain_messages = [
|
||||||
|
# {
|
||||||
|
# "role": "user",
|
||||||
|
# "content": "今天周五去爬山"
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# "role": "assistant",
|
||||||
|
# "content": "好耶"
|
||||||
|
# }
|
||||||
|
#
|
||||||
|
# ]
|
||||||
|
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
||||||
|
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
||||||
|
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# import asyncio
|
||||||
|
# asyncio.run(main())
|
||||||
@@ -15,7 +15,7 @@ class ParameterBuilder:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize the parameter builder."""
|
"""Initialize the parameter builder."""
|
||||||
logger.debug("ParameterBuilder initialized")
|
logger.info("ParameterBuilder initialized")
|
||||||
|
|
||||||
def build_tool_args(
|
def build_tool_args(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -7,88 +7,21 @@ and deduplication.
|
|||||||
from typing import List, Tuple, Optional
|
from typing import List, Tuple, Optional
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.enums import Neo4jNodeType
|
|
||||||
from app.core.memory.src.search import run_hybrid_search
|
from app.core.memory.src.search import run_hybrid_search
|
||||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||||
|
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
# 需要从展开结果中过滤的字段(含 Neo4j DateTime,不可 JSON 序列化)
|
|
||||||
_EXPAND_FIELDS_TO_REMOVE = {
|
|
||||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
|
||||||
'created_at', 'chunk_id', 'apply_id',
|
|
||||||
'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary'
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _clean_expand_fields(obj):
|
|
||||||
"""递归过滤展开结果中不可序列化的字段(DateTime 等)。"""
|
|
||||||
if isinstance(obj, dict):
|
|
||||||
return {k: _clean_expand_fields(v) for k, v in obj.items() if k not in _EXPAND_FIELDS_TO_REMOVE}
|
|
||||||
if isinstance(obj, list):
|
|
||||||
return [_clean_expand_fields(i) for i in obj]
|
|
||||||
return obj
|
|
||||||
|
|
||||||
|
|
||||||
async def expand_communities_to_statements(
|
|
||||||
community_results: List[dict],
|
|
||||||
end_user_id: str,
|
|
||||||
existing_content: str = "",
|
|
||||||
limit: int = 10,
|
|
||||||
) -> Tuple[List[dict], List[str]]:
|
|
||||||
"""
|
|
||||||
社区展开 helper:给定命中的 community 列表,拉取关联 Statement。
|
|
||||||
|
|
||||||
- 对展开结果去重(过滤已在 existing_content 中出现的文本)
|
|
||||||
- 过滤不可序列化字段
|
|
||||||
- 返回 (cleaned_expanded_stmts, new_texts)
|
|
||||||
- cleaned_expanded_stmts: 可直接写回 raw_results 的列表
|
|
||||||
- new_texts: 去重后新增的 statement 文本列表,用于追加到 clean_content
|
|
||||||
"""
|
|
||||||
community_ids = [r.get("id") for r in community_results if r.get("id")]
|
|
||||||
if not community_ids or not end_user_id:
|
|
||||||
return [], []
|
|
||||||
|
|
||||||
from app.repositories.neo4j.graph_search import search_graph_community_expand
|
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
|
||||||
|
|
||||||
connector = Neo4jConnector()
|
|
||||||
try:
|
|
||||||
result = await search_graph_community_expand(
|
|
||||||
connector=connector,
|
|
||||||
community_ids=community_ids,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[expand_communities] 社区展开检索失败,跳过: {e}")
|
|
||||||
return [], []
|
|
||||||
finally:
|
|
||||||
await connector.close()
|
|
||||||
|
|
||||||
expanded_stmts = result.get("expanded_statements", [])
|
|
||||||
if not expanded_stmts:
|
|
||||||
return [], []
|
|
||||||
|
|
||||||
existing_lines = set(existing_content.splitlines())
|
|
||||||
new_texts = [
|
|
||||||
s["statement"] for s in expanded_stmts
|
|
||||||
if s.get("statement") and s["statement"] not in existing_lines
|
|
||||||
]
|
|
||||||
cleaned = _clean_expand_fields(expanded_stmts)
|
|
||||||
logger.info(
|
|
||||||
f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements,新增 {len(new_texts)} 条,community_ids={community_ids}")
|
|
||||||
return cleaned, new_texts
|
|
||||||
|
|
||||||
|
|
||||||
class SearchService:
|
class SearchService:
|
||||||
"""Service for executing hybrid search and processing results."""
|
"""Service for executing hybrid search and processing results."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize the search service."""
|
"""Initialize the search service."""
|
||||||
logger.debug("SearchService initialized")
|
logger.info("SearchService initialized")
|
||||||
|
|
||||||
def extract_content_from_result(self, result: dict, node_type: str = "") -> str:
|
def extract_content_from_result(self, result: dict) -> str:
|
||||||
"""
|
"""
|
||||||
Extract only meaningful content from search results, dropping all metadata.
|
Extract only meaningful content from search results, dropping all metadata.
|
||||||
|
|
||||||
@@ -97,11 +30,9 @@ class SearchService:
|
|||||||
- Entities: extract 'name' and 'fact_summary' fields
|
- Entities: extract 'name' and 'fact_summary' fields
|
||||||
- Summaries: extract 'content' field
|
- Summaries: extract 'content' field
|
||||||
- Chunks: extract 'content' field
|
- Chunks: extract 'content' field
|
||||||
- Communities: extract 'content' field (c.summary), prefixed with community name
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
result: Search result dictionary
|
result: Search result dictionary
|
||||||
node_type: Hint for node type ("community", "summary", etc.)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Clean content string without metadata
|
Clean content string without metadata
|
||||||
@@ -112,24 +43,11 @@ class SearchService:
|
|||||||
content_parts = []
|
content_parts = []
|
||||||
|
|
||||||
# Statements: extract statement field
|
# Statements: extract statement field
|
||||||
if Neo4jNodeType.STATEMENT in result and result[Neo4jNodeType.STATEMENT]:
|
if 'statement' in result and result['statement']:
|
||||||
content_parts.append(result[Neo4jNodeType.STATEMENT])
|
content_parts.append(result['statement'])
|
||||||
|
|
||||||
# Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定
|
# Summaries/Chunks: extract content field
|
||||||
# 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要
|
if 'content' in result and result['content']:
|
||||||
is_community = (
|
|
||||||
node_type == Neo4jNodeType.COMMUNITY
|
|
||||||
or 'member_count' in result
|
|
||||||
or 'core_entities' in result
|
|
||||||
)
|
|
||||||
if is_community:
|
|
||||||
name = result.get('name', '')
|
|
||||||
content = result.get('content', '')
|
|
||||||
if content:
|
|
||||||
prefix = f"[主题:{name}] " if name else ""
|
|
||||||
content_parts.append(f"{prefix}{content}")
|
|
||||||
elif 'content' in result and result['content']:
|
|
||||||
# Summaries / Chunks
|
|
||||||
content_parts.append(result['content'])
|
content_parts.append(result['content'])
|
||||||
|
|
||||||
# Entities: extract name and fact_summary (commented out in original)
|
# Entities: extract name and fact_summary (commented out in original)
|
||||||
@@ -181,8 +99,7 @@ class SearchService:
|
|||||||
rerank_alpha: float = 0.4,
|
rerank_alpha: float = 0.4,
|
||||||
output_path: str = "search_results.json",
|
output_path: str = "search_results.json",
|
||||||
return_raw_results: bool = False,
|
return_raw_results: bool = False,
|
||||||
memory_config=None,
|
memory_config = None
|
||||||
expand_communities: bool = True,
|
|
||||||
) -> Tuple[str, str, Optional[dict]]:
|
) -> Tuple[str, str, Optional[dict]]:
|
||||||
"""
|
"""
|
||||||
Execute hybrid search and return clean content.
|
Execute hybrid search and return clean content.
|
||||||
@@ -197,15 +114,13 @@ class SearchService:
|
|||||||
output_path: Path to save search results (default: "search_results.json")
|
output_path: Path to save search results (default: "search_results.json")
|
||||||
return_raw_results: If True, also return the raw search results as third element (default: False)
|
return_raw_results: If True, also return the raw search results as third element (default: False)
|
||||||
memory_config: Memory configuration object (required)
|
memory_config: Memory configuration object (required)
|
||||||
expand_communities: If True, expand community hits to member statements (default: True).
|
|
||||||
Set to False for quick-summary paths that only need community-level text.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (clean_content, cleaned_query, raw_results)
|
Tuple of (clean_content, cleaned_query, raw_results)
|
||||||
raw_results is None if return_raw_results=False
|
raw_results is None if return_raw_results=False
|
||||||
"""
|
"""
|
||||||
if include is None:
|
if include is None:
|
||||||
include = [Neo4jNodeType.STATEMENT, Neo4jNodeType.CHUNK, Neo4jNodeType.EXTRACTEDENTITY, Neo4jNodeType.MEMORYSUMMARY, Neo4jNodeType.COMMUNITY]
|
include = ["statements", "chunks", "entities", "summaries"]
|
||||||
|
|
||||||
# Clean query
|
# Clean query
|
||||||
cleaned_query = self.clean_query(question)
|
cleaned_query = self.clean_query(question)
|
||||||
@@ -231,8 +146,8 @@ class SearchService:
|
|||||||
if search_type == "hybrid":
|
if search_type == "hybrid":
|
||||||
reranked_results = answer.get('reranked_results', {})
|
reranked_results = answer.get('reranked_results', {})
|
||||||
|
|
||||||
# Priority order: summaries first (most contextual), then communities, statements, chunks, entities
|
# Priority order: summaries first (most contextual), then statements, chunks, entities
|
||||||
priority_order = [Neo4jNodeType.STATEMENT, Neo4jNodeType.CHUNK, Neo4jNodeType.EXTRACTEDENTITY, Neo4jNodeType.MEMORYSUMMARY, Neo4jNodeType.COMMUNITY]
|
priority_order = ['summaries', 'statements', 'chunks', 'entities']
|
||||||
|
|
||||||
for category in priority_order:
|
for category in priority_order:
|
||||||
if category in include and category in reranked_results:
|
if category in include and category in reranked_results:
|
||||||
@@ -242,7 +157,7 @@ class SearchService:
|
|||||||
else:
|
else:
|
||||||
# For keyword or embedding search, results are directly in answer dict
|
# For keyword or embedding search, results are directly in answer dict
|
||||||
# Apply same priority order
|
# Apply same priority order
|
||||||
priority_order = [Neo4jNodeType.STATEMENT, Neo4jNodeType.CHUNK, Neo4jNodeType.EXTRACTEDENTITY, Neo4jNodeType.MEMORYSUMMARY, Neo4jNodeType.COMMUNITY]
|
priority_order = ['summaries', 'statements', 'chunks', 'entities']
|
||||||
|
|
||||||
for category in priority_order:
|
for category in priority_order:
|
||||||
if category in include and category in answer:
|
if category in include and category in answer:
|
||||||
@@ -250,25 +165,12 @@ class SearchService:
|
|||||||
if isinstance(category_results, list):
|
if isinstance(category_results, list):
|
||||||
answer_list.extend(category_results)
|
answer_list.extend(category_results)
|
||||||
|
|
||||||
# 对命中的 community 节点展开其成员 statements(路径 "0"/"1" 需要,路径 "2" 不需要)
|
# Extract clean content from all results
|
||||||
if expand_communities and Neo4jNodeType.COMMUNITY in include:
|
content_list = [
|
||||||
community_results = (
|
self.extract_content_from_result(ans)
|
||||||
answer.get('reranked_results', {}).get(Neo4jNodeType.COMMUNITY.value, [])
|
for ans in answer_list
|
||||||
if search_type == "hybrid"
|
]
|
||||||
else answer.get(Neo4jNodeType.COMMUNITY.value, [])
|
|
||||||
)
|
|
||||||
cleaned_stmts, new_texts = await expand_communities_to_statements(
|
|
||||||
community_results=community_results,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
)
|
|
||||||
answer_list.extend(cleaned_stmts)
|
|
||||||
|
|
||||||
# Extract clean content from all results,按类型传入 node_type 区分 community
|
|
||||||
content_list = []
|
|
||||||
for ans in answer_list:
|
|
||||||
# community 节点有 member_count 或 core_entities 字段
|
|
||||||
ntype = Neo4jNodeType.COMMUNITY if ('member_count' in ans or 'core_entities' in ans) else ""
|
|
||||||
content_list.append(self.extract_content_from_result(ans, node_type=ntype))
|
|
||||||
|
|
||||||
# Filter out empty strings and join with newlines
|
# Filter out empty strings and join with newlines
|
||||||
clean_content = '\n'.join([c for c in content_list if c])
|
clean_content = '\n'.join([c for c in content_list if c])
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class SessionService:
|
|||||||
store: Redis session store instance
|
store: Redis session store instance
|
||||||
"""
|
"""
|
||||||
self.store = store
|
self.store = store
|
||||||
logger.debug("SessionService initialized")
|
logger.info("SessionService initialized")
|
||||||
|
|
||||||
def resolve_user_id(self, session_string: str) -> str:
|
def resolve_user_id(self, session_string: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ class TemplateService:
|
|||||||
loader=FileSystemLoader(template_root),
|
loader=FileSystemLoader(template_root),
|
||||||
autoescape=False # Disable autoescape for prompt templates
|
autoescape=False # Disable autoescape for prompt templates
|
||||||
)
|
)
|
||||||
logger.debug(f"TemplateService initialized with root: {template_root}")
|
logger.info(f"TemplateService initialized with root: {template_root}")
|
||||||
|
|
||||||
@lru_cache(maxsize=128)
|
@lru_cache(maxsize=128)
|
||||||
def _load_template(self, template_name: str) -> Template:
|
def _load_template(self, template_name: str) -> Template:
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
from typing import List
|
from typing import List
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
|
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
|
||||||
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
|
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
|
||||||
@@ -8,20 +11,17 @@ async def get_chunked_dialogs(
|
|||||||
chunker_strategy: str = "RecursiveChunker",
|
chunker_strategy: str = "RecursiveChunker",
|
||||||
end_user_id: str = "group_1",
|
end_user_id: str = "group_1",
|
||||||
messages: list = None,
|
messages: list = None,
|
||||||
ref_id: str = "",
|
ref_id: str = "wyl_20251027",
|
||||||
config_id: str = None,
|
config_id: str = None
|
||||||
workspace_id=None,
|
|
||||||
snapshot=None,
|
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""Generate chunks from structured messages using the specified chunker strategy.
|
"""Generate chunks from structured messages using the specified chunker strategy.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
||||||
end_user_id: Group identifier
|
end_user_id: Group identifier
|
||||||
messages: Structured message list [{"role": "user", "content": "...", "dialog_at": "..."}]
|
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||||
ref_id: Reference identifier
|
ref_id: Reference identifier
|
||||||
config_id: Configuration ID for processing (used to load pruning config)
|
config_id: Configuration ID for processing
|
||||||
snapshot: Optional PipelineSnapshot instance for saving pruning output
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of DialogData objects with generated chunks
|
List of DialogData objects with generated chunks
|
||||||
@@ -34,25 +34,18 @@ async def get_chunked_dialogs(
|
|||||||
|
|
||||||
conversation_messages = []
|
conversation_messages = []
|
||||||
|
|
||||||
# step1: 消息格式校验 role:user、assistant。content
|
|
||||||
for idx, msg in enumerate(messages):
|
for idx, msg in enumerate(messages):
|
||||||
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
|
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
|
||||||
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
|
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
|
||||||
|
|
||||||
role = msg['role']
|
role = msg['role']
|
||||||
content = msg['content']
|
content = msg['content']
|
||||||
files = msg.get("file_content", [])
|
|
||||||
|
|
||||||
if role not in ['user', 'assistant']:
|
if role not in ['user', 'assistant']:
|
||||||
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
||||||
|
|
||||||
if content.strip():
|
if content.strip():
|
||||||
conversation_messages.append(ConversationMessage(
|
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
|
||||||
role=role,
|
|
||||||
msg=content.strip(),
|
|
||||||
dialog_at=msg.get("dialog_at"),
|
|
||||||
files=files,
|
|
||||||
))
|
|
||||||
|
|
||||||
if not conversation_messages:
|
if not conversation_messages:
|
||||||
raise ValueError("Message list cannot be empty after filtering")
|
raise ValueError("Message list cannot be empty after filtering")
|
||||||
@@ -62,75 +55,9 @@ async def get_chunked_dialogs(
|
|||||||
context=conversation_context,
|
context=conversation_context,
|
||||||
ref_id=ref_id,
|
ref_id=ref_id,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
config_id=config_id,
|
config_id=config_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# step2: 语义剪枝步骤(在分块之前)
|
|
||||||
try:
|
|
||||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner
|
|
||||||
from app.core.memory.models.config_models import PruningConfig
|
|
||||||
from app.db import get_db_context
|
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
|
||||||
|
|
||||||
# 加载剪枝配置
|
|
||||||
pruning_config = None
|
|
||||||
if config_id:
|
|
||||||
try:
|
|
||||||
with get_db_context() as db:
|
|
||||||
# 使用 MemoryConfigService 加载完整的 MemoryConfig 对象
|
|
||||||
config_service = MemoryConfigService(db)
|
|
||||||
memory_config = config_service.load_memory_config(
|
|
||||||
config_id=config_id,
|
|
||||||
workspace_id=workspace_id,
|
|
||||||
service_name="semantic_pruning"
|
|
||||||
)
|
|
||||||
|
|
||||||
if memory_config:
|
|
||||||
pruning_config = PruningConfig(
|
|
||||||
pruning_switch=memory_config.pruning_enabled,
|
|
||||||
pruning_scene=memory_config.pruning_scene or "education",
|
|
||||||
pruning_threshold=memory_config.pruning_threshold,
|
|
||||||
scene_id=str(memory_config.scene_id) if memory_config.scene_id else None,
|
|
||||||
ontology_class_infos=memory_config.ontology_class_infos,
|
|
||||||
)
|
|
||||||
logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}")
|
|
||||||
|
|
||||||
# 获取LLM客户端用于剪枝
|
|
||||||
if pruning_config.pruning_switch:
|
|
||||||
factory = MemoryClientFactory(db)
|
|
||||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
|
||||||
|
|
||||||
# 执行剪枝 - 使用 prune_dataset 支持消息级剪枝
|
|
||||||
pruner = SemanticPruner(config=pruning_config, llm_client=llm_client, snapshot=snapshot)
|
|
||||||
original_msg_count = len(dialog_data.context.msgs)
|
|
||||||
|
|
||||||
# 使用 prune_dataset 而不是 prune_dialog
|
|
||||||
# prune_dataset 会进行消息级剪枝,即使对话整体相关也会删除不重要消息
|
|
||||||
pruned_dialogs = await pruner.prune_dataset([dialog_data])
|
|
||||||
|
|
||||||
if pruned_dialogs:
|
|
||||||
dialog_data = pruned_dialogs[0]
|
|
||||||
remaining_msg_count = len(dialog_data.context.msgs)
|
|
||||||
deleted_count = original_msg_count - remaining_msg_count
|
|
||||||
logger.info(f"[剪枝] 完成: 原始{original_msg_count}条 -> 保留{remaining_msg_count}条 (删除{deleted_count}条)")
|
|
||||||
|
|
||||||
# 将剪枝记录挂到 metadata,供 graph_build_step 构建节点
|
|
||||||
if pruner.pruning_records:
|
|
||||||
dialog_data.metadata["assistant_pruning_records"] = [
|
|
||||||
r.model_dump() for r in pruner.pruning_records
|
|
||||||
]
|
|
||||||
logger.info(f"[剪枝] 收集到 {len(pruner.pruning_records)} 条剪枝记录")
|
|
||||||
else:
|
|
||||||
logger.warning("[剪枝] prune_dataset 返回空列表")
|
|
||||||
else:
|
|
||||||
logger.info("[剪枝] 配置中剪枝开关关闭,跳过剪枝")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[剪枝] 加载配置失败,跳过剪枝: {e}", exc_info=True)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[剪枝] 执行失败,跳过剪枝: {e}", exc_info=True)
|
|
||||||
|
|
||||||
# step3: 分块
|
|
||||||
chunker = DialogueChunker(chunker_strategy)
|
chunker = DialogueChunker(chunker_strategy)
|
||||||
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
||||||
dialog_data.chunks = extracted_chunks
|
dialog_data.chunks = extracted_chunks
|
||||||
|
|||||||
56
api/app/core/memory/agent/utils/llm_client_pool.py
Normal file
56
api/app/core/memory/agent/utils/llm_client_pool.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Dict, Optional
|
||||||
|
from app.core.memory.utils.llm.llm_utils import get_llm_client_fast
|
||||||
|
from app.db import get_db
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
|
|
||||||
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
class LLMClientPool:
|
||||||
|
"""LLM客户端连接池"""
|
||||||
|
|
||||||
|
def __init__(self, max_size: int = 5):
|
||||||
|
self.max_size = max_size
|
||||||
|
self.pools: Dict[str, asyncio.Queue] = {}
|
||||||
|
self.active_clients: Dict[str, int] = {}
|
||||||
|
|
||||||
|
async def get_client(self, llm_model_id: str):
|
||||||
|
"""获取LLM客户端"""
|
||||||
|
if llm_model_id not in self.pools:
|
||||||
|
self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size)
|
||||||
|
self.active_clients[llm_model_id] = 0
|
||||||
|
|
||||||
|
pool = self.pools[llm_model_id]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 尝试从池中获取客户端
|
||||||
|
client = pool.get_nowait()
|
||||||
|
logger.debug(f"从池中获取LLM客户端: {llm_model_id}")
|
||||||
|
return client
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
# 池为空,创建新客户端
|
||||||
|
if self.active_clients[llm_model_id] < self.max_size:
|
||||||
|
db_session = next(get_db())
|
||||||
|
client = get_llm_client_fast(llm_model_id, db_session)
|
||||||
|
self.active_clients[llm_model_id] += 1
|
||||||
|
logger.debug(f"创建新LLM客户端: {llm_model_id}")
|
||||||
|
return client
|
||||||
|
else:
|
||||||
|
# 等待可用客户端
|
||||||
|
logger.debug(f"等待LLM客户端可用: {llm_model_id}")
|
||||||
|
return await pool.get()
|
||||||
|
|
||||||
|
async def return_client(self, llm_model_id: str, client):
|
||||||
|
"""归还LLM客户端到池中"""
|
||||||
|
if llm_model_id in self.pools:
|
||||||
|
try:
|
||||||
|
self.pools[llm_model_id].put_nowait(client)
|
||||||
|
logger.debug(f"归还LLM客户端到池: {llm_model_id}")
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
# 池已满,丢弃客户端
|
||||||
|
self.active_clients[llm_model_id] -= 1
|
||||||
|
logger.debug(f"池已满,丢弃LLM客户端: {llm_model_id}")
|
||||||
|
|
||||||
|
# 全局客户端池
|
||||||
|
llm_client_pool = LLMClientPool()
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, TypedDict
|
from typing import Annotated, TypedDict
|
||||||
@@ -7,19 +8,16 @@ from langgraph.graph import add_messages
|
|||||||
|
|
||||||
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
|
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
|
||||||
|
|
||||||
|
|
||||||
class WriteState(TypedDict):
|
class WriteState(TypedDict):
|
||||||
"""
|
'''
|
||||||
Langgrapg Writing TypedDict
|
Langgrapg Writing TypedDict
|
||||||
"""
|
'''
|
||||||
messages: Annotated[list[AnyMessage], add_messages]
|
messages: Annotated[list[AnyMessage], add_messages]
|
||||||
end_user_id: str
|
end_user_id: str
|
||||||
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
||||||
memory_config: object
|
memory_config: object
|
||||||
write_result: dict
|
write_result: dict
|
||||||
data: str
|
data: str
|
||||||
language: str # 语言类型 ("zh" 中文, "en" 英文)
|
|
||||||
|
|
||||||
|
|
||||||
class ReadState(TypedDict):
|
class ReadState(TypedDict):
|
||||||
"""
|
"""
|
||||||
@@ -51,14 +49,11 @@ class ReadState(TypedDict):
|
|||||||
embedding_id: str
|
embedding_id: str
|
||||||
memory_config: object # 新增字段用于传递内存配置对象
|
memory_config: object # 新增字段用于传递内存配置对象
|
||||||
retrieve:dict
|
retrieve:dict
|
||||||
perceptual_data: dict
|
|
||||||
RetrieveSummary: dict
|
RetrieveSummary: dict
|
||||||
InputSummary: dict
|
InputSummary: dict
|
||||||
verify: dict
|
verify: dict
|
||||||
SummaryFails: dict
|
SummaryFails: dict
|
||||||
summary: dict
|
summary: dict
|
||||||
|
|
||||||
|
|
||||||
class COUNTState:
|
class COUNTState:
|
||||||
"""
|
"""
|
||||||
工作流对话检索内容计数器
|
工作流对话检索内容计数器
|
||||||
@@ -103,7 +98,6 @@ class COUNTState:
|
|||||||
self.total = 0
|
self.total = 0
|
||||||
print("[COUNTState] 已重置为 0")
|
print("[COUNTState] 已重置为 0")
|
||||||
|
|
||||||
|
|
||||||
def deduplicate_entries(entries):
|
def deduplicate_entries(entries):
|
||||||
seen = set()
|
seen = set()
|
||||||
deduped = []
|
deduped = []
|
||||||
@@ -114,7 +108,6 @@ def deduplicate_entries(entries):
|
|||||||
deduped.append(entry)
|
deduped.append(entry)
|
||||||
return deduped
|
return deduped
|
||||||
|
|
||||||
|
|
||||||
def merge_to_key_value_pairs(data, query_key, result_key):
|
def merge_to_key_value_pairs(data, query_key, result_key):
|
||||||
grouped = defaultdict(list)
|
grouped = defaultdict(list)
|
||||||
for item in data:
|
for item in data:
|
||||||
|
|||||||
@@ -39,30 +39,6 @@
|
|||||||
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
||||||
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
||||||
|
|
||||||
## 指代消歧规则(Coreference Resolution):
|
|
||||||
在拆分问题时,必须解析并替换所有指代词和抽象称呼,使问题具体化:
|
|
||||||
|
|
||||||
1. **"用户"的消歧**:
|
|
||||||
- "用户是谁?" → 分析历史记录,找出对话发起者的姓名
|
|
||||||
- 如果历史中有"我叫X"、"我的名字是X"、或多次提到某个人物,则"用户"指的就是这个人
|
|
||||||
- 示例:历史中有"老李的原名叫李建国",则"用户是谁?"应拆分为"李建国是谁?"或"老李(李建国)是谁?"
|
|
||||||
|
|
||||||
2. **"我"的消歧**:
|
|
||||||
- "我喜欢什么?" → 从历史中找出对话发起者的姓名,替换为"X喜欢什么?"
|
|
||||||
- 示例:历史中有"张曼玉推荐了《小王子》",则"我推荐的书是什么?"应拆分为"张曼玉推荐的书是什么?"
|
|
||||||
|
|
||||||
3. **"他/她/它"的消歧**:
|
|
||||||
- 从上下文或历史中找出最近提到的同类实体
|
|
||||||
- 示例:历史中有"老李的同事叫他建国哥",则"他的同事怎么称呼他?"应拆分为"老李的同事怎么称呼他?"
|
|
||||||
|
|
||||||
4. **"那个人/这个人"的消歧**:
|
|
||||||
- 从历史中找出最近提到的人物
|
|
||||||
- 示例:历史中有"李建国",则"那个人的原名是什么?"应拆分为"李建国的原名是什么?"
|
|
||||||
|
|
||||||
5. **优先级**:
|
|
||||||
- 如果历史记录中反复出现某个人物(如"老李"、"李建国"、"建国哥"),则"用户"很可能指的就是这个人
|
|
||||||
- 如果无法从历史中确定指代对象,保留原问题,但在reason中说明"无法确定指代对象"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
输出要求:
|
输出要求:
|
||||||
@@ -95,34 +71,6 @@
|
|||||||
"reason": "输出原问题的关键要素"
|
"reason": "输出原问题的关键要素"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
## 指代消歧示例(重要):
|
|
||||||
示例1 - "用户"的消歧:
|
|
||||||
输入历史:[{'Query': '老李的原名叫什么?', 'Answer': '李建国'}, {'Query': '老李的同事叫他什么?', 'Answer': '建国哥'}]
|
|
||||||
输入问题:"用户是谁?"
|
|
||||||
输出:
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"original_question": "用户是谁?",
|
|
||||||
"extended_question": "李建国是谁?",
|
|
||||||
"type": "单跳",
|
|
||||||
"reason": "历史中反复提到'老李/李建国/建国哥','用户'指的就是对话发起者李建国"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
示例2 - "我"的消歧:
|
|
||||||
输入历史:[{'Query': '张曼玉推荐了什么书?', 'Answer': '《小王子》'}]
|
|
||||||
输入问题:"我推荐的书是什么?"
|
|
||||||
输出:
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"original_question": "我推荐的书是什么?",
|
|
||||||
"extended_question": "张曼玉推荐的书是什么?",
|
|
||||||
"type": "单跳",
|
|
||||||
"reason": "历史中提到张曼玉推荐了书,'我'指的就是张曼玉"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
**Output format**
|
**Output format**
|
||||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||||
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
||||||
|
|||||||
@@ -27,30 +27,6 @@
|
|||||||
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
||||||
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
||||||
|
|
||||||
## 指代消歧规则(Coreference Resolution):
|
|
||||||
在拆分问题时,必须解析并替换所有指代词和抽象称呼,使问题具体化:
|
|
||||||
|
|
||||||
1. **"用户"的消歧**:
|
|
||||||
- "用户是谁?" → 分析历史记录,找出对话发起者的姓名
|
|
||||||
- 如果历史中有"我叫X"、"我的名字是X"、或多次提到某个人物(如"老李"、"李建国"),则"用户"指的就是这个人
|
|
||||||
- 示例:历史中反复出现"老李/李建国/建国哥",则"用户是谁?"应拆分为"李建国是谁?"或"老李(李建国)是谁?"
|
|
||||||
|
|
||||||
2. **"我"的消歧**:
|
|
||||||
- "我喜欢什么?" → 从历史中找出对话发起者的姓名,替换为"X喜欢什么?"
|
|
||||||
- 示例:历史中有"张曼玉推荐了《小王子》",则"我推荐的书是什么?"应拆分为"张曼玉推荐的书是什么?"
|
|
||||||
|
|
||||||
3. **"他/她/它"的消歧**:
|
|
||||||
- 从上下文或历史中找出最近提到的同类实体
|
|
||||||
- 示例:历史中有"老李的同事叫他建国哥",则"他的同事怎么称呼他?"应拆分为"老李的同事怎么称呼他?"
|
|
||||||
|
|
||||||
4. **"那个人/这个人"的消歧**:
|
|
||||||
- 从历史中找出最近提到的人物
|
|
||||||
- 示例:历史中有"李建国",则"那个人的原名是什么?"应拆分为"李建国的原名是什么?"
|
|
||||||
|
|
||||||
5. **优先级**:
|
|
||||||
- 如果历史记录中反复出现某个人物(如"老李"、"李建国"、"建国哥"),则"用户"很可能指的就是这个人
|
|
||||||
- 如果无法从历史中确定指代对象,保留原问题,但在reason中说明"无法确定指代对象"
|
|
||||||
|
|
||||||
## 指令:
|
## 指令:
|
||||||
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
|
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
|
||||||
单跳(Single-hop)
|
单跳(Single-hop)
|
||||||
@@ -175,34 +151,6 @@
|
|||||||
]
|
]
|
||||||
- 必须通过json.loads()的格式支持的形式输出
|
- 必须通过json.loads()的格式支持的形式输出
|
||||||
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
||||||
|
|
||||||
## 指代消歧示例(重要):
|
|
||||||
示例1 - "用户"的消歧:
|
|
||||||
输入历史:[{'Query': '老李的原名叫什么?', 'Answer': '李建国'}, {'Query': '老李的同事叫他什么?', 'Answer': '建国哥'}]
|
|
||||||
输入问题:"用户是谁?"
|
|
||||||
输出:
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"id": "Q1",
|
|
||||||
"question": "李建国是谁?",
|
|
||||||
"type": "单跳",
|
|
||||||
"reason": "历史中反复提到'老李/李建国/建国哥','用户'指的就是对话发起者李建国"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
示例2 - "我"的消歧:
|
|
||||||
输入历史:[{'Query': '张曼玉推荐了什么书?', 'Answer': '《小王子》'}]
|
|
||||||
输入问题:"我推荐的书是什么?"
|
|
||||||
输出:
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"id": "Q1",
|
|
||||||
"question": "张曼玉推荐的书是什么?",
|
|
||||||
"type": "单跳",
|
|
||||||
"reason": "历史中提到张曼玉推荐了书,'我'指的就是张曼玉"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
- 关键的JSON格式要求
|
- 关键的JSON格式要求
|
||||||
1.JSON结构仅使用标准ASCII双引号(“)-切勿使用中文引号(“”)或其他Unicode引号
|
1.JSON结构仅使用标准ASCII双引号(“)-切勿使用中文引号(“”)或其他Unicode引号
|
||||||
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
|
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import uuid
|
|||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from typing import List, Dict, Any, Optional, Union
|
from typing import List, Dict, Any, Optional, Union
|
||||||
|
|
||||||
from app.core.logging_config import get_logger
|
|
||||||
from app.core.memory.agent.utils.redis_base import (
|
from app.core.memory.agent.utils.redis_base import (
|
||||||
serialize_messages,
|
serialize_messages,
|
||||||
deserialize_messages,
|
deserialize_messages,
|
||||||
@@ -15,7 +14,7 @@ from app.core.memory.agent.utils.redis_base import (
|
|||||||
get_current_timestamp
|
get_current_timestamp
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class RedisWriteStore:
|
class RedisWriteStore:
|
||||||
@@ -67,10 +66,10 @@ class RedisWriteStore:
|
|||||||
})
|
})
|
||||||
result = pipe.execute()
|
result = pipe.execute()
|
||||||
|
|
||||||
logger.debug(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
||||||
return session_id
|
return session_id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[save_session_write] 保存会话失败: {e}")
|
print(f"[save_session_write] 保存会话失败: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
|
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
|
||||||
@@ -113,10 +112,10 @@ class RedisWriteStore:
|
|||||||
if not results:
|
if not results:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
logger.debug(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
||||||
return results
|
return results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[get_session_by_userid] 查询失败: {e}")
|
print(f"[get_session_by_userid] 查询失败: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
|
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
|
||||||
@@ -145,7 +144,7 @@ class RedisWriteStore:
|
|||||||
# 只查询 write 类型的 key
|
# 只查询 write 类型的 key
|
||||||
keys = self.r.keys('session:write:*')
|
keys = self.r.keys('session:write:*')
|
||||||
if not keys:
|
if not keys:
|
||||||
logger.debug(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 批量获取数据
|
# 批量获取数据
|
||||||
@@ -176,16 +175,18 @@ class RedisWriteStore:
|
|||||||
results.append(session_info)
|
results.append(session_info)
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 按时间排序(最新的在前)
|
# 按时间排序(最新的在前)
|
||||||
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||||
|
|
||||||
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
||||||
return results
|
return results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[get_all_sessions_by_end_user_id] 查询失败: {e}", exc_info=True)
|
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def find_user_recent_sessions(self, userid: str,
|
def find_user_recent_sessions(self, userid: str,
|
||||||
@@ -206,7 +207,7 @@ class RedisWriteStore:
|
|||||||
# 只查询 write 类型的 key
|
# 只查询 write 类型的 key
|
||||||
keys = self.r.keys('session:write:*')
|
keys = self.r.keys('session:write:*')
|
||||||
if not keys:
|
if not keys:
|
||||||
logger.debug(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 批量获取数据
|
# 批量获取数据
|
||||||
@@ -233,10 +234,11 @@ class RedisWriteStore:
|
|||||||
# 根据时间范围过滤
|
# 根据时间范围过滤
|
||||||
filtered_items = filter_by_time_range(matched_items, minutes)
|
filtered_items = filter_by_time_range(matched_items, minutes)
|
||||||
# 排序并移除时间字段
|
# 排序并移除时间字段
|
||||||
result_items = sort_and_limit_results(filtered_items)
|
result_items = sort_and_limit_results(filtered_items, limit=None)
|
||||||
|
print(result_items)
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
logger.debug(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
||||||
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||||
|
|
||||||
return result_items
|
return result_items
|
||||||
@@ -276,7 +278,7 @@ class RedisCountStore:
|
|||||||
decode_responses=True,
|
decode_responses=True,
|
||||||
encoding='utf-8'
|
encoding='utf-8'
|
||||||
)
|
)
|
||||||
self.uuid = session_id
|
self.uudi = session_id
|
||||||
|
|
||||||
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
|
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -296,7 +298,7 @@ class RedisCountStore:
|
|||||||
|
|
||||||
pipe = self.r.pipeline()
|
pipe = self.r.pipeline()
|
||||||
pipe.hset(key, mapping={
|
pipe.hset(key, mapping={
|
||||||
"id": self.uuid,
|
"id": self.uudi,
|
||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"count": int(count),
|
"count": int(count),
|
||||||
"messages": serialize_messages(messages),
|
"messages": serialize_messages(messages),
|
||||||
@@ -309,10 +311,10 @@ class RedisCountStore:
|
|||||||
|
|
||||||
result = pipe.execute()
|
result = pipe.execute()
|
||||||
|
|
||||||
logger.debug(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
||||||
return session_id
|
return session_id
|
||||||
|
|
||||||
def get_sessions_count(self, end_user_id: str) -> tuple[int, list[dict]] | bool:
|
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
|
||||||
"""
|
"""
|
||||||
通过 end_user_id 查询访问次数统计
|
通过 end_user_id 查询访问次数统计
|
||||||
|
|
||||||
@@ -333,7 +335,7 @@ class RedisCountStore:
|
|||||||
self.r.delete(index_key)
|
self.r.delete(index_key)
|
||||||
return False
|
return False
|
||||||
except Exception as type_error:
|
except Exception as type_error:
|
||||||
logger.error(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
||||||
|
|
||||||
session_id = self.r.get(index_key)
|
session_id = self.r.get(index_key)
|
||||||
|
|
||||||
@@ -353,20 +355,15 @@ class RedisCountStore:
|
|||||||
messages_str = data.get('messages')
|
messages_str = data.get('messages')
|
||||||
|
|
||||||
if count is not None:
|
if count is not None:
|
||||||
messages: list[dict] = deserialize_messages(messages_str)
|
messages = deserialize_messages(messages_str)
|
||||||
return int(count), messages
|
return [int(count), messages]
|
||||||
|
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[get_sessions_count] 查询失败: {e}")
|
print(f"[get_sessions_count] 查询失败: {e}")
|
||||||
return False
|
return False
|
||||||
|
def update_sessions_count(self, end_user_id: str, new_count: int,
|
||||||
def update_sessions_count(
|
messages: Any) -> bool:
|
||||||
self,
|
|
||||||
end_user_id: str,
|
|
||||||
new_count: int,
|
|
||||||
messages: Any
|
|
||||||
) -> bool:
|
|
||||||
"""
|
"""
|
||||||
通过 end_user_id 修改访问次数统计(优化版:使用索引)
|
通过 end_user_id 修改访问次数统计(优化版:使用索引)
|
||||||
|
|
||||||
@@ -387,17 +384,17 @@ class RedisCountStore:
|
|||||||
key_type = self.r.type(index_key)
|
key_type = self.r.type(index_key)
|
||||||
if key_type != 'string' and key_type != 'none':
|
if key_type != 'string' and key_type != 'none':
|
||||||
# 索引键类型错误,删除并返回 False
|
# 索引键类型错误,删除并返回 False
|
||||||
logger.warning(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
||||||
self.r.delete(index_key)
|
self.r.delete(index_key)
|
||||||
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||||
return False
|
return False
|
||||||
except Exception as type_error:
|
except Exception as type_error:
|
||||||
logger.error(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
||||||
|
|
||||||
session_id = self.r.get(index_key)
|
session_id = self.r.get(index_key)
|
||||||
|
|
||||||
if not session_id:
|
if not session_id:
|
||||||
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 直接更新数据
|
# 直接更新数据
|
||||||
@@ -405,15 +402,15 @@ class RedisCountStore:
|
|||||||
messages_str = serialize_messages(messages)
|
messages_str = serialize_messages(messages)
|
||||||
|
|
||||||
pipe = self.r.pipeline()
|
pipe = self.r.pipeline()
|
||||||
pipe.hset(key, 'count', str(new_count))
|
pipe.hset(key, 'count', int(new_count))
|
||||||
pipe.hset(key, 'messages', messages_str)
|
pipe.hset(key, 'messages', messages_str)
|
||||||
result = pipe.execute()
|
result = pipe.execute()
|
||||||
|
|
||||||
logger.debug(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"[update_sessions_count] 更新失败: {e}")
|
print(f"[update_sessions_count] 更新失败: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def delete_all_count_sessions(self) -> int:
|
def delete_all_count_sessions(self) -> int:
|
||||||
@@ -486,10 +483,10 @@ class RedisSessionStore:
|
|||||||
})
|
})
|
||||||
result = pipe.execute()
|
result = pipe.execute()
|
||||||
|
|
||||||
logger.debug(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
||||||
return session_id
|
return session_id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[save_session] 保存会话失败: {e}")
|
print(f"[save_session] 保存会话失败: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
# ==================== 读取操作 ====================
|
# ==================== 读取操作 ====================
|
||||||
@@ -541,7 +538,7 @@ class RedisSessionStore:
|
|||||||
|
|
||||||
keys = self.r.keys('session:*')
|
keys = self.r.keys('session:*')
|
||||||
if not keys:
|
if not keys:
|
||||||
logger.debug(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 批量获取数据
|
# 批量获取数据
|
||||||
@@ -568,7 +565,7 @@ class RedisSessionStore:
|
|||||||
result_items = sort_and_limit_results(matched_items, limit=6)
|
result_items = sort_and_limit_results(matched_items, limit=6)
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
logger.debug(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||||
|
|
||||||
return result_items
|
return result_items
|
||||||
|
|
||||||
@@ -635,7 +632,7 @@ class RedisSessionStore:
|
|||||||
|
|
||||||
keys = self.r.keys('session:*')
|
keys = self.r.keys('session:*')
|
||||||
if not keys:
|
if not keys:
|
||||||
logger.debug("[delete_duplicate_sessions] 没有会话数据")
|
print("[delete_duplicate_sessions] 没有会话数据")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# 批量获取所有数据
|
# 批量获取所有数据
|
||||||
@@ -681,7 +678,7 @@ class RedisSessionStore:
|
|||||||
deleted_count += len(batch)
|
deleted_count += len(batch)
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
logger.debug(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||||
return deleted_count
|
return deleted_count
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
202
api/app/core/memory/agent/utils/write_tools.py
Normal file
202
api/app/core/memory/agent/utils/write_tools.py
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
"""
|
||||||
|
Write Tools for Memory Knowledge Extraction Pipeline
|
||||||
|
|
||||||
|
This module provides the main write function for executing the knowledge extraction
|
||||||
|
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
|
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||||
|
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||||
|
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation
|
||||||
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
|
from app.core.memory.utils.log.logging_utils import log_time
|
||||||
|
from app.db import get_db_context
|
||||||
|
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||||
|
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||||
|
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||||
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
|
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def write(
|
||||||
|
end_user_id: str,
|
||||||
|
memory_config: MemoryConfig,
|
||||||
|
messages: list,
|
||||||
|
ref_id: str = "wyl20251027",
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Execute the complete knowledge extraction pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User identifier
|
||||||
|
apply_id: Application identifier
|
||||||
|
end_user_id: Group identifier
|
||||||
|
memory_config: MemoryConfig object containing all configuration
|
||||||
|
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||||
|
ref_id: Reference ID, defaults to "wyl20251027"
|
||||||
|
"""
|
||||||
|
# Extract config values
|
||||||
|
embedding_model_id = str(memory_config.embedding_model_id)
|
||||||
|
chunker_strategy = memory_config.chunker_strategy
|
||||||
|
config_id = str(memory_config.config_id)
|
||||||
|
|
||||||
|
logger.info("=== MemSci Knowledge Extraction Pipeline ===")
|
||||||
|
logger.info(f"Config: {memory_config.config_name} (ID: {config_id})")
|
||||||
|
logger.info(f"Workspace: {memory_config.workspace_name}")
|
||||||
|
logger.info(f"LLM model: {memory_config.llm_model_name}")
|
||||||
|
logger.info(f"Embedding model: {memory_config.embedding_model_name}")
|
||||||
|
logger.info(f"Chunker strategy: {chunker_strategy}")
|
||||||
|
logger.info(f"end_user_id ID: {end_user_id}")
|
||||||
|
|
||||||
|
# Construct clients from memory_config using factory pattern with db session
|
||||||
|
with get_db_context() as db:
|
||||||
|
factory = MemoryClientFactory(db)
|
||||||
|
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||||
|
embedder_client = factory.get_embedder_client_from_config(memory_config)
|
||||||
|
logger.info("LLM and embedding clients constructed")
|
||||||
|
|
||||||
|
# Initialize timing log
|
||||||
|
log_file = "logs/time.log"
|
||||||
|
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
with open(log_file, "a", encoding="utf-8") as f:
|
||||||
|
f.write(f"\n=== Pipeline Run Started: {timestamp} ===\n")
|
||||||
|
f.write(f"Config: {memory_config.config_name} (ID: {config_id})\n")
|
||||||
|
|
||||||
|
pipeline_start = time.time()
|
||||||
|
|
||||||
|
# Initialize Neo4j connector
|
||||||
|
neo4j_connector = Neo4jConnector()
|
||||||
|
|
||||||
|
# Step 1: Load and chunk data
|
||||||
|
step_start = time.time()
|
||||||
|
chunked_dialogs = await get_chunked_dialogs(
|
||||||
|
chunker_strategy=chunker_strategy,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
messages=messages,
|
||||||
|
ref_id=ref_id,
|
||||||
|
config_id=config_id,
|
||||||
|
)
|
||||||
|
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
||||||
|
|
||||||
|
# Step 2: Initialize and run ExtractionOrchestrator
|
||||||
|
step_start = time.time()
|
||||||
|
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||||
|
pipeline_config = get_pipeline_config(memory_config)
|
||||||
|
|
||||||
|
orchestrator = ExtractionOrchestrator(
|
||||||
|
llm_client=llm_client,
|
||||||
|
embedder_client=embedder_client,
|
||||||
|
connector=neo4j_connector,
|
||||||
|
config=pipeline_config,
|
||||||
|
embedding_id=embedding_model_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run the complete extraction pipeline
|
||||||
|
(
|
||||||
|
all_dialogue_nodes,
|
||||||
|
all_chunk_nodes,
|
||||||
|
all_statement_nodes,
|
||||||
|
all_entity_nodes,
|
||||||
|
all_statement_chunk_edges,
|
||||||
|
all_statement_entity_edges,
|
||||||
|
all_entity_entity_edges,
|
||||||
|
all_dedup_details,
|
||||||
|
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
||||||
|
|
||||||
|
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||||
|
|
||||||
|
# Step 3: Save all data to Neo4j database
|
||||||
|
step_start = time.time()
|
||||||
|
from app.repositories.neo4j.create_indexes import create_fulltext_indexes
|
||||||
|
try:
|
||||||
|
await create_fulltext_indexes()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# 添加死锁重试机制
|
||||||
|
max_retries = 3
|
||||||
|
retry_delay = 1 # 秒
|
||||||
|
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
success = await save_dialog_and_statements_to_neo4j(
|
||||||
|
dialogue_nodes=all_dialogue_nodes,
|
||||||
|
chunk_nodes=all_chunk_nodes,
|
||||||
|
statement_nodes=all_statement_nodes,
|
||||||
|
entity_nodes=all_entity_nodes,
|
||||||
|
statement_chunk_edges=all_statement_chunk_edges,
|
||||||
|
statement_entity_edges=all_statement_entity_edges,
|
||||||
|
entity_edges=all_entity_entity_edges,
|
||||||
|
connector=neo4j_connector
|
||||||
|
)
|
||||||
|
if success:
|
||||||
|
logger.info("Successfully saved all data to Neo4j")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logger.warning("Failed to save some data to Neo4j")
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
logger.info(f"Retrying... (attempt {attempt + 2}/{max_retries})")
|
||||||
|
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
# 检查是否是死锁错误
|
||||||
|
if "DeadlockDetected" in error_msg or "deadlock" in error_msg.lower():
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
logger.warning(f"Deadlock detected, retrying... (attempt {attempt + 2}/{max_retries})")
|
||||||
|
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed after {max_retries} attempts due to deadlock: {e}")
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
# 非死锁错误,直接抛出
|
||||||
|
raise
|
||||||
|
|
||||||
|
try:
|
||||||
|
await neo4j_connector.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error closing Neo4j connector: {e}")
|
||||||
|
|
||||||
|
log_time("Neo4j Database Save", time.time() - step_start, log_file)
|
||||||
|
|
||||||
|
# Step 4: Generate Memory summaries and save to Neo4j
|
||||||
|
step_start = time.time()
|
||||||
|
try:
|
||||||
|
summaries = await memory_summary_generation(
|
||||||
|
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
ms_connector = Neo4jConnector()
|
||||||
|
await add_memory_summary_nodes(summaries, ms_connector)
|
||||||
|
await add_memory_summary_statement_edges(summaries, ms_connector)
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
await ms_connector.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Memory summary step failed: {e}", exc_info=True)
|
||||||
|
finally:
|
||||||
|
log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file)
|
||||||
|
|
||||||
|
# Log total pipeline time
|
||||||
|
total_time = time.time() - pipeline_start
|
||||||
|
log_time("TOTAL PIPELINE TIME", total_time, log_file)
|
||||||
|
|
||||||
|
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
with open(log_file, "a", encoding="utf-8") as f:
|
||||||
|
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
|
||||||
|
|
||||||
|
logger.info("=== Pipeline Complete ===")
|
||||||
|
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||||
@@ -1,12 +1,9 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
@@ -19,10 +16,6 @@ class FilteredTags(BaseModel):
|
|||||||
"""用于接收LLM筛选后的核心标签列表的模型。"""
|
"""用于接收LLM筛选后的核心标签列表的模型。"""
|
||||||
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
|
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
|
||||||
|
|
||||||
class InterestTags(BaseModel):
|
|
||||||
"""用于接收LLM筛选后的兴趣活动标签列表的模型。"""
|
|
||||||
interest_tags: List[str] = Field(..., description="从原始列表中筛选出的代表用户兴趣活动的标签列表。")
|
|
||||||
|
|
||||||
async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
|
async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
|
||||||
"""
|
"""
|
||||||
使用LLM筛选标签列表,仅保留具有代表性的核心名词。
|
使用LLM筛选标签列表,仅保留具有代表性的核心名词。
|
||||||
@@ -46,20 +39,16 @@ async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
|
|||||||
|
|
||||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
config_id = connected_config.get("memory_config_id")
|
config_id = connected_config.get("memory_config_id")
|
||||||
workspace_id = connected_config.get("workspace_id")
|
|
||||||
|
|
||||||
if not config_id and not workspace_id:
|
if not config_id:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No memory_config_id found for end_user_id: {end_user_id}. "
|
f"No memory_config_id found for end_user_id: {end_user_id}. "
|
||||||
"Please ensure the user has a valid memory configuration."
|
"Please ensure the user has a valid memory configuration."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use the config_id to get the proper LLM client with workspace fallback
|
# Use the config_id to get the proper LLM client
|
||||||
config_service = MemoryConfigService(db)
|
config_service = MemoryConfigService(db)
|
||||||
memory_config = config_service.load_memory_config(
|
memory_config = config_service.load_memory_config(config_id)
|
||||||
config_id=config_id,
|
|
||||||
workspace_id=workspace_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if not memory_config.llm_model_id:
|
if not memory_config.llm_model_id:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -92,74 +81,10 @@ async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
|
|||||||
return structured_response.meaningful_tags
|
return structured_response.meaningful_tags
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"LLM筛选过程中发生错误: {e}", exc_info=True)
|
print(f"LLM筛选过程中发生错误: {e}")
|
||||||
# 在LLM失败时返回原始标签,确保流程继续
|
# 在LLM失败时返回原始标签,确保流程继续
|
||||||
return tags
|
return tags
|
||||||
|
|
||||||
async def filter_interests_with_llm(tags: List[str], end_user_id: str, language: str = "zh") -> List[str]:
|
|
||||||
"""
|
|
||||||
使用LLM从标签列表中筛选出代表用户兴趣活动的标签。
|
|
||||||
|
|
||||||
与 filter_tags_with_llm 不同,此函数专注于识别"活动/行为"类兴趣,
|
|
||||||
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tags: 原始标签列表
|
|
||||||
end_user_id: 用户ID,用于获取LLM配置
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
筛选后的兴趣活动标签列表
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
with get_db_context() as db:
|
|
||||||
from app.services.memory_agent_service import (
|
|
||||||
get_end_user_connected_config,
|
|
||||||
)
|
|
||||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
|
||||||
config_id = connected_config.get("memory_config_id")
|
|
||||||
workspace_id = connected_config.get("workspace_id")
|
|
||||||
|
|
||||||
if not config_id and not workspace_id:
|
|
||||||
raise ValueError(
|
|
||||||
f"No memory_config_id found for end_user_id: {end_user_id}."
|
|
||||||
)
|
|
||||||
|
|
||||||
config_service = MemoryConfigService(db)
|
|
||||||
memory_config = config_service.load_memory_config(
|
|
||||||
config_id=config_id,
|
|
||||||
workspace_id=workspace_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if not memory_config.llm_model_id:
|
|
||||||
raise ValueError(
|
|
||||||
f"No llm_model_id found in memory config {config_id}."
|
|
||||||
)
|
|
||||||
|
|
||||||
factory = MemoryClientFactory(db)
|
|
||||||
llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
|
||||||
|
|
||||||
tag_list_str = ", ".join(tags)
|
|
||||||
from app.core.memory.utils.prompt.prompt_utils import render_interest_filter_prompt
|
|
||||||
rendered_prompt = render_interest_filter_prompt(tag_list_str, language=language)
|
|
||||||
messages = [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": rendered_prompt
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
structured_response = await llm_client.response_structured(
|
|
||||||
messages=messages,
|
|
||||||
response_model=InterestTags
|
|
||||||
)
|
|
||||||
|
|
||||||
return structured_response.interest_tags
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"兴趣标签LLM筛选过程中发生错误: {e}", exc_info=True)
|
|
||||||
return tags
|
|
||||||
|
|
||||||
|
|
||||||
async def get_raw_tags_from_db(
|
async def get_raw_tags_from_db(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
@@ -210,14 +135,14 @@ async def get_raw_tags_from_db(
|
|||||||
|
|
||||||
return [(record["name"], record["frequency"]) for record in results]
|
return [(record["name"], record["frequency"]) for record in results]
|
||||||
|
|
||||||
async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool = False) -> List[Tuple[str, int]]:
|
async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||||
"""
|
"""
|
||||||
获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。
|
获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。
|
||||||
查询更多的标签(40条)给LLM提供更丰富的上下文进行筛选,但最终返回数量由limit参数控制。
|
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||||
limit: 最终返回的标签数量限制(默认10)
|
limit: 返回的标签数量限制
|
||||||
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@@ -232,9 +157,8 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool =
|
|||||||
# 使用项目的Neo4jConnector
|
# 使用项目的Neo4jConnector
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
try:
|
try:
|
||||||
# 1. 从数据库获取原始排名靠前的标签(查询40条给LLM提供更丰富的上下文)
|
# 1. 从数据库获取原始排名靠前的标签
|
||||||
query_limit = 40
|
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user)
|
||||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user)
|
|
||||||
if not raw_tags_with_freq:
|
if not raw_tags_with_freq:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@@ -249,61 +173,7 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool =
|
|||||||
if tag in meaningful_tag_names:
|
if tag in meaningful_tag_names:
|
||||||
final_tags.append((tag, freq))
|
final_tags.append((tag, freq))
|
||||||
|
|
||||||
# 4. 限制返回的标签数量
|
return final_tags
|
||||||
return final_tags[:limit]
|
|
||||||
finally:
|
finally:
|
||||||
# 确保关闭连接
|
# 确保关闭连接
|
||||||
await connector.close()
|
await connector.close()
|
||||||
|
|
||||||
async def get_interest_distribution(end_user_id: str, limit: int = 10, by_user: bool = False, language: str = "zh") -> List[Tuple[str, int]]:
|
|
||||||
"""
|
|
||||||
获取用户的兴趣分布标签。
|
|
||||||
|
|
||||||
与 get_hot_memory_tags 不同,此函数使用专门针对"活动/行为"的LLM prompt,
|
|
||||||
过滤掉纯物品、工具、地点等,只保留能代表用户兴趣爱好的活动类标签。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
|
||||||
limit: 最终返回的标签数量限制(默认10)
|
|
||||||
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果end_user_id未提供或为空
|
|
||||||
"""
|
|
||||||
if not end_user_id or not end_user_id.strip():
|
|
||||||
raise ValueError(
|
|
||||||
"end_user_id is required. Please provide a valid end_user_id or user_id."
|
|
||||||
)
|
|
||||||
|
|
||||||
connector = Neo4jConnector()
|
|
||||||
try:
|
|
||||||
# 查询更多原始标签,给LLM提供充足上下文
|
|
||||||
query_limit = 40
|
|
||||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user)
|
|
||||||
if not raw_tags_with_freq:
|
|
||||||
return []
|
|
||||||
|
|
||||||
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
|
|
||||||
raw_freq_map = {tag: freq for tag, freq in raw_tags_with_freq}
|
|
||||||
|
|
||||||
# 使用兴趣活动专用prompt进行筛选(支持语义推断出新标签)
|
|
||||||
interest_tag_names = await filter_interests_with_llm(raw_tag_names, end_user_id, language=language)
|
|
||||||
|
|
||||||
# 构建最终标签列表:
|
|
||||||
# - 原始标签中存在的,保留原始频率
|
|
||||||
# - LLM推断出的新标签(不在原始列表中),赋予默认频率1
|
|
||||||
final_tags = []
|
|
||||||
seen = set()
|
|
||||||
for tag in interest_tag_names:
|
|
||||||
if tag in seen:
|
|
||||||
continue
|
|
||||||
seen.add(tag)
|
|
||||||
freq = raw_freq_map.get(tag, 1)
|
|
||||||
final_tags.append((tag, freq))
|
|
||||||
|
|
||||||
# 按频率降序排列
|
|
||||||
final_tags.sort(key=lambda x: x[1], reverse=True)
|
|
||||||
|
|
||||||
return final_tags[:limit]
|
|
||||||
finally:
|
|
||||||
await connector.close()
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user