Compare commits
592 Commits
hotfix/v0.
...
release/v0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
09393b2326 | ||
|
|
eaa66ba71a | ||
|
|
c59a97afba | ||
|
|
9480a61229 | ||
|
|
7ffd250b08 | ||
|
|
52bccfaede | ||
|
|
9233e74f36 | ||
|
|
46dfd92a9f | ||
|
|
5f33cec8ad | ||
|
|
334502f06b | ||
|
|
b0bb5e883c | ||
|
|
b9cfc47e1e | ||
|
|
4a4391a19c | ||
|
|
7193eed9e3 | ||
|
|
2a03f70287 | ||
|
|
124e8d0639 | ||
|
|
7dc35bb3fb | ||
|
|
b488590537 | ||
|
|
aa56ad15f9 | ||
|
|
d6af459ca8 | ||
|
|
2f7fd85ab1 | ||
|
|
398aebd0c5 | ||
|
|
eaa4058c56 | ||
|
|
21b25bfef7 | ||
|
|
a61acbef93 | ||
|
|
a90757745d | ||
|
|
b882863907 | ||
|
|
9159d5cbb0 | ||
|
|
537f6a1812 | ||
|
|
1ea0f308ba | ||
|
|
77c023102e | ||
|
|
ad24119b2d | ||
|
|
ea6fa154e0 | ||
|
|
158507cf8e | ||
|
|
5e0d30dde8 | ||
|
|
363d775270 | ||
|
|
ad4121b0d8 | ||
|
|
671df83bcd | ||
|
|
8bb5a66401 | ||
|
|
4c9f327833 | ||
|
|
6bd528eace | ||
|
|
2b5bece9b6 | ||
|
|
ea0e65f1ec | ||
|
|
cb2a7aa60a | ||
|
|
402c8aef5d | ||
|
|
eb98a69a84 | ||
|
|
152a84aff3 | ||
|
|
c5c8be89ed | ||
|
|
30aed72b74 | ||
|
|
35c2d9d0d3 | ||
|
|
27275eee43 | ||
|
|
7eb21f677f | ||
|
|
6de5d413c4 | ||
|
|
aecb0f6497 | ||
|
|
83b7c6870d | ||
|
|
74157adb12 | ||
|
|
8011610acc | ||
|
|
f1dc507b5c | ||
|
|
f3ac7e084d | ||
|
|
ba3743f9f1 | ||
|
|
20ddc76a4d | ||
|
|
84ca98555d | ||
|
|
7e6d17e4e3 | ||
|
|
7f3c48ce2a | ||
|
|
e5c16a2a24 | ||
|
|
8887600f7d | ||
|
|
df6eb74b28 | ||
|
|
b4b9974064 | ||
|
|
ff65dee754 | ||
|
|
2c2ed0ebf3 | ||
|
|
d60f838fb8 | ||
|
|
817aa78d03 | ||
|
|
4c73887a48 | ||
|
|
94d2d975ee | ||
|
|
d59990d326 | ||
|
|
3227c25b07 | ||
|
|
08b5c7bc8a | ||
|
|
475e573891 | ||
|
|
b03300c804 | ||
|
|
a5d07ee66d | ||
|
|
10a655772f | ||
|
|
aeeb18581d | ||
|
|
fb1160e833 | ||
|
|
c448cf0660 | ||
|
|
5289b3a2cb | ||
|
|
48f3d9b105 | ||
|
|
559b4bef6b | ||
|
|
4a39fd5f46 | ||
|
|
b22c15cccc | ||
|
|
a2f85b3d98 | ||
|
|
7f1cf13b23 | ||
|
|
d4129edcf5 | ||
|
|
ab2a58d68e | ||
|
|
a28b62763e | ||
|
|
86540a81d1 | ||
|
|
dcd874fecd | ||
|
|
bbd85733b8 | ||
|
|
22c5f12657 | ||
|
|
7b5d7696cb | ||
|
|
cb33724673 | ||
|
|
48b56a3d88 | ||
|
|
83d0fb9387 | ||
|
|
bb964c1ed8 | ||
|
|
81d58b001f | ||
|
|
99bc84a9f2 | ||
|
|
37dbe0f95b | ||
|
|
d4a1904b19 | ||
|
|
ecdad19f54 | ||
|
|
fb93c509f4 | ||
|
|
f597139913 | ||
|
|
113ae59f84 | ||
|
|
62c721bdf6 | ||
|
|
4cbb0cee2f | ||
|
|
8c586935a8 | ||
|
|
d5272af76f | ||
|
|
cf8912e929 | ||
|
|
327c1904b1 | ||
|
|
58c13aaeb4 | ||
|
|
377ddd2b9b | ||
|
|
52f7ea7456 | ||
|
|
b02baedd2c | ||
|
|
f3c3b6255e | ||
|
|
b659e2a6e1 | ||
|
|
e15e32cc7b | ||
|
|
04d20dc094 | ||
|
|
b8123fc84c | ||
|
|
5a17b7fd0d | ||
|
|
e3d0602850 | ||
|
|
696b2d2417 | ||
|
|
a5613314b8 | ||
|
|
e87529876c | ||
|
|
7bb3e65fb7 | ||
|
|
5ada7e77fc | ||
|
|
79b7da44e2 | ||
|
|
26a3d8a41b | ||
|
|
2380cd55ef | ||
|
|
a105df33ab | ||
|
|
0dd8cc5d43 | ||
|
|
fd90a4c2ad | ||
|
|
b302a94620 | ||
|
|
c96dc53534 | ||
|
|
f883c1469d | ||
|
|
ddfd81259a | ||
|
|
e015455fb8 | ||
|
|
915cb54f21 | ||
|
|
cada860a16 | ||
|
|
e1f8ad871b | ||
|
|
e205aaa6e6 | ||
|
|
62edafcebe | ||
|
|
ccdf7ae81d | ||
|
|
643f69bb90 | ||
|
|
73fbc19747 | ||
|
|
7ba0726473 | ||
|
|
8c6b65db12 | ||
|
|
5ce0bdb0f5 | ||
|
|
b59e2b5bcd | ||
|
|
5a2fe738dc | ||
|
|
f04412c455 | ||
|
|
db6fc5d2db | ||
|
|
b6aca0b1e7 | ||
|
|
4fd7395464 | ||
|
|
78ba313262 | ||
|
|
d35bc3a2cf | ||
|
|
d5c8d16e64 | ||
|
|
09496bd7b9 | ||
|
|
171f25a350 | ||
|
|
c7230659e3 | ||
|
|
502d87e88d | ||
|
|
1faa258e23 | ||
|
|
bef6a50deb | ||
|
|
cc12ec3fa8 | ||
|
|
466864afe3 | ||
|
|
e0d7a5a91f | ||
|
|
5ac2d5602e | ||
|
|
f4c3974956 | ||
|
|
71e5b6586a | ||
|
|
bfb723a468 | ||
|
|
61f2e44bd5 | ||
|
|
ed765b7c26 | ||
|
|
3018d186f7 | ||
|
|
2e1470cb52 | ||
|
|
737858731b | ||
|
|
d072eb1af7 | ||
|
|
daaee63bd5 | ||
|
|
e3c643b659 | ||
|
|
017efdc320 | ||
|
|
29aef4527c | ||
|
|
d9cb2b511b | ||
|
|
18be1a9f89 | ||
|
|
49e0801d15 | ||
|
|
dde7ea9039 | ||
|
|
5262aedab9 | ||
|
|
441b21774d | ||
|
|
d6dd038167 | ||
|
|
47c242e513 | ||
|
|
811193dd75 | ||
|
|
797780824c | ||
|
|
75e95bab01 | ||
|
|
e7a400bb96 | ||
|
|
28ca4d1734 | ||
|
|
5e6490213d | ||
|
|
3b359df02f | ||
|
|
fcf3071cb0 | ||
|
|
1294aabbcc | ||
|
|
3c2a78a449 | ||
|
|
4f0e5d0866 | ||
|
|
7a84ee33c6 | ||
|
|
e3265e4ba3 | ||
|
|
3e7a004599 | ||
|
|
fa1e5ee43c | ||
|
|
c72a6fd724 | ||
|
|
0965008210 | ||
|
|
bcadd2a6f1 | ||
|
|
e4f306dabb | ||
|
|
b5ec5c2cea | ||
|
|
e539b3eeb7 | ||
|
|
7f8765b815 | ||
|
|
72b39c6fa3 | ||
|
|
9032f50a19 | ||
|
|
aa683efaa0 | ||
|
|
2d9986f902 | ||
|
|
06075ffef5 | ||
|
|
a7336b0829 | ||
|
|
0d16e168e7 | ||
|
|
a882e5e5c4 | ||
|
|
c614bb5be7 | ||
|
|
1ff0f3ebfd | ||
|
|
bafcb5c545 | ||
|
|
f8d27fada6 | ||
|
|
90365cd026 | ||
|
|
d96c7b88f0 | ||
|
|
99559621c5 | ||
|
|
926f65a1ff | ||
|
|
b20971dc95 | ||
|
|
1ff0274027 | ||
|
|
8495aa5dde | ||
|
|
d8ef7a8e02 | ||
|
|
7a4a02b2bb | ||
|
|
8f623a66c8 | ||
|
|
77ed9faea1 | ||
|
|
1ff3748935 | ||
|
|
f023c43f80 | ||
|
|
60124e3232 | ||
|
|
70d4e79de1 | ||
|
|
59b5a1bcf2 | ||
|
|
62f345b3de | ||
|
|
a3f0415cd3 | ||
|
|
2450fe3afe | ||
|
|
52e726eabc | ||
|
|
7ca80b5d01 | ||
|
|
9470dd2f1e | ||
|
|
10f1089198 | ||
|
|
095f4e3001 | ||
|
|
ef8c7093b5 | ||
|
|
05ea372776 | ||
|
|
2b067ce08a | ||
|
|
b63cff2993 | ||
|
|
5bb9ce9018 | ||
|
|
aa581a9083 | ||
|
|
ac51ccaf1f | ||
|
|
5eaedaad77 | ||
|
|
bd955569b3 | ||
|
|
7a2a941ac4 | ||
|
|
19fa8314e4 | ||
|
|
cba24e58db | ||
|
|
62355186ef | ||
|
|
82faedc972 | ||
|
|
11ea486f82 | ||
|
|
efdee32f85 | ||
|
|
988d101e93 | ||
|
|
418f9f4dba | ||
|
|
520ee7c132 | ||
|
|
72be9f75f9 | ||
|
|
2b52b32b96 | ||
|
|
a96f20ee05 | ||
|
|
b8acc0a32f | ||
|
|
e1cf3bb3d2 | ||
|
|
6f66c9727f | ||
|
|
3beca641e1 | ||
|
|
b8507a1df6 | ||
|
|
0f28d54c43 | ||
|
|
0afc38e7ef | ||
|
|
07fd85c342 | ||
|
|
4c2a1e6d1d | ||
|
|
7cfb6ace22 | ||
|
|
91cc20d589 | ||
|
|
f01ca51896 | ||
|
|
f4a63f7d55 | ||
|
|
0019f3acfd | ||
|
|
3fe90a5e13 | ||
|
|
bc14c94407 | ||
|
|
a21dad70ed | ||
|
|
807a4e715d | ||
|
|
58d18b476c | ||
|
|
5e5927a0b9 | ||
|
|
7869121382 | ||
|
|
7c0fb624d9 | ||
|
|
af83980f99 | ||
|
|
cf0d11208c | ||
|
|
87d1630230 | ||
|
|
50392384e7 | ||
|
|
9a926a8398 | ||
|
|
e5e6699168 | ||
|
|
068e2bfb7e | ||
|
|
4ce6fede67 | ||
|
|
8497c955f9 | ||
|
|
72fe3962cf | ||
|
|
c253968aa8 | ||
|
|
d517bceda2 | ||
|
|
412183c359 | ||
|
|
90e8e90528 | ||
|
|
fd05c000f6 | ||
|
|
627d6a0381 | ||
|
|
807dee8460 | ||
|
|
ac7d39524e | ||
|
|
cd018814fe | ||
|
|
e0b7e95af6 | ||
|
|
3a62d50048 | ||
|
|
0e60da6d8a | ||
|
|
39e94eb3ea | ||
|
|
3e0f59adc6 | ||
|
|
660cd2fadb | ||
|
|
6f1bb43eab | ||
|
|
61b5627505 | ||
|
|
af6392fb09 | ||
|
|
84b1a95313 | ||
|
|
8b21dab255 | ||
|
|
fc5ce63e44 | ||
|
|
15a863b41a | ||
|
|
5226c5b79d | ||
|
|
27e9f9968d | ||
|
|
d38612a10d | ||
|
|
32c71dcd89 | ||
|
|
428e7ebaa5 | ||
|
|
57833689d9 | ||
|
|
384a67482c | ||
|
|
7842435321 | ||
|
|
33c4c5d31b | ||
|
|
ca4f7aa65d | ||
|
|
b875626f18 | ||
|
|
130684cac0 | ||
|
|
5adff38bda | ||
|
|
62e0b2730b | ||
|
|
55b2e05ba8 | ||
|
|
562ca6c1f1 | ||
|
|
e298b38de9 | ||
|
|
a7b8ba0c66 | ||
|
|
460c86cd94 | ||
|
|
33a1c178ff | ||
|
|
c81612e6d3 | ||
|
|
9f9ac69f97 | ||
|
|
0516822d42 | ||
|
|
b598171a3d | ||
|
|
a4ea7f0385 | ||
|
|
32ae60fc65 | ||
|
|
6b272c5b44 | ||
|
|
2782d0661f | ||
|
|
ea2f5e61c9 | ||
|
|
5975d70bf9 | ||
|
|
e0546e01ef | ||
|
|
70aab94fc3 | ||
|
|
0f50537d7d | ||
|
|
b7c1ce261b | ||
|
|
edac6a164e | ||
|
|
1503b242ea | ||
|
|
3ff44f0108 | ||
|
|
18fd48505d | ||
|
|
807ddce5cd | ||
|
|
62fb6c79a0 | ||
|
|
cc373b2864 | ||
|
|
f2d7479229 | ||
|
|
ae1909b7e9 | ||
|
|
8e397b83b6 | ||
|
|
b0aaa12340 | ||
|
|
5eb65e7ad8 | ||
|
|
cb5610e8b1 | ||
|
|
6bb01119d0 | ||
|
|
c16e832081 | ||
|
|
e3d50c5c55 | ||
|
|
e64aadce95 | ||
|
|
bad6087c25 | ||
|
|
b04c05f4a4 | ||
|
|
5e372627f7 | ||
|
|
29611738ce | ||
|
|
de846c05ab | ||
|
|
6475387af8 | ||
|
|
b330bdba29 | ||
|
|
bed279c604 | ||
|
|
9eaf779e67 | ||
|
|
1fbccd98a7 | ||
|
|
931b800bb6 | ||
|
|
d76bb36b9f | ||
|
|
3c93409f7f | ||
|
|
9451a08e7f | ||
|
|
bc49bd2a43 | ||
|
|
4bfc9ca991 | ||
|
|
1ba60401af | ||
|
|
236e8973ac | ||
|
|
ca6cc8ae63 | ||
|
|
dd2cc89c62 | ||
|
|
a87bba93c2 | ||
|
|
a153fdb7cb | ||
|
|
4eed393db5 | ||
|
|
dc8e432719 | ||
|
|
16a0099e27 | ||
|
|
c4cf639bbc | ||
|
|
f91431a70d | ||
|
|
0102ad3a30 | ||
|
|
ebe298b71d | ||
|
|
a486ca7857 | ||
|
|
dd40e5df5f | ||
|
|
8e1ec1bae6 | ||
|
|
1f9c4919be | ||
|
|
065182ad5c | ||
|
|
90ec8db0d8 | ||
|
|
78baf1c60a | ||
|
|
072c94cccb | ||
|
|
a66030d1b3 | ||
|
|
a90ceaf5a2 | ||
|
|
725f2f5146 | ||
|
|
a2c3357e80 | ||
|
|
acbc954e6f | ||
|
|
77032583ab | ||
|
|
ef533d27ac | ||
|
|
ca1a2c7b9e | ||
|
|
145fa398dd | ||
|
|
274430b2c9 | ||
|
|
e9972834fe | ||
|
|
1ecc04fee7 | ||
|
|
78cd1f69a3 | ||
|
|
aabd9a1b57 | ||
|
|
b9439b337a | ||
|
|
eb9f4f39f1 | ||
|
|
baa4b56426 | ||
|
|
49bcc6131b | ||
|
|
0d3f6f1e14 | ||
|
|
84536925c6 | ||
|
|
b22a5a9f12 | ||
|
|
b8825a83dd | ||
|
|
08b4d5c1cf | ||
|
|
a5dfc472d3 | ||
|
|
48d29bcc63 | ||
|
|
856c6f6d78 | ||
|
|
bfc47ad738 | ||
|
|
841b6abb33 | ||
|
|
8a1114a1a7 | ||
|
|
be8c481d6d | ||
|
|
5d439346a1 | ||
|
|
ed753caaf7 | ||
|
|
9a931389ea | ||
|
|
b9d469b6e3 | ||
|
|
e817cfd292 | ||
|
|
af86cb3556 | ||
|
|
e48b146e60 | ||
|
|
07b66a9801 | ||
|
|
c3ee3c4af9 | ||
|
|
cd8229f370 | ||
|
|
9a4a614fc8 | ||
|
|
0b5a030e46 | ||
|
|
675d0fc5ef | ||
|
|
6291c28f0a | ||
|
|
30b512e554 | ||
|
|
33c73c6c6f | ||
|
|
072d118935 | ||
|
|
2e7ebf174b | ||
|
|
3ece83d419 | ||
|
|
9c1c232b2e | ||
|
|
bfc98efc9d | ||
|
|
cfbf83f71e | ||
|
|
a43e8fa594 | ||
|
|
6c8d0d9d64 | ||
|
|
bd2a3bd7ef | ||
|
|
1f72b8aa70 | ||
|
|
9bb32888a2 | ||
|
|
caee5d214e | ||
|
|
38f3455bab | ||
|
|
d60cb423a4 | ||
|
|
b20a65ce29 | ||
|
|
99862db7a0 | ||
|
|
00a8099857 | ||
|
|
117e29fbe3 | ||
|
|
32740e8159 | ||
|
|
bc5ea2d421 | ||
|
|
d34bf4bc89 | ||
|
|
c4ff1a325b | ||
|
|
d1f0258065 | ||
|
|
5db59bc9cf | ||
|
|
a711635694 | ||
|
|
15b3ce3dd5 | ||
|
|
9cc19047b4 | ||
|
|
2e8e63878e | ||
|
|
38955d7d45 | ||
|
|
b6167d4e94 | ||
|
|
7890970a39 | ||
|
|
203732de1d | ||
|
|
4961e7df79 | ||
|
|
1b52850526 | ||
|
|
1732fc7af5 | ||
|
|
a52e2137b7 | ||
|
|
377f79773d | ||
|
|
cae87de6ef | ||
|
|
63235de42b | ||
|
|
106a32bc3a | ||
|
|
2f0bb793d8 | ||
|
|
010eff17cf | ||
|
|
0b47194f12 | ||
|
|
9ff3a3d5f7 | ||
|
|
abbd92b74c | ||
|
|
960ee9f2df | ||
|
|
1c133d3d6c | ||
|
|
d270d25a99 | ||
|
|
8abd59b26e | ||
|
|
bd48b4fdbe | ||
|
|
aad6955709 | ||
|
|
18703919a8 | ||
|
|
9f2cd6afae | ||
|
|
d1beb9e5d5 | ||
|
|
2c7aaebdd5 | ||
|
|
be38c9e385 | ||
|
|
1aec7115a5 | ||
|
|
9facb513b2 | ||
|
|
9bce14be4e | ||
|
|
12f3a3ed77 | ||
|
|
8b9eb81d36 | ||
|
|
4fb3d6992c | ||
|
|
370a668ead | ||
|
|
daaad51357 | ||
|
|
6eca5f6cdf | ||
|
|
f61f86f8fe | ||
|
|
57eb5aa967 | ||
|
|
cf519738f4 | ||
|
|
cdebe014cf | ||
|
|
853ce6f4e1 | ||
|
|
9cbe9d5edc | ||
|
|
767f9ab17c | ||
|
|
7b5b2ab31a | ||
|
|
924d10ac5b | ||
|
|
0470a71d03 | ||
|
|
378b110d91 | ||
|
|
5f7db778b5 | ||
|
|
0d15457299 | ||
|
|
ad4ddea977 | ||
|
|
75bb96d4e7 | ||
|
|
68fdf5d76f | ||
|
|
258c19f9e0 | ||
|
|
386ed2b914 | ||
|
|
264183cec2 | ||
|
|
9561578a2a | ||
|
|
7ce29019f7 | ||
|
|
99ff07ccac | ||
|
|
e77a1a92fd | ||
|
|
d3cd66fc6e | ||
|
|
b95a627424 | ||
|
|
c9ca5df05c | ||
|
|
70c3c7dd74 | ||
|
|
b482822629 | ||
|
|
8f609ba29c | ||
|
|
a1ef5146d7 | ||
|
|
8b997b422a | ||
|
|
6d6338eb06 | ||
|
|
b5c5863b39 | ||
|
|
ab45b7abac | ||
|
|
2dfc3b25d8 | ||
|
|
3ea42ac27f | ||
|
|
fff5e0e8b8 | ||
|
|
ef626951bc | ||
|
|
4533644e13 | ||
|
|
ca255304d9 | ||
|
|
b40f4829cb | ||
|
|
52ae914e17 | ||
|
|
87c2419186 | ||
|
|
2ad25c48d2 | ||
|
|
75e8caf441 | ||
|
|
02660c7c97 | ||
|
|
3ea57d1cb0 | ||
|
|
4a71484151 | ||
|
|
db8b3416a6 | ||
|
|
876c39b1b0 | ||
|
|
3cca35a74f | ||
|
|
ed90405439 | ||
|
|
533000030f | ||
|
|
a58ac385b1 | ||
|
|
891cfc2704 | ||
|
|
e9ad13504a | ||
|
|
13e35ed122 | ||
|
|
7acb7045f0 | ||
|
|
f9f302dd2a | ||
|
|
bca43fcc75 | ||
|
|
7fd00009a2 | ||
|
|
4534b65d6a | ||
|
|
a5bce221bd | ||
|
|
6056952936 |
164
.github/workflows/release-notify-wechat.yml
vendored
Normal file
164
.github/workflows/release-notify-wechat.yml
vendored
Normal file
@@ -0,0 +1,164 @@
|
||||
name: Release Notify Workflow
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [closed]
|
||||
|
||||
jobs:
|
||||
notify:
|
||||
if: >
|
||||
github.event.pull_request.merged == true &&
|
||||
startsWith(github.event.pull_request.base.ref, 'release')
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
# 防止 GitHub HEAD 未同步
|
||||
- run: sleep 3
|
||||
|
||||
# 1️⃣ 获取分支 HEAD
|
||||
- name: Get HEAD
|
||||
id: head
|
||||
run: |
|
||||
HEAD_SHA=$(curl -s \
|
||||
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
|
||||
https://api.github.com/repos/${{ github.repository }}/git/ref/heads/${{ github.event.pull_request.base.ref }} \
|
||||
| jq -r '.object.sha')
|
||||
echo "head_sha=$HEAD_SHA" >> $GITHUB_OUTPUT
|
||||
|
||||
# 2️⃣ 判断是否最终PR
|
||||
- name: Check Latest
|
||||
id: check
|
||||
run: |
|
||||
if [ "${{ github.event.pull_request.merge_commit_sha }}" = "${{ steps.head.outputs.head_sha }}" ]; then
|
||||
echo "ok=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "ok=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
# 3️⃣ 尝试从 PR body 提取 Sourcery 摘要
|
||||
- name: Extract Sourcery Summary
|
||||
if: steps.check.outputs.ok == 'true'
|
||||
id: sourcery
|
||||
env:
|
||||
PR_BODY: ${{ github.event.pull_request.body }}
|
||||
run: |
|
||||
python3 << 'PYEOF'
|
||||
import os, re
|
||||
|
||||
body = os.environ.get("PR_BODY", "") or ""
|
||||
match = re.search(
|
||||
r"## Summary by Sourcery\s*\n(.*?)(?=\n## |\Z)",
|
||||
body,
|
||||
re.DOTALL
|
||||
)
|
||||
|
||||
if match:
|
||||
summary = match.group(1).strip()
|
||||
found = "true"
|
||||
else:
|
||||
summary = ""
|
||||
found = "false"
|
||||
|
||||
with open("sourcery_summary.txt", "w", encoding="utf-8") as f:
|
||||
f.write(summary)
|
||||
|
||||
with open(os.environ["GITHUB_OUTPUT"], "a") as gh:
|
||||
gh.write(f"found={found}\n")
|
||||
gh.write("summary<<EOF\n")
|
||||
gh.write(summary + "\n")
|
||||
gh.write("EOF\n")
|
||||
PYEOF
|
||||
|
||||
# 4️⃣ Fallback: 获取 commits + 通义千问总结
|
||||
- name: Get Commits
|
||||
if: steps.check.outputs.ok == 'true' && steps.sourcery.outputs.found == 'false'
|
||||
run: |
|
||||
curl -s \
|
||||
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
|
||||
${{ github.event.pull_request.commits_url }} \
|
||||
| jq -r '.[].commit.message' | head -n 20 > commits.txt
|
||||
|
||||
- name: AI Summary (Qwen Fallback)
|
||||
if: steps.check.outputs.ok == 'true' && steps.sourcery.outputs.found == 'false'
|
||||
id: qwen
|
||||
env:
|
||||
DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }}
|
||||
run: |
|
||||
python3 << 'PYEOF'
|
||||
import json, os, urllib.request
|
||||
|
||||
with open("commits.txt", "r") as f:
|
||||
commits = f.read().strip()
|
||||
|
||||
prompt = "请用中文总结以下代码提交,输出3-5条要点,面向测试人员。直接输出编号列表,不要输出标题或前言:\n" + commits
|
||||
payload = {"model": "qwen-plus", "input": {"prompt": prompt}}
|
||||
data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
|
||||
|
||||
req = urllib.request.Request(
|
||||
"https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation",
|
||||
data=data,
|
||||
headers={
|
||||
"Authorization": "Bearer " + os.environ["DASHSCOPE_API_KEY"],
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
)
|
||||
resp = urllib.request.urlopen(req)
|
||||
result = json.loads(resp.read().decode())
|
||||
summary = result.get("output", {}).get("text", "AI 摘要生成失败")
|
||||
|
||||
with open(os.environ["GITHUB_OUTPUT"], "a") as gh:
|
||||
gh.write("summary<<EOF\n")
|
||||
gh.write(summary + "\n")
|
||||
gh.write("EOF\n")
|
||||
PYEOF
|
||||
|
||||
# 5️⃣ 企业微信通知(Markdown)
|
||||
- name: Notify WeChat
|
||||
if: steps.check.outputs.ok == 'true'
|
||||
env:
|
||||
WECHAT_WEBHOOK: ${{ secrets.WECHAT_WEBHOOK }}
|
||||
BRANCH: ${{ github.event.pull_request.base.ref }}
|
||||
AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||
PR_URL: ${{ github.event.pull_request.html_url }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
||||
SOURCERY_FOUND: ${{ steps.sourcery.outputs.found }}
|
||||
SOURCERY_SUMMARY: ${{ steps.sourcery.outputs.summary }}
|
||||
QWEN_SUMMARY: ${{ steps.qwen.outputs.summary }}
|
||||
run: |
|
||||
python3 << 'PYEOF'
|
||||
import json, os, urllib.request
|
||||
|
||||
if os.environ.get("SOURCERY_FOUND") == "true":
|
||||
label = "Summary by Sourcery"
|
||||
summary = os.environ.get("SOURCERY_SUMMARY", "")
|
||||
else:
|
||||
label = "AI变更摘要"
|
||||
summary = os.environ.get("QWEN_SUMMARY", "AI 摘要生成失败")
|
||||
|
||||
pr_number = os.environ.get("PR_NUMBER", "")
|
||||
short_sha = os.environ.get("MERGE_SHA", "")[:7]
|
||||
|
||||
content = (
|
||||
"## 🚀 Release 发布通知\n"
|
||||
"> <20> **分支**: " + os.environ["BRANCH"] + "\n"
|
||||
"> 👤 **提交人**: " + os.environ["AUTHOR"] + "\n"
|
||||
"> 📝 **标题**: " + os.environ["PR_TITLE"] + "\n"
|
||||
"> 🔢 **PR编号**: #" + pr_number + "\n"
|
||||
"> 🔖 **Commit**: " + short_sha + "\n\n"
|
||||
"### 🧠 " + label + "\n" +
|
||||
summary + "\n\n"
|
||||
"---\n"
|
||||
"🔗 [查看PR详情](" + os.environ["PR_URL"] + ")"
|
||||
)
|
||||
payload = {"msgtype": "markdown", "markdown": {"content": content}}
|
||||
data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
|
||||
req = urllib.request.Request(
|
||||
os.environ["WECHAT_WEBHOOK"],
|
||||
data=data,
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
resp = urllib.request.urlopen(req)
|
||||
print(resp.read().decode())
|
||||
PYEOF
|
||||
36
.github/workflows/sync-to-gitee.yml
vendored
Normal file
36
.github/workflows/sync-to-gitee.yml
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
name: Sync to Gitee
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main # Production
|
||||
- develop # Integration
|
||||
- 'release/*' # Release preparation
|
||||
- 'hotfix/*' # Urgent fixes
|
||||
tags:
|
||||
- '*' # All version tags (v1.0.0, etc.)
|
||||
|
||||
jobs:
|
||||
sync:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout Source Code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Sync to Gitee
|
||||
run: |
|
||||
GITEE_URL="https://${{ secrets.GITEE_USERNAME }}:${{ secrets.GITEE_TOKEN }}@gitee.com/hangzhou-hongxiong-intelligent_1/MemoryBear.git"
|
||||
git remote add gitee "$GITEE_URL"
|
||||
|
||||
# 遍历并推送所有分支
|
||||
for branch in $(git branch -r | grep -v HEAD | sed 's/origin\///'); do
|
||||
echo "Syncing branch: $branch"
|
||||
git push -f gitee "origin/$branch:refs/heads/$branch"
|
||||
done
|
||||
|
||||
# 推送所有标签
|
||||
echo "Syncing tags..."
|
||||
git push gitee --tags --force
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -18,6 +18,7 @@ examples/
|
||||
.kiro
|
||||
.vscode
|
||||
.idea
|
||||
.claude
|
||||
|
||||
# Temporary outputs
|
||||
.DS_Store
|
||||
@@ -26,6 +27,7 @@ time.log
|
||||
celerybeat-schedule.db
|
||||
search_results.json
|
||||
redbear-mem-metrics/
|
||||
redbear-mem-benchmark/
|
||||
pitch-deck/
|
||||
|
||||
api/migrations/versions
|
||||
|
||||
@@ -2,6 +2,10 @@
|
||||
|
||||
# MemoryBear empowers AI with human-like memory capabilities
|
||||
|
||||
[](LICENSE)
|
||||
[](https://www.python.org/)
|
||||
[](https://github.com/SuanmoSuanyangTechnology/MemoryBear/actions/workflows/sync-to-gitee.yml)
|
||||
|
||||
[中文](./README_CN.md) | English
|
||||
|
||||
### [Installation Guide](#memorybear-installation-guide)
|
||||
|
||||
@@ -2,6 +2,10 @@
|
||||
|
||||
# MemoryBear 让AI拥有如同人类一样的记忆
|
||||
|
||||
[](LICENSE)
|
||||
[](https://www.python.org/)
|
||||
[](https://github.com/SuanmoSuanyangTechnology/MemoryBear/actions/workflows/sync-to-gitee.yml)
|
||||
|
||||
中文 | [English](./README.md)
|
||||
|
||||
### [安装教程](#memorybear安装教程)
|
||||
|
||||
@@ -111,11 +111,17 @@ celery_app.conf.update(
|
||||
# Clustering tasks → memory_tasks queue (使用相同的 worker,避免 macOS fork 问题)
|
||||
'app.tasks.run_incremental_clustering': {'queue': 'memory_tasks'},
|
||||
|
||||
# Metadata extraction → memory_tasks queue
|
||||
'app.tasks.extract_user_metadata': {'queue': 'memory_tasks'},
|
||||
|
||||
# Document tasks → document_tasks queue (prefork worker)
|
||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
|
||||
|
||||
# GraphRAG tasks → graphrag_tasks queue (独立队列,避免阻塞文档解析)
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'graphrag_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_document': {'queue': 'graphrag_tasks'},
|
||||
|
||||
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
Celery Worker 入口点
|
||||
用于启动 Celery Worker: celery -A app.celery_worker worker --loglevel=info
|
||||
"""
|
||||
from celery.signals import worker_process_init
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.logging_config import LoggingConfig, get_logger
|
||||
|
||||
@@ -13,4 +15,39 @@ logger.info("Celery worker logging initialized")
|
||||
# 导入任务模块以注册任务
|
||||
import app.tasks
|
||||
|
||||
|
||||
@worker_process_init.connect
|
||||
def _reinit_db_pool(**kwargs):
|
||||
"""
|
||||
prefork 子进程启动时重建被 fork 污染的资源。
|
||||
|
||||
fork() 后子进程继承了父进程的:
|
||||
1. SQLAlchemy 连接池 — 多进程共享 TCP socket 导致 DB 连接损坏
|
||||
2. ThreadPoolExecutor — fork 后线程状态不确定,第二个任务会死锁
|
||||
"""
|
||||
# 重建 DB 连接池
|
||||
from app.db import engine
|
||||
engine.dispose()
|
||||
logger.info("DB connection pool disposed for forked worker process")
|
||||
|
||||
# 重建模块级 ThreadPoolExecutor(fork 后线程池不可用)
|
||||
try:
|
||||
from app.core.rag.deepdoc.parser import figure_parser
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
figure_parser.shared_executor = ThreadPoolExecutor(max_workers=10)
|
||||
logger.info("figure_parser.shared_executor recreated")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to recreate figure_parser.shared_executor: {e}")
|
||||
|
||||
try:
|
||||
from app.core.rag.utils import libre_office
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import os
|
||||
max_workers = os.cpu_count() * 2 if os.cpu_count() else 4
|
||||
libre_office.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
logger.info("libre_office.executor recreated")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to recreate libre_office.executor: {e}")
|
||||
|
||||
|
||||
__all__ = ['celery_app']
|
||||
|
||||
77
api/app/config/default_free_plan.py
Normal file
77
api/app/config/default_free_plan.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
社区版默认免费套餐配置
|
||||
当无法从 SaaS 版获取 premium 模块时,使用此配置作为兜底
|
||||
|
||||
可通过环境变量覆盖配额配置,格式:QUOTA_<QUOTA_NAME>
|
||||
例如:QUOTA_END_USER_QUOTA=100
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def _get_quota_from_env():
|
||||
"""从环境变量获取配额配置"""
|
||||
quota_keys = [
|
||||
"workspace_quota",
|
||||
"skill_quota",
|
||||
"app_quota",
|
||||
"knowledge_capacity_quota",
|
||||
"memory_engine_quota",
|
||||
"end_user_quota",
|
||||
"ontology_project_quota",
|
||||
"model_quota",
|
||||
"api_ops_rate_limit",
|
||||
]
|
||||
quotas = {}
|
||||
for key in quota_keys:
|
||||
env_key = f"QUOTA_{key.upper()}"
|
||||
env_value = os.getenv(env_key)
|
||||
if env_value is not None:
|
||||
try:
|
||||
quotas[key] = float(env_value) if '.' in env_value else int(env_value)
|
||||
except ValueError:
|
||||
pass
|
||||
return quotas
|
||||
|
||||
|
||||
def _build_default_free_plan():
|
||||
"""构建默认免费套餐配置"""
|
||||
base = {
|
||||
"name": "记忆体验版",
|
||||
"name_en": "Memory Experience",
|
||||
"category": "saas_personal",
|
||||
"tier_level": 0,
|
||||
"version": "1.0",
|
||||
"status": True,
|
||||
"price": 0,
|
||||
"billing_cycle": "permanent_free",
|
||||
"core_value": "感受永久记忆",
|
||||
"core_value_en": "Experience Permanent Memory",
|
||||
"tech_support": "社群交流",
|
||||
"tech_support_en": "Community Support",
|
||||
"sla_compliance": "无",
|
||||
"sla_compliance_en": "None",
|
||||
"page_customization": "无",
|
||||
"page_customization_en": "None",
|
||||
"theme_color": "#64748B",
|
||||
"quotas": {
|
||||
"workspace_quota": 1,
|
||||
"skill_quota": 5,
|
||||
"app_quota": 2,
|
||||
"knowledge_capacity_quota": 0.3,
|
||||
"memory_engine_quota": 1,
|
||||
"end_user_quota": 10,
|
||||
"ontology_project_quota": 3,
|
||||
"model_quota": 1,
|
||||
"api_ops_rate_limit": 50,
|
||||
},
|
||||
}
|
||||
|
||||
env_quotas = _get_quota_from_env()
|
||||
if env_quotas:
|
||||
base["quotas"].update(env_quotas)
|
||||
|
||||
return base
|
||||
|
||||
|
||||
DEFAULT_FREE_PLAN = _build_default_free_plan()
|
||||
@@ -14,7 +14,6 @@ from . import (
|
||||
document_controller,
|
||||
emotion_config_controller,
|
||||
emotion_controller,
|
||||
end_user_controller,
|
||||
file_controller,
|
||||
file_storage_controller,
|
||||
home_page_controller,
|
||||
@@ -48,7 +47,8 @@ from . import (
|
||||
user_memory_controllers,
|
||||
workspace_controller,
|
||||
ontology_controller,
|
||||
skill_controller
|
||||
skill_controller,
|
||||
tenant_subscription_controller,
|
||||
)
|
||||
|
||||
# 创建管理端 API 路由器
|
||||
@@ -99,6 +99,7 @@ manager_router.include_router(file_storage_controller.router)
|
||||
manager_router.include_router(ontology_controller.router)
|
||||
manager_router.include_router(skill_controller.router)
|
||||
manager_router.include_router(i18n_controller.router)
|
||||
manager_router.include_router(end_user_controller.router)
|
||||
manager_router.include_router(tenant_subscription_controller.router)
|
||||
manager_router.include_router(tenant_subscription_controller.public_router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
@@ -167,6 +167,8 @@ def update_api_key(
|
||||
|
||||
return success(data=api_key_schema.ApiKey.model_validate(api_key), msg="API Key 更新成功")
|
||||
|
||||
except BusinessException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"未知错误: {str(e)}", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
|
||||
@@ -28,6 +28,7 @@ from app.services.app_statistics_service import AppStatisticsService
|
||||
from app.services.workflow_import_service import WorkflowImportService
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
from app.services.app_dsl_service import AppDslService
|
||||
from app.core.quota_stub import check_app_quota
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||
logger = get_business_logger()
|
||||
@@ -35,6 +36,7 @@ logger = get_business_logger()
|
||||
|
||||
@router.post("", summary="创建应用(可选创建 Agent 配置)")
|
||||
@cur_workspace_access_guard()
|
||||
@check_app_quota
|
||||
def create_app(
|
||||
payload: app_schema.AppCreate,
|
||||
db: Session = Depends(get_db),
|
||||
@@ -217,6 +219,7 @@ def delete_app(
|
||||
|
||||
@router.post("/{app_id}/copy", summary="复制应用")
|
||||
@cur_workspace_access_guard()
|
||||
@check_app_quota
|
||||
def copy_app(
|
||||
app_id: uuid.UUID,
|
||||
new_name: Optional[str] = None,
|
||||
@@ -269,6 +272,19 @@ def update_agent_config(
|
||||
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
||||
|
||||
|
||||
@router.get("/{app_id}/model/parameters/default", summary="获取 Agent 模型参数默认配置")
|
||||
@cur_workspace_access_guard()
|
||||
def get_agent_model_parameters(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = AppService(db)
|
||||
model_parameters = service.get_default_model_parameters(app_id=app_id)
|
||||
return success(data=model_parameters, msg="获取 Agent 模型参数默认配置")
|
||||
|
||||
|
||||
@router.get("/{app_id}/config", summary="获取 Agent 配置")
|
||||
@cur_workspace_access_guard()
|
||||
def get_agent_config(
|
||||
@@ -292,10 +308,19 @@ def get_opening(
|
||||
):
|
||||
"""返回开场白文本和预设问题,供前端对话界面初始化时展示"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
|
||||
features = cfg.features or {}
|
||||
if hasattr(features, "model_dump"):
|
||||
features = features.model_dump()
|
||||
|
||||
# 根据应用类型获取 features
|
||||
from app.models.app_model import App as AppModel
|
||||
app = db.get(AppModel, app_id)
|
||||
if app and app.type == "workflow":
|
||||
cfg = app_service.get_workflow_config(db=db, app_id=app_id, workspace_id=workspace_id)
|
||||
features = cfg.features or {}
|
||||
else:
|
||||
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
|
||||
features = cfg.features or {}
|
||||
if hasattr(features, "model_dump"):
|
||||
features = features.model_dump()
|
||||
|
||||
opening = features.get("opening_statement", {})
|
||||
return success(data=app_schema.OpeningResponse(
|
||||
enabled=opening.get("enabled", False),
|
||||
@@ -1070,6 +1095,14 @@ async def update_workflow_config(
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if payload.variables:
|
||||
from app.services.workflow_service import WorkflowService
|
||||
resolved = await WorkflowService(db)._resolve_variables_file_defaults(
|
||||
[v.model_dump() for v in payload.variables]
|
||||
)
|
||||
# Patch default values back into VariableDefinition objects
|
||||
for var_def, resolved_def in zip(payload.variables, resolved):
|
||||
var_def.default = resolved_def.get("default", var_def.default)
|
||||
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||
|
||||
@@ -1112,6 +1145,7 @@ async def import_workflow_config(
|
||||
|
||||
@router.post("/workflow/import/save")
|
||||
@cur_workspace_access_guard()
|
||||
@check_app_quota
|
||||
async def save_workflow_import(
|
||||
data: WorkflowImportSave,
|
||||
db: Session = Depends(get_db),
|
||||
@@ -1233,9 +1267,11 @@ async def export_app(
|
||||
async def import_app(
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
app_id: Optional[str] = Form(None),
|
||||
):
|
||||
"""从 YAML 文件导入 agent / multi_agent / workflow 应用。
|
||||
传入 app_id 时覆盖该应用的配置(类型必须一致),否则创建新应用。
|
||||
跨空间/跨租户导入时,模型/工具/知识库会按名称匹配,匹配不到则置空并返回 warnings。
|
||||
"""
|
||||
if not file.filename.lower().endswith((".yaml", ".yml")):
|
||||
@@ -1246,13 +1282,19 @@ async def import_app(
|
||||
if not dsl or "app" not in dsl:
|
||||
return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)
|
||||
|
||||
new_app, warnings = AppDslService(db).import_dsl(
|
||||
target_app_id = uuid.UUID(app_id) if app_id else None
|
||||
# 仅新建应用时检查配额,覆盖已有应用时跳过
|
||||
if target_app_id is None:
|
||||
from app.core.quota_manager import _check_quota
|
||||
_check_quota(db, current_user.tenant_id, "app_quota", "app", workspace_id=current_user.current_workspace_id)
|
||||
result_app, warnings = AppDslService(db).import_dsl(
|
||||
dsl=dsl,
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
tenant_id=current_user.tenant_id,
|
||||
user_id=current_user.id,
|
||||
app_id=target_app_id,
|
||||
)
|
||||
return success(
|
||||
data={"app": app_schema.App.model_validate(new_app), "warnings": warnings},
|
||||
data={"app": app_schema.App.model_validate(result_app), "warnings": warnings},
|
||||
msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "")
|
||||
)
|
||||
|
||||
@@ -53,22 +53,24 @@ async def login_for_access_token(
|
||||
user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
|
||||
auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
|
||||
if form_data.invite:
|
||||
auth_service.bind_workspace_with_invite(db=db,
|
||||
user=user,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id)
|
||||
auth_service.bind_workspace_with_invite(
|
||||
db=db,
|
||||
user=user,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id
|
||||
)
|
||||
except BusinessException as e:
|
||||
# 用户不存在且有邀请码,尝试注册
|
||||
if e.code == BizCode.USER_NOT_FOUND:
|
||||
auth_logger.info(f"用户不存在,使用邀请码注册: {form_data.email}")
|
||||
user = auth_service.register_user_with_invite(
|
||||
db=db,
|
||||
email=form_data.email,
|
||||
username=form_data.username,
|
||||
password=form_data.password,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id
|
||||
)
|
||||
db=db,
|
||||
email=form_data.email,
|
||||
username=form_data.username,
|
||||
password=form_data.password,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id
|
||||
)
|
||||
elif e.code == BizCode.PASSWORD_ERROR:
|
||||
# 用户存在但密码错误
|
||||
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")
|
||||
@@ -134,7 +136,7 @@ async def refresh_token(
|
||||
# 检查用户是否存在
|
||||
user = auth_service.get_user_by_id(db, userId)
|
||||
if not user:
|
||||
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
|
||||
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NO_ACCESS)
|
||||
|
||||
# 检查 refresh token 黑名单
|
||||
if settings.ENABLE_SINGLE_SESSION:
|
||||
|
||||
@@ -23,6 +23,7 @@ from app.models.user_model import User
|
||||
from app.schemas import chunk_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
@@ -442,10 +443,10 @@ async def retrieve_chunks(
|
||||
match retrieve_data.retrieve_type:
|
||||
case chunk_schema.RetrieveType.PARTICIPLE:
|
||||
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
return success(data=jsonable_encoder(rs), msg="retrieval successful")
|
||||
case chunk_schema.RetrieveType.SEMANTIC:
|
||||
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
return success(data=jsonable_encoder(rs), msg="retrieval successful")
|
||||
case _:
|
||||
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
||||
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
||||
@@ -456,22 +457,24 @@ async def retrieve_chunks(
|
||||
if doc.metadata["doc_id"] not in seen_ids:
|
||||
seen_ids.add(doc.metadata["doc_id"])
|
||||
unique_rs.append(doc)
|
||||
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k)
|
||||
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k) if unique_rs else []
|
||||
if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph:
|
||||
kb_ids = [str(kb_id) for kb_id in private_kb_ids]
|
||||
workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids]
|
||||
llm_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.llm_id)
|
||||
emb_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.embedding_id)
|
||||
# Prepare to configure chat_mdl、embedding_model、vision_model information
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
||||
key=llm_key.api_key,
|
||||
model_name=llm_key.model_name,
|
||||
base_url=llm_key.api_base
|
||||
)
|
||||
embedding_model = OpenAIEmbed(
|
||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||
base_url=db_knowledge.embedding.api_keys[0].api_base
|
||||
key=emb_key.api_key,
|
||||
model_name=emb_key.model_name,
|
||||
base_url=emb_key.api_base
|
||||
)
|
||||
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids= kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||
if doc:
|
||||
rs.insert(0, doc)
|
||||
return success(data=jsonable_encoder(rs), msg="retrieval successful")
|
||||
@@ -314,8 +314,10 @@ async def parse_documents(
|
||||
)
|
||||
|
||||
# 4. Check if the file exists
|
||||
api_logger.debug(f"Constructed file path: {file_path}")
|
||||
api_logger.debug(f"File metadata - kb_id: {db_file.kb_id}, parent_id: {db_file.parent_id}, file_id: {db_file.id}, extension: {db_file.file_ext}")
|
||||
if not os.path.exists(file_path):
|
||||
api_logger.warning(f"File not found (possibly deleted): file_path={file_path}")
|
||||
api_logger.error(f"File not found (possibly deleted): file_path={file_path}, file_id={db_file.id}, document_id={document_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="File not found (possibly deleted)"
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
"""End User 管理接口 - 无需认证"""
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas.memory_api_schema import (
|
||||
CreateEndUserRequest,
|
||||
CreateEndUserResponse,
|
||||
)
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
router = APIRouter(prefix="/end_users", tags=["End Users"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_end_user(
|
||||
data: CreateEndUserRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Create an end user.
|
||||
|
||||
Creates a new end user for the given workspace.
|
||||
If an end user with the same other_id already exists in the workspace,
|
||||
returns the existing one.
|
||||
"""
|
||||
logger.info(f"Create end user request - other_id: {data.other_id}, workspace_id: {data.workspace_id}")
|
||||
|
||||
end_user_repo = EndUserRepository(db)
|
||||
end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=None,
|
||||
workspace_id=data.workspace_id,
|
||||
other_id=data.other_id,
|
||||
)
|
||||
|
||||
logger.info(f"End user ready: {end_user.id}")
|
||||
|
||||
result = {
|
||||
"id": str(end_user.id),
|
||||
"other_id": end_user.other_id or "",
|
||||
"other_name": end_user.other_name or "",
|
||||
"workspace_id": str(end_user.workspace_id),
|
||||
}
|
||||
|
||||
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
|
||||
@@ -19,6 +19,7 @@ from app.models.user_model import User
|
||||
from app.schemas import file_schema, document_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import file_service, document_service
|
||||
from app.core.quota_stub import check_knowledge_capacity_quota
|
||||
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
@@ -131,6 +132,7 @@ async def create_folder(
|
||||
|
||||
|
||||
@router.post("/file", response_model=ApiResponse)
|
||||
@check_knowledge_capacity_quota
|
||||
async def upload_file(
|
||||
kb_id: uuid.UUID,
|
||||
parent_id: uuid.UUID,
|
||||
|
||||
@@ -3,9 +3,10 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.db import get_db, SessionLocal
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.repositories.home_page_repository import HomePageRepository
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.home_page_service import HomePageService
|
||||
|
||||
@@ -31,9 +32,32 @@ def get_workspace_list(
|
||||
|
||||
@router.get("/version", response_model=ApiResponse)
|
||||
def get_system_version():
|
||||
"""获取系统版本号+说明"""
|
||||
current_version = settings.SYSTEM_VERSION
|
||||
version_info = HomePageService.load_version_introduction(current_version)
|
||||
"""获取系统版本号 + 说明"""
|
||||
current_version = None
|
||||
version_info = None
|
||||
|
||||
# 1️⃣ 优先从数据库获取最新已发布的版本
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
current_version, version_info = HomePageRepository.get_latest_version_introduction(db)
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# 2️⃣ 降级:使用环境变量中的版本号
|
||||
if not current_version:
|
||||
current_version = settings.SYSTEM_VERSION
|
||||
version_info = HomePageService.load_version_introduction(current_version)
|
||||
|
||||
# 3️⃣ 如果数据库和 JSON 都没有,返回基本信息
|
||||
if not version_info:
|
||||
version_info = {
|
||||
"introduction": {"codeName": "", "releaseDate": "", "upgradePosition": "", "coreUpgrades": []},
|
||||
"introduction_en": {"codeName": "", "releaseDate": "", "upgradePosition": "", "coreUpgrades": []}
|
||||
}
|
||||
|
||||
return success(
|
||||
data={
|
||||
"version": current_version,
|
||||
|
||||
@@ -27,6 +27,7 @@ from app.schemas import knowledge_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import knowledge_service, document_service
|
||||
from app.services.model_service import ModelConfigService
|
||||
from app.core.quota_stub import check_knowledge_capacity_quota
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
@@ -179,6 +180,7 @@ async def get_knowledges(
|
||||
|
||||
|
||||
@router.post("/knowledge", response_model=ApiResponse)
|
||||
@check_knowledge_capacity_quota
|
||||
async def create_knowledge(
|
||||
create_data: knowledge_schema.KnowledgeCreate,
|
||||
db: Session = Depends(get_db),
|
||||
@@ -352,6 +354,7 @@ async def delete_knowledge(
|
||||
# 2. Soft-delete knowledge base
|
||||
api_logger.debug(f"Perform a soft delete: {db_knowledge.name} (ID: {knowledge_id})")
|
||||
db_knowledge.status = 2
|
||||
db_knowledge.updated_at = datetime.datetime.now()
|
||||
db.commit()
|
||||
api_logger.info(f"The knowledge base has been successfully deleted: {db_knowledge.name} (ID: {knowledge_id})")
|
||||
return success(msg="The knowledge base has been successfully deleted")
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -47,64 +49,64 @@ def get_workspace_total_end_users(
|
||||
|
||||
@router.get("/end_users", response_model=ApiResponse)
|
||||
async def get_workspace_end_users(
|
||||
workspace_id: Optional[uuid.UUID] = Query(None, description="工作空间ID(可选,默认当前用户工作空间)"),
|
||||
keyword: Optional[str] = Query(None, description="搜索关键词(同时模糊匹配 other_name 和 id)"),
|
||||
page: int = Query(1, ge=1, description="页码,从1开始"),
|
||||
pagesize: int = Query(10, ge=1, description="每页数量"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取工作空间的宿主列表(高性能优化版本 v2)
|
||||
|
||||
优化策略:
|
||||
1. 批量查询 end_users(一次查询而非循环)
|
||||
2. 并发查询所有用户的记忆数量(Neo4j)
|
||||
3. RAG 模式使用批量查询(一次 SQL)
|
||||
4. 只返回必要字段减少数据传输
|
||||
5. 添加短期缓存减少重复查询
|
||||
6. 并发执行配置查询和记忆数量查询
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"end_user": {"id": "uuid", "other_name": "名称"},
|
||||
"memory_num": {"total": 数量},
|
||||
"memory_config": {"memory_config_id": "id", "memory_config_name": "名称"}
|
||||
}
|
||||
获取工作空间的宿主列表(分页查询,支持模糊搜索)
|
||||
|
||||
返回工作空间下的宿主列表,支持分页查询和模糊搜索。
|
||||
通过 keyword 参数同时模糊匹配 other_name 和 id 字段。
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
||||
keyword: 搜索关键词(可选,同时模糊匹配 other_name 和 id)
|
||||
page: 页码(从1开始,默认1)
|
||||
pagesize: 每页数量(默认10)
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含宿主列表和分页信息
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 尝试从缓存获取(30秒缓存)
|
||||
cache_key = f"end_users:workspace:{workspace_id}"
|
||||
try:
|
||||
cached_data = await aio_redis_get(cache_key)
|
||||
if cached_data:
|
||||
api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}")
|
||||
return success(data=json.loads(cached_data), msg="宿主列表获取成功")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存读取失败: {str(e)}")
|
||||
|
||||
# 如果未提供 workspace_id,使用当前用户的工作空间
|
||||
if workspace_id is None:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 获取当前空间类型
|
||||
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
|
||||
|
||||
# 获取 end_users(已优化为批量查询)
|
||||
end_users = memory_dashboard_service.get_workspace_end_users(
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表, 类型: {current_workspace_type}")
|
||||
|
||||
# 获取分页的 end_users
|
||||
end_users_result = memory_dashboard_service.get_workspace_end_users_paginated(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user
|
||||
current_user=current_user,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
keyword=keyword
|
||||
)
|
||||
|
||||
end_users = end_users_result.get("items", [])
|
||||
total = end_users_result.get("total", 0)
|
||||
|
||||
if not end_users:
|
||||
api_logger.info("工作空间下没有宿主")
|
||||
# 缓存空结果,避免重复查询
|
||||
try:
|
||||
await aio_redis_set(cache_key, json.dumps([]), expire=30)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||
return success(data=[], msg="宿主列表获取成功")
|
||||
|
||||
api_logger.info(f"工作空间下没有宿主或当前页无数据: total={total}, page={page}")
|
||||
return success(data={
|
||||
"items": [],
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"hasnext": (page * pagesize) < total
|
||||
}
|
||||
}, msg="宿主列表获取成功")
|
||||
|
||||
end_user_ids = [str(user.id) for user in end_users]
|
||||
|
||||
|
||||
# 并发执行两个独立的查询任务
|
||||
async def get_memory_configs():
|
||||
"""获取记忆配置(在线程池中执行同步查询)"""
|
||||
@@ -116,7 +118,7 @@ async def get_workspace_end_users(
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
|
||||
return {}
|
||||
|
||||
|
||||
async def get_memory_nums():
|
||||
"""获取记忆数量"""
|
||||
if current_workspace_type == "rag":
|
||||
@@ -130,26 +132,18 @@ async def get_workspace_end_users(
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
|
||||
elif current_workspace_type == "neo4j":
|
||||
# Neo4j 模式:并发查询(带并发限制)
|
||||
# 使用信号量限制并发数,避免大量用户时压垮 Neo4j
|
||||
MAX_CONCURRENT_QUERIES = 10
|
||||
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES)
|
||||
|
||||
async def get_neo4j_memory_num(end_user_id: str):
|
||||
async with semaphore:
|
||||
try:
|
||||
return await memory_storage_service.search_all(end_user_id)
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}")
|
||||
return {"total": 0}
|
||||
|
||||
memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
|
||||
return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
|
||||
|
||||
# Neo4j 模式:批量查询(简化版本,只返回total)
|
||||
try:
|
||||
batch_result = await memory_storage_service.search_all_batch(end_user_ids)
|
||||
return {uid: {"total": count} for uid, count in batch_result.items()}
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取 Neo4j 记忆数量失败: {str(e)}")
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
|
||||
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
|
||||
try:
|
||||
from app.celery_app import celery_app as _celery_app
|
||||
@@ -170,13 +164,13 @@ async def get_workspace_end_users(
|
||||
get_memory_configs(),
|
||||
get_memory_nums()
|
||||
)
|
||||
|
||||
# 构建结果(优化:使用列表推导式)
|
||||
result = []
|
||||
|
||||
# 构建结果列表
|
||||
items = []
|
||||
for end_user in end_users:
|
||||
user_id = str(end_user.id)
|
||||
config_info = memory_configs_map.get(user_id, {})
|
||||
result.append({
|
||||
items.append({
|
||||
'end_user': {
|
||||
'id': user_id,
|
||||
'other_name': end_user.other_name
|
||||
@@ -187,12 +181,6 @@ async def get_workspace_end_users(
|
||||
"memory_config_name": config_info.get("memory_config_name")
|
||||
}
|
||||
})
|
||||
|
||||
# 写入缓存(30秒过期)
|
||||
try:
|
||||
await aio_redis_set(cache_key, json.dumps(result), expire=30)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||
|
||||
# 触发社区聚类补全任务(异步,不阻塞接口响应)
|
||||
try:
|
||||
@@ -202,7 +190,18 @@ async def get_workspace_end_users(
|
||||
except Exception as e:
|
||||
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
|
||||
|
||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||
# 构建分页响应
|
||||
result = {
|
||||
"items": items,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"hasnext": (page * pagesize) < total
|
||||
}
|
||||
}
|
||||
|
||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录,总计 {total} 条")
|
||||
return success(data=result, msg="宿主列表获取成功")
|
||||
|
||||
|
||||
@@ -592,7 +591,7 @@ async def dashboard_data(
|
||||
"total_api_call": None
|
||||
}
|
||||
|
||||
# 1. 获取记忆总量(total_memory)
|
||||
# 1. 获取记忆总量(total_memory)—— neo4j 独有逻辑:查询 neo4j 存储节点
|
||||
try:
|
||||
total_memory_data = await memory_dashboard_service.get_workspace_total_memory_count(
|
||||
db=db,
|
||||
@@ -601,49 +600,33 @@ async def dashboard_data(
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0)
|
||||
# total_app: 统计当前空间下的所有app数量
|
||||
# 包含自有app + 被分享给本工作空间的app
|
||||
from app.services import app_service as _app_svc
|
||||
_, total_app = _app_svc.AppService(db).list_apps(
|
||||
workspace_id=workspace_id, include_shared=True, pagesize=1
|
||||
)
|
||||
neo4j_data["total_app"] = total_app
|
||||
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}, 应用数量: {neo4j_data['total_app']}")
|
||||
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取记忆总量失败: {str(e)}")
|
||||
|
||||
# 2. 获取知识库类型统计(total_knowledge)
|
||||
try:
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
memory_agent_service = MemoryAgentService()
|
||||
knowledge_stats = await memory_agent_service.get_knowledge_type_stats(
|
||||
end_user_id=end_user_id,
|
||||
only_active=True,
|
||||
current_workspace_id=workspace_id,
|
||||
db=db
|
||||
)
|
||||
neo4j_data["total_knowledge"] = knowledge_stats.get("total", 0)
|
||||
api_logger.info(f"成功获取知识库类型统计total: {neo4j_data['total_knowledge']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取知识库类型统计失败: {str(e)}")
|
||||
# 2. 获取共享统计数据(total_app、total_knowledge、total_api_call)
|
||||
common_stats = memory_dashboard_service.get_dashboard_common_stats(db, workspace_id)
|
||||
neo4j_data.update(common_stats)
|
||||
api_logger.info(f"成功获取共享统计: app={common_stats['total_app']}, knowledge={common_stats['total_knowledge']}, api_call={common_stats['total_api_call']}")
|
||||
|
||||
# 3. 获取API调用统计(total_api_call)
|
||||
# 计算昨日对比
|
||||
try:
|
||||
# 使用 AppStatisticsService 获取真实的API调用统计
|
||||
app_stats_service = AppStatisticsService(db)
|
||||
api_stats = app_stats_service.get_workspace_api_statistics(
|
||||
changes = memory_dashboard_service.get_dashboard_yesterday_changes(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
storage_type=storage_type,
|
||||
today_data=neo4j_data
|
||||
)
|
||||
# 计算总调用次数
|
||||
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
|
||||
neo4j_data["total_api_call"] = total_api_calls
|
||||
api_logger.info(f"成功获取API调用统计: {neo4j_data['total_api_call']}")
|
||||
neo4j_data.update(changes)
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取API调用统计失败: {str(e)}")
|
||||
neo4j_data["total_api_call"] = 0
|
||||
|
||||
api_logger.warning(f"计算neo4j昨日对比失败: {str(e)}")
|
||||
neo4j_data.update({
|
||||
"total_memory_change": None,
|
||||
"total_app_change": None,
|
||||
"total_knowledge_change": None,
|
||||
"total_api_call_change": None,
|
||||
})
|
||||
|
||||
result["neo4j_data"] = neo4j_data
|
||||
api_logger.info("成功获取neo4j_data")
|
||||
|
||||
@@ -656,44 +639,37 @@ async def dashboard_data(
|
||||
"total_api_call": None
|
||||
}
|
||||
|
||||
# 获取RAG相关数据
|
||||
# 1. 获取记忆总量(total_memory)—— rag 独有逻辑:查询 document 表的 chunk_num
|
||||
try:
|
||||
# total_memory: 只统计用户知识库(permission_id='Memory')的chunk数
|
||||
total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
|
||||
rag_data["total_memory"] = total_chunk
|
||||
|
||||
# total_app: 统计当前空间下的所有app数量
|
||||
# 包含自有app + 被分享给本工作空间的app
|
||||
from app.services import app_service as _app_svc
|
||||
_, total_app = _app_svc.AppService(db).list_apps(
|
||||
workspace_id=workspace_id, include_shared=True, pagesize=1
|
||||
)
|
||||
rag_data["total_app"] = total_app
|
||||
|
||||
# total_knowledge: 使用 total_kb(总知识库数)
|
||||
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
|
||||
rag_data["total_knowledge"] = total_kb
|
||||
|
||||
# total_api_call: 使用 AppStatisticsService 获取真实的API调用统计
|
||||
try:
|
||||
app_stats_service = AppStatisticsService(db)
|
||||
api_stats = app_stats_service.get_workspace_api_statistics(
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
# 计算总调用次数
|
||||
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
|
||||
rag_data["total_api_call"] = total_api_calls
|
||||
api_logger.info(f"成功获取RAG模式API调用统计: {rag_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取RAG模式API调用统计失败,使用默认值: {str(e)}")
|
||||
rag_data["total_api_call"] = 0
|
||||
|
||||
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={total_app}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
|
||||
api_logger.info(f"成功获取RAG记忆总量: {total_chunk}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")
|
||||
api_logger.warning(f"获取RAG记忆总量失败: {str(e)}")
|
||||
|
||||
# 2. 获取共享统计数据(total_app、total_knowledge、total_api_call)
|
||||
common_stats = memory_dashboard_service.get_dashboard_common_stats(db, workspace_id)
|
||||
rag_data.update(common_stats)
|
||||
api_logger.info(f"成功获取共享统计: app={common_stats['total_app']}, knowledge={common_stats['total_knowledge']}, api_call={common_stats['total_api_call']}")
|
||||
|
||||
# 计算昨日对比
|
||||
try:
|
||||
changes = memory_dashboard_service.get_dashboard_yesterday_changes(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
storage_type=storage_type,
|
||||
today_data=rag_data
|
||||
)
|
||||
rag_data.update(changes)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"计算RAG昨日对比失败: {str(e)}")
|
||||
rag_data.update({
|
||||
"total_memory_change": None,
|
||||
"total_app_change": None,
|
||||
"total_knowledge_change": None,
|
||||
"total_api_call_change": None,
|
||||
})
|
||||
|
||||
result["rag_data"] = rag_data
|
||||
api_logger.info("成功获取rag_data")
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from app.services.memory_storage_service import (
|
||||
analytics_hot_memory_tags,
|
||||
analytics_recent_activity_stats,
|
||||
kb_type_distribution,
|
||||
search_all,
|
||||
search_all_batch,
|
||||
search_chunk,
|
||||
search_detials,
|
||||
search_dialogue,
|
||||
@@ -34,6 +34,7 @@ from app.services.memory_storage_service import (
|
||||
search_entity,
|
||||
search_statement,
|
||||
)
|
||||
from app.core.quota_stub import check_memory_engine_quota
|
||||
from fastapi import APIRouter, Depends, Header
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -76,6 +77,7 @@ async def get_storage_info(
|
||||
|
||||
|
||||
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
||||
@check_memory_engine_quota
|
||||
def create_config(
|
||||
payload: ConfigParamsCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
@@ -409,7 +411,10 @@ async def search_all_num(
|
||||
) -> dict:
|
||||
api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_all(end_user_id)
|
||||
if not end_user_id:
|
||||
return success(data={"total": 0}, msg="查询成功")
|
||||
batch_result = await search_all_batch([end_user_id])
|
||||
result = {"total": batch_result.get(end_user_id, 0)}
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Search all failed: {str(e)}")
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.core.response_utils import success
|
||||
from app.schemas.response_schema import ApiResponse, PageData
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService, ModelBaseService
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.quota_stub import check_model_quota, check_model_activation_quota
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -303,6 +304,7 @@ async def create_model(
|
||||
|
||||
|
||||
@router.post("/composite", response_model=ApiResponse)
|
||||
@check_model_quota
|
||||
async def create_composite_model(
|
||||
model_data: model_schema.CompositeModelCreate,
|
||||
db: Session = Depends(get_db),
|
||||
@@ -329,6 +331,7 @@ async def create_composite_model(
|
||||
|
||||
|
||||
@router.put("/composite/{model_id}", response_model=ApiResponse)
|
||||
@check_model_activation_quota
|
||||
async def update_composite_model(
|
||||
model_id: uuid.UUID,
|
||||
model_data: model_schema.CompositeModelCreate,
|
||||
@@ -370,6 +373,7 @@ def delete_composite_model(
|
||||
|
||||
|
||||
@router.put("/{model_id}", response_model=ApiResponse)
|
||||
@check_model_activation_quota
|
||||
def update_model(
|
||||
model_id: uuid.UUID,
|
||||
model_data: model_schema.ModelConfigUpdate,
|
||||
|
||||
@@ -28,6 +28,8 @@ from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, H
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.quota_stub import check_ontology_project_quota
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
@@ -163,6 +165,7 @@ def _get_ontology_service(
|
||||
api_key=api_key_config.api_key,
|
||||
base_url=api_key_config.api_base,
|
||||
is_omni=api_key_config.is_omni,
|
||||
capability=api_key_config.capability,
|
||||
max_retries=3,
|
||||
timeout=60.0
|
||||
)
|
||||
@@ -286,6 +289,7 @@ async def extract_ontology(
|
||||
# ==================== 本体场景管理接口 ====================
|
||||
|
||||
@router.post("/scene", response_model=ApiResponse)
|
||||
@check_ontology_project_quota
|
||||
async def create_scene(
|
||||
request: SceneCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
|
||||
@@ -124,10 +124,11 @@ async def get_prompt_opt(
|
||||
skill=data.skill
|
||||
):
|
||||
# chunk 是 prompt 的增量内容
|
||||
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
|
||||
yield f"event:message\ndata: {json.dumps(chunk, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
yield f"event:error\ndata: {json.dumps(
|
||||
{"error": str(e)}
|
||||
{"error": str(e)},
|
||||
ensure_ascii=False
|
||||
)}\n\n"
|
||||
yield "event:end\ndata: {}\n\n"
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.quota_manager import check_end_user_quota
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db, get_db_read
|
||||
from app.dependencies import get_share_user_id, ShareTokenData
|
||||
@@ -218,9 +219,20 @@ def list_conversations(
|
||||
end_user_repo = EndUserRepository(db)
|
||||
app_service = AppService(db)
|
||||
app = app_service._get_app_or_404(share.app_id)
|
||||
workspace_id = app.workspace_id
|
||||
|
||||
# 仅在新建终端用户时检查配额
|
||||
existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id)
|
||||
if existing_end_user is None:
|
||||
from app.core.quota_manager import _check_quota
|
||||
from app.models.workspace_model import Workspace
|
||||
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||
if ws:
|
||||
_check_quota(db, ws.tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
|
||||
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
workspace_id=app.workspace_id,
|
||||
workspace_id=workspace_id,
|
||||
other_id=other_id
|
||||
)
|
||||
logger.debug(new_end_user.id)
|
||||
@@ -348,6 +360,18 @@ async def chat(
|
||||
app_service = AppService(db)
|
||||
app = app_service._get_app_or_404(share.app_id)
|
||||
workspace_id = app.workspace_id
|
||||
|
||||
# 仅在新建终端用户时检查配额,已有用户复用不受限制
|
||||
existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id)
|
||||
logger.info(f"终端用户配额检查: workspace_id={workspace_id}, other_id={other_id}, existing={existing_end_user is not None}")
|
||||
if existing_end_user is None:
|
||||
from app.core.quota_manager import _check_quota
|
||||
from app.models.workspace_model import Workspace
|
||||
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||
if ws:
|
||||
logger.info(f"新终端用户,执行配额检查: tenant_id={ws.tenant_id}")
|
||||
_check_quota(db, ws.tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
|
||||
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
workspace_id=workspace_id,
|
||||
@@ -453,31 +477,10 @@ async def chat(
|
||||
# 流式返回
|
||||
agent_config = agent_config_4_app_release(release)
|
||||
|
||||
if payload.stream:
|
||||
# async def event_generator():
|
||||
# async for event in service.chat_stream(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# ):
|
||||
# yield event
|
||||
if not (agent_config.model_parameters.get("deep_thinking", False) and payload.thinking):
|
||||
agent_config.model_parameters["deep_thinking"] = False
|
||||
|
||||
# return StreamingResponse(
|
||||
# event_generator(),
|
||||
# media_type="text/event-stream",
|
||||
# headers={
|
||||
# "Cache-Control": "no-cache",
|
||||
# "Connection": "keep-alive",
|
||||
# "X-Accel-Buffering": "no"
|
||||
# }
|
||||
# )
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.agnet_chat_stream(
|
||||
message=payload.message,
|
||||
@@ -503,20 +506,6 @@ async def chat(
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
# 非流式返回
|
||||
# result = await service.chat(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# )
|
||||
# return success(data=conversation_schema.ChatResponse(**result))
|
||||
result = await app_chat_service.agnet_chat(
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
@@ -575,48 +564,6 @@ async def chat(
|
||||
)
|
||||
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
# 多 Agent 流式返回
|
||||
# if payload.stream:
|
||||
# async def event_generator():
|
||||
# async for event in service.multi_agent_chat_stream(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# ):
|
||||
# yield event
|
||||
|
||||
# return StreamingResponse(
|
||||
# event_generator(),
|
||||
# media_type="text/event-stream",
|
||||
# headers={
|
||||
# "Cache-Control": "no-cache",
|
||||
# "Connection": "keep-alive",
|
||||
# "X-Accel-Buffering": "no"
|
||||
# }
|
||||
# )
|
||||
|
||||
# # 多 Agent 非流式返回
|
||||
# result = await service.multi_agent_chat(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# )
|
||||
|
||||
# return success(data=conversation_schema.ChatResponse(**result))
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
config = workflow_config_4_app_release(release)
|
||||
if not config.id:
|
||||
@@ -714,7 +661,8 @@ async def config_query(
|
||||
"app_type": release.app.type,
|
||||
"variables": release.config.get("variables"),
|
||||
"memory": release.config.get("memory", {}).get("enabled"),
|
||||
"features": release.config.get("features")
|
||||
"features": release.config.get("features"),
|
||||
"model_parameters": release.config.get("model_parameters")
|
||||
}
|
||||
elif release.app.type == AppType.MULTI_AGENT:
|
||||
content = {
|
||||
|
||||
@@ -4,7 +4,17 @@
|
||||
认证方式: API Key
|
||||
"""
|
||||
from fastapi import APIRouter
|
||||
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller
|
||||
|
||||
from . import (
|
||||
app_api_controller,
|
||||
end_user_api_controller,
|
||||
memory_api_controller,
|
||||
memory_config_api_controller,
|
||||
rag_api_chunk_controller,
|
||||
rag_api_document_controller,
|
||||
rag_api_file_controller,
|
||||
rag_api_knowledge_controller,
|
||||
)
|
||||
|
||||
# 创建 V1 API 路由器
|
||||
service_router = APIRouter()
|
||||
@@ -16,5 +26,7 @@ service_router.include_router(rag_api_document_controller.router)
|
||||
service_router.include_router(rag_api_file_controller.router)
|
||||
service_router.include_router(rag_api_chunk_controller.router)
|
||||
service_router.include_router(memory_api_controller.router)
|
||||
service_router.include_router(end_user_api_controller.router)
|
||||
service_router.include_router(memory_config_api_controller.router)
|
||||
|
||||
__all__ = ["service_router"]
|
||||
|
||||
@@ -14,6 +14,7 @@ from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.models.app_model import App
|
||||
from app.models.app_model import AppType
|
||||
from app.models.app_release_model import AppRelease
|
||||
from app.repositories import knowledge_repository
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas import AppChatRequest, conversation_schema
|
||||
@@ -61,18 +62,18 @@ async def list_apps():
|
||||
# return success(data={"received": True}, msg="消息已接收")
|
||||
|
||||
|
||||
def _checkAppConfig(app: App):
|
||||
if app.type == AppType.AGENT:
|
||||
if not app.current_release.config:
|
||||
def _checkAppConfig(release: AppRelease):
|
||||
if release.type == AppType.AGENT:
|
||||
if not release.config:
|
||||
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
elif app.type == AppType.MULTI_AGENT:
|
||||
if not app.current_release.config:
|
||||
elif release.type == AppType.MULTI_AGENT:
|
||||
if not release.config:
|
||||
raise BusinessException("Multi-Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
elif app.type == AppType.WORKFLOW:
|
||||
if not app.current_release.config:
|
||||
elif release.type == AppType.WORKFLOW:
|
||||
if not release.config:
|
||||
raise BusinessException("工作流应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
else:
|
||||
raise BusinessException("不支持的应用类型", BizCode.AGENT_CONFIG_MISSING)
|
||||
raise BusinessException("不支持的应用类型", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
@@ -86,13 +87,35 @@ async def chat(
|
||||
app_service: Annotated[AppService, Depends(get_app_service)] = None,
|
||||
message: str = Body(..., description="聊天消息内容"),
|
||||
):
|
||||
"""
|
||||
Agent/Workflow 聊天接口
|
||||
|
||||
- 不传 version:使用当前生效版本(current_release,回滚后为回滚目标版本)
|
||||
- 传 version=release_id:使用指定版本uuid的历史快照,例如 {"version": "{{release_id}}"}
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = AppChatRequest(**body)
|
||||
|
||||
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
|
||||
|
||||
# 版本切换:指定 release_id 时查找对应历史快照,否则使用当前激活版本
|
||||
if payload.version is not None:
|
||||
active_release = app_service.get_release_by_id(app.id, payload.version)
|
||||
else:
|
||||
active_release = app.current_release
|
||||
other_id = payload.user_id
|
||||
workspace_id = api_key_auth.workspace_id
|
||||
end_user_repo = EndUserRepository(db)
|
||||
|
||||
# 仅在新建终端用户时检查配额,已有用户复用不受限制
|
||||
existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id)
|
||||
if existing_end_user is None:
|
||||
from app.core.quota_manager import _check_quota
|
||||
from app.models.workspace_model import Workspace
|
||||
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||
if ws:
|
||||
_check_quota(db, ws.tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
|
||||
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
@@ -127,7 +150,7 @@ async def chat(
|
||||
storage_type = 'neo4j'
|
||||
app_type = app.type
|
||||
# check app config
|
||||
_checkAppConfig(app)
|
||||
_checkAppConfig(active_release)
|
||||
|
||||
# 获取或创建会话(提前验证)
|
||||
conversation = conversation_service.create_or_get_conversation(
|
||||
@@ -142,8 +165,13 @@ async def chat(
|
||||
|
||||
# print("="*50)
|
||||
# print(app.current_release.default_model_config_id)
|
||||
agent_config = agent_config_4_app_release(app.current_release)
|
||||
agent_config = agent_config_4_app_release(active_release)
|
||||
# print(agent_config.default_model_config_id)
|
||||
|
||||
# thinking 开关:仅当 agent 配置了 deep_thinking 且请求 thinking=True 时才启用
|
||||
if not (agent_config.model_parameters.get("deep_thinking", False) and payload.thinking):
|
||||
agent_config.model_parameters["deep_thinking"] = False
|
||||
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
@@ -189,7 +217,7 @@ async def chat(
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
# 多 Agent 流式返回
|
||||
config = multi_agent_config_4_app_release(app.current_release)
|
||||
config = multi_agent_config_4_app_release(active_release)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.multi_agent_chat_stream(
|
||||
@@ -232,7 +260,7 @@ async def chat(
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
# 多 Agent 流式返回
|
||||
config = workflow_config_4_app_release(app.current_release)
|
||||
config = workflow_config_4_app_release(active_release)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.workflow_chat_stream(
|
||||
@@ -248,7 +276,7 @@ async def chat(
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id,
|
||||
release_id=active_release.id,
|
||||
public=True
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
@@ -283,7 +311,7 @@ async def chat(
|
||||
files=payload.files,
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id
|
||||
release_id=active_release.id
|
||||
)
|
||||
logger.debug(
|
||||
"工作流试运行返回结果",
|
||||
@@ -297,6 +325,4 @@ async def chat(
|
||||
msg="工作流任务执行成功"
|
||||
)
|
||||
else:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
173
api/app/controllers/service/end_user_api_controller.py
Normal file
173
api/app/controllers/service/end_user_api_controller.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""End User 服务接口 - 基于 API Key 认证"""
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.controllers import user_memory_controllers
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.quota_stub import check_end_user_quota
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.end_user_info_schema import EndUserInfoUpdate
|
||||
from app.schemas.memory_api_schema import CreateEndUserRequest, CreateEndUserResponse
|
||||
from app.services import api_key_service
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
router = APIRouter(prefix="/end_user", tags=["V1 - End User API"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
def _get_current_user(api_key_auth: ApiKeyAuth, db: Session):
|
||||
"""Build a current_user object from API key auth
|
||||
|
||||
Args:
|
||||
api_key_auth: Validated API key auth info
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
User object with current_workspace_id set
|
||||
"""
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
return current_user
|
||||
|
||||
|
||||
@router.post("/create")
|
||||
@require_api_key(scopes=["memory"])
|
||||
@check_end_user_quota
|
||||
async def create_end_user(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
):
|
||||
"""
|
||||
Create or retrieve an end user for the workspace.
|
||||
|
||||
Creates a new end user and connects it to a memory configuration.
|
||||
If an end user with the same other_id already exists in the workspace,
|
||||
returns the existing one.
|
||||
|
||||
Optionally accepts a memory_config_id to connect the end user to a specific
|
||||
memory configuration. If not provided, falls back to the workspace default config.
|
||||
Optionally accepts an app_id to bind the end user to a specific app.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = CreateEndUserRequest(**body)
|
||||
workspace_id = api_key_auth.workspace_id
|
||||
|
||||
logger.info("Create end user request - other_id: %s, workspace_id: %s", payload.other_id, workspace_id)
|
||||
|
||||
# Resolve memory_config_id: explicit > workspace default
|
||||
memory_config_id = None
|
||||
config_service = MemoryConfigService(db)
|
||||
|
||||
if payload.memory_config_id:
|
||||
try:
|
||||
memory_config_id = uuid.UUID(payload.memory_config_id)
|
||||
except ValueError:
|
||||
raise BusinessException(
|
||||
f"Invalid memory_config_id format: {payload.memory_config_id}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
config = config_service.get_config_with_fallback(memory_config_id, workspace_id)
|
||||
if not config:
|
||||
raise BusinessException(
|
||||
f"Memory config not found: {payload.memory_config_id}",
|
||||
BizCode.MEMORY_CONFIG_NOT_FOUND
|
||||
)
|
||||
memory_config_id = config.config_id
|
||||
else:
|
||||
default_config = config_service.get_workspace_default_config(workspace_id)
|
||||
if default_config:
|
||||
memory_config_id = default_config.config_id
|
||||
logger.info(f"Using workspace default memory config: {memory_config_id}")
|
||||
else:
|
||||
logger.warning(f"No default memory config found for workspace: {workspace_id}")
|
||||
|
||||
# Resolve app_id: explicit from payload, otherwise None
|
||||
app_id = None
|
||||
if payload.app_id:
|
||||
try:
|
||||
app_id = uuid.UUID(payload.app_id)
|
||||
except ValueError:
|
||||
raise BusinessException(
|
||||
f"Invalid app_id format: {payload.app_id}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
end_user_repo = EndUserRepository(db)
|
||||
end_user = end_user_repo.get_or_create_end_user_with_config(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
other_id=payload.other_id,
|
||||
memory_config_id=memory_config_id,
|
||||
other_name=payload.other_name,
|
||||
)
|
||||
end_user.other_name = payload.other_name
|
||||
logger.info(f"End user ready: {end_user.id}")
|
||||
|
||||
result = {
|
||||
"id": str(end_user.id),
|
||||
"other_id": end_user.other_id or "",
|
||||
"other_name": end_user.other_name or "",
|
||||
"workspace_id": str(end_user.workspace_id),
|
||||
"memory_config_id": str(end_user.memory_config_id) if end_user.memory_config_id else None,
|
||||
}
|
||||
|
||||
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
|
||||
|
||||
|
||||
@router.get("/info")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def get_end_user_info(
|
||||
request: Request,
|
||||
end_user_id: str,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get end user info.
|
||||
|
||||
Retrieves the info record (aliases, meta_data, etc.) for the specified end user.
|
||||
Delegates to the manager-side controller for shared logic.
|
||||
"""
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
return await user_memory_controllers.get_end_user_info(
|
||||
end_user_id=end_user_id,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/info/update")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def update_end_user_info(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
):
|
||||
"""
|
||||
Update end user info.
|
||||
|
||||
Updates the info record (other_name, aliases, meta_data) for the specified end user.
|
||||
Delegates to the manager-side controller for shared logic.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = EndUserInfoUpdate(**body)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
return await user_memory_controllers.update_end_user_info(
|
||||
info_update=payload,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
)
|
||||
@@ -1,51 +1,83 @@
|
||||
"""Memory 服务接口 - 基于 API Key 认证"""
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Query, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.quota_stub import check_end_user_quota
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.memory_api_schema import (
|
||||
ListConfigsResponse,
|
||||
MemoryReadRequest,
|
||||
MemoryReadResponse,
|
||||
MemoryReadSyncResponse,
|
||||
MemoryWriteRequest,
|
||||
MemoryWriteResponse,
|
||||
MemoryWriteSyncResponse,
|
||||
)
|
||||
from app.services.memory_api_service import MemoryAPIService
|
||||
from fastapi import APIRouter, Body, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
router = APIRouter(prefix="/memory", tags=["V1 - Memory API"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
def _sanitize_task_result(result: dict) -> dict:
|
||||
"""Make Celery task result JSON-serializable.
|
||||
|
||||
Converts UUID and other non-serializable values to strings.
|
||||
|
||||
Args:
|
||||
result: Raw task result dict from task_service
|
||||
|
||||
Returns:
|
||||
JSON-safe dict
|
||||
"""
|
||||
import uuid as _uuid
|
||||
from datetime import datetime
|
||||
|
||||
def _convert(obj):
|
||||
if isinstance(obj, dict):
|
||||
return {k: _convert(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_convert(i) for i in obj]
|
||||
if isinstance(obj, _uuid.UUID):
|
||||
return str(obj)
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
return obj
|
||||
|
||||
return _convert(result)
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_memory_info():
|
||||
"""获取记忆服务信息(占位)"""
|
||||
return success(data={}, msg="Memory API - Coming Soon")
|
||||
|
||||
|
||||
@router.post("/write_api_service")
|
||||
@router.post("/write")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def write_memory_api_service(
|
||||
async def write_memory(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(..., description="Message content"),
|
||||
):
|
||||
"""
|
||||
Write memory to storage.
|
||||
|
||||
Stores memory content for the specified end user using the Memory API Service.
|
||||
Submit a memory write task.
|
||||
|
||||
Validates the end user, then dispatches the write to a Celery background task
|
||||
with per-user fair locking. Returns a task_id for status polling.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = MemoryWriteRequest(**body)
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
|
||||
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
result = await memory_api_service.write_memory(
|
||||
|
||||
result = memory_api_service.write_memory(
|
||||
workspace_id=api_key_auth.workspace_id,
|
||||
end_user_id=payload.end_user_id,
|
||||
message=payload.message,
|
||||
@@ -53,31 +85,53 @@ async def write_memory_api_service(
|
||||
storage_type=payload.storage_type,
|
||||
user_rag_memory_id=payload.user_rag_memory_id,
|
||||
)
|
||||
|
||||
logger.info(f"Memory write successful for end_user: {payload.end_user_id}")
|
||||
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory written successfully")
|
||||
|
||||
logger.info(f"Memory write task submitted: task_id={result['task_id']}, end_user_id: {payload.end_user_id}")
|
||||
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted")
|
||||
|
||||
|
||||
@router.post("/read_api_service")
|
||||
@router.get("/write/status")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_memory_api_service(
|
||||
async def get_write_task_status(
|
||||
request: Request,
|
||||
task_id: str = Query(..., description="Celery task ID"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Check the status of a memory write task.
|
||||
|
||||
Returns the current status and result (if completed) of a previously submitted write task.
|
||||
"""
|
||||
logger.info(f"Write task status check - task_id: {task_id}")
|
||||
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
result = get_task_memory_write_result(task_id)
|
||||
|
||||
return success(data=_sanitize_task_result(result), msg="Task status retrieved")
|
||||
|
||||
|
||||
@router.post("/read")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_memory(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(..., description="Query message"),
|
||||
):
|
||||
"""
|
||||
Read memory from storage.
|
||||
|
||||
Queries and retrieves memories for the specified end user with context-aware responses.
|
||||
Submit a memory read task.
|
||||
|
||||
Validates the end user, then dispatches the read to a Celery background task.
|
||||
Returns a task_id for status polling.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = MemoryReadRequest(**body)
|
||||
logger.info(f"Memory read request - end_user_id: {payload.end_user_id}")
|
||||
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
result = await memory_api_service.read_memory(
|
||||
|
||||
result = memory_api_service.read_memory(
|
||||
workspace_id=api_key_auth.workspace_id,
|
||||
end_user_id=payload.end_user_id,
|
||||
message=payload.message,
|
||||
@@ -86,30 +140,95 @@ async def read_memory_api_service(
|
||||
storage_type=payload.storage_type,
|
||||
user_rag_memory_id=payload.user_rag_memory_id,
|
||||
)
|
||||
|
||||
logger.info(f"Memory read successful for end_user: {payload.end_user_id}")
|
||||
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read successfully")
|
||||
|
||||
logger.info(f"Memory read task submitted: task_id={result['task_id']}, end_user_id: {payload.end_user_id}")
|
||||
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read task submitted")
|
||||
|
||||
|
||||
@router.get("/configs")
|
||||
@router.get("/read/status")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def list_memory_configs(
|
||||
async def get_read_task_status(
|
||||
request: Request,
|
||||
task_id: str = Query(..., description="Celery task ID"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
List all memory configs for the workspace.
|
||||
|
||||
Returns all available memory configurations associated with the authorized workspace.
|
||||
Check the status of a memory read task.
|
||||
|
||||
Returns the current status and result (if completed) of a previously submitted read task.
|
||||
"""
|
||||
logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}")
|
||||
logger.info(f"Read task status check - task_id: {task_id}")
|
||||
|
||||
from app.services.task_service import get_task_memory_read_result
|
||||
result = get_task_memory_read_result(task_id)
|
||||
|
||||
return success(data=_sanitize_task_result(result), msg="Task status retrieved")
|
||||
|
||||
|
||||
@router.post("/write/sync")
|
||||
@require_api_key(scopes=["memory"])
|
||||
@check_end_user_quota
|
||||
async def write_memory_sync(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(..., description="Message content"),
|
||||
):
|
||||
"""
|
||||
Write memory synchronously.
|
||||
|
||||
Blocks until the write completes and returns the result directly.
|
||||
For async processing with task polling, use /write instead.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = MemoryWriteRequest(**body)
|
||||
logger.info(f"Memory write (sync) request - end_user_id: {payload.end_user_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
result = memory_api_service.list_memory_configs(
|
||||
result = await memory_api_service.write_memory_sync(
|
||||
workspace_id=api_key_auth.workspace_id,
|
||||
end_user_id=payload.end_user_id,
|
||||
message=payload.message,
|
||||
config_id=payload.config_id,
|
||||
storage_type=payload.storage_type,
|
||||
user_rag_memory_id=payload.user_rag_memory_id,
|
||||
)
|
||||
|
||||
logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}")
|
||||
return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully")
|
||||
logger.info(f"Memory write (sync) successful for end_user: {payload.end_user_id}")
|
||||
return success(data=MemoryWriteSyncResponse(**result).model_dump(), msg="Memory written successfully")
|
||||
|
||||
|
||||
@router.post("/read/sync")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_memory_sync(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(..., description="Query message"),
|
||||
):
|
||||
"""
|
||||
Read memory synchronously.
|
||||
|
||||
Blocks until the read completes and returns the answer directly.
|
||||
For async processing with task polling, use /read instead.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = MemoryReadRequest(**body)
|
||||
logger.info(f"Memory read (sync) request - end_user_id: {payload.end_user_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
result = await memory_api_service.read_memory_sync(
|
||||
workspace_id=api_key_auth.workspace_id,
|
||||
end_user_id=payload.end_user_id,
|
||||
message=payload.message,
|
||||
search_switch=payload.search_switch,
|
||||
config_id=payload.config_id,
|
||||
storage_type=payload.storage_type,
|
||||
user_rag_memory_id=payload.user_rag_memory_id,
|
||||
)
|
||||
|
||||
logger.info(f"Memory read (sync) successful for end_user: {payload.end_user_id}")
|
||||
return success(data=MemoryReadSyncResponse(**result).model_dump(), msg="Memory read successfully")
|
||||
|
||||
491
api/app/controllers/service/memory_config_api_controller.py
Normal file
491
api/app/controllers/service/memory_config_api_controller.py
Normal file
@@ -0,0 +1,491 @@
|
||||
"""Memory Config 服务接口 - 基于 API Key 认证"""
|
||||
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Header, Query, Request
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.controllers import memory_storage_controller
|
||||
from app.controllers import memory_forget_controller
|
||||
from app.controllers import ontology_controller
|
||||
from app.controllers import emotion_config_controller
|
||||
from app.controllers import memory_reflection_controller
|
||||
from app.schemas.memory_storage_schema import ForgettingConfigUpdateRequest
|
||||
from app.controllers.emotion_config_controller import EmotionConfigUpdate
|
||||
from app.schemas.memory_reflection_schemas import Memory_Reflection
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.memory_api_schema import (
|
||||
ConfigUpdateExtractedRequest,
|
||||
ConfigUpdateRequest,
|
||||
ListConfigsResponse,
|
||||
ConfigCreateRequest,
|
||||
ConfigUpdateForgettingRequest,
|
||||
EmotionConfigUpdateRequest,
|
||||
ReflectionConfigUpdateRequest,
|
||||
)
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigParamsCreate,
|
||||
)
|
||||
from app.services import api_key_service
|
||||
from app.services.memory_api_service import MemoryAPIService
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
router = APIRouter(prefix="/memory_config", tags=["V1 - Memory Config API"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
def _get_current_user(api_key_auth: ApiKeyAuth, db: Session):
|
||||
"""Build a current_user object from API key auth
|
||||
|
||||
Args:
|
||||
api_key_auth: Validated API key auth info
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
User object with current_workspace_id set
|
||||
"""
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
return current_user
|
||||
|
||||
|
||||
def _verify_config_ownership(config_id:str, workspace_id:uuid.UUID, db:Session):
|
||||
"""Verify that the config belongs to the workspace.
|
||||
|
||||
Args:
|
||||
config_id: The ID of the config to verify
|
||||
workspace_id: The workspace ID tocheck against
|
||||
db: Database session for querying
|
||||
Raises:
|
||||
BusinessException: If the config does not exist or does not belong to the workspace
|
||||
"""
|
||||
try:
|
||||
resolved_id = resolve_config_id(config_id, db)
|
||||
except ValueError as e:
|
||||
raise BusinessException(
|
||||
message=f"Invalid config_id: {e}",
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
)
|
||||
config = MemoryConfigRepository.get_by_id(db, resolved_id)
|
||||
if not config or config.workspace_id != workspace_id:
|
||||
raise BusinessException(
|
||||
message="Config not found or access denied",
|
||||
code=BizCode.MEMORY_CONFIG_NOT_FOUND,
|
||||
)
|
||||
|
||||
# @router.get("/configs")
|
||||
# @require_api_key(scopes=["memory"])
|
||||
# async def list_memory_configs(
|
||||
# request: Request,
|
||||
# api_key_auth: ApiKeyAuth = None,
|
||||
# db: Session = Depends(get_db),
|
||||
# ):
|
||||
# """
|
||||
# List all memory configs for the workspace.
|
||||
|
||||
# Returns all available memory configurations associated with the authorized workspace.
|
||||
# """
|
||||
# logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}")
|
||||
|
||||
# memory_api_service = MemoryAPIService(db)
|
||||
|
||||
# result = memory_api_service.list_memory_configs(
|
||||
# workspace_id=api_key_auth.workspace_id,
|
||||
# )
|
||||
|
||||
# logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}")
|
||||
# return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully")
|
||||
|
||||
@router.get("/read_all_config")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_all_config(
|
||||
request:Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
List all memory configs with full details (enhanced version).
|
||||
|
||||
Returns complete config fields for the authorized workspace.
|
||||
No config_id ownership check needed — results are filtered by workspace.
|
||||
"""
|
||||
logger.info(f"V1 get all configs (full) - workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
|
||||
return memory_storage_controller.read_all_config(
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
)
|
||||
|
||||
@router.get("/scenes/simple")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def get_ontology_scenes(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get available ontology scenes for the workspace.
|
||||
|
||||
Returns a simple list of scene_id and scene_name for dropdown selection.
|
||||
Used before creating a memory config to choose which ontology scene to associate.
|
||||
"""
|
||||
logger.info(f"V1 get scenes - workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
|
||||
return await ontology_controller.get_scenes_simple(
|
||||
db=db,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
@router.get("/read_config_extracted")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_config_extracted(
|
||||
request: Request,
|
||||
config_id: str = Query(..., description="config_id"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get extraction engine config details for a specific config.
|
||||
|
||||
Only configs belonging to the authorized workspace can be queried.
|
||||
"""
|
||||
logger.info(f"V1 read extracted config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
|
||||
return memory_storage_controller.read_config_extracted(
|
||||
config_id = config_id,
|
||||
current_user = current_user,
|
||||
db = db,
|
||||
)
|
||||
|
||||
@router.get("/read_config_forgetting")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_config_forgetting(
|
||||
request: Request,
|
||||
config_id: str = Query(..., description="config_id"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get forgetting settings for a specific memory config.
|
||||
|
||||
Only configs belonging to the authorized workspace can be queried.
|
||||
"""
|
||||
logger.info(f"V1 read forgetting config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
|
||||
result = await memory_forget_controller.read_forgetting_config(
|
||||
config_id = config_id,
|
||||
current_user = current_user,
|
||||
db = db,
|
||||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
|
||||
|
||||
@router.get("/read_config_emotion")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_config_emotion(
|
||||
request: Request,
|
||||
config_id: str = Query(..., description="config_id"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get emotion engine config details for a specific config.
|
||||
|
||||
Only configs belonging to the authorized workspace can be queried.
|
||||
"""
|
||||
logger.info(f"V1 read emotion config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
|
||||
return jsonable_encoder(emotion_config_controller.get_emotion_config(
|
||||
config_id=config_id,
|
||||
db=db,
|
||||
current_user=current_user,
|
||||
))
|
||||
|
||||
@router.get("/read_config_reflection")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_config_reflection(
|
||||
request: Request,
|
||||
config_id: str = Query(..., description="config_id"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get reflection engine config details for a specific config.
|
||||
|
||||
Only configs belonging to the authorized workspace can be queried.
|
||||
"""
|
||||
logger.info(f"V1 read reflection config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
|
||||
return jsonable_encoder(await memory_reflection_controller.start_reflection_configs(
|
||||
config_id=config_id,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
))
|
||||
|
||||
|
||||
@router.post("/create_config")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def create_memory_config(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
||||
):
|
||||
"""
|
||||
Create a new memory config for the workspace.
|
||||
|
||||
The config will be associated with the workspace of the API Key.
|
||||
config_name is required, other fields are optional.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = ConfigCreateRequest(**body)
|
||||
|
||||
logger.info(f"V1 create config - workspace: {api_key_auth.workspace_id}, config_name: {payload.config_name}")
|
||||
|
||||
# 构造管理端 Schema,workspace_id 从 API Key 注入
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
mgmt_payload = ConfigParamsCreate(
|
||||
config_name=payload.config_name,
|
||||
config_desc=payload.config_desc or "",
|
||||
scene_id=payload.scene_id,
|
||||
llm_id=payload.llm_id,
|
||||
embedding_id=payload.embedding_id,
|
||||
rerank_id=payload.rerank_id,
|
||||
reflection_model_id=payload.reflection_model_id,
|
||||
emotion_model_id=payload.emotion_model_id,
|
||||
)
|
||||
#将返回数据中UUID序列化处理
|
||||
result =memory_storage_controller.create_config(
|
||||
payload=mgmt_payload,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
x_language_type=x_language_type,
|
||||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
@router.put("/update_config")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def update_memory_config(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
):
|
||||
"""
|
||||
Update memory config basic info (name, description, scene).
|
||||
|
||||
Requires API Key with 'memory' scope
|
||||
Only configs belonging to the authorized workspace can be updated.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = ConfigUpdateRequest(**body)
|
||||
|
||||
logger.info(f"V1 update config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
mgmt_payload = ConfigUpdate(
|
||||
config_id = payload.config_id,
|
||||
config_name = payload.config_name,
|
||||
config_desc = payload.config_desc,
|
||||
scene_id = payload.scene_id,
|
||||
)
|
||||
|
||||
return memory_storage_controller.update_config(
|
||||
payload = mgmt_payload,
|
||||
current_user = current_user,
|
||||
db = db,
|
||||
)
|
||||
|
||||
@router.put("/update_config_extracted")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def update_memory_config_extracted(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
):
|
||||
"""
|
||||
update memory config extraction engine config (models, thresholds, chunking, pruning, etc.).
|
||||
|
||||
Requires API Key with 'memory' scope.
|
||||
Only configs belonging to the authorized workspace can be updated.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = ConfigUpdateExtractedRequest(**body)
|
||||
|
||||
logger.info(f"V1 update extracted config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
#校验权限
|
||||
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
update_fields = payload.model_dump(exclude_unset=True)
|
||||
mgmt_payload = ConfigUpdateExtracted(**update_fields)
|
||||
|
||||
return memory_storage_controller.update_config_extracted(
|
||||
payload = mgmt_payload,
|
||||
current_user = current_user,
|
||||
db = db,
|
||||
)
|
||||
|
||||
@router.put("/update_config_forgetting")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def update_memory_config_forgetting(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
):
|
||||
"""
|
||||
update memory config forgetting settings (forgetting strategy, parameters, etc.).
|
||||
|
||||
Requires API Key with 'memory' scope.
|
||||
Only configs belonging to the authorized workspace can be updated.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = ConfigUpdateForgettingRequest(**body)
|
||||
|
||||
logger.info(f"V1 update forgetting config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
#校验权限
|
||||
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
update_fields = payload.model_dump(exclude_unset=True)
|
||||
mgmt_payload = ForgettingConfigUpdateRequest(**update_fields)
|
||||
|
||||
#将返回数据中UUID序列化处理
|
||||
result = await memory_forget_controller.update_forgetting_config(
|
||||
payload = mgmt_payload,
|
||||
current_user = current_user,
|
||||
db = db,
|
||||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
@router.put("/update_config_emotion")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def update_config_emotion(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
):
|
||||
"""
|
||||
Update emotion engine config (full update).
|
||||
|
||||
All fields except emotion_model_id are required.
|
||||
Only configs belonging to the authorized workspace can be updated.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = EmotionConfigUpdateRequest(**body)
|
||||
|
||||
logger.info(f"V1 update emotion config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
update_fields = payload.model_dump(exclude_unset=True)
|
||||
mgmt_payload = EmotionConfigUpdate(**update_fields)
|
||||
return jsonable_encoder(emotion_config_controller.update_emotion_config(
|
||||
config=mgmt_payload,
|
||||
db=db,
|
||||
current_user=current_user,
|
||||
))
|
||||
|
||||
@router.put("/update_config_reflection")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def update_config_reflection(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
):
|
||||
"""
|
||||
Update reflection engine config (full update).
|
||||
|
||||
All fields are required.
|
||||
Only configs belonging to the authorized workspace can be updated.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = ReflectionConfigUpdateRequest(**body)
|
||||
|
||||
logger.info(f"V1 update reflection config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
update_fields = payload.model_dump(exclude_unset=True)
|
||||
mgmt_payload = Memory_Reflection(**update_fields)
|
||||
|
||||
return jsonable_encoder(await memory_reflection_controller.save_reflection_config(
|
||||
request=mgmt_payload,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
))
|
||||
|
||||
@router.delete("/delete_config")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def delete_memory_config(
|
||||
config_id: str,
|
||||
request: Request,
|
||||
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Delete a memory config.
|
||||
|
||||
- Default configs cannot be deleted.
|
||||
- If end users are connected and force=False, returns a warning.
|
||||
- If force=True, clears end user references and deletes the config.
|
||||
|
||||
Only configs belonging to the authorized workspace can be deleted.
|
||||
"""
|
||||
logger.info(f"V1 delete config - config_id: {config_id}, force: {force}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
|
||||
return memory_storage_controller.delete_config(
|
||||
config_id=config_id,
|
||||
force=force,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
)
|
||||
@@ -11,11 +11,13 @@ from app.schemas import skill_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services.skill_service import SkillService
|
||||
from app.core.response_utils import success
|
||||
from app.core.quota_stub import check_skill_quota
|
||||
|
||||
router = APIRouter(prefix="/skills", tags=["Skills"])
|
||||
|
||||
|
||||
@router.post("", summary="创建技能")
|
||||
@check_skill_quota
|
||||
def create_skill(
|
||||
data: skill_schema.SkillCreate,
|
||||
db: Session = Depends(get_db),
|
||||
|
||||
173
api/app/controllers/tenant_subscription_controller.py
Normal file
173
api/app/controllers/tenant_subscription_controller.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""
|
||||
租户套餐查询接口(普通用户可访问)
|
||||
"""
|
||||
import datetime
|
||||
from typing import Callable, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.i18n.dependencies import get_translator
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
|
||||
logger = get_api_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenant", tags=["Tenant"])
|
||||
public_router = APIRouter(tags=["Tenant"])
|
||||
|
||||
|
||||
@router.get("/subscription", response_model=ApiResponse, summary="获取当前用户所属租户的套餐信息")
|
||||
async def get_my_tenant_subscription(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
t: Callable = Depends(get_translator),
|
||||
):
|
||||
"""
|
||||
获取当前登录用户所属租户的有效套餐订阅信息。
|
||||
包含套餐名称、版本、配额、到期时间等。
|
||||
"""
|
||||
try:
|
||||
from premium.platform_admin.package_plan_service import TenantSubscriptionService
|
||||
|
||||
if not current_user.tenant:
|
||||
return JSONResponse(status_code=404, content=fail(code=404, msg="用户未关联租户"))
|
||||
|
||||
tenant_id = current_user.tenant.id
|
||||
svc = TenantSubscriptionService(db)
|
||||
sub = svc.get_subscription(tenant_id)
|
||||
|
||||
if not sub:
|
||||
# 无订阅记录时,兜底返回免费套餐信息
|
||||
free_plan = svc.plan_repo.get_free_plan()
|
||||
if not free_plan:
|
||||
return success(data=None, msg="暂无有效套餐")
|
||||
return success(data={
|
||||
"subscription_id": None,
|
||||
"tenant_id": str(tenant_id),
|
||||
"package_plan_id": str(free_plan.id),
|
||||
"package_version": free_plan.version,
|
||||
"package_plan": {
|
||||
"id": str(free_plan.id),
|
||||
"name": free_plan.name,
|
||||
"name_en": free_plan.name_en,
|
||||
"version": free_plan.version,
|
||||
"category": free_plan.category,
|
||||
"tier_level": free_plan.tier_level,
|
||||
"price": float(free_plan.price) if free_plan.price is not None else 0.0,
|
||||
"billing_cycle": free_plan.billing_cycle,
|
||||
"core_value": free_plan.core_value,
|
||||
"core_value_en": free_plan.core_value_en,
|
||||
"tech_support": free_plan.tech_support,
|
||||
"tech_support_en": free_plan.tech_support_en,
|
||||
"sla_compliance": free_plan.sla_compliance,
|
||||
"sla_compliance_en": free_plan.sla_compliance_en,
|
||||
"page_customization": free_plan.page_customization,
|
||||
"page_customization_en": free_plan.page_customization_en,
|
||||
"theme_color": free_plan.theme_color,
|
||||
},
|
||||
"started_at": None,
|
||||
"expired_at": None,
|
||||
"status": "active",
|
||||
"quotas": free_plan.quotas or {},
|
||||
"created_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
||||
"updated_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
||||
}, msg="免费套餐")
|
||||
|
||||
return success(data=svc.build_response(sub))
|
||||
|
||||
except ModuleNotFoundError:
|
||||
# 社区版无 premium 模块,从配置文件读取免费套餐
|
||||
if not current_user.tenant:
|
||||
return JSONResponse(status_code=404, content=fail(code=404, msg="用户未关联租户"))
|
||||
|
||||
from app.config.default_free_plan import DEFAULT_FREE_PLAN
|
||||
|
||||
plan = DEFAULT_FREE_PLAN
|
||||
response_data = {
|
||||
"subscription_id": None,
|
||||
"tenant_id": str(current_user.tenant.id),
|
||||
"package_plan_id": None,
|
||||
"package_version": plan["version"],
|
||||
"package_plan": {
|
||||
"id": None,
|
||||
"name": plan["name"],
|
||||
"name_en": plan.get("name_en"),
|
||||
"version": plan["version"],
|
||||
"category": plan["category"],
|
||||
"tier_level": plan["tier_level"],
|
||||
"price": float(plan["price"]),
|
||||
"billing_cycle": plan["billing_cycle"],
|
||||
"core_value": plan.get("core_value"),
|
||||
"core_value_en": plan.get("core_value_en"),
|
||||
"tech_support": plan.get("tech_support"),
|
||||
"tech_support_en": plan.get("tech_support_en"),
|
||||
"sla_compliance": plan.get("sla_compliance"),
|
||||
"sla_compliance_en": plan.get("sla_compliance_en"),
|
||||
"page_customization": plan.get("page_customization"),
|
||||
"page_customization_en": plan.get("page_customization_en"),
|
||||
"theme_color": plan.get("theme_color"),
|
||||
},
|
||||
"started_at": None,
|
||||
"expired_at": None,
|
||||
"status": "active",
|
||||
"quotas": plan["quotas"],
|
||||
"created_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
||||
"updated_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
||||
}
|
||||
return success(data=response_data, msg="社区版免费套餐")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取租户套餐信息失败: {e}", exc_info=True)
|
||||
return JSONResponse(status_code=500, content=fail(code=500, msg="获取套餐信息失败"))
|
||||
|
||||
|
||||
@public_router.get("/package-plans", response_model=ApiResponse, summary="获取套餐列表(公开)")
|
||||
async def list_package_plans_public(
|
||||
category: Optional[str] = None,
|
||||
status: Optional[bool] = None,
|
||||
search: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
公开接口,无需鉴权。
|
||||
SaaS 版从数据库读取套餐列表;社区版降级返回 default_free_plan.py 中的免费套餐。
|
||||
"""
|
||||
try:
|
||||
from premium.platform_admin.package_plan_service import PackagePlanService
|
||||
from premium.platform_admin.package_plan_schema import PackagePlanResponse
|
||||
svc = PackagePlanService(db)
|
||||
result = svc.get_list(page=1, size=9999, category=category, status=status, search=search)
|
||||
return success(data=[PackagePlanResponse.model_validate(p).model_dump(mode="json") for p in result["items"]])
|
||||
except ModuleNotFoundError:
|
||||
from app.config.default_free_plan import DEFAULT_FREE_PLAN
|
||||
plan = DEFAULT_FREE_PLAN
|
||||
return success(data=[{
|
||||
"id": None,
|
||||
"name": plan["name"],
|
||||
"name_en": plan.get("name_en"),
|
||||
"version": plan["version"],
|
||||
"category": plan["category"],
|
||||
"tier_level": plan["tier_level"],
|
||||
"price": float(plan["price"]),
|
||||
"billing_cycle": plan["billing_cycle"],
|
||||
"core_value": plan.get("core_value"),
|
||||
"core_value_en": plan.get("core_value_en"),
|
||||
"tech_support": plan.get("tech_support"),
|
||||
"tech_support_en": plan.get("tech_support_en"),
|
||||
"sla_compliance": plan.get("sla_compliance"),
|
||||
"sla_compliance_en": plan.get("sla_compliance_en"),
|
||||
"page_customization": plan.get("page_customization"),
|
||||
"page_customization_en": plan.get("page_customization_en"),
|
||||
"theme_color": plan.get("theme_color"),
|
||||
"status": plan.get("status", True),
|
||||
"quotas": plan["quotas"],
|
||||
}])
|
||||
except Exception as e:
|
||||
logger.error(f"获取套餐列表失败: {e}", exc_info=True)
|
||||
return JSONResponse(status_code=500, content=fail(code=500, msg="获取套餐列表失败"))
|
||||
@@ -114,11 +114,14 @@ def get_current_user_info(
|
||||
|
||||
# 设置权限:如果用户来自 SSO Source,则使用该 Source 的 permissions;否则返回 "all" 表示拥有所有权限
|
||||
if current_user.external_source:
|
||||
from premium.sso.models import SSOSource
|
||||
source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first()
|
||||
if source and source.permissions:
|
||||
result_schema.permissions = source.permissions
|
||||
else:
|
||||
try:
|
||||
from premium.sso.models import SSOSource
|
||||
source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first()
|
||||
if source and source.permissions:
|
||||
result_schema.permissions = source.permissions
|
||||
else:
|
||||
result_schema.permissions = []
|
||||
except ModuleNotFoundError:
|
||||
result_schema.permissions = []
|
||||
else:
|
||||
result_schema.permissions = ["all"]
|
||||
|
||||
@@ -35,6 +35,7 @@ from app.schemas.workspace_schema import (
|
||||
WorkspaceUpdate,
|
||||
)
|
||||
from app.services import workspace_service
|
||||
from app.core.quota_stub import check_workspace_quota
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -106,6 +107,7 @@ def get_workspaces(
|
||||
|
||||
|
||||
@router.post("", response_model=ApiResponse)
|
||||
@check_workspace_quota
|
||||
def create_workspace(
|
||||
workspace: WorkspaceCreate,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
|
||||
@@ -11,17 +11,14 @@ LangChain Agent 封装
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.db import get_db
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.errors import GraphRecursionError
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType, ModelProvider
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
from app.models.models_model import ModelType
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -41,7 +38,11 @@ class LangChainAgent:
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
streaming: bool = False,
|
||||
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
|
||||
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
|
||||
max_tool_consecutive_calls: int = 3, # 单个工具最大连续调用次数
|
||||
deep_thinking: bool = False, # 是否启用深度思考模式
|
||||
thinking_budget_tokens: Optional[int] = None, # 深度思考 token 预算
|
||||
json_output: bool = False, # 是否强制 JSON 输出
|
||||
capability: Optional[List[str]] = None # 模型能力列表,用于校验是否支持深度思考
|
||||
):
|
||||
"""初始化 LangChain Agent
|
||||
|
||||
@@ -79,6 +80,17 @@ class LangChainAgent:
|
||||
|
||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||
|
||||
# ChatTongyi 要求 messages 含 'json' 字样才能使用 response_format
|
||||
# 在 system prompt 中注入 JSON 要求
|
||||
from app.models.models_model import ModelProvider
|
||||
if json_output and (
|
||||
(provider.lower() == ModelProvider.DASHSCOPE and not is_omni)
|
||||
or provider.lower() == ModelProvider.VOLCANO
|
||||
# 有工具时 response_format 会被移除,所有 provider 都需要 system prompt 注入保证 JSON 输出
|
||||
or bool(tools)
|
||||
):
|
||||
self.system_prompt += "\n请以JSON格式输出。"
|
||||
|
||||
logger.debug(
|
||||
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
|
||||
f"tool_count={len(self.tools)}, "
|
||||
@@ -86,21 +98,28 @@ class LangChainAgent:
|
||||
f"auto_calculated={max_iterations is None}"
|
||||
)
|
||||
|
||||
# 创建 RedBearLLM(支持多提供商)
|
||||
# 创建 RedBearLLM,capability 校验由 RedBearModelConfig 统一处理
|
||||
model_config = RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
is_omni=is_omni,
|
||||
capability=capability,
|
||||
deep_thinking=deep_thinking,
|
||||
thinking_budget_tokens=thinking_budget_tokens,
|
||||
json_output=json_output,
|
||||
extra_params={
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"streaming": streaming # 使用参数控制流式
|
||||
"streaming": streaming
|
||||
}
|
||||
)
|
||||
|
||||
self.llm = RedBearLLM(model_config, type=ModelType.CHAT)
|
||||
# 从经过校验的 config 读取实际生效的能力开关
|
||||
self.deep_thinking = model_config.deep_thinking
|
||||
self.json_output = model_config.json_output
|
||||
|
||||
# 获取底层模型用于真正的流式调用
|
||||
self._underlying_llm = self.llm._model if hasattr(self.llm, '_model') else self.llm
|
||||
@@ -226,10 +245,7 @@ class LangChainAgent:
|
||||
Returns:
|
||||
List[BaseMessage]: 消息列表
|
||||
"""
|
||||
messages = []
|
||||
|
||||
# 添加系统提示词
|
||||
messages.append(SystemMessage(content=self.system_prompt))
|
||||
messages: list = []
|
||||
|
||||
# 添加历史消息
|
||||
if history:
|
||||
@@ -254,6 +270,33 @@ class LangChainAgent:
|
||||
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def _extract_tokens_from_message(msg) -> int:
|
||||
"""从 AIMessage 或类似对象中提取 total_tokens,兼容多种 provider 格式
|
||||
|
||||
支持的格式:
|
||||
- response_metadata.token_usage.total_tokens (OpenAI/ChatOpenAI)
|
||||
- response_metadata.usage.total_tokens (部分 provider)
|
||||
- usage_metadata.total_tokens (LangChain 新版)
|
||||
"""
|
||||
total = 0
|
||||
# 1. response_metadata
|
||||
response_meta = getattr(msg, "response_metadata", None)
|
||||
if response_meta and isinstance(response_meta, dict):
|
||||
# 尝试 token_usage 路径
|
||||
token_usage = response_meta.get("token_usage") or response_meta.get("usage", {})
|
||||
if isinstance(token_usage, dict):
|
||||
total = token_usage.get("total_tokens", 0)
|
||||
# 2. usage_metadata(LangChain 新版 AIMessage 属性)
|
||||
if not total:
|
||||
usage_meta = getattr(msg, "usage_metadata", None)
|
||||
if usage_meta:
|
||||
if isinstance(usage_meta, dict):
|
||||
total = usage_meta.get("total_tokens", 0)
|
||||
else:
|
||||
total = getattr(usage_meta, "total_tokens", 0)
|
||||
return total or 0
|
||||
|
||||
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
构建多模态消息内容
|
||||
@@ -288,17 +331,23 @@ class LangChainAgent:
|
||||
|
||||
return content_parts
|
||||
|
||||
@staticmethod
|
||||
def _extract_reasoning_content(msg) -> str:
|
||||
"""从 AIMessage 中提取深度思考内容(reasoning_content)
|
||||
|
||||
所有 provider 统一通过 additional_kwargs.reasoning_content 传递:
|
||||
- DeepSeek-R1 / QwQ: 原生字段
|
||||
- Volcano (Doubao-thinking): 由 VolcanoChatOpenAI 从 delta.reasoning_content 注入
|
||||
"""
|
||||
additional = getattr(msg, "additional_kwargs", None) or {}
|
||||
return additional.get("reasoning_content") or additional.get("reasoning", "")
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
config_id: Optional[str] = None, # 添加这个参数
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""执行对话
|
||||
|
||||
@@ -306,31 +355,12 @@ class LangChainAgent:
|
||||
message: 用户消息
|
||||
history: 历史消息列表 [{"role": "user/assistant", "content": "..."}]
|
||||
context: 上下文信息(如知识库检索结果)
|
||||
files: 多模态文件
|
||||
|
||||
Returns:
|
||||
Dict: 包含 content 和元数据的字典
|
||||
"""
|
||||
message_chat = message
|
||||
start_time = time.time()
|
||||
actual_config_id = config_id
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
if actual_config_id is None and end_user_id:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
try:
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
@@ -354,7 +384,7 @@ class LangChainAgent:
|
||||
{"messages": messages},
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
)
|
||||
except RecursionError as e:
|
||||
except (RecursionError, GraphRecursionError) as e:
|
||||
logger.warning(
|
||||
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),可能存在工具调用循环",
|
||||
extra={"error": str(e)}
|
||||
@@ -377,6 +407,7 @@ class LangChainAgent:
|
||||
|
||||
logger.debug(f"输出消息数量: {len(output_messages)}")
|
||||
total_tokens = 0
|
||||
reasoning_content = ""
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
logger.debug(f"找到 AI 消息,content 类型: {type(msg.content)}")
|
||||
@@ -411,16 +442,13 @@ class LangChainAgent:
|
||||
else:
|
||||
content = str(msg.content)
|
||||
logger.debug(f"转换为字符串: {content[:100]}...")
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
||||
total_tokens = self._extract_tokens_from_message(msg)
|
||||
reasoning_content = self._extract_reasoning_content(msg) if self.deep_thinking else ""
|
||||
break
|
||||
|
||||
logger.info(f"最终提取的内容长度: {len(content)}")
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
if memory_flag:
|
||||
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
|
||||
actual_config_id)
|
||||
response = {
|
||||
"content": content,
|
||||
"model": self.model_name,
|
||||
@@ -431,6 +459,8 @@ class LangChainAgent:
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
}
|
||||
if reasoning_content:
|
||||
response["reasoning_content"] = reasoning_content
|
||||
|
||||
logger.debug(
|
||||
"Agent 调用完成",
|
||||
@@ -451,22 +481,20 @@ class LangChainAgent:
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
config_id: Optional[str] = None,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
) -> AsyncGenerator[str, None]:
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
) -> AsyncGenerator[str | int | dict[str, str], None]:
|
||||
"""执行流式对话
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
history: 历史消息列表
|
||||
context: 上下文信息
|
||||
files: 多模态文件
|
||||
|
||||
Yields:
|
||||
str: 消息内容块
|
||||
int: token 统计
|
||||
Dict: 深度思考内容 {"type": "reasoning", "content": "..."}
|
||||
"""
|
||||
logger.info("=" * 80)
|
||||
logger.info(" chat_stream 方法开始执行")
|
||||
@@ -474,23 +502,6 @@ class LangChainAgent:
|
||||
logger.info(f" Has tools: {bool(self.tools)}")
|
||||
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
|
||||
logger.info("=" * 80)
|
||||
message_chat = message
|
||||
actual_config_id = config_id
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
if actual_config_id is None and end_user_id:
|
||||
try:
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
|
||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
@@ -500,17 +511,19 @@ class LangChainAgent:
|
||||
)
|
||||
|
||||
chunk_count = 0
|
||||
yielded_content = False
|
||||
|
||||
# 统一使用 agent 的 astream_events 实现流式输出
|
||||
logger.debug("使用 Agent astream_events 实现流式输出")
|
||||
full_content = ''
|
||||
full_reasoning = ''
|
||||
try:
|
||||
last_event = {}
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
version="v2",
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
):
|
||||
last_event = event
|
||||
chunk_count += 1
|
||||
kind = event.get("event")
|
||||
|
||||
@@ -519,12 +532,18 @@ class LangChainAgent:
|
||||
# LLM 流式输出
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
if chunk and hasattr(chunk, "content"):
|
||||
# 提取深度思考内容(仅在启用深度思考时)
|
||||
if self.deep_thinking:
|
||||
reasoning_chunk = self._extract_reasoning_content(chunk)
|
||||
if reasoning_chunk:
|
||||
full_reasoning += reasoning_chunk
|
||||
yield {"type": "reasoning", "content": reasoning_chunk}
|
||||
|
||||
# 处理多模态响应:content 可能是字符串或列表
|
||||
chunk_content = chunk.content
|
||||
if isinstance(chunk_content, str) and chunk_content:
|
||||
full_content += chunk_content
|
||||
yield chunk_content
|
||||
yielded_content = True
|
||||
elif isinstance(chunk_content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
for item in chunk_content:
|
||||
@@ -535,29 +554,32 @@ class LangChainAgent:
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
elif isinstance(item, str):
|
||||
full_content += item
|
||||
yield item
|
||||
yielded_content = True
|
||||
|
||||
elif kind == "on_llm_stream":
|
||||
# 另一种 LLM 流式事件
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
if chunk:
|
||||
if hasattr(chunk, "content"):
|
||||
# 提取深度思考内容(仅在启用深度思考时)
|
||||
if self.deep_thinking:
|
||||
reasoning_chunk = self._extract_reasoning_content(chunk)
|
||||
if reasoning_chunk:
|
||||
full_reasoning += reasoning_chunk
|
||||
yield {"type": "reasoning", "content": reasoning_chunk}
|
||||
|
||||
chunk_content = chunk.content
|
||||
if isinstance(chunk_content, str) and chunk_content:
|
||||
full_content += chunk_content
|
||||
yield chunk_content
|
||||
yielded_content = True
|
||||
elif isinstance(chunk_content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
for item in chunk_content:
|
||||
@@ -568,22 +590,18 @@ class LangChainAgent:
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
elif isinstance(item, str):
|
||||
full_content += item
|
||||
yield item
|
||||
yielded_content = True
|
||||
elif isinstance(chunk, str):
|
||||
full_content += chunk
|
||||
yield chunk
|
||||
yielded_content = True
|
||||
|
||||
# 记录工具调用(可选)
|
||||
elif kind == "on_tool_start":
|
||||
@@ -593,19 +611,20 @@ class LangChainAgent:
|
||||
|
||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||
# 统计token消耗
|
||||
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
||||
output_messages = last_event.get("data", {}).get("output", {}).get("messages", [])
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get(
|
||||
"total_tokens",
|
||||
0
|
||||
) if response_meta else 0
|
||||
yield total_tokens
|
||||
stream_total_tokens = self._extract_tokens_from_message(msg)
|
||||
logger.info(f"流式 token 统计: total_tokens={stream_total_tokens}")
|
||||
yield stream_total_tokens
|
||||
break
|
||||
if memory_flag:
|
||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
|
||||
actual_config_id)
|
||||
|
||||
except GraphRecursionError:
|
||||
logger.warning(
|
||||
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),模型可能不支持正确的工具调用停止判断"
|
||||
)
|
||||
if not full_content:
|
||||
yield "抱歉,我在处理您的请求时遇到了问题(已达最大处理步骤限制)。请尝试简化问题或更换模型后重试。"
|
||||
except Exception as e:
|
||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
@@ -97,7 +97,7 @@ def require_api_key(
|
||||
)
|
||||
|
||||
rate_limiter = RateLimiterService()
|
||||
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj)
|
||||
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj, db=db)
|
||||
if not is_allowed:
|
||||
logger.warning("API Key 限流触发", extra={
|
||||
"api_key_id": str(api_key_obj.id),
|
||||
@@ -106,10 +106,12 @@ def require_api_key(
|
||||
"error_msg": error_msg
|
||||
})
|
||||
# 根据错误消息判断限流类型
|
||||
if "QPS" in error_msg:
|
||||
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
|
||||
elif "Daily" in error_msg:
|
||||
if "Daily" in error_msg:
|
||||
code = BizCode.API_KEY_DAILY_LIMIT_EXCEEDED
|
||||
elif "Tenant" in error_msg:
|
||||
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED # 租户套餐速率超限,同属 QPS 类
|
||||
elif "QPS" in error_msg:
|
||||
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
|
||||
else:
|
||||
code = BizCode.API_KEY_QUOTA_EXCEEDED
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ class BizCode(IntEnum):
|
||||
TENANT_NOT_FOUND = 3002
|
||||
WORKSPACE_NO_ACCESS = 3003
|
||||
WORKSPACE_INVITE_NOT_FOUND = 3004
|
||||
WORKSPACE_ACCESS_DENIED = 3005
|
||||
# API Key 管理(3xxx)
|
||||
API_KEY_NOT_FOUND = 3007
|
||||
API_KEY_DUPLICATE_NAME = 3008
|
||||
@@ -30,6 +31,9 @@ class BizCode(IntEnum):
|
||||
API_KEY_QPS_LIMIT_EXCEEDED = 3014
|
||||
API_KEY_DAILY_LIMIT_EXCEEDED = 3015
|
||||
API_KEY_QUOTA_EXCEEDED = 3016
|
||||
API_KEY_RATE_LIMIT_EXCEEDED = 3017
|
||||
QUOTA_EXCEEDED = 3018
|
||||
RATE_LIMIT_EXCEEDED = 3019
|
||||
# 资源(4xxx)
|
||||
NOT_FOUND = 4000
|
||||
USER_NOT_FOUND = 4001
|
||||
@@ -40,6 +44,7 @@ class BizCode(IntEnum):
|
||||
FILE_NOT_FOUND = 4006
|
||||
APP_NOT_FOUND = 4007
|
||||
RELEASE_NOT_FOUND = 4008
|
||||
USER_NO_ACCESS = 4009
|
||||
|
||||
# 冲突/状态(5xxx)
|
||||
DUPLICATE_NAME = 5001
|
||||
@@ -113,8 +118,11 @@ HTTP_MAPPING = {
|
||||
BizCode.FORBIDDEN: 403,
|
||||
BizCode.TENANT_NOT_FOUND: 400,
|
||||
BizCode.WORKSPACE_NO_ACCESS: 403,
|
||||
BizCode.WORKSPACE_INVITE_NOT_FOUND: 400,
|
||||
BizCode.WORKSPACE_ACCESS_DENIED: 403,
|
||||
BizCode.NOT_FOUND: 400,
|
||||
BizCode.USER_NOT_FOUND: 200,
|
||||
BizCode.USER_NO_ACCESS: 401,
|
||||
BizCode.WORKSPACE_NOT_FOUND: 400,
|
||||
BizCode.MODEL_NOT_FOUND: 400,
|
||||
BizCode.KNOWLEDGE_NOT_FOUND: 400,
|
||||
@@ -150,7 +158,8 @@ HTTP_MAPPING = {
|
||||
BizCode.API_KEY_QPS_LIMIT_EXCEEDED: 429,
|
||||
BizCode.API_KEY_DAILY_LIMIT_EXCEEDED: 429,
|
||||
BizCode.API_KEY_QUOTA_EXCEEDED: 429,
|
||||
|
||||
BizCode.QUOTA_EXCEEDED: 402,
|
||||
|
||||
BizCode.MODEL_CONFIG_INVALID: 400,
|
||||
BizCode.API_KEY_MISSING: 400,
|
||||
BizCode.PROVIDER_NOT_SUPPORTED: 400,
|
||||
@@ -179,4 +188,21 @@ HTTP_MAPPING = {
|
||||
BizCode.DB_ERROR: 500,
|
||||
BizCode.SERVICE_UNAVAILABLE: 503,
|
||||
BizCode.RATE_LIMITED: 429,
|
||||
BizCode.RATE_LIMIT_EXCEEDED: 429,
|
||||
}
|
||||
|
||||
ERROR_CODE_TO_BIZ_CODE = {
|
||||
"QUOTA_EXCEEDED": BizCode.QUOTA_EXCEEDED,
|
||||
"RATE_LIMIT_EXCEEDED": BizCode.RATE_LIMIT_EXCEEDED,
|
||||
"API_KEY_NOT_FOUND": BizCode.API_KEY_NOT_FOUND,
|
||||
"API_KEY_INVALID": BizCode.API_KEY_INVALID,
|
||||
"API_KEY_EXPIRED": BizCode.API_KEY_EXPIRED,
|
||||
"WORKSPACE_NOT_FOUND": BizCode.WORKSPACE_NOT_FOUND,
|
||||
"WORKSPACE_NO_ACCESS": BizCode.WORKSPACE_NO_ACCESS,
|
||||
"PERMISSION_DENIED": BizCode.PERMISSION_DENIED,
|
||||
"TOKEN_EXPIRED": BizCode.TOKEN_EXPIRED,
|
||||
"TOKEN_INVALID": BizCode.TOKEN_INVALID,
|
||||
"VALIDATION_FAILED": BizCode.VALIDATION_FAILED,
|
||||
"INVALID_PARAMETER": BizCode.INVALID_PARAMETER,
|
||||
"MISSING_PARAMETER": BizCode.MISSING_PARAMETER,
|
||||
}
|
||||
|
||||
@@ -0,0 +1,408 @@
|
||||
"""
|
||||
Perceptual Memory Retrieval Node & Service
|
||||
|
||||
Provides PerceptualSearchService for searching perceptual memories (vision, audio,
|
||||
text, conversation) from Neo4j using keyword fulltext + embedding semantic search
|
||||
with BM25+embedding fusion reranking.
|
||||
|
||||
Also provides the perceptual_retrieve_node for use as a LangGraph node.
|
||||
"""
|
||||
import asyncio
|
||||
import math
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState
|
||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||
from app.repositories.neo4j.graph_search import (
|
||||
search_perceptual,
|
||||
search_perceptual_by_embedding,
|
||||
)
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class PerceptualSearchService:
|
||||
"""
|
||||
感知记忆检索服务。
|
||||
|
||||
封装关键词全文检索 + 向量语义检索 + BM25/embedding 融合排序的完整流程。
|
||||
调用方只需提供 query / keywords、end_user_id、memory_config,即可获得
|
||||
格式化并排序后的感知记忆列表和拼接文本。
|
||||
|
||||
Usage:
|
||||
service = PerceptualSearchService(end_user_id=..., memory_config=...)
|
||||
results = await service.search(query="...", keywords=[...], limit=10)
|
||||
# results = {"memories": [...], "content": "...", "keyword_raw": N, "embedding_raw": M}
|
||||
"""
|
||||
|
||||
DEFAULT_ALPHA = 0.6
|
||||
DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
end_user_id: str,
|
||||
memory_config: Any,
|
||||
alpha: float = DEFAULT_ALPHA,
|
||||
content_score_threshold: float = DEFAULT_CONTENT_SCORE_THRESHOLD,
|
||||
):
|
||||
self.end_user_id = end_user_id
|
||||
self.memory_config = memory_config
|
||||
self.alpha = alpha
|
||||
self.content_score_threshold = content_score_threshold
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
keywords: Optional[List[str]] = None,
|
||||
limit: int = 10,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
执行感知记忆检索(关键词 + 向量并行),融合排序后返回结果。
|
||||
|
||||
对 embedding 命中但 keyword 未命中的结果,补查全文索引获取 BM25 分数,
|
||||
确保所有结果都同时具备 BM25 和 embedding 两个维度的评分。
|
||||
|
||||
Args:
|
||||
query: 原始用户查询(用于向量检索和 BM25 补查)
|
||||
keywords: 关键词列表(用于全文检索),为 None 时使用 [query]
|
||||
limit: 最大返回数量
|
||||
|
||||
Returns:
|
||||
{
|
||||
"memories": [格式化后的记忆 dict, ...],
|
||||
"content": "拼接的纯文本摘要",
|
||||
"keyword_raw": int,
|
||||
"embedding_raw": int,
|
||||
}
|
||||
"""
|
||||
if keywords is None:
|
||||
keywords = [query] if query else []
|
||||
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
kw_task = self._keyword_search(connector, keywords, limit)
|
||||
emb_task = self._embedding_search(connector, query, limit)
|
||||
|
||||
kw_results, emb_results = await asyncio.gather(
|
||||
kw_task, emb_task, return_exceptions=True
|
||||
)
|
||||
if isinstance(kw_results, Exception):
|
||||
logger.warning(f"[PerceptualSearch] keyword search error: {kw_results}")
|
||||
kw_results = []
|
||||
if isinstance(emb_results, Exception):
|
||||
logger.warning(f"[PerceptualSearch] embedding search error: {emb_results}")
|
||||
emb_results = []
|
||||
|
||||
# 补查 BM25:找出 embedding 命中但 keyword 未命中的 id,
|
||||
# 用原始 query 对这些节点补查全文索引拿 BM25 score
|
||||
kw_ids = {r.get("id") for r in kw_results if r.get("id")}
|
||||
emb_only_ids = {r.get("id") for r in emb_results if r.get("id") and r.get("id") not in kw_ids}
|
||||
|
||||
if emb_only_ids and query:
|
||||
backfill = await self._bm25_backfill(connector, query, emb_only_ids, limit)
|
||||
# 把补查到的 BM25 score 注入到 embedding 结果中
|
||||
backfill_map = {r["id"]: r.get("score", 0) for r in backfill}
|
||||
for r in emb_results:
|
||||
rid = r.get("id", "")
|
||||
if rid in backfill_map:
|
||||
r["bm25_backfill_score"] = backfill_map[rid]
|
||||
logger.info(
|
||||
f"[PerceptualSearch] BM25 backfill: {len(emb_only_ids)} embedding-only ids, "
|
||||
f"{len(backfill_map)} got BM25 scores"
|
||||
)
|
||||
|
||||
reranked = self._rerank(kw_results, emb_results, limit)
|
||||
|
||||
memories = []
|
||||
content_parts = []
|
||||
for record in reranked:
|
||||
fmt = self._format_result(record)
|
||||
fmt["score"] = round(record.get("content_score", 0), 4)
|
||||
memories.append(fmt)
|
||||
content_parts.append(self._build_content_text(fmt))
|
||||
|
||||
logger.info(
|
||||
f"[PerceptualSearch] {len(memories)} results after rerank "
|
||||
f"(keyword_raw={len(kw_results)}, embedding_raw={len(emb_results)})"
|
||||
)
|
||||
return {
|
||||
"memories": memories,
|
||||
"content": "\n\n".join(content_parts),
|
||||
"keyword_raw": len(kw_results),
|
||||
"embedding_raw": len(emb_results),
|
||||
}
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
async def _bm25_backfill(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
query: str,
|
||||
target_ids: set,
|
||||
limit: int,
|
||||
) -> List[dict]:
|
||||
"""
|
||||
对指定 id 集合补查全文索引 BM25 score。
|
||||
|
||||
用原始 query 查全文索引,只保留 id 在 target_ids 中的结果。
|
||||
"""
|
||||
escaped = escape_lucene_query(query)
|
||||
if not escaped.strip():
|
||||
return []
|
||||
try:
|
||||
r = await search_perceptual(
|
||||
connector=connector, query=escaped,
|
||||
end_user_id=self.end_user_id,
|
||||
limit=limit * 5, # 多查一些以提高命中率
|
||||
)
|
||||
all_hits = r.get("perceptuals", [])
|
||||
return [h for h in all_hits if h.get("id") in target_ids]
|
||||
except Exception as e:
|
||||
logger.warning(f"[PerceptualSearch] BM25 backfill failed: {e}")
|
||||
return []
|
||||
|
||||
async def _keyword_search(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
keywords: List[str],
|
||||
limit: int,
|
||||
) -> List[dict]:
|
||||
"""并发对每个关键词做全文检索,去重后按 score 降序返回 top N 原始结果。"""
|
||||
seen_ids: set = set()
|
||||
all_results: List[dict] = []
|
||||
|
||||
async def _one(kw: str):
|
||||
escaped = escape_lucene_query(kw)
|
||||
if not escaped.strip():
|
||||
return []
|
||||
r = await search_perceptual(
|
||||
connector=connector, query=escaped,
|
||||
end_user_id=self.end_user_id, limit=limit,
|
||||
)
|
||||
return r.get("perceptuals", [])
|
||||
|
||||
tasks = [_one(kw) for kw in keywords[:10]]
|
||||
batch = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for result in batch:
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"[PerceptualSearch] keyword sub-query error: {result}")
|
||||
continue
|
||||
for rec in result:
|
||||
rid = rec.get("id", "")
|
||||
if rid and rid not in seen_ids:
|
||||
seen_ids.add(rid)
|
||||
all_results.append(rec)
|
||||
|
||||
all_results.sort(key=lambda x: float(x.get("score", 0)), reverse=True)
|
||||
return all_results[:limit]
|
||||
|
||||
async def _embedding_search(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
query_text: str,
|
||||
limit: int,
|
||||
) -> List[dict]:
|
||||
"""向量语义检索,返回原始结果(不做阈值过滤)。"""
|
||||
try:
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
with get_db_context() as db:
|
||||
cfg = MemoryConfigService(db).get_embedder_config(
|
||||
str(self.memory_config.embedding_model_id)
|
||||
)
|
||||
client = OpenAIEmbedderClient(RedBearModelConfig(**cfg))
|
||||
|
||||
r = await search_perceptual_by_embedding(
|
||||
connector=connector, embedder_client=client,
|
||||
query_text=query_text, end_user_id=self.end_user_id,
|
||||
limit=limit,
|
||||
)
|
||||
return r.get("perceptuals", [])
|
||||
except Exception as e:
|
||||
logger.warning(f"[PerceptualSearch] embedding search failed: {e}")
|
||||
return []
|
||||
|
||||
def _rerank(
|
||||
self,
|
||||
keyword_results: List[dict],
|
||||
embedding_results: List[dict],
|
||||
limit: int,
|
||||
) -> List[dict]:
|
||||
"""BM25 + embedding 融合排序。
|
||||
|
||||
对 embedding 结果中带有 bm25_backfill_score 的条目,
|
||||
将其与 keyword 结果合并后统一归一化,确保 BM25 分数在同一尺度上。
|
||||
"""
|
||||
# 把补查的 BM25 score 合并到 keyword_results 中统一归一化
|
||||
emb_backfill_items = []
|
||||
for item in embedding_results:
|
||||
backfill_score = item.get("bm25_backfill_score")
|
||||
if backfill_score is not None and item.get("id"):
|
||||
emb_backfill_items.append({"id": item["id"], "score": backfill_score})
|
||||
|
||||
# 合并后统一归一化 BM25 scores
|
||||
all_bm25_items = keyword_results + emb_backfill_items
|
||||
all_bm25_items = self._normalize_scores(all_bm25_items)
|
||||
|
||||
# 建立 id -> normalized BM25 score 的映射
|
||||
bm25_norm_map: Dict[str, float] = {}
|
||||
for item in all_bm25_items:
|
||||
item_id = item.get("id", "")
|
||||
if item_id:
|
||||
bm25_norm_map[item_id] = float(item.get("normalized_score", 0))
|
||||
|
||||
# 归一化 embedding scores
|
||||
embedding_results = self._normalize_scores(embedding_results)
|
||||
|
||||
# 合并
|
||||
combined: Dict[str, dict] = {}
|
||||
for item in keyword_results:
|
||||
item_id = item.get("id", "")
|
||||
if not item_id:
|
||||
continue
|
||||
combined[item_id] = item.copy()
|
||||
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
|
||||
combined[item_id]["embedding_score"] = 0.0
|
||||
|
||||
for item in embedding_results:
|
||||
item_id = item.get("id", "")
|
||||
if not item_id:
|
||||
continue
|
||||
if item_id in combined:
|
||||
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
else:
|
||||
combined[item_id] = item.copy()
|
||||
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
|
||||
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
|
||||
for item in combined.values():
|
||||
bm25 = float(item.get("bm25_score", 0) or 0)
|
||||
emb = float(item.get("embedding_score", 0) or 0)
|
||||
item["content_score"] = self.alpha * bm25 + (1 - self.alpha) * emb
|
||||
|
||||
results = list(combined.values())
|
||||
before = len(results)
|
||||
results = [r for r in results if r["content_score"] >= self.content_score_threshold]
|
||||
results.sort(key=lambda x: x["content_score"], reverse=True)
|
||||
results = results[:limit]
|
||||
|
||||
logger.info(
|
||||
f"[PerceptualSearch] rerank: merged={before}, after_threshold={len(results)} "
|
||||
f"(alpha={self.alpha}, threshold={self.content_score_threshold})"
|
||||
)
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _normalize_scores(items: List[dict], field: str = "score") -> List[dict]:
|
||||
"""Z-score + sigmoid 归一化。"""
|
||||
if not items:
|
||||
return items
|
||||
scores = [float(it.get(field, 0) or 0) for it in items]
|
||||
if len(scores) <= 1:
|
||||
for it in items:
|
||||
it[f"normalized_{field}"] = 1.0
|
||||
return items
|
||||
mean = sum(scores) / len(scores)
|
||||
var = sum((s - mean) ** 2 for s in scores) / len(scores)
|
||||
std = math.sqrt(var)
|
||||
if std == 0:
|
||||
for it in items:
|
||||
it[f"normalized_{field}"] = 1.0
|
||||
else:
|
||||
for it, s in zip(items, scores):
|
||||
z = (s - mean) / std
|
||||
it[f"normalized_{field}"] = 1 / (1 + math.exp(-z))
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
def _format_result(record: dict) -> dict:
|
||||
return {
|
||||
"id": record.get("id", ""),
|
||||
"perceptual_type": record.get("perceptual_type", ""),
|
||||
"file_name": record.get("file_name", ""),
|
||||
"file_path": record.get("file_path", ""),
|
||||
"summary": record.get("summary", ""),
|
||||
"topic": record.get("topic", ""),
|
||||
"domain": record.get("domain", ""),
|
||||
"keywords": record.get("keywords", []),
|
||||
"created_at": str(record.get("created_at", "")),
|
||||
"file_type": record.get("file_type", ""),
|
||||
"score": record.get("score", 0),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_content_text(formatted: dict) -> str:
|
||||
parts = []
|
||||
if formatted["summary"]:
|
||||
parts.append(formatted["summary"])
|
||||
if formatted["topic"]:
|
||||
parts.append(f"[主题: {formatted['topic']}]")
|
||||
if formatted["keywords"]:
|
||||
kw_list = formatted["keywords"]
|
||||
if isinstance(kw_list, list):
|
||||
parts.append(f"[关键词: {', '.join(kw_list)}]")
|
||||
if formatted["file_name"]:
|
||||
parts.append(f"[文件: {formatted['file_name']}]")
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def _extract_keywords_from_problems(problem_extension: dict) -> List[str]:
|
||||
"""Extract search keywords from problem extension results."""
|
||||
keywords = []
|
||||
context = problem_extension.get("context", {})
|
||||
if isinstance(context, dict):
|
||||
for original_q, extended_qs in context.items():
|
||||
keywords.append(original_q)
|
||||
if isinstance(extended_qs, list):
|
||||
keywords.extend(extended_qs)
|
||||
return keywords
|
||||
|
||||
|
||||
async def perceptual_retrieve_node(state: ReadState) -> ReadState:
|
||||
"""
|
||||
LangGraph node: perceptual memory retrieval.
|
||||
|
||||
Uses PerceptualSearchService to run keyword + embedding search with
|
||||
BM25 fusion reranking, then writes results to state['perceptual_data'].
|
||||
"""
|
||||
end_user_id = state.get("end_user_id", "")
|
||||
problem_extension = state.get("problem_extension", {})
|
||||
original_query = state.get("data", "")
|
||||
memory_config = state.get("memory_config", None)
|
||||
|
||||
logger.info(f"Perceptual_Retrieve: start, end_user_id={end_user_id}")
|
||||
|
||||
keywords = _extract_keywords_from_problems(problem_extension)
|
||||
if not keywords:
|
||||
keywords = [original_query] if original_query else []
|
||||
|
||||
logger.info(f"Perceptual_Retrieve: {len(keywords)} keywords extracted")
|
||||
|
||||
service = PerceptualSearchService(
|
||||
end_user_id=end_user_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
search_result = await service.search(
|
||||
query=original_query,
|
||||
keywords=keywords,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
result = {
|
||||
"memories": search_result["memories"],
|
||||
"content": search_result["content"],
|
||||
"_intermediate": {
|
||||
"type": "perceptual_retrieve",
|
||||
"title": "感知记忆检索",
|
||||
"data": search_result["memories"],
|
||||
"query": original_query,
|
||||
"result_count": len(search_result["memories"]),
|
||||
},
|
||||
}
|
||||
return {"perceptual_data": result}
|
||||
@@ -263,7 +263,6 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
logger.info(f"Problem extension result: {aggregated_dict}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
print(time.time() - start)
|
||||
result = {
|
||||
"context": aggregated_dict,
|
||||
"original": data,
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
|
||||
PerceptualSearchService,
|
||||
)
|
||||
from app.core.memory.agent.models.summary_models import (
|
||||
RetrieveSummaryResponse,
|
||||
SummaryResponse,
|
||||
@@ -339,11 +343,45 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
|
||||
try:
|
||||
if storage_type != "rag":
|
||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(
|
||||
|
||||
async def _perceptual_search():
|
||||
service = PerceptualSearchService(
|
||||
end_user_id=end_user_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
return await service.search(query=data, limit=5)
|
||||
|
||||
hybrid_task = SearchService().execute_hybrid_search(
|
||||
**search_params,
|
||||
memory_config=memory_config,
|
||||
expand_communities=False, # 路径 "2" 只需要 community 的 summary 文本,不展开到 Statement
|
||||
expand_communities=False,
|
||||
)
|
||||
perceptual_task = _perceptual_search()
|
||||
|
||||
gather_results = await asyncio.gather(
|
||||
hybrid_task, perceptual_task, return_exceptions=True
|
||||
)
|
||||
hybrid_result = gather_results[0]
|
||||
perceptual_results = gather_results[1]
|
||||
|
||||
# 处理 hybrid search 异常
|
||||
if isinstance(hybrid_result, Exception):
|
||||
raise hybrid_result
|
||||
retrieve_info, question, raw_results = hybrid_result
|
||||
|
||||
# 处理感知记忆结果
|
||||
if isinstance(perceptual_results, Exception):
|
||||
logger.warning(f"[Input_Summary] perceptual search failed: {perceptual_results}")
|
||||
perceptual_results = []
|
||||
|
||||
# 拼接感知记忆内容到 retrieve_info
|
||||
if perceptual_results and isinstance(perceptual_results, dict):
|
||||
perceptual_content = perceptual_results.get("content", "")
|
||||
if perceptual_content:
|
||||
retrieve_info = f"{retrieve_info}\n\n<history-files>\n{perceptual_content}"
|
||||
count = len(perceptual_results.get("memories", []))
|
||||
logger.info(f"[Input_Summary] appended {count} perceptual memories (reranked)")
|
||||
|
||||
# 调试:打印 community 检索结果数量
|
||||
if raw_results and isinstance(raw_results, dict):
|
||||
reranked = raw_results.get('reranked_results', {})
|
||||
@@ -371,10 +409,7 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
"error": str(e)
|
||||
}
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
duration = end - start
|
||||
log_time('检索', duration)
|
||||
return {"summary": summary}
|
||||
|
||||
@@ -412,8 +447,20 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
|
||||
retrieve_info_str = list(set(retrieve_info_str))
|
||||
retrieve_info_str = '\n'.join(retrieve_info_str)
|
||||
|
||||
aimessages = await summary_llm(state, history, retrieve_info_str,
|
||||
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
|
||||
# Merge perceptual memory content
|
||||
perceptual_data = state.get("perceptual_data", {})
|
||||
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
||||
if perceptual_content:
|
||||
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
||||
|
||||
aimessages = await summary_llm(
|
||||
state,
|
||||
history,
|
||||
retrieve_info_str,
|
||||
'direct_summary_prompt.jinja2',
|
||||
'retrieve_summary', RetrieveSummaryResponse,
|
||||
"1"
|
||||
)
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||
await summary_redis_save(state, aimessages)
|
||||
if aimessages == '':
|
||||
@@ -458,6 +505,12 @@ async def Summary(state: ReadState) -> ReadState:
|
||||
retrieve_info_str += i + '\n'
|
||||
history = await summary_history(state)
|
||||
|
||||
# Merge perceptual memory content
|
||||
perceptual_data = state.get("perceptual_data", {})
|
||||
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
||||
if perceptual_content:
|
||||
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
||||
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
@@ -508,6 +561,13 @@ async def Summary_fails(state: ReadState) -> ReadState:
|
||||
if key == 'answer_small':
|
||||
for i in value:
|
||||
retrieve_info_str += i + '\n'
|
||||
|
||||
# Merge perceptual memory content
|
||||
perceptual_data = state.get("perceptual_data", {})
|
||||
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
||||
if perceptual_content:
|
||||
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
||||
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
|
||||
@@ -15,7 +15,10 @@ from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
|
||||
Problem_Extension,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
|
||||
retrieve,
|
||||
retrieve_nodes,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
|
||||
perceptual_retrieve_node,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
|
||||
Input_Summary,
|
||||
@@ -48,13 +51,14 @@ async def make_read_graph():
|
||||
"""
|
||||
try:
|
||||
# Build workflow graph
|
||||
workflow = StateGraph(ReadState)
|
||||
workflow = StateGraph(ReadState)
|
||||
workflow.add_node("content_input", content_input_node)
|
||||
workflow.add_node("Split_The_Problem", Split_The_Problem)
|
||||
workflow.add_node("Problem_Extension", Problem_Extension)
|
||||
workflow.add_node("Input_Summary", Input_Summary)
|
||||
# workflow.add_node("Retrieve", retrieve_nodes)
|
||||
workflow.add_node("Retrieve", retrieve)
|
||||
workflow.add_node("Retrieve", retrieve_nodes)
|
||||
# workflow.add_node("Retrieve", retrieve)
|
||||
workflow.add_node("Perceptual_Retrieve", perceptual_retrieve_node)
|
||||
workflow.add_node("Verify", Verify)
|
||||
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
|
||||
workflow.add_node("Summary", Summary)
|
||||
@@ -65,14 +69,15 @@ async def make_read_graph():
|
||||
workflow.add_conditional_edges("content_input", Split_continue)
|
||||
workflow.add_edge("Input_Summary", END)
|
||||
workflow.add_edge("Split_The_Problem", "Problem_Extension")
|
||||
workflow.add_edge("Problem_Extension", "Retrieve")
|
||||
# After Problem_Extension, retrieve perceptual memory first, then main Retrieve
|
||||
workflow.add_edge("Problem_Extension", "Perceptual_Retrieve")
|
||||
workflow.add_edge("Perceptual_Retrieve", "Retrieve")
|
||||
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
||||
workflow.add_edge("Retrieve_Summary", END)
|
||||
workflow.add_conditional_edges("Verify", Verify_continue)
|
||||
workflow.add_edge("Summary_fails", END)
|
||||
workflow.add_edge("Summary", END)
|
||||
|
||||
'''-----'''
|
||||
# workflow.add_edge("Retrieve", END)
|
||||
|
||||
# Compile workflow
|
||||
@@ -80,7 +85,5 @@ async def make_read_graph():
|
||||
yield graph
|
||||
|
||||
except Exception as e:
|
||||
print(f"创建工作流失败: {e}")
|
||||
logger.error(f"创建工作流失败: {e}")
|
||||
raise
|
||||
finally:
|
||||
print("工作流创建完成")
|
||||
|
||||
@@ -12,7 +12,6 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
@@ -21,25 +20,6 @@ logger = get_agent_logger(__name__)
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
|
||||
|
||||
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
||||
"""
|
||||
Write messages to RAG storage system
|
||||
|
||||
Combines user and AI messages into a single string format and stores them
|
||||
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
|
||||
|
||||
Args:
|
||||
end_user_id: User identifier for the conversation
|
||||
user_message: User's input message content
|
||||
ai_message: AI's response message content
|
||||
user_rag_memory_id: RAG memory identifier for storage location
|
||||
"""
|
||||
# RAG mode: combine messages into string format (maintain original logic)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
|
||||
|
||||
async def write(
|
||||
storage_type,
|
||||
end_user_id,
|
||||
@@ -118,7 +98,7 @@ async def write(
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
|
||||
|
||||
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope):
|
||||
async def term_memory_save(end_user_id, strategy_type, scope):
|
||||
"""
|
||||
Save long-term memory data to database
|
||||
|
||||
@@ -127,10 +107,8 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
|
||||
to long-term memory storage.
|
||||
|
||||
Args:
|
||||
long_term_messages: Long-term message data to be saved
|
||||
actual_config_id: Configuration identifier for memory settings
|
||||
end_user_id: User identifier for memory association
|
||||
type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
|
||||
strategy_type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
|
||||
scope: Scope/window size for memory processing
|
||||
"""
|
||||
with get_db_context() as db_session:
|
||||
@@ -138,7 +116,10 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
result = write_store.get_session_by_userid(end_user_id)
|
||||
if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
if not result:
|
||||
logger.warning(f"No write data found for user {end_user_id}")
|
||||
return
|
||||
if strategy_type in [AgentMemory_Long_Term.STRATEGY_CHUNK, AgentMemory_Long_Term.STRATEGY_AGGREGATE]:
|
||||
data = await format_parsing(result, "dict")
|
||||
chunk_data = data[:scope]
|
||||
if len(chunk_data) == scope:
|
||||
@@ -151,9 +132,6 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
|
||||
logger.info(f'写入短长期:')
|
||||
|
||||
|
||||
"""Window-based dialogue processing"""
|
||||
|
||||
|
||||
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
|
||||
"""
|
||||
Process dialogue based on window size and write to Neo4j
|
||||
@@ -167,40 +145,33 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||
langchain_messages: Original message data list
|
||||
scope: Window size determining when to trigger long-term storage
|
||||
"""
|
||||
scope = scope
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_id is not False:
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||
redis_messages = count_store.get_sessions_count(end_user_id)[1]
|
||||
if is_end_user_id and int(is_end_user_id) != int(scope):
|
||||
is_end_user_id += 1
|
||||
langchain_messages += redis_messages
|
||||
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
||||
elif int(is_end_user_id) == int(scope):
|
||||
is_end_user_has_history = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_has_history:
|
||||
end_user_visit_count, redis_messages = is_end_user_has_history
|
||||
else:
|
||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||
return
|
||||
end_user_visit_count += 1
|
||||
if end_user_visit_count < scope:
|
||||
redis_messages.extend(langchain_messages)
|
||||
count_store.update_sessions_count(end_user_id, end_user_visit_count, redis_messages)
|
||||
else:
|
||||
logger.info('写入长期记忆NEO4J')
|
||||
formatted_messages = redis_messages
|
||||
redis_messages.extend(langchain_messages)
|
||||
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
|
||||
if hasattr(memory_config, 'config_id'):
|
||||
config_id = memory_config.config_id
|
||||
else:
|
||||
config_id = memory_config
|
||||
|
||||
await write(
|
||||
AgentMemory_Long_Term.STORAGE_NEO4J,
|
||||
end_user_id,
|
||||
"",
|
||||
"",
|
||||
None,
|
||||
end_user_id,
|
||||
config_id,
|
||||
formatted_messages
|
||||
write_message_task.delay(
|
||||
end_user_id, # end_user_id: User ID
|
||||
redis_messages, # message: JSON string format message list
|
||||
config_id, # config_id: Configuration ID string
|
||||
AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
|
||||
"" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
||||
)
|
||||
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
||||
else:
|
||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||
|
||||
|
||||
"""Time-based memory processing"""
|
||||
count_store.update_sessions_count(end_user_id, 0, [])
|
||||
|
||||
|
||||
async def memory_long_term_storage(end_user_id, memory_config, time):
|
||||
@@ -291,9 +262,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
||||
return result_dict
|
||||
|
||||
except Exception as e:
|
||||
print(f"[aggregate_judgment] 发生错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
logger.error(f"[aggregate_judgment] 发生错误: {e}", exc_info=True)
|
||||
|
||||
return {
|
||||
"is_same_event": False,
|
||||
|
||||
@@ -1,49 +1,25 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from app.db import get_db, get_db_context
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
|
||||
aggregate_judgment
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_write_graph():
|
||||
"""
|
||||
Create a write graph workflow for memory operations.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
tools: MCP tools loaded from session
|
||||
apply_id: Application identifier
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
"""
|
||||
workflow = StateGraph(WriteState)
|
||||
workflow.add_node("save_neo4j", write_node)
|
||||
workflow.add_edge(START, "save_neo4j")
|
||||
workflow.add_edge("save_neo4j", END)
|
||||
|
||||
graph = workflow.compile()
|
||||
|
||||
yield graph
|
||||
|
||||
|
||||
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
|
||||
end_user_id: str = '', scope: int = 6):
|
||||
async def long_term_storage(
|
||||
long_term_type: str,
|
||||
langchain_messages: list,
|
||||
memory_config_id: str,
|
||||
end_user_id: str,
|
||||
scope: int = 6
|
||||
):
|
||||
"""
|
||||
Handle long-term memory storage with different strategies
|
||||
|
||||
@@ -53,33 +29,39 @@ async def long_term_storage(long_term_type: str = "chunk", langchain_messages: l
|
||||
Args:
|
||||
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
|
||||
langchain_messages: List of messages to store
|
||||
memory_config: Memory configuration identifier
|
||||
memory_config_id: Memory configuration identifier
|
||||
end_user_id: User group identifier
|
||||
scope: Scope parameter for chunk-based storage (default: 6)
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
|
||||
aggregate_judgment
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
if langchain_messages is None:
|
||||
langchain_messages = []
|
||||
|
||||
write_store.save_session_write(end_user_id, langchain_messages)
|
||||
# 获取数据库会话
|
||||
with get_db_context() as db_session:
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=memory_config, # 改为整数
|
||||
config_id=memory_config_id, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
|
||||
'''Strategy 1: Dialogue window with 6 rounds of conversation'''
|
||||
# Dialogue window with 6 rounds of conversation
|
||||
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
|
||||
"""Time-based strategy"""
|
||||
# Time-based strategy
|
||||
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
"""Strategy 3: Aggregate judgment"""
|
||||
# Aggregate judgment
|
||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
|
||||
|
||||
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id):
|
||||
async def write_long_term(
|
||||
storage_type: str,
|
||||
end_user_id: str,
|
||||
messages: list[dict],
|
||||
user_rag_memory_id: str,
|
||||
actual_config_id: str
|
||||
):
|
||||
"""
|
||||
Write long-term memory with different storage types
|
||||
|
||||
@@ -89,44 +71,24 @@ async def write_long_term(storage_type, end_user_id, message_chat, aimessages, u
|
||||
Args:
|
||||
storage_type: Type of storage (RAG or traditional)
|
||||
end_user_id: User group identifier
|
||||
message_chat: User message content
|
||||
aimessages: AI response messages
|
||||
messages: message list
|
||||
user_rag_memory_id: RAG memory identifier
|
||||
actual_config_id: Actual configuration ID
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
||||
message_content = []
|
||||
for message in messages:
|
||||
message_content.append(f'{message.get("role")}:{message.get("content")}')
|
||||
messages_string = "\n".join(message_content)
|
||||
await write_rag(end_user_id, messages_string, user_rag_memory_id)
|
||||
else:
|
||||
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
|
||||
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
||||
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
|
||||
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
|
||||
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
|
||||
|
||||
# async def main():
|
||||
# """主函数 - 运行工作流"""
|
||||
# langchain_messages = [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": "今天周五去爬山"
|
||||
# },
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": "好耶"
|
||||
# }
|
||||
#
|
||||
# ]
|
||||
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
||||
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
||||
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
||||
#
|
||||
#
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# import asyncio
|
||||
# asyncio.run(main())
|
||||
await long_term_storage(long_term_type=CHUNK,
|
||||
langchain_messages=messages,
|
||||
memory_config_id=actual_config_id,
|
||||
end_user_id=end_user_id,
|
||||
scope=SCOPE)
|
||||
await term_memory_save(end_user_id, CHUNK, scope=SCOPE)
|
||||
|
||||
@@ -10,7 +10,6 @@ from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.src.search import run_hybrid_search
|
||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
# 需要从展开结果中过滤的字段(含 Neo4j DateTime,不可 JSON 序列化)
|
||||
@@ -31,10 +30,10 @@ def _clean_expand_fields(obj):
|
||||
|
||||
|
||||
async def expand_communities_to_statements(
|
||||
community_results: List[dict],
|
||||
end_user_id: str,
|
||||
existing_content: str = "",
|
||||
limit: int = 10,
|
||||
community_results: List[dict],
|
||||
end_user_id: str,
|
||||
existing_content: str = "",
|
||||
limit: int = 10,
|
||||
) -> Tuple[List[dict], List[str]]:
|
||||
"""
|
||||
社区展开 helper:给定命中的 community 列表,拉取关联 Statement。
|
||||
@@ -76,17 +75,18 @@ async def expand_communities_to_statements(
|
||||
if s.get("statement") and s["statement"] not in existing_lines
|
||||
]
|
||||
cleaned = _clean_expand_fields(expanded_stmts)
|
||||
logger.info(f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements,新增 {len(new_texts)} 条,community_ids={community_ids}")
|
||||
logger.info(
|
||||
f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements,新增 {len(new_texts)} 条,community_ids={community_ids}")
|
||||
return cleaned, new_texts
|
||||
|
||||
|
||||
class SearchService:
|
||||
"""Service for executing hybrid search and processing results."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the search service."""
|
||||
logger.info("SearchService initialized")
|
||||
|
||||
|
||||
def extract_content_from_result(self, result: dict, node_type: str = "") -> str:
|
||||
"""
|
||||
Extract only meaningful content from search results, dropping all metadata.
|
||||
@@ -107,19 +107,19 @@ class SearchService:
|
||||
"""
|
||||
if not isinstance(result, dict):
|
||||
return str(result)
|
||||
|
||||
|
||||
content_parts = []
|
||||
|
||||
|
||||
# Statements: extract statement field
|
||||
if 'statement' in result and result['statement']:
|
||||
content_parts.append(result['statement'])
|
||||
|
||||
|
||||
# Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定
|
||||
# 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要
|
||||
is_community = (
|
||||
node_type == "community"
|
||||
or 'member_count' in result
|
||||
or 'core_entities' in result
|
||||
node_type == "community"
|
||||
or 'member_count' in result
|
||||
or 'core_entities' in result
|
||||
)
|
||||
if is_community:
|
||||
name = result.get('name', '')
|
||||
@@ -130,16 +130,16 @@ class SearchService:
|
||||
elif 'content' in result and result['content']:
|
||||
# Summaries / Chunks
|
||||
content_parts.append(result['content'])
|
||||
|
||||
|
||||
# Entities: extract name and fact_summary (commented out in original)
|
||||
# if 'name' in result and result['name']:
|
||||
# content_parts.append(result['name'])
|
||||
# if result.get('fact_summary'):
|
||||
# content_parts.append(result['fact_summary'])
|
||||
|
||||
|
||||
# Return concatenated content or empty string
|
||||
return '\n'.join(content_parts) if content_parts else ""
|
||||
|
||||
|
||||
def clean_query(self, query: str) -> str:
|
||||
"""
|
||||
Clean and escape query text for Lucene.
|
||||
@@ -155,33 +155,33 @@ class SearchService:
|
||||
Cleaned and escaped query string
|
||||
"""
|
||||
q = str(query).strip()
|
||||
|
||||
|
||||
# Remove wrapping quotes
|
||||
if (q.startswith("'") and q.endswith("'")) or (
|
||||
q.startswith('"') and q.endswith('"')
|
||||
q.startswith('"') and q.endswith('"')
|
||||
):
|
||||
q = q[1:-1]
|
||||
|
||||
|
||||
# Remove newlines and carriage returns
|
||||
q = q.replace('\r', ' ').replace('\n', ' ').strip()
|
||||
|
||||
|
||||
# Apply Lucene escaping
|
||||
q = escape_lucene_query(q)
|
||||
|
||||
|
||||
return q
|
||||
|
||||
|
||||
async def execute_hybrid_search(
|
||||
self,
|
||||
end_user_id: str,
|
||||
question: str,
|
||||
limit: int = 5,
|
||||
search_type: str = "hybrid",
|
||||
include: Optional[List[str]] = None,
|
||||
rerank_alpha: float = 0.4,
|
||||
output_path: str = "search_results.json",
|
||||
return_raw_results: bool = False,
|
||||
memory_config = None,
|
||||
expand_communities: bool = True,
|
||||
self,
|
||||
end_user_id: str,
|
||||
question: str,
|
||||
limit: int = 5,
|
||||
search_type: str = "hybrid",
|
||||
include: Optional[List[str]] = None,
|
||||
rerank_alpha: float = 0.4,
|
||||
output_path: str = "search_results.json",
|
||||
return_raw_results: bool = False,
|
||||
memory_config=None,
|
||||
expand_communities: bool = True,
|
||||
) -> Tuple[str, str, Optional[dict]]:
|
||||
"""
|
||||
Execute hybrid search and return clean content.
|
||||
@@ -205,10 +205,10 @@ class SearchService:
|
||||
"""
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries", "communities"]
|
||||
|
||||
|
||||
# Clean query
|
||||
cleaned_query = self.clean_query(question)
|
||||
|
||||
|
||||
try:
|
||||
# Execute search
|
||||
answer = await run_hybrid_search(
|
||||
@@ -221,18 +221,18 @@ class SearchService:
|
||||
memory_config=memory_config,
|
||||
rerank_alpha=rerank_alpha
|
||||
)
|
||||
|
||||
|
||||
# Extract results based on search type and include parameter
|
||||
# Prioritize summaries as they contain synthesized contextual information
|
||||
answer_list = []
|
||||
|
||||
|
||||
# For hybrid search, use reranked_results
|
||||
if search_type == "hybrid":
|
||||
reranked_results = answer.get('reranked_results', {})
|
||||
|
||||
|
||||
# Priority order: summaries first (most contextual), then communities, statements, chunks, entities
|
||||
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
|
||||
|
||||
|
||||
for category in priority_order:
|
||||
if category in include and category in reranked_results:
|
||||
category_results = reranked_results[category]
|
||||
@@ -242,7 +242,7 @@ class SearchService:
|
||||
# For keyword or embedding search, results are directly in answer dict
|
||||
# Apply same priority order
|
||||
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
|
||||
|
||||
|
||||
for category in priority_order:
|
||||
if category in include and category in answer:
|
||||
category_results = answer[category]
|
||||
@@ -261,7 +261,7 @@ class SearchService:
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
answer_list.extend(cleaned_stmts)
|
||||
|
||||
|
||||
# Extract clean content from all results,按类型传入 node_type 区分 community
|
||||
content_list = []
|
||||
for ans in answer_list:
|
||||
@@ -269,19 +269,18 @@ class SearchService:
|
||||
ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else ""
|
||||
content_list.append(self.extract_content_from_result(ans, node_type=ntype))
|
||||
|
||||
|
||||
# Filter out empty strings and join with newlines
|
||||
clean_content = '\n'.join([c for c in content_list if c])
|
||||
|
||||
|
||||
# Log first 200 chars
|
||||
logger.info(f"检索接口搜索结果==>>:{clean_content[:200]}...")
|
||||
|
||||
|
||||
# Return raw results if requested
|
||||
if return_raw_results:
|
||||
return clean_content, cleaned_query, answer
|
||||
else:
|
||||
return clean_content, cleaned_query, None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Search failed for query '{question}' in group '{end_user_id}': {e}",
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Annotated, TypedDict
|
||||
@@ -52,6 +51,7 @@ class ReadState(TypedDict):
|
||||
embedding_id: str
|
||||
memory_config: object # 新增字段用于传递内存配置对象
|
||||
retrieve: dict
|
||||
perceptual_data: dict
|
||||
RetrieveSummary: dict
|
||||
InputSummary: dict
|
||||
verify: dict
|
||||
|
||||
@@ -3,8 +3,9 @@ import uuid
|
||||
from app.core.config import settings
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.memory.agent.utils.redis_base import (
|
||||
serialize_messages,
|
||||
serialize_messages,
|
||||
deserialize_messages,
|
||||
fix_encoding,
|
||||
format_session_data,
|
||||
@@ -14,12 +15,12 @@ from app.core.memory.agent.utils.redis_base import (
|
||||
get_current_timestamp
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RedisWriteStore:
|
||||
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
|
||||
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
@@ -66,10 +67,10 @@ class RedisWriteStore:
|
||||
})
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
logger.debug(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
print(f"[save_session_write] 保存会话失败: {e}")
|
||||
logger.error(f"[save_session_write] 保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
|
||||
@@ -99,7 +100,7 @@ class RedisWriteStore:
|
||||
for key, data in zip(keys, all_data):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == userid:
|
||||
# 从 key 中提取 session_id: session:write:{session_id}
|
||||
@@ -108,16 +109,16 @@ class RedisWriteStore:
|
||||
"sessionid": session_id,
|
||||
"messages": fix_encoding(data.get('messages', ''))
|
||||
})
|
||||
|
||||
|
||||
if not results:
|
||||
return False
|
||||
|
||||
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
||||
|
||||
logger.debug(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"[get_session_by_userid] 查询失败: {e}")
|
||||
logger.error(f"[get_session_by_userid] 查询失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
|
||||
"""
|
||||
通过 end_user_id 获取所有 write 类型的会话数据
|
||||
@@ -144,7 +145,7 @@ class RedisWriteStore:
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
||||
logger.debug(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
||||
return False
|
||||
|
||||
# 批量获取数据
|
||||
@@ -158,12 +159,12 @@ class RedisWriteStore:
|
||||
for key, data in zip(keys, all_data):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == end_user_id:
|
||||
# 从 key 中提取 session_id: session:write:{session_id}
|
||||
session_id = key.split(':')[-1]
|
||||
|
||||
|
||||
# 构建完整的会话信息
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
@@ -173,23 +174,21 @@ class RedisWriteStore:
|
||||
"starttime": data.get('starttime', '')
|
||||
}
|
||||
results.append(session_info)
|
||||
|
||||
|
||||
if not results:
|
||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
||||
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
||||
return False
|
||||
|
||||
|
||||
# 按时间排序(最新的在前)
|
||||
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
|
||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
||||
|
||||
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
logger.error(f"[get_all_sessions_by_end_user_id] 查询失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def find_user_recent_sessions(self, userid: str,
|
||||
def find_user_recent_sessions(self, userid: str,
|
||||
minutes: int = 5) -> List[Dict[str, str]]:
|
||||
"""
|
||||
根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据
|
||||
@@ -203,11 +202,11 @@ class RedisWriteStore:
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
logger.debug(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 批量获取数据
|
||||
@@ -221,7 +220,7 @@ class RedisWriteStore:
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == userid and data.get('starttime'):
|
||||
# write 类型没有 aimessages,所以 Answer 为空
|
||||
@@ -230,15 +229,14 @@ class RedisWriteStore:
|
||||
"Answer": "",
|
||||
"starttime": data.get('starttime', '')
|
||||
})
|
||||
|
||||
|
||||
# 根据时间范围过滤
|
||||
filtered_items = filter_by_time_range(matched_items, minutes)
|
||||
# 排序并移除时间字段
|
||||
result_items = sort_and_limit_results(filtered_items, limit=None)
|
||||
print(result_items)
|
||||
result_items = sort_and_limit_results(filtered_items)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
||||
logger.debug(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
||||
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
@@ -258,7 +256,7 @@ class RedisWriteStore:
|
||||
|
||||
class RedisCountStore:
|
||||
"""Redis Count 类型存储类,用于管理访问次数统计相关的数据"""
|
||||
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
@@ -278,7 +276,7 @@ class RedisCountStore:
|
||||
decode_responses=True,
|
||||
encoding='utf-8'
|
||||
)
|
||||
self.uudi = session_id
|
||||
self.uuid = session_id
|
||||
|
||||
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
|
||||
"""
|
||||
@@ -295,26 +293,26 @@ class RedisCountStore:
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
index_key = f'session:count:index:{end_user_id}' # 索引键
|
||||
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"id": self.uuid,
|
||||
"end_user_id": end_user_id,
|
||||
"count": int(count),
|
||||
"messages": serialize_messages(messages),
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
|
||||
|
||||
|
||||
# 创建索引:end_user_id -> session_id 映射
|
||||
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
|
||||
|
||||
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
||||
|
||||
logger.debug(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
||||
return session_id
|
||||
|
||||
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
|
||||
def get_sessions_count(self, end_user_id: str) -> tuple[int, list[dict]] | bool:
|
||||
"""
|
||||
通过 end_user_id 查询访问次数统计
|
||||
|
||||
@@ -327,7 +325,7 @@ class RedisCountStore:
|
||||
try:
|
||||
# 使用索引键快速查找
|
||||
index_key = f'session:count:index:{end_user_id}'
|
||||
|
||||
|
||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||
try:
|
||||
key_type = self.r.type(index_key)
|
||||
@@ -335,35 +333,40 @@ class RedisCountStore:
|
||||
self.r.delete(index_key)
|
||||
return False
|
||||
except Exception as type_error:
|
||||
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
logger.error(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
session_id = self.r.get(index_key)
|
||||
|
||||
|
||||
if not session_id:
|
||||
return False
|
||||
|
||||
|
||||
# 直接获取数据
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
data = self.r.hgetall(key)
|
||||
|
||||
|
||||
if not data:
|
||||
# 索引存在但数据不存在,清理索引
|
||||
self.r.delete(index_key)
|
||||
return False
|
||||
|
||||
|
||||
count = data.get('count')
|
||||
messages_str = data.get('messages')
|
||||
|
||||
|
||||
if count is not None:
|
||||
messages = deserialize_messages(messages_str)
|
||||
return [int(count), messages]
|
||||
|
||||
messages: list[dict] = deserialize_messages(messages_str)
|
||||
return int(count), messages
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"[get_sessions_count] 查询失败: {e}")
|
||||
logger.error(f"[get_sessions_count] 查询失败: {e}")
|
||||
return False
|
||||
def update_sessions_count(self, end_user_id: str, new_count: int,
|
||||
messages: Any) -> bool:
|
||||
|
||||
def update_sessions_count(
|
||||
self,
|
||||
end_user_id: str,
|
||||
new_count: int,
|
||||
messages: Any
|
||||
) -> bool:
|
||||
"""
|
||||
通过 end_user_id 修改访问次数统计(优化版:使用索引)
|
||||
|
||||
@@ -378,39 +381,39 @@ class RedisCountStore:
|
||||
try:
|
||||
# 使用索引键快速查找
|
||||
index_key = f'session:count:index:{end_user_id}'
|
||||
|
||||
|
||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||
try:
|
||||
key_type = self.r.type(index_key)
|
||||
if key_type != 'string' and key_type != 'none':
|
||||
# 索引键类型错误,删除并返回 False
|
||||
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
||||
logger.warning(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
||||
self.r.delete(index_key)
|
||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
return False
|
||||
except Exception as type_error:
|
||||
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
logger.error(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
session_id = self.r.get(index_key)
|
||||
|
||||
|
||||
if not session_id:
|
||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
return False
|
||||
|
||||
|
||||
# 直接更新数据
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
messages_str = serialize_messages(messages)
|
||||
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, 'count', int(new_count))
|
||||
pipe.hset(key, 'count', str(new_count))
|
||||
pipe.hset(key, 'messages', messages_str)
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
||||
|
||||
logger.debug(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[update_sessions_count] 更新失败: {e}")
|
||||
logger.debug(f"[update_sessions_count] 更新失败: {e}")
|
||||
return False
|
||||
|
||||
def delete_all_count_sessions(self) -> int:
|
||||
@@ -428,7 +431,7 @@ class RedisCountStore:
|
||||
|
||||
class RedisSessionStore:
|
||||
"""Redis 会话存储类,用于管理会话数据"""
|
||||
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
@@ -451,9 +454,9 @@ class RedisSessionStore:
|
||||
self.uudi = session_id
|
||||
|
||||
# ==================== 写入操作 ====================
|
||||
|
||||
def save_session(self, userid: str, messages: str, aimessages: str,
|
||||
apply_id: str, end_user_id: str) -> str:
|
||||
|
||||
def save_session(self, userid: str, messages: str, aimessages: str,
|
||||
apply_id: str, end_user_id: str) -> str:
|
||||
"""
|
||||
写入一条会话数据,返回 session_id
|
||||
|
||||
@@ -483,14 +486,14 @@ class RedisSessionStore:
|
||||
})
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
logger.debug(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
print(f"[save_session] 保存会话失败: {e}")
|
||||
logger.error(f"[save_session] 保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
# ==================== 读取操作 ====================
|
||||
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
读取一条会话数据
|
||||
@@ -520,8 +523,8 @@ class RedisSessionStore:
|
||||
sessions[sid] = self.get_session(sid)
|
||||
return sessions
|
||||
|
||||
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
||||
end_user_id: str) -> List[Dict[str, str]]:
|
||||
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
||||
end_user_id: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条
|
||||
|
||||
@@ -535,10 +538,10 @@ class RedisSessionStore:
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
keys = self.r.keys('session:*')
|
||||
if not keys:
|
||||
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
logger.debug(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 批量获取数据
|
||||
@@ -556,21 +559,21 @@ class RedisSessionStore:
|
||||
continue
|
||||
|
||||
if (data.get('apply_id') == apply_id and
|
||||
data.get('end_user_id') == end_user_id):
|
||||
data.get('end_user_id') == end_user_id):
|
||||
# 支持模糊匹配或完全匹配 sessionid
|
||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||
matched_items.append(format_session_data(data, include_time=True))
|
||||
|
||||
|
||||
# 排序、限制数量并移除时间字段
|
||||
result_items = sort_and_limit_results(matched_items, limit=6)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
logger.debug(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
# ==================== 更新操作 ====================
|
||||
|
||||
|
||||
def update_session(self, session_id: str, field: str, value: Any) -> bool:
|
||||
"""
|
||||
更新单个字段
|
||||
@@ -591,7 +594,7 @@ class RedisSessionStore:
|
||||
return bool(results[0])
|
||||
|
||||
# ==================== 删除操作 ====================
|
||||
|
||||
|
||||
def delete_session(self, session_id: str) -> int:
|
||||
"""
|
||||
删除单条会话
|
||||
@@ -632,7 +635,7 @@ class RedisSessionStore:
|
||||
|
||||
keys = self.r.keys('session:*')
|
||||
if not keys:
|
||||
print("[delete_duplicate_sessions] 没有会话数据")
|
||||
logger.debug("[delete_duplicate_sessions] 没有会话数据")
|
||||
return 0
|
||||
|
||||
# 批量获取所有数据
|
||||
@@ -678,7 +681,7 @@ class RedisSessionStore:
|
||||
deleted_count += len(batch)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||
logger.debug(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||
return deleted_count
|
||||
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from dotenv import load_dotenv
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import _USER_PLACEHOLDER_NAMES
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \
|
||||
memory_summary_generation
|
||||
@@ -152,6 +153,24 @@ async def write(
|
||||
# Step 3: Save all data to Neo4j database
|
||||
step_start = time.time()
|
||||
|
||||
# Neo4j 写入前:清洗用户/AI助手实体之间的别名交叉污染
|
||||
# 从 Neo4j 查询已有的 AI 助手别名,与本轮实体中的 AI 助手别名合并,
|
||||
# 确保用户实体的 aliases 不包含 AI 助手的名字
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
|
||||
clean_cross_role_aliases,
|
||||
fetch_neo4j_assistant_aliases,
|
||||
)
|
||||
neo4j_assistant_aliases = set()
|
||||
if all_entity_nodes:
|
||||
_eu_id = all_entity_nodes[0].end_user_id
|
||||
if _eu_id:
|
||||
neo4j_assistant_aliases = await fetch_neo4j_assistant_aliases(neo4j_connector, _eu_id)
|
||||
clean_cross_role_aliases(all_entity_nodes, external_assistant_aliases=neo4j_assistant_aliases)
|
||||
logger.info(f"Neo4j 写入前别名清洗完成,AI助手别名排除集大小: {len(neo4j_assistant_aliases)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Neo4j 写入前别名清洗失败(不影响主流程): {e}")
|
||||
|
||||
# 添加死锁重试机制
|
||||
max_retries = 3
|
||||
retry_delay = 1 # 秒
|
||||
@@ -173,15 +192,37 @@ async def write(
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
|
||||
# 使用 Celery 异步任务触发聚类(不阻塞主流程)
|
||||
if all_entity_nodes:
|
||||
end_user_id = all_entity_nodes[0].end_user_id
|
||||
|
||||
# Neo4j 写入完成后,用 PgSQL 权威 aliases 覆盖 Neo4j 用户实体
|
||||
try:
|
||||
from app.repositories.end_user_info_repository import EndUserInfoRepository
|
||||
if end_user_id:
|
||||
with get_db_context() as db_session:
|
||||
info = EndUserInfoRepository(db_session).get_by_end_user_id(uuid.UUID(end_user_id))
|
||||
pg_aliases = info.aliases if info and info.aliases else []
|
||||
if info is not None:
|
||||
# 将 Python 侧占位名集合作为参数传入,避免 Cypher 硬编码
|
||||
placeholder_names = list(_USER_PLACEHOLDER_NAMES)
|
||||
await neo4j_connector.execute_query(
|
||||
"""
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.end_user_id = $end_user_id AND toLower(e.name) IN $placeholder_names
|
||||
SET e.aliases = $aliases
|
||||
""",
|
||||
end_user_id=end_user_id, aliases=pg_aliases,
|
||||
placeholder_names=placeholder_names,
|
||||
)
|
||||
logger.info(f"[AliasSync] Neo4j 用户实体 aliases 已用 PgSQL 权威源覆盖: {pg_aliases}")
|
||||
except Exception as sync_err:
|
||||
logger.warning(f"[AliasSync] PgSQL→Neo4j aliases 同步失败(不影响主流程): {sync_err}")
|
||||
|
||||
# 使用 Celery 异步任务触发聚类(不阻塞主流程)
|
||||
try:
|
||||
from app.tasks import run_incremental_clustering
|
||||
|
||||
end_user_id = all_entity_nodes[0].end_user_id
|
||||
new_entity_ids = [e.id for e in all_entity_nodes]
|
||||
|
||||
# 异步提交 Celery 任务
|
||||
task = run_incremental_clustering.apply_async(
|
||||
kwargs={
|
||||
"end_user_id": end_user_id,
|
||||
@@ -189,7 +230,6 @@ async def write(
|
||||
"llm_model_id": str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||
"embedding_model_id": str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||
},
|
||||
# 设置任务优先级(低优先级,不影响主业务)
|
||||
priority=3,
|
||||
)
|
||||
logger.info(
|
||||
@@ -197,7 +237,6 @@ async def write(
|
||||
f"task_id={task.id}, end_user_id={end_user_id}, entity_count={len(new_entity_ids)}"
|
||||
)
|
||||
except Exception as e:
|
||||
# 聚类任务提交失败不影响主流程
|
||||
logger.error(f"[Clustering] 提交聚类任务失败(不影响主流程): {e}", exc_info=True)
|
||||
|
||||
break
|
||||
|
||||
@@ -56,7 +56,7 @@ class LLMClient(ABC):
|
||||
self.max_retries = self.config.max_retries
|
||||
self.timeout = self.config.timeout
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"初始化 LLM 客户端: provider={self.provider}, "
|
||||
f"model={self.model_name}, max_retries={self.max_retries}"
|
||||
)
|
||||
|
||||
@@ -58,6 +58,14 @@ from app.core.memory.models.triplet_models import (
|
||||
TripletExtractionResponse,
|
||||
)
|
||||
|
||||
# User metadata models
|
||||
from app.core.memory.models.metadata_models import (
|
||||
UserMetadata,
|
||||
UserMetadataProfile,
|
||||
MetadataExtractionResponse,
|
||||
MetadataFieldChange,
|
||||
)
|
||||
|
||||
# Ontology scenario models (LLM extracted from scenarios)
|
||||
from app.core.memory.models.ontology_scenario_models import (
|
||||
OntologyClass,
|
||||
@@ -124,6 +132,10 @@ __all__ = [
|
||||
"Entity",
|
||||
"Triplet",
|
||||
"TripletExtractionResponse",
|
||||
"UserMetadata",
|
||||
"UserMetadataProfile",
|
||||
"MetadataExtractionResponse",
|
||||
"MetadataFieldChange",
|
||||
# Ontology models
|
||||
"OntologyClass",
|
||||
"OntologyExtractionResponse",
|
||||
|
||||
@@ -364,12 +364,14 @@ class ChunkNode(Node):
|
||||
Attributes:
|
||||
dialog_id: ID of the parent dialog
|
||||
content: The text content of the chunk
|
||||
speaker: Speaker identifier ('user' or 'assistant')
|
||||
chunk_embedding: Optional embedding vector for the chunk
|
||||
sequence_number: Order of this chunk within the dialog
|
||||
metadata: Additional chunk metadata as key-value pairs
|
||||
"""
|
||||
dialog_id: str = Field(..., description="ID of the parent dialog")
|
||||
content: str = Field(..., description="The text content of the chunk")
|
||||
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
|
||||
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
|
||||
sequence_number: int = Field(..., description="Order of this chunk within the dialog")
|
||||
metadata: dict = Field(default_factory=dict, description="Additional chunk metadata")
|
||||
|
||||
63
api/app/core/memory/models/metadata_models.py
Normal file
63
api/app/core/memory/models/metadata_models.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Models for user metadata extraction.
|
||||
|
||||
Independent from triplet_models.py - these models are used by the
|
||||
standalone metadata extraction pipeline (post-dedup async Celery task).
|
||||
"""
|
||||
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class UserMetadataProfile(BaseModel):
|
||||
"""用户画像信息"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
role: List[str] = Field(default_factory=list, description="用户职业或角色")
|
||||
domain: List[str] = Field(default_factory=list, description="用户所在领域")
|
||||
expertise: List[str] = Field(
|
||||
default_factory=list, description="用户擅长的技能或工具"
|
||||
)
|
||||
interests: List[str] = Field(
|
||||
default_factory=list, description="用户关注的话题或领域标签"
|
||||
)
|
||||
|
||||
|
||||
class UserMetadata(BaseModel):
|
||||
"""用户元数据顶层结构"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
profile: UserMetadataProfile = Field(default_factory=UserMetadataProfile)
|
||||
|
||||
|
||||
class MetadataFieldChange(BaseModel):
|
||||
"""单个元数据字段的变更操作"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
field_path: str = Field(
|
||||
description="字段路径,用点号分隔,如 'profile.role'、'profile.expertise'"
|
||||
)
|
||||
action: Literal["set", "remove"] = Field(
|
||||
description="操作类型:'set' 表示新增或修改,'remove' 表示移除"
|
||||
)
|
||||
value: Optional[str] = Field(
|
||||
default=None,
|
||||
description="字段的新值(action='set' 时必填)。标量字段直接填值,列表字段填单个要新增的元素"
|
||||
)
|
||||
|
||||
|
||||
class MetadataExtractionResponse(BaseModel):
|
||||
"""元数据提取 LLM 响应结构(增量模式)"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
metadata_changes: List[MetadataFieldChange] = Field(
|
||||
default_factory=list,
|
||||
description="元数据的增量变更列表,每项描述一个字段的新增、修改或移除操作",
|
||||
)
|
||||
aliases_to_add: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="本次新发现的用户别名(用户自我介绍或他人对用户的称呼)",
|
||||
)
|
||||
aliases_to_remove: List[str] = Field(
|
||||
default_factory=list, description="用户明确否认的别名(如'我不叫XX了')"
|
||||
)
|
||||
@@ -1,4 +1,3 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
@@ -6,7 +5,6 @@ import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
@@ -23,7 +21,7 @@ from app.core.memory.utils.config.config_utils import (
|
||||
)
|
||||
from app.core.memory.utils.data.text_utils import extract_plain_query
|
||||
from app.core.memory.utils.data.time_utils import normalize_date_safe
|
||||
from app.core.memory.utils.llm.llm_utils import get_reranker_client
|
||||
# from app.core.memory.utils.llm.llm_utils import get_reranker_client
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import (
|
||||
@@ -43,6 +41,7 @@ load_dotenv()
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
def _parse_datetime(value: Any) -> Optional[datetime]:
|
||||
"""Parse ISO `created_at` strings of the form 'YYYY-MM-DDTHH:MM:SS.ssssss'."""
|
||||
if value is None:
|
||||
@@ -75,7 +74,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
||||
if score_field == "activation_value" and score is None:
|
||||
scores.append(None) # 保持 None,稍后特殊处理
|
||||
continue
|
||||
|
||||
|
||||
if score is not None and isinstance(score, (int, float)):
|
||||
scores.append(float(score))
|
||||
else:
|
||||
@@ -83,10 +82,10 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
||||
|
||||
if not scores:
|
||||
return results
|
||||
|
||||
|
||||
# 过滤掉 None 值,只对有效分数进行归一化
|
||||
valid_scores = [s for s in scores if s is not None]
|
||||
|
||||
|
||||
if not valid_scores:
|
||||
# 所有分数都是 None,不进行归一化
|
||||
for item in results:
|
||||
@@ -94,7 +93,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
||||
item[f"normalized_{score_field}"] = None
|
||||
return results
|
||||
|
||||
if len(valid_scores) == 1: # Single valid score, set to 1.0
|
||||
if len(valid_scores) == 1: # Single valid score, set to 1.0
|
||||
for item, score in zip(results, scores):
|
||||
if score_field in item or score_field == "activation_value":
|
||||
if score is None:
|
||||
@@ -132,7 +131,6 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
||||
return results
|
||||
|
||||
|
||||
|
||||
def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Remove duplicate items from search results based on content.
|
||||
@@ -150,52 +148,53 @@ def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
seen_ids = set()
|
||||
seen_content = set()
|
||||
deduplicated = []
|
||||
|
||||
|
||||
for item in items:
|
||||
# Try multiple ID fields to identify unique items
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
|
||||
|
||||
# Extract content from various possible fields
|
||||
content = (
|
||||
item.get("text") or
|
||||
item.get("content") or
|
||||
item.get("statement") or
|
||||
item.get("name") or
|
||||
""
|
||||
item.get("text") or
|
||||
item.get("content") or
|
||||
item.get("statement") or
|
||||
item.get("name") or
|
||||
""
|
||||
)
|
||||
|
||||
|
||||
# Normalize content for comparison (strip whitespace and lowercase)
|
||||
normalized_content = str(content).strip().lower() if content else ""
|
||||
|
||||
|
||||
# Check if we've seen this ID or content before
|
||||
is_duplicate = False
|
||||
|
||||
|
||||
if item_id and item_id in seen_ids:
|
||||
is_duplicate = True
|
||||
elif normalized_content and normalized_content in seen_content:
|
||||
# Only check content duplication if content is not empty
|
||||
is_duplicate = True
|
||||
|
||||
|
||||
if not is_duplicate:
|
||||
# Mark as seen
|
||||
if item_id:
|
||||
seen_ids.add(item_id)
|
||||
if normalized_content: # Only track non-empty content
|
||||
seen_content.add(normalized_content)
|
||||
|
||||
|
||||
deduplicated.append(item)
|
||||
|
||||
|
||||
return deduplicated
|
||||
|
||||
|
||||
def rerank_with_activation(
|
||||
keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||
embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||
alpha: float = 0.6,
|
||||
limit: int = 10,
|
||||
forgetting_config: ForgettingEngineConfig | None = None,
|
||||
activation_boost_factor: float = 0.8,
|
||||
now: datetime | None = None,
|
||||
keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||
embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||
alpha: float = 0.6,
|
||||
limit: int = 10,
|
||||
forgetting_config: ForgettingEngineConfig | None = None,
|
||||
activation_boost_factor: float = 0.8,
|
||||
now: datetime | None = None,
|
||||
content_score_threshold: float = 0.5,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
两阶段排序:先按内容相关性筛选,再按激活值排序。
|
||||
@@ -222,6 +221,8 @@ def rerank_with_activation(
|
||||
forgetting_config: 遗忘引擎配置(当前未使用)
|
||||
activation_boost_factor: 激活度对记忆强度的影响系数 (默认: 0.8)
|
||||
now: 当前时间(用于遗忘计算)
|
||||
content_score_threshold: 内容相关性最低阈值(基于归一化后的 content_score),
|
||||
低于此阈值的结果会被过滤。默认 0.5。
|
||||
|
||||
返回:
|
||||
带评分元数据的重排序结果,按 final_score 排序
|
||||
@@ -229,26 +230,26 @@ def rerank_with_activation(
|
||||
# 验证权重范围
|
||||
if not (0 <= alpha <= 1):
|
||||
raise ValueError(f"alpha 必须在 [0, 1] 范围内,当前值: {alpha}")
|
||||
|
||||
|
||||
# 初始化遗忘引擎(如果需要)
|
||||
engine = None
|
||||
if forgetting_config:
|
||||
engine = ForgettingEngine(forgetting_config)
|
||||
now_dt = now or datetime.now()
|
||||
|
||||
|
||||
reranked: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
|
||||
for category in ["statements", "chunks", "entities", "summaries", "communities"]:
|
||||
keyword_items = keyword_results.get(category, [])
|
||||
embedding_items = embedding_results.get(category, [])
|
||||
|
||||
|
||||
# 步骤 1: 归一化分数
|
||||
keyword_items = normalize_scores(keyword_items, "score")
|
||||
embedding_items = normalize_scores(embedding_items, "score")
|
||||
|
||||
|
||||
# 步骤 2: 按 ID 合并结果(去重)
|
||||
combined_items: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
# 添加关键词结果
|
||||
for item in keyword_items:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
@@ -257,7 +258,7 @@ def rerank_with_activation(
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
combined_items[item_id]["embedding_score"] = 0 # 默认值
|
||||
|
||||
|
||||
# 添加或更新向量嵌入结果
|
||||
for item in embedding_items:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
@@ -271,18 +272,18 @@ def rerank_with_activation(
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = 0 # 默认值
|
||||
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
|
||||
|
||||
# 步骤 3: 归一化激活度分数
|
||||
# 为所有项准备激活度值列表
|
||||
items_list = list(combined_items.values())
|
||||
items_list = normalize_scores(items_list, "activation_value")
|
||||
|
||||
|
||||
# 更新 combined_items 中的归一化激活度分数
|
||||
for item in items_list:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if item_id and item_id in combined_items:
|
||||
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value")
|
||||
|
||||
|
||||
# 步骤 4: 计算基础分数和最终分数
|
||||
for item_id, item in combined_items.items():
|
||||
bm25_norm = float(item.get("bm25_score", 0) or 0)
|
||||
@@ -290,45 +291,45 @@ def rerank_with_activation(
|
||||
# normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义
|
||||
raw_act_norm = item.get("normalized_activation_value")
|
||||
act_norm = float(raw_act_norm) if raw_act_norm is not None else None
|
||||
|
||||
|
||||
# 第一阶段:只考虑内容相关性(BM25 + Embedding)
|
||||
# alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重
|
||||
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
|
||||
base_score = content_score # 第一阶段用内容分数
|
||||
|
||||
|
||||
# 存储激活度分数供第二阶段使用(None 表示无激活值,不参与激活值排序)
|
||||
item["activation_score"] = act_norm # 可能为 None
|
||||
item["content_score"] = content_score
|
||||
item["base_score"] = base_score
|
||||
|
||||
|
||||
# 步骤 5: 应用遗忘曲线(可选)
|
||||
if engine:
|
||||
# 计算受激活度影响的记忆强度
|
||||
importance = float(item.get("importance_score", 0.5) or 0.5)
|
||||
|
||||
|
||||
# 获取 activation_value
|
||||
activation_val = item.get("activation_value")
|
||||
|
||||
|
||||
# 只对有激活值的节点应用遗忘曲线
|
||||
if activation_val is not None and isinstance(activation_val, (int, float)):
|
||||
activation_val = float(activation_val)
|
||||
|
||||
|
||||
# 计算记忆强度:importance_score × (1 + activation_value × boost_factor)
|
||||
memory_strength = importance * (1 + activation_val * activation_boost_factor)
|
||||
|
||||
|
||||
# 计算经过的时间(天数)
|
||||
dt = _parse_datetime(item.get("created_at"))
|
||||
if dt is None:
|
||||
time_elapsed_days = 0.0
|
||||
else:
|
||||
time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
|
||||
|
||||
|
||||
# 获取遗忘权重
|
||||
forgetting_weight = engine.calculate_weight(
|
||||
time_elapsed=time_elapsed_days,
|
||||
memory_strength=memory_strength
|
||||
)
|
||||
|
||||
|
||||
# 应用到基础分数
|
||||
item["forgetting_weight"] = forgetting_weight
|
||||
item["final_score"] = base_score * forgetting_weight
|
||||
@@ -338,7 +339,7 @@ def rerank_with_activation(
|
||||
else:
|
||||
# 不使用遗忘曲线
|
||||
item["final_score"] = base_score
|
||||
|
||||
|
||||
# 步骤 6: 两阶段排序和限制
|
||||
# 第一阶段:按内容相关性(base_score)排序,取 Top-K
|
||||
first_stage_limit = limit * 3 # 可配置,取3倍候选
|
||||
@@ -347,11 +348,11 @@ def rerank_with_activation(
|
||||
key=lambda x: float(x.get("base_score", 0) or 0), # 按内容分数排序
|
||||
reverse=True
|
||||
)[:first_stage_limit]
|
||||
|
||||
|
||||
# 第二阶段:分离有激活值和无激活值的节点
|
||||
items_with_activation = []
|
||||
items_without_activation = []
|
||||
|
||||
|
||||
for item in first_stage_sorted:
|
||||
activation_score = item.get("activation_score")
|
||||
# 检查是否有有效的激活值(不是 None)
|
||||
@@ -359,14 +360,14 @@ def rerank_with_activation(
|
||||
items_with_activation.append(item)
|
||||
else:
|
||||
items_without_activation.append(item)
|
||||
|
||||
|
||||
# 优先按激活值排序有激活值的节点
|
||||
sorted_with_activation = sorted(
|
||||
items_with_activation,
|
||||
key=lambda x: float(x.get("activation_score", 0) or 0),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
|
||||
# 如果有激活值的节点不足 limit,用无激活值的节点补充
|
||||
if len(sorted_with_activation) < limit:
|
||||
needed = limit - len(sorted_with_activation)
|
||||
@@ -374,7 +375,7 @@ def rerank_with_activation(
|
||||
sorted_items = sorted_with_activation + items_without_activation[:needed]
|
||||
else:
|
||||
sorted_items = sorted_with_activation[:limit]
|
||||
|
||||
|
||||
# 两阶段排序完成,更新 final_score 以反映实际排序依据
|
||||
# Stage 1: 按 content_score 筛选候选(已完成)
|
||||
# Stage 2: 按 activation_score 排序(已完成)
|
||||
@@ -390,16 +391,29 @@ def rerank_with_activation(
|
||||
else:
|
||||
# 无激活值:使用内容相关性分数
|
||||
item["final_score"] = item.get("base_score", 0)
|
||||
|
||||
# 最终去重确保没有重复项
|
||||
|
||||
if content_score_threshold > 0:
|
||||
before_count = len(sorted_items)
|
||||
sorted_items = [
|
||||
item for item in sorted_items
|
||||
if float(item.get("content_score", 0) or 0) >= content_score_threshold
|
||||
]
|
||||
filtered_count = before_count - len(sorted_items)
|
||||
if filtered_count > 0:
|
||||
logger.info(
|
||||
f"[rerank] {category}: filtered {filtered_count}/{before_count} "
|
||||
f"items below content_score_threshold={content_score_threshold}"
|
||||
)
|
||||
|
||||
sorted_items = _deduplicate_results(sorted_items)
|
||||
|
||||
|
||||
reranked[category] = sorted_items
|
||||
|
||||
|
||||
return reranked
|
||||
|
||||
|
||||
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None):
|
||||
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str],
|
||||
log_file: str = None):
|
||||
"""Log search query information using the logger.
|
||||
|
||||
Args:
|
||||
@@ -412,7 +426,7 @@ def log_search_query(query_text: str, search_type: str, end_user_id: str | None,
|
||||
"""
|
||||
# Ensure the query text is plain and clean before logging
|
||||
cleaned_query = extract_plain_query(query_text)
|
||||
|
||||
|
||||
# Log using the standard logger
|
||||
logger.info(
|
||||
f"Search query: query='{cleaned_query}', type={search_type}, "
|
||||
@@ -439,8 +453,8 @@ def _remove_keys_recursive(obj: Any, keys_to_remove: List[str]) -> Any:
|
||||
|
||||
|
||||
def apply_reranker_placeholder(
|
||||
results: Dict[str, List[Dict[str, Any]]],
|
||||
query_text: str,
|
||||
results: Dict[str, List[Dict[str, Any]]],
|
||||
query_text: str,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Placeholder for a cross-encoder reranker.
|
||||
@@ -483,7 +497,7 @@ def apply_reranker_placeholder(
|
||||
# ) -> Dict[str, List[Dict[str, Any]]]:
|
||||
# """
|
||||
# Apply LLM-based reranking to search results.
|
||||
|
||||
|
||||
# Args:
|
||||
# results: Search results organized by category
|
||||
# query_text: Original search query
|
||||
@@ -491,7 +505,7 @@ def apply_reranker_placeholder(
|
||||
# llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM)
|
||||
# top_k: Maximum number of items to rerank per category
|
||||
# batch_size: Number of items to process concurrently
|
||||
|
||||
|
||||
# Returns:
|
||||
# Reranked results with final_score and reranker_model fields
|
||||
# """
|
||||
@@ -501,18 +515,18 @@ def apply_reranker_placeholder(
|
||||
# # except Exception as e:
|
||||
# # logger.debug(f"Failed to load reranker config: {e}")
|
||||
# # rc = {}
|
||||
|
||||
|
||||
# # Check if reranking is enabled
|
||||
# enabled = rc.get("enabled", False)
|
||||
# if not enabled:
|
||||
# logger.debug("LLM reranking is disabled in configuration")
|
||||
# return results
|
||||
|
||||
|
||||
# # Load configuration parameters with defaults
|
||||
# llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5)
|
||||
# top_k = top_k if top_k is not None else rc.get("top_k", 20)
|
||||
# batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5)
|
||||
|
||||
|
||||
# # Initialize reranker client if not provided
|
||||
# if reranker_client is None:
|
||||
# try:
|
||||
@@ -520,10 +534,10 @@ def apply_reranker_placeholder(
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking")
|
||||
# return results
|
||||
|
||||
|
||||
# # Get model name for metadata
|
||||
# model_name = getattr(reranker_client, 'model_name', 'unknown')
|
||||
|
||||
|
||||
# # Process each category
|
||||
# reranked_results = {}
|
||||
# for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
@@ -531,38 +545,38 @@ def apply_reranker_placeholder(
|
||||
# if not items:
|
||||
# reranked_results[category] = []
|
||||
# continue
|
||||
|
||||
|
||||
# # Select top K items by combined_score for reranking
|
||||
# sorted_items = sorted(
|
||||
# items,
|
||||
# key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0),
|
||||
# reverse=True
|
||||
# )
|
||||
|
||||
|
||||
# top_items = sorted_items[:top_k]
|
||||
# remaining_items = sorted_items[top_k:]
|
||||
|
||||
|
||||
# # Extract text content from each item
|
||||
# def extract_text(item: Dict[str, Any]) -> str:
|
||||
# """Extract text content from a result item."""
|
||||
# # Try different text fields based on category
|
||||
# text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or ""
|
||||
# return str(text).strip()
|
||||
|
||||
|
||||
# # Batch items for concurrent processing
|
||||
# batches = []
|
||||
# for i in range(0, len(top_items), batch_size):
|
||||
# batch = top_items[i:i + batch_size]
|
||||
# batches.append(batch)
|
||||
|
||||
|
||||
# # Process batches concurrently
|
||||
# async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
# """Process a batch of items with LLM relevance scoring."""
|
||||
# scored_batch = []
|
||||
|
||||
|
||||
# for item in batch:
|
||||
# item_text = extract_text(item)
|
||||
|
||||
|
||||
# # Skip items with no text
|
||||
# if not item_text:
|
||||
# item_copy = item.copy()
|
||||
@@ -572,7 +586,7 @@ def apply_reranker_placeholder(
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_batch.append(item_copy)
|
||||
# continue
|
||||
|
||||
|
||||
# # Create relevance scoring prompt
|
||||
# prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0.
|
||||
|
||||
@@ -585,15 +599,15 @@ def apply_reranker_placeholder(
|
||||
# - 1.0 means perfectly relevant
|
||||
|
||||
# Relevance score:"""
|
||||
|
||||
|
||||
# # Send request to LLM
|
||||
# try:
|
||||
# messages = [{"role": "user", "content": prompt}]
|
||||
# response = await reranker_client.chat(messages)
|
||||
|
||||
|
||||
# # Parse LLM response to extract relevance score
|
||||
# response_text = str(response.content if hasattr(response, 'content') else response).strip()
|
||||
|
||||
|
||||
# # Try to extract a float from the response
|
||||
# try:
|
||||
# # Remove any non-numeric characters except decimal point
|
||||
@@ -608,11 +622,11 @@ def apply_reranker_placeholder(
|
||||
# except (ValueError, AttributeError) as e:
|
||||
# logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}")
|
||||
# llm_score = None
|
||||
|
||||
|
||||
# # Calculate final score
|
||||
# item_copy = item.copy()
|
||||
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
|
||||
|
||||
# if llm_score is not None:
|
||||
# final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score
|
||||
# item_copy["llm_relevance_score"] = llm_score
|
||||
@@ -620,7 +634,7 @@ def apply_reranker_placeholder(
|
||||
# # Use combined_score as fallback
|
||||
# final_score = combined_score
|
||||
# item_copy["llm_relevance_score"] = combined_score
|
||||
|
||||
|
||||
# item_copy["final_score"] = final_score
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_batch.append(item_copy)
|
||||
@@ -632,14 +646,14 @@ def apply_reranker_placeholder(
|
||||
# item_copy["llm_relevance_score"] = combined_score
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_batch.append(item_copy)
|
||||
|
||||
|
||||
# return scored_batch
|
||||
|
||||
|
||||
# # Process all batches concurrently
|
||||
# try:
|
||||
# batch_tasks = [process_batch(batch) for batch in batches]
|
||||
# batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# # Merge batch results
|
||||
# scored_items = []
|
||||
# for result in batch_results:
|
||||
@@ -647,7 +661,7 @@ def apply_reranker_placeholder(
|
||||
# logger.warning(f"Batch processing failed: {result}")
|
||||
# continue
|
||||
# scored_items.extend(result)
|
||||
|
||||
|
||||
# # Add remaining items (not in top K) with their combined_score as final_score
|
||||
# for item in remaining_items:
|
||||
# item_copy = item.copy()
|
||||
@@ -655,11 +669,11 @@ def apply_reranker_placeholder(
|
||||
# item_copy["final_score"] = combined_score
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_items.append(item_copy)
|
||||
|
||||
|
||||
# # Sort all items by final_score in descending order
|
||||
# scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True)
|
||||
# reranked_results[category] = scored_items
|
||||
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results")
|
||||
# # Return original items with combined_score as final_score
|
||||
@@ -668,22 +682,22 @@ def apply_reranker_placeholder(
|
||||
# item["final_score"] = combined_score
|
||||
# item["reranker_model"] = model_name
|
||||
# reranked_results[category] = items
|
||||
|
||||
|
||||
# return reranked_results
|
||||
|
||||
|
||||
async def run_hybrid_search(
|
||||
query_text: str,
|
||||
search_type: str,
|
||||
end_user_id: str | None,
|
||||
limit: int,
|
||||
include: List[str],
|
||||
output_path: str | None,
|
||||
memory_config: "MemoryConfig",
|
||||
rerank_alpha: float = 0.6,
|
||||
activation_boost_factor: float = 0.8,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
query_text: str,
|
||||
search_type: str,
|
||||
end_user_id: str | None,
|
||||
limit: int,
|
||||
include: List[str],
|
||||
output_path: str | None,
|
||||
memory_config: "MemoryConfig",
|
||||
rerank_alpha: float = 0.6,
|
||||
activation_boost_factor: float = 0.8,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -699,7 +713,7 @@ async def run_hybrid_search(
|
||||
|
||||
# Clean and normalize the incoming query before use/logging
|
||||
query_text = extract_plain_query(query_text)
|
||||
|
||||
|
||||
# Validate query is not empty after cleaning
|
||||
if not query_text or not query_text.strip():
|
||||
logger.warning("Empty query after cleaning, returning empty results")
|
||||
@@ -716,7 +730,7 @@ async def run_hybrid_search(
|
||||
"error": "Empty query"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Log the search query
|
||||
log_search_query(query_text, search_type, end_user_id, limit, include)
|
||||
|
||||
@@ -732,11 +746,10 @@ async def run_hybrid_search(
|
||||
if search_type in ["keyword", "hybrid"]:
|
||||
# Keyword-based search
|
||||
logger.info("[PERF] Starting keyword search...")
|
||||
keyword_start = time.time()
|
||||
keyword_task = asyncio.create_task(
|
||||
search_graph(
|
||||
connector=connector,
|
||||
q=query_text,
|
||||
query=query_text,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include
|
||||
@@ -746,8 +759,7 @@ async def run_hybrid_search(
|
||||
if search_type in ["embedding", "hybrid"]:
|
||||
# Embedding-based search
|
||||
logger.info("[PERF] Starting embedding search...")
|
||||
embedding_start = time.time()
|
||||
|
||||
|
||||
# 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig
|
||||
config_load_start = time.time()
|
||||
try:
|
||||
@@ -758,8 +770,7 @@ async def run_hybrid_search(
|
||||
model_name=embedder_config_dict["model_name"],
|
||||
provider=embedder_config_dict["provider"],
|
||||
api_key=embedder_config_dict["api_key"],
|
||||
base_url=embedder_config_dict["base_url"],
|
||||
type="llm"
|
||||
base_url=embedder_config_dict["base_url"]
|
||||
)
|
||||
config_load_time = time.time() - config_load_start
|
||||
logger.info(f"[PERF] Config loading took {config_load_time:.4f}s")
|
||||
@@ -769,7 +780,7 @@ async def run_hybrid_search(
|
||||
embedder = OpenAIEmbedderClient(model_config=rb_config)
|
||||
embedder_init_time = time.time() - embedder_init_start
|
||||
logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s")
|
||||
|
||||
|
||||
embedding_task = asyncio.create_task(
|
||||
search_graph_by_embedding(
|
||||
connector=connector,
|
||||
@@ -789,7 +800,7 @@ async def run_hybrid_search(
|
||||
|
||||
if keyword_task:
|
||||
keyword_results = await keyword_task
|
||||
keyword_latency = time.time() - keyword_start
|
||||
keyword_latency = time.time() - search_start_time
|
||||
latency_metrics["keyword_search_latency"] = round(keyword_latency, 4)
|
||||
logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s")
|
||||
if search_type == "keyword":
|
||||
@@ -799,7 +810,7 @@ async def run_hybrid_search(
|
||||
|
||||
if embedding_task:
|
||||
embedding_results = await embedding_task
|
||||
embedding_latency = time.time() - embedding_start
|
||||
embedding_latency = time.time() - search_start_time
|
||||
latency_metrics["embedding_search_latency"] = round(embedding_latency, 4)
|
||||
logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s")
|
||||
if search_type == "embedding":
|
||||
@@ -811,7 +822,8 @@ async def run_hybrid_search(
|
||||
if search_type == "hybrid":
|
||||
results["combined_summary"] = {
|
||||
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
|
||||
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
|
||||
"total_embedding_results": sum(
|
||||
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
|
||||
"search_query": query_text,
|
||||
"search_timestamp": datetime.now().isoformat()
|
||||
}
|
||||
@@ -819,7 +831,7 @@ async def run_hybrid_search(
|
||||
# Apply two-stage reranking with ACTR activation calculation
|
||||
rerank_start = time.time()
|
||||
logger.info("[PERF] Using two-stage reranking with ACTR activation")
|
||||
|
||||
|
||||
# 加载遗忘引擎配置
|
||||
config_start = time.time()
|
||||
try:
|
||||
@@ -830,7 +842,7 @@ async def run_hybrid_search(
|
||||
forgetting_cfg = ForgettingEngineConfig()
|
||||
config_time = time.time() - config_start
|
||||
logger.info(f"[PERF] Forgetting config loading took {config_time:.4f}s")
|
||||
|
||||
|
||||
# 统一使用激活度重排序(两阶段:检索 + ACTR计算)
|
||||
rerank_compute_start = time.time()
|
||||
reranked_results = rerank_with_activation(
|
||||
@@ -843,14 +855,14 @@ async def run_hybrid_search(
|
||||
)
|
||||
rerank_compute_time = time.time() - rerank_compute_start
|
||||
logger.info(f"[PERF] Rerank computation took {rerank_compute_time:.4f}s")
|
||||
|
||||
|
||||
rerank_latency = time.time() - rerank_start
|
||||
latency_metrics["reranking_latency"] = round(rerank_latency, 4)
|
||||
logger.info(f"[PERF] Total reranking completed in {rerank_latency:.4f}s")
|
||||
|
||||
|
||||
# Optional: apply reranker placeholder if enabled via config
|
||||
reranked_results = apply_reranker_placeholder(reranked_results, query_text)
|
||||
|
||||
|
||||
# Apply LLM reranking if enabled
|
||||
llm_rerank_applied = False
|
||||
# if use_llm_rerank:
|
||||
@@ -863,11 +875,12 @@ async def run_hybrid_search(
|
||||
# logger.info("LLM reranking applied successfully")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"LLM reranking failed: {e}, using previous scores")
|
||||
|
||||
|
||||
results["reranked_results"] = reranked_results
|
||||
results["combined_summary"] = {
|
||||
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
|
||||
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
|
||||
"total_embedding_results": sum(
|
||||
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
|
||||
"total_reranked_results": sum(len(v) if isinstance(v, list) else 0 for v in reranked_results.values()),
|
||||
"search_query": query_text,
|
||||
"search_timestamp": datetime.now().isoformat(),
|
||||
@@ -880,17 +893,17 @@ async def run_hybrid_search(
|
||||
# Calculate total latency
|
||||
total_latency = time.time() - search_start_time
|
||||
latency_metrics["total_latency"] = round(total_latency, 4)
|
||||
|
||||
|
||||
# Add latency metrics to results
|
||||
if "combined_summary" in results:
|
||||
results["combined_summary"]["latency_metrics"] = latency_metrics
|
||||
else:
|
||||
results["latency_metrics"] = latency_metrics
|
||||
|
||||
logger.info(f"[PERF] ===== SEARCH PERFORMANCE SUMMARY =====")
|
||||
|
||||
logger.info("[PERF] ===== SEARCH PERFORMANCE SUMMARY =====")
|
||||
logger.info(f"[PERF] Total search completed in {total_latency:.4f}s")
|
||||
logger.info(f"[PERF] Latency breakdown: {json.dumps(latency_metrics, indent=2)}")
|
||||
logger.info(f"[PERF] =========================================")
|
||||
logger.info("[PERF] =========================================")
|
||||
|
||||
# Sanitize results: drop large/unused fields
|
||||
_remove_keys_recursive(results, ["name_embedding"]) # drop entity name embeddings from outputs
|
||||
@@ -909,8 +922,10 @@ async def run_hybrid_search(
|
||||
# Log search completion with result count
|
||||
if search_type == "hybrid":
|
||||
result_counts = {
|
||||
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in keyword_results.items()},
|
||||
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in embedding_results.items()}
|
||||
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in
|
||||
keyword_results.items()},
|
||||
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in
|
||||
embedding_results.items()}
|
||||
}
|
||||
else:
|
||||
result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()}
|
||||
@@ -928,12 +943,12 @@ async def run_hybrid_search(
|
||||
|
||||
|
||||
async def search_by_temporal(
|
||||
end_user_id: Optional[str] = "test",
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
end_user_id: Optional[str] = "test",
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
):
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -969,13 +984,13 @@ async def search_by_temporal(
|
||||
|
||||
|
||||
async def search_by_keyword_temporal(
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = "test",
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = "test",
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
):
|
||||
"""
|
||||
Temporal keyword search across Statements.
|
||||
@@ -1012,9 +1027,9 @@ async def search_by_keyword_temporal(
|
||||
|
||||
|
||||
async def search_chunk_by_chunk_id(
|
||||
chunk_id: str,
|
||||
end_user_id: Optional[str] = "test",
|
||||
limit: int = 1,
|
||||
chunk_id: str,
|
||||
end_user_id: Optional[str] = "test",
|
||||
limit: int = 1,
|
||||
):
|
||||
"""
|
||||
Search for Chunks by chunk_id.
|
||||
@@ -1027,4 +1042,3 @@ async def search_chunk_by_chunk_id(
|
||||
limit=limit
|
||||
)
|
||||
return {"chunks": chunks}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import asyncio
|
||||
import difflib # 提供字符串相似度计算工具
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
@@ -16,6 +17,8 @@ from app.core.memory.models.graph_models import (
|
||||
)
|
||||
from app.core.memory.models.variate_config import DedupConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 模块级类型统一工具函数
|
||||
def _unify_entity_type(canonical: ExtractedEntityNode, losing: ExtractedEntityNode, suggested_type: str = None) -> None:
|
||||
@@ -79,51 +82,38 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
|
||||
canonical.connect_strength = next(iter(pair))
|
||||
|
||||
# 别名合并(去重保序,使用标准化工具)
|
||||
# 用户实体的 aliases 由 PgSQL end_user_info 作为唯一权威源,去重合并时不修改
|
||||
try:
|
||||
canonical_name = (getattr(canonical, "name", "") or "").strip()
|
||||
incoming_name = (getattr(ent, "name", "") or "").strip()
|
||||
|
||||
# 收集所有需要合并的别名
|
||||
all_aliases = []
|
||||
|
||||
# 1. 添加canonical现有的别名
|
||||
existing = getattr(canonical, "aliases", []) or []
|
||||
all_aliases.extend(existing)
|
||||
|
||||
# 2. 添加incoming实体的名称(如果不同于canonical的名称)
|
||||
if incoming_name and incoming_name != canonical_name:
|
||||
all_aliases.append(incoming_name)
|
||||
|
||||
# 3. 添加incoming实体的所有别名
|
||||
incoming = getattr(ent, "aliases", []) or []
|
||||
all_aliases.extend(incoming)
|
||||
|
||||
# 4. 标准化并去重(优先使用alias_utils工具函数)
|
||||
try:
|
||||
from app.core.memory.utils.alias_utils import normalize_aliases
|
||||
canonical.aliases = normalize_aliases(canonical_name, all_aliases)
|
||||
except Exception:
|
||||
# 如果导入失败,使用增强的去重逻辑
|
||||
seen_normalized = set()
|
||||
unique_aliases = []
|
||||
if canonical_name.lower() not in _USER_PLACEHOLDER_NAMES:
|
||||
incoming_name = (getattr(ent, "name", "") or "").strip()
|
||||
|
||||
for alias in all_aliases:
|
||||
if not alias:
|
||||
continue
|
||||
|
||||
alias_stripped = str(alias).strip()
|
||||
if not alias_stripped or alias_stripped == canonical_name:
|
||||
continue
|
||||
|
||||
# 标准化:转小写用于去重判断
|
||||
alias_normalized = alias_stripped.lower()
|
||||
|
||||
if alias_normalized not in seen_normalized:
|
||||
seen_normalized.add(alias_normalized)
|
||||
unique_aliases.append(alias_stripped)
|
||||
# 收集所有需要合并的别名,过滤掉用户占位名避免污染非用户实体
|
||||
all_aliases = list(getattr(canonical, "aliases", []) or [])
|
||||
if incoming_name and incoming_name != canonical_name and incoming_name.lower() not in _USER_PLACEHOLDER_NAMES:
|
||||
all_aliases.append(incoming_name)
|
||||
all_aliases.extend(
|
||||
a for a in (getattr(ent, "aliases", []) or [])
|
||||
if a and a.strip().lower() not in _USER_PLACEHOLDER_NAMES
|
||||
)
|
||||
|
||||
# 排序并赋值
|
||||
canonical.aliases = sorted(unique_aliases)
|
||||
try:
|
||||
from app.core.memory.utils.alias_utils import normalize_aliases
|
||||
canonical.aliases = normalize_aliases(canonical_name, all_aliases)
|
||||
except Exception:
|
||||
seen_normalized = set()
|
||||
unique_aliases = []
|
||||
for alias in all_aliases:
|
||||
if not alias:
|
||||
continue
|
||||
alias_stripped = str(alias).strip()
|
||||
if not alias_stripped or alias_stripped == canonical_name:
|
||||
continue
|
||||
alias_normalized = alias_stripped.lower()
|
||||
if alias_normalized not in seen_normalized:
|
||||
seen_normalized.add(alias_normalized)
|
||||
unique_aliases.append(alias_stripped)
|
||||
canonical.aliases = sorted(unique_aliases)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -198,6 +188,161 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 用户和AI助手的占位名称集合(用于名称标准化)
|
||||
_USER_PLACEHOLDER_NAMES = {"用户", "我", "user", "i"}
|
||||
_ASSISTANT_PLACEHOLDER_NAMES = {"ai助手", "助手", "人工智能助手", "智能助手", "智能体", "ai assistant", "assistant"}
|
||||
|
||||
# 标准化后的规范名称和类型
|
||||
_CANONICAL_USER_NAME = "用户"
|
||||
_CANONICAL_USER_TYPE = "用户"
|
||||
_CANONICAL_ASSISTANT_NAME = "AI助手"
|
||||
_CANONICAL_ASSISTANT_TYPE = "Agent"
|
||||
|
||||
# 用户和AI助手的所有可能名称(用于判断实体是否为特殊角色实体)
|
||||
_ALL_USER_NAMES = _USER_PLACEHOLDER_NAMES
|
||||
_ALL_ASSISTANT_NAMES = _ASSISTANT_PLACEHOLDER_NAMES
|
||||
|
||||
|
||||
def _is_user_entity(ent: ExtractedEntityNode) -> bool:
|
||||
"""判断实体是否为用户实体(name 或 entity_type 匹配)"""
|
||||
name = (getattr(ent, "name", "") or "").strip().lower()
|
||||
etype = (getattr(ent, "entity_type", "") or "").strip()
|
||||
return name in _ALL_USER_NAMES or etype == _CANONICAL_USER_TYPE
|
||||
|
||||
|
||||
def _is_assistant_entity(ent: ExtractedEntityNode) -> bool:
|
||||
"""判断实体是否为AI助手实体(name 或 entity_type 匹配)"""
|
||||
name = (getattr(ent, "name", "") or "").strip().lower()
|
||||
etype = (getattr(ent, "entity_type", "") or "").strip()
|
||||
return name in _ALL_ASSISTANT_NAMES or etype == _CANONICAL_ASSISTANT_TYPE
|
||||
|
||||
|
||||
def _would_merge_cross_role(a: ExtractedEntityNode, b: ExtractedEntityNode) -> bool:
|
||||
"""判断两个实体的合并是否会跨越用户/AI助手角色边界。
|
||||
|
||||
用户实体和AI助手实体永远不应该被合并在一起。
|
||||
如果一方是用户实体、另一方是AI助手实体,返回 True(阻止合并)。
|
||||
"""
|
||||
return (
|
||||
(_is_user_entity(a) and _is_assistant_entity(b))
|
||||
or (_is_assistant_entity(a) and _is_user_entity(b))
|
||||
)
|
||||
|
||||
|
||||
def _normalize_special_entity_names(
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
) -> None:
|
||||
"""标准化用户和AI助手实体的名称和类型。
|
||||
|
||||
多轮对话中,LLM 对同一角色可能使用不同的名称变体(如"用户"/"我"/"User",
|
||||
"AI助手"/"助手"/"Assistant"),导致精确匹配无法合并。
|
||||
此函数在去重前将这些变体统一为规范名称,并强制绑定 entity_type,确保:
|
||||
- name="用户" 的实体 entity_type 一定为 "用户"
|
||||
- name="AI助手" 的实体 entity_type 一定为 "Agent"
|
||||
|
||||
Args:
|
||||
entity_nodes: 实体节点列表(原地修改)
|
||||
"""
|
||||
for ent in entity_nodes:
|
||||
name = (getattr(ent, "name", "") or "").strip()
|
||||
name_lower = name.lower()
|
||||
|
||||
if name_lower in _USER_PLACEHOLDER_NAMES:
|
||||
ent.name = _CANONICAL_USER_NAME
|
||||
ent.entity_type = _CANONICAL_USER_TYPE
|
||||
elif name_lower in _ASSISTANT_PLACEHOLDER_NAMES:
|
||||
ent.name = _CANONICAL_ASSISTANT_NAME
|
||||
ent.entity_type = _CANONICAL_ASSISTANT_TYPE
|
||||
|
||||
# 第二步:清洗用户/AI助手之间的别名交叉污染(复用 clean_cross_role_aliases)
|
||||
clean_cross_role_aliases(entity_nodes)
|
||||
|
||||
|
||||
async def fetch_neo4j_assistant_aliases(neo4j_connector, end_user_id: str) -> set:
|
||||
"""从 Neo4j 查询 AI 助手实体的所有别名(小写归一化)。
|
||||
|
||||
这是助手别名查询的唯一入口,供 write_tools 和 extraction_orchestrator 共用,
|
||||
避免多处维护相同的 Cypher 和名称列表。
|
||||
|
||||
Args:
|
||||
neo4j_connector: Neo4j 连接器实例(需提供 execute_query 方法)
|
||||
end_user_id: 终端用户 ID
|
||||
|
||||
Returns:
|
||||
小写归一化后的助手别名集合
|
||||
"""
|
||||
# 查询名称列表:规范名称 + 常见变体(与 _normalize_special_entity_names 标准化后一致)
|
||||
query_names = [_CANONICAL_ASSISTANT_NAME, *_ASSISTANT_PLACEHOLDER_NAMES]
|
||||
# 去重保序
|
||||
query_names = list(dict.fromkeys(query_names))
|
||||
|
||||
cypher = """
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.end_user_id = $end_user_id AND e.name IN $names
|
||||
RETURN e.aliases AS aliases
|
||||
"""
|
||||
try:
|
||||
result = await neo4j_connector.execute_query(
|
||||
cypher, end_user_id=end_user_id, names=query_names
|
||||
)
|
||||
assistant_aliases: set = set()
|
||||
for record in (result or []):
|
||||
for alias in (record.get("aliases") or []):
|
||||
assistant_aliases.add(alias.strip().lower())
|
||||
if assistant_aliases:
|
||||
logger.debug(f"Neo4j 中 AI 助手别名: {assistant_aliases}")
|
||||
return assistant_aliases
|
||||
except Exception as e:
|
||||
logger.warning(f"查询 Neo4j AI 助手别名失败: {e}")
|
||||
return set()
|
||||
|
||||
|
||||
def clean_cross_role_aliases(
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
external_assistant_aliases: set = None,
|
||||
) -> None:
|
||||
"""清洗用户实体和AI助手实体之间的别名交叉污染。
|
||||
|
||||
在 Neo4j 写入前调用,确保:
|
||||
- 用户实体的 aliases 不包含 AI 助手的别名
|
||||
- AI 助手实体的 aliases 不包含用户的别名
|
||||
|
||||
Args:
|
||||
entity_nodes: 实体节点列表(原地修改)
|
||||
external_assistant_aliases: 外部传入的 AI 助手别名集合(如从 Neo4j 查询),
|
||||
与本轮实体中的 AI 助手别名合并使用
|
||||
"""
|
||||
# 收集本轮 AI 助手实体的所有别名
|
||||
assistant_aliases = set(external_assistant_aliases or set())
|
||||
user_aliases = set()
|
||||
|
||||
for ent in entity_nodes:
|
||||
if _is_assistant_entity(ent):
|
||||
for alias in (getattr(ent, "aliases", []) or []):
|
||||
assistant_aliases.add(alias.strip().lower())
|
||||
elif _is_user_entity(ent):
|
||||
for alias in (getattr(ent, "aliases", []) or []):
|
||||
user_aliases.add(alias.strip().lower())
|
||||
|
||||
# 从用户实体的 aliases 中移除 AI 助手别名
|
||||
if assistant_aliases:
|
||||
for ent in entity_nodes:
|
||||
if _is_user_entity(ent):
|
||||
original = getattr(ent, "aliases", []) or []
|
||||
cleaned = [a for a in original if a.strip().lower() not in assistant_aliases]
|
||||
if len(cleaned) < len(original):
|
||||
ent.aliases = cleaned
|
||||
|
||||
# 从 AI 助手实体的 aliases 中移除用户别名
|
||||
if user_aliases:
|
||||
for ent in entity_nodes:
|
||||
if _is_assistant_entity(ent):
|
||||
original = getattr(ent, "aliases", []) or []
|
||||
cleaned = [a for a in original if a.strip().lower() not in user_aliases]
|
||||
if len(cleaned) < len(original):
|
||||
ent.aliases = cleaned
|
||||
|
||||
|
||||
def accurate_match(
|
||||
entity_nodes: List[ExtractedEntityNode]
|
||||
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
|
||||
@@ -261,6 +406,10 @@ def accurate_match(
|
||||
canonical = alias_index.get((ent_uid, ent_name))
|
||||
# 确保不是自身
|
||||
if canonical is not None and canonical.id != ent.id:
|
||||
# 保护:禁止跨角色合并(用户实体和AI助手实体不能互相合并)
|
||||
if _would_merge_cross_role(canonical, ent):
|
||||
i += 1
|
||||
continue
|
||||
_merge_attribute(canonical, ent)
|
||||
id_redirect[ent.id] = canonical.id
|
||||
for k, v in list(id_redirect.items()):
|
||||
@@ -571,66 +720,37 @@ def fuzzy_match(
|
||||
|
||||
|
||||
def _merge_entities_with_aliases(canonical: ExtractedEntityNode, losing: ExtractedEntityNode):
|
||||
""" 模糊匹配中的实体合并。
|
||||
"""模糊匹配中的实体合并(别名部分)。
|
||||
|
||||
合并策略:
|
||||
1. 保留canonical的主名称不变
|
||||
2. 将losing的主名称添加为alias(如果不同)
|
||||
3. 合并两个实体的所有aliases
|
||||
4. 自动去重(case-insensitive)并排序
|
||||
|
||||
Args:
|
||||
canonical: 规范实体(保留)
|
||||
losing: 被合并实体(删除)
|
||||
|
||||
Note:
|
||||
使用alias_utils.normalize_aliases进行标准化去重
|
||||
用户实体的 aliases 由 PgSQL end_user_info 作为唯一权威源,跳过合并。
|
||||
"""
|
||||
# 获取规范实体的名称
|
||||
canonical_name = (getattr(canonical, "name", "") or "").strip()
|
||||
if canonical_name.lower() in _USER_PLACEHOLDER_NAMES:
|
||||
return
|
||||
|
||||
losing_name = (getattr(losing, "name", "") or "").strip()
|
||||
|
||||
# 收集所有需要合并的别名
|
||||
all_aliases = []
|
||||
|
||||
# 1. 添加canonical现有的别名
|
||||
current_aliases = getattr(canonical, "aliases", []) or []
|
||||
all_aliases.extend(current_aliases)
|
||||
|
||||
# 2. 添加losing实体的名称(如果不同于canonical的名称)
|
||||
all_aliases = list(getattr(canonical, "aliases", []) or [])
|
||||
if losing_name and losing_name != canonical_name:
|
||||
all_aliases.append(losing_name)
|
||||
all_aliases.extend(getattr(losing, "aliases", []) or [])
|
||||
|
||||
# 3. 添加losing实体的所有别名
|
||||
losing_aliases = getattr(losing, "aliases", []) or []
|
||||
all_aliases.extend(losing_aliases)
|
||||
|
||||
# 4. 标准化并去重(使用标准化后的字符串进行去重)
|
||||
try:
|
||||
from app.core.memory.utils.alias_utils import normalize_aliases
|
||||
canonical.aliases = normalize_aliases(canonical_name, all_aliases)
|
||||
except Exception:
|
||||
# 如果导入失败,使用增强的去重逻辑
|
||||
# 使用标准化后的字符串作为key进行去重
|
||||
seen_normalized = set()
|
||||
unique_aliases = []
|
||||
|
||||
for alias in all_aliases:
|
||||
if not alias:
|
||||
continue
|
||||
|
||||
alias_stripped = str(alias).strip()
|
||||
if not alias_stripped or alias_stripped == canonical_name:
|
||||
continue
|
||||
|
||||
# 标准化:转小写用于去重判断
|
||||
alias_normalized = alias_stripped.lower()
|
||||
|
||||
if alias_normalized not in seen_normalized:
|
||||
seen_normalized.add(alias_normalized)
|
||||
unique_aliases.append(alias_stripped)
|
||||
|
||||
# 排序并赋值
|
||||
canonical.aliases = sorted(unique_aliases)
|
||||
|
||||
# ========== 主循环:遍历所有实体对进行模糊匹配 ==========
|
||||
@@ -704,6 +824,11 @@ def fuzzy_match(
|
||||
# 条件A(快速通道):alias_match_merge = True
|
||||
# 条件B(标准通道):s_name ≥ tn AND s_type ≥ type_threshold AND overall ≥ tover
|
||||
if alias_match_merge or (s_name >= tn and s_type >= type_threshold and overall >= tover):
|
||||
# 保护:禁止跨角色合并(用户实体和AI助手实体不能互相合并)
|
||||
if _would_merge_cross_role(a, b):
|
||||
j += 1
|
||||
continue
|
||||
|
||||
# ========== 第六步:执行实体合并 ==========
|
||||
|
||||
# 6.1 合并别名
|
||||
@@ -813,6 +938,12 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
|
||||
b = entity_by_id.get(losing_id)
|
||||
if not a or not b: # 若不存在 a 或 b,可能已在精确或模糊阶段合并,在之前阶段合并之后,不会再处理但是处于审计的目的会记录
|
||||
continue
|
||||
# 保护:禁止跨角色合并(用户实体和AI助手实体不能互相合并)
|
||||
if _would_merge_cross_role(a, b):
|
||||
llm_records.append(
|
||||
f"[LLM阻断] 跨角色合并被阻止: {a.id} ({a.name}) 与 {b.id} ({b.name})"
|
||||
)
|
||||
continue
|
||||
_merge_attribute(a, b)
|
||||
# ID 重定向
|
||||
try:
|
||||
@@ -934,6 +1065,9 @@ async def deduplicate_entities_and_edges(
|
||||
返回:去重后的实体、语句→实体边、实体↔实体边。
|
||||
"""
|
||||
local_llm_records: List[str] = [] # 作为“审计日志”的本地收集器 初始化,保留为了之后对于LLM决策追溯
|
||||
# 0) 标准化用户和AI助手实体名称(确保多轮对话中的变体名称统一)
|
||||
_normalize_special_entity_names(entity_nodes)
|
||||
|
||||
# 1) 精确匹配
|
||||
deduped_entities, id_redirect, exact_merge_map = accurate_match(entity_nodes)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
|
||||
deduplicate_entities_and_edges,
|
||||
clean_cross_role_aliases,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import (
|
||||
second_layer_dedup_and_merge_with_neo4j,
|
||||
@@ -100,6 +101,10 @@ async def dedup_layers_and_merge_and_return(
|
||||
except Exception as e:
|
||||
print(f"Second-layer dedup failed: {e}")
|
||||
|
||||
# 第二层去重后,清洗用户/AI助手之间的别名交叉污染
|
||||
# 第二层从 Neo4j 合并了旧实体,可能带入历史脏数据
|
||||
clean_cross_role_aliases(fused_entity_nodes)
|
||||
|
||||
return (
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
|
||||
@@ -44,6 +44,10 @@ from app.core.memory.models.variate_config import (
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
|
||||
dedup_layers_and_merge_and_return,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
|
||||
_USER_PLACEHOLDER_NAMES,
|
||||
fetch_neo4j_assistant_aliases,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
|
||||
embedding_generation,
|
||||
generate_entity_embeddings_from_triplets,
|
||||
@@ -307,10 +311,53 @@ class ExtractionOrchestrator:
|
||||
dialog_data_list,
|
||||
)
|
||||
|
||||
# 步骤 7: 同步用户别名到数据库表(仅正式模式)
|
||||
# 步骤 7: 触发异步元数据和别名提取(仅正式模式)
|
||||
if not is_pilot_run:
|
||||
logger.info("步骤 7: 同步用户别名到 end_user 和 end_user_info 表")
|
||||
await self._update_end_user_other_name(entity_nodes, dialog_data_list)
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.metadata_extractor import (
|
||||
MetadataExtractor,
|
||||
)
|
||||
|
||||
metadata_extractor = MetadataExtractor(
|
||||
llm_client=self.llm_client, language=self.language
|
||||
)
|
||||
user_statements = (
|
||||
metadata_extractor.collect_user_related_statements(
|
||||
entity_nodes, statement_nodes, statement_entity_edges
|
||||
)
|
||||
)
|
||||
if user_statements:
|
||||
end_user_id = (
|
||||
dialog_data_list[0].end_user_id
|
||||
if dialog_data_list
|
||||
else None
|
||||
)
|
||||
config_id = (
|
||||
dialog_data_list[0].config_id
|
||||
if dialog_data_list
|
||||
and hasattr(dialog_data_list[0], "config_id")
|
||||
else None
|
||||
)
|
||||
if end_user_id:
|
||||
from app.tasks import extract_user_metadata_task
|
||||
|
||||
extract_user_metadata_task.delay(
|
||||
end_user_id=str(end_user_id),
|
||||
statements=user_statements,
|
||||
config_id=str(config_id) if config_id else None,
|
||||
language=self.language,
|
||||
)
|
||||
logger.info(
|
||||
f"已触发异步元数据提取任务,共 {len(user_statements)} 条用户相关 statement"
|
||||
)
|
||||
else:
|
||||
logger.info("未找到用户相关 statement,跳过元数据提取")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"触发元数据提取任务失败(不影响主流程): {e}", exc_info=True
|
||||
)
|
||||
|
||||
# 别名同步已迁移到 Celery 元数据提取任务中,不再在此处执行
|
||||
|
||||
logger.info(f"知识提取流水线运行完成({mode_str})")
|
||||
return (
|
||||
@@ -1103,6 +1150,7 @@ class ExtractionOrchestrator:
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
content=chunk.content,
|
||||
speaker=getattr(chunk, 'speaker', None),
|
||||
chunk_embedding=chunk.chunk_embedding,
|
||||
sequence_number=chunk_idx, # 添加必需的 sequence_number 字段
|
||||
created_at=dialog_data.created_at,
|
||||
@@ -1338,17 +1386,23 @@ class ExtractionOrchestrator:
|
||||
async def _update_end_user_other_name(
|
||||
self,
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
dialog_data_list: List[DialogData]
|
||||
dialog_data_list: List[DialogData],
|
||||
) -> None:
|
||||
"""
|
||||
从 Neo4j 读取用户实体的最终 aliases,同步到 end_user 和 end_user_info 表
|
||||
将本轮提取的用户别名同步到 end_user 和 end_user_info 表。
|
||||
|
||||
注意:
|
||||
1. other_name 使用本次对话提取的第一个别名(保持时间顺序)
|
||||
2. aliases 从 Neo4j 读取(保持完整性)
|
||||
PgSQL end_user_info.aliases 是用户别名的唯一权威源。
|
||||
此方法仅将本轮 LLM 从对话中新提取的别名增量追加到 PgSQL,
|
||||
不再从 Neo4j 二层去重合并历史别名,避免脏数据反向污染 PgSQL。
|
||||
|
||||
策略:
|
||||
1. 从本轮对话原始发言中提取用户别名(current_aliases)
|
||||
2. 从 PgSQL end_user_info 读取已有的 aliases(db_aliases)
|
||||
3. 合并 db_aliases + current_aliases,去重保序
|
||||
4. 写回 PgSQL
|
||||
|
||||
Args:
|
||||
entity_nodes: 实体节点列表
|
||||
entity_nodes: 去重后的实体节点列表(内存中)
|
||||
dialog_data_list: 对话数据列表
|
||||
"""
|
||||
try:
|
||||
@@ -1361,23 +1415,28 @@ class ExtractionOrchestrator:
|
||||
logger.warning("end_user_id 为空,跳过用户别名同步")
|
||||
return
|
||||
|
||||
# 1. 提取本次对话的用户别名(保持 LLM 提取的原始顺序,不排序)
|
||||
current_aliases = self._extract_current_aliases(entity_nodes)
|
||||
# 1. 提取本轮对话的用户别名(保持 LLM 提取的原始顺序,不排序)
|
||||
current_aliases = self._extract_current_aliases(entity_nodes, dialog_data_list)
|
||||
|
||||
# 2. 从 Neo4j 获取完整 aliases(权威数据源)
|
||||
neo4j_aliases = await self._fetch_neo4j_user_aliases(end_user_id)
|
||||
# 1.6 从 Neo4j 查询已有的 AI 助手别名,作为额外的排除源
|
||||
# (防止 LLM 未提取出 AI 助手实体时,AI 别名泄漏到用户别名中)
|
||||
neo4j_assistant_aliases = await self._fetch_neo4j_assistant_aliases(end_user_id)
|
||||
if neo4j_assistant_aliases:
|
||||
before_count = len(current_aliases)
|
||||
current_aliases = [
|
||||
a for a in current_aliases
|
||||
if a.strip().lower() not in neo4j_assistant_aliases
|
||||
]
|
||||
if len(current_aliases) < before_count:
|
||||
logger.info(f"通过 Neo4j AI 助手别名排除了 {before_count - len(current_aliases)} 个误归属别名")
|
||||
|
||||
if not neo4j_aliases:
|
||||
# Neo4j 中没有别名,使用本次对话提取的别名
|
||||
neo4j_aliases = current_aliases
|
||||
if not neo4j_aliases:
|
||||
logger.debug(f"aliases 为空,跳过同步: end_user_id={end_user_id}")
|
||||
return
|
||||
if not current_aliases:
|
||||
logger.debug(f"本轮未提取到用户别名,跳过同步: end_user_id={end_user_id}")
|
||||
return
|
||||
|
||||
logger.info(f"本次对话提取的 aliases: {current_aliases}")
|
||||
logger.info(f"Neo4j 中的完整 aliases: {neo4j_aliases}")
|
||||
logger.info(f"本轮对话提取的 aliases: {current_aliases}")
|
||||
|
||||
# 3. 同步到数据库
|
||||
# 2. 同步到数据库
|
||||
end_user_uuid = uuid.UUID(end_user_id)
|
||||
with get_db_context() as db:
|
||||
# 更新 end_user 表
|
||||
@@ -1386,7 +1445,32 @@ class ExtractionOrchestrator:
|
||||
logger.warning(f"未找到 end_user_id={end_user_id} 的用户记录")
|
||||
return
|
||||
|
||||
new_name = self._resolve_other_name(end_user.other_name, current_aliases, neo4j_aliases)
|
||||
# 3. 从 PgSQL 读取已有 aliases 并与本轮新增合并
|
||||
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
|
||||
db_aliases = (info.aliases if info and info.aliases else [])
|
||||
# 过滤掉占位名称
|
||||
db_aliases = [a for a in db_aliases if a.strip().lower() not in self.USER_PLACEHOLDER_NAMES]
|
||||
|
||||
# 合并:PgSQL 已有 + 本轮新增,去重保序(不再合并 Neo4j 历史别名)
|
||||
merged_aliases = list(db_aliases)
|
||||
seen_lower = {a.strip().lower() for a in merged_aliases}
|
||||
for alias in current_aliases:
|
||||
if alias.strip().lower() not in seen_lower:
|
||||
merged_aliases.append(alias)
|
||||
seen_lower.add(alias.strip().lower())
|
||||
|
||||
# 最终过滤:从合并结果中排除 AI 助手别名(清理历史脏数据)
|
||||
if neo4j_assistant_aliases:
|
||||
merged_aliases = [
|
||||
a for a in merged_aliases
|
||||
if a.strip().lower() not in neo4j_assistant_aliases
|
||||
]
|
||||
|
||||
logger.info(f"PgSQL 已有 aliases: {db_aliases}")
|
||||
logger.info(f"合并后 aliases: {merged_aliases}")
|
||||
|
||||
# 更新 end_user 表 other_name
|
||||
new_name = self._resolve_other_name(end_user.other_name, current_aliases, merged_aliases)
|
||||
if new_name is not None:
|
||||
end_user.other_name = new_name
|
||||
logger.info(f"更新 end_user 表 other_name → {new_name}")
|
||||
@@ -1394,78 +1478,105 @@ class ExtractionOrchestrator:
|
||||
logger.debug(f"end_user 表 other_name 保持不变: {end_user.other_name}")
|
||||
|
||||
# 更新或创建 end_user_info 记录
|
||||
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
|
||||
if info:
|
||||
new_name_info = self._resolve_other_name(info.other_name, current_aliases, neo4j_aliases)
|
||||
new_name_info = self._resolve_other_name(info.other_name, current_aliases, merged_aliases)
|
||||
if new_name_info is not None:
|
||||
info.other_name = new_name_info
|
||||
logger.info(f"更新 end_user_info 表 other_name → {new_name_info}")
|
||||
if info.aliases != neo4j_aliases:
|
||||
info.aliases = neo4j_aliases
|
||||
logger.info(f"同步 Neo4j aliases 到 end_user_info: {neo4j_aliases}")
|
||||
if info.aliases != merged_aliases:
|
||||
info.aliases = merged_aliases
|
||||
logger.info(f"同步合并后 aliases 到 end_user_info: {merged_aliases}")
|
||||
else:
|
||||
first_alias = current_aliases[0].strip() if current_aliases else ""
|
||||
# 确保 first_alias 不是占位名称
|
||||
if first_alias and first_alias not in self.USER_PLACEHOLDER_NAMES:
|
||||
if first_alias and first_alias.lower() not in self.USER_PLACEHOLDER_NAMES:
|
||||
db.add(EndUserInfo(
|
||||
end_user_id=end_user_uuid,
|
||||
other_name=first_alias,
|
||||
aliases=neo4j_aliases,
|
||||
meta_data={}
|
||||
aliases=merged_aliases,
|
||||
))
|
||||
logger.info(f"创建 end_user_info 记录,other_name={first_alias}, aliases={neo4j_aliases}")
|
||||
logger.info(f"创建 end_user_info 记录,other_name={first_alias}, aliases={merged_aliases}")
|
||||
|
||||
db.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新 end_user other_name 失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
|
||||
# 用户实体占位名称,不允许作为 other_name 或出现在 aliases 中
|
||||
USER_PLACEHOLDER_NAMES = {'用户', '我', 'User', 'I'}
|
||||
# 复用 deduped_and_disamb 模块级常量,避免重复维护
|
||||
USER_PLACEHOLDER_NAMES = _USER_PLACEHOLDER_NAMES
|
||||
|
||||
def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]:
|
||||
"""从实体节点提取用户别名(保持 LLM 提取的原始顺序,不进行任何排序)
|
||||
def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode], dialog_data_list=None) -> List[str]:
|
||||
"""从用户发言的原始实体中提取本轮新增别名(绕过去重污染)
|
||||
|
||||
这个方法直接返回 LLM 提取的别名列表,并过滤掉占位名称("用户"、"我"、"User"、"I")。
|
||||
第一个别名将被用作 other_name。
|
||||
策略:
|
||||
仅从 dialog_data_list 中找到 speaker="user" 的 statement,
|
||||
从这些 statement 的 triplet_extraction_info 中提取用户实体的 aliases。
|
||||
这样拿到的是 LLM 对用户原话的提取结果,不受去重合并的影响。
|
||||
|
||||
注意:不再使用去重后 entity_nodes 作为兜底,因为二层去重会将 Neo4j 历史别名
|
||||
合并进来,导致历史别名被误认为"本轮提取"。历史别名的同步由
|
||||
_extract_deduped_entity_aliases 负责。
|
||||
|
||||
Args:
|
||||
entity_nodes: 实体节点列表
|
||||
entity_nodes: 去重后的实体节点列表(未使用,保留参数兼容性)
|
||||
dialog_data_list: 对话数据列表
|
||||
|
||||
Returns:
|
||||
别名列表(保持 LLM 提取的原始顺序,已过滤占位名称)
|
||||
别名列表(保持原始顺序,已过滤)
|
||||
"""
|
||||
if not dialog_data_list:
|
||||
return []
|
||||
|
||||
all_user_aliases = []
|
||||
seen_lower = set()
|
||||
for dialog in dialog_data_list:
|
||||
for chunk in dialog.chunks:
|
||||
speaker = getattr(chunk, 'speaker', None)
|
||||
for statement in chunk.statements:
|
||||
stmt_speaker = getattr(statement, 'speaker', None) or speaker
|
||||
if stmt_speaker != "user":
|
||||
continue
|
||||
triplet_info = getattr(statement, 'triplet_extraction_info', None)
|
||||
if not triplet_info:
|
||||
continue
|
||||
for entity in (triplet_info.entities or []):
|
||||
ent_name = getattr(entity, 'name', '').strip()
|
||||
if ent_name.lower() in self.USER_PLACEHOLDER_NAMES:
|
||||
for alias in (getattr(entity, 'aliases', []) or []):
|
||||
a = alias.strip()
|
||||
if a and a.lower() not in self.USER_PLACEHOLDER_NAMES and a.lower() not in seen_lower:
|
||||
all_user_aliases.append(a)
|
||||
seen_lower.add(a.lower())
|
||||
if all_user_aliases:
|
||||
logger.debug(f"从用户原始发言提取到别名: {all_user_aliases}")
|
||||
return all_user_aliases
|
||||
|
||||
def _extract_deduped_entity_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]:
|
||||
"""从去重后的用户实体中提取完整别名列表。
|
||||
|
||||
二层去重会将 Neo4j 中已有的历史别名合并到 entity_nodes 的用户实体中,
|
||||
因此这里提取到的别名包含了历史积累的所有别名,可用于同步到 PgSQL。
|
||||
|
||||
Args:
|
||||
entity_nodes: 去重后的实体节点列表(含二层去重合并结果)
|
||||
|
||||
Returns:
|
||||
别名列表(已过滤占位名称,去重保序)
|
||||
"""
|
||||
for entity in entity_nodes:
|
||||
if getattr(entity, 'name', '').strip() in self.USER_PLACEHOLDER_NAMES:
|
||||
if getattr(entity, 'name', '').strip().lower() in self.USER_PLACEHOLDER_NAMES:
|
||||
aliases = getattr(entity, 'aliases', []) or []
|
||||
# 过滤掉占位名称,防止 "用户"/"我"/"User"/"I" 被存入 aliases 和 other_name
|
||||
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
|
||||
logger.debug(f"提取到用户别名(原始顺序,已过滤占位名称): {filtered}")
|
||||
return filtered
|
||||
filtered = [
|
||||
a for a in aliases
|
||||
if a.strip().lower() not in self.USER_PLACEHOLDER_NAMES
|
||||
]
|
||||
if filtered:
|
||||
return filtered
|
||||
return []
|
||||
|
||||
|
||||
async def _fetch_neo4j_user_aliases(self, end_user_id: str) -> List[str]:
|
||||
"""从 Neo4j 查询用户实体的完整 aliases 列表(已过滤占位名称)"""
|
||||
cypher = """
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '我', 'User', 'I']
|
||||
RETURN e.aliases AS aliases
|
||||
LIMIT 1
|
||||
"""
|
||||
result = await Neo4jConnector().execute_query(cypher, end_user_id=end_user_id)
|
||||
if not result:
|
||||
logger.debug(f"Neo4j 中未找到用户实体: end_user_id={end_user_id}")
|
||||
return []
|
||||
aliases = result[0].get('aliases') or []
|
||||
if not aliases:
|
||||
logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}")
|
||||
return []
|
||||
# 过滤掉占位名称,防止历史脏数据传播
|
||||
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
|
||||
return filtered
|
||||
async def _fetch_neo4j_assistant_aliases(self, end_user_id: str) -> set:
|
||||
"""从 Neo4j 查询 AI 助手实体的所有别名(用于从用户别名中排除)"""
|
||||
return await fetch_neo4j_assistant_aliases(self.connector, end_user_id)
|
||||
|
||||
def _resolve_other_name(
|
||||
self,
|
||||
@@ -1484,19 +1595,18 @@ class ExtractionOrchestrator:
|
||||
注意:返回值不允许是占位名称("用户"、"我"、"User"、"I")
|
||||
"""
|
||||
# 当前值为空或为占位名称时,需要更新
|
||||
if not current or not current.strip() or current.strip() in self.USER_PLACEHOLDER_NAMES:
|
||||
if not current or not current.strip() or current.strip().lower() in self.USER_PLACEHOLDER_NAMES:
|
||||
candidate = current_aliases[0].strip() if current_aliases else None
|
||||
# 确保候选值不是占位名称
|
||||
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
|
||||
if candidate and candidate.lower() in self.USER_PLACEHOLDER_NAMES:
|
||||
return None
|
||||
return candidate
|
||||
if current not in neo4j_aliases:
|
||||
candidate = neo4j_aliases[0].strip() if neo4j_aliases else None
|
||||
# 确保候选值不是占位名称
|
||||
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
|
||||
if candidate and candidate.lower() in self.USER_PLACEHOLDER_NAMES:
|
||||
return None
|
||||
return candidate
|
||||
|
||||
return None
|
||||
|
||||
async def _run_dedup_and_write_summary(
|
||||
|
||||
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
Metadata extractor module.
|
||||
|
||||
Collects user-related statements from post-dedup graph data and
|
||||
extracts user metadata via an independent LLM call.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.memory.models.graph_models import (
|
||||
ExtractedEntityNode,
|
||||
StatementEntityEdge,
|
||||
StatementNode,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Reuse the same user-entity detection logic from dedup module
|
||||
_USER_NAMES = {"用户", "我", "user", "i"}
|
||||
_CANONICAL_USER_TYPE = "用户"
|
||||
|
||||
|
||||
def _is_user_entity(ent: ExtractedEntityNode) -> bool:
|
||||
"""判断实体是否为用户实体"""
|
||||
name = (getattr(ent, "name", "") or "").strip().lower()
|
||||
etype = (getattr(ent, "entity_type", "") or "").strip()
|
||||
return name in _USER_NAMES or etype == _CANONICAL_USER_TYPE
|
||||
|
||||
|
||||
class MetadataExtractor:
|
||||
"""Extracts user metadata from post-dedup graph data via independent LLM call."""
|
||||
|
||||
def __init__(self, llm_client, language: Optional[str] = None):
|
||||
self.llm_client = llm_client
|
||||
self.language = language
|
||||
|
||||
@staticmethod
|
||||
def detect_language(statements: List[str]) -> str:
|
||||
"""根据 statement 文本内容检测语言。
|
||||
如果文本中包含中文字符则返回 "zh",否则返回 "en"。
|
||||
"""
|
||||
import re
|
||||
|
||||
combined = " ".join(statements)
|
||||
if re.search(r"[\u4e00-\u9fff]", combined):
|
||||
return "zh"
|
||||
return "en"
|
||||
|
||||
def collect_user_related_statements(
|
||||
self,
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
) -> List[str]:
|
||||
"""
|
||||
从去重后的数据中筛选与用户直接相关且由用户发言的 statement 文本。
|
||||
|
||||
筛选逻辑:
|
||||
1. 用户实体 → StatementEntityEdge → statement(直接关联)
|
||||
2. 只保留 speaker="user" 的 statement(过滤 assistant 回复的噪声)
|
||||
|
||||
Returns:
|
||||
用户发言的 statement 文本列表
|
||||
"""
|
||||
# Find user entity IDs
|
||||
user_entity_ids = set()
|
||||
for ent in entity_nodes:
|
||||
if _is_user_entity(ent):
|
||||
user_entity_ids.add(ent.id)
|
||||
|
||||
if not user_entity_ids:
|
||||
logger.debug("未找到用户实体节点,跳过 statement 收集")
|
||||
return []
|
||||
|
||||
# 用户实体 → StatementEntityEdge → statement
|
||||
target_stmt_ids = set()
|
||||
for edge in statement_entity_edges:
|
||||
if edge.target in user_entity_ids:
|
||||
target_stmt_ids.add(edge.source)
|
||||
|
||||
# Collect: only speaker="user" statements, preserving order
|
||||
result = []
|
||||
seen = set()
|
||||
total_associated = 0
|
||||
skipped_non_user = 0
|
||||
for stmt_node in statement_nodes:
|
||||
if stmt_node.id in target_stmt_ids and stmt_node.id not in seen:
|
||||
total_associated += 1
|
||||
speaker = getattr(stmt_node, "speaker", None) or "unknown"
|
||||
if speaker == "user":
|
||||
text = (stmt_node.statement or "").strip()
|
||||
if text:
|
||||
result.append(text)
|
||||
else:
|
||||
skipped_non_user += 1
|
||||
seen.add(stmt_node.id)
|
||||
|
||||
logger.info(
|
||||
f"收集到 {len(result)} 条用户发言 statement "
|
||||
f"(直接关联: {total_associated}, speaker=user: {len(result)}, "
|
||||
f"跳过非user: {skipped_non_user})"
|
||||
)
|
||||
if result:
|
||||
for i, text in enumerate(result):
|
||||
logger.info(f" [user statement {i + 1}] {text}")
|
||||
if total_associated > 0 and len(result) == 0:
|
||||
logger.warning(
|
||||
f"有 {total_associated} 条直接关联 statement 但全部被 speaker 过滤,"
|
||||
f"可能本次写入不包含 user 消息"
|
||||
)
|
||||
return result
|
||||
|
||||
async def extract_metadata(
|
||||
self,
|
||||
statements: List[str],
|
||||
existing_metadata: Optional[dict] = None,
|
||||
existing_aliases: Optional[List[str]] = None,
|
||||
) -> Optional[tuple]:
|
||||
"""
|
||||
对筛选后的 statement 列表调用 LLM 提取元数据增量变更和用户别名。
|
||||
|
||||
Args:
|
||||
statements: 用户发言的 statement 文本列表
|
||||
existing_metadata: 数据库已有的元数据(可选)
|
||||
existing_aliases: 数据库已有的用户别名列表(可选)
|
||||
|
||||
Returns:
|
||||
(List[MetadataFieldChange], List[str], List[str]) tuple:
|
||||
(metadata_changes, aliases_to_add, aliases_to_remove) on success, None on failure
|
||||
"""
|
||||
if not statements:
|
||||
return None
|
||||
|
||||
try:
|
||||
from app.core.memory.utils.prompt.prompt_utils import prompt_env
|
||||
|
||||
if self.language:
|
||||
detected_language = self.language
|
||||
logger.info(f"元数据提取使用显式指定语言: {detected_language}")
|
||||
else:
|
||||
detected_language = self.detect_language(statements)
|
||||
logger.info(f"元数据提取语言自动检测结果: {detected_language}")
|
||||
|
||||
template = prompt_env.get_template("extract_user_metadata.jinja2")
|
||||
prompt = template.render(
|
||||
statements=statements,
|
||||
language=detected_language,
|
||||
existing_metadata=existing_metadata,
|
||||
existing_aliases=existing_aliases,
|
||||
json_schema="",
|
||||
)
|
||||
|
||||
from app.core.memory.models.metadata_models import (
|
||||
MetadataExtractionResponse,
|
||||
)
|
||||
|
||||
response = await self.llm_client.response_structured(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
response_model=MetadataExtractionResponse,
|
||||
)
|
||||
|
||||
if response:
|
||||
changes = response.metadata_changes if response.metadata_changes else []
|
||||
to_add = response.aliases_to_add if response.aliases_to_add else []
|
||||
to_remove = (
|
||||
response.aliases_to_remove if response.aliases_to_remove else []
|
||||
)
|
||||
return changes, to_add, to_remove
|
||||
|
||||
logger.warning("LLM 返回的响应为空")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"元数据提取 LLM 调用失败: {e}", exc_info=True)
|
||||
return None
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -82,6 +81,7 @@ class StatementExtractor:
|
||||
logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty")
|
||||
return None
|
||||
|
||||
|
||||
async def _extract_statements(self, chunk, end_user_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
|
||||
"""Process a single chunk and return extracted statements
|
||||
|
||||
@@ -94,7 +94,8 @@ class StatementExtractor:
|
||||
List of ExtractedStatement objects extracted from the chunk
|
||||
"""
|
||||
chunk_content = chunk.content
|
||||
|
||||
chunk_speaker = self._get_speaker_from_chunk(chunk)
|
||||
|
||||
if not chunk_content or len(chunk_content.strip()) < 5:
|
||||
logger.warning(f"Chunk {chunk.id} content too short or empty, skipping")
|
||||
return []
|
||||
@@ -149,8 +150,6 @@ class StatementExtractor:
|
||||
relevence_info = RelevenceInfo[relevence_str] if relevence_str in RelevenceInfo.__members__ else RelevenceInfo.RELEVANT
|
||||
except (KeyError, ValueError):
|
||||
relevence_info = RelevenceInfo.RELEVANT
|
||||
|
||||
chunk_speaker = self._get_speaker_from_chunk(chunk)
|
||||
|
||||
chunk_statement = Statement(
|
||||
statement=extracted_stmt.statement,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
import asyncio
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
@@ -61,6 +60,7 @@ class TripletExtractor:
|
||||
predicate_instructions=PREDICATE_DEFINITIONS,
|
||||
language=self._get_language(),
|
||||
ontology_types=self.ontology_types,
|
||||
speaker=getattr(statement, 'speaker', None),
|
||||
)
|
||||
|
||||
# Create messages for LLM
|
||||
|
||||
@@ -42,22 +42,21 @@ class AccessHistoryManager:
|
||||
- access_count: 访问次数
|
||||
|
||||
特性:
|
||||
- 原子性更新:使用Neo4j事务确保所有字段同时更新或回滚
|
||||
- 并发安全:使用乐观锁机制防止并发冲突
|
||||
- 原子性更新:使用 APOC 原子操作确保并发安全
|
||||
- 批次内合并:同一批次中对同一节点的多次访问合并为一次更新
|
||||
- 一致性保证:提供一致性检查和自动修复功能
|
||||
- 智能修剪:自动修剪过长的访问历史
|
||||
|
||||
Attributes:
|
||||
connector: Neo4j连接器实例
|
||||
actr_calculator: ACT-R激活值计算器实例
|
||||
max_retries: 并发冲突时的最大重试次数
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
actr_calculator: ACTRCalculator,
|
||||
max_retries: int = 3
|
||||
max_retries: int = 5
|
||||
):
|
||||
"""
|
||||
初始化访问历史管理器
|
||||
@@ -65,47 +64,35 @@ class AccessHistoryManager:
|
||||
Args:
|
||||
connector: Neo4j连接器实例
|
||||
actr_calculator: ACT-R激活值计算器实例
|
||||
max_retries: 并发冲突时的最大重试次数(默认3次)
|
||||
max_retries: 已废弃,保留参数兼容性(APOC 原子操作无需重试)
|
||||
"""
|
||||
self.connector = connector
|
||||
self.actr_calculator = actr_calculator
|
||||
self.max_retries = max_retries
|
||||
|
||||
|
||||
async def record_access(
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
current_time: Optional[datetime] = None
|
||||
current_time: Optional[datetime] = None,
|
||||
access_times: int = 1
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
记录节点访问并原子性更新所有相关字段
|
||||
|
||||
这是核心方法,实现了:
|
||||
1. 首次访问:初始化access_history,计算初始激活值
|
||||
2. 后续访问:追加访问历史,重新计算激活值
|
||||
3. 历史修剪:当历史过长时自动修剪
|
||||
4. 原子性:所有字段在单个事务中更新
|
||||
5. 并发安全:使用乐观锁重试机制
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签(Statement, ExtractedEntity, MemorySummary)
|
||||
end_user_id: 组ID(可选,用于过滤)
|
||||
current_time: 当前时间(可选,默认使用系统时间)
|
||||
access_times: 本次访问次数(默认1,批量合并时可能大于1)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 更新后的节点数据,包含:
|
||||
- id: 节点ID
|
||||
- activation_value: 更新后的激活值
|
||||
- access_history: 更新后的访问历史
|
||||
- last_access_time: 最后访问时间
|
||||
- access_count: 访问次数
|
||||
- importance_score: 重要性分数
|
||||
Dict[str, Any]: 更新后的节点数据
|
||||
|
||||
Raises:
|
||||
ValueError: 如果节点不存在或节点标签无效
|
||||
RuntimeError: 如果重试次数耗尽仍然失败
|
||||
RuntimeError: 如果更新失败
|
||||
"""
|
||||
if current_time is None:
|
||||
current_time = datetime.now()
|
||||
@@ -119,55 +106,48 @@ class AccessHistoryManager:
|
||||
f"Invalid node_label: {node_label}. Must be one of {valid_labels}"
|
||||
)
|
||||
|
||||
# 使用乐观锁重试机制处理并发冲突
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
# 步骤1:读取当前节点状态
|
||||
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||
|
||||
if not node_data:
|
||||
raise ValueError(
|
||||
f"Node not found: {node_label} with id={node_id}"
|
||||
)
|
||||
|
||||
# 步骤2:计算新的访问历史和激活值
|
||||
update_data = await self._calculate_update(
|
||||
node_data=node_data,
|
||||
current_time=current_time,
|
||||
current_time_iso=current_time_iso
|
||||
try:
|
||||
# 步骤1:读取当前节点状态
|
||||
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||
|
||||
if not node_data:
|
||||
raise ValueError(
|
||||
f"Node not found: {node_label} with id={node_id}"
|
||||
)
|
||||
|
||||
# 步骤3:原子性更新节点(使用事务)
|
||||
updated_node = await self._atomic_update(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
update_data=update_data,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"成功记录访问: {node_label}[{node_id}], "
|
||||
f"activation={update_data['activation_value']:.4f}, "
|
||||
f"access_count={update_data['access_count']}"
|
||||
)
|
||||
|
||||
return updated_node
|
||||
|
||||
except Exception as e:
|
||||
if attempt < self.max_retries - 1:
|
||||
logger.warning(
|
||||
f"访问记录失败(尝试 {attempt + 1}/{self.max_retries}): {str(e)}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
logger.error(
|
||||
f"访问记录失败,重试次数耗尽: {node_label}[{node_id}], "
|
||||
f"错误: {str(e)}"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Failed to record access after {self.max_retries} attempts: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# 步骤2:计算新的访问历史和激活值
|
||||
update_data = await self._calculate_update(
|
||||
node_data=node_data,
|
||||
current_time=current_time,
|
||||
current_time_iso=current_time_iso,
|
||||
access_times=access_times
|
||||
)
|
||||
|
||||
# 步骤3:使用 APOC 原子操作更新节点(无需重试)
|
||||
updated_node = await self._atomic_update(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
update_data=update_data,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"成功记录访问: {node_label}[{node_id}], "
|
||||
f"activation={update_data['activation_value']:.4f}, "
|
||||
f"access_count={update_data['access_count']}"
|
||||
f"{f', 合并访问次数={access_times}' if access_times > 1 else ''}"
|
||||
)
|
||||
|
||||
return updated_node
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"访问记录失败: {node_label}[{node_id}], 错误: {str(e)}"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Failed to record access: {str(e)}"
|
||||
) from e
|
||||
|
||||
async def record_batch_access(
|
||||
self,
|
||||
node_ids: List[str],
|
||||
@@ -178,11 +158,10 @@ class AccessHistoryManager:
|
||||
"""
|
||||
批量记录多个节点的访问
|
||||
|
||||
为提高性能,批量更新多个节点的访问历史。
|
||||
每个节点独立更新,失败的节点不影响其他节点。
|
||||
对同一个节点的多次访问会先在内存中合并,只发起一次更新。
|
||||
|
||||
Args:
|
||||
node_ids: 节点ID列表
|
||||
node_ids: 节点ID列表(可包含重复ID)
|
||||
node_label: 节点标签(所有节点必须是同一类型)
|
||||
end_user_id: 组ID(可选)
|
||||
current_time: 当前时间(可选)
|
||||
@@ -196,25 +175,38 @@ class AccessHistoryManager:
|
||||
if current_time is None:
|
||||
current_time = datetime.now()
|
||||
|
||||
# PERFORMANCE FIX: Process all nodes in parallel instead of sequentially
|
||||
tasks = []
|
||||
# 合并同一节点的访问次数,避免对同一节点并发写入
|
||||
access_count_map: Dict[str, int] = {}
|
||||
for node_id in node_ids:
|
||||
access_count_map[node_id] = access_count_map.get(node_id, 0) + 1
|
||||
|
||||
merged_count = len(node_ids) - len(access_count_map)
|
||||
if merged_count > 0:
|
||||
logger.info(
|
||||
f"批量访问合并: 原始={len(node_ids)}, "
|
||||
f"去重后={len(access_count_map)}, 合并={merged_count}"
|
||||
)
|
||||
|
||||
# 对去重后的节点并行发起更新
|
||||
tasks = []
|
||||
for node_id, access_times in access_count_map.items():
|
||||
task = self.record_access(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
end_user_id=end_user_id,
|
||||
current_time=current_time
|
||||
current_time=current_time,
|
||||
access_times=access_times
|
||||
)
|
||||
tasks.append(task)
|
||||
tasks.append((node_id, task))
|
||||
|
||||
# Execute all tasks in parallel
|
||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
task_results = await asyncio.gather(
|
||||
*[t for _, t in tasks], return_exceptions=True
|
||||
)
|
||||
|
||||
# Collect successful results and count failures
|
||||
results = []
|
||||
failed_count = 0
|
||||
|
||||
for node_id, result in zip(node_ids, task_results):
|
||||
for (node_id, _), result in zip(tasks, task_results):
|
||||
if isinstance(result, Exception):
|
||||
failed_count += 1
|
||||
logger.warning(
|
||||
@@ -225,12 +217,12 @@ class AccessHistoryManager:
|
||||
|
||||
batch_duration = time.time() - batch_start
|
||||
logger.info(
|
||||
f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, "
|
||||
f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(access_count_map)}, "
|
||||
f"失败 {failed_count}, 耗时 {batch_duration:.4f}s"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def check_consistency(
|
||||
self,
|
||||
node_id: str,
|
||||
@@ -239,22 +231,6 @@ class AccessHistoryManager:
|
||||
) -> Tuple[ConsistencyCheckResult, Optional[str]]:
|
||||
"""
|
||||
检查节点数据的一致性
|
||||
|
||||
验证以下一致性规则:
|
||||
1. access_history[-1] == last_access_time
|
||||
2. len(access_history) == access_count
|
||||
3. 如果有访问历史,必须有激活值
|
||||
4. 激活值必须在有效范围内 [offset, 1.0]
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
end_user_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Tuple[ConsistencyCheckResult, Optional[str]]:
|
||||
- 一致性检查结果枚举
|
||||
- 错误描述(如果不一致)
|
||||
"""
|
||||
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||
|
||||
@@ -266,7 +242,6 @@ class AccessHistoryManager:
|
||||
access_count = node_data.get('access_count', 0)
|
||||
activation_value = node_data.get('activation_value')
|
||||
|
||||
# 检查1:access_history[-1] == last_access_time
|
||||
if access_history and last_access_time:
|
||||
if access_history[-1] != last_access_time:
|
||||
return (
|
||||
@@ -275,7 +250,6 @@ class AccessHistoryManager:
|
||||
f"last_access_time={last_access_time}"
|
||||
)
|
||||
|
||||
# 检查2:len(access_history) == access_count
|
||||
if len(access_history) != access_count:
|
||||
return (
|
||||
ConsistencyCheckResult.INCONSISTENT_HISTORY_COUNT,
|
||||
@@ -283,14 +257,12 @@ class AccessHistoryManager:
|
||||
f"access_count={access_count}"
|
||||
)
|
||||
|
||||
# 检查3:有访问历史必须有激活值
|
||||
if access_history and activation_value is None:
|
||||
return (
|
||||
ConsistencyCheckResult.MISSING_ACTIVATION,
|
||||
"Node has access_history but activation_value is None"
|
||||
)
|
||||
|
||||
# 检查4:激活值范围
|
||||
if activation_value is not None:
|
||||
offset = self.actr_calculator.offset
|
||||
if not (offset <= activation_value <= 1.0):
|
||||
@@ -301,30 +273,14 @@ class AccessHistoryManager:
|
||||
)
|
||||
|
||||
return ConsistencyCheckResult.CONSISTENT, None
|
||||
|
||||
|
||||
async def check_batch_consistency(
|
||||
self,
|
||||
node_label: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 1000
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
批量检查多个节点的一致性
|
||||
|
||||
Args:
|
||||
node_label: 节点标签
|
||||
end_user_id: 组ID(可选)
|
||||
limit: 检查的最大节点数
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 一致性检查报告,包含:
|
||||
- total_checked: 检查的节点总数
|
||||
- consistent_count: 一致的节点数
|
||||
- inconsistent_count: 不一致的节点数
|
||||
- inconsistencies: 不一致节点的详细信息列表
|
||||
- consistency_rate: 一致性率(0-1)
|
||||
"""
|
||||
# 查询所有相关节点
|
||||
"""批量检查多个节点的一致性"""
|
||||
query = f"""
|
||||
MATCH (n:{node_label})
|
||||
WHERE n.access_history IS NOT NULL
|
||||
@@ -343,7 +299,6 @@ class AccessHistoryManager:
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
node_ids = [r['id'] for r in results]
|
||||
|
||||
# 检查每个节点
|
||||
inconsistencies = []
|
||||
consistent_count = 0
|
||||
|
||||
@@ -382,32 +337,15 @@ class AccessHistoryManager:
|
||||
)
|
||||
|
||||
return report
|
||||
|
||||
|
||||
async def repair_inconsistency(
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
end_user_id: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
自动修复节点的数据不一致问题
|
||||
|
||||
修复策略:
|
||||
1. 如果access_history[-1] != last_access_time:使用access_history[-1]
|
||||
2. 如果len(access_history) != access_count:使用len(access_history)
|
||||
3. 如果有历史但无激活值:重新计算激活值
|
||||
4. 如果激活值超出范围:重新计算激活值
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
end_user_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
bool: 修复成功返回True,否则返回False
|
||||
"""
|
||||
"""自动修复节点的数据不一致问题"""
|
||||
try:
|
||||
# 检查一致性
|
||||
result, message = await self.check_consistency(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
@@ -418,7 +356,6 @@ class AccessHistoryManager:
|
||||
logger.info(f"节点数据一致,无需修复: {node_label}[{node_id}]")
|
||||
return True
|
||||
|
||||
# 获取节点数据
|
||||
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||
if not node_data:
|
||||
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
|
||||
@@ -427,17 +364,13 @@ class AccessHistoryManager:
|
||||
access_history = node_data.get('access_history') or []
|
||||
importance_score = node_data.get('importance_score', 0.5)
|
||||
|
||||
# 准备修复数据
|
||||
repair_data = {}
|
||||
|
||||
# 修复last_access_time
|
||||
if access_history:
|
||||
repair_data['last_access_time'] = access_history[-1]
|
||||
|
||||
# 修复access_count
|
||||
repair_data['access_count'] = len(access_history)
|
||||
|
||||
# 修复activation_value
|
||||
if access_history:
|
||||
current_time = datetime.now()
|
||||
last_access_dt = datetime.fromisoformat(access_history[-1])
|
||||
@@ -453,7 +386,6 @@ class AccessHistoryManager:
|
||||
)
|
||||
repair_data['activation_value'] = activation_value
|
||||
|
||||
# 执行修复
|
||||
query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
@@ -484,26 +416,16 @@ class AccessHistoryManager:
|
||||
f"修复节点失败: {node_label}[{node_id}], 错误: {str(e)}"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
# ==================== 私有辅助方法 ====================
|
||||
|
||||
|
||||
async def _fetch_node(
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
end_user_id: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取节点数据
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
end_user_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 节点数据,如果不存在返回None
|
||||
"""
|
||||
"""获取节点数据"""
|
||||
query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
@@ -527,12 +449,13 @@ class AccessHistoryManager:
|
||||
if results:
|
||||
return results[0]
|
||||
return None
|
||||
|
||||
|
||||
async def _calculate_update(
|
||||
self,
|
||||
node_data: Dict[str, Any],
|
||||
current_time: datetime,
|
||||
current_time_iso: str
|
||||
current_time_iso: str,
|
||||
access_times: int = 1
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
计算更新数据
|
||||
@@ -541,45 +464,40 @@ class AccessHistoryManager:
|
||||
node_data: 当前节点数据
|
||||
current_time: 当前时间(datetime对象)
|
||||
current_time_iso: 当前时间(ISO格式字符串)
|
||||
access_times: 本次访问次数(合并后可能大于1)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 更新数据,包含所有需要更新的字段
|
||||
Dict[str, Any]: 更新数据
|
||||
"""
|
||||
access_history = node_data.get('access_history') or []
|
||||
# Handle None importance_score - default to 0.5
|
||||
importance_score = node_data.get('importance_score')
|
||||
if importance_score is None:
|
||||
importance_score = 0.5
|
||||
|
||||
# 追加新的访问时间
|
||||
new_access_history = access_history + [current_time_iso]
|
||||
# 本次新增的时间戳
|
||||
new_timestamps = [current_time_iso] * access_times
|
||||
|
||||
# 修剪访问历史(如果过长)
|
||||
access_history_dt = [
|
||||
datetime.fromisoformat(ts) for ts in new_access_history
|
||||
]
|
||||
# 仅用本次新增的访问记录计算激活值
|
||||
new_history_dt = [current_time] * access_times
|
||||
trimmed_history_dt = self.actr_calculator.trim_access_history(
|
||||
access_history=access_history_dt,
|
||||
access_history=new_history_dt,
|
||||
current_time=current_time
|
||||
)
|
||||
trimmed_history = [ts.isoformat() for ts in trimmed_history_dt]
|
||||
|
||||
# 计算新的激活值
|
||||
activation_value = self.actr_calculator.calculate_memory_activation(
|
||||
access_history=trimmed_history_dt,
|
||||
current_time=current_time,
|
||||
last_access_time=current_time, # 最后访问时间就是当前时间
|
||||
last_access_time=current_time,
|
||||
importance_score=importance_score
|
||||
)
|
||||
|
||||
# 返回所有需要更新的字段
|
||||
return {
|
||||
'activation_value': activation_value,
|
||||
'access_history': trimmed_history,
|
||||
'new_timestamps': new_timestamps,
|
||||
'access_count_delta': access_times,
|
||||
'access_count': len(trimmed_history_dt),
|
||||
'last_access_time': current_time_iso,
|
||||
'access_count': len(trimmed_history)
|
||||
}
|
||||
|
||||
|
||||
async def _atomic_update(
|
||||
self,
|
||||
node_id: str,
|
||||
@@ -588,10 +506,10 @@ class AccessHistoryManager:
|
||||
end_user_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
原子性更新节点(使用乐观锁)
|
||||
原子性更新节点(使用 APOC 原子操作)
|
||||
|
||||
使用Neo4j事务和版本号确保所有字段同时更新或回滚。
|
||||
实现乐观锁机制防止并发冲突。
|
||||
使用 apoc.atomic.add 和 apoc.atomic.insert 保证并发安全,
|
||||
无需 version 字段和乐观锁,数据库层面保证原子性。
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
@@ -603,126 +521,68 @@ class AccessHistoryManager:
|
||||
Dict[str, Any]: 更新后的节点数据
|
||||
|
||||
Raises:
|
||||
RuntimeError: 如果更新失败或发生版本冲突
|
||||
RuntimeError: 如果更新失败
|
||||
"""
|
||||
# 定义事务函数
|
||||
async def update_transaction(tx, node_id, node_label, update_data, end_user_id):
|
||||
# 步骤1:读取当前节点并获取版本号
|
||||
read_query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
if end_user_id:
|
||||
read_query += " WHERE n.end_user_id = $end_user_id"
|
||||
read_query += """
|
||||
RETURN n.id as id,
|
||||
n.version as version,
|
||||
n.activation_value as activation_value,
|
||||
n.access_history as access_history,
|
||||
n.last_access_time as last_access_time,
|
||||
n.access_count as access_count,
|
||||
n.importance_score as importance_score
|
||||
"""
|
||||
content_field_map = {
|
||||
'Statement': 'n.statement as statement',
|
||||
'MemorySummary': 'n.content as content',
|
||||
'ExtractedEntity': 'null as content_placeholder',
|
||||
'Community': 'n.summary as summary'
|
||||
}
|
||||
|
||||
if node_label not in content_field_map:
|
||||
raise ValueError(
|
||||
f"Unsupported node_label: {node_label}. "
|
||||
f"Supported labels are: {list(content_field_map.keys())}"
|
||||
)
|
||||
|
||||
content_field = content_field_map[node_label]
|
||||
|
||||
where_clause = ""
|
||||
if end_user_id:
|
||||
where_clause = " AND n.end_user_id = $end_user_id"
|
||||
|
||||
query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
WHERE true{where_clause}
|
||||
CALL apoc.atomic.add(n, 'access_count', $access_count_delta, 5) YIELD oldValue AS old_count
|
||||
WITH n
|
||||
CALL (n) {{
|
||||
UNWIND $new_timestamps AS ts
|
||||
CALL apoc.atomic.insert(n, 'access_history', size(n.access_history), ts, 5) YIELD oldValue
|
||||
RETURN count(*) AS inserted
|
||||
}}
|
||||
SET n.activation_value = $activation_value,
|
||||
n.last_access_time = $last_access_time
|
||||
RETURN n.id as id,
|
||||
n.activation_value as activation_value,
|
||||
n.access_history as access_history,
|
||||
n.last_access_time as last_access_time,
|
||||
n.access_count as access_count,
|
||||
n.importance_score as importance_score,
|
||||
{content_field}
|
||||
"""
|
||||
|
||||
params = {
|
||||
'node_id': node_id,
|
||||
'access_count_delta': update_data['access_count_delta'],
|
||||
'new_timestamps': update_data['new_timestamps'],
|
||||
'activation_value': update_data['activation_value'],
|
||||
'last_access_time': update_data['last_access_time'],
|
||||
}
|
||||
if end_user_id:
|
||||
params['end_user_id'] = end_user_id
|
||||
|
||||
try:
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
|
||||
read_params = {'node_id': node_id}
|
||||
if end_user_id:
|
||||
read_params['end_user_id'] = end_user_id
|
||||
|
||||
read_result = await tx.run(read_query, **read_params)
|
||||
current_node = await read_result.single()
|
||||
|
||||
if not current_node:
|
||||
if not results:
|
||||
raise RuntimeError(f"Node not found: {node_label}[{node_id}]")
|
||||
|
||||
# 获取当前版本号(如果不存在则为0)
|
||||
current_version = current_node.get('version', 0) or 0
|
||||
new_version = current_version + 1
|
||||
|
||||
# 步骤2:使用乐观锁更新节点
|
||||
# 根据节点类型构建完整的查询语句
|
||||
content_field_map = {
|
||||
'Statement': 'n.statement as statement',
|
||||
'MemorySummary': 'n.content as content',
|
||||
'ExtractedEntity': 'null as content_placeholder' # 占位符,后续会被过滤
|
||||
}
|
||||
|
||||
# 显式检查节点类型,不支持的类型抛出错误
|
||||
if node_label not in content_field_map:
|
||||
raise ValueError(
|
||||
f"Unsupported node_label: {node_label}. "
|
||||
f"Supported labels are: {list(content_field_map.keys())}"
|
||||
)
|
||||
|
||||
content_field = content_field_map[node_label]
|
||||
|
||||
# 构建 WHERE 子句
|
||||
where_conditions = []
|
||||
if end_user_id:
|
||||
where_conditions.append("n.end_user_id = $end_user_id")
|
||||
|
||||
# 添加版本检查
|
||||
if current_version > 0:
|
||||
where_conditions.append("n.version = $current_version")
|
||||
else:
|
||||
where_conditions.append("(n.version IS NULL OR n.version = 0)")
|
||||
|
||||
where_clause = " AND ".join(where_conditions) if where_conditions else "true"
|
||||
|
||||
# 构建完整的更新查询
|
||||
update_query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
WHERE {where_clause}
|
||||
SET n.activation_value = $activation_value,
|
||||
n.access_history = $access_history,
|
||||
n.last_access_time = $last_access_time,
|
||||
n.access_count = $access_count,
|
||||
n.version = $new_version
|
||||
RETURN n.id as id,
|
||||
n.activation_value as activation_value,
|
||||
n.access_history as access_history,
|
||||
n.last_access_time as last_access_time,
|
||||
n.access_count as access_count,
|
||||
n.importance_score as importance_score,
|
||||
n.version as version,
|
||||
{content_field}
|
||||
"""
|
||||
|
||||
update_params = {
|
||||
'node_id': node_id,
|
||||
'current_version': current_version,
|
||||
'new_version': new_version,
|
||||
'activation_value': update_data['activation_value'],
|
||||
'access_history': update_data['access_history'],
|
||||
'last_access_time': update_data['last_access_time'],
|
||||
'access_count': update_data['access_count']
|
||||
}
|
||||
if end_user_id:
|
||||
update_params['end_user_id'] = end_user_id
|
||||
|
||||
update_result = await tx.run(update_query, **update_params)
|
||||
updated_node = await update_result.single()
|
||||
|
||||
if not updated_node:
|
||||
raise RuntimeError(
|
||||
f"Version conflict detected for {node_label}[{node_id}]. "
|
||||
f"Expected version {current_version}, but node was modified by another transaction."
|
||||
)
|
||||
|
||||
# 转换为字典并移除占位符字段
|
||||
result_dict = dict(updated_node)
|
||||
result_dict = dict(results[0])
|
||||
result_dict.pop('content_placeholder', None)
|
||||
|
||||
return result_dict
|
||||
|
||||
# 执行事务
|
||||
try:
|
||||
result = await self.connector.execute_write_transaction(
|
||||
update_transaction,
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
update_data=update_data,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"原子性更新失败: {node_label}[{node_id}], 错误: {str(e)}"
|
||||
|
||||
@@ -4,11 +4,6 @@
|
||||
本模块提供统一的搜索服务接口,支持关键词搜索、语义搜索和混合搜索。
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
from app.core.memory.storage_services.search.hybrid_search import HybridSearchStrategy
|
||||
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
|
||||
from app.core.memory.storage_services.search.search_strategy import (
|
||||
@@ -29,115 +24,87 @@ __all__ = [
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 向后兼容的函数式API
|
||||
# 向后兼容的函数式API (DEPRECATED - 未被使用)
|
||||
# ============================================================================
|
||||
# 为了兼容旧代码,提供与 src/search.py 相同的函数式接口
|
||||
# 所有调用方均直接使用 app.core.memory.src.search.run_hybrid_search
|
||||
# 保留注释以备参考
|
||||
|
||||
|
||||
async def run_hybrid_search(
|
||||
query_text: str,
|
||||
search_type: str = "hybrid",
|
||||
end_user_id: str | None = None,
|
||||
apply_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
limit: int = 50,
|
||||
include: list[str] | None = None,
|
||||
alpha: float = 0.6,
|
||||
use_forgetting_curve: bool = False,
|
||||
memory_config: "MemoryConfig" = None,
|
||||
**kwargs
|
||||
) -> dict:
|
||||
"""运行混合搜索(向后兼容的函数式API)
|
||||
|
||||
这是一个向后兼容的包装函数,将旧的函数式API转换为新的基于类的API。
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
search_type: 搜索类型("hybrid", "keyword", "semantic")
|
||||
end_user_id: 组ID过滤
|
||||
apply_id: 应用ID过滤
|
||||
user_id: 用户ID过滤
|
||||
limit: 每个类别的最大结果数
|
||||
include: 要包含的搜索类别列表
|
||||
alpha: BM25分数权重(0.0-1.0)
|
||||
use_forgetting_curve: 是否使用遗忘曲线
|
||||
memory_config: MemoryConfig object containing embedding_model_id
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
dict: 搜索结果字典,格式与旧API兼容
|
||||
"""
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
if not memory_config:
|
||||
raise ValueError("memory_config is required for search")
|
||||
|
||||
# 初始化客户端
|
||||
connector = Neo4jConnector()
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
|
||||
try:
|
||||
# 根据搜索类型选择策略
|
||||
if search_type == "keyword":
|
||||
strategy = KeywordSearchStrategy(connector=connector)
|
||||
elif search_type == "semantic":
|
||||
strategy = SemanticSearchStrategy(
|
||||
connector=connector,
|
||||
embedder_client=embedder_client
|
||||
)
|
||||
else: # hybrid
|
||||
strategy = HybridSearchStrategy(
|
||||
connector=connector,
|
||||
embedder_client=embedder_client,
|
||||
alpha=alpha,
|
||||
use_forgetting_curve=use_forgetting_curve
|
||||
)
|
||||
|
||||
# 执行搜索
|
||||
result = await strategy.search(
|
||||
query_text=query_text,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include,
|
||||
alpha=alpha,
|
||||
use_forgetting_curve=use_forgetting_curve,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# 转换为旧格式
|
||||
result_dict = result.to_dict()
|
||||
|
||||
# 保存到文件(如果指定了output_path)
|
||||
output_path = kwargs.get('output_path', 'search_results.json')
|
||||
if output_path:
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
try:
|
||||
# 确保目录存在
|
||||
out_dir = os.path.dirname(output_path)
|
||||
if out_dir:
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
# 保存结果
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(result_dict, f, ensure_ascii=False, indent=2, default=str)
|
||||
print(f"Search results saved to {output_path}")
|
||||
except Exception as e:
|
||||
print(f"Error saving search results: {e}")
|
||||
return result_dict
|
||||
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
__all__.append("run_hybrid_search")
|
||||
# async def run_hybrid_search(
|
||||
# query_text: str,
|
||||
# search_type: str = "hybrid",
|
||||
# end_user_id: str | None = None,
|
||||
# apply_id: str | None = None,
|
||||
# user_id: str | None = None,
|
||||
# limit: int = 50,
|
||||
# include: list[str] | None = None,
|
||||
# alpha: float = 0.6,
|
||||
# use_forgetting_curve: bool = False,
|
||||
# memory_config: "MemoryConfig" = None,
|
||||
# **kwargs
|
||||
# ) -> dict:
|
||||
# """运行混合搜索(向后兼容的函数式API)"""
|
||||
# from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
# from app.core.models.base import RedBearModelConfig
|
||||
# from app.db import get_db_context
|
||||
# from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
# from app.services.memory_config_service import MemoryConfigService
|
||||
#
|
||||
# if not memory_config:
|
||||
# raise ValueError("memory_config is required for search")
|
||||
#
|
||||
# connector = Neo4jConnector()
|
||||
# with get_db_context() as db:
|
||||
# config_service = MemoryConfigService(db)
|
||||
# embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
|
||||
# embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
# embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
#
|
||||
# try:
|
||||
# if search_type == "keyword":
|
||||
# strategy = KeywordSearchStrategy(connector=connector)
|
||||
# elif search_type == "semantic":
|
||||
# strategy = SemanticSearchStrategy(
|
||||
# connector=connector,
|
||||
# embedder_client=embedder_client
|
||||
# )
|
||||
# else:
|
||||
# strategy = HybridSearchStrategy(
|
||||
# connector=connector,
|
||||
# embedder_client=embedder_client,
|
||||
# alpha=alpha,
|
||||
# use_forgetting_curve=use_forgetting_curve
|
||||
# )
|
||||
#
|
||||
# result = await strategy.search(
|
||||
# query_text=query_text,
|
||||
# end_user_id=end_user_id,
|
||||
# limit=limit,
|
||||
# include=include,
|
||||
# alpha=alpha,
|
||||
# use_forgetting_curve=use_forgetting_curve,
|
||||
# **kwargs
|
||||
# )
|
||||
#
|
||||
# result_dict = result.to_dict()
|
||||
#
|
||||
# output_path = kwargs.get('output_path', 'search_results.json')
|
||||
# if output_path:
|
||||
# import json
|
||||
# import os
|
||||
# from datetime import datetime
|
||||
#
|
||||
# try:
|
||||
# out_dir = os.path.dirname(output_path)
|
||||
# if out_dir:
|
||||
# os.makedirs(out_dir, exist_ok=True)
|
||||
# with open(output_path, "w", encoding="utf-8") as f:
|
||||
# json.dump(result_dict, f, ensure_ascii=False, indent=2, default=str)
|
||||
# print(f"Search results saved to {output_path}")
|
||||
# except Exception as e:
|
||||
# print(f"Error saving search results: {e}")
|
||||
# return result_dict
|
||||
#
|
||||
# finally:
|
||||
# await connector.close()
|
||||
#
|
||||
# __all__.append("run_hybrid_search")
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
使用Neo4j的全文索引进行高效的文本匹配。
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import List, Optional
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
|
||||
@@ -74,7 +74,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
||||
# 调用底层的关键词搜索函数
|
||||
results_dict = await search_graph(
|
||||
connector=self.connector,
|
||||
q=query_text,
|
||||
query=query_text,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include_list
|
||||
|
||||
@@ -22,7 +22,9 @@ def escape_lucene_query(query: str) -> str:
|
||||
s = s.replace("\r", " ").replace("\n", " ").strip()
|
||||
|
||||
# Lucene reserved tokens/special characters
|
||||
specials = ['&&', '||', '\\', '+', '-', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', '*', '?', ':']
|
||||
# NOTE: '/' is the regex delimiter in Lucene — must be escaped to prevent
|
||||
# TokenMgrError when the query contains unmatched slashes.
|
||||
specials = ['&&', '||', '\\', '+', '-', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', '*', '?', ':', '/']
|
||||
# Replace longer tokens first to avoid partial double-escaping
|
||||
for token in sorted(specials, key=len, reverse=True):
|
||||
s = s.replace(token, f"\\{token}")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
from app.core.memory.utils.log.logging_utils import log_prompt_rendering, log_template_rendering
|
||||
|
||||
# Setup Jinja2 environment
|
||||
@@ -205,6 +205,7 @@ async def render_triplet_extraction_prompt(
|
||||
predicate_instructions: dict = None,
|
||||
language: str = "zh",
|
||||
ontology_types: "OntologyTypeList | None" = None,
|
||||
speaker: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
Renders the triplet extraction prompt using the extract_triplet.jinja2 template.
|
||||
@@ -216,6 +217,7 @@ async def render_triplet_extraction_prompt(
|
||||
predicate_instructions: Optional predicate instructions
|
||||
language: The language to use for entity descriptions ("zh" for Chinese, "en" for English)
|
||||
ontology_types: Optional OntologyTypeList containing predefined ontology types for entity classification
|
||||
speaker: Speaker role ("user" or "assistant") for the current statement
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
@@ -223,7 +225,7 @@ async def render_triplet_extraction_prompt(
|
||||
template = prompt_env.get_template("extract_triplet.jinja2")
|
||||
|
||||
# 准备本体类型数据
|
||||
ontology_type_section = ""
|
||||
ontology_type_section = None
|
||||
ontology_type_names = []
|
||||
type_hierarchy_hints = []
|
||||
if ontology_types and ontology_types.types:
|
||||
@@ -240,6 +242,7 @@ async def render_triplet_extraction_prompt(
|
||||
ontology_types=ontology_type_section,
|
||||
ontology_type_names=ontology_type_names,
|
||||
type_hierarchy_hints=type_hierarchy_hints,
|
||||
speaker=speaker,
|
||||
)
|
||||
# 记录渲染结果到提示日志(与示例日志结构一致)
|
||||
log_prompt_rendering('triplet extraction', rendered_prompt)
|
||||
|
||||
@@ -43,8 +43,9 @@ Each statement must be labeled as per the criteria mentioned below.
|
||||
|
||||
对话上下文和共指消解:
|
||||
- 将每个陈述句归属于说出它的参与者。
|
||||
- 如果参与者列表为说话者提供了名称(例如,"李雪(用户)"),请在提取的陈述句中使用具体名称("李雪"),而不是通用角色("用户")。
|
||||
- 将所有代词解析为对话上下文中的具体人物或实体。
|
||||
- **对于用户的发言:必须使用"用户"作为主语**,禁止将"用户"或"我"替换为用户的真实姓名或别名。例如,用户说"我叫张三"应提取为"用户叫张三",而不是"张三叫张三"。
|
||||
- 对于 AI 助手的发言:使用"助手"或"AI助手"作为主语。
|
||||
- 将所有代词解析为对话上下文中的具体人物或实体,但"我"必须解析为"用户"。
|
||||
- 识别并将抽象引用解析为其具体名称(如果提到)。
|
||||
- 将缩写和首字母缩略词扩展为其完整形式。
|
||||
{% else %}
|
||||
@@ -68,8 +69,9 @@ Context Resolution Requirements:
|
||||
|
||||
Conversational Context & Co-reference Resolution:
|
||||
- Attribute every statement to the participant who uttered it.
|
||||
- If the participant list provides a name for a speaker (e.g., "李雪 (用户)"), use the specific name ("李雪") in the extracted statement, not the generic role ("用户").
|
||||
- Resolve all pronouns to the specific person or entity from the conversation's context.
|
||||
- **For user's statements: always use "用户" (User) as the subject**. Do NOT replace "用户" or "I" with the user's real name or alias. For example, if the user says "I'm John", extract as "用户 is John", not "John is John".
|
||||
- For AI assistant's statements: use "助手" or "AI助手" as the subject.
|
||||
- Resolve all pronouns to the specific person or entity from the conversation's context, but "I"/"我" must always resolve to "用户".
|
||||
- Identify and resolve abstract references to their specific names if mentioned.
|
||||
- Expand abbreviations and acronyms to their full form.
|
||||
{% endif %}
|
||||
@@ -139,13 +141,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料与阿拉伯树胶等粘合
|
||||
示例输出: {
|
||||
"statements": [
|
||||
{
|
||||
"statement": "Sarah Chen 最近一直在尝试水彩画。",
|
||||
"statement": "用户最近一直在尝试水彩画。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen 画了一些花朵。",
|
||||
"statement": "用户画了一些花朵。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
@@ -157,13 +159,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料与阿拉伯树胶等粘合
|
||||
"relevance": "IRRELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen 认为她的水彩画中的色彩组合可以改进。",
|
||||
"statement": "用户认为她的水彩画中的色彩组合可以改进。",
|
||||
"statement_type": "OPINION",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen 真的很喜欢玫瑰和百合。",
|
||||
"statement": "用户真的很喜欢玫瑰和百合。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
@@ -186,13 +188,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料和阿拉伯树胶等粘合
|
||||
示例输出: {
|
||||
"statements": [
|
||||
{
|
||||
"statement": "张曼婷最近在尝试水彩画。",
|
||||
"statement": "用户最近在尝试水彩画。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷画了一些花朵。",
|
||||
"statement": "用户画了一些花朵。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
@@ -204,13 +206,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料和阿拉伯树胶等粘合
|
||||
"relevance": "IRRELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷觉得水彩画的色彩搭配还有提升的空间。",
|
||||
"statement": "用户觉得水彩画的色彩搭配还有提升的空间。",
|
||||
"statement_type": "OPINION",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷很喜欢玫瑰和百合。",
|
||||
"statement": "用户很喜欢玫瑰和百合。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
@@ -233,13 +235,13 @@ User: "I think the color combinations could use some improvement, but I really l
|
||||
Example Output: {
|
||||
"statements": [
|
||||
{
|
||||
"statement": "Sarah Chen has been trying watercolor painting recently.",
|
||||
"statement": "用户 has been trying watercolor painting recently.",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen painted some flowers.",
|
||||
"statement": "用户 painted some flowers.",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
@@ -251,13 +253,13 @@ Example Output: {
|
||||
"relevance": "IRRELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen thinks the color combinations in her watercolor paintings could use some improvement.",
|
||||
"statement": "用户 thinks the color combinations in her watercolor paintings could use some improvement.",
|
||||
"statement_type": "OPINION",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen really likes roses and lilies.",
|
||||
"statement": "用户 really likes roses and lilies.",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
@@ -280,13 +282,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料和阿拉伯树胶等粘合
|
||||
Example Output: {
|
||||
"statements": [
|
||||
{
|
||||
"statement": "张曼婷最近在尝试水彩画。",
|
||||
"statement": "用户最近在尝试水彩画。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷画了一些花朵。",
|
||||
"statement": "用户画了一些花朵。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
@@ -298,13 +300,13 @@ Example Output: {
|
||||
"relevance": "IRRELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷觉得水彩画的色彩搭配还有提升的空间。",
|
||||
"statement": "用户觉得水彩画的色彩搭配还有提升的空间。",
|
||||
"statement_type": "OPINION",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷很喜欢玫瑰和百合。",
|
||||
"statement": "用户很喜欢玫瑰和百合。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
|
||||
@@ -23,6 +23,16 @@ Extract entities and knowledge triplets from the given statement.
|
||||
===Inputs===
|
||||
**Chunk Content:** "{{ chunk_content }}"
|
||||
**Statement:** "{{ statement }}"
|
||||
{% if speaker %}
|
||||
**Speaker:** {{ speaker }}
|
||||
{% if speaker == "assistant" %}
|
||||
{% if language == "zh" %}
|
||||
⚠️ 当前陈述句来自 **AI助手的回复**。AI助手在回复中用来称呼用户的名字是**用户的别名**,不是 AI 助手的别名。但只能提取原文中逐字出现的名字,严禁推测或创造原文中不存在的别名变体。
|
||||
{% else %}
|
||||
⚠️ This statement is from the **AI assistant's reply**. Names the AI uses to address the user are **user's aliases**, NOT the AI assistant's aliases. But only extract names that appear VERBATIM in the text — never infer or fabricate alias variants.
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
{% if ontology_types %}
|
||||
===Ontology Type Guidance===
|
||||
@@ -87,7 +97,17 @@ Extract entities and knowledge triplets from the given statement.
|
||||
* "我叫张三,大家叫我小张" → aliases=["张三", "小张"](张三是第一个,将成为 other_name)
|
||||
* "大家叫我小李,我全名叫李明" → aliases=["小李", "李明"](小李先出现,将成为 other_name)
|
||||
- 空值:如果没有别名,使用 `[]`
|
||||
- 重要:只提取本次对话中明确提到的别名,不要推测或添加未提及的名字
|
||||
- **🚨🚨🚨 严禁幻觉:只提取对话原文中逐字出现的别名,绝对不能推测、衍生或创造任何未在原文中出现的名字。例如,看到"陈思远"不能自行添加"思远大人""远哥""小远"等变体。如果原文没有这些字,就不能出现在 aliases 中。**
|
||||
- **🚨 归属区分:必须严格区分名称的归属对象。默认情况下,用户提到的名字归属用户实体。只有出现明确的第二人称命名表达(如"叫你""给你取名")时,才将名字归属 AI/助手实体。**
|
||||
- **🚨 说话人视角:当 speaker 为 assistant 时,AI 助手用来称呼用户的名字是用户的别名,必须归入用户实体的 aliases,绝对不能归入 AI 助手实体。但同样只能提取原文中逐字出现的称呼,不能推测。**
|
||||
* "我叫陈思远,我给AI取名为远仔" → 用户 aliases=["陈思远"],AI助手 aliases=["远仔"]
|
||||
* "我叫vv" → 用户 aliases=["vv"](没有给AI取名的表达,名字归用户)
|
||||
* [speaker=assistant] "好的,VV" → 用户 aliases=["VV"](AI 在称呼用户,原文中出现了"VV")
|
||||
* [speaker=assistant] "我叫陈仔" → AI助手 aliases=["陈仔"](AI 在自我介绍,这是 AI 的别名)
|
||||
* ❌ 错误:将"远仔"放入用户的 aliases("远仔"是给AI取的名字,不是用户的名字)
|
||||
* ❌ 错误:用户说"我叫vv",却把"vv"放入 AI 助手的 aliases
|
||||
* ❌ 错误:AI 称呼用户为"VV",却把"VV"放入 AI 助手的 aliases
|
||||
* ❌ 错误:原文只有"陈思远",却在 aliases 中添加"思远大人""远哥""小远"等从未出现的变体(这是幻觉)
|
||||
{% else %}
|
||||
- Include: nicknames, full names, abbreviations, alternative names
|
||||
- Order: **The FIRST alias will be used as the user's primary display name (other_name). Put the most important/frequently used name FIRST**
|
||||
@@ -96,7 +116,17 @@ Extract entities and knowledge triplets from the given statement.
|
||||
* "I'm John, people call me Johnny" → aliases=["John", "Johnny"] (John is first, will become other_name)
|
||||
* "People call me Mike, my full name is Michael" → aliases=["Mike", "Michael"] (Mike appears first, will become other_name)
|
||||
- Empty: If no aliases, use `[]`
|
||||
- Important: Only extract aliases explicitly mentioned in current conversation, do not infer or add unmentioned names
|
||||
- **🚨🚨🚨 NO HALLUCINATION: Only extract aliases that appear VERBATIM in the original text. NEVER infer, derive, or fabricate names not present in the text. For example, seeing "John Smith" does NOT allow adding "Johnny", "Smithy", "Mr. Smith" unless those exact strings appear in the conversation.**
|
||||
- **🚨 Ownership distinction: By default, all names mentioned by the user belong to the user entity. Only assign a name to the AI/assistant entity when an explicit second-person naming expression (e.g., "I'll call you", "your name is") is present.**
|
||||
- **🚨 Speaker perspective: When speaker is "assistant", names the AI uses to address the user are the USER's aliases and MUST go into the user entity's aliases, NEVER into the AI assistant entity's aliases. But only extract names that appear verbatim in the text, never infer.**
|
||||
* "I'm Alex, I'll call you Buddy" → User aliases=["Alex"], AI assistant aliases=["Buddy"]
|
||||
* "I'm vv" → User aliases=["vv"] (no AI-naming expression, name belongs to user)
|
||||
* [speaker=assistant] "Sure thing, VV" → User aliases=["VV"] (AI addressing the user, "VV" appears in text)
|
||||
* [speaker=assistant] "I'm Jarvis" → AI assistant aliases=["Jarvis"] (AI self-introduction, this is AI's alias)
|
||||
* ❌ Wrong: putting "Buddy" in user's aliases ("Buddy" is a name for the AI, not the user)
|
||||
* ❌ Wrong: User says "I'm vv" but "vv" is put in AI assistant's aliases
|
||||
* ❌ Wrong: AI calls user "VV" but "VV" is put in AI assistant's aliases
|
||||
* ❌ Wrong: Text only has "John Smith" but aliases include "Johnny", "Smithy" (hallucinated variants)
|
||||
{% endif %}
|
||||
|
||||
|
||||
@@ -122,7 +152,60 @@ Extract entities and knowledge triplets from the given statement.
|
||||
|
||||
|
||||
|
||||
4. **ALIASES ORDER:**
|
||||
4. **AI/ASSISTANT ENTITY SPECIAL HANDLING:**
|
||||
{% if language == "zh" %}
|
||||
- **🚨 默认规则:如果对话中没有出现明确指向 AI/助手的命名表达,则所有名字都归属于用户实体。不要猜测或推断某个名字是给 AI 取的。**
|
||||
- 只有当用户**明确**对 AI/助手进行命名时,才创建 AI/助手实体并将对应名字放入其 aliases
|
||||
- AI/助手实体的 name 字段:使用 "AI助手"
|
||||
- 用户给 AI 取的名字:放入 AI/助手实体的 aliases
|
||||
- **🚨 禁止将用户给 AI 取的名字放入用户实体的 aliases 中**
|
||||
- **必须出现以下明确的命名表达才能判定为给 AI 取名:**「给你取名」「叫你」「称呼你为」「给AI取名」「你的名字是」「以后叫你」「你就叫」「你不叫X了」「你现在叫」等**第二人称(你)或明确指向 AI 的命名句式**
|
||||
- **🚨 "你不叫X了"/"你不叫X,你叫Y" 句式:X 和 Y 都是 AI 的名字(旧名和新名),绝对不是用户的名字。因为句子主语是"你"(AI)。**
|
||||
- **以下情况名字归属用户,不是给 AI 取名:**「我叫」「我的名字是」「叫我」「我是」「大家叫我」「我的英文名是」「我的昵称是」等**第一人称(我)的自我介绍句式**
|
||||
- **🚨 speaker=assistant 时的特殊规则:**
|
||||
* AI 用来称呼用户的名字 → 归入**用户**实体的 aliases(但必须是原文中逐字出现的称呼,不能推测)
|
||||
* AI 自称的名字(如"我叫陈仔""我是你的助手")→ 归入**AI助手**实体的 aliases
|
||||
* 判断依据:AI 说"你叫X"或用 X 称呼用户 → X 是用户别名;AI 说"我叫X"或"我是X" → X 是 AI 别名
|
||||
- 示例:
|
||||
* "我叫vv" → 用户实体: name="用户", aliases=["vv"](第一人称自我介绍,名字归用户)
|
||||
* "我的英文名叫vv" → 用户实体: name="用户", aliases=["vv"](第一人称自我介绍,名字归用户)
|
||||
* "我叫陈思远,我给AI取名为远仔" → 用户实体: name="用户", aliases=["陈思远"];AI实体: name="AI助手", aliases=["远仔"]
|
||||
* "叫你小助,我自己叫老王" → 用户实体: name="用户", aliases=["老王"];AI实体: name="AI助手", aliases=["小助"]
|
||||
* "你不叫远仔了,你现在叫陈仔" → AI实体: name="AI助手", aliases=["陈仔"]("远仔"是AI旧名,"陈仔"是AI新名,都归AI。不要把"远仔"或"陈仔"放入用户的aliases)
|
||||
* [speaker=assistant] "好的VV,今天想干点啥?" → 用户实体: name="用户", aliases=["VV"](AI 在称呼用户,原文中出现了"VV")
|
||||
* [speaker=assistant] "你叫陈思远,我叫陈仔" → 用户实体: name="用户", aliases=["陈思远"];AI实体: name="AI助手", aliases=["陈仔"]
|
||||
* ❌ 错误:用户说"我叫vv",却把"vv"放入 AI 助手的 aliases(没有任何给 AI 取名的表达)
|
||||
* ❌ 错误:AI 称呼用户为"VV",却把"VV"放入 AI 助手的 aliases
|
||||
* ❌ 错误:aliases=["陈思远", "远仔"]("远仔"是给AI取的名字,不是用户的名字)
|
||||
* ❌ 错误:原文只有"陈思远",却在 aliases 中添加"思远大人""远哥""小远"等从未出现的变体(这是幻觉)
|
||||
{% else %}
|
||||
- **🚨 Default rule: If there is NO explicit AI/assistant naming expression in the conversation, ALL names belong to the user entity. Do NOT guess or infer that a name is for the AI.**
|
||||
- Only create an AI/assistant entity when the user **explicitly** names the AI/assistant
|
||||
- AI/assistant entity name field: use "AI Assistant"
|
||||
- Names the user gives to the AI: put in the AI/assistant entity's aliases
|
||||
- **🚨 NEVER put names given to the AI into the user entity's aliases**
|
||||
- **An AI-naming expression MUST be present to assign a name to the AI:** "I'll call you", "your name is", "I name you", "let me call you", "you'll be called", "you're not called X anymore", "your new name is", etc. — **second-person ("you") or explicit AI-directed naming patterns**
|
||||
- **🚨 "You're not called X anymore" / "You're not X, you're Y" pattern: BOTH X and Y are AI's names (old and new). They are NOT user's names. The subject is "you" (the AI).**
|
||||
- **These patterns mean the name belongs to the USER, NOT the AI:** "I'm", "my name is", "call me", "I am", "people call me", "my English name is", "my nickname is", etc. — **first-person ("I"/"me") self-introduction patterns**
|
||||
- **🚨 Special rules when speaker=assistant:**
|
||||
* Names the AI uses to address the user → belong to the **user** entity's aliases (but only extract names that appear verbatim in the text, never infer)
|
||||
* Names the AI uses for itself (e.g., "I'm Jarvis", "I am your assistant") → belong to the **AI assistant** entity's aliases
|
||||
* Rule: AI says "you are X" or calls user X → X is user's alias; AI says "I'm X" or "I am X" → X is AI's alias
|
||||
- Examples:
|
||||
* "I'm vv" → User entity: name="User", aliases=["vv"] (first-person intro, name belongs to user)
|
||||
* "My English name is vv" → User entity: name="User", aliases=["vv"] (first-person intro, name belongs to user)
|
||||
* "I'm Alex, I'll call you Buddy" → User entity: name="User", aliases=["Alex"]; AI entity: name="AI Assistant", aliases=["Buddy"]
|
||||
* "Call yourself Jarvis, my name is Tony" → User entity: name="User", aliases=["Tony"]; AI entity: name="AI Assistant", aliases=["Jarvis"]
|
||||
* "You're not called Jarvis anymore, your new name is Friday" → AI entity: name="AI Assistant", aliases=["Friday"] (both "Jarvis" and "Friday" are AI names, NOT user names)
|
||||
* [speaker=assistant] "Sure thing, VV" → User entity: name="User", aliases=["VV"] (AI addressing the user, "VV" appears in text)
|
||||
* [speaker=assistant] "You're Alex, and I'm Jarvis" → User entity: name="User", aliases=["Alex"]; AI entity: name="AI Assistant", aliases=["Jarvis"]
|
||||
* ❌ Wrong: User says "I'm vv" but "vv" is put in AI assistant's aliases (no AI-naming expression exists)
|
||||
* ❌ Wrong: AI calls user "VV" but "VV" is put in AI assistant's aliases
|
||||
* ❌ Wrong: aliases=["Alex", "Buddy"] ("Buddy" is a name for the AI, not the user)
|
||||
* ❌ Wrong: Text only has "John Smith" but aliases include "Johnny", "Smithy" (hallucinated variants)
|
||||
{% endif %}
|
||||
|
||||
5. **ALIASES ORDER:**
|
||||
{% if language == "zh" %}
|
||||
- 顺序优先级:按出现顺序,先出现的在前
|
||||
{% else %}
|
||||
@@ -202,8 +285,19 @@ Output:
|
||||
{"entity_idx": 0, "name": "Tripod", "type": "Equipment", "description": "Photography equipment accessory", "example": "", "aliases": ["Camera Tripod"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 4 (User vs AI alias distinction - English output):** "I'm Alex, and I'll call you Buddy"
|
||||
Output:
|
||||
{
|
||||
"triplets": [
|
||||
{"subject_name": "User", "subject_id": 0, "predicate": "NAMED", "object_name": "AI Assistant", "object_id": 1, "value": "Buddy"}
|
||||
],
|
||||
"entities": [
|
||||
{"entity_idx": 0, "name": "User", "type": "Person", "description": "The user", "example": "", "aliases": ["Alex"], "is_explicit_memory": false},
|
||||
{"entity_idx": 1, "name": "AI Assistant", "type": "Person", "description": "The user's AI assistant", "example": "", "aliases": ["Buddy"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
{% else %}
|
||||
**Example 1 (English input → Chinese output):** "I plan to travel to Paris next week and visit the Louvre."
|
||||
Output:
|
||||
{
|
||||
"triplets": [
|
||||
@@ -258,6 +352,39 @@ Output:
|
||||
]
|
||||
}
|
||||
|
||||
**Example 6 (用户与AI别名区分 - Chinese):** "我称呼自己为陈思远,我给AI取名为远仔"
|
||||
Output:
|
||||
{
|
||||
"triplets": [
|
||||
{"subject_name": "用户", "subject_id": 0, "predicate": "NAMED", "object_name": "AI助手", "object_id": 1, "value": "远仔"}
|
||||
],
|
||||
"entities": [
|
||||
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["陈思远"], "is_explicit_memory": false},
|
||||
{"entity_idx": 1, "name": "AI助手", "type": "Person", "description": "用户的AI助手", "example": "", "aliases": ["远仔"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 7 (纯用户自我介绍,无AI命名 - Chinese):** "我叫vv"
|
||||
Output:
|
||||
{
|
||||
"triplets": [],
|
||||
"entities": [
|
||||
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["vv"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 8 (给AI改名 - Chinese):** "你不叫远仔了,你现在叫陈仔"
|
||||
Output:
|
||||
{
|
||||
"triplets": [
|
||||
{"subject_name": "用户", "subject_id": 0, "predicate": "NAMED", "object_name": "AI助手", "object_id": 1, "value": "陈仔"}
|
||||
],
|
||||
"entities": [
|
||||
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": [], "is_explicit_memory": false},
|
||||
{"entity_idx": 1, "name": "AI助手", "type": "Person", "description": "用户的AI助手", "example": "", "aliases": ["陈仔"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
{% endif %}
|
||||
===End of Examples===
|
||||
@@ -279,4 +406,12 @@ Output:
|
||||
- **⚠️ ALIASES ORDER: preserve temporal order of appearance**
|
||||
- **🚨 MANDATORY FIELD: EVERY entity MUST include "aliases" field, even if empty array []**
|
||||
|
||||
**Output JSON structure:**
|
||||
```json
|
||||
{
|
||||
"triplets": [...],
|
||||
"entities": [...]
|
||||
}
|
||||
```
|
||||
|
||||
{{ json_schema }}
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
===Task===
|
||||
Extract user metadata changes from the following conversation statements spoken by the user.
|
||||
|
||||
{% if language == "zh" %}
|
||||
**"三度原则"判断标准:**
|
||||
- 复用度:该信息是否会被多个功能模块使用?
|
||||
- 约束度:该信息是否会影响系统行为?
|
||||
- 时效性:该信息是长期稳定的还是临时的?仅提取长期稳定信息。
|
||||
|
||||
**提取规则:**
|
||||
- **只提取关于"用户本人"的画像信息**,忽略用户提到的第三方人物(如朋友、同事、家人)的信息
|
||||
- 仅提取文本中明确提到的信息,不要推测
|
||||
- **输出语言必须与输入文本的语言一致**(输入中文则输出中文值,输入英文则输出英文值)
|
||||
|
||||
**增量模式(重要):**
|
||||
你只需要输出**本次对话引起的变更操作**,不要输出完整的元数据。每个变更是一个对象,包含:
|
||||
- `field_path`:字段路径,用点号分隔(如 `profile.role`、`profile.expertise`)
|
||||
- `action`:操作类型
|
||||
* `set`:新增或修改一个字段的值
|
||||
* `remove`:移除一个字段的值
|
||||
- `value`:字段的新值(`action="set"` 时必填,`action="remove"` 时填要移除的元素值)
|
||||
* 所有字段均为列表类型,每个元素一条变更记录
|
||||
|
||||
**判断规则:**
|
||||
- 用户提到新信息 → `action="set"`,填入新值
|
||||
- 用户明确否定已有信息(如"我不再做老师了"、"我已经不学Python了")→ `action="remove"`,`value` 填要移除的元素值
|
||||
- 如果本次对话没有任何可提取的变更,返回空的 `metadata_changes` 数组 `[]`
|
||||
- **不要为未被提及的字段生成任何变更操作**
|
||||
|
||||
{% if existing_metadata %}
|
||||
**已有元数据(仅供参考,用于判断是否需要变更):**
|
||||
请对比已有数据和用户最新发言,只输出差异部分的变更操作。
|
||||
- 如果用户说的信息和已有数据一致,不需要输出变更
|
||||
- 如果用户否定了已有数据中的某个值,输出 `remove` 操作
|
||||
- 如果用户提到了新信息,输出 `set` 操作
|
||||
{% endif %}
|
||||
|
||||
**字段说明:**
|
||||
- profile.role:用户的职业或角色(列表),如 教师、医生、后端工程师,一个人可以有多个角色
|
||||
- profile.domain:用户所在领域(列表),如 教育、医疗、软件开发,一个人可以涉及多个领域
|
||||
- profile.expertise:用户擅长的技能或工具(列表),如 Python、心理咨询、高中物理
|
||||
- profile.interests:用户主动表达兴趣的话题或领域标签(列表)
|
||||
|
||||
**用户别名变更(增量模式):**
|
||||
- **aliases_to_add**:本次新发现的用户别名,包括:
|
||||
* 用户主动自我介绍:如"我叫张三"、"我的名字是XX"、"我的网名是XX"
|
||||
* 他人对用户的称呼:如"同事叫我陈哥"、"大家叫我小张"、"领导叫我老陈"
|
||||
* 只提取原文中逐字出现的名字,严禁推测或创造
|
||||
* 禁止提取:用户给 AI 取的名字、第三方人物自身的名字、"用户"/"我" 等占位词
|
||||
* 如果没有新别名,返回空数组 `[]`
|
||||
- **aliases_to_remove**:用户明确否认的别名,包括:
|
||||
* 用户说"我不叫XX了"、"别叫我XX"、"我改名了,不叫XX" → 将 XX 放入此数组
|
||||
* **严格限制**:只将用户原文中**逐字提到**的被否认名字放入,不要推断关联的其他别名
|
||||
* 如果没有要移除的别名,返回空数组 `[]`
|
||||
{% if existing_aliases %}
|
||||
- 已有别名:{{ existing_aliases | tojson }}(仅供参考,不需要在输出中重复)
|
||||
{% endif %}
|
||||
{% else %}
|
||||
**"Three-Degree Principle" criteria:**
|
||||
- Reusability: Will this information be used by multiple functional modules?
|
||||
- Constraint: Will this information affect system behavior?
|
||||
- Timeliness: Is this information long-term stable or temporary? Only extract long-term stable information.
|
||||
|
||||
**Extraction rules:**
|
||||
- **Only extract profile information about the user themselves**, ignore information about third parties (friends, colleagues, family) mentioned by the user
|
||||
- Only extract information explicitly mentioned in the text, do not speculate
|
||||
- **Output language must match the input text language**
|
||||
|
||||
**Incremental mode (important):**
|
||||
You should only output **the change operations caused by this conversation**, not the complete metadata. Each change is an object containing:
|
||||
- `field_path`: Field path separated by dots (e.g. `profile.role`, `profile.expertise`)
|
||||
- `action`: Operation type
|
||||
* `set`: Add or update a field value
|
||||
* `remove`: Remove a field value
|
||||
- `value`: The new value for the field (required when `action="set"`, for `action="remove"` fill in the element value to remove)
|
||||
* All fields are list types, one change record per element
|
||||
|
||||
**Decision rules:**
|
||||
- User mentions new information → `action="set"`, fill in the new value
|
||||
- User explicitly negates existing info (e.g. "I'm no longer a teacher", "I stopped learning Python") → `action="remove"`, `value` is the element to remove
|
||||
- If this conversation has no extractable changes, return an empty `metadata_changes` array `[]`
|
||||
- **Do NOT generate any change operations for fields not mentioned in the conversation**
|
||||
|
||||
{% if existing_metadata %}
|
||||
**Existing metadata (for reference only, to determine if changes are needed):**
|
||||
Compare existing data with the user's latest statements, and only output change operations for the differences.
|
||||
- If the user's statement matches existing data, no change is needed
|
||||
- If the user negates a value in existing data, output a `remove` operation
|
||||
- If the user mentions new information, output a `set` operation
|
||||
{% endif %}
|
||||
|
||||
**Field descriptions:**
|
||||
- profile.role: User's occupation or role (list), e.g. teacher, doctor, software engineer. A person can have multiple roles
|
||||
- profile.domain: User's domain (list), e.g. education, healthcare, software development. A person can span multiple domains
|
||||
- profile.expertise: User's skills or tools (list), e.g. Python, counseling, physics
|
||||
- profile.interests: Topics or domain tags the user actively expressed interest in (list)
|
||||
|
||||
**User alias changes (incremental mode):**
|
||||
- **aliases_to_add**: Newly discovered user aliases from this conversation, including:
|
||||
* User self-introductions: e.g. "I'm John", "My name is XX", "My username is XX"
|
||||
* How others address the user: e.g. "My colleagues call me Johnny", "People call me Mike"
|
||||
* Only extract names that appear VERBATIM in the text — never infer or fabricate
|
||||
* Do NOT extract: names the user gives to the AI, third-party people's own names, placeholder words like "User"/"I"
|
||||
* If no new aliases, return empty array `[]`
|
||||
- **aliases_to_remove**: Aliases the user explicitly denies, including:
|
||||
* User says "Don't call me XX anymore", "I'm not called XX", "I changed my name from XX" → put XX in this array
|
||||
* **Strict rule**: Only include the exact name the user **verbatim mentions** as denied. Do NOT infer or remove related aliases
|
||||
* If no aliases to remove, return empty array `[]`
|
||||
{% if existing_aliases %}
|
||||
- Existing aliases: {{ existing_aliases | tojson }} (for reference only, do not repeat in output)
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
===User Statements===
|
||||
{% for stmt in statements %}
|
||||
- {{ stmt }}
|
||||
{% endfor %}
|
||||
|
||||
{% if existing_metadata %}
|
||||
===Existing User Metadata===
|
||||
```json
|
||||
{{ existing_metadata | tojson }}
|
||||
```
|
||||
{% endif %}
|
||||
|
||||
===Output Format===
|
||||
Return a JSON object with the following structure:
|
||||
```json
|
||||
{
|
||||
"metadata_changes": [
|
||||
{"field_path": "profile.role", "action": "set", "value": "后端工程师"},
|
||||
{"field_path": "profile.expertise", "action": "set", "value": "Python"},
|
||||
{"field_path": "profile.expertise", "action": "remove", "value": "Java"}
|
||||
],
|
||||
"aliases_to_add": [],
|
||||
"aliases_to_remove": []
|
||||
}
|
||||
```
|
||||
|
||||
{{ json_schema }}
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Optional, TypeVar
|
||||
from typing import Any, Dict, List, Optional, TypeVar
|
||||
|
||||
from langchain_aws import ChatBedrock
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
@@ -9,11 +9,12 @@ from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_ollama import OllamaLLM
|
||||
from langchain_openai import ChatOpenAI, OpenAI
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
from app.core.models.compatible_chat import CompatibleChatOpenAI
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -24,7 +25,11 @@ class RedBearModelConfig(BaseModel):
|
||||
provider: str
|
||||
api_key: str
|
||||
base_url: Optional[str] = None
|
||||
capability: List[str] = Field(default_factory=list) # 模型能力列表,驱动所有能力开关
|
||||
is_omni: bool = False # 是否为 Omni 模型
|
||||
deep_thinking: bool = False # 是否启用深度思考模式
|
||||
thinking_budget_tokens: Optional[int] = None # 深度思考 token 预算
|
||||
json_output: bool = False # 是否强制 JSON 输出
|
||||
# 请求超时时间(秒)- 默认120秒以支持复杂的LLM调用,可通过环境变量 LLM_TIMEOUT 配置
|
||||
timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0")))
|
||||
# 最大重试次数 - 默认2次以避免过长等待,可通过环境变量 LLM_MAX_RETRIES 配置
|
||||
@@ -32,6 +37,23 @@ class RedBearModelConfig(BaseModel):
|
||||
concurrency: int = 5 # 并发限流
|
||||
extra_params: Dict[str, Any] = {}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _resolve_capabilities(self) -> "RedBearModelConfig":
|
||||
from app.core.logging_config import get_business_logger
|
||||
logger = get_business_logger()
|
||||
if self.deep_thinking and "thinking" not in self.capability:
|
||||
logger.warning(
|
||||
f"模型 {self.model_name} 不支持深度思考(capability 中无 'thinking'),已自动关闭 deep_thinking"
|
||||
)
|
||||
self.deep_thinking = False
|
||||
self.thinking_budget_tokens = None
|
||||
if self.json_output and "json_output" not in self.capability:
|
||||
logger.warning(
|
||||
f"模型 {self.model_name} 不支持 JSON 输出(capability 中无 'json_output'),已自动关闭 json_output"
|
||||
)
|
||||
self.json_output = False
|
||||
return self
|
||||
|
||||
|
||||
class RedBearModelFactory:
|
||||
"""模型工厂类"""
|
||||
@@ -44,7 +66,7 @@ class RedBearModelFactory:
|
||||
# 打印供应商信息用于调试
|
||||
from app.core.logging_config import get_business_logger
|
||||
logger = get_business_logger()
|
||||
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}, is_omni: {config.is_omni}")
|
||||
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}, is_omni: {config.is_omni}, deep_thinking: {config.deep_thinking}")
|
||||
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||
@@ -58,7 +80,7 @@ class RedBearModelFactory:
|
||||
write=60.0,
|
||||
pool=10.0,
|
||||
)
|
||||
return {
|
||||
params: Dict[str, Any] = {
|
||||
"model": config.model_name,
|
||||
"base_url": config.base_url,
|
||||
"api_key": config.api_key,
|
||||
@@ -66,6 +88,24 @@ class RedBearModelFactory:
|
||||
"max_retries": config.max_retries,
|
||||
**config.extra_params
|
||||
}
|
||||
# 流式模式下启用 stream_usage 以获取 token 统计
|
||||
is_streaming = bool(config.extra_params.get("streaming"))
|
||||
if is_streaming:
|
||||
params["stream_usage"] = True
|
||||
# 支持 thinking 的模型始终传 enable_thinking,关闭时显式传 False 避免模型默认开启思考
|
||||
if "thinking" in config.capability:
|
||||
extra_body = params.setdefault("extra_body", {})
|
||||
if config.deep_thinking:
|
||||
extra_body["enable_thinking"] = False
|
||||
if is_streaming:
|
||||
extra_body["enable_thinking"] = True
|
||||
if config.thinking_budget_tokens:
|
||||
extra_body["thinking_budget"] = config.thinking_budget_tokens
|
||||
# JSON 输出模式
|
||||
if config.json_output:
|
||||
model_kwargs = params.setdefault("model_kwargs", {})
|
||||
model_kwargs["response_format"] = {"type": "json_object"}
|
||||
return params
|
||||
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
|
||||
# 使用 httpx.Timeout 对象来设置详细的超时配置
|
||||
@@ -78,7 +118,7 @@ class RedBearModelFactory:
|
||||
write=60.0, # 写入超时:60秒
|
||||
pool=10.0, # 连接池超时:10秒
|
||||
)
|
||||
return {
|
||||
params: Dict[str, Any] = {
|
||||
"model": config.model_name,
|
||||
"base_url": config.base_url,
|
||||
"api_key": config.api_key,
|
||||
@@ -86,16 +126,55 @@ class RedBearModelFactory:
|
||||
"max_retries": config.max_retries,
|
||||
**config.extra_params
|
||||
}
|
||||
# 流式模式下启用 stream_usage 以获取 token 统计
|
||||
is_streaming = bool(config.extra_params.get("streaming"))
|
||||
if is_streaming:
|
||||
params["stream_usage"] = True
|
||||
# 支持 thinking 的模型始终传 enable_thinking,关闭时显式传 False 避免模型默认开启思考
|
||||
if "thinking" in config.capability:
|
||||
# VOLCANO 深度思考仅流式支持
|
||||
if provider == ModelProvider.VOLCANO:
|
||||
thinking_config: Dict[str, Any] = {"type": "enabled" if config.deep_thinking else "disabled"}
|
||||
if config.deep_thinking and config.thinking_budget_tokens:
|
||||
thinking_config["budget_tokens"] = config.thinking_budget_tokens
|
||||
params["extra_body"] = {"thinking": thinking_config}
|
||||
else:
|
||||
extra_body = params.setdefault("extra_body", {})
|
||||
if config.deep_thinking:
|
||||
extra_body["enable_thinking"] = False
|
||||
if is_streaming:
|
||||
extra_body["enable_thinking"] = True
|
||||
if config.thinking_budget_tokens:
|
||||
extra_body["thinking_budget"] = config.thinking_budget_tokens
|
||||
# JSON 输出模式
|
||||
if config.json_output:
|
||||
model_kwargs = params.setdefault("model_kwargs", {})
|
||||
# VOLCANO 模型不支持 response_format,JSON 输出由 system prompt 注入实现
|
||||
if provider != ModelProvider.VOLCANO:
|
||||
model_kwargs["response_format"] = {"type": "json_object"}
|
||||
return params
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
# DashScope (通义千问) 使用自己的参数格式
|
||||
# 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数
|
||||
# 只支持: model, dashscope_api_key, max_retries, client
|
||||
return {
|
||||
params = {
|
||||
"model": config.model_name,
|
||||
"dashscope_api_key": config.api_key,
|
||||
"max_retries": config.max_retries,
|
||||
**config.extra_params
|
||||
}
|
||||
# 支持 thinking 的模型始终传 enable_thinking,关闭时显式传 False 避免模型默认开启思考
|
||||
if "thinking" in config.capability:
|
||||
is_streaming = bool(config.extra_params.get("streaming"))
|
||||
model_kwargs = params.setdefault("model_kwargs", {})
|
||||
if config.deep_thinking:
|
||||
model_kwargs["enable_thinking"] = False
|
||||
if is_streaming:
|
||||
model_kwargs["enable_thinking"] = True
|
||||
model_kwargs["incremental_output"] = True
|
||||
if config.thinking_budget_tokens:
|
||||
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
|
||||
if config.json_output:
|
||||
model_kwargs = params.setdefault("model_kwargs", {})
|
||||
model_kwargs["response_format"] = {"type": "json_object"}
|
||||
return params
|
||||
elif provider == ModelProvider.BEDROCK:
|
||||
# Bedrock 使用 AWS 凭证
|
||||
# api_key 格式: "access_key_id:secret_access_key" 或只是 access_key_id
|
||||
@@ -134,6 +213,17 @@ class RedBearModelFactory:
|
||||
elif "region_name" not in params:
|
||||
params["region_name"] = "us-east-1" # 默认区域
|
||||
|
||||
# 深度思考模式:Claude 3.7 Sonnet 等支持思考的模型
|
||||
# 通过 additional_model_request_fields 传递 thinking 块,关闭时不传(Bedrock 无 disabled 选项)
|
||||
if config.deep_thinking:
|
||||
budget = config.thinking_budget_tokens or 10000
|
||||
params["additional_model_request_fields"] = {
|
||||
"thinking": {"type": "enabled", "budget_tokens": budget}
|
||||
}
|
||||
# JSON 输出模式
|
||||
if config.json_output:
|
||||
model_kwargs = params.setdefault("model_kwargs", {})
|
||||
model_kwargs["response_format"] = {"type": "json_object"}
|
||||
return params
|
||||
else:
|
||||
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
@@ -145,10 +235,15 @@ class RedBearModelFactory:
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
return {
|
||||
"model": config.model_name,
|
||||
# "base_url": config.base_url,
|
||||
"jina_api_key": config.api_key,
|
||||
**config.extra_params
|
||||
}
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
return {
|
||||
"model": config.model_name,
|
||||
"dashscope_api_key": config.api_key,
|
||||
**config.extra_params
|
||||
}
|
||||
else:
|
||||
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
@@ -157,16 +252,19 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
|
||||
"""根据模型提供商获取对应的模型类"""
|
||||
provider = config.provider.lower()
|
||||
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||
# dashscope的omni模型 和 volcano模型使用
|
||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||
return ChatOpenAI
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.VOLCANO]:
|
||||
if type == ModelType.LLM:
|
||||
return OpenAI
|
||||
elif type == ModelType.CHAT:
|
||||
return ChatOpenAI
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
return CompatibleChatOpenAI
|
||||
if provider == ModelProvider.VOLCANO:
|
||||
return CompatibleChatOpenAI
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
return CompatibleChatOpenAI
|
||||
# if type == ModelType.LLM:
|
||||
# return OpenAI
|
||||
# elif type == ModelType.CHAT:
|
||||
# return CompatibleChatOpenAI
|
||||
# else:
|
||||
# raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
return ChatTongyi
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
@@ -202,6 +300,9 @@ def get_provider_rerank_class(provider: str):
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
from langchain_community.document_compressors import JinaRerank
|
||||
return JinaRerank
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
from langchain_community.document_compressors.dashscope_rerank import DashScopeRerank
|
||||
return DashScopeRerank
|
||||
# elif provider == ModelProvider.OLLAMA:
|
||||
# from langchain_ollama import OllamaEmbeddings
|
||||
# return OllamaEmbeddings
|
||||
|
||||
73
api/app/core/models/compatible_chat.py
Normal file
73
api/app/core/models/compatible_chat.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
火山引擎 ChatOpenAI 扩展
|
||||
|
||||
ChatOpenAI 在解析流式 SSE 时只取 delta.content,会丢弃 delta.reasoning_content。
|
||||
此类仅重写 _convert_chunk_to_generation_chunk,将 reasoning_content 补入 additional_kwargs。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
|
||||
class CompatibleChatOpenAI(ChatOpenAI):
|
||||
"""火山和千问的omni兼容模型,支持深度思考内容(reasoning_content)的流式和非流式透传。
|
||||
|
||||
同时修复 json_output + tools 同时使用时 langchain_openai 强制走 .parse()/.stream()
|
||||
导致 strict 校验报错的问题:有工具时从 payload 中移除 response_format,
|
||||
让父类走普通 .create()/.astream() 路径,JSON 输出由 system prompt 指令保证。
|
||||
"""
|
||||
|
||||
def _get_request_payload(
|
||||
self,
|
||||
input_: list[BaseMessage],
|
||||
*,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
||||
# 有工具时 langchain_openai 检测到 response_format 会切换到 .parse()/.stream()
|
||||
# 接口,OpenAI SDK 要求此时所有工具必须 strict=True,动态生成的工具不满足。
|
||||
# 移除 response_format,让父类走普通路径,JSON 输出由 system prompt 指令保证。
|
||||
if payload.get("tools") and "response_format" in payload:
|
||||
payload.pop("response_format")
|
||||
return payload
|
||||
|
||||
def _create_chat_result(self, response: Union[dict, Any], generation_info: Optional[dict] = None) -> ChatResult:
|
||||
result = super()._create_chat_result(response, generation_info)
|
||||
# 将非流式响应中的 reasoning_content 补入 additional_kwargs
|
||||
choices = response.choices if hasattr(response, "choices") else response.get("choices", [])
|
||||
if choices:
|
||||
message = choices[0].message if hasattr(choices[0], "message") else choices[0].get("message", {})
|
||||
reasoning = (
|
||||
getattr(message, "reasoning_content", None)
|
||||
or (message.get("reasoning_content") if isinstance(message, dict) else None)
|
||||
)
|
||||
if reasoning and result.generations:
|
||||
result.generations[0].message.additional_kwargs["reasoning_content"] = reasoning
|
||||
return result
|
||||
|
||||
def _convert_chunk_to_generation_chunk(
|
||||
self,
|
||||
chunk: dict,
|
||||
default_chunk_class: type,
|
||||
base_generation_info: Optional[dict],
|
||||
) -> Optional[ChatGenerationChunk]:
|
||||
gen_chunk = super()._convert_chunk_to_generation_chunk(
|
||||
chunk, default_chunk_class, base_generation_info
|
||||
)
|
||||
if gen_chunk is None:
|
||||
return None
|
||||
|
||||
# 从原始 chunk 中提取 reasoning_content
|
||||
choices = chunk.get("choices") or chunk.get("chunk", {}).get("choices", [])
|
||||
if choices:
|
||||
delta = choices[0].get("delta") or {}
|
||||
reasoning: Any = delta.get("reasoning_content")
|
||||
if reasoning:
|
||||
gen_chunk.message.additional_kwargs["reasoning_content"] = reasoning
|
||||
|
||||
return gen_chunk
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Union
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from app.core.models.base import RedBearModelConfig, get_provider_embedding_class, RedBearModelFactory
|
||||
@@ -22,11 +22,38 @@ class RedBearEmbeddings(Embeddings):
|
||||
self._model = self._create_model(config)
|
||||
self._client = None
|
||||
|
||||
def _create_model(self, config: RedBearModelConfig) -> Embeddings:
|
||||
@staticmethod
|
||||
def _create_model(config: RedBearModelConfig) -> Embeddings:
|
||||
"""根据配置创建 LangChain 模型"""
|
||||
embedding_class = get_provider_embedding_class(config.provider)
|
||||
model_params = RedBearModelFactory.get_model_params(config)
|
||||
return embedding_class(**model_params)
|
||||
provider = config.provider.lower()
|
||||
# Embedding models only need connection params, never LLM-specific ones
|
||||
# (e.g. enable_thinking, model_kwargs) — build params directly.
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
import httpx
|
||||
params = {
|
||||
"model": config.model_name,
|
||||
"base_url": config.base_url,
|
||||
"api_key": config.api_key,
|
||||
"timeout": httpx.Timeout(timeout=config.timeout, connect=60.0),
|
||||
"max_retries": config.max_retries
|
||||
}
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
params = {
|
||||
"model": config.model_name,
|
||||
"dashscope_api_key": config.api_key,
|
||||
"max_retries": config.max_retries,
|
||||
}
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
params = {
|
||||
"model": config.model_name,
|
||||
"base_url": config.base_url,
|
||||
}
|
||||
elif provider == ModelProvider.BEDROCK:
|
||||
params = RedBearModelFactory.get_model_params(config)
|
||||
else:
|
||||
params = RedBearModelFactory.get_model_params(config)
|
||||
return embedding_class(**params)
|
||||
|
||||
def _create_volcano_client(self, config: RedBearModelConfig):
|
||||
"""创建火山引擎客户端"""
|
||||
|
||||
@@ -76,5 +76,9 @@ class RedBearRerank(BaseDocumentCompressor):
|
||||
from langchain_community.document_compressors import JinaRerank
|
||||
model_instance: JinaRerank = self._model
|
||||
return model_instance.rerank(documents=documents, query=query, top_n=top_n)
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
from langchain_community.document_compressors.dashscope_rerank import DashScopeRerank
|
||||
model_instance: DashScopeRerank = self._model
|
||||
return model_instance.rerank(documents=documents, query=query, top_n=top_n)
|
||||
else:
|
||||
raise ValueError(f"不支持的模型提供商: {provider}")
|
||||
|
||||
@@ -6,11 +6,13 @@ models:
|
||||
description: AI21 Labs大语言模型,completion生成模式,256000上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: bedrock
|
||||
|
||||
- name: amazon nova
|
||||
type: llm
|
||||
provider: bedrock
|
||||
@@ -19,6 +21,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -27,6 +30,7 @@ models:
|
||||
- stream-tool-call
|
||||
- vision
|
||||
logo: bedrock
|
||||
|
||||
- name: anthropic claude
|
||||
type: llm
|
||||
provider: bedrock
|
||||
@@ -35,6 +39,8 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -44,13 +50,15 @@ models:
|
||||
- stream-tool-call
|
||||
- document
|
||||
logo: bedrock
|
||||
|
||||
- name: cohere
|
||||
type: llm
|
||||
provider: bedrock
|
||||
description: Cohere大语言模型,支持智能体思考、工具调用、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -58,6 +66,7 @@ models:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
logo: bedrock
|
||||
|
||||
- name: deepseek
|
||||
type: llm
|
||||
provider: bedrock
|
||||
@@ -66,6 +75,8 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -74,39 +85,45 @@ models:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
logo: bedrock
|
||||
|
||||
- name: meta
|
||||
type: llm
|
||||
provider: bedrock
|
||||
description: Meta Llama大语言模型,支持智能体思考、工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- tool-call
|
||||
logo: bedrock
|
||||
|
||||
- name: mistral
|
||||
type: llm
|
||||
provider: bedrock
|
||||
description: Mistral AI大语言模型,支持智能体思考、工具调用,32000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- tool-call
|
||||
logo: bedrock
|
||||
|
||||
- name: openai
|
||||
type: llm
|
||||
provider: bedrock
|
||||
description: OpenAI大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -114,13 +131,15 @@ models:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
logo: bedrock
|
||||
|
||||
- name: qwen
|
||||
type: llm
|
||||
provider: bedrock
|
||||
description: Qwen大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -128,6 +147,7 @@ models:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
logo: bedrock
|
||||
|
||||
- name: amazon.rerank-v1:0
|
||||
type: rerank
|
||||
provider: bedrock
|
||||
@@ -139,6 +159,7 @@ models:
|
||||
tags:
|
||||
- 重排序模型
|
||||
logo: bedrock
|
||||
|
||||
- name: cohere.rerank-v3-5:0
|
||||
type: rerank
|
||||
provider: bedrock
|
||||
@@ -150,6 +171,7 @@ models:
|
||||
tags:
|
||||
- 重排序模型
|
||||
logo: bedrock
|
||||
|
||||
- name: amazon.nova-2-multimodal-embeddings-v1:0
|
||||
type: embedding
|
||||
provider: bedrock
|
||||
@@ -163,6 +185,7 @@ models:
|
||||
- 文本嵌入模型
|
||||
- vision
|
||||
logo: bedrock
|
||||
|
||||
- name: amazon.titan-embed-text-v1
|
||||
type: embedding
|
||||
provider: bedrock
|
||||
@@ -174,6 +197,7 @@ models:
|
||||
tags:
|
||||
- 文本嵌入模型
|
||||
logo: bedrock
|
||||
|
||||
- name: amazon.titan-embed-text-v2:0
|
||||
type: embedding
|
||||
provider: bedrock
|
||||
@@ -185,6 +209,7 @@ models:
|
||||
tags:
|
||||
- 文本嵌入模型
|
||||
logo: bedrock
|
||||
|
||||
- name: cohere.embed-english-v3
|
||||
type: embedding
|
||||
provider: bedrock
|
||||
@@ -196,6 +221,7 @@ models:
|
||||
tags:
|
||||
- 文本嵌入模型
|
||||
logo: bedrock
|
||||
|
||||
- name: cohere.embed-multilingual-v3
|
||||
type: embedding
|
||||
provider: bedrock
|
||||
|
||||
@@ -6,91 +6,109 @@ models:
|
||||
description: DeepSeek-R1-Distill-Qwen-14B大语言模型,支持智能体思考,32000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: deepseek-r1-distill-qwen-32b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: DeepSeek-R1-Distill-Qwen-32B大语言模型,支持智能体思考,32000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: deepseek-r1
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: DeepSeek-R1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: deepseek-v3.1
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: DeepSeek-V3.1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: deepseek-v3.2-exp
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: DeepSeek-V3.2-exp实验版大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: deepseek-v3.2
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: DeepSeek-V3.2大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: deepseek-v3
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: DeepSeek-V3大语言模型,支持智能体思考,64000上下文窗口,对话模式,支持文本与JSON格式输出
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: farui-plus
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: farui-plus大语言模型,支持多工具调用、智能体思考、流式工具调用,12288上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -98,13 +116,15 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: glm-4.7
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: GLM-4.7大语言模型,支持多工具调用、智能体思考、流式工具调用,202752超大上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -112,6 +132,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qvq-max-latest
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -119,7 +140,9 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -127,6 +150,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qvq-max
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -134,7 +158,9 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -142,6 +168,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-coder-turbo-0919
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -155,13 +182,16 @@ models:
|
||||
- 代码模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-max-latest
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-max-latest大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -169,6 +199,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-max-longcontext
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -183,13 +214,15 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-max
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-max大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -197,6 +230,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-mt-plus
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -210,6 +244,7 @@ models:
|
||||
- 翻译模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-mt-turbo
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -223,6 +258,7 @@ models:
|
||||
- 翻译模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-plus-0112
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -237,6 +273,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-plus-0125
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -251,6 +288,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-plus-0723
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -265,6 +303,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-plus-0806
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -279,6 +318,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-plus-0919
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -293,6 +333,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-plus-1125
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -307,6 +348,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-plus-1127
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -321,6 +363,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-plus-1220
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -335,6 +378,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-vl-max
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -342,8 +386,9 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -352,6 +397,7 @@ models:
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-vl-plus-0809
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -359,8 +405,8 @@ models:
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -369,6 +415,7 @@ models:
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-vl-plus-2025-01-02
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -376,8 +423,8 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -386,6 +433,7 @@ models:
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-vl-plus-2025-01-25
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -393,8 +441,8 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -403,6 +451,7 @@ models:
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-vl-plus-latest
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -410,8 +459,9 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -420,6 +470,7 @@ models:
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-vl-plus
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -427,8 +478,9 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -437,13 +489,15 @@ models:
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen2.5-0.5b-instruct
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen2.5-0.5b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -451,13 +505,16 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-14b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-14b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -465,13 +522,15 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-235b-a22b-instruct-2507
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-235b-a22b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -479,13 +538,16 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-235b-a22b-thinking-2507
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-235b-a22b-thinking-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -493,13 +555,16 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-235b-a22b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-235b-a22b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -507,13 +572,15 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-30b-a3b-instruct-2507
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-30b-a3b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -521,13 +588,16 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-30b-a3b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-30b-a3b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -535,13 +605,16 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-32b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-32b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -549,13 +622,16 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-4b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-4b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -563,13 +639,16 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-8b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-8b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -577,65 +656,78 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-coder-30b-a3b-instruct
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-coder-30b-a3b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-coder-480b-a35b-instruct
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-coder-480b-a35b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-coder-plus-2025-09-23
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-coder-plus-2025-09-23大语言模型,支持智能体思考,1000000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-coder-plus
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-coder-plus大语言模型,支持智能体思考,1000000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-max-2025-09-23
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-max-2025-09-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -644,13 +736,16 @@ models:
|
||||
- stream-tool-call
|
||||
- 联网搜索
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-max-2026-01-23
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-max-2026-01-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -659,13 +754,16 @@ models:
|
||||
- stream-tool-call
|
||||
- 联网搜索
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-max-preview
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-max-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -673,13 +771,16 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-max
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-max大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -688,13 +789,15 @@ models:
|
||||
- stream-tool-call
|
||||
- 联网搜索
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-next-80b-a3b-instruct
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-next-80b-a3b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -702,13 +805,16 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-next-80b-a3b-thinking
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-next-80b-a3b-thinking大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -716,6 +822,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-omni-flash-2025-12-01
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -723,9 +830,11 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- audio
|
||||
- vision
|
||||
- video
|
||||
- audio
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -735,6 +844,7 @@ models:
|
||||
- video
|
||||
- audio
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-vl-235b-a22b-instruct
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -742,8 +852,9 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -754,6 +865,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-vl-235b-a22b-thinking
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -761,8 +873,10 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -773,6 +887,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-vl-30b-a3b-instruct
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -780,8 +895,9 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -792,6 +908,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-vl-30b-a3b-thinking
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -799,8 +916,10 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -811,6 +930,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-vl-flash
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -818,8 +938,10 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -830,6 +952,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-vl-plus-2025-09-23
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -837,8 +960,10 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -847,6 +972,7 @@ models:
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-vl-plus
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -854,8 +980,10 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -864,45 +992,55 @@ models:
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwq-32b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwq-32b大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwq-plus-0305
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwq-plus-0305大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwq-plus
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwq-plus大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: gte-rerank-v2
|
||||
type: rerank
|
||||
provider: dashscope
|
||||
@@ -914,6 +1052,7 @@ models:
|
||||
tags:
|
||||
- 重排序模型
|
||||
logo: dashscope
|
||||
|
||||
- name: gte-rerank
|
||||
type: rerank
|
||||
provider: dashscope
|
||||
@@ -925,6 +1064,7 @@ models:
|
||||
tags:
|
||||
- 重排序模型
|
||||
logo: dashscope
|
||||
|
||||
- name: multimodal-embedding-v1
|
||||
type: embedding
|
||||
provider: dashscope
|
||||
@@ -932,13 +1072,14 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 嵌入模型
|
||||
- 多模态模型
|
||||
- vision
|
||||
logo: dashscope
|
||||
|
||||
- name: text-embedding-v1
|
||||
type: embedding
|
||||
provider: dashscope
|
||||
@@ -951,6 +1092,7 @@ models:
|
||||
- 嵌入模型
|
||||
- 文本嵌入
|
||||
logo: dashscope
|
||||
|
||||
- name: text-embedding-v2
|
||||
type: embedding
|
||||
provider: dashscope
|
||||
@@ -963,6 +1105,7 @@ models:
|
||||
- 嵌入模型
|
||||
- 文本嵌入
|
||||
logo: dashscope
|
||||
|
||||
- name: text-embedding-v3
|
||||
type: embedding
|
||||
provider: dashscope
|
||||
@@ -975,6 +1118,7 @@ models:
|
||||
- 嵌入模型
|
||||
- 文本嵌入
|
||||
logo: dashscope
|
||||
|
||||
- name: text-embedding-v4
|
||||
type: embedding
|
||||
provider: dashscope
|
||||
@@ -986,4 +1130,4 @@ models:
|
||||
tags:
|
||||
- 嵌入模型
|
||||
- 文本嵌入
|
||||
logo: dashscope
|
||||
logo: dashscope
|
||||
|
||||
@@ -10,6 +10,7 @@ models:
|
||||
- vision
|
||||
- audio
|
||||
- video
|
||||
- json_output
|
||||
is_omni: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -20,13 +21,15 @@ models:
|
||||
- audio
|
||||
- video
|
||||
logo: openai
|
||||
|
||||
- name: gpt-3.5-turbo-0125
|
||||
type: llm
|
||||
provider: openai
|
||||
description: gpt-3.5-turbo-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -34,13 +37,15 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: openai
|
||||
|
||||
- name: gpt-3.5-turbo-1106
|
||||
type: llm
|
||||
provider: openai
|
||||
description: gpt-3.5-turbo-1106大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -48,13 +53,15 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: openai
|
||||
|
||||
- name: gpt-3.5-turbo-16k
|
||||
type: llm
|
||||
provider: openai
|
||||
description: gpt-3.5-turbo-16k大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -62,6 +69,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: openai
|
||||
|
||||
- name: gpt-3.5-turbo-instruct
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -73,13 +81,15 @@ models:
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: openai
|
||||
|
||||
- name: gpt-3.5-turbo
|
||||
type: llm
|
||||
provider: openai
|
||||
description: gpt-3.5-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -87,13 +97,15 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: openai
|
||||
|
||||
- name: gpt-4-0125-preview
|
||||
type: llm
|
||||
provider: openai
|
||||
description: gpt-4-0125-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -101,13 +113,15 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: openai
|
||||
|
||||
- name: gpt-4-1106-preview
|
||||
type: llm
|
||||
provider: openai
|
||||
description: gpt-4-1106-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -115,6 +129,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: openai
|
||||
|
||||
- name: gpt-4-turbo-2024-04-09
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -123,6 +138,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -131,13 +147,15 @@ models:
|
||||
- stream-tool-call
|
||||
- vision
|
||||
logo: openai
|
||||
|
||||
- name: gpt-4-turbo-preview
|
||||
type: llm
|
||||
provider: openai
|
||||
description: gpt-4-turbo-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -145,6 +163,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: openai
|
||||
|
||||
- name: gpt-4-turbo
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -153,6 +172,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -161,6 +181,7 @@ models:
|
||||
- stream-tool-call
|
||||
- vision
|
||||
logo: openai
|
||||
|
||||
- name: o1-preview
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -173,6 +194,7 @@ models:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: openai
|
||||
|
||||
- name: o1
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -181,6 +203,8 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -190,6 +214,7 @@ models:
|
||||
- vision
|
||||
- structured-output
|
||||
logo: openai
|
||||
|
||||
- name: o3-2025-04-16
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -198,6 +223,8 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -207,13 +234,16 @@ models:
|
||||
- stream-tool-call
|
||||
- structured-output
|
||||
logo: openai
|
||||
|
||||
- name: o3-mini-2025-01-31
|
||||
type: llm
|
||||
provider: openai
|
||||
description: o3-mini-2025-01-31大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -222,13 +252,16 @@ models:
|
||||
- stream-tool-call
|
||||
- structured-output
|
||||
logo: openai
|
||||
|
||||
- name: o3-mini
|
||||
type: llm
|
||||
provider: openai
|
||||
description: o3-mini大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -237,6 +270,7 @@ models:
|
||||
- stream-tool-call
|
||||
- structured-output
|
||||
logo: openai
|
||||
|
||||
- name: o3-pro-2025-06-10
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -245,6 +279,8 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -253,6 +289,7 @@ models:
|
||||
- vision
|
||||
- structured-output
|
||||
logo: openai
|
||||
|
||||
- name: o3-pro
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -261,6 +298,8 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -269,6 +308,7 @@ models:
|
||||
- vision
|
||||
- structured-output
|
||||
logo: openai
|
||||
|
||||
- name: o3
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -277,6 +317,8 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -286,6 +328,7 @@ models:
|
||||
- stream-tool-call
|
||||
- structured-output
|
||||
logo: openai
|
||||
|
||||
- name: o4-mini-2025-04-16
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -294,6 +337,8 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -303,6 +348,7 @@ models:
|
||||
- stream-tool-call
|
||||
- structured-output
|
||||
logo: openai
|
||||
|
||||
- name: o4-mini
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -311,6 +357,8 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -320,6 +368,7 @@ models:
|
||||
- stream-tool-call
|
||||
- structured-output
|
||||
logo: openai
|
||||
|
||||
- name: text-embedding-3-large
|
||||
type: embedding
|
||||
provider: openai
|
||||
@@ -331,6 +380,7 @@ models:
|
||||
tags:
|
||||
- 文本向量模型
|
||||
logo: openai
|
||||
|
||||
- name: text-embedding-3-small
|
||||
type: embedding
|
||||
provider: openai
|
||||
@@ -342,6 +392,7 @@ models:
|
||||
tags:
|
||||
- 文本向量模型
|
||||
logo: openai
|
||||
|
||||
- name: text-embedding-ada-002
|
||||
type: embedding
|
||||
provider: openai
|
||||
|
||||
@@ -10,6 +10,8 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -24,6 +26,8 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -38,6 +42,8 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -52,6 +58,8 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -68,6 +76,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -82,6 +91,8 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -96,6 +107,8 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -110,6 +123,8 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -124,6 +139,8 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -139,6 +156,8 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -166,7 +185,8 @@ models:
|
||||
description: 全新一代主力模型,性能全面升级,在知识、代码、推理等方面表现卓越。最大支持 128k 上下文窗口,输出长度支持最大 12k tokens。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -178,7 +198,8 @@ models:
|
||||
description: 全新一代轻量版模型,极致响应速度,效果与时延均达到全球一流水平。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
|
||||
791
api/app/core/quota_manager.py
Normal file
791
api/app/core/quota_manager.py
Normal file
@@ -0,0 +1,791 @@
|
||||
"""
|
||||
统一配额管理器 - 社区版和 SaaS 版共用
|
||||
|
||||
配额来源策略:
|
||||
1. 优先从 premium 模块的 tenant_subscriptions 表读取(SaaS 版)
|
||||
2. 降级到 default_free_plan.py 配置文件(社区版兜底)
|
||||
"""
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
from typing import Optional, Callable, Dict, Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_auth_logger
|
||||
from app.i18n.exceptions import QuotaExceededError, InternalServerError
|
||||
|
||||
logger = get_auth_logger()
|
||||
|
||||
# Redis key 格式常量,与 RateLimiterService.check_qps 保持一致(per api_key 独立计数)
|
||||
API_KEY_QPS_REDIS_KEY = "rate_limit:qps:{api_key_id}"
|
||||
|
||||
|
||||
def _get_user_from_kwargs(kwargs: dict):
|
||||
"""从 kwargs 中获取 user 对象"""
|
||||
for key in ["user", "current_user"]:
|
||||
if key in kwargs:
|
||||
return kwargs[key]
|
||||
return None
|
||||
|
||||
|
||||
def _get_workspace_id_from_kwargs(kwargs: dict):
|
||||
"""从 kwargs 中获取 workspace_id"""
|
||||
# 优先从 kwargs['workspace_id'] 获取
|
||||
workspace_id = kwargs.get("workspace_id")
|
||||
if workspace_id:
|
||||
return workspace_id
|
||||
|
||||
# 从 api_key_auth.workspace_id 获取(API Key 认证场景)
|
||||
api_key_auth = kwargs.get("api_key_auth")
|
||||
if api_key_auth and hasattr(api_key_auth, 'workspace_id'):
|
||||
return api_key_auth.workspace_id
|
||||
|
||||
# 从 user.current_workspace_id 获取
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if user:
|
||||
ws_id = getattr(user, 'current_workspace_id', None)
|
||||
if ws_id:
|
||||
return ws_id
|
||||
|
||||
logger.warning(f"无法获取 workspace_id, kwargs keys: {list(kwargs.keys())}")
|
||||
return None
|
||||
|
||||
|
||||
def _get_tenant_id_from_kwargs(db: Session, kwargs: dict):
|
||||
"""从 kwargs 中获取 tenant_id"""
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if user and hasattr(user, 'tenant_id'):
|
||||
return user.tenant_id
|
||||
|
||||
workspace_id = kwargs.get("workspace_id")
|
||||
if workspace_id:
|
||||
from app.models.workspace_model import Workspace
|
||||
workspace = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||
if workspace:
|
||||
return workspace.tenant_id
|
||||
|
||||
api_key_auth = kwargs.get("api_key_auth")
|
||||
if api_key_auth and hasattr(api_key_auth, 'workspace_id'):
|
||||
from app.models.workspace_model import Workspace
|
||||
workspace = db.query(Workspace).filter(Workspace.id == api_key_auth.workspace_id).first()
|
||||
if workspace:
|
||||
return workspace.tenant_id
|
||||
|
||||
data = kwargs.get("data") or kwargs.get("body") or kwargs.get("payload")
|
||||
if data and hasattr(data, "workspace_id"):
|
||||
from app.models.workspace_model import Workspace
|
||||
workspace = db.query(Workspace).filter(Workspace.id == data.workspace_id).first()
|
||||
if workspace:
|
||||
return workspace.tenant_id
|
||||
|
||||
share_data = kwargs.get("share_data")
|
||||
if share_data and hasattr(share_data, 'share_token'):
|
||||
from app.models.workspace_model import Workspace
|
||||
from app.models.app_model import App
|
||||
share_token = share_data.share_token
|
||||
from app.models.release_share_model import ReleaseShare
|
||||
share_record = db.query(ReleaseShare).filter(ReleaseShare.share_token == share_token).first()
|
||||
if share_record:
|
||||
app = db.query(App).filter(App.id == share_record.app_id, App.is_active.is_(True)).first()
|
||||
if app:
|
||||
workspace = db.query(Workspace).filter(Workspace.id == app.workspace_id).first()
|
||||
if workspace:
|
||||
return workspace.tenant_id
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_quota_config(db: Session, tenant_id: UUID) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取租户的配额配置
|
||||
|
||||
优先级:
|
||||
1. premium 模块的 tenant_subscriptions(SaaS 版)
|
||||
2. default_free_plan.py 配置文件(社区版兜底)
|
||||
"""
|
||||
# 尝试从 premium 模块获取(SaaS 版)
|
||||
try:
|
||||
from premium.platform_admin.package_plan_service import TenantSubscriptionService
|
||||
# premium 模块存在,运行时错误不应被静默降级,直接抛出
|
||||
quota_config = TenantSubscriptionService(db).get_effective_quota(tenant_id)
|
||||
if quota_config:
|
||||
logger.debug(f"从 premium 模块获取租户 {tenant_id} 配额配置")
|
||||
return quota_config
|
||||
# premium 存在但该租户无订阅记录,降级到免费套餐
|
||||
logger.debug(f"租户 {tenant_id} 无 premium 订阅,降级到免费套餐")
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
# 社区版:premium 包不存在,正常降级
|
||||
logger.debug("premium 模块不存在,使用社区版免费套餐配额")
|
||||
|
||||
# 降级到社区版配置文件
|
||||
try:
|
||||
from app.config.default_free_plan import DEFAULT_FREE_PLAN
|
||||
logger.debug(f"使用社区版免费套餐配额: tenant={tenant_id}")
|
||||
return DEFAULT_FREE_PLAN.get("quotas")
|
||||
except Exception as e:
|
||||
logger.error(f"无法从配置文件获取配额: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_api_ops_rate_limit(db: Session, tenant_id: UUID) -> Optional[int]:
|
||||
"""
|
||||
获取租户套餐的 API 操作速率限制(QPS 上限)
|
||||
|
||||
该函数兼容社区版和 SaaS 版:
|
||||
- SaaS 版:从 premium 模块的套餐配额读取
|
||||
- 社区版:从 default_free_plan.py 配置文件读取
|
||||
|
||||
Returns:
|
||||
int: api_ops_rate_limit 值,如果未配置则返回 None
|
||||
"""
|
||||
quota_config = _get_quota_config(db, tenant_id)
|
||||
if quota_config:
|
||||
return quota_config.get("api_ops_rate_limit")
|
||||
return None
|
||||
|
||||
|
||||
class QuotaUsageRepository:
|
||||
"""配额使用量数据访问层"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def count_workspaces(self, tenant_id: UUID) -> int:
|
||||
from app.models.workspace_model import Workspace
|
||||
return self.db.query(Workspace).filter(
|
||||
Workspace.tenant_id == tenant_id,
|
||||
Workspace.is_active.is_(True)
|
||||
).count()
|
||||
|
||||
def count_apps(self, tenant_id: UUID, workspace_id: Optional[UUID] = None) -> int:
|
||||
from app.models.app_model import App
|
||||
from app.models.workspace_model import Workspace
|
||||
query = self.db.query(App).join(
|
||||
Workspace, App.workspace_id == Workspace.id
|
||||
).filter(
|
||||
App.is_active.is_(True)
|
||||
)
|
||||
if workspace_id:
|
||||
query = query.filter(App.workspace_id == workspace_id)
|
||||
else:
|
||||
query = query.filter(Workspace.tenant_id == tenant_id)
|
||||
return query.count()
|
||||
|
||||
def count_skills(self, tenant_id: UUID) -> int:
|
||||
from app.models.skill_model import Skill
|
||||
return self.db.query(Skill).filter(
|
||||
Skill.tenant_id == tenant_id,
|
||||
Skill.is_active.is_(True)
|
||||
).count()
|
||||
|
||||
def sum_knowledge_capacity_gb(self, tenant_id: UUID, workspace_id: Optional[UUID] = None) -> float:
|
||||
from app.models.document_model import Document
|
||||
from app.models.knowledge_model import Knowledge
|
||||
from app.models.workspace_model import Workspace
|
||||
query = self.db.query(func.coalesce(func.sum(Document.file_size), 0)).join(
|
||||
Knowledge, Document.kb_id == Knowledge.id
|
||||
).join(
|
||||
Workspace, Knowledge.workspace_id == Workspace.id
|
||||
).filter(
|
||||
Document.status == 1,
|
||||
)
|
||||
if workspace_id:
|
||||
query = query.filter(Knowledge.workspace_id == workspace_id)
|
||||
else:
|
||||
query = query.filter(Workspace.tenant_id == tenant_id)
|
||||
result = query.scalar()
|
||||
return float(result) / (1024 ** 3) if result else 0.0
|
||||
|
||||
def count_memory_engines(self, tenant_id: UUID, workspace_id: Optional[UUID] = None) -> int:
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
from app.models.workspace_model import Workspace
|
||||
query = self.db.query(MemoryConfig).join(
|
||||
Workspace, MemoryConfig.workspace_id == Workspace.id
|
||||
)
|
||||
if workspace_id:
|
||||
query = query.filter(MemoryConfig.workspace_id == workspace_id)
|
||||
else:
|
||||
query = query.filter(Workspace.tenant_id == tenant_id)
|
||||
return query.count()
|
||||
|
||||
def count_end_users(self, tenant_id: UUID, workspace_id: Optional[UUID] = None) -> int:
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.workspace_model import Workspace
|
||||
from app.models.user_model import User
|
||||
query = self.db.query(EndUser).join(
|
||||
Workspace, EndUser.workspace_id == Workspace.id
|
||||
)
|
||||
if workspace_id:
|
||||
query = query.filter(EndUser.workspace_id == workspace_id)
|
||||
else:
|
||||
query = query.filter(Workspace.tenant_id == tenant_id)
|
||||
trial_user_ids = [
|
||||
str(u.id) for u in self.db.query(User.id).filter(User.tenant_id == tenant_id).all()
|
||||
]
|
||||
if trial_user_ids:
|
||||
query = query.filter(~EndUser.other_id.in_(trial_user_ids))
|
||||
return query.count()
|
||||
|
||||
def count_models(self, tenant_id: UUID) -> int:
|
||||
from app.models.models_model import ModelConfig
|
||||
return self.db.query(ModelConfig).filter(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_active == True,
|
||||
ModelConfig.is_composite == True
|
||||
).count()
|
||||
|
||||
def count_ontology_projects(self, tenant_id: UUID, workspace_id: Optional[UUID] = None) -> int:
|
||||
from app.models.ontology_scene import OntologyScene
|
||||
from app.models.workspace_model import Workspace
|
||||
if workspace_id:
|
||||
return self.db.query(OntologyScene).filter(
|
||||
OntologyScene.workspace_id == workspace_id
|
||||
).count()
|
||||
return self.db.query(OntologyScene).join(
|
||||
Workspace, OntologyScene.workspace_id == Workspace.id
|
||||
).filter(
|
||||
Workspace.tenant_id == tenant_id
|
||||
).count()
|
||||
|
||||
def get_usage_by_quota_type(self, tenant_id: UUID, quota_type: str, workspace_id: Optional[UUID] = None):
|
||||
"""按配额类型分发,返回当前使用量"""
|
||||
dispatch = {
|
||||
"workspace_quota": self.count_workspaces,
|
||||
"app_quota": self.count_apps,
|
||||
"skill_quota": self.count_skills,
|
||||
"knowledge_capacity_quota": self.sum_knowledge_capacity_gb,
|
||||
"memory_engine_quota": self.count_memory_engines,
|
||||
"end_user_quota": self.count_end_users,
|
||||
"model_quota": self.count_models,
|
||||
"ontology_project_quota": self.count_ontology_projects,
|
||||
}
|
||||
fn = dispatch.get(quota_type)
|
||||
if workspace_id:
|
||||
return fn(tenant_id, workspace_id) if fn else 0
|
||||
return fn(tenant_id) if fn else 0
|
||||
|
||||
|
||||
def _check_quota(
|
||||
db: Session,
|
||||
tenant_id: UUID,
|
||||
quota_type: str,
|
||||
resource_name: str,
|
||||
usage_func: Optional[Callable] = None,
|
||||
workspace_id: Optional[UUID] = None,
|
||||
) -> None:
|
||||
"""核心配额检查逻辑:对比使用量和配额限制"""
|
||||
try:
|
||||
quota_config = _get_quota_config(db, tenant_id)
|
||||
if not quota_config:
|
||||
logger.warning(f"租户 {tenant_id} 无有效配额配置,跳过配额检查")
|
||||
return
|
||||
|
||||
quota_limit = quota_config.get(quota_type)
|
||||
if quota_limit is None:
|
||||
logger.warning(f"配额配置未包含 {quota_type},跳过配额检查")
|
||||
return
|
||||
|
||||
if usage_func:
|
||||
current_usage = usage_func(db, tenant_id, workspace_id) if workspace_id else usage_func(db, tenant_id)
|
||||
else:
|
||||
current_usage = QuotaUsageRepository(db).get_usage_by_quota_type(tenant_id, quota_type, workspace_id)
|
||||
|
||||
if current_usage >= quota_limit:
|
||||
logger.warning(
|
||||
f"配额不足: tenant={tenant_id}, workspace={workspace_id}, type={quota_type}, "
|
||||
f"usage={current_usage}, limit={quota_limit}"
|
||||
)
|
||||
raise QuotaExceededError(
|
||||
resource=resource_name,
|
||||
current_usage=current_usage,
|
||||
quota_limit=quota_limit,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"配额检查通过: tenant={tenant_id}, workspace={workspace_id}, type={quota_type}, "
|
||||
f"usage={current_usage}, limit={quota_limit}"
|
||||
)
|
||||
|
||||
except QuotaExceededError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"配额检查异常: tenant={tenant_id}, workspace={workspace_id}, type={quota_type}, "
|
||||
f"error_type={type(e).__name__}, error={str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
# ─── 具名装饰器 ────────────────────────────────────────────────────────────
|
||||
|
||||
def check_workspace_quota(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "workspace_quota", "workspace")
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "workspace_quota", "workspace")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
def check_skill_quota(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "skill_quota", "skill")
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "skill_quota", "skill")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
def check_app_quota(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
workspace_id = _get_workspace_id_from_kwargs(kwargs)
|
||||
if not workspace_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "app_quota", "app", workspace_id=workspace_id)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
workspace_id = _get_workspace_id_from_kwargs(kwargs)
|
||||
if not workspace_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "app_quota", "app", workspace_id=workspace_id)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
def check_knowledge_capacity_quota(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
if not db:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
|
||||
if not tenant_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 tenant_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
workspace_id = _get_workspace_id_from_kwargs(kwargs)
|
||||
if not workspace_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, tenant_id, "knowledge_capacity_quota", "knowledge_capacity", workspace_id=workspace_id)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
workspace_id = _get_workspace_id_from_kwargs(kwargs)
|
||||
if not workspace_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "knowledge_capacity_quota", "knowledge_capacity", workspace_id=workspace_id)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
def check_memory_engine_quota(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
logger.debug(f"check_memory_engine_quota async_wrapper: db={db is not None}, user={user}, kwargs_keys={list(kwargs.keys())}")
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
workspace_id = _get_workspace_id_from_kwargs(kwargs)
|
||||
if not workspace_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "memory_engine_quota", "memory_engine", workspace_id=workspace_id)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
logger.debug(f"check_memory_engine_quota sync_wrapper: db={db is not None}, user={user}, kwargs_keys={list(kwargs.keys())}")
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
workspace_id = _get_workspace_id_from_kwargs(kwargs)
|
||||
if not workspace_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "memory_engine_quota", "memory_engine", workspace_id=workspace_id)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
def check_end_user_quota(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
if not db:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
|
||||
if not tenant_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 tenant_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
workspace_id = _get_workspace_id_from_kwargs(kwargs)
|
||||
if not workspace_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
if not db:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
|
||||
if not tenant_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 tenant_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
workspace_id = _get_workspace_id_from_kwargs(kwargs)
|
||||
if not workspace_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
def check_ontology_project_quota(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
workspace_id = _get_workspace_id_from_kwargs(kwargs)
|
||||
if not workspace_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "ontology_project_quota", "ontology_project", workspace_id=workspace_id)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
workspace_id = _get_workspace_id_from_kwargs(kwargs)
|
||||
if not workspace_id:
|
||||
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "ontology_project_quota", "ontology_project", workspace_id=workspace_id)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
def check_model_quota(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "model_quota", "model")
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, "model_quota", "model")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
def check_model_activation_quota(func: Callable) -> Callable:
|
||||
"""模型激活时的配额检查装饰器"""
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
|
||||
model_id = kwargs.get("model_id") or (args[1] if len(args) > 1 else None)
|
||||
model_data = kwargs.get("model_data")
|
||||
|
||||
if not model_id or not model_data:
|
||||
logger.warning("模型激活配额检查失败:缺少 model_id 或 model_data 参数")
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
if model_data.is_active:
|
||||
try:
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
existing_model = ModelConfigService.get_model_by_id(
|
||||
db=db,
|
||||
model_id=model_id,
|
||||
tenant_id=user.tenant_id
|
||||
)
|
||||
|
||||
if not existing_model.is_active:
|
||||
logger.info(f"模型激活操作,检查配额: model_id={model_id}, tenant_id={user.tenant_id}")
|
||||
_check_quota(db, user.tenant_id, "model_quota", "model")
|
||||
except Exception as e:
|
||||
logger.error(f"模型激活配额检查异常: model_id={model_id}, error={str(e)}")
|
||||
raise
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
|
||||
model_id = kwargs.get("model_id") or (args[1] if len(args) > 1 else None)
|
||||
model_data = kwargs.get("model_data")
|
||||
|
||||
if not model_id or not model_data:
|
||||
logger.warning("模型激活配额检查失败:缺少 model_id 或 model_data 参数")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if model_data.is_active:
|
||||
try:
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
existing_model = ModelConfigService.get_model_by_id(
|
||||
db=db,
|
||||
model_id=model_id,
|
||||
tenant_id=user.tenant_id
|
||||
)
|
||||
|
||||
if not existing_model.is_active:
|
||||
logger.info(f"模型激活操作,检查配额: model_id={model_id}, tenant_id={user.tenant_id}")
|
||||
_check_quota(db, user.tenant_id, "model_quota", "model")
|
||||
except Exception as e:
|
||||
logger.error(f"模型激活配额检查异常: model_id={model_id}, error={str(e)}")
|
||||
raise
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
def check_quota(quota_type: str, resource_name: str, usage_func: Optional[Callable] = None):
|
||||
"""通用配额检查装饰器,支持自定义使用量获取函数"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, quota_type, resource_name, usage_func)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
db: Session = kwargs.get("db")
|
||||
user = _get_user_from_kwargs(kwargs)
|
||||
if not db or not user:
|
||||
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||
raise InternalServerError()
|
||||
_check_quota(db, user.tenant_id, quota_type, resource_name, usage_func)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
# ─── 配额使用统计 ────────────────────────────────────────────────────────────
|
||||
|
||||
async def get_quota_usage(db: Session, tenant_id: UUID) -> dict:
|
||||
"""获取租户所有配额的使用情况
|
||||
|
||||
对于 workspace 级别的配额(app/knowledge_capacity/memory_engine/end_user):
|
||||
- used: 租户汇总(所有空间加总)
|
||||
- limit: quota × 活跃工作区数(有效总限额,使汇总数据自洽)
|
||||
- per_workspace: 各空间明细,包含 workspace_id、workspace_name、used、limit、percentage
|
||||
- 配额检查逻辑不变:仍按单个空间独立检查
|
||||
"""
|
||||
quota_config = _get_quota_config(db, tenant_id)
|
||||
if not quota_config:
|
||||
return {}
|
||||
|
||||
repo = QuotaUsageRepository(db)
|
||||
|
||||
def pct(used, limit):
|
||||
return round(used / limit * 100, 1) if limit else None
|
||||
|
||||
workspace_count = repo.count_workspaces(tenant_id)
|
||||
skill_count = repo.count_skills(tenant_id)
|
||||
app_count = repo.count_apps(tenant_id)
|
||||
knowledge_gb = repo.sum_knowledge_capacity_gb(tenant_id)
|
||||
memory_count = repo.count_memory_engines(tenant_id)
|
||||
end_user_count = repo.count_end_users(tenant_id)
|
||||
model_count = repo.count_models(tenant_id)
|
||||
ontology_count = repo.count_ontology_projects(tenant_id)
|
||||
|
||||
# 获取租户下所有活跃工作区,用于按空间拆分明细
|
||||
from app.models.workspace_model import Workspace
|
||||
active_workspaces = db.query(Workspace).filter(
|
||||
Workspace.tenant_id == tenant_id,
|
||||
Workspace.is_active.is_(True)
|
||||
).all()
|
||||
|
||||
# 构建各空间的 workspace 级配额明细
|
||||
def _build_per_workspace_detail(count_func, per_unit_limit):
|
||||
"""为 workspace 级配额构建 per_workspace 明细列表"""
|
||||
if not per_unit_limit or not active_workspaces:
|
||||
return []
|
||||
details = []
|
||||
for ws in active_workspaces:
|
||||
ws_used = count_func(tenant_id, ws.id)
|
||||
details.append({
|
||||
"workspace_id": str(ws.id),
|
||||
"workspace_name": ws.name,
|
||||
"used": ws_used,
|
||||
"limit": per_unit_limit,
|
||||
"percentage": pct(ws_used, per_unit_limit),
|
||||
})
|
||||
return details
|
||||
|
||||
# workspace 级配额的每空间限额
|
||||
app_quota_per_ws = quota_config.get("app_quota")
|
||||
knowledge_quota_per_ws = quota_config.get("knowledge_capacity_quota")
|
||||
memory_quota_per_ws = quota_config.get("memory_engine_quota")
|
||||
end_user_quota_per_ws = quota_config.get("end_user_quota")
|
||||
ontology_quota_per_ws = quota_config.get("ontology_project_quota")
|
||||
|
||||
# workspace 级配额的有效总限额 = 每空间限额 × 活跃工作区数
|
||||
app_effective_limit = app_quota_per_ws * workspace_count if app_quota_per_ws is not None and workspace_count > 0 else app_quota_per_ws
|
||||
knowledge_effective_limit = knowledge_quota_per_ws * workspace_count if knowledge_quota_per_ws is not None and workspace_count > 0 else knowledge_quota_per_ws
|
||||
memory_effective_limit = memory_quota_per_ws * workspace_count if memory_quota_per_ws is not None and workspace_count > 0 else memory_quota_per_ws
|
||||
end_user_effective_limit = end_user_quota_per_ws * workspace_count if end_user_quota_per_ws is not None and workspace_count > 0 else end_user_quota_per_ws
|
||||
ontology_effective_limit = ontology_quota_per_ws * workspace_count if ontology_quota_per_ws is not None and workspace_count > 0 else ontology_quota_per_ws
|
||||
|
||||
api_ops_current = 0
|
||||
try:
|
||||
from app.aioRedis import aio_redis as _aio_redis
|
||||
from app.models.api_key_model import ApiKey
|
||||
# api_ops_rate_limit 限的是每个 api_key 每秒最高限额
|
||||
# 展示当前最接近触发限流的 key 的 QPS(取最大值)
|
||||
api_key_ids = db.query(ApiKey.id).join(
|
||||
Workspace, ApiKey.workspace_id == Workspace.id
|
||||
).filter(
|
||||
Workspace.tenant_id == tenant_id,
|
||||
ApiKey.is_active.is_(True)
|
||||
).all()
|
||||
for (key_id,) in api_key_ids:
|
||||
_rk = API_KEY_QPS_REDIS_KEY.format(api_key_id=key_id)
|
||||
val = await _aio_redis.get(_rk)
|
||||
count = int(val) if val else 0
|
||||
if count > api_ops_current:
|
||||
api_ops_current = count
|
||||
except Exception as e:
|
||||
logger.warning(f"获取 api_ops_current 失败,返回 0: {type(e).__name__}: {e}")
|
||||
|
||||
return {
|
||||
"workspace": {"used": workspace_count, "limit": quota_config.get("workspace_quota"), "percentage": pct(workspace_count, quota_config.get("workspace_quota"))},
|
||||
"skill": {"used": skill_count, "limit": quota_config.get("skill_quota"), "percentage": pct(skill_count, quota_config.get("skill_quota"))},
|
||||
"app": {
|
||||
"used": app_count,
|
||||
"limit": app_effective_limit,
|
||||
"percentage": pct(app_count, app_effective_limit),
|
||||
"per_workspace": _build_per_workspace_detail(repo.count_apps, app_quota_per_ws),
|
||||
},
|
||||
"knowledge_capacity": {
|
||||
"used": round(knowledge_gb, 2),
|
||||
"limit": knowledge_effective_limit,
|
||||
"percentage": pct(knowledge_gb, knowledge_effective_limit),
|
||||
"unit": "GB",
|
||||
"per_workspace": _build_per_workspace_detail(repo.sum_knowledge_capacity_gb, knowledge_quota_per_ws),
|
||||
},
|
||||
"memory_engine": {
|
||||
"used": memory_count,
|
||||
"limit": memory_effective_limit,
|
||||
"percentage": pct(memory_count, memory_effective_limit),
|
||||
"per_workspace": _build_per_workspace_detail(repo.count_memory_engines, memory_quota_per_ws),
|
||||
},
|
||||
"end_user": {
|
||||
"used": end_user_count,
|
||||
"limit": end_user_effective_limit,
|
||||
"percentage": pct(end_user_count, end_user_effective_limit),
|
||||
"per_workspace": _build_per_workspace_detail(repo.count_end_users, end_user_quota_per_ws),
|
||||
},
|
||||
"ontology_project": {
|
||||
"used": ontology_count,
|
||||
"limit": ontology_effective_limit,
|
||||
"percentage": pct(ontology_count, ontology_effective_limit),
|
||||
"per_workspace": _build_per_workspace_detail(repo.count_ontology_projects, ontology_quota_per_ws),
|
||||
},
|
||||
"model": {"used": model_count, "limit": quota_config.get("model_quota"), "percentage": pct(model_count, quota_config.get("model_quota"))},
|
||||
"api_ops_rate_limit": {"current": api_ops_current, "limit": quota_config.get("api_ops_rate_limit"), "percentage": None, "unit": "次/秒"},
|
||||
}
|
||||
38
api/app/core/quota_stub.py
Normal file
38
api/app/core/quota_stub.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
配额检查 stub - 社区版和 SaaS 版统一使用 core.quota_manager 实现
|
||||
|
||||
所有配额检查逻辑统一在 core 层实现,两个版本共用:
|
||||
- 社区版:从 default_free_plan.py 读取配额限制
|
||||
- SaaS 版:优先从 tenant_subscriptions 表读取,降级到配置文件
|
||||
"""
|
||||
from app.core.quota_manager import (
|
||||
check_workspace_quota,
|
||||
check_skill_quota,
|
||||
check_app_quota,
|
||||
check_knowledge_capacity_quota,
|
||||
check_memory_engine_quota,
|
||||
check_end_user_quota,
|
||||
check_ontology_project_quota,
|
||||
check_model_quota,
|
||||
check_model_activation_quota,
|
||||
get_quota_usage,
|
||||
_check_quota,
|
||||
QuotaUsageRepository,
|
||||
API_KEY_QPS_REDIS_KEY,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"check_workspace_quota",
|
||||
"check_skill_quota",
|
||||
"check_app_quota",
|
||||
"check_knowledge_capacity_quota",
|
||||
"check_memory_engine_quota",
|
||||
"check_end_user_quota",
|
||||
"check_ontology_project_quota",
|
||||
"check_model_quota",
|
||||
"check_model_activation_quota",
|
||||
"get_quota_usage",
|
||||
"_check_quota",
|
||||
"QuotaUsageRepository",
|
||||
"API_KEY_QPS_REDIS_KEY",
|
||||
]
|
||||
@@ -672,10 +672,15 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
excel_parser = ExcelParser()
|
||||
if parser_config.get("html4excel") and parser_config.get("html4excel").lower() == "true":
|
||||
sections = [(_, "") for _ in excel_parser.html(binary, 12) if _]
|
||||
parser_config["chunk_token_num"] = 0
|
||||
else:
|
||||
sections = [(_, "") for _ in excel_parser(binary) if _]
|
||||
parser_config["chunk_token_num"] = 12800
|
||||
callback(0.8, "Finish parsing.")
|
||||
# Excel 每行直接作为一个 chunk,不经过 naive_merge 避免被 delimiter 拆分
|
||||
chunks = [s for s, _ in sections]
|
||||
res.extend(tokenize_chunks(chunks, doc, is_english, None))
|
||||
res.extend(embed_res)
|
||||
res.extend(url_res)
|
||||
return res
|
||||
|
||||
elif re.search(r"\.(txt|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|sql)$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
|
||||
@@ -33,18 +33,16 @@ def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception:
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
|
||||
effective_timeout = seconds if seconds else 120 # 默认 120 秒超时
|
||||
for a in range(attempts):
|
||||
try:
|
||||
if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
|
||||
result = result_queue.get(timeout=seconds)
|
||||
else:
|
||||
result = result_queue.get()
|
||||
result = result_queue.get(timeout=effective_timeout)
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
return result
|
||||
except queue.Empty:
|
||||
pass
|
||||
raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds and {attempts} attempts.")
|
||||
raise TimeoutError(f"Function '{func.__name__}' timed out after {effective_timeout} seconds and {attempts} attempts.")
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs) -> Any:
|
||||
|
||||
@@ -232,14 +232,14 @@ class RAGExcelParser:
|
||||
t = str(ti[i].value) if i < len(ti) else ""
|
||||
t += (":" if t else "") + str(c.value)
|
||||
fields.append(t)
|
||||
line = "; ".join(fields)
|
||||
line = "\n".join(fields)
|
||||
if sheetname.lower().find("sheet") < 0:
|
||||
line += " ——" + sheetname
|
||||
line += "\n——" + sheetname
|
||||
res.append(line)
|
||||
else:
|
||||
# 只有表头的情况
|
||||
if header_fields:
|
||||
line = "; ".join(header_fields)
|
||||
line = "\n".join(header_fields)
|
||||
if sheetname.lower().find("sheet") < 0:
|
||||
line += " ——" + sheetname
|
||||
res.append(line)
|
||||
|
||||
@@ -292,9 +292,10 @@ class MinerUParser(RAGPdfParser):
|
||||
self.page_from = page_from
|
||||
self.page_to = page_to
|
||||
try:
|
||||
with pdfplumber.open(fnm) if isinstance(fnm, (str, PathLike)) else pdfplumber.open(BytesIO(fnm)) as pdf:
|
||||
self.pdf = pdf
|
||||
self.page_images = [p.to_image(resolution=72 * zoomin, antialias=True).original for _, p in enumerate(self.pdf.pages[page_from:page_to])]
|
||||
with sys.modules[LOCK_KEY_pdfplumber]: # ← 加这一行,获取全局锁
|
||||
with pdfplumber.open(fnm) if isinstance(fnm, (str, PathLike)) else pdfplumber.open(BytesIO(fnm)) as pdf:
|
||||
self.pdf = pdf
|
||||
self.page_images = [p.to_image(resolution=72 * zoomin, antialias=True).original for _, p in enumerate(self.pdf.pages[page_from:page_to])]
|
||||
except Exception as e:
|
||||
self.page_images = None
|
||||
self.total_page = 0
|
||||
|
||||
@@ -50,7 +50,9 @@ class OpenAIEmbed(Base):
|
||||
def encode(self, texts: list):
|
||||
# OpenAI requires batch size <=16
|
||||
batch_size = 16
|
||||
texts = [truncate(t, 8191) for t in texts]
|
||||
# Use 8000 instead of 8191 to leave safety margin for tokenizer differences
|
||||
# between cl100k_base (used by truncate) and the actual embedding model
|
||||
texts = [truncate(t, 8000) for t in texts]
|
||||
ress = []
|
||||
total_tokens = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
@@ -63,7 +65,7 @@ class OpenAIEmbed(Base):
|
||||
return np.array(ress), total_tokens
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True})
|
||||
res = self.client.embeddings.create(input=[truncate(text, 8000)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True})
|
||||
return np.array(res.data[0].embedding), self.total_token_count(res)
|
||||
|
||||
|
||||
@@ -79,6 +81,7 @@ class LocalAIEmbed(Base):
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
texts = [truncate(t, 8000) for t in texts]
|
||||
ress = []
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
|
||||
@@ -173,6 +176,7 @@ class XinferenceEmbed(Base):
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
texts = [truncate(t, 8000) for t in texts]
|
||||
ress = []
|
||||
total_tokens = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
@@ -188,7 +192,7 @@ class XinferenceEmbed(Base):
|
||||
def encode_queries(self, text):
|
||||
res = None
|
||||
try:
|
||||
res = self.client.embeddings.create(input=[text], model=self.model_name)
|
||||
res = self.client.embeddings.create(input=[truncate(text, 8000)], model=self.model_name)
|
||||
return np.array(res.data[0].embedding), self.total_token_count(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
|
||||
@@ -28,6 +28,7 @@ from app.core.rag.common.float_utils import get_float
|
||||
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -112,11 +113,10 @@ def knowledge_retrieval(
|
||||
continue
|
||||
|
||||
# Use the specified reranker for re-ranking
|
||||
if reranker_id:
|
||||
if reranker_id and all_results:
|
||||
try:
|
||||
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
|
||||
all_results = rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
|
||||
except Exception as rerank_error:
|
||||
# If reranker fails, log warning and continue with original results
|
||||
logger.warning(
|
||||
"Reranker failed, falling back to original results",
|
||||
extra={
|
||||
@@ -132,7 +132,10 @@ def knowledge_retrieval(
|
||||
from app.core.rag.common.settings import kg_retriever
|
||||
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||
if doc:
|
||||
all_results.insert(0, doc)
|
||||
all_results.insert(0, DocumentChunk(
|
||||
page_content=doc.get("page_content", ""),
|
||||
metadata=doc.get("metadata", {})
|
||||
))
|
||||
except Exception as graph_error:
|
||||
print(f"Failed to retrieve from knowledge graph: {str(graph_error)}")
|
||||
|
||||
@@ -198,16 +201,18 @@ def _retrieve_for_knowledge(
|
||||
workspace_ids.append(str(db_knowledge.workspace_id))
|
||||
|
||||
if not chat_model:
|
||||
llm_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.llm_id)
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base,
|
||||
key=llm_key.api_key,
|
||||
model_name=llm_key.model_name,
|
||||
base_url=llm_key.api_base,
|
||||
)
|
||||
if not embedding_model:
|
||||
emb_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.embedding_id)
|
||||
embedding_model = OpenAIEmbed(
|
||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||
base_url=db_knowledge.embedding.api_keys[0].api_base,
|
||||
key=emb_key.api_key,
|
||||
model_name=emb_key.model_name,
|
||||
base_url=emb_key.api_base,
|
||||
)
|
||||
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
@@ -248,6 +253,29 @@ def _retrieve_for_knowledge(
|
||||
seen_ids.add(doc.metadata["doc_id"])
|
||||
unique_rs.append(doc)
|
||||
rs = unique_rs
|
||||
if unique_rs:
|
||||
rs = vector_service.rerank(
|
||||
query=kb_config["query"],
|
||||
docs=unique_rs,
|
||||
top_k=kb_config["top_k"]
|
||||
)
|
||||
if kb_config["retrieve_type"] == "graph":
|
||||
try:
|
||||
from app.core.rag.common.settings import kg_retriever
|
||||
graph_doc = kg_retriever.retrieval(
|
||||
question=kb_config["query"],
|
||||
workspace_ids=[str(db_knowledge.workspace_id)],
|
||||
kb_ids=[str(db_knowledge.id)],
|
||||
emb_mdl=embedding_model,
|
||||
llm=chat_model,
|
||||
)
|
||||
if graph_doc:
|
||||
rs.insert(0, DocumentChunk(
|
||||
page_content=graph_doc.get("page_content", ""),
|
||||
metadata=graph_doc.get("metadata", {})
|
||||
))
|
||||
except Exception as graph_error:
|
||||
logger.warning(f"Graph retrieval failed for kb {db_knowledge.id}: {graph_error}")
|
||||
|
||||
results.extend(rs)
|
||||
return results, chat_model, embedding_model
|
||||
|
||||
@@ -68,9 +68,9 @@ class ESConnection(DocStoreConnection):
|
||||
client_config = {
|
||||
"hosts": [hosts],
|
||||
"basic_auth": (os.getenv("ELASTICSEARCH_USERNAME", "elastic"), os.getenv("ELASTICSEARCH_PASSWORD", "elastic")),
|
||||
"request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 100000)),
|
||||
"request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 30)),
|
||||
"retry_on_timeout": os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", True) == "true",
|
||||
"max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 10000)),
|
||||
"max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 3)),
|
||||
}
|
||||
|
||||
# Only add SSL settings if using HTTPS
|
||||
|
||||
@@ -1,25 +1,22 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
import threading
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
from elasticsearch import Elasticsearch, helpers
|
||||
from elasticsearch.helpers import BulkIndexError
|
||||
from packaging.version import parse as parse_version
|
||||
from pydantic import BaseModel, model_validator
|
||||
from abc import ABC
|
||||
# langchain-community
|
||||
# langchain-xinference
|
||||
# from langchain_community.embeddings import XinferenceEmbeddings
|
||||
# from langchain_xinference import XinferenceRerank
|
||||
from langchain_core.documents import Document
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models import RedBearLLM, RedBearRerank
|
||||
from app.core.models import RedBearRerank
|
||||
from app.core.models.embedding import RedBearEmbeddings
|
||||
from app.models.models_model import ModelConfig, ModelApiKey
|
||||
from app.services.model_service import ModelConfigService
|
||||
from app.models.models_model import ModelApiKey
|
||||
|
||||
from app.models.knowledge_model import Knowledge
|
||||
from app.core.rag.vdb.field import Field
|
||||
@@ -29,37 +26,9 @@ from app.core.rag.models.chunk import DocumentChunk
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ElasticSearchConfig(BaseModel):
|
||||
# Regular Elasticsearch config
|
||||
host: str | None = None
|
||||
port: int | None = None
|
||||
username: str | None = None
|
||||
password: str | None = None
|
||||
|
||||
# Common config
|
||||
ca_certs: str | None = None
|
||||
verify_certs: bool = False
|
||||
request_timeout: int = 100000
|
||||
retry_on_timeout: bool = True
|
||||
max_retries: int = 10000
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict):
|
||||
# Regular Elasticsearch validation
|
||||
if not values.get("host"):
|
||||
raise ValueError("config HOST is required for regular Elasticsearch")
|
||||
if not values.get("port"):
|
||||
raise ValueError("config PORT is required for regular Elasticsearch")
|
||||
if not values.get("username"):
|
||||
raise ValueError("config USERNAME is required for regular Elasticsearch")
|
||||
if not values.get("password"):
|
||||
raise ValueError("config PASSWORD is required for regular Elasticsearch")
|
||||
return values
|
||||
|
||||
|
||||
class ElasticSearchVector(BaseVector):
|
||||
def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey):
|
||||
def __init__(self, index_name: str, client: Elasticsearch,
|
||||
embedding_config: ModelApiKey, reranker_config: ModelApiKey):
|
||||
super().__init__(index_name.lower())
|
||||
|
||||
# 初始化 Embedding 模型(自动支持火山引擎多模态)
|
||||
@@ -77,58 +46,8 @@ class ElasticSearchVector(BaseVector):
|
||||
api_key=reranker_config.api_key,
|
||||
base_url=reranker_config.api_base
|
||||
))
|
||||
self._client = self._init_client(config)
|
||||
self._version = self._get_version()
|
||||
self._check_version()
|
||||
|
||||
def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
|
||||
"""
|
||||
Initialize Elasticsearch client for regular Elasticsearch.
|
||||
"""
|
||||
try:
|
||||
# Regular Elasticsearch configuration
|
||||
parsed_url = urlparse(config.host or "")
|
||||
if parsed_url.scheme in {"http", "https"}:
|
||||
hosts = f"{config.host}:{config.port}"
|
||||
use_https = parsed_url.scheme == "https"
|
||||
else:
|
||||
hosts = f"https://{config.host}:{config.port}"
|
||||
use_https = False
|
||||
|
||||
client_config = {
|
||||
"hosts": [hosts],
|
||||
"basic_auth": (config.username, config.password),
|
||||
"request_timeout": config.request_timeout,
|
||||
"retry_on_timeout": config.retry_on_timeout,
|
||||
"max_retries": config.max_retries,
|
||||
}
|
||||
|
||||
# Only add SSL settings if using HTTPS
|
||||
if use_https:
|
||||
client_config["verify_certs"] = config.verify_certs
|
||||
if config.ca_certs:
|
||||
client_config["ca_certs"] = config.ca_certs
|
||||
|
||||
client = Elasticsearch(**client_config)
|
||||
|
||||
# Test connection
|
||||
if not client.ping():
|
||||
raise ConnectionError("Failed to connect to Elasticsearch")
|
||||
|
||||
except requests.ConnectionError as e:
|
||||
raise ConnectionError(f"Vector database connection error: {str(e)}")
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")
|
||||
|
||||
return client
|
||||
|
||||
def _get_version(self) -> str:
|
||||
info = self._client.info()
|
||||
return cast(str, info["version"]["number"])
|
||||
|
||||
def _check_version(self):
|
||||
if parse_version(self._version) < parse_version("8.0.0"):
|
||||
raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
|
||||
# 使用外部传入的共享客户端
|
||||
self._client = client
|
||||
|
||||
def get_type(self) -> str:
|
||||
return "elasticsearch"
|
||||
@@ -745,29 +664,79 @@ class ElasticSearchVector(BaseVector):
|
||||
|
||||
|
||||
class ElasticSearchVectorFactory:
|
||||
@staticmethod
|
||||
def init_vector(knowledge: Knowledge) -> ElasticSearchVector:
|
||||
"""ES 向量服务工厂 - 单例共享连接"""
|
||||
|
||||
_client: Elasticsearch | None = None
|
||||
_lock = threading.Lock()
|
||||
_version_checked = False
|
||||
|
||||
@classmethod
|
||||
def _get_shared_client(cls) -> Elasticsearch:
|
||||
"""获取共享的 ES 客户端(线程安全的懒加载单例)"""
|
||||
if cls._client is not None:
|
||||
return cls._client
|
||||
|
||||
with cls._lock:
|
||||
# 双重检查,防止并发时重复创建
|
||||
if cls._client is not None:
|
||||
return cls._client
|
||||
|
||||
try:
|
||||
parsed_url = urlparse(os.getenv("ELASTICSEARCH_HOST", "127.0.0.1") or "")
|
||||
if parsed_url.scheme in {"http", "https"}:
|
||||
hosts = f'{os.getenv("ELASTICSEARCH_HOST")}:{os.getenv("ELASTICSEARCH_PORT", 9200)}'
|
||||
use_https = parsed_url.scheme == "https"
|
||||
else:
|
||||
hosts = f'https://{os.getenv("ELASTICSEARCH_HOST", "127.0.0.1")}:{os.getenv("ELASTICSEARCH_PORT", 9200)}'
|
||||
use_https = False
|
||||
|
||||
client_config = {
|
||||
"hosts": [hosts],
|
||||
"basic_auth": (
|
||||
os.getenv("ELASTICSEARCH_USERNAME", "elastic"),
|
||||
os.getenv("ELASTICSEARCH_PASSWORD", "elastic"),
|
||||
),
|
||||
"request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 30)),
|
||||
"retry_on_timeout": True,
|
||||
"max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 3)),
|
||||
"connections_per_node": int(os.getenv("ELASTICSEARCH_CONNECTIONS_PER_NODE", 10)),
|
||||
}
|
||||
|
||||
if use_https:
|
||||
client_config["verify_certs"] = os.getenv("ELASTICSEARCH_VERIFY_CERTS", "false") == "true"
|
||||
ca_certs = os.getenv("ELASTICSEARCH_CA_CERTS")
|
||||
if ca_certs:
|
||||
client_config["ca_certs"] = str(ca_certs)
|
||||
|
||||
client = Elasticsearch(**client_config)
|
||||
|
||||
if not client.ping():
|
||||
raise ConnectionError("Failed to connect to Elasticsearch")
|
||||
|
||||
# 版本检查只做一次
|
||||
if not cls._version_checked:
|
||||
info = client.info()
|
||||
version = info["version"]["number"]
|
||||
if parse_version(version) < parse_version("8.0.0"):
|
||||
raise ValueError(f"Elasticsearch version must be >= 8.0.0, got {version}")
|
||||
cls._version_checked = True
|
||||
logger.info(f"Elasticsearch shared client initialized, version: {version}")
|
||||
|
||||
cls._client = client
|
||||
|
||||
except requests.ConnectionError as e:
|
||||
raise ConnectionError(f"Vector database connection error: {str(e)}")
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")
|
||||
|
||||
return cls._client
|
||||
|
||||
@classmethod
|
||||
def init_vector(cls, knowledge: Knowledge) -> ElasticSearchVector:
|
||||
"""创建向量服务实例(共享 ES 连接)"""
|
||||
client = cls._get_shared_client()
|
||||
collection_name = f"Vector_index_{knowledge.id}_Node"
|
||||
|
||||
# Use regular Elasticsearch with config values
|
||||
config_dict = {
|
||||
"host": os.getenv("ELASTICSEARCH_HOST", "127.0.0.1"),
|
||||
"port": os.getenv("ELASTICSEARCH_PORT", 9200),
|
||||
"username": os.getenv("ELASTICSEARCH_USERNAME", "elastic"),
|
||||
"password": os.getenv("ELASTICSEARCH_PASSWORD", "elastic"),
|
||||
}
|
||||
|
||||
# Common configuration
|
||||
config_dict.update(
|
||||
{
|
||||
"ca_certs": str(os.getenv("ELASTICSEARCH_CA_CERTS")) if os.getenv("ELASTICSEARCH_CA_CERTS") else None,
|
||||
"verify_certs": os.getenv("ELASTICSEARCH_VERIFY_CERTS", False) == "true",
|
||||
"request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 100000)),
|
||||
"retry_on_timeout": os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", True) == "true",
|
||||
"max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 10000)),
|
||||
}
|
||||
)
|
||||
|
||||
if knowledge.embedding is None:
|
||||
raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}")
|
||||
if knowledge.reranker is None:
|
||||
@@ -775,9 +744,9 @@ class ElasticSearchVectorFactory:
|
||||
|
||||
return ElasticSearchVector(
|
||||
index_name=collection_name,
|
||||
config=ElasticSearchConfig(**config_dict),
|
||||
client=client,
|
||||
embedding_config=knowledge.embedding.api_keys[0],
|
||||
reranker_config=knowledge.reranker.api_keys[0]
|
||||
reranker_config=knowledge.reranker.api_keys[0],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ class DateTimeTool(BuiltinTool):
|
||||
type=ParameterType.STRING,
|
||||
description="操作类型",
|
||||
required=True,
|
||||
enum=["format", "convert_timezone", "timestamp_to_datetime", "now"]
|
||||
enum=["format", "convert_timezone", "timestamp_to_datetime", "now", "datetime_to_timestamp"]
|
||||
),
|
||||
ToolParameter(
|
||||
name="input_value",
|
||||
@@ -230,7 +230,7 @@ class DateTimeTool(BuiltinTool):
|
||||
@staticmethod
|
||||
def _datetime_to_timestamp(kwargs) -> dict:
|
||||
"""日期时间转时间戳"""
|
||||
input_value = kwargs.get("input_value")
|
||||
input_value = kwargs.get("input_value").strip()
|
||||
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
|
||||
timezone_str = kwargs.get("from_timezone", "Asia/Shanghai")
|
||||
|
||||
@@ -253,9 +253,9 @@ class DateTimeTool(BuiltinTool):
|
||||
return {
|
||||
"datetime": input_value,
|
||||
"timezone": timezone_str,
|
||||
"timestamp": int(dt.timestamp()),
|
||||
"timestamp": int(dt.timestamp() * 1000),
|
||||
"iso_format": dt.isoformat(),
|
||||
"result_data": int(dt.timestamp())
|
||||
"result_data": int(dt.timestamp() * 1000)
|
||||
}
|
||||
|
||||
def _calculate_datetime(self, kwargs) -> dict:
|
||||
|
||||
300
api/app/core/tools/builtin/openclaw_tool.py
Normal file
300
api/app/core/tools/builtin/openclaw_tool.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""OpenClaw 远程 Agent 内置工具"""
|
||||
import time
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from typing import List, Dict, Any, Optional
|
||||
import aiohttp
|
||||
|
||||
from app.core.tools.builtin.base import BuiltinTool
|
||||
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class OpenClawTool(BuiltinTool):
|
||||
"""OpenClaw 远程 Agent 工具 — 支持文本和图片多模态输入"""
|
||||
|
||||
def __init__(self, tool_id: str, config: Dict[str, Any]):
|
||||
super().__init__(tool_id, config)
|
||||
params = self.parameters_config
|
||||
|
||||
# 用户配置项(前端表单填写)
|
||||
self._server_url = params.get("server_url", "")
|
||||
self._api_key = params.get("api_key", "")
|
||||
self._agent_id = params.get("agent_id", "main")
|
||||
|
||||
# 内部默认值
|
||||
self._model = "openclaw"
|
||||
self._session_strategy = "by_user"
|
||||
self._timeout = 120
|
||||
|
||||
# 运行时上下文(通过 set_runtime_context 注入)
|
||||
self._user_id = "anonymous"
|
||||
self._conversation_id = None
|
||||
self._uploaded_files = []
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "openclaw_tool"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"OpenClaw 远程 Agent:将任务委托给远程 OpenClaw Agent。"
|
||||
"具备 3D 模型生成与打印控制、设备管理、文件处理、浏览器自动化、"
|
||||
"Shell 命令执行、网络搜索等能力。支持文本和图片多模态交互。"
|
||||
)
|
||||
|
||||
def get_required_config_parameters(self) -> List[str]:
|
||||
return ["server_url", "api_key"]
|
||||
|
||||
@property
|
||||
def parameters(self) -> List[ToolParameter]:
|
||||
return [
|
||||
ToolParameter(
|
||||
name="operation",
|
||||
type=ParameterType.STRING,
|
||||
description="任务类型",
|
||||
required=True,
|
||||
enum= ["print_task", "device_query", "image_understand", "general"]
|
||||
),
|
||||
ToolParameter(
|
||||
name="message",
|
||||
type=ParameterType.STRING,
|
||||
description="发送给 OpenClaw Agent 的文本请求内容",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="image_url",
|
||||
type=ParameterType.STRING,
|
||||
description="可选,附带的图片 URL 或 base64 data URI(OpenClaw 支持图片输入)",
|
||||
required=False
|
||||
)
|
||||
]
|
||||
|
||||
# ---------- 运行时上下文注入 ----------
|
||||
def set_runtime_context(
|
||||
self,
|
||||
user_id: str = "anonymous",
|
||||
conversation_id: Optional[str] = None,
|
||||
uploaded_files: Optional[list] = None
|
||||
):
|
||||
"""注入运行时上下文(由 chat service 调用)"""
|
||||
self._user_id = user_id
|
||||
self._conversation_id = conversation_id
|
||||
self._uploaded_files = uploaded_files or []
|
||||
|
||||
# ---------- 连接测试 ----------
|
||||
async def test_connection(self) -> Dict[str, Any]:
|
||||
"""测试 OpenClaw Gateway 连接"""
|
||||
if not self._server_url:
|
||||
return {"success": False, "message": "未配置 server_url"}
|
||||
if not self._api_key:
|
||||
return {"success": False, "message": "未配置 api_key"}
|
||||
|
||||
url = f"{self._server_url.rstrip('/')}/v1/responses"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"x-openclaw-agent-id": self._agent_id
|
||||
}
|
||||
body = {
|
||||
"model": self._model,
|
||||
"user": "connection-test",
|
||||
"input": "hi",
|
||||
"stream": False
|
||||
}
|
||||
try:
|
||||
timeout_cfg = aiohttp.ClientTimeout(total=30)
|
||||
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
|
||||
async with session.post(url, json=body, headers=headers) as resp:
|
||||
if resp.status < 400:
|
||||
return {"success": True, "message": "OpenClaw 连接成功"}
|
||||
error_text = await resp.text()
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"OpenClaw HTTP {resp.status}: {error_text[:200]}"
|
||||
}
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"OpenClaw 连接失败: {str(e)}"}
|
||||
|
||||
# ---------- 执行 ----------
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行 OpenClaw 调用"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
message = kwargs.get("message", "")
|
||||
if not message:
|
||||
return ToolResult.error_result(
|
||||
error="message 参数不能为空",
|
||||
error_code="OPENCLAW_INVALID_INPUT",
|
||||
execution_time=time.time() - start_time
|
||||
)
|
||||
|
||||
# 提取图片:优先从用户上传文件中获取,LLM 传的 image_url 作为兜底
|
||||
image_url = self._extract_image_from_uploads()
|
||||
if not image_url:
|
||||
image_url = kwargs.get("image_url")
|
||||
if image_url and not image_url.startswith("data:"):
|
||||
image_url = await self._download_and_encode_image(image_url)
|
||||
|
||||
# 构建请求
|
||||
url = f"{self._server_url.rstrip('/')}/v1/responses"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"x-openclaw-agent-id": self._agent_id
|
||||
}
|
||||
user_field = (
|
||||
f"conv-{self._conversation_id}"
|
||||
if self._session_strategy == "by_conversation" and self._conversation_id
|
||||
else f"user-{self._user_id}"
|
||||
)
|
||||
input_field = self._build_input(message, image_url)
|
||||
body = {
|
||||
"model": self._model,
|
||||
"user": user_field,
|
||||
"input": input_field,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
timeout_cfg = aiohttp.ClientTimeout(total=self._timeout)
|
||||
# 打印请求日志(截断 base64 避免日志过大)
|
||||
log_body = {**body}
|
||||
if isinstance(log_body.get("input"), list):
|
||||
log_body["input"] = "[multimodal input, truncated]"
|
||||
elif isinstance(log_body.get("input"), str) and len(log_body["input"]) > 500:
|
||||
log_body["input"] = log_body["input"][:500] + "..."
|
||||
logger.info(
|
||||
f"OpenClaw 请求: url={url}, agent_id={self._agent_id}, "
|
||||
f"has_image={bool(image_url)}, body={log_body}"
|
||||
)
|
||||
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
|
||||
async with session.post(url, json=body, headers=headers) as resp:
|
||||
execution_time = time.time() - start_time
|
||||
if resp.status >= 400:
|
||||
error_text = await resp.text()
|
||||
return ToolResult.error_result(
|
||||
error=f"OpenClaw HTTP {resp.status}: {error_text[:500]}",
|
||||
error_code="OPENCLAW_HTTP_ERROR",
|
||||
execution_time=execution_time
|
||||
)
|
||||
data = await resp.json()
|
||||
text = self._extract_response(data)
|
||||
display_text = self._format_result(text)
|
||||
return ToolResult.success_result(
|
||||
data=display_text,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
return ToolResult.error_result(
|
||||
error=f"OpenClaw 网络连接失败: {str(e)}",
|
||||
error_code="OPENCLAW_NETWORK_ERROR",
|
||||
execution_time=time.time() - start_time
|
||||
)
|
||||
except Exception as e:
|
||||
return ToolResult.error_result(
|
||||
error=f"OpenClaw 调用失败: {str(e)}",
|
||||
error_code="OPENCLAW_EXECUTION_ERROR",
|
||||
execution_time=time.time() - start_time
|
||||
)
|
||||
|
||||
# ---------- 私有方法 ----------
|
||||
def _extract_image_from_uploads(self) -> Optional[str]:
|
||||
"""从用户上传文件中提取图片 URL"""
|
||||
for f in self._uploaded_files:
|
||||
f_type = f.get("type", "")
|
||||
if f_type == "image":
|
||||
source = f.get("source", {})
|
||||
if source.get("type") == "base64":
|
||||
media_type = source.get("media_type", "image/jpeg")
|
||||
data = source.get("data", "")
|
||||
return f"data:{media_type};base64,{data}"
|
||||
elif f.get("image"):
|
||||
return f.get("image")
|
||||
elif f.get("url"):
|
||||
return f.get("url")
|
||||
elif f_type == "image_url":
|
||||
return f.get("image_url", {}).get("url", "")
|
||||
return None
|
||||
|
||||
async def _download_and_encode_image(self, image_url: str) -> str:
|
||||
"""下载图片并转为 base64 data URI"""
|
||||
try:
|
||||
from PIL import Image
|
||||
MAX_RAW_SIZE = 4 * 1024 * 1024
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
image_url, allow_redirects=True,
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
return image_url
|
||||
content_type = resp.headers.get("Content-Type", "image/jpeg")
|
||||
if not content_type.startswith("image/"):
|
||||
return image_url
|
||||
img_bytes = await resp.read()
|
||||
|
||||
if len(img_bytes) > MAX_RAW_SIZE:
|
||||
img = Image.open(BytesIO(img_bytes))
|
||||
if img.mode in ("RGBA", "P", "LA"):
|
||||
img = img.convert("RGB")
|
||||
if max(img.size) > 2048:
|
||||
img.thumbnail((2048, 2048), Image.LANCZOS)
|
||||
buf = BytesIO()
|
||||
img.save(buf, format="JPEG", quality=75, optimize=True)
|
||||
img_bytes = buf.getvalue()
|
||||
content_type = "image/jpeg"
|
||||
|
||||
b64 = base64.b64encode(img_bytes).decode("utf-8")
|
||||
return f"data:{content_type};base64,{b64}"
|
||||
except Exception as e:
|
||||
logger.warning(f"OpenClaw 下载图片失败,使用原始 URL: {e}")
|
||||
return image_url
|
||||
|
||||
def _build_input(self, message: str, image_url: Optional[str] = None):
|
||||
"""构造请求 input 字段:有图片则构造多模态结构,否则纯文本"""
|
||||
if not image_url:
|
||||
return message
|
||||
|
||||
content_parts = [{"type": "input_text", "text": message}]
|
||||
if image_url.startswith("data:"):
|
||||
try:
|
||||
header, data = image_url.split(",", 1)
|
||||
media_type = header.split(":")[1].split(";")[0]
|
||||
content_parts.append({
|
||||
"type": "input_image",
|
||||
"source": {"type": "base64", "media_type": media_type, "data": data}
|
||||
})
|
||||
except (ValueError, IndexError):
|
||||
return message
|
||||
else:
|
||||
content_parts.append({
|
||||
"type": "input_image",
|
||||
"source": {"type": "url", "url": image_url}
|
||||
})
|
||||
|
||||
return [{"type": "message", "role": "user", "content": content_parts}]
|
||||
|
||||
def _extract_response(self, response_data: Dict[str, Any]) -> str:
|
||||
"""从 OpenClaw 响应中提取文本内容
|
||||
|
||||
OpenClaw /v1/responses 只返回 output_text 类型的内容。
|
||||
图片信息(如有)由 OpenClaw Skill 以 Markdown 链接形式嵌入文本中返回。
|
||||
"""
|
||||
output = response_data.get("output", [])
|
||||
texts = []
|
||||
for item in output:
|
||||
if item.get("type") == "message":
|
||||
for content in item.get("content", []):
|
||||
if content.get("type") == "output_text" and content.get("text"):
|
||||
texts.append(content["text"])
|
||||
return "\n".join(texts) if texts else str(response_data)
|
||||
|
||||
@staticmethod
|
||||
def _format_result(text: str) -> str:
|
||||
"""格式化结果为 LLM 可读字符串"""
|
||||
return text or "(OpenClaw 返回了空内容)"
|
||||
@@ -11,6 +11,11 @@ class OperationTool(BaseTool):
|
||||
self.base_tool = base_tool
|
||||
self.operation = operation
|
||||
super().__init__(base_tool.tool_id, base_tool.config)
|
||||
|
||||
def set_runtime_context(self, **kwargs):
|
||||
"""转发运行时上下文到 base_tool"""
|
||||
if hasattr(self.base_tool, 'set_runtime_context'):
|
||||
self.base_tool.set_runtime_context(**kwargs)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -32,6 +37,8 @@ class OperationTool(BaseTool):
|
||||
return self._get_datetime_params()
|
||||
elif self.base_tool.name == 'json_tool':
|
||||
return self._get_json_params()
|
||||
elif self.base_tool.name == 'openclaw_tool':
|
||||
return self._get_openclaw_params()
|
||||
else:
|
||||
# 默认返回除operation外的所有参数
|
||||
return [p for p in self.base_tool.parameters if p.name != "operation"]
|
||||
@@ -138,6 +145,29 @@ class OperationTool(BaseTool):
|
||||
default="Asia/Shanghai"
|
||||
)
|
||||
]
|
||||
elif self.operation == "datetime_to_timestamp":
|
||||
return [
|
||||
ToolParameter(
|
||||
name="input_value",
|
||||
type=ParameterType.STRING,
|
||||
description="输入值(时间字符串,如:2026-04-07 10:30:25)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="input_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
required=False,
|
||||
default="%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
ToolParameter(
|
||||
name="from_timezone",
|
||||
type=ParameterType.STRING,
|
||||
description="源时区(如:UTC, Asia/Shanghai)",
|
||||
required=False,
|
||||
default="Asia/Shanghai"
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
@@ -209,6 +239,64 @@ class OperationTool(BaseTool):
|
||||
else:
|
||||
return base_params
|
||||
|
||||
def _get_openclaw_params(self) -> List[ToolParameter]:
|
||||
"""获取 openclaw_tool 特定操作的参数"""
|
||||
if self.operation == "print_task":
|
||||
return [
|
||||
ToolParameter(
|
||||
name="message",
|
||||
type=ParameterType.STRING,
|
||||
description="发送给 OpenClaw 的打印任务描述,将用户的原始消息原封不动地传递给 OpenClaw,禁止改写、补充或润色用户的原文",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="image_url",
|
||||
type=ParameterType.STRING,
|
||||
description="可选,附带的设计图片或参考图,OpenClaw 可据此生成 3D 模型",
|
||||
required=False
|
||||
)
|
||||
]
|
||||
elif self.operation == "device_query":
|
||||
return [
|
||||
ToolParameter(
|
||||
name="message",
|
||||
type=ParameterType.STRING,
|
||||
description="发送给 OpenClaw 的设备查询指令",
|
||||
required=True
|
||||
)
|
||||
]
|
||||
elif self.operation == "image_understand":
|
||||
return [
|
||||
ToolParameter(
|
||||
name="message",
|
||||
type=ParameterType.STRING,
|
||||
description="发送给 OpenClaw 的图片理解任务,应描述需要对图片做什么(如描述内容、提取文字、分析信息)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="image_url",
|
||||
type=ParameterType.STRING,
|
||||
description="要分析的图片 URL 或 base64 data URI",
|
||||
required=False
|
||||
)
|
||||
]
|
||||
else:
|
||||
# general 及其他
|
||||
return [
|
||||
ToolParameter(
|
||||
name="message",
|
||||
type=ParameterType.STRING,
|
||||
description="发送给 OpenClaw Agent 的任务描述,应包含完整的任务需求",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="image_url",
|
||||
type=ParameterType.STRING,
|
||||
description="可选,附带的图片 URL 或 base64 data URI",
|
||||
required=False
|
||||
)
|
||||
]
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行特定操作"""
|
||||
# 添加operation参数
|
||||
|
||||
15
api/app/core/tools/configs/builtin/openclaw_tool.json
Normal file
15
api/app/core/tools/configs/builtin/openclaw_tool.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"name": "openclaw_tool",
|
||||
"description": "调用OpenClaw Agent远程服务",
|
||||
"tool_class": "OpenClawTool",
|
||||
"category": "agent",
|
||||
"requires_config": true,
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {
|
||||
"server_url": "",
|
||||
"api_key": "",
|
||||
"agent_id": "main"
|
||||
},
|
||||
"tags": ["agent", "openclaw", "multimodal", "3d-printing", "builtin"]
|
||||
}
|
||||
@@ -30,5 +30,18 @@
|
||||
"parameters": {
|
||||
"api_key": {"type": "string", "description": "百度搜索API密钥", "sensitive": true, "required": true}
|
||||
}
|
||||
},
|
||||
"openclaw": {
|
||||
"name": "OpenClaw远程Agent",
|
||||
"description": "OpenClaw Agent远程服务",
|
||||
"tool_class": "OpenClawTool",
|
||||
"category": "agent",
|
||||
"requires_config": true,
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {
|
||||
"server_url": {"type": "string", "description": "OpenClaw Gateway 地址", "required": true},
|
||||
"api_key": {"type": "string", "description": "OpenClaw API Key", "sensitive": true, "required": true}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -30,7 +30,7 @@ class CustomTool(BaseTool):
|
||||
self.auth_config = config.get("auth_config", {})
|
||||
self.base_url = config.get("base_url", "")
|
||||
self.timeout = config.get("timeout", 30)
|
||||
|
||||
|
||||
# 解析schema
|
||||
self._parsed_operations = self._parse_openapi_schema()
|
||||
|
||||
|
||||
@@ -131,7 +131,7 @@ class LangchainAdapter:
|
||||
def _tool_supports_operations(tool: BaseTool) -> bool:
|
||||
"""检查工具是否支持多操作"""
|
||||
# 内置工具中支持操作的工具
|
||||
builtin_operation_tools = ['datetime_tool', 'json_tool']
|
||||
builtin_operation_tools = ['datetime_tool', 'json_tool', 'openclaw_tool']
|
||||
|
||||
# 检查内置工具
|
||||
if tool.tool_type.value == "builtin" and tool.name in builtin_operation_tools:
|
||||
|
||||
@@ -99,7 +99,7 @@ class SimpleMCPClient:
|
||||
# 建立 SSE 连接
|
||||
response = await self._session.get(self.server_url)
|
||||
|
||||
if response.status not in (200, 202):
|
||||
if not (200 <= response.status < 300):
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
|
||||
|
||||
@@ -190,9 +190,7 @@ class SimpleMCPClient:
|
||||
|
||||
try:
|
||||
async with self._session.post(self._endpoint_url, json=request) as response:
|
||||
# MCP SSE 协议:POST 请求返回 200 或 202 均为正常
|
||||
# 202 Accepted 表示请求已接受,结果通过 SSE 流异步返回
|
||||
if response.status not in (200, 202):
|
||||
if not (200 <= response.status < 300):
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
|
||||
|
||||
@@ -207,7 +205,7 @@ class SimpleMCPClient:
|
||||
raise MCPConnectionError("endpoint URL 未初始化")
|
||||
|
||||
async with self._session.post(self._endpoint_url, json=notification) as response:
|
||||
if response.status not in (200, 202):
|
||||
if not (200 <= response.status < 300):
|
||||
logger.warning(f"通知发送失败: {response.status}")
|
||||
|
||||
async def _initialize_modelscope_session(self):
|
||||
@@ -225,7 +223,7 @@ class SimpleMCPClient:
|
||||
|
||||
try:
|
||||
async with self._session.post(self.server_url, json=init_request) as response:
|
||||
if response.status != 200:
|
||||
if not (200 <= response.status < 300):
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ class WorkflowParserResult(BaseModel):
|
||||
edges: list[EdgeDefinition] = Field(default_factory=list)
|
||||
nodes: list[NodeDefinition] = Field(default_factory=list)
|
||||
variables: list[VariableDefinition] = Field(default_factory=list)
|
||||
features: dict[str, Any] = Field(default_factory=dict)
|
||||
warnings: list[ExceptionDefinition] = Field(default_factory=list)
|
||||
errors: list[ExceptionDefinition] = Field(default_factory=list)
|
||||
|
||||
@@ -51,6 +52,7 @@ class WorkflowImportResult(BaseModel):
|
||||
edges: list[EdgeDefinition] = Field(default_factory=list)
|
||||
nodes: list[NodeDefinition] = Field(default_factory=list)
|
||||
variables: list[VariableDefinition] = Field(default_factory=list)
|
||||
features: dict[str, Any] = Field(default_factory=dict)
|
||||
warnings: list[ExceptionDefinition] = Field(default_factory=list)
|
||||
errors: list[ExceptionDefinition] = Field(default_factory=list)
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from app.core.workflow.adapters.errors import (
|
||||
ExceptionType
|
||||
)
|
||||
from app.core.workflow.nodes.assigner.config import AssignmentItem
|
||||
from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig
|
||||
from app.core.workflow.nodes.base_config import VariableDefinition as NodeVariableDefinition, BaseNodeConfig
|
||||
from app.core.workflow.nodes.code.config import InputVariable, OutputVariable
|
||||
from app.core.workflow.nodes.configs import (
|
||||
StartNodeConfig,
|
||||
@@ -32,13 +32,17 @@ from app.core.workflow.nodes.configs import (
|
||||
NoteNodeConfig,
|
||||
ParameterExtractorNodeConfig,
|
||||
QuestionClassifierNodeConfig,
|
||||
VariableAggregatorNodeConfig
|
||||
VariableAggregatorNodeConfig,
|
||||
ListOperatorNodeConfig,
|
||||
DocExtractorNodeConfig,
|
||||
)
|
||||
from app.schemas.workflow_schema import VariableDefinition as SchemaVariableDefinition
|
||||
from app.core.workflow.nodes.cycle_graph.config import (
|
||||
ConditionDetail as LoopConditionDetail,
|
||||
ConditionsConfig,
|
||||
CycleVariable
|
||||
)
|
||||
from app.core.workflow.nodes.list_operator.config import FilterCondition
|
||||
from app.core.workflow.nodes.enums import (
|
||||
ValueInputType,
|
||||
ComparisonOperator,
|
||||
@@ -90,9 +94,12 @@ class DifyConverter(BaseConverter):
|
||||
NodeType.VAR_AGGREGATOR: self.convert_variable_aggregator_node_config,
|
||||
NodeType.TOOL: self.convert_tool_node_config,
|
||||
NodeType.NOTES: self.convert_notes_config,
|
||||
NodeType.LIST_OPERATOR: self.convert_list_operator_node_config,
|
||||
NodeType.DOCUMENT_EXTRACTOR: self.convert_document_extractor_node_config,
|
||||
NodeType.CYCLE_START: lambda x: {},
|
||||
NodeType.BREAK: lambda x: {},
|
||||
}
|
||||
self._file_vars_to_conv: list[SchemaVariableDefinition] = []
|
||||
|
||||
def get_node_convert(self, node_type):
|
||||
func = self.CONFIG_CONVERT_MAP.get(node_type, lambda x: {})
|
||||
@@ -126,7 +133,7 @@ class DifyConverter(BaseConverter):
|
||||
selector = var_selector.split('.')
|
||||
if len(selector) not in [2, 3] and var_selector != "context":
|
||||
raise Exception(f"invalid variable selector: {var_selector}")
|
||||
if len(selector) == 3:
|
||||
if len(selector) == 3 and selector[0] in ("conversation", "sys"):
|
||||
selector = selector[1:]
|
||||
if selector[0] == "conversation":
|
||||
selector[0] = "conv"
|
||||
@@ -213,7 +220,9 @@ class DifyConverter(BaseConverter):
|
||||
"end with": ComparisonOperator.END_WITH,
|
||||
"not contains": ComparisonOperator.NOT_CONTAINS,
|
||||
"exists": ComparisonOperator.NOT_EMPTY,
|
||||
"not exists": ComparisonOperator.EMPTY
|
||||
"not exists": ComparisonOperator.EMPTY,
|
||||
"in": ComparisonOperator.IN,
|
||||
"not in": ComparisonOperator.NOT_IN,
|
||||
}
|
||||
return operator_map.get(operator, operator)
|
||||
|
||||
@@ -279,19 +288,25 @@ class DifyConverter(BaseConverter):
|
||||
)
|
||||
continue
|
||||
|
||||
if var_type in ["file", "array[file]"]:
|
||||
self.errors.append(
|
||||
ExceptionDefinition(
|
||||
type=ExceptionType.VARIABLE,
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
name=var["variable"],
|
||||
detail=f"Unsupported Variable type for start node: {var_type}"
|
||||
)
|
||||
)
|
||||
if var_type in [VariableType.FILE, VariableType.ARRAY_FILE]:
|
||||
# 开始节点不支持文件变量,转为会话变量
|
||||
self._file_vars_to_conv.append(SchemaVariableDefinition(
|
||||
name=var["variable"],
|
||||
type=var_type.value,
|
||||
required=var.get("required", False),
|
||||
default=None,
|
||||
description=var.get("label", ""),
|
||||
))
|
||||
self.warnings.append(ExceptionDefinition(
|
||||
type=ExceptionType.VARIABLE,
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
name=var["variable"],
|
||||
detail=f"File variable '{var['variable']}' is not supported in start node, moved to conversation variables"
|
||||
))
|
||||
continue
|
||||
|
||||
var_def = VariableDefinition(
|
||||
var_def = NodeVariableDefinition(
|
||||
name=var["variable"],
|
||||
type=var_type,
|
||||
required=var["required"],
|
||||
@@ -476,11 +491,11 @@ class DifyConverter(BaseConverter):
|
||||
node_data = node["data"]
|
||||
result = IterationNodeConfig.model_construct(
|
||||
input=self._process_list_variable_literal(node_data["iterator_selector"]),
|
||||
parallel=node_data["is_parallel"],
|
||||
parallel_count=node_data["parallel_nums"],
|
||||
parallel=node_data.get("is_parallel", False),
|
||||
parallel_count=node_data.get("parallel_nums", 4),
|
||||
output=self._process_list_variable_literal(node_data["output_selector"]),
|
||||
output_type=self.variable_type_map(node_data.get("output_type")),
|
||||
flatten=node_data["flatten_output"],
|
||||
flatten=node_data.get("flatten_output", False),
|
||||
).model_dump()
|
||||
|
||||
self.config_validate(node["id"], node["data"]["title"], IterationNodeConfig, result)
|
||||
@@ -489,7 +504,23 @@ class DifyConverter(BaseConverter):
|
||||
def convert_assigner_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
assignments = []
|
||||
for assignment in node_data["items"]:
|
||||
|
||||
# Support both formats:
|
||||
# 1. New format: node_data["items"] list
|
||||
# 2. Flat format: assigned_variable_selector + input_variable_selector + write_mode
|
||||
if "items" in node_data:
|
||||
raw_items = node_data["items"]
|
||||
elif "assigned_variable_selector" in node_data and "input_variable_selector" in node_data:
|
||||
raw_items = [{
|
||||
"variable_selector": node_data["assigned_variable_selector"],
|
||||
"value": node_data["input_variable_selector"],
|
||||
"input_type": ValueInputType.VARIABLE,
|
||||
"operation": node_data.get("write_mode", "over-write"),
|
||||
}]
|
||||
else:
|
||||
raw_items = []
|
||||
|
||||
for assignment in raw_items:
|
||||
if assignment.get("operation") is None or assignment.get("value") is None:
|
||||
continue
|
||||
assignments.append(
|
||||
@@ -771,3 +802,119 @@ class DifyConverter(BaseConverter):
|
||||
show_author=node_data.get("showAuthor", True)
|
||||
).model_dump()
|
||||
return result
|
||||
|
||||
def convert_list_operator_node_config(self, node: dict) -> dict:
|
||||
"""Dify list-operator — convert variable path array to {{ }} selector format."""
|
||||
node_data = node["data"]
|
||||
variable_path = node_data.get("variable", [])
|
||||
input_list = self._process_list_variable_literal(variable_path) or ""
|
||||
filter_by = node_data.get("filter_by", {"enabled": False, "conditions": []})
|
||||
# Convert each condition's comparison_operator from Dify format to native
|
||||
if filter_by.get("conditions"):
|
||||
converted_conditions = []
|
||||
for cond in filter_by["conditions"]:
|
||||
converted_conditions.append({
|
||||
**cond,
|
||||
"comparison_operator": self.convert_compare_operator(
|
||||
cond.get("comparison_operator", "")
|
||||
)
|
||||
})
|
||||
filter_by = {**filter_by, "conditions": converted_conditions}
|
||||
result = {
|
||||
"input_list": input_list,
|
||||
"filter_by": filter_by,
|
||||
"order_by": node_data.get("order_by", {"enabled": False, "key": "", "value": "asc"}),
|
||||
"limit": node_data.get("limit", {"enabled": False, "size": -1}),
|
||||
"extract_by": node_data.get("extract_by", {"enabled": False, "serial": "1"}),
|
||||
}
|
||||
self.config_validate(node["id"], node["data"]["title"], ListOperatorNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_document_extractor_node_config(self, node: dict) -> dict:
|
||||
"""Convert Dify document-extractor node to MemoryBear DocExtractorNodeConfig.
|
||||
|
||||
Dify document-extractor data fields:
|
||||
variable_selector: list[str] - file variable path
|
||||
"""
|
||||
node_data = node["data"]
|
||||
file_selector = self._process_list_variable_literal(
|
||||
node_data.get("variable_selector", [])
|
||||
) or ""
|
||||
result = DocExtractorNodeConfig.model_construct(
|
||||
file_selector=file_selector,
|
||||
).model_dump()
|
||||
self.config_validate(node["id"], node["data"]["title"], DocExtractorNodeConfig, result)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def convert_features(features: dict) -> dict:
|
||||
"""Convert Dify features to MemoryBear FeaturesConfigForm format."""
|
||||
if not features:
|
||||
return {}
|
||||
|
||||
result: dict = {}
|
||||
|
||||
# opening_statement
|
||||
opening = features.get("opening_statement", "")
|
||||
suggested = features.get("suggested_questions", [])
|
||||
result["opening_statement"] = {
|
||||
"enabled": bool(opening),
|
||||
"statement": opening or None,
|
||||
"suggested_questions": suggested,
|
||||
}
|
||||
|
||||
# citation (对应 Dify retriever_resource)
|
||||
retriever = features.get("retriever_resource", {})
|
||||
result["citation"] = {
|
||||
"enabled": retriever.get("enabled", False) if isinstance(retriever, dict) else False,
|
||||
}
|
||||
|
||||
# file_upload: Dify allowed_file_types 数组 -> 前端扁平字段
|
||||
file_upload = features.get("file_upload", {})
|
||||
allowed_types = file_upload.get("allowed_file_types", []) if file_upload else []
|
||||
allowed_methods = file_upload.get("allowed_file_upload_methods", ["local_file", "remote_url"])
|
||||
if isinstance(allowed_methods, list):
|
||||
if len(allowed_methods) >= 2:
|
||||
transfer_method = "both"
|
||||
elif allowed_methods:
|
||||
transfer_method = allowed_methods[0]
|
||||
else:
|
||||
transfer_method = "both"
|
||||
else:
|
||||
transfer_method = allowed_methods or "both"
|
||||
|
||||
file_config = file_upload.get("fileUploadConfig", {})
|
||||
result["file_upload"] = {
|
||||
"enabled": file_upload.get("enabled", False) if file_upload else False,
|
||||
"image_enabled": "image" in allowed_types,
|
||||
"image_max_size_mb": file_config.get("image_file_size_limit", 10) if file_config else 10,
|
||||
"image_allowed_extensions": ["png", "jpg", "jpeg"],
|
||||
"audio_enabled": "audio" in allowed_types,
|
||||
"audio_max_size_mb": file_config.get("audio_file_size_limit", 50) if file_config else 50,
|
||||
"audio_allowed_extensions": ["mp3", "wav", "m4a"],
|
||||
"document_enabled": "document" in allowed_types,
|
||||
"document_max_size_mb": file_config.get("file_size_limit", 100) if file_config else 100,
|
||||
"document_allowed_extensions": ["pdf", "docx", "doc", "xlsx", "xls", "txt", "csv", "json", "md"],
|
||||
"video_enabled": "video" in allowed_types,
|
||||
"video_max_size_mb": file_config.get("video_file_size_limit", 100) if file_config else 100,
|
||||
"video_allowed_extensions": ["mp4", "mov"],
|
||||
"max_file_count": file_upload.get("number_limits", 1) if file_upload else 1,
|
||||
"allowed_transfer_methods": transfer_method,
|
||||
}
|
||||
|
||||
# text_to_speech
|
||||
tts = features.get("text_to_speech", {})
|
||||
result["text_to_speech"] = {
|
||||
"enabled": tts.get("enabled", False) if isinstance(tts, dict) else False,
|
||||
"voice": tts.get("voice") if isinstance(tts, dict) else None,
|
||||
"language": tts.get("language") if isinstance(tts, dict) else None,
|
||||
"autoplay": False,
|
||||
}
|
||||
|
||||
# suggested_questions_after_answer
|
||||
sqa = features.get("suggested_questions_after_answer", {})
|
||||
result["suggested_questions_after_answer"] = {
|
||||
"enabled": sqa.get("enabled", False) if isinstance(sqa, dict) else False,
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
@@ -45,6 +45,8 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
"question-classifier": NodeType.QUESTION_CLASSIFIER,
|
||||
"variable-aggregator": NodeType.VAR_AGGREGATOR,
|
||||
"tool": NodeType.TOOL,
|
||||
"list-operator": NodeType.LIST_OPERATOR,
|
||||
"document-extractor": NodeType.DOCUMENT_EXTRACTOR,
|
||||
"": NodeType.NOTES
|
||||
}
|
||||
|
||||
@@ -117,9 +119,12 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
if variable:
|
||||
self.conv_variables.append(con_var)
|
||||
|
||||
# for variables in config.get("workflow").get("environment_variables"):
|
||||
# variable = self._convert_variable(variables)
|
||||
# conv_variables.append(variable)
|
||||
# 开始节点的文件变量合并到会话变量
|
||||
self.conv_variables.extend(self._file_vars_to_conv)
|
||||
|
||||
features = self.convert_features(
|
||||
self.config.get("workflow", {}).get("features", {})
|
||||
)
|
||||
|
||||
trigger = self._convert_trigger({})
|
||||
execution_config = self._convert_execution({})
|
||||
@@ -133,6 +138,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
edges=self.edges,
|
||||
nodes=self.nodes,
|
||||
variables=self.conv_variables,
|
||||
features=features,
|
||||
warnings=self.warnings,
|
||||
errors=self.errors
|
||||
)
|
||||
|
||||
@@ -22,6 +22,8 @@ from app.core.workflow.nodes.configs import (
|
||||
MemoryReadNodeConfig,
|
||||
MemoryWriteNodeConfig,
|
||||
NoteNodeConfig,
|
||||
ListOperatorNodeConfig,
|
||||
DocExtractorNodeConfig,
|
||||
)
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
@@ -51,6 +53,8 @@ class MemoryBearConverter(BaseConverter):
|
||||
NodeType.MEMORY_READ: MemoryReadNodeConfig,
|
||||
NodeType.MEMORY_WRITE: MemoryWriteNodeConfig,
|
||||
NodeType.NOTES: NoteNodeConfig,
|
||||
NodeType.LIST_OPERATOR: ListOperatorNodeConfig,
|
||||
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNodeConfig,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -31,9 +31,9 @@ logger = logging.getLogger(__name__)
|
||||
# Example:
|
||||
# "Hello {{user.name}}!" ->
|
||||
# ["Hello ", "{{user.name}}", "!"]
|
||||
_OUTPUT_PATTERN = re.compile(r'\{\{.*?}}|[^{}]+')
|
||||
_OUTPUT_PATTERN = re.compile(r'\{\{.*?}}|[^{]+|{')
|
||||
# Strict variable format: {{ node_id.field_name }}
|
||||
_VARIABLE_PATTERN = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*}}')
|
||||
_VARIABLE_PATTERN = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)?\s*}}')
|
||||
|
||||
|
||||
class GraphBuilder:
|
||||
|
||||
@@ -59,6 +59,9 @@ class WorkflowResultBuilder:
|
||||
conversation_vars = variable_pool.get_all_conversation_vars()
|
||||
sys_vars = variable_pool.get_all_system_vars()
|
||||
|
||||
# 汇总所有 knowledge 节点的 citations
|
||||
citations = self.aggregate_citations(node_outputs)
|
||||
|
||||
return {
|
||||
"status": "completed" if success else "failed",
|
||||
"output": final_output,
|
||||
@@ -71,9 +74,25 @@ class WorkflowResultBuilder:
|
||||
"conversation_id": execution_context.conversation_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": token_usage,
|
||||
"citations": citations,
|
||||
"error": result.get("error"),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def aggregate_citations(node_outputs: dict) -> list:
|
||||
"""从所有 knowledge 节点的输出中汇总 citations,去重"""
|
||||
seen = set()
|
||||
citations = []
|
||||
for node_output in node_outputs.values():
|
||||
if not isinstance(node_output, dict):
|
||||
continue
|
||||
for c in node_output.get("citations", []):
|
||||
key = c.get("document_id")
|
||||
if key and key not in seen:
|
||||
seen.add(key)
|
||||
citations.append(c)
|
||||
return citations
|
||||
|
||||
@staticmethod
|
||||
def aggregate_token_usage(node_outputs: dict) -> dict[str, int] | None:
|
||||
"""
|
||||
|
||||
@@ -9,10 +9,10 @@ from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
|
||||
def merge_activate_state(x, y):
|
||||
return {
|
||||
k: x.get(k, False) or y.get(k, False)
|
||||
for k in set(x) | set(y)
|
||||
}
|
||||
merged = dict(x)
|
||||
for k, v in y.items():
|
||||
merged[k] = merged.get(k, False) or v
|
||||
return merged
|
||||
|
||||
|
||||
def merge_looping_state(x, y):
|
||||
|
||||
@@ -14,7 +14,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
||||
logger = get_logger(__name__)
|
||||
|
||||
SCOPE_PATTERN = re.compile(
|
||||
r"\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*}}"
|
||||
r"\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)?\s*}}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -17,6 +17,54 @@ from app.core.workflow.variable.variable_objects import T, create_variable_insta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VARIABLE_PATTERN = re.compile(r"\{\{\s*(.*?)\s*}}")
|
||||
|
||||
|
||||
class LazyVariableDict:
|
||||
def __init__(self, source, literal):
|
||||
self._source: dict[str, VariableStruct[Any]] = source
|
||||
self._literal: bool = literal
|
||||
self._cache = {}
|
||||
|
||||
def keys(self):
|
||||
return self._source.keys()
|
||||
|
||||
def _resolve(self, key):
|
||||
if key in self._cache:
|
||||
return self._cache[key]
|
||||
var_struct = self._source.get(key)
|
||||
if var_struct is None:
|
||||
return None
|
||||
raw = var_struct.instance.get_value()
|
||||
# literal 模式下 dict/list 保留结构,让 Jinja2 能继续访问子字段(如 .type)
|
||||
value = raw if (not self._literal or isinstance(raw, (dict, list))) else var_struct.instance.to_literal()
|
||||
self._cache[key] = value
|
||||
return value
|
||||
|
||||
def get(self, key, default=None):
|
||||
value = self._resolve(key)
|
||||
return default if value is None else value
|
||||
|
||||
def __getitem__(self, key):
|
||||
value = self._resolve(key)
|
||||
if value is None:
|
||||
raise KeyError(key)
|
||||
return value
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key.startswith('_'):
|
||||
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'")
|
||||
return self._resolve(key)
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self._source
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._source)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._source)
|
||||
|
||||
|
||||
class VariableSelector:
|
||||
"""变量选择器
|
||||
@@ -117,10 +165,9 @@ class VariablePool:
|
||||
|
||||
@staticmethod
|
||||
def transform_selector(selector):
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
variable_literal = re.sub(pattern, r"\1", selector).strip()
|
||||
variable_literal = VARIABLE_PATTERN.sub(r"\1", selector).strip()
|
||||
selector = VariableSelector.from_string(variable_literal).path
|
||||
if len(selector) != 2:
|
||||
if len(selector) not in (2, 3):
|
||||
raise ValueError(f"Selector not valid - {selector}")
|
||||
return selector
|
||||
|
||||
@@ -152,6 +199,19 @@ class VariablePool:
|
||||
return None
|
||||
return var_instance
|
||||
|
||||
@staticmethod
|
||||
def _extract_field(struct: "VariableStruct", field: str | None) -> Any:
|
||||
"""If field is given, drill into a dict/object/array[file] variable's value."""
|
||||
if field is None:
|
||||
return struct.instance.get_value()
|
||||
value = struct.instance.get_value()
|
||||
# array[file]: extract the field from every element, return a list
|
||||
if isinstance(value, list):
|
||||
return [item.get(field) if isinstance(item, dict) else getattr(item, field, None) for item in value]
|
||||
if not isinstance(value, dict):
|
||||
raise KeyError(f"Variable is not an object or array, cannot access field '{field}'")
|
||||
return value.get(field)
|
||||
|
||||
def get_instance(
|
||||
self,
|
||||
selector: str,
|
||||
@@ -206,12 +266,14 @@ class VariablePool:
|
||||
Raises:
|
||||
KeyError: If strict is True and the variable does not exist.
|
||||
"""
|
||||
path = self.transform_selector(selector)
|
||||
variable_struct = self._get_variable_struct(selector)
|
||||
if variable_struct is None:
|
||||
if strict:
|
||||
raise KeyError(f"{selector} not exist")
|
||||
return default
|
||||
|
||||
if len(path) == 3:
|
||||
return self._extract_field(variable_struct, path[2])
|
||||
return variable_struct.instance.get_value()
|
||||
|
||||
def get_literal(
|
||||
@@ -238,12 +300,15 @@ class VariablePool:
|
||||
Raises:
|
||||
KeyError: If strict is True and the variable does not exist.
|
||||
"""
|
||||
path = self.transform_selector(selector)
|
||||
variable_struct = self._get_variable_struct(selector)
|
||||
if variable_struct is None:
|
||||
if strict:
|
||||
raise KeyError(f"{selector} not exist")
|
||||
return default
|
||||
|
||||
if len(path) == 3:
|
||||
value = self._extract_field(variable_struct, path[2])
|
||||
return str(value) if value is not None else ""
|
||||
return variable_struct.instance.to_literal()
|
||||
|
||||
async def set(
|
||||
@@ -274,7 +339,7 @@ class VariablePool:
|
||||
namespace: str,
|
||||
key: str,
|
||||
value: Any,
|
||||
var_type: VariableType,
|
||||
var_type: VariableType | None,
|
||||
mut: bool
|
||||
):
|
||||
if self.has(f"{namespace}.{key}"):
|
||||
@@ -301,7 +366,24 @@ class VariablePool:
|
||||
Returns:
|
||||
变量是否存在
|
||||
"""
|
||||
return self._get_variable_struct(selector) is not None
|
||||
path = self.transform_selector(selector)
|
||||
struct = self._get_variable_struct(selector)
|
||||
if struct is None:
|
||||
return False
|
||||
if len(path) == 3:
|
||||
value = struct.instance.get_value()
|
||||
return isinstance(value, dict) and path[2] in value
|
||||
return True
|
||||
|
||||
def lazy_namespace(self, namespace: str, literal: bool = False) -> LazyVariableDict:
|
||||
return LazyVariableDict(self.variables.get(namespace, {}), literal)
|
||||
|
||||
def lazy_all_node_outputs(self, literal: bool = False) -> dict[str, LazyVariableDict]:
|
||||
return {
|
||||
ns: LazyVariableDict(vars_dict, literal)
|
||||
for ns, vars_dict in self.variables.items()
|
||||
if ns not in ("sys", "conv")
|
||||
}
|
||||
|
||||
def get_all_system_vars(self, literal=False) -> dict[str, Any]:
|
||||
"""获取所有系统变量
|
||||
@@ -439,6 +521,23 @@ class VariablePoolInitializer:
|
||||
var_value = var_default
|
||||
else:
|
||||
var_value = DEFAULT_VALUE(var_type)
|
||||
# Convert FileInput-format dicts to full FileObject dicts
|
||||
if var_type == VariableType.FILE:
|
||||
if not var_value:
|
||||
continue
|
||||
var_value = await self._resolve_file_default(var_value)
|
||||
if not var_value:
|
||||
continue
|
||||
elif var_type == VariableType.ARRAY_FILE:
|
||||
if not var_value:
|
||||
var_value = []
|
||||
else:
|
||||
resolved = []
|
||||
for item in var_value:
|
||||
f = await self._resolve_file_default(item)
|
||||
if f:
|
||||
resolved.append(f)
|
||||
var_value = resolved
|
||||
await variable_pool.new(
|
||||
namespace="conv",
|
||||
key=var_name,
|
||||
@@ -447,6 +546,17 @@ class VariablePoolInitializer:
|
||||
mut=True
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _resolve_file_default(file_def: dict) -> dict | None:
|
||||
"""Accept only already-resolved FileObject dicts (is_file=True).
|
||||
FileInput-format dicts are converted at save time by WorkflowService._resolve_variables_file_defaults.
|
||||
"""
|
||||
if not isinstance(file_def, dict):
|
||||
return None
|
||||
if file_def.get("is_file"):
|
||||
return file_def
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def _init_system_vars(
|
||||
variable_pool: VariablePool,
|
||||
@@ -479,5 +589,3 @@ class VariablePoolInitializer:
|
||||
var_type=var_type,
|
||||
mut=False
|
||||
)
|
||||
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user