Compare commits
1324 Commits
v0.2.7
...
refactor/w
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d96b9fab20 | ||
|
|
c27ca5a380 | ||
|
|
aa9eb66668 | ||
|
|
e3ab19dd4f | ||
|
|
d255f33f1f | ||
|
|
6419dcd932 | ||
|
|
9dc9b7aee7 | ||
|
|
cf389bb978 | ||
|
|
d66d601e41 | ||
|
|
4af9b02815 | ||
|
|
1f0c88a5f0 | ||
|
|
7747ed7ac1 | ||
|
|
2355536b44 | ||
|
|
b0ddd12cc6 | ||
|
|
a98011fc8a | ||
|
|
41535c34e6 | ||
|
|
feae2f2e1e | ||
|
|
415234d4c8 | ||
|
|
e38a60e107 | ||
|
|
86eb08c73f | ||
|
|
53f1b0e586 | ||
|
|
49cc47a79a | ||
|
|
1817f52edf | ||
|
|
40633d72c3 | ||
|
|
6f10296969 | ||
|
|
89228825cf | ||
|
|
cab4deb2ff | ||
|
|
4048a10858 | ||
|
|
d6ef0f4923 | ||
|
|
75fbe44839 | ||
|
|
06597c567b | ||
|
|
28694fefb0 | ||
|
|
7a0f08148e | ||
|
|
d3058ce379 | ||
|
|
8d88df391d | ||
|
|
7621321d1b | ||
|
|
0e29b0b2a5 | ||
|
|
2fa4d29548 | ||
|
|
7bb181c1c7 | ||
|
|
a9c87b03ff | ||
|
|
720af8d261 | ||
|
|
09d32ed446 | ||
|
|
9a5ce7f7c6 | ||
|
|
531d785629 | ||
|
|
6d80d74f4a | ||
|
|
3d9882643e | ||
|
|
b4e4be1133 | ||
|
|
16926d9db5 | ||
|
|
f369a63c8d | ||
|
|
1861b0fbc9 | ||
|
|
750d4ca841 | ||
|
|
ce4a3daec7 | ||
|
|
c12d06bb07 | ||
|
|
98d8d7b261 | ||
|
|
12a08a487d | ||
|
|
f7fa33c0c4 | ||
|
|
faf8d1a51a | ||
|
|
adb7f873b5 | ||
|
|
b64bcc2c50 | ||
|
|
8baa466b31 | ||
|
|
d9de96cffa | ||
|
|
dd7f9f6cee | ||
|
|
546bfb9627 | ||
|
|
d5d81f0c4f | ||
|
|
9301eaf8df | ||
|
|
a268d0f7f1 | ||
|
|
610ae27cf9 | ||
|
|
6aef8227b1 | ||
|
|
675c7faf32 | ||
|
|
cd34d5f5ce | ||
|
|
1403b38648 | ||
|
|
b6e27da7b0 | ||
|
|
2c14344d3f | ||
|
|
141fd94513 | ||
|
|
a9413f57d1 | ||
|
|
0fc463036e | ||
|
|
ed5f98a746 | ||
|
|
422af69904 | ||
|
|
6cb48664b7 | ||
|
|
f48bb3cbee | ||
|
|
8dee2eae6a | ||
|
|
f63bcd6321 | ||
|
|
0228e6ad64 | ||
|
|
84ccb1e528 | ||
|
|
caef0fe44e | ||
|
|
21eb500680 | ||
|
|
c70f536acc | ||
|
|
5f96a6380e | ||
|
|
2c864f6337 | ||
|
|
32dfee803a | ||
|
|
4d9cfb70f7 | ||
|
|
4b0afe867a | ||
|
|
676c9a226c | ||
|
|
8f31236303 | ||
|
|
f2aedd29bc | ||
|
|
cf8db47389 | ||
|
|
62af9cd241 | ||
|
|
74be09340c | ||
|
|
cedf47b3bc | ||
|
|
0a51ab619d | ||
|
|
c7c1570d40 | ||
|
|
c556995f3a | ||
|
|
dc0a0ebcae | ||
|
|
2c2551e15c | ||
|
|
be10bab763 | ||
|
|
89f2f9a045 | ||
|
|
f4c168d904 | ||
|
|
1191f0f54e | ||
|
|
58710bc800 | ||
|
|
b33f5951d8 | ||
|
|
279353e1ce | ||
|
|
2d120a64b1 | ||
|
|
0f7a7263eb | ||
|
|
767eb5e6f2 | ||
|
|
5c89acced6 | ||
|
|
9fdb952396 | ||
|
|
fb23c34475 | ||
|
|
4619b40d03 | ||
|
|
5f39d9a208 | ||
|
|
f6cf53f81c | ||
|
|
08a455f6b3 | ||
|
|
5960b5add8 | ||
|
|
7ac0eff0b8 | ||
|
|
c818855bab | ||
|
|
fe2c975d61 | ||
|
|
8deb69b595 | ||
|
|
404ce9f9ba | ||
|
|
aac89b172f | ||
|
|
bf9a3503de | ||
|
|
5c836c90c9 | ||
|
|
fc7d9df3cb | ||
|
|
17905196c9 | ||
|
|
b8009074d5 | ||
|
|
09393b2326 | ||
|
|
eaa66ba71a | ||
|
|
c59a97afba | ||
|
|
9480a61229 | ||
|
|
7ffd250b08 | ||
|
|
52bccfaede | ||
|
|
27f6d18a05 | ||
|
|
2a514a9e04 | ||
|
|
9233e74f36 | ||
|
|
46dfd92a9f | ||
|
|
5f33cec8ad | ||
|
|
334502f06b | ||
|
|
b0bb5e883c | ||
|
|
b9cfc47e1e | ||
|
|
4a4391a19c | ||
|
|
7ccc1068ff | ||
|
|
f650406869 | ||
|
|
7193eed9e3 | ||
|
|
ec6b08cde2 | ||
|
|
f93ec8d609 | ||
|
|
fedb02caf7 | ||
|
|
ae770fb131 | ||
|
|
f8ef32c1dd | ||
|
|
c5ae82c3c2 | ||
|
|
2a03f70287 | ||
|
|
124e8d0639 | ||
|
|
6f323f2435 | ||
|
|
881d74d29d | ||
|
|
903b4f2a6e | ||
|
|
7cd76444f1 | ||
|
|
7dc35bb3fb | ||
|
|
b488590537 | ||
|
|
aa56ad15f9 | ||
|
|
cda20ac3f1 | ||
|
|
d6af459ca8 | ||
|
|
2f7fd85ab1 | ||
|
|
398aebd0c5 | ||
|
|
eaa4058c56 | ||
|
|
21b25bfef7 | ||
|
|
a61acbef93 | ||
|
|
a90757745d | ||
|
|
749083bdbe | ||
|
|
b882863907 | ||
|
|
9159d5cbb0 | ||
|
|
7552a5c8fa | ||
|
|
537f6a1812 | ||
|
|
1ea0f308ba | ||
|
|
f37e9b444b | ||
|
|
5304117ae2 | ||
|
|
77c023102e | ||
|
|
ad24119b2d | ||
|
|
ea6fa154e0 | ||
|
|
158507cf8e | ||
|
|
5e0d30dde8 | ||
|
|
363d775270 | ||
|
|
ad4121b0d8 | ||
|
|
71f62bb591 | ||
|
|
46504fda30 | ||
|
|
1cfad37c64 | ||
|
|
129c9cbb3c | ||
|
|
acafceafb0 | ||
|
|
aff94a766a | ||
|
|
42ebba9090 | ||
|
|
1e95cb6604 | ||
|
|
8b3e3c8044 | ||
|
|
671df83bcd | ||
|
|
8bb5a66401 | ||
|
|
4c9f327833 | ||
|
|
866a5552d4 | ||
|
|
93d4607b14 | ||
|
|
9533a9a693 | ||
|
|
6bd528eace | ||
|
|
2b5bece9b6 | ||
|
|
ea0e65f1ec | ||
|
|
cb2a7aa60a | ||
|
|
402c8aef5d | ||
|
|
eb98a69a84 | ||
|
|
152a84aff3 | ||
|
|
a106f4e3cd | ||
|
|
9c20301a52 | ||
|
|
c5c8be89ed | ||
|
|
30aed72b74 | ||
|
|
35c2d9d0d3 | ||
|
|
27275eee43 | ||
|
|
cde02026d3 | ||
|
|
1a826c0026 | ||
|
|
8cab49c2b1 | ||
|
|
7eb21f677f | ||
|
|
6de5d413c4 | ||
|
|
a2df14f658 | ||
|
|
aecb0f6497 | ||
|
|
83b7c6870d | ||
|
|
74157adb12 | ||
|
|
8011610acc | ||
|
|
f1dc507b5c | ||
|
|
f3ac7e084d | ||
|
|
ba3743f9f1 | ||
|
|
20ddc76a4d | ||
|
|
84ca98555d | ||
|
|
7e6d17e4e3 | ||
|
|
7f3c48ce2a | ||
|
|
e5c16a2a24 | ||
|
|
8887600f7d | ||
|
|
df6eb74b28 | ||
|
|
b4b9974064 | ||
|
|
ff65dee754 | ||
|
|
2c2ed0ebf3 | ||
|
|
d60f838fb8 | ||
|
|
817aa78d03 | ||
|
|
4c73887a48 | ||
|
|
94d2d975ee | ||
|
|
d59990d326 | ||
|
|
3227c25b07 | ||
|
|
dc3207b1d3 | ||
|
|
08b5c7bc8a | ||
|
|
688503a1ca | ||
|
|
475e573891 | ||
|
|
b03300c804 | ||
|
|
a5d07ee66d | ||
|
|
10a655772f | ||
|
|
aeeb18581d | ||
|
|
fb1160e833 | ||
|
|
c448cf0660 | ||
|
|
c50969dea4 | ||
|
|
3a1d222c42 | ||
|
|
10a91ec5cb | ||
|
|
b4812cdac1 | ||
|
|
1744b045fb | ||
|
|
5289b3a2cb | ||
|
|
48f3d9b105 | ||
|
|
559b4bef6b | ||
|
|
4a39fd5f46 | ||
|
|
b22c15cccc | ||
|
|
a2f85b3d98 | ||
|
|
7f1cf13b23 | ||
|
|
d4129edcf5 | ||
|
|
ab2a58d68e | ||
|
|
a28b62763e | ||
|
|
86540a81d1 | ||
|
|
dcd874fecd | ||
|
|
bbd85733b8 | ||
|
|
22c5f12657 | ||
|
|
7b5d7696cb | ||
|
|
cb33724673 | ||
|
|
48b56a3d88 | ||
|
|
83d0fb9387 | ||
|
|
bb964c1ed8 | ||
|
|
81d58b001f | ||
|
|
99bc84a9f2 | ||
|
|
37dbe0f95b | ||
|
|
d4a1904b19 | ||
|
|
ecdad19f54 | ||
|
|
fb93c509f4 | ||
|
|
f597139913 | ||
|
|
113ae59f84 | ||
|
|
62c721bdf6 | ||
|
|
4cbb0cee2f | ||
|
|
8c586935a8 | ||
|
|
d5272af76f | ||
|
|
cf8912e929 | ||
|
|
327c1904b1 | ||
|
|
58c13aaeb4 | ||
|
|
377ddd2b9b | ||
|
|
52f7ea7456 | ||
|
|
b02baedd2c | ||
|
|
f3c3b6255e | ||
|
|
b659e2a6e1 | ||
|
|
e15e32cc7b | ||
|
|
04d20dc094 | ||
|
|
b8123fc84c | ||
|
|
5a17b7fd0d | ||
|
|
e3d0602850 | ||
|
|
696b2d2417 | ||
|
|
a5613314b8 | ||
|
|
e87529876c | ||
|
|
7bb3e65fb7 | ||
|
|
5ada7e77fc | ||
|
|
79b7da44e2 | ||
|
|
26a3d8a41b | ||
|
|
2380cd55ef | ||
|
|
a105df33ab | ||
|
|
749cf79581 | ||
|
|
0dd8cc5d43 | ||
|
|
fd90a4c2ad | ||
|
|
b302a94620 | ||
|
|
c96dc53534 | ||
|
|
f883c1469d | ||
|
|
ddfd81259a | ||
|
|
e015455fb8 | ||
|
|
915cb54f21 | ||
|
|
cada860a16 | ||
|
|
e1f8ad871b | ||
|
|
e205aaa6e6 | ||
|
|
62edafcebe | ||
|
|
ccdf7ae81d | ||
|
|
643f69bb90 | ||
|
|
73fbc19747 | ||
|
|
7ba0726473 | ||
|
|
8c6b65db12 | ||
|
|
5ce0bdb0f5 | ||
|
|
a01525e239 | ||
|
|
b59e2b5bcd | ||
|
|
5a2fe738dc | ||
|
|
f04412c455 | ||
|
|
db6fc5d2db | ||
|
|
b6aca0b1e7 | ||
|
|
4fd7395464 | ||
|
|
78ba313262 | ||
|
|
d35bc3a2cf | ||
|
|
d5c8d16e64 | ||
|
|
09496bd7b9 | ||
|
|
171f25a350 | ||
|
|
c7230659e3 | ||
|
|
502d87e88d | ||
|
|
1faa258e23 | ||
|
|
bef6a50deb | ||
|
|
cc12ec3fa8 | ||
|
|
466864afe3 | ||
|
|
643a3fbe09 | ||
|
|
e0d7a5a91f | ||
|
|
5ac2d5602e | ||
|
|
f4c3974956 | ||
|
|
71e5b6586a | ||
|
|
bfb723a468 | ||
|
|
61f2e44bd5 | ||
|
|
ed765b7c26 | ||
|
|
3018d186f7 | ||
|
|
2e1470cb52 | ||
|
|
737858731b | ||
|
|
d072eb1af7 | ||
|
|
2716a55c7f | ||
|
|
daaee63bd5 | ||
|
|
e3c643b659 | ||
|
|
017efdc320 | ||
|
|
29aef4527c | ||
|
|
d9cb2b511b | ||
|
|
18be1a9f89 | ||
|
|
49e0801d15 | ||
|
|
dde7ea9039 | ||
|
|
3e48d620b2 | ||
|
|
5262aedab9 | ||
|
|
441b21774d | ||
|
|
d6dd038167 | ||
|
|
47c242e513 | ||
|
|
811193dd75 | ||
|
|
797780824c | ||
|
|
75e95bab01 | ||
|
|
e7a400bb96 | ||
|
|
28ca4d1734 | ||
|
|
5e6490213d | ||
|
|
3b359df02f | ||
|
|
fcf3071cb0 | ||
|
|
1294aabbcc | ||
|
|
3c2a78a449 | ||
|
|
4f0e5d0866 | ||
|
|
7a84ee33c6 | ||
|
|
e3265e4ba3 | ||
|
|
3e7a004599 | ||
|
|
fa1e5ee43c | ||
|
|
c72a6fd724 | ||
|
|
0965008210 | ||
|
|
bcadd2a6f1 | ||
|
|
e4f306dabb | ||
|
|
b5ec5c2cea | ||
|
|
e539b3eeb7 | ||
|
|
7f8765b815 | ||
|
|
72b39c6fa3 | ||
|
|
9032f50a19 | ||
|
|
aa683efaa0 | ||
|
|
2d9986f902 | ||
|
|
06075ffef5 | ||
|
|
a7336b0829 | ||
|
|
0d16e168e7 | ||
|
|
a882e5e5c4 | ||
|
|
c614bb5be7 | ||
|
|
1ff0f3ebfd | ||
|
|
bafcb5c545 | ||
|
|
f8d27fada6 | ||
|
|
90365cd026 | ||
|
|
d96c7b88f0 | ||
|
|
99559621c5 | ||
|
|
926f65a1ff | ||
|
|
b20971dc95 | ||
|
|
1ff0274027 | ||
|
|
8495aa5dde | ||
|
|
d8ef7a8e02 | ||
|
|
7a4a02b2bb | ||
|
|
8f623a66c8 | ||
|
|
77ed9faea1 | ||
|
|
1ff3748935 | ||
|
|
f023c43f80 | ||
|
|
60124e3232 | ||
|
|
70d4e79de1 | ||
|
|
59b5a1bcf2 | ||
|
|
62f345b3de | ||
|
|
a3f0415cd3 | ||
|
|
2450fe3afe | ||
|
|
52e726eabc | ||
|
|
7ca80b5d01 | ||
|
|
9470dd2f1e | ||
|
|
10f1089198 | ||
|
|
095f4e3001 | ||
|
|
ef8c7093b5 | ||
|
|
05ea372776 | ||
|
|
2b067ce08a | ||
|
|
b63cff2993 | ||
|
|
5bb9ce9018 | ||
|
|
aa581a9083 | ||
|
|
ac51ccaf1f | ||
|
|
dca3173ed9 | ||
|
|
5eaedaad77 | ||
|
|
bd955569b3 | ||
|
|
7a2a941ac4 | ||
|
|
19fa8314e4 | ||
|
|
cba24e58db | ||
|
|
62355186ef | ||
|
|
82faedc972 | ||
|
|
11ea486f82 | ||
|
|
efdee32f85 | ||
|
|
988d101e93 | ||
|
|
418f9f4dba | ||
|
|
520ee7c132 | ||
|
|
72be9f75f9 | ||
|
|
2b52b32b96 | ||
|
|
a96f20ee05 | ||
|
|
b8acc0a32f | ||
|
|
e1cf3bb3d2 | ||
|
|
6f66c9727f | ||
|
|
3beca641e1 | ||
|
|
b8507a1df6 | ||
|
|
0f28d54c43 | ||
|
|
0afc38e7ef | ||
|
|
07fd85c342 | ||
|
|
4c2a1e6d1d | ||
|
|
7cfb6ace22 | ||
|
|
91cc20d589 | ||
|
|
f01ca51896 | ||
|
|
f4a63f7d55 | ||
|
|
0019f3acfd | ||
|
|
3fe90a5e13 | ||
|
|
bc14c94407 | ||
|
|
a21dad70ed | ||
|
|
807a4e715d | ||
|
|
58d18b476c | ||
|
|
5e5927a0b9 | ||
|
|
7869121382 | ||
|
|
7c0fb624d9 | ||
|
|
af83980f99 | ||
|
|
cf0d11208c | ||
|
|
87d1630230 | ||
|
|
50392384e7 | ||
|
|
9a926a8398 | ||
|
|
e5e6699168 | ||
|
|
068e2bfb7e | ||
|
|
4ce6fede67 | ||
|
|
8497c955f9 | ||
|
|
72fe3962cf | ||
|
|
c253968aa8 | ||
|
|
d517bceda2 | ||
|
|
412183c359 | ||
|
|
90e8e90528 | ||
|
|
fd05c000f6 | ||
|
|
627d6a0381 | ||
|
|
807dee8460 | ||
|
|
ac7d39524e | ||
|
|
cd018814fe | ||
|
|
e0b7e95af6 | ||
|
|
3a62d50048 | ||
|
|
0e60da6d8a | ||
|
|
39e94eb3ea | ||
|
|
3e0f59adc6 | ||
|
|
660cd2fadb | ||
|
|
6f1bb43eab | ||
|
|
61b5627505 | ||
|
|
af6392fb09 | ||
|
|
84b1a95313 | ||
|
|
8b21dab255 | ||
|
|
fc5ce63e44 | ||
|
|
15a863b41a | ||
|
|
5226c5b79d | ||
|
|
27e9f9968d | ||
|
|
d38612a10d | ||
|
|
32c71dcd89 | ||
|
|
428e7ebaa5 | ||
|
|
57833689d9 | ||
|
|
384a67482c | ||
|
|
7842435321 | ||
|
|
33c4c5d31b | ||
|
|
ca4f7aa65d | ||
|
|
b875626f18 | ||
|
|
130684cac0 | ||
|
|
5adff38bda | ||
|
|
62e0b2730b | ||
|
|
55b2e05ba8 | ||
|
|
562ca6c1f1 | ||
|
|
e298b38de9 | ||
|
|
a7b8ba0c66 | ||
|
|
460c86cd94 | ||
|
|
33a1c178ff | ||
|
|
c81612e6d3 | ||
|
|
9f9ac69f97 | ||
|
|
0516822d42 | ||
|
|
b598171a3d | ||
|
|
a4ea7f0385 | ||
|
|
32ae60fc65 | ||
|
|
6b272c5b44 | ||
|
|
2782d0661f | ||
|
|
ea2f5e61c9 | ||
|
|
5975d70bf9 | ||
|
|
e0546e01ef | ||
|
|
70aab94fc3 | ||
|
|
0f50537d7d | ||
|
|
b7c1ce261b | ||
|
|
edac6a164e | ||
|
|
1503b242ea | ||
|
|
3ff44f0108 | ||
|
|
18fd48505d | ||
|
|
807ddce5cd | ||
|
|
62fb6c79a0 | ||
|
|
cc373b2864 | ||
|
|
f2d7479229 | ||
|
|
ae1909b7e9 | ||
|
|
8e397b83b6 | ||
|
|
b0aaa12340 | ||
|
|
5eb65e7ad8 | ||
|
|
cb5610e8b1 | ||
|
|
6bb01119d0 | ||
|
|
c16e832081 | ||
|
|
e3d50c5c55 | ||
|
|
e64aadce95 | ||
|
|
bad6087c25 | ||
|
|
b04c05f4a4 | ||
|
|
5e372627f7 | ||
|
|
29611738ce | ||
|
|
de846c05ab | ||
|
|
6475387af8 | ||
|
|
b330bdba29 | ||
|
|
bed279c604 | ||
|
|
9eaf779e67 | ||
|
|
1fbccd98a7 | ||
|
|
931b800bb6 | ||
|
|
d76bb36b9f | ||
|
|
3c93409f7f | ||
|
|
9451a08e7f | ||
|
|
bc49bd2a43 | ||
|
|
4bfc9ca991 | ||
|
|
1ba60401af | ||
|
|
236e8973ac | ||
|
|
ca6cc8ae63 | ||
|
|
dd2cc89c62 | ||
|
|
a87bba93c2 | ||
|
|
a153fdb7cb | ||
|
|
4eed393db5 | ||
|
|
dc8e432719 | ||
|
|
16a0099e27 | ||
|
|
c4cf639bbc | ||
|
|
f91431a70d | ||
|
|
0102ad3a30 | ||
|
|
ebe298b71d | ||
|
|
a486ca7857 | ||
|
|
dd40e5df5f | ||
|
|
8e1ec1bae6 | ||
|
|
1f9c4919be | ||
|
|
065182ad5c | ||
|
|
90ec8db0d8 | ||
|
|
78baf1c60a | ||
|
|
072c94cccb | ||
|
|
a66030d1b3 | ||
|
|
a90ceaf5a2 | ||
|
|
725f2f5146 | ||
|
|
a2c3357e80 | ||
|
|
acbc954e6f | ||
|
|
77032583ab | ||
|
|
ef533d27ac | ||
|
|
ca1a2c7b9e | ||
|
|
145fa398dd | ||
|
|
274430b2c9 | ||
|
|
e9972834fe | ||
|
|
1ecc04fee7 | ||
|
|
78cd1f69a3 | ||
|
|
aabd9a1b57 | ||
|
|
b9439b337a | ||
|
|
eb9f4f39f1 | ||
|
|
baa4b56426 | ||
|
|
49bcc6131b | ||
|
|
0d3f6f1e14 | ||
|
|
84536925c6 | ||
|
|
b22a5a9f12 | ||
|
|
b8825a83dd | ||
|
|
08b4d5c1cf | ||
|
|
a5dfc472d3 | ||
|
|
48d29bcc63 | ||
|
|
856c6f6d78 | ||
|
|
bfc47ad738 | ||
|
|
841b6abb33 | ||
|
|
8a1114a1a7 | ||
|
|
be8c481d6d | ||
|
|
5d439346a1 | ||
|
|
ed753caaf7 | ||
|
|
9a931389ea | ||
|
|
b9d469b6e3 | ||
|
|
e817cfd292 | ||
|
|
af86cb3556 | ||
|
|
e48b146e60 | ||
|
|
07b66a9801 | ||
|
|
c3ee3c4af9 | ||
|
|
cd8229f370 | ||
|
|
9a4a614fc8 | ||
|
|
0b5a030e46 | ||
|
|
675d0fc5ef | ||
|
|
6291c28f0a | ||
|
|
30b512e554 | ||
|
|
33c73c6c6f | ||
|
|
072d118935 | ||
|
|
2e7ebf174b | ||
|
|
3ece83d419 | ||
|
|
9c1c232b2e | ||
|
|
bfc98efc9d | ||
|
|
cfbf83f71e | ||
|
|
a43e8fa594 | ||
|
|
6c8d0d9d64 | ||
|
|
bd2a3bd7ef | ||
|
|
1f72b8aa70 | ||
|
|
9bb32888a2 | ||
|
|
caee5d214e | ||
|
|
38f3455bab | ||
|
|
d60cb423a4 | ||
|
|
b20a65ce29 | ||
|
|
99862db7a0 | ||
|
|
00a8099857 | ||
|
|
117e29fbe3 | ||
|
|
32740e8159 | ||
|
|
bc5ea2d421 | ||
|
|
d34bf4bc89 | ||
|
|
c4ff1a325b | ||
|
|
d1f0258065 | ||
|
|
5db59bc9cf | ||
|
|
a711635694 | ||
|
|
15b3ce3dd5 | ||
|
|
9cc19047b4 | ||
|
|
2e8e63878e | ||
|
|
38955d7d45 | ||
|
|
b6167d4e94 | ||
|
|
7890970a39 | ||
|
|
203732de1d | ||
|
|
4961e7df79 | ||
|
|
fa4be10e51 | ||
|
|
1b52850526 | ||
|
|
1732fc7af5 | ||
|
|
a52e2137b7 | ||
|
|
377f79773d | ||
|
|
cae87de6ef | ||
|
|
63235de42b | ||
|
|
106a32bc3a | ||
|
|
dcb7b496d3 | ||
|
|
2f0bb793d8 | ||
|
|
010eff17cf | ||
|
|
0b47194f12 | ||
|
|
9ff3a3d5f7 | ||
|
|
abbd92b74c | ||
|
|
960ee9f2df | ||
|
|
1c133d3d6c | ||
|
|
d270d25a99 | ||
|
|
8abd59b26e | ||
|
|
bd48b4fdbe | ||
|
|
9535545947 | ||
|
|
aad6955709 | ||
|
|
18703919a8 | ||
|
|
9f2cd6afae | ||
|
|
d1beb9e5d5 | ||
|
|
2c7aaebdd5 | ||
|
|
be38c9e385 | ||
|
|
1aec7115a5 | ||
|
|
9facb513b2 | ||
|
|
9bce14be4e | ||
|
|
59f5c7a8bb | ||
|
|
12f3a3ed77 | ||
|
|
8b9eb81d36 | ||
|
|
4fb3d6992c | ||
|
|
370a668ead | ||
|
|
daaad51357 | ||
|
|
6eca5f6cdf | ||
|
|
f61f86f8fe | ||
|
|
57eb5aa967 | ||
|
|
1305a08c86 | ||
|
|
cf519738f4 | ||
|
|
cdebe014cf | ||
|
|
853ce6f4e1 | ||
|
|
9cbe9d5edc | ||
|
|
767f9ab17c | ||
|
|
7b5b2ab31a | ||
|
|
924d10ac5b | ||
|
|
0470a71d03 | ||
|
|
378b110d91 | ||
|
|
5f7db778b5 | ||
|
|
0d15457299 | ||
|
|
ad4ddea977 | ||
|
|
75bb96d4e7 | ||
|
|
68fdf5d76f | ||
|
|
258c19f9e0 | ||
|
|
386ed2b914 | ||
|
|
264183cec2 | ||
|
|
9561578a2a | ||
|
|
7ce29019f7 | ||
|
|
99ff07ccac | ||
|
|
e77a1a92fd | ||
|
|
d3cd66fc6e | ||
|
|
b95a627424 | ||
|
|
c9ca5df05c | ||
|
|
70c3c7dd74 | ||
|
|
b482822629 | ||
|
|
8f609ba29c | ||
|
|
a1ef5146d7 | ||
|
|
8b997b422a | ||
|
|
6d6338eb06 | ||
|
|
b5c5863b39 | ||
|
|
ab45b7abac | ||
|
|
2dfc3b25d8 | ||
|
|
3ea42ac27f | ||
|
|
fff5e0e8b8 | ||
|
|
fe29141437 | ||
|
|
17d3c81c02 | ||
|
|
ef626951bc | ||
|
|
4533644e13 | ||
|
|
ca255304d9 | ||
|
|
b40f4829cb | ||
|
|
52ae914e17 | ||
|
|
baf02e4faa | ||
|
|
87c2419186 | ||
|
|
2ad25c48d2 | ||
|
|
75e8caf441 | ||
|
|
4d6038c3cc | ||
|
|
d4450658a8 | ||
|
|
02660c7c97 | ||
|
|
3ceb2efeaf | ||
|
|
e134b96333 | ||
|
|
3ea57d1cb0 | ||
|
|
4a71484151 | ||
|
|
db8b3416a6 | ||
|
|
4df41966fe | ||
|
|
2d6cde157e | ||
|
|
abc27c8372 | ||
|
|
dbe387f666 | ||
|
|
5e70d436a8 | ||
|
|
b7198f1abd | ||
|
|
5c87a2beeb | ||
|
|
3419bb137a | ||
|
|
a00684c67d | ||
|
|
6e7c641fd4 | ||
|
|
876c39b1b0 | ||
|
|
0c677701c0 | ||
|
|
4974f9aa98 | ||
|
|
c90b58bbcd | ||
|
|
d6a243f1be | ||
|
|
418114ef72 | ||
|
|
ceed61167f | ||
|
|
83774d7443 | ||
|
|
052c7c19b3 | ||
|
|
d42db0ca33 | ||
|
|
e15af5a2ba | ||
|
|
8b44b2cd61 | ||
|
|
9d91453200 | ||
|
|
ea8db7cd90 | ||
|
|
d60f16df1b | ||
|
|
3cca35a74f | ||
|
|
8dd24533bf | ||
|
|
ed90405439 | ||
|
|
533000030f | ||
|
|
a58ac385b1 | ||
|
|
91b7f2a980 | ||
|
|
891cfc2704 | ||
|
|
f7e89af9d2 | ||
|
|
afbd8c9b4f | ||
|
|
09b3b01d37 | ||
|
|
e3dcbed5f9 | ||
|
|
c7b51e7ad8 | ||
|
|
e9ad13504a | ||
|
|
c0cd2373c0 | ||
|
|
6e757ae9e2 | ||
|
|
64a73c41d6 | ||
|
|
dae7431075 | ||
|
|
643bbbcf5c | ||
|
|
6702e86536 | ||
|
|
13e35ed122 | ||
|
|
ab2bdfa088 | ||
|
|
8285250096 | ||
|
|
e59a215078 | ||
|
|
c89eccf8fe | ||
|
|
5703fc0cb4 | ||
|
|
7acb7045f0 | ||
|
|
3aed5c447a | ||
|
|
13352178ad | ||
|
|
f9f302dd2a | ||
|
|
8f216db353 | ||
|
|
9f6026492d | ||
|
|
b699b746a5 | ||
|
|
6095170169 | ||
|
|
173697e86a | ||
|
|
5c11da6a2e | ||
|
|
96214c433f | ||
|
|
167c915631 | ||
|
|
f485398768 | ||
|
|
289b1989e5 | ||
|
|
8224848ce1 | ||
|
|
c43d258455 | ||
|
|
c3e5c8b8bb | ||
|
|
930cadcaa8 | ||
|
|
57b6b34567 | ||
|
|
f878846364 | ||
|
|
7dce63dc0b | ||
|
|
03bc8ee7f5 | ||
|
|
4aefb01b0b | ||
|
|
4e9b5736b1 | ||
|
|
46fa99a8b8 | ||
|
|
17ea92357d | ||
|
|
bd70a8b812 | ||
|
|
ad5dc3c138 | ||
|
|
e37b1b01ca | ||
|
|
e659ca9fa2 | ||
|
|
758be0087f | ||
|
|
200c13b59f | ||
|
|
32f6886000 | ||
|
|
7fbf3e8873 | ||
|
|
3026702000 | ||
|
|
8677db114b | ||
|
|
2597a1f532 | ||
|
|
4298cd7d06 | ||
|
|
8197f9db35 | ||
|
|
3da6331515 | ||
|
|
539999131c | ||
|
|
d0ca5c8b27 | ||
|
|
ee6b8ffa62 | ||
|
|
14838dc064 | ||
|
|
e017870f44 | ||
|
|
9730c5ce0f | ||
|
|
bca43fcc75 | ||
|
|
f30260939a | ||
|
|
8ba0a74473 | ||
|
|
4f69224cfd | ||
|
|
6f7fee18c9 | ||
|
|
7fd00009a2 | ||
|
|
4534b65d6a | ||
|
|
cc58c7333c | ||
|
|
c936277507 | ||
|
|
701df40270 | ||
|
|
b724dbe53a | ||
|
|
ac7c891ded | ||
|
|
a5bce221bd | ||
|
|
3ed6f49bb0 | ||
|
|
a416a6b2bd | ||
|
|
35be03803f | ||
|
|
6427018ffb | ||
|
|
06b823ff96 | ||
|
|
0fdb489227 | ||
|
|
f6394a791e | ||
|
|
4bfd4944d0 | ||
|
|
7faf291ec3 | ||
|
|
3d291e3c23 | ||
|
|
b35bedc730 | ||
|
|
4d39cdf464 | ||
|
|
a874cc70a4 | ||
|
|
2319432182 | ||
|
|
7556468c6e | ||
|
|
91d38c0648 | ||
|
|
df3d58d388 | ||
|
|
80856e3c92 | ||
|
|
8c6f395818 | ||
|
|
2f4f7219e3 | ||
|
|
4c5183eddc | ||
|
|
dfc0ee9424 | ||
|
|
8dbb067b83 | ||
|
|
1df3fc416a | ||
|
|
6223b80cc4 | ||
|
|
68489f1b28 | ||
|
|
477853b04e | ||
|
|
863be50aaf | ||
|
|
d72d57f966 | ||
|
|
5b940e5f1a | ||
|
|
9ae1d2f0d9 | ||
|
|
318f1be107 | ||
|
|
4cab6317de | ||
|
|
81bfc9af36 | ||
|
|
189013f0f8 | ||
|
|
6f5bcd18a4 | ||
|
|
c7ef97c7a6 | ||
|
|
4d4a780ab7 | ||
|
|
9d2f3aa8f9 | ||
|
|
f2c9902a07 | ||
|
|
2525f8795c | ||
|
|
b7a03a844f | ||
|
|
c13c3846d1 | ||
|
|
30b5db1e98 | ||
|
|
f92eb9f45a | ||
|
|
a136d44e27 | ||
|
|
65b2f9e6e1 | ||
|
|
5275a274c3 | ||
|
|
4f09c4fbb3 | ||
|
|
7a3220aff5 | ||
|
|
14a32778f7 | ||
|
|
2a12cb04bf | ||
|
|
1e986c641f | ||
|
|
38c6c7f053 | ||
|
|
7c0743eb8f | ||
|
|
e981f066a3 | ||
|
|
db14d40fb3 | ||
|
|
e8d575fd0b | ||
|
|
a7285e35ad | ||
|
|
c4461c4917 | ||
|
|
2df615eca0 | ||
|
|
504e5ba61e | ||
|
|
0bae290e0c | ||
|
|
294ee49d59 | ||
|
|
26c36f70e6 | ||
|
|
c4b83b1f9c | ||
|
|
14413fd413 | ||
|
|
caab58dd2f | ||
|
|
0e899bea05 | ||
|
|
1794f8f209 | ||
|
|
85daf576e9 | ||
|
|
56fd5680cf | ||
|
|
0380c13a3b | ||
|
|
9ddc523f91 | ||
|
|
491ef27b8a | ||
|
|
edd115582f | ||
|
|
45eef12842 | ||
|
|
49364802c2 | ||
|
|
8873078006 | ||
|
|
2b9fd33bc8 | ||
|
|
e86d679ae5 | ||
|
|
def7367e33 | ||
|
|
54cff5861a | ||
|
|
dc2a73155b | ||
|
|
1856c55c04 | ||
|
|
522eb569f1 | ||
|
|
9df41456f6 | ||
|
|
04c54081c8 | ||
|
|
1c49e3c167 | ||
|
|
fb6ce839d2 | ||
|
|
c7275dccac | ||
|
|
d62b484d71 | ||
|
|
8ff1c6bd08 | ||
|
|
3dcf901043 | ||
|
|
d6dfc2cb12 | ||
|
|
8a3032ce4a | ||
|
|
391c60c812 | ||
|
|
b739b032d9 | ||
|
|
3dc863cabf | ||
|
|
611b14dfea | ||
|
|
de6e2f54d2 | ||
|
|
89d188fbf3 | ||
|
|
6bba574ca6 | ||
|
|
9cbffd6408 | ||
|
|
4d2ad5757c | ||
|
|
cd0ca9cae4 | ||
|
|
3369b702e4 | ||
|
|
cbec2c1356 | ||
|
|
5987eee0a8 | ||
|
|
6348304b7d | ||
|
|
59f8010519 | ||
|
|
9308c6efae | ||
|
|
2f78b7cf5e | ||
|
|
f86448f4bf | ||
|
|
48e2e613bb | ||
|
|
1060074740 | ||
|
|
95b7df7e38 | ||
|
|
fd1634eec4 | ||
|
|
efeead41b2 | ||
|
|
a3428c2435 | ||
|
|
31b8a3764e | ||
|
|
2ff81ba101 | ||
|
|
93deb286a3 | ||
|
|
7bd97bf6d3 | ||
|
|
2d1a1b4a1f | ||
|
|
503c890d93 | ||
|
|
1f73501786 | ||
|
|
eef13cb717 | ||
|
|
c70ac1339e | ||
|
|
24c13d408e | ||
|
|
338d7f1065 | ||
|
|
27672cfaa0 | ||
|
|
4dbb2bf2e2 | ||
|
|
37bc4beab4 | ||
|
|
6056952936 | ||
|
|
31085ed678 | ||
|
|
dce7206c44 | ||
|
|
c17a2dad2d | ||
|
|
e8ae46b286 | ||
|
|
78316de411 | ||
|
|
c205e7d20e | ||
|
|
81f3b50200 | ||
|
|
e3795fe1ed | ||
|
|
72a2f2a7e8 | ||
|
|
0f092e08f4 | ||
|
|
8e7603bcc4 | ||
|
|
035cc17264 | ||
|
|
a079358028 | ||
|
|
cf26c9f39c | ||
|
|
fa29a39920 | ||
|
|
2146c555d2 | ||
|
|
240f1d431b | ||
|
|
9f947a3395 | ||
|
|
bf5c4628c3 | ||
|
|
911d5e0b34 | ||
|
|
bd31aa5abf | ||
|
|
0775fad5f0 | ||
|
|
726148d7ee | ||
|
|
0f1b1d7d10 | ||
|
|
fabc8936ab | ||
|
|
11aa2e1f9e | ||
|
|
ca654cca74 | ||
|
|
bd1f649bd0 | ||
|
|
06de54ebfd | ||
|
|
ea00747c66 | ||
|
|
3db031891e | ||
|
|
fb6ca3909a | ||
|
|
929afb1770 | ||
|
|
6235584b2e | ||
|
|
0b1ea33b41 | ||
|
|
3929f811b8 | ||
|
|
7c6e48b04e | ||
|
|
b1b53f6b1d | ||
|
|
551a2b59a5 | ||
|
|
9a765ac71e | ||
|
|
83e26732de | ||
|
|
52fdfc7744 | ||
|
|
4e544325a0 | ||
|
|
99a2f396fd | ||
|
|
0157c9d262 | ||
|
|
5ddacab162 | ||
|
|
a51e34852c | ||
|
|
fcc81ac025 | ||
|
|
36f670b2e9 | ||
|
|
cbcbc8822c | ||
|
|
69c001bf84 | ||
|
|
aa2d1e7a35 | ||
|
|
39b2f3ba0e | ||
|
|
43064ab71b | ||
|
|
4144f0b9b5 | ||
|
|
08f0be17ce | ||
|
|
2915e464bf | ||
|
|
152559ae46 | ||
|
|
1f531f1ace | ||
|
|
7ec947189c | ||
|
|
b4615bacdc | ||
|
|
e849fed5c1 | ||
|
|
0f5cae4590 | ||
|
|
1c3029f360 | ||
|
|
e2411e0bdd | ||
|
|
7af88b19cf | ||
|
|
c3f8dbd4bc | ||
|
|
c1e48fde86 | ||
|
|
f644c84fbb | ||
|
|
d0afce27c4 | ||
|
|
b84aba71e7 | ||
|
|
2e481df465 | ||
|
|
a322ec4fd5 | ||
|
|
bdbf9c0609 | ||
|
|
ef7d59e442 | ||
|
|
27b782e12a | ||
|
|
37a22fbfa9 | ||
|
|
d798d101f7 | ||
|
|
825f225f63 | ||
|
|
4d5e2958dc | ||
|
|
6105d46198 | ||
|
|
7aec157859 | ||
|
|
13abb03d87 | ||
|
|
e8947ad0bb | ||
|
|
7056865726 | ||
|
|
9d8c26b999 | ||
|
|
c2c832f8c9 | ||
|
|
6bc4f04293 | ||
|
|
9d150ab353 | ||
|
|
f045b59b2d | ||
|
|
0bb8278a39 | ||
|
|
e43f812c14 | ||
|
|
d584b47280 | ||
|
|
3e995cd971 | ||
|
|
b018e35ada | ||
|
|
4bc030c1ef | ||
|
|
86a0aa1f9f | ||
|
|
d523e4f3c6 | ||
|
|
84c23e7c4e | ||
|
|
186d097e00 | ||
|
|
c5cfe557da | ||
|
|
f786a66a3c | ||
|
|
ebd51928d7 | ||
|
|
2258b5c43c | ||
|
|
2e50e30071 | ||
|
|
8c804a1011 | ||
|
|
1a4c2d7cd0 | ||
|
|
c2fc4ab4ff | ||
|
|
83fcabadae | ||
|
|
d12ad213e0 | ||
|
|
33d522b387 | ||
|
|
5997458aaf | ||
|
|
68f9471caf | ||
|
|
ecbb61db27 | ||
|
|
b42815ee7a | ||
|
|
49d7398e14 | ||
|
|
91589c1497 | ||
|
|
a07727c047 | ||
|
|
25bc506f74 | ||
|
|
18ca83d763 | ||
|
|
4bbc561625 | ||
|
|
d77220a603 | ||
|
|
f52b681133 | ||
|
|
f6efa0d711 | ||
|
|
0fccc91dac | ||
|
|
8d8c6c695a | ||
|
|
57342259ce | ||
|
|
be46ed8865 | ||
|
|
04b2205769 | ||
|
|
76ba357982 | ||
|
|
2c318f6e60 | ||
|
|
3f04153f22 | ||
|
|
3df8af3852 | ||
|
|
8b9ab8a841 | ||
|
|
750dbcc7c3 | ||
|
|
5d6007aaff | ||
|
|
291767031c | ||
|
|
22ffe6ef1d | ||
|
|
02df1a70f3 | ||
|
|
8c5fa9c441 | ||
|
|
e6c558c2a0 | ||
|
|
b52e4d756c | ||
|
|
1089a52ca0 | ||
|
|
c7fb9ab8e3 | ||
|
|
83017d0c80 | ||
|
|
e24217a6ba | ||
|
|
a0f2f738df | ||
|
|
9d9250954b | ||
|
|
f042f44501 | ||
|
|
56c98648f9 | ||
|
|
956efe6a09 | ||
|
|
bb64ad23dd | ||
|
|
a97326df74 | ||
|
|
1503f8781a | ||
|
|
163ddbb6ed | ||
|
|
7bbfd33ca0 | ||
|
|
0ea47ce890 | ||
|
|
38f891235c | ||
|
|
4d83c074d9 | ||
|
|
0e9672df80 | ||
|
|
abc7460539 | ||
|
|
4bb2ccfba7 | ||
|
|
969d428320 | ||
|
|
ff64522c50 | ||
|
|
65dc1a8f48 | ||
|
|
859b7f3c7f | ||
|
|
da3f875555 | ||
|
|
44d63a44da | ||
|
|
7e5e1609b0 | ||
|
|
d94adcb19c | ||
|
|
83894df260 | ||
|
|
7b99a32a1e | ||
|
|
e8c3744f5e | ||
|
|
06d1f54030 | ||
|
|
599ccb6bde | ||
|
|
db9050c302 | ||
|
|
71b3b665b5 | ||
|
|
3b8a806661 | ||
|
|
774719fb50 | ||
|
|
a3ccd41288 | ||
|
|
8ddacb7bc9 | ||
|
|
e74a74c3fb | ||
|
|
262a9ddc48 | ||
|
|
70f84b65ec | ||
|
|
ec5cb42f67 | ||
|
|
0802481fd2 | ||
|
|
548ba0ae36 | ||
|
|
fc2360d40d | ||
|
|
ab67bda5a1 | ||
|
|
376d5ca7d0 | ||
|
|
55438136b0 | ||
|
|
82db3517d7 | ||
|
|
130490c022 | ||
|
|
ede8a11584 | ||
|
|
ba65b06582 | ||
|
|
f4f04036f3 | ||
|
|
43130dcbc8 | ||
|
|
ff6459e439 | ||
|
|
1893de4c75 | ||
|
|
dfcc85a466 | ||
|
|
dacfb360f6 | ||
|
|
8a0d83b340 | ||
|
|
be2ce854a1 | ||
|
|
e492dcd968 | ||
|
|
55bfee856d | ||
|
|
f951075551 | ||
|
|
964086a08a | ||
|
|
67501025b3 | ||
|
|
e1cc5c841a | ||
|
|
6b839bd5a8 | ||
|
|
5df339b56d | ||
|
|
56adca9f22 | ||
|
|
1e63dd8d2d | ||
|
|
fab9272124 | ||
|
|
2f66fd9aae | ||
|
|
5616583fa1 | ||
|
|
3f0e991112 | ||
|
|
477d404727 | ||
|
|
8e6288bca8 | ||
|
|
72bba0662f | ||
|
|
090f46006a | ||
|
|
abe0c7e7d1 | ||
|
|
6516f56ada | ||
|
|
ea391dc44e | ||
|
|
e21f713de0 | ||
|
|
3498e2e884 | ||
|
|
ea8edc5914 | ||
|
|
b62c40dba3 | ||
|
|
0832337839 | ||
|
|
b82f4491fb | ||
|
|
bdf0c256b3 | ||
|
|
3d91a9e926 | ||
|
|
779dbdea26 | ||
|
|
e8e342c206 | ||
|
|
78829d36cc | ||
|
|
f7c2e82dc0 | ||
|
|
88598fb9fb | ||
|
|
19d149c129 | ||
|
|
f09de3a11c | ||
|
|
e13acdc8a9 | ||
|
|
b8e85bed61 | ||
|
|
396493ad2b | ||
|
|
f32d92b9d0 | ||
|
|
6d79db8ba3 | ||
|
|
f9fb480cc3 | ||
|
|
1efa8798bf | ||
|
|
c244e9834f | ||
|
|
b1a7b58f97 | ||
|
|
e81f39b50e | ||
|
|
a0c4515a81 | ||
|
|
4bf418a3d6 | ||
|
|
f033607c8b | ||
|
|
860cd31799 | ||
|
|
d674b48f7d | ||
|
|
07c899f0a9 | ||
|
|
382e4c5377 | ||
|
|
fe6518d052 | ||
|
|
dc513dfbeb | ||
|
|
3d9bc7a986 | ||
|
|
3d79b72d70 | ||
|
|
6eb9b772e7 | ||
|
|
90c8ff35d1 | ||
|
|
ad87fd96db | ||
|
|
c7cc0cd922 | ||
|
|
81a232177e | ||
|
|
73aee97be5 | ||
|
|
aab54ca1a8 | ||
|
|
c354618e20 | ||
|
|
5141a91041 | ||
|
|
668539e737 | ||
|
|
967139cea4 | ||
|
|
6d8b1aede4 | ||
|
|
744ba31ba6 | ||
|
|
db8257b67a | ||
|
|
85770dc037 | ||
|
|
69f976a79a | ||
|
|
fd7e77eff8 | ||
|
|
05c2a093c0 | ||
|
|
01a1e8eab1 | ||
|
|
b71bc1f875 | ||
|
|
6a0ee22d81 | ||
|
|
cbc8714414 | ||
|
|
f6d929ab7a | ||
|
|
a7a2dabc5a | ||
|
|
0694075447 | ||
|
|
d66b9dd8cb | ||
|
|
7267198a8c | ||
|
|
7b8f101824 | ||
|
|
a4c942a21f | ||
|
|
2a66775e45 | ||
|
|
f0c3d5f308 | ||
|
|
d660521c5c | ||
|
|
c612dfbc1f | ||
|
|
fc58ac0408 | ||
|
|
4f5ee24bc5 | ||
|
|
5b431400be | ||
|
|
509d1a2e24 | ||
|
|
153e68e055 | ||
|
|
77b9a6a94e | ||
|
|
d68bbab419 | ||
|
|
6d53d9178c | ||
|
|
06fe3f2f01 | ||
|
|
e2b6c713e7 | ||
|
|
0b3b241436 | ||
|
|
4c18f9e858 | ||
|
|
8fec54c085 | ||
|
|
d8e37a4d2b | ||
|
|
1da2c4fa37 |
164
.github/workflows/release-notify-wechat.yml
vendored
Normal file
164
.github/workflows/release-notify-wechat.yml
vendored
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
name: Release Notify Workflow
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
types: [closed]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
notify:
|
||||||
|
if: >
|
||||||
|
github.event.pull_request.merged == true &&
|
||||||
|
startsWith(github.event.pull_request.base.ref, 'release')
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
# 防止 GitHub HEAD 未同步
|
||||||
|
- run: sleep 3
|
||||||
|
|
||||||
|
# 1️⃣ 获取分支 HEAD
|
||||||
|
- name: Get HEAD
|
||||||
|
id: head
|
||||||
|
run: |
|
||||||
|
HEAD_SHA=$(curl -s \
|
||||||
|
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
|
||||||
|
https://api.github.com/repos/${{ github.repository }}/git/ref/heads/${{ github.event.pull_request.base.ref }} \
|
||||||
|
| jq -r '.object.sha')
|
||||||
|
echo "head_sha=$HEAD_SHA" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
# 2️⃣ 判断是否最终PR
|
||||||
|
- name: Check Latest
|
||||||
|
id: check
|
||||||
|
run: |
|
||||||
|
if [ "${{ github.event.pull_request.merge_commit_sha }}" = "${{ steps.head.outputs.head_sha }}" ]; then
|
||||||
|
echo "ok=true" >> $GITHUB_OUTPUT
|
||||||
|
else
|
||||||
|
echo "ok=false" >> $GITHUB_OUTPUT
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 3️⃣ 尝试从 PR body 提取 Sourcery 摘要
|
||||||
|
- name: Extract Sourcery Summary
|
||||||
|
if: steps.check.outputs.ok == 'true'
|
||||||
|
id: sourcery
|
||||||
|
env:
|
||||||
|
PR_BODY: ${{ github.event.pull_request.body }}
|
||||||
|
run: |
|
||||||
|
python3 << 'PYEOF'
|
||||||
|
import os, re
|
||||||
|
|
||||||
|
body = os.environ.get("PR_BODY", "") or ""
|
||||||
|
match = re.search(
|
||||||
|
r"## Summary by Sourcery\s*\n(.*?)(?=\n## |\Z)",
|
||||||
|
body,
|
||||||
|
re.DOTALL
|
||||||
|
)
|
||||||
|
|
||||||
|
if match:
|
||||||
|
summary = match.group(1).strip()
|
||||||
|
found = "true"
|
||||||
|
else:
|
||||||
|
summary = ""
|
||||||
|
found = "false"
|
||||||
|
|
||||||
|
with open("sourcery_summary.txt", "w", encoding="utf-8") as f:
|
||||||
|
f.write(summary)
|
||||||
|
|
||||||
|
with open(os.environ["GITHUB_OUTPUT"], "a") as gh:
|
||||||
|
gh.write(f"found={found}\n")
|
||||||
|
gh.write("summary<<EOF\n")
|
||||||
|
gh.write(summary + "\n")
|
||||||
|
gh.write("EOF\n")
|
||||||
|
PYEOF
|
||||||
|
|
||||||
|
# 4️⃣ Fallback: 获取 commits + 通义千问总结
|
||||||
|
- name: Get Commits
|
||||||
|
if: steps.check.outputs.ok == 'true' && steps.sourcery.outputs.found == 'false'
|
||||||
|
run: |
|
||||||
|
curl -s \
|
||||||
|
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
|
||||||
|
${{ github.event.pull_request.commits_url }} \
|
||||||
|
| jq -r '.[].commit.message' | head -n 20 > commits.txt
|
||||||
|
|
||||||
|
- name: AI Summary (Qwen Fallback)
|
||||||
|
if: steps.check.outputs.ok == 'true' && steps.sourcery.outputs.found == 'false'
|
||||||
|
id: qwen
|
||||||
|
env:
|
||||||
|
DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }}
|
||||||
|
run: |
|
||||||
|
python3 << 'PYEOF'
|
||||||
|
import json, os, urllib.request
|
||||||
|
|
||||||
|
with open("commits.txt", "r") as f:
|
||||||
|
commits = f.read().strip()
|
||||||
|
|
||||||
|
prompt = "请用中文总结以下代码提交,输出3-5条要点,面向测试人员。直接输出编号列表,不要输出标题或前言:\n" + commits
|
||||||
|
payload = {"model": "qwen-plus", "input": {"prompt": prompt}}
|
||||||
|
data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
|
||||||
|
|
||||||
|
req = urllib.request.Request(
|
||||||
|
"https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation",
|
||||||
|
data=data,
|
||||||
|
headers={
|
||||||
|
"Authorization": "Bearer " + os.environ["DASHSCOPE_API_KEY"],
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
resp = urllib.request.urlopen(req)
|
||||||
|
result = json.loads(resp.read().decode())
|
||||||
|
summary = result.get("output", {}).get("text", "AI 摘要生成失败")
|
||||||
|
|
||||||
|
with open(os.environ["GITHUB_OUTPUT"], "a") as gh:
|
||||||
|
gh.write("summary<<EOF\n")
|
||||||
|
gh.write(summary + "\n")
|
||||||
|
gh.write("EOF\n")
|
||||||
|
PYEOF
|
||||||
|
|
||||||
|
# 5️⃣ 企业微信通知(Markdown)
|
||||||
|
- name: Notify WeChat
|
||||||
|
if: steps.check.outputs.ok == 'true'
|
||||||
|
env:
|
||||||
|
WECHAT_WEBHOOK: ${{ secrets.WECHAT_WEBHOOK }}
|
||||||
|
BRANCH: ${{ github.event.pull_request.base.ref }}
|
||||||
|
AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||||
|
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||||
|
PR_URL: ${{ github.event.pull_request.html_url }}
|
||||||
|
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||||
|
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
||||||
|
SOURCERY_FOUND: ${{ steps.sourcery.outputs.found }}
|
||||||
|
SOURCERY_SUMMARY: ${{ steps.sourcery.outputs.summary }}
|
||||||
|
QWEN_SUMMARY: ${{ steps.qwen.outputs.summary }}
|
||||||
|
run: |
|
||||||
|
python3 << 'PYEOF'
|
||||||
|
import json, os, urllib.request
|
||||||
|
|
||||||
|
if os.environ.get("SOURCERY_FOUND") == "true":
|
||||||
|
label = "Summary by Sourcery"
|
||||||
|
summary = os.environ.get("SOURCERY_SUMMARY", "")
|
||||||
|
else:
|
||||||
|
label = "AI变更摘要"
|
||||||
|
summary = os.environ.get("QWEN_SUMMARY", "AI 摘要生成失败")
|
||||||
|
|
||||||
|
pr_number = os.environ.get("PR_NUMBER", "")
|
||||||
|
short_sha = os.environ.get("MERGE_SHA", "")[:7]
|
||||||
|
|
||||||
|
content = (
|
||||||
|
"## 🚀 Release 发布通知\n"
|
||||||
|
"> <20> **分支**: " + os.environ["BRANCH"] + "\n"
|
||||||
|
"> 👤 **提交人**: " + os.environ["AUTHOR"] + "\n"
|
||||||
|
"> 📝 **标题**: " + os.environ["PR_TITLE"] + "\n"
|
||||||
|
"> 🔢 **PR编号**: #" + pr_number + "\n"
|
||||||
|
"> 🔖 **Commit**: " + short_sha + "\n\n"
|
||||||
|
"### 🧠 " + label + "\n" +
|
||||||
|
summary + "\n\n"
|
||||||
|
"---\n"
|
||||||
|
"🔗 [查看PR详情](" + os.environ["PR_URL"] + ")"
|
||||||
|
)
|
||||||
|
payload = {"msgtype": "markdown", "markdown": {"content": content}}
|
||||||
|
data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
|
||||||
|
req = urllib.request.Request(
|
||||||
|
os.environ["WECHAT_WEBHOOK"],
|
||||||
|
data=data,
|
||||||
|
headers={"Content-Type": "application/json"}
|
||||||
|
)
|
||||||
|
resp = urllib.request.urlopen(req)
|
||||||
|
print(resp.read().decode())
|
||||||
|
PYEOF
|
||||||
33
.github/workflows/sync-to-gitee.yml
vendored
Normal file
33
.github/workflows/sync-to-gitee.yml
vendored
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
name: Sync to Gitee
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- '**' # All branchs
|
||||||
|
tags:
|
||||||
|
- '**' # All version tags (v1.0.0, etc.)
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
sync:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout Source Code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Sync to Gitee
|
||||||
|
run: |
|
||||||
|
GITEE_URL="https://${{ secrets.GITEE_USERNAME }}:${{ secrets.GITEE_TOKEN }}@gitee.com/hangzhou-hongxiong-intelligent_1/MemoryBear.git"
|
||||||
|
git remote add gitee "$GITEE_URL"
|
||||||
|
|
||||||
|
# 遍历并推送所有分支
|
||||||
|
for branch in $(git branch -r | grep -v HEAD | sed 's/origin\///'); do
|
||||||
|
echo "Syncing branch: $branch"
|
||||||
|
git push -f gitee "origin/$branch:refs/heads/$branch"
|
||||||
|
done
|
||||||
|
|
||||||
|
# 推送所有标签
|
||||||
|
echo "Syncing tags..."
|
||||||
|
git push gitee --tags --force
|
||||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -18,6 +18,7 @@ examples/
|
|||||||
.kiro
|
.kiro
|
||||||
.vscode
|
.vscode
|
||||||
.idea
|
.idea
|
||||||
|
.claude
|
||||||
|
|
||||||
# Temporary outputs
|
# Temporary outputs
|
||||||
.DS_Store
|
.DS_Store
|
||||||
@@ -25,6 +26,9 @@ examples/
|
|||||||
time.log
|
time.log
|
||||||
celerybeat-schedule.db
|
celerybeat-schedule.db
|
||||||
search_results.json
|
search_results.json
|
||||||
|
redbear-mem-metrics/
|
||||||
|
redbear-mem-benchmark/
|
||||||
|
pitch-deck/
|
||||||
|
|
||||||
api/migrations/versions
|
api/migrations/versions
|
||||||
tmp
|
tmp
|
||||||
@@ -39,3 +43,6 @@ cl100k_base.tiktoken
|
|||||||
libssl*.deb
|
libssl*.deb
|
||||||
|
|
||||||
sandbox/lib/seccomp_redbear/target
|
sandbox/lib/seccomp_redbear/target
|
||||||
|
|
||||||
|
# Qoder repowiki generated content
|
||||||
|
.qoder/repowiki/zh/
|
||||||
|
|||||||
@@ -2,6 +2,10 @@
|
|||||||
|
|
||||||
# MemoryBear empowers AI with human-like memory capabilities
|
# MemoryBear empowers AI with human-like memory capabilities
|
||||||
|
|
||||||
|
[](LICENSE)
|
||||||
|
[](https://www.python.org/)
|
||||||
|
[](https://github.com/SuanmoSuanyangTechnology/MemoryBear/actions/workflows/sync-to-gitee.yml)
|
||||||
|
|
||||||
[中文](./README_CN.md) | English
|
[中文](./README_CN.md) | English
|
||||||
|
|
||||||
### [Installation Guide](#memorybear-installation-guide)
|
### [Installation Guide](#memorybear-installation-guide)
|
||||||
|
|||||||
@@ -2,6 +2,10 @@
|
|||||||
|
|
||||||
# MemoryBear 让AI拥有如同人类一样的记忆
|
# MemoryBear 让AI拥有如同人类一样的记忆
|
||||||
|
|
||||||
|
[](LICENSE)
|
||||||
|
[](https://www.python.org/)
|
||||||
|
[](https://github.com/SuanmoSuanyangTechnology/MemoryBear/actions/workflows/sync-to-gitee.yml)
|
||||||
|
|
||||||
中文 | [English](./README.md)
|
中文 | [English](./README.md)
|
||||||
|
|
||||||
### [安装教程](#memorybear安装教程)
|
### [安装教程](#memorybear安装教程)
|
||||||
|
|||||||
@@ -60,7 +60,12 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
|
|||||||
# are written from script.py.mako
|
# are written from script.py.mako
|
||||||
# output_encoding = utf-8
|
# output_encoding = utf-8
|
||||||
|
|
||||||
sqlalchemy.url = postgresql://user:password@localhost/dbname
|
# Database connection URL - DO NOT hardcode credentials here!
|
||||||
|
# Connection string is set dynamically from environment variables in migrations/env.py
|
||||||
|
# Required env vars: DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME
|
||||||
|
# Example: postgresql://user:password@localhost:5432/dbname
|
||||||
|
; sqlalchemy.url = postgresql://user:password@host:port/dbname
|
||||||
|
sqlalchemy.url = driver://user:password@host:port/dbname
|
||||||
|
|
||||||
|
|
||||||
[post_write_hooks]
|
[post_write_hooks]
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
import os
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
|
|
||||||
import redis.asyncio as redis
|
import redis.asyncio as redis
|
||||||
from redis.asyncio import ConnectionPool
|
from redis.asyncio import ConnectionPool
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
# 设置日志记录器
|
# 设置日志记录器
|
||||||
@@ -20,6 +23,50 @@ pool = ConnectionPool.from_url(
|
|||||||
)
|
)
|
||||||
aio_redis = redis.StrictRedis(connection_pool=pool)
|
aio_redis = redis.StrictRedis(connection_pool=pool)
|
||||||
|
|
||||||
|
_REDIS_URL = f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}"
|
||||||
|
|
||||||
|
# Thread-local storage for connection pools.
|
||||||
|
# Each thread (and each forked process) gets its own pool to avoid
|
||||||
|
# "Future attached to a different loop" errors in Celery --pool=threads
|
||||||
|
# and stale connections after fork in --pool=prefork.
|
||||||
|
_thread_local = threading.local()
|
||||||
|
|
||||||
|
|
||||||
|
def get_thread_safe_redis() -> redis.StrictRedis:
|
||||||
|
"""Return a Redis client whose connection pool is bound to the current
|
||||||
|
thread, process **and** event loop.
|
||||||
|
|
||||||
|
The pool is recreated when:
|
||||||
|
- The PID changes (fork, Celery --pool=prefork)
|
||||||
|
- The thread has no pool yet (Celery --pool=threads)
|
||||||
|
- The previously-cached event loop has been closed (Celery tasks call
|
||||||
|
``_shutdown_loop_gracefully`` which closes the loop after each run)
|
||||||
|
"""
|
||||||
|
current_pid = os.getpid()
|
||||||
|
cached_loop = getattr(_thread_local, "loop", None)
|
||||||
|
loop_stale = cached_loop is not None and cached_loop.is_closed()
|
||||||
|
|
||||||
|
if not hasattr(_thread_local, "pool") \
|
||||||
|
or getattr(_thread_local, "pid", None) != current_pid \
|
||||||
|
or loop_stale:
|
||||||
|
_thread_local.pid = current_pid
|
||||||
|
# Python 3.10+: get_event_loop() raises RuntimeError in threads
|
||||||
|
# where no loop has been set yet (e.g. Celery --pool=threads).
|
||||||
|
try:
|
||||||
|
_thread_local.loop = asyncio.get_event_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
_thread_local.loop = None
|
||||||
|
_thread_local.pool = ConnectionPool.from_url(
|
||||||
|
_REDIS_URL,
|
||||||
|
db=settings.REDIS_DB,
|
||||||
|
password=settings.REDIS_PASSWORD,
|
||||||
|
decode_responses=True,
|
||||||
|
max_connections=5,
|
||||||
|
health_check_interval=30,
|
||||||
|
)
|
||||||
|
|
||||||
|
return redis.StrictRedis(connection_pool=_thread_local.pool)
|
||||||
|
|
||||||
|
|
||||||
async def get_redis_connection():
|
async def get_redis_connection():
|
||||||
"""获取Redis连接"""
|
"""获取Redis连接"""
|
||||||
@@ -43,10 +90,8 @@ async def aio_redis_set(key: str, val: str | dict, expire: int = None):
|
|||||||
val = json.dumps(val, ensure_ascii=False)
|
val = json.dumps(val, ensure_ascii=False)
|
||||||
|
|
||||||
if expire is not None:
|
if expire is not None:
|
||||||
# 设置带过期时间的键值
|
|
||||||
await aio_redis.set(key, val, ex=expire)
|
await aio_redis.set(key, val, ex=expire)
|
||||||
else:
|
else:
|
||||||
# 设置永久键值
|
|
||||||
await aio_redis.set(key, val)
|
await aio_redis.set(key, val)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Redis set错误: {str(e)}")
|
logger.error(f"Redis set错误: {str(e)}")
|
||||||
|
|||||||
8
api/app/cache/memory/activity_stats_cache.py
vendored
8
api/app/cache/memory/activity_stats_cache.py
vendored
@@ -10,7 +10,7 @@ import logging
|
|||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from app.aioRedis import aio_redis
|
from app.aioRedis import get_thread_safe_redis
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -68,7 +68,7 @@ class ActivityStatsCache:
|
|||||||
"cached": True,
|
"cached": True,
|
||||||
}
|
}
|
||||||
value = json.dumps(payload, ensure_ascii=False)
|
value = json.dumps(payload, ensure_ascii=False)
|
||||||
await aio_redis.set(key, value, ex=expire)
|
await get_thread_safe_redis().set(key, value, ex=expire)
|
||||||
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒")
|
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -90,7 +90,7 @@ class ActivityStatsCache:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
key = cls._get_key(workspace_id)
|
key = cls._get_key(workspace_id)
|
||||||
value = await aio_redis.get(key)
|
value = await get_thread_safe_redis().get(key)
|
||||||
if value:
|
if value:
|
||||||
payload = json.loads(value)
|
payload = json.loads(value)
|
||||||
logger.info(f"命中活动统计缓存: {key}")
|
logger.info(f"命中活动统计缓存: {key}")
|
||||||
@@ -116,7 +116,7 @@ class ActivityStatsCache:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
key = cls._get_key(workspace_id)
|
key = cls._get_key(workspace_id)
|
||||||
result = await aio_redis.delete(key)
|
result = await get_thread_safe_redis().delete(key)
|
||||||
logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
|
logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
|
||||||
return result > 0
|
return result > 0
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
8
api/app/cache/memory/interest_memory.py
vendored
8
api/app/cache/memory/interest_memory.py
vendored
@@ -9,7 +9,7 @@ import logging
|
|||||||
from typing import Optional, List, Dict, Any
|
from typing import Optional, List, Dict, Any
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from app.aioRedis import aio_redis
|
from app.aioRedis import get_thread_safe_redis
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -62,7 +62,7 @@ class InterestMemoryCache:
|
|||||||
"cached": True,
|
"cached": True,
|
||||||
}
|
}
|
||||||
value = json.dumps(payload, ensure_ascii=False)
|
value = json.dumps(payload, ensure_ascii=False)
|
||||||
await aio_redis.set(key, value, ex=expire)
|
await get_thread_safe_redis().set(key, value, ex=expire)
|
||||||
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒")
|
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -86,7 +86,7 @@ class InterestMemoryCache:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
key = cls._get_key(end_user_id, language)
|
key = cls._get_key(end_user_id, language)
|
||||||
value = await aio_redis.get(key)
|
value = await get_thread_safe_redis().get(key)
|
||||||
if value:
|
if value:
|
||||||
payload = json.loads(value)
|
payload = json.loads(value)
|
||||||
logger.info(f"命中兴趣分布缓存: {key}")
|
logger.info(f"命中兴趣分布缓存: {key}")
|
||||||
@@ -114,7 +114,7 @@ class InterestMemoryCache:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
key = cls._get_key(end_user_id, language)
|
key = cls._get_key(end_user_id, language)
|
||||||
result = await aio_redis.delete(key)
|
result = await get_thread_safe_redis().delete(key)
|
||||||
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
|
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
|
||||||
return result > 0
|
return result > 0
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
|
import re
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
@@ -11,21 +12,25 @@ from app.core.logging_config import get_logger
|
|||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _mask_url(url: str) -> str:
|
||||||
|
"""隐藏 URL 中的密码部分,适用于 redis:// 和 amqp:// 等协议"""
|
||||||
|
return re.sub(r'(://[^:]*:)[^@]+(@)', r'\1***\2', url)
|
||||||
|
|
||||||
|
|
||||||
# macOS fork() safety - must be set before any Celery initialization
|
# macOS fork() safety - must be set before any Celery initialization
|
||||||
if platform.system() == 'Darwin':
|
if platform.system() == 'Darwin':
|
||||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||||
|
|
||||||
# 创建 Celery 应用实例
|
# 创建 Celery 应用实例
|
||||||
# broker: 任务队列(使用 Redis DB,由 CELERY_BROKER_DB 指定)
|
# broker: 优先使用环境变量 CELERY_BROKER_URL(支持 amqp:// 等任意协议),
|
||||||
# backend: 结果存储(使用 Redis DB,由 CELERY_BACKEND_DB 指定)
|
# 未配置则回退到 Redis 方案
|
||||||
|
# backend: 结果存储(使用 Redis)
|
||||||
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND,
|
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND,
|
||||||
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
|
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
|
||||||
|
|
||||||
# Build canonical broker/backend URLs and force them into os.environ so that
|
_broker_url = os.getenv("CELERY_BROKER_URL") or \
|
||||||
# Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first)
|
f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
|
||||||
# cannot be overridden by stray env vars.
|
|
||||||
# See: https://github.com/celery/celery/issues/4284
|
|
||||||
_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
|
|
||||||
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
|
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
|
||||||
os.environ["CELERY_BROKER_URL"] = _broker_url
|
os.environ["CELERY_BROKER_URL"] = _broker_url
|
||||||
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
|
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
|
||||||
@@ -45,8 +50,8 @@ celery_app = Celery(
|
|||||||
logger.info(
|
logger.info(
|
||||||
"Celery app initialized",
|
"Celery app initialized",
|
||||||
extra={
|
extra={
|
||||||
"broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"),
|
"broker": _mask_url(_broker_url),
|
||||||
"backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"),
|
"backend": _mask_url(_backend_url),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
# Default queue for unrouted tasks
|
# Default queue for unrouted tasks
|
||||||
@@ -77,6 +82,7 @@ celery_app.conf.update(
|
|||||||
|
|
||||||
# Worker 设置 (per-worker settings are in docker-compose command line)
|
# Worker 设置 (per-worker settings are in docker-compose command line)
|
||||||
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
||||||
|
worker_redirect_stdouts_level='INFO', # stdout/print → INFO instead of WARNING
|
||||||
|
|
||||||
# 结果过期时间
|
# 结果过期时间
|
||||||
result_expires=3600, # 结果保存1小时
|
result_expires=3600, # 结果保存1小时
|
||||||
@@ -102,11 +108,29 @@ celery_app.conf.update(
|
|||||||
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
||||||
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
||||||
|
|
||||||
|
# Clustering tasks → memory_tasks queue (使用相同的 worker,避免 macOS fork 问题)
|
||||||
|
'app.tasks.run_incremental_clustering': {'queue': 'memory_tasks'},
|
||||||
|
|
||||||
|
# Metadata extraction → memory_tasks queue
|
||||||
|
'app.tasks.extract_user_metadata': {'queue': 'memory_tasks'},
|
||||||
|
|
||||||
|
# Async emotion extraction → memory_tasks queue (IO-bound LLM calls)
|
||||||
|
'app.tasks.extract_emotion_batch': {'queue': 'memory_tasks'},
|
||||||
|
|
||||||
|
# Post-store dedup + alias merge → memory_tasks queue
|
||||||
|
'app.tasks.post_store_dedup_and_alias_merge': {'queue': 'memory_tasks'},
|
||||||
|
|
||||||
|
# Async metadata extraction → memory_tasks queue
|
||||||
|
'app.tasks.extract_metadata_batch': {'queue': 'memory_tasks'},
|
||||||
|
|
||||||
# Document tasks → document_tasks queue (prefork worker)
|
# Document tasks → document_tasks queue (prefork worker)
|
||||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
|
||||||
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
|
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
|
||||||
|
|
||||||
|
# GraphRAG tasks → graphrag_tasks queue (独立队列,避免阻塞文档解析)
|
||||||
|
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'graphrag_tasks'},
|
||||||
|
'app.core.rag.tasks.build_graphrag_for_document': {'queue': 'graphrag_tasks'},
|
||||||
|
|
||||||
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
||||||
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
||||||
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
||||||
@@ -115,6 +139,7 @@ celery_app.conf.update(
|
|||||||
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
|
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
|
||||||
'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'},
|
'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'},
|
||||||
'app.tasks.init_interest_distribution_for_users': {'queue': 'periodic_tasks'},
|
'app.tasks.init_interest_distribution_for_users': {'queue': 'periodic_tasks'},
|
||||||
|
'app.tasks.init_community_clustering_for_users': {'queue': 'periodic_tasks'},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -131,7 +156,7 @@ implicit_emotions_update_schedule = crontab(
|
|||||||
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
|
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
|
||||||
)
|
)
|
||||||
|
|
||||||
#构建定时任务配置
|
# 构建定时任务配置
|
||||||
beat_schedule_config = {
|
beat_schedule_config = {
|
||||||
"run-workspace-reflection": {
|
"run-workspace-reflection": {
|
||||||
"task": "app.tasks.workspace_reflection_task",
|
"task": "app.tasks.workspace_reflection_task",
|
||||||
|
|||||||
500
api/app/celery_task_scheduler.py
Normal file
500
api/app/celery_task_scheduler.py
Normal file
@@ -0,0 +1,500 @@
|
|||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import socket
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import redis
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.logging_config import get_named_logger
|
||||||
|
from app.celery_app import celery_app
|
||||||
|
|
||||||
|
logger = get_named_logger("task_scheduler")
|
||||||
|
|
||||||
|
# per-user queue scheduler:uq:{user_id}
|
||||||
|
USER_QUEUE_PREFIX = "scheduler:uq:"
|
||||||
|
# User Collection of Pending Messages
|
||||||
|
ACTIVE_USERS = "scheduler:active_users"
|
||||||
|
# Set of users that can dispatch (ready signal)
|
||||||
|
READY_SET = "scheduler:ready_users"
|
||||||
|
# Metadata of tasks that have been dispatched and are pending completion
|
||||||
|
PENDING_HASH = "scheduler:pending_tasks"
|
||||||
|
# Dynamic Sharding: Instance Registry
|
||||||
|
REGISTRY_KEY = "scheduler:instances"
|
||||||
|
|
||||||
|
TASK_TIMEOUT = 7800 # Task timeout (seconds), considered lost if exceeded
|
||||||
|
HEARTBEAT_INTERVAL = 10 # Heartbeat interval (seconds)
|
||||||
|
INSTANCE_TTL = 30 # Instance timeout (seconds)
|
||||||
|
|
||||||
|
LUA_ATOMIC_LOCK = """
|
||||||
|
local dispatch_lock = KEYS[1]
|
||||||
|
local lock_key = KEYS[2]
|
||||||
|
local instance_id = ARGV[1]
|
||||||
|
local dispatch_ttl = tonumber(ARGV[2])
|
||||||
|
local lock_ttl = tonumber(ARGV[3])
|
||||||
|
|
||||||
|
if redis.call('SET', dispatch_lock, instance_id, 'NX', 'EX', dispatch_ttl) == false then
|
||||||
|
return 0
|
||||||
|
end
|
||||||
|
|
||||||
|
if redis.call('EXISTS', lock_key) == 1 then
|
||||||
|
redis.call('DEL', dispatch_lock)
|
||||||
|
return -1
|
||||||
|
end
|
||||||
|
|
||||||
|
redis.call('SET', lock_key, 'dispatching', 'EX', lock_ttl)
|
||||||
|
return 1
|
||||||
|
"""
|
||||||
|
|
||||||
|
LUA_SAFE_DELETE = """
|
||||||
|
if redis.call('GET', KEYS[1]) == ARGV[1] then
|
||||||
|
return redis.call('DEL', KEYS[1])
|
||||||
|
end
|
||||||
|
return 0
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def stable_hash(value: str) -> int:
|
||||||
|
return int.from_bytes(
|
||||||
|
hashlib.md5(value.encode("utf-8")).digest(),
|
||||||
|
"big"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def health_check_server(scheduler_ref):
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
health_app = FastAPI()
|
||||||
|
|
||||||
|
@health_app.get("/")
|
||||||
|
def health():
|
||||||
|
return scheduler_ref.health()
|
||||||
|
|
||||||
|
port = int(os.environ.get("SCHEDULER_HEALTH_PORT", "8001"))
|
||||||
|
threading.Thread(
|
||||||
|
target=uvicorn.run,
|
||||||
|
kwargs={
|
||||||
|
"app": health_app,
|
||||||
|
"host": "0.0.0.0",
|
||||||
|
"port": port,
|
||||||
|
"log_config": None,
|
||||||
|
},
|
||||||
|
daemon=True,
|
||||||
|
).start()
|
||||||
|
logger.info("[Health] Server started at http://0.0.0.0:%s", port)
|
||||||
|
|
||||||
|
|
||||||
|
class RedisTaskScheduler:
|
||||||
|
def __init__(self):
|
||||||
|
self.redis = redis.Redis(
|
||||||
|
host=settings.REDIS_HOST,
|
||||||
|
port=settings.REDIS_PORT,
|
||||||
|
db=settings.REDIS_DB_CELERY_BACKEND,
|
||||||
|
password=settings.REDIS_PASSWORD,
|
||||||
|
decode_responses=True,
|
||||||
|
)
|
||||||
|
self.running = False
|
||||||
|
self.dispatched = 0
|
||||||
|
self.errors = 0
|
||||||
|
|
||||||
|
self.instance_id = f"{socket.gethostname()}-{os.getpid()}"
|
||||||
|
self._shard_index = 0
|
||||||
|
self._shard_count = 1
|
||||||
|
self._last_heartbeat = 0.0
|
||||||
|
|
||||||
|
def push_task(self, task_name, user_id, params):
|
||||||
|
try:
|
||||||
|
msg_id = str(uuid.uuid4())
|
||||||
|
msg = json.dumps({
|
||||||
|
"msg_id": msg_id,
|
||||||
|
"task_name": task_name,
|
||||||
|
"user_id": user_id,
|
||||||
|
"params": json.dumps(params),
|
||||||
|
})
|
||||||
|
|
||||||
|
lock_key = f"{task_name}:{user_id}"
|
||||||
|
queue_key = f"{USER_QUEUE_PREFIX}{user_id}"
|
||||||
|
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
pipe.rpush(queue_key, msg)
|
||||||
|
pipe.sadd(ACTIVE_USERS, user_id)
|
||||||
|
pipe.set(
|
||||||
|
f"task_tracker:{msg_id}",
|
||||||
|
json.dumps({"status": "QUEUED", "task_id": None}),
|
||||||
|
ex=86400,
|
||||||
|
)
|
||||||
|
pipe.execute()
|
||||||
|
|
||||||
|
if not self.redis.exists(lock_key):
|
||||||
|
self.redis.sadd(READY_SET, user_id)
|
||||||
|
|
||||||
|
logger.info("Task pushed: msg_id=%s task=%s user=%s", msg_id, task_name, user_id)
|
||||||
|
return msg_id
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Push task exception %s", e, exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_task_status(self, msg_id: str) -> dict:
|
||||||
|
raw = self.redis.get(f"task_tracker:{msg_id}")
|
||||||
|
if raw is None:
|
||||||
|
return {"status": "NOT_FOUND"}
|
||||||
|
|
||||||
|
tracker = json.loads(raw)
|
||||||
|
status = tracker["status"]
|
||||||
|
task_id = tracker.get("task_id")
|
||||||
|
result_content = tracker.get("result") or {}
|
||||||
|
|
||||||
|
if status == "DISPATCHED" and task_id:
|
||||||
|
result_raw = self.redis.get(f"celery-task-meta-{task_id}")
|
||||||
|
if result_raw:
|
||||||
|
result_data = json.loads(result_raw)
|
||||||
|
status = result_data.get("status", status)
|
||||||
|
result_content = result_data.get("result")
|
||||||
|
|
||||||
|
return {"status": status, "task_id": task_id, "result": result_content}
|
||||||
|
|
||||||
|
def _cleanup_finished(self):
|
||||||
|
pending = self.redis.hgetall(PENDING_HASH)
|
||||||
|
if not pending:
|
||||||
|
return
|
||||||
|
|
||||||
|
now = time.time()
|
||||||
|
task_ids = list(pending.keys())
|
||||||
|
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
for task_id in task_ids:
|
||||||
|
pipe.get(f"celery-task-meta-{task_id}")
|
||||||
|
results = pipe.execute()
|
||||||
|
|
||||||
|
cleanup_pipe = self.redis.pipeline()
|
||||||
|
has_cleanup = False
|
||||||
|
ready_user_ids = set()
|
||||||
|
|
||||||
|
for task_id, raw_result in zip(task_ids, results):
|
||||||
|
try:
|
||||||
|
meta = json.loads(pending[task_id])
|
||||||
|
lock_key = meta["lock_key"]
|
||||||
|
dispatched_at = meta.get("dispatched_at", 0)
|
||||||
|
age = now - dispatched_at
|
||||||
|
|
||||||
|
should_cleanup = False
|
||||||
|
result_data = {}
|
||||||
|
|
||||||
|
if raw_result is not None:
|
||||||
|
result_data = json.loads(raw_result)
|
||||||
|
if result_data.get("status") in ("SUCCESS", "FAILURE", "REVOKED"):
|
||||||
|
should_cleanup = True
|
||||||
|
logger.info(
|
||||||
|
"Task finished: %s state=%s", task_id,
|
||||||
|
result_data.get("status"),
|
||||||
|
)
|
||||||
|
elif age > TASK_TIMEOUT:
|
||||||
|
should_cleanup = True
|
||||||
|
logger.warning(
|
||||||
|
"Task expired or lost: %s age=%.0fs, force cleanup",
|
||||||
|
task_id, age,
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_cleanup:
|
||||||
|
final_status = (
|
||||||
|
result_data.get("status", "UNKNOWN") if result_data else "EXPIRED"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.redis.eval(LUA_SAFE_DELETE, 1, lock_key, task_id)
|
||||||
|
|
||||||
|
cleanup_pipe.hdel(PENDING_HASH, task_id)
|
||||||
|
|
||||||
|
tracker_msg_id = meta.get("msg_id")
|
||||||
|
if tracker_msg_id:
|
||||||
|
cleanup_pipe.set(
|
||||||
|
f"task_tracker:{tracker_msg_id}",
|
||||||
|
json.dumps({
|
||||||
|
"status": final_status,
|
||||||
|
"task_id": task_id,
|
||||||
|
"result": result_data.get("result") or {},
|
||||||
|
}),
|
||||||
|
ex=86400,
|
||||||
|
)
|
||||||
|
has_cleanup = True
|
||||||
|
|
||||||
|
parts = lock_key.split(":", 1)
|
||||||
|
if len(parts) == 2:
|
||||||
|
ready_user_ids.add(parts[1])
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True)
|
||||||
|
self.errors += 1
|
||||||
|
|
||||||
|
if has_cleanup:
|
||||||
|
cleanup_pipe.execute()
|
||||||
|
|
||||||
|
if ready_user_ids:
|
||||||
|
self.redis.sadd(READY_SET, *ready_user_ids)
|
||||||
|
|
||||||
|
def _heartbeat(self):
|
||||||
|
now = time.time()
|
||||||
|
if now - self._last_heartbeat < HEARTBEAT_INTERVAL:
|
||||||
|
return
|
||||||
|
self._last_heartbeat = now
|
||||||
|
|
||||||
|
self.redis.hset(REGISTRY_KEY, self.instance_id, str(now))
|
||||||
|
|
||||||
|
all_instances = self.redis.hgetall(REGISTRY_KEY)
|
||||||
|
|
||||||
|
alive = []
|
||||||
|
dead = []
|
||||||
|
for iid, ts in all_instances.items():
|
||||||
|
if now - float(ts) < INSTANCE_TTL:
|
||||||
|
alive.append(iid)
|
||||||
|
else:
|
||||||
|
dead.append(iid)
|
||||||
|
|
||||||
|
if dead:
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
for iid in dead:
|
||||||
|
pipe.hdel(REGISTRY_KEY, iid)
|
||||||
|
pipe.execute()
|
||||||
|
logger.info("Cleaned dead instances: %s", dead)
|
||||||
|
|
||||||
|
alive.sort()
|
||||||
|
self._shard_count = max(len(alive), 1)
|
||||||
|
self._shard_index = (
|
||||||
|
alive.index(self.instance_id) if self.instance_id in alive else 0
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
"Shard: %s/%s (instance=%s, alive=%d)",
|
||||||
|
self._shard_index, self._shard_count,
|
||||||
|
self.instance_id, len(alive),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _is_mine(self, user_id: str) -> bool:
|
||||||
|
if self._shard_count <= 1:
|
||||||
|
return True
|
||||||
|
return stable_hash(user_id) % self._shard_count == self._shard_index
|
||||||
|
|
||||||
|
def _dispatch(self, msg_id, msg_data) -> bool:
|
||||||
|
user_id = msg_data["user_id"]
|
||||||
|
task_name = msg_data["task_name"]
|
||||||
|
params = json.loads(msg_data.get("params", "{}"))
|
||||||
|
|
||||||
|
lock_key = f"{task_name}:{user_id}"
|
||||||
|
dispatch_lock = f"dispatch:{msg_id}"
|
||||||
|
|
||||||
|
result = self.redis.eval(
|
||||||
|
LUA_ATOMIC_LOCK, 2,
|
||||||
|
dispatch_lock, lock_key,
|
||||||
|
self.instance_id, str(300), str(3600),
|
||||||
|
)
|
||||||
|
|
||||||
|
if result == 0:
|
||||||
|
return False
|
||||||
|
if result == -1:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
task = celery_app.send_task(task_name, kwargs=params)
|
||||||
|
except Exception as e:
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
pipe.delete(dispatch_lock)
|
||||||
|
pipe.delete(lock_key)
|
||||||
|
pipe.execute()
|
||||||
|
self.errors += 1
|
||||||
|
logger.error(
|
||||||
|
"send_task failed for %s:%s msg=%s: %s",
|
||||||
|
task_name, user_id, msg_id, e, exc_info=True,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
pipe.set(lock_key, task.id, ex=3600)
|
||||||
|
pipe.hset(PENDING_HASH, task.id, json.dumps({
|
||||||
|
"lock_key": lock_key,
|
||||||
|
"dispatched_at": time.time(),
|
||||||
|
"msg_id": msg_id,
|
||||||
|
}))
|
||||||
|
pipe.delete(dispatch_lock)
|
||||||
|
pipe.set(
|
||||||
|
f"task_tracker:{msg_id}",
|
||||||
|
json.dumps({"status": "DISPATCHED", "task_id": task.id}),
|
||||||
|
ex=86400,
|
||||||
|
)
|
||||||
|
pipe.execute()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"Post-dispatch state update failed for %s: %s",
|
||||||
|
task.id, e, exc_info=True,
|
||||||
|
)
|
||||||
|
self.errors += 1
|
||||||
|
|
||||||
|
self.dispatched += 1
|
||||||
|
logger.info("Task dispatched: %s (msg=%s)", task.id, msg_id)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _process_batch(self, user_ids):
|
||||||
|
if not user_ids:
|
||||||
|
return
|
||||||
|
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
for uid in user_ids:
|
||||||
|
pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0)
|
||||||
|
heads = pipe.execute()
|
||||||
|
|
||||||
|
candidates = [] # (user_id, msg_dict)
|
||||||
|
empty_users = []
|
||||||
|
|
||||||
|
for uid, head in zip(user_ids, heads):
|
||||||
|
if head is None:
|
||||||
|
empty_users.append(uid)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
candidates.append((uid, json.loads(head)))
|
||||||
|
except (json.JSONDecodeError, TypeError) as e:
|
||||||
|
logger.error("Bad message in queue for user %s: %s", uid, e)
|
||||||
|
self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}")
|
||||||
|
|
||||||
|
if empty_users:
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
for uid in empty_users:
|
||||||
|
pipe.srem(ACTIVE_USERS, uid)
|
||||||
|
pipe.execute()
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
return
|
||||||
|
|
||||||
|
for uid, msg in candidates:
|
||||||
|
if self._dispatch(msg["msg_id"], msg):
|
||||||
|
self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}")
|
||||||
|
|
||||||
|
def schedule_loop(self):
|
||||||
|
self._heartbeat()
|
||||||
|
self._cleanup_finished()
|
||||||
|
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
pipe.smembers(READY_SET)
|
||||||
|
pipe.delete(READY_SET)
|
||||||
|
results = pipe.execute()
|
||||||
|
ready_users = results[0] or set()
|
||||||
|
|
||||||
|
my_users = [uid for uid in ready_users if self._is_mine(uid)]
|
||||||
|
|
||||||
|
if not my_users:
|
||||||
|
time.sleep(0.5)
|
||||||
|
return
|
||||||
|
|
||||||
|
self._process_batch(my_users)
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
def _full_scan(self):
|
||||||
|
cursor = 0
|
||||||
|
ready_batch = []
|
||||||
|
while True:
|
||||||
|
cursor, user_ids = self.redis.sscan(
|
||||||
|
ACTIVE_USERS, cursor=cursor, count=1000,
|
||||||
|
)
|
||||||
|
if user_ids:
|
||||||
|
my_users = [uid for uid in user_ids if self._is_mine(uid)]
|
||||||
|
if my_users:
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
for uid in my_users:
|
||||||
|
pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0)
|
||||||
|
heads = pipe.execute()
|
||||||
|
|
||||||
|
for uid, head in zip(my_users, heads):
|
||||||
|
if head is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
msg = json.loads(head)
|
||||||
|
lock_key = f"{msg['task_name']}:{uid}"
|
||||||
|
ready_batch.append((uid, lock_key))
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if cursor == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not ready_batch:
|
||||||
|
return
|
||||||
|
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
for _, lock_key in ready_batch:
|
||||||
|
pipe.exists(lock_key)
|
||||||
|
lock_exists = pipe.execute()
|
||||||
|
|
||||||
|
ready_uids = [
|
||||||
|
uid for (uid, _), locked in zip(ready_batch, lock_exists)
|
||||||
|
if not locked
|
||||||
|
]
|
||||||
|
|
||||||
|
if ready_uids:
|
||||||
|
self.redis.sadd(READY_SET, *ready_uids)
|
||||||
|
logger.info("Full scan found %d ready users", len(ready_uids))
|
||||||
|
|
||||||
|
def run_server(self):
|
||||||
|
health_check_server(self)
|
||||||
|
self.running = True
|
||||||
|
|
||||||
|
last_full_scan = 0.0
|
||||||
|
full_scan_interval = 30.0
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Scheduler started: instance=%s", self.instance_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
self.schedule_loop()
|
||||||
|
|
||||||
|
now = time.time()
|
||||||
|
if now - last_full_scan > full_scan_interval:
|
||||||
|
self._full_scan()
|
||||||
|
last_full_scan = now
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Scheduler exception %s", e, exc_info=True)
|
||||||
|
self.errors += 1
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
def health(self) -> dict:
|
||||||
|
return {
|
||||||
|
"running": self.running,
|
||||||
|
"active_users": self.redis.scard(ACTIVE_USERS),
|
||||||
|
"ready_users": self.redis.scard(READY_SET),
|
||||||
|
"pending_tasks": self.redis.hlen(PENDING_HASH),
|
||||||
|
"dispatched": self.dispatched,
|
||||||
|
"errors": self.errors,
|
||||||
|
"shard": f"{self._shard_index}/{self._shard_count}",
|
||||||
|
"instance": self.instance_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
logger.info("Scheduler shutting down: instance=%s", self.instance_id)
|
||||||
|
self.running = False
|
||||||
|
try:
|
||||||
|
self.redis.hdel(REGISTRY_KEY, self.instance_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Shutdown cleanup error: %s", e)
|
||||||
|
|
||||||
|
|
||||||
|
scheduler: RedisTaskScheduler | None = None
|
||||||
|
if scheduler is None:
|
||||||
|
scheduler = RedisTaskScheduler()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
def _signal_handler(signum, frame):
|
||||||
|
scheduler.shutdown()
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, _signal_handler)
|
||||||
|
signal.signal(signal.SIGINT, _signal_handler)
|
||||||
|
|
||||||
|
scheduler.run_server()
|
||||||
@@ -2,6 +2,9 @@
|
|||||||
Celery Worker 入口点
|
Celery Worker 入口点
|
||||||
用于启动 Celery Worker: celery -A app.celery_worker worker --loglevel=info
|
用于启动 Celery Worker: celery -A app.celery_worker worker --loglevel=info
|
||||||
"""
|
"""
|
||||||
|
# 必须在导入任何使用 DashScope SDK 的模块之前应用补丁
|
||||||
|
import app.plugins.dashscope_patch # noqa: F401
|
||||||
|
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
from app.core.logging_config import LoggingConfig, get_logger
|
from app.core.logging_config import LoggingConfig, get_logger
|
||||||
|
|
||||||
@@ -13,4 +16,39 @@ logger.info("Celery worker logging initialized")
|
|||||||
# 导入任务模块以注册任务
|
# 导入任务模块以注册任务
|
||||||
import app.tasks
|
import app.tasks
|
||||||
|
|
||||||
|
|
||||||
|
@worker_process_init.connect
|
||||||
|
def _reinit_db_pool(**kwargs):
|
||||||
|
"""
|
||||||
|
prefork 子进程启动时重建被 fork 污染的资源。
|
||||||
|
|
||||||
|
fork() 后子进程继承了父进程的:
|
||||||
|
1. SQLAlchemy 连接池 — 多进程共享 TCP socket 导致 DB 连接损坏
|
||||||
|
2. ThreadPoolExecutor — fork 后线程状态不确定,第二个任务会死锁
|
||||||
|
"""
|
||||||
|
# 重建 DB 连接池
|
||||||
|
from app.db import engine
|
||||||
|
engine.dispose()
|
||||||
|
logger.info("DB connection pool disposed for forked worker process")
|
||||||
|
|
||||||
|
# 重建模块级 ThreadPoolExecutor(fork 后线程池不可用)
|
||||||
|
try:
|
||||||
|
from app.core.rag.deepdoc.parser import figure_parser
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
figure_parser.shared_executor = ThreadPoolExecutor(max_workers=10)
|
||||||
|
logger.info("figure_parser.shared_executor recreated")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to recreate figure_parser.shared_executor: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from app.core.rag.utils import libre_office
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
import os
|
||||||
|
max_workers = os.cpu_count() * 2 if os.cpu_count() else 4
|
||||||
|
libre_office.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||||
|
logger.info("libre_office.executor recreated")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to recreate libre_office.executor: {e}")
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['celery_app']
|
__all__ = ['celery_app']
|
||||||
77
api/app/config/default_free_plan.py
Normal file
77
api/app/config/default_free_plan.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
"""
|
||||||
|
社区版默认免费套餐配置
|
||||||
|
当无法从 SaaS 版获取 premium 模块时,使用此配置作为兜底
|
||||||
|
|
||||||
|
可通过环境变量覆盖配额配置,格式:QUOTA_<QUOTA_NAME>
|
||||||
|
例如:QUOTA_END_USER_QUOTA=100
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def _get_quota_from_env():
|
||||||
|
"""从环境变量获取配额配置"""
|
||||||
|
quota_keys = [
|
||||||
|
"workspace_quota",
|
||||||
|
"skill_quota",
|
||||||
|
"app_quota",
|
||||||
|
"knowledge_capacity_quota",
|
||||||
|
"memory_engine_quota",
|
||||||
|
"end_user_quota",
|
||||||
|
"ontology_project_quota",
|
||||||
|
"model_quota",
|
||||||
|
"api_ops_rate_limit",
|
||||||
|
]
|
||||||
|
quotas = {}
|
||||||
|
for key in quota_keys:
|
||||||
|
env_key = f"QUOTA_{key.upper()}"
|
||||||
|
env_value = os.getenv(env_key)
|
||||||
|
if env_value is not None:
|
||||||
|
try:
|
||||||
|
quotas[key] = float(env_value) if '.' in env_value else int(env_value)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return quotas
|
||||||
|
|
||||||
|
|
||||||
|
def _build_default_free_plan():
|
||||||
|
"""构建默认免费套餐配置"""
|
||||||
|
base = {
|
||||||
|
"name": "记忆体验版",
|
||||||
|
"name_en": "Memory Experience",
|
||||||
|
"category": "saas_personal",
|
||||||
|
"tier_level": 0,
|
||||||
|
"version": "1.0",
|
||||||
|
"status": True,
|
||||||
|
"price": 0,
|
||||||
|
"billing_cycle": "permanent_free",
|
||||||
|
"core_value": "感受永久记忆",
|
||||||
|
"core_value_en": "Experience Permanent Memory",
|
||||||
|
"tech_support": "社群交流",
|
||||||
|
"tech_support_en": "Community Support",
|
||||||
|
"sla_compliance": "无",
|
||||||
|
"sla_compliance_en": "None",
|
||||||
|
"page_customization": "无",
|
||||||
|
"page_customization_en": "None",
|
||||||
|
"theme_color": "#64748B",
|
||||||
|
"quotas": {
|
||||||
|
"workspace_quota": 1,
|
||||||
|
"skill_quota": 5,
|
||||||
|
"app_quota": 2,
|
||||||
|
"knowledge_capacity_quota": 0.3,
|
||||||
|
"memory_engine_quota": 1,
|
||||||
|
"end_user_quota": 10,
|
||||||
|
"ontology_project_quota": 3,
|
||||||
|
"model_quota": 1,
|
||||||
|
"api_ops_rate_limit": 50,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
env_quotas = _get_quota_from_env()
|
||||||
|
if env_quotas:
|
||||||
|
base["quotas"].update(env_quotas)
|
||||||
|
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_FREE_PLAN = _build_default_free_plan()
|
||||||
@@ -8,6 +8,7 @@ from fastapi import APIRouter
|
|||||||
from . import (
|
from . import (
|
||||||
api_key_controller,
|
api_key_controller,
|
||||||
app_controller,
|
app_controller,
|
||||||
|
app_log_controller,
|
||||||
auth_controller,
|
auth_controller,
|
||||||
chunk_controller,
|
chunk_controller,
|
||||||
document_controller,
|
document_controller,
|
||||||
@@ -16,6 +17,7 @@ from . import (
|
|||||||
file_controller,
|
file_controller,
|
||||||
file_storage_controller,
|
file_storage_controller,
|
||||||
home_page_controller,
|
home_page_controller,
|
||||||
|
i18n_controller,
|
||||||
implicit_memory_controller,
|
implicit_memory_controller,
|
||||||
knowledge_controller,
|
knowledge_controller,
|
||||||
knowledgeshare_controller,
|
knowledgeshare_controller,
|
||||||
@@ -45,7 +47,8 @@ from . import (
|
|||||||
user_memory_controllers,
|
user_memory_controllers,
|
||||||
workspace_controller,
|
workspace_controller,
|
||||||
ontology_controller,
|
ontology_controller,
|
||||||
skill_controller
|
skill_controller,
|
||||||
|
tenant_subscription_controller,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建管理端 API 路由器
|
# 创建管理端 API 路由器
|
||||||
@@ -68,6 +71,7 @@ manager_router.include_router(chunk_controller.router)
|
|||||||
manager_router.include_router(test_controller.router)
|
manager_router.include_router(test_controller.router)
|
||||||
manager_router.include_router(knowledgeshare_controller.router)
|
manager_router.include_router(knowledgeshare_controller.router)
|
||||||
manager_router.include_router(app_controller.router)
|
manager_router.include_router(app_controller.router)
|
||||||
|
manager_router.include_router(app_log_controller.router)
|
||||||
manager_router.include_router(upload_controller.router)
|
manager_router.include_router(upload_controller.router)
|
||||||
manager_router.include_router(memory_agent_controller.router)
|
manager_router.include_router(memory_agent_controller.router)
|
||||||
manager_router.include_router(memory_dashboard_controller.router)
|
manager_router.include_router(memory_dashboard_controller.router)
|
||||||
@@ -94,5 +98,8 @@ manager_router.include_router(memory_working_controller.router)
|
|||||||
manager_router.include_router(file_storage_controller.router)
|
manager_router.include_router(file_storage_controller.router)
|
||||||
manager_router.include_router(ontology_controller.router)
|
manager_router.include_router(ontology_controller.router)
|
||||||
manager_router.include_router(skill_controller.router)
|
manager_router.include_router(skill_controller.router)
|
||||||
|
manager_router.include_router(i18n_controller.router)
|
||||||
|
manager_router.include_router(tenant_subscription_controller.router)
|
||||||
|
manager_router.include_router(tenant_subscription_controller.public_router)
|
||||||
|
|
||||||
__all__ = ["manager_router"]
|
__all__ = ["manager_router"]
|
||||||
|
|||||||
@@ -167,6 +167,8 @@ def update_api_key(
|
|||||||
|
|
||||||
return success(data=api_key_schema.ApiKey.model_validate(api_key), msg="API Key 更新成功")
|
return success(data=api_key_schema.ApiKey.model_validate(api_key), msg="API Key 更新成功")
|
||||||
|
|
||||||
|
except BusinessException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"未知错误: {str(e)}", extra={
|
logger.error(f"未知错误: {str(e)}", extra={
|
||||||
"api_key_id": str(api_key_id),
|
"api_key_id": str(api_key_id),
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from app.services.app_statistics_service import AppStatisticsService
|
|||||||
from app.services.workflow_import_service import WorkflowImportService
|
from app.services.workflow_import_service import WorkflowImportService
|
||||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||||
from app.services.app_dsl_service import AppDslService
|
from app.services.app_dsl_service import AppDslService
|
||||||
|
from app.core.quota_stub import check_app_quota
|
||||||
|
|
||||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
@@ -35,6 +36,7 @@ logger = get_business_logger()
|
|||||||
|
|
||||||
@router.post("", summary="创建应用(可选创建 Agent 配置)")
|
@router.post("", summary="创建应用(可选创建 Agent 配置)")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
|
@check_app_quota
|
||||||
def create_app(
|
def create_app(
|
||||||
payload: app_schema.AppCreate,
|
payload: app_schema.AppCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
@@ -53,6 +55,7 @@ def list_apps(
|
|||||||
status: str | None = None,
|
status: str | None = None,
|
||||||
search: str | None = None,
|
search: str | None = None,
|
||||||
include_shared: bool = True,
|
include_shared: bool = True,
|
||||||
|
shared_only: bool = False,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
pagesize: int = 10,
|
pagesize: int = 10,
|
||||||
ids: Optional[str] = None,
|
ids: Optional[str] = None,
|
||||||
@@ -64,16 +67,42 @@ def list_apps(
|
|||||||
- 默认包含本工作空间的应用和分享给本工作空间的应用
|
- 默认包含本工作空间的应用和分享给本工作空间的应用
|
||||||
- 设置 include_shared=false 可以只查看本工作空间的应用
|
- 设置 include_shared=false 可以只查看本工作空间的应用
|
||||||
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页
|
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页
|
||||||
|
- search 参数支持:应用名称模糊搜索、API Key 精确搜索
|
||||||
"""
|
"""
|
||||||
|
from sqlalchemy import select as sa_select
|
||||||
|
from app.models.api_key_model import ApiKey
|
||||||
|
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
service = app_service.AppService(db)
|
service = app_service.AppService(db)
|
||||||
|
|
||||||
# 当 ids 存在且不为 None 时,根据 ids 获取应用
|
# 通过 search 参数搜索:支持应用名称模糊搜索和 API Key 精确搜索
|
||||||
|
if search:
|
||||||
|
search = search.strip()
|
||||||
|
# 尝试作为 API Key 精确匹配(API Key 通常较长)
|
||||||
|
if len(search) >= 10:
|
||||||
|
matched_id = db.execute(
|
||||||
|
sa_select(ApiKey.resource_id).where(
|
||||||
|
ApiKey.workspace_id == workspace_id,
|
||||||
|
ApiKey.api_key == search,
|
||||||
|
ApiKey.resource_id.isnot(None),
|
||||||
|
)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
if matched_id:
|
||||||
|
# 找到 API Key,直接返回关联的应用
|
||||||
|
ids = str(matched_id)
|
||||||
|
|
||||||
|
# 当 ids 存在时,根据 ids 获取应用(不分页)
|
||||||
if ids is not None:
|
if ids is not None:
|
||||||
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
|
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
|
||||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
if app_ids:
|
||||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||||
return success(data=items)
|
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||||
|
# 返回标准分页格式
|
||||||
|
meta = PageMeta(page=1, pagesize=len(items), total=len(items), hasnext=False)
|
||||||
|
return success(data=PageData(page=meta, items=items))
|
||||||
|
# ids 为空时,返回空列表
|
||||||
|
meta = PageMeta(page=1, pagesize=0, total=0, hasnext=False)
|
||||||
|
return success(data=PageData(page=meta, items=[]))
|
||||||
|
|
||||||
# 正常分页查询
|
# 正常分页查询
|
||||||
items_orm, total = app_service.list_apps(
|
items_orm, total = app_service.list_apps(
|
||||||
@@ -84,6 +113,7 @@ def list_apps(
|
|||||||
status=status,
|
status=status,
|
||||||
search=search,
|
search=search,
|
||||||
include_shared=include_shared,
|
include_shared=include_shared,
|
||||||
|
shared_only=shared_only,
|
||||||
page=page,
|
page=page,
|
||||||
pagesize=pagesize,
|
pagesize=pagesize,
|
||||||
)
|
)
|
||||||
@@ -93,6 +123,37 @@ def list_apps(
|
|||||||
return success(data=PageData(page=meta, items=items))
|
return success(data=PageData(page=meta, items=items))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/my-shared-out", summary="列出本工作空间主动分享出去的记录")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def list_my_shared_out(
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user=Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""列出本工作空间主动分享给其他工作空间的所有记录(我的共享)"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
service = app_service.AppService(db)
|
||||||
|
shares = service.list_my_shared_out(workspace_id=workspace_id)
|
||||||
|
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
||||||
|
return success(data=data)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/share/{target_workspace_id}", summary="取消对某工作空间的所有应用分享")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def unshare_all_apps_to_workspace(
|
||||||
|
target_workspace_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user=Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""Cancel all app shares from current workspace to a target workspace."""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
service = app_service.AppService(db)
|
||||||
|
count = service.unshare_all_apps_to_workspace(
|
||||||
|
target_workspace_id=target_workspace_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
return success(msg=f"已取消 {count} 个应用的分享", data={"count": count})
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{app_id}", summary="获取应用详情")
|
@router.get("/{app_id}", summary="获取应用详情")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def get_app(
|
def get_app(
|
||||||
@@ -158,9 +219,11 @@ def delete_app(
|
|||||||
|
|
||||||
@router.post("/{app_id}/copy", summary="复制应用")
|
@router.post("/{app_id}/copy", summary="复制应用")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
|
@check_app_quota
|
||||||
def copy_app(
|
def copy_app(
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
new_name: Optional[str] = None,
|
new_name: Optional[str] = None,
|
||||||
|
payload: app_schema.CopyAppRequest = None,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user=Depends(get_current_user),
|
current_user=Depends(get_current_user),
|
||||||
):
|
):
|
||||||
@@ -172,6 +235,8 @@ def copy_app(
|
|||||||
- 不影响原应用
|
- 不影响原应用
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
# body takes precedence over query param for backward compatibility
|
||||||
|
new_name = (payload.new_name if payload else None) or new_name
|
||||||
logger.info(
|
logger.info(
|
||||||
"用户请求复制应用",
|
"用户请求复制应用",
|
||||||
extra={
|
extra={
|
||||||
@@ -207,6 +272,19 @@ def update_agent_config(
|
|||||||
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{app_id}/model/parameters/default", summary="获取 Agent 模型参数默认配置")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def get_agent_model_parameters(
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user=Depends(get_current_user),
|
||||||
|
):
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
service = AppService(db)
|
||||||
|
model_parameters = service.get_default_model_parameters(app_id=app_id)
|
||||||
|
return success(data=model_parameters, msg="获取 Agent 模型参数默认配置")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{app_id}/config", summary="获取 Agent 配置")
|
@router.get("/{app_id}/config", summary="获取 Agent 配置")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def get_agent_config(
|
def get_agent_config(
|
||||||
@@ -221,6 +299,36 @@ def get_agent_config(
|
|||||||
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{app_id}/opening", summary="获取应用开场白配置")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def get_opening(
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user=Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""返回开场白文本和预设问题,供前端对话界面初始化时展示"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
|
# 根据应用类型获取 features
|
||||||
|
from app.models.app_model import App as AppModel
|
||||||
|
app = db.get(AppModel, app_id)
|
||||||
|
if app and app.type == "workflow":
|
||||||
|
cfg = app_service.get_workflow_config(db=db, app_id=app_id, workspace_id=workspace_id)
|
||||||
|
features = cfg.features or {}
|
||||||
|
else:
|
||||||
|
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
|
||||||
|
features = cfg.features or {}
|
||||||
|
if hasattr(features, "model_dump"):
|
||||||
|
features = features.model_dump()
|
||||||
|
|
||||||
|
opening = features.get("opening_statement", {})
|
||||||
|
return success(data=app_schema.OpeningResponse(
|
||||||
|
enabled=opening.get("enabled", False),
|
||||||
|
statement=opening.get("statement"),
|
||||||
|
suggested_questions=opening.get("suggested_questions", []),
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
|
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def publish_app(
|
def publish_app(
|
||||||
@@ -302,7 +410,8 @@ def share_app(
|
|||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
target_workspace_ids=payload.target_workspace_ids,
|
target_workspace_ids=payload.target_workspace_ids,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
workspace_id=workspace_id
|
workspace_id=workspace_id,
|
||||||
|
permission=payload.permission
|
||||||
)
|
)
|
||||||
|
|
||||||
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
||||||
@@ -333,6 +442,32 @@ def unshare_app(
|
|||||||
return success(msg="应用分享已取消")
|
return success(msg="应用分享已取消")
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/{app_id}/share/{target_workspace_id}", summary="更新共享权限")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def update_share_permission(
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
target_workspace_id: uuid.UUID,
|
||||||
|
payload: app_schema.UpdateSharePermissionRequest,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user=Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""更新共享权限(readonly <-> editable)
|
||||||
|
|
||||||
|
- 只能修改自己工作空间应用的共享权限
|
||||||
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
|
service = app_service.AppService(db)
|
||||||
|
share = service.update_share_permission(
|
||||||
|
app_id=app_id,
|
||||||
|
target_workspace_id=target_workspace_id,
|
||||||
|
permission=payload.permission,
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(data=app_schema.AppShare.model_validate(share))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{app_id}/shares", summary="列出应用的分享记录")
|
@router.get("/{app_id}/shares", summary="列出应用的分享记录")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def list_app_shares(
|
def list_app_shares(
|
||||||
@@ -356,6 +491,46 @@ def list_app_shares(
|
|||||||
return success(data=data)
|
return success(data=data)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/shared/{source_workspace_id}", summary="批量移除某来源工作空间的所有共享应用")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def remove_all_shared_apps_from_workspace(
|
||||||
|
source_workspace_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user=Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""Remove all shared apps from a specific source workspace (recipient operation)."""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
service = app_service.AppService(db)
|
||||||
|
count = service.remove_all_shared_apps_from_workspace(
|
||||||
|
source_workspace_id=source_workspace_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
return success(msg=f"已移除 {count} 个共享应用", data={"count": count})
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{app_id}/shared", summary="移除共享给我的应用")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def remove_shared_app(
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user=Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""被共享者从自己的工作空间移除共享应用
|
||||||
|
|
||||||
|
- 不会删除源应用,只删除共享记录
|
||||||
|
- 只能移除共享给自己工作空间的应用
|
||||||
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
|
service = app_service.AppService(db)
|
||||||
|
service.remove_shared_app(
|
||||||
|
app_id=app_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(msg="已移除共享应用")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{app_id}/draft/run", summary="试运行 Agent(使用当前草稿配置)")
|
@router.post("/{app_id}/draft/run", summary="试运行 Agent(使用当前草稿配置)")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def draft_run(
|
async def draft_run(
|
||||||
@@ -396,7 +571,7 @@ async def draft_run(
|
|||||||
# 提前验证和准备(在流式响应开始前完成)
|
# 提前验证和准备(在流式响应开始前完成)
|
||||||
from app.services.app_service import AppService
|
from app.services.app_service import AppService
|
||||||
from app.services.multi_agent_service import MultiAgentService
|
from app.services.multi_agent_service import MultiAgentService
|
||||||
from app.models import AgentConfig, ModelConfig
|
from app.models import AgentConfig, ModelConfig, AppRelease
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.services.draft_run_service import AgentRunService
|
from app.services.draft_run_service import AgentRunService
|
||||||
@@ -413,11 +588,12 @@ async def draft_run(
|
|||||||
service._validate_app_accessible(app, workspace_id)
|
service._validate_app_accessible(app, workspace_id)
|
||||||
|
|
||||||
if payload.user_id is None:
|
if payload.user_id is None:
|
||||||
|
# 先获取 app 的 workspace_id
|
||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
|
workspace_id=app.workspace_id,
|
||||||
other_id=str(current_user.id),
|
other_id=str(current_user.id),
|
||||||
original_user_id=str(current_user.id) # Save original user_id to other_id
|
|
||||||
)
|
)
|
||||||
payload.user_id = str(new_end_user.id)
|
payload.user_id = str(new_end_user.id)
|
||||||
|
|
||||||
@@ -434,18 +610,29 @@ async def draft_run(
|
|||||||
service._check_agent_config(app_id)
|
service._check_agent_config(app_id)
|
||||||
|
|
||||||
# 2. 获取 Agent 配置
|
# 2. 获取 Agent 配置
|
||||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
# 共享应用:从最新发布版本读配置快照,而非草稿
|
||||||
agent_cfg = db.scalars(stmt).first()
|
is_shared = app.workspace_id != workspace_id
|
||||||
if not agent_cfg:
|
if is_shared:
|
||||||
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
if not app.current_release_id:
|
||||||
|
raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
release = db.get(AppRelease, app.current_release_id)
|
||||||
|
if not release:
|
||||||
|
raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
agent_cfg = service._agent_config_from_release(release)
|
||||||
|
model_config = db.get(ModelConfig, release.default_model_config_id) if release.default_model_config_id else None
|
||||||
|
else:
|
||||||
|
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
||||||
|
agent_cfg = db.scalars(stmt).first()
|
||||||
|
if not agent_cfg:
|
||||||
|
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
|
||||||
# 3. 获取模型配置
|
# 3. 获取模型配置
|
||||||
model_config = None
|
model_config = None
|
||||||
if agent_cfg.default_model_config_id:
|
if agent_cfg.default_model_config_id:
|
||||||
model_config = db.get(ModelConfig, agent_cfg.default_model_config_id)
|
model_config = db.get(ModelConfig, agent_cfg.default_model_config_id)
|
||||||
if not model_config:
|
if not model_config:
|
||||||
from app.core.exceptions import ResourceNotFoundException
|
from app.core.exceptions import ResourceNotFoundException
|
||||||
raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id))
|
raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id))
|
||||||
|
|
||||||
# 流式返回
|
# 流式返回
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
@@ -601,7 +788,17 @@ async def draft_run(
|
|||||||
msg="多 Agent 任务执行成功"
|
msg="多 Agent 任务执行成功"
|
||||||
)
|
)
|
||||||
elif app.type == AppType.WORKFLOW: # 工作流
|
elif app.type == AppType.WORKFLOW: # 工作流
|
||||||
config = workflow_service.check_config(app_id)
|
# 共享应用:从最新发布版本读配置快照,而非草稿
|
||||||
|
is_shared = app.workspace_id != workspace_id
|
||||||
|
if is_shared:
|
||||||
|
if not app.current_release_id:
|
||||||
|
raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
release = db.get(AppRelease, app.current_release_id)
|
||||||
|
if not release:
|
||||||
|
raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
config = service._workflow_config_from_release(release)
|
||||||
|
else:
|
||||||
|
config = workflow_service.check_config(app_id)
|
||||||
# 3. 流式返回
|
# 3. 流式返回
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -744,6 +941,16 @@ async def draft_run_compare(
|
|||||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||||
service._validate_app_accessible(app, workspace_id)
|
service._validate_app_accessible(app, workspace_id)
|
||||||
|
|
||||||
|
if payload.user_id is None:
|
||||||
|
# 先获取 app 的 workspace_id
|
||||||
|
end_user_repo = EndUserRepository(db)
|
||||||
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
|
app_id=app_id,
|
||||||
|
workspace_id=app.workspace_id,
|
||||||
|
other_id=str(current_user.id),
|
||||||
|
)
|
||||||
|
payload.user_id = str(new_end_user.id)
|
||||||
|
|
||||||
# 2. 获取 Agent 配置
|
# 2. 获取 Agent 配置
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from app.models import AgentConfig
|
from app.models import AgentConfig
|
||||||
@@ -789,6 +996,13 @@ async def draft_run_compare(
|
|||||||
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
|
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
|
||||||
})
|
})
|
||||||
|
|
||||||
|
# 从 features 中读取功能开关(与 draft_run 保持一致)
|
||||||
|
features_config: dict = agent_cfg.features or {}
|
||||||
|
if hasattr(features_config, 'model_dump'):
|
||||||
|
features_config = features_config.model_dump()
|
||||||
|
web_search_feature = features_config.get("web_search", {})
|
||||||
|
web_search = isinstance(web_search_feature, dict) and web_search_feature.get("enabled", False)
|
||||||
|
|
||||||
# 流式返回
|
# 流式返回
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
@@ -800,11 +1014,11 @@ async def draft_run_compare(
|
|||||||
message=payload.message,
|
message=payload.message,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
conversation_id=payload.conversation_id,
|
conversation_id=payload.conversation_id,
|
||||||
user_id=payload.user_id or str(current_user.id),
|
user_id=payload.user_id,
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
web_search=True,
|
web_search=web_search,
|
||||||
memory=True,
|
memory=True,
|
||||||
parallel=payload.parallel,
|
parallel=payload.parallel,
|
||||||
timeout=payload.timeout or 60,
|
timeout=payload.timeout or 60,
|
||||||
@@ -831,11 +1045,11 @@ async def draft_run_compare(
|
|||||||
message=payload.message,
|
message=payload.message,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
conversation_id=payload.conversation_id,
|
conversation_id=payload.conversation_id,
|
||||||
user_id=payload.user_id or str(current_user.id),
|
user_id=payload.user_id,
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
web_search=True,
|
web_search=web_search,
|
||||||
memory=True,
|
memory=True,
|
||||||
parallel=payload.parallel,
|
parallel=payload.parallel,
|
||||||
timeout=payload.timeout or 60,
|
timeout=payload.timeout or 60,
|
||||||
@@ -881,6 +1095,14 @@ async def update_workflow_config(
|
|||||||
current_user: Annotated[User, Depends(get_current_user)]
|
current_user: Annotated[User, Depends(get_current_user)]
|
||||||
):
|
):
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
if payload.variables:
|
||||||
|
from app.services.workflow_service import WorkflowService
|
||||||
|
resolved = await WorkflowService(db)._resolve_variables_file_defaults(
|
||||||
|
[v.model_dump() for v in payload.variables]
|
||||||
|
)
|
||||||
|
# Patch default values back into VariableDefinition objects
|
||||||
|
for var_def, resolved_def in zip(payload.variables, resolved):
|
||||||
|
var_def.default = resolved_def.get("default", var_def.default)
|
||||||
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||||
|
|
||||||
@@ -923,6 +1145,7 @@ async def import_workflow_config(
|
|||||||
|
|
||||||
@router.post("/workflow/import/save")
|
@router.post("/workflow/import/save")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
|
@check_app_quota
|
||||||
async def save_workflow_import(
|
async def save_workflow_import(
|
||||||
data: WorkflowImportSave,
|
data: WorkflowImportSave,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
@@ -1044,9 +1267,11 @@ async def export_app(
|
|||||||
async def import_app(
|
async def import_app(
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
|
app_id: Optional[str] = Form(None),
|
||||||
):
|
):
|
||||||
"""从 YAML 文件导入 agent / multi_agent / workflow 应用。
|
"""从 YAML 文件导入 agent / multi_agent / workflow 应用。
|
||||||
|
传入 app_id 时覆盖该应用的配置(类型必须一致),否则创建新应用。
|
||||||
跨空间/跨租户导入时,模型/工具/知识库会按名称匹配,匹配不到则置空并返回 warnings。
|
跨空间/跨租户导入时,模型/工具/知识库会按名称匹配,匹配不到则置空并返回 warnings。
|
||||||
"""
|
"""
|
||||||
if not file.filename.lower().endswith((".yaml", ".yml")):
|
if not file.filename.lower().endswith((".yaml", ".yml")):
|
||||||
@@ -1057,13 +1282,62 @@ async def import_app(
|
|||||||
if not dsl or "app" not in dsl:
|
if not dsl or "app" not in dsl:
|
||||||
return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)
|
return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)
|
||||||
|
|
||||||
new_app, warnings = AppDslService(db).import_dsl(
|
target_app_id = uuid.UUID(app_id) if app_id else None
|
||||||
|
# 仅新建应用时检查配额,覆盖已有应用时跳过
|
||||||
|
if target_app_id is None:
|
||||||
|
from app.core.quota_manager import _check_quota
|
||||||
|
_check_quota(db, current_user.tenant_id, "app_quota", "app", workspace_id=current_user.current_workspace_id)
|
||||||
|
result_app, warnings = AppDslService(db).import_dsl(
|
||||||
dsl=dsl,
|
dsl=dsl,
|
||||||
workspace_id=current_user.current_workspace_id,
|
workspace_id=current_user.current_workspace_id,
|
||||||
tenant_id=current_user.tenant_id,
|
tenant_id=current_user.tenant_id,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
app_id=target_app_id,
|
||||||
)
|
)
|
||||||
return success(
|
return success(
|
||||||
data={"app": app_schema.App.model_validate(new_app), "warnings": warnings},
|
data={"app": app_schema.App.model_validate(result_app), "warnings": warnings},
|
||||||
msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "")
|
msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/citations/{document_id}/download", summary="下载引用文档原始文件")
|
||||||
|
async def download_citation_file(
|
||||||
|
document_id: uuid.UUID = Path(..., description="引用文档ID"),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
下载引用文档的原始文件。
|
||||||
|
仅当应用功能特性 citation.allow_download=true 时,前端才会展示此下载链接。
|
||||||
|
路由本身不做权限校验,由业务层通过 allow_download 开关控制入口。
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from fastapi import HTTPException, status as http_status
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.models.document_model import Document
|
||||||
|
from app.models.file_model import File as FileModel
|
||||||
|
|
||||||
|
doc = db.query(Document).filter(Document.id == document_id).first()
|
||||||
|
if not doc:
|
||||||
|
raise HTTPException(status_code=http_status.HTTP_404_NOT_FOUND, detail="文档不存在")
|
||||||
|
|
||||||
|
file_record = db.query(FileModel).filter(FileModel.id == doc.file_id).first()
|
||||||
|
if not file_record:
|
||||||
|
raise HTTPException(status_code=http_status.HTTP_404_NOT_FOUND, detail="原始文件不存在")
|
||||||
|
|
||||||
|
file_path = os.path.join(
|
||||||
|
settings.FILE_PATH,
|
||||||
|
str(file_record.kb_id),
|
||||||
|
str(file_record.parent_id),
|
||||||
|
f"{file_record.id}{file_record.file_ext}"
|
||||||
|
)
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
raise HTTPException(status_code=http_status.HTTP_404_NOT_FOUND, detail="文件未找到")
|
||||||
|
|
||||||
|
encoded_name = quote(doc.file_name)
|
||||||
|
return FileResponse(
|
||||||
|
path=file_path,
|
||||||
|
filename=doc.file_name,
|
||||||
|
media_type="application/octet-stream",
|
||||||
|
headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_name}"}
|
||||||
|
)
|
||||||
|
|||||||
110
api/app/controllers/app_log_controller.py
Normal file
110
api/app/controllers/app_log_controller.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
"""应用日志(消息记录)接口"""
|
||||||
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||||
|
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail, AppLogMessage
|
||||||
|
from app.schemas.response_schema import PageData, PageMeta
|
||||||
|
from app.services.app_service import AppService
|
||||||
|
from app.services.app_log_service import AppLogService
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/apps", tags=["App Logs"])
|
||||||
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{app_id}/logs", summary="应用日志 - 会话列表")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def list_app_logs(
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
page: int = Query(1, ge=1),
|
||||||
|
pagesize: int = Query(20, ge=1, le=100),
|
||||||
|
is_draft: Optional[bool] = Query(None, description="是否草稿会话(不传则返回全部)"),
|
||||||
|
keyword: Optional[str] = Query(None, description="搜索关键词(匹配消息内容)"),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user=Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""查看应用下所有会话记录(分页)
|
||||||
|
|
||||||
|
- is_draft 不传则返回所有会话(草稿 + 正式)
|
||||||
|
- is_draft=True 只返回草稿会话
|
||||||
|
- is_draft=False 只返回发布会话
|
||||||
|
- 支持按 keyword 搜索(匹配消息内容)
|
||||||
|
- 按最新更新时间倒序排列
|
||||||
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
|
# 验证应用访问权限
|
||||||
|
app_service = AppService(db)
|
||||||
|
app = app_service.get_app(app_id, workspace_id)
|
||||||
|
|
||||||
|
# 使用 Service 层查询
|
||||||
|
log_service = AppLogService(db)
|
||||||
|
conversations, total = log_service.list_conversations(
|
||||||
|
app_id=app_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
page=page,
|
||||||
|
pagesize=pagesize,
|
||||||
|
is_draft=is_draft,
|
||||||
|
keyword=keyword,
|
||||||
|
app_type=app.type,
|
||||||
|
)
|
||||||
|
|
||||||
|
items = [AppLogConversation.model_validate(c) for c in conversations]
|
||||||
|
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
||||||
|
|
||||||
|
return success(data=PageData(page=meta, items=items))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{app_id}/logs/{conversation_id}", summary="应用日志 - 会话消息详情")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def get_app_log_detail(
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
conversation_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user=Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""查看某会话的完整消息记录
|
||||||
|
|
||||||
|
- 返回会话基本信息 + 所有消息(按时间正序)
|
||||||
|
- 消息 meta_data 包含模型名、token 用量等信息
|
||||||
|
- 所有人(包括共享者和被共享者)都只能查看自己的会话详情
|
||||||
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
|
# 验证应用访问权限
|
||||||
|
app_service = AppService(db)
|
||||||
|
app = app_service.get_app(app_id, workspace_id)
|
||||||
|
|
||||||
|
# 使用 Service 层查询
|
||||||
|
log_service = AppLogService(db)
|
||||||
|
conversation, messages, node_executions_map = log_service.get_conversation_detail(
|
||||||
|
app_id=app_id,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
app_type=app.type
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建基础会话信息(不经过 ORM relationship)
|
||||||
|
base = AppLogConversation.model_validate(conversation)
|
||||||
|
|
||||||
|
# 单独处理 messages,避免触发 SQLAlchemy relationship 校验
|
||||||
|
if messages and isinstance(messages[0], AppLogMessage):
|
||||||
|
# 工作流:已经是 AppLogMessage 实例
|
||||||
|
msg_list = messages
|
||||||
|
else:
|
||||||
|
# Agent:ORM Message 对象逐个转换
|
||||||
|
msg_list = [AppLogMessage.model_validate(m) for m in messages]
|
||||||
|
|
||||||
|
detail = AppLogConversationDetail(
|
||||||
|
**base.model_dump(),
|
||||||
|
messages=msg_list,
|
||||||
|
node_executions_map=node_executions_map,
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(data=detail)
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Callable
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -16,6 +17,7 @@ from app.core.exceptions import BusinessException
|
|||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.dependencies import get_current_user, oauth2_scheme
|
from app.dependencies import get_current_user, oauth2_scheme
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
|
from app.i18n.dependencies import get_translator
|
||||||
|
|
||||||
# 获取专用日志器
|
# 获取专用日志器
|
||||||
auth_logger = get_auth_logger()
|
auth_logger = get_auth_logger()
|
||||||
@@ -26,7 +28,8 @@ router = APIRouter(tags=["Authentication"])
|
|||||||
@router.post("/token", response_model=ApiResponse)
|
@router.post("/token", response_model=ApiResponse)
|
||||||
async def login_for_access_token(
|
async def login_for_access_token(
|
||||||
form_data: TokenRequest,
|
form_data: TokenRequest,
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""用户登录获取token"""
|
"""用户登录获取token"""
|
||||||
auth_logger.info(f"用户登录请求: {form_data.email}")
|
auth_logger.info(f"用户登录请求: {form_data.email}")
|
||||||
@@ -40,36 +43,38 @@ async def login_for_access_token(
|
|||||||
invite_info = workspace_service.validate_invite_token(db, form_data.invite)
|
invite_info = workspace_service.validate_invite_token(db, form_data.invite)
|
||||||
|
|
||||||
if not invite_info.is_valid:
|
if not invite_info.is_valid:
|
||||||
raise BusinessException("邀请码无效或已过期", code=BizCode.BAD_REQUEST)
|
raise BusinessException(t("auth.invite.invalid"), code=BizCode.BAD_REQUEST)
|
||||||
|
|
||||||
if invite_info.email != form_data.email:
|
if invite_info.email != form_data.email:
|
||||||
raise BusinessException("邀请邮箱与登录邮箱不匹配", code=BizCode.BAD_REQUEST)
|
raise BusinessException(t("auth.invite.email_mismatch"), code=BizCode.BAD_REQUEST)
|
||||||
auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}")
|
auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}")
|
||||||
try:
|
try:
|
||||||
# 尝试认证用户
|
# 尝试认证用户
|
||||||
user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
|
user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
|
||||||
auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
|
auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
|
||||||
if form_data.invite:
|
if form_data.invite:
|
||||||
auth_service.bind_workspace_with_invite(db=db,
|
auth_service.bind_workspace_with_invite(
|
||||||
user=user,
|
db=db,
|
||||||
invite_token=form_data.invite,
|
user=user,
|
||||||
workspace_id=invite_info.workspace_id)
|
invite_token=form_data.invite,
|
||||||
|
workspace_id=invite_info.workspace_id
|
||||||
|
)
|
||||||
except BusinessException as e:
|
except BusinessException as e:
|
||||||
# 用户不存在且有邀请码,尝试注册
|
# 用户不存在且有邀请码,尝试注册
|
||||||
if e.code == BizCode.USER_NOT_FOUND:
|
if e.code == BizCode.USER_NOT_FOUND:
|
||||||
auth_logger.info(f"用户不存在,使用邀请码注册: {form_data.email}")
|
auth_logger.info(f"用户不存在,使用邀请码注册: {form_data.email}")
|
||||||
user = auth_service.register_user_with_invite(
|
user = auth_service.register_user_with_invite(
|
||||||
db=db,
|
db=db,
|
||||||
email=form_data.email,
|
email=form_data.email,
|
||||||
username=form_data.username,
|
username=form_data.username,
|
||||||
password=form_data.password,
|
password=form_data.password,
|
||||||
invite_token=form_data.invite,
|
invite_token=form_data.invite,
|
||||||
workspace_id=invite_info.workspace_id
|
workspace_id=invite_info.workspace_id
|
||||||
)
|
)
|
||||||
elif e.code == BizCode.PASSWORD_ERROR:
|
elif e.code == BizCode.PASSWORD_ERROR:
|
||||||
# 用户存在但密码错误
|
# 用户存在但密码错误
|
||||||
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")
|
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")
|
||||||
raise BusinessException("接受邀请失败,密码验证错误", BizCode.LOGIN_FAILED)
|
raise BusinessException(t("auth.invite.password_verification_failed"), BizCode.LOGIN_FAILED)
|
||||||
else:
|
else:
|
||||||
# 其他认证失败情况,直接抛出
|
# 其他认证失败情况,直接抛出
|
||||||
raise
|
raise
|
||||||
@@ -82,7 +87,7 @@ async def login_for_access_token(
|
|||||||
except BusinessException as e:
|
except BusinessException as e:
|
||||||
|
|
||||||
# 其他认证失败情况,直接抛出
|
# 其他认证失败情况,直接抛出
|
||||||
raise BusinessException(e.message,BizCode.LOGIN_FAILED)
|
raise BusinessException(e.message, BizCode.LOGIN_FAILED)
|
||||||
|
|
||||||
# 创建 tokens
|
# 创建 tokens
|
||||||
access_token, access_token_id = security.create_access_token(subject=user.id)
|
access_token, access_token_id = security.create_access_token(subject=user.id)
|
||||||
@@ -110,14 +115,15 @@ async def login_for_access_token(
|
|||||||
expires_at=access_expires_at,
|
expires_at=access_expires_at,
|
||||||
refresh_expires_at=refresh_expires_at
|
refresh_expires_at=refresh_expires_at
|
||||||
),
|
),
|
||||||
msg="登录成功"
|
msg=t("auth.login.success")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/refresh", response_model=ApiResponse)
|
@router.post("/refresh", response_model=ApiResponse)
|
||||||
async def refresh_token(
|
async def refresh_token(
|
||||||
refresh_request: RefreshTokenRequest,
|
refresh_request: RefreshTokenRequest,
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""刷新token"""
|
"""刷新token"""
|
||||||
auth_logger.info("收到token刷新请求")
|
auth_logger.info("收到token刷新请求")
|
||||||
@@ -125,18 +131,18 @@ async def refresh_token(
|
|||||||
# 验证 refresh token
|
# 验证 refresh token
|
||||||
userId = security.verify_token(refresh_request.refresh_token, "refresh")
|
userId = security.verify_token(refresh_request.refresh_token, "refresh")
|
||||||
if not userId:
|
if not userId:
|
||||||
raise BusinessException("无效的refresh token", code=BizCode.TOKEN_INVALID)
|
raise BusinessException(t("auth.token.invalid_refresh_token"), code=BizCode.TOKEN_INVALID)
|
||||||
|
|
||||||
# 检查用户是否存在
|
# 检查用户是否存在
|
||||||
user = auth_service.get_user_by_id(db, userId)
|
user = auth_service.get_user_by_id(db, userId)
|
||||||
if not user:
|
if not user:
|
||||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NO_ACCESS)
|
||||||
|
|
||||||
# 检查 refresh token 黑名单
|
# 检查 refresh token 黑名单
|
||||||
if settings.ENABLE_SINGLE_SESSION:
|
if settings.ENABLE_SINGLE_SESSION:
|
||||||
refresh_token_id = security.get_token_id(refresh_request.refresh_token)
|
refresh_token_id = security.get_token_id(refresh_request.refresh_token)
|
||||||
if refresh_token_id and await SessionService.is_token_blacklisted(refresh_token_id):
|
if refresh_token_id and await SessionService.is_token_blacklisted(refresh_token_id):
|
||||||
raise BusinessException("Refresh token已失效", code=BizCode.TOKEN_BLACKLISTED)
|
raise BusinessException(t("auth.token.refresh_token_blacklisted"), code=BizCode.TOKEN_BLACKLISTED)
|
||||||
|
|
||||||
# 生成新 tokens
|
# 生成新 tokens
|
||||||
new_access_token, new_access_token_id = security.create_access_token(subject=user.id)
|
new_access_token, new_access_token_id = security.create_access_token(subject=user.id)
|
||||||
@@ -167,7 +173,7 @@ async def refresh_token(
|
|||||||
expires_at=access_expires_at,
|
expires_at=access_expires_at,
|
||||||
refresh_expires_at=refresh_expires_at
|
refresh_expires_at=refresh_expires_at
|
||||||
),
|
),
|
||||||
msg="token刷新成功"
|
msg=t("auth.token.refresh_success")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -175,14 +181,15 @@ async def refresh_token(
|
|||||||
async def logout(
|
async def logout(
|
||||||
token: str = Depends(oauth2_scheme),
|
token: str = Depends(oauth2_scheme),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""登出当前用户:加入token黑名单并清理会话"""
|
"""登出当前用户:加入token黑名单并清理会话"""
|
||||||
auth_logger.info(f"用户 {current_user.username} 请求登出")
|
auth_logger.info(f"用户 {current_user.username} 请求登出")
|
||||||
|
|
||||||
token_id = security.get_token_id(token)
|
token_id = security.get_token_id(token)
|
||||||
if not token_id:
|
if not token_id:
|
||||||
raise BusinessException("无效的access token", code=BizCode.TOKEN_INVALID)
|
raise BusinessException(t("auth.token.invalid"), code=BizCode.TOKEN_INVALID)
|
||||||
|
|
||||||
# 加入黑名单
|
# 加入黑名单
|
||||||
await SessionService.blacklist_token(token_id)
|
await SessionService.blacklist_token(token_id)
|
||||||
@@ -192,5 +199,5 @@ async def logout(
|
|||||||
await SessionService.clear_user_session(current_user.username)
|
await SessionService.clear_user_session(current_user.username)
|
||||||
|
|
||||||
auth_logger.info(f"用户 {current_user.username} 登出成功")
|
auth_logger.info(f"用户 {current_user.username} 登出成功")
|
||||||
return success(msg="登出成功")
|
return success(msg=t("auth.logout.success"))
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from app.models.user_model import User
|
|||||||
from app.schemas import chunk_schema
|
from app.schemas import chunk_schema
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
||||||
|
from app.services.model_service import ModelApiKeyService
|
||||||
|
|
||||||
# Obtain a dedicated API logger
|
# Obtain a dedicated API logger
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -442,10 +443,10 @@ async def retrieve_chunks(
|
|||||||
match retrieve_data.retrieve_type:
|
match retrieve_data.retrieve_type:
|
||||||
case chunk_schema.RetrieveType.PARTICIPLE:
|
case chunk_schema.RetrieveType.PARTICIPLE:
|
||||||
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
||||||
return success(data=rs, msg="retrieval successful")
|
return success(data=jsonable_encoder(rs), msg="retrieval successful")
|
||||||
case chunk_schema.RetrieveType.SEMANTIC:
|
case chunk_schema.RetrieveType.SEMANTIC:
|
||||||
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
||||||
return success(data=rs, msg="retrieval successful")
|
return success(data=jsonable_encoder(rs), msg="retrieval successful")
|
||||||
case _:
|
case _:
|
||||||
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
||||||
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
||||||
@@ -456,22 +457,24 @@ async def retrieve_chunks(
|
|||||||
if doc.metadata["doc_id"] not in seen_ids:
|
if doc.metadata["doc_id"] not in seen_ids:
|
||||||
seen_ids.add(doc.metadata["doc_id"])
|
seen_ids.add(doc.metadata["doc_id"])
|
||||||
unique_rs.append(doc)
|
unique_rs.append(doc)
|
||||||
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k)
|
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k) if unique_rs else []
|
||||||
if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph:
|
if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph:
|
||||||
kb_ids = [str(kb_id) for kb_id in private_kb_ids]
|
kb_ids = [str(kb_id) for kb_id in private_kb_ids]
|
||||||
workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids]
|
workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids]
|
||||||
|
llm_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.llm_id)
|
||||||
|
emb_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.embedding_id)
|
||||||
# Prepare to configure chat_mdl、embedding_model、vision_model information
|
# Prepare to configure chat_mdl、embedding_model、vision_model information
|
||||||
chat_model = Base(
|
chat_model = Base(
|
||||||
key=db_knowledge.llm.api_keys[0].api_key,
|
key=llm_key.api_key,
|
||||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
model_name=llm_key.model_name,
|
||||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
base_url=llm_key.api_base
|
||||||
)
|
)
|
||||||
embedding_model = OpenAIEmbed(
|
embedding_model = OpenAIEmbed(
|
||||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
key=emb_key.api_key,
|
||||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
model_name=emb_key.model_name,
|
||||||
base_url=db_knowledge.embedding.api_keys[0].api_base
|
base_url=emb_key.api_base
|
||||||
)
|
)
|
||||||
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids= kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||||
if doc:
|
if doc:
|
||||||
rs.insert(0, doc)
|
rs.insert(0, doc)
|
||||||
return success(data=jsonable_encoder(rs), msg="retrieval successful")
|
return success(data=jsonable_encoder(rs), msg="retrieval successful")
|
||||||
@@ -314,8 +314,10 @@ async def parse_documents(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 4. Check if the file exists
|
# 4. Check if the file exists
|
||||||
|
api_logger.debug(f"Constructed file path: {file_path}")
|
||||||
|
api_logger.debug(f"File metadata - kb_id: {db_file.kb_id}, parent_id: {db_file.parent_id}, file_id: {db_file.id}, extension: {db_file.file_ext}")
|
||||||
if not os.path.exists(file_path):
|
if not os.path.exists(file_path):
|
||||||
api_logger.warning(f"File not found (possibly deleted): file_path={file_path}")
|
api_logger.error(f"File not found (possibly deleted): file_path={file_path}, file_id={db_file.id}, document_id={document_id}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail="File not found (possibly deleted)"
|
detail="File not found (possibly deleted)"
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from app.models.user_model import User
|
|||||||
from app.schemas import file_schema, document_schema
|
from app.schemas import file_schema, document_schema
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import file_service, document_service
|
from app.services import file_service, document_service
|
||||||
|
from app.core.quota_stub import check_knowledge_capacity_quota
|
||||||
|
|
||||||
|
|
||||||
# Obtain a dedicated API logger
|
# Obtain a dedicated API logger
|
||||||
@@ -131,6 +132,7 @@ async def create_folder(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/file", response_model=ApiResponse)
|
@router.post("/file", response_model=ApiResponse)
|
||||||
|
@check_knowledge_capacity_quota
|
||||||
async def upload_file(
|
async def upload_file(
|
||||||
kb_id: uuid.UUID,
|
kb_id: uuid.UUID,
|
||||||
parent_id: uuid.UUID,
|
parent_id: uuid.UUID,
|
||||||
|
|||||||
@@ -14,8 +14,11 @@ Routes:
|
|||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
import httpx
|
||||||
|
import mimetypes
|
||||||
|
from urllib.parse import urlparse, unquote
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
|
||||||
from fastapi.responses import FileResponse, RedirectResponse
|
from fastapi.responses import FileResponse, RedirectResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -47,6 +50,19 @@ router = APIRouter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _match_scheme(request: Request, url: str) -> str:
|
||||||
|
"""
|
||||||
|
将 presigned URL 的协议替换为与当前请求一致的协议(http/https)。
|
||||||
|
解决反向代理场景下 presigned URL 协议与请求协议不匹配的问题。
|
||||||
|
"""
|
||||||
|
incoming_scheme = request.headers.get("x-forwarded-proto") or request.url.scheme
|
||||||
|
if url.startswith("http://") and incoming_scheme == "https":
|
||||||
|
return "https://" + url[7:]
|
||||||
|
if url.startswith("https://") and incoming_scheme == "http":
|
||||||
|
return "http://" + url[8:]
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
@router.post("/files", response_model=ApiResponse)
|
@router.post("/files", response_model=ApiResponse)
|
||||||
async def upload_file(
|
async def upload_file(
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
@@ -78,7 +94,7 @@ async def upload_file(
|
|||||||
|
|
||||||
if file_size > settings.MAX_FILE_SIZE:
|
if file_size > settings.MAX_FILE_SIZE:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_413_CONTENT_TOO_LARGE,
|
||||||
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
|
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -159,7 +175,6 @@ async def upload_file_with_share_token(
|
|||||||
|
|
||||||
# Get share and release info from share_token
|
# Get share and release info from share_token
|
||||||
service = ReleaseShareService(db)
|
service = ReleaseShareService(db)
|
||||||
share_info = service.get_shared_release_info(share_token=share_data.share_token)
|
|
||||||
|
|
||||||
# Get share object to access app_id
|
# Get share object to access app_id
|
||||||
share = service.repo.get_by_share_token(share_data.share_token)
|
share = service.repo.get_by_share_token(share_data.share_token)
|
||||||
@@ -278,8 +293,104 @@ async def upload_file_with_share_token(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/files/info-by-url", response_model=ApiResponse)
|
||||||
|
async def get_file_info_by_url(
|
||||||
|
url: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get file information by network URL (no authentication required).
|
||||||
|
|
||||||
|
Fetches file metadata from a remote URL via HTTP HEAD request.
|
||||||
|
Falls back to GET request if HEAD is not supported.
|
||||||
|
Returns file type, name, and size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The network URL of the file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse with file information.
|
||||||
|
"""
|
||||||
|
api_logger.info(f"File info by URL request: url={url}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
# Try HEAD request first
|
||||||
|
response = await client.head(url, follow_redirects=True)
|
||||||
|
|
||||||
|
# If HEAD fails, try GET request (some servers don't support HEAD)
|
||||||
|
if response.status_code != 200:
|
||||||
|
api_logger.info(f"HEAD request failed with {response.status_code}, trying GET request")
|
||||||
|
response = await client.get(url, follow_redirects=True)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
api_logger.error(f"Failed to fetch file info: HTTP {response.status_code}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Unable to access file: HTTP {response.status_code}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get file size from Content-Length header or actual content
|
||||||
|
file_size = response.headers.get("Content-Length")
|
||||||
|
if file_size:
|
||||||
|
file_size = int(file_size)
|
||||||
|
elif hasattr(response, 'content'):
|
||||||
|
file_size = len(response.content)
|
||||||
|
else:
|
||||||
|
file_size = None
|
||||||
|
|
||||||
|
# Get content type from Content-Type header
|
||||||
|
content_type = response.headers.get("Content-Type", "application/octet-stream")
|
||||||
|
# Remove charset and other parameters from content type
|
||||||
|
content_type = content_type.split(';')[0].strip()
|
||||||
|
|
||||||
|
# Extract filename from Content-Disposition or URL
|
||||||
|
file_name = None
|
||||||
|
content_disposition = response.headers.get("Content-Disposition")
|
||||||
|
if content_disposition and "filename=" in content_disposition:
|
||||||
|
parts = content_disposition.split("filename=")
|
||||||
|
if len(parts) > 1:
|
||||||
|
file_name = parts[1].strip('"').strip("'")
|
||||||
|
|
||||||
|
if not file_name:
|
||||||
|
parsed_url = urlparse(url)
|
||||||
|
file_name = unquote(os.path.basename(parsed_url.path)) or "unknown"
|
||||||
|
|
||||||
|
# Extract file extension from filename
|
||||||
|
_, file_ext = os.path.splitext(file_name)
|
||||||
|
|
||||||
|
# If no extension found, infer from content type
|
||||||
|
if not file_ext:
|
||||||
|
ext = mimetypes.guess_extension(content_type)
|
||||||
|
if ext:
|
||||||
|
file_ext = ext
|
||||||
|
file_name = f"{file_name}{file_ext}"
|
||||||
|
|
||||||
|
api_logger.info(f"File info retrieved: name={file_name}, size={file_size}, type={content_type}")
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data={
|
||||||
|
"url": url,
|
||||||
|
"file_name": file_name,
|
||||||
|
"file_ext": file_ext.lower() if file_ext else "",
|
||||||
|
"file_size": file_size,
|
||||||
|
"content_type": content_type,
|
||||||
|
},
|
||||||
|
msg="File information retrieved successfully"
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"Unexpected error: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to retrieve file information: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/files/{file_id}", response_model=Any)
|
@router.get("/files/{file_id}", response_model=Any)
|
||||||
async def download_file(
|
async def download_file(
|
||||||
|
request: Request,
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
@@ -327,6 +438,7 @@ async def download_file(
|
|||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
||||||
|
presigned_url = _match_scheme(request, presigned_url)
|
||||||
api_logger.info(f"Redirecting to presigned URL: file_key={file_key}")
|
api_logger.info(f"Redirecting to presigned URL: file_key={file_key}")
|
||||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
@@ -400,6 +512,7 @@ async def delete_file(
|
|||||||
|
|
||||||
@router.get("/files/{file_id}/url", response_model=ApiResponse)
|
@router.get("/files/{file_id}/url", response_model=ApiResponse)
|
||||||
async def get_file_url(
|
async def get_file_url(
|
||||||
|
request: Request,
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
expires: int = None,
|
expires: int = None,
|
||||||
permanent: bool = False,
|
permanent: bool = False,
|
||||||
@@ -461,8 +574,13 @@ async def get_file_url(
|
|||||||
# For local storage, generate signed URL with expiration
|
# For local storage, generate signed URL with expiration
|
||||||
url = generate_signed_url(str(file_id), expires)
|
url = generate_signed_url(str(file_id), expires)
|
||||||
else:
|
else:
|
||||||
# For remote storage (OSS/S3), get presigned URL
|
# For remote storage (OSS/S3), get presigned URL with forced download
|
||||||
url = await storage_service.get_file_url(file_key, expires=expires)
|
url = await storage_service.get_file_url(
|
||||||
|
file_key,
|
||||||
|
expires=expires,
|
||||||
|
file_name=file_metadata.file_name,
|
||||||
|
)
|
||||||
|
url = _match_scheme(request, url)
|
||||||
|
|
||||||
api_logger.info(f"Generated file URL: file_id={file_id}")
|
api_logger.info(f"Generated file URL: file_id={file_id}")
|
||||||
return success(
|
return success(
|
||||||
@@ -482,8 +600,54 @@ async def get_file_url(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/files/{file_id}/public-url", response_model=ApiResponse)
|
||||||
|
async def get_permanent_file_url(
|
||||||
|
file_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取文件的永久公开 URL(无过期时间)。
|
||||||
|
|
||||||
|
- 本地存储:返回 API 永久访问地址(基于 FILE_LOCAL_SERVER_URL 配置)
|
||||||
|
- 远程存储(OSS/S3):返回 bucket 公读地址(需 bucket 已配置公共读权限)
|
||||||
|
"""
|
||||||
|
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||||
|
if not file_metadata:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The file does not exist")
|
||||||
|
|
||||||
|
if file_metadata.status != "completed":
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"File upload not completed, status: {file_metadata.status}")
|
||||||
|
|
||||||
|
file_key = file_metadata.file_key
|
||||||
|
storage = storage_service.storage
|
||||||
|
|
||||||
|
try:
|
||||||
|
if isinstance(storage, LocalStorage):
|
||||||
|
url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{file_id}"
|
||||||
|
else:
|
||||||
|
url = await storage.get_permanent_url(file_key)
|
||||||
|
if not url:
|
||||||
|
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||||
|
detail="Permanent URL not supported for current storage backend")
|
||||||
|
|
||||||
|
api_logger.info(f"Generated permanent URL: file_id={file_id}")
|
||||||
|
return success(
|
||||||
|
data={"url": url, "expires_in": None, "permanent": True, "file_name": file_metadata.file_name},
|
||||||
|
msg="Permanent file URL generated successfully"
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"Failed to generate permanent URL: {e}")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to generate permanent URL: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/public/{file_id}", response_model=Any)
|
@router.get("/public/{file_id}", response_model=Any)
|
||||||
async def public_download_file(
|
async def public_download_file(
|
||||||
|
request: Request,
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
expires: int = 0,
|
expires: int = 0,
|
||||||
signature: str = "",
|
signature: str = "",
|
||||||
@@ -555,6 +719,7 @@ async def public_download_file(
|
|||||||
# For remote storage, redirect to presigned URL
|
# For remote storage, redirect to presigned URL
|
||||||
try:
|
try:
|
||||||
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
||||||
|
presigned_url = _match_scheme(request, presigned_url)
|
||||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||||
@@ -566,6 +731,7 @@ async def public_download_file(
|
|||||||
|
|
||||||
@router.get("/permanent/{file_id}", response_model=Any)
|
@router.get("/permanent/{file_id}", response_model=Any)
|
||||||
async def permanent_download_file(
|
async def permanent_download_file(
|
||||||
|
request: Request,
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||||
@@ -624,7 +790,8 @@ async def permanent_download_file(
|
|||||||
# For remote storage, redirect to presigned URL with long expiration
|
# For remote storage, redirect to presigned URL with long expiration
|
||||||
try:
|
try:
|
||||||
# Use a very long expiration (7 days max for most cloud providers)
|
# Use a very long expiration (7 days max for most cloud providers)
|
||||||
presigned_url = await storage_service.get_file_url(file_key, expires=604800)
|
presigned_url = await storage_service.get_file_url(file_key, expires=604800, file_name=file_metadata.file_name)
|
||||||
|
presigned_url = _match_scheme(request, presigned_url)
|
||||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||||
@@ -632,3 +799,44 @@ async def permanent_download_file(
|
|||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=f"Failed to retrieve file: {str(e)}"
|
detail=f"Failed to retrieve file: {str(e)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/files/{file_id}/status", response_model=ApiResponse)
|
||||||
|
async def get_file_status(
|
||||||
|
file_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get file upload/processing status (no authentication required).
|
||||||
|
|
||||||
|
This endpoint is used to check if a file (e.g., TTS audio) is ready.
|
||||||
|
Returns status: pending, completed, or failed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: The UUID of the file.
|
||||||
|
db: Database session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse with file status and metadata.
|
||||||
|
"""
|
||||||
|
api_logger.info(f"File status request: file_id={file_id}")
|
||||||
|
|
||||||
|
# Query file metadata from database
|
||||||
|
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||||
|
if not file_metadata:
|
||||||
|
api_logger.warning(f"File not found in database: file_id={file_id}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="The file does not exist"
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data={
|
||||||
|
"file_id": str(file_id),
|
||||||
|
"status": file_metadata.status,
|
||||||
|
"file_name": file_metadata.file_name,
|
||||||
|
"file_size": file_metadata.file_size,
|
||||||
|
"content_type": file_metadata.content_type,
|
||||||
|
},
|
||||||
|
msg="File status retrieved successfully"
|
||||||
|
)
|
||||||
|
|||||||
@@ -3,9 +3,10 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.db import get_db
|
from app.db import get_db, SessionLocal
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
|
from app.repositories.home_page_repository import HomePageRepository
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services.home_page_service import HomePageService
|
from app.services.home_page_service import HomePageService
|
||||||
|
|
||||||
@@ -31,9 +32,32 @@ def get_workspace_list(
|
|||||||
|
|
||||||
@router.get("/version", response_model=ApiResponse)
|
@router.get("/version", response_model=ApiResponse)
|
||||||
def get_system_version():
|
def get_system_version():
|
||||||
"""获取系统版本号+说明"""
|
"""获取系统版本号 + 说明"""
|
||||||
current_version = settings.SYSTEM_VERSION
|
current_version = None
|
||||||
version_info = HomePageService.load_version_introduction(current_version)
|
version_info = None
|
||||||
|
|
||||||
|
# 1️⃣ 优先从数据库获取最新已发布的版本
|
||||||
|
try:
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
current_version, version_info = HomePageRepository.get_latest_version_introduction(db)
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 2️⃣ 降级:使用环境变量中的版本号
|
||||||
|
if not current_version:
|
||||||
|
current_version = settings.SYSTEM_VERSION
|
||||||
|
version_info = HomePageService.load_version_introduction(current_version)
|
||||||
|
|
||||||
|
# 3️⃣ 如果数据库和 JSON 都没有,返回基本信息
|
||||||
|
if not version_info:
|
||||||
|
version_info = {
|
||||||
|
"introduction": {"codeName": "", "releaseDate": "", "upgradePosition": "", "coreUpgrades": []},
|
||||||
|
"introduction_en": {"codeName": "", "releaseDate": "", "upgradePosition": "", "coreUpgrades": []}
|
||||||
|
}
|
||||||
|
|
||||||
return success(
|
return success(
|
||||||
data={
|
data={
|
||||||
"version": current_version,
|
"version": current_version,
|
||||||
|
|||||||
833
api/app/controllers/i18n_controller.py
Normal file
833
api/app/controllers/i18n_controller.py
Normal file
@@ -0,0 +1,833 @@
|
|||||||
|
"""
|
||||||
|
I18n Management API Controller
|
||||||
|
|
||||||
|
This module provides management APIs for:
|
||||||
|
- Language management (list, get, add, update languages)
|
||||||
|
- Translation management (get, update, reload translations)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user, get_current_superuser
|
||||||
|
from app.i18n.dependencies import get_translator
|
||||||
|
from app.i18n.service import get_translation_service
|
||||||
|
from app.models.user_model import User
|
||||||
|
from app.schemas.i18n_schema import (
|
||||||
|
LanguageInfo,
|
||||||
|
LanguageListResponse,
|
||||||
|
LanguageCreateRequest,
|
||||||
|
LanguageUpdateRequest,
|
||||||
|
TranslationResponse,
|
||||||
|
TranslationUpdateRequest,
|
||||||
|
MissingTranslationsResponse,
|
||||||
|
ReloadResponse
|
||||||
|
)
|
||||||
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
|
||||||
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/i18n",
|
||||||
|
tags=["I18n Management"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Language Management APIs
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
@router.get("/languages", response_model=ApiResponse)
|
||||||
|
def get_languages(
|
||||||
|
t: Callable = Depends(get_translator),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get list of all supported languages.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of language information including code, name, and status
|
||||||
|
"""
|
||||||
|
api_logger.info(f"Get languages request from user: {current_user.username}")
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
translation_service = get_translation_service()
|
||||||
|
|
||||||
|
# Get available locales from translation service
|
||||||
|
available_locales = translation_service.get_available_locales()
|
||||||
|
|
||||||
|
# Build language info list
|
||||||
|
languages = []
|
||||||
|
for locale in available_locales:
|
||||||
|
is_default = locale == settings.I18N_DEFAULT_LANGUAGE
|
||||||
|
is_enabled = locale in settings.I18N_SUPPORTED_LANGUAGES
|
||||||
|
|
||||||
|
# Get native names
|
||||||
|
native_names = {
|
||||||
|
"zh": "中文(简体)",
|
||||||
|
"en": "English",
|
||||||
|
"ja": "日本語",
|
||||||
|
"ko": "한국어",
|
||||||
|
"fr": "Français",
|
||||||
|
"de": "Deutsch",
|
||||||
|
"es": "Español"
|
||||||
|
}
|
||||||
|
|
||||||
|
language_info = LanguageInfo(
|
||||||
|
code=locale,
|
||||||
|
name=f"{locale.upper()}",
|
||||||
|
native_name=native_names.get(locale, locale),
|
||||||
|
is_enabled=is_enabled,
|
||||||
|
is_default=is_default
|
||||||
|
)
|
||||||
|
languages.append(language_info)
|
||||||
|
|
||||||
|
response = LanguageListResponse(languages=languages)
|
||||||
|
|
||||||
|
api_logger.info(f"Returning {len(languages)} languages")
|
||||||
|
return success(data=response.dict(), msg=t("common.success.retrieved"))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/languages/{locale}", response_model=ApiResponse)
|
||||||
|
def get_language(
|
||||||
|
locale: str,
|
||||||
|
t: Callable = Depends(get_translator),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get information about a specific language.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
locale: Language code (e.g., 'zh', 'en')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Language information
|
||||||
|
"""
|
||||||
|
api_logger.info(f"Get language info request: locale={locale}, user={current_user.username}")
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
translation_service = get_translation_service()
|
||||||
|
|
||||||
|
# Check if locale exists
|
||||||
|
available_locales = translation_service.get_available_locales()
|
||||||
|
if locale not in available_locales:
|
||||||
|
api_logger.warning(f"Language not found: {locale}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=t("i18n.language.not_found", locale=locale)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build language info
|
||||||
|
is_default = locale == settings.I18N_DEFAULT_LANGUAGE
|
||||||
|
is_enabled = locale in settings.I18N_SUPPORTED_LANGUAGES
|
||||||
|
|
||||||
|
native_names = {
|
||||||
|
"zh": "中文(简体)",
|
||||||
|
"en": "English",
|
||||||
|
"ja": "日本語",
|
||||||
|
"ko": "한국어",
|
||||||
|
"fr": "Français",
|
||||||
|
"de": "Deutsch",
|
||||||
|
"es": "Español"
|
||||||
|
}
|
||||||
|
|
||||||
|
language_info = LanguageInfo(
|
||||||
|
code=locale,
|
||||||
|
name=f"{locale.upper()}",
|
||||||
|
native_name=native_names.get(locale, locale),
|
||||||
|
is_enabled=is_enabled,
|
||||||
|
is_default=is_default
|
||||||
|
)
|
||||||
|
|
||||||
|
api_logger.info(f"Returning language info for: {locale}")
|
||||||
|
return success(data=language_info.dict(), msg=t("common.success.retrieved"))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/languages", response_model=ApiResponse)
|
||||||
|
def add_language(
|
||||||
|
request: LanguageCreateRequest,
|
||||||
|
t: Callable = Depends(get_translator),
|
||||||
|
current_user: User = Depends(get_current_superuser)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Add a new language (admin only).
|
||||||
|
|
||||||
|
Note: This endpoint validates the request but actual language addition
|
||||||
|
requires creating translation files in the locales directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Language creation request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Success message
|
||||||
|
"""
|
||||||
|
api_logger.info(
|
||||||
|
f"Add language request: code={request.code}, admin={current_user.username}"
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
translation_service = get_translation_service()
|
||||||
|
|
||||||
|
# Check if language already exists
|
||||||
|
available_locales = translation_service.get_available_locales()
|
||||||
|
if request.code in available_locales:
|
||||||
|
api_logger.warning(f"Language already exists: {request.code}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=t("i18n.language.already_exists", locale=request.code)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Note: Actual language addition requires creating translation files
|
||||||
|
# This endpoint serves as a validation and documentation point
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"Language addition validated: {request.code}. "
|
||||||
|
"Translation files need to be created manually."
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(
|
||||||
|
msg=t(
|
||||||
|
"i18n.language.add_instructions",
|
||||||
|
locale=request.code,
|
||||||
|
dir=settings.I18N_CORE_LOCALES_DIR
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/languages/{locale}", response_model=ApiResponse)
|
||||||
|
def update_language(
|
||||||
|
locale: str,
|
||||||
|
request: LanguageUpdateRequest,
|
||||||
|
t: Callable = Depends(get_translator),
|
||||||
|
current_user: User = Depends(get_current_superuser)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update language configuration (admin only).
|
||||||
|
|
||||||
|
Note: This endpoint validates the request but actual configuration
|
||||||
|
changes require updating environment variables or config files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
locale: Language code
|
||||||
|
request: Language update request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Success message
|
||||||
|
"""
|
||||||
|
api_logger.info(
|
||||||
|
f"Update language request: locale={locale}, admin={current_user.username}"
|
||||||
|
)
|
||||||
|
|
||||||
|
translation_service = get_translation_service()
|
||||||
|
|
||||||
|
# Check if language exists
|
||||||
|
available_locales = translation_service.get_available_locales()
|
||||||
|
if locale not in available_locales:
|
||||||
|
api_logger.warning(f"Language not found: {locale}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=t("i18n.language.not_found", locale=locale)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Note: Actual configuration changes require updating settings
|
||||||
|
# This endpoint serves as a validation and documentation point
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"Language update validated: {locale}. "
|
||||||
|
"Configuration changes require environment variable updates."
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(msg=t("i18n.language.update_instructions", locale=locale))
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Translation Management APIs
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
@router.get("/translations", response_model=ApiResponse)
|
||||||
|
def get_all_translations(
|
||||||
|
locale: Optional[str] = None,
|
||||||
|
t: Callable = Depends(get_translator),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get all translations for all or specific locale.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
locale: Optional locale filter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
All translations organized by locale and namespace
|
||||||
|
"""
|
||||||
|
api_logger.info(
|
||||||
|
f"Get all translations request: locale={locale}, user={current_user.username}"
|
||||||
|
)
|
||||||
|
|
||||||
|
translation_service = get_translation_service()
|
||||||
|
|
||||||
|
if locale:
|
||||||
|
# Get translations for specific locale
|
||||||
|
available_locales = translation_service.get_available_locales()
|
||||||
|
if locale not in available_locales:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=t("i18n.language.not_found", locale=locale)
|
||||||
|
)
|
||||||
|
|
||||||
|
translations = {
|
||||||
|
locale: translation_service._cache.get(locale, {})
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Get all translations
|
||||||
|
translations = translation_service._cache
|
||||||
|
|
||||||
|
response = TranslationResponse(translations=translations)
|
||||||
|
|
||||||
|
api_logger.info(f"Returning translations for: {locale or 'all locales'}")
|
||||||
|
return success(data=response.dict(), msg=t("common.success.retrieved"))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/translations/{locale}", response_model=ApiResponse)
|
||||||
|
def get_locale_translations(
|
||||||
|
locale: str,
|
||||||
|
t: Callable = Depends(get_translator),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get all translations for a specific locale.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
locale: Language code
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
All translations for the locale organized by namespace
|
||||||
|
"""
|
||||||
|
api_logger.info(
|
||||||
|
f"Get locale translations request: locale={locale}, user={current_user.username}"
|
||||||
|
)
|
||||||
|
|
||||||
|
translation_service = get_translation_service()
|
||||||
|
|
||||||
|
# Check if locale exists
|
||||||
|
available_locales = translation_service.get_available_locales()
|
||||||
|
if locale not in available_locales:
|
||||||
|
api_logger.warning(f"Language not found: {locale}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=t("i18n.language.not_found", locale=locale)
|
||||||
|
)
|
||||||
|
|
||||||
|
translations = translation_service._cache.get(locale, {})
|
||||||
|
|
||||||
|
api_logger.info(f"Returning {len(translations)} namespaces for locale: {locale}")
|
||||||
|
return success(data={"locale": locale, "translations": translations}, msg=t("common.success.retrieved"))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/translations/{locale}/{namespace}", response_model=ApiResponse)
|
||||||
|
def get_namespace_translations(
|
||||||
|
locale: str,
|
||||||
|
namespace: str,
|
||||||
|
t: Callable = Depends(get_translator),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get translations for a specific namespace in a locale.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
locale: Language code
|
||||||
|
namespace: Translation namespace (e.g., 'common', 'auth')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Translations for the specified namespace
|
||||||
|
"""
|
||||||
|
api_logger.info(
|
||||||
|
f"Get namespace translations request: locale={locale}, "
|
||||||
|
f"namespace={namespace}, user={current_user.username}"
|
||||||
|
)
|
||||||
|
|
||||||
|
translation_service = get_translation_service()
|
||||||
|
|
||||||
|
# Check if locale exists
|
||||||
|
available_locales = translation_service.get_available_locales()
|
||||||
|
if locale not in available_locales:
|
||||||
|
api_logger.warning(f"Language not found: {locale}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=t("i18n.language.not_found", locale=locale)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get namespace translations
|
||||||
|
locale_translations = translation_service._cache.get(locale, {})
|
||||||
|
namespace_translations = locale_translations.get(namespace, {})
|
||||||
|
|
||||||
|
if not namespace_translations:
|
||||||
|
api_logger.warning(f"Namespace not found: {namespace} in locale: {locale}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=t("i18n.namespace.not_found", namespace=namespace, locale=locale)
|
||||||
|
)
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"Returning translations for namespace: {namespace} in locale: {locale}"
|
||||||
|
)
|
||||||
|
return success(
|
||||||
|
data={
|
||||||
|
"locale": locale,
|
||||||
|
"namespace": namespace,
|
||||||
|
"translations": namespace_translations
|
||||||
|
},
|
||||||
|
msg=t("common.success.retrieved")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/translations/{locale}/{key:path}", response_model=ApiResponse)
|
||||||
|
def update_translation(
|
||||||
|
locale: str,
|
||||||
|
key: str,
|
||||||
|
request: TranslationUpdateRequest,
|
||||||
|
t: Callable = Depends(get_translator),
|
||||||
|
current_user: User = Depends(get_current_superuser)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update a single translation (admin only).
|
||||||
|
|
||||||
|
Note: This endpoint validates the request but actual translation updates
|
||||||
|
require modifying translation files in the locales directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
locale: Language code
|
||||||
|
key: Translation key (format: "namespace.key.subkey")
|
||||||
|
request: Translation update request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Success message
|
||||||
|
"""
|
||||||
|
api_logger.info(
|
||||||
|
f"Update translation request: locale={locale}, key={key}, "
|
||||||
|
f"admin={current_user.username}"
|
||||||
|
)
|
||||||
|
|
||||||
|
translation_service = get_translation_service()
|
||||||
|
|
||||||
|
# Check if locale exists
|
||||||
|
available_locales = translation_service.get_available_locales()
|
||||||
|
if locale not in available_locales:
|
||||||
|
api_logger.warning(f"Language not found: {locale}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=t("i18n.language.not_found", locale=locale)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate key format
|
||||||
|
if "." not in key:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=t("i18n.translation.invalid_key_format", key=key)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Note: Actual translation updates require modifying JSON files
|
||||||
|
# This endpoint serves as a validation and documentation point
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"Translation update validated: {locale}/{key}. "
|
||||||
|
"Translation files need to be updated manually."
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(
|
||||||
|
msg=t("i18n.translation.update_instructions", locale=locale, key=key)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/translations/missing", response_model=ApiResponse)
|
||||||
|
def get_missing_translations(
|
||||||
|
locale: Optional[str] = None,
|
||||||
|
t: Callable = Depends(get_translator),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get list of missing translations.
|
||||||
|
|
||||||
|
Compares translations across locales to find missing keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
locale: Optional locale to check (defaults to checking all non-default locales)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of missing translation keys
|
||||||
|
"""
|
||||||
|
api_logger.info(
|
||||||
|
f"Get missing translations request: locale={locale}, user={current_user.username}"
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
translation_service = get_translation_service()
|
||||||
|
|
||||||
|
default_locale = settings.I18N_DEFAULT_LANGUAGE
|
||||||
|
available_locales = translation_service.get_available_locales()
|
||||||
|
|
||||||
|
# Get default locale translations as reference
|
||||||
|
default_translations = translation_service._cache.get(default_locale, {})
|
||||||
|
|
||||||
|
# Collect all keys from default locale
|
||||||
|
def collect_keys(data, prefix=""):
|
||||||
|
keys = []
|
||||||
|
for key, value in data.items():
|
||||||
|
full_key = f"{prefix}.{key}" if prefix else key
|
||||||
|
if isinstance(value, dict):
|
||||||
|
keys.extend(collect_keys(value, full_key))
|
||||||
|
else:
|
||||||
|
keys.append(full_key)
|
||||||
|
return keys
|
||||||
|
|
||||||
|
default_keys = set()
|
||||||
|
for namespace, translations in default_translations.items():
|
||||||
|
namespace_keys = collect_keys(translations, namespace)
|
||||||
|
default_keys.update(namespace_keys)
|
||||||
|
|
||||||
|
# Find missing keys in target locale(s)
|
||||||
|
missing_by_locale = {}
|
||||||
|
|
||||||
|
target_locales = [locale] if locale else [
|
||||||
|
loc for loc in available_locales if loc != default_locale
|
||||||
|
]
|
||||||
|
|
||||||
|
for target_locale in target_locales:
|
||||||
|
if target_locale not in available_locales:
|
||||||
|
continue
|
||||||
|
|
||||||
|
target_translations = translation_service._cache.get(target_locale, {})
|
||||||
|
target_keys = set()
|
||||||
|
|
||||||
|
for namespace, translations in target_translations.items():
|
||||||
|
namespace_keys = collect_keys(translations, namespace)
|
||||||
|
target_keys.update(namespace_keys)
|
||||||
|
|
||||||
|
missing_keys = default_keys - target_keys
|
||||||
|
if missing_keys:
|
||||||
|
missing_by_locale[target_locale] = sorted(list(missing_keys))
|
||||||
|
|
||||||
|
response = MissingTranslationsResponse(missing_translations=missing_by_locale)
|
||||||
|
|
||||||
|
total_missing = sum(len(keys) for keys in missing_by_locale.values())
|
||||||
|
api_logger.info(f"Found {total_missing} missing translations across {len(missing_by_locale)} locales")
|
||||||
|
|
||||||
|
return success(data=response.dict(), msg=t("common.success.retrieved"))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/reload", response_model=ApiResponse)
|
||||||
|
def reload_translations(
|
||||||
|
locale: Optional[str] = None,
|
||||||
|
t: Callable = Depends(get_translator),
|
||||||
|
current_user: User = Depends(get_current_superuser)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Trigger hot reload of translation files (admin only).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
locale: Optional locale to reload (defaults to reloading all locales)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Reload status and statistics
|
||||||
|
"""
|
||||||
|
api_logger.info(
|
||||||
|
f"Reload translations request: locale={locale or 'all'}, "
|
||||||
|
f"admin={current_user.username}"
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
if not settings.I18N_ENABLE_HOT_RELOAD:
|
||||||
|
api_logger.warning("Hot reload is disabled in configuration")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=t("i18n.reload.disabled")
|
||||||
|
)
|
||||||
|
|
||||||
|
translation_service = get_translation_service()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Reload translations
|
||||||
|
translation_service.reload(locale)
|
||||||
|
|
||||||
|
# Get statistics
|
||||||
|
available_locales = translation_service.get_available_locales()
|
||||||
|
reloaded_locales = [locale] if locale else available_locales
|
||||||
|
|
||||||
|
response = ReloadResponse(
|
||||||
|
success=True,
|
||||||
|
reloaded_locales=reloaded_locales,
|
||||||
|
total_locales=len(available_locales)
|
||||||
|
)
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"Successfully reloaded translations for: {', '.join(reloaded_locales)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(data=response.dict(), msg=t("i18n.reload.success"))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"Failed to reload translations: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=t("i18n.reload.failed", error=str(e))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Performance Monitoring APIs
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
@router.get("/metrics", response_model=ApiResponse)
|
||||||
|
def get_metrics(
|
||||||
|
t: Callable = Depends(get_translator),
|
||||||
|
current_user: User = Depends(get_current_superuser)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get i18n performance metrics (admin only).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Performance metrics including:
|
||||||
|
- Request counts
|
||||||
|
- Missing translations
|
||||||
|
- Timing statistics
|
||||||
|
- Locale usage
|
||||||
|
- Error counts
|
||||||
|
"""
|
||||||
|
api_logger.info(f"Get metrics request: admin={current_user.username}")
|
||||||
|
|
||||||
|
translation_service = get_translation_service()
|
||||||
|
metrics = translation_service.get_metrics_summary()
|
||||||
|
|
||||||
|
api_logger.info("Returning i18n metrics")
|
||||||
|
return success(data=metrics, msg=t("common.success.retrieved"))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/metrics/cache", response_model=ApiResponse)
|
||||||
|
def get_cache_stats(
|
||||||
|
t: Callable = Depends(get_translator),
|
||||||
|
current_user: User = Depends(get_current_superuser)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get cache statistics (admin only).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cache statistics including:
|
||||||
|
- Hit/miss rates
|
||||||
|
- LRU cache performance
|
||||||
|
- Loaded locales
|
||||||
|
- Memory usage
|
||||||
|
"""
|
||||||
|
api_logger.info(f"Get cache stats request: admin={current_user.username}")
|
||||||
|
|
||||||
|
translation_service = get_translation_service()
|
||||||
|
cache_stats = translation_service.get_cache_stats()
|
||||||
|
memory_usage = translation_service.get_memory_usage()
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"cache": cache_stats,
|
||||||
|
"memory": memory_usage
|
||||||
|
}
|
||||||
|
|
||||||
|
api_logger.info("Returning cache statistics")
|
||||||
|
return success(data=data, msg=t("common.success.retrieved"))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/metrics/prometheus")
|
||||||
|
def get_prometheus_metrics(
|
||||||
|
current_user: User = Depends(get_current_superuser)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get metrics in Prometheus format (admin only).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Prometheus-formatted metrics as plain text
|
||||||
|
"""
|
||||||
|
api_logger.info(f"Get Prometheus metrics request: admin={current_user.username}")
|
||||||
|
|
||||||
|
from app.i18n.metrics import get_metrics
|
||||||
|
metrics = get_metrics()
|
||||||
|
prometheus_output = metrics.export_prometheus()
|
||||||
|
|
||||||
|
from fastapi.responses import PlainTextResponse
|
||||||
|
return PlainTextResponse(content=prometheus_output)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/metrics/reset", response_model=ApiResponse)
|
||||||
|
def reset_metrics(
|
||||||
|
t: Callable = Depends(get_translator),
|
||||||
|
current_user: User = Depends(get_current_superuser)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Reset all metrics (admin only).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Success message
|
||||||
|
"""
|
||||||
|
api_logger.info(f"Reset metrics request: admin={current_user.username}")
|
||||||
|
|
||||||
|
from app.i18n.metrics import get_metrics
|
||||||
|
metrics = get_metrics()
|
||||||
|
metrics.reset()
|
||||||
|
|
||||||
|
translation_service = get_translation_service()
|
||||||
|
translation_service.cache.reset_stats()
|
||||||
|
|
||||||
|
api_logger.info("Metrics reset completed")
|
||||||
|
return success(msg=t("i18n.metrics.reset_success"))
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Missing Translation Logging and Reporting APIs
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
@router.get("/logs/missing", response_model=ApiResponse)
|
||||||
|
def get_missing_translation_logs(
|
||||||
|
locale: Optional[str] = None,
|
||||||
|
limit: Optional[int] = 100,
|
||||||
|
t: Callable = Depends(get_translator),
|
||||||
|
current_user: User = Depends(get_current_superuser)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get missing translation logs (admin only).
|
||||||
|
|
||||||
|
Returns logged missing translations with context information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
locale: Optional locale filter
|
||||||
|
limit: Maximum number of entries to return (default: 100)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Missing translation logs with context
|
||||||
|
"""
|
||||||
|
api_logger.info(
|
||||||
|
f"Get missing translation logs request: locale={locale}, "
|
||||||
|
f"limit={limit}, admin={current_user.username}"
|
||||||
|
)
|
||||||
|
|
||||||
|
translation_service = get_translation_service()
|
||||||
|
translation_logger = translation_service.translation_logger
|
||||||
|
|
||||||
|
# Get missing translations
|
||||||
|
missing_translations = translation_logger.get_missing_translations(locale)
|
||||||
|
|
||||||
|
# Get missing with context
|
||||||
|
missing_with_context = translation_logger.get_missing_with_context(locale, limit)
|
||||||
|
|
||||||
|
# Get statistics
|
||||||
|
statistics = translation_logger.get_statistics()
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"missing_translations": missing_translations,
|
||||||
|
"recent_context": missing_with_context,
|
||||||
|
"statistics": statistics
|
||||||
|
}
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"Returning {statistics['total_missing']} missing translations"
|
||||||
|
)
|
||||||
|
return success(data=data, msg=t("common.success.retrieved"))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/logs/missing/report", response_model=ApiResponse)
|
||||||
|
def generate_missing_translation_report(
|
||||||
|
locale: Optional[str] = None,
|
||||||
|
t: Callable = Depends(get_translator),
|
||||||
|
current_user: User = Depends(get_current_superuser)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate a comprehensive missing translation report (admin only).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
locale: Optional locale filter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Comprehensive report with missing translations and statistics
|
||||||
|
"""
|
||||||
|
api_logger.info(
|
||||||
|
f"Generate missing translation report request: locale={locale}, "
|
||||||
|
f"admin={current_user.username}"
|
||||||
|
)
|
||||||
|
|
||||||
|
translation_service = get_translation_service()
|
||||||
|
translation_logger = translation_service.translation_logger
|
||||||
|
|
||||||
|
# Generate report
|
||||||
|
report = translation_logger.generate_report(locale)
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"Generated report with {report['total_missing']} missing translations"
|
||||||
|
)
|
||||||
|
return success(data=report, msg=t("common.success.retrieved"))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/logs/missing/export", response_model=ApiResponse)
|
||||||
|
def export_missing_translations(
|
||||||
|
locale: Optional[str] = None,
|
||||||
|
t: Callable = Depends(get_translator),
|
||||||
|
current_user: User = Depends(get_current_superuser)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Export missing translations to JSON file (admin only).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
locale: Optional locale filter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Export status and file path
|
||||||
|
"""
|
||||||
|
api_logger.info(
|
||||||
|
f"Export missing translations request: locale={locale}, "
|
||||||
|
f"admin={current_user.username}"
|
||||||
|
)
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
translation_service = get_translation_service()
|
||||||
|
translation_logger = translation_service.translation_logger
|
||||||
|
|
||||||
|
# Generate filename with timestamp
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
locale_suffix = f"_{locale}" if locale else "_all"
|
||||||
|
output_file = f"logs/i18n/missing_translations{locale_suffix}_{timestamp}.json"
|
||||||
|
|
||||||
|
# Export to file
|
||||||
|
translation_logger.export_to_json(output_file)
|
||||||
|
|
||||||
|
api_logger.info(f"Missing translations exported to: {output_file}")
|
||||||
|
return success(
|
||||||
|
data={"file_path": output_file},
|
||||||
|
msg=t("i18n.logs.export_success", file=output_file)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/logs/missing", response_model=ApiResponse)
|
||||||
|
def clear_missing_translation_logs(
|
||||||
|
locale: Optional[str] = None,
|
||||||
|
t: Callable = Depends(get_translator),
|
||||||
|
current_user: User = Depends(get_current_superuser)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Clear missing translation logs (admin only).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
locale: Optional locale to clear (clears all if not specified)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Success message
|
||||||
|
"""
|
||||||
|
api_logger.info(
|
||||||
|
f"Clear missing translation logs request: locale={locale or 'all'}, "
|
||||||
|
f"admin={current_user.username}"
|
||||||
|
)
|
||||||
|
|
||||||
|
translation_service = get_translation_service()
|
||||||
|
translation_logger = translation_service.translation_logger
|
||||||
|
|
||||||
|
# Clear logs
|
||||||
|
translation_logger.clear(locale)
|
||||||
|
|
||||||
|
api_logger.info(f"Cleared missing translation logs for: {locale or 'all locales'}")
|
||||||
|
return success(msg=t("i18n.logs.clear_success"))
|
||||||
@@ -27,6 +27,7 @@ from app.schemas import knowledge_schema
|
|||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import knowledge_service, document_service
|
from app.services import knowledge_service, document_service
|
||||||
from app.services.model_service import ModelConfigService
|
from app.services.model_service import ModelConfigService
|
||||||
|
from app.core.quota_stub import check_knowledge_capacity_quota
|
||||||
|
|
||||||
# Obtain a dedicated API logger
|
# Obtain a dedicated API logger
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -179,6 +180,7 @@ async def get_knowledges(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/knowledge", response_model=ApiResponse)
|
@router.post("/knowledge", response_model=ApiResponse)
|
||||||
|
@check_knowledge_capacity_quota
|
||||||
async def create_knowledge(
|
async def create_knowledge(
|
||||||
create_data: knowledge_schema.KnowledgeCreate,
|
create_data: knowledge_schema.KnowledgeCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
@@ -352,6 +354,7 @@ async def delete_knowledge(
|
|||||||
# 2. Soft-delete knowledge base
|
# 2. Soft-delete knowledge base
|
||||||
api_logger.debug(f"Perform a soft delete: {db_knowledge.name} (ID: {knowledge_id})")
|
api_logger.debug(f"Perform a soft delete: {db_knowledge.name} (ID: {knowledge_id})")
|
||||||
db_knowledge.status = 2
|
db_knowledge.status = 2
|
||||||
|
db_knowledge.updated_at = datetime.datetime.now()
|
||||||
db.commit()
|
db.commit()
|
||||||
api_logger.info(f"The knowledge base has been successfully deleted: {db_knowledge.name} (ID: {knowledge_id})")
|
api_logger.info(f"The knowledge base has been successfully deleted: {db_knowledge.name} (ID: {knowledge_id})")
|
||||||
return success(msg="The knowledge base has been successfully deleted")
|
return success(msg="The knowledge base has been successfully deleted")
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from app.models import mcp_market_config_model
|
|||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.schemas import mcp_market_config_schema
|
from app.schemas import mcp_market_config_schema
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import mcp_market_config_service
|
from app.services import mcp_market_config_service, mcp_market_service
|
||||||
|
|
||||||
# Obtain a dedicated API logger
|
# Obtain a dedicated API logger
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -91,9 +91,11 @@ async def get_mcp_servers(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
cookies = api.get_cookies(token)
|
cookies = api.get_cookies(token)
|
||||||
|
headers=api.builder_headers(api.headers)
|
||||||
|
headers['Authorization'] = f'Bearer {token}'
|
||||||
r = api.session.put(
|
r = api.session.put(
|
||||||
url=api.mcp_base_url,
|
url=api.mcp_base_url,
|
||||||
headers=api.builder_headers(api.headers),
|
headers=headers,
|
||||||
json=body,
|
json=body,
|
||||||
cookies=cookies)
|
cookies=cookies)
|
||||||
raise_for_http_status(r)
|
raise_for_http_status(r)
|
||||||
@@ -123,6 +125,17 @@ async def get_mcp_servers(
|
|||||||
"has_next": True if page * pagesize < total else False
|
"has_next": True if page * pagesize < total else False
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
# 5. Update mck_market.mcp_count
|
||||||
|
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=db_mcp_market_config.mcp_market_id, current_user=current_user)
|
||||||
|
if not db_mcp_market:
|
||||||
|
api_logger.warning(f"The mcp market does not exist or access is denied: mcp_market_id={db_mcp_market_config.mcp_market_id}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="The mcp market does not exist or access is denied"
|
||||||
|
)
|
||||||
|
db_mcp_market.mcp_count = total
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_mcp_market)
|
||||||
return success(data=result, msg="Query of mcp servers list successful")
|
return success(data=result, msg="Query of mcp servers list successful")
|
||||||
|
|
||||||
|
|
||||||
@@ -162,6 +175,7 @@ async def get_operational_mcp_servers(
|
|||||||
|
|
||||||
url = f'{api.mcp_base_url}/operational'
|
url = f'{api.mcp_base_url}/operational'
|
||||||
headers = api.builder_headers(api.headers)
|
headers = api.builder_headers(api.headers)
|
||||||
|
headers['Authorization'] = f'Bearer {token}'
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cookies = api.get_cookies(access_token=token, cookies_required=True)
|
cookies = api.get_cookies(access_token=token, cookies_required=True)
|
||||||
@@ -249,7 +263,9 @@ async def create_mcp_market_config(
|
|||||||
api.login(create_data.token)
|
api.login(create_data.token)
|
||||||
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
||||||
cookies = api.get_cookies(create_data.token)
|
cookies = api.get_cookies(create_data.token)
|
||||||
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
|
headers = api.builder_headers(api.headers)
|
||||||
|
headers['Authorization'] = f'Bearer {create_data.token}'
|
||||||
|
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
|
||||||
raise_for_http_status(r)
|
raise_for_http_status(r)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
||||||
@@ -265,6 +281,32 @@ async def create_mcp_market_config(
|
|||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail=f"The mcp market id already exists: {create_data.mcp_market_id}"
|
detail=f"The mcp market id already exists: {create_data.mcp_market_id}"
|
||||||
)
|
)
|
||||||
|
# 2. verify token
|
||||||
|
create_data.status = 1
|
||||||
|
try:
|
||||||
|
api = MCPApi()
|
||||||
|
token = create_data.token
|
||||||
|
api.login(token)
|
||||||
|
|
||||||
|
body = {
|
||||||
|
'filter': {},
|
||||||
|
'page_number': 1,
|
||||||
|
'page_size': 20,
|
||||||
|
'search': ""
|
||||||
|
}
|
||||||
|
cookies = api.get_cookies(token)
|
||||||
|
headers = api.builder_headers(api.headers)
|
||||||
|
headers['Authorization'] = f'Bearer {token}'
|
||||||
|
r = api.session.put(
|
||||||
|
url=api.mcp_base_url,
|
||||||
|
headers=headers,
|
||||||
|
json=body,
|
||||||
|
cookies=cookies)
|
||||||
|
raise_for_http_status(r)
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
api_logger.error(f"Failed to get MCP servers: {str(e)}")
|
||||||
|
create_data.status = 0
|
||||||
|
# 3. create mcp_market_config
|
||||||
db_mcp_market_config = mcp_market_config_service.create_mcp_market_config(db=db, mcp_market_config=create_data, current_user=current_user)
|
db_mcp_market_config = mcp_market_config_service.create_mcp_market_config(db=db, mcp_market_config=create_data, current_user=current_user)
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"The mcp market config has been successfully created: (ID: {db_mcp_market_config.id})")
|
f"The mcp market config has been successfully created: (ID: {db_mcp_market_config.id})")
|
||||||
@@ -358,7 +400,9 @@ async def update_mcp_market_config(
|
|||||||
api.login(update_data.token)
|
api.login(update_data.token)
|
||||||
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
||||||
cookies = api.get_cookies(update_data.token)
|
cookies = api.get_cookies(update_data.token)
|
||||||
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
|
headers = api.builder_headers(api.headers)
|
||||||
|
headers['Authorization'] = f'Bearer {update_data.token}'
|
||||||
|
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
|
||||||
raise_for_http_status(r)
|
raise_for_http_status(r)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
||||||
@@ -395,7 +439,7 @@ async def update_mcp_market_config(
|
|||||||
detail=f"The mcp market config update failed: {str(e)}"
|
detail=f"The mcp market config update failed: {str(e)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. Return the updated mcp market config
|
# 5. Return the updated mcp market config
|
||||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||||
msg="The mcp market config information updated successfully")
|
msg="The mcp market config information updated successfully")
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ from app.core.language_utils import get_language_from_header
|
|||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.memory.agent.utils.redis_tool import store
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
from app.core.memory.agent.utils.session_tools import SessionService
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
|
from app.core.memory.enums import SearchStrategy, Neo4jNodeType
|
||||||
|
from app.core.memory.memory_service import MemoryService
|
||||||
from app.core.rag.llm.cv_model import QWenCV
|
from app.core.rag.llm.cv_model import QWenCV
|
||||||
from app.core.response_utils import fail, success
|
from app.core.response_utils import fail, success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
@@ -19,10 +21,11 @@ from app.dependencies import cur_workspace_access_guard, get_current_user
|
|||||||
from app.models import ModelApiKey
|
from app.models import ModelApiKey
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.repositories import knowledge_repository
|
from app.repositories import knowledge_repository
|
||||||
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
from app.schemas.memory_agent_schema import StorageType, UserInput, Write_UserInput, WriteMemoryRequest
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import task_service, workspace_service
|
from app.services import task_service, workspace_service
|
||||||
from app.services.memory_agent_service import MemoryAgentService
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
|
from app.services.memory_agent_service import get_end_user_connected_config as get_config
|
||||||
from app.services.model_service import ModelConfigService
|
from app.services.model_service import ModelConfigService
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
@@ -118,142 +121,142 @@ async def download_log(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/writer_service", response_model=ApiResponse)
|
# @router.post("/writer_service", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
# @cur_workspace_access_guard()
|
||||||
async def write_server(
|
# async def write_server(
|
||||||
user_input: Write_UserInput,
|
# user_input: Write_UserInput,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
# language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
db: Session = Depends(get_db),
|
# db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
# current_user: User = Depends(get_current_user)
|
||||||
):
|
# ):
|
||||||
"""
|
# """
|
||||||
Write service endpoint - processes write operations synchronously
|
# Write service endpoint - processes write operations synchronously
|
||||||
|
#
|
||||||
Args:
|
# Args:
|
||||||
user_input: Write request containing message and end_user_id
|
# user_input: Write request containing message and end_user_id
|
||||||
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||||
|
#
|
||||||
Returns:
|
# Returns:
|
||||||
Response with write operation status
|
# Response with write operation status
|
||||||
"""
|
# """
|
||||||
# 使用集中化的语言校验
|
# # 使用集中化的语言校验
|
||||||
language = get_language_from_header(language_type)
|
# language = get_language_from_header(language_type)
|
||||||
|
#
|
||||||
config_id = user_input.config_id
|
# config_id = user_input.config_id
|
||||||
workspace_id = current_user.current_workspace_id
|
# workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
# api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||||
|
#
|
||||||
# 获取 storage_type,如果为 None 则使用默认值
|
# # 获取 storage_type,如果为 None 则使用默认值
|
||||||
storage_type = workspace_service.get_workspace_storage_type(
|
# storage_type = workspace_service.get_workspace_storage_type(
|
||||||
db=db,
|
# db=db,
|
||||||
workspace_id=workspace_id,
|
# workspace_id=workspace_id,
|
||||||
user=current_user
|
# user=current_user
|
||||||
)
|
# )
|
||||||
if storage_type is None: storage_type = 'neo4j'
|
# if storage_type is None: storage_type = 'neo4j'
|
||||||
user_rag_memory_id = ''
|
# user_rag_memory_id = ''
|
||||||
|
#
|
||||||
# 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
# # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
||||||
if storage_type == 'rag':
|
# if storage_type == 'rag':
|
||||||
if workspace_id:
|
# if workspace_id:
|
||||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
# knowledge = knowledge_repository.get_knowledge_by_name(
|
||||||
db=db,
|
# db=db,
|
||||||
name="USER_RAG_MERORY",
|
# name="USER_RAG_MERORY",
|
||||||
workspace_id=workspace_id
|
# workspace_id=workspace_id
|
||||||
)
|
# )
|
||||||
if knowledge:
|
# if knowledge:
|
||||||
user_rag_memory_id = str(knowledge.id)
|
# user_rag_memory_id = str(knowledge.id)
|
||||||
else:
|
# else:
|
||||||
api_logger.warning(
|
# api_logger.warning(
|
||||||
f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
# f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||||
storage_type = 'neo4j'
|
# storage_type = 'neo4j'
|
||||||
else:
|
# else:
|
||||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
# api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||||
storage_type = 'neo4j'
|
# storage_type = 'neo4j'
|
||||||
|
#
|
||||||
api_logger.info(
|
# api_logger.info(
|
||||||
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
# f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||||
try:
|
# try:
|
||||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
# messages_list = memory_agent_service.get_messages_list(user_input)
|
||||||
result = await memory_agent_service.write_memory(
|
# result = await memory_agent_service.write_memory(
|
||||||
user_input.end_user_id,
|
# user_input.end_user_id,
|
||||||
messages_list,
|
# messages_list,
|
||||||
config_id,
|
# config_id,
|
||||||
db,
|
# db,
|
||||||
storage_type,
|
# storage_type,
|
||||||
user_rag_memory_id,
|
# user_rag_memory_id,
|
||||||
language
|
# language
|
||||||
)
|
# )
|
||||||
|
#
|
||||||
return success(data=result, msg="写入成功")
|
# return success(data=result, msg="写入成功")
|
||||||
except BaseException as e:
|
# except BaseException as e:
|
||||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
# # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||||
if hasattr(e, 'exceptions'):
|
# if hasattr(e, 'exceptions'):
|
||||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
# error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||||
detailed_error = "; ".join(error_messages)
|
# detailed_error = "; ".join(error_messages)
|
||||||
api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
# api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
||||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
# return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
||||||
api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
# api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
||||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||||
|
#
|
||||||
|
#
|
||||||
@router.post("/writer_service_async", response_model=ApiResponse)
|
# @router.post("/writer_service_async", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
# @cur_workspace_access_guard()
|
||||||
async def write_server_async(
|
# async def write_server_async(
|
||||||
user_input: Write_UserInput,
|
# user_input: Write_UserInput,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
# language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
db: Session = Depends(get_db),
|
# db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
# current_user: User = Depends(get_current_user)
|
||||||
):
|
# ):
|
||||||
"""
|
# """
|
||||||
Async write service endpoint - enqueues write processing to Celery
|
# Async write service endpoint - enqueues write processing to Celery
|
||||||
|
#
|
||||||
Args:
|
# Args:
|
||||||
user_input: Write request containing message and end_user_id
|
# user_input: Write request containing message and end_user_id
|
||||||
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||||
|
#
|
||||||
Returns:
|
# Returns:
|
||||||
Task ID for tracking async operation
|
# Task ID for tracking async operation
|
||||||
Use GET /memory/write_result/{task_id} to check task status and get result
|
# Use GET /memory/write_result/{task_id} to check task status and get result
|
||||||
"""
|
# """
|
||||||
# 使用集中化的语言校验
|
# # 使用集中化的语言校验
|
||||||
language = get_language_from_header(language_type)
|
# language = get_language_from_header(language_type)
|
||||||
|
#
|
||||||
config_id = user_input.config_id
|
# config_id = user_input.config_id
|
||||||
workspace_id = current_user.current_workspace_id
|
# workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(
|
# api_logger.info(
|
||||||
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
# f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||||
|
#
|
||||||
# 获取 storage_type,如果为 None 则使用默认值
|
# # 获取 storage_type,如果为 None 则使用默认值
|
||||||
storage_type = workspace_service.get_workspace_storage_type(
|
# storage_type = workspace_service.get_workspace_storage_type(
|
||||||
db=db,
|
# db=db,
|
||||||
workspace_id=workspace_id,
|
# workspace_id=workspace_id,
|
||||||
user=current_user
|
# user=current_user
|
||||||
)
|
# )
|
||||||
if storage_type is None: storage_type = 'neo4j'
|
# if storage_type is None: storage_type = 'neo4j'
|
||||||
user_rag_memory_id = ''
|
# user_rag_memory_id = ''
|
||||||
if workspace_id:
|
# if workspace_id:
|
||||||
|
#
|
||||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
# knowledge = knowledge_repository.get_knowledge_by_name(
|
||||||
db=db,
|
# db=db,
|
||||||
name="USER_RAG_MERORY",
|
# name="USER_RAG_MERORY",
|
||||||
workspace_id=workspace_id
|
# workspace_id=workspace_id
|
||||||
)
|
# )
|
||||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
# if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||||
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
# api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||||
try:
|
# try:
|
||||||
# 获取标准化的消息列表
|
# # 获取标准化的消息列表
|
||||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
# messages_list = memory_agent_service.get_messages_list(user_input)
|
||||||
|
#
|
||||||
task = celery_app.send_task(
|
# task = celery_app.send_task(
|
||||||
"app.core.memory.agent.write_message",
|
# "app.core.memory.agent.write_message",
|
||||||
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
# args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
||||||
)
|
# )
|
||||||
api_logger.info(f"Write task queued: {task.id}")
|
# api_logger.info(f"Write task queued: {task.id}")
|
||||||
|
#
|
||||||
return success(data={"task_id": task.id}, msg="写入任务已提交")
|
# return success(data={"task_id": task.id}, msg="写入任务已提交")
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
api_logger.error(f"Async write operation failed: {str(e)}")
|
# api_logger.error(f"Async write operation failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/read_service", response_model=ApiResponse)
|
@router.post("/read_service", response_model=ApiResponse)
|
||||||
@@ -300,33 +303,90 @@ async def read_server(
|
|||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||||
try:
|
try:
|
||||||
result = await memory_agent_service.read_memory(
|
# result = await memory_agent_service.read_memory(
|
||||||
user_input.end_user_id,
|
# user_input.end_user_id,
|
||||||
user_input.message,
|
# user_input.message,
|
||||||
user_input.history,
|
# user_input.history,
|
||||||
user_input.search_switch,
|
# user_input.search_switch,
|
||||||
config_id,
|
# config_id,
|
||||||
|
# db,
|
||||||
|
# storage_type,
|
||||||
|
# user_rag_memory_id
|
||||||
|
# )
|
||||||
|
# if str(user_input.search_switch) == "2":
|
||||||
|
# retrieve_info = result['answer']
|
||||||
|
# history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
|
||||||
|
# user_input.end_user_id)
|
||||||
|
# query = user_input.message
|
||||||
|
#
|
||||||
|
# # 调用 memory_agent_service 的方法生成最终答案
|
||||||
|
# result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
|
||||||
|
# end_user_id=user_input.end_user_id,
|
||||||
|
# retrieve_info=retrieve_info,
|
||||||
|
# history=history,
|
||||||
|
# query=query,
|
||||||
|
# config_id=config_id,
|
||||||
|
# db=db
|
||||||
|
# )
|
||||||
|
# if "信息不足,无法回答" in result['answer']:
|
||||||
|
# result['answer'] = retrieve_info
|
||||||
|
memory_config = get_config(user_input.end_user_id, db)
|
||||||
|
service = MemoryService(
|
||||||
db,
|
db,
|
||||||
storage_type,
|
memory_config["memory_config_id"],
|
||||||
user_rag_memory_id
|
end_user_id=user_input.end_user_id
|
||||||
)
|
)
|
||||||
if str(user_input.search_switch) == "2":
|
search_result = await service.read(
|
||||||
retrieve_info = result['answer']
|
user_input.message,
|
||||||
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
|
SearchStrategy(user_input.search_switch)
|
||||||
user_input.end_user_id)
|
)
|
||||||
query = user_input.message
|
intermediate_outputs = []
|
||||||
|
sub_queries = set()
|
||||||
|
for memory in search_result.memories:
|
||||||
|
sub_queries.add(str(memory.query))
|
||||||
|
if user_input.search_switch in [SearchStrategy.DEEP, SearchStrategy.NORMAL]:
|
||||||
|
intermediate_outputs.append({
|
||||||
|
"type": "problem_split",
|
||||||
|
"title": "问题拆分",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"id": f"Q{idx+1}",
|
||||||
|
"question": question
|
||||||
|
}
|
||||||
|
for idx, question in enumerate(sub_queries)
|
||||||
|
]
|
||||||
|
})
|
||||||
|
perceptual_data = [
|
||||||
|
memory.data
|
||||||
|
for memory in search_result.memories
|
||||||
|
if memory.source == Neo4jNodeType.PERCEPTUAL
|
||||||
|
]
|
||||||
|
|
||||||
# 调用 memory_agent_service 的方法生成最终答案
|
intermediate_outputs.append({
|
||||||
result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
|
"type": "perceptual_retrieve",
|
||||||
|
"title": "感知记忆检索",
|
||||||
|
"data": perceptual_data,
|
||||||
|
"total": len(perceptual_data),
|
||||||
|
})
|
||||||
|
intermediate_outputs.append({
|
||||||
|
"type": "search_result",
|
||||||
|
"title": f"合并检索结果 (共{len(sub_queries)}个查询,{len(search_result.memories)}条结果)",
|
||||||
|
"result": search_result.content,
|
||||||
|
"raw_result": search_result.memories,
|
||||||
|
"total": len(search_result.memories),
|
||||||
|
})
|
||||||
|
result = {
|
||||||
|
'answer': await memory_agent_service.generate_summary_from_retrieve(
|
||||||
end_user_id=user_input.end_user_id,
|
end_user_id=user_input.end_user_id,
|
||||||
retrieve_info=retrieve_info,
|
retrieve_info=search_result.content,
|
||||||
history=history,
|
history=[],
|
||||||
query=query,
|
query=user_input.message,
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
db=db
|
db=db
|
||||||
)
|
),
|
||||||
if "信息不足,无法回答" in result['answer']:
|
"intermediate_outputs": intermediate_outputs
|
||||||
result['answer'] = retrieve_info
|
}
|
||||||
|
|
||||||
return success(data=result, msg="回复对话消息成功")
|
return success(data=result, msg="回复对话消息成功")
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||||
@@ -801,11 +861,8 @@ async def get_end_user_connected_config(
|
|||||||
Returns:
|
Returns:
|
||||||
包含 memory_config_id 和相关信息的响应
|
包含 memory_config_id 和相关信息的响应
|
||||||
"""
|
"""
|
||||||
from app.services.memory_agent_service import (
|
|
||||||
get_end_user_connected_config as get_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(f"Getting connected config for end_user: {end_user_id}")
|
api_logger.info(f"Getting connected config for end_user_id: {end_user_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = get_config(end_user_id, db)
|
result = get_config(end_user_id, db)
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -47,61 +49,61 @@ def get_workspace_total_end_users(
|
|||||||
|
|
||||||
@router.get("/end_users", response_model=ApiResponse)
|
@router.get("/end_users", response_model=ApiResponse)
|
||||||
async def get_workspace_end_users(
|
async def get_workspace_end_users(
|
||||||
|
workspace_id: Optional[uuid.UUID] = Query(None, description="工作空间ID(可选,默认当前用户工作空间)"),
|
||||||
|
keyword: Optional[str] = Query(None, description="搜索关键词(同时模糊匹配 other_name 和 id)"),
|
||||||
|
page: int = Query(1, ge=1, description="页码,从1开始"),
|
||||||
|
pagesize: int = Query(10, ge=1, description="每页数量"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取工作空间的宿主列表(高性能优化版本 v2)
|
获取工作空间的宿主列表(分页查询,支持模糊搜索)
|
||||||
|
|
||||||
优化策略:
|
返回工作空间下的宿主列表,支持分页查询和模糊搜索。
|
||||||
1. 批量查询 end_users(一次查询而非循环)
|
通过 keyword 参数同时模糊匹配 other_name 和 id 字段。
|
||||||
2. 并发查询所有用户的记忆数量(Neo4j)
|
|
||||||
3. RAG 模式使用批量查询(一次 SQL)
|
|
||||||
4. 只返回必要字段减少数据传输
|
|
||||||
5. 添加短期缓存减少重复查询
|
|
||||||
6. 并发执行配置查询和记忆数量查询
|
|
||||||
|
|
||||||
返回格式:
|
Args:
|
||||||
{
|
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
||||||
"end_user": {"id": "uuid", "other_name": "名称"},
|
keyword: 搜索关键词(可选,同时模糊匹配 other_name 和 id)
|
||||||
"memory_num": {"total": 数量},
|
page: 页码(从1开始,默认1)
|
||||||
"memory_config": {"memory_config_id": "id", "memory_config_name": "名称"}
|
pagesize: 每页数量(默认10)
|
||||||
}
|
db: 数据库会话
|
||||||
|
current_user: 当前用户
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse: 包含宿主列表和分页信息
|
||||||
"""
|
"""
|
||||||
import asyncio
|
# 如果未提供 workspace_id,使用当前用户的工作空间
|
||||||
import json
|
if workspace_id is None:
|
||||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
|
|
||||||
# 尝试从缓存获取(30秒缓存)
|
|
||||||
cache_key = f"end_users:workspace:{workspace_id}"
|
|
||||||
try:
|
|
||||||
cached_data = await aio_redis_get(cache_key)
|
|
||||||
if cached_data:
|
|
||||||
api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}")
|
|
||||||
return success(data=json.loads(cached_data), msg="宿主列表获取成功")
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.warning(f"Redis 缓存读取失败: {str(e)}")
|
|
||||||
|
|
||||||
# 获取当前空间类型
|
# 获取当前空间类型
|
||||||
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
|
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
|
||||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
|
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表, 类型: {current_workspace_type}")
|
||||||
|
|
||||||
# 获取 end_users(已优化为批量查询)
|
# 获取分页的 end_users
|
||||||
end_users = memory_dashboard_service.get_workspace_end_users(
|
end_users_result = memory_dashboard_service.get_workspace_end_users_paginated(
|
||||||
db=db,
|
db=db,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
current_user=current_user
|
current_user=current_user,
|
||||||
|
page=page,
|
||||||
|
pagesize=pagesize,
|
||||||
|
keyword=keyword
|
||||||
)
|
)
|
||||||
|
|
||||||
|
end_users = end_users_result.get("items", [])
|
||||||
|
total = end_users_result.get("total", 0)
|
||||||
|
|
||||||
if not end_users:
|
if not end_users:
|
||||||
api_logger.info("工作空间下没有宿主")
|
api_logger.info(f"工作空间下没有宿主或当前页无数据: total={total}, page={page}")
|
||||||
# 缓存空结果,避免重复查询
|
return success(data={
|
||||||
try:
|
"items": [],
|
||||||
await aio_redis_set(cache_key, json.dumps([]), expire=30)
|
"page": {
|
||||||
except Exception as e:
|
"page": page,
|
||||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
"pagesize": pagesize,
|
||||||
return success(data=[], msg="宿主列表获取成功")
|
"total": total,
|
||||||
|
"hasnext": (page * pagesize) < total
|
||||||
|
}
|
||||||
|
}, msg="宿主列表获取成功")
|
||||||
|
|
||||||
end_user_ids = [str(user.id) for user in end_users]
|
end_user_ids = [str(user.id) for user in end_users]
|
||||||
|
|
||||||
@@ -132,21 +134,13 @@ async def get_workspace_end_users(
|
|||||||
return {uid: {"total": 0} for uid in end_user_ids}
|
return {uid: {"total": 0} for uid in end_user_ids}
|
||||||
|
|
||||||
elif current_workspace_type == "neo4j":
|
elif current_workspace_type == "neo4j":
|
||||||
# Neo4j 模式:并发查询(带并发限制)
|
# Neo4j 模式:批量查询(简化版本,只返回total)
|
||||||
# 使用信号量限制并发数,避免大量用户时压垮 Neo4j
|
try:
|
||||||
MAX_CONCURRENT_QUERIES = 10
|
batch_result = await memory_storage_service.search_all_batch(end_user_ids)
|
||||||
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES)
|
return {uid: {"total": count} for uid, count in batch_result.items()}
|
||||||
|
except Exception as e:
|
||||||
async def get_neo4j_memory_num(end_user_id: str):
|
api_logger.error(f"批量获取 Neo4j 记忆数量失败: {str(e)}")
|
||||||
async with semaphore:
|
return {uid: {"total": 0} for uid in end_user_ids}
|
||||||
try:
|
|
||||||
return await memory_storage_service.search_all(end_user_id)
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}")
|
|
||||||
return {"total": 0}
|
|
||||||
|
|
||||||
memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
|
|
||||||
return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
|
|
||||||
|
|
||||||
return {uid: {"total": 0} for uid in end_user_ids}
|
return {uid: {"total": 0} for uid in end_user_ids}
|
||||||
|
|
||||||
@@ -171,12 +165,12 @@ async def get_workspace_end_users(
|
|||||||
get_memory_nums()
|
get_memory_nums()
|
||||||
)
|
)
|
||||||
|
|
||||||
# 构建结果(优化:使用列表推导式)
|
# 构建结果列表
|
||||||
result = []
|
items = []
|
||||||
for end_user in end_users:
|
for end_user in end_users:
|
||||||
user_id = str(end_user.id)
|
user_id = str(end_user.id)
|
||||||
config_info = memory_configs_map.get(user_id, {})
|
config_info = memory_configs_map.get(user_id, {})
|
||||||
result.append({
|
items.append({
|
||||||
'end_user': {
|
'end_user': {
|
||||||
'id': user_id,
|
'id': user_id,
|
||||||
'other_name': end_user.other_name
|
'other_name': end_user.other_name
|
||||||
@@ -188,13 +182,26 @@ async def get_workspace_end_users(
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
# 写入缓存(30秒过期)
|
# 触发社区聚类补全任务(异步,不阻塞接口响应)
|
||||||
try:
|
try:
|
||||||
await aio_redis_set(cache_key, json.dumps(result), expire=30)
|
from app.tasks import init_community_clustering_for_users
|
||||||
|
init_community_clustering_for_users.delay(end_user_ids=end_user_ids, workspace_id=str(workspace_id))
|
||||||
|
api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
|
||||||
|
|
||||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
# 构建分页响应
|
||||||
|
result = {
|
||||||
|
"items": items,
|
||||||
|
"page": {
|
||||||
|
"page": page,
|
||||||
|
"pagesize": pagesize,
|
||||||
|
"total": total,
|
||||||
|
"hasnext": (page * pagesize) < total
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录,总计 {total} 条")
|
||||||
return success(data=result, msg="宿主列表获取成功")
|
return success(data=result, msg="宿主列表获取成功")
|
||||||
|
|
||||||
|
|
||||||
@@ -584,7 +591,7 @@ async def dashboard_data(
|
|||||||
"total_api_call": None
|
"total_api_call": None
|
||||||
}
|
}
|
||||||
|
|
||||||
# 1. 获取记忆总量(total_memory)
|
# 1. 获取记忆总量(total_memory)—— neo4j 独有逻辑:查询 neo4j 存储节点
|
||||||
try:
|
try:
|
||||||
total_memory_data = await memory_dashboard_service.get_workspace_total_memory_count(
|
total_memory_data = await memory_dashboard_service.get_workspace_total_memory_count(
|
||||||
db=db,
|
db=db,
|
||||||
@@ -593,45 +600,32 @@ async def dashboard_data(
|
|||||||
end_user_id=end_user_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0)
|
neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0)
|
||||||
# total_app: 统计当前空间下的所有app数量
|
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}")
|
||||||
from app.repositories import app_repository
|
|
||||||
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
|
||||||
neo4j_data["total_app"] = len(apps_orm)
|
|
||||||
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}, 应用数量: {neo4j_data['total_app']}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"获取记忆总量失败: {str(e)}")
|
api_logger.warning(f"获取记忆总量失败: {str(e)}")
|
||||||
|
|
||||||
# 2. 获取知识库类型统计(total_knowledge)
|
# 2. 获取共享统计数据(total_app、total_knowledge、total_api_call)
|
||||||
try:
|
common_stats = memory_dashboard_service.get_dashboard_common_stats(db, workspace_id)
|
||||||
from app.services.memory_agent_service import MemoryAgentService
|
neo4j_data.update(common_stats)
|
||||||
memory_agent_service = MemoryAgentService()
|
api_logger.info(f"成功获取共享统计: app={common_stats['total_app']}, knowledge={common_stats['total_knowledge']}, api_call={common_stats['total_api_call']}")
|
||||||
knowledge_stats = await memory_agent_service.get_knowledge_type_stats(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
only_active=True,
|
|
||||||
current_workspace_id=workspace_id,
|
|
||||||
db=db
|
|
||||||
)
|
|
||||||
neo4j_data["total_knowledge"] = knowledge_stats.get("total", 0)
|
|
||||||
api_logger.info(f"成功获取知识库类型统计total: {neo4j_data['total_knowledge']}")
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.warning(f"获取知识库类型统计失败: {str(e)}")
|
|
||||||
|
|
||||||
# 3. 获取API调用统计(total_api_call)
|
# 计算昨日对比
|
||||||
try:
|
try:
|
||||||
# 使用 AppStatisticsService 获取真实的API调用统计
|
changes = memory_dashboard_service.get_dashboard_yesterday_changes(
|
||||||
app_stats_service = AppStatisticsService(db)
|
db=db,
|
||||||
api_stats = app_stats_service.get_workspace_api_statistics(
|
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
start_date=start_date,
|
storage_type=storage_type,
|
||||||
end_date=end_date
|
today_data=neo4j_data
|
||||||
)
|
)
|
||||||
# 计算总调用次数
|
neo4j_data.update(changes)
|
||||||
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
|
|
||||||
neo4j_data["total_api_call"] = total_api_calls
|
|
||||||
api_logger.info(f"成功获取API调用统计: {neo4j_data['total_api_call']}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"获取API调用统计失败: {str(e)}")
|
api_logger.warning(f"计算neo4j昨日对比失败: {str(e)}")
|
||||||
neo4j_data["total_api_call"] = 0
|
neo4j_data.update({
|
||||||
|
"total_memory_change": None,
|
||||||
|
"total_app_change": None,
|
||||||
|
"total_knowledge_change": None,
|
||||||
|
"total_api_call_change": None,
|
||||||
|
})
|
||||||
|
|
||||||
result["neo4j_data"] = neo4j_data
|
result["neo4j_data"] = neo4j_data
|
||||||
api_logger.info("成功获取neo4j_data")
|
api_logger.info("成功获取neo4j_data")
|
||||||
@@ -645,40 +639,36 @@ async def dashboard_data(
|
|||||||
"total_api_call": None
|
"total_api_call": None
|
||||||
}
|
}
|
||||||
|
|
||||||
# 获取RAG相关数据
|
# 1. 获取记忆总量(total_memory)—— rag 独有逻辑:查询 document 表的 chunk_num
|
||||||
try:
|
try:
|
||||||
# total_memory: 只统计用户知识库(permission_id='Memory')的chunk数
|
|
||||||
total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
|
total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
|
||||||
rag_data["total_memory"] = total_chunk
|
rag_data["total_memory"] = total_chunk
|
||||||
|
api_logger.info(f"成功获取RAG记忆总量: {total_chunk}")
|
||||||
# total_app: 统计当前空间下的所有app数量
|
|
||||||
from app.repositories import app_repository
|
|
||||||
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
|
||||||
rag_data["total_app"] = len(apps_orm)
|
|
||||||
|
|
||||||
# total_knowledge: 使用 total_kb(总知识库数)
|
|
||||||
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
|
|
||||||
rag_data["total_knowledge"] = total_kb
|
|
||||||
|
|
||||||
# total_api_call: 使用 AppStatisticsService 获取真实的API调用统计
|
|
||||||
try:
|
|
||||||
app_stats_service = AppStatisticsService(db)
|
|
||||||
api_stats = app_stats_service.get_workspace_api_statistics(
|
|
||||||
workspace_id=workspace_id,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date
|
|
||||||
)
|
|
||||||
# 计算总调用次数
|
|
||||||
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
|
|
||||||
rag_data["total_api_call"] = total_api_calls
|
|
||||||
api_logger.info(f"成功获取RAG模式API调用统计: {rag_data['total_api_call']}")
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.warning(f"获取RAG模式API调用统计失败,使用默认值: {str(e)}")
|
|
||||||
rag_data["total_api_call"] = 0
|
|
||||||
|
|
||||||
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")
|
api_logger.warning(f"获取RAG记忆总量失败: {str(e)}")
|
||||||
|
|
||||||
|
# 2. 获取共享统计数据(total_app、total_knowledge、total_api_call)
|
||||||
|
common_stats = memory_dashboard_service.get_dashboard_common_stats(db, workspace_id)
|
||||||
|
rag_data.update(common_stats)
|
||||||
|
api_logger.info(f"成功获取共享统计: app={common_stats['total_app']}, knowledge={common_stats['total_knowledge']}, api_call={common_stats['total_api_call']}")
|
||||||
|
|
||||||
|
# 计算昨日对比
|
||||||
|
try:
|
||||||
|
changes = memory_dashboard_service.get_dashboard_yesterday_changes(
|
||||||
|
db=db,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
storage_type=storage_type,
|
||||||
|
today_data=rag_data
|
||||||
|
)
|
||||||
|
rag_data.update(changes)
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.warning(f"计算RAG昨日对比失败: {str(e)}")
|
||||||
|
rag_data.update({
|
||||||
|
"total_memory_change": None,
|
||||||
|
"total_app_change": None,
|
||||||
|
"total_knowledge_change": None,
|
||||||
|
"total_api_call_change": None,
|
||||||
|
})
|
||||||
|
|
||||||
result["rag_data"] = rag_data
|
result["rag_data"] = rag_data
|
||||||
api_logger.info("成功获取rag_data")
|
api_logger.info("成功获取rag_data")
|
||||||
|
|||||||
@@ -4,7 +4,9 @@
|
|||||||
处理显性记忆相关的API接口,包括情景记忆和语义记忆的查询。
|
处理显性记忆相关的API接口,包括情景记忆和语义记忆的查询。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.response_utils import success, fail
|
from app.core.response_utils import success, fail
|
||||||
@@ -69,6 +71,140 @@ async def get_explicit_memory_overview_api(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "显性记忆总览查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "显性记忆总览查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/episodics", response_model=ApiResponse)
|
||||||
|
async def get_episodic_memory_list_api(
|
||||||
|
end_user_id: str = Query(..., description="end user ID"),
|
||||||
|
page: int = Query(1, gt=0, description="page number, starting from 1"),
|
||||||
|
pagesize: int = Query(10, gt=0, le=100, description="number of items per page, max 100"),
|
||||||
|
start_date: Optional[int] = Query(None, description="start timestamp (ms)"),
|
||||||
|
end_date: Optional[int] = Query(None, description="end timestamp (ms)"),
|
||||||
|
episodic_type: str = Query("all", description="episodic type :all/conversation/project_work/learning/decision/important_event"),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
获取情景记忆分页列表
|
||||||
|
|
||||||
|
返回指定用户的情景记忆列表,支持分页、时间范围筛选和情景类型筛选。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 终端用户ID(必填)
|
||||||
|
page: 页码(从1开始,默认1)
|
||||||
|
pagesize: 每页数量(默认10,最大100)
|
||||||
|
start_date: 开始时间戳(可选,毫秒),自动扩展到当天 00:00:00
|
||||||
|
end_date: 结束时间戳(可选,毫秒),自动扩展到当天 23:59:59
|
||||||
|
episodic_type: 情景类型筛选(可选,默认all)
|
||||||
|
current_user: 当前用户
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse: 包含情景记忆分页列表
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
- 基础分页查询:GET /episodics?end_user_id=xxx&page=1&pagesize=5
|
||||||
|
返回第1页,每页5条数据
|
||||||
|
- 按时间范围筛选:GET /episodics?end_user_id=xxx&page=1&pagesize=5&start_date=1738684800000&end_date=1738771199000
|
||||||
|
返回指定时间范围内的数据
|
||||||
|
- 按情景类型筛选:GET /episodics?end_user_id=xxx&page=1&pagesize=5&episodic_type=important_event
|
||||||
|
返回类型为"重要事件"的数据
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- start_date 和 end_date 必须同时提供或同时不提供
|
||||||
|
- start_date 不能大于 end_date
|
||||||
|
- episodic_type 可选值:all, conversation, project_work, learning, decision, important_event
|
||||||
|
- total 为该用户情景记忆总数(不受筛选条件影响)
|
||||||
|
- page.total 为筛选后的总条数
|
||||||
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
|
# 检查用户是否已选择工作空间
|
||||||
|
if workspace_id is None:
|
||||||
|
api_logger.warning(f"用户 {current_user.username} 尝试查询情景记忆列表但未选择工作空间")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"情景记忆分页查询: end_user_id={end_user_id}, "
|
||||||
|
f"start_date={start_date}, end_date={end_date}, episodic_type={episodic_type}, "
|
||||||
|
f"page={page}, pagesize={pagesize}, username={current_user.username}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 1. 参数校验
|
||||||
|
if page < 1 or pagesize < 1:
|
||||||
|
api_logger.warning(f"分页参数错误: page={page}, pagesize={pagesize}")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "分页参数必须大于0")
|
||||||
|
|
||||||
|
valid_episodic_types = ["all", "conversation", "project_work", "learning", "decision", "important_event"]
|
||||||
|
if episodic_type not in valid_episodic_types:
|
||||||
|
api_logger.warning(f"无效的情景类型参数: {episodic_type}")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, f"无效的情景类型参数,可选值:{', '.join(valid_episodic_types)}")
|
||||||
|
|
||||||
|
# 时间戳参数校验
|
||||||
|
if (start_date is not None and end_date is None) or (end_date is not None and start_date is None):
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "start_date和end_date必须同时提供")
|
||||||
|
|
||||||
|
if start_date is not None and end_date is not None and start_date > end_date:
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "start_date不能大于end_date")
|
||||||
|
|
||||||
|
# 2. 执行查询
|
||||||
|
try:
|
||||||
|
result = await memory_explicit_service.get_episodic_memory_list(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
page=page,
|
||||||
|
pagesize=pagesize,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
episodic_type=episodic_type,
|
||||||
|
)
|
||||||
|
api_logger.info(
|
||||||
|
f"情景记忆分页查询成功: end_user_id={end_user_id}, "
|
||||||
|
f"total={result['total']}, 返回={len(result['items'])}条"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"情景记忆分页查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "情景记忆分页查询失败", str(e))
|
||||||
|
|
||||||
|
# 3. 返回结构化响应
|
||||||
|
return success(data=result, msg="查询成功")
|
||||||
|
|
||||||
|
@router.get("/semantics", response_model=ApiResponse)
|
||||||
|
async def get_semantic_memory_list_api(
|
||||||
|
end_user_id: str = Query(..., description="终端用户ID"),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
获取语义记忆列表
|
||||||
|
|
||||||
|
返回指定用户的全量语义记忆列表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 终端用户ID(必填)
|
||||||
|
current_user: 当前用户
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse: 包含语义记忆全量列表
|
||||||
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
|
if workspace_id is None:
|
||||||
|
api_logger.warning(f"用户 {current_user.username} 尝试查询语义记忆列表但未选择工作空间")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"语义记忆列表查询: end_user_id={end_user_id}, username={current_user.username}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await memory_explicit_service.get_semantic_memory_list(
|
||||||
|
end_user_id=end_user_id
|
||||||
|
)
|
||||||
|
api_logger.info(
|
||||||
|
f"语义记忆列表查询成功: end_user_id={end_user_id}, total={len(result)}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"语义记忆列表查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "语义记忆列表查询失败", str(e))
|
||||||
|
|
||||||
|
return success(data=result, msg="查询成功")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/details", response_model=ApiResponse)
|
@router.post("/details", response_model=ApiResponse)
|
||||||
async def get_explicit_memory_details_api(
|
async def get_explicit_memory_details_api(
|
||||||
request: ExplicitMemoryDetailsRequest,
|
request: ExplicitMemoryDetailsRequest,
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from app.schemas.memory_storage_schema import (
|
|||||||
ForgettingCurveRequest,
|
ForgettingCurveRequest,
|
||||||
ForgettingCurveResponse,
|
ForgettingCurveResponse,
|
||||||
ForgettingCurvePoint,
|
ForgettingCurvePoint,
|
||||||
|
PendingNodesResponse,
|
||||||
)
|
)
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services.memory_forget_service import MemoryForgetService
|
from app.services.memory_forget_service import MemoryForgetService
|
||||||
@@ -308,6 +309,100 @@ async def get_forgetting_stats(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/pending-nodes", response_model=ApiResponse)
|
||||||
|
async def get_pending_nodes(
|
||||||
|
end_user_id: str,
|
||||||
|
page: int = 1,
|
||||||
|
pagesize: int = 10,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取待遗忘节点列表(独立分页接口)
|
||||||
|
|
||||||
|
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
|
||||||
|
此接口独立分页,与 /stats 接口分离。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 组ID(即 end_user_id,必填)
|
||||||
|
page: 页码(从1开始,默认1)
|
||||||
|
pagesize: 每页数量(默认10)
|
||||||
|
current_user: 当前用户
|
||||||
|
db: 数据库会话
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse: 包含待遗忘节点列表和分页信息的响应
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
- 第1页,每页10条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=1&pagesize=10
|
||||||
|
- 第2页,每页20条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=2&pagesize=20
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- page 从1开始,pagesize 必须大于0
|
||||||
|
- 返回格式:{"items": [...], "page": {"page": 1, "pagesize": 10, "total": 100, "hasnext": true}}
|
||||||
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
# 检查用户是否已选择工作空间
|
||||||
|
if workspace_id is None:
|
||||||
|
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未选择工作空间")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
|
# 验证 end_user_id 必填
|
||||||
|
if not end_user_id:
|
||||||
|
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未提供 end_user_id")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "end_user_id 不能为空", "end_user_id is required")
|
||||||
|
|
||||||
|
# 通过 end_user_id 获取关联的 config_id
|
||||||
|
try:
|
||||||
|
from app.services.memory_agent_service import get_end_user_connected_config
|
||||||
|
|
||||||
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
|
config_id = connected_config.get("memory_config_id")
|
||||||
|
config_id = resolve_config_id(config_id, db)
|
||||||
|
|
||||||
|
if config_id is None:
|
||||||
|
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||||
|
|
||||||
|
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
|
||||||
|
except ValueError as e:
|
||||||
|
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
|
||||||
|
|
||||||
|
# 验证分页参数
|
||||||
|
if page < 1:
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "page 必须大于等于1", "page < 1")
|
||||||
|
if pagesize < 1:
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "pagesize 必须大于等于1", "pagesize < 1")
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取待遗忘节点: "
|
||||||
|
f"end_user_id={end_user_id}, page={page}, pagesize={pagesize}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 调用服务层获取待遗忘节点列表
|
||||||
|
result = await forget_service.get_pending_nodes(
|
||||||
|
db=db,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
config_id=config_id,
|
||||||
|
page=page,
|
||||||
|
pagesize=pagesize
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建响应
|
||||||
|
response_data = PendingNodesResponse(**result)
|
||||||
|
|
||||||
|
return success(data=response_data.model_dump(), msg="查询成功")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"获取待遗忘节点列表失败: {str(e)}")
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "获取待遗忘节点列表失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/forgetting_curve", response_model=ApiResponse)
|
@router.post("/forgetting_curve", response_model=ApiResponse)
|
||||||
async def get_forgetting_curve(
|
async def get_forgetting_curve(
|
||||||
request: ForgettingCurveRequest,
|
request: ForgettingCurveRequest,
|
||||||
|
|||||||
@@ -1,3 +1,19 @@
|
|||||||
|
"""
|
||||||
|
Memory Reflection Controller
|
||||||
|
|
||||||
|
This module provides REST API endpoints for managing memory reflection configurations
|
||||||
|
and operations. It handles reflection engine setup, configuration management, and
|
||||||
|
execution of self-reflection processes across memory systems.
|
||||||
|
|
||||||
|
Key Features:
|
||||||
|
- Reflection configuration management (save, retrieve, update)
|
||||||
|
- Workspace-wide reflection execution across multiple applications
|
||||||
|
- Individual configuration-based reflection runs
|
||||||
|
- Multi-language support for reflection outputs
|
||||||
|
- Integration with Neo4j memory storage and LLM models
|
||||||
|
- Comprehensive error handling and logging
|
||||||
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
@@ -28,9 +44,13 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from app.utils.config_utils import resolve_config_id
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
|
# Load environment variables for configuration
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
# Initialize API logger for request tracking and debugging
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
|
# Configure router with prefix and tags for API organization
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/memory",
|
prefix="/memory",
|
||||||
tags=["Memory"],
|
tags=["Memory"],
|
||||||
@@ -43,7 +63,38 @@ async def save_reflection_config(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Save reflection configuration to data_comfig table"""
|
"""
|
||||||
|
Save reflection configuration to memory config table
|
||||||
|
|
||||||
|
Persists reflection engine configuration settings to the data_config table,
|
||||||
|
including reflection parameters, model settings, and evaluation criteria.
|
||||||
|
Validates configuration parameters and ensures data consistency.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Memory reflection configuration data including:
|
||||||
|
- config_id: Configuration identifier to update
|
||||||
|
- reflection_enabled: Whether reflection is enabled
|
||||||
|
- reflection_period_in_hours: Reflection execution interval
|
||||||
|
- reflexion_range: Scope of reflection (partial/all)
|
||||||
|
- baseline: Reflection strategy (time/fact/hybrid)
|
||||||
|
- reflection_model_id: LLM model for reflection operations
|
||||||
|
- memory_verify: Enable memory verification checks
|
||||||
|
- quality_assessment: Enable quality assessment evaluation
|
||||||
|
current_user: Authenticated user saving the configuration
|
||||||
|
db: Database session for data operations
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Success response with saved reflection configuration data
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException 400: If config_id is missing or parameters are invalid
|
||||||
|
HTTPException 500: If configuration save operation fails
|
||||||
|
|
||||||
|
Database Operations:
|
||||||
|
- Updates memory_config table with reflection settings
|
||||||
|
- Commits transaction and refreshes entity
|
||||||
|
- Maintains configuration consistency
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
config_id = request.config_id
|
config_id = request.config_id
|
||||||
config_id = resolve_config_id(config_id, db)
|
config_id = resolve_config_id(config_id, db)
|
||||||
@@ -54,6 +105,7 @@ async def save_reflection_config(
|
|||||||
)
|
)
|
||||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||||
|
|
||||||
|
# Update reflection configuration in database
|
||||||
memory_config = MemoryConfigRepository.update_reflection_config(
|
memory_config = MemoryConfigRepository.update_reflection_config(
|
||||||
db,
|
db,
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
@@ -66,6 +118,7 @@ async def save_reflection_config(
|
|||||||
quality_assessment=request.quality_assessment
|
quality_assessment=request.quality_assessment
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Commit transaction and refresh entity
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(memory_config)
|
db.refresh(memory_config)
|
||||||
|
|
||||||
@@ -102,13 +155,55 @@ async def start_workspace_reflection(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""启动工作空间中所有匹配应用的反思功能"""
|
"""
|
||||||
|
Start reflection functionality for all matching applications in workspace
|
||||||
|
|
||||||
|
Initiates reflection processes across all applications within the user's current
|
||||||
|
workspace that have valid memory configurations. Processes each application's
|
||||||
|
configurations and associated end users, executing reflection operations
|
||||||
|
with proper error isolation and transaction management.
|
||||||
|
|
||||||
|
This endpoint serves as a workspace-wide reflection orchestrator, ensuring
|
||||||
|
that reflection failures for individual users don't affect other operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_user: Authenticated user initiating workspace reflection
|
||||||
|
db: Database session for configuration queries
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Success response with reflection results for all processed applications:
|
||||||
|
- app_id: Application identifier
|
||||||
|
- config_id: Memory configuration identifier
|
||||||
|
- end_user_id: End user identifier
|
||||||
|
- reflection_result: Individual reflection operation result
|
||||||
|
|
||||||
|
Processing Logic:
|
||||||
|
1. Retrieve all applications in the current workspace
|
||||||
|
2. Filter applications with valid memory configurations
|
||||||
|
3. For each configuration, find matching releases
|
||||||
|
4. Execute reflection for each end user with isolated transactions
|
||||||
|
5. Aggregate results with error handling per user
|
||||||
|
|
||||||
|
Error Handling:
|
||||||
|
- Individual user reflection failures are isolated
|
||||||
|
- Failed operations are logged and included in results
|
||||||
|
- Database transactions are isolated per user to prevent cascading failures
|
||||||
|
- Comprehensive error reporting for debugging
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException 500: If workspace reflection initialization fails
|
||||||
|
|
||||||
|
Performance Notes:
|
||||||
|
- Uses independent database sessions for each user operation
|
||||||
|
- Prevents transaction failures from affecting other users
|
||||||
|
- Comprehensive logging for operation tracking
|
||||||
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}")
|
api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}")
|
||||||
|
|
||||||
# 使用独立的数据库会话来获取工作空间应用详情,避免事务失败
|
# Use independent database session to get workspace app details, avoiding transaction failures
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
with get_db_context() as query_db:
|
with get_db_context() as query_db:
|
||||||
service = WorkspaceAppService(query_db)
|
service = WorkspaceAppService(query_db)
|
||||||
@@ -116,8 +211,9 @@ async def start_workspace_reflection(
|
|||||||
|
|
||||||
reflection_results = []
|
reflection_results = []
|
||||||
|
|
||||||
|
# Process each application in the workspace
|
||||||
for data in result['apps_detailed_info']:
|
for data in result['apps_detailed_info']:
|
||||||
# 跳过没有配置的应用
|
# Skip applications without configurations
|
||||||
if not data['memory_configs']:
|
if not data['memory_configs']:
|
||||||
api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过")
|
api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过")
|
||||||
continue
|
continue
|
||||||
@@ -126,22 +222,22 @@ async def start_workspace_reflection(
|
|||||||
memory_configs = data['memory_configs']
|
memory_configs = data['memory_configs']
|
||||||
end_users = data['end_users']
|
end_users = data['end_users']
|
||||||
|
|
||||||
# 为每个配置和用户组合执行反思
|
# Execute reflection for each configuration and user combination
|
||||||
for config in memory_configs:
|
for config in memory_configs:
|
||||||
config_id_str = str(config['config_id'])
|
config_id_str = str(config['config_id'])
|
||||||
|
|
||||||
# 找到匹配此配置的所有release
|
# Find all releases matching this configuration
|
||||||
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
|
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
|
||||||
|
|
||||||
if not matching_releases:
|
if not matching_releases:
|
||||||
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
|
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 为每个用户执行反思 - 使用独立的数据库会话
|
# Execute reflection for each user - using independent database sessions
|
||||||
for user in end_users:
|
for user in end_users:
|
||||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}")
|
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}")
|
||||||
|
|
||||||
# 为每个用户创建独立的数据库会话,避免事务失败影响其他用户
|
# Create independent database session for each user to avoid transaction failure impact
|
||||||
with get_db_context() as user_db:
|
with get_db_context() as user_db:
|
||||||
try:
|
try:
|
||||||
reflection_service = MemoryReflectionService(user_db)
|
reflection_service = MemoryReflectionService(user_db)
|
||||||
@@ -184,14 +280,51 @@ async def start_reflection_configs(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""通过config_id查询memory_config表中的反思配置信息"""
|
"""
|
||||||
|
Query reflection configuration information by config_id
|
||||||
|
|
||||||
|
Retrieves detailed reflection configuration settings from the memory_config
|
||||||
|
table for a specific configuration ID. Provides comprehensive reflection
|
||||||
|
parameters including model settings, evaluation criteria, and operational flags.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_id: Configuration identifier (UUID or integer) to query
|
||||||
|
current_user: Authenticated user making the request
|
||||||
|
db: Database session for data operations
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Success response with detailed reflection configuration:
|
||||||
|
- config_id: Resolved configuration identifier
|
||||||
|
- reflection_enabled: Whether reflection is enabled for this config
|
||||||
|
- reflection_period_in_hours: Reflection execution interval
|
||||||
|
- reflexion_range: Scope of reflection operations (partial/all)
|
||||||
|
- baseline: Reflection strategy (time/fact/hybrid)
|
||||||
|
- reflection_model_id: LLM model identifier for reflection
|
||||||
|
- memory_verify: Memory verification flag
|
||||||
|
- quality_assessment: Quality assessment flag
|
||||||
|
|
||||||
|
Database Operations:
|
||||||
|
- Queries memory_config table by resolved config_id
|
||||||
|
- Retrieves all reflection-related configuration fields
|
||||||
|
- Resolves configuration ID for consistent formatting
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException 404: If configuration with specified ID is not found
|
||||||
|
HTTPException 500: If configuration query operation fails
|
||||||
|
|
||||||
|
ID Resolution:
|
||||||
|
- Supports both UUID and integer config_id formats
|
||||||
|
- Automatically resolves to appropriate internal format
|
||||||
|
- Maintains consistency across different ID representations
|
||||||
|
"""
|
||||||
config_id = resolve_config_id(config_id, db)
|
config_id = resolve_config_id(config_id, db)
|
||||||
try:
|
try:
|
||||||
config_id=resolve_config_id(config_id,db)
|
config_id=resolve_config_id(config_id,db)
|
||||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||||
memory_config_id = resolve_config_id(result.config_id, db)
|
memory_config_id = resolve_config_id(result.config_id, db)
|
||||||
# 构建返回数据
|
|
||||||
|
# Build response data with comprehensive configuration details
|
||||||
reflection_config = {
|
reflection_config = {
|
||||||
"config_id": memory_config_id,
|
"config_id": memory_config_id,
|
||||||
"reflection_enabled": result.enable_self_reflexion,
|
"reflection_enabled": result.enable_self_reflexion,
|
||||||
@@ -205,9 +338,11 @@ async def start_reflection_configs(
|
|||||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||||
return success(data=reflection_config, msg="反思配置查询成功")
|
return success(data=reflection_config, msg="反思配置查询成功")
|
||||||
|
|
||||||
|
api_logger.info(f"Successfully queried reflection config, config_id: {config_id}")
|
||||||
|
return success(data=reflection_config, msg="Reflection configuration query successful")
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
# 重新抛出HTTP异常
|
# Re-raise HTTP exceptions without modification
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"查询反思配置失败: {str(e)}")
|
api_logger.error(f"查询反思配置失败: {str(e)}")
|
||||||
@@ -223,13 +358,66 @@ async def reflection_run(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Activate the reflection function for all matching applications in the workspace"""
|
"""
|
||||||
# 使用集中化的语言校验
|
Execute reflection engine with specified configuration
|
||||||
|
|
||||||
|
Runs the reflection engine using configuration parameters from the database.
|
||||||
|
Validates model availability, sets up the reflection engine with proper
|
||||||
|
configuration, and executes the reflection process with multi-language support.
|
||||||
|
|
||||||
|
This endpoint provides a test run capability for reflection configurations,
|
||||||
|
allowing users to validate their reflection settings and see results before
|
||||||
|
deploying to production environments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_id: Configuration identifier (UUID or integer) for reflection settings
|
||||||
|
language_type: Language preference header for output localization (optional)
|
||||||
|
current_user: Authenticated user executing the reflection
|
||||||
|
db: Database session for configuration queries
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Success response with reflection execution results including:
|
||||||
|
- baseline: Reflection strategy used
|
||||||
|
- source_data: Input data processed
|
||||||
|
- memory_verifies: Memory verification results (if enabled)
|
||||||
|
- quality_assessments: Quality assessment results (if enabled)
|
||||||
|
- reflexion_data: Generated reflection insights and solutions
|
||||||
|
|
||||||
|
Configuration Validation:
|
||||||
|
- Verifies configuration exists in database
|
||||||
|
- Validates LLM model availability
|
||||||
|
- Falls back to default model if specified model is unavailable
|
||||||
|
- Ensures all required parameters are properly set
|
||||||
|
|
||||||
|
Reflection Engine Setup:
|
||||||
|
- Creates ReflectionConfig with database parameters
|
||||||
|
- Initializes Neo4j connector for memory access
|
||||||
|
- Sets up ReflectionEngine with validated model
|
||||||
|
- Configures language preferences for output
|
||||||
|
|
||||||
|
Error Handling:
|
||||||
|
- Model validation with fallback to default
|
||||||
|
- Configuration validation and error reporting
|
||||||
|
- Comprehensive logging for debugging
|
||||||
|
- Graceful handling of missing configurations
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException 404: If configuration is not found
|
||||||
|
HTTPException 500: If reflection execution fails
|
||||||
|
|
||||||
|
Performance Notes:
|
||||||
|
- Direct database query for configuration retrieval
|
||||||
|
- Model validation to prevent runtime failures
|
||||||
|
- Efficient reflection engine initialization
|
||||||
|
- Language-aware output processing
|
||||||
|
"""
|
||||||
|
# Use centralized language validation for consistent localization
|
||||||
language = get_language_from_header(language_type)
|
language = get_language_from_header(language_type)
|
||||||
|
|
||||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||||
config_id = resolve_config_id(config_id, db)
|
config_id = resolve_config_id(config_id, db)
|
||||||
# 使用MemoryConfigRepository查询反思配置
|
|
||||||
|
# Query reflection configuration using MemoryConfigRepository
|
||||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||||
if not result:
|
if not result:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -239,7 +427,7 @@ async def reflection_run(
|
|||||||
|
|
||||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||||
|
|
||||||
# 验证模型ID是否存在
|
# Validate model ID existence
|
||||||
model_id = result.reflection_model_id
|
model_id = result.reflection_model_id
|
||||||
if model_id:
|
if model_id:
|
||||||
try:
|
try:
|
||||||
@@ -250,6 +438,7 @@ async def reflection_run(
|
|||||||
# 可以设置为None,让反思引擎使用默认模型
|
# 可以设置为None,让反思引擎使用默认模型
|
||||||
model_id = None
|
model_id = None
|
||||||
|
|
||||||
|
# Create reflection configuration with database parameters
|
||||||
config = ReflectionConfig(
|
config = ReflectionConfig(
|
||||||
enabled=result.enable_self_reflexion,
|
enabled=result.enable_self_reflexion,
|
||||||
iteration_period=result.iteration_period,
|
iteration_period=result.iteration_period,
|
||||||
@@ -262,11 +451,13 @@ async def reflection_run(
|
|||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
language_type=language_type
|
language_type=language_type
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Initialize Neo4j connector and reflection engine
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
engine = ReflectionEngine(
|
engine = ReflectionEngine(
|
||||||
config=config,
|
config=config,
|
||||||
neo4j_connector=connector,
|
neo4j_connector=connector,
|
||||||
llm_client=model_id # 传入验证后的 model_id
|
llm_client=model_id # Pass validated model_id
|
||||||
)
|
)
|
||||||
|
|
||||||
result=await (engine.reflection_run())
|
result=await (engine.reflection_run())
|
||||||
|
|||||||
@@ -1,3 +1,18 @@
|
|||||||
|
"""
|
||||||
|
Memory Short Term Controller
|
||||||
|
|
||||||
|
This module provides REST API endpoints for managing short-term and long-term memory
|
||||||
|
data retrieval and analysis. It handles memory system statistics, data aggregation,
|
||||||
|
and provides comprehensive memory insights for end users.
|
||||||
|
|
||||||
|
Key Features:
|
||||||
|
- Short-term memory data retrieval and statistics
|
||||||
|
- Long-term memory data aggregation
|
||||||
|
- Entity count integration
|
||||||
|
- Multi-language response support
|
||||||
|
- Memory system analytics and reporting
|
||||||
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
@@ -13,9 +28,13 @@ from app.models.user_model import User
|
|||||||
from app.services.memory_short_service import LongService, ShortService
|
from app.services.memory_short_service import LongService, ShortService
|
||||||
from app.services.memory_storage_service import search_entity
|
from app.services.memory_storage_service import search_entity
|
||||||
|
|
||||||
|
# Load environment variables for configuration
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
# Initialize API logger for request tracking and debugging
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
|
# Configure router with prefix and tags for API organization
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/memory/short",
|
prefix="/memory/short",
|
||||||
tags=["Memory"],
|
tags=["Memory"],
|
||||||
@@ -27,24 +46,73 @@ async def short_term_configs(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
# 使用集中化的语言校验
|
"""
|
||||||
|
Retrieve comprehensive short-term and long-term memory statistics
|
||||||
|
|
||||||
|
Provides a comprehensive overview of memory system data for a specific end user,
|
||||||
|
including short-term memory entries, long-term memory aggregations, entity counts,
|
||||||
|
and retrieval statistics. Supports multi-language responses based on request headers.
|
||||||
|
|
||||||
|
This endpoint serves as a central dashboard for memory system analytics, combining
|
||||||
|
data from multiple memory subsystems to provide a holistic view of user memory state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: Unique identifier for the end user whose memory data to retrieve
|
||||||
|
language_type: Language preference header for response localization (optional)
|
||||||
|
current_user: Authenticated user making the request (injected by dependency)
|
||||||
|
db: Database session for data operations (injected by dependency)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Success response containing comprehensive memory statistics:
|
||||||
|
- short_term: List of short-term memory entries with detailed data
|
||||||
|
- long_term: List of long-term memory aggregations and summaries
|
||||||
|
- entity: Count of entities associated with the end user
|
||||||
|
- retrieval_number: Total count of short-term memory retrievals
|
||||||
|
- long_term_number: Total count of long-term memory entries
|
||||||
|
|
||||||
|
Response Structure:
|
||||||
|
{
|
||||||
|
"code": 200,
|
||||||
|
"msg": "Short-term memory system data retrieved successfully",
|
||||||
|
"data": {
|
||||||
|
"short_term": [...], # Short-term memory entries
|
||||||
|
"long_term": [...], # Long-term memory data
|
||||||
|
"entity": 42, # Entity count
|
||||||
|
"retrieval_number": 156, # Short-term retrieval count
|
||||||
|
"long_term_number": 23 # Long-term memory count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If end_user_id is invalid or data retrieval fails
|
||||||
|
|
||||||
|
Performance Notes:
|
||||||
|
- Combines multiple service calls for comprehensive data
|
||||||
|
- Entity search is performed asynchronously for better performance
|
||||||
|
- Response time depends on memory data volume for the specified user
|
||||||
|
"""
|
||||||
|
# Use centralized language validation for consistent localization
|
||||||
language = get_language_from_header(language_type)
|
language = get_language_from_header(language_type)
|
||||||
|
|
||||||
# 获取短期记忆数据
|
# Retrieve short-term memory data and statistics
|
||||||
short_term=ShortService(end_user_id, db)
|
short_term = ShortService(end_user_id, db)
|
||||||
short_result=short_term.get_short_databasets()
|
short_result = short_term.get_short_databasets() # Get short-term memory entries
|
||||||
short_count=short_term.get_short_count()
|
short_count = short_term.get_short_count() # Get short-term retrieval count
|
||||||
|
|
||||||
long_term=LongService(end_user_id, db)
|
# Retrieve long-term memory data and aggregations
|
||||||
long_result=long_term.get_long_databasets()
|
long_term = LongService(end_user_id, db)
|
||||||
|
long_result = long_term.get_long_databasets() # Get long-term memory entries
|
||||||
|
|
||||||
|
# Get entity count for the specified end user
|
||||||
entity_result = await search_entity(end_user_id)
|
entity_result = await search_entity(end_user_id)
|
||||||
|
|
||||||
|
# Compile comprehensive memory statistics response
|
||||||
result = {
|
result = {
|
||||||
'short_term': short_result,
|
'short_term': short_result, # Short-term memory entries
|
||||||
'long_term': long_result,
|
'long_term': long_result, # Long-term memory data
|
||||||
'entity': entity_result.get('num', 0),
|
'entity': entity_result.get('num', 0), # Entity count (default to 0 if not found)
|
||||||
"retrieval_number":short_count,
|
"retrieval_number": short_count, # Short-term retrieval statistics
|
||||||
"long_term_number":len(long_result)
|
"long_term_number": len(long_result) # Long-term memory entry count
|
||||||
}
|
}
|
||||||
|
|
||||||
return success(data=result, msg="短期记忆系统数据获取成功")
|
return success(data=result, msg="短期记忆系统数据获取成功")
|
||||||
@@ -26,7 +26,7 @@ from app.services.memory_storage_service import (
|
|||||||
analytics_hot_memory_tags,
|
analytics_hot_memory_tags,
|
||||||
analytics_recent_activity_stats,
|
analytics_recent_activity_stats,
|
||||||
kb_type_distribution,
|
kb_type_distribution,
|
||||||
search_all,
|
search_all_batch,
|
||||||
search_chunk,
|
search_chunk,
|
||||||
search_detials,
|
search_detials,
|
||||||
search_dialogue,
|
search_dialogue,
|
||||||
@@ -34,6 +34,7 @@ from app.services.memory_storage_service import (
|
|||||||
search_entity,
|
search_entity,
|
||||||
search_statement,
|
search_statement,
|
||||||
)
|
)
|
||||||
|
from app.core.quota_stub import check_memory_engine_quota
|
||||||
from fastapi import APIRouter, Depends, Header
|
from fastapi import APIRouter, Depends, Header
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -54,8 +55,8 @@ router = APIRouter(
|
|||||||
|
|
||||||
@router.get("/info", response_model=ApiResponse)
|
@router.get("/info", response_model=ApiResponse)
|
||||||
async def get_storage_info(
|
async def get_storage_info(
|
||||||
storage_id: str,
|
storage_id: str,
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Example wrapper endpoint - retrieves storage information
|
Example wrapper endpoint - retrieves storage information
|
||||||
@@ -75,17 +76,13 @@ async def get_storage_info(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
||||||
|
@check_memory_engine_quota
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
|
||||||
def create_config(
|
def create_config(
|
||||||
payload: ConfigParamsCreate,
|
payload: ConfigParamsCreate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
@@ -107,9 +104,11 @@ def create_config(
|
|||||||
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
|
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
|
||||||
lang = get_language_from_header(x_language_type)
|
lang = get_language_from_header(x_language_type)
|
||||||
if lang == "en":
|
if lang == "en":
|
||||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
|
msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
|
||||||
|
f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
|
||||||
else:
|
else:
|
||||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
|
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
|
||||||
|
f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
|
||||||
return JSONResponse(status_code=400, content=msg)
|
return JSONResponse(status_code=400, content=msg)
|
||||||
api_logger.error(f"Create config failed: {err_str}")
|
api_logger.error(f"Create config failed: {err_str}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
|
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
|
||||||
@@ -119,9 +118,11 @@ def create_config(
|
|||||||
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
|
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
|
||||||
lang = get_language_from_header(x_language_type)
|
lang = get_language_from_header(x_language_type)
|
||||||
if lang == "en":
|
if lang == "en":
|
||||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
|
msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
|
||||||
|
f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
|
||||||
else:
|
else:
|
||||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
|
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
|
||||||
|
f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
|
||||||
return JSONResponse(status_code=400, content=msg)
|
return JSONResponse(status_code=400, content=msg)
|
||||||
api_logger.error(f"Create config failed: {str(e)}")
|
api_logger.error(f"Create config failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
|
||||||
@@ -129,10 +130,10 @@ def create_config(
|
|||||||
|
|
||||||
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
||||||
def delete_config(
|
def delete_config(
|
||||||
config_id: UUID|int,
|
config_id: UUID | int,
|
||||||
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""删除记忆配置(带终端用户保护)
|
"""删除记忆配置(带终端用户保护)
|
||||||
|
|
||||||
@@ -145,7 +146,7 @@ def delete_config(
|
|||||||
force: 设置为 true 可强制删除(即使有终端用户正在使用)
|
force: 设置为 true 可强制删除(即使有终端用户正在使用)
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
config_id=resolve_config_id(config_id, db)
|
config_id = resolve_config_id(config_id, db)
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
||||||
@@ -203,9 +204,9 @@ def delete_config(
|
|||||||
|
|
||||||
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
|
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
|
||||||
def update_config(
|
def update_config(
|
||||||
payload: ConfigUpdate,
|
payload: ConfigUpdate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||||
@@ -217,7 +218,8 @@ def update_config(
|
|||||||
# 校验至少有一个字段需要更新
|
# 校验至少有一个字段需要更新
|
||||||
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
|
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
|
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
|
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段",
|
||||||
|
"config_name, config_desc, scene_id 均为空")
|
||||||
|
|
||||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
||||||
try:
|
try:
|
||||||
@@ -231,9 +233,9 @@ def update_config(
|
|||||||
|
|
||||||
@router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选
|
@router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选
|
||||||
def update_config_extracted(
|
def update_config_extracted(
|
||||||
payload: ConfigUpdateExtracted,
|
payload: ConfigUpdateExtracted,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||||
@@ -256,11 +258,11 @@ def update_config_extracted(
|
|||||||
# 遗忘引擎配置接口已迁移到 memory_forget_controller.py
|
# 遗忘引擎配置接口已迁移到 memory_forget_controller.py
|
||||||
# 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config
|
# 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config
|
||||||
|
|
||||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||||
def read_config_extracted(
|
def read_config_extracted(
|
||||||
config_id: UUID | int,
|
config_id: UUID | int,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
config_id = resolve_config_id(config_id, db)
|
config_id = resolve_config_id(config_id, db)
|
||||||
@@ -278,10 +280,11 @@ def read_config_extracted(
|
|||||||
api_logger.error(f"Read config extracted failed: {str(e)}")
|
api_logger.error(f"Read config extracted failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
|
||||||
|
|
||||||
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
|
||||||
|
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
||||||
def read_all_config(
|
def read_all_config(
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
@@ -303,10 +306,10 @@ def read_all_config(
|
|||||||
|
|
||||||
@router.post("/pilot_run", response_model=None)
|
@router.post("/pilot_run", response_model=None)
|
||||||
async def pilot_run(
|
async def pilot_run(
|
||||||
payload: ConfigPilotRun,
|
payload: ConfigPilotRun,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> StreamingResponse:
|
) -> StreamingResponse:
|
||||||
# 使用集中化的语言校验
|
# 使用集中化的语言校验
|
||||||
language = get_language_from_header(language_type)
|
language = get_language_from_header(language_type)
|
||||||
@@ -333,9 +336,9 @@ async def pilot_run(
|
|||||||
|
|
||||||
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
|
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
|
||||||
async def get_kb_type_distribution(
|
async def get_kb_type_distribution(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}")
|
api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await kb_type_distribution(end_user_id)
|
result = await kb_type_distribution(end_user_id)
|
||||||
@@ -347,9 +350,9 @@ async def get_kb_type_distribution(
|
|||||||
|
|
||||||
@router.get("/search/dialogue", response_model=ApiResponse)
|
@router.get("/search/dialogue", response_model=ApiResponse)
|
||||||
async def search_dialogues_num(
|
async def search_dialogues_num(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}")
|
api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await search_dialogue(end_user_id)
|
result = await search_dialogue(end_user_id)
|
||||||
@@ -361,9 +364,9 @@ async def search_dialogues_num(
|
|||||||
|
|
||||||
@router.get("/search/chunk", response_model=ApiResponse)
|
@router.get("/search/chunk", response_model=ApiResponse)
|
||||||
async def search_chunks_num(
|
async def search_chunks_num(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}")
|
api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await search_chunk(end_user_id)
|
result = await search_chunk(end_user_id)
|
||||||
@@ -375,9 +378,9 @@ async def search_chunks_num(
|
|||||||
|
|
||||||
@router.get("/search/statement", response_model=ApiResponse)
|
@router.get("/search/statement", response_model=ApiResponse)
|
||||||
async def search_statements_num(
|
async def search_statements_num(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Search statement requested for end_user_id: {end_user_id}")
|
api_logger.info(f"Search statement requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await search_statement(end_user_id)
|
result = await search_statement(end_user_id)
|
||||||
@@ -389,9 +392,9 @@ async def search_statements_num(
|
|||||||
|
|
||||||
@router.get("/search/entity", response_model=ApiResponse)
|
@router.get("/search/entity", response_model=ApiResponse)
|
||||||
async def search_entities_num(
|
async def search_entities_num(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Search entity requested for end_user_id: {end_user_id}")
|
api_logger.info(f"Search entity requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await search_entity(end_user_id)
|
result = await search_entity(end_user_id)
|
||||||
@@ -403,12 +406,15 @@ async def search_entities_num(
|
|||||||
|
|
||||||
@router.get("/search", response_model=ApiResponse)
|
@router.get("/search", response_model=ApiResponse)
|
||||||
async def search_all_num(
|
async def search_all_num(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
|
api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await search_all(end_user_id)
|
if not end_user_id:
|
||||||
|
return success(data={"total": 0}, msg="查询成功")
|
||||||
|
batch_result = await search_all_batch([end_user_id])
|
||||||
|
result = {"total": batch_result.get(end_user_id, 0)}
|
||||||
return success(data=result, msg="查询成功")
|
return success(data=result, msg="查询成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Search all failed: {str(e)}")
|
api_logger.error(f"Search all failed: {str(e)}")
|
||||||
@@ -417,9 +423,9 @@ async def search_all_num(
|
|||||||
|
|
||||||
@router.get("/search/detials", response_model=ApiResponse)
|
@router.get("/search/detials", response_model=ApiResponse)
|
||||||
async def search_entities_detials(
|
async def search_entities_detials(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Search details requested for end_user_id: {end_user_id}")
|
api_logger.info(f"Search details requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await search_detials(end_user_id)
|
result = await search_detials(end_user_id)
|
||||||
@@ -431,9 +437,9 @@ async def search_entities_detials(
|
|||||||
|
|
||||||
@router.get("/search/edges", response_model=ApiResponse)
|
@router.get("/search/edges", response_model=ApiResponse)
|
||||||
async def search_entity_edges(
|
async def search_entity_edges(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Search edges requested for end_user_id: {end_user_id}")
|
api_logger.info(f"Search edges requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await search_edges(end_user_id)
|
result = await search_edges(end_user_id)
|
||||||
@@ -443,14 +449,12 @@ async def search_entity_edges(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
|
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
|
||||||
async def get_hot_memory_tags_api(
|
async def get_hot_memory_tags_api(
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
获取热门记忆标签(带Redis缓存)
|
获取热门记忆标签(带Redis缓存)
|
||||||
|
|
||||||
@@ -505,8 +509,8 @@ async def get_hot_memory_tags_api(
|
|||||||
|
|
||||||
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
|
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
|
||||||
async def clear_hot_memory_tags_cache(
|
async def clear_hot_memory_tags_cache(
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
清除热门标签缓存
|
清除热门标签缓存
|
||||||
|
|
||||||
@@ -543,7 +547,7 @@ async def clear_hot_memory_tags_cache(
|
|||||||
|
|
||||||
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
||||||
async def get_recent_activity_stats_api(
|
async def get_recent_activity_stats_api(
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
|
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
|
||||||
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
|
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
|
||||||
@@ -553,4 +557,3 @@ async def get_recent_activity_stats_api(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Recent activity stats failed: {str(e)}")
|
api_logger.error(f"Recent activity stats failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from app.core.response_utils import success
|
|||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.models import User
|
from app.models import User
|
||||||
|
from app.schemas import conversation_schema
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
|
|
||||||
@@ -32,35 +33,47 @@ def get_memory_count(
|
|||||||
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
|
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
|
||||||
def get_conversations(
|
def get_conversations(
|
||||||
end_user_id: uuid.UUID,
|
end_user_id: uuid.UUID,
|
||||||
|
page: int = 1,
|
||||||
|
pagesize: int = 20,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Retrieve all conversations for the current user in a specific group.
|
Retrieve conversations for the current user in a specific group with pagination.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id (UUID): The group identifier.
|
end_user_id (UUID): The group identifier.
|
||||||
|
page (int): Page number (1-based). Defaults to 1.
|
||||||
|
pagesize (int): Number of items per page. Defaults to 20.
|
||||||
current_user (User, optional): The authenticated user.
|
current_user (User, optional): The authenticated user.
|
||||||
db (Session, optional): SQLAlchemy session.
|
db (Session, optional): SQLAlchemy session.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ApiResponse: Contains a list of conversation IDs.
|
ApiResponse: Contains a paginated list of conversations.
|
||||||
|
|
||||||
Notes:
|
|
||||||
- Initializes the ConversationService with the current DB session.
|
|
||||||
- Returns only conversation IDs for lightweight response.
|
|
||||||
- Logs can be added to trace requests in production.
|
|
||||||
"""
|
"""
|
||||||
|
page = max(1, page)
|
||||||
|
page_size = max(1, min(pagesize, 100)) # Limit page size between 1 and 100
|
||||||
conversation_service = ConversationService(db)
|
conversation_service = ConversationService(db)
|
||||||
conversations = conversation_service.get_user_conversations(
|
conversations, total = conversation_service.get_user_conversations(
|
||||||
end_user_id
|
end_user_id,
|
||||||
|
page=page,
|
||||||
|
page_size=page_size
|
||||||
)
|
)
|
||||||
return success(data=[
|
return success(data={
|
||||||
{
|
"items": [
|
||||||
"id": conversation.id,
|
{
|
||||||
"title": conversation.title
|
"id": conversation.id,
|
||||||
} for conversation in conversations
|
"title": conversation.title
|
||||||
], msg="get conversations success")
|
} for conversation in conversations
|
||||||
|
],
|
||||||
|
"total": total,
|
||||||
|
"page": {
|
||||||
|
"page": page,
|
||||||
|
"pagesize": page_size,
|
||||||
|
"total": total,
|
||||||
|
"hasnext": (page * page_size) < total
|
||||||
|
},
|
||||||
|
}, msg="get conversations success")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
|
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
|
||||||
@@ -90,11 +103,7 @@ def get_messages(
|
|||||||
conversation_id,
|
conversation_id,
|
||||||
)
|
)
|
||||||
messages = [
|
messages = [
|
||||||
{
|
conversation_schema.Message.model_validate(message)
|
||||||
"role": message.role,
|
|
||||||
"content": message.content,
|
|
||||||
"created_at": int(message.created_at.timestamp() * 1000),
|
|
||||||
}
|
|
||||||
for message in messages_obj
|
for message in messages_obj
|
||||||
]
|
]
|
||||||
return success(data=messages, msg="get conversation history success")
|
return success(data=messages, msg="get conversation history success")
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from app.core.response_utils import success
|
|||||||
from app.schemas.response_schema import ApiResponse, PageData
|
from app.schemas.response_schema import ApiResponse, PageData
|
||||||
from app.services.model_service import ModelConfigService, ModelApiKeyService, ModelBaseService
|
from app.services.model_service import ModelConfigService, ModelApiKeyService, ModelBaseService
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
|
from app.core.quota_stub import check_model_quota, check_model_activation_quota
|
||||||
|
|
||||||
# 获取API专用日志器
|
# 获取API专用日志器
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -42,6 +43,7 @@ def get_model_strategies():
|
|||||||
@router.get("", response_model=ApiResponse)
|
@router.get("", response_model=ApiResponse)
|
||||||
def get_model_list(
|
def get_model_list(
|
||||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||||
|
capability: Optional[list[str]] = Query(None, description="能力筛选(支持多个,如 ?capability=chat 或 ?capability=chat, embedding)"),
|
||||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||||
@@ -74,10 +76,21 @@ def get_model_list(
|
|||||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||||
|
|
||||||
|
capability_list = []
|
||||||
|
if capability is not None:
|
||||||
|
flat_capability = []
|
||||||
|
for item in capability:
|
||||||
|
split_items = [c.strip() for c in item.split(', ') if c.strip()]
|
||||||
|
flat_capability.extend(split_items)
|
||||||
|
|
||||||
|
unique_flat_capability = list(dict.fromkeys(flat_capability))
|
||||||
|
capability_list = unique_flat_capability
|
||||||
|
|
||||||
api_logger.error(f"获取模型type_list: {type_list}")
|
api_logger.error(f"获取模型type_list: {type_list}")
|
||||||
query = model_schema.ModelConfigQuery(
|
query = model_schema.ModelConfigQuery(
|
||||||
type=type_list,
|
type=type_list,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
|
capability=capability_list,
|
||||||
is_active=is_active,
|
is_active=is_active,
|
||||||
is_public=is_public,
|
is_public=is_public,
|
||||||
search=search,
|
search=search,
|
||||||
@@ -291,6 +304,7 @@ async def create_model(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/composite", response_model=ApiResponse)
|
@router.post("/composite", response_model=ApiResponse)
|
||||||
|
@check_model_quota
|
||||||
async def create_composite_model(
|
async def create_composite_model(
|
||||||
model_data: model_schema.CompositeModelCreate,
|
model_data: model_schema.CompositeModelCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
@@ -317,6 +331,7 @@ async def create_composite_model(
|
|||||||
|
|
||||||
|
|
||||||
@router.put("/composite/{model_id}", response_model=ApiResponse)
|
@router.put("/composite/{model_id}", response_model=ApiResponse)
|
||||||
|
@check_model_activation_quota
|
||||||
async def update_composite_model(
|
async def update_composite_model(
|
||||||
model_id: uuid.UUID,
|
model_id: uuid.UUID,
|
||||||
model_data: model_schema.CompositeModelCreate,
|
model_data: model_schema.CompositeModelCreate,
|
||||||
|
|||||||
@@ -28,6 +28,8 @@ from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, H
|
|||||||
from fastapi.responses import StreamingResponse, JSONResponse
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.quota_stub import check_ontology_project_quota
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.language_utils import get_language_from_header
|
from app.core.language_utils import get_language_from_header
|
||||||
@@ -163,6 +165,7 @@ def _get_ontology_service(
|
|||||||
api_key=api_key_config.api_key,
|
api_key=api_key_config.api_key,
|
||||||
base_url=api_key_config.api_base,
|
base_url=api_key_config.api_base,
|
||||||
is_omni=api_key_config.is_omni,
|
is_omni=api_key_config.is_omni,
|
||||||
|
capability=api_key_config.capability,
|
||||||
max_retries=3,
|
max_retries=3,
|
||||||
timeout=60.0
|
timeout=60.0
|
||||||
)
|
)
|
||||||
@@ -286,6 +289,7 @@ async def extract_ontology(
|
|||||||
# ==================== 本体场景管理接口 ====================
|
# ==================== 本体场景管理接口 ====================
|
||||||
|
|
||||||
@router.post("/scene", response_model=ApiResponse)
|
@router.post("/scene", response_model=ApiResponse)
|
||||||
|
@check_ontology_project_quota
|
||||||
async def create_scene(
|
async def create_scene(
|
||||||
request: SceneCreateRequest,
|
request: SceneCreateRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
|||||||
@@ -124,10 +124,11 @@ async def get_prompt_opt(
|
|||||||
skill=data.skill
|
skill=data.skill
|
||||||
):
|
):
|
||||||
# chunk 是 prompt 的增量内容
|
# chunk 是 prompt 的增量内容
|
||||||
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
|
yield f"event:message\ndata: {json.dumps(chunk, ensure_ascii=False)}\n\n"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield f"event:error\ndata: {json.dumps(
|
yield f"event:error\ndata: {json.dumps(
|
||||||
{"error": str(e)}
|
{"error": str(e)},
|
||||||
|
ensure_ascii=False
|
||||||
)}\n\n"
|
)}\n\n"
|
||||||
yield "event:end\ndata: {}\n\n"
|
yield "event:end\ndata: {}\n\n"
|
||||||
|
|
||||||
|
|||||||
@@ -10,10 +10,10 @@ from sqlalchemy.orm import Session
|
|||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.quota_manager import check_end_user_quota
|
||||||
from app.core.response_utils import success, fail
|
from app.core.response_utils import success, fail
|
||||||
from app.db import get_db, get_db_read
|
from app.db import get_db, get_db_read
|
||||||
from app.dependencies import get_share_user_id, ShareTokenData
|
from app.dependencies import get_share_user_id, ShareTokenData
|
||||||
from app.models.app_model import App
|
|
||||||
from app.models.app_model import AppType
|
from app.models.app_model import AppType
|
||||||
from app.repositories import knowledge_repository
|
from app.repositories import knowledge_repository
|
||||||
from app.repositories.end_user_repository import EndUserRepository
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
@@ -22,11 +22,13 @@ from app.schemas import release_share_schema, conversation_schema
|
|||||||
from app.schemas.response_schema import PageData, PageMeta
|
from app.schemas.response_schema import PageData, PageMeta
|
||||||
from app.services import workspace_service
|
from app.services import workspace_service
|
||||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||||
|
from app.services.app_service import AppService
|
||||||
from app.services.auth_service import create_access_token
|
from app.services.auth_service import create_access_token
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
from app.services.release_share_service import ReleaseShareService
|
from app.services.release_share_service import ReleaseShareService
|
||||||
from app.services.shared_chat_service import SharedChatService
|
from app.services.shared_chat_service import SharedChatService
|
||||||
from app.services.workflow_service import WorkflowService
|
from app.services.workflow_service import WorkflowService
|
||||||
|
from app.models.file_metadata_model import FileMetadata
|
||||||
from app.utils.app_config_utils import workflow_config_4_app_release, \
|
from app.utils.app_config_utils import workflow_config_4_app_release, \
|
||||||
agent_config_4_app_release, multi_agent_config_4_app_release
|
agent_config_4_app_release, multi_agent_config_4_app_release
|
||||||
|
|
||||||
@@ -215,8 +217,22 @@ def list_conversations(
|
|||||||
service = SharedChatService(db)
|
service = SharedChatService(db)
|
||||||
share, release = service.get_release_by_share_token(share_data.share_token, password)
|
share, release = service.get_release_by_share_token(share_data.share_token, password)
|
||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
|
app_service = AppService(db)
|
||||||
|
app = app_service._get_app_or_404(share.app_id)
|
||||||
|
workspace_id = app.workspace_id
|
||||||
|
|
||||||
|
# 仅在新建终端用户时检查配额
|
||||||
|
existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id)
|
||||||
|
if existing_end_user is None:
|
||||||
|
from app.core.quota_manager import _check_quota
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
|
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||||
|
if ws:
|
||||||
|
_check_quota(db, ws.tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
|
||||||
|
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
app_id=share.app_id,
|
app_id=share.app_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
other_id=other_id
|
other_id=other_id
|
||||||
)
|
)
|
||||||
logger.debug(new_end_user.id)
|
logger.debug(new_end_user.id)
|
||||||
@@ -256,8 +272,41 @@ def get_conversation(
|
|||||||
conv_service = ConversationService(db)
|
conv_service = ConversationService(db)
|
||||||
messages = conv_service.get_messages(conversation_id)
|
messages = conv_service.get_messages(conversation_id)
|
||||||
|
|
||||||
# 构建响应
|
file_ids = []
|
||||||
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump()
|
message_file_id_map = {}
|
||||||
|
|
||||||
|
# 第一次遍历:解析 audio_url,收集所有有效的 file_id
|
||||||
|
for idx, m in enumerate(messages):
|
||||||
|
if m.role == "assistant" and m.meta_data:
|
||||||
|
audio_url = m.meta_data.get("audio_url")
|
||||||
|
if not audio_url:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
file_id = uuid.UUID(audio_url.rstrip("/").split("/")[-1])
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
# audio_url 无法解析为 UUID,标记为 unknown
|
||||||
|
m.meta_data["audio_status"] = "unknown"
|
||||||
|
continue
|
||||||
|
|
||||||
|
file_ids.append(file_id)
|
||||||
|
message_file_id_map[idx] = file_id
|
||||||
|
|
||||||
|
# 批量查询所有相关的 FileMetadata
|
||||||
|
file_status_map = {}
|
||||||
|
if file_ids:
|
||||||
|
file_metas = (
|
||||||
|
db.query(FileMetadata)
|
||||||
|
.filter(FileMetadata.id.in_(set(file_ids)))
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
file_status_map = {fm.id: fm.status for fm in file_metas}
|
||||||
|
|
||||||
|
# 第二次遍历:将查询结果映射回消息
|
||||||
|
for idx, file_id in message_file_id_map.items():
|
||||||
|
m = messages[idx]
|
||||||
|
m.meta_data["audio_status"] = file_status_map.get(file_id, "unknown")
|
||||||
|
|
||||||
|
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump(mode="json")
|
||||||
conv_dict["messages"] = [
|
conv_dict["messages"] = [
|
||||||
conversation_schema.Message.model_validate(m) for m in messages
|
conversation_schema.Message.model_validate(m) for m in messages
|
||||||
]
|
]
|
||||||
@@ -308,25 +357,51 @@ async def chat(
|
|||||||
|
|
||||||
# Store end_user_id in database with original user_id
|
# Store end_user_id in database with original user_id
|
||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
|
app_service = AppService(db)
|
||||||
|
app = app_service._get_app_or_404(share.app_id)
|
||||||
|
workspace_id = app.workspace_id
|
||||||
|
|
||||||
|
# 仅在新建终端用户时检查配额,已有用户复用不受限制
|
||||||
|
existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id)
|
||||||
|
logger.info(f"终端用户配额检查: workspace_id={workspace_id}, other_id={other_id}, existing={existing_end_user is not None}")
|
||||||
|
if existing_end_user is None:
|
||||||
|
from app.core.quota_manager import _check_quota
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
|
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||||
|
if ws:
|
||||||
|
logger.info(f"新终端用户,执行配额检查: tenant_id={ws.tenant_id}")
|
||||||
|
_check_quota(db, ws.tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
|
||||||
|
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
app_id=share.app_id,
|
app_id=share.app_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
other_id=other_id,
|
other_id=other_id,
|
||||||
original_user_id=user_id # Save original user_id to other_id
|
original_user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Only extract and set memory_config_id when the end user doesn't have one yet
|
||||||
|
if not new_end_user.memory_config_id:
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
memory_config_service = MemoryConfigService(db)
|
||||||
|
memory_config_id, _ = memory_config_service.extract_memory_config_id(release.type, release.config or {})
|
||||||
|
if memory_config_id:
|
||||||
|
new_end_user.memory_config_id = memory_config_id
|
||||||
|
db.commit()
|
||||||
|
db.refresh(new_end_user)
|
||||||
end_user_id = str(new_end_user.id)
|
end_user_id = str(new_end_user.id)
|
||||||
|
|
||||||
appid = share.app_id
|
# appid = share.app_id
|
||||||
"""获取存储类型和工作空间的ID"""
|
"""获取存储类型和工作空间的ID"""
|
||||||
|
|
||||||
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
|
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
|
||||||
app = db.query(App).filter(
|
# app = db.query(App).filter(
|
||||||
App.id == appid,
|
# App.id == appid,
|
||||||
App.is_active.is_(True)
|
# App.is_active.is_(True)
|
||||||
).first()
|
# ).first()
|
||||||
if not app:
|
# if not app:
|
||||||
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
# raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
||||||
|
|
||||||
workspace_id = app.workspace_id
|
# workspace_id = app.workspace_id
|
||||||
|
|
||||||
# 直接从 workspace 获取 storage_type(公开分享场景无需权限检查)
|
# 直接从 workspace 获取 storage_type(公开分享场景无需权限检查)
|
||||||
storage_type = workspace_service.get_workspace_storage_type_without_auth(
|
storage_type = workspace_service.get_workspace_storage_type_without_auth(
|
||||||
@@ -402,31 +477,10 @@ async def chat(
|
|||||||
# 流式返回
|
# 流式返回
|
||||||
agent_config = agent_config_4_app_release(release)
|
agent_config = agent_config_4_app_release(release)
|
||||||
|
|
||||||
if payload.stream:
|
if not (agent_config.model_parameters.get("deep_thinking", False) and payload.thinking):
|
||||||
# async def event_generator():
|
agent_config.model_parameters["deep_thinking"] = False
|
||||||
# async for event in service.chat_stream(
|
|
||||||
# share_token=share_token,
|
|
||||||
# message=payload.message,
|
|
||||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
|
||||||
# user_id=str(new_end_user.id), # 转换为字符串
|
|
||||||
# variables=payload.variables,
|
|
||||||
# password=password,
|
|
||||||
# web_search=payload.web_search,
|
|
||||||
# memory=payload.memory,
|
|
||||||
# storage_type=storage_type,
|
|
||||||
# user_rag_memory_id=user_rag_memory_id
|
|
||||||
# ):
|
|
||||||
# yield event
|
|
||||||
|
|
||||||
# return StreamingResponse(
|
if payload.stream:
|
||||||
# event_generator(),
|
|
||||||
# media_type="text/event-stream",
|
|
||||||
# headers={
|
|
||||||
# "Cache-Control": "no-cache",
|
|
||||||
# "Connection": "keep-alive",
|
|
||||||
# "X-Accel-Buffering": "no"
|
|
||||||
# }
|
|
||||||
# )
|
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
async for event in app_chat_service.agnet_chat_stream(
|
async for event in app_chat_service.agnet_chat_stream(
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
@@ -452,20 +506,6 @@ async def chat(
|
|||||||
"X-Accel-Buffering": "no"
|
"X-Accel-Buffering": "no"
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
# 非流式返回
|
|
||||||
# result = await service.chat(
|
|
||||||
# share_token=share_token,
|
|
||||||
# message=payload.message,
|
|
||||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
|
||||||
# user_id=str(new_end_user.id), # 转换为字符串
|
|
||||||
# variables=payload.variables,
|
|
||||||
# password=password,
|
|
||||||
# web_search=payload.web_search,
|
|
||||||
# memory=payload.memory,
|
|
||||||
# storage_type=storage_type,
|
|
||||||
# user_rag_memory_id=user_rag_memory_id
|
|
||||||
# )
|
|
||||||
# return success(data=conversation_schema.ChatResponse(**result))
|
|
||||||
result = await app_chat_service.agnet_chat(
|
result = await app_chat_service.agnet_chat(
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
@@ -524,48 +564,6 @@ async def chat(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||||
# 多 Agent 流式返回
|
|
||||||
# if payload.stream:
|
|
||||||
# async def event_generator():
|
|
||||||
# async for event in service.multi_agent_chat_stream(
|
|
||||||
# share_token=share_token,
|
|
||||||
# message=payload.message,
|
|
||||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
|
||||||
# user_id=str(new_end_user.id), # 转换为字符串
|
|
||||||
# variables=payload.variables,
|
|
||||||
# password=password,
|
|
||||||
# web_search=payload.web_search,
|
|
||||||
# memory=payload.memory,
|
|
||||||
# storage_type=storage_type,
|
|
||||||
# user_rag_memory_id=user_rag_memory_id
|
|
||||||
# ):
|
|
||||||
# yield event
|
|
||||||
|
|
||||||
# return StreamingResponse(
|
|
||||||
# event_generator(),
|
|
||||||
# media_type="text/event-stream",
|
|
||||||
# headers={
|
|
||||||
# "Cache-Control": "no-cache",
|
|
||||||
# "Connection": "keep-alive",
|
|
||||||
# "X-Accel-Buffering": "no"
|
|
||||||
# }
|
|
||||||
# )
|
|
||||||
|
|
||||||
# # 多 Agent 非流式返回
|
|
||||||
# result = await service.multi_agent_chat(
|
|
||||||
# share_token=share_token,
|
|
||||||
# message=payload.message,
|
|
||||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
|
||||||
# user_id=str(new_end_user.id), # 转换为字符串
|
|
||||||
# variables=payload.variables,
|
|
||||||
# password=password,
|
|
||||||
# web_search=payload.web_search,
|
|
||||||
# memory=payload.memory,
|
|
||||||
# storage_type=storage_type,
|
|
||||||
# user_rag_memory_id=user_rag_memory_id
|
|
||||||
# )
|
|
||||||
|
|
||||||
# return success(data=conversation_schema.ChatResponse(**result))
|
|
||||||
elif app_type == AppType.WORKFLOW:
|
elif app_type == AppType.WORKFLOW:
|
||||||
config = workflow_config_4_app_release(release)
|
config = workflow_config_4_app_release(release)
|
||||||
if not config.id:
|
if not config.id:
|
||||||
@@ -610,11 +608,11 @@ async def chat(
|
|||||||
|
|
||||||
# 多 Agent 非流式返回
|
# 多 Agent 非流式返回
|
||||||
result = await app_chat_service.workflow_chat(
|
result = await app_chat_service.workflow_chat(
|
||||||
|
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
user_id=end_user_id, # 转换为字符串
|
user_id=end_user_id, # 转换为字符串
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
|
files=payload.files,
|
||||||
config=config,
|
config=config,
|
||||||
web_search=payload.web_search,
|
web_search=payload.web_search,
|
||||||
memory=payload.memory,
|
memory=payload.memory,
|
||||||
@@ -654,17 +652,23 @@ async def config_query(
|
|||||||
workflow_service = WorkflowService(db)
|
workflow_service = WorkflowService(db)
|
||||||
content = {
|
content = {
|
||||||
"app_type": release.app.type,
|
"app_type": release.app.type,
|
||||||
"variables": workflow_service.get_start_node_variables(release.config)
|
"variables": workflow_service.get_start_node_variables(release.config),
|
||||||
|
"memory": workflow_service.is_memory_enable(release.config),
|
||||||
|
"features": release.config.get("features")
|
||||||
}
|
}
|
||||||
elif release.app.type == AppType.AGENT:
|
elif release.app.type == AppType.AGENT:
|
||||||
content = {
|
content = {
|
||||||
"app_type": release.app.type,
|
"app_type": release.app.type,
|
||||||
"variables": release.config.get("variables")
|
"variables": release.config.get("variables"),
|
||||||
|
"memory": release.config.get("memory", {}).get("enabled"),
|
||||||
|
"features": release.config.get("features"),
|
||||||
|
"model_parameters": release.config.get("model_parameters")
|
||||||
}
|
}
|
||||||
elif release.app.type == AppType.MULTI_AGENT:
|
elif release.app.type == AppType.MULTI_AGENT:
|
||||||
content = {
|
content = {
|
||||||
"app_type": release.app.type,
|
"app_type": release.app.type,
|
||||||
"variables": []
|
"variables": [],
|
||||||
|
"features": release.config.get("features")
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED)
|
return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||||
|
|||||||
@@ -4,7 +4,18 @@
|
|||||||
认证方式: API Key
|
认证方式: API Key
|
||||||
"""
|
"""
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller
|
|
||||||
|
from . import (
|
||||||
|
app_api_controller,
|
||||||
|
end_user_api_controller,
|
||||||
|
memory_api_controller,
|
||||||
|
memory_config_api_controller,
|
||||||
|
rag_api_chunk_controller,
|
||||||
|
rag_api_document_controller,
|
||||||
|
rag_api_file_controller,
|
||||||
|
rag_api_knowledge_controller,
|
||||||
|
user_memory_api_controller,
|
||||||
|
)
|
||||||
|
|
||||||
# 创建 V1 API 路由器
|
# 创建 V1 API 路由器
|
||||||
service_router = APIRouter()
|
service_router = APIRouter()
|
||||||
@@ -16,5 +27,8 @@ service_router.include_router(rag_api_document_controller.router)
|
|||||||
service_router.include_router(rag_api_file_controller.router)
|
service_router.include_router(rag_api_file_controller.router)
|
||||||
service_router.include_router(rag_api_chunk_controller.router)
|
service_router.include_router(rag_api_chunk_controller.router)
|
||||||
service_router.include_router(memory_api_controller.router)
|
service_router.include_router(memory_api_controller.router)
|
||||||
|
service_router.include_router(end_user_api_controller.router)
|
||||||
|
service_router.include_router(memory_config_api_controller.router)
|
||||||
|
service_router.include_router(user_memory_api_controller.router)
|
||||||
|
|
||||||
__all__ = ["service_router"]
|
__all__ = ["service_router"]
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from app.core.response_utils import success
|
|||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.models.app_model import App
|
from app.models.app_model import App
|
||||||
from app.models.app_model import AppType
|
from app.models.app_model import AppType
|
||||||
|
from app.models.app_release_model import AppRelease
|
||||||
from app.repositories import knowledge_repository
|
from app.repositories import knowledge_repository
|
||||||
from app.repositories.end_user_repository import EndUserRepository
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
from app.schemas import AppChatRequest, conversation_schema
|
from app.schemas import AppChatRequest, conversation_schema
|
||||||
@@ -61,18 +62,18 @@ async def list_apps():
|
|||||||
# return success(data={"received": True}, msg="消息已接收")
|
# return success(data={"received": True}, msg="消息已接收")
|
||||||
|
|
||||||
|
|
||||||
def _checkAppConfig(app: App):
|
def _checkAppConfig(release: AppRelease):
|
||||||
if app.type == AppType.AGENT:
|
if release.type == AppType.AGENT:
|
||||||
if not app.current_release.config:
|
if not release.config:
|
||||||
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||||
elif app.type == AppType.MULTI_AGENT:
|
elif release.type == AppType.MULTI_AGENT:
|
||||||
if not app.current_release.config:
|
if not release.config:
|
||||||
raise BusinessException("Multi-Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
raise BusinessException("Multi-Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||||
elif app.type == AppType.WORKFLOW:
|
elif release.type == AppType.WORKFLOW:
|
||||||
if not app.current_release.config:
|
if not release.config:
|
||||||
raise BusinessException("工作流应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
raise BusinessException("工作流应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||||
else:
|
else:
|
||||||
raise BusinessException("不支持的应用类型", BizCode.AGENT_CONFIG_MISSING)
|
raise BusinessException("不支持的应用类型", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/chat")
|
@router.post("/chat")
|
||||||
@@ -86,17 +87,39 @@ async def chat(
|
|||||||
app_service: Annotated[AppService, Depends(get_app_service)] = None,
|
app_service: Annotated[AppService, Depends(get_app_service)] = None,
|
||||||
message: str = Body(..., description="聊天消息内容"),
|
message: str = Body(..., description="聊天消息内容"),
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Agent/Workflow 聊天接口
|
||||||
|
|
||||||
|
- 不传 version:使用当前生效版本(current_release,回滚后为回滚目标版本)
|
||||||
|
- 传 version=release_id:使用指定版本uuid的历史快照,例如 {"version": "{{release_id}}"}
|
||||||
|
"""
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
payload = AppChatRequest(**body)
|
payload = AppChatRequest(**body)
|
||||||
|
|
||||||
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
|
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
|
||||||
|
|
||||||
|
# 版本切换:指定 release_id 时查找对应历史快照,否则使用当前激活版本
|
||||||
|
if payload.version is not None:
|
||||||
|
active_release = app_service.get_release_by_id(app.id, payload.version)
|
||||||
|
else:
|
||||||
|
active_release = app.current_release
|
||||||
other_id = payload.user_id
|
other_id = payload.user_id
|
||||||
workspace_id = app.workspace_id
|
workspace_id = api_key_auth.workspace_id
|
||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
|
|
||||||
|
# 仅在新建终端用户时检查配额,已有用户复用不受限制
|
||||||
|
existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id)
|
||||||
|
if existing_end_user is None:
|
||||||
|
from app.core.quota_manager import _check_quota
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
|
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||||
|
if ws:
|
||||||
|
_check_quota(db, ws.tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
|
||||||
|
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
app_id=app.id,
|
app_id=app.id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
other_id=other_id,
|
other_id=other_id,
|
||||||
original_user_id=other_id # Save original user_id to other_id
|
|
||||||
)
|
)
|
||||||
end_user_id = str(new_end_user.id)
|
end_user_id = str(new_end_user.id)
|
||||||
web_search = True
|
web_search = True
|
||||||
@@ -127,7 +150,7 @@ async def chat(
|
|||||||
storage_type = 'neo4j'
|
storage_type = 'neo4j'
|
||||||
app_type = app.type
|
app_type = app.type
|
||||||
# check app config
|
# check app config
|
||||||
_checkAppConfig(app)
|
_checkAppConfig(active_release)
|
||||||
|
|
||||||
# 获取或创建会话(提前验证)
|
# 获取或创建会话(提前验证)
|
||||||
conversation = conversation_service.create_or_get_conversation(
|
conversation = conversation_service.create_or_get_conversation(
|
||||||
@@ -142,8 +165,13 @@ async def chat(
|
|||||||
|
|
||||||
# print("="*50)
|
# print("="*50)
|
||||||
# print(app.current_release.default_model_config_id)
|
# print(app.current_release.default_model_config_id)
|
||||||
agent_config = agent_config_4_app_release(app.current_release)
|
agent_config = agent_config_4_app_release(active_release)
|
||||||
# print(agent_config.default_model_config_id)
|
# print(agent_config.default_model_config_id)
|
||||||
|
|
||||||
|
# thinking 开关:仅当 agent 配置了 deep_thinking 且请求 thinking=True 时才启用
|
||||||
|
if not (agent_config.model_parameters.get("deep_thinking", False) and payload.thinking):
|
||||||
|
agent_config.model_parameters["deep_thinking"] = False
|
||||||
|
|
||||||
# 流式返回
|
# 流式返回
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
@@ -189,7 +217,7 @@ async def chat(
|
|||||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||||
elif app_type == AppType.MULTI_AGENT:
|
elif app_type == AppType.MULTI_AGENT:
|
||||||
# 多 Agent 流式返回
|
# 多 Agent 流式返回
|
||||||
config = multi_agent_config_4_app_release(app.current_release)
|
config = multi_agent_config_4_app_release(active_release)
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
async for event in app_chat_service.multi_agent_chat_stream(
|
async for event in app_chat_service.multi_agent_chat_stream(
|
||||||
@@ -232,7 +260,7 @@ async def chat(
|
|||||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||||
elif app_type == AppType.WORKFLOW:
|
elif app_type == AppType.WORKFLOW:
|
||||||
# 多 Agent 流式返回
|
# 多 Agent 流式返回
|
||||||
config = workflow_config_4_app_release(app.current_release)
|
config = workflow_config_4_app_release(active_release)
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
async for event in app_chat_service.workflow_chat_stream(
|
async for event in app_chat_service.workflow_chat_stream(
|
||||||
@@ -248,7 +276,7 @@ async def chat(
|
|||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
app_id=app.id,
|
app_id=app.id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
release_id=app.current_release.id,
|
release_id=active_release.id,
|
||||||
public=True
|
public=True
|
||||||
):
|
):
|
||||||
event_type = event.get("event", "message")
|
event_type = event.get("event", "message")
|
||||||
@@ -268,7 +296,7 @@ async def chat(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 多 Agent 非流式返回
|
# workflow 非流式返回
|
||||||
result = await app_chat_service.workflow_chat(
|
result = await app_chat_service.workflow_chat(
|
||||||
|
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
@@ -280,9 +308,10 @@ async def chat(
|
|||||||
memory=memory,
|
memory=memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
|
files=payload.files,
|
||||||
app_id=app.id,
|
app_id=app.id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
release_id=app.current_release.id
|
release_id=active_release.id
|
||||||
)
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"工作流试运行返回结果",
|
"工作流试运行返回结果",
|
||||||
@@ -296,6 +325,4 @@ async def chat(
|
|||||||
msg="工作流任务执行成功"
|
msg="工作流任务执行成功"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
from app.core.exceptions import BusinessException
|
|
||||||
from app.core.error_codes import BizCode
|
|
||||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||||
|
|||||||
173
api/app/controllers/service/end_user_api_controller.py
Normal file
173
api/app/controllers/service/end_user_api_controller.py
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
"""End User 服务接口 - 基于 API Key 认证"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Body, Depends, Request
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.controllers import user_memory_controllers
|
||||||
|
from app.core.api_key_auth import require_api_key
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.quota_stub import check_end_user_quota
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.db import get_db
|
||||||
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
|
from app.schemas.api_key_schema import ApiKeyAuth
|
||||||
|
from app.schemas.end_user_info_schema import EndUserInfoUpdate
|
||||||
|
from app.schemas.memory_api_schema import CreateEndUserRequest, CreateEndUserResponse
|
||||||
|
from app.services import api_key_service
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/end_user", tags=["V1 - End User API"])
|
||||||
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_current_user(api_key_auth: ApiKeyAuth, db: Session):
|
||||||
|
"""Build a current_user object from API key auth
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key_auth: Validated API key auth info
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User object with current_workspace_id set
|
||||||
|
"""
|
||||||
|
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||||
|
current_user = api_key.creator
|
||||||
|
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/create")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
@check_end_user_quota
|
||||||
|
async def create_end_user(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
message: str = Body(None, description="Request body"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create or retrieve an end user for the workspace.
|
||||||
|
|
||||||
|
Creates a new end user and connects it to a memory configuration.
|
||||||
|
If an end user with the same other_id already exists in the workspace,
|
||||||
|
returns the existing one.
|
||||||
|
|
||||||
|
Optionally accepts a memory_config_id to connect the end user to a specific
|
||||||
|
memory configuration. If not provided, falls back to the workspace default config.
|
||||||
|
Optionally accepts an app_id to bind the end user to a specific app.
|
||||||
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = CreateEndUserRequest(**body)
|
||||||
|
workspace_id = api_key_auth.workspace_id
|
||||||
|
|
||||||
|
logger.info("Create end user request - other_id: %s, workspace_id: %s", payload.other_id, workspace_id)
|
||||||
|
|
||||||
|
# Resolve memory_config_id: explicit > workspace default
|
||||||
|
memory_config_id = None
|
||||||
|
config_service = MemoryConfigService(db)
|
||||||
|
|
||||||
|
if payload.memory_config_id:
|
||||||
|
try:
|
||||||
|
memory_config_id = uuid.UUID(payload.memory_config_id)
|
||||||
|
except ValueError:
|
||||||
|
raise BusinessException(
|
||||||
|
f"Invalid memory_config_id format: {payload.memory_config_id}",
|
||||||
|
BizCode.INVALID_PARAMETER
|
||||||
|
)
|
||||||
|
config = config_service.get_config_with_fallback(memory_config_id, workspace_id)
|
||||||
|
if not config:
|
||||||
|
raise BusinessException(
|
||||||
|
f"Memory config not found: {payload.memory_config_id}",
|
||||||
|
BizCode.MEMORY_CONFIG_NOT_FOUND
|
||||||
|
)
|
||||||
|
memory_config_id = config.config_id
|
||||||
|
else:
|
||||||
|
default_config = config_service.get_workspace_default_config(workspace_id)
|
||||||
|
if default_config:
|
||||||
|
memory_config_id = default_config.config_id
|
||||||
|
logger.info(f"Using workspace default memory config: {memory_config_id}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"No default memory config found for workspace: {workspace_id}")
|
||||||
|
|
||||||
|
# Resolve app_id: explicit from payload, otherwise None
|
||||||
|
app_id = None
|
||||||
|
if payload.app_id:
|
||||||
|
try:
|
||||||
|
app_id = uuid.UUID(payload.app_id)
|
||||||
|
except ValueError:
|
||||||
|
raise BusinessException(
|
||||||
|
f"Invalid app_id format: {payload.app_id}",
|
||||||
|
BizCode.INVALID_PARAMETER
|
||||||
|
)
|
||||||
|
|
||||||
|
end_user_repo = EndUserRepository(db)
|
||||||
|
end_user = end_user_repo.get_or_create_end_user_with_config(
|
||||||
|
app_id=app_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
other_id=payload.other_id,
|
||||||
|
memory_config_id=memory_config_id,
|
||||||
|
other_name=payload.other_name,
|
||||||
|
)
|
||||||
|
end_user.other_name = payload.other_name
|
||||||
|
logger.info(f"End user ready: {end_user.id}")
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"id": str(end_user.id),
|
||||||
|
"other_id": end_user.other_id or "",
|
||||||
|
"other_name": end_user.other_name or "",
|
||||||
|
"workspace_id": str(end_user.workspace_id),
|
||||||
|
"memory_config_id": str(end_user.memory_config_id) if end_user.memory_config_id else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/info")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def get_end_user_info(
|
||||||
|
request: Request,
|
||||||
|
end_user_id: str,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get end user info.
|
||||||
|
|
||||||
|
Retrieves the info record (aliases, meta_data, etc.) for the specified end user.
|
||||||
|
Delegates to the manager-side controller for shared logic.
|
||||||
|
"""
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
return await user_memory_controllers.get_end_user_info(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
current_user=current_user,
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/info/update")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def update_end_user_info(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
message: str = Body(None, description="Request body"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update end user info.
|
||||||
|
|
||||||
|
Updates the info record (other_name, aliases, meta_data) for the specified end user.
|
||||||
|
Delegates to the manager-side controller for shared logic.
|
||||||
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = EndUserInfoUpdate(**body)
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
return await user_memory_controllers.update_end_user_info(
|
||||||
|
info_update=payload,
|
||||||
|
current_user=current_user,
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
@@ -1,49 +1,84 @@
|
|||||||
"""Memory 服务接口 - 基于 API Key 认证"""
|
"""Memory 服务接口 - 基于 API Key 认证"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Body, Depends, Query, Request
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.celery_task_scheduler import scheduler
|
||||||
from app.core.api_key_auth import require_api_key
|
from app.core.api_key_auth import require_api_key
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.quota_stub import check_end_user_quota
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.schemas.api_key_schema import ApiKeyAuth
|
from app.schemas.api_key_schema import ApiKeyAuth
|
||||||
from app.schemas.memory_api_schema import (
|
from app.schemas.memory_api_schema import (
|
||||||
MemoryReadRequest,
|
MemoryReadRequest,
|
||||||
MemoryReadResponse,
|
MemoryReadResponse,
|
||||||
|
MemoryReadSyncResponse,
|
||||||
MemoryWriteRequest,
|
MemoryWriteRequest,
|
||||||
MemoryWriteResponse,
|
MemoryWriteResponse,
|
||||||
|
MemoryWriteSyncResponse,
|
||||||
)
|
)
|
||||||
from app.services.memory_api_service import MemoryAPIService
|
from app.services.memory_api_service import MemoryAPIService
|
||||||
from fastapi import APIRouter, Body, Depends, Request
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/memory", tags=["V1 - Memory API"])
|
router = APIRouter(prefix="/memory", tags=["V1 - Memory API"])
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_task_result(result: dict) -> dict:
|
||||||
|
"""Make Celery task result JSON-serializable.
|
||||||
|
|
||||||
|
Converts UUID and other non-serializable values to strings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: Raw task result dict from task_service
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON-safe dict
|
||||||
|
"""
|
||||||
|
import uuid as _uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
def _convert(obj):
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return {k: _convert(v) for k, v in obj.items()}
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return [_convert(i) for i in obj]
|
||||||
|
if isinstance(obj, _uuid.UUID):
|
||||||
|
return str(obj)
|
||||||
|
if isinstance(obj, datetime):
|
||||||
|
return obj.isoformat()
|
||||||
|
return obj
|
||||||
|
|
||||||
|
return _convert(result)
|
||||||
|
|
||||||
|
|
||||||
@router.get("")
|
@router.get("")
|
||||||
async def get_memory_info():
|
async def get_memory_info():
|
||||||
"""获取记忆服务信息(占位)"""
|
"""获取记忆服务信息(占位)"""
|
||||||
return success(data={}, msg="Memory API - Coming Soon")
|
return success(data={}, msg="Memory API - Coming Soon")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/write_api_service")
|
@router.post("/write")
|
||||||
@require_api_key(scopes=["memory"])
|
@require_api_key(scopes=["memory"])
|
||||||
async def write_memory_api_service(
|
async def write_memory(
|
||||||
request: Request,
|
request: Request,
|
||||||
api_key_auth: ApiKeyAuth = None,
|
api_key_auth: ApiKeyAuth = None,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
payload: MemoryWriteRequest = Body(..., embed=False),
|
message: str = Body(..., description="Message content"),
|
||||||
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Write memory to storage.
|
Submit a memory write task.
|
||||||
|
|
||||||
Stores memory content for the specified end user using the Memory API Service.
|
Validates the end user, then dispatches the write to a Celery background task
|
||||||
|
with per-user fair locking. Returns a task_id for status polling.
|
||||||
"""
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = MemoryWriteRequest(**body)
|
||||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
|
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
memory_api_service = MemoryAPIService(db)
|
memory_api_service = MemoryAPIService(db)
|
||||||
|
|
||||||
result = await memory_api_service.write_memory(
|
result = memory_api_service.write_memory(
|
||||||
workspace_id=api_key_auth.workspace_id,
|
workspace_id=api_key_auth.workspace_id,
|
||||||
end_user_id=payload.end_user_id,
|
end_user_id=payload.end_user_id,
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
@@ -52,28 +87,51 @@ async def write_memory_api_service(
|
|||||||
user_rag_memory_id=payload.user_rag_memory_id,
|
user_rag_memory_id=payload.user_rag_memory_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Memory write successful for end_user: {payload.end_user_id}")
|
logger.info(f"Memory write task submitted: task_id: {result['task_id']} end_user_id: {payload.end_user_id}")
|
||||||
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory written successfully")
|
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/read_api_service")
|
@router.get("/write/status")
|
||||||
@require_api_key(scopes=["memory"])
|
@require_api_key(scopes=["memory"])
|
||||||
async def read_memory_api_service(
|
async def get_write_task_status(
|
||||||
|
request: Request,
|
||||||
|
task_id: str = Query(..., description="Celery task ID"),
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Check the status of a memory write task.
|
||||||
|
|
||||||
|
Returns the current status and result (if completed) of a previously submitted write task.
|
||||||
|
"""
|
||||||
|
logger.info(f"Write task status check - task_id: {task_id}")
|
||||||
|
|
||||||
|
result = scheduler.get_task_status(task_id)
|
||||||
|
|
||||||
|
return success(data=_sanitize_task_result(result), msg="Task status retrieved")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/read")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def read_memory(
|
||||||
request: Request,
|
request: Request,
|
||||||
api_key_auth: ApiKeyAuth = None,
|
api_key_auth: ApiKeyAuth = None,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
payload: MemoryReadRequest = Body(..., embed=False),
|
message: str = Body(..., description="Query message"),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Read memory from storage.
|
Submit a memory read task.
|
||||||
|
|
||||||
Queries and retrieves memories for the specified end user with context-aware responses.
|
Validates the end user, then dispatches the read to a Celery background task.
|
||||||
|
Returns a task_id for status polling.
|
||||||
"""
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = MemoryReadRequest(**body)
|
||||||
logger.info(f"Memory read request - end_user_id: {payload.end_user_id}")
|
logger.info(f"Memory read request - end_user_id: {payload.end_user_id}")
|
||||||
|
|
||||||
memory_api_service = MemoryAPIService(db)
|
memory_api_service = MemoryAPIService(db)
|
||||||
|
|
||||||
result = await memory_api_service.read_memory(
|
result = memory_api_service.read_memory(
|
||||||
workspace_id=api_key_auth.workspace_id,
|
workspace_id=api_key_auth.workspace_id,
|
||||||
end_user_id=payload.end_user_id,
|
end_user_id=payload.end_user_id,
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
@@ -83,5 +141,94 @@ async def read_memory_api_service(
|
|||||||
user_rag_memory_id=payload.user_rag_memory_id,
|
user_rag_memory_id=payload.user_rag_memory_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Memory read successful for end_user: {payload.end_user_id}")
|
logger.info(f"Memory read task submitted: task_id={result['task_id']}, end_user_id: {payload.end_user_id}")
|
||||||
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read successfully")
|
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read task submitted")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/read/status")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def get_read_task_status(
|
||||||
|
request: Request,
|
||||||
|
task_id: str = Query(..., description="Celery task ID"),
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Check the status of a memory read task.
|
||||||
|
|
||||||
|
Returns the current status and result (if completed) of a previously submitted read task.
|
||||||
|
"""
|
||||||
|
logger.info(f"Read task status check - task_id: {task_id}")
|
||||||
|
|
||||||
|
from app.services.task_service import get_task_memory_read_result
|
||||||
|
result = get_task_memory_read_result(task_id)
|
||||||
|
|
||||||
|
return success(data=_sanitize_task_result(result), msg="Task status retrieved")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/write/sync")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
@check_end_user_quota
|
||||||
|
async def write_memory_sync(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
message: str = Body(..., description="Message content"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Write memory synchronously.
|
||||||
|
|
||||||
|
Blocks until the write completes and returns the result directly.
|
||||||
|
For async processing with task polling, use /write instead.
|
||||||
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = MemoryWriteRequest(**body)
|
||||||
|
logger.info(f"Memory write (sync) request - end_user_id: {payload.end_user_id}")
|
||||||
|
|
||||||
|
memory_api_service = MemoryAPIService(db)
|
||||||
|
|
||||||
|
result = await memory_api_service.write_memory_sync(
|
||||||
|
workspace_id=api_key_auth.workspace_id,
|
||||||
|
end_user_id=payload.end_user_id,
|
||||||
|
message=payload.message,
|
||||||
|
config_id=payload.config_id,
|
||||||
|
storage_type=payload.storage_type,
|
||||||
|
user_rag_memory_id=payload.user_rag_memory_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Memory write (sync) successful for end_user: {payload.end_user_id}")
|
||||||
|
return success(data=MemoryWriteSyncResponse(**result).model_dump(), msg="Memory written successfully")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/read/sync")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def read_memory_sync(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
message: str = Body(..., description="Query message"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Read memory synchronously.
|
||||||
|
|
||||||
|
Blocks until the read completes and returns the answer directly.
|
||||||
|
For async processing with task polling, use /read instead.
|
||||||
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = MemoryReadRequest(**body)
|
||||||
|
logger.info(f"Memory read (sync) request - end_user_id: {payload.end_user_id}")
|
||||||
|
|
||||||
|
memory_api_service = MemoryAPIService(db)
|
||||||
|
|
||||||
|
result = await memory_api_service.read_memory_sync(
|
||||||
|
workspace_id=api_key_auth.workspace_id,
|
||||||
|
end_user_id=payload.end_user_id,
|
||||||
|
message=payload.message,
|
||||||
|
search_switch=payload.search_switch,
|
||||||
|
config_id=payload.config_id,
|
||||||
|
storage_type=payload.storage_type,
|
||||||
|
user_rag_memory_id=payload.user_rag_memory_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Memory read (sync) successful for end_user: {payload.end_user_id}")
|
||||||
|
return success(data=MemoryReadSyncResponse(**result).model_dump(), msg="Memory read successfully")
|
||||||
|
|||||||
491
api/app/controllers/service/memory_config_api_controller.py
Normal file
491
api/app/controllers/service/memory_config_api_controller.py
Normal file
@@ -0,0 +1,491 @@
|
|||||||
|
"""Memory Config 服务接口 - 基于 API Key 认证"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Body, Depends, Header, Query, Request
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.controllers import memory_storage_controller
|
||||||
|
from app.controllers import memory_forget_controller
|
||||||
|
from app.controllers import ontology_controller
|
||||||
|
from app.controllers import emotion_config_controller
|
||||||
|
from app.controllers import memory_reflection_controller
|
||||||
|
from app.schemas.memory_storage_schema import ForgettingConfigUpdateRequest
|
||||||
|
from app.controllers.emotion_config_controller import EmotionConfigUpdate
|
||||||
|
from app.schemas.memory_reflection_schemas import Memory_Reflection
|
||||||
|
from app.core.api_key_auth import require_api_key
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.db import get_db
|
||||||
|
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||||
|
from app.schemas.api_key_schema import ApiKeyAuth
|
||||||
|
from app.schemas.memory_api_schema import (
|
||||||
|
ConfigUpdateExtractedRequest,
|
||||||
|
ConfigUpdateRequest,
|
||||||
|
ListConfigsResponse,
|
||||||
|
ConfigCreateRequest,
|
||||||
|
ConfigUpdateForgettingRequest,
|
||||||
|
EmotionConfigUpdateRequest,
|
||||||
|
ReflectionConfigUpdateRequest,
|
||||||
|
)
|
||||||
|
from app.schemas.memory_storage_schema import (
|
||||||
|
ConfigUpdate,
|
||||||
|
ConfigUpdateExtracted,
|
||||||
|
ConfigParamsCreate,
|
||||||
|
)
|
||||||
|
from app.services import api_key_service
|
||||||
|
from app.services.memory_api_service import MemoryAPIService
|
||||||
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/memory_config", tags=["V1 - Memory Config API"])
|
||||||
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_current_user(api_key_auth: ApiKeyAuth, db: Session):
|
||||||
|
"""Build a current_user object from API key auth
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key_auth: Validated API key auth info
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User object with current_workspace_id set
|
||||||
|
"""
|
||||||
|
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||||
|
current_user = api_key.creator
|
||||||
|
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_config_ownership(config_id:str, workspace_id:uuid.UUID, db:Session):
|
||||||
|
"""Verify that the config belongs to the workspace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_id: The ID of the config to verify
|
||||||
|
workspace_id: The workspace ID tocheck against
|
||||||
|
db: Database session for querying
|
||||||
|
Raises:
|
||||||
|
BusinessException: If the config does not exist or does not belong to the workspace
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
resolved_id = resolve_config_id(config_id, db)
|
||||||
|
except ValueError as e:
|
||||||
|
raise BusinessException(
|
||||||
|
message=f"Invalid config_id: {e}",
|
||||||
|
code=BizCode.INVALID_PARAMETER,
|
||||||
|
)
|
||||||
|
config = MemoryConfigRepository.get_by_id(db, resolved_id)
|
||||||
|
if not config or config.workspace_id != workspace_id:
|
||||||
|
raise BusinessException(
|
||||||
|
message="Config not found or access denied",
|
||||||
|
code=BizCode.MEMORY_CONFIG_NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# @router.get("/configs")
|
||||||
|
# @require_api_key(scopes=["memory"])
|
||||||
|
# async def list_memory_configs(
|
||||||
|
# request: Request,
|
||||||
|
# api_key_auth: ApiKeyAuth = None,
|
||||||
|
# db: Session = Depends(get_db),
|
||||||
|
# ):
|
||||||
|
# """
|
||||||
|
# List all memory configs for the workspace.
|
||||||
|
|
||||||
|
# Returns all available memory configurations associated with the authorized workspace.
|
||||||
|
# """
|
||||||
|
# logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
# memory_api_service = MemoryAPIService(db)
|
||||||
|
|
||||||
|
# result = memory_api_service.list_memory_configs(
|
||||||
|
# workspace_id=api_key_auth.workspace_id,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}")
|
||||||
|
# return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully")
|
||||||
|
|
||||||
|
@router.get("/read_all_config")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def read_all_config(
|
||||||
|
request:Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
List all memory configs with full details (enhanced version).
|
||||||
|
|
||||||
|
Returns complete config fields for the authorized workspace.
|
||||||
|
No config_id ownership check needed — results are filtered by workspace.
|
||||||
|
"""
|
||||||
|
logger.info(f"V1 get all configs (full) - workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
|
||||||
|
return memory_storage_controller.read_all_config(
|
||||||
|
current_user=current_user,
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("/scenes/simple")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def get_ontology_scenes(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get available ontology scenes for the workspace.
|
||||||
|
|
||||||
|
Returns a simple list of scene_id and scene_name for dropdown selection.
|
||||||
|
Used before creating a memory config to choose which ontology scene to associate.
|
||||||
|
"""
|
||||||
|
logger.info(f"V1 get scenes - workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
|
||||||
|
return await ontology_controller.get_scenes_simple(
|
||||||
|
db=db,
|
||||||
|
current_user=current_user,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("/read_config_extracted")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def read_config_extracted(
|
||||||
|
request: Request,
|
||||||
|
config_id: str = Query(..., description="config_id"),
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get extraction engine config details for a specific config.
|
||||||
|
|
||||||
|
Only configs belonging to the authorized workspace can be queried.
|
||||||
|
"""
|
||||||
|
logger.info(f"V1 read extracted config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
|
||||||
|
return memory_storage_controller.read_config_extracted(
|
||||||
|
config_id = config_id,
|
||||||
|
current_user = current_user,
|
||||||
|
db = db,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("/read_config_forgetting")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def read_config_forgetting(
|
||||||
|
request: Request,
|
||||||
|
config_id: str = Query(..., description="config_id"),
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get forgetting settings for a specific memory config.
|
||||||
|
|
||||||
|
Only configs belonging to the authorized workspace can be queried.
|
||||||
|
"""
|
||||||
|
logger.info(f"V1 read forgetting config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
|
||||||
|
result = await memory_forget_controller.read_forgetting_config(
|
||||||
|
config_id = config_id,
|
||||||
|
current_user = current_user,
|
||||||
|
db = db,
|
||||||
|
)
|
||||||
|
return jsonable_encoder(result)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/read_config_emotion")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def read_config_emotion(
|
||||||
|
request: Request,
|
||||||
|
config_id: str = Query(..., description="config_id"),
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get emotion engine config details for a specific config.
|
||||||
|
|
||||||
|
Only configs belonging to the authorized workspace can be queried.
|
||||||
|
"""
|
||||||
|
logger.info(f"V1 read emotion config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
|
||||||
|
return jsonable_encoder(emotion_config_controller.get_emotion_config(
|
||||||
|
config_id=config_id,
|
||||||
|
db=db,
|
||||||
|
current_user=current_user,
|
||||||
|
))
|
||||||
|
|
||||||
|
@router.get("/read_config_reflection")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def read_config_reflection(
|
||||||
|
request: Request,
|
||||||
|
config_id: str = Query(..., description="config_id"),
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get reflection engine config details for a specific config.
|
||||||
|
|
||||||
|
Only configs belonging to the authorized workspace can be queried.
|
||||||
|
"""
|
||||||
|
logger.info(f"V1 read reflection config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
|
||||||
|
return jsonable_encoder(await memory_reflection_controller.start_reflection_configs(
|
||||||
|
config_id=config_id,
|
||||||
|
current_user=current_user,
|
||||||
|
db=db,
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/create_config")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def create_memory_config(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
message: str = Body(None, description="Request body"),
|
||||||
|
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a new memory config for the workspace.
|
||||||
|
|
||||||
|
The config will be associated with the workspace of the API Key.
|
||||||
|
config_name is required, other fields are optional.
|
||||||
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = ConfigCreateRequest(**body)
|
||||||
|
|
||||||
|
logger.info(f"V1 create config - workspace: {api_key_auth.workspace_id}, config_name: {payload.config_name}")
|
||||||
|
|
||||||
|
# 构造管理端 Schema,workspace_id 从 API Key 注入
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
mgmt_payload = ConfigParamsCreate(
|
||||||
|
config_name=payload.config_name,
|
||||||
|
config_desc=payload.config_desc or "",
|
||||||
|
scene_id=payload.scene_id,
|
||||||
|
llm_id=payload.llm_id,
|
||||||
|
embedding_id=payload.embedding_id,
|
||||||
|
rerank_id=payload.rerank_id,
|
||||||
|
reflection_model_id=payload.reflection_model_id,
|
||||||
|
emotion_model_id=payload.emotion_model_id,
|
||||||
|
)
|
||||||
|
#将返回数据中UUID序列化处理
|
||||||
|
result =memory_storage_controller.create_config(
|
||||||
|
payload=mgmt_payload,
|
||||||
|
current_user=current_user,
|
||||||
|
db=db,
|
||||||
|
x_language_type=x_language_type,
|
||||||
|
)
|
||||||
|
return jsonable_encoder(result)
|
||||||
|
|
||||||
|
@router.put("/update_config")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def update_memory_config(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
message: str = Body(None, description="Request body"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update memory config basic info (name, description, scene).
|
||||||
|
|
||||||
|
Requires API Key with 'memory' scope
|
||||||
|
Only configs belonging to the authorized workspace can be updated.
|
||||||
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = ConfigUpdateRequest(**body)
|
||||||
|
|
||||||
|
logger.info(f"V1 update config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
mgmt_payload = ConfigUpdate(
|
||||||
|
config_id = payload.config_id,
|
||||||
|
config_name = payload.config_name,
|
||||||
|
config_desc = payload.config_desc,
|
||||||
|
scene_id = payload.scene_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return memory_storage_controller.update_config(
|
||||||
|
payload = mgmt_payload,
|
||||||
|
current_user = current_user,
|
||||||
|
db = db,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.put("/update_config_extracted")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def update_memory_config_extracted(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
message: str = Body(None, description="Request body"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
update memory config extraction engine config (models, thresholds, chunking, pruning, etc.).
|
||||||
|
|
||||||
|
Requires API Key with 'memory' scope.
|
||||||
|
Only configs belonging to the authorized workspace can be updated.
|
||||||
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = ConfigUpdateExtractedRequest(**body)
|
||||||
|
|
||||||
|
logger.info(f"V1 update extracted config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
#校验权限
|
||||||
|
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
update_fields = payload.model_dump(exclude_unset=True)
|
||||||
|
mgmt_payload = ConfigUpdateExtracted(**update_fields)
|
||||||
|
|
||||||
|
return memory_storage_controller.update_config_extracted(
|
||||||
|
payload = mgmt_payload,
|
||||||
|
current_user = current_user,
|
||||||
|
db = db,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.put("/update_config_forgetting")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def update_memory_config_forgetting(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
message: str = Body(None, description="Request body"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
update memory config forgetting settings (forgetting strategy, parameters, etc.).
|
||||||
|
|
||||||
|
Requires API Key with 'memory' scope.
|
||||||
|
Only configs belonging to the authorized workspace can be updated.
|
||||||
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = ConfigUpdateForgettingRequest(**body)
|
||||||
|
|
||||||
|
logger.info(f"V1 update forgetting config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
#校验权限
|
||||||
|
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
update_fields = payload.model_dump(exclude_unset=True)
|
||||||
|
mgmt_payload = ForgettingConfigUpdateRequest(**update_fields)
|
||||||
|
|
||||||
|
#将返回数据中UUID序列化处理
|
||||||
|
result = await memory_forget_controller.update_forgetting_config(
|
||||||
|
payload = mgmt_payload,
|
||||||
|
current_user = current_user,
|
||||||
|
db = db,
|
||||||
|
)
|
||||||
|
return jsonable_encoder(result)
|
||||||
|
|
||||||
|
@router.put("/update_config_emotion")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def update_config_emotion(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
message: str = Body(None, description="Request body"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update emotion engine config (full update).
|
||||||
|
|
||||||
|
All fields except emotion_model_id are required.
|
||||||
|
Only configs belonging to the authorized workspace can be updated.
|
||||||
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = EmotionConfigUpdateRequest(**body)
|
||||||
|
|
||||||
|
logger.info(f"V1 update emotion config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
update_fields = payload.model_dump(exclude_unset=True)
|
||||||
|
mgmt_payload = EmotionConfigUpdate(**update_fields)
|
||||||
|
return jsonable_encoder(emotion_config_controller.update_emotion_config(
|
||||||
|
config=mgmt_payload,
|
||||||
|
db=db,
|
||||||
|
current_user=current_user,
|
||||||
|
))
|
||||||
|
|
||||||
|
@router.put("/update_config_reflection")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def update_config_reflection(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
message: str = Body(None, description="Request body"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update reflection engine config (full update).
|
||||||
|
|
||||||
|
All fields are required.
|
||||||
|
Only configs belonging to the authorized workspace can be updated.
|
||||||
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = ReflectionConfigUpdateRequest(**body)
|
||||||
|
|
||||||
|
logger.info(f"V1 update reflection config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
update_fields = payload.model_dump(exclude_unset=True)
|
||||||
|
mgmt_payload = Memory_Reflection(**update_fields)
|
||||||
|
|
||||||
|
return jsonable_encoder(await memory_reflection_controller.save_reflection_config(
|
||||||
|
request=mgmt_payload,
|
||||||
|
current_user=current_user,
|
||||||
|
db=db,
|
||||||
|
))
|
||||||
|
|
||||||
|
@router.delete("/delete_config")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def delete_memory_config(
|
||||||
|
config_id: str,
|
||||||
|
request: Request,
|
||||||
|
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Delete a memory config.
|
||||||
|
|
||||||
|
- Default configs cannot be deleted.
|
||||||
|
- If end users are connected and force=False, returns a warning.
|
||||||
|
- If force=True, clears end user references and deletes the config.
|
||||||
|
|
||||||
|
Only configs belonging to the authorized workspace can be deleted.
|
||||||
|
"""
|
||||||
|
logger.info(f"V1 delete config - config_id: {config_id}, force: {force}, workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
|
||||||
|
return memory_storage_controller.delete_config(
|
||||||
|
config_id=config_id,
|
||||||
|
force=force,
|
||||||
|
current_user=current_user,
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
230
api/app/controllers/service/user_memory_api_controller.py
Normal file
230
api/app/controllers/service/user_memory_api_controller.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
"""User Memory 服务接口 — 基于 API Key 认证
|
||||||
|
|
||||||
|
包装 user_memory_controllers.py 和 memory_agent_controller.py 中的内部接口,
|
||||||
|
提供基于 API Key 认证的对外服务:
|
||||||
|
1./analytics/graph_data - 知识图谱数据接口
|
||||||
|
2./analytics/community_graph - 社区图谱接口
|
||||||
|
3./analytics/node_statistics - 记忆节点统计接口
|
||||||
|
4./analytics/user_summary - 用户摘要接口
|
||||||
|
5./analytics/memory_insight - 记忆洞察接口
|
||||||
|
6./analytics/interest_distribution - 兴趣分布接口
|
||||||
|
7./analytics/end_user_info - 终端用户信息接口
|
||||||
|
8./analytics/generate_cache - 缓存生成接口
|
||||||
|
|
||||||
|
|
||||||
|
路由前缀: /memory
|
||||||
|
子路径: /analytics/...
|
||||||
|
最终路径: /v1/memory/analytics/...
|
||||||
|
认证方式: API Key (@require_api_key)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Header, Query, Request, Body
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.api_key_auth import require_api_key
|
||||||
|
from app.core.api_key_utils import get_current_user_from_api_key, validate_end_user_in_workspace
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.db import get_db
|
||||||
|
from app.schemas.api_key_schema import ApiKeyAuth
|
||||||
|
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
||||||
|
|
||||||
|
# 包装内部服务 controller
|
||||||
|
from app.controllers import user_memory_controllers, memory_agent_controller
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/memory", tags=["V1 - User Memory API"])
|
||||||
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 知识图谱 ====================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analytics/graph_data")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def get_graph_data(
|
||||||
|
request: Request,
|
||||||
|
end_user_id: str = Query(..., description="End user ID"),
|
||||||
|
node_types: Optional[str] = Query(None, description="Comma-separated node types filter"),
|
||||||
|
limit: int = Query(100, description="Max nodes to return (auto-capped at 1000 in service layer)"),
|
||||||
|
depth: int = Query(1, description="Graph traversal depth (auto-capped at 3 in service layer)"),
|
||||||
|
center_node_id: Optional[str] = Query(None, description="Center node for subgraph"),
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Get knowledge graph data (nodes + edges) for an end user."""
|
||||||
|
current_user = get_current_user_from_api_key(db, api_key_auth)
|
||||||
|
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
||||||
|
|
||||||
|
return await user_memory_controllers.get_graph_data_api(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
node_types=node_types,
|
||||||
|
limit=limit,
|
||||||
|
depth=depth,
|
||||||
|
center_node_id=center_node_id,
|
||||||
|
current_user=current_user,
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analytics/community_graph")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def get_community_graph(
|
||||||
|
request: Request,
|
||||||
|
end_user_id: str = Query(..., description="End user ID"),
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Get community clustering graph for an end user."""
|
||||||
|
current_user = get_current_user_from_api_key(db, api_key_auth)
|
||||||
|
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
||||||
|
|
||||||
|
return await user_memory_controllers.get_community_graph_data_api(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
current_user=current_user,
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 节点统计 ====================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analytics/node_statistics")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def get_node_statistics(
|
||||||
|
request: Request,
|
||||||
|
end_user_id: str = Query(..., description="End user ID"),
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Get memory node type statistics for an end user."""
|
||||||
|
current_user = get_current_user_from_api_key(db, api_key_auth)
|
||||||
|
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
||||||
|
|
||||||
|
return await user_memory_controllers.get_node_statistics_api(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
current_user=current_user,
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 用户摘要 & 洞察 ====================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analytics/user_summary")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def get_user_summary(
|
||||||
|
request: Request,
|
||||||
|
end_user_id: str = Query(..., description="End user ID"),
|
||||||
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Get cached user summary for an end user."""
|
||||||
|
current_user = get_current_user_from_api_key(db, api_key_auth)
|
||||||
|
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
||||||
|
|
||||||
|
return await user_memory_controllers.get_user_summary_api(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
language_type=language_type,
|
||||||
|
current_user=current_user,
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analytics/memory_insight")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def get_memory_insight(
|
||||||
|
request: Request,
|
||||||
|
end_user_id: str = Query(..., description="End user ID"),
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Get cached memory insight report for an end user."""
|
||||||
|
current_user = get_current_user_from_api_key(db, api_key_auth)
|
||||||
|
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
||||||
|
|
||||||
|
return await user_memory_controllers.get_memory_insight_report_api(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
current_user=current_user,
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 兴趣分布 ====================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analytics/interest_distribution")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def get_interest_distribution(
|
||||||
|
request: Request,
|
||||||
|
end_user_id: str = Query(..., description="End user ID"),
|
||||||
|
limit: int = Query(5, le=5, description="Max interest tags to return"),
|
||||||
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Get interest distribution tags for an end user."""
|
||||||
|
current_user = get_current_user_from_api_key(db, api_key_auth)
|
||||||
|
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
||||||
|
|
||||||
|
return await memory_agent_controller.get_interest_distribution_by_user_api(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
limit=limit,
|
||||||
|
language_type=language_type,
|
||||||
|
current_user=current_user,
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 终端用户信息 ====================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analytics/end_user_info")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def get_end_user_info(
|
||||||
|
request: Request,
|
||||||
|
end_user_id: str = Query(..., description="End user ID"),
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Get end user basic information (name, aliases, metadata)."""
|
||||||
|
current_user = get_current_user_from_api_key(db, api_key_auth)
|
||||||
|
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
||||||
|
|
||||||
|
return await user_memory_controllers.get_end_user_info(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
current_user=current_user,
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 缓存生成 ====================
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/analytics/generate_cache")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def generate_cache(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
message: str = Body(None, description="Request body"),
|
||||||
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
|
):
|
||||||
|
"""Trigger cache generation (user summary + memory insight) for an end user or all workspace users."""
|
||||||
|
body = await request.json()
|
||||||
|
cache_request = GenerateCacheRequest(**body)
|
||||||
|
|
||||||
|
current_user = get_current_user_from_api_key(db, api_key_auth)
|
||||||
|
|
||||||
|
if cache_request.end_user_id:
|
||||||
|
validate_end_user_in_workspace(db, cache_request.end_user_id, api_key_auth.workspace_id)
|
||||||
|
|
||||||
|
return await user_memory_controllers.generate_cache_api(
|
||||||
|
request=cache_request,
|
||||||
|
language_type=language_type,
|
||||||
|
current_user=current_user,
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -11,11 +11,13 @@ from app.schemas import skill_schema
|
|||||||
from app.schemas.response_schema import PageData, PageMeta
|
from app.schemas.response_schema import PageData, PageMeta
|
||||||
from app.services.skill_service import SkillService
|
from app.services.skill_service import SkillService
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
|
from app.core.quota_stub import check_skill_quota
|
||||||
|
|
||||||
router = APIRouter(prefix="/skills", tags=["Skills"])
|
router = APIRouter(prefix="/skills", tags=["Skills"])
|
||||||
|
|
||||||
|
|
||||||
@router.post("", summary="创建技能")
|
@router.post("", summary="创建技能")
|
||||||
|
@check_skill_quota
|
||||||
def create_skill(
|
def create_skill(
|
||||||
data: skill_schema.SkillCreate,
|
data: skill_schema.SkillCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
|||||||
173
api/app/controllers/tenant_subscription_controller.py
Normal file
173
api/app/controllers/tenant_subscription_controller.py
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
"""
|
||||||
|
租户套餐查询接口(普通用户可访问)
|
||||||
|
"""
|
||||||
|
import datetime
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
from app.core.response_utils import success, fail
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user
|
||||||
|
from app.i18n.dependencies import get_translator
|
||||||
|
from app.models.user_model import User
|
||||||
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
|
||||||
|
logger = get_api_logger()
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/tenant", tags=["Tenant"])
|
||||||
|
public_router = APIRouter(tags=["Tenant"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/subscription", response_model=ApiResponse, summary="获取当前用户所属租户的套餐信息")
|
||||||
|
async def get_my_tenant_subscription(
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
t: Callable = Depends(get_translator),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取当前登录用户所属租户的有效套餐订阅信息。
|
||||||
|
包含套餐名称、版本、配额、到期时间等。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from premium.platform_admin.package_plan_service import TenantSubscriptionService
|
||||||
|
|
||||||
|
if not current_user.tenant:
|
||||||
|
return JSONResponse(status_code=404, content=fail(code=404, msg="用户未关联租户"))
|
||||||
|
|
||||||
|
tenant_id = current_user.tenant.id
|
||||||
|
svc = TenantSubscriptionService(db)
|
||||||
|
sub = svc.get_subscription(tenant_id)
|
||||||
|
|
||||||
|
if not sub:
|
||||||
|
# 无订阅记录时,兜底返回免费套餐信息
|
||||||
|
free_plan = svc.plan_repo.get_free_plan()
|
||||||
|
if not free_plan:
|
||||||
|
return success(data=None, msg="暂无有效套餐")
|
||||||
|
return success(data={
|
||||||
|
"subscription_id": None,
|
||||||
|
"tenant_id": str(tenant_id),
|
||||||
|
"package_plan_id": str(free_plan.id),
|
||||||
|
"package_version": free_plan.version,
|
||||||
|
"package_plan": {
|
||||||
|
"id": str(free_plan.id),
|
||||||
|
"name": free_plan.name,
|
||||||
|
"name_en": free_plan.name_en,
|
||||||
|
"version": free_plan.version,
|
||||||
|
"category": free_plan.category,
|
||||||
|
"tier_level": free_plan.tier_level,
|
||||||
|
"price": float(free_plan.price) if free_plan.price is not None else 0.0,
|
||||||
|
"billing_cycle": free_plan.billing_cycle,
|
||||||
|
"core_value": free_plan.core_value,
|
||||||
|
"core_value_en": free_plan.core_value_en,
|
||||||
|
"tech_support": free_plan.tech_support,
|
||||||
|
"tech_support_en": free_plan.tech_support_en,
|
||||||
|
"sla_compliance": free_plan.sla_compliance,
|
||||||
|
"sla_compliance_en": free_plan.sla_compliance_en,
|
||||||
|
"page_customization": free_plan.page_customization,
|
||||||
|
"page_customization_en": free_plan.page_customization_en,
|
||||||
|
"theme_color": free_plan.theme_color,
|
||||||
|
},
|
||||||
|
"started_at": None,
|
||||||
|
"expired_at": None,
|
||||||
|
"status": "active",
|
||||||
|
"quotas": free_plan.quotas or {},
|
||||||
|
"created_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
||||||
|
"updated_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
||||||
|
}, msg="免费套餐")
|
||||||
|
|
||||||
|
return success(data=svc.build_response(sub))
|
||||||
|
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
# 社区版无 premium 模块,从配置文件读取免费套餐
|
||||||
|
if not current_user.tenant:
|
||||||
|
return JSONResponse(status_code=404, content=fail(code=404, msg="用户未关联租户"))
|
||||||
|
|
||||||
|
from app.config.default_free_plan import DEFAULT_FREE_PLAN
|
||||||
|
|
||||||
|
plan = DEFAULT_FREE_PLAN
|
||||||
|
response_data = {
|
||||||
|
"subscription_id": None,
|
||||||
|
"tenant_id": str(current_user.tenant.id),
|
||||||
|
"package_plan_id": None,
|
||||||
|
"package_version": plan["version"],
|
||||||
|
"package_plan": {
|
||||||
|
"id": None,
|
||||||
|
"name": plan["name"],
|
||||||
|
"name_en": plan.get("name_en"),
|
||||||
|
"version": plan["version"],
|
||||||
|
"category": plan["category"],
|
||||||
|
"tier_level": plan["tier_level"],
|
||||||
|
"price": float(plan["price"]),
|
||||||
|
"billing_cycle": plan["billing_cycle"],
|
||||||
|
"core_value": plan.get("core_value"),
|
||||||
|
"core_value_en": plan.get("core_value_en"),
|
||||||
|
"tech_support": plan.get("tech_support"),
|
||||||
|
"tech_support_en": plan.get("tech_support_en"),
|
||||||
|
"sla_compliance": plan.get("sla_compliance"),
|
||||||
|
"sla_compliance_en": plan.get("sla_compliance_en"),
|
||||||
|
"page_customization": plan.get("page_customization"),
|
||||||
|
"page_customization_en": plan.get("page_customization_en"),
|
||||||
|
"theme_color": plan.get("theme_color"),
|
||||||
|
},
|
||||||
|
"started_at": None,
|
||||||
|
"expired_at": None,
|
||||||
|
"status": "active",
|
||||||
|
"quotas": plan["quotas"],
|
||||||
|
"created_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
||||||
|
"updated_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
||||||
|
}
|
||||||
|
return success(data=response_data, msg="社区版免费套餐")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取租户套餐信息失败: {e}", exc_info=True)
|
||||||
|
return JSONResponse(status_code=500, content=fail(code=500, msg="获取套餐信息失败"))
|
||||||
|
|
||||||
|
|
||||||
|
@public_router.get("/package-plans", response_model=ApiResponse, summary="获取套餐列表(公开)")
|
||||||
|
async def list_package_plans_public(
|
||||||
|
category: Optional[str] = None,
|
||||||
|
status: Optional[bool] = None,
|
||||||
|
search: Optional[str] = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
公开接口,无需鉴权。
|
||||||
|
SaaS 版从数据库读取套餐列表;社区版降级返回 default_free_plan.py 中的免费套餐。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from premium.platform_admin.package_plan_service import PackagePlanService
|
||||||
|
from premium.platform_admin.package_plan_schema import PackagePlanResponse
|
||||||
|
svc = PackagePlanService(db)
|
||||||
|
result = svc.get_list(page=1, size=9999, category=category, status=status, search=search)
|
||||||
|
return success(data=[PackagePlanResponse.model_validate(p).model_dump(mode="json") for p in result["items"]])
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
from app.config.default_free_plan import DEFAULT_FREE_PLAN
|
||||||
|
plan = DEFAULT_FREE_PLAN
|
||||||
|
return success(data=[{
|
||||||
|
"id": None,
|
||||||
|
"name": plan["name"],
|
||||||
|
"name_en": plan.get("name_en"),
|
||||||
|
"version": plan["version"],
|
||||||
|
"category": plan["category"],
|
||||||
|
"tier_level": plan["tier_level"],
|
||||||
|
"price": float(plan["price"]),
|
||||||
|
"billing_cycle": plan["billing_cycle"],
|
||||||
|
"core_value": plan.get("core_value"),
|
||||||
|
"core_value_en": plan.get("core_value_en"),
|
||||||
|
"tech_support": plan.get("tech_support"),
|
||||||
|
"tech_support_en": plan.get("tech_support_en"),
|
||||||
|
"sla_compliance": plan.get("sla_compliance"),
|
||||||
|
"sla_compliance_en": plan.get("sla_compliance_en"),
|
||||||
|
"page_customization": plan.get("page_customization"),
|
||||||
|
"page_customization_en": plan.get("page_customization_en"),
|
||||||
|
"theme_color": plan.get("theme_color"),
|
||||||
|
"status": plan.get("status", True),
|
||||||
|
"quotas": plan["quotas"],
|
||||||
|
}])
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取套餐列表失败: {e}", exc_info=True)
|
||||||
|
return JSONResponse(status_code=500, content=fail(code=500, msg="获取套餐列表失败"))
|
||||||
@@ -3,8 +3,11 @@ from typing import Optional
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
from app.schemas.tool_schema import (
|
from app.schemas.tool_schema import (
|
||||||
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest, CustomToolTestRequest
|
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest,
|
||||||
|
CustomToolTestRequest, ToolActiveUpdate
|
||||||
)
|
)
|
||||||
|
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
@@ -73,6 +76,8 @@ async def get_tool_methods(
|
|||||||
if methods is None:
|
if methods is None:
|
||||||
raise HTTPException(status_code=404, detail="工具不存在")
|
raise HTTPException(status_code=404, detail="工具不存在")
|
||||||
return success(data=methods, msg="获取工具方法成功")
|
return success(data=methods, msg="获取工具方法成功")
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -118,6 +123,8 @@ async def create_tool(
|
|||||||
raise HTTPException(status_code=400, detail=e.message)
|
raise HTTPException(status_code=400, detail=e.message)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -146,6 +153,8 @@ async def update_tool(
|
|||||||
return success(msg="工具更新成功")
|
return success(msg="工具更新成功")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -156,7 +165,7 @@ async def delete_tool(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
service: ToolService = Depends(get_tool_service)
|
service: ToolService = Depends(get_tool_service)
|
||||||
):
|
):
|
||||||
"""删除工具"""
|
"""删除工具(逻辑删除,is_active=False)"""
|
||||||
try:
|
try:
|
||||||
success_flag = service.delete_tool(tool_id, current_user.tenant_id)
|
success_flag = service.delete_tool(tool_id, current_user.tenant_id)
|
||||||
if not success_flag:
|
if not success_flag:
|
||||||
@@ -164,6 +173,34 @@ async def delete_tool(
|
|||||||
return success(msg="工具删除成功")
|
return success(msg="工具删除成功")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/{tool_id}/active", response_model=ApiResponse)
|
||||||
|
async def set_tool_active(
|
||||||
|
tool_id: str,
|
||||||
|
request: ToolActiveUpdate,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
service: ToolService = Depends(get_tool_service)
|
||||||
|
):
|
||||||
|
"""设置工具可用状态(启用/禁用)
|
||||||
|
|
||||||
|
- is_active=true: 启用工具
|
||||||
|
- is_active=false: 禁用工具(等同于删除,但可恢复)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
success_flag = service.set_tool_active(tool_id, current_user.tenant_id, request.is_active)
|
||||||
|
if not success_flag:
|
||||||
|
raise HTTPException(status_code=404, detail="工具不存在")
|
||||||
|
action = "启用" if request.is_active else "禁用"
|
||||||
|
return success(msg=f"工具已{action}")
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -196,6 +233,8 @@ async def execute_tool(
|
|||||||
},
|
},
|
||||||
msg="工具执行完成"
|
msg="工具执行完成"
|
||||||
)
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -212,6 +251,8 @@ async def parse_openapi_schema(
|
|||||||
if result["success"] is False:
|
if result["success"] is False:
|
||||||
raise HTTPException(status_code=400, detail=result["message"])
|
raise HTTPException(status_code=400, detail=result["message"])
|
||||||
return success(data=result, msg="Schema解析完成")
|
return success(data=result, msg="Schema解析完成")
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -225,8 +266,10 @@ async def sync_mcp_tools(
|
|||||||
try:
|
try:
|
||||||
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
|
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
|
||||||
if not result.get("success", False):
|
if not result.get("success", False):
|
||||||
raise HTTPException(status_code=400, detail=result.get("message", "同步失败"))
|
raise BusinessException(result.get("message", "工具列表同步失败"), BizCode.BAD_REQUEST)
|
||||||
return success(data=result, msg="MCP工具列表同步完成")
|
return success(data=result, msg="MCP工具列表同步完成")
|
||||||
|
except BusinessException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -249,8 +292,10 @@ async def test_tool_connection(
|
|||||||
# 普通连接测试
|
# 普通连接测试
|
||||||
result = await service.test_connection(tool_id, current_user.tenant_id)
|
result = await service.test_connection(tool_id, current_user.tenant_id)
|
||||||
if result["success"] is False:
|
if result["success"] is False:
|
||||||
raise HTTPException(status_code=400, detail=result["message"])
|
raise BusinessException(result["message"], BizCode.SERVICE_UNAVAILABLE)
|
||||||
return success(data=result, msg="连接测试完成")
|
return success(data=result, msg="连接测试完成")
|
||||||
|
except BusinessException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
@@ -19,6 +20,7 @@ from app.services import user_service
|
|||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.core.security import verify_password
|
from app.core.security import verify_password
|
||||||
|
from app.i18n.dependencies import get_translator
|
||||||
|
|
||||||
# 获取API专用日志器
|
# 获取API专用日志器
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -33,7 +35,8 @@ router = APIRouter(
|
|||||||
def create_superuser(
|
def create_superuser(
|
||||||
user: user_schema.UserCreate,
|
user: user_schema.UserCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_superuser: User = Depends(get_current_superuser)
|
current_superuser: User = Depends(get_current_superuser),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""创建超级管理员(仅超级管理员可访问)"""
|
"""创建超级管理员(仅超级管理员可访问)"""
|
||||||
api_logger.info(f"超级管理员创建请求: {user.username}, email: {user.email}")
|
api_logger.info(f"超级管理员创建请求: {user.username}, email: {user.email}")
|
||||||
@@ -42,7 +45,7 @@ def create_superuser(
|
|||||||
api_logger.info(f"超级管理员创建成功: {result.username} (ID: {result.id})")
|
api_logger.info(f"超级管理员创建成功: {result.username} (ID: {result.id})")
|
||||||
|
|
||||||
result_schema = user_schema.User.model_validate(result)
|
result_schema = user_schema.User.model_validate(result)
|
||||||
return success(data=result_schema, msg="超级管理员创建成功")
|
return success(data=result_schema, msg=t("users.create.superuser_success"))
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{user_id}", response_model=ApiResponse)
|
@router.delete("/{user_id}", response_model=ApiResponse)
|
||||||
@@ -50,6 +53,7 @@ def delete_user(
|
|||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""停用用户(软删除)"""
|
"""停用用户(软删除)"""
|
||||||
api_logger.info(f"用户停用请求: user_id={user_id}, 操作者: {current_user.username}")
|
api_logger.info(f"用户停用请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||||
@@ -57,13 +61,14 @@ def delete_user(
|
|||||||
db=db, user_id_to_deactivate=user_id, current_user=current_user
|
db=db, user_id_to_deactivate=user_id, current_user=current_user
|
||||||
)
|
)
|
||||||
api_logger.info(f"用户停用成功: {result.username} (ID: {result.id})")
|
api_logger.info(f"用户停用成功: {result.username} (ID: {result.id})")
|
||||||
return success(msg="用户停用成功")
|
return success(msg=t("users.delete.deactivate_success"))
|
||||||
|
|
||||||
@router.post("/{user_id}/activate", response_model=ApiResponse)
|
@router.post("/{user_id}/activate", response_model=ApiResponse)
|
||||||
def activate_user(
|
def activate_user(
|
||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""激活用户"""
|
"""激活用户"""
|
||||||
api_logger.info(f"用户激活请求: user_id={user_id}, 操作者: {current_user.username}")
|
api_logger.info(f"用户激活请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||||
@@ -74,13 +79,14 @@ def activate_user(
|
|||||||
api_logger.info(f"用户激活成功: {result.username} (ID: {result.id})")
|
api_logger.info(f"用户激活成功: {result.username} (ID: {result.id})")
|
||||||
|
|
||||||
result_schema = user_schema.User.model_validate(result)
|
result_schema = user_schema.User.model_validate(result)
|
||||||
return success(data=result_schema, msg="用户激活成功")
|
return success(data=result_schema, msg=t("users.activate.success"))
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=ApiResponse)
|
@router.get("", response_model=ApiResponse)
|
||||||
def get_current_user_info(
|
def get_current_user_info(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""获取当前用户信息"""
|
"""获取当前用户信息"""
|
||||||
api_logger.info(f"当前用户信息请求: {current_user.username}")
|
api_logger.info(f"当前用户信息请求: {current_user.username}")
|
||||||
@@ -105,7 +111,22 @@ def get_current_user_info(
|
|||||||
break
|
break
|
||||||
|
|
||||||
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
|
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
|
||||||
return success(data=result_schema, msg="用户信息获取成功")
|
|
||||||
|
# 设置权限:如果用户来自 SSO Source,则使用该 Source 的 permissions;否则返回 "all" 表示拥有所有权限
|
||||||
|
if current_user.external_source:
|
||||||
|
try:
|
||||||
|
from premium.sso.models import SSOSource
|
||||||
|
source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first()
|
||||||
|
if source and source.permissions:
|
||||||
|
result_schema.permissions = source.permissions
|
||||||
|
else:
|
||||||
|
result_schema.permissions = []
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
result_schema.permissions = []
|
||||||
|
else:
|
||||||
|
result_schema.permissions = ["all"]
|
||||||
|
|
||||||
|
return success(data=result_schema, msg=t("users.info.get_success"))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/superusers", response_model=ApiResponse)
|
@router.get("/superusers", response_model=ApiResponse)
|
||||||
@@ -113,6 +134,7 @@ def get_tenant_superusers(
|
|||||||
include_inactive: bool = False,
|
include_inactive: bool = False,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_superuser),
|
current_user: User = Depends(get_current_superuser),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""获取当前租户下的超管账号列表(仅超级管理员可访问)"""
|
"""获取当前租户下的超管账号列表(仅超级管理员可访问)"""
|
||||||
api_logger.info(f"获取租户超管列表请求: {current_user.username}")
|
api_logger.info(f"获取租户超管列表请求: {current_user.username}")
|
||||||
@@ -125,8 +147,7 @@ def get_tenant_superusers(
|
|||||||
api_logger.info(f"租户超管列表获取成功: count={len(superusers)}")
|
api_logger.info(f"租户超管列表获取成功: count={len(superusers)}")
|
||||||
|
|
||||||
superusers_schema = [user_schema.User.model_validate(u) for u in superusers]
|
superusers_schema = [user_schema.User.model_validate(u) for u in superusers]
|
||||||
return success(data=superusers_schema, msg="租户超管列表获取成功")
|
return success(data=superusers_schema, msg=t("users.list.superusers_success"))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{user_id}", response_model=ApiResponse)
|
@router.get("/{user_id}", response_model=ApiResponse)
|
||||||
@@ -134,6 +155,7 @@ def get_user_info_by_id(
|
|||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""根据用户ID获取用户信息"""
|
"""根据用户ID获取用户信息"""
|
||||||
api_logger.info(f"获取用户信息请求: user_id={user_id}, 操作者: {current_user.username}")
|
api_logger.info(f"获取用户信息请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||||
@@ -144,7 +166,7 @@ def get_user_info_by_id(
|
|||||||
api_logger.info(f"用户信息获取成功: {result.username}")
|
api_logger.info(f"用户信息获取成功: {result.username}")
|
||||||
|
|
||||||
result_schema = user_schema.User.model_validate(result)
|
result_schema = user_schema.User.model_validate(result)
|
||||||
return success(data=result_schema, msg="用户信息获取成功")
|
return success(data=result_schema, msg=t("users.info.get_success"))
|
||||||
|
|
||||||
|
|
||||||
@router.put("/change-password", response_model=ApiResponse)
|
@router.put("/change-password", response_model=ApiResponse)
|
||||||
@@ -152,6 +174,7 @@ async def change_password(
|
|||||||
request: ChangePasswordRequest,
|
request: ChangePasswordRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""修改当前用户密码"""
|
"""修改当前用户密码"""
|
||||||
api_logger.info(f"用户密码修改请求: {current_user.username}")
|
api_logger.info(f"用户密码修改请求: {current_user.username}")
|
||||||
@@ -164,7 +187,7 @@ async def change_password(
|
|||||||
current_user=current_user
|
current_user=current_user
|
||||||
)
|
)
|
||||||
api_logger.info(f"用户密码修改成功: {current_user.username}")
|
api_logger.info(f"用户密码修改成功: {current_user.username}")
|
||||||
return success(msg="密码修改成功")
|
return success(msg=t("auth.password.change_success"))
|
||||||
|
|
||||||
|
|
||||||
@router.put("/admin/change-password", response_model=ApiResponse)
|
@router.put("/admin/change-password", response_model=ApiResponse)
|
||||||
@@ -172,6 +195,7 @@ async def admin_change_password(
|
|||||||
request: AdminChangePasswordRequest,
|
request: AdminChangePasswordRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_superuser),
|
current_user: User = Depends(get_current_superuser),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""超级管理员修改指定用户的密码"""
|
"""超级管理员修改指定用户的密码"""
|
||||||
api_logger.info(f"管理员密码修改请求: 管理员 {current_user.username} 修改用户 {request.user_id}")
|
api_logger.info(f"管理员密码修改请求: 管理员 {current_user.username} 修改用户 {request.user_id}")
|
||||||
@@ -186,16 +210,17 @@ async def admin_change_password(
|
|||||||
# 根据是否生成了随机密码来构造响应
|
# 根据是否生成了随机密码来构造响应
|
||||||
if request.new_password:
|
if request.new_password:
|
||||||
api_logger.info(f"管理员密码修改成功: 用户 {request.user_id}")
|
api_logger.info(f"管理员密码修改成功: 用户 {request.user_id}")
|
||||||
return success(msg="密码修改成功")
|
return success(msg=t("auth.password.change_success"))
|
||||||
else:
|
else:
|
||||||
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
|
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
|
||||||
return success(data=generated_password, msg="密码重置成功")
|
return success(data=generated_password, msg=t("auth.password.reset_success"))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/verify_pwd", response_model=ApiResponse)
|
@router.post("/verify_pwd", response_model=ApiResponse)
|
||||||
def verify_pwd(
|
def verify_pwd(
|
||||||
request: VerifyPasswordRequest,
|
request: VerifyPasswordRequest,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""验证当前用户密码"""
|
"""验证当前用户密码"""
|
||||||
api_logger.info(f"用户验证密码请求: {current_user.username}")
|
api_logger.info(f"用户验证密码请求: {current_user.username}")
|
||||||
@@ -203,8 +228,8 @@ def verify_pwd(
|
|||||||
is_valid = verify_password(request.password, current_user.hashed_password)
|
is_valid = verify_password(request.password, current_user.hashed_password)
|
||||||
api_logger.info(f"用户密码验证结果: {current_user.username}, valid={is_valid}")
|
api_logger.info(f"用户密码验证结果: {current_user.username}, valid={is_valid}")
|
||||||
if not is_valid:
|
if not is_valid:
|
||||||
raise BusinessException("密码验证失败", code=BizCode.VALIDATION_FAILED)
|
raise BusinessException(t("users.errors.password_verification_failed"), code=BizCode.VALIDATION_FAILED)
|
||||||
return success(data={"valid": is_valid}, msg="验证完成")
|
return success(data={"valid": is_valid}, msg=t("common.success.retrieved"))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/send-email-code", response_model=ApiResponse)
|
@router.post("/send-email-code", response_model=ApiResponse)
|
||||||
@@ -212,6 +237,7 @@ async def send_email_code(
|
|||||||
request: SendEmailCodeRequest,
|
request: SendEmailCodeRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""发送邮箱验证码"""
|
"""发送邮箱验证码"""
|
||||||
api_logger.info(f"用户请求发送邮箱验证码: {current_user.username}, email={request.email}")
|
api_logger.info(f"用户请求发送邮箱验证码: {current_user.username}, email={request.email}")
|
||||||
@@ -219,7 +245,7 @@ async def send_email_code(
|
|||||||
await user_service.send_email_code_method(db=db, email=request.email, user_id=current_user.id)
|
await user_service.send_email_code_method(db=db, email=request.email, user_id=current_user.id)
|
||||||
|
|
||||||
api_logger.info(f"邮箱验证码已发送: {current_user.username}")
|
api_logger.info(f"邮箱验证码已发送: {current_user.username}")
|
||||||
return success(msg="验证码已发送到您的邮箱,请查收")
|
return success(msg=t("users.email.code_sent"))
|
||||||
|
|
||||||
|
|
||||||
@router.put("/change-email", response_model=ApiResponse)
|
@router.put("/change-email", response_model=ApiResponse)
|
||||||
@@ -227,6 +253,7 @@ async def change_email(
|
|||||||
request: VerifyEmailCodeRequest,
|
request: VerifyEmailCodeRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""验证验证码并修改邮箱"""
|
"""验证验证码并修改邮箱"""
|
||||||
api_logger.info(f"用户修改邮箱: {current_user.username}, new_email={request.new_email}")
|
api_logger.info(f"用户修改邮箱: {current_user.username}, new_email={request.new_email}")
|
||||||
@@ -239,4 +266,51 @@ async def change_email(
|
|||||||
)
|
)
|
||||||
|
|
||||||
api_logger.info(f"用户邮箱修改成功: {current_user.username}")
|
api_logger.info(f"用户邮箱修改成功: {current_user.username}")
|
||||||
return success(msg="邮箱修改成功")
|
return success(msg=t("users.email.change_success"))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me/language", response_model=ApiResponse)
|
||||||
|
def get_current_user_language(
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
|
):
|
||||||
|
"""获取当前用户的语言偏好"""
|
||||||
|
api_logger.info(f"获取用户语言偏好: {current_user.username}")
|
||||||
|
|
||||||
|
language = user_service.get_user_language_preference(
|
||||||
|
db=db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
current_user=current_user
|
||||||
|
)
|
||||||
|
|
||||||
|
api_logger.info(f"用户语言偏好获取成功: {current_user.username}, language={language}")
|
||||||
|
return success(
|
||||||
|
data=user_schema.LanguagePreferenceResponse(language=language),
|
||||||
|
msg=t("users.language.get_success")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/me/language", response_model=ApiResponse)
|
||||||
|
def update_current_user_language(
|
||||||
|
request: user_schema.LanguagePreferenceRequest,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
|
):
|
||||||
|
"""设置当前用户的语言偏好"""
|
||||||
|
api_logger.info(f"更新用户语言偏好: {current_user.username}, language={request.language}")
|
||||||
|
|
||||||
|
updated_user = user_service.update_user_language_preference(
|
||||||
|
db=db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
language=request.language,
|
||||||
|
current_user=current_user
|
||||||
|
)
|
||||||
|
|
||||||
|
api_logger.info(f"用户语言偏好更新成功: {current_user.username}, language={request.language}")
|
||||||
|
return success(
|
||||||
|
data=user_schema.LanguagePreferenceResponse(language=updated_user.preferred_language),
|
||||||
|
msg=t("users.language.update_success")
|
||||||
|
)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
import datetime
|
import datetime
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from fastapi import APIRouter, Depends,Header
|
from fastapi import APIRouter, Depends, Header
|
||||||
|
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.core.language_utils import get_language_from_header
|
from app.core.language_utils import get_language_from_header
|
||||||
@@ -17,14 +17,17 @@ from app.services.user_memory_service import (
|
|||||||
UserMemoryService,
|
UserMemoryService,
|
||||||
analytics_memory_types,
|
analytics_memory_types,
|
||||||
analytics_graph_data,
|
analytics_graph_data,
|
||||||
|
analytics_community_graph_data,
|
||||||
)
|
)
|
||||||
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
from app.services.memory_entity_relationship_service import MemoryEntityService, MemoryEmotion, MemoryInteraction
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
||||||
from app.repositories.workspace_repository import WorkspaceRepository
|
from app.repositories.workspace_repository import WorkspaceRepository
|
||||||
from app.schemas.end_user_schema import (
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
EndUserProfileResponse,
|
from app.schemas.end_user_info_schema import (
|
||||||
EndUserProfileUpdate,
|
EndUserInfoResponse,
|
||||||
|
EndUserInfoCreate,
|
||||||
|
EndUserInfoUpdate,
|
||||||
)
|
)
|
||||||
from app.models.end_user_model import EndUser
|
from app.models.end_user_model import EndUser
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
@@ -44,9 +47,9 @@ router = APIRouter(
|
|||||||
|
|
||||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||||
async def get_memory_insight_report_api(
|
async def get_memory_insight_report_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
获取缓存的记忆洞察报告
|
获取缓存的记忆洞察报告
|
||||||
@@ -72,10 +75,10 @@ async def get_memory_insight_report_api(
|
|||||||
|
|
||||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||||
async def get_user_summary_api(
|
async def get_user_summary_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
获取缓存的用户摘要
|
获取缓存的用户摘要
|
||||||
@@ -101,7 +104,7 @@ async def get_user_summary_api(
|
|||||||
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||||
try:
|
try:
|
||||||
# 调用服务层获取缓存数据
|
# 调用服务层获取缓存数据
|
||||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language)
|
result = await user_memory_service.get_cached_user_summary(db, end_user_id, model_id, language)
|
||||||
|
|
||||||
if result["is_cached"]:
|
if result["is_cached"]:
|
||||||
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
||||||
@@ -116,10 +119,10 @@ async def get_user_summary_api(
|
|||||||
|
|
||||||
@router.post("/analytics/generate_cache", response_model=ApiResponse)
|
@router.post("/analytics/generate_cache", response_model=ApiResponse)
|
||||||
async def generate_cache_api(
|
async def generate_cache_api(
|
||||||
request: GenerateCacheRequest,
|
request: GenerateCacheRequest,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
手动触发缓存生成
|
手动触发缓存生成
|
||||||
@@ -154,10 +157,12 @@ async def generate_cache_api(
|
|||||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
||||||
|
|
||||||
# 生成记忆洞察
|
# 生成记忆洞察
|
||||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, language=language)
|
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id,
|
||||||
|
language=language)
|
||||||
|
|
||||||
# 生成用户摘要
|
# 生成用户摘要
|
||||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, language=language)
|
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id,
|
||||||
|
language=language)
|
||||||
|
|
||||||
# 构建响应
|
# 构建响应
|
||||||
result = {
|
result = {
|
||||||
@@ -208,9 +213,9 @@ async def generate_cache_api(
|
|||||||
|
|
||||||
@router.get("/analytics/node_statistics", response_model=ApiResponse)
|
@router.get("/analytics/node_statistics", response_model=ApiResponse)
|
||||||
async def get_node_statistics_api(
|
async def get_node_statistics_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
@@ -219,7 +224,8 @@ async def get_node_statistics_api(
|
|||||||
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
api_logger.info(
|
||||||
|
f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 调用新的记忆类型统计函数
|
# 调用新的记忆类型统计函数
|
||||||
@@ -227,21 +233,23 @@ async def get_node_statistics_api(
|
|||||||
|
|
||||||
# 计算总数用于日志
|
# 计算总数用于日志
|
||||||
total_count = sum(item["count"] for item in result)
|
total_count = sum(item["count"] for item in result)
|
||||||
api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
api_logger.info(
|
||||||
|
f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
||||||
return success(data=result, msg="查询成功")
|
return success(data=result, msg="查询成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}")
|
api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/graph_data", response_model=ApiResponse)
|
@router.get("/analytics/graph_data", response_model=ApiResponse)
|
||||||
async def get_graph_data_api(
|
async def get_graph_data_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
node_types: Optional[str] = None,
|
node_types: Optional[str] = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
depth: int = 1,
|
depth: int = 1,
|
||||||
center_node_id: Optional[str] = None,
|
center_node_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
@@ -295,110 +303,165 @@ async def get_graph_data_api(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/read_end_user/profile", response_model=ApiResponse)
|
@router.get("/analytics/community_graph", response_model=ApiResponse)
|
||||||
async def get_end_user_profile(
|
async def get_community_graph_data_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
workspace_repo = WorkspaceRepository(db)
|
|
||||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
|
||||||
|
|
||||||
if workspace_models:
|
|
||||||
model_id = workspace_models.get("llm", None)
|
|
||||||
else:
|
|
||||||
model_id = None
|
|
||||||
# 检查用户是否已选择工作空间
|
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试查询社区图谱但未选择工作空间")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"用户信息查询请求: end_user_id={end_user_id}, user={current_user.username}, "
|
f"社区图谱查询请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||||
f"workspace={workspace_id}"
|
f"workspace={workspace_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 查询终端用户
|
result = await analytics_community_graph_data(db=db, end_user_id=end_user_id)
|
||||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
|
||||||
|
|
||||||
if not end_user:
|
if "message" in result and result["statistics"]["total_nodes"] == 0:
|
||||||
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
|
api_logger.warning(f"社区图谱查询返回空结果: {result.get('message')}")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
|
return success(data=result, msg=result.get("message", "查询成功"))
|
||||||
# 构建响应数据
|
|
||||||
profile_data = EndUserProfileResponse(
|
api_logger.info(
|
||||||
id=end_user.id,
|
f"成功获取社区图谱: end_user_id={end_user_id}, "
|
||||||
other_name=end_user.other_name,
|
f"nodes={result['statistics']['total_nodes']}, "
|
||||||
position=end_user.position,
|
f"edges={result['statistics']['total_edges']}"
|
||||||
department=end_user.department,
|
|
||||||
contact=end_user.contact,
|
|
||||||
phone=end_user.phone,
|
|
||||||
hire_date=end_user.hire_date,
|
|
||||||
updatetime_profile=end_user.updatetime_profile
|
|
||||||
)
|
)
|
||||||
|
return success(data=result, msg="查询成功")
|
||||||
api_logger.info(f"成功获取用户信息: end_user_id={end_user_id}")
|
|
||||||
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="查询成功")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"用户信息查询失败: end_user_id={end_user_id}, error={str(e)}")
|
api_logger.error(f"社区图谱查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "用户信息查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "社区图谱查询失败", str(e))
|
||||||
|
|
||||||
|
#=======================终端用户信息接口=======================
|
||||||
|
|
||||||
@router.post("/updated_end_user/profile", response_model=ApiResponse)
|
@router.get("/end_user_info", response_model=ApiResponse)
|
||||||
async def update_end_user_profile(
|
async def get_end_user_info(
|
||||||
profile_update: EndUserProfileUpdate,
|
end_user_id: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
更新终端用户的基本信息
|
查询终端用户信息记录
|
||||||
|
|
||||||
该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。
|
根据 end_user_id 查询单条终端用户信息记录。
|
||||||
所有字段都是可选的,只更新提供的字段。
|
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
end_user_id = profile_update.end_user_id
|
|
||||||
|
|
||||||
# 验证工作空间
|
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试查询终端用户信息但未选择工作空间")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"用户信息更新请求: end_user_id={end_user_id}, user={current_user.username}, "
|
f"查询终端用户信息请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||||
f"workspace={workspace_id}"
|
f"workspace={workspace_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 调用 Service 层处理业务逻辑
|
# 校验 end_user 是否属于当前工作空间
|
||||||
result = user_memory_service.update_end_user_profile(db, end_user_id, profile_update)
|
end_user_repo = EndUserRepository(db)
|
||||||
|
end_user = end_user_repo.get_end_user_by_id(end_user_id)
|
||||||
|
if end_user is None:
|
||||||
|
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", "end_user not found")
|
||||||
|
if str(end_user.workspace_id) != str(workspace_id):
|
||||||
|
api_logger.warning(
|
||||||
|
f"用户 {current_user.username} 尝试查询不属于工作空间 {workspace_id} 的终端用户 {end_user_id}"
|
||||||
|
)
|
||||||
|
return fail(BizCode.PERMISSION_DENIED, "该终端用户不属于当前工作空间", "end_user workspace mismatch")
|
||||||
|
|
||||||
|
result = user_memory_service.get_end_user_info(db, end_user_id)
|
||||||
|
|
||||||
if result["success"]:
|
if result["success"]:
|
||||||
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}")
|
api_logger.info(f"成功查询终端用户信息: end_user_id={end_user_id}")
|
||||||
|
return success(data=result["data"], msg="查询成功")
|
||||||
|
else:
|
||||||
|
error_msg = result["error"]
|
||||||
|
api_logger.error(f"查询终端用户信息失败: end_user_id={end_user_id}, error={error_msg}")
|
||||||
|
|
||||||
|
if error_msg == "终端用户信息记录不存在":
|
||||||
|
return fail(BizCode.USER_NOT_FOUND, "终端用户信息记录不存在", error_msg)
|
||||||
|
elif error_msg == "无效的终端用户ID格式":
|
||||||
|
return fail(BizCode.INVALID_USER_ID, "无效的终端用户ID格式", error_msg)
|
||||||
|
else:
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "查询终端用户信息失败", error_msg)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/end_user_info/updated", response_model=ApiResponse)
|
||||||
|
async def update_end_user_info(
|
||||||
|
info_update: EndUserInfoUpdate,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
更新终端用户信息记录
|
||||||
|
|
||||||
|
根据 end_user_id 更新终端用户信息记录,支持批量更新多个别名。
|
||||||
|
|
||||||
|
示例请求体:
|
||||||
|
{
|
||||||
|
"end_user_id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
|
||||||
|
"other_name": "张三1",
|
||||||
|
"aliases": ["小张", "张工"],
|
||||||
|
"meta_data": {"position": "工程师", "department": "技术部"}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
end_user_id = info_update.end_user_id
|
||||||
|
|
||||||
|
if workspace_id is None:
|
||||||
|
api_logger.warning(f"用户 {current_user.username} 尝试更新终端用户信息但未选择工作空间")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"更新终端用户信息请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||||
|
f"workspace={workspace_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 校验 end_user 是否属于当前工作空间
|
||||||
|
end_user_repo = EndUserRepository(db)
|
||||||
|
end_user = end_user_repo.get_end_user_by_id(end_user_id)
|
||||||
|
if end_user is None:
|
||||||
|
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", "end_user not found")
|
||||||
|
if str(end_user.workspace_id) != str(workspace_id):
|
||||||
|
api_logger.warning(
|
||||||
|
f"用户 {current_user.username} 尝试更新不属于工作空间 {workspace_id} 的终端用户 {end_user_id}"
|
||||||
|
)
|
||||||
|
return fail(BizCode.PERMISSION_DENIED, "该终端用户不属于当前工作空间", "end_user workspace mismatch")
|
||||||
|
|
||||||
|
# 获取更新数据(排除 end_user_id)
|
||||||
|
update_data = info_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
|
||||||
|
|
||||||
|
result = user_memory_service.update_end_user_info(db, end_user_id, update_data)
|
||||||
|
|
||||||
|
if result["success"]:
|
||||||
|
api_logger.info(f"成功更新终端用户信息: end_user_id={end_user_id}")
|
||||||
return success(data=result["data"], msg="更新成功")
|
return success(data=result["data"], msg="更新成功")
|
||||||
else:
|
else:
|
||||||
error_msg = result["error"]
|
error_msg = result["error"]
|
||||||
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
|
api_logger.error(f"终端用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
|
||||||
|
|
||||||
# 根据错误类型映射到合适的业务错误码
|
if error_msg == "终端用户信息记录不存在":
|
||||||
if error_msg == "终端用户不存在":
|
return fail(BizCode.USER_NOT_FOUND, "终端用户信息记录不存在", error_msg)
|
||||||
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg)
|
elif error_msg == "无效的终端用户ID格式":
|
||||||
elif error_msg == "无效的用户ID格式":
|
return fail(BizCode.INVALID_USER_ID, "无效的终端用户ID格式", error_msg)
|
||||||
return fail(BizCode.INVALID_USER_ID, "无效的用户ID格式", error_msg)
|
|
||||||
else:
|
else:
|
||||||
# 只有未预期的错误才使用 INTERNAL_ERROR
|
return fail(BizCode.INTERNAL_ERROR, "终端用户信息更新失败", error_msg)
|
||||||
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
|
|
||||||
|
|
||||||
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
||||||
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default=None, alias="X-Language-Type"),
|
async def memory_space_timeline_of_shared_memories(
|
||||||
current_user: User = Depends(get_current_user),
|
id: str, label: str,
|
||||||
db: Session = Depends(get_db),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
):
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
# 使用集中化的语言校验
|
# 使用集中化的语言校验
|
||||||
language = get_language_from_header(language_type)
|
language = get_language_from_header(language_type)
|
||||||
|
|
||||||
workspace_id=current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
workspace_repo = WorkspaceRepository(db)
|
workspace_repo = WorkspaceRepository(db)
|
||||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||||
|
|
||||||
@@ -410,11 +473,13 @@ async def memory_space_timeline_of_shared_memories(id: str, label: str,language_
|
|||||||
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
|
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
|
||||||
|
|
||||||
return success(data=timeline_memories_result, msg="共同记忆时间线")
|
return success(data=timeline_memories_result, msg="共同记忆时间线")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
|
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
|
||||||
async def memory_space_relationship_evolution(id: str, label: str,
|
async def memory_space_relationship_evolution(id: str, label: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")
|
api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,12 @@ from app.dependencies import (
|
|||||||
get_current_user,
|
get_current_user,
|
||||||
workspace_access_guard,
|
workspace_access_guard,
|
||||||
)
|
)
|
||||||
|
from app.i18n.dependencies import get_current_language, get_translator
|
||||||
|
from app.i18n.serializers import (
|
||||||
|
WorkspaceSerializer,
|
||||||
|
WorkspaceMemberSerializer,
|
||||||
|
WorkspaceInviteSerializer
|
||||||
|
)
|
||||||
from app.models.tenant_model import Tenants
|
from app.models.tenant_model import Tenants
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.models.workspace_model import InviteStatus
|
from app.models.workspace_model import InviteStatus
|
||||||
@@ -29,6 +35,7 @@ from app.schemas.workspace_schema import (
|
|||||||
WorkspaceUpdate,
|
WorkspaceUpdate,
|
||||||
)
|
)
|
||||||
from app.services import workspace_service
|
from app.services import workspace_service
|
||||||
|
from app.core.quota_stub import check_workspace_quota
|
||||||
|
|
||||||
# 获取API专用日志器
|
# 获取API专用日志器
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -65,7 +72,9 @@ def get_workspaces(
|
|||||||
include_current: bool = Query(True, description="是否包含当前工作空间"),
|
include_current: bool = Query(True, description="是否包含当前工作空间"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
current_tenant: Tenants = Depends(get_current_tenant)
|
current_tenant: Tenants = Depends(get_current_tenant),
|
||||||
|
language: str = Depends(get_current_language),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""获取当前租户下用户参与的所有工作空间
|
"""获取当前租户下用户参与的所有工作空间
|
||||||
|
|
||||||
@@ -88,16 +97,24 @@ def get_workspaces(
|
|||||||
)
|
)
|
||||||
|
|
||||||
api_logger.info(f"成功获取 {len(workspaces)} 个工作空间")
|
api_logger.info(f"成功获取 {len(workspaces)} 个工作空间")
|
||||||
workspaces_schema = [WorkspaceResponse.model_validate(w) for w in workspaces]
|
|
||||||
return success(data=workspaces_schema, msg="工作空间列表获取成功")
|
# 使用序列化器添加国际化字段
|
||||||
|
serializer = WorkspaceSerializer()
|
||||||
|
workspaces_data = [WorkspaceResponse.model_validate(w).model_dump() for w in workspaces]
|
||||||
|
workspaces_i18n = serializer.serialize_list(workspaces_data, language)
|
||||||
|
|
||||||
|
return success(data=workspaces_i18n, msg=t("workspace.list_retrieved"))
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=ApiResponse)
|
@router.post("", response_model=ApiResponse)
|
||||||
|
@check_workspace_quota
|
||||||
def create_workspace(
|
def create_workspace(
|
||||||
workspace: WorkspaceCreate,
|
workspace: WorkspaceCreate,
|
||||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_superuser),
|
current_user: User = Depends(get_current_superuser),
|
||||||
|
language: str = Depends(get_current_language),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""创建新的工作空间"""
|
"""创建新的工作空间"""
|
||||||
from app.core.language_utils import get_language_from_header
|
from app.core.language_utils import get_language_from_header
|
||||||
@@ -118,8 +135,13 @@ def create_workspace(
|
|||||||
f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, "
|
f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, "
|
||||||
f"创建者: {current_user.username}, language={language}"
|
f"创建者: {current_user.username}, language={language}"
|
||||||
)
|
)
|
||||||
result_schema = WorkspaceResponse.model_validate(result)
|
|
||||||
return success(data=result_schema, msg="工作空间创建成功")
|
# 使用序列化器添加国际化字段
|
||||||
|
serializer = WorkspaceSerializer()
|
||||||
|
result_data = WorkspaceResponse.model_validate(result).model_dump()
|
||||||
|
result_i18n = serializer.serialize(result_data, language)
|
||||||
|
|
||||||
|
return success(data=result_i18n, msg=t("workspace.created"))
|
||||||
|
|
||||||
@router.put("", response_model=ApiResponse)
|
@router.put("", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
@@ -127,6 +149,8 @@ def update_workspace(
|
|||||||
workspace: WorkspaceUpdate,
|
workspace: WorkspaceUpdate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
language: str = Depends(get_current_language),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""更新工作空间"""
|
"""更新工作空间"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
@@ -139,14 +163,21 @@ def update_workspace(
|
|||||||
user=current_user,
|
user=current_user,
|
||||||
)
|
)
|
||||||
api_logger.info(f"工作空间更新成功 - ID: {workspace_id}, 用户: {current_user.username}")
|
api_logger.info(f"工作空间更新成功 - ID: {workspace_id}, 用户: {current_user.username}")
|
||||||
result_schema = WorkspaceResponse.model_validate(result)
|
|
||||||
return success(data=result_schema, msg="工作空间更新成功")
|
# 使用序列化器添加国际化字段
|
||||||
|
serializer = WorkspaceSerializer()
|
||||||
|
result_data = WorkspaceResponse.model_validate(result).model_dump()
|
||||||
|
result_i18n = serializer.serialize(result_data, language)
|
||||||
|
|
||||||
|
return success(data=result_i18n, msg=t("workspace.updated"))
|
||||||
|
|
||||||
@router.get("/members", response_model=ApiResponse)
|
@router.get("/members", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def get_cur_workspace_members(
|
def get_cur_workspace_members(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
language: str = Depends(get_current_language),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""获取工作空间成员列表(关系序列化)"""
|
"""获取工作空间成员列表(关系序列化)"""
|
||||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {current_user.current_workspace_id} 的成员列表")
|
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {current_user.current_workspace_id} 的成员列表")
|
||||||
@@ -157,8 +188,14 @@ def get_cur_workspace_members(
|
|||||||
user=current_user,
|
user=current_user,
|
||||||
)
|
)
|
||||||
api_logger.info(f"工作空间成员列表获取成功 - ID: {current_user.current_workspace_id}, 数量: {len(members)}")
|
api_logger.info(f"工作空间成员列表获取成功 - ID: {current_user.current_workspace_id}, 数量: {len(members)}")
|
||||||
|
|
||||||
|
# 转换为表格项并使用序列化器添加国际化字段
|
||||||
table_items = _convert_members_to_table_items(members)
|
table_items = _convert_members_to_table_items(members)
|
||||||
return success(data=table_items, msg="工作空间成员列表获取成功")
|
serializer = WorkspaceMemberSerializer()
|
||||||
|
members_data = [item.model_dump() for item in table_items]
|
||||||
|
members_i18n = serializer.serialize_list(members_data, language)
|
||||||
|
|
||||||
|
return success(data=members_i18n, msg=t("workspace.members.list_retrieved"))
|
||||||
|
|
||||||
|
|
||||||
@router.put("/members", response_model=ApiResponse)
|
@router.put("/members", response_model=ApiResponse)
|
||||||
@@ -168,6 +205,7 @@ def update_workspace_members(
|
|||||||
updates: List[WorkspaceMemberUpdate],
|
updates: List[WorkspaceMemberUpdate],
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的成员角色")
|
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的成员角色")
|
||||||
@@ -178,27 +216,28 @@ def update_workspace_members(
|
|||||||
user=current_user,
|
user=current_user,
|
||||||
)
|
)
|
||||||
api_logger.info(f"工作空间成员角色更新成功 - ID: {workspace_id}, 数量: {len(members)}")
|
api_logger.info(f"工作空间成员角色更新成功 - ID: {workspace_id}, 数量: {len(members)}")
|
||||||
return success(msg="成员角色更新成功")
|
return success(msg=t("workspace.members.role_updated"))
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/members/{member_id}", response_model=ApiResponse)
|
@router.delete("/members/{member_id}", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def delete_workspace_member(
|
async def delete_workspace_member(
|
||||||
member_id: uuid.UUID,
|
member_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||||
|
|
||||||
workspace_service.delete_workspace_member(
|
await workspace_service.delete_workspace_member(
|
||||||
db=db,
|
db=db,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
member_id=member_id,
|
member_id=member_id,
|
||||||
user=current_user,
|
user=current_user,
|
||||||
)
|
)
|
||||||
api_logger.info(f"工作空间成员删除成功 - ID: {workspace_id}, 成员: {member_id}")
|
api_logger.info(f"工作空间成员删除成功 - ID: {workspace_id}, 成员: {member_id}")
|
||||||
return success(msg="成员删除成功")
|
return success(msg=t("workspace.members.deleted"))
|
||||||
|
|
||||||
|
|
||||||
# 创建空间协作邀请
|
# 创建空间协作邀请
|
||||||
@@ -208,6 +247,8 @@ def create_workspace_invite(
|
|||||||
invite_data: WorkspaceInviteCreate,
|
invite_data: WorkspaceInviteCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
language: str = Depends(get_current_language),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""创建工作空间邀请"""
|
"""创建工作空间邀请"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
@@ -220,7 +261,12 @@ def create_workspace_invite(
|
|||||||
user=current_user
|
user=current_user
|
||||||
)
|
)
|
||||||
api_logger.info(f"工作空间邀请创建成功 - 工作空间: {workspace_id}, 邮箱: {invite_data.email}")
|
api_logger.info(f"工作空间邀请创建成功 - 工作空间: {workspace_id}, 邮箱: {invite_data.email}")
|
||||||
return success(data=result, msg="邀请创建成功")
|
|
||||||
|
# 使用序列化器添加国际化字段
|
||||||
|
serializer = WorkspaceInviteSerializer()
|
||||||
|
result_i18n = serializer.serialize(result, language)
|
||||||
|
|
||||||
|
return success(data=result_i18n, msg=t("workspace.invites.created"))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/invites", response_model=ApiResponse)
|
@router.get("/invites", response_model=ApiResponse)
|
||||||
@@ -232,6 +278,8 @@ def get_workspace_invites(
|
|||||||
offset: int = Query(0, ge=0),
|
offset: int = Query(0, ge=0),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
language: str = Depends(get_current_language),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""获取工作空间邀请列表"""
|
"""获取工作空间邀请列表"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
@@ -246,18 +294,30 @@ def get_workspace_invites(
|
|||||||
offset=offset
|
offset=offset
|
||||||
)
|
)
|
||||||
api_logger.info(f"成功获取 {len(invites)} 个邀请记录")
|
api_logger.info(f"成功获取 {len(invites)} 个邀请记录")
|
||||||
return success(data=invites, msg="邀请列表获取成功")
|
|
||||||
|
# 使用序列化器添加国际化字段
|
||||||
|
serializer = WorkspaceInviteSerializer()
|
||||||
|
invites_i18n = serializer.serialize_list(invites, language)
|
||||||
|
|
||||||
|
return success(data=invites_i18n, msg=t("workspace.invites.list_retrieved"))
|
||||||
|
|
||||||
|
|
||||||
@public_router.get("/invites/validate/{token}", response_model=ApiResponse)
|
@public_router.get("/invites/validate/{token}", response_model=ApiResponse)
|
||||||
def get_workspace_invite_info(
|
def get_workspace_invite_info(
|
||||||
token: str,
|
token: str,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
language: str = Depends(get_current_language),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""获取工作空间邀请用户信息(无需认证)"""
|
"""获取工作空间邀请用户信息(无需认证)"""
|
||||||
result = workspace_service.validate_invite_token(db=db, token=token)
|
result = workspace_service.validate_invite_token(db=db, token=token)
|
||||||
api_logger.info(f"工作空间邀请验证成功 - 邀请: {token}")
|
api_logger.info(f"工作空间邀请验证成功 - 邀请: {token}")
|
||||||
return success(data=result, msg="邀请验证成功")
|
|
||||||
|
# 使用序列化器添加国际化字段
|
||||||
|
serializer = WorkspaceInviteSerializer()
|
||||||
|
result_i18n = serializer.serialize(result, language)
|
||||||
|
|
||||||
|
return success(data=result_i18n, msg=t("workspace.invites.validated"))
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/invites/{invite_id}", response_model=ApiResponse)
|
@router.delete("/invites/{invite_id}", response_model=ApiResponse)
|
||||||
@@ -267,6 +327,8 @@ def revoke_workspace_invite(
|
|||||||
invite_id: uuid.UUID,
|
invite_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
language: str = Depends(get_current_language),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""撤销工作空间邀请"""
|
"""撤销工作空间邀请"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
@@ -279,7 +341,12 @@ def revoke_workspace_invite(
|
|||||||
user=current_user
|
user=current_user
|
||||||
)
|
)
|
||||||
api_logger.info(f"工作空间邀请撤销成功 - 邀请: {invite_id}")
|
api_logger.info(f"工作空间邀请撤销成功 - 邀请: {invite_id}")
|
||||||
return success(data=result, msg="邀请撤销成功")
|
|
||||||
|
# 使用序列化器添加国际化字段
|
||||||
|
serializer = WorkspaceInviteSerializer()
|
||||||
|
result_i18n = serializer.serialize(result, language)
|
||||||
|
|
||||||
|
return success(data=result_i18n, msg=t("workspace.invites.revoked"))
|
||||||
|
|
||||||
# ==================== 公开邀请接口(无需认证) ====================
|
# ==================== 公开邀请接口(无需认证) ====================
|
||||||
|
|
||||||
@@ -302,6 +369,7 @@ def switch_workspace(
|
|||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""切换工作空间"""
|
"""切换工作空间"""
|
||||||
api_logger.info(f"用户 {current_user.username} 请求切换工作空间为 {workspace_id}")
|
api_logger.info(f"用户 {current_user.username} 请求切换工作空间为 {workspace_id}")
|
||||||
@@ -312,7 +380,7 @@ def switch_workspace(
|
|||||||
user=current_user,
|
user=current_user,
|
||||||
)
|
)
|
||||||
api_logger.info(f"成功切换工作空间为 {workspace_id}")
|
api_logger.info(f"成功切换工作空间为 {workspace_id}")
|
||||||
return success(msg="工作空间切换成功")
|
return success(msg=t("workspace.switched"))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/storage", response_model=ApiResponse)
|
@router.get("/storage", response_model=ApiResponse)
|
||||||
@@ -320,6 +388,7 @@ def switch_workspace(
|
|||||||
def get_workspace_storage_type(
|
def get_workspace_storage_type(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""获取当前工作空间的存储类型"""
|
"""获取当前工作空间的存储类型"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
@@ -331,7 +400,7 @@ def get_workspace_storage_type(
|
|||||||
user=current_user
|
user=current_user
|
||||||
)
|
)
|
||||||
api_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {storage_type}")
|
api_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {storage_type}")
|
||||||
return success(data={"storage_type": storage_type}, msg="存储类型获取成功")
|
return success(data={"storage_type": storage_type}, msg=t("workspace.storage.type_retrieved"))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/workspace_models", response_model=ApiResponse)
|
@router.get("/workspace_models", response_model=ApiResponse)
|
||||||
@@ -339,6 +408,8 @@ def get_workspace_storage_type(
|
|||||||
def workspace_models_configs(
|
def workspace_models_configs(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
language: str = Depends(get_current_language),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""获取当前工作空间的模型配置(llm, embedding, rerank)"""
|
"""获取当前工作空间的模型配置(llm, embedding, rerank)"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
@@ -354,14 +425,14 @@ def workspace_models_configs(
|
|||||||
api_logger.warning(f"工作空间 {workspace_id} 不存在或无权访问")
|
api_logger.warning(f"工作空间 {workspace_id} 不存在或无权访问")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail="工作空间不存在或无权访问"
|
detail=t("workspace.not_found")
|
||||||
)
|
)
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"成功获取工作空间 {workspace_id} 的模型配置: "
|
f"成功获取工作空间 {workspace_id} 的模型配置: "
|
||||||
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
|
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
|
||||||
)
|
)
|
||||||
return success(data=WorkspaceModelsConfig.model_validate(configs), msg="模型配置获取成功")
|
return success(data=WorkspaceModelsConfig.model_validate(configs), msg=t("workspace.models.config_retrieved"))
|
||||||
|
|
||||||
|
|
||||||
@router.put("/workspace_models", response_model=ApiResponse)
|
@router.put("/workspace_models", response_model=ApiResponse)
|
||||||
@@ -370,6 +441,7 @@ def update_workspace_models_configs(
|
|||||||
models_update: WorkspaceModelsUpdate,
|
models_update: WorkspaceModelsUpdate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""更新当前工作空间的模型配置(llm, embedding, rerank)"""
|
"""更新当前工作空间的模型配置(llm, embedding, rerank)"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
@@ -386,5 +458,5 @@ def update_workspace_models_configs(
|
|||||||
f"成功更新工作空间 {workspace_id} 的模型配置: "
|
f"成功更新工作空间 {workspace_id} 的模型配置: "
|
||||||
f"llm={updated_workspace.llm}, embedding={updated_workspace.embedding}, rerank={updated_workspace.rerank}"
|
f"llm={updated_workspace.llm}, embedding={updated_workspace.embedding}, rerank={updated_workspace.rerank}"
|
||||||
)
|
)
|
||||||
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg="模型配置更新成功")
|
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg=t("workspace.models.config_updated"))
|
||||||
|
|
||||||
|
|||||||
@@ -11,17 +11,14 @@ LangChain Agent 封装
|
|||||||
import time
|
import time
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
from langchain.agents import create_agent
|
||||||
from app.db import get_db
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
from langgraph.errors import GraphRecursionError
|
||||||
|
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
from app.models.models_model import ModelType, ModelProvider
|
from app.models.models_model import ModelType
|
||||||
from app.services.memory_agent_service import (
|
|
||||||
get_end_user_connected_config,
|
|
||||||
)
|
|
||||||
from langchain.agents import create_agent
|
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import BaseTool
|
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
@@ -41,7 +38,11 @@ class LangChainAgent:
|
|||||||
tools: Optional[Sequence[BaseTool]] = None,
|
tools: Optional[Sequence[BaseTool]] = None,
|
||||||
streaming: bool = False,
|
streaming: bool = False,
|
||||||
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
|
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
|
||||||
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
|
max_tool_consecutive_calls: int = 3, # 单个工具最大连续调用次数
|
||||||
|
deep_thinking: bool = False, # 是否启用深度思考模式
|
||||||
|
thinking_budget_tokens: Optional[int] = None, # 深度思考 token 预算
|
||||||
|
json_output: bool = False, # 是否强制 JSON 输出
|
||||||
|
capability: Optional[List[str]] = None # 模型能力列表,用于校验是否支持深度思考
|
||||||
):
|
):
|
||||||
"""初始化 LangChain Agent
|
"""初始化 LangChain Agent
|
||||||
|
|
||||||
@@ -79,6 +80,17 @@ class LangChainAgent:
|
|||||||
|
|
||||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||||
|
|
||||||
|
# ChatTongyi 要求 messages 含 'json' 字样才能使用 response_format
|
||||||
|
# 在 system prompt 中注入 JSON 要求
|
||||||
|
from app.models.models_model import ModelProvider
|
||||||
|
if json_output and (
|
||||||
|
(provider.lower() == ModelProvider.DASHSCOPE and not is_omni)
|
||||||
|
or provider.lower() == ModelProvider.VOLCANO
|
||||||
|
# 有工具时 response_format 会被移除,所有 provider 都需要 system prompt 注入保证 JSON 输出
|
||||||
|
or bool(tools)
|
||||||
|
):
|
||||||
|
self.system_prompt += "\n请以JSON格式输出。"
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
|
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
|
||||||
f"tool_count={len(self.tools)}, "
|
f"tool_count={len(self.tools)}, "
|
||||||
@@ -86,21 +98,28 @@ class LangChainAgent:
|
|||||||
f"auto_calculated={max_iterations is None}"
|
f"auto_calculated={max_iterations is None}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建 RedBearLLM(支持多提供商)
|
# 创建 RedBearLLM,capability 校验由 RedBearModelConfig 统一处理
|
||||||
model_config = RedBearModelConfig(
|
model_config = RedBearModelConfig(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=api_base,
|
base_url=api_base,
|
||||||
is_omni=is_omni,
|
is_omni=is_omni,
|
||||||
|
capability=capability,
|
||||||
|
deep_thinking=deep_thinking,
|
||||||
|
thinking_budget_tokens=thinking_budget_tokens,
|
||||||
|
json_output=json_output,
|
||||||
extra_params={
|
extra_params={
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"max_tokens": max_tokens,
|
"max_tokens": max_tokens,
|
||||||
"streaming": streaming # 使用参数控制流式
|
"streaming": streaming
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.llm = RedBearLLM(model_config, type=ModelType.CHAT)
|
self.llm = RedBearLLM(model_config, type=ModelType.CHAT)
|
||||||
|
# 从经过校验的 config 读取实际生效的能力开关
|
||||||
|
self.deep_thinking = model_config.deep_thinking
|
||||||
|
self.json_output = model_config.json_output
|
||||||
|
|
||||||
# 获取底层模型用于真正的流式调用
|
# 获取底层模型用于真正的流式调用
|
||||||
self._underlying_llm = self.llm._model if hasattr(self.llm, '_model') else self.llm
|
self._underlying_llm = self.llm._model if hasattr(self.llm, '_model') else self.llm
|
||||||
@@ -226,10 +245,7 @@ class LangChainAgent:
|
|||||||
Returns:
|
Returns:
|
||||||
List[BaseMessage]: 消息列表
|
List[BaseMessage]: 消息列表
|
||||||
"""
|
"""
|
||||||
messages = []
|
messages: list = []
|
||||||
|
|
||||||
# 添加系统提示词
|
|
||||||
messages.append(SystemMessage(content=self.system_prompt))
|
|
||||||
|
|
||||||
# 添加历史消息
|
# 添加历史消息
|
||||||
if history:
|
if history:
|
||||||
@@ -254,6 +270,33 @@ class LangChainAgent:
|
|||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_tokens_from_message(msg) -> int:
|
||||||
|
"""从 AIMessage 或类似对象中提取 total_tokens,兼容多种 provider 格式
|
||||||
|
|
||||||
|
支持的格式:
|
||||||
|
- response_metadata.token_usage.total_tokens (OpenAI/ChatOpenAI)
|
||||||
|
- response_metadata.usage.total_tokens (部分 provider)
|
||||||
|
- usage_metadata.total_tokens (LangChain 新版)
|
||||||
|
"""
|
||||||
|
total = 0
|
||||||
|
# 1. response_metadata
|
||||||
|
response_meta = getattr(msg, "response_metadata", None)
|
||||||
|
if response_meta and isinstance(response_meta, dict):
|
||||||
|
# 尝试 token_usage 路径
|
||||||
|
token_usage = response_meta.get("token_usage") or response_meta.get("usage", {})
|
||||||
|
if isinstance(token_usage, dict):
|
||||||
|
total = token_usage.get("total_tokens", 0)
|
||||||
|
# 2. usage_metadata(LangChain 新版 AIMessage 属性)
|
||||||
|
if not total:
|
||||||
|
usage_meta = getattr(msg, "usage_metadata", None)
|
||||||
|
if usage_meta:
|
||||||
|
if isinstance(usage_meta, dict):
|
||||||
|
total = usage_meta.get("total_tokens", 0)
|
||||||
|
else:
|
||||||
|
total = getattr(usage_meta, "total_tokens", 0)
|
||||||
|
return total or 0
|
||||||
|
|
||||||
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
构建多模态消息内容
|
构建多模态消息内容
|
||||||
@@ -288,17 +331,23 @@ class LangChainAgent:
|
|||||||
|
|
||||||
return content_parts
|
return content_parts
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_reasoning_content(msg) -> str:
|
||||||
|
"""从 AIMessage 中提取深度思考内容(reasoning_content)
|
||||||
|
|
||||||
|
所有 provider 统一通过 additional_kwargs.reasoning_content 传递:
|
||||||
|
- DeepSeek-R1 / QwQ: 原生字段
|
||||||
|
- Volcano (Doubao-thinking): 由 VolcanoChatOpenAI 从 delta.reasoning_content 注入
|
||||||
|
"""
|
||||||
|
additional = getattr(msg, "additional_kwargs", None) or {}
|
||||||
|
return additional.get("reasoning_content") or additional.get("reasoning", "")
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
history: Optional[List[Dict[str, str]]] = None,
|
history: Optional[List[Dict[str, str]]] = None,
|
||||||
context: Optional[str] = None,
|
context: Optional[str] = None,
|
||||||
end_user_id: Optional[str] = None,
|
files: Optional[List[Dict[str, Any]]] = None
|
||||||
config_id: Optional[str] = None, # 添加这个参数
|
|
||||||
storage_type: Optional[str] = None,
|
|
||||||
user_rag_memory_id: Optional[str] = None,
|
|
||||||
memory_flag: Optional[bool] = True,
|
|
||||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""执行对话
|
"""执行对话
|
||||||
|
|
||||||
@@ -306,32 +355,12 @@ class LangChainAgent:
|
|||||||
message: 用户消息
|
message: 用户消息
|
||||||
history: 历史消息列表 [{"role": "user/assistant", "content": "..."}]
|
history: 历史消息列表 [{"role": "user/assistant", "content": "..."}]
|
||||||
context: 上下文信息(如知识库检索结果)
|
context: 上下文信息(如知识库检索结果)
|
||||||
|
files: 多模态文件
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 包含 content 和元数据的字典
|
Dict: 包含 content 和元数据的字典
|
||||||
"""
|
"""
|
||||||
message_chat = message
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
actual_config_id = config_id
|
|
||||||
# If config_id is None, try to get from end_user's connected config
|
|
||||||
if actual_config_id is None and end_user_id:
|
|
||||||
try:
|
|
||||||
from app.services.memory_agent_service import (
|
|
||||||
get_end_user_connected_config,
|
|
||||||
)
|
|
||||||
db = next(get_db())
|
|
||||||
try:
|
|
||||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
|
||||||
actual_config_id = connected_config.get("memory_config_id")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to get db session: {e}")
|
|
||||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
|
||||||
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
|
||||||
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
|
||||||
try:
|
try:
|
||||||
# 准备消息列表(支持多模态)
|
# 准备消息列表(支持多模态)
|
||||||
messages = self._prepare_messages(message, history, context, files)
|
messages = self._prepare_messages(message, history, context, files)
|
||||||
@@ -355,7 +384,7 @@ class LangChainAgent:
|
|||||||
{"messages": messages},
|
{"messages": messages},
|
||||||
config={"recursion_limit": self.max_iterations}
|
config={"recursion_limit": self.max_iterations}
|
||||||
)
|
)
|
||||||
except RecursionError as e:
|
except (RecursionError, GraphRecursionError) as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),可能存在工具调用循环",
|
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),可能存在工具调用循环",
|
||||||
extra={"error": str(e)}
|
extra={"error": str(e)}
|
||||||
@@ -378,6 +407,7 @@ class LangChainAgent:
|
|||||||
|
|
||||||
logger.debug(f"输出消息数量: {len(output_messages)}")
|
logger.debug(f"输出消息数量: {len(output_messages)}")
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
|
reasoning_content = ""
|
||||||
for msg in reversed(output_messages):
|
for msg in reversed(output_messages):
|
||||||
if isinstance(msg, AIMessage):
|
if isinstance(msg, AIMessage):
|
||||||
logger.debug(f"找到 AI 消息,content 类型: {type(msg.content)}")
|
logger.debug(f"找到 AI 消息,content 类型: {type(msg.content)}")
|
||||||
@@ -412,16 +442,13 @@ class LangChainAgent:
|
|||||||
else:
|
else:
|
||||||
content = str(msg.content)
|
content = str(msg.content)
|
||||||
logger.debug(f"转换为字符串: {content[:100]}...")
|
logger.debug(f"转换为字符串: {content[:100]}...")
|
||||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
total_tokens = self._extract_tokens_from_message(msg)
|
||||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
reasoning_content = self._extract_reasoning_content(msg) if self.deep_thinking else ""
|
||||||
break
|
break
|
||||||
|
|
||||||
logger.info(f"最终提取的内容长度: {len(content)}")
|
logger.info(f"最终提取的内容长度: {len(content)}")
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
if memory_flag:
|
|
||||||
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
|
|
||||||
actual_config_id)
|
|
||||||
response = {
|
response = {
|
||||||
"content": content,
|
"content": content,
|
||||||
"model": self.model_name,
|
"model": self.model_name,
|
||||||
@@ -432,6 +459,8 @@ class LangChainAgent:
|
|||||||
"total_tokens": total_tokens
|
"total_tokens": total_tokens
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if reasoning_content:
|
||||||
|
response["reasoning_content"] = reasoning_content
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Agent 调用完成",
|
"Agent 调用完成",
|
||||||
@@ -452,22 +481,20 @@ class LangChainAgent:
|
|||||||
message: str,
|
message: str,
|
||||||
history: Optional[List[Dict[str, str]]] = None,
|
history: Optional[List[Dict[str, str]]] = None,
|
||||||
context: Optional[str] = None,
|
context: Optional[str] = None,
|
||||||
end_user_id: Optional[str] = None,
|
files: Optional[List[Dict[str, Any]]] = None
|
||||||
config_id: Optional[str] = None,
|
) -> AsyncGenerator[str | int | dict[str, str], None]:
|
||||||
storage_type: Optional[str] = None,
|
|
||||||
user_rag_memory_id: Optional[str] = None,
|
|
||||||
memory_flag: Optional[bool] = True,
|
|
||||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
|
||||||
) -> AsyncGenerator[str, None]:
|
|
||||||
"""执行流式对话
|
"""执行流式对话
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message: 用户消息
|
message: 用户消息
|
||||||
history: 历史消息列表
|
history: 历史消息列表
|
||||||
context: 上下文信息
|
context: 上下文信息
|
||||||
|
files: 多模态文件
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
str: 消息内容块
|
str: 消息内容块
|
||||||
|
int: token 统计
|
||||||
|
Dict: 深度思考内容 {"type": "reasoning", "content": "..."}
|
||||||
"""
|
"""
|
||||||
logger.info("=" * 80)
|
logger.info("=" * 80)
|
||||||
logger.info(" chat_stream 方法开始执行")
|
logger.info(" chat_stream 方法开始执行")
|
||||||
@@ -475,23 +502,6 @@ class LangChainAgent:
|
|||||||
logger.info(f" Has tools: {bool(self.tools)}")
|
logger.info(f" Has tools: {bool(self.tools)}")
|
||||||
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
|
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
|
||||||
logger.info("=" * 80)
|
logger.info("=" * 80)
|
||||||
message_chat = message
|
|
||||||
actual_config_id = config_id
|
|
||||||
# If config_id is None, try to get from end_user's connected config
|
|
||||||
if actual_config_id is None and end_user_id:
|
|
||||||
try:
|
|
||||||
db = next(get_db())
|
|
||||||
try:
|
|
||||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
|
||||||
actual_config_id = connected_config.get("memory_config_id")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to get db session: {e}")
|
|
||||||
|
|
||||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
|
||||||
try:
|
try:
|
||||||
# 准备消息列表(支持多模态)
|
# 准备消息列表(支持多模态)
|
||||||
messages = self._prepare_messages(message, history, context, files)
|
messages = self._prepare_messages(message, history, context, files)
|
||||||
@@ -501,17 +511,19 @@ class LangChainAgent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
yielded_content = False
|
|
||||||
|
|
||||||
# 统一使用 agent 的 astream_events 实现流式输出
|
# 统一使用 agent 的 astream_events 实现流式输出
|
||||||
logger.debug("使用 Agent astream_events 实现流式输出")
|
logger.debug("使用 Agent astream_events 实现流式输出")
|
||||||
full_content = ''
|
full_content = ''
|
||||||
|
full_reasoning = ''
|
||||||
try:
|
try:
|
||||||
|
last_event = {}
|
||||||
async for event in self.agent.astream_events(
|
async for event in self.agent.astream_events(
|
||||||
{"messages": messages},
|
{"messages": messages},
|
||||||
version="v2",
|
version="v2",
|
||||||
config={"recursion_limit": self.max_iterations}
|
config={"recursion_limit": self.max_iterations}
|
||||||
):
|
):
|
||||||
|
last_event = event
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
kind = event.get("event")
|
kind = event.get("event")
|
||||||
|
|
||||||
@@ -520,12 +532,18 @@ class LangChainAgent:
|
|||||||
# LLM 流式输出
|
# LLM 流式输出
|
||||||
chunk = event.get("data", {}).get("chunk")
|
chunk = event.get("data", {}).get("chunk")
|
||||||
if chunk and hasattr(chunk, "content"):
|
if chunk and hasattr(chunk, "content"):
|
||||||
|
# 提取深度思考内容(仅在启用深度思考时)
|
||||||
|
if self.deep_thinking:
|
||||||
|
reasoning_chunk = self._extract_reasoning_content(chunk)
|
||||||
|
if reasoning_chunk:
|
||||||
|
full_reasoning += reasoning_chunk
|
||||||
|
yield {"type": "reasoning", "content": reasoning_chunk}
|
||||||
|
|
||||||
# 处理多模态响应:content 可能是字符串或列表
|
# 处理多模态响应:content 可能是字符串或列表
|
||||||
chunk_content = chunk.content
|
chunk_content = chunk.content
|
||||||
if isinstance(chunk_content, str) and chunk_content:
|
if isinstance(chunk_content, str) and chunk_content:
|
||||||
full_content += chunk_content
|
full_content += chunk_content
|
||||||
yield chunk_content
|
yield chunk_content
|
||||||
yielded_content = True
|
|
||||||
elif isinstance(chunk_content, list):
|
elif isinstance(chunk_content, list):
|
||||||
# 多模态响应:提取文本部分
|
# 多模态响应:提取文本部分
|
||||||
for item in chunk_content:
|
for item in chunk_content:
|
||||||
@@ -536,29 +554,32 @@ class LangChainAgent:
|
|||||||
if text:
|
if text:
|
||||||
full_content += text
|
full_content += text
|
||||||
yield text
|
yield text
|
||||||
yielded_content = True
|
|
||||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||||
elif item.get("type") == "text":
|
elif item.get("type") == "text":
|
||||||
text = item.get("text", "")
|
text = item.get("text", "")
|
||||||
if text:
|
if text:
|
||||||
full_content += text
|
full_content += text
|
||||||
yield text
|
yield text
|
||||||
yielded_content = True
|
|
||||||
elif isinstance(item, str):
|
elif isinstance(item, str):
|
||||||
full_content += item
|
full_content += item
|
||||||
yield item
|
yield item
|
||||||
yielded_content = True
|
|
||||||
|
|
||||||
elif kind == "on_llm_stream":
|
elif kind == "on_llm_stream":
|
||||||
# 另一种 LLM 流式事件
|
# 另一种 LLM 流式事件
|
||||||
chunk = event.get("data", {}).get("chunk")
|
chunk = event.get("data", {}).get("chunk")
|
||||||
if chunk:
|
if chunk:
|
||||||
if hasattr(chunk, "content"):
|
if hasattr(chunk, "content"):
|
||||||
|
# 提取深度思考内容(仅在启用深度思考时)
|
||||||
|
if self.deep_thinking:
|
||||||
|
reasoning_chunk = self._extract_reasoning_content(chunk)
|
||||||
|
if reasoning_chunk:
|
||||||
|
full_reasoning += reasoning_chunk
|
||||||
|
yield {"type": "reasoning", "content": reasoning_chunk}
|
||||||
|
|
||||||
chunk_content = chunk.content
|
chunk_content = chunk.content
|
||||||
if isinstance(chunk_content, str) and chunk_content:
|
if isinstance(chunk_content, str) and chunk_content:
|
||||||
full_content += chunk_content
|
full_content += chunk_content
|
||||||
yield chunk_content
|
yield chunk_content
|
||||||
yielded_content = True
|
|
||||||
elif isinstance(chunk_content, list):
|
elif isinstance(chunk_content, list):
|
||||||
# 多模态响应:提取文本部分
|
# 多模态响应:提取文本部分
|
||||||
for item in chunk_content:
|
for item in chunk_content:
|
||||||
@@ -569,22 +590,18 @@ class LangChainAgent:
|
|||||||
if text:
|
if text:
|
||||||
full_content += text
|
full_content += text
|
||||||
yield text
|
yield text
|
||||||
yielded_content = True
|
|
||||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||||
elif item.get("type") == "text":
|
elif item.get("type") == "text":
|
||||||
text = item.get("text", "")
|
text = item.get("text", "")
|
||||||
if text:
|
if text:
|
||||||
full_content += text
|
full_content += text
|
||||||
yield text
|
yield text
|
||||||
yielded_content = True
|
|
||||||
elif isinstance(item, str):
|
elif isinstance(item, str):
|
||||||
full_content += item
|
full_content += item
|
||||||
yield item
|
yield item
|
||||||
yielded_content = True
|
|
||||||
elif isinstance(chunk, str):
|
elif isinstance(chunk, str):
|
||||||
full_content += chunk
|
full_content += chunk
|
||||||
yield chunk
|
yield chunk
|
||||||
yielded_content = True
|
|
||||||
|
|
||||||
# 记录工具调用(可选)
|
# 记录工具调用(可选)
|
||||||
elif kind == "on_tool_start":
|
elif kind == "on_tool_start":
|
||||||
@@ -594,17 +611,20 @@ class LangChainAgent:
|
|||||||
|
|
||||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||||
# 统计token消耗
|
# 统计token消耗
|
||||||
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
output_messages = last_event.get("data", {}).get("output", {}).get("messages", [])
|
||||||
for msg in reversed(output_messages):
|
for msg in reversed(output_messages):
|
||||||
if isinstance(msg, AIMessage):
|
if isinstance(msg, AIMessage):
|
||||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
stream_total_tokens = self._extract_tokens_from_message(msg)
|
||||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
|
logger.info(f"流式 token 统计: total_tokens={stream_total_tokens}")
|
||||||
0) if response_meta else 0
|
yield stream_total_tokens
|
||||||
yield total_tokens
|
|
||||||
break
|
break
|
||||||
if memory_flag:
|
|
||||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
|
except GraphRecursionError:
|
||||||
actual_config_id)
|
logger.warning(
|
||||||
|
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),模型可能不支持正确的工具调用停止判断"
|
||||||
|
)
|
||||||
|
if not full_content:
|
||||||
|
yield "抱歉,我在处理您的请求时遇到了问题(已达最大处理步骤限制)。请尝试简化问题或更换模型后重试。"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -70,6 +70,8 @@ def require_api_key(
|
|||||||
})
|
})
|
||||||
raise BusinessException("API Key 无效或已过期", BizCode.API_KEY_INVALID)
|
raise BusinessException("API Key 无效或已过期", BizCode.API_KEY_INVALID)
|
||||||
|
|
||||||
|
ApiKeyAuthService.check_app_published(db, api_key_obj)
|
||||||
|
|
||||||
if scopes:
|
if scopes:
|
||||||
missing_scopes = []
|
missing_scopes = []
|
||||||
for scope in scopes:
|
for scope in scopes:
|
||||||
@@ -97,7 +99,7 @@ def require_api_key(
|
|||||||
)
|
)
|
||||||
|
|
||||||
rate_limiter = RateLimiterService()
|
rate_limiter = RateLimiterService()
|
||||||
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj)
|
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj, db=db)
|
||||||
if not is_allowed:
|
if not is_allowed:
|
||||||
logger.warning("API Key 限流触发", extra={
|
logger.warning("API Key 限流触发", extra={
|
||||||
"api_key_id": str(api_key_obj.id),
|
"api_key_id": str(api_key_obj.id),
|
||||||
@@ -106,10 +108,12 @@ def require_api_key(
|
|||||||
"error_msg": error_msg
|
"error_msg": error_msg
|
||||||
})
|
})
|
||||||
# 根据错误消息判断限流类型
|
# 根据错误消息判断限流类型
|
||||||
if "QPS" in error_msg:
|
if "Daily" in error_msg:
|
||||||
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
|
|
||||||
elif "Daily" in error_msg:
|
|
||||||
code = BizCode.API_KEY_DAILY_LIMIT_EXCEEDED
|
code = BizCode.API_KEY_DAILY_LIMIT_EXCEEDED
|
||||||
|
elif "Tenant" in error_msg:
|
||||||
|
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED # 租户套餐速率超限,同属 QPS 类
|
||||||
|
elif "QPS" in error_msg:
|
||||||
|
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
|
||||||
else:
|
else:
|
||||||
code = BizCode.API_KEY_QUOTA_EXCEEDED
|
code = BizCode.API_KEY_QUOTA_EXCEEDED
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,15 @@
|
|||||||
"""API Key 工具函数"""
|
"""API Key 工具函数"""
|
||||||
import secrets
|
import secrets
|
||||||
|
import uuid as _uuid
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session as _Session
|
||||||
|
from app.core.error_codes import BizCode as _BizCode
|
||||||
|
from app.core.exceptions import BusinessException as _BusinessException
|
||||||
|
from app.models.end_user_model import EndUser as _EndUser
|
||||||
|
from app.repositories.end_user_repository import EndUserRepository as _EndUserRepository
|
||||||
|
|
||||||
from app.models.api_key_model import ApiKeyType
|
from app.models.api_key_model import ApiKeyType
|
||||||
from fastapi import Response
|
from fastapi import Response
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
@@ -65,3 +72,72 @@ def datetime_to_timestamp(dt: Optional[datetime]) -> Optional[int]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
return int(dt.timestamp() * 1000)
|
return int(dt.timestamp() * 1000)
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user_from_api_key(db: _Session, api_key_auth):
|
||||||
|
"""通过 API Key 构造 current_user 对象。
|
||||||
|
|
||||||
|
从 API Key 反查创建者(管理员用户),并设置其 workspace 上下文。
|
||||||
|
与内部接口的 Depends(get_current_user) (JWT) 等价。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
api_key_auth: API Key 认证信息(ApiKeyAuth)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User ORM 对象,已设置 current_workspace_id
|
||||||
|
"""
|
||||||
|
from app.services import api_key_service
|
||||||
|
|
||||||
|
api_key = api_key_service.ApiKeyService.get_api_key(
|
||||||
|
db, api_key_auth.api_key_id, api_key_auth.workspace_id
|
||||||
|
)
|
||||||
|
current_user = api_key.creator
|
||||||
|
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
def validate_end_user_in_workspace(
|
||||||
|
db: _Session,
|
||||||
|
end_user_id: str,
|
||||||
|
workspace_id,
|
||||||
|
) -> _EndUser:
|
||||||
|
"""校验 end_user 是否存在且属于指定 workspace。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
end_user_id: 终端用户 ID
|
||||||
|
workspace_id: 工作空间 ID(UUID 或字符串均可)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EndUser ORM 对象(校验通过时)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BusinessException(INVALID_PARAMETER): end_user_id 格式无效
|
||||||
|
BusinessException(USER_NOT_FOUND): end_user 不存在
|
||||||
|
BusinessException(PERMISSION_DENIED): end_user 不属于该 workspace
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
_uuid.UUID(end_user_id)
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
raise _BusinessException(
|
||||||
|
f"Invalid end_user_id format: {end_user_id}",
|
||||||
|
_BizCode.INVALID_PARAMETER,
|
||||||
|
)
|
||||||
|
|
||||||
|
end_user_repo = _EndUserRepository(db)
|
||||||
|
end_user = end_user_repo.get_end_user_by_id(end_user_id)
|
||||||
|
|
||||||
|
if end_user is None:
|
||||||
|
raise _BusinessException(
|
||||||
|
"End user not found",
|
||||||
|
_BizCode.USER_NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
if str(end_user.workspace_id) != str(workspace_id):
|
||||||
|
raise _BusinessException(
|
||||||
|
"End user does not belong to this workspace",
|
||||||
|
_BizCode.PERMISSION_DENIED,
|
||||||
|
)
|
||||||
|
|
||||||
|
return end_user
|
||||||
@@ -97,6 +97,7 @@ class Settings:
|
|||||||
|
|
||||||
# File Upload
|
# File Upload
|
||||||
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
||||||
|
MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20"))
|
||||||
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
||||||
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
||||||
|
|
||||||
@@ -162,6 +163,44 @@ class Settings:
|
|||||||
# This controls the language used for memory summary titles and other generated content
|
# This controls the language used for memory summary titles and other generated content
|
||||||
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
|
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
|
||||||
|
|
||||||
|
# ========================================================================
|
||||||
|
# Internationalization (i18n) Configuration
|
||||||
|
# ========================================================================
|
||||||
|
# Default language for API responses
|
||||||
|
I18N_DEFAULT_LANGUAGE: str = os.getenv("I18N_DEFAULT_LANGUAGE", "zh")
|
||||||
|
|
||||||
|
# Supported languages (comma-separated)
|
||||||
|
I18N_SUPPORTED_LANGUAGES: list[str] = [
|
||||||
|
lang.strip()
|
||||||
|
for lang in os.getenv("I18N_SUPPORTED_LANGUAGES", "zh,en").split(",")
|
||||||
|
if lang.strip()
|
||||||
|
]
|
||||||
|
|
||||||
|
# Core locales directory (community edition)
|
||||||
|
# Use absolute path to work from any working directory
|
||||||
|
I18N_CORE_LOCALES_DIR: str = os.getenv(
|
||||||
|
"I18N_CORE_LOCALES_DIR",
|
||||||
|
os.path.join(os.path.dirname(os.path.dirname(__file__)), "locales")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Premium locales directory (enterprise edition, optional)
|
||||||
|
I18N_PREMIUM_LOCALES_DIR: Optional[str] = os.getenv("I18N_PREMIUM_LOCALES_DIR", None)
|
||||||
|
|
||||||
|
# Enable translation cache
|
||||||
|
I18N_ENABLE_TRANSLATION_CACHE: bool = os.getenv("I18N_ENABLE_TRANSLATION_CACHE", "true").lower() == "true"
|
||||||
|
|
||||||
|
# LRU cache size for hot translations
|
||||||
|
I18N_LRU_CACHE_SIZE: int = int(os.getenv("I18N_LRU_CACHE_SIZE", "1000"))
|
||||||
|
|
||||||
|
# Enable hot reload of translation files
|
||||||
|
I18N_ENABLE_HOT_RELOAD: bool = os.getenv("I18N_ENABLE_HOT_RELOAD", "false").lower() == "true"
|
||||||
|
|
||||||
|
# Fallback language when translation is missing
|
||||||
|
I18N_FALLBACK_LANGUAGE: str = os.getenv("I18N_FALLBACK_LANGUAGE", "zh")
|
||||||
|
|
||||||
|
# Log missing translations
|
||||||
|
I18N_LOG_MISSING_TRANSLATIONS: bool = os.getenv("I18N_LOG_MISSING_TRANSLATIONS", "true").lower() == "true"
|
||||||
|
|
||||||
# Logging settings
|
# Logging settings
|
||||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
||||||
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||||
@@ -192,8 +231,8 @@ class Settings:
|
|||||||
# Celery configuration (internal)
|
# Celery configuration (internal)
|
||||||
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
|
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
|
||||||
# 详见 docs/celery-env-bug-report.md
|
# 详见 docs/celery-env-bug-report.md
|
||||||
# 默认使用 Redis DB 3 (broker) 和 DB 4 (backend),与业务缓存 (DB 1/2) 隔离
|
# 默认使用 Redis 作为 broker 和 backend,与业务缓存隔离
|
||||||
# 多人共用同一 Redis 时,每位开发者应在 .env 中配置不同的 DB 编号避免任务互相干扰
|
# 如需使用 RabbitMQ,在 .env 中设置 CELERY_BROKER_URL=amqp://user:pass@host:5672/vhost
|
||||||
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3"))
|
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3"))
|
||||||
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4"))
|
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4"))
|
||||||
|
|
||||||
@@ -203,6 +242,8 @@ class Settings:
|
|||||||
SMTP_USER: str = os.getenv("SMTP_USER", "")
|
SMTP_USER: str = os.getenv("SMTP_USER", "")
|
||||||
SMTP_PASSWORD: str = os.getenv("SMTP_PASSWORD", "")
|
SMTP_PASSWORD: str = os.getenv("SMTP_PASSWORD", "")
|
||||||
|
|
||||||
|
SANDBOX_URL: str = os.getenv("SANDBOX_URL", "")
|
||||||
|
|
||||||
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
|
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
|
||||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||||
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||||
@@ -260,11 +301,11 @@ class Settings:
|
|||||||
# Prompt 中最大类型数量
|
# Prompt 中最大类型数量
|
||||||
MAX_ONTOLOGY_TYPES_IN_PROMPT: int = int(os.getenv("MAX_ONTOLOGY_TYPES_IN_PROMPT", "50"))
|
MAX_ONTOLOGY_TYPES_IN_PROMPT: int = int(os.getenv("MAX_ONTOLOGY_TYPES_IN_PROMPT", "50"))
|
||||||
|
|
||||||
# 核心通用类型列表(逗号分隔)
|
# 核心通用类型列表(逗号分隔)—— 与 ontology.md Entity Ontology 保持一致的 13 类
|
||||||
CORE_GENERAL_TYPES: str = os.getenv(
|
CORE_GENERAL_TYPES: str = os.getenv(
|
||||||
"CORE_GENERAL_TYPES",
|
"CORE_GENERAL_TYPES",
|
||||||
"Person,Organization,Company,GovernmentAgency,Place,Location,City,Country,Building,"
|
"人物,组织,群体,角色职业,地点设施,物品设备,软件平台,识别联系信息,"
|
||||||
"Event,SportsEvent,SocialEvent,Work,Book,Film,Software,Concept,TopicalConcept,AcademicSubject"
|
"文档媒体,知识能力,偏好习惯,具体目标,称呼别名"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 实验模式开关(允许通过 API 动态切换本体配置)
|
# 实验模式开关(允许通过 API 动态切换本体配置)
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ class BizCode(IntEnum):
|
|||||||
TENANT_NOT_FOUND = 3002
|
TENANT_NOT_FOUND = 3002
|
||||||
WORKSPACE_NO_ACCESS = 3003
|
WORKSPACE_NO_ACCESS = 3003
|
||||||
WORKSPACE_INVITE_NOT_FOUND = 3004
|
WORKSPACE_INVITE_NOT_FOUND = 3004
|
||||||
|
WORKSPACE_ACCESS_DENIED = 3005
|
||||||
# API Key 管理(3xxx)
|
# API Key 管理(3xxx)
|
||||||
API_KEY_NOT_FOUND = 3007
|
API_KEY_NOT_FOUND = 3007
|
||||||
API_KEY_DUPLICATE_NAME = 3008
|
API_KEY_DUPLICATE_NAME = 3008
|
||||||
@@ -30,6 +31,9 @@ class BizCode(IntEnum):
|
|||||||
API_KEY_QPS_LIMIT_EXCEEDED = 3014
|
API_KEY_QPS_LIMIT_EXCEEDED = 3014
|
||||||
API_KEY_DAILY_LIMIT_EXCEEDED = 3015
|
API_KEY_DAILY_LIMIT_EXCEEDED = 3015
|
||||||
API_KEY_QUOTA_EXCEEDED = 3016
|
API_KEY_QUOTA_EXCEEDED = 3016
|
||||||
|
API_KEY_RATE_LIMIT_EXCEEDED = 3017
|
||||||
|
QUOTA_EXCEEDED = 3018
|
||||||
|
RATE_LIMIT_EXCEEDED = 3019
|
||||||
# 资源(4xxx)
|
# 资源(4xxx)
|
||||||
NOT_FOUND = 4000
|
NOT_FOUND = 4000
|
||||||
USER_NOT_FOUND = 4001
|
USER_NOT_FOUND = 4001
|
||||||
@@ -40,6 +44,7 @@ class BizCode(IntEnum):
|
|||||||
FILE_NOT_FOUND = 4006
|
FILE_NOT_FOUND = 4006
|
||||||
APP_NOT_FOUND = 4007
|
APP_NOT_FOUND = 4007
|
||||||
RELEASE_NOT_FOUND = 4008
|
RELEASE_NOT_FOUND = 4008
|
||||||
|
USER_NO_ACCESS = 4009
|
||||||
|
|
||||||
# 冲突/状态(5xxx)
|
# 冲突/状态(5xxx)
|
||||||
DUPLICATE_NAME = 5001
|
DUPLICATE_NAME = 5001
|
||||||
@@ -61,6 +66,7 @@ class BizCode(IntEnum):
|
|||||||
PERMISSION_DENIED = 6010
|
PERMISSION_DENIED = 6010
|
||||||
INVALID_CONVERSATION = 6011
|
INVALID_CONVERSATION = 6011
|
||||||
CONFIG_MISSING = 6012
|
CONFIG_MISSING = 6012
|
||||||
|
APP_NOT_PUBLISHED = 6013
|
||||||
|
|
||||||
# 模型(7xxx)
|
# 模型(7xxx)
|
||||||
MODEL_CONFIG_INVALID = 7001
|
MODEL_CONFIG_INVALID = 7001
|
||||||
@@ -113,8 +119,11 @@ HTTP_MAPPING = {
|
|||||||
BizCode.FORBIDDEN: 403,
|
BizCode.FORBIDDEN: 403,
|
||||||
BizCode.TENANT_NOT_FOUND: 400,
|
BizCode.TENANT_NOT_FOUND: 400,
|
||||||
BizCode.WORKSPACE_NO_ACCESS: 403,
|
BizCode.WORKSPACE_NO_ACCESS: 403,
|
||||||
|
BizCode.WORKSPACE_INVITE_NOT_FOUND: 400,
|
||||||
|
BizCode.WORKSPACE_ACCESS_DENIED: 403,
|
||||||
BizCode.NOT_FOUND: 400,
|
BizCode.NOT_FOUND: 400,
|
||||||
BizCode.USER_NOT_FOUND: 200,
|
BizCode.USER_NOT_FOUND: 200,
|
||||||
|
BizCode.USER_NO_ACCESS: 401,
|
||||||
BizCode.WORKSPACE_NOT_FOUND: 400,
|
BizCode.WORKSPACE_NOT_FOUND: 400,
|
||||||
BizCode.MODEL_NOT_FOUND: 400,
|
BizCode.MODEL_NOT_FOUND: 400,
|
||||||
BizCode.KNOWLEDGE_NOT_FOUND: 400,
|
BizCode.KNOWLEDGE_NOT_FOUND: 400,
|
||||||
@@ -150,6 +159,7 @@ HTTP_MAPPING = {
|
|||||||
BizCode.API_KEY_QPS_LIMIT_EXCEEDED: 429,
|
BizCode.API_KEY_QPS_LIMIT_EXCEEDED: 429,
|
||||||
BizCode.API_KEY_DAILY_LIMIT_EXCEEDED: 429,
|
BizCode.API_KEY_DAILY_LIMIT_EXCEEDED: 429,
|
||||||
BizCode.API_KEY_QUOTA_EXCEEDED: 429,
|
BizCode.API_KEY_QUOTA_EXCEEDED: 429,
|
||||||
|
BizCode.QUOTA_EXCEEDED: 402,
|
||||||
|
|
||||||
BizCode.MODEL_CONFIG_INVALID: 400,
|
BizCode.MODEL_CONFIG_INVALID: 400,
|
||||||
BizCode.API_KEY_MISSING: 400,
|
BizCode.API_KEY_MISSING: 400,
|
||||||
@@ -179,4 +189,21 @@ HTTP_MAPPING = {
|
|||||||
BizCode.DB_ERROR: 500,
|
BizCode.DB_ERROR: 500,
|
||||||
BizCode.SERVICE_UNAVAILABLE: 503,
|
BizCode.SERVICE_UNAVAILABLE: 503,
|
||||||
BizCode.RATE_LIMITED: 429,
|
BizCode.RATE_LIMITED: 429,
|
||||||
|
BizCode.RATE_LIMIT_EXCEEDED: 429,
|
||||||
|
}
|
||||||
|
|
||||||
|
ERROR_CODE_TO_BIZ_CODE = {
|
||||||
|
"QUOTA_EXCEEDED": BizCode.QUOTA_EXCEEDED,
|
||||||
|
"RATE_LIMIT_EXCEEDED": BizCode.RATE_LIMIT_EXCEEDED,
|
||||||
|
"API_KEY_NOT_FOUND": BizCode.API_KEY_NOT_FOUND,
|
||||||
|
"API_KEY_INVALID": BizCode.API_KEY_INVALID,
|
||||||
|
"API_KEY_EXPIRED": BizCode.API_KEY_EXPIRED,
|
||||||
|
"WORKSPACE_NOT_FOUND": BizCode.WORKSPACE_NOT_FOUND,
|
||||||
|
"WORKSPACE_NO_ACCESS": BizCode.WORKSPACE_NO_ACCESS,
|
||||||
|
"PERMISSION_DENIED": BizCode.PERMISSION_DENIED,
|
||||||
|
"TOKEN_EXPIRED": BizCode.TOKEN_EXPIRED,
|
||||||
|
"TOKEN_INVALID": BizCode.TOKEN_INVALID,
|
||||||
|
"VALIDATION_FAILED": BizCode.VALIDATION_FAILED,
|
||||||
|
"INVALID_PARAMETER": BizCode.INVALID_PARAMETER,
|
||||||
|
"MISSING_PARAMETER": BizCode.MISSING_PARAMETER,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,6 +46,10 @@ def validate_language(language: Optional[str]) -> str:
|
|||||||
if language is None:
|
if language is None:
|
||||||
return DEFAULT_LANGUAGE
|
return DEFAULT_LANGUAGE
|
||||||
|
|
||||||
|
# 处理枚举类型:优先取 .value,避免 str(Language.ZH) → "Language.ZH"
|
||||||
|
if hasattr(language, "value"):
|
||||||
|
language = language.value
|
||||||
|
|
||||||
# 标准化:转小写并去除空白
|
# 标准化:转小写并去除空白
|
||||||
lang = str(language).lower().strip()
|
lang = str(language).lower().strip()
|
||||||
|
|
||||||
|
|||||||
@@ -131,6 +131,10 @@ class LoggingConfig:
|
|||||||
neo4j_logger = logging.getLogger(neo4j_logger_name)
|
neo4j_logger = logging.getLogger(neo4j_logger_name)
|
||||||
neo4j_logger.addFilter(neo4j_filter)
|
neo4j_logger.addFilter(neo4j_filter)
|
||||||
|
|
||||||
|
# 压制 httpx / httpcore 的请求级日志(大量 HTTP Request: POST ... 噪音)
|
||||||
|
for noisy_logger in ["httpx", "httpcore", "httpcore.http11", "httpcore.connection"]:
|
||||||
|
logging.getLogger(noisy_logger).setLevel(logging.WARNING)
|
||||||
|
|
||||||
# 创建格式化器
|
# 创建格式化器
|
||||||
formatter = logging.Formatter(
|
formatter = logging.Formatter(
|
||||||
fmt=settings.LOG_FORMAT,
|
fmt=settings.LOG_FORMAT,
|
||||||
@@ -529,8 +533,9 @@ def log_time(step_name: str, duration: float, log_file: str = "logs/time.log") -
|
|||||||
# Fallback to console only if file write fails
|
# Fallback to console only if file write fails
|
||||||
print(f"Warning: Could not write to timing log: {e}")
|
print(f"Warning: Could not write to timing log: {e}")
|
||||||
|
|
||||||
# Always print to console (backward compatible behavior)
|
# Always log at INFO level (avoids Celery treating stdout as WARNING)
|
||||||
print(f"✓ {step_name}: {duration:.2f}s")
|
_timing_logger = logging.getLogger(__name__)
|
||||||
|
_timing_logger.info(f"✓ {step_name}: {duration:.2f}s")
|
||||||
|
|
||||||
|
|
||||||
def get_agent_logger(name: str = "agent_service",
|
def get_agent_logger(name: str = "agent_service",
|
||||||
|
|||||||
@@ -1,16 +1,45 @@
|
|||||||
from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
|
from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
|
||||||
|
from app.schemas.memory_agent_schema import AgentMemoryDataset
|
||||||
|
|
||||||
|
|
||||||
def content_input_node(state: ReadState) -> ReadState:
|
def content_input_node(state: ReadState) -> ReadState:
|
||||||
"""开始节点 - 提取内容并保持状态信息"""
|
"""
|
||||||
|
Start node - Extract content and maintain state information
|
||||||
|
|
||||||
|
Extracts the content from the first message in the state and returns it
|
||||||
|
as the data field while preserving all other state information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing messages and other state data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Updated state with extracted content in data field
|
||||||
|
"""
|
||||||
|
|
||||||
content = state['messages'][0].content if state.get('messages') else ''
|
content = state['messages'][0].content if state.get('messages') else ''
|
||||||
# 返回内容并保持所有状态信息
|
# Return content and maintain all state information
|
||||||
|
for pronoun in AgentMemoryDataset.PRONOUN:
|
||||||
|
content = content.replace(pronoun, AgentMemoryDataset.NAME)
|
||||||
|
|
||||||
return {"data": content}
|
return {"data": content}
|
||||||
|
|
||||||
|
|
||||||
def content_input_write(state: WriteState) -> WriteState:
|
def content_input_write(state: WriteState) -> WriteState:
|
||||||
"""开始节点 - 提取内容并保持状态信息"""
|
"""
|
||||||
|
Start node - Extract content and maintain state information for write operations
|
||||||
|
|
||||||
|
Extracts the content from the first message in the state for write operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: WriteState containing messages and other state data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
WriteState: Updated state with extracted content in data field
|
||||||
|
"""
|
||||||
|
|
||||||
content = state['messages'][0].content if state.get('messages') else ''
|
content = state['messages'][0].content if state.get('messages') else ''
|
||||||
# 返回内容并保持所有状态信息
|
# Return content and maintain all state information
|
||||||
|
for pronoun in AgentMemoryDataset.PRONOUN:
|
||||||
|
content = content.replace(pronoun, AgentMemoryDataset.NAME)
|
||||||
|
|
||||||
return {"data": content}
|
return {"data": content}
|
||||||
@@ -0,0 +1,408 @@
|
|||||||
|
"""
|
||||||
|
Perceptual Memory Retrieval Node & Service
|
||||||
|
|
||||||
|
Provides PerceptualSearchService for searching perceptual memories (vision, audio,
|
||||||
|
text, conversation) from Neo4j using keyword fulltext + embedding semantic search
|
||||||
|
with BM25+embedding fusion reranking.
|
||||||
|
|
||||||
|
Also provides the perceptual_retrieve_node for use as a LangGraph node.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import math
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
|
from app.core.memory.agent.utils.llm_tools import ReadState
|
||||||
|
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||||
|
from app.repositories.neo4j.graph_search import (
|
||||||
|
search_perceptual_by_fulltext,
|
||||||
|
search_perceptual_by_embedding,
|
||||||
|
)
|
||||||
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PerceptualSearchService:
|
||||||
|
"""
|
||||||
|
感知记忆检索服务。
|
||||||
|
|
||||||
|
封装关键词全文检索 + 向量语义检索 + BM25/embedding 融合排序的完整流程。
|
||||||
|
调用方只需提供 query / keywords、end_user_id、memory_config,即可获得
|
||||||
|
格式化并排序后的感知记忆列表和拼接文本。
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
service = PerceptualSearchService(end_user_id=..., memory_config=...)
|
||||||
|
results = await service.search(query="...", keywords=[...], limit=10)
|
||||||
|
# results = {"memories": [...], "content": "...", "keyword_raw": N, "embedding_raw": M}
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_ALPHA = 0.6
|
||||||
|
DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
end_user_id: str,
|
||||||
|
memory_config: Any,
|
||||||
|
alpha: float = DEFAULT_ALPHA,
|
||||||
|
content_score_threshold: float = DEFAULT_CONTENT_SCORE_THRESHOLD,
|
||||||
|
):
|
||||||
|
self.end_user_id = end_user_id
|
||||||
|
self.memory_config = memory_config
|
||||||
|
self.alpha = alpha
|
||||||
|
self.content_score_threshold = content_score_threshold
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
keywords: Optional[List[str]] = None,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
执行感知记忆检索(关键词 + 向量并行),融合排序后返回结果。
|
||||||
|
|
||||||
|
对 embedding 命中但 keyword 未命中的结果,补查全文索引获取 BM25 分数,
|
||||||
|
确保所有结果都同时具备 BM25 和 embedding 两个维度的评分。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 原始用户查询(用于向量检索和 BM25 补查)
|
||||||
|
keywords: 关键词列表(用于全文检索),为 None 时使用 [query]
|
||||||
|
limit: 最大返回数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{
|
||||||
|
"memories": [格式化后的记忆 dict, ...],
|
||||||
|
"content": "拼接的纯文本摘要",
|
||||||
|
"keyword_raw": int,
|
||||||
|
"embedding_raw": int,
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
if keywords is None:
|
||||||
|
keywords = [query] if query else []
|
||||||
|
|
||||||
|
connector = Neo4jConnector()
|
||||||
|
try:
|
||||||
|
kw_task = self._keyword_search(connector, keywords, limit)
|
||||||
|
emb_task = self._embedding_search(connector, query, limit)
|
||||||
|
|
||||||
|
kw_results, emb_results = await asyncio.gather(
|
||||||
|
kw_task, emb_task, return_exceptions=True
|
||||||
|
)
|
||||||
|
if isinstance(kw_results, Exception):
|
||||||
|
logger.warning(f"[PerceptualSearch] keyword search error: {kw_results}")
|
||||||
|
kw_results = []
|
||||||
|
if isinstance(emb_results, Exception):
|
||||||
|
logger.warning(f"[PerceptualSearch] embedding search error: {emb_results}")
|
||||||
|
emb_results = []
|
||||||
|
|
||||||
|
# 补查 BM25:找出 embedding 命中但 keyword 未命中的 id,
|
||||||
|
# 用原始 query 对这些节点补查全文索引拿 BM25 score
|
||||||
|
kw_ids = {r.get("id") for r in kw_results if r.get("id")}
|
||||||
|
emb_only_ids = {r.get("id") for r in emb_results if r.get("id") and r.get("id") not in kw_ids}
|
||||||
|
|
||||||
|
if emb_only_ids and query:
|
||||||
|
backfill = await self._bm25_backfill(connector, query, emb_only_ids, limit)
|
||||||
|
# 把补查到的 BM25 score 注入到 embedding 结果中
|
||||||
|
backfill_map = {r["id"]: r.get("score", 0) for r in backfill}
|
||||||
|
for r in emb_results:
|
||||||
|
rid = r.get("id", "")
|
||||||
|
if rid in backfill_map:
|
||||||
|
r["bm25_backfill_score"] = backfill_map[rid]
|
||||||
|
logger.info(
|
||||||
|
f"[PerceptualSearch] BM25 backfill: {len(emb_only_ids)} embedding-only ids, "
|
||||||
|
f"{len(backfill_map)} got BM25 scores"
|
||||||
|
)
|
||||||
|
|
||||||
|
reranked = self._rerank(kw_results, emb_results, limit)
|
||||||
|
|
||||||
|
memories = []
|
||||||
|
content_parts = []
|
||||||
|
for record in reranked:
|
||||||
|
fmt = self._format_result(record)
|
||||||
|
fmt["score"] = round(record.get("content_score", 0), 4)
|
||||||
|
memories.append(fmt)
|
||||||
|
content_parts.append(self._build_content_text(fmt))
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[PerceptualSearch] {len(memories)} results after rerank "
|
||||||
|
f"(keyword_raw={len(kw_results)}, embedding_raw={len(emb_results)})"
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"memories": memories,
|
||||||
|
"content": "\n\n".join(content_parts),
|
||||||
|
"keyword_raw": len(kw_results),
|
||||||
|
"embedding_raw": len(emb_results),
|
||||||
|
}
|
||||||
|
finally:
|
||||||
|
await connector.close()
|
||||||
|
|
||||||
|
async def _bm25_backfill(
|
||||||
|
self,
|
||||||
|
connector: Neo4jConnector,
|
||||||
|
query: str,
|
||||||
|
target_ids: set,
|
||||||
|
limit: int,
|
||||||
|
) -> List[dict]:
|
||||||
|
"""
|
||||||
|
对指定 id 集合补查全文索引 BM25 score。
|
||||||
|
|
||||||
|
用原始 query 查全文索引,只保留 id 在 target_ids 中的结果。
|
||||||
|
"""
|
||||||
|
escaped = escape_lucene_query(query)
|
||||||
|
if not escaped.strip():
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
r = await search_perceptual_by_fulltext(
|
||||||
|
connector=connector, query=escaped,
|
||||||
|
end_user_id=self.end_user_id,
|
||||||
|
limit=limit * 5, # 多查一些以提高命中率
|
||||||
|
)
|
||||||
|
all_hits = r.get("perceptuals", [])
|
||||||
|
return [h for h in all_hits if h.get("id") in target_ids]
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[PerceptualSearch] BM25 backfill failed: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _keyword_search(
|
||||||
|
self,
|
||||||
|
connector: Neo4jConnector,
|
||||||
|
keywords: List[str],
|
||||||
|
limit: int,
|
||||||
|
) -> List[dict]:
|
||||||
|
"""并发对每个关键词做全文检索,去重后按 score 降序返回 top N 原始结果。"""
|
||||||
|
seen_ids: set = set()
|
||||||
|
all_results: List[dict] = []
|
||||||
|
|
||||||
|
async def _one(kw: str):
|
||||||
|
escaped = escape_lucene_query(kw)
|
||||||
|
if not escaped.strip():
|
||||||
|
return []
|
||||||
|
r = await search_perceptual_by_fulltext(
|
||||||
|
connector=connector, query=escaped,
|
||||||
|
end_user_id=self.end_user_id, limit=limit,
|
||||||
|
)
|
||||||
|
return r.get("perceptuals", [])
|
||||||
|
|
||||||
|
tasks = [_one(kw) for kw in keywords[:10]]
|
||||||
|
batch = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
for result in batch:
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
logger.warning(f"[PerceptualSearch] keyword sub-query error: {result}")
|
||||||
|
continue
|
||||||
|
for rec in result:
|
||||||
|
rid = rec.get("id", "")
|
||||||
|
if rid and rid not in seen_ids:
|
||||||
|
seen_ids.add(rid)
|
||||||
|
all_results.append(rec)
|
||||||
|
|
||||||
|
all_results.sort(key=lambda x: float(x.get("score", 0)), reverse=True)
|
||||||
|
return all_results[:limit]
|
||||||
|
|
||||||
|
async def _embedding_search(
|
||||||
|
self,
|
||||||
|
connector: Neo4jConnector,
|
||||||
|
query_text: str,
|
||||||
|
limit: int,
|
||||||
|
) -> List[dict]:
|
||||||
|
"""向量语义检索,返回原始结果(不做阈值过滤)。"""
|
||||||
|
try:
|
||||||
|
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||||
|
from app.core.models.base import RedBearModelConfig
|
||||||
|
from app.db import get_db_context
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
|
with get_db_context() as db:
|
||||||
|
cfg = MemoryConfigService(db).get_embedder_config(
|
||||||
|
str(self.memory_config.embedding_model_id)
|
||||||
|
)
|
||||||
|
client = OpenAIEmbedderClient(RedBearModelConfig(**cfg))
|
||||||
|
|
||||||
|
r = await search_perceptual_by_embedding(
|
||||||
|
connector=connector, embedder_client=client,
|
||||||
|
query_text=query_text, end_user_id=self.end_user_id,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
return r.get("perceptuals", [])
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[PerceptualSearch] embedding search failed: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _rerank(
|
||||||
|
self,
|
||||||
|
keyword_results: List[dict],
|
||||||
|
embedding_results: List[dict],
|
||||||
|
limit: int,
|
||||||
|
) -> List[dict]:
|
||||||
|
"""BM25 + embedding 融合排序。
|
||||||
|
|
||||||
|
对 embedding 结果中带有 bm25_backfill_score 的条目,
|
||||||
|
将其与 keyword 结果合并后统一归一化,确保 BM25 分数在同一尺度上。
|
||||||
|
"""
|
||||||
|
# 把补查的 BM25 score 合并到 keyword_results 中统一归一化
|
||||||
|
emb_backfill_items = []
|
||||||
|
for item in embedding_results:
|
||||||
|
backfill_score = item.get("bm25_backfill_score")
|
||||||
|
if backfill_score is not None and item.get("id"):
|
||||||
|
emb_backfill_items.append({"id": item["id"], "score": backfill_score})
|
||||||
|
|
||||||
|
# 合并后统一归一化 BM25 scores
|
||||||
|
all_bm25_items = keyword_results + emb_backfill_items
|
||||||
|
all_bm25_items = self._normalize_scores(all_bm25_items)
|
||||||
|
|
||||||
|
# 建立 id -> normalized BM25 score 的映射
|
||||||
|
bm25_norm_map: Dict[str, float] = {}
|
||||||
|
for item in all_bm25_items:
|
||||||
|
item_id = item.get("id", "")
|
||||||
|
if item_id:
|
||||||
|
bm25_norm_map[item_id] = float(item.get("normalized_score", 0))
|
||||||
|
|
||||||
|
# 归一化 embedding scores
|
||||||
|
embedding_results = self._normalize_scores(embedding_results)
|
||||||
|
|
||||||
|
# 合并
|
||||||
|
combined: Dict[str, dict] = {}
|
||||||
|
for item in keyword_results:
|
||||||
|
item_id = item.get("id", "")
|
||||||
|
if not item_id:
|
||||||
|
continue
|
||||||
|
combined[item_id] = item.copy()
|
||||||
|
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
|
||||||
|
combined[item_id]["embedding_score"] = 0.0
|
||||||
|
|
||||||
|
for item in embedding_results:
|
||||||
|
item_id = item.get("id", "")
|
||||||
|
if not item_id:
|
||||||
|
continue
|
||||||
|
if item_id in combined:
|
||||||
|
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||||
|
else:
|
||||||
|
combined[item_id] = item.copy()
|
||||||
|
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
|
||||||
|
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||||
|
|
||||||
|
for item in combined.values():
|
||||||
|
bm25 = float(item.get("bm25_score", 0) or 0)
|
||||||
|
emb = float(item.get("embedding_score", 0) or 0)
|
||||||
|
item["content_score"] = self.alpha * bm25 + (1 - self.alpha) * emb
|
||||||
|
|
||||||
|
results = list(combined.values())
|
||||||
|
before = len(results)
|
||||||
|
results = [r for r in results if r["content_score"] >= self.content_score_threshold]
|
||||||
|
results.sort(key=lambda x: x["content_score"], reverse=True)
|
||||||
|
results = results[:limit]
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[PerceptualSearch] rerank: merged={before}, after_threshold={len(results)} "
|
||||||
|
f"(alpha={self.alpha}, threshold={self.content_score_threshold})"
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_scores(items: List[dict], field: str = "score") -> List[dict]:
|
||||||
|
"""Z-score + sigmoid 归一化。"""
|
||||||
|
if not items:
|
||||||
|
return items
|
||||||
|
scores = [float(it.get(field, 0) or 0) for it in items]
|
||||||
|
if len(scores) <= 1:
|
||||||
|
for it in items:
|
||||||
|
it[f"normalized_{field}"] = 1.0
|
||||||
|
return items
|
||||||
|
mean = sum(scores) / len(scores)
|
||||||
|
var = sum((s - mean) ** 2 for s in scores) / len(scores)
|
||||||
|
std = math.sqrt(var)
|
||||||
|
if std == 0:
|
||||||
|
for it in items:
|
||||||
|
it[f"normalized_{field}"] = 1.0
|
||||||
|
else:
|
||||||
|
for it, s in zip(items, scores):
|
||||||
|
z = (s - mean) / std
|
||||||
|
it[f"normalized_{field}"] = 1 / (1 + math.exp(-z))
|
||||||
|
return items
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _format_result(record: dict) -> dict:
|
||||||
|
return {
|
||||||
|
"id": record.get("id", ""),
|
||||||
|
"perceptual_type": record.get("perceptual_type", ""),
|
||||||
|
"file_name": record.get("file_name", ""),
|
||||||
|
"file_path": record.get("file_path", ""),
|
||||||
|
"summary": record.get("summary", ""),
|
||||||
|
"topic": record.get("topic", ""),
|
||||||
|
"domain": record.get("domain", ""),
|
||||||
|
"keywords": record.get("keywords", []),
|
||||||
|
"created_at": str(record.get("created_at", "")),
|
||||||
|
"file_type": record.get("file_type", ""),
|
||||||
|
"score": record.get("score", 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_content_text(formatted: dict) -> str:
|
||||||
|
parts = []
|
||||||
|
if formatted["summary"]:
|
||||||
|
parts.append(formatted["summary"])
|
||||||
|
if formatted["topic"]:
|
||||||
|
parts.append(f"[主题: {formatted['topic']}]")
|
||||||
|
if formatted["keywords"]:
|
||||||
|
kw_list = formatted["keywords"]
|
||||||
|
if isinstance(kw_list, list):
|
||||||
|
parts.append(f"[关键词: {', '.join(kw_list)}]")
|
||||||
|
if formatted["file_name"]:
|
||||||
|
parts.append(f"[文件: {formatted['file_name']}]")
|
||||||
|
return " ".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_keywords_from_problems(problem_extension: dict) -> List[str]:
|
||||||
|
"""Extract search keywords from problem extension results."""
|
||||||
|
keywords = []
|
||||||
|
context = problem_extension.get("context", {})
|
||||||
|
if isinstance(context, dict):
|
||||||
|
for original_q, extended_qs in context.items():
|
||||||
|
keywords.append(original_q)
|
||||||
|
if isinstance(extended_qs, list):
|
||||||
|
keywords.extend(extended_qs)
|
||||||
|
return keywords
|
||||||
|
|
||||||
|
|
||||||
|
async def perceptual_retrieve_node(state: ReadState) -> ReadState:
|
||||||
|
"""
|
||||||
|
LangGraph node: perceptual memory retrieval.
|
||||||
|
|
||||||
|
Uses PerceptualSearchService to run keyword + embedding search with
|
||||||
|
BM25 fusion reranking, then writes results to state['perceptual_data'].
|
||||||
|
"""
|
||||||
|
end_user_id = state.get("end_user_id", "")
|
||||||
|
problem_extension = state.get("problem_extension", {})
|
||||||
|
original_query = state.get("data", "")
|
||||||
|
memory_config = state.get("memory_config", None)
|
||||||
|
|
||||||
|
logger.info(f"Perceptual_Retrieve: start, end_user_id={end_user_id}")
|
||||||
|
|
||||||
|
keywords = _extract_keywords_from_problems(problem_extension)
|
||||||
|
if not keywords:
|
||||||
|
keywords = [original_query] if original_query else []
|
||||||
|
|
||||||
|
logger.info(f"Perceptual_Retrieve: {len(keywords)} keywords extracted")
|
||||||
|
|
||||||
|
service = PerceptualSearchService(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
memory_config=memory_config,
|
||||||
|
)
|
||||||
|
search_result = await service.search(
|
||||||
|
query=original_query,
|
||||||
|
keywords=keywords,
|
||||||
|
limit=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"memories": search_result["memories"],
|
||||||
|
"content": search_result["content"],
|
||||||
|
"_intermediate": {
|
||||||
|
"type": "perceptual_retrieve",
|
||||||
|
"title": "感知记忆检索",
|
||||||
|
"data": search_result["memories"],
|
||||||
|
"query": original_query,
|
||||||
|
"result_count": len(search_result["memories"]),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return {"perceptual_data": result}
|
||||||
@@ -19,19 +19,39 @@ logger = get_agent_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class ProblemNodeService(LLMServiceMixin):
|
class ProblemNodeService(LLMServiceMixin):
|
||||||
"""问题处理节点服务类"""
|
"""
|
||||||
|
Problem processing node service class
|
||||||
|
|
||||||
|
Handles problem decomposition and extension operations using LLM services.
|
||||||
|
Inherits from LLMServiceMixin to provide structured LLM calling capabilities.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
template_service: Service for rendering Jinja2 templates
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.template_service = TemplateService(template_root)
|
self.template_service = TemplateService(template_root)
|
||||||
|
|
||||||
|
|
||||||
# 创建全局服务实例
|
# Create global service instance
|
||||||
problem_service = ProblemNodeService()
|
problem_service = ProblemNodeService()
|
||||||
|
|
||||||
|
|
||||||
async def Split_The_Problem(state: ReadState) -> ReadState:
|
async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||||
"""问题分解节点"""
|
"""
|
||||||
|
Problem decomposition node
|
||||||
|
|
||||||
|
Breaks down complex user queries into smaller, more manageable sub-problems.
|
||||||
|
Uses LLM to analyze the input and generate structured problem decomposition
|
||||||
|
with question types and reasoning.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing user input and configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Updated state with problem decomposition results
|
||||||
|
"""
|
||||||
# 从状态中获取数据
|
# 从状态中获取数据
|
||||||
content = state.get('data', '')
|
content = state.get('data', '')
|
||||||
end_user_id = state.get('end_user_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
@@ -64,7 +84,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
# 添加更详细的日志记录
|
# 添加更详细的日志记录
|
||||||
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
||||||
|
|
||||||
# 验证结构化响应
|
# Validate structured response
|
||||||
if not structured or not hasattr(structured, 'root'):
|
if not structured or not hasattr(structured, 'root'):
|
||||||
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
|
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
|
||||||
split_result = json.dumps([], ensure_ascii=False)
|
split_result = json.dumps([], ensure_ascii=False)
|
||||||
@@ -106,7 +126,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
exc_info=True
|
exc_info=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# 提供更详细的错误信息
|
# Provide more detailed error information
|
||||||
error_details = {
|
error_details = {
|
||||||
"error_type": type(e).__name__,
|
"error_type": type(e).__name__,
|
||||||
"error_message": str(e),
|
"error_message": str(e),
|
||||||
@@ -116,7 +136,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
logger.error(f"Split_The_Problem error details: {error_details}")
|
logger.error(f"Split_The_Problem error details: {error_details}")
|
||||||
|
|
||||||
# 创建默认的空结果
|
# Create default empty result
|
||||||
result = {
|
result = {
|
||||||
"context": json.dumps([], ensure_ascii=False),
|
"context": json.dumps([], ensure_ascii=False),
|
||||||
"original": content,
|
"original": content,
|
||||||
@@ -130,13 +150,25 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# 返回更新后的状态,包含spit_context字段
|
# Return updated state including spit_context field
|
||||||
return {"spit_data": result}
|
return {"spit_data": result}
|
||||||
|
|
||||||
|
|
||||||
async def Problem_Extension(state: ReadState) -> ReadState:
|
async def Problem_Extension(state: ReadState) -> ReadState:
|
||||||
"""问题扩展节点"""
|
"""
|
||||||
# 获取原始数据和分解结果
|
Problem extension node
|
||||||
|
|
||||||
|
Extends the decomposed problems from Split_The_Problem node by generating
|
||||||
|
additional related questions and organizing them by original question.
|
||||||
|
Uses LLM to create comprehensive question extensions for better memory retrieval.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing decomposed problems and configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Updated state with extended problem results
|
||||||
|
"""
|
||||||
|
# Get original data and decomposition results
|
||||||
start = time.time()
|
start = time.time()
|
||||||
content = state.get('data', '')
|
content = state.get('data', '')
|
||||||
data = state.get('spit_data', '')['context']
|
data = state.get('spit_data', '')['context']
|
||||||
@@ -182,7 +214,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
|
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
|
||||||
|
|
||||||
# 验证结构化响应
|
# Validate structured response
|
||||||
if not response_content or not hasattr(response_content, 'root'):
|
if not response_content or not hasattr(response_content, 'root'):
|
||||||
logger.warning("Problem_Extension: 结构化响应为空或格式不正确")
|
logger.warning("Problem_Extension: 结构化响应为空或格式不正确")
|
||||||
aggregated_dict = {}
|
aggregated_dict = {}
|
||||||
@@ -216,7 +248,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
exc_info=True
|
exc_info=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# 提供更详细的错误信息
|
# Provide more detailed error information
|
||||||
error_details = {
|
error_details = {
|
||||||
"error_type": type(e).__name__,
|
"error_type": type(e).__name__,
|
||||||
"error_message": str(e),
|
"error_message": str(e),
|
||||||
@@ -231,7 +263,6 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
logger.info(f"Problem extension result: {aggregated_dict}")
|
logger.info(f"Problem extension result: {aggregated_dict}")
|
||||||
|
|
||||||
# Emit intermediate output for frontend
|
# Emit intermediate output for frontend
|
||||||
print(time.time() - start)
|
|
||||||
result = {
|
result = {
|
||||||
"context": aggregated_dict,
|
"context": aggregated_dict,
|
||||||
"original": data,
|
"original": data,
|
||||||
|
|||||||
@@ -29,6 +29,18 @@ logger = get_agent_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
async def rag_config(state):
|
async def rag_config(state):
|
||||||
|
"""
|
||||||
|
Configure RAG (Retrieval-Augmented Generation) settings
|
||||||
|
|
||||||
|
Creates configuration for knowledge base retrieval including similarity thresholds,
|
||||||
|
weights, and reranker settings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: Current state containing user_rag_memory_id
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: RAG configuration dictionary
|
||||||
|
"""
|
||||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
kb_config = {
|
kb_config = {
|
||||||
"knowledge_bases": [
|
"knowledge_bases": [
|
||||||
@@ -48,6 +60,19 @@ async def rag_config(state):
|
|||||||
|
|
||||||
|
|
||||||
async def rag_knowledge(state, question):
|
async def rag_knowledge(state, question):
|
||||||
|
"""
|
||||||
|
Retrieve knowledge using RAG approach
|
||||||
|
|
||||||
|
Performs knowledge retrieval from configured knowledge bases using the
|
||||||
|
provided question and returns formatted results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: Current state containing configuration
|
||||||
|
question: Question to search for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
|
||||||
|
"""
|
||||||
kb_config = await rag_config(state)
|
kb_config = await rag_config(state)
|
||||||
end_user_id = state.get('end_user_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
@@ -68,12 +93,24 @@ async def rag_knowledge(state, question):
|
|||||||
|
|
||||||
|
|
||||||
async def llm_infomation(state: ReadState) -> ReadState:
|
async def llm_infomation(state: ReadState) -> ReadState:
|
||||||
|
"""
|
||||||
|
Get LLM configuration information from state
|
||||||
|
|
||||||
|
Retrieves model configuration details including model ID and tenant ID
|
||||||
|
from the memory configuration in the current state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing memory configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Model configuration as Pydantic model
|
||||||
|
"""
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
model_id = memory_config.llm_model_id
|
model_id = memory_config.llm_model_id
|
||||||
tenant_id = memory_config.tenant_id
|
tenant_id = memory_config.tenant_id
|
||||||
|
|
||||||
# 使用现有的 memory_config 而不是重新查询数据库
|
# Use existing memory_config instead of re-querying database
|
||||||
# 或者使用线程安全的数据库访问
|
# or use thread-safe database access
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id)
|
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id)
|
||||||
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
|
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
|
||||||
@@ -82,16 +119,20 @@ async def llm_infomation(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
async def clean_databases(data) -> str:
|
async def clean_databases(data) -> str:
|
||||||
"""
|
"""
|
||||||
简化的数据库搜索结果清理函数
|
Simplified database search result cleaning function
|
||||||
|
|
||||||
|
Processes and cleans search results from various sources including
|
||||||
|
reranked results and time-based search results. Extracts text content
|
||||||
|
from structured data and returns as formatted string.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: 搜索结果数据
|
data: Search result data (can be string, dict, or other types)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
清理后的内容字符串
|
str: Cleaned content string
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 解析JSON字符串
|
# Parse JSON string
|
||||||
if isinstance(data, str):
|
if isinstance(data, str):
|
||||||
try:
|
try:
|
||||||
data = json.loads(data)
|
data = json.loads(data)
|
||||||
@@ -101,24 +142,24 @@ async def clean_databases(data) -> str:
|
|||||||
if not isinstance(data, dict):
|
if not isinstance(data, dict):
|
||||||
return str(data)
|
return str(data)
|
||||||
|
|
||||||
# 获取结果数据
|
# Get result data
|
||||||
# with open("搜索结果.json","w",encoding='utf-8') as f:
|
# with open("搜索结果.json","w",encoding='utf-8') as f:
|
||||||
# f.write(json.dumps(data, indent=4, ensure_ascii=False))
|
# f.write(json.dumps(data, indent=4, ensure_ascii=False))
|
||||||
results = data.get('results', data)
|
results = data.get('results', data)
|
||||||
if not isinstance(results, dict):
|
if not isinstance(results, dict):
|
||||||
return str(results)
|
return str(results)
|
||||||
|
|
||||||
# 收集所有内容
|
# Collect all content
|
||||||
content_list = []
|
content_list = []
|
||||||
|
|
||||||
# 处理重排序结果
|
# Process reranked results
|
||||||
reranked = results.get('reranked_results', {})
|
reranked = results.get('reranked_results', {})
|
||||||
if reranked:
|
if reranked:
|
||||||
for category in ['summaries', 'statements', 'chunks', 'entities']:
|
for category in ['summaries', 'communities', 'statements', 'chunks', 'entities']:
|
||||||
items = reranked.get(category, [])
|
items = reranked.get(category, [])
|
||||||
if isinstance(items, list):
|
if isinstance(items, list):
|
||||||
content_list.extend(items)
|
content_list.extend(items)
|
||||||
# 处理时间搜索结果
|
# Process time search results
|
||||||
time_search = results.get('time_search', {})
|
time_search = results.get('time_search', {})
|
||||||
if time_search:
|
if time_search:
|
||||||
if isinstance(time_search, dict):
|
if isinstance(time_search, dict):
|
||||||
@@ -128,11 +169,18 @@ async def clean_databases(data) -> str:
|
|||||||
elif isinstance(time_search, list):
|
elif isinstance(time_search, list):
|
||||||
content_list.extend(time_search)
|
content_list.extend(time_search)
|
||||||
|
|
||||||
# 提取文本内容
|
# Extract text content,对 community 按 name 去重(多次 tool 调用会产生重复)
|
||||||
text_parts = []
|
text_parts = []
|
||||||
|
seen_community_names = set()
|
||||||
for item in content_list:
|
for item in content_list:
|
||||||
if isinstance(item, dict):
|
if isinstance(item, dict):
|
||||||
text = item.get('statement') or item.get('content', '')
|
# community 节点用 name 去重
|
||||||
|
if 'member_count' in item or 'core_entities' in item:
|
||||||
|
community_name = item.get('name') or item.get('id', '')
|
||||||
|
if community_name in seen_community_names:
|
||||||
|
continue
|
||||||
|
seen_community_names.add(community_name)
|
||||||
|
text = item.get('statement') or item.get('content') or item.get('summary', '')
|
||||||
if text:
|
if text:
|
||||||
text_parts.append(text)
|
text_parts.append(text)
|
||||||
elif isinstance(item, str):
|
elif isinstance(item, str):
|
||||||
@@ -146,10 +194,19 @@ async def clean_databases(data) -> str:
|
|||||||
|
|
||||||
|
|
||||||
async def retrieve_nodes(state: ReadState) -> ReadState:
|
async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||||
'''
|
"""
|
||||||
|
Retrieve information using simplified search approach
|
||||||
|
|
||||||
模型信息
|
Processes extended problems from previous nodes and performs retrieval
|
||||||
'''
|
using either RAG or hybrid search based on storage type. Handles concurrent
|
||||||
|
processing of multiple questions and deduplicates results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing problem extensions and configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Updated state with retrieval results and intermediate outputs
|
||||||
|
"""
|
||||||
|
|
||||||
problem_extension = state.get('problem_extension', '')['context']
|
problem_extension = state.get('problem_extension', '')['context']
|
||||||
storage_type = state.get('storage_type', '')
|
storage_type = state.get('storage_type', '')
|
||||||
@@ -163,7 +220,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
|||||||
problem_list.append(data)
|
problem_list.append(data)
|
||||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||||
|
|
||||||
# 创建异步任务处理单个问题
|
# Create async task to process individual questions
|
||||||
async def process_question_nodes(idx, question):
|
async def process_question_nodes(idx, question):
|
||||||
try:
|
try:
|
||||||
# Prepare search parameters based on storage type
|
# Prepare search parameters based on storage type
|
||||||
@@ -209,7 +266,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# 并发处理所有问题
|
# Process all questions concurrently
|
||||||
tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)]
|
tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)]
|
||||||
databases_anser = await asyncio.gather(*tasks)
|
databases_anser = await asyncio.gather(*tasks)
|
||||||
databases_data = {
|
databases_data = {
|
||||||
@@ -257,7 +314,20 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
|
|
||||||
async def retrieve(state: ReadState) -> ReadState:
|
async def retrieve(state: ReadState) -> ReadState:
|
||||||
# 从state中获取end_user_id
|
"""
|
||||||
|
Advanced retrieve function using LangChain agents and tools
|
||||||
|
|
||||||
|
Uses LangChain agents with specialized retrieval tools (time-based and hybrid)
|
||||||
|
to perform sophisticated information retrieval. Supports both RAG and traditional
|
||||||
|
memory storage approaches with concurrent processing and result deduplication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing problem extensions and configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Updated state with retrieval results and intermediate outputs
|
||||||
|
"""
|
||||||
|
# Get end_user_id from state
|
||||||
import time
|
import time
|
||||||
start = time.time()
|
start = time.time()
|
||||||
problem_extension = state.get('problem_extension', '')['context']
|
problem_extension = state.get('problem_extension', '')['context']
|
||||||
@@ -291,7 +361,11 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
)
|
)
|
||||||
|
|
||||||
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
||||||
search_params = {"end_user_id": end_user_id, "return_raw_results": True}
|
search_params = {
|
||||||
|
"end_user_id": end_user_id,
|
||||||
|
"return_raw_results": True,
|
||||||
|
"include": ["summaries", "statements", "chunks", "entities", "communities"],
|
||||||
|
}
|
||||||
hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||||
agent = create_agent(
|
agent = create_agent(
|
||||||
llm,
|
llm,
|
||||||
@@ -299,21 +373,21 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建异步任务处理单个问题
|
# Create async task to process individual questions
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
# 在模块级别定义信号量,限制最大并发数
|
# Define semaphore at module level to limit maximum concurrency
|
||||||
SEMAPHORE = asyncio.Semaphore(5) # 限制最多5个并发数据库操作
|
SEMAPHORE = asyncio.Semaphore(5) # Limit to maximum 5 concurrent database operations
|
||||||
|
|
||||||
async def process_question(idx, question):
|
async def process_question(idx, question):
|
||||||
async with SEMAPHORE: # 限制并发
|
async with SEMAPHORE: # Limit concurrency
|
||||||
try:
|
try:
|
||||||
if storage_type == "rag" and user_rag_memory_id:
|
if storage_type == "rag" and user_rag_memory_id:
|
||||||
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
|
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
|
||||||
question)
|
question)
|
||||||
else:
|
else:
|
||||||
cleaned_query = question
|
cleaned_query = question
|
||||||
# 使用 asyncio 在线程池中运行同步的 agent.invoke
|
# Use asyncio to run synchronous agent.invoke in thread pool
|
||||||
import asyncio
|
import asyncio
|
||||||
response = await asyncio.get_event_loop().run_in_executor(
|
response = await asyncio.get_event_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
@@ -327,8 +401,32 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
raw_results = tool_results['content']
|
raw_results = tool_results['content']
|
||||||
clean_content = await clean_databases(raw_results)
|
clean_content = await clean_databases(raw_results)
|
||||||
|
|
||||||
|
# 社区展开:从 tool 返回结果中提取命中的 community,
|
||||||
|
# 沿 BELONGS_TO_COMMUNITY 关系拉取关联 Statement 追加到 clean_content
|
||||||
|
_expanded_stmts_to_write = []
|
||||||
|
try:
|
||||||
|
results_dict = raw_results.get('results', {}) if isinstance(raw_results, dict) else {}
|
||||||
|
reranked = results_dict.get('reranked_results', {})
|
||||||
|
community_hits = reranked.get('communities', [])
|
||||||
|
if not community_hits:
|
||||||
|
community_hits = results_dict.get('communities', [])
|
||||||
|
if community_hits:
|
||||||
|
from app.core.memory.agent.services.search_service import expand_communities_to_statements
|
||||||
|
_expanded_stmts_to_write, new_texts = await expand_communities_to_statements(
|
||||||
|
community_results=community_hits,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
existing_content=clean_content,
|
||||||
|
)
|
||||||
|
if new_texts:
|
||||||
|
clean_content = clean_content + '\n' + '\n'.join(new_texts)
|
||||||
|
except Exception as parse_err:
|
||||||
|
logger.warning(f"[Retrieve] 解析社区命中结果失败,跳过展开: {parse_err}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
raw_results = raw_results['results']
|
raw_results = raw_results['results']
|
||||||
|
# 写回展开结果,接口返回中可见(已在 helper 中清洗过字段)
|
||||||
|
if _expanded_stmts_to_write and isinstance(raw_results, dict):
|
||||||
|
raw_results.setdefault('reranked_results', {})['expanded_statements'] = _expanded_stmts_to_write
|
||||||
except Exception:
|
except Exception:
|
||||||
raw_results = []
|
raw_results = []
|
||||||
|
|
||||||
@@ -362,7 +460,7 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# 并发处理所有问题
|
# Process all questions concurrently
|
||||||
import asyncio
|
import asyncio
|
||||||
tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)]
|
tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)]
|
||||||
databases_anser = await asyncio.gather(*tasks)
|
databases_anser = await asyncio.gather(*tasks)
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger, log_time
|
from app.core.logging_config import get_agent_logger, log_time
|
||||||
|
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
|
||||||
|
PerceptualSearchService,
|
||||||
|
)
|
||||||
from app.core.memory.agent.models.summary_models import (
|
from app.core.memory.agent.models.summary_models import (
|
||||||
RetrieveSummaryResponse,
|
RetrieveSummaryResponse,
|
||||||
SummaryResponse,
|
SummaryResponse,
|
||||||
@@ -15,6 +19,7 @@ from app.core.memory.agent.utils.llm_tools import (
|
|||||||
from app.core.memory.agent.utils.redis_tool import store
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
from app.core.memory.agent.utils.session_tools import SessionService
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
|
from app.core.memory.enums import Neo4jNodeType
|
||||||
from app.core.rag.nlp.search import knowledge_retrieval
|
from app.core.rag.nlp.search import knowledge_retrieval
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
|
|
||||||
@@ -23,18 +28,39 @@ logger = get_agent_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class SummaryNodeService(LLMServiceMixin):
|
class SummaryNodeService(LLMServiceMixin):
|
||||||
"""总结节点服务类"""
|
"""
|
||||||
|
Summary node service class
|
||||||
|
|
||||||
|
Handles summary generation operations using LLM services. Inherits from
|
||||||
|
LLMServiceMixin to provide structured LLM calling capabilities for
|
||||||
|
generating summaries from retrieved information.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
template_service: Service for rendering Jinja2 templates
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.template_service = TemplateService(template_root)
|
self.template_service = TemplateService(template_root)
|
||||||
|
|
||||||
|
|
||||||
# 创建全局服务实例
|
# Create global service instance
|
||||||
summary_service = SummaryNodeService()
|
summary_service = SummaryNodeService()
|
||||||
|
|
||||||
|
|
||||||
async def rag_config(state):
|
async def rag_config(state):
|
||||||
|
"""
|
||||||
|
Configure RAG (Retrieval-Augmented Generation) settings for summary operations
|
||||||
|
|
||||||
|
Creates configuration for knowledge base retrieval including similarity thresholds,
|
||||||
|
weights, and reranker settings specifically for summary generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: Current state containing user_rag_memory_id
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: RAG configuration dictionary with knowledge base settings
|
||||||
|
"""
|
||||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
kb_config = {
|
kb_config = {
|
||||||
"knowledge_bases": [
|
"knowledge_bases": [
|
||||||
@@ -54,6 +80,23 @@ async def rag_config(state):
|
|||||||
|
|
||||||
|
|
||||||
async def rag_knowledge(state, question):
|
async def rag_knowledge(state, question):
|
||||||
|
"""
|
||||||
|
Retrieve knowledge using RAG approach for summary generation
|
||||||
|
|
||||||
|
Performs knowledge retrieval from configured knowledge bases using the
|
||||||
|
provided question and returns formatted results for summary processing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: Current state containing configuration
|
||||||
|
question: Question to search for in knowledge base
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
|
||||||
|
- retrieval_knowledge: List of retrieved knowledge chunks
|
||||||
|
- clean_content: Formatted content string
|
||||||
|
- cleaned_query: Processed query string
|
||||||
|
- raw_results: Raw retrieval results
|
||||||
|
"""
|
||||||
kb_config = await rag_config(state)
|
kb_config = await rag_config(state)
|
||||||
end_user_id = state.get('end_user_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
@@ -74,6 +117,18 @@ async def rag_knowledge(state, question):
|
|||||||
|
|
||||||
|
|
||||||
async def summary_history(state: ReadState) -> ReadState:
|
async def summary_history(state: ReadState) -> ReadState:
|
||||||
|
"""
|
||||||
|
Retrieve conversation history for summary context
|
||||||
|
|
||||||
|
Gets the conversation history for the current user to provide context
|
||||||
|
for summary generation operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing end_user_id
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Conversation history data
|
||||||
|
"""
|
||||||
end_user_id = state.get("end_user_id", '')
|
end_user_id = state.get("end_user_id", '')
|
||||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||||
return history
|
return history
|
||||||
@@ -82,11 +137,26 @@ async def summary_history(state: ReadState) -> ReadState:
|
|||||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
|
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
|
||||||
search_mode) -> str:
|
search_mode) -> str:
|
||||||
"""
|
"""
|
||||||
增强的summary_llm函数,包含更好的错误处理和数据验证
|
Enhanced summary_llm function with better error handling and data validation
|
||||||
|
|
||||||
|
Generates summaries using LLM with structured output. Includes fallback mechanisms
|
||||||
|
for handling LLM failures and provides robust error recovery.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing current context
|
||||||
|
history: Conversation history for context
|
||||||
|
retrieve_info: Retrieved information to summarize
|
||||||
|
template_name: Jinja2 template name for prompt generation
|
||||||
|
operation_name: Type of operation (summary, input_summary, retrieve_summary)
|
||||||
|
response_model: Pydantic model for structured output
|
||||||
|
search_mode: Search mode flag ("0" for simple, "1" for complex)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Generated summary text or fallback message
|
||||||
"""
|
"""
|
||||||
data = state.get("data", '')
|
data = state.get("data", '')
|
||||||
|
|
||||||
# 构建系统提示词
|
# Build system prompt
|
||||||
if str(search_mode) == "0":
|
if str(search_mode) == "0":
|
||||||
system_prompt = await summary_service.template_service.render_template(
|
system_prompt = await summary_service.template_service.render_template(
|
||||||
template_name=template_name,
|
template_name=template_name,
|
||||||
@@ -103,7 +173,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
retrieve_info=retrieve_info
|
retrieve_info=retrieve_info
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
# 使用优化的LLM服务进行结构化输出
|
# Use optimized LLM service for structured output
|
||||||
with get_db_context() as db_session:
|
with get_db_context() as db_session:
|
||||||
structured = await summary_service.call_llm_structured(
|
structured = await summary_service.call_llm_structured(
|
||||||
state=state,
|
state=state,
|
||||||
@@ -112,23 +182,23 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
fallback_value=None
|
fallback_value=None
|
||||||
)
|
)
|
||||||
# 验证结构化响应
|
# Validate structured response
|
||||||
if structured is None:
|
if structured is None:
|
||||||
logger.warning("LLM返回None,使用默认回答")
|
logger.warning("LLM返回None,使用默认回答")
|
||||||
return "信息不足,无法回答"
|
return "信息不足,无法回答"
|
||||||
|
|
||||||
# 根据操作类型提取答案
|
# Extract answer based on operation type
|
||||||
if operation_name == "summary":
|
if operation_name == "summary":
|
||||||
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
|
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
|
||||||
else:
|
else:
|
||||||
# 处理RetrieveSummaryResponse
|
# Handle RetrieveSummaryResponse
|
||||||
if hasattr(structured, 'data') and structured.data:
|
if hasattr(structured, 'data') and structured.data:
|
||||||
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
|
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
|
||||||
else:
|
else:
|
||||||
logger.warning("结构化响应缺少data字段")
|
logger.warning("结构化响应缺少data字段")
|
||||||
aimessages = "信息不足,无法回答"
|
aimessages = "信息不足,无法回答"
|
||||||
|
|
||||||
# 验证答案不为空
|
# Validate answer is not empty
|
||||||
if not aimessages or aimessages.strip() == "":
|
if not aimessages or aimessages.strip() == "":
|
||||||
aimessages = "信息不足,无法回答"
|
aimessages = "信息不足,无法回答"
|
||||||
|
|
||||||
@@ -137,7 +207,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"结构化输出失败: {e}", exc_info=True)
|
logger.error(f"结构化输出失败: {e}", exc_info=True)
|
||||||
|
|
||||||
# 尝试非结构化输出作为fallback
|
# Try unstructured output as fallback
|
||||||
try:
|
try:
|
||||||
logger.info("尝试非结构化输出作为fallback")
|
logger.info("尝试非结构化输出作为fallback")
|
||||||
response = await summary_service.call_llm_simple(
|
response = await summary_service.call_llm_simple(
|
||||||
@@ -148,9 +218,9 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
)
|
)
|
||||||
|
|
||||||
if response and response.strip():
|
if response and response.strip():
|
||||||
# 简单清理响应
|
# Simple response cleaning
|
||||||
cleaned_response = response.strip()
|
cleaned_response = response.strip()
|
||||||
# 移除可能的JSON标记
|
# Remove possible JSON markers
|
||||||
if cleaned_response.startswith('```'):
|
if cleaned_response.startswith('```'):
|
||||||
lines = cleaned_response.split('\n')
|
lines = cleaned_response.split('\n')
|
||||||
cleaned_response = '\n'.join(lines[1:-1])
|
cleaned_response = '\n'.join(lines[1:-1])
|
||||||
@@ -165,6 +235,19 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
|
|
||||||
|
|
||||||
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
||||||
|
"""
|
||||||
|
Save summary results to Redis session storage
|
||||||
|
|
||||||
|
Stores the generated summary and user query in Redis for session management
|
||||||
|
and conversation history tracking.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing user and query information
|
||||||
|
aimessages: Generated summary message to save
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Updated state after saving to Redis
|
||||||
|
"""
|
||||||
data = state.get("data", '')
|
data = state.get("data", '')
|
||||||
end_user_id = state.get("end_user_id", '')
|
end_user_id = state.get("end_user_id", '')
|
||||||
await SessionService(store).save_session(
|
await SessionService(store).save_session(
|
||||||
@@ -179,6 +262,20 @@ async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
|||||||
|
|
||||||
|
|
||||||
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
|
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
|
||||||
|
"""
|
||||||
|
Format summary results for different output types
|
||||||
|
|
||||||
|
Creates structured output formats for both input summary and retrieval summary
|
||||||
|
operations, including metadata and intermediate results for frontend display.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing storage and user information
|
||||||
|
aimessages: Generated summary message
|
||||||
|
raw_results: Raw search/retrieval results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (input_summary, retrieve_summary) formatted result dictionaries
|
||||||
|
"""
|
||||||
storage_type = state.get("storage_type", '')
|
storage_type = state.get("storage_type", '')
|
||||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
data = state.get("data", '')
|
data = state.get("data", '')
|
||||||
@@ -217,6 +314,19 @@ async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState
|
|||||||
|
|
||||||
|
|
||||||
async def Input_Summary(state: ReadState) -> ReadState:
|
async def Input_Summary(state: ReadState) -> ReadState:
|
||||||
|
"""
|
||||||
|
Generate quick input summary from retrieved information
|
||||||
|
|
||||||
|
Performs fast retrieval and generates a quick summary response for user queries.
|
||||||
|
This function prioritizes speed by only searching summary nodes and provides
|
||||||
|
immediate feedback to users.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing user query, storage configuration, and context
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Dictionary containing summary results with status and metadata
|
||||||
|
"""
|
||||||
start = time.time()
|
start = time.time()
|
||||||
storage_type = state.get("storage_type", '')
|
storage_type = state.get("storage_type", '')
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
@@ -229,13 +339,56 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"question": data,
|
"question": data,
|
||||||
"return_raw_results": True,
|
"return_raw_results": True,
|
||||||
"include": ["summaries"] # Only search summary nodes for faster performance
|
"include": [Neo4jNodeType.MEMORYSUMMARY, Neo4jNodeType.COMMUNITY] # MemorySummary 和 Community 同为高维度概括节点
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if storage_type != "rag":
|
if storage_type != "rag":
|
||||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params,
|
|
||||||
memory_config=memory_config)
|
async def _perceptual_search():
|
||||||
|
service = PerceptualSearchService(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
memory_config=memory_config,
|
||||||
|
)
|
||||||
|
return await service.search(query=data, limit=5)
|
||||||
|
|
||||||
|
hybrid_task = SearchService().execute_hybrid_search(
|
||||||
|
**search_params,
|
||||||
|
memory_config=memory_config,
|
||||||
|
expand_communities=False,
|
||||||
|
)
|
||||||
|
perceptual_task = _perceptual_search()
|
||||||
|
|
||||||
|
gather_results = await asyncio.gather(
|
||||||
|
hybrid_task, perceptual_task, return_exceptions=True
|
||||||
|
)
|
||||||
|
hybrid_result = gather_results[0]
|
||||||
|
perceptual_results = gather_results[1]
|
||||||
|
|
||||||
|
# 处理 hybrid search 异常
|
||||||
|
if isinstance(hybrid_result, Exception):
|
||||||
|
raise hybrid_result
|
||||||
|
retrieve_info, question, raw_results = hybrid_result
|
||||||
|
|
||||||
|
# 处理感知记忆结果
|
||||||
|
if isinstance(perceptual_results, Exception):
|
||||||
|
logger.warning(f"[Input_Summary] perceptual search failed: {perceptual_results}")
|
||||||
|
perceptual_results = []
|
||||||
|
|
||||||
|
# 拼接感知记忆内容到 retrieve_info
|
||||||
|
if perceptual_results and isinstance(perceptual_results, dict):
|
||||||
|
perceptual_content = perceptual_results.get("content", "")
|
||||||
|
if perceptual_content:
|
||||||
|
retrieve_info = f"{retrieve_info}\n\n<history-files>\n{perceptual_content}"
|
||||||
|
count = len(perceptual_results.get("memories", []))
|
||||||
|
logger.info(f"[Input_Summary] appended {count} perceptual memories (reranked)")
|
||||||
|
|
||||||
|
# 调试:打印 community 检索结果数量
|
||||||
|
if raw_results and isinstance(raw_results, dict):
|
||||||
|
reranked = raw_results.get('reranked_results', {})
|
||||||
|
community_hits = reranked.get('communities', [])
|
||||||
|
logger.debug(f"[Input_Summary] community 命中数: {len(community_hits)}, "
|
||||||
|
f"summary 命中数: {len(reranked.get('summaries', []))}")
|
||||||
else:
|
else:
|
||||||
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
|
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -257,15 +410,25 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
"error": str(e)
|
"error": str(e)
|
||||||
}
|
}
|
||||||
end = time.time()
|
end = time.time()
|
||||||
try:
|
duration = end - start
|
||||||
duration = end - start
|
|
||||||
except Exception:
|
|
||||||
duration = 0.0
|
|
||||||
log_time('检索', duration)
|
log_time('检索', duration)
|
||||||
return {"summary": summary}
|
return {"summary": summary}
|
||||||
|
|
||||||
|
|
||||||
async def Retrieve_Summary(state: ReadState) -> ReadState:
|
async def Retrieve_Summary(state: ReadState) -> ReadState:
|
||||||
|
"""
|
||||||
|
Generate comprehensive summary from retrieved expansion issues
|
||||||
|
|
||||||
|
Processes retrieved expansion issues and generates a detailed summary using LLM.
|
||||||
|
This function handles complex retrieval results and provides comprehensive answers
|
||||||
|
based on expanded query results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing retrieve data with expansion issues
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Dictionary containing comprehensive summary results
|
||||||
|
"""
|
||||||
retrieve = state.get("retrieve", '')
|
retrieve = state.get("retrieve", '')
|
||||||
history = await summary_history(state)
|
history = await summary_history(state)
|
||||||
import json
|
import json
|
||||||
@@ -285,8 +448,20 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
|
|||||||
retrieve_info_str = list(set(retrieve_info_str))
|
retrieve_info_str = list(set(retrieve_info_str))
|
||||||
retrieve_info_str = '\n'.join(retrieve_info_str)
|
retrieve_info_str = '\n'.join(retrieve_info_str)
|
||||||
|
|
||||||
aimessages = await summary_llm(state, history, retrieve_info_str,
|
# Merge perceptual memory content
|
||||||
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
|
perceptual_data = state.get("perceptual_data", {})
|
||||||
|
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
||||||
|
if perceptual_content:
|
||||||
|
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
||||||
|
|
||||||
|
aimessages = await summary_llm(
|
||||||
|
state,
|
||||||
|
history,
|
||||||
|
retrieve_info_str,
|
||||||
|
'direct_summary_prompt.jinja2',
|
||||||
|
'retrieve_summary', RetrieveSummaryResponse,
|
||||||
|
"1"
|
||||||
|
)
|
||||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||||
await summary_redis_save(state, aimessages)
|
await summary_redis_save(state, aimessages)
|
||||||
if aimessages == '':
|
if aimessages == '':
|
||||||
@@ -299,13 +474,26 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
|
|||||||
duration = 0.0
|
duration = 0.0
|
||||||
log_time('Retrieval summary', duration)
|
log_time('Retrieval summary', duration)
|
||||||
|
|
||||||
# 修复协程调用 - 先await,然后访问返回值
|
# Fixed coroutine call - await first, then access return value
|
||||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||||
summary = summary_result[1]
|
summary = summary_result[1]
|
||||||
return {"summary": summary}
|
return {"summary": summary}
|
||||||
|
|
||||||
|
|
||||||
async def Summary(state: ReadState) -> ReadState:
|
async def Summary(state: ReadState) -> ReadState:
|
||||||
|
"""
|
||||||
|
Generate final comprehensive summary from verified data
|
||||||
|
|
||||||
|
Creates the final summary using verified expansion issues and conversation history.
|
||||||
|
This function processes verified data to generate the most comprehensive and
|
||||||
|
accurate response to user queries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing verified data and query information
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Dictionary containing final summary results
|
||||||
|
"""
|
||||||
start = time.time()
|
start = time.time()
|
||||||
query = state.get("data", '')
|
query = state.get("data", '')
|
||||||
verify = state.get("verify", '')
|
verify = state.get("verify", '')
|
||||||
@@ -318,6 +506,12 @@ async def Summary(state: ReadState) -> ReadState:
|
|||||||
retrieve_info_str += i + '\n'
|
retrieve_info_str += i + '\n'
|
||||||
history = await summary_history(state)
|
history = await summary_history(state)
|
||||||
|
|
||||||
|
# Merge perceptual memory content
|
||||||
|
perceptual_data = state.get("perceptual_data", {})
|
||||||
|
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
||||||
|
if perceptual_content:
|
||||||
|
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"query": query,
|
"query": query,
|
||||||
"history": history,
|
"history": history,
|
||||||
@@ -336,13 +530,26 @@ async def Summary(state: ReadState) -> ReadState:
|
|||||||
duration = 0.0
|
duration = 0.0
|
||||||
log_time('Retrieval summary', duration)
|
log_time('Retrieval summary', duration)
|
||||||
|
|
||||||
# 修复协程调用 - 先await,然后访问返回值
|
# Fixed coroutine call - await first, then access return value
|
||||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||||
summary = summary_result[1]
|
summary = summary_result[1]
|
||||||
return {"summary": summary}
|
return {"summary": summary}
|
||||||
|
|
||||||
|
|
||||||
async def Summary_fails(state: ReadState) -> ReadState:
|
async def Summary_fails(state: ReadState) -> ReadState:
|
||||||
|
"""
|
||||||
|
Generate fallback summary when normal summary process fails
|
||||||
|
|
||||||
|
Provides a fallback summary generation mechanism when the standard summary
|
||||||
|
process encounters errors or fails to produce satisfactory results. Uses
|
||||||
|
a specialized failure template to handle edge cases.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing verified data and failure context
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Dictionary containing fallback summary results
|
||||||
|
"""
|
||||||
storage_type = state.get("storage_type", '')
|
storage_type = state.get("storage_type", '')
|
||||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
history = await summary_history(state)
|
history = await summary_history(state)
|
||||||
@@ -355,6 +562,13 @@ async def Summary_fails(state: ReadState) -> ReadState:
|
|||||||
if key == 'answer_small':
|
if key == 'answer_small':
|
||||||
for i in value:
|
for i in value:
|
||||||
retrieve_info_str += i + '\n'
|
retrieve_info_str += i + '\n'
|
||||||
|
|
||||||
|
# Merge perceptual memory content
|
||||||
|
perceptual_data = state.get("perceptual_data", {})
|
||||||
|
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
||||||
|
if perceptual_content:
|
||||||
|
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"query": query,
|
"query": query,
|
||||||
"history": history,
|
"history": history,
|
||||||
|
|||||||
@@ -18,24 +18,46 @@ logger = get_agent_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class VerificationNodeService(LLMServiceMixin):
|
class VerificationNodeService(LLMServiceMixin):
|
||||||
"""验证节点服务类"""
|
"""
|
||||||
|
Verification node service class
|
||||||
|
|
||||||
|
Handles data verification operations using LLM services. Inherits from
|
||||||
|
LLMServiceMixin to provide structured LLM calling capabilities for
|
||||||
|
verifying and validating retrieved information.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
template_service: Service for rendering Jinja2 templates
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.template_service = TemplateService(template_root)
|
self.template_service = TemplateService(template_root)
|
||||||
|
|
||||||
|
|
||||||
# 创建全局服务实例
|
# Create global service instance
|
||||||
verification_service = VerificationNodeService()
|
verification_service = VerificationNodeService()
|
||||||
|
|
||||||
|
|
||||||
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||||
"""处理验证结果并生成输出格式"""
|
"""
|
||||||
|
Process verification results and generate output format
|
||||||
|
|
||||||
|
Transforms VerificationResult objects into structured output format suitable
|
||||||
|
for frontend consumption. Handles conversion of VerificationItem objects to
|
||||||
|
dictionary format and adds metadata for tracking.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing storage and user configuration
|
||||||
|
messages_deal: VerificationResult containing verification outcomes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Formatted verification result with status and metadata
|
||||||
|
"""
|
||||||
storage_type = state.get('storage_type', '')
|
storage_type = state.get('storage_type', '')
|
||||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
data = state.get('data', '')
|
data = state.get('data', '')
|
||||||
|
|
||||||
# 将 VerificationItem 对象转换为字典列表
|
# Convert VerificationItem objects to dictionary list
|
||||||
verified_data = []
|
verified_data = []
|
||||||
if messages_deal.expansion_issue:
|
if messages_deal.expansion_issue:
|
||||||
for item in messages_deal.expansion_issue:
|
for item in messages_deal.expansion_issue:
|
||||||
@@ -89,7 +111,7 @@ async def Verify(state: ReadState):
|
|||||||
|
|
||||||
logger.info("Verify: 开始渲染模板")
|
logger.info("Verify: 开始渲染模板")
|
||||||
|
|
||||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
# Generate JSON schema to guide LLM output format
|
||||||
json_schema = VerificationResult.model_json_schema()
|
json_schema = VerificationResult.model_json_schema()
|
||||||
|
|
||||||
system_prompt = await verification_service.template_service.render_template(
|
system_prompt = await verification_service.template_service.render_template(
|
||||||
@@ -104,8 +126,8 @@ async def Verify(state: ReadState):
|
|||||||
# 使用优化的LLM服务,添加超时保护
|
# 使用优化的LLM服务,添加超时保护
|
||||||
logger.info("Verify: 开始调用 LLM")
|
logger.info("Verify: 开始调用 LLM")
|
||||||
try:
|
try:
|
||||||
# 添加 asyncio.wait_for 超时包裹,防止无限等待
|
# Add asyncio.wait_for timeout wrapper to prevent infinite waiting
|
||||||
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
|
# Timeout set to 150 seconds (slightly longer than LLM config's 120 seconds)
|
||||||
|
|
||||||
with get_db_context() as db_session:
|
with get_db_context() as db_session:
|
||||||
structured = await asyncio.wait_for(
|
structured = await asyncio.wait_for(
|
||||||
@@ -122,7 +144,7 @@ async def Verify(state: ReadState):
|
|||||||
"reason": "验证失败或超时"
|
"reason": "验证失败或超时"
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
timeout=150.0 # 150秒超时
|
timeout=150.0 # 150 second timeout
|
||||||
)
|
)
|
||||||
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
|||||||
@@ -1,67 +0,0 @@
|
|||||||
from app.cache.memory.interest_memory import InterestMemoryCache
|
|
||||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
|
||||||
from app.core.memory.agent.utils.write_tools import write
|
|
||||||
from app.core.logging_config import get_agent_logger
|
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
async def write_node(state: WriteState) -> WriteState:
|
|
||||||
"""
|
|
||||||
Write data to the database/file system.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: WriteState containing messages, end_user_id, memory_config, and language
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Contains 'write_result' with status and data fields
|
|
||||||
"""
|
|
||||||
messages = state.get('messages', [])
|
|
||||||
end_user_id = state.get('end_user_id', '')
|
|
||||||
memory_config = state.get('memory_config', '')
|
|
||||||
language = state.get('language', 'zh') # 默认中文
|
|
||||||
|
|
||||||
# Convert LangChain messages to structured format expected by write()
|
|
||||||
structured_messages = []
|
|
||||||
for msg in messages:
|
|
||||||
if hasattr(msg, 'type') and hasattr(msg, 'content'):
|
|
||||||
# Map LangChain message types to role names
|
|
||||||
role = 'user' if msg.type == 'human' else 'assistant' if msg.type == 'ai' else msg.type
|
|
||||||
structured_messages.append({
|
|
||||||
"role": role,
|
|
||||||
"content": msg.content # content is now guaranteed to be a string
|
|
||||||
})
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = await write(
|
|
||||||
messages=structured_messages,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
memory_config=memory_config,
|
|
||||||
language=language,
|
|
||||||
)
|
|
||||||
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
|
||||||
|
|
||||||
# 写入 neo4j 成功后,删除该用户的兴趣分布缓存,确保下次请求重新生成
|
|
||||||
for lang in ["zh", "en"]:
|
|
||||||
deleted = await InterestMemoryCache.delete_interest_distribution(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
language=lang,
|
|
||||||
)
|
|
||||||
if deleted:
|
|
||||||
logger.info(f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
|
|
||||||
|
|
||||||
write_result = {
|
|
||||||
"status": "success",
|
|
||||||
"data": structured_messages,
|
|
||||||
"config_id": memory_config.config_id,
|
|
||||||
"config_name": memory_config.config_name,
|
|
||||||
}
|
|
||||||
return {"write_result": write_result}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Data_write failed: {e}", exc_info=True)
|
|
||||||
write_result = {
|
|
||||||
"status": "error",
|
|
||||||
"message": str(e),
|
|
||||||
}
|
|
||||||
return {"write_result": write_result}
|
|
||||||
@@ -1,21 +1,20 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
import logging
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage
|
|
||||||
from langgraph.constants import START, END
|
from langgraph.constants import START, END
|
||||||
from langgraph.graph import StateGraph
|
from langgraph.graph import StateGraph
|
||||||
|
|
||||||
from app.db import get_db
|
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
|
||||||
|
|
||||||
from app.core.memory.agent.utils.llm_tools import ReadState
|
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_node
|
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_node
|
||||||
|
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
|
||||||
|
perceptual_retrieve_node,
|
||||||
|
)
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
|
from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
|
||||||
Split_The_Problem,
|
Split_The_Problem,
|
||||||
Problem_Extension,
|
Problem_Extension,
|
||||||
)
|
)
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
|
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
|
||||||
retrieve,
|
retrieve_nodes,
|
||||||
)
|
)
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
|
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
|
||||||
Input_Summary,
|
Input_Summary,
|
||||||
@@ -29,11 +28,26 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
|
|||||||
Retrieve_continue,
|
Retrieve_continue,
|
||||||
Verify_continue,
|
Verify_continue,
|
||||||
)
|
)
|
||||||
|
from app.core.memory.agent.utils.llm_tools import ReadState
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def make_read_graph():
|
async def make_read_graph():
|
||||||
"""创建并返回 LangGraph 工作流"""
|
"""
|
||||||
|
Create and return a LangGraph workflow for memory reading operations
|
||||||
|
|
||||||
|
Builds a state graph workflow that handles memory retrieval, problem analysis,
|
||||||
|
verification, and summarization. The workflow includes nodes for content input,
|
||||||
|
problem splitting, retrieval, verification, and various summary operations.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
StateGraph: Compiled LangGraph workflow for memory reading
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If workflow creation fails
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
# Build workflow graph
|
# Build workflow graph
|
||||||
workflow = StateGraph(ReadState)
|
workflow = StateGraph(ReadState)
|
||||||
@@ -41,139 +55,34 @@ async def make_read_graph():
|
|||||||
workflow.add_node("Split_The_Problem", Split_The_Problem)
|
workflow.add_node("Split_The_Problem", Split_The_Problem)
|
||||||
workflow.add_node("Problem_Extension", Problem_Extension)
|
workflow.add_node("Problem_Extension", Problem_Extension)
|
||||||
workflow.add_node("Input_Summary", Input_Summary)
|
workflow.add_node("Input_Summary", Input_Summary)
|
||||||
# workflow.add_node("Retrieve", retrieve_nodes)
|
workflow.add_node("Retrieve", retrieve_nodes)
|
||||||
workflow.add_node("Retrieve", retrieve)
|
# workflow.add_node("Retrieve", retrieve)
|
||||||
|
workflow.add_node("Perceptual_Retrieve", perceptual_retrieve_node)
|
||||||
workflow.add_node("Verify", Verify)
|
workflow.add_node("Verify", Verify)
|
||||||
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
|
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
|
||||||
workflow.add_node("Summary", Summary)
|
workflow.add_node("Summary", Summary)
|
||||||
workflow.add_node("Summary_fails", Summary_fails)
|
workflow.add_node("Summary_fails", Summary_fails)
|
||||||
|
|
||||||
# 添加边
|
# Add edges to define workflow flow
|
||||||
workflow.add_edge(START, "content_input")
|
workflow.add_edge(START, "content_input")
|
||||||
workflow.add_conditional_edges("content_input", Split_continue)
|
workflow.add_conditional_edges("content_input", Split_continue)
|
||||||
workflow.add_edge("Input_Summary", END)
|
workflow.add_edge("Input_Summary", END)
|
||||||
workflow.add_edge("Split_The_Problem", "Problem_Extension")
|
workflow.add_edge("Split_The_Problem", "Problem_Extension")
|
||||||
workflow.add_edge("Problem_Extension", "Retrieve")
|
# After Problem_Extension, retrieve perceptual memory first, then main Retrieve
|
||||||
|
workflow.add_edge("Problem_Extension", "Perceptual_Retrieve")
|
||||||
|
workflow.add_edge("Perceptual_Retrieve", "Retrieve")
|
||||||
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
||||||
workflow.add_edge("Retrieve_Summary", END)
|
workflow.add_edge("Retrieve_Summary", END)
|
||||||
workflow.add_conditional_edges("Verify", Verify_continue)
|
workflow.add_conditional_edges("Verify", Verify_continue)
|
||||||
workflow.add_edge("Summary_fails", END)
|
workflow.add_edge("Summary_fails", END)
|
||||||
workflow.add_edge("Summary", END)
|
workflow.add_edge("Summary", END)
|
||||||
|
|
||||||
'''-----'''
|
|
||||||
# workflow.add_edge("Retrieve", END)
|
# workflow.add_edge("Retrieve", END)
|
||||||
|
|
||||||
# 编译工作流
|
# Compile workflow
|
||||||
graph = workflow.compile()
|
graph = workflow.compile()
|
||||||
yield graph
|
yield graph
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"创建工作流失败: {e}")
|
logger.error(f"创建工作流失败: {e}")
|
||||||
raise
|
raise
|
||||||
finally:
|
|
||||||
print("工作流创建完成")
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
"""主函数 - 运行工作流"""
|
|
||||||
message = "昨天有什么好看的电影"
|
|
||||||
end_user_id = '88a459f5_text09' # 组ID
|
|
||||||
storage_type = 'neo4j' # 存储类型
|
|
||||||
search_switch = '1' # 搜索开关
|
|
||||||
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
|
|
||||||
|
|
||||||
# 获取数据库会话
|
|
||||||
db_session = next(get_db())
|
|
||||||
config_service = MemoryConfigService(db_session)
|
|
||||||
memory_config = config_service.load_memory_config(
|
|
||||||
config_id=17, # 改为整数
|
|
||||||
service_name="MemoryAgentService"
|
|
||||||
)
|
|
||||||
import time
|
|
||||||
start = time.time()
|
|
||||||
try:
|
|
||||||
async with make_read_graph() as graph:
|
|
||||||
config = {"configurable": {"thread_id": end_user_id}}
|
|
||||||
# 初始状态 - 包含所有必要字段
|
|
||||||
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
|
|
||||||
"end_user_id": end_user_id
|
|
||||||
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
|
|
||||||
"memory_config": memory_config}
|
|
||||||
# 获取节点更新信息
|
|
||||||
_intermediate_outputs = []
|
|
||||||
summary = ''
|
|
||||||
|
|
||||||
async for update_event in graph.astream(
|
|
||||||
initial_state,
|
|
||||||
stream_mode="updates",
|
|
||||||
config=config
|
|
||||||
):
|
|
||||||
for node_name, node_data in update_event.items():
|
|
||||||
print(f"处理节点: {node_name}")
|
|
||||||
|
|
||||||
# 处理不同Summary节点的返回结构
|
|
||||||
if 'Summary' in node_name:
|
|
||||||
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
|
|
||||||
summary = node_data['InputSummary']['summary_result']
|
|
||||||
elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']:
|
|
||||||
summary = node_data['RetrieveSummary']['summary_result']
|
|
||||||
elif 'summary' in node_data and 'summary_result' in node_data['summary']:
|
|
||||||
summary = node_data['summary']['summary_result']
|
|
||||||
elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']:
|
|
||||||
summary = node_data['SummaryFails']['summary_result']
|
|
||||||
|
|
||||||
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
|
|
||||||
if spit_data and spit_data != [] and spit_data != {}:
|
|
||||||
_intermediate_outputs.append(spit_data)
|
|
||||||
|
|
||||||
# Problem_Extension 节点
|
|
||||||
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
|
|
||||||
if problem_extension and problem_extension != [] and problem_extension != {}:
|
|
||||||
_intermediate_outputs.append(problem_extension)
|
|
||||||
|
|
||||||
# Retrieve 节点
|
|
||||||
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
|
|
||||||
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
|
|
||||||
_intermediate_outputs.extend(retrieve_node)
|
|
||||||
|
|
||||||
# Verify 节点
|
|
||||||
verify_n = node_data.get('verify', {}).get('_intermediate', None)
|
|
||||||
if verify_n and verify_n != [] and verify_n != {}:
|
|
||||||
_intermediate_outputs.append(verify_n)
|
|
||||||
|
|
||||||
# Summary 节点
|
|
||||||
summary_n = node_data.get('summary', {}).get('_intermediate', None)
|
|
||||||
if summary_n and summary_n != [] and summary_n != {}:
|
|
||||||
_intermediate_outputs.append(summary_n)
|
|
||||||
|
|
||||||
# # 过滤掉空值
|
|
||||||
# _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
|
|
||||||
#
|
|
||||||
# # 优化搜索结果
|
|
||||||
# print("=== 开始优化搜索结果 ===")
|
|
||||||
# optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
|
|
||||||
# result=reorder_output_results(optimized_outputs)
|
|
||||||
# # 保存优化后的结果到文件
|
|
||||||
# with open('_intermediate_outputs_optimized.json', 'w', encoding='utf-8') as f:
|
|
||||||
# import json
|
|
||||||
# f.write(json.dumps(result, indent=4, ensure_ascii=False))
|
|
||||||
#
|
|
||||||
print(f"=== 最终摘要 ===")
|
|
||||||
print(summary)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
finally:
|
|
||||||
db_session.close()
|
|
||||||
|
|
||||||
end = time.time()
|
|
||||||
print(100 * 'y')
|
|
||||||
print(f"总耗时: {end - start}s")
|
|
||||||
print(100 * 'y')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
asyncio.run(main())
|
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
|
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
||||||
|
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
counter = COUNTState(limit=3)
|
counter = COUNTState(limit=3)
|
||||||
def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
|
|
||||||
|
|
||||||
|
def Split_continue(state: ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||||
"""
|
"""
|
||||||
Determine routing based on search_switch value.
|
Determine routing based on search_switch value.
|
||||||
|
|
||||||
@@ -25,6 +25,7 @@ def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summa
|
|||||||
return 'Input_Summary'
|
return 'Input_Summary'
|
||||||
return 'Split_The_Problem' # 默认情况
|
return 'Split_The_Problem' # 默认情况
|
||||||
|
|
||||||
|
|
||||||
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||||
"""
|
"""
|
||||||
Determine routing based on search_switch value.
|
Determine routing based on search_switch value.
|
||||||
@@ -43,8 +44,10 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
|||||||
elif search_switch == '1':
|
elif search_switch == '1':
|
||||||
return 'Retrieve_Summary'
|
return 'Retrieve_Summary'
|
||||||
return 'Retrieve_Summary' # Default based on business logic
|
return 'Retrieve_Summary' # Default based on business logic
|
||||||
|
|
||||||
|
|
||||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||||
status=state.get('verify', '')['status']
|
status = state.get('verify', '')['status']
|
||||||
# loop_count = counter.get_total()
|
# loop_count = counter.get_total()
|
||||||
if "success" in status:
|
if "success" in status:
|
||||||
# counter.reset()
|
# counter.reset()
|
||||||
@@ -53,7 +56,7 @@ def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "co
|
|||||||
# if loop_count < 2: # Maximum loop count is 3
|
# if loop_count < 2: # Maximum loop count is 3
|
||||||
# return "content_input"
|
# return "content_input"
|
||||||
# else:
|
# else:
|
||||||
# counter.reset()
|
# counter.reset()
|
||||||
return "Summary_fails"
|
return "Summary_fails"
|
||||||
else:
|
else:
|
||||||
# Add default return value to avoid returning None
|
# Add default return value to avoid returning None
|
||||||
|
|||||||
@@ -1,184 +1,244 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from app.celery_task_scheduler import scheduler
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
|
|
||||||
|
|
||||||
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
||||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||||
from app.core.memory.agent.utils.redis_tool import write_store
|
|
||||||
from app.core.memory.agent.utils.redis_tool import count_store
|
from app.core.memory.agent.utils.redis_tool import count_store
|
||||||
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context, get_db
|
from app.db import get_db_context
|
||||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||||
from app.services.memory_konwledges_server import write_rag
|
|
||||||
from app.services.task_service import get_task_memory_write_result
|
|
||||||
from app.tasks import write_message_task
|
|
||||||
from app.utils.config_utils import resolve_config_id
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
|
|
||||||
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
|
||||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
async def write(
|
||||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
storage_type,
|
||||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
end_user_id,
|
||||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
user_message,
|
||||||
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
|
ai_message,
|
||||||
actual_config_id, long_term_messages=[]):
|
user_rag_memory_id,
|
||||||
|
actual_end_user_id,
|
||||||
|
actual_config_id,
|
||||||
|
long_term_messages=None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
写入记忆(支持结构化消息)
|
Write memory with structured message support
|
||||||
|
|
||||||
|
Handles memory writing operations for different storage types (Neo4j/RAG).
|
||||||
|
Supports both individual message pairs and batch long-term message processing.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
storage_type: 存储类型 (neo4j/rag)
|
storage_type: Storage type identifier ("neo4j" or "rag")
|
||||||
end_user_id: 终端用户ID
|
end_user_id: Terminal user identifier
|
||||||
user_message: 用户消息内容
|
user_message: User message content
|
||||||
ai_message: AI 回复内容
|
ai_message: AI response content
|
||||||
user_rag_memory_id: RAG 记忆ID
|
user_rag_memory_id: RAG memory identifier
|
||||||
actual_end_user_id: 实际用户ID
|
actual_end_user_id: Actual user identifier for storage
|
||||||
actual_config_id: 配置ID
|
actual_config_id: Configuration identifier
|
||||||
|
long_term_messages: Optional list of structured messages for batch processing
|
||||||
|
|
||||||
逻辑说明:
|
Logic explanation:
|
||||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
- RAG mode: Combines user_message and ai_message into string format, maintains original logic
|
||||||
- Neo4j 模式:使用结构化消息列表
|
- Neo4j mode: Uses structured message lists
|
||||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
1. If both user_message and ai_message are not empty: Creates paired messages [user, assistant]
|
||||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
2. If only user_message exists: Creates single user message [user] (for historical memory scenarios)
|
||||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
3. Each message is converted to independent Chunk, preserving speaker field
|
||||||
"""
|
"""
|
||||||
|
|
||||||
db = next(get_db())
|
if long_term_messages is None:
|
||||||
try:
|
long_term_messages = []
|
||||||
|
with get_db_context() as db:
|
||||||
actual_config_id = resolve_config_id(actual_config_id, db)
|
actual_config_id = resolve_config_id(actual_config_id, db)
|
||||||
# Neo4j 模式:使用结构化消息列表
|
# Neo4j mode: Use structured message lists
|
||||||
structured_messages = []
|
structured_messages = []
|
||||||
|
|
||||||
# 始终添加用户消息(如果不为空)
|
# Always add user message (if not empty)
|
||||||
if isinstance(user_message, str) and user_message.strip() != "":
|
if isinstance(user_message, str) and user_message.strip() != "":
|
||||||
structured_messages.append({"role": "user", "content": user_message})
|
structured_messages.append({"role": "user", "content": user_message})
|
||||||
|
|
||||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
# Only add assistant message when AI reply is not empty
|
||||||
if isinstance(ai_message, str) and ai_message.strip() != "":
|
if isinstance(ai_message, str) and ai_message.strip() != "":
|
||||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||||
|
|
||||||
# 如果提供了 long_term_messages,使用它替代 structured_messages
|
# If long_term_messages provided, use it to replace structured_messages
|
||||||
if long_term_messages and isinstance(long_term_messages, list):
|
if long_term_messages and isinstance(long_term_messages, list):
|
||||||
structured_messages = long_term_messages
|
structured_messages = long_term_messages
|
||||||
elif long_term_messages and isinstance(long_term_messages, str):
|
elif long_term_messages and isinstance(long_term_messages, str):
|
||||||
# 如果是 JSON 字符串,先解析
|
# If it's a JSON string, parse it first
|
||||||
try:
|
try:
|
||||||
structured_messages = json.loads(long_term_messages)
|
structured_messages = json.loads(long_term_messages)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
|
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
|
||||||
|
|
||||||
# 如果没有消息,直接返回
|
# If no messages, return directly
|
||||||
if not structured_messages:
|
if not structured_messages:
|
||||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||||
write_id = write_message_task.delay(
|
# write_id = write_message_task.delay(
|
||||||
actual_end_user_id, # end_user_id: 用户ID
|
# actual_end_user_id, # end_user_id: User ID
|
||||||
structured_messages, # message: JSON 字符串格式的消息列表
|
# structured_messages, # message: JSON string format message list
|
||||||
str(actual_config_id), # config_id: 配置ID字符串
|
# str(actual_config_id), # config_id: Configuration ID string
|
||||||
storage_type, # storage_type: "neo4j"
|
# storage_type, # storage_type: "neo4j"
|
||||||
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
# user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
||||||
|
# )
|
||||||
|
scheduler.push_task(
|
||||||
|
"app.core.memory.agent.write_message",
|
||||||
|
str(actual_end_user_id),
|
||||||
|
{
|
||||||
|
"end_user_id": str(actual_end_user_id),
|
||||||
|
"message": structured_messages,
|
||||||
|
"config_id": str(actual_config_id),
|
||||||
|
"storage_type": storage_type,
|
||||||
|
"user_rag_memory_id": user_rag_memory_id or ""
|
||||||
|
}
|
||||||
)
|
)
|
||||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
|
||||||
write_status = get_task_memory_write_result(str(write_id))
|
|
||||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
|
# logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||||
|
# write_status = get_task_memory_write_result(str(write_id))
|
||||||
|
# logger.info(f'[WRITE] Task result - user={actual_end_user_id}')
|
||||||
|
|
||||||
|
|
||||||
|
async def term_memory_save(end_user_id, strategy_type, scope):
|
||||||
|
"""
|
||||||
|
Save long-term memory data to database
|
||||||
|
|
||||||
|
Handles the storage of long-term memory data based on different strategies
|
||||||
|
(chunk-based or aggregate-based) and manages the transition from short-term
|
||||||
|
to long-term memory storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: User identifier for memory association
|
||||||
|
strategy_type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
|
||||||
|
scope: Scope/window size for memory processing
|
||||||
|
"""
|
||||||
with get_db_context() as db_session:
|
with get_db_context() as db_session:
|
||||||
repo = LongTermMemoryRepository(db_session)
|
repo = LongTermMemoryRepository(db_session)
|
||||||
|
|
||||||
|
|
||||||
from app.core.memory.agent.utils.redis_tool import write_store
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
result = write_store.get_session_by_userid(end_user_id)
|
result = write_store.get_session_by_userid(end_user_id)
|
||||||
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
if not result:
|
||||||
|
logger.warning(f"No write data found for user {end_user_id}")
|
||||||
|
return
|
||||||
|
if strategy_type in [AgentMemory_Long_Term.STRATEGY_CHUNK, AgentMemory_Long_Term.STRATEGY_AGGREGATE]:
|
||||||
data = await format_parsing(result, "dict")
|
data = await format_parsing(result, "dict")
|
||||||
chunk_data = data[:scope]
|
chunk_data = data[:scope]
|
||||||
if len(chunk_data)==scope:
|
if len(chunk_data) == scope:
|
||||||
repo.upsert(end_user_id, chunk_data)
|
repo.upsert(end_user_id, chunk_data)
|
||||||
logger.info(f'---------写入短长期-----------')
|
logger.info('---------写入短长期-----------')
|
||||||
else:
|
else:
|
||||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
||||||
long_messages = await messages_parse(long_time_data)
|
long_messages = await messages_parse(long_time_data)
|
||||||
repo.upsert(end_user_id, long_messages)
|
repo.upsert(end_user_id, long_messages)
|
||||||
logger.info(f'写入短长期:')
|
logger.info('写入短长期:')
|
||||||
|
|
||||||
|
|
||||||
|
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
|
||||||
|
"""
|
||||||
|
TODO 考虑作为滑动窗口写入的函数
|
||||||
|
Process dialogue based on window size and write to Neo4j
|
||||||
|
|
||||||
'''根据窗口'''
|
Manages conversation data based on a sliding window approach. When the window
|
||||||
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
reaches the specified scope size, it triggers long-term memory storage to Neo4j.
|
||||||
'''
|
|
||||||
根据窗口获取redis数据,写入neo4j:
|
Args:
|
||||||
Args:
|
end_user_id: Terminal user identifier
|
||||||
end_user_id: 终端用户ID
|
memory_config: Memory configuration object containing settings
|
||||||
memory_config: 内存配置对象
|
langchain_messages: Original message data list
|
||||||
langchain_messages:原始数据LIST
|
scope: Window size determining when to trigger long-term storage
|
||||||
scope:窗口大小
|
"""
|
||||||
'''
|
is_end_user_has_history = count_store.get_sessions_count(end_user_id)
|
||||||
scope=scope
|
if is_end_user_has_history:
|
||||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
end_user_visit_count, redis_messages = is_end_user_has_history
|
||||||
if is_end_user_id is not False:
|
else:
|
||||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||||
redis_messages = count_store.get_sessions_count(end_user_id)[1]
|
return
|
||||||
if is_end_user_id and int(is_end_user_id) != int(scope):
|
end_user_visit_count += 1
|
||||||
is_end_user_id += 1
|
if end_user_visit_count < scope:
|
||||||
langchain_messages += redis_messages
|
redis_messages.extend(langchain_messages)
|
||||||
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
count_store.update_sessions_count(end_user_id, end_user_visit_count, redis_messages)
|
||||||
elif int(is_end_user_id) == int(scope):
|
else:
|
||||||
logger.info('写入长期记忆NEO4J')
|
logger.info('写入长期记忆NEO4J')
|
||||||
formatted_messages = (redis_messages)
|
redis_messages.extend(langchain_messages)
|
||||||
# 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用)
|
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
|
||||||
if hasattr(memory_config, 'config_id'):
|
if hasattr(memory_config, 'config_id'):
|
||||||
config_id = memory_config.config_id
|
config_id = memory_config.config_id
|
||||||
else:
|
else:
|
||||||
config_id = memory_config
|
config_id = memory_config
|
||||||
|
|
||||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
scheduler.push_task(
|
||||||
config_id, formatted_messages)
|
"app.core.memory.agent.write_message",
|
||||||
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
str(end_user_id),
|
||||||
else:
|
{
|
||||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
"end_user_id": str(end_user_id),
|
||||||
|
"message": redis_messages,
|
||||||
|
"config_id": str(config_id),
|
||||||
|
"storage_type": AgentMemory_Long_Term.STORAGE_NEO4J,
|
||||||
|
"user_rag_memory_id": ""
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# write_message_task.delay(
|
||||||
|
# end_user_id, # end_user_id: User ID
|
||||||
|
# redis_messages, # message: JSON string format message list
|
||||||
|
# config_id, # config_id: Configuration ID string
|
||||||
|
# AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
|
||||||
|
# "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
||||||
|
# )
|
||||||
|
count_store.update_sessions_count(end_user_id, 0, [])
|
||||||
|
|
||||||
|
|
||||||
"""根据时间"""
|
async def memory_long_term_storage(end_user_id, memory_config, time):
|
||||||
async def memory_long_term_storage(end_user_id,memory_config,time):
|
|
||||||
'''
|
|
||||||
根据时间获取redis数据,写入neo4j:
|
|
||||||
Args:
|
|
||||||
end_user_id: 终端用户ID
|
|
||||||
memory_config: 内存配置对象
|
|
||||||
'''
|
|
||||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
|
||||||
format_messages = (long_time_data)
|
|
||||||
messages=[]
|
|
||||||
memory_config=memory_config.config_id
|
|
||||||
for i in format_messages:
|
|
||||||
message=json.loads(i['Query'])
|
|
||||||
messages+= message
|
|
||||||
if format_messages!=[]:
|
|
||||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
|
||||||
memory_config, messages)
|
|
||||||
'''聚合判断'''
|
|
||||||
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
|
||||||
"""
|
"""
|
||||||
聚合判断函数:判断输入句子和历史消息是否描述同一事件
|
Process memory storage based on time intervals and write to Neo4j
|
||||||
|
|
||||||
|
Retrieves Redis data based on time intervals and writes it to Neo4j for
|
||||||
|
long-term storage. This function handles time-based memory consolidation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id: 终端用户ID
|
end_user_id: Terminal user identifier
|
||||||
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
memory_config: Memory configuration object containing settings
|
||||||
memory_config: 内存配置对象
|
time: Time interval for data retrieval
|
||||||
"""
|
"""
|
||||||
|
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
||||||
|
format_messages = long_time_data
|
||||||
|
messages = []
|
||||||
|
memory_config = memory_config.config_id
|
||||||
|
for i in format_messages:
|
||||||
|
message = json.loads(i['Query'])
|
||||||
|
messages += message
|
||||||
|
if format_messages:
|
||||||
|
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||||
|
memory_config, messages)
|
||||||
|
|
||||||
|
|
||||||
|
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
||||||
|
"""
|
||||||
|
Aggregation judgment function: determine if input sentence and historical messages describe the same event
|
||||||
|
|
||||||
|
Uses LLM-based analysis to determine whether new messages should be aggregated with existing
|
||||||
|
historical data or stored as separate events. This helps optimize memory storage and retrieval.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: Terminal user identifier
|
||||||
|
ori_messages: Original message list, format like [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||||
|
memory_config: Memory configuration object containing LLM settings
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Aggregation judgment result containing is_same_event flag and processed output
|
||||||
|
"""
|
||||||
|
history = None
|
||||||
try:
|
try:
|
||||||
# 1. 获取历史会话数据(使用新方法)
|
# 1. Get historical session data (using new method)
|
||||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
||||||
history = await format_parsing(result)
|
history = await format_parsing(result)
|
||||||
if not result:
|
if not result:
|
||||||
@@ -225,9 +285,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
|||||||
return result_dict
|
return result_dict
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[aggregate_judgment] 发生错误: {e}")
|
logger.error(f"[aggregate_judgment] 发生错误: {e}", exc_info=True)
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"is_same_event": False,
|
"is_same_event": False,
|
||||||
|
|||||||
@@ -2,41 +2,53 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
|
||||||
from langchain.tools import tool
|
from langchain.tools import tool
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
from app.core.memory.src.search import (
|
from app.core.memory.src.search import (
|
||||||
search_by_temporal,
|
search_by_temporal,
|
||||||
search_by_keyword_temporal,
|
search_by_keyword_temporal,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def extract_tool_message_content(response):
|
def extract_tool_message_content(response):
|
||||||
"""从agent响应中提取ToolMessage内容和工具名称"""
|
"""
|
||||||
|
Extract ToolMessage content and tool names from agent response
|
||||||
|
|
||||||
|
Parses agent response messages to extract tool execution results and metadata.
|
||||||
|
Handles JSON parsing and provides structured access to tool output data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: Agent response dictionary containing messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Dictionary containing tool_name and parsed content, or None if no tool message found
|
||||||
|
- tool_name: Name of the executed tool
|
||||||
|
- content: Parsed tool execution result (JSON or raw text)
|
||||||
|
"""
|
||||||
messages = response.get('messages', [])
|
messages = response.get('messages', [])
|
||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if hasattr(message, 'tool_call_id') and hasattr(message, 'content'):
|
if hasattr(message, 'tool_call_id') and hasattr(message, 'content'):
|
||||||
# 这是一个ToolMessage
|
# This is a ToolMessage
|
||||||
tool_content = message.content
|
tool_content = message.content
|
||||||
tool_name = None
|
tool_name = None
|
||||||
|
|
||||||
# 尝试获取工具名称
|
# Try to get tool name
|
||||||
if hasattr(message, 'name'):
|
if hasattr(message, 'name'):
|
||||||
tool_name = message.name
|
tool_name = message.name
|
||||||
elif hasattr(message, 'tool_name'):
|
elif hasattr(message, 'tool_name'):
|
||||||
tool_name = message.tool_name
|
tool_name = message.tool_name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 解析JSON内容
|
# Parse JSON content
|
||||||
parsed_content = json.loads(tool_content)
|
parsed_content = json.loads(tool_content)
|
||||||
return {
|
return {
|
||||||
'tool_name': tool_name,
|
'tool_name': tool_name,
|
||||||
'content': parsed_content
|
'content': parsed_content
|
||||||
}
|
}
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# 如果不是JSON格式,直接返回内容
|
# If not JSON format, return content directly
|
||||||
return {
|
return {
|
||||||
'tool_name': tool_name,
|
'tool_name': tool_name,
|
||||||
'content': tool_content
|
'content': tool_content
|
||||||
@@ -46,26 +58,49 @@ def extract_tool_message_content(response):
|
|||||||
|
|
||||||
|
|
||||||
class TimeRetrievalInput(BaseModel):
|
class TimeRetrievalInput(BaseModel):
|
||||||
"""时间检索工具的输入模式"""
|
"""
|
||||||
|
Input schema for time retrieval tool
|
||||||
|
|
||||||
|
Defines the expected input parameters for time-based retrieval operations.
|
||||||
|
Used for validation and documentation of tool parameters.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
context: User input query content for search
|
||||||
|
end_user_id: Group ID for filtering search results, defaults to test user
|
||||||
|
"""
|
||||||
context: str = Field(description="用户输入的查询内容")
|
context: str = Field(description="用户输入的查询内容")
|
||||||
end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
||||||
|
|
||||||
|
|
||||||
def create_time_retrieval_tool(end_user_id: str):
|
def create_time_retrieval_tool(end_user_id: str):
|
||||||
"""
|
"""
|
||||||
创建一个带有特定end_user_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
Create a TimeRetrieval tool with specific end_user_id (synchronous version) for searching statements by time range
|
||||||
|
|
||||||
|
Creates a specialized time-based retrieval tool that searches for statements within
|
||||||
|
specified time ranges. Includes field cleaning functionality to remove unnecessary
|
||||||
|
metadata from search results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: User identifier for scoping search results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
function: Configured TimeRetrievalWithGroupId tool function
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def clean_temporal_result_fields(data):
|
def clean_temporal_result_fields(data):
|
||||||
"""
|
"""
|
||||||
清理时间搜索结果中不需要的字段,并修改结构
|
Clean unnecessary fields from temporal search results and modify structure
|
||||||
|
|
||||||
|
Removes metadata fields that are not needed for end-user consumption and
|
||||||
|
restructures the response format for better usability.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: 要清理的数据
|
data: Data to be cleaned (dict, list, or other types)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
清理后的数据
|
Cleaned data with unnecessary fields removed
|
||||||
"""
|
"""
|
||||||
# 需要过滤的字段列表
|
# List of fields to filter out
|
||||||
fields_to_remove = {
|
fields_to_remove = {
|
||||||
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
|
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
|
||||||
'valid_at', 'invalid_at', 'statement_ids'
|
'valid_at', 'invalid_at', 'statement_ids'
|
||||||
@@ -75,9 +110,9 @@ def create_time_retrieval_tool(end_user_id: str):
|
|||||||
cleaned = {}
|
cleaned = {}
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
if key == 'statements' and isinstance(value, dict) and 'statements' in value:
|
if key == 'statements' and isinstance(value, dict) and 'statements' in value:
|
||||||
# 将 statements: {"statements": [...]} 改为 time_search: {"statements": [...]}
|
# Change statements: {"statements": [...]} to time_search: {"statements": [...]}
|
||||||
cleaned_value = clean_temporal_result_fields(value)
|
cleaned_value = clean_temporal_result_fields(value)
|
||||||
# 进一步将内部的 statements 改为 time_search
|
# Further change internal statements to time_search
|
||||||
if 'statements' in cleaned_value:
|
if 'statements' in cleaned_value:
|
||||||
cleaned['results'] = {
|
cleaned['results'] = {
|
||||||
'time_search': cleaned_value['statements']
|
'time_search': cleaned_value['statements']
|
||||||
@@ -93,24 +128,33 @@ def create_time_retrieval_tool(end_user_id: str):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
|
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None,
|
||||||
|
end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||||
"""
|
"""
|
||||||
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
Optimized time retrieval tool, combines time range search only (synchronous version), automatically filters unnecessary metadata fields
|
||||||
显式接收参数:
|
|
||||||
- context: 查询上下文内容
|
Performs time-based search operations with automatic metadata filtering. Supports
|
||||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
flexible date range specification and provides clean, user-friendly output.
|
||||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
|
||||||
- end_user_id_param: 组ID(可选,用于覆盖默认组ID)
|
Explicit parameters:
|
||||||
- clean_output: 是否清理输出中的元数据字段
|
- context: Query context content
|
||||||
-end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
- start_date: Start time (optional, format: YYYY-MM-DD)
|
||||||
|
- end_date: End time (optional, format: YYYY-MM-DD)
|
||||||
|
- end_user_id_param: Group ID (optional, overrides default group ID)
|
||||||
|
- clean_output: Whether to clean metadata fields from output
|
||||||
|
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: JSON formatted search results with temporal data
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _async_search():
|
async def _async_search():
|
||||||
# 使用传入的参数或默认值
|
# Use passed parameters or default values
|
||||||
actual_end_user_id = end_user_id_param or end_user_id
|
actual_end_user_id = end_user_id_param or end_user_id
|
||||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||||
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
|
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||||
|
|
||||||
# 基本时间搜索
|
# Basic time search
|
||||||
results = await search_by_temporal(
|
results = await search_by_temporal(
|
||||||
end_user_id=actual_end_user_id,
|
end_user_id=actual_end_user_id,
|
||||||
start_date=actual_start_date,
|
start_date=actual_start_date,
|
||||||
@@ -118,7 +162,7 @@ def create_time_retrieval_tool(end_user_id: str):
|
|||||||
limit=10
|
limit=10
|
||||||
)
|
)
|
||||||
|
|
||||||
# 清理结果中不需要的字段
|
# Clean unnecessary fields from results
|
||||||
if clean_output:
|
if clean_output:
|
||||||
cleaned_results = clean_temporal_result_fields(results)
|
cleaned_results = clean_temporal_result_fields(results)
|
||||||
else:
|
else:
|
||||||
@@ -129,22 +173,32 @@ def create_time_retrieval_tool(end_user_id: str):
|
|||||||
return asyncio.run(_async_search())
|
return asyncio.run(_async_search())
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str:
|
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None,
|
||||||
|
clean_output: bool = True) -> str:
|
||||||
"""
|
"""
|
||||||
优化的关键词时间检索工具,结合关键词和时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
Optimized keyword time retrieval tool, combines keyword and time range search (synchronous version), automatically filters unnecessary metadata fields
|
||||||
显式接收参数:
|
|
||||||
- context: 查询内容
|
Performs combined keyword and temporal search operations with automatic metadata
|
||||||
- days_back: 向前搜索的天数,默认7天
|
filtering. Provides more targeted search results by combining content relevance
|
||||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
with time-based filtering.
|
||||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
|
||||||
- clean_output: 是否清理输出中的元数据字段
|
Explicit parameters:
|
||||||
- end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
- context: Query content for keyword matching
|
||||||
|
- days_back: Number of days to search backwards, default 7 days
|
||||||
|
- start_date: Start time (optional, format: YYYY-MM-DD)
|
||||||
|
- end_date: End time (optional, format: YYYY-MM-DD)
|
||||||
|
- clean_output: Whether to clean metadata fields from output
|
||||||
|
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: JSON formatted search results combining keyword and temporal data
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _async_search():
|
async def _async_search():
|
||||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||||
actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
|
actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
|
||||||
|
|
||||||
# 关键词时间搜索
|
# Keyword time search
|
||||||
results = await search_by_keyword_temporal(
|
results = await search_by_keyword_temporal(
|
||||||
query_text=context,
|
query_text=context,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
@@ -153,7 +207,7 @@ def create_time_retrieval_tool(end_user_id: str):
|
|||||||
limit=15
|
limit=15
|
||||||
)
|
)
|
||||||
|
|
||||||
# 清理结果中不需要的字段
|
# Clean unnecessary fields from results
|
||||||
if clean_output:
|
if clean_output:
|
||||||
cleaned_results = clean_temporal_result_fields(results)
|
cleaned_results = clean_temporal_result_fields(results)
|
||||||
else:
|
else:
|
||||||
@@ -168,43 +222,53 @@ def create_time_retrieval_tool(end_user_id: str):
|
|||||||
|
|
||||||
def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||||
"""
|
"""
|
||||||
创建混合检索工具,使用run_hybrid_search进行混合检索,优化输出格式并过滤不需要的字段
|
Create hybrid retrieval tool using run_hybrid_search for hybrid retrieval, optimize output format and filter unnecessary fields
|
||||||
|
|
||||||
|
Creates an advanced hybrid search tool that combines multiple search strategies
|
||||||
|
(keyword, vector, hybrid) with automatic result cleaning and formatting.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_config: 内存配置对象
|
memory_config: Memory configuration object containing LLM and search settings
|
||||||
**search_params: 搜索参数,包含end_user_id, limit, include等
|
**search_params: Search parameters including end_user_id, limit, include, etc.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
function: Configured HybridSearch tool function with async capabilities
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def clean_result_fields(data):
|
def clean_result_fields(data):
|
||||||
"""
|
"""
|
||||||
递归清理结果中不需要的字段
|
Recursively clean unnecessary fields from results
|
||||||
|
|
||||||
|
Removes metadata fields that are not needed for end-user consumption,
|
||||||
|
improving readability and reducing response size.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: 要清理的数据(可能是字典、列表或其他类型)
|
data: Data to be cleaned (can be dict, list, or other types)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
清理后的数据
|
Cleaned data with unnecessary fields removed
|
||||||
"""
|
"""
|
||||||
# 需要过滤的字段列表
|
# List of fields to filter out
|
||||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development
|
||||||
fields_to_remove = {
|
fields_to_remove = {
|
||||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||||
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
'created_at', 'chunk_id', 'apply_id',
|
||||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
|
'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
|
||||||
}
|
}
|
||||||
|
# 注意:'id' 字段保留,community 展开时需要用 community id 查询成员 statements
|
||||||
|
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
# 对字典进行清理
|
# Clean dictionary
|
||||||
cleaned = {}
|
cleaned = {}
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
if key not in fields_to_remove:
|
if key not in fields_to_remove:
|
||||||
cleaned[key] = clean_result_fields(value) # 递归清理嵌套数据
|
cleaned[key] = clean_result_fields(value) # Recursively clean nested data
|
||||||
return cleaned
|
return cleaned
|
||||||
elif isinstance(data, list):
|
elif isinstance(data, list):
|
||||||
# 对列表中的每个元素进行清理
|
# Clean each element in list
|
||||||
return [clean_result_fields(item) for item in data]
|
return [clean_result_fields(item) for item in data]
|
||||||
else:
|
else:
|
||||||
# 其他类型直接返回
|
# Return other types directly
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -216,49 +280,55 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
|||||||
rerank_alpha: float = 0.6,
|
rerank_alpha: float = 0.6,
|
||||||
use_forgetting_rerank: bool = False,
|
use_forgetting_rerank: bool = False,
|
||||||
use_llm_rerank: bool = False,
|
use_llm_rerank: bool = False,
|
||||||
clean_output: bool = True # 新增:是否清理输出字段
|
clean_output: bool = True # New: whether to clean output fields
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
优化的混合检索工具,支持关键词、向量和混合搜索,自动过滤不需要的元数据字段
|
Optimized hybrid retrieval tool, supports keyword, vector and hybrid search, automatically filters unnecessary metadata fields
|
||||||
|
|
||||||
|
Provides comprehensive search capabilities combining multiple search strategies
|
||||||
|
with intelligent result ranking and automatic metadata filtering for clean output.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
context: 查询内容
|
context: Query content for search
|
||||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
search_type: Search type ('keyword', 'embedding', 'hybrid')
|
||||||
limit: 结果数量限制
|
limit: Result quantity limit
|
||||||
end_user_id: 组ID,用于过滤搜索结果
|
end_user_id: Group ID for filtering search results
|
||||||
rerank_alpha: 重排序权重参数
|
rerank_alpha: Reranking weight parameter for result scoring
|
||||||
use_forgetting_rerank: 是否使用遗忘重排序
|
use_forgetting_rerank: Whether to use forgetting-based reranking
|
||||||
use_llm_rerank: 是否使用LLM重排序
|
use_llm_rerank: Whether to use LLM-based reranking
|
||||||
clean_output: 是否清理输出中的元数据字段
|
clean_output: Whether to clean metadata fields from output
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: JSON formatted comprehensive search results
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 导入run_hybrid_search函数
|
# Import run_hybrid_search function
|
||||||
from app.core.memory.src.search import run_hybrid_search
|
from app.core.memory.src.search import run_hybrid_search
|
||||||
|
|
||||||
# 合并参数,优先使用传入的参数
|
# Merge parameters, prioritize passed parameters
|
||||||
final_params = {
|
final_params = {
|
||||||
"query_text": context,
|
"query_text": context,
|
||||||
"search_type": search_type,
|
"search_type": search_type,
|
||||||
"end_user_id": end_user_id or search_params.get("end_user_id"),
|
"end_user_id": end_user_id or search_params.get("end_user_id"),
|
||||||
"limit": limit or search_params.get("limit", 10),
|
"limit": limit or search_params.get("limit", 10),
|
||||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
|
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities", "communities"]),
|
||||||
"output_path": None, # 不保存到文件
|
"output_path": None, # Don't save to file
|
||||||
"memory_config": memory_config,
|
"memory_config": memory_config,
|
||||||
"rerank_alpha": rerank_alpha,
|
"rerank_alpha": rerank_alpha,
|
||||||
"use_forgetting_rerank": use_forgetting_rerank,
|
"use_forgetting_rerank": use_forgetting_rerank,
|
||||||
"use_llm_rerank": use_llm_rerank
|
"use_llm_rerank": use_llm_rerank
|
||||||
}
|
}
|
||||||
|
|
||||||
# 执行混合检索
|
# Execute hybrid retrieval
|
||||||
raw_results = await run_hybrid_search(**final_params)
|
raw_results = await run_hybrid_search(**final_params)
|
||||||
|
|
||||||
# 清理结果中不需要的字段
|
# Clean unnecessary fields from results
|
||||||
if clean_output:
|
if clean_output:
|
||||||
cleaned_results = clean_result_fields(raw_results)
|
cleaned_results = clean_result_fields(raw_results)
|
||||||
else:
|
else:
|
||||||
cleaned_results = raw_results
|
cleaned_results = raw_results
|
||||||
|
|
||||||
# 格式化返回结果
|
# Format return results
|
||||||
formatted_results = {
|
formatted_results = {
|
||||||
"search_query": context,
|
"search_query": context,
|
||||||
"search_type": search_type,
|
"search_type": search_type,
|
||||||
@@ -281,32 +351,46 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
|||||||
|
|
||||||
def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||||
"""
|
"""
|
||||||
创建同步版本的混合检索工具,优化输出格式并过滤不需要的字段
|
Create synchronous version of hybrid retrieval tool, optimize output format and filter unnecessary fields
|
||||||
|
|
||||||
|
Creates a synchronous wrapper around the async hybrid search functionality,
|
||||||
|
making it compatible with synchronous tool execution environments.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_config: 内存配置对象
|
memory_config: Memory configuration object containing search settings
|
||||||
**search_params: 搜索参数
|
**search_params: Search parameters for configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
function: Configured HybridSearchSync tool function
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def HybridSearchSync(
|
def HybridSearchSync(
|
||||||
context: str,
|
context: str,
|
||||||
search_type: str = "hybrid",
|
search_type: str = "hybrid",
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
end_user_id: str = None,
|
end_user_id: str = None,
|
||||||
clean_output: bool = True
|
clean_output: bool = True
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
优化的混合检索工具(同步版本),自动过滤不需要的元数据字段
|
Optimized hybrid retrieval tool (synchronous version), automatically filters unnecessary metadata fields
|
||||||
|
|
||||||
|
Provides the same hybrid search capabilities as the async version but in a
|
||||||
|
synchronous execution context. Automatically handles async-to-sync conversion.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
context: 查询内容
|
context: Query content for search
|
||||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
search_type: Search type ('keyword', 'embedding', 'hybrid')
|
||||||
limit: 结果数量限制
|
limit: Result quantity limit
|
||||||
end_user_id: 组ID,用于过滤搜索结果
|
end_user_id: Group ID for filtering search results
|
||||||
clean_output: 是否清理输出中的元数据字段
|
clean_output: Whether to clean metadata fields from output
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: JSON formatted search results
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _async_search():
|
async def _async_search():
|
||||||
# 创建异步工具并执行
|
# Create async tool and execute
|
||||||
async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
|
async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
|
||||||
return await async_tool.ainvoke({
|
return await async_tool.ainvoke({
|
||||||
"context": context,
|
"context": context,
|
||||||
|
|||||||
@@ -1,20 +1,28 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, AIMessage
|
from langchain_core.messages import HumanMessage, AIMessage
|
||||||
async def format_parsing(messages: list,type:str='string'):
|
|
||||||
|
|
||||||
|
async def format_parsing(messages: list, type: str = 'string'):
|
||||||
"""
|
"""
|
||||||
格式化解析消息列表
|
Format and parse message lists into different output types
|
||||||
|
|
||||||
|
Processes message lists from storage and converts them into either string format
|
||||||
|
or dictionary format based on the specified type parameter. Handles JSON parsing
|
||||||
|
and role-based message organization.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: 消息列表
|
messages: List of message objects from storage containing message data
|
||||||
type: 返回类型 ('string' 或 'dict')
|
type: Return type specification ('string' for text format, 'dict' for key-value pairs)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
格式化后的消息列表
|
list: Formatted message list in the specified format
|
||||||
|
- 'string': List of formatted text messages with role prefixes
|
||||||
|
- 'dict': List of dictionaries mapping user messages to AI responses
|
||||||
"""
|
"""
|
||||||
result = []
|
result = []
|
||||||
user=[]
|
user = []
|
||||||
ai=[]
|
ai = []
|
||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
hstory_messages = message['messages']
|
hstory_messages = message['messages']
|
||||||
@@ -24,25 +32,38 @@ async def format_parsing(messages: list,type:str='string'):
|
|||||||
role = content['role']
|
role = content['role']
|
||||||
content = content['content']
|
content = content['content']
|
||||||
if type == "string":
|
if type == "string":
|
||||||
if role == 'human' or role=="user":
|
if role == 'human' or role == "user":
|
||||||
content = '用户:' + content
|
content = '用户:' + content
|
||||||
else:
|
else:
|
||||||
content = 'AI:' + content
|
content = 'AI:' + content
|
||||||
result.append(content)
|
result.append(content)
|
||||||
if type == "dict" :
|
if type == "dict":
|
||||||
if role == 'human' or role=="user":
|
if role == 'human' or role == "user":
|
||||||
user.append( content)
|
user.append(content)
|
||||||
else:
|
else:
|
||||||
ai.append(content)
|
ai.append(content)
|
||||||
if type == "dict":
|
if type == "dict":
|
||||||
for key,values in zip(user,ai):
|
for key, values in zip(user, ai):
|
||||||
result.append({key:values})
|
result.append({key: values})
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def messages_parse(messages: list | dict):
|
async def messages_parse(messages: list | dict):
|
||||||
user=[]
|
"""
|
||||||
ai=[]
|
Parse messages from storage format into user-AI conversation pairs
|
||||||
database=[]
|
|
||||||
|
Extracts and organizes conversation data from stored message format,
|
||||||
|
separating user and AI messages and pairing them for database storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List or dictionary containing stored message data with Query fields
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: List of dictionaries containing user-AI message pairs for database storage
|
||||||
|
"""
|
||||||
|
user = []
|
||||||
|
ai = []
|
||||||
|
database = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
Query = message['Query']
|
Query = message['Query']
|
||||||
Query = json.loads(Query)
|
Query = json.loads(Query)
|
||||||
@@ -54,10 +75,23 @@ async def messages_parse(messages: list | dict):
|
|||||||
ai.append(data['content'])
|
ai.append(data['content'])
|
||||||
for key, values in zip(user, ai):
|
for key, values in zip(user, ai):
|
||||||
database.append({key, values})
|
database.append({key, values})
|
||||||
return database
|
return database
|
||||||
|
|
||||||
|
|
||||||
async def agent_chat_messages(user_content,ai_content):
|
async def agent_chat_messages(user_content, ai_content):
|
||||||
|
"""
|
||||||
|
Create structured chat message format for agent conversations
|
||||||
|
|
||||||
|
Formats user and AI content into a standardized message structure suitable
|
||||||
|
for agent processing and storage. Creates role-based message objects.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_content: User's message content string
|
||||||
|
ai_content: AI's response content string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: List of structured message dictionaries with role and content fields
|
||||||
|
"""
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
|
|||||||
@@ -1,104 +1,106 @@
|
|||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from langgraph.constants import END, START
|
|
||||||
from langgraph.graph import StateGraph
|
|
||||||
|
|
||||||
from app.db import get_db, get_db_context
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
aggregate_judgment
|
||||||
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
|
from app.db import get_db_context
|
||||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
from app.services.memory_konwledges_server import write_rag
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
if sys.platform.startswith("win"):
|
|
||||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
|
||||||
|
|
||||||
|
async def long_term_storage(
|
||||||
@asynccontextmanager
|
long_term_type: str,
|
||||||
async def make_write_graph():
|
langchain_messages: list,
|
||||||
|
memory_config_id: str,
|
||||||
|
end_user_id: str,
|
||||||
|
scope: int = 6
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Create a write graph workflow for memory operations.
|
Handle long-term memory storage with different strategies
|
||||||
|
|
||||||
|
Supports multiple storage strategies including chunk-based, time-based,
|
||||||
|
and aggregate judgment approaches for long-term memory persistence.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: User identifier
|
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
|
||||||
tools: MCP tools loaded from session
|
langchain_messages: List of messages to store
|
||||||
apply_id: Application identifier
|
memory_config_id: Memory configuration identifier
|
||||||
end_user_id: Group identifier
|
end_user_id: User group identifier
|
||||||
memory_config: MemoryConfig object containing all configuration
|
scope: Scope parameter for chunk-based storage (default: 6)
|
||||||
"""
|
"""
|
||||||
workflow = StateGraph(WriteState)
|
if langchain_messages is None:
|
||||||
workflow.add_node("save_neo4j", write_node)
|
langchain_messages = []
|
||||||
workflow.add_edge(START, "save_neo4j")
|
|
||||||
workflow.add_edge("save_neo4j", END)
|
|
||||||
|
|
||||||
graph = workflow.compile()
|
write_store.save_session_write(end_user_id, langchain_messages)
|
||||||
|
|
||||||
yield graph
|
|
||||||
|
|
||||||
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
|
||||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
|
||||||
from app.core.memory.agent.utils.redis_tool import write_store
|
|
||||||
write_store.save_session_write(end_user_id, (langchain_messages))
|
|
||||||
# 获取数据库会话
|
# 获取数据库会话
|
||||||
with get_db_context() as db_session:
|
with get_db_context() as db_session:
|
||||||
config_service = MemoryConfigService(db_session)
|
config_service = MemoryConfigService(db_session)
|
||||||
|
# 通过 end_user_id 获取 workspace_id,确保日志和 fallback 逻辑完整
|
||||||
|
from app.services.memory_agent_service import get_end_user_connected_config
|
||||||
|
import uuid as _uuid
|
||||||
|
workspace_id = None
|
||||||
|
try:
|
||||||
|
connected = get_end_user_connected_config(end_user_id, db_session)
|
||||||
|
raw = connected.get("workspace_id")
|
||||||
|
if raw and raw != "None":
|
||||||
|
workspace_id = _uuid.UUID(str(raw))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
memory_config = config_service.load_memory_config(
|
memory_config = config_service.load_memory_config(
|
||||||
config_id=memory_config, # 改为整数
|
config_id=memory_config_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
service_name="MemoryAgentService"
|
service_name="MemoryAgentService"
|
||||||
)
|
)
|
||||||
if long_term_type=='chunk':
|
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
|
||||||
'''方案一:对话窗口6轮对话'''
|
# Dialogue window with 6 rounds of conversation
|
||||||
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||||
if long_term_type=='time':
|
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
|
||||||
"""时间"""
|
# Time-based strategy
|
||||||
await memory_long_term_storage(end_user_id, memory_config,5)
|
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
|
||||||
if long_term_type=='aggregate':
|
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||||
"""方案三:聚合判断"""
|
# Aggregate judgment
|
||||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||||
|
|
||||||
|
|
||||||
|
async def write_long_term(
|
||||||
|
storage_type: str,
|
||||||
|
end_user_id: str,
|
||||||
|
messages: list[dict],
|
||||||
|
user_rag_memory_id: str,
|
||||||
|
actual_config_id: str
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Write long-term memory with different storage types
|
||||||
|
|
||||||
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
|
Handles both RAG-based storage and traditional memory storage approaches.
|
||||||
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
For traditional storage, uses chunk-based strategy with paired user-AI messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_type: Type of storage (RAG or traditional)
|
||||||
|
end_user_id: User group identifier
|
||||||
|
messages: message list
|
||||||
|
user_rag_memory_id: RAG memory identifier
|
||||||
|
actual_config_id: Actual configuration ID
|
||||||
|
"""
|
||||||
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
|
||||||
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||||
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
message_content = []
|
||||||
|
for message in messages:
|
||||||
|
message_content.append(f'{message.get("role")}:{message.get("content")}')
|
||||||
|
messages_string = "\n".join(message_content)
|
||||||
|
await write_rag(end_user_id, messages_string, user_rag_memory_id)
|
||||||
else:
|
else:
|
||||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
|
||||||
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||||
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||||
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
await long_term_storage(long_term_type=CHUNK,
|
||||||
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
|
langchain_messages=messages,
|
||||||
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
|
memory_config_id=actual_config_id,
|
||||||
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
|
end_user_id=end_user_id,
|
||||||
|
scope=SCOPE)
|
||||||
# async def main():
|
await term_memory_save(end_user_id, CHUNK, scope=SCOPE)
|
||||||
# """主函数 - 运行工作流"""
|
|
||||||
# langchain_messages = [
|
|
||||||
# {
|
|
||||||
# "role": "user",
|
|
||||||
# "content": "今天周五去爬山"
|
|
||||||
# },
|
|
||||||
# {
|
|
||||||
# "role": "assistant",
|
|
||||||
# "content": "好耶"
|
|
||||||
# }
|
|
||||||
#
|
|
||||||
# ]
|
|
||||||
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
|
||||||
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
|
||||||
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
|
||||||
#
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# if __name__ == "__main__":
|
|
||||||
# import asyncio
|
|
||||||
# asyncio.run(main())
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ class ParameterBuilder:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize the parameter builder."""
|
"""Initialize the parameter builder."""
|
||||||
logger.info("ParameterBuilder initialized")
|
logger.debug("ParameterBuilder initialized")
|
||||||
|
|
||||||
def build_tool_args(
|
def build_tool_args(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -7,21 +7,88 @@ and deduplication.
|
|||||||
from typing import List, Tuple, Optional
|
from typing import List, Tuple, Optional
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
|
from app.core.memory.enums import Neo4jNodeType
|
||||||
from app.core.memory.src.search import run_hybrid_search
|
from app.core.memory.src.search import run_hybrid_search
|
||||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||||
|
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
# 需要从展开结果中过滤的字段(含 Neo4j DateTime,不可 JSON 序列化)
|
||||||
|
_EXPAND_FIELDS_TO_REMOVE = {
|
||||||
|
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||||
|
'created_at', 'chunk_id', 'apply_id',
|
||||||
|
'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_expand_fields(obj):
|
||||||
|
"""递归过滤展开结果中不可序列化的字段(DateTime 等)。"""
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return {k: _clean_expand_fields(v) for k, v in obj.items() if k not in _EXPAND_FIELDS_TO_REMOVE}
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return [_clean_expand_fields(i) for i in obj]
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
async def expand_communities_to_statements(
|
||||||
|
community_results: List[dict],
|
||||||
|
end_user_id: str,
|
||||||
|
existing_content: str = "",
|
||||||
|
limit: int = 10,
|
||||||
|
) -> Tuple[List[dict], List[str]]:
|
||||||
|
"""
|
||||||
|
社区展开 helper:给定命中的 community 列表,拉取关联 Statement。
|
||||||
|
|
||||||
|
- 对展开结果去重(过滤已在 existing_content 中出现的文本)
|
||||||
|
- 过滤不可序列化字段
|
||||||
|
- 返回 (cleaned_expanded_stmts, new_texts)
|
||||||
|
- cleaned_expanded_stmts: 可直接写回 raw_results 的列表
|
||||||
|
- new_texts: 去重后新增的 statement 文本列表,用于追加到 clean_content
|
||||||
|
"""
|
||||||
|
community_ids = [r.get("id") for r in community_results if r.get("id")]
|
||||||
|
if not community_ids or not end_user_id:
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
from app.repositories.neo4j.graph_search import search_graph_community_expand
|
||||||
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
|
connector = Neo4jConnector()
|
||||||
|
try:
|
||||||
|
result = await search_graph_community_expand(
|
||||||
|
connector=connector,
|
||||||
|
community_ids=community_ids,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[expand_communities] 社区展开检索失败,跳过: {e}")
|
||||||
|
return [], []
|
||||||
|
finally:
|
||||||
|
await connector.close()
|
||||||
|
|
||||||
|
expanded_stmts = result.get("expanded_statements", [])
|
||||||
|
if not expanded_stmts:
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
existing_lines = set(existing_content.splitlines())
|
||||||
|
new_texts = [
|
||||||
|
s["statement"] for s in expanded_stmts
|
||||||
|
if s.get("statement") and s["statement"] not in existing_lines
|
||||||
|
]
|
||||||
|
cleaned = _clean_expand_fields(expanded_stmts)
|
||||||
|
logger.info(
|
||||||
|
f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements,新增 {len(new_texts)} 条,community_ids={community_ids}")
|
||||||
|
return cleaned, new_texts
|
||||||
|
|
||||||
|
|
||||||
class SearchService:
|
class SearchService:
|
||||||
"""Service for executing hybrid search and processing results."""
|
"""Service for executing hybrid search and processing results."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize the search service."""
|
"""Initialize the search service."""
|
||||||
logger.info("SearchService initialized")
|
logger.debug("SearchService initialized")
|
||||||
|
|
||||||
def extract_content_from_result(self, result: dict) -> str:
|
def extract_content_from_result(self, result: dict, node_type: str = "") -> str:
|
||||||
"""
|
"""
|
||||||
Extract only meaningful content from search results, dropping all metadata.
|
Extract only meaningful content from search results, dropping all metadata.
|
||||||
|
|
||||||
@@ -30,9 +97,11 @@ class SearchService:
|
|||||||
- Entities: extract 'name' and 'fact_summary' fields
|
- Entities: extract 'name' and 'fact_summary' fields
|
||||||
- Summaries: extract 'content' field
|
- Summaries: extract 'content' field
|
||||||
- Chunks: extract 'content' field
|
- Chunks: extract 'content' field
|
||||||
|
- Communities: extract 'content' field (c.summary), prefixed with community name
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
result: Search result dictionary
|
result: Search result dictionary
|
||||||
|
node_type: Hint for node type ("community", "summary", etc.)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Clean content string without metadata
|
Clean content string without metadata
|
||||||
@@ -43,11 +112,24 @@ class SearchService:
|
|||||||
content_parts = []
|
content_parts = []
|
||||||
|
|
||||||
# Statements: extract statement field
|
# Statements: extract statement field
|
||||||
if 'statement' in result and result['statement']:
|
if Neo4jNodeType.STATEMENT in result and result[Neo4jNodeType.STATEMENT]:
|
||||||
content_parts.append(result['statement'])
|
content_parts.append(result[Neo4jNodeType.STATEMENT])
|
||||||
|
|
||||||
# Summaries/Chunks: extract content field
|
# Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定
|
||||||
if 'content' in result and result['content']:
|
# 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要
|
||||||
|
is_community = (
|
||||||
|
node_type == Neo4jNodeType.COMMUNITY
|
||||||
|
or 'member_count' in result
|
||||||
|
or 'core_entities' in result
|
||||||
|
)
|
||||||
|
if is_community:
|
||||||
|
name = result.get('name', '')
|
||||||
|
content = result.get('content', '')
|
||||||
|
if content:
|
||||||
|
prefix = f"[主题:{name}] " if name else ""
|
||||||
|
content_parts.append(f"{prefix}{content}")
|
||||||
|
elif 'content' in result and result['content']:
|
||||||
|
# Summaries / Chunks
|
||||||
content_parts.append(result['content'])
|
content_parts.append(result['content'])
|
||||||
|
|
||||||
# Entities: extract name and fact_summary (commented out in original)
|
# Entities: extract name and fact_summary (commented out in original)
|
||||||
@@ -77,7 +159,7 @@ class SearchService:
|
|||||||
|
|
||||||
# Remove wrapping quotes
|
# Remove wrapping quotes
|
||||||
if (q.startswith("'") and q.endswith("'")) or (
|
if (q.startswith("'") and q.endswith("'")) or (
|
||||||
q.startswith('"') and q.endswith('"')
|
q.startswith('"') and q.endswith('"')
|
||||||
):
|
):
|
||||||
q = q[1:-1]
|
q = q[1:-1]
|
||||||
|
|
||||||
@@ -90,16 +172,17 @@ class SearchService:
|
|||||||
return q
|
return q
|
||||||
|
|
||||||
async def execute_hybrid_search(
|
async def execute_hybrid_search(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
question: str,
|
question: str,
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
search_type: str = "hybrid",
|
search_type: str = "hybrid",
|
||||||
include: Optional[List[str]] = None,
|
include: Optional[List[str]] = None,
|
||||||
rerank_alpha: float = 0.4,
|
rerank_alpha: float = 0.4,
|
||||||
output_path: str = "search_results.json",
|
output_path: str = "search_results.json",
|
||||||
return_raw_results: bool = False,
|
return_raw_results: bool = False,
|
||||||
memory_config = None
|
memory_config=None,
|
||||||
|
expand_communities: bool = True,
|
||||||
) -> Tuple[str, str, Optional[dict]]:
|
) -> Tuple[str, str, Optional[dict]]:
|
||||||
"""
|
"""
|
||||||
Execute hybrid search and return clean content.
|
Execute hybrid search and return clean content.
|
||||||
@@ -114,13 +197,15 @@ class SearchService:
|
|||||||
output_path: Path to save search results (default: "search_results.json")
|
output_path: Path to save search results (default: "search_results.json")
|
||||||
return_raw_results: If True, also return the raw search results as third element (default: False)
|
return_raw_results: If True, also return the raw search results as third element (default: False)
|
||||||
memory_config: Memory configuration object (required)
|
memory_config: Memory configuration object (required)
|
||||||
|
expand_communities: If True, expand community hits to member statements (default: True).
|
||||||
|
Set to False for quick-summary paths that only need community-level text.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (clean_content, cleaned_query, raw_results)
|
Tuple of (clean_content, cleaned_query, raw_results)
|
||||||
raw_results is None if return_raw_results=False
|
raw_results is None if return_raw_results=False
|
||||||
"""
|
"""
|
||||||
if include is None:
|
if include is None:
|
||||||
include = ["statements", "chunks", "entities", "summaries"]
|
include = [Neo4jNodeType.STATEMENT, Neo4jNodeType.CHUNK, Neo4jNodeType.EXTRACTEDENTITY, Neo4jNodeType.MEMORYSUMMARY, Neo4jNodeType.COMMUNITY]
|
||||||
|
|
||||||
# Clean query
|
# Clean query
|
||||||
cleaned_query = self.clean_query(question)
|
cleaned_query = self.clean_query(question)
|
||||||
@@ -146,8 +231,8 @@ class SearchService:
|
|||||||
if search_type == "hybrid":
|
if search_type == "hybrid":
|
||||||
reranked_results = answer.get('reranked_results', {})
|
reranked_results = answer.get('reranked_results', {})
|
||||||
|
|
||||||
# Priority order: summaries first (most contextual), then statements, chunks, entities
|
# Priority order: summaries first (most contextual), then communities, statements, chunks, entities
|
||||||
priority_order = ['summaries', 'statements', 'chunks', 'entities']
|
priority_order = [Neo4jNodeType.STATEMENT, Neo4jNodeType.CHUNK, Neo4jNodeType.EXTRACTEDENTITY, Neo4jNodeType.MEMORYSUMMARY, Neo4jNodeType.COMMUNITY]
|
||||||
|
|
||||||
for category in priority_order:
|
for category in priority_order:
|
||||||
if category in include and category in reranked_results:
|
if category in include and category in reranked_results:
|
||||||
@@ -157,7 +242,7 @@ class SearchService:
|
|||||||
else:
|
else:
|
||||||
# For keyword or embedding search, results are directly in answer dict
|
# For keyword or embedding search, results are directly in answer dict
|
||||||
# Apply same priority order
|
# Apply same priority order
|
||||||
priority_order = ['summaries', 'statements', 'chunks', 'entities']
|
priority_order = [Neo4jNodeType.STATEMENT, Neo4jNodeType.CHUNK, Neo4jNodeType.EXTRACTEDENTITY, Neo4jNodeType.MEMORYSUMMARY, Neo4jNodeType.COMMUNITY]
|
||||||
|
|
||||||
for category in priority_order:
|
for category in priority_order:
|
||||||
if category in include and category in answer:
|
if category in include and category in answer:
|
||||||
@@ -165,12 +250,25 @@ class SearchService:
|
|||||||
if isinstance(category_results, list):
|
if isinstance(category_results, list):
|
||||||
answer_list.extend(category_results)
|
answer_list.extend(category_results)
|
||||||
|
|
||||||
# Extract clean content from all results
|
# 对命中的 community 节点展开其成员 statements(路径 "0"/"1" 需要,路径 "2" 不需要)
|
||||||
content_list = [
|
if expand_communities and Neo4jNodeType.COMMUNITY in include:
|
||||||
self.extract_content_from_result(ans)
|
community_results = (
|
||||||
for ans in answer_list
|
answer.get('reranked_results', {}).get(Neo4jNodeType.COMMUNITY.value, [])
|
||||||
]
|
if search_type == "hybrid"
|
||||||
|
else answer.get(Neo4jNodeType.COMMUNITY.value, [])
|
||||||
|
)
|
||||||
|
cleaned_stmts, new_texts = await expand_communities_to_statements(
|
||||||
|
community_results=community_results,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
)
|
||||||
|
answer_list.extend(cleaned_stmts)
|
||||||
|
|
||||||
|
# Extract clean content from all results,按类型传入 node_type 区分 community
|
||||||
|
content_list = []
|
||||||
|
for ans in answer_list:
|
||||||
|
# community 节点有 member_count 或 core_entities 字段
|
||||||
|
ntype = Neo4jNodeType.COMMUNITY if ('member_count' in ans or 'core_entities' in ans) else ""
|
||||||
|
content_list.append(self.extract_content_from_result(ans, node_type=ntype))
|
||||||
|
|
||||||
# Filter out empty strings and join with newlines
|
# Filter out empty strings and join with newlines
|
||||||
clean_content = '\n'.join([c for c in content_list if c])
|
clean_content = '\n'.join([c for c in content_list if c])
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class SessionService:
|
|||||||
store: Redis session store instance
|
store: Redis session store instance
|
||||||
"""
|
"""
|
||||||
self.store = store
|
self.store = store
|
||||||
logger.info("SessionService initialized")
|
logger.debug("SessionService initialized")
|
||||||
|
|
||||||
def resolve_user_id(self, session_string: str) -> str:
|
def resolve_user_id(self, session_string: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ class TemplateService:
|
|||||||
loader=FileSystemLoader(template_root),
|
loader=FileSystemLoader(template_root),
|
||||||
autoescape=False # Disable autoescape for prompt templates
|
autoescape=False # Disable autoescape for prompt templates
|
||||||
)
|
)
|
||||||
logger.info(f"TemplateService initialized with root: {template_root}")
|
logger.debug(f"TemplateService initialized with root: {template_root}")
|
||||||
|
|
||||||
@lru_cache(maxsize=128)
|
@lru_cache(maxsize=128)
|
||||||
def _load_template(self, template_name: str) -> Template:
|
def _load_template(self, template_name: str) -> Template:
|
||||||
|
|||||||
@@ -1,7 +1,4 @@
|
|||||||
import os
|
|
||||||
import json
|
|
||||||
from typing import List
|
from typing import List
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
|
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
|
||||||
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
|
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
|
||||||
@@ -11,17 +8,20 @@ async def get_chunked_dialogs(
|
|||||||
chunker_strategy: str = "RecursiveChunker",
|
chunker_strategy: str = "RecursiveChunker",
|
||||||
end_user_id: str = "group_1",
|
end_user_id: str = "group_1",
|
||||||
messages: list = None,
|
messages: list = None,
|
||||||
ref_id: str = "wyl_20251027",
|
ref_id: str = "",
|
||||||
config_id: str = None
|
config_id: str = None,
|
||||||
|
workspace_id=None,
|
||||||
|
snapshot=None,
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""Generate chunks from structured messages using the specified chunker strategy.
|
"""Generate chunks from structured messages using the specified chunker strategy.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
||||||
end_user_id: Group identifier
|
end_user_id: Group identifier
|
||||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
messages: Structured message list [{"role": "user", "content": "...", "dialog_at": "..."}]
|
||||||
ref_id: Reference identifier
|
ref_id: Reference identifier
|
||||||
config_id: Configuration ID for processing (used to load pruning config)
|
config_id: Configuration ID for processing (used to load pruning config)
|
||||||
|
snapshot: Optional PipelineSnapshot instance for saving pruning output
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of DialogData objects with generated chunks
|
List of DialogData objects with generated chunks
|
||||||
@@ -34,18 +34,25 @@ async def get_chunked_dialogs(
|
|||||||
|
|
||||||
conversation_messages = []
|
conversation_messages = []
|
||||||
|
|
||||||
|
# step1: 消息格式校验 role:user、assistant。content
|
||||||
for idx, msg in enumerate(messages):
|
for idx, msg in enumerate(messages):
|
||||||
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
|
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
|
||||||
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
|
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
|
||||||
|
|
||||||
role = msg['role']
|
role = msg['role']
|
||||||
content = msg['content']
|
content = msg['content']
|
||||||
|
files = msg.get("file_content", [])
|
||||||
|
|
||||||
if role not in ['user', 'assistant']:
|
if role not in ['user', 'assistant']:
|
||||||
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
||||||
|
|
||||||
if content.strip():
|
if content.strip():
|
||||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
|
conversation_messages.append(ConversationMessage(
|
||||||
|
role=role,
|
||||||
|
msg=content.strip(),
|
||||||
|
dialog_at=msg.get("dialog_at"),
|
||||||
|
files=files,
|
||||||
|
))
|
||||||
|
|
||||||
if not conversation_messages:
|
if not conversation_messages:
|
||||||
raise ValueError("Message list cannot be empty after filtering")
|
raise ValueError("Message list cannot be empty after filtering")
|
||||||
@@ -55,10 +62,10 @@ async def get_chunked_dialogs(
|
|||||||
context=conversation_context,
|
context=conversation_context,
|
||||||
ref_id=ref_id,
|
ref_id=ref_id,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
config_id=config_id
|
config_id=config_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 语义剪枝步骤(在分块之前)
|
# step2: 语义剪枝步骤(在分块之前)
|
||||||
try:
|
try:
|
||||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner
|
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner
|
||||||
from app.core.memory.models.config_models import PruningConfig
|
from app.core.memory.models.config_models import PruningConfig
|
||||||
@@ -75,6 +82,7 @@ async def get_chunked_dialogs(
|
|||||||
config_service = MemoryConfigService(db)
|
config_service = MemoryConfigService(db)
|
||||||
memory_config = config_service.load_memory_config(
|
memory_config = config_service.load_memory_config(
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
service_name="semantic_pruning"
|
service_name="semantic_pruning"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -84,7 +92,7 @@ async def get_chunked_dialogs(
|
|||||||
pruning_scene=memory_config.pruning_scene or "education",
|
pruning_scene=memory_config.pruning_scene or "education",
|
||||||
pruning_threshold=memory_config.pruning_threshold,
|
pruning_threshold=memory_config.pruning_threshold,
|
||||||
scene_id=str(memory_config.scene_id) if memory_config.scene_id else None,
|
scene_id=str(memory_config.scene_id) if memory_config.scene_id else None,
|
||||||
ontology_classes=memory_config.ontology_classes,
|
ontology_class_infos=memory_config.ontology_class_infos,
|
||||||
)
|
)
|
||||||
logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}")
|
logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}")
|
||||||
|
|
||||||
@@ -94,7 +102,7 @@ async def get_chunked_dialogs(
|
|||||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||||
|
|
||||||
# 执行剪枝 - 使用 prune_dataset 支持消息级剪枝
|
# 执行剪枝 - 使用 prune_dataset 支持消息级剪枝
|
||||||
pruner = SemanticPruner(config=pruning_config, llm_client=llm_client)
|
pruner = SemanticPruner(config=pruning_config, llm_client=llm_client, snapshot=snapshot)
|
||||||
original_msg_count = len(dialog_data.context.msgs)
|
original_msg_count = len(dialog_data.context.msgs)
|
||||||
|
|
||||||
# 使用 prune_dataset 而不是 prune_dialog
|
# 使用 prune_dataset 而不是 prune_dialog
|
||||||
@@ -106,6 +114,13 @@ async def get_chunked_dialogs(
|
|||||||
remaining_msg_count = len(dialog_data.context.msgs)
|
remaining_msg_count = len(dialog_data.context.msgs)
|
||||||
deleted_count = original_msg_count - remaining_msg_count
|
deleted_count = original_msg_count - remaining_msg_count
|
||||||
logger.info(f"[剪枝] 完成: 原始{original_msg_count}条 -> 保留{remaining_msg_count}条 (删除{deleted_count}条)")
|
logger.info(f"[剪枝] 完成: 原始{original_msg_count}条 -> 保留{remaining_msg_count}条 (删除{deleted_count}条)")
|
||||||
|
|
||||||
|
# 将剪枝记录挂到 metadata,供 graph_build_step 构建节点
|
||||||
|
if pruner.pruning_records:
|
||||||
|
dialog_data.metadata["assistant_pruning_records"] = [
|
||||||
|
r.model_dump() for r in pruner.pruning_records
|
||||||
|
]
|
||||||
|
logger.info(f"[剪枝] 收集到 {len(pruner.pruning_records)} 条剪枝记录")
|
||||||
else:
|
else:
|
||||||
logger.warning("[剪枝] prune_dataset 返回空列表")
|
logger.warning("[剪枝] prune_dataset 返回空列表")
|
||||||
else:
|
else:
|
||||||
@@ -115,6 +130,7 @@ async def get_chunked_dialogs(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"[剪枝] 执行失败,跳过剪枝: {e}", exc_info=True)
|
logger.warning(f"[剪枝] 执行失败,跳过剪枝: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# step3: 分块
|
||||||
chunker = DialogueChunker(chunker_strategy)
|
chunker = DialogueChunker(chunker_strategy)
|
||||||
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
||||||
dialog_data.chunks = extracted_chunks
|
dialog_data.chunks = extracted_chunks
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import os
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, TypedDict
|
from typing import Annotated, TypedDict
|
||||||
@@ -8,10 +7,11 @@ from langgraph.graph import add_messages
|
|||||||
|
|
||||||
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
|
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
|
||||||
|
|
||||||
|
|
||||||
class WriteState(TypedDict):
|
class WriteState(TypedDict):
|
||||||
'''
|
"""
|
||||||
Langgrapg Writing TypedDict
|
Langgrapg Writing TypedDict
|
||||||
'''
|
"""
|
||||||
messages: Annotated[list[AnyMessage], add_messages]
|
messages: Annotated[list[AnyMessage], add_messages]
|
||||||
end_user_id: str
|
end_user_id: str
|
||||||
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
||||||
@@ -20,6 +20,7 @@ class WriteState(TypedDict):
|
|||||||
data: str
|
data: str
|
||||||
language: str # 语言类型 ("zh" 中文, "en" 英文)
|
language: str # 语言类型 ("zh" 中文, "en" 英文)
|
||||||
|
|
||||||
|
|
||||||
class ReadState(TypedDict):
|
class ReadState(TypedDict):
|
||||||
"""
|
"""
|
||||||
LangGraph 工作流状态定义
|
LangGraph 工作流状态定义
|
||||||
@@ -43,18 +44,21 @@ class ReadState(TypedDict):
|
|||||||
config_id: str
|
config_id: str
|
||||||
data: str # 新增字段用于传递内容
|
data: str # 新增字段用于传递内容
|
||||||
spit_data: dict # 新增字段用于传递问题分解结果
|
spit_data: dict # 新增字段用于传递问题分解结果
|
||||||
problem_extension:dict
|
problem_extension: dict
|
||||||
storage_type: str
|
storage_type: str
|
||||||
user_rag_memory_id: str
|
user_rag_memory_id: str
|
||||||
llm_id: str
|
llm_id: str
|
||||||
embedding_id: str
|
embedding_id: str
|
||||||
memory_config: object # 新增字段用于传递内存配置对象
|
memory_config: object # 新增字段用于传递内存配置对象
|
||||||
retrieve:dict
|
retrieve: dict
|
||||||
|
perceptual_data: dict
|
||||||
RetrieveSummary: dict
|
RetrieveSummary: dict
|
||||||
InputSummary: dict
|
InputSummary: dict
|
||||||
verify: dict
|
verify: dict
|
||||||
SummaryFails: dict
|
SummaryFails: dict
|
||||||
summary: dict
|
summary: dict
|
||||||
|
|
||||||
|
|
||||||
class COUNTState:
|
class COUNTState:
|
||||||
"""
|
"""
|
||||||
工作流对话检索内容计数器
|
工作流对话检索内容计数器
|
||||||
@@ -99,6 +103,7 @@ class COUNTState:
|
|||||||
self.total = 0
|
self.total = 0
|
||||||
print("[COUNTState] 已重置为 0")
|
print("[COUNTState] 已重置为 0")
|
||||||
|
|
||||||
|
|
||||||
def deduplicate_entries(entries):
|
def deduplicate_entries(entries):
|
||||||
seen = set()
|
seen = set()
|
||||||
deduped = []
|
deduped = []
|
||||||
@@ -109,6 +114,7 @@ def deduplicate_entries(entries):
|
|||||||
deduped.append(entry)
|
deduped.append(entry)
|
||||||
return deduped
|
return deduped
|
||||||
|
|
||||||
|
|
||||||
def merge_to_key_value_pairs(data, query_key, result_key):
|
def merge_to_key_value_pairs(data, query_key, result_key):
|
||||||
grouped = defaultdict(list)
|
grouped = defaultdict(list)
|
||||||
for item in data:
|
for item in data:
|
||||||
|
|||||||
@@ -39,6 +39,30 @@
|
|||||||
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
||||||
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
||||||
|
|
||||||
|
## 指代消歧规则(Coreference Resolution):
|
||||||
|
在拆分问题时,必须解析并替换所有指代词和抽象称呼,使问题具体化:
|
||||||
|
|
||||||
|
1. **"用户"的消歧**:
|
||||||
|
- "用户是谁?" → 分析历史记录,找出对话发起者的姓名
|
||||||
|
- 如果历史中有"我叫X"、"我的名字是X"、或多次提到某个人物,则"用户"指的就是这个人
|
||||||
|
- 示例:历史中有"老李的原名叫李建国",则"用户是谁?"应拆分为"李建国是谁?"或"老李(李建国)是谁?"
|
||||||
|
|
||||||
|
2. **"我"的消歧**:
|
||||||
|
- "我喜欢什么?" → 从历史中找出对话发起者的姓名,替换为"X喜欢什么?"
|
||||||
|
- 示例:历史中有"张曼玉推荐了《小王子》",则"我推荐的书是什么?"应拆分为"张曼玉推荐的书是什么?"
|
||||||
|
|
||||||
|
3. **"他/她/它"的消歧**:
|
||||||
|
- 从上下文或历史中找出最近提到的同类实体
|
||||||
|
- 示例:历史中有"老李的同事叫他建国哥",则"他的同事怎么称呼他?"应拆分为"老李的同事怎么称呼他?"
|
||||||
|
|
||||||
|
4. **"那个人/这个人"的消歧**:
|
||||||
|
- 从历史中找出最近提到的人物
|
||||||
|
- 示例:历史中有"李建国",则"那个人的原名是什么?"应拆分为"李建国的原名是什么?"
|
||||||
|
|
||||||
|
5. **优先级**:
|
||||||
|
- 如果历史记录中反复出现某个人物(如"老李"、"李建国"、"建国哥"),则"用户"很可能指的就是这个人
|
||||||
|
- 如果无法从历史中确定指代对象,保留原问题,但在reason中说明"无法确定指代对象"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
输出要求:
|
输出要求:
|
||||||
@@ -71,6 +95,34 @@
|
|||||||
"reason": "输出原问题的关键要素"
|
"reason": "输出原问题的关键要素"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
## 指代消歧示例(重要):
|
||||||
|
示例1 - "用户"的消歧:
|
||||||
|
输入历史:[{'Query': '老李的原名叫什么?', 'Answer': '李建国'}, {'Query': '老李的同事叫他什么?', 'Answer': '建国哥'}]
|
||||||
|
输入问题:"用户是谁?"
|
||||||
|
输出:
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"original_question": "用户是谁?",
|
||||||
|
"extended_question": "李建国是谁?",
|
||||||
|
"type": "单跳",
|
||||||
|
"reason": "历史中反复提到'老李/李建国/建国哥','用户'指的就是对话发起者李建国"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
示例2 - "我"的消歧:
|
||||||
|
输入历史:[{'Query': '张曼玉推荐了什么书?', 'Answer': '《小王子》'}]
|
||||||
|
输入问题:"我推荐的书是什么?"
|
||||||
|
输出:
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"original_question": "我推荐的书是什么?",
|
||||||
|
"extended_question": "张曼玉推荐的书是什么?",
|
||||||
|
"type": "单跳",
|
||||||
|
"reason": "历史中提到张曼玉推荐了书,'我'指的就是张曼玉"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
**Output format**
|
**Output format**
|
||||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||||
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
||||||
|
|||||||
@@ -27,6 +27,30 @@
|
|||||||
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
||||||
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
||||||
|
|
||||||
|
## 指代消歧规则(Coreference Resolution):
|
||||||
|
在拆分问题时,必须解析并替换所有指代词和抽象称呼,使问题具体化:
|
||||||
|
|
||||||
|
1. **"用户"的消歧**:
|
||||||
|
- "用户是谁?" → 分析历史记录,找出对话发起者的姓名
|
||||||
|
- 如果历史中有"我叫X"、"我的名字是X"、或多次提到某个人物(如"老李"、"李建国"),则"用户"指的就是这个人
|
||||||
|
- 示例:历史中反复出现"老李/李建国/建国哥",则"用户是谁?"应拆分为"李建国是谁?"或"老李(李建国)是谁?"
|
||||||
|
|
||||||
|
2. **"我"的消歧**:
|
||||||
|
- "我喜欢什么?" → 从历史中找出对话发起者的姓名,替换为"X喜欢什么?"
|
||||||
|
- 示例:历史中有"张曼玉推荐了《小王子》",则"我推荐的书是什么?"应拆分为"张曼玉推荐的书是什么?"
|
||||||
|
|
||||||
|
3. **"他/她/它"的消歧**:
|
||||||
|
- 从上下文或历史中找出最近提到的同类实体
|
||||||
|
- 示例:历史中有"老李的同事叫他建国哥",则"他的同事怎么称呼他?"应拆分为"老李的同事怎么称呼他?"
|
||||||
|
|
||||||
|
4. **"那个人/这个人"的消歧**:
|
||||||
|
- 从历史中找出最近提到的人物
|
||||||
|
- 示例:历史中有"李建国",则"那个人的原名是什么?"应拆分为"李建国的原名是什么?"
|
||||||
|
|
||||||
|
5. **优先级**:
|
||||||
|
- 如果历史记录中反复出现某个人物(如"老李"、"李建国"、"建国哥"),则"用户"很可能指的就是这个人
|
||||||
|
- 如果无法从历史中确定指代对象,保留原问题,但在reason中说明"无法确定指代对象"
|
||||||
|
|
||||||
## 指令:
|
## 指令:
|
||||||
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
|
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
|
||||||
单跳(Single-hop)
|
单跳(Single-hop)
|
||||||
@@ -151,6 +175,34 @@
|
|||||||
]
|
]
|
||||||
- 必须通过json.loads()的格式支持的形式输出
|
- 必须通过json.loads()的格式支持的形式输出
|
||||||
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
||||||
|
|
||||||
|
## 指代消歧示例(重要):
|
||||||
|
示例1 - "用户"的消歧:
|
||||||
|
输入历史:[{'Query': '老李的原名叫什么?', 'Answer': '李建国'}, {'Query': '老李的同事叫他什么?', 'Answer': '建国哥'}]
|
||||||
|
输入问题:"用户是谁?"
|
||||||
|
输出:
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": "Q1",
|
||||||
|
"question": "李建国是谁?",
|
||||||
|
"type": "单跳",
|
||||||
|
"reason": "历史中反复提到'老李/李建国/建国哥','用户'指的就是对话发起者李建国"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
示例2 - "我"的消歧:
|
||||||
|
输入历史:[{'Query': '张曼玉推荐了什么书?', 'Answer': '《小王子》'}]
|
||||||
|
输入问题:"我推荐的书是什么?"
|
||||||
|
输出:
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": "Q1",
|
||||||
|
"question": "张曼玉推荐的书是什么?",
|
||||||
|
"type": "单跳",
|
||||||
|
"reason": "历史中提到张曼玉推荐了书,'我'指的就是张曼玉"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
- 关键的JSON格式要求
|
- 关键的JSON格式要求
|
||||||
1.JSON结构仅使用标准ASCII双引号(“)-切勿使用中文引号(“”)或其他Unicode引号
|
1.JSON结构仅使用标准ASCII双引号(“)-切勿使用中文引号(“”)或其他Unicode引号
|
||||||
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
|
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import uuid
|
|||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from typing import List, Dict, Any, Optional, Union
|
from typing import List, Dict, Any, Optional, Union
|
||||||
|
|
||||||
|
from app.core.logging_config import get_logger
|
||||||
from app.core.memory.agent.utils.redis_base import (
|
from app.core.memory.agent.utils.redis_base import (
|
||||||
serialize_messages,
|
serialize_messages,
|
||||||
deserialize_messages,
|
deserialize_messages,
|
||||||
@@ -14,7 +15,7 @@ from app.core.memory.agent.utils.redis_base import (
|
|||||||
get_current_timestamp
|
get_current_timestamp
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RedisWriteStore:
|
class RedisWriteStore:
|
||||||
@@ -66,10 +67,10 @@ class RedisWriteStore:
|
|||||||
})
|
})
|
||||||
result = pipe.execute()
|
result = pipe.execute()
|
||||||
|
|
||||||
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
logger.debug(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
||||||
return session_id
|
return session_id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[save_session_write] 保存会话失败: {e}")
|
logger.error(f"[save_session_write] 保存会话失败: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
|
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
|
||||||
@@ -112,10 +113,10 @@ class RedisWriteStore:
|
|||||||
if not results:
|
if not results:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
logger.debug(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
||||||
return results
|
return results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[get_session_by_userid] 查询失败: {e}")
|
logger.error(f"[get_session_by_userid] 查询失败: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
|
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
|
||||||
@@ -144,7 +145,7 @@ class RedisWriteStore:
|
|||||||
# 只查询 write 类型的 key
|
# 只查询 write 类型的 key
|
||||||
keys = self.r.keys('session:write:*')
|
keys = self.r.keys('session:write:*')
|
||||||
if not keys:
|
if not keys:
|
||||||
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
logger.debug(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 批量获取数据
|
# 批量获取数据
|
||||||
@@ -175,18 +176,16 @@ class RedisWriteStore:
|
|||||||
results.append(session_info)
|
results.append(session_info)
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 按时间排序(最新的在前)
|
# 按时间排序(最新的在前)
|
||||||
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||||
|
|
||||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
||||||
return results
|
return results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
|
logger.error(f"[get_all_sessions_by_end_user_id] 查询失败: {e}", exc_info=True)
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def find_user_recent_sessions(self, userid: str,
|
def find_user_recent_sessions(self, userid: str,
|
||||||
@@ -207,7 +206,7 @@ class RedisWriteStore:
|
|||||||
# 只查询 write 类型的 key
|
# 只查询 write 类型的 key
|
||||||
keys = self.r.keys('session:write:*')
|
keys = self.r.keys('session:write:*')
|
||||||
if not keys:
|
if not keys:
|
||||||
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
logger.debug(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 批量获取数据
|
# 批量获取数据
|
||||||
@@ -234,11 +233,10 @@ class RedisWriteStore:
|
|||||||
# 根据时间范围过滤
|
# 根据时间范围过滤
|
||||||
filtered_items = filter_by_time_range(matched_items, minutes)
|
filtered_items = filter_by_time_range(matched_items, minutes)
|
||||||
# 排序并移除时间字段
|
# 排序并移除时间字段
|
||||||
result_items = sort_and_limit_results(filtered_items, limit=None)
|
result_items = sort_and_limit_results(filtered_items)
|
||||||
print(result_items)
|
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
logger.debug(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
||||||
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||||
|
|
||||||
return result_items
|
return result_items
|
||||||
@@ -278,7 +276,7 @@ class RedisCountStore:
|
|||||||
decode_responses=True,
|
decode_responses=True,
|
||||||
encoding='utf-8'
|
encoding='utf-8'
|
||||||
)
|
)
|
||||||
self.uudi = session_id
|
self.uuid = session_id
|
||||||
|
|
||||||
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
|
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -298,7 +296,7 @@ class RedisCountStore:
|
|||||||
|
|
||||||
pipe = self.r.pipeline()
|
pipe = self.r.pipeline()
|
||||||
pipe.hset(key, mapping={
|
pipe.hset(key, mapping={
|
||||||
"id": self.uudi,
|
"id": self.uuid,
|
||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"count": int(count),
|
"count": int(count),
|
||||||
"messages": serialize_messages(messages),
|
"messages": serialize_messages(messages),
|
||||||
@@ -311,10 +309,10 @@ class RedisCountStore:
|
|||||||
|
|
||||||
result = pipe.execute()
|
result = pipe.execute()
|
||||||
|
|
||||||
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
logger.debug(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
||||||
return session_id
|
return session_id
|
||||||
|
|
||||||
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
|
def get_sessions_count(self, end_user_id: str) -> tuple[int, list[dict]] | bool:
|
||||||
"""
|
"""
|
||||||
通过 end_user_id 查询访问次数统计
|
通过 end_user_id 查询访问次数统计
|
||||||
|
|
||||||
@@ -335,7 +333,7 @@ class RedisCountStore:
|
|||||||
self.r.delete(index_key)
|
self.r.delete(index_key)
|
||||||
return False
|
return False
|
||||||
except Exception as type_error:
|
except Exception as type_error:
|
||||||
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
logger.error(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
||||||
|
|
||||||
session_id = self.r.get(index_key)
|
session_id = self.r.get(index_key)
|
||||||
|
|
||||||
@@ -355,15 +353,20 @@ class RedisCountStore:
|
|||||||
messages_str = data.get('messages')
|
messages_str = data.get('messages')
|
||||||
|
|
||||||
if count is not None:
|
if count is not None:
|
||||||
messages = deserialize_messages(messages_str)
|
messages: list[dict] = deserialize_messages(messages_str)
|
||||||
return [int(count), messages]
|
return int(count), messages
|
||||||
|
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[get_sessions_count] 查询失败: {e}")
|
logger.error(f"[get_sessions_count] 查询失败: {e}")
|
||||||
return False
|
return False
|
||||||
def update_sessions_count(self, end_user_id: str, new_count: int,
|
|
||||||
messages: Any) -> bool:
|
def update_sessions_count(
|
||||||
|
self,
|
||||||
|
end_user_id: str,
|
||||||
|
new_count: int,
|
||||||
|
messages: Any
|
||||||
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
通过 end_user_id 修改访问次数统计(优化版:使用索引)
|
通过 end_user_id 修改访问次数统计(优化版:使用索引)
|
||||||
|
|
||||||
@@ -384,17 +387,17 @@ class RedisCountStore:
|
|||||||
key_type = self.r.type(index_key)
|
key_type = self.r.type(index_key)
|
||||||
if key_type != 'string' and key_type != 'none':
|
if key_type != 'string' and key_type != 'none':
|
||||||
# 索引键类型错误,删除并返回 False
|
# 索引键类型错误,删除并返回 False
|
||||||
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
logger.warning(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
||||||
self.r.delete(index_key)
|
self.r.delete(index_key)
|
||||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||||
return False
|
return False
|
||||||
except Exception as type_error:
|
except Exception as type_error:
|
||||||
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
logger.error(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
||||||
|
|
||||||
session_id = self.r.get(index_key)
|
session_id = self.r.get(index_key)
|
||||||
|
|
||||||
if not session_id:
|
if not session_id:
|
||||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 直接更新数据
|
# 直接更新数据
|
||||||
@@ -402,15 +405,15 @@ class RedisCountStore:
|
|||||||
messages_str = serialize_messages(messages)
|
messages_str = serialize_messages(messages)
|
||||||
|
|
||||||
pipe = self.r.pipeline()
|
pipe = self.r.pipeline()
|
||||||
pipe.hset(key, 'count', int(new_count))
|
pipe.hset(key, 'count', str(new_count))
|
||||||
pipe.hset(key, 'messages', messages_str)
|
pipe.hset(key, 'messages', messages_str)
|
||||||
result = pipe.execute()
|
result = pipe.execute()
|
||||||
|
|
||||||
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
logger.debug(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[update_sessions_count] 更新失败: {e}")
|
logger.debug(f"[update_sessions_count] 更新失败: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def delete_all_count_sessions(self) -> int:
|
def delete_all_count_sessions(self) -> int:
|
||||||
@@ -453,7 +456,7 @@ class RedisSessionStore:
|
|||||||
# ==================== 写入操作 ====================
|
# ==================== 写入操作 ====================
|
||||||
|
|
||||||
def save_session(self, userid: str, messages: str, aimessages: str,
|
def save_session(self, userid: str, messages: str, aimessages: str,
|
||||||
apply_id: str, end_user_id: str) -> str:
|
apply_id: str, end_user_id: str) -> str:
|
||||||
"""
|
"""
|
||||||
写入一条会话数据,返回 session_id
|
写入一条会话数据,返回 session_id
|
||||||
|
|
||||||
@@ -483,10 +486,10 @@ class RedisSessionStore:
|
|||||||
})
|
})
|
||||||
result = pipe.execute()
|
result = pipe.execute()
|
||||||
|
|
||||||
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
logger.debug(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
||||||
return session_id
|
return session_id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[save_session] 保存会话失败: {e}")
|
logger.error(f"[save_session] 保存会话失败: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
# ==================== 读取操作 ====================
|
# ==================== 读取操作 ====================
|
||||||
@@ -521,7 +524,7 @@ class RedisSessionStore:
|
|||||||
return sessions
|
return sessions
|
||||||
|
|
||||||
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
||||||
end_user_id: str) -> List[Dict[str, str]]:
|
end_user_id: str) -> List[Dict[str, str]]:
|
||||||
"""
|
"""
|
||||||
根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条
|
根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条
|
||||||
|
|
||||||
@@ -538,7 +541,7 @@ class RedisSessionStore:
|
|||||||
|
|
||||||
keys = self.r.keys('session:*')
|
keys = self.r.keys('session:*')
|
||||||
if not keys:
|
if not keys:
|
||||||
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
logger.debug(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 批量获取数据
|
# 批量获取数据
|
||||||
@@ -556,7 +559,7 @@ class RedisSessionStore:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if (data.get('apply_id') == apply_id and
|
if (data.get('apply_id') == apply_id and
|
||||||
data.get('end_user_id') == end_user_id):
|
data.get('end_user_id') == end_user_id):
|
||||||
# 支持模糊匹配或完全匹配 sessionid
|
# 支持模糊匹配或完全匹配 sessionid
|
||||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||||
matched_items.append(format_session_data(data, include_time=True))
|
matched_items.append(format_session_data(data, include_time=True))
|
||||||
@@ -565,7 +568,7 @@ class RedisSessionStore:
|
|||||||
result_items = sort_and_limit_results(matched_items, limit=6)
|
result_items = sort_and_limit_results(matched_items, limit=6)
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
logger.debug(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||||
|
|
||||||
return result_items
|
return result_items
|
||||||
|
|
||||||
@@ -632,7 +635,7 @@ class RedisSessionStore:
|
|||||||
|
|
||||||
keys = self.r.keys('session:*')
|
keys = self.r.keys('session:*')
|
||||||
if not keys:
|
if not keys:
|
||||||
print("[delete_duplicate_sessions] 没有会话数据")
|
logger.debug("[delete_duplicate_sessions] 没有会话数据")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# 批量获取所有数据
|
# 批量获取所有数据
|
||||||
@@ -678,7 +681,7 @@ class RedisSessionStore:
|
|||||||
deleted_count += len(batch)
|
deleted_count += len(batch)
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
logger.debug(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||||
return deleted_count
|
return deleted_count
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,248 +0,0 @@
|
|||||||
"""
|
|
||||||
Write Tools for Memory Knowledge Extraction Pipeline
|
|
||||||
|
|
||||||
This module provides the main write function for executing the knowledge extraction
|
|
||||||
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
|
||||||
"""
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger
|
|
||||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
|
||||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
|
||||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation
|
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
|
||||||
from app.core.memory.utils.log.logging_utils import log_time
|
|
||||||
from app.db import get_db_context
|
|
||||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
|
||||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
|
||||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
|
||||||
from app.schemas.memory_config_schema import MemoryConfig
|
|
||||||
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
async def write(
|
|
||||||
end_user_id: str,
|
|
||||||
memory_config: MemoryConfig,
|
|
||||||
messages: list,
|
|
||||||
ref_id: str = "wyl20251027",
|
|
||||||
language: str = "zh",
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Execute the complete knowledge extraction pipeline.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: Group identifier
|
|
||||||
memory_config: MemoryConfig object containing all configuration
|
|
||||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
|
||||||
ref_id: Reference ID, defaults to "wyl20251027"
|
|
||||||
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
|
||||||
"""
|
|
||||||
# Extract config values
|
|
||||||
embedding_model_id = str(memory_config.embedding_model_id)
|
|
||||||
chunker_strategy = memory_config.chunker_strategy
|
|
||||||
config_id = str(memory_config.config_id)
|
|
||||||
|
|
||||||
logger.info("=== MemSci Knowledge Extraction Pipeline ===")
|
|
||||||
logger.info(f"Config: {memory_config.config_name} (ID: {config_id})")
|
|
||||||
logger.info(f"Workspace: {memory_config.workspace_name}")
|
|
||||||
logger.info(f"LLM model: {memory_config.llm_model_name}")
|
|
||||||
logger.info(f"Embedding model: {memory_config.embedding_model_name}")
|
|
||||||
logger.info(f"Chunker strategy: {chunker_strategy}")
|
|
||||||
logger.info(f"end_user_id ID: {end_user_id}")
|
|
||||||
|
|
||||||
# Construct clients from memory_config using factory pattern with db session
|
|
||||||
with get_db_context() as db:
|
|
||||||
factory = MemoryClientFactory(db)
|
|
||||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
|
||||||
embedder_client = factory.get_embedder_client_from_config(memory_config)
|
|
||||||
logger.info("LLM and embedding clients constructed")
|
|
||||||
|
|
||||||
# Initialize timing log
|
|
||||||
log_file = "logs/time.log"
|
|
||||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
with open(log_file, "a", encoding="utf-8") as f:
|
|
||||||
f.write(f"\n=== Pipeline Run Started: {timestamp} ===\n")
|
|
||||||
f.write(f"Config: {memory_config.config_name} (ID: {config_id})\n")
|
|
||||||
|
|
||||||
pipeline_start = time.time()
|
|
||||||
|
|
||||||
# Initialize Neo4j connector
|
|
||||||
neo4j_connector = Neo4jConnector()
|
|
||||||
|
|
||||||
# Step 1: Load and chunk data
|
|
||||||
step_start = time.time()
|
|
||||||
chunked_dialogs = await get_chunked_dialogs(
|
|
||||||
chunker_strategy=chunker_strategy,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
messages=messages,
|
|
||||||
ref_id=ref_id,
|
|
||||||
config_id=config_id,
|
|
||||||
)
|
|
||||||
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
|
||||||
|
|
||||||
# Step 2: Initialize and run ExtractionOrchestrator
|
|
||||||
step_start = time.time()
|
|
||||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
|
||||||
pipeline_config = get_pipeline_config(memory_config)
|
|
||||||
|
|
||||||
# Fetch ontology types if scene_id is configured
|
|
||||||
ontology_types = None
|
|
||||||
if memory_config.scene_id:
|
|
||||||
try:
|
|
||||||
from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_for_scene
|
|
||||||
|
|
||||||
with get_db_context() as db:
|
|
||||||
ontology_types = load_ontology_types_for_scene(
|
|
||||||
scene_id=memory_config.scene_id,
|
|
||||||
workspace_id=memory_config.workspace_id,
|
|
||||||
db=db
|
|
||||||
)
|
|
||||||
|
|
||||||
if ontology_types:
|
|
||||||
logger.info(
|
|
||||||
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info(f"No ontology classes found for scene_id: {memory_config.scene_id}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(
|
|
||||||
f"Failed to fetch ontology types for scene_id {memory_config.scene_id}: {e}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
|
|
||||||
orchestrator = ExtractionOrchestrator(
|
|
||||||
llm_client=llm_client,
|
|
||||||
embedder_client=embedder_client,
|
|
||||||
connector=neo4j_connector,
|
|
||||||
config=pipeline_config,
|
|
||||||
embedding_id=embedding_model_id,
|
|
||||||
language=language,
|
|
||||||
ontology_types=ontology_types,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run the complete extraction pipeline
|
|
||||||
(
|
|
||||||
all_dialogue_nodes,
|
|
||||||
all_chunk_nodes,
|
|
||||||
all_statement_nodes,
|
|
||||||
all_entity_nodes,
|
|
||||||
all_statement_chunk_edges,
|
|
||||||
all_statement_entity_edges,
|
|
||||||
all_entity_entity_edges,
|
|
||||||
all_dedup_details,
|
|
||||||
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
|
||||||
|
|
||||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
|
||||||
|
|
||||||
# Step 3: Save all data to Neo4j database
|
|
||||||
step_start = time.time()
|
|
||||||
from app.repositories.neo4j.create_indexes import create_fulltext_indexes
|
|
||||||
try:
|
|
||||||
await create_fulltext_indexes()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
|
||||||
|
|
||||||
# 添加死锁重试机制
|
|
||||||
max_retries = 3
|
|
||||||
retry_delay = 1 # 秒
|
|
||||||
|
|
||||||
for attempt in range(max_retries):
|
|
||||||
try:
|
|
||||||
success = await save_dialog_and_statements_to_neo4j(
|
|
||||||
dialogue_nodes=all_dialogue_nodes,
|
|
||||||
chunk_nodes=all_chunk_nodes,
|
|
||||||
statement_nodes=all_statement_nodes,
|
|
||||||
entity_nodes=all_entity_nodes,
|
|
||||||
statement_chunk_edges=all_statement_chunk_edges,
|
|
||||||
statement_entity_edges=all_statement_entity_edges,
|
|
||||||
entity_edges=all_entity_entity_edges,
|
|
||||||
connector=neo4j_connector
|
|
||||||
)
|
|
||||||
if success:
|
|
||||||
logger.info("Successfully saved all data to Neo4j")
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
logger.warning("Failed to save some data to Neo4j")
|
|
||||||
if attempt < max_retries - 1:
|
|
||||||
logger.info(f"Retrying... (attempt {attempt + 2}/{max_retries})")
|
|
||||||
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = str(e)
|
|
||||||
# 检查是否是死锁错误
|
|
||||||
if "DeadlockDetected" in error_msg or "deadlock" in error_msg.lower():
|
|
||||||
if attempt < max_retries - 1:
|
|
||||||
logger.warning(f"Deadlock detected, retrying... (attempt {attempt + 2}/{max_retries})")
|
|
||||||
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
|
||||||
else:
|
|
||||||
logger.error(f"Failed after {max_retries} attempts due to deadlock: {e}")
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
# 非死锁错误,直接抛出
|
|
||||||
raise
|
|
||||||
|
|
||||||
try:
|
|
||||||
await neo4j_connector.close()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error closing Neo4j connector: {e}")
|
|
||||||
|
|
||||||
log_time("Neo4j Database Save", time.time() - step_start, log_file)
|
|
||||||
|
|
||||||
# Step 4: Generate Memory summaries and save to Neo4j
|
|
||||||
step_start = time.time()
|
|
||||||
try:
|
|
||||||
summaries = await memory_summary_generation(
|
|
||||||
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
ms_connector = Neo4jConnector()
|
|
||||||
await add_memory_summary_nodes(summaries, ms_connector)
|
|
||||||
await add_memory_summary_statement_edges(summaries, ms_connector)
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
await ms_connector.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Memory summary step failed: {e}", exc_info=True)
|
|
||||||
finally:
|
|
||||||
log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file)
|
|
||||||
|
|
||||||
# Log total pipeline time
|
|
||||||
total_time = time.time() - pipeline_start
|
|
||||||
log_time("TOTAL PIPELINE TIME", total_time, log_file)
|
|
||||||
|
|
||||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
with open(log_file, "a", encoding="utf-8") as f:
|
|
||||||
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
|
|
||||||
|
|
||||||
# 将提取统计写入 Redis,按 workspace_id 存储
|
|
||||||
try:
|
|
||||||
from app.cache.memory.activity_stats_cache import ActivityStatsCache
|
|
||||||
|
|
||||||
stats_to_cache = {
|
|
||||||
"chunk_count": len(all_chunk_nodes) if all_chunk_nodes else 0,
|
|
||||||
"statements_count": len(all_statement_nodes) if all_statement_nodes else 0,
|
|
||||||
"triplet_entities_count": len(all_entity_nodes) if all_entity_nodes else 0,
|
|
||||||
"triplet_relations_count": len(all_entity_entity_edges) if all_entity_entity_edges else 0,
|
|
||||||
"temporal_count": 0,
|
|
||||||
}
|
|
||||||
await ActivityStatsCache.set_activity_stats(
|
|
||||||
workspace_id=str(memory_config.workspace_id),
|
|
||||||
stats=stats_to_cache,
|
|
||||||
)
|
|
||||||
logger.info(f"[WRITE] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}")
|
|
||||||
except Exception as cache_err:
|
|
||||||
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
|
||||||
|
|
||||||
logger.info("=== Pipeline Complete ===")
|
|
||||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
|
||||||
@@ -64,7 +64,7 @@ class ImplicitMemoryLLMClient:
|
|||||||
self.default_model_id = default_model_id
|
self.default_model_id = default_model_id
|
||||||
self._client_factory = MemoryClientFactory(db)
|
self._client_factory = MemoryClientFactory(db)
|
||||||
|
|
||||||
logger.info("ImplicitMemoryLLMClient initialized")
|
logger.debug("ImplicitMemoryLLMClient initialized")
|
||||||
|
|
||||||
def _get_llm_client(self, model_id: Optional[str] = None):
|
def _get_llm_client(self, model_id: Optional[str] = None):
|
||||||
"""Get LLM client instance.
|
"""Get LLM client instance.
|
||||||
|
|||||||
31
api/app/core/memory/enums.py
Normal file
31
api/app/core/memory/enums.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
|
||||||
|
class StorageType(StrEnum):
|
||||||
|
NEO4J = 'neo4j'
|
||||||
|
RAG = 'rag'
|
||||||
|
|
||||||
|
|
||||||
|
class Neo4jStorageStrategy(StrEnum):
|
||||||
|
WINDOW = 'window'
|
||||||
|
TIMELINE = 'timeline'
|
||||||
|
AGGREGATE = "aggregate"
|
||||||
|
|
||||||
|
|
||||||
|
class SearchStrategy(StrEnum):
|
||||||
|
DEEP = "0"
|
||||||
|
NORMAL = "1"
|
||||||
|
QUICK = "2"
|
||||||
|
|
||||||
|
|
||||||
|
class Neo4jNodeType(StrEnum):
|
||||||
|
CHUNK = "Chunk"
|
||||||
|
COMMUNITY = "Community"
|
||||||
|
DIALOGUE = "Dialogue"
|
||||||
|
EXTRACTEDENTITY = "ExtractedEntity"
|
||||||
|
MEMORYSUMMARY = "MemorySummary"
|
||||||
|
PERCEPTUAL = "Perceptual"
|
||||||
|
STATEMENT = "Statement"
|
||||||
|
|
||||||
|
RAG = "Rag"
|
||||||
|
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
from typing import Any, List
|
|
||||||
import re
|
|
||||||
import os
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
# Fix tokenizer parallelism warning
|
# Fix tokenizer parallelism warning
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
@@ -21,6 +21,7 @@ from chonkie import (
|
|||||||
|
|
||||||
from app.core.memory.models.config_models import ChunkerConfig
|
from app.core.memory.models.config_models import ChunkerConfig
|
||||||
from app.core.memory.models.message_models import DialogData, Chunk
|
from app.core.memory.models.message_models import DialogData, Chunk
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -32,6 +33,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class LLMChunker:
|
class LLMChunker:
|
||||||
"""LLM-based intelligent chunking strategy"""
|
"""LLM-based intelligent chunking strategy"""
|
||||||
|
|
||||||
def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000):
|
def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000):
|
||||||
self.llm_client = llm_client
|
self.llm_client = llm_client
|
||||||
self.chunk_size = chunk_size
|
self.chunk_size = chunk_size
|
||||||
@@ -46,7 +48,8 @@ class LLMChunker:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": "You are a professional text analysis assistant, skilled at splitting long texts into semantically coherent paragraphs."},
|
{"role": "system",
|
||||||
|
"content": "You are a professional text analysis assistant, skilled at splitting long texts into semantically coherent paragraphs."},
|
||||||
{"role": "user", "content": prompt}
|
{"role": "user", "content": prompt}
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -239,6 +242,7 @@ class ChunkerClient:
|
|||||||
chunk = Chunk(
|
chunk = Chunk(
|
||||||
content=f"{msg.role}: {sub_chunk_text}",
|
content=f"{msg.role}: {sub_chunk_text}",
|
||||||
speaker=msg.role, # 直接继承角色
|
speaker=msg.role, # 直接继承角色
|
||||||
|
dialog_at=getattr(msg, "dialog_at", None),
|
||||||
metadata={
|
metadata={
|
||||||
"message_index": msg_idx,
|
"message_index": msg_idx,
|
||||||
"message_role": msg.role,
|
"message_role": msg.role,
|
||||||
@@ -246,6 +250,7 @@ class ChunkerClient:
|
|||||||
"total_sub_chunks": len(sub_chunks),
|
"total_sub_chunks": len(sub_chunks),
|
||||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||||
},
|
},
|
||||||
|
files=msg.files
|
||||||
)
|
)
|
||||||
dialogue.chunks.append(chunk)
|
dialogue.chunks.append(chunk)
|
||||||
else:
|
else:
|
||||||
@@ -253,11 +258,13 @@ class ChunkerClient:
|
|||||||
chunk = Chunk(
|
chunk = Chunk(
|
||||||
content=f"{msg.role}: {msg_content}",
|
content=f"{msg.role}: {msg_content}",
|
||||||
speaker=msg.role, # 直接继承角色
|
speaker=msg.role, # 直接继承角色
|
||||||
|
dialog_at=getattr(msg, "dialog_at", None),
|
||||||
metadata={
|
metadata={
|
||||||
"message_index": msg_idx,
|
"message_index": msg_idx,
|
||||||
"message_role": msg.role,
|
"message_role": msg.role,
|
||||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||||
},
|
},
|
||||||
|
files=msg.files
|
||||||
)
|
)
|
||||||
dialogue.chunks.append(chunk)
|
dialogue.chunks.append(chunk)
|
||||||
|
|
||||||
@@ -309,7 +316,7 @@ class ChunkerClient:
|
|||||||
f.write("=" * 60 + "\n\n")
|
f.write("=" * 60 + "\n\n")
|
||||||
|
|
||||||
for i, chunk in enumerate(dialogue.chunks):
|
for i, chunk in enumerate(dialogue.chunks):
|
||||||
f.write(f"Chunk {i+1}:\n")
|
f.write(f"Chunk {i + 1}:\n")
|
||||||
f.write(f"Size: {len(chunk.content)} characters\n")
|
f.write(f"Size: {len(chunk.content)} characters\n")
|
||||||
if hasattr(chunk, 'metadata') and 'start_index' in chunk.metadata:
|
if hasattr(chunk, 'metadata') and 'start_index' in chunk.metadata:
|
||||||
f.write(f"Position: {chunk.metadata.get('start_index')}-{chunk.metadata.get('end_index')}\n")
|
f.write(f"Position: {chunk.metadata.get('start_index')}-{chunk.metadata.get('end_index')}\n")
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ class LLMClient(ABC):
|
|||||||
self.max_retries = self.config.max_retries
|
self.max_retries = self.config.max_retries
|
||||||
self.timeout = self.config.timeout
|
self.timeout = self.config.timeout
|
||||||
|
|
||||||
logger.info(
|
logger.debug(
|
||||||
f"初始化 LLM 客户端: provider={self.provider}, "
|
f"初始化 LLM 客户端: provider={self.provider}, "
|
||||||
f"model={self.model_name}, max_retries={self.max_retries}"
|
f"model={self.model_name}, max_retries={self.max_retries}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ class OpenAIClient(LLMClient):
|
|||||||
type=type_
|
type=type_
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"OpenAI 客户端初始化完成: type={type_}")
|
logger.debug(f"OpenAI 客户端初始化完成: type={type_}")
|
||||||
|
|
||||||
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any:
|
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
OpenAI Embedder 客户端实现
|
OpenAI Embedder 客户端实现
|
||||||
|
|
||||||
基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。
|
基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。
|
||||||
|
自动支持火山引擎的多模态 Embedding。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
@@ -13,6 +14,7 @@ from app.core.memory.llm_tools.embedder_client import (
|
|||||||
)
|
)
|
||||||
from app.core.models.base import RedBearModelConfig
|
from app.core.models.base import RedBearModelConfig
|
||||||
from app.core.models.embedding import RedBearEmbeddings
|
from app.core.models.embedding import RedBearEmbeddings
|
||||||
|
from app.models.models_model import ModelProvider
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -25,6 +27,7 @@ class OpenAIEmbedderClient(EmbedderClient):
|
|||||||
- 批量文本嵌入
|
- 批量文本嵌入
|
||||||
- 自动重试机制
|
- 自动重试机制
|
||||||
- 错误处理
|
- 错误处理
|
||||||
|
- 火山引擎多模态 Embedding(自动识别)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_config: RedBearModelConfig):
|
def __init__(self, model_config: RedBearModelConfig):
|
||||||
@@ -36,7 +39,7 @@ class OpenAIEmbedderClient(EmbedderClient):
|
|||||||
"""
|
"""
|
||||||
super().__init__(model_config)
|
super().__init__(model_config)
|
||||||
|
|
||||||
# 初始化 RedBearEmbeddings 模型
|
# 初始化 RedBearEmbeddings(自动支持火山引擎多模态)
|
||||||
self.model = RedBearEmbeddings(
|
self.model = RedBearEmbeddings(
|
||||||
RedBearModelConfig(
|
RedBearModelConfig(
|
||||||
model_name=self.model_name,
|
model_name=self.model_name,
|
||||||
@@ -47,8 +50,9 @@ class OpenAIEmbedderClient(EmbedderClient):
|
|||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
self.is_multimodal = self.model.is_multimodal_supported()
|
||||||
|
|
||||||
logger.info("OpenAI Embedder 客户端初始化完成")
|
logger.info(f"OpenAI Embedder 客户端初始化完成 (provider={self.provider}, multimodal={self.is_multimodal})")
|
||||||
|
|
||||||
async def response(
|
async def response(
|
||||||
self,
|
self,
|
||||||
@@ -77,7 +81,14 @@ class OpenAIEmbedderClient(EmbedderClient):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
# 生成嵌入向量
|
# 生成嵌入向量
|
||||||
embeddings = await self.model.aembed_documents(texts)
|
if self.is_multimodal:
|
||||||
|
# 火山引擎多模态 Embedding
|
||||||
|
embeddings = await self.model.aembed_multimodal(
|
||||||
|
[{"type": "text", "text": text} for text in texts]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 普通 Embedding
|
||||||
|
embeddings = await self.model.aembed_documents(texts)
|
||||||
|
|
||||||
logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量")
|
logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量")
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|||||||
143
api/app/core/memory/memory_service.py
Normal file
143
api/app/core/memory/memory_service.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
"""
|
||||||
|
MemoryService — 记忆模块统一入口(Facade)
|
||||||
|
|
||||||
|
所有外部调用方(controllers、Celery tasks、API service)只依赖此类。
|
||||||
|
|
||||||
|
职责:
|
||||||
|
- 接收已加载的 MemoryConfig,选择并调用对应的 Pipeline
|
||||||
|
- 不包含任何业务逻辑实现
|
||||||
|
- 不直接操作数据库或 LLM
|
||||||
|
|
||||||
|
依赖方向:外部调用方 → MemoryService → Pipeline → Engine → Repository
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.core.memory.pipelines.pilot_write_pipeline import PilotWriteResult
|
||||||
|
from app.core.memory.pipelines.write_pipeline import WriteResult
|
||||||
|
from app.core.memory.models.message_models import DialogData
|
||||||
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryService:
|
||||||
|
"""记忆模块统一入口
|
||||||
|
|
||||||
|
所有外部调用方(controllers、Celery tasks、API service)只依赖此类。
|
||||||
|
|
||||||
|
设计决策:
|
||||||
|
- __init__ 接收已加载的 MemoryConfig(而非 config_id),
|
||||||
|
配置加载的职责留在调用方(MemoryAgentService),
|
||||||
|
因为调用方需要 config 做其他事情(如感知记忆处理)。
|
||||||
|
- 未实现的方法抛出 NotImplementedError,明确标记待实现状态。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
memory_config: MemoryConfig,
|
||||||
|
end_user_id: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
memory_config: 已加载的不可变配置对象
|
||||||
|
end_user_id: 终端用户 ID
|
||||||
|
"""
|
||||||
|
self.memory_config = memory_config
|
||||||
|
self.end_user_id = end_user_id
|
||||||
|
|
||||||
|
async def write(
|
||||||
|
self,
|
||||||
|
messages: List[dict],
|
||||||
|
language: str = "zh",
|
||||||
|
ref_id: str = "",
|
||||||
|
is_pilot_run: bool = False,
|
||||||
|
progress_callback: Optional[
|
||||||
|
Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]
|
||||||
|
] = None,
|
||||||
|
) -> WriteResult:
|
||||||
|
"""写入记忆:对话 → 萃取 → 存储 → 聚类 → 摘要
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: 结构化消息 [{"role": "user"/"assistant", "content": "...", "dialog_at": "..."}]
|
||||||
|
language: 语言 ("zh" | "en")
|
||||||
|
ref_id: 引用 ID,为空则自动生成
|
||||||
|
is_pilot_run: 试运行模式(只萃取不写入)
|
||||||
|
progress_callback: 可选的进度回调
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
WriteResult 包含状态和统计信息
|
||||||
|
"""
|
||||||
|
from app.core.memory.pipelines.write_pipeline import WritePipeline
|
||||||
|
|
||||||
|
pipeline = WritePipeline(
|
||||||
|
memory_config=self.memory_config,
|
||||||
|
end_user_id=self.end_user_id,
|
||||||
|
language=language,
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
)
|
||||||
|
return await pipeline.run(
|
||||||
|
messages=messages,
|
||||||
|
ref_id=ref_id,
|
||||||
|
is_pilot_run=is_pilot_run,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def pilot_write(
|
||||||
|
self,
|
||||||
|
chunked_dialogs: List[DialogData],
|
||||||
|
language: str = "zh",
|
||||||
|
progress_callback: Optional[
|
||||||
|
Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]
|
||||||
|
] = None,
|
||||||
|
) -> PilotWriteResult:
|
||||||
|
"""试运行写入:只执行萃取链路,不写入 Neo4j
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunked_dialogs: 预处理 + 分块后的 DialogData 列表
|
||||||
|
language: 语言 ("zh" | "en")
|
||||||
|
progress_callback: 可选的进度回调
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PilotWriteResult 包含萃取结果、图构建结果和去重结果
|
||||||
|
"""
|
||||||
|
from app.core.memory.pipelines.pilot_write_pipeline import PilotWritePipeline
|
||||||
|
|
||||||
|
pipeline = PilotWritePipeline(
|
||||||
|
memory_config=self.memory_config,
|
||||||
|
end_user_id=self.end_user_id,
|
||||||
|
language=language,
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
)
|
||||||
|
return await pipeline.run(chunked_dialogs)
|
||||||
|
|
||||||
|
async def read(
|
||||||
|
self, query: str, history: list, search_switch: str
|
||||||
|
) -> dict:
|
||||||
|
"""读取记忆:根据 search_switch 选择快速/深度路径"""
|
||||||
|
raise NotImplementedError("ReadPipeline 尚未实现")
|
||||||
|
|
||||||
|
# async def search(
|
||||||
|
# self,
|
||||||
|
# query: str,
|
||||||
|
# search_type: str = "hybrid",
|
||||||
|
# limit: int = 10,
|
||||||
|
# ) -> dict:
|
||||||
|
# """独立检索:不经过 LangGraph,直接执行混合检索"""
|
||||||
|
# raise NotImplementedError("SearchPipeline 尚未实现")
|
||||||
|
|
||||||
|
async def forget(
|
||||||
|
self, max_batch: int = 100, min_days: int = 30
|
||||||
|
) -> dict:
|
||||||
|
"""遗忘:识别低激活节点并融合"""
|
||||||
|
raise NotImplementedError("ForgettingPipeline 尚未实现")
|
||||||
|
|
||||||
|
async def reflect(self) -> dict:
|
||||||
|
"""反思:检测事实冲突并修正"""
|
||||||
|
raise NotImplementedError("ReflectionPipeline 尚未实现")
|
||||||
|
|
||||||
|
# async def cluster(self, new_entity_ids: list[str] = None) -> None:
|
||||||
|
# """聚类:全量初始化或增量更新社区"""
|
||||||
|
# raise NotImplementedError("ClusteringPipeline 尚未实现")
|
||||||
@@ -58,6 +58,12 @@ from app.core.memory.models.triplet_models import (
|
|||||||
TripletExtractionResponse,
|
TripletExtractionResponse,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# User metadata models
|
||||||
|
from app.core.memory.models.metadata_models import (
|
||||||
|
MetadataExtractionResponse,
|
||||||
|
MetadataFieldChange,
|
||||||
|
)
|
||||||
|
|
||||||
# Ontology scenario models (LLM extracted from scenarios)
|
# Ontology scenario models (LLM extracted from scenarios)
|
||||||
from app.core.memory.models.ontology_scenario_models import (
|
from app.core.memory.models.ontology_scenario_models import (
|
||||||
OntologyClass,
|
OntologyClass,
|
||||||
@@ -124,6 +130,8 @@ __all__ = [
|
|||||||
"Entity",
|
"Entity",
|
||||||
"Triplet",
|
"Triplet",
|
||||||
"TripletExtractionResponse",
|
"TripletExtractionResponse",
|
||||||
|
"MetadataExtractionResponse",
|
||||||
|
"MetadataFieldChange",
|
||||||
# Ontology models
|
# Ontology models
|
||||||
"OntologyClass",
|
"OntologyClass",
|
||||||
"OntologyExtractionResponse",
|
"OntologyExtractionResponse",
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ of the memory system including LLM, chunking, pruning, and search.
|
|||||||
Classes:
|
Classes:
|
||||||
LLMConfig: Configuration for LLM client
|
LLMConfig: Configuration for LLM client
|
||||||
ChunkerConfig: Configuration for dialogue chunking
|
ChunkerConfig: Configuration for dialogue chunking
|
||||||
|
OntologyClassInfo: Single ontology class with name and description
|
||||||
PruningConfig: Configuration for semantic pruning
|
PruningConfig: Configuration for semantic pruning
|
||||||
TemporalSearchParams: Parameters for temporal search queries
|
TemporalSearchParams: Parameters for temporal search queries
|
||||||
"""
|
"""
|
||||||
@@ -50,30 +51,41 @@ class ChunkerConfig(BaseModel):
|
|||||||
min_characters_per_chunk: Optional[int] = Field(24, ge=0, description="The minimum number of characters in each chunk.")
|
min_characters_per_chunk: Optional[int] = Field(24, ge=0, description="The minimum number of characters in each chunk.")
|
||||||
|
|
||||||
|
|
||||||
|
class OntologyClassInfo(BaseModel):
|
||||||
|
"""本体类型的名称与语义描述,用于剪枝提示词注入。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
class_name: 本体类型名称(如"患者"、"课程")
|
||||||
|
class_description: 本体类型语义描述,告知 LLM 该类型在当前场景下的含义
|
||||||
|
"""
|
||||||
|
class_name: str = Field(..., description="本体类型名称")
|
||||||
|
class_description: str = Field(default="", description="本体类型语义描述")
|
||||||
|
|
||||||
|
|
||||||
class PruningConfig(BaseModel):
|
class PruningConfig(BaseModel):
|
||||||
"""Configuration for semantic pruning of dialogue content.
|
"""Configuration for semantic pruning of dialogue content.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
pruning_switch: Enable or disable semantic pruning
|
pruning_switch: Enable or disable semantic pruning
|
||||||
pruning_scene: Scene name for pruning, either a built-in key
|
pruning_scene: Scene name for pruning from ontology_scene table
|
||||||
('education', 'online_service', 'outbound') or a custom scene_name
|
|
||||||
from ontology_scene table
|
|
||||||
pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal)
|
pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal)
|
||||||
scene_id: Optional ontology scene UUID, used to load custom ontology classes
|
scene_id: Optional ontology scene UUID
|
||||||
ontology_classes: List of class_name strings from ontology_class table,
|
ontology_class_infos: Full ontology class info (name + description) from
|
||||||
injected into the prompt when pruning_scene is not a built-in scene
|
ontology_class table, injected into the pruning prompt to drive
|
||||||
|
scene-aware preservation decisions
|
||||||
"""
|
"""
|
||||||
pruning_switch: bool = Field(False, description="Enable semantic pruning when True.")
|
pruning_switch: bool = Field(False, description="Enable semantic pruning when True.")
|
||||||
pruning_scene: str = Field(
|
pruning_scene: str = Field(
|
||||||
"education",
|
"education",
|
||||||
description="Scene for pruning: built-in key or custom scene_name from ontology_scene.",
|
description="Scene name from ontology_scene table.",
|
||||||
)
|
)
|
||||||
pruning_threshold: float = Field(
|
pruning_threshold: float = Field(
|
||||||
0.5, ge=0.0, le=0.9,
|
0.5, ge=0.0, le=0.9,
|
||||||
description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).")
|
description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).")
|
||||||
scene_id: Optional[str] = Field(None, description="Ontology scene UUID (optional).")
|
scene_id: Optional[str] = Field(None, description="Ontology scene UUID (optional).")
|
||||||
ontology_classes: Optional[List[str]] = Field(
|
ontology_class_infos: List[OntologyClassInfo] = Field(
|
||||||
None, description="Class names from ontology_class table for custom scenes."
|
default_factory=list,
|
||||||
|
description="Full ontology class info (name + description) injected into pruning prompt."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -106,7 +106,6 @@ class Edge(BaseModel):
|
|||||||
end_user_id: End user ID for multi-tenancy
|
end_user_id: End user ID for multi-tenancy
|
||||||
run_id: Unique identifier for the pipeline run that created this edge
|
run_id: Unique identifier for the pipeline run that created this edge
|
||||||
created_at: Timestamp when the edge was created (system perspective)
|
created_at: Timestamp when the edge was created (system perspective)
|
||||||
expired_at: Optional timestamp when the edge expires (system perspective)
|
|
||||||
"""
|
"""
|
||||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.")
|
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.")
|
||||||
source: str = Field(..., description="The ID of the source node.")
|
source: str = Field(..., description="The ID of the source node.")
|
||||||
@@ -114,7 +113,6 @@ class Edge(BaseModel):
|
|||||||
end_user_id: str = Field(..., description="The end user ID of the edge.")
|
end_user_id: str = Field(..., description="The end user ID of the edge.")
|
||||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||||
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
||||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkEdge(Edge):
|
class ChunkEdge(Edge):
|
||||||
@@ -162,6 +160,7 @@ class EntityEntityEdge(Edge):
|
|||||||
invalid_at: Optional end date of temporal validity
|
invalid_at: Optional end date of temporal validity
|
||||||
"""
|
"""
|
||||||
relation_type: str = Field(..., description="Relation type as defined in ontology")
|
relation_type: str = Field(..., description="Relation type as defined in ontology")
|
||||||
|
relation_type_description: str = Field(default="", description="Chinese definition of the relation type from ontology")
|
||||||
relation_value: Optional[str] = Field(None, description="Value of the relation")
|
relation_value: Optional[str] = Field(None, description="Value of the relation")
|
||||||
statement: str = Field(..., description='The statement of the edge.')
|
statement: str = Field(..., description='The statement of the edge.')
|
||||||
source_statement_id: str = Field(..., description="Statement where this relationship was extracted")
|
source_statement_id: str = Field(..., description="Statement where this relationship was extracted")
|
||||||
@@ -175,6 +174,12 @@ class EntityEntityEdge(Edge):
|
|||||||
return parse_historical_datetime(v)
|
return parse_historical_datetime(v)
|
||||||
|
|
||||||
|
|
||||||
|
class PerceptualEdge(Edge):
|
||||||
|
"""Edge connecting perceptual nodes to their source chunks
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Node(BaseModel):
|
class Node(BaseModel):
|
||||||
"""Base class for all graph nodes in the knowledge graph.
|
"""Base class for all graph nodes in the knowledge graph.
|
||||||
|
|
||||||
@@ -184,14 +189,12 @@ class Node(BaseModel):
|
|||||||
end_user_id: End user ID for multi-tenancy
|
end_user_id: End user ID for multi-tenancy
|
||||||
run_id: Unique identifier for the pipeline run that created this node
|
run_id: Unique identifier for the pipeline run that created this node
|
||||||
created_at: Timestamp when the node was created (system perspective)
|
created_at: Timestamp when the node was created (system perspective)
|
||||||
expired_at: Optional timestamp when the node expires (system perspective)
|
|
||||||
"""
|
"""
|
||||||
id: str = Field(..., description="The unique identifier for the node.")
|
id: str = Field(..., description="The unique identifier for the node.")
|
||||||
name: str = Field(..., description="The name of the node.")
|
name: str = Field(..., description="The name of the node.")
|
||||||
end_user_id: str = Field(..., description="The end user ID of the node.")
|
end_user_id: str = Field(..., description="The end user ID of the node.")
|
||||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||||
created_at: datetime = Field(..., description="The valid time of the node from system perspective.")
|
created_at: datetime = Field(..., description="The valid time of the node from system perspective.")
|
||||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.")
|
|
||||||
|
|
||||||
|
|
||||||
class DialogueNode(Node):
|
class DialogueNode(Node):
|
||||||
@@ -206,7 +209,8 @@ class DialogueNode(Node):
|
|||||||
ref_id: str = Field(..., description="Reference identifier of the dialog")
|
ref_id: str = Field(..., description="Reference identifier of the dialog")
|
||||||
content: str = Field(..., description="Dialogue content")
|
content: str = Field(..., description="Dialogue content")
|
||||||
dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector")
|
dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector")
|
||||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialogue (integer or string)")
|
config_id: Optional[int | str] = Field(None,
|
||||||
|
description="Configuration ID used to process this dialogue (integer or string)")
|
||||||
|
|
||||||
|
|
||||||
class StatementNode(Node):
|
class StatementNode(Node):
|
||||||
@@ -276,12 +280,14 @@ class StatementNode(Node):
|
|||||||
temporal_info: TemporalInfo = Field(..., description="Temporal information")
|
temporal_info: TemporalInfo = Field(..., description="Temporal information")
|
||||||
valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
|
valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
|
||||||
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
|
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
|
||||||
|
dialog_at: Optional[datetime] = Field(None, description="Absolute timestamp of the conversation this statement belongs to")
|
||||||
|
|
||||||
# Embedding and other fields
|
# Embedding and other fields
|
||||||
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
|
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
|
||||||
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
|
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
|
||||||
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
|
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
|
||||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)")
|
config_id: Optional[int | str] = Field(None,
|
||||||
|
description="Configuration ID used to process this statement (integer or string)")
|
||||||
|
|
||||||
# ACT-R Memory Activation Properties
|
# ACT-R Memory Activation Properties
|
||||||
importance_score: float = Field(
|
importance_score: float = Field(
|
||||||
@@ -310,7 +316,7 @@ class StatementNode(Node):
|
|||||||
description="Total number of times this node has been accessed"
|
description="Total number of times this node has been accessed"
|
||||||
)
|
)
|
||||||
|
|
||||||
@field_validator('valid_at', 'invalid_at', mode='before')
|
@field_validator('valid_at', 'invalid_at', 'dialog_at', mode='before')
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_datetime(cls, v):
|
def validate_datetime(cls, v):
|
||||||
"""使用通用的历史日期解析函数"""
|
"""使用通用的历史日期解析函数"""
|
||||||
@@ -356,12 +362,14 @@ class ChunkNode(Node):
|
|||||||
Attributes:
|
Attributes:
|
||||||
dialog_id: ID of the parent dialog
|
dialog_id: ID of the parent dialog
|
||||||
content: The text content of the chunk
|
content: The text content of the chunk
|
||||||
|
speaker: Speaker identifier ('user' or 'assistant')
|
||||||
chunk_embedding: Optional embedding vector for the chunk
|
chunk_embedding: Optional embedding vector for the chunk
|
||||||
sequence_number: Order of this chunk within the dialog
|
sequence_number: Order of this chunk within the dialog
|
||||||
metadata: Additional chunk metadata as key-value pairs
|
metadata: Additional chunk metadata as key-value pairs
|
||||||
"""
|
"""
|
||||||
dialog_id: str = Field(..., description="ID of the parent dialog")
|
dialog_id: str = Field(..., description="ID of the parent dialog")
|
||||||
content: str = Field(..., description="The text content of the chunk")
|
content: str = Field(..., description="The text content of the chunk")
|
||||||
|
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
|
||||||
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
|
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
|
||||||
sequence_number: int = Field(..., description="Order of this chunk within the dialog")
|
sequence_number: int = Field(..., description="Order of this chunk within the dialog")
|
||||||
metadata: dict = Field(default_factory=dict, description="Additional chunk metadata")
|
metadata: dict = Field(default_factory=dict, description="Additional chunk metadata")
|
||||||
@@ -403,6 +411,7 @@ class ExtractedEntityNode(Node):
|
|||||||
entity_idx: int = Field(..., description="Unique identifier for the entity")
|
entity_idx: int = Field(..., description="Unique identifier for the entity")
|
||||||
statement_id: str = Field(..., description="Statement this entity was extracted from")
|
statement_id: str = Field(..., description="Statement this entity was extracted from")
|
||||||
entity_type: str = Field(..., description="Type of the entity")
|
entity_type: str = Field(..., description="Type of the entity")
|
||||||
|
type_description: str = Field(default="", description="Chinese definition of the entity type from ontology")
|
||||||
description: str = Field(..., description="Entity description")
|
description: str = Field(..., description="Entity description")
|
||||||
example: str = Field(
|
example: str = Field(
|
||||||
default="",
|
default="",
|
||||||
@@ -416,7 +425,8 @@ class ExtractedEntityNode(Node):
|
|||||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||||
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
||||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
config_id: Optional[int | str] = Field(None,
|
||||||
|
description="Configuration ID used to process this entity (integer or string)")
|
||||||
|
|
||||||
# ACT-R Memory Activation Properties
|
# ACT-R Memory Activation Properties
|
||||||
importance_score: float = Field(
|
importance_score: float = Field(
|
||||||
@@ -451,9 +461,19 @@ class ExtractedEntityNode(Node):
|
|||||||
description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)"
|
description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# User Metadata Fields (populated by async metadata extraction after dedup)
|
||||||
|
core_facts: List[str] = Field(default_factory=list, description="Stable basic facts about the user")
|
||||||
|
traits: List[str] = Field(default_factory=list, description="Stable personality traits or behavioral tendencies")
|
||||||
|
relations: List[str] = Field(default_factory=list, description="Durable relationships with people/groups/entities")
|
||||||
|
goals: List[str] = Field(default_factory=list, description="Long-term goals or ongoing pursuits")
|
||||||
|
interests: List[str] = Field(default_factory=list, description="Stable interests, preferences, or hobbies")
|
||||||
|
beliefs_or_stances: List[str] = Field(default_factory=list, description="Stable beliefs, values, or stances")
|
||||||
|
anchors: List[str] = Field(default_factory=list, description="Personally meaningful objects or symbols")
|
||||||
|
events: List[str] = Field(default_factory=list, description="Durable personal experiences or milestones")
|
||||||
|
|
||||||
@field_validator('aliases', mode='before')
|
@field_validator('aliases', mode='before')
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
|
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
|
||||||
"""Validate and clean aliases field using utility function.
|
"""Validate and clean aliases field using utility function.
|
||||||
|
|
||||||
This validator ensures that the aliases field is always a valid list of strings.
|
This validator ensures that the aliases field is always a valid list of strings.
|
||||||
@@ -507,7 +527,8 @@ class MemorySummaryNode(Node):
|
|||||||
memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory")
|
memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory")
|
||||||
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
|
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
|
||||||
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
|
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
|
||||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)")
|
config_id: Optional[int | str] = Field(None,
|
||||||
|
description="Configuration ID used to process this summary (integer or string)")
|
||||||
|
|
||||||
# ACT-R Forgetting Engine Properties
|
# ACT-R Forgetting Engine Properties
|
||||||
original_statement_id: Optional[str] = Field(
|
original_statement_id: Optional[str] = Field(
|
||||||
@@ -549,3 +570,62 @@ class MemorySummaryNode(Node):
|
|||||||
ge=0,
|
ge=0,
|
||||||
description="Total number of times this node has been accessed (reset to 1 on creation)"
|
description="Total number of times this node has been accessed (reset to 1 on creation)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PerceptualNode(Node):
|
||||||
|
"""Node representing a multimodal message in the knowledge graph.
|
||||||
|
"""
|
||||||
|
perceptual_type: int
|
||||||
|
file_path: str
|
||||||
|
file_name: str
|
||||||
|
file_ext: str
|
||||||
|
summary: str
|
||||||
|
keywords: list[str]
|
||||||
|
topic: str
|
||||||
|
domain: str
|
||||||
|
file_type: str
|
||||||
|
summary_embedding: list[float] | None
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantOriginalNode(Node):
|
||||||
|
"""Node storing the original text of an Assistant message before pruning.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
pair_id: Shared ID with the corresponding AssistantPrunedNode for pairing
|
||||||
|
dialog_id: ID of the parent dialogue this message belongs to
|
||||||
|
text: The full original Assistant response text
|
||||||
|
"""
|
||||||
|
pair_id: str = Field(..., description="Shared pairing ID with the corresponding pruned node")
|
||||||
|
dialog_id: str = Field(..., description="ID of the parent dialogue")
|
||||||
|
text: str = Field(..., description="Original Assistant message text")
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantPrunedNode(Node):
|
||||||
|
"""Node storing the pruned (compressed) text of an Assistant message.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
pair_id: Shared ID with the corresponding AssistantOriginalNode for pairing
|
||||||
|
dialog_id: ID of the parent dialogue this message belongs to
|
||||||
|
text: The pruned memory hint text (or "NULL" if no memory value)
|
||||||
|
memory_type: Type of the memory hint (comfort|suggestion|recommendation|warning|instruction|NULL)
|
||||||
|
text_embedding: Optional embedding vector for semantic search on pruned text
|
||||||
|
"""
|
||||||
|
pair_id: str = Field(..., description="Shared pairing ID with the corresponding original node")
|
||||||
|
dialog_id: str = Field(..., description="ID of the parent dialogue")
|
||||||
|
text: str = Field(..., description="Pruned assistant memory hint text")
|
||||||
|
memory_type: str = Field(..., description="Memory type: comfort|suggestion|recommendation|warning|instruction|NULL")
|
||||||
|
text_embedding: Optional[List[float]] = Field(None, description="Embedding vector for semantic search")
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantPrunedEdge(Edge):
|
||||||
|
"""Edge connecting an AssistantOriginal node to its AssistantPruned node (PRUNED_TO).
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
pair_id: Shared pairing ID for traceability
|
||||||
|
"""
|
||||||
|
pair_id: str = Field(..., description="Shared pairing ID for traceability")
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantDialogEdge(Edge):
|
||||||
|
"""Edge connecting an AssistantOriginal node to its parent Dialogue node (BELONGS_TO_DIALOG)."""
|
||||||
|
pass
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ class ConversationMessage(BaseModel):
|
|||||||
"""
|
"""
|
||||||
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
|
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
|
||||||
msg: str = Field(..., description="The text content of the message.")
|
msg: str = Field(..., description="The text content of the message.")
|
||||||
|
dialog_at: Optional[str] = Field(None, description="Absolute timestamp of this message (ISO 8601).")
|
||||||
|
files: list[tuple] = Field(default_factory=list, description="The file content of the message", exclude=True)
|
||||||
|
|
||||||
|
|
||||||
class TemporalValidityRange(BaseModel):
|
class TemporalValidityRange(BaseModel):
|
||||||
@@ -93,6 +95,13 @@ class Statement(BaseModel):
|
|||||||
emotion_keywords: Optional[List[str]] = Field(default_factory=list, description="Emotion keywords, max 3")
|
emotion_keywords: Optional[List[str]] = Field(default_factory=list, description="Emotion keywords, max 3")
|
||||||
emotion_subject: Optional[str] = Field(None, description="Emotion subject: self/other/object")
|
emotion_subject: Optional[str] = Field(None, description="Emotion subject: self/other/object")
|
||||||
emotion_target: Optional[str] = Field(None, description="Emotion target: person or object name")
|
emotion_target: Optional[str] = Field(None, description="Emotion target: person or object name")
|
||||||
|
# Reference resolution
|
||||||
|
has_unsolved_reference: bool = Field(False, description="Whether the statement has unresolved references")
|
||||||
|
has_emotional_state: bool = Field(
|
||||||
|
False,
|
||||||
|
description="Whether the statement reflects user's emotional state",
|
||||||
|
)
|
||||||
|
dialog_at: Optional[str] = Field(None, description="Absolute timestamp of the source message (ISO 8601).")
|
||||||
|
|
||||||
|
|
||||||
class ConversationContext(BaseModel):
|
class ConversationContext(BaseModel):
|
||||||
@@ -130,7 +139,9 @@ class Chunk(BaseModel):
|
|||||||
content: str = Field(..., description="The content of the chunk as a string.")
|
content: str = Field(..., description="The content of the chunk as a string.")
|
||||||
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
|
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
|
||||||
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
|
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
|
||||||
chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.")
|
files: list[tuple] = Field(default_factory=list, description="List of files in the chunk.")
|
||||||
|
chunk_embedding: Optional[List[float]] = Field(default=None, description="The embedding vector of the chunk.")
|
||||||
|
dialog_at: Optional[str] = Field(None, description="Absolute timestamp of the source message (ISO 8601).")
|
||||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
|
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -147,6 +158,7 @@ class Chunk(BaseModel):
|
|||||||
return cls(
|
return cls(
|
||||||
content=f"{message.role}: {message.msg}",
|
content=f"{message.role}: {message.msg}",
|
||||||
speaker=message.role,
|
speaker=message.role,
|
||||||
|
dialog_at=message.dialog_at,
|
||||||
metadata=metadata or {}
|
metadata=metadata or {}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -161,7 +173,6 @@ class DialogData(BaseModel):
|
|||||||
ref_id: Reference ID linking to external dialog system
|
ref_id: Reference ID linking to external dialog system
|
||||||
end_user_id: End user ID for multi-tenancy
|
end_user_id: End user ID for multi-tenancy
|
||||||
created_at: Timestamp when the dialog was created
|
created_at: Timestamp when the dialog was created
|
||||||
expired_at: Timestamp when the dialog expires (default: far future)
|
|
||||||
metadata: Additional metadata as key-value pairs
|
metadata: Additional metadata as key-value pairs
|
||||||
chunks: List of chunks from the conversation
|
chunks: List of chunks from the conversation
|
||||||
config_id: Configuration ID used to process this dialog
|
config_id: Configuration ID used to process this dialog
|
||||||
@@ -176,7 +187,6 @@ class DialogData(BaseModel):
|
|||||||
end_user_id: str = Field(default=..., description="End user ID of dialogue data")
|
end_user_id: str = Field(default=..., description="End user ID of dialogue data")
|
||||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||||
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.")
|
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.")
|
||||||
expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.")
|
|
||||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the dialog.")
|
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the dialog.")
|
||||||
chunks: List[Chunk] = Field(default_factory=list, description="A list of chunks from the conversation context.")
|
chunks: List[Chunk] = Field(default_factory=list, description="A list of chunks from the conversation context.")
|
||||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialog (integer or string)")
|
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialog (integer or string)")
|
||||||
|
|||||||
80
api/app/core/memory/models/metadata_models.py
Normal file
80
api/app/core/memory/models/metadata_models.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
"""Models for user metadata extraction.
|
||||||
|
|
||||||
|
Independent from triplet_models.py - these models are used by the
|
||||||
|
standalone metadata extraction pipeline (post-dedup async Celery task).
|
||||||
|
|
||||||
|
The field definitions align with the Jinja2 prompt template
|
||||||
|
``extract_user_metadata.jinja2``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Literal, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataExtractionResponse(BaseModel):
|
||||||
|
"""LLM 元数据提取响应结构。
|
||||||
|
|
||||||
|
字段与 extract_user_metadata.jinja2 模板的输出 JSON 一一对应。
|
||||||
|
每个字段都是字符串数组,表示本次新增的元数据条目。
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="ignore")
|
||||||
|
|
||||||
|
aliases: List[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="用户别名、昵称、称呼",
|
||||||
|
)
|
||||||
|
core_facts: List[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="用户稳定的基础事实(身份、年龄、国籍、所在地等)",
|
||||||
|
)
|
||||||
|
traits: List[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="用户稳定的人格特质、风格、行为倾向",
|
||||||
|
)
|
||||||
|
relations: List[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="用户与他人/群体/宠物/重要对象之间的长期关系",
|
||||||
|
)
|
||||||
|
goals: List[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="用户明确、稳定的长期目标或计划",
|
||||||
|
)
|
||||||
|
interests: List[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="用户稳定的兴趣、偏好、长期爱好",
|
||||||
|
)
|
||||||
|
beliefs_or_stances: List[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="用户稳定的信念、价值立场",
|
||||||
|
)
|
||||||
|
anchors: List[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="对用户有长期意义的物品、收藏、纪念物",
|
||||||
|
)
|
||||||
|
events: List[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="对用户画像有长期价值的个人经历、事件、里程碑",
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── 便捷属性 ──
|
||||||
|
|
||||||
|
METADATA_FIELDS: List[str] = [
|
||||||
|
"core_facts", "traits", "relations", "goals",
|
||||||
|
"interests", "beliefs_or_stances", "anchors", "events",
|
||||||
|
]
|
||||||
|
|
||||||
|
def has_any_metadata(self) -> bool:
|
||||||
|
"""是否提取到了任何元数据(不含 aliases)。"""
|
||||||
|
return any(
|
||||||
|
bool(getattr(self, field, []))
|
||||||
|
for field in self.METADATA_FIELDS
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_metadata_dict(self) -> dict:
|
||||||
|
"""返回 8 个元数据字段的字典(不含 aliases),用于 Neo4j 回写。"""
|
||||||
|
return {
|
||||||
|
field: getattr(self, field, [])
|
||||||
|
for field in self.METADATA_FIELDS
|
||||||
|
}
|
||||||
65
api/app/core/memory/models/service_models.py
Normal file
65
api/app/core/memory/models/service_models.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
from typing import Self
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_serializer, ConfigDict, model_validator, computed_field
|
||||||
|
|
||||||
|
from app.core.memory.enums import Neo4jNodeType, StorageType
|
||||||
|
from app.core.validators import file_validator
|
||||||
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryContext(BaseModel):
|
||||||
|
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
|
||||||
|
|
||||||
|
end_user_id: str
|
||||||
|
memory_config: MemoryConfig
|
||||||
|
storage_type: StorageType = StorageType.NEO4J
|
||||||
|
user_rag_memory_id: str | None = None
|
||||||
|
language: str = "zh"
|
||||||
|
|
||||||
|
|
||||||
|
class Memory(BaseModel):
|
||||||
|
source: Neo4jNodeType = Field(...)
|
||||||
|
score: float = Field(default=0.0)
|
||||||
|
content: str = Field(default="")
|
||||||
|
data: dict = Field(default_factory=dict)
|
||||||
|
query: str = Field(...)
|
||||||
|
id: str = Field(...)
|
||||||
|
|
||||||
|
@field_serializer("source")
|
||||||
|
def serialize_source(self, v) -> str:
|
||||||
|
return v.value
|
||||||
|
|
||||||
|
|
||||||
|
class MemorySearchResult(BaseModel):
|
||||||
|
memories: list[Memory]
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def content(self) -> str:
|
||||||
|
return "\n".join([memory.content for memory in self.memories])
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def count(self) -> int:
|
||||||
|
return len(self.memories)
|
||||||
|
|
||||||
|
def filter(self, score_threshold: float) -> Self:
|
||||||
|
self.memories = [memory for memory in self.memories if memory.score >= score_threshold]
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __add__(self, other: "MemorySearchResult") -> "MemorySearchResult":
|
||||||
|
if not isinstance(other, MemorySearchResult):
|
||||||
|
raise TypeError("")
|
||||||
|
|
||||||
|
merged = MemorySearchResult(memories=list(self.memories))
|
||||||
|
|
||||||
|
ids = {m.id for m in merged.memories}
|
||||||
|
|
||||||
|
for memory in other.memories:
|
||||||
|
if memory.id not in ids:
|
||||||
|
merged.memories.append(memory)
|
||||||
|
ids.add(memory.id)
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
@@ -37,6 +37,7 @@ class Entity(BaseModel):
|
|||||||
name: str = Field(..., description="Name of the entity")
|
name: str = Field(..., description="Name of the entity")
|
||||||
name_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the entity name")
|
name_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the entity name")
|
||||||
type: str = Field(..., description="Type/category of the entity")
|
type: str = Field(..., description="Type/category of the entity")
|
||||||
|
type_description: str = Field(default="", description="Chinese definition of the entity type from ontology")
|
||||||
description: str = Field(..., description="Description of the entity")
|
description: str = Field(..., description="Description of the entity")
|
||||||
example: str = Field(
|
example: str = Field(
|
||||||
default="",
|
default="",
|
||||||
@@ -79,6 +80,7 @@ class Triplet(BaseModel):
|
|||||||
subject_name: str = Field(..., description="Name of the subject entity")
|
subject_name: str = Field(..., description="Name of the subject entity")
|
||||||
subject_id: int = Field(..., description="ID of the subject entity")
|
subject_id: int = Field(..., description="ID of the subject entity")
|
||||||
predicate: str = Field(..., description="Relationship/predicate between subject and object")
|
predicate: str = Field(..., description="Relationship/predicate between subject and object")
|
||||||
|
predicate_description: str = Field(default="", description="Chinese definition of the predicate from ontology")
|
||||||
object_name: str = Field(..., description="Name of the object entity")
|
object_name: str = Field(..., description="Name of the object entity")
|
||||||
object_id: int = Field(..., description="ID of the object entity")
|
object_id: int = Field(..., description="ID of the object entity")
|
||||||
value: Optional[str] = Field(None, description="Additional value or context")
|
value: Optional[str] = Field(None, description="Additional value or context")
|
||||||
|
|||||||
@@ -149,3 +149,16 @@ class ExtractionPipelineConfig(BaseModel):
|
|||||||
temporal_extraction: TemporalExtractionConfig = Field(default_factory=TemporalExtractionConfig)
|
temporal_extraction: TemporalExtractionConfig = Field(default_factory=TemporalExtractionConfig)
|
||||||
deduplication: DedupConfig = Field(default_factory=DedupConfig)
|
deduplication: DedupConfig = Field(default_factory=DedupConfig)
|
||||||
forgetting_engine: ForgettingEngineConfig = Field(default_factory=ForgettingEngineConfig)
|
forgetting_engine: ForgettingEngineConfig = Field(default_factory=ForgettingEngineConfig)
|
||||||
|
# 情绪引擎(旁路模块,SidecarStepFactory 通过此字段判断是否启用)
|
||||||
|
emotion_enabled: bool = Field(default=False, description="是否启用情绪提取旁路")
|
||||||
|
|
||||||
|
# TODO 设置控制并发数量以适配LLM的QPM限流
|
||||||
|
# # 流水线 LLM 并发上限(statement + triplet 共享),防止 QPM 爆掉
|
||||||
|
# # 可通过环境变量 MAX_CONCURRENT_LLM_CALLS 覆盖
|
||||||
|
# max_concurrent_llm_calls: int = Field(
|
||||||
|
# default_factory=lambda: int(
|
||||||
|
# __import__("os").environ.get("MAX_CONCURRENT_LLM_CALLS", "5")
|
||||||
|
# ),
|
||||||
|
# ge=1, le=64,
|
||||||
|
# description="Maximum concurrent LLM calls in the extraction pipeline",
|
||||||
|
# )
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -23,15 +23,12 @@ from app.core.memory.models.ontology_extraction_models import OntologyTypeInfo,
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# 默认核心通用类型
|
# 默认核心通用类型 —— 与 ontology.md Entity Ontology 对齐的 13 类
|
||||||
DEFAULT_CORE_GENERAL_TYPES: Set[str] = {
|
DEFAULT_CORE_GENERAL_TYPES: Set[str] = {
|
||||||
"Person", "Organization", "Company", "GovernmentAgency",
|
"人物", "组织", "群体", "角色职业",
|
||||||
"Place", "Location", "City", "Country", "Building",
|
"地点设施", "物品设备", "软件平台", "识别联系信息",
|
||||||
"Event", "SportsEvent", "MusicEvent", "SocialEvent",
|
"文档媒体", "知识能力", "偏好习惯", "具体目标",
|
||||||
"Work", "Book", "Film", "Software", "Album",
|
"称呼别名",
|
||||||
"Concept", "TopicalConcept", "AcademicSubject",
|
|
||||||
"Device", "Food", "Drug", "ChemicalSubstance",
|
|
||||||
"TimePeriod", "Year",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -129,9 +126,11 @@ class OntologyTypeMerger:
|
|||||||
if type_name not in seen_names and remaining_slots > 0:
|
if type_name not in seen_names and remaining_slots > 0:
|
||||||
general_type = self.general_registry.get_type(type_name)
|
general_type = self.general_registry.get_type(type_name)
|
||||||
if general_type:
|
if general_type:
|
||||||
|
# 优先使用 rdfs:comment(完整定义),其次才是 label;
|
||||||
|
# 对中文 13 类本体,label 与 class_name 相同,单独展示无增益。
|
||||||
description = (
|
description = (
|
||||||
general_type.labels.get("zh") or
|
|
||||||
general_type.description or
|
general_type.description or
|
||||||
|
general_type.labels.get("zh") or
|
||||||
general_type.get_label("en") or
|
general_type.get_label("en") or
|
||||||
type_name
|
type_name
|
||||||
)
|
)
|
||||||
@@ -157,8 +156,8 @@ class OntologyTypeMerger:
|
|||||||
parent_type = self.general_registry.get_type(parent_name)
|
parent_type = self.general_registry.get_type(parent_name)
|
||||||
if parent_type:
|
if parent_type:
|
||||||
description = (
|
description = (
|
||||||
parent_type.labels.get("zh") or
|
|
||||||
parent_type.description or
|
parent_type.description or
|
||||||
|
parent_type.labels.get("zh") or
|
||||||
parent_name
|
parent_name
|
||||||
)
|
)
|
||||||
related_types_added.append(OntologyTypeInfo(
|
related_types_added.append(OntologyTypeInfo(
|
||||||
|
|||||||
44
api/app/core/memory/pipelines/__init__.py
Normal file
44
api/app/core/memory/pipelines/__init__.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
"""
|
||||||
|
Memory Pipelines — 记忆模块流水线编排层
|
||||||
|
|
||||||
|
每条 Pipeline 定义一个完整的业务流程,按顺序编排多个 Engine 的调用。
|
||||||
|
Pipeline 不包含业务逻辑实现,只做步骤编排和数据传递。
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name):
|
||||||
|
"""延迟导入,避免循环依赖"""
|
||||||
|
if name in ("WritePipeline", "ExtractionResult", "WriteResult"):
|
||||||
|
from app.core.memory.pipelines.write_pipeline import (
|
||||||
|
ExtractionResult,
|
||||||
|
WritePipeline,
|
||||||
|
WriteResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
_exports = {
|
||||||
|
"WritePipeline": WritePipeline,
|
||||||
|
"ExtractionResult": ExtractionResult,
|
||||||
|
"WriteResult": WriteResult,
|
||||||
|
}
|
||||||
|
return _exports[name]
|
||||||
|
if name in ("PilotWritePipeline", "PilotWriteResult"):
|
||||||
|
from app.core.memory.pipelines.pilot_write_pipeline import (
|
||||||
|
PilotWritePipeline,
|
||||||
|
PilotWriteResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
_exports = {
|
||||||
|
"PilotWritePipeline": PilotWritePipeline,
|
||||||
|
"PilotWriteResult": PilotWriteResult,
|
||||||
|
}
|
||||||
|
return _exports[name]
|
||||||
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"WritePipeline",
|
||||||
|
"ExtractionResult",
|
||||||
|
"WriteResult",
|
||||||
|
"PilotWritePipeline",
|
||||||
|
"PilotWriteResult",
|
||||||
|
]
|
||||||
54
api/app/core/memory/pipelines/base_pipeline.py
Normal file
54
api/app/core/memory/pipelines/base_pipeline.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
import uuid
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.memory.models.service_models import MemoryContext
|
||||||
|
from app.core.models import RedBearModelConfig, RedBearLLM, RedBearEmbeddings
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
from app.services.model_service import ModelApiKeyService
|
||||||
|
|
||||||
|
|
||||||
|
class ModelClientMixin(ABC):
|
||||||
|
@staticmethod
|
||||||
|
def get_llm_client(db: Session, model_id: uuid.UUID) -> RedBearLLM:
|
||||||
|
api_config = ModelApiKeyService.get_available_api_key(db, model_id)
|
||||||
|
return RedBearLLM(
|
||||||
|
RedBearModelConfig(
|
||||||
|
model_name=api_config.model_name,
|
||||||
|
provider=api_config.provider,
|
||||||
|
api_key=api_config.api_key,
|
||||||
|
base_url=api_config.api_base,
|
||||||
|
is_omni=api_config.is_omni,
|
||||||
|
support_thinking="thinking" in (api_config.capability or []),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_embedding_client(db: Session, model_id: uuid.UUID) -> RedBearEmbeddings:
|
||||||
|
config_service = MemoryConfigService(db)
|
||||||
|
embedder_client_config = config_service.get_embedder_config(str(model_id))
|
||||||
|
return RedBearEmbeddings(
|
||||||
|
RedBearModelConfig(
|
||||||
|
model_name=embedder_client_config["model_name"],
|
||||||
|
provider=embedder_client_config["provider"],
|
||||||
|
api_key=embedder_client_config["api_key"],
|
||||||
|
base_url=embedder_client_config["base_url"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BasePipeline(ABC):
|
||||||
|
def __init__(self, ctx: MemoryContext):
|
||||||
|
self.ctx = ctx
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def run(self, *args, **kwargs) -> Any:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DBRequiredPipeline(BasePipeline, ABC):
|
||||||
|
def __init__(self, ctx: MemoryContext, db: Session):
|
||||||
|
super().__init__(ctx)
|
||||||
|
self.db = db
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user