Compare commits
383 Commits
v0.2.10
...
release/v0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
09393b2326 | ||
|
|
eaa66ba71a | ||
|
|
c59a97afba | ||
|
|
9480a61229 | ||
|
|
7ffd250b08 | ||
|
|
52bccfaede | ||
|
|
9233e74f36 | ||
|
|
46dfd92a9f | ||
|
|
5f33cec8ad | ||
|
|
334502f06b | ||
|
|
b0bb5e883c | ||
|
|
b9cfc47e1e | ||
|
|
4a4391a19c | ||
|
|
7193eed9e3 | ||
|
|
2a03f70287 | ||
|
|
124e8d0639 | ||
|
|
7dc35bb3fb | ||
|
|
b488590537 | ||
|
|
aa56ad15f9 | ||
|
|
d6af459ca8 | ||
|
|
2f7fd85ab1 | ||
|
|
398aebd0c5 | ||
|
|
eaa4058c56 | ||
|
|
21b25bfef7 | ||
|
|
a61acbef93 | ||
|
|
a90757745d | ||
|
|
b882863907 | ||
|
|
9159d5cbb0 | ||
|
|
537f6a1812 | ||
|
|
1ea0f308ba | ||
|
|
77c023102e | ||
|
|
ad24119b2d | ||
|
|
ea6fa154e0 | ||
|
|
158507cf8e | ||
|
|
5e0d30dde8 | ||
|
|
363d775270 | ||
|
|
ad4121b0d8 | ||
|
|
671df83bcd | ||
|
|
8bb5a66401 | ||
|
|
4c9f327833 | ||
|
|
6bd528eace | ||
|
|
2b5bece9b6 | ||
|
|
ea0e65f1ec | ||
|
|
cb2a7aa60a | ||
|
|
402c8aef5d | ||
|
|
eb98a69a84 | ||
|
|
152a84aff3 | ||
|
|
c5c8be89ed | ||
|
|
30aed72b74 | ||
|
|
35c2d9d0d3 | ||
|
|
27275eee43 | ||
|
|
7eb21f677f | ||
|
|
6de5d413c4 | ||
|
|
aecb0f6497 | ||
|
|
83b7c6870d | ||
|
|
74157adb12 | ||
|
|
8011610acc | ||
|
|
f1dc507b5c | ||
|
|
f3ac7e084d | ||
|
|
ba3743f9f1 | ||
|
|
20ddc76a4d | ||
|
|
84ca98555d | ||
|
|
7e6d17e4e3 | ||
|
|
7f3c48ce2a | ||
|
|
e5c16a2a24 | ||
|
|
8887600f7d | ||
|
|
df6eb74b28 | ||
|
|
b4b9974064 | ||
|
|
ff65dee754 | ||
|
|
2c2ed0ebf3 | ||
|
|
d60f838fb8 | ||
|
|
817aa78d03 | ||
|
|
4c73887a48 | ||
|
|
94d2d975ee | ||
|
|
d59990d326 | ||
|
|
3227c25b07 | ||
|
|
08b5c7bc8a | ||
|
|
475e573891 | ||
|
|
b03300c804 | ||
|
|
a5d07ee66d | ||
|
|
10a655772f | ||
|
|
aeeb18581d | ||
|
|
fb1160e833 | ||
|
|
c448cf0660 | ||
|
|
5289b3a2cb | ||
|
|
48f3d9b105 | ||
|
|
559b4bef6b | ||
|
|
4a39fd5f46 | ||
|
|
b22c15cccc | ||
|
|
a2f85b3d98 | ||
|
|
7f1cf13b23 | ||
|
|
d4129edcf5 | ||
|
|
ab2a58d68e | ||
|
|
a28b62763e | ||
|
|
86540a81d1 | ||
|
|
dcd874fecd | ||
|
|
bbd85733b8 | ||
|
|
22c5f12657 | ||
|
|
7b5d7696cb | ||
|
|
cb33724673 | ||
|
|
48b56a3d88 | ||
|
|
83d0fb9387 | ||
|
|
bb964c1ed8 | ||
|
|
81d58b001f | ||
|
|
99bc84a9f2 | ||
|
|
37dbe0f95b | ||
|
|
d4a1904b19 | ||
|
|
ecdad19f54 | ||
|
|
fb93c509f4 | ||
|
|
f597139913 | ||
|
|
113ae59f84 | ||
|
|
62c721bdf6 | ||
|
|
4cbb0cee2f | ||
|
|
8c586935a8 | ||
|
|
d5272af76f | ||
|
|
cf8912e929 | ||
|
|
327c1904b1 | ||
|
|
58c13aaeb4 | ||
|
|
377ddd2b9b | ||
|
|
52f7ea7456 | ||
|
|
b02baedd2c | ||
|
|
f3c3b6255e | ||
|
|
b659e2a6e1 | ||
|
|
e15e32cc7b | ||
|
|
04d20dc094 | ||
|
|
b8123fc84c | ||
|
|
5a17b7fd0d | ||
|
|
e3d0602850 | ||
|
|
696b2d2417 | ||
|
|
a5613314b8 | ||
|
|
e87529876c | ||
|
|
7bb3e65fb7 | ||
|
|
5ada7e77fc | ||
|
|
79b7da44e2 | ||
|
|
26a3d8a41b | ||
|
|
2380cd55ef | ||
|
|
a105df33ab | ||
|
|
0dd8cc5d43 | ||
|
|
fd90a4c2ad | ||
|
|
b302a94620 | ||
|
|
c96dc53534 | ||
|
|
f883c1469d | ||
|
|
ddfd81259a | ||
|
|
e015455fb8 | ||
|
|
915cb54f21 | ||
|
|
cada860a16 | ||
|
|
e1f8ad871b | ||
|
|
e205aaa6e6 | ||
|
|
62edafcebe | ||
|
|
ccdf7ae81d | ||
|
|
643f69bb90 | ||
|
|
73fbc19747 | ||
|
|
7ba0726473 | ||
|
|
8c6b65db12 | ||
|
|
5ce0bdb0f5 | ||
|
|
b59e2b5bcd | ||
|
|
5a2fe738dc | ||
|
|
f04412c455 | ||
|
|
db6fc5d2db | ||
|
|
b6aca0b1e7 | ||
|
|
4fd7395464 | ||
|
|
78ba313262 | ||
|
|
d35bc3a2cf | ||
|
|
d5c8d16e64 | ||
|
|
09496bd7b9 | ||
|
|
171f25a350 | ||
|
|
c7230659e3 | ||
|
|
502d87e88d | ||
|
|
1faa258e23 | ||
|
|
bef6a50deb | ||
|
|
cc12ec3fa8 | ||
|
|
466864afe3 | ||
|
|
e0d7a5a91f | ||
|
|
5ac2d5602e | ||
|
|
f4c3974956 | ||
|
|
71e5b6586a | ||
|
|
bfb723a468 | ||
|
|
61f2e44bd5 | ||
|
|
ed765b7c26 | ||
|
|
3018d186f7 | ||
|
|
2e1470cb52 | ||
|
|
737858731b | ||
|
|
d072eb1af7 | ||
|
|
daaee63bd5 | ||
|
|
e3c643b659 | ||
|
|
017efdc320 | ||
|
|
29aef4527c | ||
|
|
d9cb2b511b | ||
|
|
18be1a9f89 | ||
|
|
49e0801d15 | ||
|
|
dde7ea9039 | ||
|
|
5262aedab9 | ||
|
|
441b21774d | ||
|
|
d6dd038167 | ||
|
|
47c242e513 | ||
|
|
811193dd75 | ||
|
|
797780824c | ||
|
|
75e95bab01 | ||
|
|
e7a400bb96 | ||
|
|
28ca4d1734 | ||
|
|
5e6490213d | ||
|
|
3b359df02f | ||
|
|
fcf3071cb0 | ||
|
|
1294aabbcc | ||
|
|
3c2a78a449 | ||
|
|
4f0e5d0866 | ||
|
|
7a84ee33c6 | ||
|
|
e3265e4ba3 | ||
|
|
3e7a004599 | ||
|
|
fa1e5ee43c | ||
|
|
c72a6fd724 | ||
|
|
0965008210 | ||
|
|
bcadd2a6f1 | ||
|
|
e4f306dabb | ||
|
|
b5ec5c2cea | ||
|
|
e539b3eeb7 | ||
|
|
7f8765b815 | ||
|
|
72b39c6fa3 | ||
|
|
9032f50a19 | ||
|
|
aa683efaa0 | ||
|
|
2d9986f902 | ||
|
|
06075ffef5 | ||
|
|
a7336b0829 | ||
|
|
0d16e168e7 | ||
|
|
a882e5e5c4 | ||
|
|
c614bb5be7 | ||
|
|
1ff0f3ebfd | ||
|
|
bafcb5c545 | ||
|
|
f8d27fada6 | ||
|
|
90365cd026 | ||
|
|
d96c7b88f0 | ||
|
|
99559621c5 | ||
|
|
926f65a1ff | ||
|
|
b20971dc95 | ||
|
|
1ff0274027 | ||
|
|
8495aa5dde | ||
|
|
d8ef7a8e02 | ||
|
|
7a4a02b2bb | ||
|
|
8f623a66c8 | ||
|
|
77ed9faea1 | ||
|
|
1ff3748935 | ||
|
|
f023c43f80 | ||
|
|
60124e3232 | ||
|
|
70d4e79de1 | ||
|
|
59b5a1bcf2 | ||
|
|
62f345b3de | ||
|
|
a3f0415cd3 | ||
|
|
2450fe3afe | ||
|
|
52e726eabc | ||
|
|
7ca80b5d01 | ||
|
|
9470dd2f1e | ||
|
|
10f1089198 | ||
|
|
095f4e3001 | ||
|
|
ef8c7093b5 | ||
|
|
05ea372776 | ||
|
|
2b067ce08a | ||
|
|
b63cff2993 | ||
|
|
5bb9ce9018 | ||
|
|
aa581a9083 | ||
|
|
ac51ccaf1f | ||
|
|
5eaedaad77 | ||
|
|
bd955569b3 | ||
|
|
7a2a941ac4 | ||
|
|
19fa8314e4 | ||
|
|
cba24e58db | ||
|
|
62355186ef | ||
|
|
82faedc972 | ||
|
|
11ea486f82 | ||
|
|
efdee32f85 | ||
|
|
988d101e93 | ||
|
|
418f9f4dba | ||
|
|
520ee7c132 | ||
|
|
72be9f75f9 | ||
|
|
2b52b32b96 | ||
|
|
a96f20ee05 | ||
|
|
b8acc0a32f | ||
|
|
e1cf3bb3d2 | ||
|
|
6f66c9727f | ||
|
|
3beca641e1 | ||
|
|
b8507a1df6 | ||
|
|
0f28d54c43 | ||
|
|
0afc38e7ef | ||
|
|
07fd85c342 | ||
|
|
4c2a1e6d1d | ||
|
|
7cfb6ace22 | ||
|
|
91cc20d589 | ||
|
|
f01ca51896 | ||
|
|
f4a63f7d55 | ||
|
|
0019f3acfd | ||
|
|
3fe90a5e13 | ||
|
|
bc14c94407 | ||
|
|
a21dad70ed | ||
|
|
807a4e715d | ||
|
|
58d18b476c | ||
|
|
5e5927a0b9 | ||
|
|
7869121382 | ||
|
|
7c0fb624d9 | ||
|
|
af83980f99 | ||
|
|
cf0d11208c | ||
|
|
87d1630230 | ||
|
|
50392384e7 | ||
|
|
9a926a8398 | ||
|
|
e5e6699168 | ||
|
|
068e2bfb7e | ||
|
|
4ce6fede67 | ||
|
|
8497c955f9 | ||
|
|
72fe3962cf | ||
|
|
c253968aa8 | ||
|
|
d517bceda2 | ||
|
|
412183c359 | ||
|
|
90e8e90528 | ||
|
|
fd05c000f6 | ||
|
|
627d6a0381 | ||
|
|
807dee8460 | ||
|
|
ac7d39524e | ||
|
|
cd018814fe | ||
|
|
e0b7e95af6 | ||
|
|
3a62d50048 | ||
|
|
0e60da6d8a | ||
|
|
39e94eb3ea | ||
|
|
3e0f59adc6 | ||
|
|
660cd2fadb | ||
|
|
6f1bb43eab | ||
|
|
61b5627505 | ||
|
|
af6392fb09 | ||
|
|
84b1a95313 | ||
|
|
8b21dab255 | ||
|
|
fc5ce63e44 | ||
|
|
15a863b41a | ||
|
|
5226c5b79d | ||
|
|
27e9f9968d | ||
|
|
d38612a10d | ||
|
|
32c71dcd89 | ||
|
|
428e7ebaa5 | ||
|
|
57833689d9 | ||
|
|
384a67482c | ||
|
|
7842435321 | ||
|
|
33c4c5d31b | ||
|
|
ca4f7aa65d | ||
|
|
b875626f18 | ||
|
|
130684cac0 | ||
|
|
5adff38bda | ||
|
|
62e0b2730b | ||
|
|
55b2e05ba8 | ||
|
|
562ca6c1f1 | ||
|
|
e298b38de9 | ||
|
|
a7b8ba0c66 | ||
|
|
460c86cd94 | ||
|
|
33a1c178ff | ||
|
|
c81612e6d3 | ||
|
|
9f9ac69f97 | ||
|
|
0516822d42 | ||
|
|
b598171a3d | ||
|
|
a4ea7f0385 | ||
|
|
32ae60fc65 | ||
|
|
6b272c5b44 | ||
|
|
2782d0661f | ||
|
|
ea2f5e61c9 | ||
|
|
5975d70bf9 | ||
|
|
e0546e01ef | ||
|
|
70aab94fc3 | ||
|
|
0f50537d7d | ||
|
|
b7c1ce261b | ||
|
|
edac6a164e | ||
|
|
1503b242ea | ||
|
|
3ff44f0108 | ||
|
|
18fd48505d | ||
|
|
807ddce5cd | ||
|
|
62fb6c79a0 | ||
|
|
cc373b2864 | ||
|
|
f2d7479229 | ||
|
|
ae1909b7e9 | ||
|
|
e817cfd292 | ||
|
|
e48b146e60 | ||
|
|
07b66a9801 | ||
|
|
cd8229f370 | ||
|
|
cfbf83f71e | ||
|
|
99862db7a0 | ||
|
|
00a8099857 | ||
|
|
117e29fbe3 | ||
|
|
2f0bb793d8 | ||
|
|
010eff17cf | ||
|
|
7ce29019f7 |
164
.github/workflows/release-notify-wechat.yml
vendored
Normal file
164
.github/workflows/release-notify-wechat.yml
vendored
Normal file
@@ -0,0 +1,164 @@
|
||||
name: Release Notify Workflow
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [closed]
|
||||
|
||||
jobs:
|
||||
notify:
|
||||
if: >
|
||||
github.event.pull_request.merged == true &&
|
||||
startsWith(github.event.pull_request.base.ref, 'release')
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
# 防止 GitHub HEAD 未同步
|
||||
- run: sleep 3
|
||||
|
||||
# 1️⃣ 获取分支 HEAD
|
||||
- name: Get HEAD
|
||||
id: head
|
||||
run: |
|
||||
HEAD_SHA=$(curl -s \
|
||||
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
|
||||
https://api.github.com/repos/${{ github.repository }}/git/ref/heads/${{ github.event.pull_request.base.ref }} \
|
||||
| jq -r '.object.sha')
|
||||
echo "head_sha=$HEAD_SHA" >> $GITHUB_OUTPUT
|
||||
|
||||
# 2️⃣ 判断是否最终PR
|
||||
- name: Check Latest
|
||||
id: check
|
||||
run: |
|
||||
if [ "${{ github.event.pull_request.merge_commit_sha }}" = "${{ steps.head.outputs.head_sha }}" ]; then
|
||||
echo "ok=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "ok=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
# 3️⃣ 尝试从 PR body 提取 Sourcery 摘要
|
||||
- name: Extract Sourcery Summary
|
||||
if: steps.check.outputs.ok == 'true'
|
||||
id: sourcery
|
||||
env:
|
||||
PR_BODY: ${{ github.event.pull_request.body }}
|
||||
run: |
|
||||
python3 << 'PYEOF'
|
||||
import os, re
|
||||
|
||||
body = os.environ.get("PR_BODY", "") or ""
|
||||
match = re.search(
|
||||
r"## Summary by Sourcery\s*\n(.*?)(?=\n## |\Z)",
|
||||
body,
|
||||
re.DOTALL
|
||||
)
|
||||
|
||||
if match:
|
||||
summary = match.group(1).strip()
|
||||
found = "true"
|
||||
else:
|
||||
summary = ""
|
||||
found = "false"
|
||||
|
||||
with open("sourcery_summary.txt", "w", encoding="utf-8") as f:
|
||||
f.write(summary)
|
||||
|
||||
with open(os.environ["GITHUB_OUTPUT"], "a") as gh:
|
||||
gh.write(f"found={found}\n")
|
||||
gh.write("summary<<EOF\n")
|
||||
gh.write(summary + "\n")
|
||||
gh.write("EOF\n")
|
||||
PYEOF
|
||||
|
||||
# 4️⃣ Fallback: 获取 commits + 通义千问总结
|
||||
- name: Get Commits
|
||||
if: steps.check.outputs.ok == 'true' && steps.sourcery.outputs.found == 'false'
|
||||
run: |
|
||||
curl -s \
|
||||
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
|
||||
${{ github.event.pull_request.commits_url }} \
|
||||
| jq -r '.[].commit.message' | head -n 20 > commits.txt
|
||||
|
||||
- name: AI Summary (Qwen Fallback)
|
||||
if: steps.check.outputs.ok == 'true' && steps.sourcery.outputs.found == 'false'
|
||||
id: qwen
|
||||
env:
|
||||
DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }}
|
||||
run: |
|
||||
python3 << 'PYEOF'
|
||||
import json, os, urllib.request
|
||||
|
||||
with open("commits.txt", "r") as f:
|
||||
commits = f.read().strip()
|
||||
|
||||
prompt = "请用中文总结以下代码提交,输出3-5条要点,面向测试人员。直接输出编号列表,不要输出标题或前言:\n" + commits
|
||||
payload = {"model": "qwen-plus", "input": {"prompt": prompt}}
|
||||
data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
|
||||
|
||||
req = urllib.request.Request(
|
||||
"https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation",
|
||||
data=data,
|
||||
headers={
|
||||
"Authorization": "Bearer " + os.environ["DASHSCOPE_API_KEY"],
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
)
|
||||
resp = urllib.request.urlopen(req)
|
||||
result = json.loads(resp.read().decode())
|
||||
summary = result.get("output", {}).get("text", "AI 摘要生成失败")
|
||||
|
||||
with open(os.environ["GITHUB_OUTPUT"], "a") as gh:
|
||||
gh.write("summary<<EOF\n")
|
||||
gh.write(summary + "\n")
|
||||
gh.write("EOF\n")
|
||||
PYEOF
|
||||
|
||||
# 5️⃣ 企业微信通知(Markdown)
|
||||
- name: Notify WeChat
|
||||
if: steps.check.outputs.ok == 'true'
|
||||
env:
|
||||
WECHAT_WEBHOOK: ${{ secrets.WECHAT_WEBHOOK }}
|
||||
BRANCH: ${{ github.event.pull_request.base.ref }}
|
||||
AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||
PR_URL: ${{ github.event.pull_request.html_url }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
||||
SOURCERY_FOUND: ${{ steps.sourcery.outputs.found }}
|
||||
SOURCERY_SUMMARY: ${{ steps.sourcery.outputs.summary }}
|
||||
QWEN_SUMMARY: ${{ steps.qwen.outputs.summary }}
|
||||
run: |
|
||||
python3 << 'PYEOF'
|
||||
import json, os, urllib.request
|
||||
|
||||
if os.environ.get("SOURCERY_FOUND") == "true":
|
||||
label = "Summary by Sourcery"
|
||||
summary = os.environ.get("SOURCERY_SUMMARY", "")
|
||||
else:
|
||||
label = "AI变更摘要"
|
||||
summary = os.environ.get("QWEN_SUMMARY", "AI 摘要生成失败")
|
||||
|
||||
pr_number = os.environ.get("PR_NUMBER", "")
|
||||
short_sha = os.environ.get("MERGE_SHA", "")[:7]
|
||||
|
||||
content = (
|
||||
"## 🚀 Release 发布通知\n"
|
||||
"> <20> **分支**: " + os.environ["BRANCH"] + "\n"
|
||||
"> 👤 **提交人**: " + os.environ["AUTHOR"] + "\n"
|
||||
"> 📝 **标题**: " + os.environ["PR_TITLE"] + "\n"
|
||||
"> 🔢 **PR编号**: #" + pr_number + "\n"
|
||||
"> 🔖 **Commit**: " + short_sha + "\n\n"
|
||||
"### 🧠 " + label + "\n" +
|
||||
summary + "\n\n"
|
||||
"---\n"
|
||||
"🔗 [查看PR详情](" + os.environ["PR_URL"] + ")"
|
||||
)
|
||||
payload = {"msgtype": "markdown", "markdown": {"content": content}}
|
||||
data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
|
||||
req = urllib.request.Request(
|
||||
os.environ["WECHAT_WEBHOOK"],
|
||||
data=data,
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
resp = urllib.request.urlopen(req)
|
||||
print(resp.read().decode())
|
||||
PYEOF
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -27,6 +27,7 @@ time.log
|
||||
celerybeat-schedule.db
|
||||
search_results.json
|
||||
redbear-mem-metrics/
|
||||
redbear-mem-benchmark/
|
||||
pitch-deck/
|
||||
|
||||
api/migrations/versions
|
||||
|
||||
@@ -111,11 +111,17 @@ celery_app.conf.update(
|
||||
# Clustering tasks → memory_tasks queue (使用相同的 worker,避免 macOS fork 问题)
|
||||
'app.tasks.run_incremental_clustering': {'queue': 'memory_tasks'},
|
||||
|
||||
# Metadata extraction → memory_tasks queue
|
||||
'app.tasks.extract_user_metadata': {'queue': 'memory_tasks'},
|
||||
|
||||
# Document tasks → document_tasks queue (prefork worker)
|
||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
|
||||
|
||||
# GraphRAG tasks → graphrag_tasks queue (独立队列,避免阻塞文档解析)
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'graphrag_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_document': {'queue': 'graphrag_tasks'},
|
||||
|
||||
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
Celery Worker 入口点
|
||||
用于启动 Celery Worker: celery -A app.celery_worker worker --loglevel=info
|
||||
"""
|
||||
from celery.signals import worker_process_init
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.logging_config import LoggingConfig, get_logger
|
||||
|
||||
@@ -13,4 +15,39 @@ logger.info("Celery worker logging initialized")
|
||||
# 导入任务模块以注册任务
|
||||
import app.tasks
|
||||
|
||||
|
||||
@worker_process_init.connect
|
||||
def _reinit_db_pool(**kwargs):
|
||||
"""
|
||||
prefork 子进程启动时重建被 fork 污染的资源。
|
||||
|
||||
fork() 后子进程继承了父进程的:
|
||||
1. SQLAlchemy 连接池 — 多进程共享 TCP socket 导致 DB 连接损坏
|
||||
2. ThreadPoolExecutor — fork 后线程状态不确定,第二个任务会死锁
|
||||
"""
|
||||
# 重建 DB 连接池
|
||||
from app.db import engine
|
||||
engine.dispose()
|
||||
logger.info("DB connection pool disposed for forked worker process")
|
||||
|
||||
# 重建模块级 ThreadPoolExecutor(fork 后线程池不可用)
|
||||
try:
|
||||
from app.core.rag.deepdoc.parser import figure_parser
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
figure_parser.shared_executor = ThreadPoolExecutor(max_workers=10)
|
||||
logger.info("figure_parser.shared_executor recreated")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to recreate figure_parser.shared_executor: {e}")
|
||||
|
||||
try:
|
||||
from app.core.rag.utils import libre_office
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import os
|
||||
max_workers = os.cpu_count() * 2 if os.cpu_count() else 4
|
||||
libre_office.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
logger.info("libre_office.executor recreated")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to recreate libre_office.executor: {e}")
|
||||
|
||||
|
||||
__all__ = ['celery_app']
|
||||
|
||||
77
api/app/config/default_free_plan.py
Normal file
77
api/app/config/default_free_plan.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
社区版默认免费套餐配置
|
||||
当无法从 SaaS 版获取 premium 模块时,使用此配置作为兜底
|
||||
|
||||
可通过环境变量覆盖配额配置,格式:QUOTA_<QUOTA_NAME>
|
||||
例如:QUOTA_END_USER_QUOTA=100
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def _get_quota_from_env():
|
||||
"""从环境变量获取配额配置"""
|
||||
quota_keys = [
|
||||
"workspace_quota",
|
||||
"skill_quota",
|
||||
"app_quota",
|
||||
"knowledge_capacity_quota",
|
||||
"memory_engine_quota",
|
||||
"end_user_quota",
|
||||
"ontology_project_quota",
|
||||
"model_quota",
|
||||
"api_ops_rate_limit",
|
||||
]
|
||||
quotas = {}
|
||||
for key in quota_keys:
|
||||
env_key = f"QUOTA_{key.upper()}"
|
||||
env_value = os.getenv(env_key)
|
||||
if env_value is not None:
|
||||
try:
|
||||
quotas[key] = float(env_value) if '.' in env_value else int(env_value)
|
||||
except ValueError:
|
||||
pass
|
||||
return quotas
|
||||
|
||||
|
||||
def _build_default_free_plan():
|
||||
"""构建默认免费套餐配置"""
|
||||
base = {
|
||||
"name": "记忆体验版",
|
||||
"name_en": "Memory Experience",
|
||||
"category": "saas_personal",
|
||||
"tier_level": 0,
|
||||
"version": "1.0",
|
||||
"status": True,
|
||||
"price": 0,
|
||||
"billing_cycle": "permanent_free",
|
||||
"core_value": "感受永久记忆",
|
||||
"core_value_en": "Experience Permanent Memory",
|
||||
"tech_support": "社群交流",
|
||||
"tech_support_en": "Community Support",
|
||||
"sla_compliance": "无",
|
||||
"sla_compliance_en": "None",
|
||||
"page_customization": "无",
|
||||
"page_customization_en": "None",
|
||||
"theme_color": "#64748B",
|
||||
"quotas": {
|
||||
"workspace_quota": 1,
|
||||
"skill_quota": 5,
|
||||
"app_quota": 2,
|
||||
"knowledge_capacity_quota": 0.3,
|
||||
"memory_engine_quota": 1,
|
||||
"end_user_quota": 10,
|
||||
"ontology_project_quota": 3,
|
||||
"model_quota": 1,
|
||||
"api_ops_rate_limit": 50,
|
||||
},
|
||||
}
|
||||
|
||||
env_quotas = _get_quota_from_env()
|
||||
if env_quotas:
|
||||
base["quotas"].update(env_quotas)
|
||||
|
||||
return base
|
||||
|
||||
|
||||
DEFAULT_FREE_PLAN = _build_default_free_plan()
|
||||
@@ -47,7 +47,8 @@ from . import (
|
||||
user_memory_controllers,
|
||||
workspace_controller,
|
||||
ontology_controller,
|
||||
skill_controller
|
||||
skill_controller,
|
||||
tenant_subscription_controller,
|
||||
)
|
||||
|
||||
# 创建管理端 API 路由器
|
||||
@@ -98,5 +99,7 @@ manager_router.include_router(file_storage_controller.router)
|
||||
manager_router.include_router(ontology_controller.router)
|
||||
manager_router.include_router(skill_controller.router)
|
||||
manager_router.include_router(i18n_controller.router)
|
||||
manager_router.include_router(tenant_subscription_controller.router)
|
||||
manager_router.include_router(tenant_subscription_controller.public_router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
@@ -167,6 +167,8 @@ def update_api_key(
|
||||
|
||||
return success(data=api_key_schema.ApiKey.model_validate(api_key), msg="API Key 更新成功")
|
||||
|
||||
except BusinessException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"未知错误: {str(e)}", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
|
||||
@@ -28,6 +28,7 @@ from app.services.app_statistics_service import AppStatisticsService
|
||||
from app.services.workflow_import_service import WorkflowImportService
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
from app.services.app_dsl_service import AppDslService
|
||||
from app.core.quota_stub import check_app_quota
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||
logger = get_business_logger()
|
||||
@@ -35,6 +36,7 @@ logger = get_business_logger()
|
||||
|
||||
@router.post("", summary="创建应用(可选创建 Agent 配置)")
|
||||
@cur_workspace_access_guard()
|
||||
@check_app_quota
|
||||
def create_app(
|
||||
payload: app_schema.AppCreate,
|
||||
db: Session = Depends(get_db),
|
||||
@@ -217,6 +219,7 @@ def delete_app(
|
||||
|
||||
@router.post("/{app_id}/copy", summary="复制应用")
|
||||
@cur_workspace_access_guard()
|
||||
@check_app_quota
|
||||
def copy_app(
|
||||
app_id: uuid.UUID,
|
||||
new_name: Optional[str] = None,
|
||||
@@ -269,6 +272,19 @@ def update_agent_config(
|
||||
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
||||
|
||||
|
||||
@router.get("/{app_id}/model/parameters/default", summary="获取 Agent 模型参数默认配置")
|
||||
@cur_workspace_access_guard()
|
||||
def get_agent_model_parameters(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = AppService(db)
|
||||
model_parameters = service.get_default_model_parameters(app_id=app_id)
|
||||
return success(data=model_parameters, msg="获取 Agent 模型参数默认配置")
|
||||
|
||||
|
||||
@router.get("/{app_id}/config", summary="获取 Agent 配置")
|
||||
@cur_workspace_access_guard()
|
||||
def get_agent_config(
|
||||
@@ -1129,6 +1145,7 @@ async def import_workflow_config(
|
||||
|
||||
@router.post("/workflow/import/save")
|
||||
@cur_workspace_access_guard()
|
||||
@check_app_quota
|
||||
async def save_workflow_import(
|
||||
data: WorkflowImportSave,
|
||||
db: Session = Depends(get_db),
|
||||
@@ -1250,9 +1267,11 @@ async def export_app(
|
||||
async def import_app(
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
app_id: Optional[str] = Form(None),
|
||||
):
|
||||
"""从 YAML 文件导入 agent / multi_agent / workflow 应用。
|
||||
传入 app_id 时覆盖该应用的配置(类型必须一致),否则创建新应用。
|
||||
跨空间/跨租户导入时,模型/工具/知识库会按名称匹配,匹配不到则置空并返回 warnings。
|
||||
"""
|
||||
if not file.filename.lower().endswith((".yaml", ".yml")):
|
||||
@@ -1263,13 +1282,19 @@ async def import_app(
|
||||
if not dsl or "app" not in dsl:
|
||||
return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)
|
||||
|
||||
new_app, warnings = AppDslService(db).import_dsl(
|
||||
target_app_id = uuid.UUID(app_id) if app_id else None
|
||||
# 仅新建应用时检查配额,覆盖已有应用时跳过
|
||||
if target_app_id is None:
|
||||
from app.core.quota_manager import _check_quota
|
||||
_check_quota(db, current_user.tenant_id, "app_quota", "app", workspace_id=current_user.current_workspace_id)
|
||||
result_app, warnings = AppDslService(db).import_dsl(
|
||||
dsl=dsl,
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
tenant_id=current_user.tenant_id,
|
||||
user_id=current_user.id,
|
||||
app_id=target_app_id,
|
||||
)
|
||||
return success(
|
||||
data={"app": app_schema.App.model_validate(new_app), "warnings": warnings},
|
||||
data={"app": app_schema.App.model_validate(result_app), "warnings": warnings},
|
||||
msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "")
|
||||
)
|
||||
|
||||
@@ -136,7 +136,7 @@ async def refresh_token(
|
||||
# 检查用户是否存在
|
||||
user = auth_service.get_user_by_id(db, userId)
|
||||
if not user:
|
||||
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
|
||||
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NO_ACCESS)
|
||||
|
||||
# 检查 refresh token 黑名单
|
||||
if settings.ENABLE_SINGLE_SESSION:
|
||||
|
||||
@@ -23,6 +23,7 @@ from app.models.user_model import User
|
||||
from app.schemas import chunk_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
@@ -442,10 +443,10 @@ async def retrieve_chunks(
|
||||
match retrieve_data.retrieve_type:
|
||||
case chunk_schema.RetrieveType.PARTICIPLE:
|
||||
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
return success(data=jsonable_encoder(rs), msg="retrieval successful")
|
||||
case chunk_schema.RetrieveType.SEMANTIC:
|
||||
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
return success(data=jsonable_encoder(rs), msg="retrieval successful")
|
||||
case _:
|
||||
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
||||
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
||||
@@ -456,22 +457,24 @@ async def retrieve_chunks(
|
||||
if doc.metadata["doc_id"] not in seen_ids:
|
||||
seen_ids.add(doc.metadata["doc_id"])
|
||||
unique_rs.append(doc)
|
||||
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k)
|
||||
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k) if unique_rs else []
|
||||
if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph:
|
||||
kb_ids = [str(kb_id) for kb_id in private_kb_ids]
|
||||
workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids]
|
||||
llm_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.llm_id)
|
||||
emb_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.embedding_id)
|
||||
# Prepare to configure chat_mdl、embedding_model、vision_model information
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
||||
key=llm_key.api_key,
|
||||
model_name=llm_key.model_name,
|
||||
base_url=llm_key.api_base
|
||||
)
|
||||
embedding_model = OpenAIEmbed(
|
||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||
base_url=db_knowledge.embedding.api_keys[0].api_base
|
||||
key=emb_key.api_key,
|
||||
model_name=emb_key.model_name,
|
||||
base_url=emb_key.api_base
|
||||
)
|
||||
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids= kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||
if doc:
|
||||
rs.insert(0, doc)
|
||||
return success(data=jsonable_encoder(rs), msg="retrieval successful")
|
||||
@@ -19,6 +19,7 @@ from app.models.user_model import User
|
||||
from app.schemas import file_schema, document_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import file_service, document_service
|
||||
from app.core.quota_stub import check_knowledge_capacity_quota
|
||||
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
@@ -131,6 +132,7 @@ async def create_folder(
|
||||
|
||||
|
||||
@router.post("/file", response_model=ApiResponse)
|
||||
@check_knowledge_capacity_quota
|
||||
async def upload_file(
|
||||
kb_id: uuid.UUID,
|
||||
parent_id: uuid.UUID,
|
||||
|
||||
@@ -27,6 +27,7 @@ from app.schemas import knowledge_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import knowledge_service, document_service
|
||||
from app.services.model_service import ModelConfigService
|
||||
from app.core.quota_stub import check_knowledge_capacity_quota
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
@@ -179,6 +180,7 @@ async def get_knowledges(
|
||||
|
||||
|
||||
@router.post("/knowledge", response_model=ApiResponse)
|
||||
@check_knowledge_capacity_quota
|
||||
async def create_knowledge(
|
||||
create_data: knowledge_schema.KnowledgeCreate,
|
||||
db: Session = Depends(get_db),
|
||||
|
||||
@@ -34,6 +34,7 @@ from app.services.memory_storage_service import (
|
||||
search_entity,
|
||||
search_statement,
|
||||
)
|
||||
from app.core.quota_stub import check_memory_engine_quota
|
||||
from fastapi import APIRouter, Depends, Header
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -76,6 +77,7 @@ async def get_storage_info(
|
||||
|
||||
|
||||
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
||||
@check_memory_engine_quota
|
||||
def create_config(
|
||||
payload: ConfigParamsCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.core.response_utils import success
|
||||
from app.schemas.response_schema import ApiResponse, PageData
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService, ModelBaseService
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.quota_stub import check_model_quota, check_model_activation_quota
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -303,6 +304,7 @@ async def create_model(
|
||||
|
||||
|
||||
@router.post("/composite", response_model=ApiResponse)
|
||||
@check_model_quota
|
||||
async def create_composite_model(
|
||||
model_data: model_schema.CompositeModelCreate,
|
||||
db: Session = Depends(get_db),
|
||||
@@ -329,6 +331,7 @@ async def create_composite_model(
|
||||
|
||||
|
||||
@router.put("/composite/{model_id}", response_model=ApiResponse)
|
||||
@check_model_activation_quota
|
||||
async def update_composite_model(
|
||||
model_id: uuid.UUID,
|
||||
model_data: model_schema.CompositeModelCreate,
|
||||
@@ -370,6 +373,7 @@ def delete_composite_model(
|
||||
|
||||
|
||||
@router.put("/{model_id}", response_model=ApiResponse)
|
||||
@check_model_activation_quota
|
||||
def update_model(
|
||||
model_id: uuid.UUID,
|
||||
model_data: model_schema.ModelConfigUpdate,
|
||||
|
||||
@@ -28,6 +28,8 @@ from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, H
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.quota_stub import check_ontology_project_quota
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
@@ -163,7 +165,7 @@ def _get_ontology_service(
|
||||
api_key=api_key_config.api_key,
|
||||
base_url=api_key_config.api_base,
|
||||
is_omni=api_key_config.is_omni,
|
||||
support_thinking="thinking" in (api_key_config.capability or []),
|
||||
capability=api_key_config.capability,
|
||||
max_retries=3,
|
||||
timeout=60.0
|
||||
)
|
||||
@@ -287,6 +289,7 @@ async def extract_ontology(
|
||||
# ==================== 本体场景管理接口 ====================
|
||||
|
||||
@router.post("/scene", response_model=ApiResponse)
|
||||
@check_ontology_project_quota
|
||||
async def create_scene(
|
||||
request: SceneCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
|
||||
@@ -124,10 +124,11 @@ async def get_prompt_opt(
|
||||
skill=data.skill
|
||||
):
|
||||
# chunk 是 prompt 的增量内容
|
||||
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
|
||||
yield f"event:message\ndata: {json.dumps(chunk, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
yield f"event:error\ndata: {json.dumps(
|
||||
{"error": str(e)}
|
||||
{"error": str(e)},
|
||||
ensure_ascii=False
|
||||
)}\n\n"
|
||||
yield "event:end\ndata: {}\n\n"
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.quota_manager import check_end_user_quota
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db, get_db_read
|
||||
from app.dependencies import get_share_user_id, ShareTokenData
|
||||
@@ -218,9 +219,20 @@ def list_conversations(
|
||||
end_user_repo = EndUserRepository(db)
|
||||
app_service = AppService(db)
|
||||
app = app_service._get_app_or_404(share.app_id)
|
||||
workspace_id = app.workspace_id
|
||||
|
||||
# 仅在新建终端用户时检查配额
|
||||
existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id)
|
||||
if existing_end_user is None:
|
||||
from app.core.quota_manager import _check_quota
|
||||
from app.models.workspace_model import Workspace
|
||||
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||
if ws:
|
||||
_check_quota(db, ws.tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
|
||||
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
workspace_id=app.workspace_id,
|
||||
workspace_id=workspace_id,
|
||||
other_id=other_id
|
||||
)
|
||||
logger.debug(new_end_user.id)
|
||||
@@ -348,6 +360,18 @@ async def chat(
|
||||
app_service = AppService(db)
|
||||
app = app_service._get_app_or_404(share.app_id)
|
||||
workspace_id = app.workspace_id
|
||||
|
||||
# 仅在新建终端用户时检查配额,已有用户复用不受限制
|
||||
existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id)
|
||||
logger.info(f"终端用户配额检查: workspace_id={workspace_id}, other_id={other_id}, existing={existing_end_user is not None}")
|
||||
if existing_end_user is None:
|
||||
from app.core.quota_manager import _check_quota
|
||||
from app.models.workspace_model import Workspace
|
||||
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||
if ws:
|
||||
logger.info(f"新终端用户,执行配额检查: tenant_id={ws.tenant_id}")
|
||||
_check_quota(db, ws.tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
|
||||
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
workspace_id=workspace_id,
|
||||
|
||||
@@ -4,7 +4,17 @@
|
||||
认证方式: API Key
|
||||
"""
|
||||
from fastapi import APIRouter
|
||||
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller, end_user_api_controller
|
||||
|
||||
from . import (
|
||||
app_api_controller,
|
||||
end_user_api_controller,
|
||||
memory_api_controller,
|
||||
memory_config_api_controller,
|
||||
rag_api_chunk_controller,
|
||||
rag_api_document_controller,
|
||||
rag_api_file_controller,
|
||||
rag_api_knowledge_controller,
|
||||
)
|
||||
|
||||
# 创建 V1 API 路由器
|
||||
service_router = APIRouter()
|
||||
@@ -17,5 +27,6 @@ service_router.include_router(rag_api_file_controller.router)
|
||||
service_router.include_router(rag_api_chunk_controller.router)
|
||||
service_router.include_router(memory_api_controller.router)
|
||||
service_router.include_router(end_user_api_controller.router)
|
||||
service_router.include_router(memory_config_api_controller.router)
|
||||
|
||||
__all__ = ["service_router"]
|
||||
|
||||
@@ -14,6 +14,7 @@ from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.models.app_model import App
|
||||
from app.models.app_model import AppType
|
||||
from app.models.app_release_model import AppRelease
|
||||
from app.repositories import knowledge_repository
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas import AppChatRequest, conversation_schema
|
||||
@@ -61,18 +62,18 @@ async def list_apps():
|
||||
# return success(data={"received": True}, msg="消息已接收")
|
||||
|
||||
|
||||
def _checkAppConfig(app: App):
|
||||
if app.type == AppType.AGENT:
|
||||
if not app.current_release.config:
|
||||
def _checkAppConfig(release: AppRelease):
|
||||
if release.type == AppType.AGENT:
|
||||
if not release.config:
|
||||
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
elif app.type == AppType.MULTI_AGENT:
|
||||
if not app.current_release.config:
|
||||
elif release.type == AppType.MULTI_AGENT:
|
||||
if not release.config:
|
||||
raise BusinessException("Multi-Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
elif app.type == AppType.WORKFLOW:
|
||||
if not app.current_release.config:
|
||||
elif release.type == AppType.WORKFLOW:
|
||||
if not release.config:
|
||||
raise BusinessException("工作流应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
else:
|
||||
raise BusinessException("不支持的应用类型", BizCode.AGENT_CONFIG_MISSING)
|
||||
raise BusinessException("不支持的应用类型", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
@@ -86,13 +87,35 @@ async def chat(
|
||||
app_service: Annotated[AppService, Depends(get_app_service)] = None,
|
||||
message: str = Body(..., description="聊天消息内容"),
|
||||
):
|
||||
"""
|
||||
Agent/Workflow 聊天接口
|
||||
|
||||
- 不传 version:使用当前生效版本(current_release,回滚后为回滚目标版本)
|
||||
- 传 version=release_id:使用指定版本uuid的历史快照,例如 {"version": "{{release_id}}"}
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = AppChatRequest(**body)
|
||||
|
||||
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
|
||||
|
||||
# 版本切换:指定 release_id 时查找对应历史快照,否则使用当前激活版本
|
||||
if payload.version is not None:
|
||||
active_release = app_service.get_release_by_id(app.id, payload.version)
|
||||
else:
|
||||
active_release = app.current_release
|
||||
other_id = payload.user_id
|
||||
workspace_id = api_key_auth.workspace_id
|
||||
end_user_repo = EndUserRepository(db)
|
||||
|
||||
# 仅在新建终端用户时检查配额,已有用户复用不受限制
|
||||
existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id)
|
||||
if existing_end_user is None:
|
||||
from app.core.quota_manager import _check_quota
|
||||
from app.models.workspace_model import Workspace
|
||||
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||
if ws:
|
||||
_check_quota(db, ws.tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
|
||||
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
@@ -127,7 +150,7 @@ async def chat(
|
||||
storage_type = 'neo4j'
|
||||
app_type = app.type
|
||||
# check app config
|
||||
_checkAppConfig(app)
|
||||
_checkAppConfig(active_release)
|
||||
|
||||
# 获取或创建会话(提前验证)
|
||||
conversation = conversation_service.create_or_get_conversation(
|
||||
@@ -142,7 +165,7 @@ async def chat(
|
||||
|
||||
# print("="*50)
|
||||
# print(app.current_release.default_model_config_id)
|
||||
agent_config = agent_config_4_app_release(app.current_release)
|
||||
agent_config = agent_config_4_app_release(active_release)
|
||||
# print(agent_config.default_model_config_id)
|
||||
|
||||
# thinking 开关:仅当 agent 配置了 deep_thinking 且请求 thinking=True 时才启用
|
||||
@@ -194,7 +217,7 @@ async def chat(
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
# 多 Agent 流式返回
|
||||
config = multi_agent_config_4_app_release(app.current_release)
|
||||
config = multi_agent_config_4_app_release(active_release)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.multi_agent_chat_stream(
|
||||
@@ -237,7 +260,7 @@ async def chat(
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
# 多 Agent 流式返回
|
||||
config = workflow_config_4_app_release(app.current_release)
|
||||
config = workflow_config_4_app_release(active_release)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.workflow_chat_stream(
|
||||
@@ -253,7 +276,7 @@ async def chat(
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id,
|
||||
release_id=active_release.id,
|
||||
public=True
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
@@ -288,7 +311,7 @@ async def chat(
|
||||
files=payload.files,
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id
|
||||
release_id=active_release.id
|
||||
)
|
||||
logger.debug(
|
||||
"工作流试运行返回结果",
|
||||
@@ -302,6 +325,4 @@ async def chat(
|
||||
msg="工作流任务执行成功"
|
||||
)
|
||||
else:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
@@ -5,28 +5,49 @@ import uuid
|
||||
from fastapi import APIRouter, Body, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.controllers import user_memory_controllers
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.quota_stub import check_end_user_quota
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.end_user_info_schema import EndUserInfoUpdate
|
||||
from app.schemas.memory_api_schema import CreateEndUserRequest, CreateEndUserResponse
|
||||
from app.services import api_key_service
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
router = APIRouter(prefix="/end_user", tags=["V1 - End User API"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
def _get_current_user(api_key_auth: ApiKeyAuth, db: Session):
|
||||
"""Build a current_user object from API key auth
|
||||
|
||||
Args:
|
||||
api_key_auth: Validated API key auth info
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
User object with current_workspace_id set
|
||||
"""
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
return current_user
|
||||
|
||||
|
||||
@router.post("/create")
|
||||
@require_api_key(scopes=["memory"])
|
||||
@check_end_user_quota
|
||||
async def create_end_user(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(..., description="Request body"),
|
||||
message: str = Body(None, description="Request body"),
|
||||
):
|
||||
"""
|
||||
Create or retrieve an end user for the workspace.
|
||||
@@ -37,6 +58,7 @@ async def create_end_user(
|
||||
|
||||
Optionally accepts a memory_config_id to connect the end user to a specific
|
||||
memory configuration. If not provided, falls back to the workspace default config.
|
||||
Optionally accepts an app_id to bind the end user to a specific app.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = CreateEndUserRequest(**body)
|
||||
@@ -71,14 +93,26 @@ async def create_end_user(
|
||||
else:
|
||||
logger.warning(f"No default memory config found for workspace: {workspace_id}")
|
||||
|
||||
# Resolve app_id: explicit from payload, otherwise None
|
||||
app_id = None
|
||||
if payload.app_id:
|
||||
try:
|
||||
app_id = uuid.UUID(payload.app_id)
|
||||
except ValueError:
|
||||
raise BusinessException(
|
||||
f"Invalid app_id format: {payload.app_id}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
end_user_repo = EndUserRepository(db)
|
||||
end_user = end_user_repo.get_or_create_end_user_with_config(
|
||||
app_id=api_key_auth.resource_id,
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
other_id=payload.other_id,
|
||||
memory_config_id=memory_config_id,
|
||||
other_name=payload.other_name,
|
||||
)
|
||||
|
||||
end_user.other_name = payload.other_name
|
||||
logger.info(f"End user ready: {end_user.id}")
|
||||
|
||||
result = {
|
||||
@@ -90,3 +124,50 @@ async def create_end_user(
|
||||
}
|
||||
|
||||
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
|
||||
|
||||
|
||||
@router.get("/info")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def get_end_user_info(
|
||||
request: Request,
|
||||
end_user_id: str,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get end user info.
|
||||
|
||||
Retrieves the info record (aliases, meta_data, etc.) for the specified end user.
|
||||
Delegates to the manager-side controller for shared logic.
|
||||
"""
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
return await user_memory_controllers.get_end_user_info(
|
||||
end_user_id=end_user_id,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/info/update")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def update_end_user_info(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
):
|
||||
"""
|
||||
Update end user info.
|
||||
|
||||
Updates the info record (other_name, aliases, meta_data) for the specified end user.
|
||||
Delegates to the manager-side controller for shared logic.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = EndUserInfoUpdate(**body)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
return await user_memory_controllers.update_end_user_info(
|
||||
info_update=payload,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
)
|
||||
|
||||
@@ -1,53 +1,83 @@
|
||||
"""Memory 服务接口 - 基于 API Key 认证"""
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Query, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.quota_stub import check_end_user_quota
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.memory_api_schema import (
|
||||
CreateEndUserRequest,
|
||||
CreateEndUserResponse,
|
||||
ListConfigsResponse,
|
||||
MemoryReadRequest,
|
||||
MemoryReadResponse,
|
||||
MemoryReadSyncResponse,
|
||||
MemoryWriteRequest,
|
||||
MemoryWriteResponse,
|
||||
MemoryWriteSyncResponse,
|
||||
)
|
||||
from app.services.memory_api_service import MemoryAPIService
|
||||
from fastapi import APIRouter, Body, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
router = APIRouter(prefix="/memory", tags=["V1 - Memory API"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
def _sanitize_task_result(result: dict) -> dict:
|
||||
"""Make Celery task result JSON-serializable.
|
||||
|
||||
Converts UUID and other non-serializable values to strings.
|
||||
|
||||
Args:
|
||||
result: Raw task result dict from task_service
|
||||
|
||||
Returns:
|
||||
JSON-safe dict
|
||||
"""
|
||||
import uuid as _uuid
|
||||
from datetime import datetime
|
||||
|
||||
def _convert(obj):
|
||||
if isinstance(obj, dict):
|
||||
return {k: _convert(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_convert(i) for i in obj]
|
||||
if isinstance(obj, _uuid.UUID):
|
||||
return str(obj)
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
return obj
|
||||
|
||||
return _convert(result)
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_memory_info():
|
||||
"""获取记忆服务信息(占位)"""
|
||||
return success(data={}, msg="Memory API - Coming Soon")
|
||||
|
||||
|
||||
@router.post("/write_api_service")
|
||||
@router.post("/write")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def write_memory_api_service(
|
||||
async def write_memory(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(..., description="Message content"),
|
||||
):
|
||||
"""
|
||||
Write memory to storage.
|
||||
|
||||
Stores memory content for the specified end user using the Memory API Service.
|
||||
Submit a memory write task.
|
||||
|
||||
Validates the end user, then dispatches the write to a Celery background task
|
||||
with per-user fair locking. Returns a task_id for status polling.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = MemoryWriteRequest(**body)
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
|
||||
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
result = await memory_api_service.write_memory(
|
||||
|
||||
result = memory_api_service.write_memory(
|
||||
workspace_id=api_key_auth.workspace_id,
|
||||
end_user_id=payload.end_user_id,
|
||||
message=payload.message,
|
||||
@@ -55,31 +85,53 @@ async def write_memory_api_service(
|
||||
storage_type=payload.storage_type,
|
||||
user_rag_memory_id=payload.user_rag_memory_id,
|
||||
)
|
||||
|
||||
logger.info(f"Memory write successful for end_user: {payload.end_user_id}")
|
||||
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory written successfully")
|
||||
|
||||
logger.info(f"Memory write task submitted: task_id={result['task_id']}, end_user_id: {payload.end_user_id}")
|
||||
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted")
|
||||
|
||||
|
||||
@router.post("/read_api_service")
|
||||
@router.get("/write/status")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_memory_api_service(
|
||||
async def get_write_task_status(
|
||||
request: Request,
|
||||
task_id: str = Query(..., description="Celery task ID"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Check the status of a memory write task.
|
||||
|
||||
Returns the current status and result (if completed) of a previously submitted write task.
|
||||
"""
|
||||
logger.info(f"Write task status check - task_id: {task_id}")
|
||||
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
result = get_task_memory_write_result(task_id)
|
||||
|
||||
return success(data=_sanitize_task_result(result), msg="Task status retrieved")
|
||||
|
||||
|
||||
@router.post("/read")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_memory(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(..., description="Query message"),
|
||||
):
|
||||
"""
|
||||
Read memory from storage.
|
||||
|
||||
Queries and retrieves memories for the specified end user with context-aware responses.
|
||||
Submit a memory read task.
|
||||
|
||||
Validates the end user, then dispatches the read to a Celery background task.
|
||||
Returns a task_id for status polling.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = MemoryReadRequest(**body)
|
||||
logger.info(f"Memory read request - end_user_id: {payload.end_user_id}")
|
||||
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
result = await memory_api_service.read_memory(
|
||||
|
||||
result = memory_api_service.read_memory(
|
||||
workspace_id=api_key_auth.workspace_id,
|
||||
end_user_id=payload.end_user_id,
|
||||
message=payload.message,
|
||||
@@ -88,58 +140,95 @@ async def read_memory_api_service(
|
||||
storage_type=payload.storage_type,
|
||||
user_rag_memory_id=payload.user_rag_memory_id,
|
||||
)
|
||||
|
||||
logger.info(f"Memory read successful for end_user: {payload.end_user_id}")
|
||||
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read successfully")
|
||||
|
||||
logger.info(f"Memory read task submitted: task_id={result['task_id']}, end_user_id: {payload.end_user_id}")
|
||||
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read task submitted")
|
||||
|
||||
|
||||
@router.get("/configs")
|
||||
@router.get("/read/status")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def list_memory_configs(
|
||||
async def get_read_task_status(
|
||||
request: Request,
|
||||
task_id: str = Query(..., description="Celery task ID"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
List all memory configs for the workspace.
|
||||
|
||||
Returns all available memory configurations associated with the authorized workspace.
|
||||
Check the status of a memory read task.
|
||||
|
||||
Returns the current status and result (if completed) of a previously submitted read task.
|
||||
"""
|
||||
logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}")
|
||||
logger.info(f"Read task status check - task_id: {task_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
from app.services.task_service import get_task_memory_read_result
|
||||
result = get_task_memory_read_result(task_id)
|
||||
|
||||
result = memory_api_service.list_memory_configs(
|
||||
workspace_id=api_key_auth.workspace_id,
|
||||
)
|
||||
|
||||
logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}")
|
||||
return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully")
|
||||
return success(data=_sanitize_task_result(result), msg="Task status retrieved")
|
||||
|
||||
|
||||
@router.post("/end_users")
|
||||
@router.post("/write/sync")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def create_end_user(
|
||||
@check_end_user_quota
|
||||
async def write_memory_sync(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(..., description="Message content"),
|
||||
):
|
||||
"""
|
||||
Create an end user.
|
||||
|
||||
Creates a new end user for the authorized workspace.
|
||||
If an end user with the same other_id already exists, returns the existing one.
|
||||
Write memory synchronously.
|
||||
|
||||
Blocks until the write completes and returns the result directly.
|
||||
For async processing with task polling, use /write instead.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = CreateEndUserRequest(**body)
|
||||
logger.info(f"Create end user request - other_id: {payload.other_id}, workspace_id: {api_key_auth.workspace_id}")
|
||||
payload = MemoryWriteRequest(**body)
|
||||
logger.info(f"Memory write (sync) request - end_user_id: {payload.end_user_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
result = memory_api_service.create_end_user(
|
||||
result = await memory_api_service.write_memory_sync(
|
||||
workspace_id=api_key_auth.workspace_id,
|
||||
other_id=payload.other_id,
|
||||
end_user_id=payload.end_user_id,
|
||||
message=payload.message,
|
||||
config_id=payload.config_id,
|
||||
storage_type=payload.storage_type,
|
||||
user_rag_memory_id=payload.user_rag_memory_id,
|
||||
)
|
||||
|
||||
logger.info(f"End user ready: {result['id']}")
|
||||
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
|
||||
logger.info(f"Memory write (sync) successful for end_user: {payload.end_user_id}")
|
||||
return success(data=MemoryWriteSyncResponse(**result).model_dump(), msg="Memory written successfully")
|
||||
|
||||
|
||||
@router.post("/read/sync")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_memory_sync(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(..., description="Query message"),
|
||||
):
|
||||
"""
|
||||
Read memory synchronously.
|
||||
|
||||
Blocks until the read completes and returns the answer directly.
|
||||
For async processing with task polling, use /read instead.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = MemoryReadRequest(**body)
|
||||
logger.info(f"Memory read (sync) request - end_user_id: {payload.end_user_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
result = await memory_api_service.read_memory_sync(
|
||||
workspace_id=api_key_auth.workspace_id,
|
||||
end_user_id=payload.end_user_id,
|
||||
message=payload.message,
|
||||
search_switch=payload.search_switch,
|
||||
config_id=payload.config_id,
|
||||
storage_type=payload.storage_type,
|
||||
user_rag_memory_id=payload.user_rag_memory_id,
|
||||
)
|
||||
|
||||
logger.info(f"Memory read (sync) successful for end_user: {payload.end_user_id}")
|
||||
return success(data=MemoryReadSyncResponse(**result).model_dump(), msg="Memory read successfully")
|
||||
|
||||
491
api/app/controllers/service/memory_config_api_controller.py
Normal file
491
api/app/controllers/service/memory_config_api_controller.py
Normal file
@@ -0,0 +1,491 @@
|
||||
"""Memory Config 服务接口 - 基于 API Key 认证"""
|
||||
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Header, Query, Request
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.controllers import memory_storage_controller
|
||||
from app.controllers import memory_forget_controller
|
||||
from app.controllers import ontology_controller
|
||||
from app.controllers import emotion_config_controller
|
||||
from app.controllers import memory_reflection_controller
|
||||
from app.schemas.memory_storage_schema import ForgettingConfigUpdateRequest
|
||||
from app.controllers.emotion_config_controller import EmotionConfigUpdate
|
||||
from app.schemas.memory_reflection_schemas import Memory_Reflection
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.memory_api_schema import (
|
||||
ConfigUpdateExtractedRequest,
|
||||
ConfigUpdateRequest,
|
||||
ListConfigsResponse,
|
||||
ConfigCreateRequest,
|
||||
ConfigUpdateForgettingRequest,
|
||||
EmotionConfigUpdateRequest,
|
||||
ReflectionConfigUpdateRequest,
|
||||
)
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigParamsCreate,
|
||||
)
|
||||
from app.services import api_key_service
|
||||
from app.services.memory_api_service import MemoryAPIService
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
router = APIRouter(prefix="/memory_config", tags=["V1 - Memory Config API"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
def _get_current_user(api_key_auth: ApiKeyAuth, db: Session):
|
||||
"""Build a current_user object from API key auth
|
||||
|
||||
Args:
|
||||
api_key_auth: Validated API key auth info
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
User object with current_workspace_id set
|
||||
"""
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
return current_user
|
||||
|
||||
|
||||
def _verify_config_ownership(config_id:str, workspace_id:uuid.UUID, db:Session):
|
||||
"""Verify that the config belongs to the workspace.
|
||||
|
||||
Args:
|
||||
config_id: The ID of the config to verify
|
||||
workspace_id: The workspace ID tocheck against
|
||||
db: Database session for querying
|
||||
Raises:
|
||||
BusinessException: If the config does not exist or does not belong to the workspace
|
||||
"""
|
||||
try:
|
||||
resolved_id = resolve_config_id(config_id, db)
|
||||
except ValueError as e:
|
||||
raise BusinessException(
|
||||
message=f"Invalid config_id: {e}",
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
)
|
||||
config = MemoryConfigRepository.get_by_id(db, resolved_id)
|
||||
if not config or config.workspace_id != workspace_id:
|
||||
raise BusinessException(
|
||||
message="Config not found or access denied",
|
||||
code=BizCode.MEMORY_CONFIG_NOT_FOUND,
|
||||
)
|
||||
|
||||
# @router.get("/configs")
|
||||
# @require_api_key(scopes=["memory"])
|
||||
# async def list_memory_configs(
|
||||
# request: Request,
|
||||
# api_key_auth: ApiKeyAuth = None,
|
||||
# db: Session = Depends(get_db),
|
||||
# ):
|
||||
# """
|
||||
# List all memory configs for the workspace.
|
||||
|
||||
# Returns all available memory configurations associated with the authorized workspace.
|
||||
# """
|
||||
# logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}")
|
||||
|
||||
# memory_api_service = MemoryAPIService(db)
|
||||
|
||||
# result = memory_api_service.list_memory_configs(
|
||||
# workspace_id=api_key_auth.workspace_id,
|
||||
# )
|
||||
|
||||
# logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}")
|
||||
# return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully")
|
||||
|
||||
@router.get("/read_all_config")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_all_config(
|
||||
request:Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
List all memory configs with full details (enhanced version).
|
||||
|
||||
Returns complete config fields for the authorized workspace.
|
||||
No config_id ownership check needed — results are filtered by workspace.
|
||||
"""
|
||||
logger.info(f"V1 get all configs (full) - workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
|
||||
return memory_storage_controller.read_all_config(
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
)
|
||||
|
||||
@router.get("/scenes/simple")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def get_ontology_scenes(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get available ontology scenes for the workspace.
|
||||
|
||||
Returns a simple list of scene_id and scene_name for dropdown selection.
|
||||
Used before creating a memory config to choose which ontology scene to associate.
|
||||
"""
|
||||
logger.info(f"V1 get scenes - workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
|
||||
return await ontology_controller.get_scenes_simple(
|
||||
db=db,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
@router.get("/read_config_extracted")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_config_extracted(
|
||||
request: Request,
|
||||
config_id: str = Query(..., description="config_id"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get extraction engine config details for a specific config.
|
||||
|
||||
Only configs belonging to the authorized workspace can be queried.
|
||||
"""
|
||||
logger.info(f"V1 read extracted config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
|
||||
return memory_storage_controller.read_config_extracted(
|
||||
config_id = config_id,
|
||||
current_user = current_user,
|
||||
db = db,
|
||||
)
|
||||
|
||||
@router.get("/read_config_forgetting")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_config_forgetting(
|
||||
request: Request,
|
||||
config_id: str = Query(..., description="config_id"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get forgetting settings for a specific memory config.
|
||||
|
||||
Only configs belonging to the authorized workspace can be queried.
|
||||
"""
|
||||
logger.info(f"V1 read forgetting config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
|
||||
result = await memory_forget_controller.read_forgetting_config(
|
||||
config_id = config_id,
|
||||
current_user = current_user,
|
||||
db = db,
|
||||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
|
||||
|
||||
@router.get("/read_config_emotion")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_config_emotion(
|
||||
request: Request,
|
||||
config_id: str = Query(..., description="config_id"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get emotion engine config details for a specific config.
|
||||
|
||||
Only configs belonging to the authorized workspace can be queried.
|
||||
"""
|
||||
logger.info(f"V1 read emotion config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
|
||||
return jsonable_encoder(emotion_config_controller.get_emotion_config(
|
||||
config_id=config_id,
|
||||
db=db,
|
||||
current_user=current_user,
|
||||
))
|
||||
|
||||
@router.get("/read_config_reflection")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_config_reflection(
|
||||
request: Request,
|
||||
config_id: str = Query(..., description="config_id"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get reflection engine config details for a specific config.
|
||||
|
||||
Only configs belonging to the authorized workspace can be queried.
|
||||
"""
|
||||
logger.info(f"V1 read reflection config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
|
||||
return jsonable_encoder(await memory_reflection_controller.start_reflection_configs(
|
||||
config_id=config_id,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
))
|
||||
|
||||
|
||||
@router.post("/create_config")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def create_memory_config(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
||||
):
|
||||
"""
|
||||
Create a new memory config for the workspace.
|
||||
|
||||
The config will be associated with the workspace of the API Key.
|
||||
config_name is required, other fields are optional.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = ConfigCreateRequest(**body)
|
||||
|
||||
logger.info(f"V1 create config - workspace: {api_key_auth.workspace_id}, config_name: {payload.config_name}")
|
||||
|
||||
# 构造管理端 Schema,workspace_id 从 API Key 注入
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
mgmt_payload = ConfigParamsCreate(
|
||||
config_name=payload.config_name,
|
||||
config_desc=payload.config_desc or "",
|
||||
scene_id=payload.scene_id,
|
||||
llm_id=payload.llm_id,
|
||||
embedding_id=payload.embedding_id,
|
||||
rerank_id=payload.rerank_id,
|
||||
reflection_model_id=payload.reflection_model_id,
|
||||
emotion_model_id=payload.emotion_model_id,
|
||||
)
|
||||
#将返回数据中UUID序列化处理
|
||||
result =memory_storage_controller.create_config(
|
||||
payload=mgmt_payload,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
x_language_type=x_language_type,
|
||||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
@router.put("/update_config")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def update_memory_config(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
):
|
||||
"""
|
||||
Update memory config basic info (name, description, scene).
|
||||
|
||||
Requires API Key with 'memory' scope
|
||||
Only configs belonging to the authorized workspace can be updated.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = ConfigUpdateRequest(**body)
|
||||
|
||||
logger.info(f"V1 update config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
mgmt_payload = ConfigUpdate(
|
||||
config_id = payload.config_id,
|
||||
config_name = payload.config_name,
|
||||
config_desc = payload.config_desc,
|
||||
scene_id = payload.scene_id,
|
||||
)
|
||||
|
||||
return memory_storage_controller.update_config(
|
||||
payload = mgmt_payload,
|
||||
current_user = current_user,
|
||||
db = db,
|
||||
)
|
||||
|
||||
@router.put("/update_config_extracted")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def update_memory_config_extracted(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
):
|
||||
"""
|
||||
update memory config extraction engine config (models, thresholds, chunking, pruning, etc.).
|
||||
|
||||
Requires API Key with 'memory' scope.
|
||||
Only configs belonging to the authorized workspace can be updated.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = ConfigUpdateExtractedRequest(**body)
|
||||
|
||||
logger.info(f"V1 update extracted config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
#校验权限
|
||||
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
update_fields = payload.model_dump(exclude_unset=True)
|
||||
mgmt_payload = ConfigUpdateExtracted(**update_fields)
|
||||
|
||||
return memory_storage_controller.update_config_extracted(
|
||||
payload = mgmt_payload,
|
||||
current_user = current_user,
|
||||
db = db,
|
||||
)
|
||||
|
||||
@router.put("/update_config_forgetting")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def update_memory_config_forgetting(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
):
|
||||
"""
|
||||
update memory config forgetting settings (forgetting strategy, parameters, etc.).
|
||||
|
||||
Requires API Key with 'memory' scope.
|
||||
Only configs belonging to the authorized workspace can be updated.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = ConfigUpdateForgettingRequest(**body)
|
||||
|
||||
logger.info(f"V1 update forgetting config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
#校验权限
|
||||
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
update_fields = payload.model_dump(exclude_unset=True)
|
||||
mgmt_payload = ForgettingConfigUpdateRequest(**update_fields)
|
||||
|
||||
#将返回数据中UUID序列化处理
|
||||
result = await memory_forget_controller.update_forgetting_config(
|
||||
payload = mgmt_payload,
|
||||
current_user = current_user,
|
||||
db = db,
|
||||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
@router.put("/update_config_emotion")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def update_config_emotion(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
):
|
||||
"""
|
||||
Update emotion engine config (full update).
|
||||
|
||||
All fields except emotion_model_id are required.
|
||||
Only configs belonging to the authorized workspace can be updated.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = EmotionConfigUpdateRequest(**body)
|
||||
|
||||
logger.info(f"V1 update emotion config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
update_fields = payload.model_dump(exclude_unset=True)
|
||||
mgmt_payload = EmotionConfigUpdate(**update_fields)
|
||||
return jsonable_encoder(emotion_config_controller.update_emotion_config(
|
||||
config=mgmt_payload,
|
||||
db=db,
|
||||
current_user=current_user,
|
||||
))
|
||||
|
||||
@router.put("/update_config_reflection")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def update_config_reflection(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
):
|
||||
"""
|
||||
Update reflection engine config (full update).
|
||||
|
||||
All fields are required.
|
||||
Only configs belonging to the authorized workspace can be updated.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = ReflectionConfigUpdateRequest(**body)
|
||||
|
||||
logger.info(f"V1 update reflection config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
update_fields = payload.model_dump(exclude_unset=True)
|
||||
mgmt_payload = Memory_Reflection(**update_fields)
|
||||
|
||||
return jsonable_encoder(await memory_reflection_controller.save_reflection_config(
|
||||
request=mgmt_payload,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
))
|
||||
|
||||
@router.delete("/delete_config")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def delete_memory_config(
|
||||
config_id: str,
|
||||
request: Request,
|
||||
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Delete a memory config.
|
||||
|
||||
- Default configs cannot be deleted.
|
||||
- If end users are connected and force=False, returns a warning.
|
||||
- If force=True, clears end user references and deletes the config.
|
||||
|
||||
Only configs belonging to the authorized workspace can be deleted.
|
||||
"""
|
||||
logger.info(f"V1 delete config - config_id: {config_id}, force: {force}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
|
||||
return memory_storage_controller.delete_config(
|
||||
config_id=config_id,
|
||||
force=force,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
)
|
||||
@@ -11,11 +11,13 @@ from app.schemas import skill_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services.skill_service import SkillService
|
||||
from app.core.response_utils import success
|
||||
from app.core.quota_stub import check_skill_quota
|
||||
|
||||
router = APIRouter(prefix="/skills", tags=["Skills"])
|
||||
|
||||
|
||||
@router.post("", summary="创建技能")
|
||||
@check_skill_quota
|
||||
def create_skill(
|
||||
data: skill_schema.SkillCreate,
|
||||
db: Session = Depends(get_db),
|
||||
|
||||
173
api/app/controllers/tenant_subscription_controller.py
Normal file
173
api/app/controllers/tenant_subscription_controller.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""
|
||||
租户套餐查询接口(普通用户可访问)
|
||||
"""
|
||||
import datetime
|
||||
from typing import Callable, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.i18n.dependencies import get_translator
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
|
||||
logger = get_api_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenant", tags=["Tenant"])
|
||||
public_router = APIRouter(tags=["Tenant"])
|
||||
|
||||
|
||||
@router.get("/subscription", response_model=ApiResponse, summary="获取当前用户所属租户的套餐信息")
|
||||
async def get_my_tenant_subscription(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
t: Callable = Depends(get_translator),
|
||||
):
|
||||
"""
|
||||
获取当前登录用户所属租户的有效套餐订阅信息。
|
||||
包含套餐名称、版本、配额、到期时间等。
|
||||
"""
|
||||
try:
|
||||
from premium.platform_admin.package_plan_service import TenantSubscriptionService
|
||||
|
||||
if not current_user.tenant:
|
||||
return JSONResponse(status_code=404, content=fail(code=404, msg="用户未关联租户"))
|
||||
|
||||
tenant_id = current_user.tenant.id
|
||||
svc = TenantSubscriptionService(db)
|
||||
sub = svc.get_subscription(tenant_id)
|
||||
|
||||
if not sub:
|
||||
# 无订阅记录时,兜底返回免费套餐信息
|
||||
free_plan = svc.plan_repo.get_free_plan()
|
||||
if not free_plan:
|
||||
return success(data=None, msg="暂无有效套餐")
|
||||
return success(data={
|
||||
"subscription_id": None,
|
||||
"tenant_id": str(tenant_id),
|
||||
"package_plan_id": str(free_plan.id),
|
||||
"package_version": free_plan.version,
|
||||
"package_plan": {
|
||||
"id": str(free_plan.id),
|
||||
"name": free_plan.name,
|
||||
"name_en": free_plan.name_en,
|
||||
"version": free_plan.version,
|
||||
"category": free_plan.category,
|
||||
"tier_level": free_plan.tier_level,
|
||||
"price": float(free_plan.price) if free_plan.price is not None else 0.0,
|
||||
"billing_cycle": free_plan.billing_cycle,
|
||||
"core_value": free_plan.core_value,
|
||||
"core_value_en": free_plan.core_value_en,
|
||||
"tech_support": free_plan.tech_support,
|
||||
"tech_support_en": free_plan.tech_support_en,
|
||||
"sla_compliance": free_plan.sla_compliance,
|
||||
"sla_compliance_en": free_plan.sla_compliance_en,
|
||||
"page_customization": free_plan.page_customization,
|
||||
"page_customization_en": free_plan.page_customization_en,
|
||||
"theme_color": free_plan.theme_color,
|
||||
},
|
||||
"started_at": None,
|
||||
"expired_at": None,
|
||||
"status": "active",
|
||||
"quotas": free_plan.quotas or {},
|
||||
"created_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
||||
"updated_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
||||
}, msg="免费套餐")
|
||||
|
||||
return success(data=svc.build_response(sub))
|
||||
|
||||
except ModuleNotFoundError:
|
||||
# 社区版无 premium 模块,从配置文件读取免费套餐
|
||||
if not current_user.tenant:
|
||||
return JSONResponse(status_code=404, content=fail(code=404, msg="用户未关联租户"))
|
||||
|
||||
from app.config.default_free_plan import DEFAULT_FREE_PLAN
|
||||
|
||||
plan = DEFAULT_FREE_PLAN
|
||||
response_data = {
|
||||
"subscription_id": None,
|
||||
"tenant_id": str(current_user.tenant.id),
|
||||
"package_plan_id": None,
|
||||
"package_version": plan["version"],
|
||||
"package_plan": {
|
||||
"id": None,
|
||||
"name": plan["name"],
|
||||
"name_en": plan.get("name_en"),
|
||||
"version": plan["version"],
|
||||
"category": plan["category"],
|
||||
"tier_level": plan["tier_level"],
|
||||
"price": float(plan["price"]),
|
||||
"billing_cycle": plan["billing_cycle"],
|
||||
"core_value": plan.get("core_value"),
|
||||
"core_value_en": plan.get("core_value_en"),
|
||||
"tech_support": plan.get("tech_support"),
|
||||
"tech_support_en": plan.get("tech_support_en"),
|
||||
"sla_compliance": plan.get("sla_compliance"),
|
||||
"sla_compliance_en": plan.get("sla_compliance_en"),
|
||||
"page_customization": plan.get("page_customization"),
|
||||
"page_customization_en": plan.get("page_customization_en"),
|
||||
"theme_color": plan.get("theme_color"),
|
||||
},
|
||||
"started_at": None,
|
||||
"expired_at": None,
|
||||
"status": "active",
|
||||
"quotas": plan["quotas"],
|
||||
"created_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
||||
"updated_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
||||
}
|
||||
return success(data=response_data, msg="社区版免费套餐")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取租户套餐信息失败: {e}", exc_info=True)
|
||||
return JSONResponse(status_code=500, content=fail(code=500, msg="获取套餐信息失败"))
|
||||
|
||||
|
||||
@public_router.get("/package-plans", response_model=ApiResponse, summary="获取套餐列表(公开)")
|
||||
async def list_package_plans_public(
|
||||
category: Optional[str] = None,
|
||||
status: Optional[bool] = None,
|
||||
search: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
公开接口,无需鉴权。
|
||||
SaaS 版从数据库读取套餐列表;社区版降级返回 default_free_plan.py 中的免费套餐。
|
||||
"""
|
||||
try:
|
||||
from premium.platform_admin.package_plan_service import PackagePlanService
|
||||
from premium.platform_admin.package_plan_schema import PackagePlanResponse
|
||||
svc = PackagePlanService(db)
|
||||
result = svc.get_list(page=1, size=9999, category=category, status=status, search=search)
|
||||
return success(data=[PackagePlanResponse.model_validate(p).model_dump(mode="json") for p in result["items"]])
|
||||
except ModuleNotFoundError:
|
||||
from app.config.default_free_plan import DEFAULT_FREE_PLAN
|
||||
plan = DEFAULT_FREE_PLAN
|
||||
return success(data=[{
|
||||
"id": None,
|
||||
"name": plan["name"],
|
||||
"name_en": plan.get("name_en"),
|
||||
"version": plan["version"],
|
||||
"category": plan["category"],
|
||||
"tier_level": plan["tier_level"],
|
||||
"price": float(plan["price"]),
|
||||
"billing_cycle": plan["billing_cycle"],
|
||||
"core_value": plan.get("core_value"),
|
||||
"core_value_en": plan.get("core_value_en"),
|
||||
"tech_support": plan.get("tech_support"),
|
||||
"tech_support_en": plan.get("tech_support_en"),
|
||||
"sla_compliance": plan.get("sla_compliance"),
|
||||
"sla_compliance_en": plan.get("sla_compliance_en"),
|
||||
"page_customization": plan.get("page_customization"),
|
||||
"page_customization_en": plan.get("page_customization_en"),
|
||||
"theme_color": plan.get("theme_color"),
|
||||
"status": plan.get("status", True),
|
||||
"quotas": plan["quotas"],
|
||||
}])
|
||||
except Exception as e:
|
||||
logger.error(f"获取套餐列表失败: {e}", exc_info=True)
|
||||
return JSONResponse(status_code=500, content=fail(code=500, msg="获取套餐列表失败"))
|
||||
@@ -114,11 +114,14 @@ def get_current_user_info(
|
||||
|
||||
# 设置权限:如果用户来自 SSO Source,则使用该 Source 的 permissions;否则返回 "all" 表示拥有所有权限
|
||||
if current_user.external_source:
|
||||
from premium.sso.models import SSOSource
|
||||
source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first()
|
||||
if source and source.permissions:
|
||||
result_schema.permissions = source.permissions
|
||||
else:
|
||||
try:
|
||||
from premium.sso.models import SSOSource
|
||||
source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first()
|
||||
if source and source.permissions:
|
||||
result_schema.permissions = source.permissions
|
||||
else:
|
||||
result_schema.permissions = []
|
||||
except ModuleNotFoundError:
|
||||
result_schema.permissions = []
|
||||
else:
|
||||
result_schema.permissions = ["all"]
|
||||
|
||||
@@ -35,6 +35,7 @@ from app.schemas.workspace_schema import (
|
||||
WorkspaceUpdate,
|
||||
)
|
||||
from app.services import workspace_service
|
||||
from app.core.quota_stub import check_workspace_quota
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -106,6 +107,7 @@ def get_workspaces(
|
||||
|
||||
|
||||
@router.post("", response_model=ApiResponse)
|
||||
@check_workspace_quota
|
||||
def create_workspace(
|
||||
workspace: WorkspaceCreate,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
|
||||
@@ -12,7 +12,7 @@ import time
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.errors import GraphRecursionError
|
||||
|
||||
@@ -41,6 +41,7 @@ class LangChainAgent:
|
||||
max_tool_consecutive_calls: int = 3, # 单个工具最大连续调用次数
|
||||
deep_thinking: bool = False, # 是否启用深度思考模式
|
||||
thinking_budget_tokens: Optional[int] = None, # 深度思考 token 预算
|
||||
json_output: bool = False, # 是否强制 JSON 输出
|
||||
capability: Optional[List[str]] = None # 模型能力列表,用于校验是否支持深度思考
|
||||
):
|
||||
"""初始化 LangChain Agent
|
||||
@@ -64,7 +65,6 @@ class LangChainAgent:
|
||||
self.streaming = streaming
|
||||
self.is_omni = is_omni
|
||||
self.max_tool_consecutive_calls = max_tool_consecutive_calls
|
||||
self.deep_thinking = deep_thinking and ("thinking" in (capability or []))
|
||||
|
||||
# 工具调用计数器:记录每个工具的连续调用次数
|
||||
self.tool_call_counter: Dict[str, int] = {}
|
||||
@@ -80,6 +80,17 @@ class LangChainAgent:
|
||||
|
||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||
|
||||
# ChatTongyi 要求 messages 含 'json' 字样才能使用 response_format
|
||||
# 在 system prompt 中注入 JSON 要求
|
||||
from app.models.models_model import ModelProvider
|
||||
if json_output and (
|
||||
(provider.lower() == ModelProvider.DASHSCOPE and not is_omni)
|
||||
or provider.lower() == ModelProvider.VOLCANO
|
||||
# 有工具时 response_format 会被移除,所有 provider 都需要 system prompt 注入保证 JSON 输出
|
||||
or bool(tools)
|
||||
):
|
||||
self.system_prompt += "\n请以JSON格式输出。"
|
||||
|
||||
logger.debug(
|
||||
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
|
||||
f"tool_count={len(self.tools)}, "
|
||||
@@ -87,23 +98,17 @@ class LangChainAgent:
|
||||
f"auto_calculated={max_iterations is None}"
|
||||
)
|
||||
|
||||
# 根据 capability 校验是否真正支持深度思考
|
||||
actual_deep_thinking = self.deep_thinking
|
||||
if deep_thinking and not actual_deep_thinking:
|
||||
logger.warning(
|
||||
f"模型 {model_name} 不支持深度思考(capability 中无 'thinking'),已自动关闭 deep_thinking"
|
||||
)
|
||||
|
||||
# 创建 RedBearLLM(支持多提供商)
|
||||
# 创建 RedBearLLM,capability 校验由 RedBearModelConfig 统一处理
|
||||
model_config = RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
is_omni=is_omni,
|
||||
deep_thinking=actual_deep_thinking,
|
||||
thinking_budget_tokens=thinking_budget_tokens if actual_deep_thinking else None,
|
||||
support_thinking="thinking" in (capability or []),
|
||||
capability=capability,
|
||||
deep_thinking=deep_thinking,
|
||||
thinking_budget_tokens=thinking_budget_tokens,
|
||||
json_output=json_output,
|
||||
extra_params={
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
@@ -112,6 +117,9 @@ class LangChainAgent:
|
||||
)
|
||||
|
||||
self.llm = RedBearLLM(model_config, type=ModelType.CHAT)
|
||||
# 从经过校验的 config 读取实际生效的能力开关
|
||||
self.deep_thinking = model_config.deep_thinking
|
||||
self.json_output = model_config.json_output
|
||||
|
||||
# 获取底层模型用于真正的流式调用
|
||||
self._underlying_llm = self.llm._model if hasattr(self.llm, '_model') else self.llm
|
||||
@@ -237,9 +245,7 @@ class LangChainAgent:
|
||||
Returns:
|
||||
List[BaseMessage]: 消息列表
|
||||
"""
|
||||
messages:list = [SystemMessage(content=self.system_prompt)]
|
||||
|
||||
# 添加系统提示词
|
||||
messages: list = []
|
||||
|
||||
# 添加历史消息
|
||||
if history:
|
||||
|
||||
@@ -97,7 +97,7 @@ def require_api_key(
|
||||
)
|
||||
|
||||
rate_limiter = RateLimiterService()
|
||||
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj)
|
||||
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj, db=db)
|
||||
if not is_allowed:
|
||||
logger.warning("API Key 限流触发", extra={
|
||||
"api_key_id": str(api_key_obj.id),
|
||||
@@ -106,10 +106,12 @@ def require_api_key(
|
||||
"error_msg": error_msg
|
||||
})
|
||||
# 根据错误消息判断限流类型
|
||||
if "QPS" in error_msg:
|
||||
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
|
||||
elif "Daily" in error_msg:
|
||||
if "Daily" in error_msg:
|
||||
code = BizCode.API_KEY_DAILY_LIMIT_EXCEEDED
|
||||
elif "Tenant" in error_msg:
|
||||
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED # 租户套餐速率超限,同属 QPS 类
|
||||
elif "QPS" in error_msg:
|
||||
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
|
||||
else:
|
||||
code = BizCode.API_KEY_QUOTA_EXCEEDED
|
||||
|
||||
|
||||
@@ -31,6 +31,9 @@ class BizCode(IntEnum):
|
||||
API_KEY_QPS_LIMIT_EXCEEDED = 3014
|
||||
API_KEY_DAILY_LIMIT_EXCEEDED = 3015
|
||||
API_KEY_QUOTA_EXCEEDED = 3016
|
||||
API_KEY_RATE_LIMIT_EXCEEDED = 3017
|
||||
QUOTA_EXCEEDED = 3018
|
||||
RATE_LIMIT_EXCEEDED = 3019
|
||||
# 资源(4xxx)
|
||||
NOT_FOUND = 4000
|
||||
USER_NOT_FOUND = 4001
|
||||
@@ -41,6 +44,7 @@ class BizCode(IntEnum):
|
||||
FILE_NOT_FOUND = 4006
|
||||
APP_NOT_FOUND = 4007
|
||||
RELEASE_NOT_FOUND = 4008
|
||||
USER_NO_ACCESS = 4009
|
||||
|
||||
# 冲突/状态(5xxx)
|
||||
DUPLICATE_NAME = 5001
|
||||
@@ -118,6 +122,7 @@ HTTP_MAPPING = {
|
||||
BizCode.WORKSPACE_ACCESS_DENIED: 403,
|
||||
BizCode.NOT_FOUND: 400,
|
||||
BizCode.USER_NOT_FOUND: 200,
|
||||
BizCode.USER_NO_ACCESS: 401,
|
||||
BizCode.WORKSPACE_NOT_FOUND: 400,
|
||||
BizCode.MODEL_NOT_FOUND: 400,
|
||||
BizCode.KNOWLEDGE_NOT_FOUND: 400,
|
||||
@@ -153,7 +158,8 @@ HTTP_MAPPING = {
|
||||
BizCode.API_KEY_QPS_LIMIT_EXCEEDED: 429,
|
||||
BizCode.API_KEY_DAILY_LIMIT_EXCEEDED: 429,
|
||||
BizCode.API_KEY_QUOTA_EXCEEDED: 429,
|
||||
|
||||
BizCode.QUOTA_EXCEEDED: 402,
|
||||
|
||||
BizCode.MODEL_CONFIG_INVALID: 400,
|
||||
BizCode.API_KEY_MISSING: 400,
|
||||
BizCode.PROVIDER_NOT_SUPPORTED: 400,
|
||||
@@ -182,4 +188,21 @@ HTTP_MAPPING = {
|
||||
BizCode.DB_ERROR: 500,
|
||||
BizCode.SERVICE_UNAVAILABLE: 503,
|
||||
BizCode.RATE_LIMITED: 429,
|
||||
BizCode.RATE_LIMIT_EXCEEDED: 429,
|
||||
}
|
||||
|
||||
ERROR_CODE_TO_BIZ_CODE = {
|
||||
"QUOTA_EXCEEDED": BizCode.QUOTA_EXCEEDED,
|
||||
"RATE_LIMIT_EXCEEDED": BizCode.RATE_LIMIT_EXCEEDED,
|
||||
"API_KEY_NOT_FOUND": BizCode.API_KEY_NOT_FOUND,
|
||||
"API_KEY_INVALID": BizCode.API_KEY_INVALID,
|
||||
"API_KEY_EXPIRED": BizCode.API_KEY_EXPIRED,
|
||||
"WORKSPACE_NOT_FOUND": BizCode.WORKSPACE_NOT_FOUND,
|
||||
"WORKSPACE_NO_ACCESS": BizCode.WORKSPACE_NO_ACCESS,
|
||||
"PERMISSION_DENIED": BizCode.PERMISSION_DENIED,
|
||||
"TOKEN_EXPIRED": BizCode.TOKEN_EXPIRED,
|
||||
"TOKEN_INVALID": BizCode.TOKEN_INVALID,
|
||||
"VALIDATION_FAILED": BizCode.VALIDATION_FAILED,
|
||||
"INVALID_PARAMETER": BizCode.INVALID_PARAMETER,
|
||||
"MISSING_PARAMETER": BizCode.MISSING_PARAMETER,
|
||||
}
|
||||
|
||||
@@ -153,7 +153,7 @@ class PerceptualSearchService:
|
||||
return []
|
||||
try:
|
||||
r = await search_perceptual(
|
||||
connector=connector, q=escaped,
|
||||
connector=connector, query=escaped,
|
||||
end_user_id=self.end_user_id,
|
||||
limit=limit * 5, # 多查一些以提高命中率
|
||||
)
|
||||
@@ -178,7 +178,7 @@ class PerceptualSearchService:
|
||||
if not escaped.strip():
|
||||
return []
|
||||
r = await search_perceptual(
|
||||
connector=connector, q=escaped,
|
||||
connector=connector, query=escaped,
|
||||
end_user_id=self.end_user_id, limit=limit,
|
||||
)
|
||||
return r.get("perceptuals", [])
|
||||
|
||||
@@ -14,6 +14,7 @@ from dotenv import load_dotenv
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import _USER_PLACEHOLDER_NAMES
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \
|
||||
memory_summary_generation
|
||||
@@ -191,15 +192,37 @@ async def write(
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
|
||||
# 使用 Celery 异步任务触发聚类(不阻塞主流程)
|
||||
if all_entity_nodes:
|
||||
end_user_id = all_entity_nodes[0].end_user_id
|
||||
|
||||
# Neo4j 写入完成后,用 PgSQL 权威 aliases 覆盖 Neo4j 用户实体
|
||||
try:
|
||||
from app.repositories.end_user_info_repository import EndUserInfoRepository
|
||||
if end_user_id:
|
||||
with get_db_context() as db_session:
|
||||
info = EndUserInfoRepository(db_session).get_by_end_user_id(uuid.UUID(end_user_id))
|
||||
pg_aliases = info.aliases if info and info.aliases else []
|
||||
if info is not None:
|
||||
# 将 Python 侧占位名集合作为参数传入,避免 Cypher 硬编码
|
||||
placeholder_names = list(_USER_PLACEHOLDER_NAMES)
|
||||
await neo4j_connector.execute_query(
|
||||
"""
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.end_user_id = $end_user_id AND toLower(e.name) IN $placeholder_names
|
||||
SET e.aliases = $aliases
|
||||
""",
|
||||
end_user_id=end_user_id, aliases=pg_aliases,
|
||||
placeholder_names=placeholder_names,
|
||||
)
|
||||
logger.info(f"[AliasSync] Neo4j 用户实体 aliases 已用 PgSQL 权威源覆盖: {pg_aliases}")
|
||||
except Exception as sync_err:
|
||||
logger.warning(f"[AliasSync] PgSQL→Neo4j aliases 同步失败(不影响主流程): {sync_err}")
|
||||
|
||||
# 使用 Celery 异步任务触发聚类(不阻塞主流程)
|
||||
try:
|
||||
from app.tasks import run_incremental_clustering
|
||||
|
||||
end_user_id = all_entity_nodes[0].end_user_id
|
||||
new_entity_ids = [e.id for e in all_entity_nodes]
|
||||
|
||||
# 异步提交 Celery 任务
|
||||
task = run_incremental_clustering.apply_async(
|
||||
kwargs={
|
||||
"end_user_id": end_user_id,
|
||||
@@ -207,7 +230,6 @@ async def write(
|
||||
"llm_model_id": str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||
"embedding_model_id": str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||
},
|
||||
# 设置任务优先级(低优先级,不影响主业务)
|
||||
priority=3,
|
||||
)
|
||||
logger.info(
|
||||
@@ -215,7 +237,6 @@ async def write(
|
||||
f"task_id={task.id}, end_user_id={end_user_id}, entity_count={len(new_entity_ids)}"
|
||||
)
|
||||
except Exception as e:
|
||||
# 聚类任务提交失败不影响主流程
|
||||
logger.error(f"[Clustering] 提交聚类任务失败(不影响主流程): {e}", exc_info=True)
|
||||
|
||||
break
|
||||
|
||||
@@ -58,6 +58,14 @@ from app.core.memory.models.triplet_models import (
|
||||
TripletExtractionResponse,
|
||||
)
|
||||
|
||||
# User metadata models
|
||||
from app.core.memory.models.metadata_models import (
|
||||
UserMetadata,
|
||||
UserMetadataProfile,
|
||||
MetadataExtractionResponse,
|
||||
MetadataFieldChange,
|
||||
)
|
||||
|
||||
# Ontology scenario models (LLM extracted from scenarios)
|
||||
from app.core.memory.models.ontology_scenario_models import (
|
||||
OntologyClass,
|
||||
@@ -124,6 +132,10 @@ __all__ = [
|
||||
"Entity",
|
||||
"Triplet",
|
||||
"TripletExtractionResponse",
|
||||
"UserMetadata",
|
||||
"UserMetadataProfile",
|
||||
"MetadataExtractionResponse",
|
||||
"MetadataFieldChange",
|
||||
# Ontology models
|
||||
"OntologyClass",
|
||||
"OntologyExtractionResponse",
|
||||
|
||||
@@ -364,12 +364,14 @@ class ChunkNode(Node):
|
||||
Attributes:
|
||||
dialog_id: ID of the parent dialog
|
||||
content: The text content of the chunk
|
||||
speaker: Speaker identifier ('user' or 'assistant')
|
||||
chunk_embedding: Optional embedding vector for the chunk
|
||||
sequence_number: Order of this chunk within the dialog
|
||||
metadata: Additional chunk metadata as key-value pairs
|
||||
"""
|
||||
dialog_id: str = Field(..., description="ID of the parent dialog")
|
||||
content: str = Field(..., description="The text content of the chunk")
|
||||
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
|
||||
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
|
||||
sequence_number: int = Field(..., description="Order of this chunk within the dialog")
|
||||
metadata: dict = Field(default_factory=dict, description="Additional chunk metadata")
|
||||
|
||||
63
api/app/core/memory/models/metadata_models.py
Normal file
63
api/app/core/memory/models/metadata_models.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Models for user metadata extraction.
|
||||
|
||||
Independent from triplet_models.py - these models are used by the
|
||||
standalone metadata extraction pipeline (post-dedup async Celery task).
|
||||
"""
|
||||
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class UserMetadataProfile(BaseModel):
|
||||
"""用户画像信息"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
role: List[str] = Field(default_factory=list, description="用户职业或角色")
|
||||
domain: List[str] = Field(default_factory=list, description="用户所在领域")
|
||||
expertise: List[str] = Field(
|
||||
default_factory=list, description="用户擅长的技能或工具"
|
||||
)
|
||||
interests: List[str] = Field(
|
||||
default_factory=list, description="用户关注的话题或领域标签"
|
||||
)
|
||||
|
||||
|
||||
class UserMetadata(BaseModel):
|
||||
"""用户元数据顶层结构"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
profile: UserMetadataProfile = Field(default_factory=UserMetadataProfile)
|
||||
|
||||
|
||||
class MetadataFieldChange(BaseModel):
|
||||
"""单个元数据字段的变更操作"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
field_path: str = Field(
|
||||
description="字段路径,用点号分隔,如 'profile.role'、'profile.expertise'"
|
||||
)
|
||||
action: Literal["set", "remove"] = Field(
|
||||
description="操作类型:'set' 表示新增或修改,'remove' 表示移除"
|
||||
)
|
||||
value: Optional[str] = Field(
|
||||
default=None,
|
||||
description="字段的新值(action='set' 时必填)。标量字段直接填值,列表字段填单个要新增的元素"
|
||||
)
|
||||
|
||||
|
||||
class MetadataExtractionResponse(BaseModel):
|
||||
"""元数据提取 LLM 响应结构(增量模式)"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
metadata_changes: List[MetadataFieldChange] = Field(
|
||||
default_factory=list,
|
||||
description="元数据的增量变更列表,每项描述一个字段的新增、修改或移除操作",
|
||||
)
|
||||
aliases_to_add: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="本次新发现的用户别名(用户自我介绍或他人对用户的称呼)",
|
||||
)
|
||||
aliases_to_remove: List[str] = Field(
|
||||
default_factory=list, description="用户明确否认的别名(如'我不叫XX了')"
|
||||
)
|
||||
@@ -1,4 +1,3 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
@@ -6,7 +5,6 @@ import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
@@ -23,7 +21,7 @@ from app.core.memory.utils.config.config_utils import (
|
||||
)
|
||||
from app.core.memory.utils.data.text_utils import extract_plain_query
|
||||
from app.core.memory.utils.data.time_utils import normalize_date_safe
|
||||
from app.core.memory.utils.llm.llm_utils import get_reranker_client
|
||||
# from app.core.memory.utils.llm.llm_utils import get_reranker_client
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import (
|
||||
@@ -748,11 +746,10 @@ async def run_hybrid_search(
|
||||
if search_type in ["keyword", "hybrid"]:
|
||||
# Keyword-based search
|
||||
logger.info("[PERF] Starting keyword search...")
|
||||
keyword_start = time.time()
|
||||
keyword_task = asyncio.create_task(
|
||||
search_graph(
|
||||
connector=connector,
|
||||
q=query_text,
|
||||
query=query_text,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include
|
||||
@@ -762,7 +759,6 @@ async def run_hybrid_search(
|
||||
if search_type in ["embedding", "hybrid"]:
|
||||
# Embedding-based search
|
||||
logger.info("[PERF] Starting embedding search...")
|
||||
embedding_start = time.time()
|
||||
|
||||
# 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig
|
||||
config_load_start = time.time()
|
||||
@@ -904,10 +900,10 @@ async def run_hybrid_search(
|
||||
else:
|
||||
results["latency_metrics"] = latency_metrics
|
||||
|
||||
logger.info(f"[PERF] ===== SEARCH PERFORMANCE SUMMARY =====")
|
||||
logger.info("[PERF] ===== SEARCH PERFORMANCE SUMMARY =====")
|
||||
logger.info(f"[PERF] Total search completed in {total_latency:.4f}s")
|
||||
logger.info(f"[PERF] Latency breakdown: {json.dumps(latency_metrics, indent=2)}")
|
||||
logger.info(f"[PERF] =========================================")
|
||||
logger.info("[PERF] =========================================")
|
||||
|
||||
# Sanitize results: drop large/unused fields
|
||||
_remove_keys_recursive(results, ["name_embedding"]) # drop entity name embeddings from outputs
|
||||
|
||||
@@ -82,51 +82,38 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
|
||||
canonical.connect_strength = next(iter(pair))
|
||||
|
||||
# 别名合并(去重保序,使用标准化工具)
|
||||
# 用户实体的 aliases 由 PgSQL end_user_info 作为唯一权威源,去重合并时不修改
|
||||
try:
|
||||
canonical_name = (getattr(canonical, "name", "") or "").strip()
|
||||
incoming_name = (getattr(ent, "name", "") or "").strip()
|
||||
|
||||
# 收集所有需要合并的别名
|
||||
all_aliases = []
|
||||
|
||||
# 1. 添加canonical现有的别名
|
||||
existing = getattr(canonical, "aliases", []) or []
|
||||
all_aliases.extend(existing)
|
||||
|
||||
# 2. 添加incoming实体的名称(如果不同于canonical的名称)
|
||||
if incoming_name and incoming_name != canonical_name:
|
||||
all_aliases.append(incoming_name)
|
||||
|
||||
# 3. 添加incoming实体的所有别名
|
||||
incoming = getattr(ent, "aliases", []) or []
|
||||
all_aliases.extend(incoming)
|
||||
|
||||
# 4. 标准化并去重(优先使用alias_utils工具函数)
|
||||
try:
|
||||
from app.core.memory.utils.alias_utils import normalize_aliases
|
||||
canonical.aliases = normalize_aliases(canonical_name, all_aliases)
|
||||
except Exception:
|
||||
# 如果导入失败,使用增强的去重逻辑
|
||||
seen_normalized = set()
|
||||
unique_aliases = []
|
||||
if canonical_name.lower() not in _USER_PLACEHOLDER_NAMES:
|
||||
incoming_name = (getattr(ent, "name", "") or "").strip()
|
||||
|
||||
for alias in all_aliases:
|
||||
if not alias:
|
||||
continue
|
||||
|
||||
alias_stripped = str(alias).strip()
|
||||
if not alias_stripped or alias_stripped == canonical_name:
|
||||
continue
|
||||
|
||||
# 标准化:转小写用于去重判断
|
||||
alias_normalized = alias_stripped.lower()
|
||||
|
||||
if alias_normalized not in seen_normalized:
|
||||
seen_normalized.add(alias_normalized)
|
||||
unique_aliases.append(alias_stripped)
|
||||
# 收集所有需要合并的别名,过滤掉用户占位名避免污染非用户实体
|
||||
all_aliases = list(getattr(canonical, "aliases", []) or [])
|
||||
if incoming_name and incoming_name != canonical_name and incoming_name.lower() not in _USER_PLACEHOLDER_NAMES:
|
||||
all_aliases.append(incoming_name)
|
||||
all_aliases.extend(
|
||||
a for a in (getattr(ent, "aliases", []) or [])
|
||||
if a and a.strip().lower() not in _USER_PLACEHOLDER_NAMES
|
||||
)
|
||||
|
||||
# 排序并赋值
|
||||
canonical.aliases = sorted(unique_aliases)
|
||||
try:
|
||||
from app.core.memory.utils.alias_utils import normalize_aliases
|
||||
canonical.aliases = normalize_aliases(canonical_name, all_aliases)
|
||||
except Exception:
|
||||
seen_normalized = set()
|
||||
unique_aliases = []
|
||||
for alias in all_aliases:
|
||||
if not alias:
|
||||
continue
|
||||
alias_stripped = str(alias).strip()
|
||||
if not alias_stripped or alias_stripped == canonical_name:
|
||||
continue
|
||||
alias_normalized = alias_stripped.lower()
|
||||
if alias_normalized not in seen_normalized:
|
||||
seen_normalized.add(alias_normalized)
|
||||
unique_aliases.append(alias_stripped)
|
||||
canonical.aliases = sorted(unique_aliases)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -733,66 +720,37 @@ def fuzzy_match(
|
||||
|
||||
|
||||
def _merge_entities_with_aliases(canonical: ExtractedEntityNode, losing: ExtractedEntityNode):
|
||||
""" 模糊匹配中的实体合并。
|
||||
"""模糊匹配中的实体合并(别名部分)。
|
||||
|
||||
合并策略:
|
||||
1. 保留canonical的主名称不变
|
||||
2. 将losing的主名称添加为alias(如果不同)
|
||||
3. 合并两个实体的所有aliases
|
||||
4. 自动去重(case-insensitive)并排序
|
||||
|
||||
Args:
|
||||
canonical: 规范实体(保留)
|
||||
losing: 被合并实体(删除)
|
||||
|
||||
Note:
|
||||
使用alias_utils.normalize_aliases进行标准化去重
|
||||
用户实体的 aliases 由 PgSQL end_user_info 作为唯一权威源,跳过合并。
|
||||
"""
|
||||
# 获取规范实体的名称
|
||||
canonical_name = (getattr(canonical, "name", "") or "").strip()
|
||||
if canonical_name.lower() in _USER_PLACEHOLDER_NAMES:
|
||||
return
|
||||
|
||||
losing_name = (getattr(losing, "name", "") or "").strip()
|
||||
|
||||
# 收集所有需要合并的别名
|
||||
all_aliases = []
|
||||
|
||||
# 1. 添加canonical现有的别名
|
||||
current_aliases = getattr(canonical, "aliases", []) or []
|
||||
all_aliases.extend(current_aliases)
|
||||
|
||||
# 2. 添加losing实体的名称(如果不同于canonical的名称)
|
||||
all_aliases = list(getattr(canonical, "aliases", []) or [])
|
||||
if losing_name and losing_name != canonical_name:
|
||||
all_aliases.append(losing_name)
|
||||
all_aliases.extend(getattr(losing, "aliases", []) or [])
|
||||
|
||||
# 3. 添加losing实体的所有别名
|
||||
losing_aliases = getattr(losing, "aliases", []) or []
|
||||
all_aliases.extend(losing_aliases)
|
||||
|
||||
# 4. 标准化并去重(使用标准化后的字符串进行去重)
|
||||
try:
|
||||
from app.core.memory.utils.alias_utils import normalize_aliases
|
||||
canonical.aliases = normalize_aliases(canonical_name, all_aliases)
|
||||
except Exception:
|
||||
# 如果导入失败,使用增强的去重逻辑
|
||||
# 使用标准化后的字符串作为key进行去重
|
||||
seen_normalized = set()
|
||||
unique_aliases = []
|
||||
|
||||
for alias in all_aliases:
|
||||
if not alias:
|
||||
continue
|
||||
|
||||
alias_stripped = str(alias).strip()
|
||||
if not alias_stripped or alias_stripped == canonical_name:
|
||||
continue
|
||||
|
||||
# 标准化:转小写用于去重判断
|
||||
alias_normalized = alias_stripped.lower()
|
||||
|
||||
if alias_normalized not in seen_normalized:
|
||||
seen_normalized.add(alias_normalized)
|
||||
unique_aliases.append(alias_stripped)
|
||||
|
||||
# 排序并赋值
|
||||
canonical.aliases = sorted(unique_aliases)
|
||||
|
||||
# ========== 主循环:遍历所有实体对进行模糊匹配 ==========
|
||||
|
||||
@@ -311,10 +311,53 @@ class ExtractionOrchestrator:
|
||||
dialog_data_list,
|
||||
)
|
||||
|
||||
# 步骤 7: 同步用户别名到数据库表(仅正式模式)
|
||||
# 步骤 7: 触发异步元数据和别名提取(仅正式模式)
|
||||
if not is_pilot_run:
|
||||
logger.info("步骤 7: 同步用户别名到 end_user 和 end_user_info 表")
|
||||
await self._update_end_user_other_name(entity_nodes, dialog_data_list)
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.metadata_extractor import (
|
||||
MetadataExtractor,
|
||||
)
|
||||
|
||||
metadata_extractor = MetadataExtractor(
|
||||
llm_client=self.llm_client, language=self.language
|
||||
)
|
||||
user_statements = (
|
||||
metadata_extractor.collect_user_related_statements(
|
||||
entity_nodes, statement_nodes, statement_entity_edges
|
||||
)
|
||||
)
|
||||
if user_statements:
|
||||
end_user_id = (
|
||||
dialog_data_list[0].end_user_id
|
||||
if dialog_data_list
|
||||
else None
|
||||
)
|
||||
config_id = (
|
||||
dialog_data_list[0].config_id
|
||||
if dialog_data_list
|
||||
and hasattr(dialog_data_list[0], "config_id")
|
||||
else None
|
||||
)
|
||||
if end_user_id:
|
||||
from app.tasks import extract_user_metadata_task
|
||||
|
||||
extract_user_metadata_task.delay(
|
||||
end_user_id=str(end_user_id),
|
||||
statements=user_statements,
|
||||
config_id=str(config_id) if config_id else None,
|
||||
language=self.language,
|
||||
)
|
||||
logger.info(
|
||||
f"已触发异步元数据提取任务,共 {len(user_statements)} 条用户相关 statement"
|
||||
)
|
||||
else:
|
||||
logger.info("未找到用户相关 statement,跳过元数据提取")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"触发元数据提取任务失败(不影响主流程): {e}", exc_info=True
|
||||
)
|
||||
|
||||
# 别名同步已迁移到 Celery 元数据提取任务中,不再在此处执行
|
||||
|
||||
logger.info(f"知识提取流水线运行完成({mode_str})")
|
||||
return (
|
||||
@@ -1107,6 +1150,7 @@ class ExtractionOrchestrator:
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
content=chunk.content,
|
||||
speaker=getattr(chunk, 'speaker', None),
|
||||
chunk_embedding=chunk.chunk_embedding,
|
||||
sequence_number=chunk_idx, # 添加必需的 sequence_number 字段
|
||||
created_at=dialog_data.created_at,
|
||||
@@ -1342,23 +1386,23 @@ class ExtractionOrchestrator:
|
||||
async def _update_end_user_other_name(
|
||||
self,
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
dialog_data_list: List[DialogData]
|
||||
dialog_data_list: List[DialogData],
|
||||
) -> None:
|
||||
"""
|
||||
将本轮提取的用户别名同步到 end_user 和 end_user_info 表。
|
||||
|
||||
注意:此方法在 Neo4j 写入之前调用,因此不能依赖 Neo4j 作为别名的权威数据源。
|
||||
改为直接使用内存中去重后的 entity_nodes 的 aliases,与 PgSQL 已有的 aliases 合并。
|
||||
PgSQL end_user_info.aliases 是用户别名的唯一权威源。
|
||||
此方法仅将本轮 LLM 从对话中新提取的别名增量追加到 PgSQL,
|
||||
不再从 Neo4j 二层去重合并历史别名,避免脏数据反向污染 PgSQL。
|
||||
|
||||
策略:
|
||||
1. 从内存中的 entity_nodes 提取本轮用户别名(current_aliases)
|
||||
2. 从去重后的 entity_nodes 中提取完整别名(含 Neo4j 二层去重合并的历史别名)
|
||||
3. 从 PgSQL end_user_info 读取已有的 aliases(db_aliases)
|
||||
4. 合并 db_aliases + deduped_aliases + current_aliases,去重保序
|
||||
5. 写回 PgSQL
|
||||
1. 从本轮对话原始发言中提取用户别名(current_aliases)
|
||||
2. 从 PgSQL end_user_info 读取已有的 aliases(db_aliases)
|
||||
3. 合并 db_aliases + current_aliases,去重保序
|
||||
4. 写回 PgSQL
|
||||
|
||||
Args:
|
||||
entity_nodes: 去重后的实体节点列表(内存中,含二层去重合并结果)
|
||||
entity_nodes: 去重后的实体节点列表(内存中)
|
||||
dialog_data_list: 对话数据列表
|
||||
"""
|
||||
try:
|
||||
@@ -1374,11 +1418,6 @@ class ExtractionOrchestrator:
|
||||
# 1. 提取本轮对话的用户别名(保持 LLM 提取的原始顺序,不排序)
|
||||
current_aliases = self._extract_current_aliases(entity_nodes, dialog_data_list)
|
||||
|
||||
# 1.5 从去重后的 entity_nodes 中提取完整别名
|
||||
# 二层去重会将 Neo4j 中已有的历史别名合并到 entity_nodes 中,
|
||||
# 这里提取出来确保 PgSQL 与 Neo4j 的别名保持同步
|
||||
deduped_aliases = self._extract_deduped_entity_aliases(entity_nodes)
|
||||
|
||||
# 1.6 从 Neo4j 查询已有的 AI 助手别名,作为额外的排除源
|
||||
# (防止 LLM 未提取出 AI 助手实体时,AI 别名泄漏到用户别名中)
|
||||
neo4j_assistant_aliases = await self._fetch_neo4j_assistant_aliases(end_user_id)
|
||||
@@ -1390,19 +1429,12 @@ class ExtractionOrchestrator:
|
||||
]
|
||||
if len(current_aliases) < before_count:
|
||||
logger.info(f"通过 Neo4j AI 助手别名排除了 {before_count - len(current_aliases)} 个误归属别名")
|
||||
# 同样过滤 deduped_aliases
|
||||
deduped_aliases = [
|
||||
a for a in deduped_aliases
|
||||
if a.strip().lower() not in neo4j_assistant_aliases
|
||||
]
|
||||
|
||||
if not current_aliases and not deduped_aliases:
|
||||
if not current_aliases:
|
||||
logger.debug(f"本轮未提取到用户别名,跳过同步: end_user_id={end_user_id}")
|
||||
return
|
||||
|
||||
logger.info(f"本轮对话提取的 aliases: {current_aliases}")
|
||||
if deduped_aliases:
|
||||
logger.info(f"去重后实体的完整 aliases(含历史): {deduped_aliases}")
|
||||
|
||||
# 2. 同步到数据库
|
||||
end_user_uuid = uuid.UUID(end_user_id)
|
||||
@@ -1413,21 +1445,15 @@ class ExtractionOrchestrator:
|
||||
logger.warning(f"未找到 end_user_id={end_user_id} 的用户记录")
|
||||
return
|
||||
|
||||
# 3. 从 PgSQL 读取已有 aliases 并与本轮合并
|
||||
# 3. 从 PgSQL 读取已有 aliases 并与本轮新增合并
|
||||
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
|
||||
db_aliases = (info.aliases if info and info.aliases else [])
|
||||
# 过滤掉占位名称
|
||||
db_aliases = [a for a in db_aliases if a.strip().lower() not in self.USER_PLACEHOLDER_NAMES]
|
||||
|
||||
# 合并:已有 + 去重后完整别名 + 本轮新增,去重保序
|
||||
# 合并:PgSQL 已有 + 本轮新增,去重保序(不再合并 Neo4j 历史别名)
|
||||
merged_aliases = list(db_aliases)
|
||||
seen_lower = {a.strip().lower() for a in merged_aliases}
|
||||
# 先合并去重后实体的完整别名(含 Neo4j 历史别名)
|
||||
for alias in deduped_aliases:
|
||||
if alias.strip().lower() not in seen_lower:
|
||||
merged_aliases.append(alias)
|
||||
seen_lower.add(alias.strip().lower())
|
||||
# 再合并本轮新提取的别名
|
||||
for alias in current_aliases:
|
||||
if alias.strip().lower() not in seen_lower:
|
||||
merged_aliases.append(alias)
|
||||
@@ -1461,16 +1487,13 @@ class ExtractionOrchestrator:
|
||||
info.aliases = merged_aliases
|
||||
logger.info(f"同步合并后 aliases 到 end_user_info: {merged_aliases}")
|
||||
else:
|
||||
first_alias = current_aliases[0].strip() if current_aliases else (
|
||||
deduped_aliases[0].strip() if deduped_aliases else ""
|
||||
)
|
||||
first_alias = current_aliases[0].strip() if current_aliases else ""
|
||||
# 确保 first_alias 不是占位名称
|
||||
if first_alias and first_alias.lower() not in self.USER_PLACEHOLDER_NAMES:
|
||||
db.add(EndUserInfo(
|
||||
end_user_id=end_user_uuid,
|
||||
other_name=first_alias,
|
||||
aliases=merged_aliases,
|
||||
meta_data={}
|
||||
))
|
||||
logger.info(f"创建 end_user_info 记录,other_name={first_alias}, aliases={merged_aliases}")
|
||||
|
||||
@@ -1478,9 +1501,6 @@ class ExtractionOrchestrator:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新 end_user other_name 失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
|
||||
# 用户实体占位名称,不允许作为 other_name 或出现在 aliases 中
|
||||
# 复用 deduped_and_disamb 模块级常量,避免重复维护
|
||||
USER_PLACEHOLDER_NAMES = _USER_PLACEHOLDER_NAMES
|
||||
@@ -1587,7 +1607,6 @@ class ExtractionOrchestrator:
|
||||
if candidate and candidate.lower() in self.USER_PLACEHOLDER_NAMES:
|
||||
return None
|
||||
return candidate
|
||||
|
||||
return None
|
||||
|
||||
async def _run_dedup_and_write_summary(
|
||||
|
||||
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
Metadata extractor module.
|
||||
|
||||
Collects user-related statements from post-dedup graph data and
|
||||
extracts user metadata via an independent LLM call.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.memory.models.graph_models import (
|
||||
ExtractedEntityNode,
|
||||
StatementEntityEdge,
|
||||
StatementNode,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Reuse the same user-entity detection logic from dedup module
|
||||
_USER_NAMES = {"用户", "我", "user", "i"}
|
||||
_CANONICAL_USER_TYPE = "用户"
|
||||
|
||||
|
||||
def _is_user_entity(ent: ExtractedEntityNode) -> bool:
|
||||
"""判断实体是否为用户实体"""
|
||||
name = (getattr(ent, "name", "") or "").strip().lower()
|
||||
etype = (getattr(ent, "entity_type", "") or "").strip()
|
||||
return name in _USER_NAMES or etype == _CANONICAL_USER_TYPE
|
||||
|
||||
|
||||
class MetadataExtractor:
|
||||
"""Extracts user metadata from post-dedup graph data via independent LLM call."""
|
||||
|
||||
def __init__(self, llm_client, language: Optional[str] = None):
|
||||
self.llm_client = llm_client
|
||||
self.language = language
|
||||
|
||||
@staticmethod
|
||||
def detect_language(statements: List[str]) -> str:
|
||||
"""根据 statement 文本内容检测语言。
|
||||
如果文本中包含中文字符则返回 "zh",否则返回 "en"。
|
||||
"""
|
||||
import re
|
||||
|
||||
combined = " ".join(statements)
|
||||
if re.search(r"[\u4e00-\u9fff]", combined):
|
||||
return "zh"
|
||||
return "en"
|
||||
|
||||
def collect_user_related_statements(
|
||||
self,
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
) -> List[str]:
|
||||
"""
|
||||
从去重后的数据中筛选与用户直接相关且由用户发言的 statement 文本。
|
||||
|
||||
筛选逻辑:
|
||||
1. 用户实体 → StatementEntityEdge → statement(直接关联)
|
||||
2. 只保留 speaker="user" 的 statement(过滤 assistant 回复的噪声)
|
||||
|
||||
Returns:
|
||||
用户发言的 statement 文本列表
|
||||
"""
|
||||
# Find user entity IDs
|
||||
user_entity_ids = set()
|
||||
for ent in entity_nodes:
|
||||
if _is_user_entity(ent):
|
||||
user_entity_ids.add(ent.id)
|
||||
|
||||
if not user_entity_ids:
|
||||
logger.debug("未找到用户实体节点,跳过 statement 收集")
|
||||
return []
|
||||
|
||||
# 用户实体 → StatementEntityEdge → statement
|
||||
target_stmt_ids = set()
|
||||
for edge in statement_entity_edges:
|
||||
if edge.target in user_entity_ids:
|
||||
target_stmt_ids.add(edge.source)
|
||||
|
||||
# Collect: only speaker="user" statements, preserving order
|
||||
result = []
|
||||
seen = set()
|
||||
total_associated = 0
|
||||
skipped_non_user = 0
|
||||
for stmt_node in statement_nodes:
|
||||
if stmt_node.id in target_stmt_ids and stmt_node.id not in seen:
|
||||
total_associated += 1
|
||||
speaker = getattr(stmt_node, "speaker", None) or "unknown"
|
||||
if speaker == "user":
|
||||
text = (stmt_node.statement or "").strip()
|
||||
if text:
|
||||
result.append(text)
|
||||
else:
|
||||
skipped_non_user += 1
|
||||
seen.add(stmt_node.id)
|
||||
|
||||
logger.info(
|
||||
f"收集到 {len(result)} 条用户发言 statement "
|
||||
f"(直接关联: {total_associated}, speaker=user: {len(result)}, "
|
||||
f"跳过非user: {skipped_non_user})"
|
||||
)
|
||||
if result:
|
||||
for i, text in enumerate(result):
|
||||
logger.info(f" [user statement {i + 1}] {text}")
|
||||
if total_associated > 0 and len(result) == 0:
|
||||
logger.warning(
|
||||
f"有 {total_associated} 条直接关联 statement 但全部被 speaker 过滤,"
|
||||
f"可能本次写入不包含 user 消息"
|
||||
)
|
||||
return result
|
||||
|
||||
async def extract_metadata(
|
||||
self,
|
||||
statements: List[str],
|
||||
existing_metadata: Optional[dict] = None,
|
||||
existing_aliases: Optional[List[str]] = None,
|
||||
) -> Optional[tuple]:
|
||||
"""
|
||||
对筛选后的 statement 列表调用 LLM 提取元数据增量变更和用户别名。
|
||||
|
||||
Args:
|
||||
statements: 用户发言的 statement 文本列表
|
||||
existing_metadata: 数据库已有的元数据(可选)
|
||||
existing_aliases: 数据库已有的用户别名列表(可选)
|
||||
|
||||
Returns:
|
||||
(List[MetadataFieldChange], List[str], List[str]) tuple:
|
||||
(metadata_changes, aliases_to_add, aliases_to_remove) on success, None on failure
|
||||
"""
|
||||
if not statements:
|
||||
return None
|
||||
|
||||
try:
|
||||
from app.core.memory.utils.prompt.prompt_utils import prompt_env
|
||||
|
||||
if self.language:
|
||||
detected_language = self.language
|
||||
logger.info(f"元数据提取使用显式指定语言: {detected_language}")
|
||||
else:
|
||||
detected_language = self.detect_language(statements)
|
||||
logger.info(f"元数据提取语言自动检测结果: {detected_language}")
|
||||
|
||||
template = prompt_env.get_template("extract_user_metadata.jinja2")
|
||||
prompt = template.render(
|
||||
statements=statements,
|
||||
language=detected_language,
|
||||
existing_metadata=existing_metadata,
|
||||
existing_aliases=existing_aliases,
|
||||
json_schema="",
|
||||
)
|
||||
|
||||
from app.core.memory.models.metadata_models import (
|
||||
MetadataExtractionResponse,
|
||||
)
|
||||
|
||||
response = await self.llm_client.response_structured(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
response_model=MetadataExtractionResponse,
|
||||
)
|
||||
|
||||
if response:
|
||||
changes = response.metadata_changes if response.metadata_changes else []
|
||||
to_add = response.aliases_to_add if response.aliases_to_add else []
|
||||
to_remove = (
|
||||
response.aliases_to_remove if response.aliases_to_remove else []
|
||||
)
|
||||
return changes, to_add, to_remove
|
||||
|
||||
logger.warning("LLM 返回的响应为空")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"元数据提取 LLM 调用失败: {e}", exc_info=True)
|
||||
return None
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -82,6 +81,7 @@ class StatementExtractor:
|
||||
logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty")
|
||||
return None
|
||||
|
||||
|
||||
async def _extract_statements(self, chunk, end_user_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
|
||||
"""Process a single chunk and return extracted statements
|
||||
|
||||
@@ -94,7 +94,8 @@ class StatementExtractor:
|
||||
List of ExtractedStatement objects extracted from the chunk
|
||||
"""
|
||||
chunk_content = chunk.content
|
||||
|
||||
chunk_speaker = self._get_speaker_from_chunk(chunk)
|
||||
|
||||
if not chunk_content or len(chunk_content.strip()) < 5:
|
||||
logger.warning(f"Chunk {chunk.id} content too short or empty, skipping")
|
||||
return []
|
||||
@@ -149,8 +150,6 @@ class StatementExtractor:
|
||||
relevence_info = RelevenceInfo[relevence_str] if relevence_str in RelevenceInfo.__members__ else RelevenceInfo.RELEVANT
|
||||
except (KeyError, ValueError):
|
||||
relevence_info = RelevenceInfo.RELEVANT
|
||||
|
||||
chunk_speaker = self._get_speaker_from_chunk(chunk)
|
||||
|
||||
chunk_statement = Statement(
|
||||
statement=extracted_stmt.statement,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
import asyncio
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
|
||||
@@ -42,22 +42,21 @@ class AccessHistoryManager:
|
||||
- access_count: 访问次数
|
||||
|
||||
特性:
|
||||
- 原子性更新:使用Neo4j事务确保所有字段同时更新或回滚
|
||||
- 并发安全:使用乐观锁机制防止并发冲突
|
||||
- 原子性更新:使用 APOC 原子操作确保并发安全
|
||||
- 批次内合并:同一批次中对同一节点的多次访问合并为一次更新
|
||||
- 一致性保证:提供一致性检查和自动修复功能
|
||||
- 智能修剪:自动修剪过长的访问历史
|
||||
|
||||
Attributes:
|
||||
connector: Neo4j连接器实例
|
||||
actr_calculator: ACT-R激活值计算器实例
|
||||
max_retries: 并发冲突时的最大重试次数
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
actr_calculator: ACTRCalculator,
|
||||
max_retries: int = 3
|
||||
max_retries: int = 5
|
||||
):
|
||||
"""
|
||||
初始化访问历史管理器
|
||||
@@ -65,47 +64,35 @@ class AccessHistoryManager:
|
||||
Args:
|
||||
connector: Neo4j连接器实例
|
||||
actr_calculator: ACT-R激活值计算器实例
|
||||
max_retries: 并发冲突时的最大重试次数(默认3次)
|
||||
max_retries: 已废弃,保留参数兼容性(APOC 原子操作无需重试)
|
||||
"""
|
||||
self.connector = connector
|
||||
self.actr_calculator = actr_calculator
|
||||
self.max_retries = max_retries
|
||||
|
||||
|
||||
async def record_access(
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
current_time: Optional[datetime] = None
|
||||
current_time: Optional[datetime] = None,
|
||||
access_times: int = 1
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
记录节点访问并原子性更新所有相关字段
|
||||
|
||||
这是核心方法,实现了:
|
||||
1. 首次访问:初始化access_history,计算初始激活值
|
||||
2. 后续访问:追加访问历史,重新计算激活值
|
||||
3. 历史修剪:当历史过长时自动修剪
|
||||
4. 原子性:所有字段在单个事务中更新
|
||||
5. 并发安全:使用乐观锁重试机制
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签(Statement, ExtractedEntity, MemorySummary)
|
||||
end_user_id: 组ID(可选,用于过滤)
|
||||
current_time: 当前时间(可选,默认使用系统时间)
|
||||
access_times: 本次访问次数(默认1,批量合并时可能大于1)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 更新后的节点数据,包含:
|
||||
- id: 节点ID
|
||||
- activation_value: 更新后的激活值
|
||||
- access_history: 更新后的访问历史
|
||||
- last_access_time: 最后访问时间
|
||||
- access_count: 访问次数
|
||||
- importance_score: 重要性分数
|
||||
Dict[str, Any]: 更新后的节点数据
|
||||
|
||||
Raises:
|
||||
ValueError: 如果节点不存在或节点标签无效
|
||||
RuntimeError: 如果重试次数耗尽仍然失败
|
||||
RuntimeError: 如果更新失败
|
||||
"""
|
||||
if current_time is None:
|
||||
current_time = datetime.now()
|
||||
@@ -119,55 +106,48 @@ class AccessHistoryManager:
|
||||
f"Invalid node_label: {node_label}. Must be one of {valid_labels}"
|
||||
)
|
||||
|
||||
# 使用乐观锁重试机制处理并发冲突
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
# 步骤1:读取当前节点状态
|
||||
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||
|
||||
if not node_data:
|
||||
raise ValueError(
|
||||
f"Node not found: {node_label} with id={node_id}"
|
||||
)
|
||||
|
||||
# 步骤2:计算新的访问历史和激活值
|
||||
update_data = await self._calculate_update(
|
||||
node_data=node_data,
|
||||
current_time=current_time,
|
||||
current_time_iso=current_time_iso
|
||||
try:
|
||||
# 步骤1:读取当前节点状态
|
||||
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||
|
||||
if not node_data:
|
||||
raise ValueError(
|
||||
f"Node not found: {node_label} with id={node_id}"
|
||||
)
|
||||
|
||||
# 步骤3:原子性更新节点(使用事务)
|
||||
updated_node = await self._atomic_update(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
update_data=update_data,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"成功记录访问: {node_label}[{node_id}], "
|
||||
f"activation={update_data['activation_value']:.4f}, "
|
||||
f"access_count={update_data['access_count']}"
|
||||
)
|
||||
|
||||
return updated_node
|
||||
|
||||
except Exception as e:
|
||||
if attempt < self.max_retries - 1:
|
||||
logger.warning(
|
||||
f"访问记录失败(尝试 {attempt + 1}/{self.max_retries}): {str(e)}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
logger.error(
|
||||
f"访问记录失败,重试次数耗尽: {node_label}[{node_id}], "
|
||||
f"错误: {str(e)}"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Failed to record access after {self.max_retries} attempts: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# 步骤2:计算新的访问历史和激活值
|
||||
update_data = await self._calculate_update(
|
||||
node_data=node_data,
|
||||
current_time=current_time,
|
||||
current_time_iso=current_time_iso,
|
||||
access_times=access_times
|
||||
)
|
||||
|
||||
# 步骤3:使用 APOC 原子操作更新节点(无需重试)
|
||||
updated_node = await self._atomic_update(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
update_data=update_data,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"成功记录访问: {node_label}[{node_id}], "
|
||||
f"activation={update_data['activation_value']:.4f}, "
|
||||
f"access_count={update_data['access_count']}"
|
||||
f"{f', 合并访问次数={access_times}' if access_times > 1 else ''}"
|
||||
)
|
||||
|
||||
return updated_node
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"访问记录失败: {node_label}[{node_id}], 错误: {str(e)}"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Failed to record access: {str(e)}"
|
||||
) from e
|
||||
|
||||
async def record_batch_access(
|
||||
self,
|
||||
node_ids: List[str],
|
||||
@@ -178,11 +158,10 @@ class AccessHistoryManager:
|
||||
"""
|
||||
批量记录多个节点的访问
|
||||
|
||||
为提高性能,批量更新多个节点的访问历史。
|
||||
每个节点独立更新,失败的节点不影响其他节点。
|
||||
对同一个节点的多次访问会先在内存中合并,只发起一次更新。
|
||||
|
||||
Args:
|
||||
node_ids: 节点ID列表
|
||||
node_ids: 节点ID列表(可包含重复ID)
|
||||
node_label: 节点标签(所有节点必须是同一类型)
|
||||
end_user_id: 组ID(可选)
|
||||
current_time: 当前时间(可选)
|
||||
@@ -196,25 +175,38 @@ class AccessHistoryManager:
|
||||
if current_time is None:
|
||||
current_time = datetime.now()
|
||||
|
||||
# PERFORMANCE FIX: Process all nodes in parallel instead of sequentially
|
||||
tasks = []
|
||||
# 合并同一节点的访问次数,避免对同一节点并发写入
|
||||
access_count_map: Dict[str, int] = {}
|
||||
for node_id in node_ids:
|
||||
access_count_map[node_id] = access_count_map.get(node_id, 0) + 1
|
||||
|
||||
merged_count = len(node_ids) - len(access_count_map)
|
||||
if merged_count > 0:
|
||||
logger.info(
|
||||
f"批量访问合并: 原始={len(node_ids)}, "
|
||||
f"去重后={len(access_count_map)}, 合并={merged_count}"
|
||||
)
|
||||
|
||||
# 对去重后的节点并行发起更新
|
||||
tasks = []
|
||||
for node_id, access_times in access_count_map.items():
|
||||
task = self.record_access(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
end_user_id=end_user_id,
|
||||
current_time=current_time
|
||||
current_time=current_time,
|
||||
access_times=access_times
|
||||
)
|
||||
tasks.append(task)
|
||||
tasks.append((node_id, task))
|
||||
|
||||
# Execute all tasks in parallel
|
||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
task_results = await asyncio.gather(
|
||||
*[t for _, t in tasks], return_exceptions=True
|
||||
)
|
||||
|
||||
# Collect successful results and count failures
|
||||
results = []
|
||||
failed_count = 0
|
||||
|
||||
for node_id, result in zip(node_ids, task_results):
|
||||
for (node_id, _), result in zip(tasks, task_results):
|
||||
if isinstance(result, Exception):
|
||||
failed_count += 1
|
||||
logger.warning(
|
||||
@@ -225,12 +217,12 @@ class AccessHistoryManager:
|
||||
|
||||
batch_duration = time.time() - batch_start
|
||||
logger.info(
|
||||
f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, "
|
||||
f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(access_count_map)}, "
|
||||
f"失败 {failed_count}, 耗时 {batch_duration:.4f}s"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def check_consistency(
|
||||
self,
|
||||
node_id: str,
|
||||
@@ -239,22 +231,6 @@ class AccessHistoryManager:
|
||||
) -> Tuple[ConsistencyCheckResult, Optional[str]]:
|
||||
"""
|
||||
检查节点数据的一致性
|
||||
|
||||
验证以下一致性规则:
|
||||
1. access_history[-1] == last_access_time
|
||||
2. len(access_history) == access_count
|
||||
3. 如果有访问历史,必须有激活值
|
||||
4. 激活值必须在有效范围内 [offset, 1.0]
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
end_user_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Tuple[ConsistencyCheckResult, Optional[str]]:
|
||||
- 一致性检查结果枚举
|
||||
- 错误描述(如果不一致)
|
||||
"""
|
||||
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||
|
||||
@@ -266,7 +242,6 @@ class AccessHistoryManager:
|
||||
access_count = node_data.get('access_count', 0)
|
||||
activation_value = node_data.get('activation_value')
|
||||
|
||||
# 检查1:access_history[-1] == last_access_time
|
||||
if access_history and last_access_time:
|
||||
if access_history[-1] != last_access_time:
|
||||
return (
|
||||
@@ -275,7 +250,6 @@ class AccessHistoryManager:
|
||||
f"last_access_time={last_access_time}"
|
||||
)
|
||||
|
||||
# 检查2:len(access_history) == access_count
|
||||
if len(access_history) != access_count:
|
||||
return (
|
||||
ConsistencyCheckResult.INCONSISTENT_HISTORY_COUNT,
|
||||
@@ -283,14 +257,12 @@ class AccessHistoryManager:
|
||||
f"access_count={access_count}"
|
||||
)
|
||||
|
||||
# 检查3:有访问历史必须有激活值
|
||||
if access_history and activation_value is None:
|
||||
return (
|
||||
ConsistencyCheckResult.MISSING_ACTIVATION,
|
||||
"Node has access_history but activation_value is None"
|
||||
)
|
||||
|
||||
# 检查4:激活值范围
|
||||
if activation_value is not None:
|
||||
offset = self.actr_calculator.offset
|
||||
if not (offset <= activation_value <= 1.0):
|
||||
@@ -301,30 +273,14 @@ class AccessHistoryManager:
|
||||
)
|
||||
|
||||
return ConsistencyCheckResult.CONSISTENT, None
|
||||
|
||||
|
||||
async def check_batch_consistency(
|
||||
self,
|
||||
node_label: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 1000
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
批量检查多个节点的一致性
|
||||
|
||||
Args:
|
||||
node_label: 节点标签
|
||||
end_user_id: 组ID(可选)
|
||||
limit: 检查的最大节点数
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 一致性检查报告,包含:
|
||||
- total_checked: 检查的节点总数
|
||||
- consistent_count: 一致的节点数
|
||||
- inconsistent_count: 不一致的节点数
|
||||
- inconsistencies: 不一致节点的详细信息列表
|
||||
- consistency_rate: 一致性率(0-1)
|
||||
"""
|
||||
# 查询所有相关节点
|
||||
"""批量检查多个节点的一致性"""
|
||||
query = f"""
|
||||
MATCH (n:{node_label})
|
||||
WHERE n.access_history IS NOT NULL
|
||||
@@ -343,7 +299,6 @@ class AccessHistoryManager:
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
node_ids = [r['id'] for r in results]
|
||||
|
||||
# 检查每个节点
|
||||
inconsistencies = []
|
||||
consistent_count = 0
|
||||
|
||||
@@ -382,32 +337,15 @@ class AccessHistoryManager:
|
||||
)
|
||||
|
||||
return report
|
||||
|
||||
|
||||
async def repair_inconsistency(
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
end_user_id: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
自动修复节点的数据不一致问题
|
||||
|
||||
修复策略:
|
||||
1. 如果access_history[-1] != last_access_time:使用access_history[-1]
|
||||
2. 如果len(access_history) != access_count:使用len(access_history)
|
||||
3. 如果有历史但无激活值:重新计算激活值
|
||||
4. 如果激活值超出范围:重新计算激活值
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
end_user_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
bool: 修复成功返回True,否则返回False
|
||||
"""
|
||||
"""自动修复节点的数据不一致问题"""
|
||||
try:
|
||||
# 检查一致性
|
||||
result, message = await self.check_consistency(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
@@ -418,7 +356,6 @@ class AccessHistoryManager:
|
||||
logger.info(f"节点数据一致,无需修复: {node_label}[{node_id}]")
|
||||
return True
|
||||
|
||||
# 获取节点数据
|
||||
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||
if not node_data:
|
||||
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
|
||||
@@ -427,17 +364,13 @@ class AccessHistoryManager:
|
||||
access_history = node_data.get('access_history') or []
|
||||
importance_score = node_data.get('importance_score', 0.5)
|
||||
|
||||
# 准备修复数据
|
||||
repair_data = {}
|
||||
|
||||
# 修复last_access_time
|
||||
if access_history:
|
||||
repair_data['last_access_time'] = access_history[-1]
|
||||
|
||||
# 修复access_count
|
||||
repair_data['access_count'] = len(access_history)
|
||||
|
||||
# 修复activation_value
|
||||
if access_history:
|
||||
current_time = datetime.now()
|
||||
last_access_dt = datetime.fromisoformat(access_history[-1])
|
||||
@@ -453,7 +386,6 @@ class AccessHistoryManager:
|
||||
)
|
||||
repair_data['activation_value'] = activation_value
|
||||
|
||||
# 执行修复
|
||||
query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
@@ -484,26 +416,16 @@ class AccessHistoryManager:
|
||||
f"修复节点失败: {node_label}[{node_id}], 错误: {str(e)}"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
# ==================== 私有辅助方法 ====================
|
||||
|
||||
|
||||
async def _fetch_node(
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
end_user_id: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取节点数据
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
end_user_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 节点数据,如果不存在返回None
|
||||
"""
|
||||
"""获取节点数据"""
|
||||
query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
@@ -527,12 +449,13 @@ class AccessHistoryManager:
|
||||
if results:
|
||||
return results[0]
|
||||
return None
|
||||
|
||||
|
||||
async def _calculate_update(
|
||||
self,
|
||||
node_data: Dict[str, Any],
|
||||
current_time: datetime,
|
||||
current_time_iso: str
|
||||
current_time_iso: str,
|
||||
access_times: int = 1
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
计算更新数据
|
||||
@@ -541,45 +464,40 @@ class AccessHistoryManager:
|
||||
node_data: 当前节点数据
|
||||
current_time: 当前时间(datetime对象)
|
||||
current_time_iso: 当前时间(ISO格式字符串)
|
||||
access_times: 本次访问次数(合并后可能大于1)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 更新数据,包含所有需要更新的字段
|
||||
Dict[str, Any]: 更新数据
|
||||
"""
|
||||
access_history = node_data.get('access_history') or []
|
||||
# Handle None importance_score - default to 0.5
|
||||
importance_score = node_data.get('importance_score')
|
||||
if importance_score is None:
|
||||
importance_score = 0.5
|
||||
|
||||
# 追加新的访问时间
|
||||
new_access_history = access_history + [current_time_iso]
|
||||
# 本次新增的时间戳
|
||||
new_timestamps = [current_time_iso] * access_times
|
||||
|
||||
# 修剪访问历史(如果过长)
|
||||
access_history_dt = [
|
||||
datetime.fromisoformat(ts) for ts in new_access_history
|
||||
]
|
||||
# 仅用本次新增的访问记录计算激活值
|
||||
new_history_dt = [current_time] * access_times
|
||||
trimmed_history_dt = self.actr_calculator.trim_access_history(
|
||||
access_history=access_history_dt,
|
||||
access_history=new_history_dt,
|
||||
current_time=current_time
|
||||
)
|
||||
trimmed_history = [ts.isoformat() for ts in trimmed_history_dt]
|
||||
|
||||
# 计算新的激活值
|
||||
activation_value = self.actr_calculator.calculate_memory_activation(
|
||||
access_history=trimmed_history_dt,
|
||||
current_time=current_time,
|
||||
last_access_time=current_time, # 最后访问时间就是当前时间
|
||||
last_access_time=current_time,
|
||||
importance_score=importance_score
|
||||
)
|
||||
|
||||
# 返回所有需要更新的字段
|
||||
return {
|
||||
'activation_value': activation_value,
|
||||
'access_history': trimmed_history,
|
||||
'new_timestamps': new_timestamps,
|
||||
'access_count_delta': access_times,
|
||||
'access_count': len(trimmed_history_dt),
|
||||
'last_access_time': current_time_iso,
|
||||
'access_count': len(trimmed_history)
|
||||
}
|
||||
|
||||
|
||||
async def _atomic_update(
|
||||
self,
|
||||
node_id: str,
|
||||
@@ -588,10 +506,10 @@ class AccessHistoryManager:
|
||||
end_user_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
原子性更新节点(使用乐观锁)
|
||||
原子性更新节点(使用 APOC 原子操作)
|
||||
|
||||
使用Neo4j事务和版本号确保所有字段同时更新或回滚。
|
||||
实现乐观锁机制防止并发冲突。
|
||||
使用 apoc.atomic.add 和 apoc.atomic.insert 保证并发安全,
|
||||
无需 version 字段和乐观锁,数据库层面保证原子性。
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
@@ -603,126 +521,68 @@ class AccessHistoryManager:
|
||||
Dict[str, Any]: 更新后的节点数据
|
||||
|
||||
Raises:
|
||||
RuntimeError: 如果更新失败或发生版本冲突
|
||||
RuntimeError: 如果更新失败
|
||||
"""
|
||||
# 定义事务函数
|
||||
async def update_transaction(tx, node_id, node_label, update_data, end_user_id):
|
||||
# 步骤1:读取当前节点并获取版本号
|
||||
read_query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
if end_user_id:
|
||||
read_query += " WHERE n.end_user_id = $end_user_id"
|
||||
read_query += """
|
||||
RETURN n.id as id,
|
||||
n.version as version,
|
||||
n.activation_value as activation_value,
|
||||
n.access_history as access_history,
|
||||
n.last_access_time as last_access_time,
|
||||
n.access_count as access_count,
|
||||
n.importance_score as importance_score
|
||||
"""
|
||||
content_field_map = {
|
||||
'Statement': 'n.statement as statement',
|
||||
'MemorySummary': 'n.content as content',
|
||||
'ExtractedEntity': 'null as content_placeholder',
|
||||
'Community': 'n.summary as summary'
|
||||
}
|
||||
|
||||
if node_label not in content_field_map:
|
||||
raise ValueError(
|
||||
f"Unsupported node_label: {node_label}. "
|
||||
f"Supported labels are: {list(content_field_map.keys())}"
|
||||
)
|
||||
|
||||
content_field = content_field_map[node_label]
|
||||
|
||||
where_clause = ""
|
||||
if end_user_id:
|
||||
where_clause = " AND n.end_user_id = $end_user_id"
|
||||
|
||||
query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
WHERE true{where_clause}
|
||||
CALL apoc.atomic.add(n, 'access_count', $access_count_delta, 5) YIELD oldValue AS old_count
|
||||
WITH n
|
||||
CALL (n) {{
|
||||
UNWIND $new_timestamps AS ts
|
||||
CALL apoc.atomic.insert(n, 'access_history', size(n.access_history), ts, 5) YIELD oldValue
|
||||
RETURN count(*) AS inserted
|
||||
}}
|
||||
SET n.activation_value = $activation_value,
|
||||
n.last_access_time = $last_access_time
|
||||
RETURN n.id as id,
|
||||
n.activation_value as activation_value,
|
||||
n.access_history as access_history,
|
||||
n.last_access_time as last_access_time,
|
||||
n.access_count as access_count,
|
||||
n.importance_score as importance_score,
|
||||
{content_field}
|
||||
"""
|
||||
|
||||
params = {
|
||||
'node_id': node_id,
|
||||
'access_count_delta': update_data['access_count_delta'],
|
||||
'new_timestamps': update_data['new_timestamps'],
|
||||
'activation_value': update_data['activation_value'],
|
||||
'last_access_time': update_data['last_access_time'],
|
||||
}
|
||||
if end_user_id:
|
||||
params['end_user_id'] = end_user_id
|
||||
|
||||
try:
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
|
||||
read_params = {'node_id': node_id}
|
||||
if end_user_id:
|
||||
read_params['end_user_id'] = end_user_id
|
||||
|
||||
read_result = await tx.run(read_query, **read_params)
|
||||
current_node = await read_result.single()
|
||||
|
||||
if not current_node:
|
||||
if not results:
|
||||
raise RuntimeError(f"Node not found: {node_label}[{node_id}]")
|
||||
|
||||
# 获取当前版本号(如果不存在则为0)
|
||||
current_version = current_node.get('version', 0) or 0
|
||||
new_version = current_version + 1
|
||||
|
||||
# 步骤2:使用乐观锁更新节点
|
||||
# 根据节点类型构建完整的查询语句
|
||||
content_field_map = {
|
||||
'Statement': 'n.statement as statement',
|
||||
'MemorySummary': 'n.content as content',
|
||||
'ExtractedEntity': 'null as content_placeholder' # 占位符,后续会被过滤
|
||||
}
|
||||
|
||||
# 显式检查节点类型,不支持的类型抛出错误
|
||||
if node_label not in content_field_map:
|
||||
raise ValueError(
|
||||
f"Unsupported node_label: {node_label}. "
|
||||
f"Supported labels are: {list(content_field_map.keys())}"
|
||||
)
|
||||
|
||||
content_field = content_field_map[node_label]
|
||||
|
||||
# 构建 WHERE 子句
|
||||
where_conditions = []
|
||||
if end_user_id:
|
||||
where_conditions.append("n.end_user_id = $end_user_id")
|
||||
|
||||
# 添加版本检查
|
||||
if current_version > 0:
|
||||
where_conditions.append("n.version = $current_version")
|
||||
else:
|
||||
where_conditions.append("(n.version IS NULL OR n.version = 0)")
|
||||
|
||||
where_clause = " AND ".join(where_conditions) if where_conditions else "true"
|
||||
|
||||
# 构建完整的更新查询
|
||||
update_query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
WHERE {where_clause}
|
||||
SET n.activation_value = $activation_value,
|
||||
n.access_history = $access_history,
|
||||
n.last_access_time = $last_access_time,
|
||||
n.access_count = $access_count,
|
||||
n.version = $new_version
|
||||
RETURN n.id as id,
|
||||
n.activation_value as activation_value,
|
||||
n.access_history as access_history,
|
||||
n.last_access_time as last_access_time,
|
||||
n.access_count as access_count,
|
||||
n.importance_score as importance_score,
|
||||
n.version as version,
|
||||
{content_field}
|
||||
"""
|
||||
|
||||
update_params = {
|
||||
'node_id': node_id,
|
||||
'current_version': current_version,
|
||||
'new_version': new_version,
|
||||
'activation_value': update_data['activation_value'],
|
||||
'access_history': update_data['access_history'],
|
||||
'last_access_time': update_data['last_access_time'],
|
||||
'access_count': update_data['access_count']
|
||||
}
|
||||
if end_user_id:
|
||||
update_params['end_user_id'] = end_user_id
|
||||
|
||||
update_result = await tx.run(update_query, **update_params)
|
||||
updated_node = await update_result.single()
|
||||
|
||||
if not updated_node:
|
||||
raise RuntimeError(
|
||||
f"Version conflict detected for {node_label}[{node_id}]. "
|
||||
f"Expected version {current_version}, but node was modified by another transaction."
|
||||
)
|
||||
|
||||
# 转换为字典并移除占位符字段
|
||||
result_dict = dict(updated_node)
|
||||
result_dict = dict(results[0])
|
||||
result_dict.pop('content_placeholder', None)
|
||||
|
||||
return result_dict
|
||||
|
||||
# 执行事务
|
||||
try:
|
||||
result = await self.connector.execute_write_transaction(
|
||||
update_transaction,
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
update_data=update_data,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"原子性更新失败: {node_label}[{node_id}], 错误: {str(e)}"
|
||||
|
||||
@@ -4,11 +4,6 @@
|
||||
本模块提供统一的搜索服务接口,支持关键词搜索、语义搜索和混合搜索。
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
from app.core.memory.storage_services.search.hybrid_search import HybridSearchStrategy
|
||||
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
|
||||
from app.core.memory.storage_services.search.search_strategy import (
|
||||
@@ -29,115 +24,87 @@ __all__ = [
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 向后兼容的函数式API
|
||||
# 向后兼容的函数式API (DEPRECATED - 未被使用)
|
||||
# ============================================================================
|
||||
# 为了兼容旧代码,提供与 src/search.py 相同的函数式接口
|
||||
# 所有调用方均直接使用 app.core.memory.src.search.run_hybrid_search
|
||||
# 保留注释以备参考
|
||||
|
||||
|
||||
async def run_hybrid_search(
|
||||
query_text: str,
|
||||
search_type: str = "hybrid",
|
||||
end_user_id: str | None = None,
|
||||
apply_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
limit: int = 50,
|
||||
include: list[str] | None = None,
|
||||
alpha: float = 0.6,
|
||||
use_forgetting_curve: bool = False,
|
||||
memory_config: "MemoryConfig" = None,
|
||||
**kwargs
|
||||
) -> dict:
|
||||
"""运行混合搜索(向后兼容的函数式API)
|
||||
|
||||
这是一个向后兼容的包装函数,将旧的函数式API转换为新的基于类的API。
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
search_type: 搜索类型("hybrid", "keyword", "semantic")
|
||||
end_user_id: 组ID过滤
|
||||
apply_id: 应用ID过滤
|
||||
user_id: 用户ID过滤
|
||||
limit: 每个类别的最大结果数
|
||||
include: 要包含的搜索类别列表
|
||||
alpha: BM25分数权重(0.0-1.0)
|
||||
use_forgetting_curve: 是否使用遗忘曲线
|
||||
memory_config: MemoryConfig object containing embedding_model_id
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
dict: 搜索结果字典,格式与旧API兼容
|
||||
"""
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
if not memory_config:
|
||||
raise ValueError("memory_config is required for search")
|
||||
|
||||
# 初始化客户端
|
||||
connector = Neo4jConnector()
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
|
||||
try:
|
||||
# 根据搜索类型选择策略
|
||||
if search_type == "keyword":
|
||||
strategy = KeywordSearchStrategy(connector=connector)
|
||||
elif search_type == "semantic":
|
||||
strategy = SemanticSearchStrategy(
|
||||
connector=connector,
|
||||
embedder_client=embedder_client
|
||||
)
|
||||
else: # hybrid
|
||||
strategy = HybridSearchStrategy(
|
||||
connector=connector,
|
||||
embedder_client=embedder_client,
|
||||
alpha=alpha,
|
||||
use_forgetting_curve=use_forgetting_curve
|
||||
)
|
||||
|
||||
# 执行搜索
|
||||
result = await strategy.search(
|
||||
query_text=query_text,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include,
|
||||
alpha=alpha,
|
||||
use_forgetting_curve=use_forgetting_curve,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# 转换为旧格式
|
||||
result_dict = result.to_dict()
|
||||
|
||||
# 保存到文件(如果指定了output_path)
|
||||
output_path = kwargs.get('output_path', 'search_results.json')
|
||||
if output_path:
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
try:
|
||||
# 确保目录存在
|
||||
out_dir = os.path.dirname(output_path)
|
||||
if out_dir:
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
# 保存结果
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(result_dict, f, ensure_ascii=False, indent=2, default=str)
|
||||
print(f"Search results saved to {output_path}")
|
||||
except Exception as e:
|
||||
print(f"Error saving search results: {e}")
|
||||
return result_dict
|
||||
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
__all__.append("run_hybrid_search")
|
||||
# async def run_hybrid_search(
|
||||
# query_text: str,
|
||||
# search_type: str = "hybrid",
|
||||
# end_user_id: str | None = None,
|
||||
# apply_id: str | None = None,
|
||||
# user_id: str | None = None,
|
||||
# limit: int = 50,
|
||||
# include: list[str] | None = None,
|
||||
# alpha: float = 0.6,
|
||||
# use_forgetting_curve: bool = False,
|
||||
# memory_config: "MemoryConfig" = None,
|
||||
# **kwargs
|
||||
# ) -> dict:
|
||||
# """运行混合搜索(向后兼容的函数式API)"""
|
||||
# from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
# from app.core.models.base import RedBearModelConfig
|
||||
# from app.db import get_db_context
|
||||
# from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
# from app.services.memory_config_service import MemoryConfigService
|
||||
#
|
||||
# if not memory_config:
|
||||
# raise ValueError("memory_config is required for search")
|
||||
#
|
||||
# connector = Neo4jConnector()
|
||||
# with get_db_context() as db:
|
||||
# config_service = MemoryConfigService(db)
|
||||
# embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
|
||||
# embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
# embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
#
|
||||
# try:
|
||||
# if search_type == "keyword":
|
||||
# strategy = KeywordSearchStrategy(connector=connector)
|
||||
# elif search_type == "semantic":
|
||||
# strategy = SemanticSearchStrategy(
|
||||
# connector=connector,
|
||||
# embedder_client=embedder_client
|
||||
# )
|
||||
# else:
|
||||
# strategy = HybridSearchStrategy(
|
||||
# connector=connector,
|
||||
# embedder_client=embedder_client,
|
||||
# alpha=alpha,
|
||||
# use_forgetting_curve=use_forgetting_curve
|
||||
# )
|
||||
#
|
||||
# result = await strategy.search(
|
||||
# query_text=query_text,
|
||||
# end_user_id=end_user_id,
|
||||
# limit=limit,
|
||||
# include=include,
|
||||
# alpha=alpha,
|
||||
# use_forgetting_curve=use_forgetting_curve,
|
||||
# **kwargs
|
||||
# )
|
||||
#
|
||||
# result_dict = result.to_dict()
|
||||
#
|
||||
# output_path = kwargs.get('output_path', 'search_results.json')
|
||||
# if output_path:
|
||||
# import json
|
||||
# import os
|
||||
# from datetime import datetime
|
||||
#
|
||||
# try:
|
||||
# out_dir = os.path.dirname(output_path)
|
||||
# if out_dir:
|
||||
# os.makedirs(out_dir, exist_ok=True)
|
||||
# with open(output_path, "w", encoding="utf-8") as f:
|
||||
# json.dump(result_dict, f, ensure_ascii=False, indent=2, default=str)
|
||||
# print(f"Search results saved to {output_path}")
|
||||
# except Exception as e:
|
||||
# print(f"Error saving search results: {e}")
|
||||
# return result_dict
|
||||
#
|
||||
# finally:
|
||||
# await connector.close()
|
||||
#
|
||||
# __all__.append("run_hybrid_search")
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
使用Neo4j的全文索引进行高效的文本匹配。
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import List, Optional
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
|
||||
@@ -74,7 +74,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
||||
# 调用底层的关键词搜索函数
|
||||
results_dict = await search_graph(
|
||||
connector=self.connector,
|
||||
q=query_text,
|
||||
query=query_text,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include_list
|
||||
|
||||
@@ -22,7 +22,9 @@ def escape_lucene_query(query: str) -> str:
|
||||
s = s.replace("\r", " ").replace("\n", " ").strip()
|
||||
|
||||
# Lucene reserved tokens/special characters
|
||||
specials = ['&&', '||', '\\', '+', '-', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', '*', '?', ':']
|
||||
# NOTE: '/' is the regex delimiter in Lucene — must be escaped to prevent
|
||||
# TokenMgrError when the query contains unmatched slashes.
|
||||
specials = ['&&', '||', '\\', '+', '-', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', '*', '?', ':', '/']
|
||||
# Replace longer tokens first to avoid partial double-escaping
|
||||
for token in sorted(specials, key=len, reverse=True):
|
||||
s = s.replace(token, f"\\{token}")
|
||||
|
||||
@@ -43,8 +43,9 @@ Each statement must be labeled as per the criteria mentioned below.
|
||||
|
||||
对话上下文和共指消解:
|
||||
- 将每个陈述句归属于说出它的参与者。
|
||||
- 如果参与者列表为说话者提供了名称(例如,"李雪(用户)"),请在提取的陈述句中使用具体名称("李雪"),而不是通用角色("用户")。
|
||||
- 将所有代词解析为对话上下文中的具体人物或实体。
|
||||
- **对于用户的发言:必须使用"用户"作为主语**,禁止将"用户"或"我"替换为用户的真实姓名或别名。例如,用户说"我叫张三"应提取为"用户叫张三",而不是"张三叫张三"。
|
||||
- 对于 AI 助手的发言:使用"助手"或"AI助手"作为主语。
|
||||
- 将所有代词解析为对话上下文中的具体人物或实体,但"我"必须解析为"用户"。
|
||||
- 识别并将抽象引用解析为其具体名称(如果提到)。
|
||||
- 将缩写和首字母缩略词扩展为其完整形式。
|
||||
{% else %}
|
||||
@@ -68,8 +69,9 @@ Context Resolution Requirements:
|
||||
|
||||
Conversational Context & Co-reference Resolution:
|
||||
- Attribute every statement to the participant who uttered it.
|
||||
- If the participant list provides a name for a speaker (e.g., "李雪 (用户)"), use the specific name ("李雪") in the extracted statement, not the generic role ("用户").
|
||||
- Resolve all pronouns to the specific person or entity from the conversation's context.
|
||||
- **For user's statements: always use "用户" (User) as the subject**. Do NOT replace "用户" or "I" with the user's real name or alias. For example, if the user says "I'm John", extract as "用户 is John", not "John is John".
|
||||
- For AI assistant's statements: use "助手" or "AI助手" as the subject.
|
||||
- Resolve all pronouns to the specific person or entity from the conversation's context, but "I"/"我" must always resolve to "用户".
|
||||
- Identify and resolve abstract references to their specific names if mentioned.
|
||||
- Expand abbreviations and acronyms to their full form.
|
||||
{% endif %}
|
||||
@@ -139,13 +141,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料与阿拉伯树胶等粘合
|
||||
示例输出: {
|
||||
"statements": [
|
||||
{
|
||||
"statement": "Sarah Chen 最近一直在尝试水彩画。",
|
||||
"statement": "用户最近一直在尝试水彩画。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen 画了一些花朵。",
|
||||
"statement": "用户画了一些花朵。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
@@ -157,13 +159,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料与阿拉伯树胶等粘合
|
||||
"relevance": "IRRELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen 认为她的水彩画中的色彩组合可以改进。",
|
||||
"statement": "用户认为她的水彩画中的色彩组合可以改进。",
|
||||
"statement_type": "OPINION",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen 真的很喜欢玫瑰和百合。",
|
||||
"statement": "用户真的很喜欢玫瑰和百合。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
@@ -186,13 +188,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料和阿拉伯树胶等粘合
|
||||
示例输出: {
|
||||
"statements": [
|
||||
{
|
||||
"statement": "张曼婷最近在尝试水彩画。",
|
||||
"statement": "用户最近在尝试水彩画。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷画了一些花朵。",
|
||||
"statement": "用户画了一些花朵。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
@@ -204,13 +206,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料和阿拉伯树胶等粘合
|
||||
"relevance": "IRRELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷觉得水彩画的色彩搭配还有提升的空间。",
|
||||
"statement": "用户觉得水彩画的色彩搭配还有提升的空间。",
|
||||
"statement_type": "OPINION",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷很喜欢玫瑰和百合。",
|
||||
"statement": "用户很喜欢玫瑰和百合。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
@@ -233,13 +235,13 @@ User: "I think the color combinations could use some improvement, but I really l
|
||||
Example Output: {
|
||||
"statements": [
|
||||
{
|
||||
"statement": "Sarah Chen has been trying watercolor painting recently.",
|
||||
"statement": "用户 has been trying watercolor painting recently.",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen painted some flowers.",
|
||||
"statement": "用户 painted some flowers.",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
@@ -251,13 +253,13 @@ Example Output: {
|
||||
"relevance": "IRRELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen thinks the color combinations in her watercolor paintings could use some improvement.",
|
||||
"statement": "用户 thinks the color combinations in her watercolor paintings could use some improvement.",
|
||||
"statement_type": "OPINION",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen really likes roses and lilies.",
|
||||
"statement": "用户 really likes roses and lilies.",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
@@ -280,13 +282,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料和阿拉伯树胶等粘合
|
||||
Example Output: {
|
||||
"statements": [
|
||||
{
|
||||
"statement": "张曼婷最近在尝试水彩画。",
|
||||
"statement": "用户最近在尝试水彩画。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷画了一些花朵。",
|
||||
"statement": "用户画了一些花朵。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
@@ -298,13 +300,13 @@ Example Output: {
|
||||
"relevance": "IRRELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷觉得水彩画的色彩搭配还有提升的空间。",
|
||||
"statement": "用户觉得水彩画的色彩搭配还有提升的空间。",
|
||||
"statement_type": "OPINION",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷很喜欢玫瑰和百合。",
|
||||
"statement": "用户很喜欢玫瑰和百合。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
|
||||
@@ -406,4 +406,12 @@ Output:
|
||||
- **⚠️ ALIASES ORDER: preserve temporal order of appearance**
|
||||
- **🚨 MANDATORY FIELD: EVERY entity MUST include "aliases" field, even if empty array []**
|
||||
|
||||
**Output JSON structure:**
|
||||
```json
|
||||
{
|
||||
"triplets": [...],
|
||||
"entities": [...]
|
||||
}
|
||||
```
|
||||
|
||||
{{ json_schema }}
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
===Task===
|
||||
Extract user metadata changes from the following conversation statements spoken by the user.
|
||||
|
||||
{% if language == "zh" %}
|
||||
**"三度原则"判断标准:**
|
||||
- 复用度:该信息是否会被多个功能模块使用?
|
||||
- 约束度:该信息是否会影响系统行为?
|
||||
- 时效性:该信息是长期稳定的还是临时的?仅提取长期稳定信息。
|
||||
|
||||
**提取规则:**
|
||||
- **只提取关于"用户本人"的画像信息**,忽略用户提到的第三方人物(如朋友、同事、家人)的信息
|
||||
- 仅提取文本中明确提到的信息,不要推测
|
||||
- **输出语言必须与输入文本的语言一致**(输入中文则输出中文值,输入英文则输出英文值)
|
||||
|
||||
**增量模式(重要):**
|
||||
你只需要输出**本次对话引起的变更操作**,不要输出完整的元数据。每个变更是一个对象,包含:
|
||||
- `field_path`:字段路径,用点号分隔(如 `profile.role`、`profile.expertise`)
|
||||
- `action`:操作类型
|
||||
* `set`:新增或修改一个字段的值
|
||||
* `remove`:移除一个字段的值
|
||||
- `value`:字段的新值(`action="set"` 时必填,`action="remove"` 时填要移除的元素值)
|
||||
* 所有字段均为列表类型,每个元素一条变更记录
|
||||
|
||||
**判断规则:**
|
||||
- 用户提到新信息 → `action="set"`,填入新值
|
||||
- 用户明确否定已有信息(如"我不再做老师了"、"我已经不学Python了")→ `action="remove"`,`value` 填要移除的元素值
|
||||
- 如果本次对话没有任何可提取的变更,返回空的 `metadata_changes` 数组 `[]`
|
||||
- **不要为未被提及的字段生成任何变更操作**
|
||||
|
||||
{% if existing_metadata %}
|
||||
**已有元数据(仅供参考,用于判断是否需要变更):**
|
||||
请对比已有数据和用户最新发言,只输出差异部分的变更操作。
|
||||
- 如果用户说的信息和已有数据一致,不需要输出变更
|
||||
- 如果用户否定了已有数据中的某个值,输出 `remove` 操作
|
||||
- 如果用户提到了新信息,输出 `set` 操作
|
||||
{% endif %}
|
||||
|
||||
**字段说明:**
|
||||
- profile.role:用户的职业或角色(列表),如 教师、医生、后端工程师,一个人可以有多个角色
|
||||
- profile.domain:用户所在领域(列表),如 教育、医疗、软件开发,一个人可以涉及多个领域
|
||||
- profile.expertise:用户擅长的技能或工具(列表),如 Python、心理咨询、高中物理
|
||||
- profile.interests:用户主动表达兴趣的话题或领域标签(列表)
|
||||
|
||||
**用户别名变更(增量模式):**
|
||||
- **aliases_to_add**:本次新发现的用户别名,包括:
|
||||
* 用户主动自我介绍:如"我叫张三"、"我的名字是XX"、"我的网名是XX"
|
||||
* 他人对用户的称呼:如"同事叫我陈哥"、"大家叫我小张"、"领导叫我老陈"
|
||||
* 只提取原文中逐字出现的名字,严禁推测或创造
|
||||
* 禁止提取:用户给 AI 取的名字、第三方人物自身的名字、"用户"/"我" 等占位词
|
||||
* 如果没有新别名,返回空数组 `[]`
|
||||
- **aliases_to_remove**:用户明确否认的别名,包括:
|
||||
* 用户说"我不叫XX了"、"别叫我XX"、"我改名了,不叫XX" → 将 XX 放入此数组
|
||||
* **严格限制**:只将用户原文中**逐字提到**的被否认名字放入,不要推断关联的其他别名
|
||||
* 如果没有要移除的别名,返回空数组 `[]`
|
||||
{% if existing_aliases %}
|
||||
- 已有别名:{{ existing_aliases | tojson }}(仅供参考,不需要在输出中重复)
|
||||
{% endif %}
|
||||
{% else %}
|
||||
**"Three-Degree Principle" criteria:**
|
||||
- Reusability: Will this information be used by multiple functional modules?
|
||||
- Constraint: Will this information affect system behavior?
|
||||
- Timeliness: Is this information long-term stable or temporary? Only extract long-term stable information.
|
||||
|
||||
**Extraction rules:**
|
||||
- **Only extract profile information about the user themselves**, ignore information about third parties (friends, colleagues, family) mentioned by the user
|
||||
- Only extract information explicitly mentioned in the text, do not speculate
|
||||
- **Output language must match the input text language**
|
||||
|
||||
**Incremental mode (important):**
|
||||
You should only output **the change operations caused by this conversation**, not the complete metadata. Each change is an object containing:
|
||||
- `field_path`: Field path separated by dots (e.g. `profile.role`, `profile.expertise`)
|
||||
- `action`: Operation type
|
||||
* `set`: Add or update a field value
|
||||
* `remove`: Remove a field value
|
||||
- `value`: The new value for the field (required when `action="set"`, for `action="remove"` fill in the element value to remove)
|
||||
* All fields are list types, one change record per element
|
||||
|
||||
**Decision rules:**
|
||||
- User mentions new information → `action="set"`, fill in the new value
|
||||
- User explicitly negates existing info (e.g. "I'm no longer a teacher", "I stopped learning Python") → `action="remove"`, `value` is the element to remove
|
||||
- If this conversation has no extractable changes, return an empty `metadata_changes` array `[]`
|
||||
- **Do NOT generate any change operations for fields not mentioned in the conversation**
|
||||
|
||||
{% if existing_metadata %}
|
||||
**Existing metadata (for reference only, to determine if changes are needed):**
|
||||
Compare existing data with the user's latest statements, and only output change operations for the differences.
|
||||
- If the user's statement matches existing data, no change is needed
|
||||
- If the user negates a value in existing data, output a `remove` operation
|
||||
- If the user mentions new information, output a `set` operation
|
||||
{% endif %}
|
||||
|
||||
**Field descriptions:**
|
||||
- profile.role: User's occupation or role (list), e.g. teacher, doctor, software engineer. A person can have multiple roles
|
||||
- profile.domain: User's domain (list), e.g. education, healthcare, software development. A person can span multiple domains
|
||||
- profile.expertise: User's skills or tools (list), e.g. Python, counseling, physics
|
||||
- profile.interests: Topics or domain tags the user actively expressed interest in (list)
|
||||
|
||||
**User alias changes (incremental mode):**
|
||||
- **aliases_to_add**: Newly discovered user aliases from this conversation, including:
|
||||
* User self-introductions: e.g. "I'm John", "My name is XX", "My username is XX"
|
||||
* How others address the user: e.g. "My colleagues call me Johnny", "People call me Mike"
|
||||
* Only extract names that appear VERBATIM in the text — never infer or fabricate
|
||||
* Do NOT extract: names the user gives to the AI, third-party people's own names, placeholder words like "User"/"I"
|
||||
* If no new aliases, return empty array `[]`
|
||||
- **aliases_to_remove**: Aliases the user explicitly denies, including:
|
||||
* User says "Don't call me XX anymore", "I'm not called XX", "I changed my name from XX" → put XX in this array
|
||||
* **Strict rule**: Only include the exact name the user **verbatim mentions** as denied. Do NOT infer or remove related aliases
|
||||
* If no aliases to remove, return empty array `[]`
|
||||
{% if existing_aliases %}
|
||||
- Existing aliases: {{ existing_aliases | tojson }} (for reference only, do not repeat in output)
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
===User Statements===
|
||||
{% for stmt in statements %}
|
||||
- {{ stmt }}
|
||||
{% endfor %}
|
||||
|
||||
{% if existing_metadata %}
|
||||
===Existing User Metadata===
|
||||
```json
|
||||
{{ existing_metadata | tojson }}
|
||||
```
|
||||
{% endif %}
|
||||
|
||||
===Output Format===
|
||||
Return a JSON object with the following structure:
|
||||
```json
|
||||
{
|
||||
"metadata_changes": [
|
||||
{"field_path": "profile.role", "action": "set", "value": "后端工程师"},
|
||||
{"field_path": "profile.expertise", "action": "set", "value": "Python"},
|
||||
{"field_path": "profile.expertise", "action": "remove", "value": "Java"}
|
||||
],
|
||||
"aliases_to_add": [],
|
||||
"aliases_to_remove": []
|
||||
}
|
||||
```
|
||||
|
||||
{{ json_schema }}
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Optional, TypeVar
|
||||
from typing import Any, Dict, List, Optional, TypeVar
|
||||
|
||||
from langchain_aws import ChatBedrock
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
@@ -9,12 +9,12 @@ from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_ollama import OllamaLLM
|
||||
from langchain_openai import ChatOpenAI, OpenAI
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
from app.core.models.volcano_chat import VolcanoChatOpenAI
|
||||
from app.core.models.compatible_chat import CompatibleChatOpenAI
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -25,10 +25,11 @@ class RedBearModelConfig(BaseModel):
|
||||
provider: str
|
||||
api_key: str
|
||||
base_url: Optional[str] = None
|
||||
capability: List[str] = Field(default_factory=list) # 模型能力列表,驱动所有能力开关
|
||||
is_omni: bool = False # 是否为 Omni 模型
|
||||
deep_thinking: bool = False # 是否启用深度思考模式
|
||||
thinking_budget_tokens: Optional[int] = None # 深度思考 token 预算
|
||||
support_thinking: bool = False # 模型是否支持 enable_thinking 参数(capability 含 thinking)
|
||||
json_output: bool = False # 是否强制 JSON 输出
|
||||
# 请求超时时间(秒)- 默认120秒以支持复杂的LLM调用,可通过环境变量 LLM_TIMEOUT 配置
|
||||
timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0")))
|
||||
# 最大重试次数 - 默认2次以避免过长等待,可通过环境变量 LLM_MAX_RETRIES 配置
|
||||
@@ -36,6 +37,23 @@ class RedBearModelConfig(BaseModel):
|
||||
concurrency: int = 5 # 并发限流
|
||||
extra_params: Dict[str, Any] = {}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _resolve_capabilities(self) -> "RedBearModelConfig":
|
||||
from app.core.logging_config import get_business_logger
|
||||
logger = get_business_logger()
|
||||
if self.deep_thinking and "thinking" not in self.capability:
|
||||
logger.warning(
|
||||
f"模型 {self.model_name} 不支持深度思考(capability 中无 'thinking'),已自动关闭 deep_thinking"
|
||||
)
|
||||
self.deep_thinking = False
|
||||
self.thinking_budget_tokens = None
|
||||
if self.json_output and "json_output" not in self.capability:
|
||||
logger.warning(
|
||||
f"模型 {self.model_name} 不支持 JSON 输出(capability 中无 'json_output'),已自动关闭 json_output"
|
||||
)
|
||||
self.json_output = False
|
||||
return self
|
||||
|
||||
|
||||
class RedBearModelFactory:
|
||||
"""模型工厂类"""
|
||||
@@ -74,18 +92,19 @@ class RedBearModelFactory:
|
||||
is_streaming = bool(config.extra_params.get("streaming"))
|
||||
if is_streaming:
|
||||
params["stream_usage"] = True
|
||||
# 只有支持 thinking 的模型才传 enable_thinking
|
||||
if config.support_thinking:
|
||||
model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {})
|
||||
if is_streaming:
|
||||
model_kwargs["enable_thinking"] = config.deep_thinking
|
||||
if config.deep_thinking:
|
||||
model_kwargs["incremental_output"] = True
|
||||
if config.thinking_budget_tokens:
|
||||
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
|
||||
else:
|
||||
model_kwargs["enable_thinking"] = False
|
||||
params["model_kwargs"] = model_kwargs
|
||||
# 支持 thinking 的模型始终传 enable_thinking,关闭时显式传 False 避免模型默认开启思考
|
||||
if "thinking" in config.capability:
|
||||
extra_body = params.setdefault("extra_body", {})
|
||||
if config.deep_thinking:
|
||||
extra_body["enable_thinking"] = False
|
||||
if is_streaming:
|
||||
extra_body["enable_thinking"] = True
|
||||
if config.thinking_budget_tokens:
|
||||
extra_body["thinking_budget"] = config.thinking_budget_tokens
|
||||
# JSON 输出模式
|
||||
if config.json_output:
|
||||
model_kwargs = params.setdefault("model_kwargs", {})
|
||||
model_kwargs["response_format"] = {"type": "json_object"}
|
||||
return params
|
||||
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
|
||||
@@ -108,26 +127,31 @@ class RedBearModelFactory:
|
||||
**config.extra_params
|
||||
}
|
||||
# 流式模式下启用 stream_usage 以获取 token 统计
|
||||
if config.extra_params.get("streaming"):
|
||||
params["stream_usage"] = True
|
||||
# 深度思考模式
|
||||
is_streaming = bool(config.extra_params.get("streaming"))
|
||||
if is_streaming and not config.is_omni:
|
||||
if is_streaming:
|
||||
params["stream_usage"] = True
|
||||
# 支持 thinking 的模型始终传 enable_thinking,关闭时显式传 False 避免模型默认开启思考
|
||||
if "thinking" in config.capability:
|
||||
# VOLCANO 深度思考仅流式支持
|
||||
if provider == ModelProvider.VOLCANO:
|
||||
# 火山引擎深度思考仅流式调用支持,非流式时不传 thinking 参数
|
||||
thinking_config: Dict[str, Any] = {
|
||||
"type": "enabled" if config.deep_thinking else "disabled"
|
||||
}
|
||||
thinking_config: Dict[str, Any] = {"type": "enabled" if config.deep_thinking else "disabled"}
|
||||
if config.deep_thinking and config.thinking_budget_tokens:
|
||||
thinking_config["budget_tokens"] = config.thinking_budget_tokens
|
||||
params["extra_body"] = {"thinking": thinking_config}
|
||||
else:
|
||||
# 始终显式传递 enable_thinking,不支持该参数的模型(如 DeepSeek-R1)会直接忽略
|
||||
model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {})
|
||||
model_kwargs["enable_thinking"] = config.deep_thinking
|
||||
if config.deep_thinking and config.thinking_budget_tokens:
|
||||
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
|
||||
params["model_kwargs"] = model_kwargs
|
||||
extra_body = params.setdefault("extra_body", {})
|
||||
if config.deep_thinking:
|
||||
extra_body["enable_thinking"] = False
|
||||
if is_streaming:
|
||||
extra_body["enable_thinking"] = True
|
||||
if config.thinking_budget_tokens:
|
||||
extra_body["thinking_budget"] = config.thinking_budget_tokens
|
||||
# JSON 输出模式
|
||||
if config.json_output:
|
||||
model_kwargs = params.setdefault("model_kwargs", {})
|
||||
# VOLCANO 模型不支持 response_format,JSON 输出由 system prompt 注入实现
|
||||
if provider != ModelProvider.VOLCANO:
|
||||
model_kwargs["response_format"] = {"type": "json_object"}
|
||||
return params
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
params = {
|
||||
@@ -136,19 +160,20 @@ class RedBearModelFactory:
|
||||
"max_retries": config.max_retries,
|
||||
**config.extra_params
|
||||
}
|
||||
# 只有支持 thinking 的模型才传 enable_thinking
|
||||
if config.support_thinking:
|
||||
# 支持 thinking 的模型始终传 enable_thinking,关闭时显式传 False 避免模型默认开启思考
|
||||
if "thinking" in config.capability:
|
||||
is_streaming = bool(config.extra_params.get("streaming"))
|
||||
model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {})
|
||||
if is_streaming:
|
||||
model_kwargs["enable_thinking"] = config.deep_thinking
|
||||
if config.deep_thinking:
|
||||
model_kwargs["incremental_output"] = True
|
||||
if config.thinking_budget_tokens:
|
||||
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
|
||||
else:
|
||||
model_kwargs = params.setdefault("model_kwargs", {})
|
||||
if config.deep_thinking:
|
||||
model_kwargs["enable_thinking"] = False
|
||||
params["model_kwargs"] = model_kwargs
|
||||
if is_streaming:
|
||||
model_kwargs["enable_thinking"] = True
|
||||
model_kwargs["incremental_output"] = True
|
||||
if config.thinking_budget_tokens:
|
||||
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
|
||||
if config.json_output:
|
||||
model_kwargs = params.setdefault("model_kwargs", {})
|
||||
model_kwargs["response_format"] = {"type": "json_object"}
|
||||
return params
|
||||
elif provider == ModelProvider.BEDROCK:
|
||||
# Bedrock 使用 AWS 凭证
|
||||
@@ -195,6 +220,10 @@ class RedBearModelFactory:
|
||||
params["additional_model_request_fields"] = {
|
||||
"thinking": {"type": "enabled", "budget_tokens": budget}
|
||||
}
|
||||
# JSON 输出模式
|
||||
if config.json_output:
|
||||
model_kwargs = params.setdefault("model_kwargs", {})
|
||||
model_kwargs["response_format"] = {"type": "json_object"}
|
||||
return params
|
||||
else:
|
||||
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
@@ -206,10 +235,15 @@ class RedBearModelFactory:
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
return {
|
||||
"model": config.model_name,
|
||||
# "base_url": config.base_url,
|
||||
"jina_api_key": config.api_key,
|
||||
**config.extra_params
|
||||
}
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
return {
|
||||
"model": config.model_name,
|
||||
"dashscope_api_key": config.api_key,
|
||||
**config.extra_params
|
||||
}
|
||||
else:
|
||||
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
@@ -218,18 +252,19 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
|
||||
"""根据模型提供商获取对应的模型类"""
|
||||
provider = config.provider.lower()
|
||||
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||
# dashscope的omni模型 和 volcano模型使用
|
||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||
return ChatOpenAI
|
||||
return CompatibleChatOpenAI
|
||||
if provider == ModelProvider.VOLCANO:
|
||||
return VolcanoChatOpenAI
|
||||
return CompatibleChatOpenAI
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
if type == ModelType.LLM:
|
||||
return OpenAI
|
||||
elif type == ModelType.CHAT:
|
||||
return ChatOpenAI
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
return CompatibleChatOpenAI
|
||||
# if type == ModelType.LLM:
|
||||
# return OpenAI
|
||||
# elif type == ModelType.CHAT:
|
||||
# return CompatibleChatOpenAI
|
||||
# else:
|
||||
# raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
return ChatTongyi
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
@@ -265,6 +300,9 @@ def get_provider_rerank_class(provider: str):
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
from langchain_community.document_compressors import JinaRerank
|
||||
return JinaRerank
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
from langchain_community.document_compressors.dashscope_rerank import DashScopeRerank
|
||||
return DashScopeRerank
|
||||
# elif provider == ModelProvider.OLLAMA:
|
||||
# from langchain_ollama import OllamaEmbeddings
|
||||
# return OllamaEmbeddings
|
||||
|
||||
@@ -8,12 +8,33 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
|
||||
class VolcanoChatOpenAI(ChatOpenAI):
|
||||
"""火山引擎 Chat 模型,支持深度思考内容(reasoning_content)的流式和非流式透传。"""
|
||||
class CompatibleChatOpenAI(ChatOpenAI):
|
||||
"""火山和千问的omni兼容模型,支持深度思考内容(reasoning_content)的流式和非流式透传。
|
||||
|
||||
同时修复 json_output + tools 同时使用时 langchain_openai 强制走 .parse()/.stream()
|
||||
导致 strict 校验报错的问题:有工具时从 payload 中移除 response_format,
|
||||
让父类走普通 .create()/.astream() 路径,JSON 输出由 system prompt 指令保证。
|
||||
"""
|
||||
|
||||
def _get_request_payload(
|
||||
self,
|
||||
input_: list[BaseMessage],
|
||||
*,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
||||
# 有工具时 langchain_openai 检测到 response_format 会切换到 .parse()/.stream()
|
||||
# 接口,OpenAI SDK 要求此时所有工具必须 strict=True,动态生成的工具不满足。
|
||||
# 移除 response_format,让父类走普通路径,JSON 输出由 system prompt 指令保证。
|
||||
if payload.get("tools") and "response_format" in payload:
|
||||
payload.pop("response_format")
|
||||
return payload
|
||||
|
||||
def _create_chat_result(self, response: Union[dict, Any], generation_info: Optional[dict] = None) -> ChatResult:
|
||||
result = super()._create_chat_result(response, generation_info)
|
||||
@@ -36,9 +36,7 @@ class RedBearEmbeddings(Embeddings):
|
||||
"base_url": config.base_url,
|
||||
"api_key": config.api_key,
|
||||
"timeout": httpx.Timeout(timeout=config.timeout, connect=60.0),
|
||||
"max_retries": config.max_retries,
|
||||
"check_embedding_ctx_length": False,
|
||||
"encoding_format": "float"
|
||||
"max_retries": config.max_retries
|
||||
}
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
params = {
|
||||
|
||||
@@ -76,5 +76,9 @@ class RedBearRerank(BaseDocumentCompressor):
|
||||
from langchain_community.document_compressors import JinaRerank
|
||||
model_instance: JinaRerank = self._model
|
||||
return model_instance.rerank(documents=documents, query=query, top_n=top_n)
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
from langchain_community.document_compressors.dashscope_rerank import DashScopeRerank
|
||||
model_instance: DashScopeRerank = self._model
|
||||
return model_instance.rerank(documents=documents, query=query, top_n=top_n)
|
||||
else:
|
||||
raise ValueError(f"不支持的模型提供商: {provider}")
|
||||
|
||||
@@ -6,7 +6,8 @@ models:
|
||||
description: AI21 Labs大语言模型,completion生成模式,256000上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -20,6 +21,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -38,6 +40,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -54,7 +57,8 @@ models:
|
||||
description: Cohere大语言模型,支持智能体思考、工具调用、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -72,6 +76,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -87,7 +92,8 @@ models:
|
||||
description: Meta Llama大语言模型,支持智能体思考、工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -101,7 +107,8 @@ models:
|
||||
description: Mistral AI大语言模型,支持智能体思考、工具调用,32000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -115,7 +122,8 @@ models:
|
||||
description: OpenAI大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -130,7 +138,8 @@ models:
|
||||
description: Qwen大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
|
||||
@@ -8,6 +8,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -22,6 +23,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -36,6 +38,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -48,7 +51,8 @@ models:
|
||||
description: DeepSeek-V3.1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -61,7 +65,8 @@ models:
|
||||
description: DeepSeek-V3.2-exp实验版大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -74,7 +79,8 @@ models:
|
||||
description: DeepSeek-V3.2大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -87,7 +93,8 @@ models:
|
||||
description: DeepSeek-V3大语言模型,支持智能体思考,64000上下文窗口,对话模式,支持文本与JSON格式输出
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -100,7 +107,8 @@ models:
|
||||
description: farui-plus大语言模型,支持多工具调用、智能体思考、流式工具调用,12288上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -115,7 +123,8 @@ models:
|
||||
description: GLM-4.7大语言模型,支持多工具调用、智能体思考、流式工具调用,202752超大上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -133,6 +142,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -150,6 +160,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -180,6 +191,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -210,7 +222,7 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -376,6 +388,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -448,6 +461,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -466,6 +480,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -481,7 +496,8 @@ models:
|
||||
description: qwen2.5-0.5b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -498,6 +514,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -513,7 +530,7 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -530,6 +547,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -546,6 +564,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -561,7 +580,7 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -578,6 +597,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -594,6 +614,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -610,6 +631,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -626,6 +648,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -641,7 +664,7 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -656,7 +679,7 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -672,6 +695,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -687,6 +711,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -702,6 +727,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -719,6 +745,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -736,6 +763,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -752,6 +780,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -768,7 +797,7 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -785,6 +814,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -803,6 +833,8 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- audio
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -822,7 +854,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -844,6 +876,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -864,7 +897,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -886,6 +919,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -907,6 +941,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -928,6 +963,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -947,6 +983,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -964,6 +1001,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -979,6 +1017,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -994,6 +1033,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
|
||||
@@ -10,6 +10,7 @@ models:
|
||||
- vision
|
||||
- audio
|
||||
- video
|
||||
- json_output
|
||||
is_omni: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -27,7 +28,8 @@ models:
|
||||
description: gpt-3.5-turbo-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -42,7 +44,8 @@ models:
|
||||
description: gpt-3.5-turbo-1106大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -57,7 +60,8 @@ models:
|
||||
description: gpt-3.5-turbo-16k大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -84,7 +88,8 @@ models:
|
||||
description: gpt-3.5-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -99,7 +104,8 @@ models:
|
||||
description: gpt-4-0125-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -114,7 +120,8 @@ models:
|
||||
description: gpt-4-1106-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -131,6 +138,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -146,7 +154,8 @@ models:
|
||||
description: gpt-4-turbo-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -163,6 +172,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -194,6 +204,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -213,6 +224,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -231,6 +243,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -248,6 +261,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -266,6 +280,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -284,6 +299,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -302,6 +318,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -321,6 +338,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -340,6 +358,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
|
||||
@@ -11,6 +11,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -26,6 +27,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -41,6 +43,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -56,6 +59,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -72,6 +76,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -87,6 +92,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -102,6 +108,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -117,6 +124,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -132,6 +140,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -148,6 +157,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -175,7 +185,8 @@ models:
|
||||
description: 全新一代主力模型,性能全面升级,在知识、代码、推理等方面表现卓越。最大支持 128k 上下文窗口,输出长度支持最大 12k tokens。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -187,7 +198,8 @@ models:
|
||||
description: 全新一代轻量版模型,极致响应速度,效果与时延均达到全球一流水平。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
|
||||
791
api/app/core/quota_manager.py
Normal file
791
api/app/core/quota_manager.py
Normal file
@@ -0,0 +1,791 @@
|
||||
"""
|
||||
统一配额管理器 - 社区版和 SaaS 版共用
|
||||
|
||||
配额来源策略:
|
||||
1. 优先从 premium 模块的 tenant_subscriptions 表读取(SaaS 版)
|
||||
2. 降级到 default_free_plan.py 配置文件(社区版兜底)
|
||||
"""
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
from typing import Optional, Callable, Dict, Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_auth_logger
|
||||
from app.i18n.exceptions import QuotaExceededError, InternalServerError
|
||||
|
||||
logger = get_auth_logger()
|
||||
|
||||
# Redis key 格式常量,与 RateLimiterService.check_qps 保持一致(per api_key 独立计数)
|
||||
API_KEY_QPS_REDIS_KEY = "rate_limit:qps:{api_key_id}"
|
||||
|
||||
|
||||
def _get_user_from_kwargs(kwargs: dict):
|
||||
"""从 kwargs 中获取 user 对象"""
|
||||
for key in ["user", "current_user"]:
|
||||
if key in kwargs:
|
||||
return kwargs[key]
|
||||
return None
|
||||
|
||||
|
||||
def _get_workspace_id_from_kwargs(kwargs: dict):
|
||||
"""从 kwargs 中获取 workspace_id"""
|
||||
# 优先从 kwargs['workspace_id'] 获取
|
||||
workspace_id = kwargs.get("workspace_id")
|
||||
if workspace_id:
|
||||
return workspace_id
|
||||
|
||||
# 从 api_key_auth.workspace_id 获取(API Key 认证场景)
|
||||
api_key_auth = kwargs.get("api_key_auth")
|
||||
if api_key_auth and hasattr(api_key_auth, 'workspace_id'):
|
||||
return api_key_auth.workspace_id
|
||||
|
||||
# 从 user.current_workspace_id 获取
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if user:
|
||||
ws_id = getattr(user, 'current_workspace_id', None)
|
||||
if ws_id:
|
||||
return ws_id
|
||||
|
||||
logger.warning(f"无法获取 workspace_id, kwargs keys: {list(kwargs.keys())}")
|
||||
return None
|
||||
|
||||
|
||||
def _get_tenant_id_from_kwargs(db: Session, kwargs: dict):
|
||||
"""从 kwargs 中获取 tenant_id"""
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if user and hasattr(user, 'tenant_id'):
|
||||
return user.tenant_id
|
||||
|
||||
workspace_id = kwargs.get("workspace_id")
|
||||
if workspace_id:
|
||||
from app.models.workspace_model import Workspace
|
||||
workspace = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||
if workspace:
|
||||
return workspace.tenant_id
|
||||
|
||||
api_key_auth = kwargs.get("api_key_auth")
|
||||
if api_key_auth and hasattr(api_key_auth, 'workspace_id'):
|
||||
from app.models.workspace_model import Workspace
|
||||
workspace = db.query(Workspace).filter(Workspace.id == api_key_auth.workspace_id).first()
|
||||
if workspace:
|
||||
return workspace.tenant_id
|
||||
|
||||
data = kwargs.get("data") or kwargs.get("body") or kwargs.get("payload")
|
||||
if data and hasattr(data, "workspace_id"):
|
||||
from app.models.workspace_model import Workspace
|
||||
workspace = db.query(Workspace).filter(Workspace.id == data.workspace_id).first()
|
||||
if workspace:
|
||||
return workspace.tenant_id
|
||||
|
||||
share_data = kwargs.get("share_data")
|
||||
if share_data and hasattr(share_data, 'share_token'):
|
||||
from app.models.workspace_model import Workspace
|
||||
from app.models.app_model import App
|
||||
share_token = share_data.share_token
|
||||
from app.models.release_share_model import ReleaseShare
|
||||
share_record = db.query(ReleaseShare).filter(ReleaseShare.share_token == share_token).first()
|
||||
if share_record:
|
||||
app = db.query(App).filter(App.id == share_record.app_id, App.is_active.is_(True)).first()
|
||||
if app:
|
||||
workspace = db.query(Workspace).filter(Workspace.id == app.workspace_id).first()
|
||||
if workspace:
|
||||
return workspace.tenant_id
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_quota_config(db: Session, tenant_id: UUID) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取租户的配额配置
|
||||
|
||||
优先级:
|
||||
1. premium 模块的 tenant_subscriptions(SaaS 版)
|
||||
2. default_free_plan.py 配置文件(社区版兜底)
|
||||
"""
|
||||
# 尝试从 premium 模块获取(SaaS 版)
|
||||
try:
|
||||
from premium.platform_admin.package_plan_service import TenantSubscriptionService
|
||||
# premium 模块存在,运行时错误不应被静默降级,直接抛出
|
||||
quota_config = TenantSubscriptionService(db).get_effective_quota(tenant_id)
|
||||
if quota_config:
|
||||
logger.debug(f"从 premium 模块获取租户 {tenant_id} 配额配置")
|
||||
return quota_config
|
||||
# premium 存在但该租户无订阅记录,降级到免费套餐
|
||||
logger.debug(f"租户 {tenant_id} 无 premium 订阅,降级到免费套餐")
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
# 社区版:premium 包不存在,正常降级
|
||||
logger.debug("premium 模块不存在,使用社区版免费套餐配额")
|
||||
|
||||
# 降级到社区版配置文件
|
||||
try:
|
||||
from app.config.default_free_plan import DEFAULT_FREE_PLAN
|
||||
logger.debug(f"使用社区版免费套餐配额: tenant={tenant_id}")
|
||||
return DEFAULT_FREE_PLAN.get("quotas")
|
||||
except Exception as e:
|
||||
logger.error(f"无法从配置文件获取配额: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_api_ops_rate_limit(db: Session, tenant_id: UUID) -> Optional[int]:
|
||||
"""
|
||||
获取租户套餐的 API 操作速率限制(QPS 上限)
|
||||
|
||||
该函数兼容社区版和 SaaS 版:
|
||||
- SaaS 版:从 premium 模块的套餐配额读取
|
||||
- 社区版:从 default_free_plan.py 配置文件读取
|
||||
|
||||
Returns:
|
||||
int: api_ops_rate_limit 值,如果未配置则返回 None
|
||||
"""
|
||||
quota_config = _get_quota_config(db, tenant_id)
|
||||
if quota_config:
|
||||
return quota_config.get("api_ops_rate_limit")
|
||||
return None
|
||||
|
||||
|
||||
class QuotaUsageRepository:
|
||||
"""配额使用量数据访问层"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def count_workspaces(self, tenant_id: UUID) -> int:
|
||||
from app.models.workspace_model import Workspace
|
||||
return self.db.query(Workspace).filter(
|
||||
Workspace.tenant_id == tenant_id,
|
||||
Workspace.is_active.is_(True)
|
||||
).count()
|
||||
|
||||
def count_apps(self, tenant_id: UUID, workspace_id: Optional[UUID] = None) -> int:
|
||||
from app.models.app_model import App
|
||||
from app.models.workspace_model import Workspace
|
||||
query = self.db.query(App).join(
|
||||
Workspace, App.workspace_id == Workspace.id
|
||||
).filter(
|
||||
App.is_active.is_(True)
|
||||
)
|
||||
if workspace_id:
|
||||
query = query.filter(App.workspace_id == workspace_id)
|
||||
else:
|
||||
query = query.filter(Workspace.tenant_id == tenant_id)
|
||||
return query.count()
|
||||
|
||||
def count_skills(self, tenant_id: UUID) -> int:
|
||||
from app.models.skill_model import Skill
|
||||
return self.db.query(Skill).filter(
|
||||
Skill.tenant_id == tenant_id,
|
||||
Skill.is_active.is_(True)
|
||||
).count()
|
||||
|
||||
def sum_knowledge_capacity_gb(self, tenant_id: UUID, workspace_id: Optional[UUID] = None) -> float:
|
||||
from app.models.document_model import Document
|
||||
from app.models.knowledge_model import Knowledge
|
||||
from app.models.workspace_model import Workspace
|
||||
query = self.db.query(func.coalesce(func.sum(Document.file_size), 0)).join(
|
||||
Knowledge, Document.kb_id == Knowledge.id
|
||||
).join(
|
||||
Workspace, Knowledge.workspace_id == Workspace.id
|
||||
).filter(
|
||||
Document.status == 1,
|
||||
)
|
||||
if workspace_id:
|
||||
query = query.filter(Knowledge.workspace_id == workspace_id)
|
||||
else:
|
||||
query = query.filter(Workspace.tenant_id == tenant_id)
|
||||
result = query.scalar()
|
||||
return float(result) / (1024 ** 3) if result else 0.0
|
||||
|
||||
def count_memory_engines(self, tenant_id: UUID, workspace_id: Optional[UUID] = None) -> int:
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
from app.models.workspace_model import Workspace
|
||||
query = self.db.query(MemoryConfig).join(
|
||||
Workspace, MemoryConfig.workspace_id == Workspace.id
|
||||
)
|
||||
if workspace_id:
|
||||
query = query.filter(MemoryConfig.workspace_id == workspace_id)
|
||||
else:
|
||||
query = query.filter(Workspace.tenant_id == tenant_id)
|
||||
return query.count()
|
||||
|
||||
def count_end_users(self, tenant_id: UUID, workspace_id: Optional[UUID] = None) -> int:
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.workspace_model import Workspace
|
||||
from app.models.user_model import User
|
||||
query = self.db.query(EndUser).join(
|
||||
Workspace, EndUser.workspace_id == Workspace.id
|
||||
)
|
||||
if workspace_id:
|
||||
query = query.filter(EndUser.workspace_id == workspace_id)
|
||||
else:
|
||||
query = query.filter(Workspace.tenant_id == tenant_id)
|
||||
trial_user_ids = [
|
||||
str(u.id) for u in self.db.query(User.id).filter(User.tenant_id == tenant_id).all()
|
||||
]
|
||||
if trial_user_ids:
|
||||
query = query.filter(~EndUser.other_id.in_(trial_user_ids))
|
||||
return query.count()
|
||||
|
||||
def count_models(self, tenant_id: UUID) -> int:
|
||||
from app.models.models_model import ModelConfig
|
||||
return self.db.query(ModelConfig).filter(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_active == True,
|
||||
ModelConfig.is_composite == True
|
||||
).count()
|
||||
|
||||
def count_ontology_projects(self, tenant_id: UUID, workspace_id: Optional[UUID] = None) -> int:
|
||||
from app.models.ontology_scene import OntologyScene
|
||||
from app.models.workspace_model import Workspace
|
||||
if workspace_id:
|
||||
return self.db.query(OntologyScene).filter(
|
||||
OntologyScene.workspace_id == workspace_id
|
||||
).count()
|
||||
return self.db.query(OntologyScene).join(
|
||||
Workspace, OntologyScene.workspace_id == Workspace.id
|
||||
).filter(
|
||||
Workspace.tenant_id == tenant_id
|
||||
).count()
|
||||
|
||||
def get_usage_by_quota_type(self, tenant_id: UUID, quota_type: str, workspace_id: Optional[UUID] = None):
|
||||
"""按配额类型分发,返回当前使用量"""
|
||||
dispatch = {
|
||||
"workspace_quota": self.count_workspaces,
|
||||
"app_quota": self.count_apps,
|
||||
"skill_quota": self.count_skills,
|
||||
"knowledge_capacity_quota": self.sum_knowledge_capacity_gb,
|
||||
"memory_engine_quota": self.count_memory_engines,
|
||||
"end_user_quota": self.count_end_users,
|
||||
"model_quota": self.count_models,
|
||||
"ontology_project_quota": self.count_ontology_projects,
|
||||
}
|
||||
fn = dispatch.get(quota_type)
|
||||
if workspace_id:
|
||||
return fn(tenant_id, workspace_id) if fn else 0
|
||||
return fn(tenant_id) if fn else 0
|
||||
|
||||
|
||||
def _check_quota(
|
||||
db: Session,
|
||||
tenant_id: UUID,
|
||||
quota_type: str,
|
||||
resource_name: str,
|
||||
usage_func: Optional[Callable] = None,
|
||||
workspace_id: Optional[UUID] = None,
|
||||
) -> None:
|
||||
"""核心配额检查逻辑:对比使用量和配额限制"""
|
||||
try:
|
||||
quota_config = _get_quota_config(db, tenant_id)
|
||||
if not quota_config:
|
||||
logger.warning(f"租户 {tenant_id} 无有效配额配置,跳过配额检查")
|
||||
return
|
||||
|
||||
quota_limit = quota_config.get(quota_type)
|
||||
if quota_limit is None:
|
||||
logger.warning(f"配额配置未包含 {quota_type},跳过配额检查")
|
||||
return
|
||||
|
||||
if usage_func:
|
||||
current_usage = usage_func(db, tenant_id, workspace_id) if workspace_id else usage_func(db, tenant_id)
|
||||
else:
|
||||
current_usage = QuotaUsageRepository(db).get_usage_by_quota_type(tenant_id, quota_type, workspace_id)
|
||||
|
||||
if current_usage >= quota_limit:
|
||||
logger.warning(
|
||||
f"配额不足: tenant={tenant_id}, workspace={workspace_id}, type={quota_type}, "
|
||||
f"usage={current_usage}, limit={quota_limit}"
|
||||
)
|
||||
raise QuotaExceededError(
|
||||
resource=resource_name,
|
||||
current_usage=current_usage,
|
||||
quota_limit=quota_limit,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"配额检查通过: tenant={tenant_id}, workspace={workspace_id}, type={quota_type}, "
|
||||
f"usage={current_usage}, limit={quota_limit}"
|
||||
)
|
||||
|
||||
except QuotaExceededError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"配额检查异常: tenant={tenant_id}, workspace={workspace_id}, type={quota_type}, "
|
||||
f"error_type={type(e).__name__}, error={str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
# ─── 具名装饰器 ────────────────────────────────────────────────────────────
|
||||
|
||||
def check_workspace_quota(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "workspace_quota", "workspace")
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "workspace_quota", "workspace")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
def check_skill_quota(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "skill_quota", "skill")
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "skill_quota", "skill")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
def check_app_quota(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
workspace_id = _get_workspace_id_from_kwargs(kwargs)
|
||||
if not workspace_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "app_quota", "app", workspace_id=workspace_id)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
workspace_id = _get_workspace_id_from_kwargs(kwargs)
|
||||
if not workspace_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "app_quota", "app", workspace_id=workspace_id)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
def check_knowledge_capacity_quota(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
if not db:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
|
||||
if not tenant_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 tenant_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
workspace_id = _get_workspace_id_from_kwargs(kwargs)
|
||||
if not workspace_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, tenant_id, "knowledge_capacity_quota", "knowledge_capacity", workspace_id=workspace_id)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
workspace_id = _get_workspace_id_from_kwargs(kwargs)
|
||||
if not workspace_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "knowledge_capacity_quota", "knowledge_capacity", workspace_id=workspace_id)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
def check_memory_engine_quota(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
logger.debug(f"check_memory_engine_quota async_wrapper: db={db is not None}, user={user}, kwargs_keys={list(kwargs.keys())}")
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
workspace_id = _get_workspace_id_from_kwargs(kwargs)
|
||||
if not workspace_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "memory_engine_quota", "memory_engine", workspace_id=workspace_id)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
logger.debug(f"check_memory_engine_quota sync_wrapper: db={db is not None}, user={user}, kwargs_keys={list(kwargs.keys())}")
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
workspace_id = _get_workspace_id_from_kwargs(kwargs)
|
||||
if not workspace_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "memory_engine_quota", "memory_engine", workspace_id=workspace_id)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
def check_end_user_quota(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
if not db:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
|
||||
if not tenant_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 tenant_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
workspace_id = _get_workspace_id_from_kwargs(kwargs)
|
||||
if not workspace_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
if not db:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
|
||||
if not tenant_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 tenant_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
workspace_id = _get_workspace_id_from_kwargs(kwargs)
|
||||
if not workspace_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
def check_ontology_project_quota(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
workspace_id = _get_workspace_id_from_kwargs(kwargs)
|
||||
if not workspace_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "ontology_project_quota", "ontology_project", workspace_id=workspace_id)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
workspace_id = _get_workspace_id_from_kwargs(kwargs)
|
||||
if not workspace_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "ontology_project_quota", "ontology_project", workspace_id=workspace_id)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
def check_model_quota(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "model_quota", "model")
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "model_quota", "model")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
def check_model_activation_quota(func: Callable) -> Callable:
|
||||
"""模型激活时的配额检查装饰器"""
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
|
||||
model_id = kwargs.get("model_id") or (args[1] if len(args) > 1 else None)
|
||||
model_data = kwargs.get("model_data")
|
||||
|
||||
if not model_id or not model_data:
|
||||
logger.warning("模型激活配额检查失败:缺少 model_id 或 model_data 参数")
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
if model_data.is_active:
|
||||
try:
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
existing_model = ModelConfigService.get_model_by_id(
|
||||
db=db,
|
||||
model_id=model_id,
|
||||
tenant_id=user.tenant_id
|
||||
)
|
||||
|
||||
if not existing_model.is_active:
|
||||
logger.info(f"模型激活操作,检查配额: model_id={model_id}, tenant_id={user.tenant_id}")
|
||||
_check_quota(db, user.tenant_id, "model_quota", "model")
|
||||
except Exception as e:
|
||||
logger.error(f"模型激活配额检查异常: model_id={model_id}, error={str(e)}")
|
||||
raise
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
|
||||
model_id = kwargs.get("model_id") or (args[1] if len(args) > 1 else None)
|
||||
model_data = kwargs.get("model_data")
|
||||
|
||||
if not model_id or not model_data:
|
||||
logger.warning("模型激活配额检查失败:缺少 model_id 或 model_data 参数")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if model_data.is_active:
|
||||
try:
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
existing_model = ModelConfigService.get_model_by_id(
|
||||
db=db,
|
||||
model_id=model_id,
|
||||
tenant_id=user.tenant_id
|
||||
)
|
||||
|
||||
if not existing_model.is_active:
|
||||
logger.info(f"模型激活操作,检查配额: model_id={model_id}, tenant_id={user.tenant_id}")
|
||||
_check_quota(db, user.tenant_id, "model_quota", "model")
|
||||
except Exception as e:
|
||||
logger.error(f"模型激活配额检查异常: model_id={model_id}, error={str(e)}")
|
||||
raise
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
def check_quota(quota_type: str, resource_name: str, usage_func: Optional[Callable] = None):
|
||||
"""通用配额检查装饰器,支持自定义使用量获取函数"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, quota_type, resource_name, usage_func)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, quota_type, resource_name, usage_func)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
# ─── 配额使用统计 ────────────────────────────────────────────────────────────
|
||||
|
||||
async def get_quota_usage(db: Session, tenant_id: UUID) -> dict:
|
||||
"""获取租户所有配额的使用情况
|
||||
|
||||
对于 workspace 级别的配额(app/knowledge_capacity/memory_engine/end_user):
|
||||
- used: 租户汇总(所有空间加总)
|
||||
- limit: quota × 活跃工作区数(有效总限额,使汇总数据自洽)
|
||||
- per_workspace: 各空间明细,包含 workspace_id、workspace_name、used、limit、percentage
|
||||
- 配额检查逻辑不变:仍按单个空间独立检查
|
||||
"""
|
||||
quota_config = _get_quota_config(db, tenant_id)
|
||||
if not quota_config:
|
||||
return {}
|
||||
|
||||
repo = QuotaUsageRepository(db)
|
||||
|
||||
def pct(used, limit):
|
||||
return round(used / limit * 100, 1) if limit else None
|
||||
|
||||
workspace_count = repo.count_workspaces(tenant_id)
|
||||
skill_count = repo.count_skills(tenant_id)
|
||||
app_count = repo.count_apps(tenant_id)
|
||||
knowledge_gb = repo.sum_knowledge_capacity_gb(tenant_id)
|
||||
memory_count = repo.count_memory_engines(tenant_id)
|
||||
end_user_count = repo.count_end_users(tenant_id)
|
||||
model_count = repo.count_models(tenant_id)
|
||||
ontology_count = repo.count_ontology_projects(tenant_id)
|
||||
|
||||
# 获取租户下所有活跃工作区,用于按空间拆分明细
|
||||
from app.models.workspace_model import Workspace
|
||||
active_workspaces = db.query(Workspace).filter(
|
||||
Workspace.tenant_id == tenant_id,
|
||||
Workspace.is_active.is_(True)
|
||||
).all()
|
||||
|
||||
# 构建各空间的 workspace 级配额明细
|
||||
def _build_per_workspace_detail(count_func, per_unit_limit):
|
||||
"""为 workspace 级配额构建 per_workspace 明细列表"""
|
||||
if not per_unit_limit or not active_workspaces:
|
||||
return []
|
||||
details = []
|
||||
for ws in active_workspaces:
|
||||
ws_used = count_func(tenant_id, ws.id)
|
||||
details.append({
|
||||
"workspace_id": str(ws.id),
|
||||
"workspace_name": ws.name,
|
||||
"used": ws_used,
|
||||
"limit": per_unit_limit,
|
||||
"percentage": pct(ws_used, per_unit_limit),
|
||||
})
|
||||
return details
|
||||
|
||||
# workspace 级配额的每空间限额
|
||||
app_quota_per_ws = quota_config.get("app_quota")
|
||||
knowledge_quota_per_ws = quota_config.get("knowledge_capacity_quota")
|
||||
memory_quota_per_ws = quota_config.get("memory_engine_quota")
|
||||
end_user_quota_per_ws = quota_config.get("end_user_quota")
|
||||
ontology_quota_per_ws = quota_config.get("ontology_project_quota")
|
||||
|
||||
# workspace 级配额的有效总限额 = 每空间限额 × 活跃工作区数
|
||||
app_effective_limit = app_quota_per_ws * workspace_count if app_quota_per_ws is not None and workspace_count > 0 else app_quota_per_ws
|
||||
knowledge_effective_limit = knowledge_quota_per_ws * workspace_count if knowledge_quota_per_ws is not None and workspace_count > 0 else knowledge_quota_per_ws
|
||||
memory_effective_limit = memory_quota_per_ws * workspace_count if memory_quota_per_ws is not None and workspace_count > 0 else memory_quota_per_ws
|
||||
end_user_effective_limit = end_user_quota_per_ws * workspace_count if end_user_quota_per_ws is not None and workspace_count > 0 else end_user_quota_per_ws
|
||||
ontology_effective_limit = ontology_quota_per_ws * workspace_count if ontology_quota_per_ws is not None and workspace_count > 0 else ontology_quota_per_ws
|
||||
|
||||
api_ops_current = 0
|
||||
try:
|
||||
from app.aioRedis import aio_redis as _aio_redis
|
||||
from app.models.api_key_model import ApiKey
|
||||
# api_ops_rate_limit 限的是每个 api_key 每秒最高限额
|
||||
# 展示当前最接近触发限流的 key 的 QPS(取最大值)
|
||||
api_key_ids = db.query(ApiKey.id).join(
|
||||
Workspace, ApiKey.workspace_id == Workspace.id
|
||||
).filter(
|
||||
Workspace.tenant_id == tenant_id,
|
||||
ApiKey.is_active.is_(True)
|
||||
).all()
|
||||
for (key_id,) in api_key_ids:
|
||||
_rk = API_KEY_QPS_REDIS_KEY.format(api_key_id=key_id)
|
||||
val = await _aio_redis.get(_rk)
|
||||
count = int(val) if val else 0
|
||||
if count > api_ops_current:
|
||||
api_ops_current = count
|
||||
except Exception as e:
|
||||
logger.warning(f"获取 api_ops_current 失败,返回 0: {type(e).__name__}: {e}")
|
||||
|
||||
return {
|
||||
"workspace": {"used": workspace_count, "limit": quota_config.get("workspace_quota"), "percentage": pct(workspace_count, quota_config.get("workspace_quota"))},
|
||||
"skill": {"used": skill_count, "limit": quota_config.get("skill_quota"), "percentage": pct(skill_count, quota_config.get("skill_quota"))},
|
||||
"app": {
|
||||
"used": app_count,
|
||||
"limit": app_effective_limit,
|
||||
"percentage": pct(app_count, app_effective_limit),
|
||||
"per_workspace": _build_per_workspace_detail(repo.count_apps, app_quota_per_ws),
|
||||
},
|
||||
"knowledge_capacity": {
|
||||
"used": round(knowledge_gb, 2),
|
||||
"limit": knowledge_effective_limit,
|
||||
"percentage": pct(knowledge_gb, knowledge_effective_limit),
|
||||
"unit": "GB",
|
||||
"per_workspace": _build_per_workspace_detail(repo.sum_knowledge_capacity_gb, knowledge_quota_per_ws),
|
||||
},
|
||||
"memory_engine": {
|
||||
"used": memory_count,
|
||||
"limit": memory_effective_limit,
|
||||
"percentage": pct(memory_count, memory_effective_limit),
|
||||
"per_workspace": _build_per_workspace_detail(repo.count_memory_engines, memory_quota_per_ws),
|
||||
},
|
||||
"end_user": {
|
||||
"used": end_user_count,
|
||||
"limit": end_user_effective_limit,
|
||||
"percentage": pct(end_user_count, end_user_effective_limit),
|
||||
"per_workspace": _build_per_workspace_detail(repo.count_end_users, end_user_quota_per_ws),
|
||||
},
|
||||
"ontology_project": {
|
||||
"used": ontology_count,
|
||||
"limit": ontology_effective_limit,
|
||||
"percentage": pct(ontology_count, ontology_effective_limit),
|
||||
"per_workspace": _build_per_workspace_detail(repo.count_ontology_projects, ontology_quota_per_ws),
|
||||
},
|
||||
"model": {"used": model_count, "limit": quota_config.get("model_quota"), "percentage": pct(model_count, quota_config.get("model_quota"))},
|
||||
"api_ops_rate_limit": {"current": api_ops_current, "limit": quota_config.get("api_ops_rate_limit"), "percentage": None, "unit": "次/秒"},
|
||||
}
|
||||
38
api/app/core/quota_stub.py
Normal file
38
api/app/core/quota_stub.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
配额检查 stub - 社区版和 SaaS 版统一使用 core.quota_manager 实现
|
||||
|
||||
所有配额检查逻辑统一在 core 层实现,两个版本共用:
|
||||
- 社区版:从 default_free_plan.py 读取配额限制
|
||||
- SaaS 版:优先从 tenant_subscriptions 表读取,降级到配置文件
|
||||
"""
|
||||
from app.core.quota_manager import (
|
||||
check_workspace_quota,
|
||||
check_skill_quota,
|
||||
check_app_quota,
|
||||
check_knowledge_capacity_quota,
|
||||
check_memory_engine_quota,
|
||||
check_end_user_quota,
|
||||
check_ontology_project_quota,
|
||||
check_model_quota,
|
||||
check_model_activation_quota,
|
||||
get_quota_usage,
|
||||
_check_quota,
|
||||
QuotaUsageRepository,
|
||||
API_KEY_QPS_REDIS_KEY,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"check_workspace_quota",
|
||||
"check_skill_quota",
|
||||
"check_app_quota",
|
||||
"check_knowledge_capacity_quota",
|
||||
"check_memory_engine_quota",
|
||||
"check_end_user_quota",
|
||||
"check_ontology_project_quota",
|
||||
"check_model_quota",
|
||||
"check_model_activation_quota",
|
||||
"get_quota_usage",
|
||||
"_check_quota",
|
||||
"QuotaUsageRepository",
|
||||
"API_KEY_QPS_REDIS_KEY",
|
||||
]
|
||||
@@ -672,10 +672,15 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
excel_parser = ExcelParser()
|
||||
if parser_config.get("html4excel") and parser_config.get("html4excel").lower() == "true":
|
||||
sections = [(_, "") for _ in excel_parser.html(binary, 12) if _]
|
||||
parser_config["chunk_token_num"] = 0
|
||||
else:
|
||||
sections = [(_, "") for _ in excel_parser(binary) if _]
|
||||
parser_config["chunk_token_num"] = 12800
|
||||
callback(0.8, "Finish parsing.")
|
||||
# Excel 每行直接作为一个 chunk,不经过 naive_merge 避免被 delimiter 拆分
|
||||
chunks = [s for s, _ in sections]
|
||||
res.extend(tokenize_chunks(chunks, doc, is_english, None))
|
||||
res.extend(embed_res)
|
||||
res.extend(url_res)
|
||||
return res
|
||||
|
||||
elif re.search(r"\.(txt|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|sql)$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
|
||||
@@ -33,18 +33,16 @@ def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception:
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
|
||||
effective_timeout = seconds if seconds else 120 # 默认 120 秒超时
|
||||
for a in range(attempts):
|
||||
try:
|
||||
if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
|
||||
result = result_queue.get(timeout=seconds)
|
||||
else:
|
||||
result = result_queue.get()
|
||||
result = result_queue.get(timeout=effective_timeout)
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
return result
|
||||
except queue.Empty:
|
||||
pass
|
||||
raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds and {attempts} attempts.")
|
||||
raise TimeoutError(f"Function '{func.__name__}' timed out after {effective_timeout} seconds and {attempts} attempts.")
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs) -> Any:
|
||||
|
||||
@@ -232,14 +232,14 @@ class RAGExcelParser:
|
||||
t = str(ti[i].value) if i < len(ti) else ""
|
||||
t += (":" if t else "") + str(c.value)
|
||||
fields.append(t)
|
||||
line = "; ".join(fields)
|
||||
line = "\n".join(fields)
|
||||
if sheetname.lower().find("sheet") < 0:
|
||||
line += " ——" + sheetname
|
||||
line += "\n——" + sheetname
|
||||
res.append(line)
|
||||
else:
|
||||
# 只有表头的情况
|
||||
if header_fields:
|
||||
line = "; ".join(header_fields)
|
||||
line = "\n".join(header_fields)
|
||||
if sheetname.lower().find("sheet") < 0:
|
||||
line += " ——" + sheetname
|
||||
res.append(line)
|
||||
|
||||
@@ -292,9 +292,10 @@ class MinerUParser(RAGPdfParser):
|
||||
self.page_from = page_from
|
||||
self.page_to = page_to
|
||||
try:
|
||||
with pdfplumber.open(fnm) if isinstance(fnm, (str, PathLike)) else pdfplumber.open(BytesIO(fnm)) as pdf:
|
||||
self.pdf = pdf
|
||||
self.page_images = [p.to_image(resolution=72 * zoomin, antialias=True).original for _, p in enumerate(self.pdf.pages[page_from:page_to])]
|
||||
with sys.modules[LOCK_KEY_pdfplumber]: # ← 加这一行,获取全局锁
|
||||
with pdfplumber.open(fnm) if isinstance(fnm, (str, PathLike)) else pdfplumber.open(BytesIO(fnm)) as pdf:
|
||||
self.pdf = pdf
|
||||
self.page_images = [p.to_image(resolution=72 * zoomin, antialias=True).original for _, p in enumerate(self.pdf.pages[page_from:page_to])]
|
||||
except Exception as e:
|
||||
self.page_images = None
|
||||
self.total_page = 0
|
||||
|
||||
@@ -50,7 +50,9 @@ class OpenAIEmbed(Base):
|
||||
def encode(self, texts: list):
|
||||
# OpenAI requires batch size <=16
|
||||
batch_size = 16
|
||||
texts = [truncate(t, 8191) for t in texts]
|
||||
# Use 8000 instead of 8191 to leave safety margin for tokenizer differences
|
||||
# between cl100k_base (used by truncate) and the actual embedding model
|
||||
texts = [truncate(t, 8000) for t in texts]
|
||||
ress = []
|
||||
total_tokens = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
@@ -63,7 +65,7 @@ class OpenAIEmbed(Base):
|
||||
return np.array(ress), total_tokens
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True})
|
||||
res = self.client.embeddings.create(input=[truncate(text, 8000)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True})
|
||||
return np.array(res.data[0].embedding), self.total_token_count(res)
|
||||
|
||||
|
||||
@@ -79,6 +81,7 @@ class LocalAIEmbed(Base):
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
texts = [truncate(t, 8000) for t in texts]
|
||||
ress = []
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
|
||||
@@ -173,6 +176,7 @@ class XinferenceEmbed(Base):
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
texts = [truncate(t, 8000) for t in texts]
|
||||
ress = []
|
||||
total_tokens = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
@@ -188,7 +192,7 @@ class XinferenceEmbed(Base):
|
||||
def encode_queries(self, text):
|
||||
res = None
|
||||
try:
|
||||
res = self.client.embeddings.create(input=[text], model=self.model_name)
|
||||
res = self.client.embeddings.create(input=[truncate(text, 8000)], model=self.model_name)
|
||||
return np.array(res.data[0].embedding), self.total_token_count(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
|
||||
@@ -28,6 +28,7 @@ from app.core.rag.common.float_utils import get_float
|
||||
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -112,11 +113,10 @@ def knowledge_retrieval(
|
||||
continue
|
||||
|
||||
# Use the specified reranker for re-ranking
|
||||
if reranker_id:
|
||||
if reranker_id and all_results:
|
||||
try:
|
||||
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
|
||||
all_results = rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
|
||||
except Exception as rerank_error:
|
||||
# If reranker fails, log warning and continue with original results
|
||||
logger.warning(
|
||||
"Reranker failed, falling back to original results",
|
||||
extra={
|
||||
@@ -132,7 +132,10 @@ def knowledge_retrieval(
|
||||
from app.core.rag.common.settings import kg_retriever
|
||||
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||
if doc:
|
||||
all_results.insert(0, doc)
|
||||
all_results.insert(0, DocumentChunk(
|
||||
page_content=doc.get("page_content", ""),
|
||||
metadata=doc.get("metadata", {})
|
||||
))
|
||||
except Exception as graph_error:
|
||||
print(f"Failed to retrieve from knowledge graph: {str(graph_error)}")
|
||||
|
||||
@@ -198,16 +201,18 @@ def _retrieve_for_knowledge(
|
||||
workspace_ids.append(str(db_knowledge.workspace_id))
|
||||
|
||||
if not chat_model:
|
||||
llm_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.llm_id)
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base,
|
||||
key=llm_key.api_key,
|
||||
model_name=llm_key.model_name,
|
||||
base_url=llm_key.api_base,
|
||||
)
|
||||
if not embedding_model:
|
||||
emb_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.embedding_id)
|
||||
embedding_model = OpenAIEmbed(
|
||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||
base_url=db_knowledge.embedding.api_keys[0].api_base,
|
||||
key=emb_key.api_key,
|
||||
model_name=emb_key.model_name,
|
||||
base_url=emb_key.api_base,
|
||||
)
|
||||
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
@@ -248,6 +253,29 @@ def _retrieve_for_knowledge(
|
||||
seen_ids.add(doc.metadata["doc_id"])
|
||||
unique_rs.append(doc)
|
||||
rs = unique_rs
|
||||
if unique_rs:
|
||||
rs = vector_service.rerank(
|
||||
query=kb_config["query"],
|
||||
docs=unique_rs,
|
||||
top_k=kb_config["top_k"]
|
||||
)
|
||||
if kb_config["retrieve_type"] == "graph":
|
||||
try:
|
||||
from app.core.rag.common.settings import kg_retriever
|
||||
graph_doc = kg_retriever.retrieval(
|
||||
question=kb_config["query"],
|
||||
workspace_ids=[str(db_knowledge.workspace_id)],
|
||||
kb_ids=[str(db_knowledge.id)],
|
||||
emb_mdl=embedding_model,
|
||||
llm=chat_model,
|
||||
)
|
||||
if graph_doc:
|
||||
rs.insert(0, DocumentChunk(
|
||||
page_content=graph_doc.get("page_content", ""),
|
||||
metadata=graph_doc.get("metadata", {})
|
||||
))
|
||||
except Exception as graph_error:
|
||||
logger.warning(f"Graph retrieval failed for kb {db_knowledge.id}: {graph_error}")
|
||||
|
||||
results.extend(rs)
|
||||
return results, chat_model, embedding_model
|
||||
|
||||
@@ -68,9 +68,9 @@ class ESConnection(DocStoreConnection):
|
||||
client_config = {
|
||||
"hosts": [hosts],
|
||||
"basic_auth": (os.getenv("ELASTICSEARCH_USERNAME", "elastic"), os.getenv("ELASTICSEARCH_PASSWORD", "elastic")),
|
||||
"request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 100000)),
|
||||
"request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 30)),
|
||||
"retry_on_timeout": os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", True) == "true",
|
||||
"max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 10000)),
|
||||
"max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 3)),
|
||||
}
|
||||
|
||||
# Only add SSL settings if using HTTPS
|
||||
|
||||
@@ -1,25 +1,22 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
import threading
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
from elasticsearch import Elasticsearch, helpers
|
||||
from elasticsearch.helpers import BulkIndexError
|
||||
from packaging.version import parse as parse_version
|
||||
from pydantic import BaseModel, model_validator
|
||||
from abc import ABC
|
||||
# langchain-community
|
||||
# langchain-xinference
|
||||
# from langchain_community.embeddings import XinferenceEmbeddings
|
||||
# from langchain_xinference import XinferenceRerank
|
||||
from langchain_core.documents import Document
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models import RedBearLLM, RedBearRerank
|
||||
from app.core.models import RedBearRerank
|
||||
from app.core.models.embedding import RedBearEmbeddings
|
||||
from app.models.models_model import ModelConfig, ModelApiKey
|
||||
from app.services.model_service import ModelConfigService
|
||||
from app.models.models_model import ModelApiKey
|
||||
|
||||
from app.models.knowledge_model import Knowledge
|
||||
from app.core.rag.vdb.field import Field
|
||||
@@ -29,37 +26,9 @@ from app.core.rag.models.chunk import DocumentChunk
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ElasticSearchConfig(BaseModel):
|
||||
# Regular Elasticsearch config
|
||||
host: str | None = None
|
||||
port: int | None = None
|
||||
username: str | None = None
|
||||
password: str | None = None
|
||||
|
||||
# Common config
|
||||
ca_certs: str | None = None
|
||||
verify_certs: bool = False
|
||||
request_timeout: int = 100000
|
||||
retry_on_timeout: bool = True
|
||||
max_retries: int = 10000
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict):
|
||||
# Regular Elasticsearch validation
|
||||
if not values.get("host"):
|
||||
raise ValueError("config HOST is required for regular Elasticsearch")
|
||||
if not values.get("port"):
|
||||
raise ValueError("config PORT is required for regular Elasticsearch")
|
||||
if not values.get("username"):
|
||||
raise ValueError("config USERNAME is required for regular Elasticsearch")
|
||||
if not values.get("password"):
|
||||
raise ValueError("config PASSWORD is required for regular Elasticsearch")
|
||||
return values
|
||||
|
||||
|
||||
class ElasticSearchVector(BaseVector):
|
||||
def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey):
|
||||
def __init__(self, index_name: str, client: Elasticsearch,
|
||||
embedding_config: ModelApiKey, reranker_config: ModelApiKey):
|
||||
super().__init__(index_name.lower())
|
||||
|
||||
# 初始化 Embedding 模型(自动支持火山引擎多模态)
|
||||
@@ -77,58 +46,8 @@ class ElasticSearchVector(BaseVector):
|
||||
api_key=reranker_config.api_key,
|
||||
base_url=reranker_config.api_base
|
||||
))
|
||||
self._client = self._init_client(config)
|
||||
self._version = self._get_version()
|
||||
self._check_version()
|
||||
|
||||
def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
|
||||
"""
|
||||
Initialize Elasticsearch client for regular Elasticsearch.
|
||||
"""
|
||||
try:
|
||||
# Regular Elasticsearch configuration
|
||||
parsed_url = urlparse(config.host or "")
|
||||
if parsed_url.scheme in {"http", "https"}:
|
||||
hosts = f"{config.host}:{config.port}"
|
||||
use_https = parsed_url.scheme == "https"
|
||||
else:
|
||||
hosts = f"https://{config.host}:{config.port}"
|
||||
use_https = False
|
||||
|
||||
client_config = {
|
||||
"hosts": [hosts],
|
||||
"basic_auth": (config.username, config.password),
|
||||
"request_timeout": config.request_timeout,
|
||||
"retry_on_timeout": config.retry_on_timeout,
|
||||
"max_retries": config.max_retries,
|
||||
}
|
||||
|
||||
# Only add SSL settings if using HTTPS
|
||||
if use_https:
|
||||
client_config["verify_certs"] = config.verify_certs
|
||||
if config.ca_certs:
|
||||
client_config["ca_certs"] = config.ca_certs
|
||||
|
||||
client = Elasticsearch(**client_config)
|
||||
|
||||
# Test connection
|
||||
if not client.ping():
|
||||
raise ConnectionError("Failed to connect to Elasticsearch")
|
||||
|
||||
except requests.ConnectionError as e:
|
||||
raise ConnectionError(f"Vector database connection error: {str(e)}")
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")
|
||||
|
||||
return client
|
||||
|
||||
def _get_version(self) -> str:
|
||||
info = self._client.info()
|
||||
return cast(str, info["version"]["number"])
|
||||
|
||||
def _check_version(self):
|
||||
if parse_version(self._version) < parse_version("8.0.0"):
|
||||
raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
|
||||
# 使用外部传入的共享客户端
|
||||
self._client = client
|
||||
|
||||
def get_type(self) -> str:
|
||||
return "elasticsearch"
|
||||
@@ -745,29 +664,79 @@ class ElasticSearchVector(BaseVector):
|
||||
|
||||
|
||||
class ElasticSearchVectorFactory:
|
||||
@staticmethod
|
||||
def init_vector(knowledge: Knowledge) -> ElasticSearchVector:
|
||||
"""ES 向量服务工厂 - 单例共享连接"""
|
||||
|
||||
_client: Elasticsearch | None = None
|
||||
_lock = threading.Lock()
|
||||
_version_checked = False
|
||||
|
||||
@classmethod
|
||||
def _get_shared_client(cls) -> Elasticsearch:
|
||||
"""获取共享的 ES 客户端(线程安全的懒加载单例)"""
|
||||
if cls._client is not None:
|
||||
return cls._client
|
||||
|
||||
with cls._lock:
|
||||
# 双重检查,防止并发时重复创建
|
||||
if cls._client is not None:
|
||||
return cls._client
|
||||
|
||||
try:
|
||||
parsed_url = urlparse(os.getenv("ELASTICSEARCH_HOST", "127.0.0.1") or "")
|
||||
if parsed_url.scheme in {"http", "https"}:
|
||||
hosts = f'{os.getenv("ELASTICSEARCH_HOST")}:{os.getenv("ELASTICSEARCH_PORT", 9200)}'
|
||||
use_https = parsed_url.scheme == "https"
|
||||
else:
|
||||
hosts = f'https://{os.getenv("ELASTICSEARCH_HOST", "127.0.0.1")}:{os.getenv("ELASTICSEARCH_PORT", 9200)}'
|
||||
use_https = False
|
||||
|
||||
client_config = {
|
||||
"hosts": [hosts],
|
||||
"basic_auth": (
|
||||
os.getenv("ELASTICSEARCH_USERNAME", "elastic"),
|
||||
os.getenv("ELASTICSEARCH_PASSWORD", "elastic"),
|
||||
),
|
||||
"request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 30)),
|
||||
"retry_on_timeout": True,
|
||||
"max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 3)),
|
||||
"connections_per_node": int(os.getenv("ELASTICSEARCH_CONNECTIONS_PER_NODE", 10)),
|
||||
}
|
||||
|
||||
if use_https:
|
||||
client_config["verify_certs"] = os.getenv("ELASTICSEARCH_VERIFY_CERTS", "false") == "true"
|
||||
ca_certs = os.getenv("ELASTICSEARCH_CA_CERTS")
|
||||
if ca_certs:
|
||||
client_config["ca_certs"] = str(ca_certs)
|
||||
|
||||
client = Elasticsearch(**client_config)
|
||||
|
||||
if not client.ping():
|
||||
raise ConnectionError("Failed to connect to Elasticsearch")
|
||||
|
||||
# 版本检查只做一次
|
||||
if not cls._version_checked:
|
||||
info = client.info()
|
||||
version = info["version"]["number"]
|
||||
if parse_version(version) < parse_version("8.0.0"):
|
||||
raise ValueError(f"Elasticsearch version must be >= 8.0.0, got {version}")
|
||||
cls._version_checked = True
|
||||
logger.info(f"Elasticsearch shared client initialized, version: {version}")
|
||||
|
||||
cls._client = client
|
||||
|
||||
except requests.ConnectionError as e:
|
||||
raise ConnectionError(f"Vector database connection error: {str(e)}")
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")
|
||||
|
||||
return cls._client
|
||||
|
||||
@classmethod
|
||||
def init_vector(cls, knowledge: Knowledge) -> ElasticSearchVector:
|
||||
"""创建向量服务实例(共享 ES 连接)"""
|
||||
client = cls._get_shared_client()
|
||||
collection_name = f"Vector_index_{knowledge.id}_Node"
|
||||
|
||||
# Use regular Elasticsearch with config values
|
||||
config_dict = {
|
||||
"host": os.getenv("ELASTICSEARCH_HOST", "127.0.0.1"),
|
||||
"port": os.getenv("ELASTICSEARCH_PORT", 9200),
|
||||
"username": os.getenv("ELASTICSEARCH_USERNAME", "elastic"),
|
||||
"password": os.getenv("ELASTICSEARCH_PASSWORD", "elastic"),
|
||||
}
|
||||
|
||||
# Common configuration
|
||||
config_dict.update(
|
||||
{
|
||||
"ca_certs": str(os.getenv("ELASTICSEARCH_CA_CERTS")) if os.getenv("ELASTICSEARCH_CA_CERTS") else None,
|
||||
"verify_certs": os.getenv("ELASTICSEARCH_VERIFY_CERTS", False) == "true",
|
||||
"request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 100000)),
|
||||
"retry_on_timeout": os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", True) == "true",
|
||||
"max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 10000)),
|
||||
}
|
||||
)
|
||||
|
||||
if knowledge.embedding is None:
|
||||
raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}")
|
||||
if knowledge.reranker is None:
|
||||
@@ -775,9 +744,9 @@ class ElasticSearchVectorFactory:
|
||||
|
||||
return ElasticSearchVector(
|
||||
index_name=collection_name,
|
||||
config=ElasticSearchConfig(**config_dict),
|
||||
client=client,
|
||||
embedding_config=knowledge.embedding.api_keys[0],
|
||||
reranker_config=knowledge.reranker.api_keys[0]
|
||||
reranker_config=knowledge.reranker.api_keys[0],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -230,7 +230,7 @@ class DateTimeTool(BuiltinTool):
|
||||
@staticmethod
|
||||
def _datetime_to_timestamp(kwargs) -> dict:
|
||||
"""日期时间转时间戳"""
|
||||
input_value = kwargs.get("input_value")
|
||||
input_value = kwargs.get("input_value").strip()
|
||||
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
|
||||
timezone_str = kwargs.get("from_timezone", "Asia/Shanghai")
|
||||
|
||||
@@ -253,9 +253,9 @@ class DateTimeTool(BuiltinTool):
|
||||
return {
|
||||
"datetime": input_value,
|
||||
"timezone": timezone_str,
|
||||
"timestamp": int(dt.timestamp()),
|
||||
"timestamp": int(dt.timestamp() * 1000),
|
||||
"iso_format": dt.isoformat(),
|
||||
"result_data": int(dt.timestamp())
|
||||
"result_data": int(dt.timestamp() * 1000)
|
||||
}
|
||||
|
||||
def _calculate_datetime(self, kwargs) -> dict:
|
||||
|
||||
300
api/app/core/tools/builtin/openclaw_tool.py
Normal file
300
api/app/core/tools/builtin/openclaw_tool.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""OpenClaw 远程 Agent 内置工具"""
|
||||
import time
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from typing import List, Dict, Any, Optional
|
||||
import aiohttp
|
||||
|
||||
from app.core.tools.builtin.base import BuiltinTool
|
||||
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class OpenClawTool(BuiltinTool):
|
||||
"""OpenClaw 远程 Agent 工具 — 支持文本和图片多模态输入"""
|
||||
|
||||
def __init__(self, tool_id: str, config: Dict[str, Any]):
|
||||
super().__init__(tool_id, config)
|
||||
params = self.parameters_config
|
||||
|
||||
# 用户配置项(前端表单填写)
|
||||
self._server_url = params.get("server_url", "")
|
||||
self._api_key = params.get("api_key", "")
|
||||
self._agent_id = params.get("agent_id", "main")
|
||||
|
||||
# 内部默认值
|
||||
self._model = "openclaw"
|
||||
self._session_strategy = "by_user"
|
||||
self._timeout = 120
|
||||
|
||||
# 运行时上下文(通过 set_runtime_context 注入)
|
||||
self._user_id = "anonymous"
|
||||
self._conversation_id = None
|
||||
self._uploaded_files = []
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "openclaw_tool"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"OpenClaw 远程 Agent:将任务委托给远程 OpenClaw Agent。"
|
||||
"具备 3D 模型生成与打印控制、设备管理、文件处理、浏览器自动化、"
|
||||
"Shell 命令执行、网络搜索等能力。支持文本和图片多模态交互。"
|
||||
)
|
||||
|
||||
def get_required_config_parameters(self) -> List[str]:
|
||||
return ["server_url", "api_key"]
|
||||
|
||||
@property
|
||||
def parameters(self) -> List[ToolParameter]:
|
||||
return [
|
||||
ToolParameter(
|
||||
name="operation",
|
||||
type=ParameterType.STRING,
|
||||
description="任务类型",
|
||||
required=True,
|
||||
enum= ["print_task", "device_query", "image_understand", "general"]
|
||||
),
|
||||
ToolParameter(
|
||||
name="message",
|
||||
type=ParameterType.STRING,
|
||||
description="发送给 OpenClaw Agent 的文本请求内容",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="image_url",
|
||||
type=ParameterType.STRING,
|
||||
description="可选,附带的图片 URL 或 base64 data URI(OpenClaw 支持图片输入)",
|
||||
required=False
|
||||
)
|
||||
]
|
||||
|
||||
# ---------- 运行时上下文注入 ----------
|
||||
def set_runtime_context(
|
||||
self,
|
||||
user_id: str = "anonymous",
|
||||
conversation_id: Optional[str] = None,
|
||||
uploaded_files: Optional[list] = None
|
||||
):
|
||||
"""注入运行时上下文(由 chat service 调用)"""
|
||||
self._user_id = user_id
|
||||
self._conversation_id = conversation_id
|
||||
self._uploaded_files = uploaded_files or []
|
||||
|
||||
# ---------- 连接测试 ----------
|
||||
async def test_connection(self) -> Dict[str, Any]:
|
||||
"""测试 OpenClaw Gateway 连接"""
|
||||
if not self._server_url:
|
||||
return {"success": False, "message": "未配置 server_url"}
|
||||
if not self._api_key:
|
||||
return {"success": False, "message": "未配置 api_key"}
|
||||
|
||||
url = f"{self._server_url.rstrip('/')}/v1/responses"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"x-openclaw-agent-id": self._agent_id
|
||||
}
|
||||
body = {
|
||||
"model": self._model,
|
||||
"user": "connection-test",
|
||||
"input": "hi",
|
||||
"stream": False
|
||||
}
|
||||
try:
|
||||
timeout_cfg = aiohttp.ClientTimeout(total=30)
|
||||
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
|
||||
async with session.post(url, json=body, headers=headers) as resp:
|
||||
if resp.status < 400:
|
||||
return {"success": True, "message": "OpenClaw 连接成功"}
|
||||
error_text = await resp.text()
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"OpenClaw HTTP {resp.status}: {error_text[:200]}"
|
||||
}
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"OpenClaw 连接失败: {str(e)}"}
|
||||
|
||||
# ---------- 执行 ----------
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行 OpenClaw 调用"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
message = kwargs.get("message", "")
|
||||
if not message:
|
||||
return ToolResult.error_result(
|
||||
error="message 参数不能为空",
|
||||
error_code="OPENCLAW_INVALID_INPUT",
|
||||
execution_time=time.time() - start_time
|
||||
)
|
||||
|
||||
# 提取图片:优先从用户上传文件中获取,LLM 传的 image_url 作为兜底
|
||||
image_url = self._extract_image_from_uploads()
|
||||
if not image_url:
|
||||
image_url = kwargs.get("image_url")
|
||||
if image_url and not image_url.startswith("data:"):
|
||||
image_url = await self._download_and_encode_image(image_url)
|
||||
|
||||
# 构建请求
|
||||
url = f"{self._server_url.rstrip('/')}/v1/responses"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"x-openclaw-agent-id": self._agent_id
|
||||
}
|
||||
user_field = (
|
||||
f"conv-{self._conversation_id}"
|
||||
if self._session_strategy == "by_conversation" and self._conversation_id
|
||||
else f"user-{self._user_id}"
|
||||
)
|
||||
input_field = self._build_input(message, image_url)
|
||||
body = {
|
||||
"model": self._model,
|
||||
"user": user_field,
|
||||
"input": input_field,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
timeout_cfg = aiohttp.ClientTimeout(total=self._timeout)
|
||||
# 打印请求日志(截断 base64 避免日志过大)
|
||||
log_body = {**body}
|
||||
if isinstance(log_body.get("input"), list):
|
||||
log_body["input"] = "[multimodal input, truncated]"
|
||||
elif isinstance(log_body.get("input"), str) and len(log_body["input"]) > 500:
|
||||
log_body["input"] = log_body["input"][:500] + "..."
|
||||
logger.info(
|
||||
f"OpenClaw 请求: url={url}, agent_id={self._agent_id}, "
|
||||
f"has_image={bool(image_url)}, body={log_body}"
|
||||
)
|
||||
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
|
||||
async with session.post(url, json=body, headers=headers) as resp:
|
||||
execution_time = time.time() - start_time
|
||||
if resp.status >= 400:
|
||||
error_text = await resp.text()
|
||||
return ToolResult.error_result(
|
||||
error=f"OpenClaw HTTP {resp.status}: {error_text[:500]}",
|
||||
error_code="OPENCLAW_HTTP_ERROR",
|
||||
execution_time=execution_time
|
||||
)
|
||||
data = await resp.json()
|
||||
text = self._extract_response(data)
|
||||
display_text = self._format_result(text)
|
||||
return ToolResult.success_result(
|
||||
data=display_text,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
return ToolResult.error_result(
|
||||
error=f"OpenClaw 网络连接失败: {str(e)}",
|
||||
error_code="OPENCLAW_NETWORK_ERROR",
|
||||
execution_time=time.time() - start_time
|
||||
)
|
||||
except Exception as e:
|
||||
return ToolResult.error_result(
|
||||
error=f"OpenClaw 调用失败: {str(e)}",
|
||||
error_code="OPENCLAW_EXECUTION_ERROR",
|
||||
execution_time=time.time() - start_time
|
||||
)
|
||||
|
||||
# ---------- 私有方法 ----------
|
||||
def _extract_image_from_uploads(self) -> Optional[str]:
|
||||
"""从用户上传文件中提取图片 URL"""
|
||||
for f in self._uploaded_files:
|
||||
f_type = f.get("type", "")
|
||||
if f_type == "image":
|
||||
source = f.get("source", {})
|
||||
if source.get("type") == "base64":
|
||||
media_type = source.get("media_type", "image/jpeg")
|
||||
data = source.get("data", "")
|
||||
return f"data:{media_type};base64,{data}"
|
||||
elif f.get("image"):
|
||||
return f.get("image")
|
||||
elif f.get("url"):
|
||||
return f.get("url")
|
||||
elif f_type == "image_url":
|
||||
return f.get("image_url", {}).get("url", "")
|
||||
return None
|
||||
|
||||
async def _download_and_encode_image(self, image_url: str) -> str:
|
||||
"""下载图片并转为 base64 data URI"""
|
||||
try:
|
||||
from PIL import Image
|
||||
MAX_RAW_SIZE = 4 * 1024 * 1024
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
image_url, allow_redirects=True,
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
return image_url
|
||||
content_type = resp.headers.get("Content-Type", "image/jpeg")
|
||||
if not content_type.startswith("image/"):
|
||||
return image_url
|
||||
img_bytes = await resp.read()
|
||||
|
||||
if len(img_bytes) > MAX_RAW_SIZE:
|
||||
img = Image.open(BytesIO(img_bytes))
|
||||
if img.mode in ("RGBA", "P", "LA"):
|
||||
img = img.convert("RGB")
|
||||
if max(img.size) > 2048:
|
||||
img.thumbnail((2048, 2048), Image.LANCZOS)
|
||||
buf = BytesIO()
|
||||
img.save(buf, format="JPEG", quality=75, optimize=True)
|
||||
img_bytes = buf.getvalue()
|
||||
content_type = "image/jpeg"
|
||||
|
||||
b64 = base64.b64encode(img_bytes).decode("utf-8")
|
||||
return f"data:{content_type};base64,{b64}"
|
||||
except Exception as e:
|
||||
logger.warning(f"OpenClaw 下载图片失败,使用原始 URL: {e}")
|
||||
return image_url
|
||||
|
||||
def _build_input(self, message: str, image_url: Optional[str] = None):
|
||||
"""构造请求 input 字段:有图片则构造多模态结构,否则纯文本"""
|
||||
if not image_url:
|
||||
return message
|
||||
|
||||
content_parts = [{"type": "input_text", "text": message}]
|
||||
if image_url.startswith("data:"):
|
||||
try:
|
||||
header, data = image_url.split(",", 1)
|
||||
media_type = header.split(":")[1].split(";")[0]
|
||||
content_parts.append({
|
||||
"type": "input_image",
|
||||
"source": {"type": "base64", "media_type": media_type, "data": data}
|
||||
})
|
||||
except (ValueError, IndexError):
|
||||
return message
|
||||
else:
|
||||
content_parts.append({
|
||||
"type": "input_image",
|
||||
"source": {"type": "url", "url": image_url}
|
||||
})
|
||||
|
||||
return [{"type": "message", "role": "user", "content": content_parts}]
|
||||
|
||||
def _extract_response(self, response_data: Dict[str, Any]) -> str:
|
||||
"""从 OpenClaw 响应中提取文本内容
|
||||
|
||||
OpenClaw /v1/responses 只返回 output_text 类型的内容。
|
||||
图片信息(如有)由 OpenClaw Skill 以 Markdown 链接形式嵌入文本中返回。
|
||||
"""
|
||||
output = response_data.get("output", [])
|
||||
texts = []
|
||||
for item in output:
|
||||
if item.get("type") == "message":
|
||||
for content in item.get("content", []):
|
||||
if content.get("type") == "output_text" and content.get("text"):
|
||||
texts.append(content["text"])
|
||||
return "\n".join(texts) if texts else str(response_data)
|
||||
|
||||
@staticmethod
|
||||
def _format_result(text: str) -> str:
|
||||
"""格式化结果为 LLM 可读字符串"""
|
||||
return text or "(OpenClaw 返回了空内容)"
|
||||
@@ -11,6 +11,11 @@ class OperationTool(BaseTool):
|
||||
self.base_tool = base_tool
|
||||
self.operation = operation
|
||||
super().__init__(base_tool.tool_id, base_tool.config)
|
||||
|
||||
def set_runtime_context(self, **kwargs):
|
||||
"""转发运行时上下文到 base_tool"""
|
||||
if hasattr(self.base_tool, 'set_runtime_context'):
|
||||
self.base_tool.set_runtime_context(**kwargs)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -32,6 +37,8 @@ class OperationTool(BaseTool):
|
||||
return self._get_datetime_params()
|
||||
elif self.base_tool.name == 'json_tool':
|
||||
return self._get_json_params()
|
||||
elif self.base_tool.name == 'openclaw_tool':
|
||||
return self._get_openclaw_params()
|
||||
else:
|
||||
# 默认返回除operation外的所有参数
|
||||
return [p for p in self.base_tool.parameters if p.name != "operation"]
|
||||
@@ -138,6 +145,29 @@ class OperationTool(BaseTool):
|
||||
default="Asia/Shanghai"
|
||||
)
|
||||
]
|
||||
elif self.operation == "datetime_to_timestamp":
|
||||
return [
|
||||
ToolParameter(
|
||||
name="input_value",
|
||||
type=ParameterType.STRING,
|
||||
description="输入值(时间字符串,如:2026-04-07 10:30:25)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="input_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
required=False,
|
||||
default="%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
ToolParameter(
|
||||
name="from_timezone",
|
||||
type=ParameterType.STRING,
|
||||
description="源时区(如:UTC, Asia/Shanghai)",
|
||||
required=False,
|
||||
default="Asia/Shanghai"
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
@@ -209,6 +239,64 @@ class OperationTool(BaseTool):
|
||||
else:
|
||||
return base_params
|
||||
|
||||
def _get_openclaw_params(self) -> List[ToolParameter]:
|
||||
"""获取 openclaw_tool 特定操作的参数"""
|
||||
if self.operation == "print_task":
|
||||
return [
|
||||
ToolParameter(
|
||||
name="message",
|
||||
type=ParameterType.STRING,
|
||||
description="发送给 OpenClaw 的打印任务描述,将用户的原始消息原封不动地传递给 OpenClaw,禁止改写、补充或润色用户的原文",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="image_url",
|
||||
type=ParameterType.STRING,
|
||||
description="可选,附带的设计图片或参考图,OpenClaw 可据此生成 3D 模型",
|
||||
required=False
|
||||
)
|
||||
]
|
||||
elif self.operation == "device_query":
|
||||
return [
|
||||
ToolParameter(
|
||||
name="message",
|
||||
type=ParameterType.STRING,
|
||||
description="发送给 OpenClaw 的设备查询指令",
|
||||
required=True
|
||||
)
|
||||
]
|
||||
elif self.operation == "image_understand":
|
||||
return [
|
||||
ToolParameter(
|
||||
name="message",
|
||||
type=ParameterType.STRING,
|
||||
description="发送给 OpenClaw 的图片理解任务,应描述需要对图片做什么(如描述内容、提取文字、分析信息)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="image_url",
|
||||
type=ParameterType.STRING,
|
||||
description="要分析的图片 URL 或 base64 data URI",
|
||||
required=False
|
||||
)
|
||||
]
|
||||
else:
|
||||
# general 及其他
|
||||
return [
|
||||
ToolParameter(
|
||||
name="message",
|
||||
type=ParameterType.STRING,
|
||||
description="发送给 OpenClaw Agent 的任务描述,应包含完整的任务需求",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="image_url",
|
||||
type=ParameterType.STRING,
|
||||
description="可选,附带的图片 URL 或 base64 data URI",
|
||||
required=False
|
||||
)
|
||||
]
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行特定操作"""
|
||||
# 添加operation参数
|
||||
|
||||
15
api/app/core/tools/configs/builtin/openclaw_tool.json
Normal file
15
api/app/core/tools/configs/builtin/openclaw_tool.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"name": "openclaw_tool",
|
||||
"description": "调用OpenClaw Agent远程服务",
|
||||
"tool_class": "OpenClawTool",
|
||||
"category": "agent",
|
||||
"requires_config": true,
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {
|
||||
"server_url": "",
|
||||
"api_key": "",
|
||||
"agent_id": "main"
|
||||
},
|
||||
"tags": ["agent", "openclaw", "multimodal", "3d-printing", "builtin"]
|
||||
}
|
||||
@@ -30,5 +30,18 @@
|
||||
"parameters": {
|
||||
"api_key": {"type": "string", "description": "百度搜索API密钥", "sensitive": true, "required": true}
|
||||
}
|
||||
},
|
||||
"openclaw": {
|
||||
"name": "OpenClaw远程Agent",
|
||||
"description": "OpenClaw Agent远程服务",
|
||||
"tool_class": "OpenClawTool",
|
||||
"category": "agent",
|
||||
"requires_config": true,
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {
|
||||
"server_url": {"type": "string", "description": "OpenClaw Gateway 地址", "required": true},
|
||||
"api_key": {"type": "string", "description": "OpenClaw API Key", "sensitive": true, "required": true}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -30,7 +30,7 @@ class CustomTool(BaseTool):
|
||||
self.auth_config = config.get("auth_config", {})
|
||||
self.base_url = config.get("base_url", "")
|
||||
self.timeout = config.get("timeout", 30)
|
||||
|
||||
|
||||
# 解析schema
|
||||
self._parsed_operations = self._parse_openapi_schema()
|
||||
|
||||
|
||||
@@ -131,7 +131,7 @@ class LangchainAdapter:
|
||||
def _tool_supports_operations(tool: BaseTool) -> bool:
|
||||
"""检查工具是否支持多操作"""
|
||||
# 内置工具中支持操作的工具
|
||||
builtin_operation_tools = ['datetime_tool', 'json_tool']
|
||||
builtin_operation_tools = ['datetime_tool', 'json_tool', 'openclaw_tool']
|
||||
|
||||
# 检查内置工具
|
||||
if tool.tool_type.value == "builtin" and tool.name in builtin_operation_tools:
|
||||
|
||||
@@ -40,6 +40,7 @@ class WorkflowParserResult(BaseModel):
|
||||
edges: list[EdgeDefinition] = Field(default_factory=list)
|
||||
nodes: list[NodeDefinition] = Field(default_factory=list)
|
||||
variables: list[VariableDefinition] = Field(default_factory=list)
|
||||
features: dict[str, Any] = Field(default_factory=dict)
|
||||
warnings: list[ExceptionDefinition] = Field(default_factory=list)
|
||||
errors: list[ExceptionDefinition] = Field(default_factory=list)
|
||||
|
||||
@@ -51,6 +52,7 @@ class WorkflowImportResult(BaseModel):
|
||||
edges: list[EdgeDefinition] = Field(default_factory=list)
|
||||
nodes: list[NodeDefinition] = Field(default_factory=list)
|
||||
variables: list[VariableDefinition] = Field(default_factory=list)
|
||||
features: dict[str, Any] = Field(default_factory=dict)
|
||||
warnings: list[ExceptionDefinition] = Field(default_factory=list)
|
||||
errors: list[ExceptionDefinition] = Field(default_factory=list)
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from app.core.workflow.adapters.errors import (
|
||||
ExceptionType
|
||||
)
|
||||
from app.core.workflow.nodes.assigner.config import AssignmentItem
|
||||
from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig
|
||||
from app.core.workflow.nodes.base_config import VariableDefinition as NodeVariableDefinition, BaseNodeConfig
|
||||
from app.core.workflow.nodes.code.config import InputVariable, OutputVariable
|
||||
from app.core.workflow.nodes.configs import (
|
||||
StartNodeConfig,
|
||||
@@ -36,6 +36,7 @@ from app.core.workflow.nodes.configs import (
|
||||
ListOperatorNodeConfig,
|
||||
DocExtractorNodeConfig,
|
||||
)
|
||||
from app.schemas.workflow_schema import VariableDefinition as SchemaVariableDefinition
|
||||
from app.core.workflow.nodes.cycle_graph.config import (
|
||||
ConditionDetail as LoopConditionDetail,
|
||||
ConditionsConfig,
|
||||
@@ -98,6 +99,7 @@ class DifyConverter(BaseConverter):
|
||||
NodeType.CYCLE_START: lambda x: {},
|
||||
NodeType.BREAK: lambda x: {},
|
||||
}
|
||||
self._file_vars_to_conv: list[SchemaVariableDefinition] = []
|
||||
|
||||
def get_node_convert(self, node_type):
|
||||
func = self.CONFIG_CONVERT_MAP.get(node_type, lambda x: {})
|
||||
@@ -286,19 +288,25 @@ class DifyConverter(BaseConverter):
|
||||
)
|
||||
continue
|
||||
|
||||
if var_type in ["file", "array[file]"]:
|
||||
self.errors.append(
|
||||
ExceptionDefinition(
|
||||
type=ExceptionType.VARIABLE,
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
name=var["variable"],
|
||||
detail=f"Unsupported Variable type for start node: {var_type}"
|
||||
)
|
||||
)
|
||||
if var_type in [VariableType.FILE, VariableType.ARRAY_FILE]:
|
||||
# 开始节点不支持文件变量,转为会话变量
|
||||
self._file_vars_to_conv.append(SchemaVariableDefinition(
|
||||
name=var["variable"],
|
||||
type=var_type.value,
|
||||
required=var.get("required", False),
|
||||
default=None,
|
||||
description=var.get("label", ""),
|
||||
))
|
||||
self.warnings.append(ExceptionDefinition(
|
||||
type=ExceptionType.VARIABLE,
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
name=var["variable"],
|
||||
detail=f"File variable '{var['variable']}' is not supported in start node, moved to conversation variables"
|
||||
))
|
||||
continue
|
||||
|
||||
var_def = VariableDefinition(
|
||||
var_def = NodeVariableDefinition(
|
||||
name=var["variable"],
|
||||
type=var_type,
|
||||
required=var["required"],
|
||||
@@ -837,3 +845,76 @@ class DifyConverter(BaseConverter):
|
||||
).model_dump()
|
||||
self.config_validate(node["id"], node["data"]["title"], DocExtractorNodeConfig, result)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def convert_features(features: dict) -> dict:
|
||||
"""Convert Dify features to MemoryBear FeaturesConfigForm format."""
|
||||
if not features:
|
||||
return {}
|
||||
|
||||
result: dict = {}
|
||||
|
||||
# opening_statement
|
||||
opening = features.get("opening_statement", "")
|
||||
suggested = features.get("suggested_questions", [])
|
||||
result["opening_statement"] = {
|
||||
"enabled": bool(opening),
|
||||
"statement": opening or None,
|
||||
"suggested_questions": suggested,
|
||||
}
|
||||
|
||||
# citation (对应 Dify retriever_resource)
|
||||
retriever = features.get("retriever_resource", {})
|
||||
result["citation"] = {
|
||||
"enabled": retriever.get("enabled", False) if isinstance(retriever, dict) else False,
|
||||
}
|
||||
|
||||
# file_upload: Dify allowed_file_types 数组 -> 前端扁平字段
|
||||
file_upload = features.get("file_upload", {})
|
||||
allowed_types = file_upload.get("allowed_file_types", []) if file_upload else []
|
||||
allowed_methods = file_upload.get("allowed_file_upload_methods", ["local_file", "remote_url"])
|
||||
if isinstance(allowed_methods, list):
|
||||
if len(allowed_methods) >= 2:
|
||||
transfer_method = "both"
|
||||
elif allowed_methods:
|
||||
transfer_method = allowed_methods[0]
|
||||
else:
|
||||
transfer_method = "both"
|
||||
else:
|
||||
transfer_method = allowed_methods or "both"
|
||||
|
||||
file_config = file_upload.get("fileUploadConfig", {})
|
||||
result["file_upload"] = {
|
||||
"enabled": file_upload.get("enabled", False) if file_upload else False,
|
||||
"image_enabled": "image" in allowed_types,
|
||||
"image_max_size_mb": file_config.get("image_file_size_limit", 10) if file_config else 10,
|
||||
"image_allowed_extensions": ["png", "jpg", "jpeg"],
|
||||
"audio_enabled": "audio" in allowed_types,
|
||||
"audio_max_size_mb": file_config.get("audio_file_size_limit", 50) if file_config else 50,
|
||||
"audio_allowed_extensions": ["mp3", "wav", "m4a"],
|
||||
"document_enabled": "document" in allowed_types,
|
||||
"document_max_size_mb": file_config.get("file_size_limit", 100) if file_config else 100,
|
||||
"document_allowed_extensions": ["pdf", "docx", "doc", "xlsx", "xls", "txt", "csv", "json", "md"],
|
||||
"video_enabled": "video" in allowed_types,
|
||||
"video_max_size_mb": file_config.get("video_file_size_limit", 100) if file_config else 100,
|
||||
"video_allowed_extensions": ["mp4", "mov"],
|
||||
"max_file_count": file_upload.get("number_limits", 1) if file_upload else 1,
|
||||
"allowed_transfer_methods": transfer_method,
|
||||
}
|
||||
|
||||
# text_to_speech
|
||||
tts = features.get("text_to_speech", {})
|
||||
result["text_to_speech"] = {
|
||||
"enabled": tts.get("enabled", False) if isinstance(tts, dict) else False,
|
||||
"voice": tts.get("voice") if isinstance(tts, dict) else None,
|
||||
"language": tts.get("language") if isinstance(tts, dict) else None,
|
||||
"autoplay": False,
|
||||
}
|
||||
|
||||
# suggested_questions_after_answer
|
||||
sqa = features.get("suggested_questions_after_answer", {})
|
||||
result["suggested_questions_after_answer"] = {
|
||||
"enabled": sqa.get("enabled", False) if isinstance(sqa, dict) else False,
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
@@ -119,9 +119,12 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
if variable:
|
||||
self.conv_variables.append(con_var)
|
||||
|
||||
# for variables in config.get("workflow").get("environment_variables"):
|
||||
# variable = self._convert_variable(variables)
|
||||
# conv_variables.append(variable)
|
||||
# 开始节点的文件变量合并到会话变量
|
||||
self.conv_variables.extend(self._file_vars_to_conv)
|
||||
|
||||
features = self.convert_features(
|
||||
self.config.get("workflow", {}).get("features", {})
|
||||
)
|
||||
|
||||
trigger = self._convert_trigger({})
|
||||
execution_config = self._convert_execution({})
|
||||
@@ -135,6 +138,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
edges=self.edges,
|
||||
nodes=self.nodes,
|
||||
variables=self.conv_variables,
|
||||
features=features,
|
||||
warnings=self.warnings,
|
||||
errors=self.errors
|
||||
)
|
||||
|
||||
@@ -31,9 +31,9 @@ logger = logging.getLogger(__name__)
|
||||
# Example:
|
||||
# "Hello {{user.name}}!" ->
|
||||
# ["Hello ", "{{user.name}}", "!"]
|
||||
_OUTPUT_PATTERN = re.compile(r'\{\{.*?}}|[^{}]+')
|
||||
_OUTPUT_PATTERN = re.compile(r'\{\{.*?}}|[^{]+|{')
|
||||
# Strict variable format: {{ node_id.field_name }}
|
||||
_VARIABLE_PATTERN = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*}}')
|
||||
_VARIABLE_PATTERN = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)?\s*}}')
|
||||
|
||||
|
||||
class GraphBuilder:
|
||||
|
||||
@@ -14,7 +14,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
||||
logger = get_logger(__name__)
|
||||
|
||||
SCOPE_PATTERN = re.compile(
|
||||
r"\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*}}"
|
||||
r"\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)?\s*}}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -34,19 +34,22 @@ class LazyVariableDict:
|
||||
return self._cache[key]
|
||||
var_struct = self._source.get(key)
|
||||
if var_struct is None:
|
||||
raise KeyError(key)
|
||||
value = var_struct.instance.to_literal() if self._literal else var_struct.instance.get_value()
|
||||
return None
|
||||
raw = var_struct.instance.get_value()
|
||||
# literal 模式下 dict/list 保留结构,让 Jinja2 能继续访问子字段(如 .type)
|
||||
value = raw if (not self._literal or isinstance(raw, (dict, list))) else var_struct.instance.to_literal()
|
||||
self._cache[key] = value
|
||||
return value
|
||||
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self._resolve(key)
|
||||
except KeyError:
|
||||
return default
|
||||
value = self._resolve(key)
|
||||
return default if value is None else value
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._resolve(key)
|
||||
value = self._resolve(key)
|
||||
if value is None:
|
||||
raise KeyError(key)
|
||||
return value
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key.startswith('_'):
|
||||
@@ -164,7 +167,7 @@ class VariablePool:
|
||||
def transform_selector(selector):
|
||||
variable_literal = VARIABLE_PATTERN.sub(r"\1", selector).strip()
|
||||
selector = VariableSelector.from_string(variable_literal).path
|
||||
if len(selector) != 2:
|
||||
if len(selector) not in (2, 3):
|
||||
raise ValueError(f"Selector not valid - {selector}")
|
||||
return selector
|
||||
|
||||
@@ -196,6 +199,19 @@ class VariablePool:
|
||||
return None
|
||||
return var_instance
|
||||
|
||||
@staticmethod
|
||||
def _extract_field(struct: "VariableStruct", field: str | None) -> Any:
|
||||
"""If field is given, drill into a dict/object/array[file] variable's value."""
|
||||
if field is None:
|
||||
return struct.instance.get_value()
|
||||
value = struct.instance.get_value()
|
||||
# array[file]: extract the field from every element, return a list
|
||||
if isinstance(value, list):
|
||||
return [item.get(field) if isinstance(item, dict) else getattr(item, field, None) for item in value]
|
||||
if not isinstance(value, dict):
|
||||
raise KeyError(f"Variable is not an object or array, cannot access field '{field}'")
|
||||
return value.get(field)
|
||||
|
||||
def get_instance(
|
||||
self,
|
||||
selector: str,
|
||||
@@ -250,12 +266,14 @@ class VariablePool:
|
||||
Raises:
|
||||
KeyError: If strict is True and the variable does not exist.
|
||||
"""
|
||||
path = self.transform_selector(selector)
|
||||
variable_struct = self._get_variable_struct(selector)
|
||||
if variable_struct is None:
|
||||
if strict:
|
||||
raise KeyError(f"{selector} not exist")
|
||||
return default
|
||||
|
||||
if len(path) == 3:
|
||||
return self._extract_field(variable_struct, path[2])
|
||||
return variable_struct.instance.get_value()
|
||||
|
||||
def get_literal(
|
||||
@@ -282,12 +300,15 @@ class VariablePool:
|
||||
Raises:
|
||||
KeyError: If strict is True and the variable does not exist.
|
||||
"""
|
||||
path = self.transform_selector(selector)
|
||||
variable_struct = self._get_variable_struct(selector)
|
||||
if variable_struct is None:
|
||||
if strict:
|
||||
raise KeyError(f"{selector} not exist")
|
||||
return default
|
||||
|
||||
if len(path) == 3:
|
||||
value = self._extract_field(variable_struct, path[2])
|
||||
return str(value) if value is not None else ""
|
||||
return variable_struct.instance.to_literal()
|
||||
|
||||
async def set(
|
||||
@@ -345,7 +366,14 @@ class VariablePool:
|
||||
Returns:
|
||||
变量是否存在
|
||||
"""
|
||||
return self._get_variable_struct(selector) is not None
|
||||
path = self.transform_selector(selector)
|
||||
struct = self._get_variable_struct(selector)
|
||||
if struct is None:
|
||||
return False
|
||||
if len(path) == 3:
|
||||
value = struct.instance.get_value()
|
||||
return isinstance(value, dict) and path[2] in value
|
||||
return True
|
||||
|
||||
def lazy_namespace(self, namespace: str, literal: bool = False) -> LazyVariableDict:
|
||||
return LazyVariableDict(self.variables.get(namespace, {}), literal)
|
||||
|
||||
@@ -28,86 +28,135 @@ class IterationRuntime:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_id: str,
|
||||
stream: bool,
|
||||
graph: CompiledStateGraph,
|
||||
node_id: str,
|
||||
config: dict[str, Any],
|
||||
state: WorkflowState,
|
||||
variable_pool: VariablePool,
|
||||
child_variable_pool: VariablePool,
|
||||
cycle_nodes: list,
|
||||
cycle_edges: list,
|
||||
):
|
||||
"""
|
||||
Initialize the iteration runtime.
|
||||
|
||||
Args:
|
||||
graph: Compiled workflow graph capable of async invocation.
|
||||
node_id: Unique identifier of the loop node.
|
||||
config: Dictionary containing iteration node configuration.
|
||||
state: Current workflow state at the point of iteration.
|
||||
stream: Whether to run in streaming mode. When True, each iteration
|
||||
uses graph.astream and emits cycle_item events in real time.
|
||||
When False, graph.ainvoke is used instead.
|
||||
node_id: The unique identifier of the iteration node in the workflow.
|
||||
Also used as the variable namespace for item/index inside
|
||||
the subgraph (e.g. {{ node_id.item }}).
|
||||
config: Raw configuration dict for the iteration node, parsed into
|
||||
IterationNodeConfig. Controls input/output variable selectors,
|
||||
parallel execution settings, and output flattening.
|
||||
state: The parent workflow state at the point the iteration node is
|
||||
entered. Each task receives a copy of this state as its
|
||||
starting point.
|
||||
variable_pool: The parent VariablePool containing all variables available
|
||||
at the time the iteration node executes, including sys.*,
|
||||
conv.*, and outputs from upstream nodes. Used as the source
|
||||
for deep-copying into each task's independent child pool.
|
||||
cycle_nodes: List of node config dicts belonging to this iteration's
|
||||
subgraph (i.e. nodes whose cycle field equals node_id).
|
||||
Passed to GraphBuilder when constructing each task's subgraph.
|
||||
cycle_edges: List of edge config dicts connecting nodes within the subgraph.
|
||||
Passed to GraphBuilder alongside cycle_nodes.
|
||||
"""
|
||||
self.start_id = start_id
|
||||
self.stream = stream
|
||||
self.graph = graph
|
||||
self.state = state
|
||||
self.node_id = node_id
|
||||
self.typed_config = IterationNodeConfig(**config)
|
||||
self.looping = True
|
||||
self.variable_pool = variable_pool
|
||||
self.child_variable_pool = child_variable_pool
|
||||
self.cycle_nodes = cycle_nodes
|
||||
self.cycle_edges = cycle_edges
|
||||
self.event_write = get_stream_writer()
|
||||
self.checkpoint = RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": uuid.uuid4()
|
||||
}
|
||||
)
|
||||
|
||||
self.output_value = None
|
||||
self.result: list = []
|
||||
|
||||
async def _init_iteration_state(self, item, idx):
|
||||
def _build_child_graph(self) -> tuple[CompiledStateGraph, VariablePool, str]:
|
||||
"""
|
||||
Initialize a per-iteration copy of the workflow state.
|
||||
Build an independent compiled subgraph for a single iteration task.
|
||||
|
||||
Args:
|
||||
item: Current element from the input array for this iteration.
|
||||
idx: Index of the element in the input array.
|
||||
Each call creates a brand-new VariablePool by deep-copying the parent pool,
|
||||
then passes it to GraphBuilder. GraphBuilder binds this pool to every node's
|
||||
execution closure at build time, so the pool and the subgraph always reference
|
||||
the same object. This is the key design invariant: item/index written into the
|
||||
pool after build will be visible to all nodes inside the subgraph.
|
||||
|
||||
Returns:
|
||||
A copy of the workflow state with iteration-specific variables set.
|
||||
graph: The compiled LangGraph subgraph ready for invocation.
|
||||
child_pool: The VariablePool bound to this subgraph's node closures.
|
||||
Callers must write item/index into this pool before invoking
|
||||
the graph, and read output from it after invocation.
|
||||
start_node_id: The ID of the CYCLE_START node inside the subgraph,
|
||||
used to set the initial activation signal in workflow state.
|
||||
"""
|
||||
loopstate = WorkflowState(
|
||||
**self.state
|
||||
from app.core.workflow.engine.graph_builder import GraphBuilder
|
||||
child_pool = VariablePool()
|
||||
child_pool.copy(self.variable_pool)
|
||||
builder = GraphBuilder(
|
||||
{"nodes": self.cycle_nodes, "edges": self.cycle_edges},
|
||||
stream=self.stream,
|
||||
variable_pool=child_pool,
|
||||
cycle=self.node_id,
|
||||
)
|
||||
self.child_variable_pool.copy(self.variable_pool)
|
||||
await self.child_variable_pool.new(self.node_id, "item", item, VariableType.type_map(item), mut=True)
|
||||
await self.child_variable_pool.new(self.node_id, "index", item, VariableType.type_map(item), mut=True)
|
||||
loopstate["node_outputs"][self.node_id] = {
|
||||
"item": item,
|
||||
"index": idx,
|
||||
}
|
||||
graph = builder.build()
|
||||
return graph, builder.variable_pool, builder.start_node_id
|
||||
|
||||
async def _init_iteration_state(self, item, idx, child_pool: VariablePool, start_id: str):
|
||||
"""
|
||||
Initialize the workflow state for a single iteration.
|
||||
|
||||
Writes the current item and its index into child_pool under the iteration
|
||||
node's namespace (e.g. iteration_xxx.item, iteration_xxx.index), making them
|
||||
accessible to downstream nodes inside the subgraph via variable selectors.
|
||||
|
||||
Also prepares a copy of the parent workflow state with:
|
||||
- node_outputs[node_id] set to {item, index} so the state snapshot is consistent
|
||||
with the pool values.
|
||||
- looping flag set to 1 (active) to signal the subgraph is inside a cycle.
|
||||
- activate[start_id] set to True to trigger the CYCLE_START node.
|
||||
|
||||
Args:
|
||||
item: The current element from the input array.
|
||||
idx: The zero-based index of this element in the input array.
|
||||
child_pool: The VariablePool bound to this iteration's subgraph.
|
||||
Must be the same object returned by _build_child_graph.
|
||||
start_id: The ID of the CYCLE_START node inside the subgraph.
|
||||
|
||||
Returns:
|
||||
A WorkflowState instance ready to be passed to graph.ainvoke or graph.astream.
|
||||
"""
|
||||
loopstate = WorkflowState(**self.state)
|
||||
await child_pool.new(self.node_id, "item", item, VariableType.type_map(item), mut=True)
|
||||
await child_pool.new(self.node_id, "index", idx, VariableType.type_map(idx), mut=True)
|
||||
loopstate["node_outputs"][self.node_id] = {"item": item, "index": idx}
|
||||
loopstate["looping"] = 1
|
||||
loopstate["activate"][self.start_id] = True
|
||||
loopstate["activate"][start_id] = True
|
||||
return loopstate
|
||||
|
||||
def merge_conv_vars(self):
|
||||
self.variable_pool.variables["conv"].update(
|
||||
self.child_variable_pool.variables["conv"]
|
||||
)
|
||||
def _merge_conv_vars(self, child_pool: VariablePool):
|
||||
self.variable_pool.variables["conv"].update(child_pool.variables["conv"])
|
||||
|
||||
async def run_task(self, item, idx):
|
||||
"""
|
||||
Execute a single iteration asynchronously.
|
||||
Each task builds its own subgraph so the variable pool closure is independent.
|
||||
|
||||
Args:
|
||||
item: The input element for this iteration.
|
||||
idx: The index of this iteration.
|
||||
Returns:
|
||||
Tuple of (idx, output, result, child_pool, stopped)
|
||||
"""
|
||||
graph, child_pool, start_id = self._build_child_graph()
|
||||
checkpoint = RunnableConfig(configurable={"thread_id": uuid.uuid4()})
|
||||
init_state = await self._init_iteration_state(item, idx, child_pool, start_id)
|
||||
|
||||
if self.stream:
|
||||
async for event in self.graph.astream(
|
||||
await self._init_iteration_state(item, idx),
|
||||
async for event in graph.astream(
|
||||
init_state,
|
||||
stream_mode=["debug"],
|
||||
config=self.checkpoint
|
||||
config=checkpoint
|
||||
):
|
||||
if isinstance(event, tuple) and len(event) == 2:
|
||||
mode, data = event
|
||||
@@ -117,7 +166,6 @@ class IterationRuntime:
|
||||
event_type = data.get("type")
|
||||
payload = data.get("payload", {})
|
||||
node_name = payload.get("name")
|
||||
|
||||
if node_name and node_name.startswith("nop"):
|
||||
continue
|
||||
if event_type == "task_result":
|
||||
@@ -140,17 +188,13 @@ class IterationRuntime:
|
||||
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
|
||||
}
|
||||
})
|
||||
result = self.graph.get_state(config=self.checkpoint).values
|
||||
result = graph.get_state(config=checkpoint).values
|
||||
else:
|
||||
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
|
||||
output = self.child_variable_pool.get_value(self.output_value)
|
||||
if isinstance(output, list) and self.typed_config.flatten:
|
||||
self.result.extend(output)
|
||||
else:
|
||||
self.result.append(output)
|
||||
if result["looping"] == 2:
|
||||
self.looping = False
|
||||
return result
|
||||
result = await graph.ainvoke(init_state)
|
||||
|
||||
output = child_pool.get_value(self.output_value)
|
||||
stopped = result["looping"] == 2
|
||||
return idx, output, result, child_pool, stopped
|
||||
|
||||
def _create_iteration_tasks(self, array_obj, idx):
|
||||
"""
|
||||
@@ -196,16 +240,32 @@ class IterationRuntime:
|
||||
tasks = self._create_iteration_tasks(array_obj, idx)
|
||||
logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}")
|
||||
idx += self.typed_config.parallel_count
|
||||
child_state.extend(await asyncio.gather(*tasks))
|
||||
self.merge_conv_vars()
|
||||
batch = await asyncio.gather(*tasks)
|
||||
# Sort by idx to preserve order, then collect results
|
||||
batch_sorted = sorted(batch, key=lambda x: x[0])
|
||||
for _, output, result, child_pool, stopped in batch_sorted:
|
||||
if isinstance(output, list) and self.typed_config.flatten:
|
||||
self.result.extend(output)
|
||||
else:
|
||||
self.result.append(output)
|
||||
child_state.append(result)
|
||||
self._merge_conv_vars(child_pool)
|
||||
if stopped:
|
||||
self.looping = False
|
||||
else:
|
||||
# Execute iterations sequentially
|
||||
while idx < len(array_obj) and self.looping:
|
||||
logger.info(f"Iteration node {self.node_id}: running")
|
||||
item = array_obj[idx]
|
||||
result = await self.run_task(item, idx)
|
||||
self.merge_conv_vars()
|
||||
_, output, result, child_pool, stopped = await self.run_task(item, idx)
|
||||
if isinstance(output, list) and self.typed_config.flatten:
|
||||
self.result.extend(output)
|
||||
else:
|
||||
self.result.append(output)
|
||||
self._merge_conv_vars(child_pool)
|
||||
child_state.append(result)
|
||||
if stopped:
|
||||
self.looping = False
|
||||
idx += 1
|
||||
logger.info(f"Iteration node {self.node_id}: execution completed")
|
||||
return {
|
||||
|
||||
@@ -55,9 +55,9 @@ class CycleGraphNode(BaseNode):
|
||||
if config.output_type in [
|
||||
VariableType.ARRAY_FILE,
|
||||
VariableType.ARRAY_STRING,
|
||||
VariableType.NUMBER,
|
||||
VariableType.ARRAY_NUMBER,
|
||||
VariableType.ARRAY_OBJECT,
|
||||
VariableType.BOOLEAN
|
||||
VariableType.ARRAY_BOOLEAN
|
||||
]:
|
||||
if config.flatten:
|
||||
outputs['output'] = config.output_type
|
||||
@@ -123,7 +123,7 @@ class CycleGraphNode(BaseNode):
|
||||
|
||||
return cycle_nodes, cycle_edges
|
||||
|
||||
def build_graph(self):
|
||||
def build_graph(self, variable_pool: VariablePool):
|
||||
"""
|
||||
Build and compile the internal subgraph for this cycle node.
|
||||
|
||||
@@ -135,6 +135,7 @@ class CycleGraphNode(BaseNode):
|
||||
from app.core.workflow.engine.graph_builder import GraphBuilder
|
||||
|
||||
self.child_variable_pool = VariablePool()
|
||||
self.child_variable_pool.copy(variable_pool)
|
||||
builder = GraphBuilder(
|
||||
{
|
||||
"nodes": self.cycle_nodes,
|
||||
@@ -165,8 +166,8 @@ class CycleGraphNode(BaseNode):
|
||||
Raises:
|
||||
RuntimeError: If the node type is unsupported.
|
||||
"""
|
||||
self.build_graph()
|
||||
if self.node_type == NodeType.LOOP:
|
||||
self.build_graph(variable_pool)
|
||||
return await LoopRuntime(
|
||||
start_id=self.start_node_id,
|
||||
stream=False,
|
||||
@@ -179,20 +180,19 @@ class CycleGraphNode(BaseNode):
|
||||
).run()
|
||||
if self.node_type == NodeType.ITERATION:
|
||||
return await IterationRuntime(
|
||||
start_id=self.start_node_id,
|
||||
stream=False,
|
||||
graph=self.graph,
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
state=state,
|
||||
variable_pool=variable_pool,
|
||||
child_variable_pool=self.child_variable_pool
|
||||
cycle_nodes=self.cycle_nodes,
|
||||
cycle_edges=self.cycle_edges,
|
||||
).run()
|
||||
raise RuntimeError("Unknown cycle node type")
|
||||
|
||||
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||
self.build_graph()
|
||||
if self.node_type == NodeType.LOOP:
|
||||
self.build_graph(variable_pool)
|
||||
yield {
|
||||
"__final__": True,
|
||||
"result": await LoopRuntime(
|
||||
@@ -211,14 +211,13 @@ class CycleGraphNode(BaseNode):
|
||||
yield {
|
||||
"__final__": True,
|
||||
"result": await IterationRuntime(
|
||||
start_id=self.start_node_id,
|
||||
stream=True,
|
||||
graph=self.graph,
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
state=state,
|
||||
variable_pool=variable_pool,
|
||||
child_variable_pool=self.child_variable_pool
|
||||
cycle_nodes=self.cycle_nodes,
|
||||
cycle_edges=self.cycle_edges,
|
||||
).run()
|
||||
}
|
||||
return
|
||||
|
||||
@@ -72,8 +72,9 @@ class HttpContentTypeConfig(BaseModel):
|
||||
@classmethod
|
||||
def validate_data(cls, v, info):
|
||||
content_type = info.data.get("content_type")
|
||||
if content_type == HttpContentType.FROM_DATA and not isinstance(v, HttpFormData):
|
||||
raise ValueError("When content_type is 'form-data', data must be of type HttpFormData")
|
||||
if content_type == HttpContentType.FROM_DATA and (
|
||||
not isinstance(v, list) or not all(isinstance(item, HttpFormData) for item in v)):
|
||||
raise ValueError("When content_type is 'form-data', data must be a list of HttpFormData")
|
||||
elif content_type in [HttpContentType.JSON] and not isinstance(v, str):
|
||||
raise ValueError("When content_type is JSON, data must be of type str")
|
||||
elif content_type in [HttpContentType.WWW_FORM] and not isinstance(v, dict):
|
||||
|
||||
@@ -260,17 +260,22 @@ class HttpRequestNode(BaseNode):
|
||||
))
|
||||
case HttpContentType.FROM_DATA:
|
||||
data = {}
|
||||
content["files"] = {}
|
||||
files = []
|
||||
for item in self.typed_config.body.data:
|
||||
key = self._render_template(item.key, variable_pool)
|
||||
if item.type == "text":
|
||||
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value,
|
||||
variable_pool)
|
||||
data[key] = self._render_template(item.value, variable_pool)
|
||||
elif item.type == "file":
|
||||
content["files"][self._render_template(item.key, variable_pool)] = (
|
||||
uuid.uuid4().hex,
|
||||
await variable_pool.get_instance(item.value).get_content()
|
||||
)
|
||||
file_instance = variable_pool.get_instance(item.value)
|
||||
if isinstance(file_instance, ArrayVariable):
|
||||
for v in file_instance.value:
|
||||
if isinstance(v, FileVariable):
|
||||
files.append((key, (uuid.uuid4().hex, await v.get_content())))
|
||||
elif isinstance(file_instance, FileVariable):
|
||||
files.append((key, (uuid.uuid4().hex, await file_instance.get_content())))
|
||||
content["data"] = data
|
||||
if files:
|
||||
content["files"] = files
|
||||
case HttpContentType.BINARY:
|
||||
content["files"] = []
|
||||
file_instence = variable_pool.get_instance(self.typed_config.body.data)
|
||||
|
||||
@@ -6,6 +6,30 @@ from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
|
||||
|
||||
|
||||
class SubVariableConditionItem(BaseModel):
|
||||
"""A single condition on a file object's field, used inside sub_variable_condition."""
|
||||
key: str = Field(..., description="Field name of the file object, e.g. type, size, name")
|
||||
operator: ComparisonOperator = Field(..., description="Comparison operator")
|
||||
value: Any = Field(default=None, description="Value to compare with, or variable selector when input_type=variable")
|
||||
input_type: ValueInputType = Field(default=ValueInputType.CONSTANT, description="constant or variable")
|
||||
|
||||
@field_validator("input_type", mode="before")
|
||||
@classmethod
|
||||
def lower_input_type(cls, v):
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return ValueInputType(v.lower())
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid input_type: {v}")
|
||||
return v
|
||||
|
||||
|
||||
class SubVariableCondition(BaseModel):
|
||||
"""Sub-conditions applied to each file element in an array[file] variable."""
|
||||
logical_operator: LogicOperator = Field(default=LogicOperator.AND)
|
||||
conditions: list[SubVariableConditionItem] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ConditionDetail(BaseModel):
|
||||
operator: ComparisonOperator = Field(
|
||||
...,
|
||||
@@ -14,12 +38,12 @@ class ConditionDetail(BaseModel):
|
||||
|
||||
left: str = Field(
|
||||
...,
|
||||
description="Value to compare against"
|
||||
description="Variable selector, e.g. {{sys.files}}"
|
||||
)
|
||||
|
||||
right: Any = Field(
|
||||
default=None,
|
||||
description="Value to compare with"
|
||||
description="Value to compare with (unused when sub_variable_condition is set)"
|
||||
)
|
||||
|
||||
input_type: ValueInputType = Field(
|
||||
@@ -27,6 +51,11 @@ class ConditionDetail(BaseModel):
|
||||
description="Value input type for comparison"
|
||||
)
|
||||
|
||||
sub_variable_condition: SubVariableCondition | None = Field(
|
||||
default=None,
|
||||
description="Sub-conditions for array[file] fields. When set, operator must be contains/not_contains."
|
||||
)
|
||||
|
||||
@field_validator("input_type", mode="before")
|
||||
@classmethod
|
||||
def lower_input_type(cls, v):
|
||||
@@ -39,16 +68,19 @@ class ConditionDetail(BaseModel):
|
||||
|
||||
|
||||
class ConditionBranchConfig(BaseModel):
|
||||
"""Configuration for a conditional branch"""
|
||||
"""Configuration for a conditional branch.
|
||||
|
||||
logical_operator controls how all expressions are combined (AND/OR).
|
||||
"""
|
||||
|
||||
logical_operator: LogicOperator = Field(
|
||||
default=LogicOperator.AND,
|
||||
description="Logical operator used to combine multiple condition expressions"
|
||||
description="Logical operator used to combine all conditions"
|
||||
)
|
||||
|
||||
expressions: list[ConditionDetail] = Field(
|
||||
...,
|
||||
description="List of condition expressions within this branch"
|
||||
default_factory=list,
|
||||
description="List of conditions within this branch"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
|
||||
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance, ArrayFileContainsOperator
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -90,11 +90,9 @@ class IfElseNode(BaseNode):
|
||||
list[str]: A list of Python boolean expression strings,
|
||||
ordered by branch priority.
|
||||
"""
|
||||
branch_index = 0
|
||||
conditions = []
|
||||
|
||||
for case_branch in self.typed_config.cases:
|
||||
branch_index += 1
|
||||
branch_result = []
|
||||
for expression in case_branch.expressions:
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
@@ -103,13 +101,18 @@ class IfElseNode(BaseNode):
|
||||
left_value = self.get_variable(left_string, variable_pool)
|
||||
except KeyError:
|
||||
left_value = None
|
||||
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
|
||||
variable_pool,
|
||||
expression.left,
|
||||
expression.right,
|
||||
expression.input_type
|
||||
)
|
||||
|
||||
if expression.sub_variable_condition is not None and isinstance(left_value, list):
|
||||
evaluator = ArrayFileContainsOperator(left_value, expression.sub_variable_condition, variable_pool)
|
||||
else:
|
||||
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
|
||||
variable_pool,
|
||||
expression.left,
|
||||
expression.right,
|
||||
expression.input_type
|
||||
)
|
||||
branch_result.append(self._evaluate(expression.operator, evaluator))
|
||||
|
||||
if case_branch.logical_operator == LogicOperator.AND:
|
||||
conditions.append(all(branch_result))
|
||||
else:
|
||||
|
||||
@@ -8,6 +8,8 @@ from langchain_core.documents import Document
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearRerank, RedBearModelConfig
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
@@ -39,8 +41,9 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
if isinstance(business_result, dict) and "chunks" in business_result:
|
||||
return business_result["chunks"]
|
||||
return business_result
|
||||
|
||||
def _extract_citations(self, business_result: Any) -> list:
|
||||
|
||||
@staticmethod
|
||||
def _extract_citations(business_result: Any) -> list:
|
||||
if isinstance(business_result, dict):
|
||||
return business_result.get("citations", [])
|
||||
return []
|
||||
@@ -230,23 +233,23 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
}
|
||||
)
|
||||
)
|
||||
case RetrieveType.HYBRID:
|
||||
case retrieve_type if retrieve_type in (RetrieveType.HYBRID, RetrieveType.Graph):
|
||||
rs1_task = asyncio.to_thread(
|
||||
vector_service.search_by_vector, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.vector_similarity_weight
|
||||
}
|
||||
)
|
||||
vector_service.search_by_vector, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.vector_similarity_weight
|
||||
}
|
||||
)
|
||||
rs2_task = asyncio.to_thread(
|
||||
vector_service.search_by_full_text, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.similarity_threshold
|
||||
}
|
||||
)
|
||||
vector_service.search_by_full_text, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.similarity_threshold
|
||||
}
|
||||
)
|
||||
rs1, rs2 = await asyncio.gather(rs1_task, rs2_task)
|
||||
|
||||
# Deduplicate hybrid retrieval results
|
||||
@@ -266,6 +269,33 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
key=lambda d: d.metadata.get("score", 0),
|
||||
reverse=True
|
||||
)[:kb_config.top_k])
|
||||
if kb_config.retrieve_type == RetrieveType.Graph:
|
||||
from app.core.rag.common.settings import kg_retriever
|
||||
llm_key = self.model_balance(db_knowledge.llm)
|
||||
emb_key = self.model_balance(db_knowledge.embedding)
|
||||
chat_model = Base(
|
||||
key=llm_key.api_key,
|
||||
model_name=llm_key.model_name,
|
||||
base_url=llm_key.api_base
|
||||
)
|
||||
embedding_model = OpenAIEmbed(
|
||||
key=emb_key.api_key,
|
||||
model_name=emb_key.model_name,
|
||||
base_url=emb_key.api_base
|
||||
)
|
||||
doc = await asyncio.to_thread(
|
||||
kg_retriever.retrieval,
|
||||
question=query,
|
||||
workspace_ids=[str(db_knowledge.workspace_id)],
|
||||
kb_ids=[str(kb_config.kb_id)],
|
||||
emb_mdl=embedding_model,
|
||||
llm=chat_model
|
||||
)
|
||||
if doc:
|
||||
rs.insert(0, DocumentChunk(
|
||||
page_content=doc.get("page_content", ""),
|
||||
metadata=doc.get("metadata", {})
|
||||
))
|
||||
case _:
|
||||
raise RuntimeError("Unknown retrieval type")
|
||||
return rs
|
||||
|
||||
@@ -116,6 +116,11 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
description="Top-p 采样参数"
|
||||
)
|
||||
|
||||
json_output: bool = Field(
|
||||
default=False,
|
||||
description="是否以 JSON 格式输出"
|
||||
)
|
||||
|
||||
frequency_penalty: float | None = Field(
|
||||
default=None,
|
||||
ge=-2.0,
|
||||
|
||||
@@ -5,7 +5,6 @@ LLM 节点实现
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
@@ -22,6 +21,7 @@ from app.db import get_db_context
|
||||
from app.models import ModelType
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.model_service import ModelConfigService
|
||||
from app.models.models_model import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -80,7 +80,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
def _render_context(self, message: str, variable_pool: VariablePool):
|
||||
context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
|
||||
return re.sub(r"{{context}}", context, message)
|
||||
return message.replace("{{context}}", context)
|
||||
|
||||
async def _prepare_llm(
|
||||
self,
|
||||
@@ -126,7 +126,11 @@ class LLMNode(BaseNode):
|
||||
|
||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
|
||||
extra_params = {"streaming": stream} if stream else {}
|
||||
extra_params: dict[str, Any] = {"streaming": stream} if stream else {}
|
||||
if self.typed_config.temperature is not None:
|
||||
extra_params["temperature"] = self.typed_config.temperature
|
||||
if self.typed_config.max_tokens is not None:
|
||||
extra_params["max_tokens"] = self.typed_config.max_tokens
|
||||
|
||||
llm = RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
@@ -135,7 +139,9 @@ class LLMNode(BaseNode):
|
||||
api_key=model_info.api_key,
|
||||
base_url=model_info.api_base,
|
||||
extra_params=extra_params,
|
||||
is_omni=model_info.is_omni
|
||||
is_omni=model_info.is_omni,
|
||||
capability=model_info.capability,
|
||||
json_output=self.typed_config.json_output,
|
||||
),
|
||||
type=model_info.model_type
|
||||
)
|
||||
@@ -218,6 +224,19 @@ class LLMNode(BaseNode):
|
||||
rendered = self._render_template(prompt_template, variable_pool)
|
||||
self.messages = [{"role": "user", "content": rendered}]
|
||||
|
||||
# ChatTongyi 要求 messages 含 'json' 字样才能使用 response_format,在 system prompt 中注入
|
||||
# VOLCANO 模型不支持 response_format,同样需要 system prompt 注入
|
||||
need_json_prompt = self.typed_config.json_output and (
|
||||
(model_info.provider.lower() == ModelProvider.DASHSCOPE and not model_info.is_omni)
|
||||
or model_info.provider.lower() == ModelProvider.VOLCANO
|
||||
)
|
||||
if need_json_prompt:
|
||||
system_msg = next((m for m in self.messages if m["role"] == "system"), None)
|
||||
if system_msg:
|
||||
system_msg["content"] += "\n请以JSON格式输出。"
|
||||
else:
|
||||
self.messages.insert(0, {"role": "system", "content": "请以JSON格式输出。"})
|
||||
|
||||
return llm
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage:
|
||||
|
||||
@@ -395,11 +395,73 @@ class NoneObjectComparisonOperator:
|
||||
return lambda *args, **kwargs: False
|
||||
|
||||
|
||||
class ArrayFileContainsOperator:
|
||||
"""Handles contains/not_contains on array[file] with sub_variable_condition."""
|
||||
|
||||
def __init__(self, left_value: list[dict], sub_variable_condition: Any, pool: VariablePool | None = None):
|
||||
self.left_value = left_value
|
||||
self.sub_variable_condition = sub_variable_condition
|
||||
self.pool = pool
|
||||
|
||||
def _resolve_value(self, cond: Any) -> Any:
|
||||
if cond.input_type == ValueInputType.VARIABLE and self.pool is not None:
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
selector = re.sub(pattern, r"\1", str(cond.value)).strip()
|
||||
return self.pool.get_value(selector, default=None, strict=False)
|
||||
return cond.value
|
||||
|
||||
def _match_item(self, file_item: dict) -> bool:
|
||||
results = []
|
||||
for cond in self.sub_variable_condition.conditions:
|
||||
field_val = file_item.get(cond.key)
|
||||
expected = self._resolve_value(cond)
|
||||
result = self._eval_sub(field_val, cond.operator.value, expected)
|
||||
results.append(result)
|
||||
if self.sub_variable_condition.logical_operator.value == "and":
|
||||
return all(results)
|
||||
return any(results)
|
||||
|
||||
@staticmethod
|
||||
def _eval_sub(field_val: Any, op: str, expected: Any) -> bool:
|
||||
if field_val is None:
|
||||
return op == "empty"
|
||||
match op:
|
||||
case "eq": return str(field_val) == str(expected)
|
||||
case "ne": return str(field_val) != str(expected)
|
||||
case "contains": return isinstance(field_val, str) and str(expected) in field_val
|
||||
case "not_contains": return isinstance(field_val, str) and str(expected) not in field_val
|
||||
case "in": return field_val in (expected if isinstance(expected, list) else [expected])
|
||||
case "not_in": return field_val not in (expected if isinstance(expected, list) else [expected])
|
||||
case "gt": return isinstance(field_val, (int, float)) and field_val > float(expected)
|
||||
case "ge": return isinstance(field_val, (int, float)) and field_val >= float(expected)
|
||||
case "lt": return isinstance(field_val, (int, float)) and field_val < float(expected)
|
||||
case "le": return isinstance(field_val, (int, float)) and field_val <= float(expected)
|
||||
case "empty": return field_val in (None, "", 0)
|
||||
case "not_empty": return field_val not in (None, "", 0)
|
||||
case _: return False
|
||||
|
||||
def contains(self) -> bool:
|
||||
return any(self._match_item(f) for f in self.left_value if isinstance(f, dict))
|
||||
|
||||
def not_contains(self) -> bool:
|
||||
return not self.contains()
|
||||
|
||||
def empty(self) -> bool:
|
||||
return not self.left_value
|
||||
|
||||
def not_empty(self) -> bool:
|
||||
return bool(self.left_value)
|
||||
|
||||
def __getattr__(self, name):
|
||||
return lambda *args, **kwargs: False
|
||||
|
||||
|
||||
CompareOperatorInstance = Union[
|
||||
StringComparisonOperator,
|
||||
NumberComparisonOperator,
|
||||
BooleanComparisonOperator,
|
||||
ArrayComparisonOperator,
|
||||
ArrayFileContainsOperator,
|
||||
ObjectComparisonOperator
|
||||
]
|
||||
CompareOperatorType = Type[CompareOperatorInstance]
|
||||
|
||||
@@ -11,10 +11,12 @@ from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.db import get_db_read
|
||||
from app.services.tool_service import ToolService
|
||||
from app.models.tool_model import ToolType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}")
|
||||
PURE_VARIABLE_PATTERN = re.compile(r"^\{\{\s*([\w.]+)\s*}}$")
|
||||
|
||||
|
||||
class ToolNode(BaseNode):
|
||||
@@ -52,13 +54,21 @@ class ToolNode(BaseNode):
|
||||
# 渲染工具参数
|
||||
rendered_parameters = {}
|
||||
for param_name, param_template in self.typed_config.tool_parameters.items():
|
||||
if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template):
|
||||
try:
|
||||
rendered_value = self._render_template(param_template, variable_pool)
|
||||
except Exception as e:
|
||||
raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e
|
||||
if isinstance(param_template, str):
|
||||
pure_match = PURE_VARIABLE_PATTERN.match(param_template)
|
||||
if pure_match:
|
||||
# 纯单变量引用直接取原始值,保留 int/bool/float 等类型
|
||||
rendered_value = self.get_variable(pure_match.group(1), variable_pool, strict=False)
|
||||
if rendered_value is None:
|
||||
rendered_value = self._render_template(param_template, variable_pool)
|
||||
elif TEMPLATE_PATTERN.search(param_template):
|
||||
try:
|
||||
rendered_value = self._render_template(param_template, variable_pool)
|
||||
except Exception as e:
|
||||
raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e
|
||||
else:
|
||||
rendered_value = param_template
|
||||
else:
|
||||
# 非模板参数(数字/布尔/普通字符串)直接保留原值
|
||||
rendered_value = param_template
|
||||
rendered_parameters[param_name] = rendered_value
|
||||
|
||||
@@ -67,6 +77,18 @@ class ToolNode(BaseNode):
|
||||
# 执行工具
|
||||
with get_db_read() as db:
|
||||
tool_service = ToolService(db)
|
||||
|
||||
# MCP 工具:将 operation 映射为 tool_name,其余参数包装进 arguments
|
||||
tool_instance = tool_service.get_tool_instance(self.typed_config.tool_id, tenant_id)
|
||||
if tool_instance and tool_instance.tool_type == ToolType.MCP:
|
||||
operation = rendered_parameters.pop("operation", None)
|
||||
if operation:
|
||||
old_params = rendered_parameters
|
||||
rendered_parameters = {
|
||||
"tool_name": operation,
|
||||
"arguments": old_params
|
||||
}
|
||||
|
||||
result = await tool_service.execute_tool(
|
||||
tool_id=self.typed_config.tool_id,
|
||||
parameters=rendered_parameters,
|
||||
|
||||
@@ -84,7 +84,7 @@ class FileVariable(BaseVariable):
|
||||
total_bytes = 0
|
||||
chunks = []
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||
async with client.stream("GET", self.value.url) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes(8192):
|
||||
|
||||
@@ -6,12 +6,14 @@ error messages based on the current request's language.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from app.i18n.service import get_translation_service
|
||||
from app.core.error_codes import ERROR_CODE_TO_BIZ_CODE, BizCode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -118,15 +120,24 @@ class I18nException(HTTPException):
|
||||
**params
|
||||
)
|
||||
|
||||
# Build error detail
|
||||
detail = {
|
||||
"error_code": self.error_code,
|
||||
"message": message,
|
||||
}
|
||||
# Convert error_code string to BizCode value
|
||||
biz_code = ERROR_CODE_TO_BIZ_CODE.get(
|
||||
self.error_code,
|
||||
BizCode.BAD_REQUEST
|
||||
)
|
||||
|
||||
# Add parameters to detail if provided
|
||||
if params:
|
||||
detail["params"] = params
|
||||
# Build error detail in standard format for compatibility
|
||||
# main.py handler expects "message" and "error_code" fields for filtering
|
||||
# but we also include standard format fields
|
||||
detail = {
|
||||
"code": biz_code.value,
|
||||
"msg": message,
|
||||
"message": message,
|
||||
"error_code": self.error_code,
|
||||
"data": params if params else {},
|
||||
"error": message,
|
||||
"time": int(time.time() * 1000),
|
||||
}
|
||||
|
||||
# Initialize HTTPException
|
||||
super().__init__(
|
||||
@@ -482,14 +493,39 @@ class RateLimitExceededError(I18nException):
|
||||
)
|
||||
|
||||
|
||||
class QuotaExceededError(ForbiddenError):
|
||||
"""Quota exceeded error."""
|
||||
class QuotaExceededError(I18nException):
|
||||
"""Quota exceeded error (402)."""
|
||||
|
||||
# resource key -> i18n display key
|
||||
_RESOURCE_KEY_MAP = {
|
||||
"workspace": "errors.quota_resources.workspace",
|
||||
"app": "errors.quota_resources.app",
|
||||
"skill": "errors.quota_resources.skill",
|
||||
"knowledge_capacity": "errors.quota_resources.knowledge_capacity",
|
||||
"memory_engine": "errors.quota_resources.memory_engine",
|
||||
"end_user": "errors.quota_resources.end_user",
|
||||
"model": "errors.quota_resources.model",
|
||||
"ontology_project": "errors.quota_resources.ontology_project",
|
||||
"api_ops_rate_limit": "errors.quota_resources.api_ops_rate_limit",
|
||||
}
|
||||
|
||||
def __init__(self, resource: Optional[str] = None, **params):
|
||||
# Translate resource key to a localized display name before calling super()
|
||||
if resource:
|
||||
params["resource"] = resource
|
||||
resource_i18n_key = self._RESOURCE_KEY_MAP.get(resource)
|
||||
if resource_i18n_key:
|
||||
try:
|
||||
from app.i18n.service import get_translation_service
|
||||
from app.core.config import settings
|
||||
_locale = _current_locale.get() or settings.I18N_DEFAULT_LANGUAGE
|
||||
params["resource"] = get_translation_service().translate(resource_i18n_key, _locale)
|
||||
except Exception:
|
||||
params["resource"] = resource
|
||||
else:
|
||||
params["resource"] = resource
|
||||
super().__init__(
|
||||
error_key="errors.api.quota_exceeded",
|
||||
status_code=402,
|
||||
error_code="QUOTA_EXCEEDED",
|
||||
**params
|
||||
)
|
||||
|
||||
@@ -106,7 +106,7 @@
|
||||
},
|
||||
"api": {
|
||||
"rate_limit_exceeded": "API rate limit exceeded",
|
||||
"quota_exceeded": "API quota exceeded",
|
||||
"quota_exceeded": "{resource} quota exceeded",
|
||||
"invalid_api_key": "Invalid API key",
|
||||
"api_key_expired": "API key has expired",
|
||||
"api_key_revoked": "API key has been revoked",
|
||||
@@ -114,7 +114,8 @@
|
||||
"method_not_allowed": "Method not allowed",
|
||||
"invalid_request": "Invalid request",
|
||||
"missing_parameter": "Missing required parameter: {param}",
|
||||
"invalid_parameter": "Invalid parameter: {param}"
|
||||
"invalid_parameter": "Invalid parameter: {param}",
|
||||
"api_key_rate_limit_exceeded": "API Key rate limit ({rate_limit}) exceeds tenant plan limit ({limit})"
|
||||
},
|
||||
"database": {
|
||||
"connection_failed": "Database connection failed",
|
||||
@@ -134,5 +135,16 @@
|
||||
"invalid_format": "Invalid format: {field}",
|
||||
"invalid_value": "Invalid value: {field}",
|
||||
"out_of_range": "Value out of range: {field}"
|
||||
},
|
||||
"quota_resources": {
|
||||
"workspace": "Workspace",
|
||||
"app": "App",
|
||||
"skill": "Skill",
|
||||
"knowledge_capacity": "Knowledge capacity",
|
||||
"memory_engine": "Memory engine",
|
||||
"end_user": "End user",
|
||||
"model": "Model",
|
||||
"ontology_project": "Ontology project",
|
||||
"api_ops_rate_limit": "API ops rate limit"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -106,7 +106,7 @@
|
||||
},
|
||||
"api": {
|
||||
"rate_limit_exceeded": "API调用频率超限",
|
||||
"quota_exceeded": "API调用配额已用完",
|
||||
"quota_exceeded": "{resource} 配额已超限",
|
||||
"invalid_api_key": "无效的API密钥",
|
||||
"api_key_expired": "API密钥已过期",
|
||||
"api_key_revoked": "API密钥已被撤销",
|
||||
@@ -114,7 +114,8 @@
|
||||
"method_not_allowed": "不支持的请求方法",
|
||||
"invalid_request": "无效的请求",
|
||||
"missing_parameter": "缺少必需参数:{param}",
|
||||
"invalid_parameter": "参数无效:{param}"
|
||||
"invalid_parameter": "参数无效:{param}",
|
||||
"api_key_rate_limit_exceeded": "API Key 的 QPS 限制({rate_limit})超过租户套餐上限({limit})"
|
||||
},
|
||||
"database": {
|
||||
"connection_failed": "数据库连接失败",
|
||||
@@ -134,5 +135,16 @@
|
||||
"invalid_format": "格式不正确:{field}",
|
||||
"invalid_value": "值无效:{field}",
|
||||
"out_of_range": "值超出范围:{field}"
|
||||
},
|
||||
"quota_resources": {
|
||||
"workspace": "工作空间",
|
||||
"app": "应用",
|
||||
"skill": "技能",
|
||||
"knowledge_capacity": "知识库容量",
|
||||
"memory_engine": "记忆引擎",
|
||||
"end_user": "终端用户",
|
||||
"model": "模型",
|
||||
"ontology_project": "本体工程",
|
||||
"api_ops_rate_limit": "API 操作速率"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,11 +29,8 @@ class Tenants(Base):
|
||||
contact_email = Column(String(255), nullable=True) # 联系人邮箱
|
||||
contact_phone = Column(String(50), nullable=True) # 联系人电话
|
||||
|
||||
# 租户套餐信息
|
||||
plan = Column(String(50), nullable=True) # 套餐类型
|
||||
plan_expired_at = Column(DateTime, nullable=True) # 套餐到期时间
|
||||
api_ops_rate_limit = Column(String(100), nullable=True) # API 调用频率限制
|
||||
status = Column(String(50), nullable=True, default='active') # 租户状态
|
||||
# 租户套餐信息(只读,从 tenant_subscriptions 动态获取)
|
||||
status = Column(String(50), nullable=True, default='active', server_default='active') # 租户状态
|
||||
|
||||
# Relationship to users - one tenant has many users
|
||||
users = relationship("User", back_populates="tenant")
|
||||
|
||||
@@ -61,3 +61,15 @@ def get_apps_by_id(db: Session, app_id: uuid.UUID) -> App:
|
||||
"""根据工作空间ID查询应用"""
|
||||
repo = AppRepository(db)
|
||||
return repo.get_apps_by_id(app_id)
|
||||
|
||||
|
||||
def get_release_by_id(db: Session, app_id: uuid.UUID, release_id: uuid.UUID):
|
||||
"""根据发布版本ID查询发布快照(仅返回激活状态)"""
|
||||
from app.models.app_release_model import AppRelease
|
||||
return db.scalars(
|
||||
select(AppRelease).where(
|
||||
AppRelease.app_id == app_id,
|
||||
AppRelease.id == release_id,
|
||||
AppRelease.is_active.is_(True),
|
||||
)
|
||||
).first()
|
||||
|
||||
@@ -66,6 +66,17 @@ class EndUserRepository:
|
||||
db_logger.error(f"查询宿主 {end_user_id} 时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_end_user_by_other_id(self, workspace_id: uuid.UUID, other_id: str) -> Optional["EndUser"]:
|
||||
"""按 workspace_id + other_id 查找终端用户,不存在返回 None"""
|
||||
return (
|
||||
self.db.query(EndUser)
|
||||
.filter(
|
||||
EndUser.workspace_id == workspace_id,
|
||||
EndUser.other_id == other_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_or_create_end_user(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
|
||||
@@ -5,16 +5,9 @@ Implicit Emotions Storage Repository
|
||||
事务由调用方控制,仓储层只使用 flush/refresh
|
||||
"""
|
||||
import logging
|
||||
from datetime import date, datetime, timezone
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Generator, Optional
|
||||
|
||||
|
||||
class TimeFilterUnavailableError(Exception):
|
||||
"""redis_client 不可用,无法执行时间轴筛选。
|
||||
|
||||
调用方捕获此异常后可选择回退到 get_all_user_ids 进行全量处理。
|
||||
"""
|
||||
|
||||
import redis
|
||||
from sqlalchemy import exists, not_, select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -25,6 +18,13 @@ from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TimeFilterUnavailableError(Exception):
|
||||
"""redis_client 不可用,无法执行时间轴筛选。
|
||||
|
||||
调用方捕获此异常后可选择回退到 get_all_user_ids 进行全量处理。
|
||||
"""
|
||||
|
||||
|
||||
class ImplicitEmotionsStorageRepository:
|
||||
"""隐性记忆和情绪存储仓储类"""
|
||||
|
||||
@@ -216,9 +216,7 @@ class ImplicitEmotionsStorageRepository:
|
||||
"""
|
||||
from sqlalchemy import String as SAString
|
||||
from sqlalchemy import cast
|
||||
CST = timezone(timedelta(hours=8))
|
||||
now_cst = datetime.now(CST)
|
||||
today_start = now_cst.replace(hour=0, minute=0, second=0, microsecond=0).astimezone(timezone.utc).replace(tzinfo=None)
|
||||
today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
tomorrow_start = today_start + timedelta(days=1)
|
||||
offset = 0
|
||||
while True:
|
||||
|
||||
@@ -328,7 +328,7 @@ class MemoryConfigRepository:
|
||||
if not db_config:
|
||||
db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
|
||||
return None
|
||||
|
||||
#TODO:部分更新没有用patch请求,是在Repository层中用先查再部分更新的方式实现的,后续可以考虑改成patch请求更符合RESTful设计原则
|
||||
update_data = update.model_dump(exclude_unset=True)
|
||||
update_data.pop("config_id", None)
|
||||
|
||||
|
||||
@@ -263,16 +263,15 @@ class ModelConfigRepository:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def get_by_type(db: Session, model_type: ModelType, tenant_id: uuid.UUID | None = None, is_active: bool = True) -> List[ModelConfig]:
|
||||
"""根据类型获取模型配置"""
|
||||
db_logger.debug(f"根据类型查询模型配置: type={model_type}, tenant_id={tenant_id}, is_active={is_active}")
|
||||
|
||||
def get_by_type(db: Session, model_types: List[ModelType], tenant_id: uuid.UUID | None = None, is_active: bool = True) -> List[ModelConfig]:
|
||||
"""根据类型获取模型配置,支持多类型查询"""
|
||||
db_logger.debug(f"根据类型查询模型配置: types={[t.value for t in model_types]}, tenant_id={tenant_id}, is_active={is_active}")
|
||||
|
||||
try:
|
||||
query = db.query(ModelConfig).options(
|
||||
joinedload(ModelConfig.api_keys)
|
||||
).filter(ModelConfig.type == model_type)
|
||||
|
||||
# 添加租户过滤
|
||||
).filter(ModelConfig.type.in_([t.value for t in model_types]))
|
||||
|
||||
if tenant_id:
|
||||
query = query.filter(
|
||||
or_(
|
||||
@@ -280,16 +279,18 @@ class ModelConfigRepository:
|
||||
ModelConfig.is_public
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if is_active:
|
||||
query = query.filter(ModelConfig.is_active)
|
||||
|
||||
models = query.order_by(ModelConfig.name).all()
|
||||
|
||||
query = query.filter(ModelConfig.is_composite == False)
|
||||
|
||||
models = query.order_by(ModelConfig.created_at.desc()).all()
|
||||
db_logger.debug(f"根据类型查询模型配置成功: 数量={len(models)}")
|
||||
return models
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db_logger.error(f"根据类型查询模型配置失败: type={model_type} - {str(e)}")
|
||||
db_logger.error(f"根据类型查询模型配置失败: types={model_types} - {str(e)}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -23,6 +23,7 @@ SET s += {
|
||||
end_user_id: statement.end_user_id,
|
||||
stmt_type: statement.stmt_type,
|
||||
statement: statement.statement,
|
||||
speaker: statement.speaker,
|
||||
emotion_intensity: statement.emotion_intensity,
|
||||
emotion_target: statement.emotion_target,
|
||||
emotion_subject: statement.emotion_subject,
|
||||
@@ -56,6 +57,7 @@ SET c += {
|
||||
expired_at: chunk.expired_at,
|
||||
dialog_id: chunk.dialog_id,
|
||||
content: chunk.content,
|
||||
speaker: chunk.speaker,
|
||||
chunk_embedding: chunk.chunk_embedding,
|
||||
sequence_number: chunk.sequence_number,
|
||||
start_index: chunk.start_index,
|
||||
@@ -91,6 +93,8 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
|
||||
END,
|
||||
e.statement_id = CASE WHEN entity.statement_id IS NOT NULL AND entity.statement_id <> '' THEN entity.statement_id ELSE e.statement_id END,
|
||||
e.aliases = CASE
|
||||
// 用户实体的 aliases 由 PgSQL end_user_info 作为唯一权威源,知识抽取完全不写入
|
||||
WHEN entity.name IN ['用户', '我', 'User', 'I'] THEN e.aliases
|
||||
WHEN entity.aliases IS NOT NULL AND size(entity.aliases) > 0
|
||||
THEN CASE
|
||||
WHEN e.aliases IS NULL THEN entity.aliases
|
||||
@@ -283,7 +287,7 @@ LIMIT $limit
|
||||
"""
|
||||
|
||||
SEARCH_STATEMENTS_BY_KEYWORD = """
|
||||
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
|
||||
CALL db.index.fulltext.queryNodes("statementsFulltext", $query) YIELD node AS s, score
|
||||
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
|
||||
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
|
||||
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
|
||||
@@ -307,7 +311,7 @@ LIMIT $limit
|
||||
"""
|
||||
# 查询实体名称包含指定字符串的实体
|
||||
SEARCH_ENTITIES_BY_NAME = """
|
||||
CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score
|
||||
CALL db.index.fulltext.queryNodes("entitiesFulltext", $query) YIELD node AS e, score
|
||||
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
|
||||
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
|
||||
@@ -337,21 +341,21 @@ LIMIT $limit
|
||||
"""
|
||||
|
||||
SEARCH_ENTITIES_BY_NAME_OR_ALIAS = """
|
||||
CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score
|
||||
CALL db.index.fulltext.queryNodes("entitiesFulltext", $query) YIELD node AS e, score
|
||||
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||
WITH e, score
|
||||
WITH collect({entity: e, score: score}) AS fulltextResults
|
||||
With collect({entity: e, score: score}) AS fulltextResults
|
||||
|
||||
OPTIONAL MATCH (ae:ExtractedEntity)
|
||||
WHERE ($end_user_id IS NULL OR ae.end_user_id = $end_user_id)
|
||||
AND ae.aliases IS NOT NULL
|
||||
AND ANY(alias IN ae.aliases WHERE toLower(alias) CONTAINS toLower($q))
|
||||
AND ANY(alias IN ae.aliases WHERE toLower(alias) CONTAINS toLower($query))
|
||||
WITH fulltextResults, collect(ae) AS aliasEntities
|
||||
|
||||
UNWIND (fulltextResults + [x IN aliasEntities | {entity: x, score:
|
||||
CASE
|
||||
WHEN ANY(alias IN x.aliases WHERE toLower(alias) = toLower($q)) THEN 1.0
|
||||
WHEN ANY(alias IN x.aliases WHERE toLower(alias) STARTS WITH toLower($q)) THEN 0.9
|
||||
WHEN ANY(alias IN x.aliases WHERE toLower(alias) = toLower($query)) THEN 1.0
|
||||
WHEN ANY(alias IN x.aliases WHERE toLower(alias) STARTS WITH toLower($query)) THEN 0.9
|
||||
ELSE 0.8
|
||||
END
|
||||
}]) AS row
|
||||
@@ -384,7 +388,7 @@ LIMIT $limit
|
||||
|
||||
|
||||
SEARCH_CHUNKS_BY_CONTENT = """
|
||||
CALL db.index.fulltext.queryNodes("chunksFulltext", $q) YIELD node AS c, score
|
||||
CALL db.index.fulltext.queryNodes("chunksFulltext", $query) YIELD node AS c, score
|
||||
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
|
||||
OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement)
|
||||
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
|
||||
@@ -501,7 +505,7 @@ LIMIT $limit
|
||||
"""
|
||||
|
||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL = """
|
||||
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
|
||||
CALL db.index.fulltext.queryNodes("statementsFulltext", $query) YIELD node AS s, score
|
||||
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
|
||||
AND ((($start_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) >= datetime($start_date)))
|
||||
AND ($end_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) <= datetime($end_date))))
|
||||
@@ -677,7 +681,7 @@ SET n.invalid_at = $new_invalid_at
|
||||
|
||||
# MemorySummary keyword search using fulltext index
|
||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """
|
||||
CALL db.index.fulltext.queryNodes("summariesFulltext", $q) YIELD node AS m, score
|
||||
CALL db.index.fulltext.queryNodes("summariesFulltext", $query) YIELD node AS m, score
|
||||
WHERE ($end_user_id IS NULL OR m.end_user_id = $end_user_id)
|
||||
OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement)
|
||||
RETURN m.id AS id,
|
||||
@@ -1363,7 +1367,7 @@ RETURN c.community_id AS community_id
|
||||
|
||||
# Community keyword search: matches name or summary via fulltext index
|
||||
SEARCH_COMMUNITIES_BY_KEYWORD = """
|
||||
CALL db.index.fulltext.queryNodes("communitiesFulltext", $q) YIELD node AS c, score
|
||||
CALL db.index.fulltext.queryNodes("communitiesFulltext", $query) YIELD node AS c, score
|
||||
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
|
||||
RETURN c.community_id AS id,
|
||||
c.name AS name,
|
||||
@@ -1451,7 +1455,7 @@ RETURN elementId(r) AS uuid
|
||||
"""
|
||||
|
||||
SEARCH_PERCEPTUAL_BY_KEYWORD = """
|
||||
CALL db.index.fulltext.queryNodes("perceptualFulltext", $q) YIELD node AS p, score
|
||||
CALL db.index.fulltext.queryNodes("perceptualFulltext", $query) YIELD node AS p, score
|
||||
WHERE p.end_user_id = $end_user_id
|
||||
RETURN p.id AS id,
|
||||
p.end_user_id AS end_user_id,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user