Compare commits
2323 Commits
revert-218
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
feae2f2e1e | ||
|
|
415234d4c8 | ||
|
|
e38a60e107 | ||
|
|
86eb08c73f | ||
|
|
53f1b0e586 | ||
|
|
49cc47a79a | ||
|
|
1817f52edf | ||
|
|
40633d72c3 | ||
|
|
6f10296969 | ||
|
|
89228825cf | ||
|
|
cab4deb2ff | ||
|
|
4048a10858 | ||
|
|
d6ef0f4923 | ||
|
|
75fbe44839 | ||
|
|
06597c567b | ||
|
|
28694fefb0 | ||
|
|
7a0f08148e | ||
|
|
d3058ce379 | ||
|
|
8d88df391d | ||
|
|
7621321d1b | ||
|
|
0e29b0b2a5 | ||
|
|
2fa4d29548 | ||
|
|
7bb181c1c7 | ||
|
|
a9c87b03ff | ||
|
|
720af8d261 | ||
|
|
09d32ed446 | ||
|
|
9a5ce7f7c6 | ||
|
|
531d785629 | ||
|
|
6d80d74f4a | ||
|
|
3d9882643e | ||
|
|
b4e4be1133 | ||
|
|
16926d9db5 | ||
|
|
f369a63c8d | ||
|
|
1861b0fbc9 | ||
|
|
750d4ca841 | ||
|
|
ce4a3daec7 | ||
|
|
c12d06bb07 | ||
|
|
98d8d7b261 | ||
|
|
12a08a487d | ||
|
|
f7fa33c0c4 | ||
|
|
faf8d1a51a | ||
|
|
adb7f873b5 | ||
|
|
b64bcc2c50 | ||
|
|
8baa466b31 | ||
|
|
d9de96cffa | ||
|
|
dd7f9f6cee | ||
|
|
546bfb9627 | ||
|
|
d5d81f0c4f | ||
|
|
9301eaf8df | ||
|
|
a268d0f7f1 | ||
|
|
610ae27cf9 | ||
|
|
6aef8227b1 | ||
|
|
675c7faf32 | ||
|
|
cd34d5f5ce | ||
|
|
1403b38648 | ||
|
|
b6e27da7b0 | ||
|
|
2c14344d3f | ||
|
|
141fd94513 | ||
|
|
a9413f57d1 | ||
|
|
0fc463036e | ||
|
|
ed5f98a746 | ||
|
|
422af69904 | ||
|
|
6cb48664b7 | ||
|
|
f48bb3cbee | ||
|
|
8dee2eae6a | ||
|
|
f63bcd6321 | ||
|
|
0228e6ad64 | ||
|
|
84ccb1e528 | ||
|
|
caef0fe44e | ||
|
|
21eb500680 | ||
|
|
c70f536acc | ||
|
|
5f96a6380e | ||
|
|
2c864f6337 | ||
|
|
32dfee803a | ||
|
|
4d9cfb70f7 | ||
|
|
4b0afe867a | ||
|
|
676c9a226c | ||
|
|
8f31236303 | ||
|
|
f2aedd29bc | ||
|
|
cf8db47389 | ||
|
|
62af9cd241 | ||
|
|
74be09340c | ||
|
|
cedf47b3bc | ||
|
|
0a51ab619d | ||
|
|
c7c1570d40 | ||
|
|
c556995f3a | ||
|
|
dc0a0ebcae | ||
|
|
2c2551e15c | ||
|
|
be10bab763 | ||
|
|
89f2f9a045 | ||
|
|
f4c168d904 | ||
|
|
1191f0f54e | ||
|
|
58710bc800 | ||
|
|
b33f5951d8 | ||
|
|
279353e1ce | ||
|
|
2d120a64b1 | ||
|
|
0f7a7263eb | ||
|
|
767eb5e6f2 | ||
|
|
5c89acced6 | ||
|
|
9fdb952396 | ||
|
|
fb23c34475 | ||
|
|
4619b40d03 | ||
|
|
5f39d9a208 | ||
|
|
f6cf53f81c | ||
|
|
08a455f6b3 | ||
|
|
5960b5add8 | ||
|
|
7ac0eff0b8 | ||
|
|
c818855bab | ||
|
|
fe2c975d61 | ||
|
|
8deb69b595 | ||
|
|
404ce9f9ba | ||
|
|
aac89b172f | ||
|
|
bf9a3503de | ||
|
|
5c836c90c9 | ||
|
|
fc7d9df3cb | ||
|
|
17905196c9 | ||
|
|
b8009074d5 | ||
|
|
09393b2326 | ||
|
|
eaa66ba71a | ||
|
|
c59a97afba | ||
|
|
9480a61229 | ||
|
|
7ffd250b08 | ||
|
|
52bccfaede | ||
|
|
27f6d18a05 | ||
|
|
2a514a9e04 | ||
|
|
9233e74f36 | ||
|
|
46dfd92a9f | ||
|
|
5f33cec8ad | ||
|
|
334502f06b | ||
|
|
b0bb5e883c | ||
|
|
b9cfc47e1e | ||
|
|
4a4391a19c | ||
|
|
7ccc1068ff | ||
|
|
f650406869 | ||
|
|
7193eed9e3 | ||
|
|
ec6b08cde2 | ||
|
|
f93ec8d609 | ||
|
|
fedb02caf7 | ||
|
|
ae770fb131 | ||
|
|
f8ef32c1dd | ||
|
|
c5ae82c3c2 | ||
|
|
2a03f70287 | ||
|
|
124e8d0639 | ||
|
|
6f323f2435 | ||
|
|
881d74d29d | ||
|
|
903b4f2a6e | ||
|
|
7cd76444f1 | ||
|
|
7dc35bb3fb | ||
|
|
b488590537 | ||
|
|
aa56ad15f9 | ||
|
|
cda20ac3f1 | ||
|
|
d6af459ca8 | ||
|
|
2f7fd85ab1 | ||
|
|
398aebd0c5 | ||
|
|
eaa4058c56 | ||
|
|
21b25bfef7 | ||
|
|
a61acbef93 | ||
|
|
a90757745d | ||
|
|
749083bdbe | ||
|
|
b882863907 | ||
|
|
9159d5cbb0 | ||
|
|
7552a5c8fa | ||
|
|
537f6a1812 | ||
|
|
1ea0f308ba | ||
|
|
f37e9b444b | ||
|
|
5304117ae2 | ||
|
|
77c023102e | ||
|
|
ad24119b2d | ||
|
|
ea6fa154e0 | ||
|
|
158507cf8e | ||
|
|
5e0d30dde8 | ||
|
|
363d775270 | ||
|
|
ad4121b0d8 | ||
|
|
71f62bb591 | ||
|
|
46504fda30 | ||
|
|
1cfad37c64 | ||
|
|
129c9cbb3c | ||
|
|
acafceafb0 | ||
|
|
aff94a766a | ||
|
|
42ebba9090 | ||
|
|
1e95cb6604 | ||
|
|
8b3e3c8044 | ||
|
|
671df83bcd | ||
|
|
8bb5a66401 | ||
|
|
4c9f327833 | ||
|
|
866a5552d4 | ||
|
|
93d4607b14 | ||
|
|
9533a9a693 | ||
|
|
6bd528eace | ||
|
|
2b5bece9b6 | ||
|
|
ea0e65f1ec | ||
|
|
cb2a7aa60a | ||
|
|
402c8aef5d | ||
|
|
eb98a69a84 | ||
|
|
152a84aff3 | ||
|
|
a106f4e3cd | ||
|
|
9c20301a52 | ||
|
|
c5c8be89ed | ||
|
|
30aed72b74 | ||
|
|
35c2d9d0d3 | ||
|
|
27275eee43 | ||
|
|
cde02026d3 | ||
|
|
1a826c0026 | ||
|
|
8cab49c2b1 | ||
|
|
7eb21f677f | ||
|
|
6de5d413c4 | ||
|
|
a2df14f658 | ||
|
|
aecb0f6497 | ||
|
|
83b7c6870d | ||
|
|
74157adb12 | ||
|
|
8011610acc | ||
|
|
f1dc507b5c | ||
|
|
f3ac7e084d | ||
|
|
ba3743f9f1 | ||
|
|
20ddc76a4d | ||
|
|
84ca98555d | ||
|
|
7e6d17e4e3 | ||
|
|
7f3c48ce2a | ||
|
|
e5c16a2a24 | ||
|
|
8887600f7d | ||
|
|
df6eb74b28 | ||
|
|
b4b9974064 | ||
|
|
ff65dee754 | ||
|
|
2c2ed0ebf3 | ||
|
|
d60f838fb8 | ||
|
|
817aa78d03 | ||
|
|
4c73887a48 | ||
|
|
94d2d975ee | ||
|
|
d59990d326 | ||
|
|
3227c25b07 | ||
|
|
dc3207b1d3 | ||
|
|
08b5c7bc8a | ||
|
|
688503a1ca | ||
|
|
475e573891 | ||
|
|
b03300c804 | ||
|
|
a5d07ee66d | ||
|
|
10a655772f | ||
|
|
aeeb18581d | ||
|
|
fb1160e833 | ||
|
|
c448cf0660 | ||
|
|
c50969dea4 | ||
|
|
3a1d222c42 | ||
|
|
10a91ec5cb | ||
|
|
b4812cdac1 | ||
|
|
1744b045fb | ||
|
|
5289b3a2cb | ||
|
|
48f3d9b105 | ||
|
|
559b4bef6b | ||
|
|
4a39fd5f46 | ||
|
|
b22c15cccc | ||
|
|
a2f85b3d98 | ||
|
|
7f1cf13b23 | ||
|
|
d4129edcf5 | ||
|
|
ab2a58d68e | ||
|
|
a28b62763e | ||
|
|
86540a81d1 | ||
|
|
dcd874fecd | ||
|
|
bbd85733b8 | ||
|
|
22c5f12657 | ||
|
|
7b5d7696cb | ||
|
|
cb33724673 | ||
|
|
48b56a3d88 | ||
|
|
83d0fb9387 | ||
|
|
bb964c1ed8 | ||
|
|
81d58b001f | ||
|
|
99bc84a9f2 | ||
|
|
37dbe0f95b | ||
|
|
d4a1904b19 | ||
|
|
ecdad19f54 | ||
|
|
fb93c509f4 | ||
|
|
f597139913 | ||
|
|
113ae59f84 | ||
|
|
62c721bdf6 | ||
|
|
4cbb0cee2f | ||
|
|
8c586935a8 | ||
|
|
d5272af76f | ||
|
|
cf8912e929 | ||
|
|
327c1904b1 | ||
|
|
58c13aaeb4 | ||
|
|
377ddd2b9b | ||
|
|
52f7ea7456 | ||
|
|
b02baedd2c | ||
|
|
f3c3b6255e | ||
|
|
b659e2a6e1 | ||
|
|
e15e32cc7b | ||
|
|
04d20dc094 | ||
|
|
b8123fc84c | ||
|
|
5a17b7fd0d | ||
|
|
e3d0602850 | ||
|
|
696b2d2417 | ||
|
|
a5613314b8 | ||
|
|
e87529876c | ||
|
|
7bb3e65fb7 | ||
|
|
5ada7e77fc | ||
|
|
79b7da44e2 | ||
|
|
26a3d8a41b | ||
|
|
2380cd55ef | ||
|
|
a105df33ab | ||
|
|
749cf79581 | ||
|
|
0dd8cc5d43 | ||
|
|
fd90a4c2ad | ||
|
|
b302a94620 | ||
|
|
c96dc53534 | ||
|
|
f883c1469d | ||
|
|
ddfd81259a | ||
|
|
e015455fb8 | ||
|
|
915cb54f21 | ||
|
|
cada860a16 | ||
|
|
e1f8ad871b | ||
|
|
e205aaa6e6 | ||
|
|
62edafcebe | ||
|
|
ccdf7ae81d | ||
|
|
643f69bb90 | ||
|
|
73fbc19747 | ||
|
|
7ba0726473 | ||
|
|
8c6b65db12 | ||
|
|
5ce0bdb0f5 | ||
|
|
a01525e239 | ||
|
|
b59e2b5bcd | ||
|
|
5a2fe738dc | ||
|
|
f04412c455 | ||
|
|
db6fc5d2db | ||
|
|
b6aca0b1e7 | ||
|
|
4fd7395464 | ||
|
|
78ba313262 | ||
|
|
d35bc3a2cf | ||
|
|
d5c8d16e64 | ||
|
|
09496bd7b9 | ||
|
|
171f25a350 | ||
|
|
c7230659e3 | ||
|
|
502d87e88d | ||
|
|
1faa258e23 | ||
|
|
bef6a50deb | ||
|
|
cc12ec3fa8 | ||
|
|
466864afe3 | ||
|
|
643a3fbe09 | ||
|
|
e0d7a5a91f | ||
|
|
5ac2d5602e | ||
|
|
f4c3974956 | ||
|
|
71e5b6586a | ||
|
|
bfb723a468 | ||
|
|
61f2e44bd5 | ||
|
|
ed765b7c26 | ||
|
|
3018d186f7 | ||
|
|
2e1470cb52 | ||
|
|
737858731b | ||
|
|
d072eb1af7 | ||
|
|
2716a55c7f | ||
|
|
daaee63bd5 | ||
|
|
e3c643b659 | ||
|
|
017efdc320 | ||
|
|
29aef4527c | ||
|
|
d9cb2b511b | ||
|
|
18be1a9f89 | ||
|
|
49e0801d15 | ||
|
|
dde7ea9039 | ||
|
|
3e48d620b2 | ||
|
|
5262aedab9 | ||
|
|
441b21774d | ||
|
|
d6dd038167 | ||
|
|
47c242e513 | ||
|
|
811193dd75 | ||
|
|
797780824c | ||
|
|
75e95bab01 | ||
|
|
e7a400bb96 | ||
|
|
28ca4d1734 | ||
|
|
5e6490213d | ||
|
|
3b359df02f | ||
|
|
fcf3071cb0 | ||
|
|
1294aabbcc | ||
|
|
3c2a78a449 | ||
|
|
4f0e5d0866 | ||
|
|
7a84ee33c6 | ||
|
|
e3265e4ba3 | ||
|
|
3e7a004599 | ||
|
|
fa1e5ee43c | ||
|
|
c72a6fd724 | ||
|
|
0965008210 | ||
|
|
bcadd2a6f1 | ||
|
|
e4f306dabb | ||
|
|
b5ec5c2cea | ||
|
|
e539b3eeb7 | ||
|
|
7f8765b815 | ||
|
|
72b39c6fa3 | ||
|
|
9032f50a19 | ||
|
|
aa683efaa0 | ||
|
|
2d9986f902 | ||
|
|
06075ffef5 | ||
|
|
a7336b0829 | ||
|
|
0d16e168e7 | ||
|
|
a882e5e5c4 | ||
|
|
c614bb5be7 | ||
|
|
1ff0f3ebfd | ||
|
|
bafcb5c545 | ||
|
|
f8d27fada6 | ||
|
|
90365cd026 | ||
|
|
d96c7b88f0 | ||
|
|
99559621c5 | ||
|
|
926f65a1ff | ||
|
|
b20971dc95 | ||
|
|
1ff0274027 | ||
|
|
8495aa5dde | ||
|
|
d8ef7a8e02 | ||
|
|
7a4a02b2bb | ||
|
|
8f623a66c8 | ||
|
|
77ed9faea1 | ||
|
|
1ff3748935 | ||
|
|
f023c43f80 | ||
|
|
60124e3232 | ||
|
|
70d4e79de1 | ||
|
|
59b5a1bcf2 | ||
|
|
62f345b3de | ||
|
|
a3f0415cd3 | ||
|
|
2450fe3afe | ||
|
|
52e726eabc | ||
|
|
7ca80b5d01 | ||
|
|
9470dd2f1e | ||
|
|
10f1089198 | ||
|
|
095f4e3001 | ||
|
|
ef8c7093b5 | ||
|
|
05ea372776 | ||
|
|
2b067ce08a | ||
|
|
b63cff2993 | ||
|
|
5bb9ce9018 | ||
|
|
aa581a9083 | ||
|
|
ac51ccaf1f | ||
|
|
dca3173ed9 | ||
|
|
5eaedaad77 | ||
|
|
bd955569b3 | ||
|
|
7a2a941ac4 | ||
|
|
19fa8314e4 | ||
|
|
cba24e58db | ||
|
|
62355186ef | ||
|
|
82faedc972 | ||
|
|
11ea486f82 | ||
|
|
efdee32f85 | ||
|
|
988d101e93 | ||
|
|
418f9f4dba | ||
|
|
520ee7c132 | ||
|
|
72be9f75f9 | ||
|
|
2b52b32b96 | ||
|
|
a96f20ee05 | ||
|
|
b8acc0a32f | ||
|
|
e1cf3bb3d2 | ||
|
|
6f66c9727f | ||
|
|
3beca641e1 | ||
|
|
b8507a1df6 | ||
|
|
0f28d54c43 | ||
|
|
0afc38e7ef | ||
|
|
07fd85c342 | ||
|
|
4c2a1e6d1d | ||
|
|
7cfb6ace22 | ||
|
|
91cc20d589 | ||
|
|
f01ca51896 | ||
|
|
f4a63f7d55 | ||
|
|
0019f3acfd | ||
|
|
3fe90a5e13 | ||
|
|
bc14c94407 | ||
|
|
a21dad70ed | ||
|
|
807a4e715d | ||
|
|
58d18b476c | ||
|
|
5e5927a0b9 | ||
|
|
7869121382 | ||
|
|
7c0fb624d9 | ||
|
|
af83980f99 | ||
|
|
cf0d11208c | ||
|
|
87d1630230 | ||
|
|
50392384e7 | ||
|
|
9a926a8398 | ||
|
|
e5e6699168 | ||
|
|
068e2bfb7e | ||
|
|
4ce6fede67 | ||
|
|
8497c955f9 | ||
|
|
72fe3962cf | ||
|
|
c253968aa8 | ||
|
|
d517bceda2 | ||
|
|
412183c359 | ||
|
|
90e8e90528 | ||
|
|
fd05c000f6 | ||
|
|
627d6a0381 | ||
|
|
807dee8460 | ||
|
|
ac7d39524e | ||
|
|
cd018814fe | ||
|
|
e0b7e95af6 | ||
|
|
3a62d50048 | ||
|
|
0e60da6d8a | ||
|
|
39e94eb3ea | ||
|
|
3e0f59adc6 | ||
|
|
660cd2fadb | ||
|
|
6f1bb43eab | ||
|
|
61b5627505 | ||
|
|
af6392fb09 | ||
|
|
84b1a95313 | ||
|
|
8b21dab255 | ||
|
|
fc5ce63e44 | ||
|
|
15a863b41a | ||
|
|
5226c5b79d | ||
|
|
27e9f9968d | ||
|
|
d38612a10d | ||
|
|
32c71dcd89 | ||
|
|
428e7ebaa5 | ||
|
|
57833689d9 | ||
|
|
384a67482c | ||
|
|
7842435321 | ||
|
|
33c4c5d31b | ||
|
|
ca4f7aa65d | ||
|
|
b875626f18 | ||
|
|
130684cac0 | ||
|
|
5adff38bda | ||
|
|
62e0b2730b | ||
|
|
55b2e05ba8 | ||
|
|
562ca6c1f1 | ||
|
|
e298b38de9 | ||
|
|
a7b8ba0c66 | ||
|
|
460c86cd94 | ||
|
|
33a1c178ff | ||
|
|
c81612e6d3 | ||
|
|
9f9ac69f97 | ||
|
|
0516822d42 | ||
|
|
b598171a3d | ||
|
|
a4ea7f0385 | ||
|
|
32ae60fc65 | ||
|
|
6b272c5b44 | ||
|
|
2782d0661f | ||
|
|
ea2f5e61c9 | ||
|
|
5975d70bf9 | ||
|
|
e0546e01ef | ||
|
|
70aab94fc3 | ||
|
|
0f50537d7d | ||
|
|
b7c1ce261b | ||
|
|
edac6a164e | ||
|
|
1503b242ea | ||
|
|
3ff44f0108 | ||
|
|
18fd48505d | ||
|
|
807ddce5cd | ||
|
|
62fb6c79a0 | ||
|
|
cc373b2864 | ||
|
|
f2d7479229 | ||
|
|
ae1909b7e9 | ||
|
|
8e397b83b6 | ||
|
|
b0aaa12340 | ||
|
|
5eb65e7ad8 | ||
|
|
cb5610e8b1 | ||
|
|
6bb01119d0 | ||
|
|
c16e832081 | ||
|
|
e3d50c5c55 | ||
|
|
e64aadce95 | ||
|
|
bad6087c25 | ||
|
|
b04c05f4a4 | ||
|
|
5e372627f7 | ||
|
|
29611738ce | ||
|
|
de846c05ab | ||
|
|
6475387af8 | ||
|
|
b330bdba29 | ||
|
|
bed279c604 | ||
|
|
9eaf779e67 | ||
|
|
1fbccd98a7 | ||
|
|
931b800bb6 | ||
|
|
d76bb36b9f | ||
|
|
3c93409f7f | ||
|
|
9451a08e7f | ||
|
|
bc49bd2a43 | ||
|
|
4bfc9ca991 | ||
|
|
1ba60401af | ||
|
|
236e8973ac | ||
|
|
ca6cc8ae63 | ||
|
|
dd2cc89c62 | ||
|
|
a87bba93c2 | ||
|
|
a153fdb7cb | ||
|
|
4eed393db5 | ||
|
|
dc8e432719 | ||
|
|
16a0099e27 | ||
|
|
c4cf639bbc | ||
|
|
f91431a70d | ||
|
|
0102ad3a30 | ||
|
|
ebe298b71d | ||
|
|
a486ca7857 | ||
|
|
dd40e5df5f | ||
|
|
8e1ec1bae6 | ||
|
|
1f9c4919be | ||
|
|
065182ad5c | ||
|
|
90ec8db0d8 | ||
|
|
78baf1c60a | ||
|
|
072c94cccb | ||
|
|
a66030d1b3 | ||
|
|
a90ceaf5a2 | ||
|
|
725f2f5146 | ||
|
|
a2c3357e80 | ||
|
|
acbc954e6f | ||
|
|
77032583ab | ||
|
|
ef533d27ac | ||
|
|
ca1a2c7b9e | ||
|
|
145fa398dd | ||
|
|
274430b2c9 | ||
|
|
e9972834fe | ||
|
|
1ecc04fee7 | ||
|
|
78cd1f69a3 | ||
|
|
aabd9a1b57 | ||
|
|
b9439b337a | ||
|
|
eb9f4f39f1 | ||
|
|
baa4b56426 | ||
|
|
49bcc6131b | ||
|
|
0d3f6f1e14 | ||
|
|
84536925c6 | ||
|
|
b22a5a9f12 | ||
|
|
b8825a83dd | ||
|
|
08b4d5c1cf | ||
|
|
a5dfc472d3 | ||
|
|
48d29bcc63 | ||
|
|
856c6f6d78 | ||
|
|
bfc47ad738 | ||
|
|
841b6abb33 | ||
|
|
8a1114a1a7 | ||
|
|
be8c481d6d | ||
|
|
5d439346a1 | ||
|
|
ed753caaf7 | ||
|
|
9a931389ea | ||
|
|
b9d469b6e3 | ||
|
|
e817cfd292 | ||
|
|
af86cb3556 | ||
|
|
e48b146e60 | ||
|
|
07b66a9801 | ||
|
|
c3ee3c4af9 | ||
|
|
cd8229f370 | ||
|
|
9a4a614fc8 | ||
|
|
0b5a030e46 | ||
|
|
675d0fc5ef | ||
|
|
6291c28f0a | ||
|
|
30b512e554 | ||
|
|
33c73c6c6f | ||
|
|
072d118935 | ||
|
|
2e7ebf174b | ||
|
|
3ece83d419 | ||
|
|
9c1c232b2e | ||
|
|
bfc98efc9d | ||
|
|
cfbf83f71e | ||
|
|
a43e8fa594 | ||
|
|
6c8d0d9d64 | ||
|
|
bd2a3bd7ef | ||
|
|
1f72b8aa70 | ||
|
|
9bb32888a2 | ||
|
|
caee5d214e | ||
|
|
38f3455bab | ||
|
|
d60cb423a4 | ||
|
|
b20a65ce29 | ||
|
|
99862db7a0 | ||
|
|
00a8099857 | ||
|
|
117e29fbe3 | ||
|
|
32740e8159 | ||
|
|
bc5ea2d421 | ||
|
|
d34bf4bc89 | ||
|
|
c4ff1a325b | ||
|
|
d1f0258065 | ||
|
|
5db59bc9cf | ||
|
|
a711635694 | ||
|
|
15b3ce3dd5 | ||
|
|
9cc19047b4 | ||
|
|
2e8e63878e | ||
|
|
38955d7d45 | ||
|
|
b6167d4e94 | ||
|
|
7890970a39 | ||
|
|
203732de1d | ||
|
|
4961e7df79 | ||
|
|
fa4be10e51 | ||
|
|
1b52850526 | ||
|
|
1732fc7af5 | ||
|
|
a52e2137b7 | ||
|
|
377f79773d | ||
|
|
cae87de6ef | ||
|
|
63235de42b | ||
|
|
106a32bc3a | ||
|
|
dcb7b496d3 | ||
|
|
2f0bb793d8 | ||
|
|
010eff17cf | ||
|
|
0b47194f12 | ||
|
|
9ff3a3d5f7 | ||
|
|
abbd92b74c | ||
|
|
960ee9f2df | ||
|
|
1c133d3d6c | ||
|
|
d270d25a99 | ||
|
|
8abd59b26e | ||
|
|
bd48b4fdbe | ||
|
|
9535545947 | ||
|
|
aad6955709 | ||
|
|
18703919a8 | ||
|
|
9f2cd6afae | ||
|
|
d1beb9e5d5 | ||
|
|
2c7aaebdd5 | ||
|
|
be38c9e385 | ||
|
|
1aec7115a5 | ||
|
|
9facb513b2 | ||
|
|
9bce14be4e | ||
|
|
59f5c7a8bb | ||
|
|
12f3a3ed77 | ||
|
|
8b9eb81d36 | ||
|
|
4fb3d6992c | ||
|
|
370a668ead | ||
|
|
daaad51357 | ||
|
|
6eca5f6cdf | ||
|
|
f61f86f8fe | ||
|
|
57eb5aa967 | ||
|
|
1305a08c86 | ||
|
|
cf519738f4 | ||
|
|
cdebe014cf | ||
|
|
853ce6f4e1 | ||
|
|
9cbe9d5edc | ||
|
|
767f9ab17c | ||
|
|
7b5b2ab31a | ||
|
|
924d10ac5b | ||
|
|
0470a71d03 | ||
|
|
378b110d91 | ||
|
|
5f7db778b5 | ||
|
|
0d15457299 | ||
|
|
ad4ddea977 | ||
|
|
75bb96d4e7 | ||
|
|
68fdf5d76f | ||
|
|
258c19f9e0 | ||
|
|
386ed2b914 | ||
|
|
264183cec2 | ||
|
|
9561578a2a | ||
|
|
7ce29019f7 | ||
|
|
99ff07ccac | ||
|
|
e77a1a92fd | ||
|
|
d3cd66fc6e | ||
|
|
b95a627424 | ||
|
|
c9ca5df05c | ||
|
|
70c3c7dd74 | ||
|
|
b482822629 | ||
|
|
8f609ba29c | ||
|
|
a1ef5146d7 | ||
|
|
8b997b422a | ||
|
|
6d6338eb06 | ||
|
|
b5c5863b39 | ||
|
|
ab45b7abac | ||
|
|
2dfc3b25d8 | ||
|
|
3ea42ac27f | ||
|
|
fff5e0e8b8 | ||
|
|
fe29141437 | ||
|
|
17d3c81c02 | ||
|
|
ef626951bc | ||
|
|
4533644e13 | ||
|
|
ca255304d9 | ||
|
|
b40f4829cb | ||
|
|
52ae914e17 | ||
|
|
baf02e4faa | ||
|
|
87c2419186 | ||
|
|
2ad25c48d2 | ||
|
|
75e8caf441 | ||
|
|
4d6038c3cc | ||
|
|
d4450658a8 | ||
|
|
02660c7c97 | ||
|
|
3ceb2efeaf | ||
|
|
e134b96333 | ||
|
|
3ea57d1cb0 | ||
|
|
4a71484151 | ||
|
|
db8b3416a6 | ||
|
|
4df41966fe | ||
|
|
2d6cde157e | ||
|
|
abc27c8372 | ||
|
|
dbe387f666 | ||
|
|
5e70d436a8 | ||
|
|
b7198f1abd | ||
|
|
5c87a2beeb | ||
|
|
3419bb137a | ||
|
|
a00684c67d | ||
|
|
6e7c641fd4 | ||
|
|
876c39b1b0 | ||
|
|
0c677701c0 | ||
|
|
4974f9aa98 | ||
|
|
c90b58bbcd | ||
|
|
d6a243f1be | ||
|
|
418114ef72 | ||
|
|
ceed61167f | ||
|
|
83774d7443 | ||
|
|
052c7c19b3 | ||
|
|
d42db0ca33 | ||
|
|
e15af5a2ba | ||
|
|
8b44b2cd61 | ||
|
|
9d91453200 | ||
|
|
ea8db7cd90 | ||
|
|
d60f16df1b | ||
|
|
3cca35a74f | ||
|
|
8dd24533bf | ||
|
|
ed90405439 | ||
|
|
533000030f | ||
|
|
a58ac385b1 | ||
|
|
91b7f2a980 | ||
|
|
891cfc2704 | ||
|
|
f7e89af9d2 | ||
|
|
afbd8c9b4f | ||
|
|
09b3b01d37 | ||
|
|
e3dcbed5f9 | ||
|
|
c7b51e7ad8 | ||
|
|
e9ad13504a | ||
|
|
c0cd2373c0 | ||
|
|
6e757ae9e2 | ||
|
|
64a73c41d6 | ||
|
|
dae7431075 | ||
|
|
643bbbcf5c | ||
|
|
6702e86536 | ||
|
|
13e35ed122 | ||
|
|
ab2bdfa088 | ||
|
|
8285250096 | ||
|
|
e59a215078 | ||
|
|
c89eccf8fe | ||
|
|
5703fc0cb4 | ||
|
|
7acb7045f0 | ||
|
|
3aed5c447a | ||
|
|
13352178ad | ||
|
|
f9f302dd2a | ||
|
|
8f216db353 | ||
|
|
9f6026492d | ||
|
|
b699b746a5 | ||
|
|
6095170169 | ||
|
|
173697e86a | ||
|
|
5c11da6a2e | ||
|
|
96214c433f | ||
|
|
167c915631 | ||
|
|
f485398768 | ||
|
|
289b1989e5 | ||
|
|
8224848ce1 | ||
|
|
c43d258455 | ||
|
|
c3e5c8b8bb | ||
|
|
930cadcaa8 | ||
|
|
57b6b34567 | ||
|
|
f878846364 | ||
|
|
7dce63dc0b | ||
|
|
03bc8ee7f5 | ||
|
|
4aefb01b0b | ||
|
|
4e9b5736b1 | ||
|
|
46fa99a8b8 | ||
|
|
17ea92357d | ||
|
|
bd70a8b812 | ||
|
|
ad5dc3c138 | ||
|
|
e37b1b01ca | ||
|
|
e659ca9fa2 | ||
|
|
758be0087f | ||
|
|
200c13b59f | ||
|
|
32f6886000 | ||
|
|
7fbf3e8873 | ||
|
|
3026702000 | ||
|
|
8677db114b | ||
|
|
2597a1f532 | ||
|
|
4298cd7d06 | ||
|
|
8197f9db35 | ||
|
|
3da6331515 | ||
|
|
539999131c | ||
|
|
d0ca5c8b27 | ||
|
|
ee6b8ffa62 | ||
|
|
14838dc064 | ||
|
|
e017870f44 | ||
|
|
9730c5ce0f | ||
|
|
bca43fcc75 | ||
|
|
f30260939a | ||
|
|
8ba0a74473 | ||
|
|
4f69224cfd | ||
|
|
6f7fee18c9 | ||
|
|
7fd00009a2 | ||
|
|
4534b65d6a | ||
|
|
cc58c7333c | ||
|
|
c936277507 | ||
|
|
701df40270 | ||
|
|
b724dbe53a | ||
|
|
ac7c891ded | ||
|
|
a5bce221bd | ||
|
|
3ed6f49bb0 | ||
|
|
a416a6b2bd | ||
|
|
35be03803f | ||
|
|
6427018ffb | ||
|
|
06b823ff96 | ||
|
|
0fdb489227 | ||
|
|
f6394a791e | ||
|
|
4bfd4944d0 | ||
|
|
7faf291ec3 | ||
|
|
3d291e3c23 | ||
|
|
b35bedc730 | ||
|
|
4d39cdf464 | ||
|
|
a874cc70a4 | ||
|
|
2319432182 | ||
|
|
7556468c6e | ||
|
|
91d38c0648 | ||
|
|
df3d58d388 | ||
|
|
80856e3c92 | ||
|
|
8c6f395818 | ||
|
|
2f4f7219e3 | ||
|
|
4c5183eddc | ||
|
|
dfc0ee9424 | ||
|
|
8dbb067b83 | ||
|
|
1df3fc416a | ||
|
|
6223b80cc4 | ||
|
|
68489f1b28 | ||
|
|
477853b04e | ||
|
|
863be50aaf | ||
|
|
d72d57f966 | ||
|
|
5b940e5f1a | ||
|
|
9ae1d2f0d9 | ||
|
|
318f1be107 | ||
|
|
4cab6317de | ||
|
|
81bfc9af36 | ||
|
|
189013f0f8 | ||
|
|
6f5bcd18a4 | ||
|
|
c7ef97c7a6 | ||
|
|
4d4a780ab7 | ||
|
|
9d2f3aa8f9 | ||
|
|
f2c9902a07 | ||
|
|
2525f8795c | ||
|
|
b7a03a844f | ||
|
|
c13c3846d1 | ||
|
|
30b5db1e98 | ||
|
|
f92eb9f45a | ||
|
|
a136d44e27 | ||
|
|
65b2f9e6e1 | ||
|
|
5275a274c3 | ||
|
|
4f09c4fbb3 | ||
|
|
7a3220aff5 | ||
|
|
14a32778f7 | ||
|
|
2a12cb04bf | ||
|
|
1e986c641f | ||
|
|
38c6c7f053 | ||
|
|
7c0743eb8f | ||
|
|
e981f066a3 | ||
|
|
db14d40fb3 | ||
|
|
e8d575fd0b | ||
|
|
a7285e35ad | ||
|
|
c4461c4917 | ||
|
|
2df615eca0 | ||
|
|
504e5ba61e | ||
|
|
0bae290e0c | ||
|
|
294ee49d59 | ||
|
|
26c36f70e6 | ||
|
|
c4b83b1f9c | ||
|
|
14413fd413 | ||
|
|
caab58dd2f | ||
|
|
0e899bea05 | ||
|
|
1794f8f209 | ||
|
|
85daf576e9 | ||
|
|
56fd5680cf | ||
|
|
0380c13a3b | ||
|
|
9ddc523f91 | ||
|
|
491ef27b8a | ||
|
|
edd115582f | ||
|
|
45eef12842 | ||
|
|
49364802c2 | ||
|
|
8873078006 | ||
|
|
2b9fd33bc8 | ||
|
|
e86d679ae5 | ||
|
|
def7367e33 | ||
|
|
54cff5861a | ||
|
|
dc2a73155b | ||
|
|
1856c55c04 | ||
|
|
522eb569f1 | ||
|
|
9df41456f6 | ||
|
|
04c54081c8 | ||
|
|
1c49e3c167 | ||
|
|
fb6ce839d2 | ||
|
|
c7275dccac | ||
|
|
d62b484d71 | ||
|
|
8ff1c6bd08 | ||
|
|
3dcf901043 | ||
|
|
d6dfc2cb12 | ||
|
|
8a3032ce4a | ||
|
|
391c60c812 | ||
|
|
b739b032d9 | ||
|
|
3dc863cabf | ||
|
|
611b14dfea | ||
|
|
de6e2f54d2 | ||
|
|
89d188fbf3 | ||
|
|
6bba574ca6 | ||
|
|
9cbffd6408 | ||
|
|
4d2ad5757c | ||
|
|
cd0ca9cae4 | ||
|
|
3369b702e4 | ||
|
|
cbec2c1356 | ||
|
|
5987eee0a8 | ||
|
|
6348304b7d | ||
|
|
59f8010519 | ||
|
|
9308c6efae | ||
|
|
2f78b7cf5e | ||
|
|
f86448f4bf | ||
|
|
48e2e613bb | ||
|
|
1060074740 | ||
|
|
95b7df7e38 | ||
|
|
fd1634eec4 | ||
|
|
efeead41b2 | ||
|
|
a3428c2435 | ||
|
|
31b8a3764e | ||
|
|
2ff81ba101 | ||
|
|
93deb286a3 | ||
|
|
7bd97bf6d3 | ||
|
|
2d1a1b4a1f | ||
|
|
503c890d93 | ||
|
|
1f73501786 | ||
|
|
eef13cb717 | ||
|
|
c70ac1339e | ||
|
|
24c13d408e | ||
|
|
338d7f1065 | ||
|
|
27672cfaa0 | ||
|
|
4dbb2bf2e2 | ||
|
|
37bc4beab4 | ||
|
|
6056952936 | ||
|
|
31085ed678 | ||
|
|
dce7206c44 | ||
|
|
c17a2dad2d | ||
|
|
e8ae46b286 | ||
|
|
78316de411 | ||
|
|
c205e7d20e | ||
|
|
81f3b50200 | ||
|
|
e3795fe1ed | ||
|
|
72a2f2a7e8 | ||
|
|
0f092e08f4 | ||
|
|
8e7603bcc4 | ||
|
|
035cc17264 | ||
|
|
a079358028 | ||
|
|
cf26c9f39c | ||
|
|
fa29a39920 | ||
|
|
2146c555d2 | ||
|
|
240f1d431b | ||
|
|
9f947a3395 | ||
|
|
bf5c4628c3 | ||
|
|
911d5e0b34 | ||
|
|
bd31aa5abf | ||
|
|
0775fad5f0 | ||
|
|
726148d7ee | ||
|
|
0f1b1d7d10 | ||
|
|
fabc8936ab | ||
|
|
11aa2e1f9e | ||
|
|
ca654cca74 | ||
|
|
bd1f649bd0 | ||
|
|
06de54ebfd | ||
|
|
ea00747c66 | ||
|
|
3db031891e | ||
|
|
fb6ca3909a | ||
|
|
929afb1770 | ||
|
|
6235584b2e | ||
|
|
0b1ea33b41 | ||
|
|
3929f811b8 | ||
|
|
7c6e48b04e | ||
|
|
b1b53f6b1d | ||
|
|
551a2b59a5 | ||
|
|
9a765ac71e | ||
|
|
83e26732de | ||
|
|
52fdfc7744 | ||
|
|
4e544325a0 | ||
|
|
99a2f396fd | ||
|
|
0157c9d262 | ||
|
|
5ddacab162 | ||
|
|
a51e34852c | ||
|
|
fcc81ac025 | ||
|
|
36f670b2e9 | ||
|
|
cbcbc8822c | ||
|
|
69c001bf84 | ||
|
|
aa2d1e7a35 | ||
|
|
39b2f3ba0e | ||
|
|
43064ab71b | ||
|
|
4144f0b9b5 | ||
|
|
08f0be17ce | ||
|
|
2915e464bf | ||
|
|
152559ae46 | ||
|
|
1f531f1ace | ||
|
|
7ec947189c | ||
|
|
b4615bacdc | ||
|
|
e849fed5c1 | ||
|
|
0f5cae4590 | ||
|
|
1c3029f360 | ||
|
|
e2411e0bdd | ||
|
|
7af88b19cf | ||
|
|
c3f8dbd4bc | ||
|
|
c1e48fde86 | ||
|
|
f644c84fbb | ||
|
|
d0afce27c4 | ||
|
|
b84aba71e7 | ||
|
|
2e481df465 | ||
|
|
a322ec4fd5 | ||
|
|
bdbf9c0609 | ||
|
|
ef7d59e442 | ||
|
|
27b782e12a | ||
|
|
37a22fbfa9 | ||
|
|
d798d101f7 | ||
|
|
825f225f63 | ||
|
|
4d5e2958dc | ||
|
|
6105d46198 | ||
|
|
7aec157859 | ||
|
|
13abb03d87 | ||
|
|
e8947ad0bb | ||
|
|
7056865726 | ||
|
|
9d8c26b999 | ||
|
|
c2c832f8c9 | ||
|
|
6bc4f04293 | ||
|
|
9d150ab353 | ||
|
|
f045b59b2d | ||
|
|
0bb8278a39 | ||
|
|
e43f812c14 | ||
|
|
d584b47280 | ||
|
|
3e995cd971 | ||
|
|
b018e35ada | ||
|
|
4bc030c1ef | ||
|
|
86a0aa1f9f | ||
|
|
d523e4f3c6 | ||
|
|
84c23e7c4e | ||
|
|
186d097e00 | ||
|
|
c5cfe557da | ||
|
|
f786a66a3c | ||
|
|
ebd51928d7 | ||
|
|
2258b5c43c | ||
|
|
2e50e30071 | ||
|
|
8c804a1011 | ||
|
|
1a4c2d7cd0 | ||
|
|
c2fc4ab4ff | ||
|
|
83fcabadae | ||
|
|
d12ad213e0 | ||
|
|
33d522b387 | ||
|
|
5997458aaf | ||
|
|
68f9471caf | ||
|
|
ecbb61db27 | ||
|
|
b42815ee7a | ||
|
|
49d7398e14 | ||
|
|
91589c1497 | ||
|
|
a07727c047 | ||
|
|
25bc506f74 | ||
|
|
18ca83d763 | ||
|
|
4bbc561625 | ||
|
|
d77220a603 | ||
|
|
f52b681133 | ||
|
|
f6efa0d711 | ||
|
|
0fccc91dac | ||
|
|
8d8c6c695a | ||
|
|
57342259ce | ||
|
|
be46ed8865 | ||
|
|
04b2205769 | ||
|
|
76ba357982 | ||
|
|
2c318f6e60 | ||
|
|
3f04153f22 | ||
|
|
3df8af3852 | ||
|
|
8b9ab8a841 | ||
|
|
750dbcc7c3 | ||
|
|
5d6007aaff | ||
|
|
291767031c | ||
|
|
22ffe6ef1d | ||
|
|
02df1a70f3 | ||
|
|
8c5fa9c441 | ||
|
|
e6c558c2a0 | ||
|
|
b52e4d756c | ||
|
|
1089a52ca0 | ||
|
|
c7fb9ab8e3 | ||
|
|
83017d0c80 | ||
|
|
e24217a6ba | ||
|
|
a0f2f738df | ||
|
|
9d9250954b | ||
|
|
f042f44501 | ||
|
|
56c98648f9 | ||
|
|
956efe6a09 | ||
|
|
bb64ad23dd | ||
|
|
a97326df74 | ||
|
|
1503f8781a | ||
|
|
163ddbb6ed | ||
|
|
7bbfd33ca0 | ||
|
|
0ea47ce890 | ||
|
|
38f891235c | ||
|
|
4d83c074d9 | ||
|
|
0e9672df80 | ||
|
|
abc7460539 | ||
|
|
4bb2ccfba7 | ||
|
|
969d428320 | ||
|
|
ff64522c50 | ||
|
|
65dc1a8f48 | ||
|
|
859b7f3c7f | ||
|
|
da3f875555 | ||
|
|
44d63a44da | ||
|
|
7e5e1609b0 | ||
|
|
d94adcb19c | ||
|
|
83894df260 | ||
|
|
7b99a32a1e | ||
|
|
e8c3744f5e | ||
|
|
06d1f54030 | ||
|
|
599ccb6bde | ||
|
|
db9050c302 | ||
|
|
71b3b665b5 | ||
|
|
3b8a806661 | ||
|
|
774719fb50 | ||
|
|
a3ccd41288 | ||
|
|
8ddacb7bc9 | ||
|
|
e74a74c3fb | ||
|
|
262a9ddc48 | ||
|
|
70f84b65ec | ||
|
|
ec5cb42f67 | ||
|
|
0802481fd2 | ||
|
|
548ba0ae36 | ||
|
|
fc2360d40d | ||
|
|
ab67bda5a1 | ||
|
|
376d5ca7d0 | ||
|
|
55438136b0 | ||
|
|
82db3517d7 | ||
|
|
130490c022 | ||
|
|
ede8a11584 | ||
|
|
ba65b06582 | ||
|
|
f4f04036f3 | ||
|
|
43130dcbc8 | ||
|
|
ff6459e439 | ||
|
|
1893de4c75 | ||
|
|
dfcc85a466 | ||
|
|
dacfb360f6 | ||
|
|
8a0d83b340 | ||
|
|
be2ce854a1 | ||
|
|
e492dcd968 | ||
|
|
55bfee856d | ||
|
|
f951075551 | ||
|
|
964086a08a | ||
|
|
67501025b3 | ||
|
|
e1cc5c841a | ||
|
|
6b839bd5a8 | ||
|
|
5df339b56d | ||
|
|
56adca9f22 | ||
|
|
1e63dd8d2d | ||
|
|
fab9272124 | ||
|
|
2f66fd9aae | ||
|
|
5616583fa1 | ||
|
|
3f0e991112 | ||
|
|
477d404727 | ||
|
|
8e6288bca8 | ||
|
|
72bba0662f | ||
|
|
090f46006a | ||
|
|
abe0c7e7d1 | ||
|
|
6516f56ada | ||
|
|
ea391dc44e | ||
|
|
e21f713de0 | ||
|
|
3498e2e884 | ||
|
|
ea8edc5914 | ||
|
|
b62c40dba3 | ||
|
|
0832337839 | ||
|
|
b82f4491fb | ||
|
|
bdf0c256b3 | ||
|
|
3d91a9e926 | ||
|
|
779dbdea26 | ||
|
|
e8e342c206 | ||
|
|
78829d36cc | ||
|
|
f7c2e82dc0 | ||
|
|
88598fb9fb | ||
|
|
19d149c129 | ||
|
|
f09de3a11c | ||
|
|
e13acdc8a9 | ||
|
|
b8e85bed61 | ||
|
|
396493ad2b | ||
|
|
f32d92b9d0 | ||
|
|
6d79db8ba3 | ||
|
|
f9fb480cc3 | ||
|
|
1efa8798bf | ||
|
|
c244e9834f | ||
|
|
b1a7b58f97 | ||
|
|
e81f39b50e | ||
|
|
3c99fb116c | ||
|
|
e7e136036c | ||
|
|
ca84fc6c9d | ||
|
|
a0c4515a81 | ||
|
|
4bf418a3d6 | ||
|
|
f033607c8b | ||
|
|
32d612fbeb | ||
|
|
9ce3a881f3 | ||
|
|
860cd31799 | ||
|
|
d674b48f7d | ||
|
|
1635f9dbef | ||
|
|
07c899f0a9 | ||
|
|
382e4c5377 | ||
|
|
fe6518d052 | ||
|
|
dc513dfbeb | ||
|
|
3d9bc7a986 | ||
|
|
75e36173cd | ||
|
|
8097f227ca | ||
|
|
3d79b72d70 | ||
|
|
6eb9b772e7 | ||
|
|
90c8ff35d1 | ||
|
|
ad87fd96db | ||
|
|
fd1debe681 | ||
|
|
c7cc0cd922 | ||
|
|
81a232177e | ||
|
|
73aee97be5 | ||
|
|
39f3a85bb1 | ||
|
|
098a2e54ae | ||
|
|
d575478b53 | ||
|
|
aab54ca1a8 | ||
|
|
d4f2094ee0 | ||
|
|
c354618e20 | ||
|
|
5141a91041 | ||
|
|
668539e737 | ||
|
|
967139cea4 | ||
|
|
6d8b1aede4 | ||
|
|
744ba31ba6 | ||
|
|
db8257b67a | ||
|
|
85770dc037 | ||
|
|
69f976a79a | ||
|
|
fd7e77eff8 | ||
|
|
05c2a093c0 | ||
|
|
01a1e8eab1 | ||
|
|
b71bc1f875 | ||
|
|
6a0ee22d81 | ||
|
|
cbc8714414 | ||
|
|
065f8db2f7 | ||
|
|
0ac7f83726 | ||
|
|
d03473da10 | ||
|
|
dac1c01a2c | ||
|
|
f6d929ab7a | ||
|
|
a7a2dabc5a | ||
|
|
83015a3404 | ||
|
|
b88e9c5f5e | ||
|
|
8380a8a811 | ||
|
|
6c69181290 | ||
|
|
0694075447 | ||
|
|
d66b9dd8cb | ||
|
|
7267198a8c | ||
|
|
7b8f101824 | ||
|
|
0f36c5c872 | ||
|
|
6a67f028ce | ||
|
|
5d82786c20 | ||
|
|
e368f1c1d6 | ||
|
|
572ce7f9ec | ||
|
|
a4c942a21f | ||
|
|
4859ab3ba7 | ||
|
|
983b5f5087 | ||
|
|
75b87955dd | ||
|
|
110de0afbc | ||
|
|
2c074cd5c1 | ||
|
|
73e51a9b0b | ||
|
|
3a47039919 | ||
|
|
2961ea4e44 | ||
|
|
af2ffc9737 | ||
|
|
d7911244fc | ||
|
|
2a66775e45 | ||
|
|
6029a5a9a8 | ||
|
|
71d9ae15a1 | ||
|
|
f0c3d5f308 | ||
|
|
4706ea59fe | ||
|
|
5774a95f61 | ||
|
|
d660521c5c | ||
|
|
5db2c5092e | ||
|
|
59618457df | ||
|
|
c612dfbc1f | ||
|
|
fc58ac0408 | ||
|
|
8d053c97a7 | ||
|
|
a3e6f67ff7 | ||
|
|
01da2e3eee | ||
|
|
168cce1678 | ||
|
|
7240dfe793 | ||
|
|
b9340ba02d | ||
|
|
4f5ee24bc5 | ||
|
|
6a1b8d3ee3 | ||
|
|
f1207dc8b9 | ||
|
|
86c51559bb | ||
|
|
8b0f806079 | ||
|
|
99e94b3567 | ||
|
|
cfd5c1bc93 | ||
|
|
45d9e45346 | ||
|
|
fcb3845543 | ||
|
|
97eabc0c36 | ||
|
|
5328163973 | ||
|
|
7ff9dfee8c | ||
|
|
5b431400be | ||
|
|
1e1675ec12 | ||
|
|
f941541304 | ||
|
|
3f7083c5b3 | ||
|
|
e81faebf69 | ||
|
|
8a4d58c520 | ||
|
|
2ac29ee89c | ||
|
|
252cdcd6f5 | ||
|
|
16e2c95965 | ||
|
|
10560fb34c | ||
|
|
58aa60ca0e | ||
|
|
d24b186d3e | ||
|
|
b4e81615b1 | ||
|
|
424d2033ea | ||
|
|
fd556f9b00 | ||
|
|
e2f5fa87b1 | ||
|
|
e4a2bd3b9b | ||
|
|
e3ada17a78 | ||
|
|
3e5a7adfe4 | ||
|
|
3237f4cd6e | ||
|
|
beea826377 | ||
|
|
7cdbbefc64 | ||
|
|
18780622b3 | ||
|
|
f405ac4d84 | ||
|
|
9fe47e2fb2 | ||
|
|
e4aaa18f61 | ||
|
|
5c3d9717dd | ||
|
|
ac86bbd60c | ||
|
|
33d12c43b2 | ||
|
|
107c676185 | ||
|
|
0f221b7ee6 | ||
|
|
e1939ef472 | ||
|
|
5438d35f17 | ||
|
|
9c26d1f4c8 | ||
|
|
4c2b31f31f | ||
|
|
4f88a13256 | ||
|
|
21ae448ed7 | ||
|
|
50466124c8 | ||
|
|
ece88a3879 | ||
|
|
cedc4a92cc | ||
|
|
c8065b0c60 | ||
|
|
476632294f | ||
|
|
349d46e043 | ||
|
|
00e0201bf9 | ||
|
|
b9ebe22df1 | ||
|
|
389dd8d402 | ||
|
|
966bd8528d | ||
|
|
8f789d47a2 | ||
|
|
509d1a2e24 | ||
|
|
94a40e49a0 | ||
|
|
8429279eea | ||
|
|
cef14cda9e | ||
|
|
c14f067afb | ||
|
|
6c8dca6379 | ||
|
|
819d205166 | ||
|
|
153e68e055 | ||
|
|
77b9a6a94e | ||
|
|
d68bbab419 | ||
|
|
6d53d9178c | ||
|
|
9e17f65eda | ||
|
|
7373f68172 | ||
|
|
0999bd30d7 | ||
|
|
f01185a7fc | ||
|
|
7cd7303754 | ||
|
|
d19fec2155 | ||
|
|
2612abc9d0 | ||
|
|
06fe3f2f01 | ||
|
|
e2b6c713e7 | ||
|
|
0b3b241436 | ||
|
|
4c18f9e858 | ||
|
|
8fec54c085 | ||
|
|
d8e37a4d2b | ||
|
|
1da2c4fa37 | ||
|
|
d080b44ac3 | ||
|
|
df18868888 | ||
|
|
4438b08560 | ||
|
|
1029f94669 | ||
|
|
0a3acf446d | ||
|
|
c01ad5a19e | ||
|
|
5a7723553c | ||
|
|
975844eccf | ||
|
|
865ad31f2f | ||
|
|
b756f0c86c | ||
|
|
3e5f6176af | ||
|
|
ab5b165dc2 | ||
|
|
f9393c2f63 | ||
|
|
aa6638424c | ||
|
|
834387e254 | ||
|
|
9caa986c80 | ||
|
|
72b84dfc8f | ||
|
|
af10195025 | ||
|
|
22382423ad | ||
|
|
0f80c67cbd | ||
|
|
aa6473c1c7 | ||
|
|
cde61cb6ac | ||
|
|
b1368997c2 | ||
|
|
ec7dc448c1 | ||
|
|
254147265e | ||
|
|
479bba9a4e | ||
|
|
cfb39a6baa | ||
|
|
05c9ed1450 | ||
|
|
f53633a8b8 | ||
|
|
f56bc0f85a | ||
|
|
63882e9391 | ||
|
|
3c4dfb868f | ||
|
|
9600d687fa | ||
|
|
cae9105b8d | ||
|
|
41a0036bf6 | ||
|
|
2c9401ccfb | ||
|
|
08e4ad6a7c | ||
|
|
2b0dedc81c | ||
|
|
314e6e29d5 | ||
|
|
16b87de0df | ||
|
|
8c3af7f4ff | ||
|
|
391cd602a2 | ||
|
|
5f56cc8056 | ||
|
|
827ab27bef | ||
|
|
ccc67df8df | ||
|
|
82538c469f | ||
|
|
076ceee29d | ||
|
|
822b73b015 | ||
|
|
862bff51cb | ||
|
|
247db844a4 | ||
|
|
5495d32822 | ||
|
|
bccbeaabe4 | ||
|
|
a496991400 | ||
|
|
0ea83b4364 | ||
|
|
03676b7adc | ||
|
|
af6fde414f | ||
|
|
d069809001 | ||
|
|
fc240849cf | ||
|
|
61d2a328fe | ||
|
|
fed0ae8e9c | ||
|
|
eaf0de453b | ||
|
|
e833db954a | ||
|
|
0b2651f4ed | ||
|
|
10c677a6fd | ||
|
|
3398c4737a | ||
|
|
a008f5fbef | ||
|
|
6a42e73667 | ||
|
|
7611db19f3 | ||
|
|
d3399dfaf5 | ||
|
|
248f0d95ac | ||
|
|
5c39d841ee | ||
|
|
87be67cb9a | ||
|
|
1a08bea864 | ||
|
|
bc4406cec6 | ||
|
|
4206c849c3 | ||
|
|
3f052b7798 | ||
|
|
f1c5f24f6b | ||
|
|
e981c95225 | ||
|
|
4ce4f53835 | ||
|
|
f16e369540 | ||
|
|
47bf93d65e | ||
|
|
5c2e0af33e | ||
|
|
aaa0410781 | ||
|
|
366b148f3d | ||
|
|
6a265de31c | ||
|
|
c3707f543c | ||
|
|
8de368348b | ||
|
|
d052c31ac5 | ||
|
|
31320afed6 | ||
|
|
7afe507296 | ||
|
|
4188443101 | ||
|
|
a1fc0fd394 | ||
|
|
71fe35533d | ||
|
|
a2ed335e59 | ||
|
|
8422a05d74 | ||
|
|
139ae3bcb4 | ||
|
|
a0a57d5fbb | ||
|
|
80fa88ac37 | ||
|
|
0fda1c752d | ||
|
|
6c2fc75199 | ||
|
|
2cb6aeb022 | ||
|
|
e0174f75b3 | ||
|
|
51d04746a3 | ||
|
|
3b08d6c320 | ||
|
|
495c5802a0 | ||
|
|
621b074b3d | ||
|
|
6df32983b5 | ||
|
|
9c9fe9dde7 | ||
|
|
128c1a6178 | ||
|
|
f90e102854 | ||
|
|
2e1eb9a5a6 | ||
|
|
60a95f6556 | ||
|
|
218637e81d | ||
|
|
404f78af0f | ||
|
|
130f15665c | ||
|
|
6301528301 | ||
|
|
6feea968e0 | ||
|
|
b5199b2eb9 | ||
|
|
78ce2a9a8b | ||
|
|
6ed542b007 | ||
|
|
5322b0c4a3 | ||
|
|
a72d5d2c77 | ||
|
|
16c1cbe24f | ||
|
|
0d8f4c76e7 | ||
|
|
e511b14933 | ||
|
|
b5ba53208e | ||
|
|
b8bfb4d0c5 | ||
|
|
1b666638bc | ||
|
|
2bd364eca3 | ||
|
|
f27fc51801 | ||
|
|
0f85eff76b | ||
|
|
0def474cc2 | ||
|
|
026e4376d4 | ||
|
|
cf571cf02b | ||
|
|
590ec3a446 | ||
|
|
23bfdcefef | ||
|
|
647a978865 | ||
|
|
86f72100f0 | ||
|
|
8b255259ba | ||
|
|
8aad8faae9 | ||
|
|
420f391f3c | ||
|
|
817221347f | ||
|
|
13dce5e265 | ||
|
|
850d9ee70b | ||
|
|
ba36ccb21f | ||
|
|
f712754927 | ||
|
|
efe3865aa4 | ||
|
|
53dbe2f436 | ||
|
|
720498084b | ||
|
|
f5eda38dc9 | ||
|
|
8ada221777 | ||
|
|
4ee198813a | ||
|
|
440e8acd99 | ||
|
|
218671ef06 | ||
|
|
34de0bb9c5 | ||
|
|
8e6cf09056 | ||
|
|
5929072b76 | ||
|
|
37325e9802 | ||
|
|
778bc4bd70 | ||
|
|
f78f59ec42 | ||
|
|
d4c4160215 | ||
|
|
85aea97c21 | ||
|
|
b075cad4de | ||
|
|
f326febc8a | ||
|
|
1738e45090 | ||
|
|
6e758faa37 | ||
|
|
32e79c5df0 | ||
|
|
aa69cd3a0c | ||
|
|
da4a1f536d | ||
|
|
b3af757167 | ||
|
|
82794f051a | ||
|
|
a726a81224 | ||
|
|
9aae6163f0 | ||
|
|
941527e7ee | ||
|
|
a3f05220d3 | ||
|
|
7446241735 | ||
|
|
6033d37537 | ||
|
|
1524d7b5ce | ||
|
|
e00341a4cc | ||
|
|
f5185d2e95 | ||
|
|
c041d24989 | ||
|
|
dc9003f9db | ||
|
|
07e0c70629 | ||
|
|
37f77e0990 | ||
|
|
aef1a57ea8 | ||
|
|
69af479224 | ||
|
|
f38223c97f | ||
|
|
1ac6702eb0 | ||
|
|
2510f60dce | ||
|
|
b9d7fb2598 | ||
|
|
a39ba564fa | ||
|
|
34310bfabe | ||
|
|
78fd189510 | ||
|
|
94836ed9af | ||
|
|
1d662fb63e | ||
|
|
d1933d2aef | ||
|
|
163872be6e | ||
|
|
14fcb66a9c | ||
|
|
c488eb0cd0 | ||
|
|
91d20f7272 | ||
|
|
c3d7963fe0 | ||
|
|
c31a92bf01 | ||
|
|
b5703c1b82 | ||
|
|
df34735a9b | ||
|
|
31bee889d7 | ||
|
|
b3ba0a6ed6 | ||
|
|
ce3b7897d7 | ||
|
|
9115ad6950 | ||
|
|
c6b76438f4 | ||
|
|
68c4c7429c | ||
|
|
8466c8e019 | ||
|
|
d899b27448 | ||
|
|
229eb5cc86 | ||
|
|
66c153f1ad | ||
|
|
bbb2c6c903 | ||
|
|
5edf3f2b8a | ||
|
|
006c6cd159 | ||
|
|
9675982555 | ||
|
|
c6c7a1827c | ||
|
|
3ac8a9431b | ||
|
|
5c42a84c3e | ||
|
|
8fdaebbe6e | ||
|
|
9a98ccff2c | ||
|
|
ee4027c561 | ||
|
|
7f36a06f26 | ||
|
|
0826a34d8b | ||
|
|
1792cb4d93 | ||
|
|
304ccef101 | ||
|
|
bdc22c892d | ||
|
|
a5034e84ba | ||
|
|
6e2de96fed | ||
|
|
2b6d86e591 | ||
|
|
8c6f4cb117 | ||
|
|
16d4b32eb7 | ||
|
|
45a64dbbac | ||
|
|
537668b463 | ||
|
|
07fea23dd0 | ||
|
|
cef14291f0 | ||
|
|
bbde0588af | ||
|
|
aa7d52568b | ||
|
|
f39c77ac70 | ||
|
|
aa733354e8 | ||
|
|
7cec966979 | ||
|
|
74865d2cf2 | ||
|
|
c9a8753473 | ||
|
|
ce8a2cbe34 | ||
|
|
c0fdd0c6d3 | ||
|
|
88bfcfe6cd | ||
|
|
c4dcf1fd65 | ||
|
|
6cebddf893 | ||
|
|
1738ed3664 | ||
|
|
37ddcb91ac | ||
|
|
574ab4506b | ||
|
|
81353538e5 | ||
|
|
5abfcdfbe8 | ||
|
|
9962a61c21 | ||
|
|
5cf2b08777 | ||
|
|
9be1c01b70 | ||
|
|
62b2ecdfc2 | ||
|
|
2ff9000d25 | ||
|
|
5829148ce4 | ||
|
|
8e15a340f6 | ||
|
|
1270b7cdd8 | ||
|
|
7c02fe8148 | ||
|
|
4ac63e1c23 | ||
|
|
4aeb653ed2 | ||
|
|
2d5c2de613 | ||
|
|
96590941cf | ||
|
|
0655ff4a91 | ||
|
|
0ba370052e | ||
|
|
4d59e04aba | ||
|
|
6db6c33564 | ||
|
|
ed0d963aec | ||
|
|
3a36d038ee | ||
|
|
3d068a9c96 | ||
|
|
87df352adc | ||
|
|
8b546b7366 | ||
|
|
77ea0680fb | ||
|
|
4c592bf7e3 | ||
|
|
6718553bf4 | ||
|
|
79dc6f3f69 | ||
|
|
8df72d2822 | ||
|
|
b9578bd08a | ||
|
|
035e56e42f | ||
|
|
3ce5926689 | ||
|
|
035464c0ac | ||
|
|
f1fcffbfc0 | ||
|
|
b79fe07052 | ||
|
|
e6aa0e0e10 | ||
|
|
54700e6fbe | ||
|
|
5a90d4776d | ||
|
|
f81fdca62a | ||
|
|
3a0671c661 | ||
|
|
1037729fb3 | ||
|
|
5f211620c5 | ||
|
|
cb6a3aae9e | ||
|
|
5e512df3d4 | ||
|
|
9916cf3265 | ||
|
|
729c283c63 | ||
|
|
c99f04314c | ||
|
|
dd9be2ed90 | ||
|
|
f7aed9dd98 | ||
|
|
5253cf3899 | ||
|
|
f7d92be5ea | ||
|
|
97d8168824 | ||
|
|
550bd4da23 | ||
|
|
2327be7557 | ||
|
|
a7ffc19ba1 | ||
|
|
bbaa39c569 | ||
|
|
d1de0250e7 | ||
|
|
2d731c6412 | ||
|
|
6a6e64f487 | ||
|
|
b9201c918a | ||
|
|
7dedad898a | ||
|
|
d497189352 | ||
|
|
fa4da8f467 | ||
|
|
e9ff742162 | ||
|
|
3849cfb835 | ||
|
|
c453af23c6 | ||
|
|
bcf2376f5a | ||
|
|
4f0b653a82 | ||
|
|
be2f56ae6a | ||
|
|
cbc9602495 | ||
|
|
616709acbb | ||
|
|
c72ce381c0 | ||
|
|
67053ab8ae | ||
|
|
33238d34c9 | ||
|
|
2ef54168fc | ||
|
|
b33ccf00f9 | ||
|
|
829eb4b3be | ||
|
|
6c49456c13 | ||
|
|
fc8f06ee14 | ||
|
|
120a524b7e | ||
|
|
bd037ac3a3 | ||
|
|
b8ea427029 | ||
|
|
275be47224 | ||
|
|
4ea9c7e660 | ||
|
|
92d78d9a52 | ||
|
|
a820001eea | ||
|
|
8273f6d217 | ||
|
|
bd63e0fce8 | ||
|
|
12ba3d473e | ||
|
|
0b9cc0f068 | ||
|
|
5ca397befa | ||
|
|
da735fe776 | ||
|
|
b4f69f2cff | ||
|
|
1885c00cbc | ||
|
|
1e4fdeb1a6 | ||
|
|
cb7dbb0ed4 | ||
|
|
44083aec79 | ||
|
|
4a9b743153 | ||
|
|
b462e17a5b | ||
|
|
b272a52b57 | ||
|
|
3f87c64e83 | ||
|
|
1795364f5f | ||
|
|
e69fbb2f97 | ||
|
|
32b40fc6bf | ||
|
|
f039ea7f56 | ||
|
|
41334f5f1e | ||
|
|
79b19b744e | ||
|
|
2103410694 | ||
|
|
2143d94e83 | ||
|
|
9ae2612945 | ||
|
|
3a09b26b6d | ||
|
|
e381449aec | ||
|
|
bacffc94d9 | ||
|
|
7044f705e7 | ||
|
|
6db4fe28a7 | ||
|
|
f966176694 | ||
|
|
bd24de4577 | ||
|
|
dc2ea5c007 | ||
|
|
4fb673077a | ||
|
|
b3a136ac03 | ||
|
|
22f1bfa3fa | ||
|
|
f6ad0aab94 | ||
|
|
371fdeb948 | ||
|
|
f7a0af75c4 | ||
|
|
b31e526e4d | ||
|
|
26abf7b586 | ||
|
|
d477e24e34 | ||
|
|
3ca3e8e023 | ||
|
|
3bd374495b | ||
|
|
b26f60ee8d | ||
|
|
df681eaf22 | ||
|
|
01458ac111 | ||
|
|
6c7a68802b | ||
|
|
e3074b833f | ||
|
|
1097d699f8 | ||
|
|
55b4e0ebd3 | ||
|
|
0011a8ce9f | ||
|
|
100bf4fa49 | ||
|
|
6da5b81311 | ||
|
|
787adf5423 | ||
|
|
01b500e7d1 | ||
|
|
e64603ea27 | ||
|
|
4219e12cc0 | ||
|
|
c86ccf0931 | ||
|
|
d4571fb75b | ||
|
|
ec2369c397 | ||
|
|
6ebd48408b | ||
|
|
7e7b54593c | ||
|
|
f93c9f5cd2 | ||
|
|
a810fbe008 | ||
|
|
600a914bd9 | ||
|
|
b1688950c4 | ||
|
|
d8e3f9b7b8 | ||
|
|
08d55e4463 | ||
|
|
55e2baa865 | ||
|
|
55174dc707 | ||
|
|
d57e3b3f64 | ||
|
|
aa42cd0aec | ||
|
|
ac6d9a39ec | ||
|
|
9b07775395 | ||
|
|
936fb8b8a1 | ||
|
|
6c8318b696 | ||
|
|
d554079e2b | ||
|
|
37464a101e | ||
|
|
c5674246b0 | ||
|
|
f076199e3f | ||
|
|
8326db1143 | ||
|
|
992e41e0a0 | ||
|
|
076e95d5c2 | ||
|
|
dfd79e5972 | ||
|
|
b16c9d53ef | ||
|
|
5fe85fb457 | ||
|
|
b45f470310 | ||
|
|
0ecda33ab8 | ||
|
|
7fcfca455a | ||
|
|
6a32154b8f | ||
|
|
132206677f | ||
|
|
30a8775548 | ||
|
|
045bc9aefc | ||
|
|
d5c46574cc | ||
|
|
37fea09403 | ||
|
|
063e8fae43 | ||
|
|
184c4fbf7f | ||
|
|
e19d27f640 | ||
|
|
ea96830758 | ||
|
|
d2edbc738d | ||
|
|
03bc8c8280 | ||
|
|
68908213da | ||
|
|
b3d5add89a | ||
|
|
7fe2d8fbe1 | ||
|
|
de545a69ca | ||
|
|
dc48ba540d | ||
|
|
81e92b4fa6 | ||
|
|
ebad5e00a3 | ||
|
|
bca03f1365 | ||
|
|
c89f55f0bd | ||
|
|
4d98bace87 | ||
|
|
dcdc899528 | ||
|
|
b57aa55001 | ||
|
|
d0c0168c20 | ||
|
|
af596a09cf | ||
|
|
6849c620b8 | ||
|
|
12598f0dca | ||
|
|
3f4ce4f16f | ||
|
|
4aaf0d8d5c | ||
|
|
65db056e09 | ||
|
|
232cef7cb9 | ||
|
|
73a432879a | ||
|
|
09afec17f9 | ||
|
|
ac47ab3deb | ||
|
|
8b3d7c168a | ||
|
|
60e8eb63ac | ||
|
|
4f29cd24b8 | ||
|
|
ba73ade2a0 | ||
|
|
7559305fc9 | ||
|
|
6985f553f9 | ||
|
|
8fc15df6d0 | ||
|
|
eb8160a5af | ||
|
|
16cf6eee9b | ||
|
|
320f684354 | ||
|
|
12062a5440 | ||
|
|
4423a9d979 | ||
|
|
1eb44defb6 | ||
|
|
e253fba2e9 | ||
|
|
c05d95924f | ||
|
|
2db583d62d | ||
|
|
59d8e1bf9f | ||
|
|
1001344c27 | ||
|
|
8a0e2da03f | ||
|
|
f58886be6f | ||
|
|
3c1d3b4d6a | ||
|
|
bbba995ff7 | ||
|
|
0033b5be80 | ||
|
|
87d53fb9b7 | ||
|
|
157031f23e | ||
|
|
8a37869489 | ||
|
|
5c10f11681 | ||
|
|
7b72bf0cd0 | ||
|
|
be29666916 | ||
|
|
8d4c5b5b33 | ||
|
|
52260f469a | ||
|
|
c566d22836 | ||
|
|
75f59a86c8 | ||
|
|
1eaf12446f | ||
|
|
efdd42426e | ||
|
|
62c557deae | ||
|
|
db1da4a61a | ||
|
|
db46c186aa | ||
|
|
677a603835 | ||
|
|
447d8790ad | ||
|
|
7a78f15a90 | ||
|
|
c1941809e9 | ||
|
|
623aaf8a0e | ||
|
|
7b3bf41120 | ||
|
|
0c3960eb0b | ||
|
|
fe3c31c08c | ||
|
|
94600cdbfc | ||
|
|
4e7ab3d7e3 | ||
|
|
47b25d7a26 | ||
|
|
0249666fa4 | ||
|
|
2e8504ce2f | ||
|
|
aca7d25001 | ||
|
|
2444309bc2 | ||
|
|
97c5a78d48 | ||
|
|
effdb88455 | ||
|
|
2f0ce3852e | ||
|
|
5475496399 | ||
|
|
b569d77a23 | ||
|
|
dfa7a2d4cf | ||
|
|
169e01276d | ||
|
|
07e698265e | ||
|
|
0632d7611f | ||
|
|
b3f39eedac | ||
|
|
46ed7e38bf | ||
|
|
8c5199d32d | ||
|
|
36ed833d64 | ||
|
|
47969ce61e | ||
|
|
06731e2026 | ||
|
|
123347169d | ||
|
|
f9101a744c | ||
|
|
97eb33000f | ||
|
|
60231ec88d | ||
|
|
3364374dc6 | ||
|
|
a3cf773e75 | ||
|
|
4092d5fbaf | ||
|
|
07e9fde9e8 | ||
|
|
9b4613630b | ||
|
|
f125d11b6d | ||
|
|
657d48a5f9 | ||
|
|
3735bdde19 | ||
|
|
3f906d81cb | ||
|
|
7c1f622797 | ||
|
|
cfe696ae8d | ||
|
|
021c50a8f2 | ||
|
|
95745ba869 | ||
|
|
adfae54816 | ||
|
|
10ed093eb8 | ||
|
|
c96df6bfa5 | ||
|
|
0126d18525 | ||
|
|
9e6e8f50f8 | ||
|
|
7e0b31626f | ||
|
|
1d9e249a77 | ||
|
|
88b89ef315 | ||
|
|
62b7925cb0 | ||
|
|
cc1528f550 | ||
|
|
1c8a83140b | ||
|
|
34276e2066 | ||
|
|
71abd16ae7 | ||
|
|
918e7285c4 | ||
|
|
056d422c71 | ||
|
|
5ee54f4e0e | ||
|
|
260c75e70c | ||
|
|
2d7401922f | ||
|
|
8c7a1348cf | ||
|
|
24fbdbd716 | ||
|
|
aad8f0e36b | ||
|
|
15cad44f08 | ||
|
|
0271454671 | ||
|
|
d0ddf288ca | ||
|
|
bc250ac377 | ||
|
|
7922fc3b0e | ||
|
|
161da723b9 | ||
|
|
514c19a247 | ||
|
|
41550d4a41 | ||
|
|
33cc3c1c3f | ||
|
|
7d15182202 | ||
|
|
8f0a1d9c6e | ||
|
|
72b5e5cf8e | ||
|
|
62aba2dd38 | ||
|
|
cdd6b80089 | ||
|
|
333836f5e7 | ||
|
|
a2dfda3471 | ||
|
|
2d28b4b05c | ||
|
|
87f9bcc6a3 | ||
|
|
48aca996ff | ||
|
|
c8c7e9b304 | ||
|
|
97ff023995 | ||
|
|
e273a336f8 | ||
|
|
34f0c3b90c | ||
|
|
7c2902d2b8 | ||
|
|
8e41afdffc | ||
|
|
7268886294 | ||
|
|
cbae900866 | ||
|
|
ffff138a6f | ||
|
|
88c95db8d0 | ||
|
|
56e657a0bb | ||
|
|
bc36b79105 | ||
|
|
5694bc0230 | ||
|
|
36130031f9 | ||
|
|
b8f1095f53 | ||
|
|
442fa09533 | ||
|
|
42ef2efbc8 | ||
|
|
ead3080b2b | ||
|
|
c6ea31c296 | ||
|
|
21eae29bb7 | ||
|
|
406740b524 | ||
|
|
9d30bc4062 | ||
|
|
fad91b64ab | ||
|
|
2132e71a81 | ||
|
|
bd8a451879 | ||
|
|
24dafa7359 | ||
|
|
3b5df793fb | ||
|
|
da835b6138 | ||
|
|
7e650d86a5 | ||
|
|
308e28cecc | ||
|
|
9a3c74fb64 | ||
|
|
f571f0688a | ||
|
|
1e9c32a102 | ||
|
|
8c69199689 | ||
|
|
3efb3e8a35 | ||
|
|
cfcb278406 | ||
|
|
9e195ea63b | ||
|
|
dc0d34c281 | ||
|
|
72076c218f | ||
|
|
151fd3b950 | ||
|
|
2d484fcb30 | ||
|
|
6e0407f404 | ||
|
|
8670aaba1e | ||
|
|
f27de7df35 | ||
|
|
63fa4dc8ec | ||
|
|
a191e32f71 | ||
|
|
9a38e8a4a0 | ||
|
|
6194222289 | ||
|
|
0d077eaeb7 | ||
|
|
b2c7a9a005 | ||
|
|
be01f1869e | ||
|
|
9f2b6390b0 | ||
|
|
e196f86e30 | ||
|
|
ec41d45234 | ||
|
|
567d1ba18b | ||
|
|
df8706983b | ||
|
|
8697498b32 | ||
|
|
af917c538a | ||
|
|
034e97dfa6 | ||
|
|
5e1e5f68e1 | ||
|
|
fb76f765cc | ||
|
|
7a3f57261d | ||
|
|
a1a460625d | ||
|
|
3f42ea2c61 | ||
|
|
940c594066 | ||
|
|
5e47fc45ab | ||
|
|
b471d56a86 | ||
|
|
61f8029205 | ||
|
|
e2f047d035 | ||
|
|
1aff4eda67 | ||
|
|
a6c5c44ed8 | ||
|
|
3f389d685a | ||
|
|
5d5351f0bc | ||
|
|
1224802ac6 | ||
|
|
e919f89caf | ||
|
|
bb8e7a68ea | ||
|
|
48f95e0ea4 | ||
|
|
931e9bcf0d | ||
|
|
67a3351c4c | ||
|
|
dfe5eeed7b | ||
|
|
3464573f17 | ||
|
|
9cf49c9c75 | ||
|
|
4e837cb90c | ||
|
|
e4fb58496b | ||
|
|
15a254c0cd | ||
|
|
d62746fc8c | ||
|
|
4b8b6fe407 | ||
|
|
6754834eb3 | ||
|
|
be98db561d | ||
|
|
574d0afc72 | ||
|
|
31c8ad611c | ||
|
|
b23730388d | ||
|
|
1b853aa893 | ||
|
|
36cb0a12ad | ||
|
|
5439eacf2d | ||
|
|
2687c3b80e | ||
|
|
fa009327ad | ||
|
|
838bd46e83 | ||
|
|
ccc2009aa8 | ||
|
|
d9aba92314 | ||
|
|
696b0475a8 | ||
|
|
e7370489e8 | ||
|
|
f1503b2238 | ||
|
|
cd4661e878 | ||
|
|
364e01ec7a | ||
|
|
ffb7b0ba38 | ||
|
|
22151eb49b | ||
|
|
d0354345f6 | ||
|
|
b1e61eb1e4 | ||
|
|
36e0ed15b6 | ||
|
|
095dfc2879 | ||
|
|
17dea9433e | ||
|
|
c285444e2f | ||
|
|
8ba402d080 | ||
|
|
88ab86734d | ||
|
|
504d87b0b0 | ||
|
|
b0d5818351 | ||
|
|
8826a01d32 | ||
|
|
cfb7a40841 | ||
|
|
8267761890 | ||
|
|
a651ae6ed4 | ||
|
|
a01911ba5f | ||
|
|
ee50b25d06 | ||
|
|
a67be85858 | ||
|
|
59c5a3973a | ||
|
|
d76d7343ff | ||
|
|
2b9638e7d3 | ||
|
|
3459a73705 | ||
|
|
bd480a466b | ||
|
|
4c34cb55b6 | ||
|
|
7347f9104c | ||
|
|
e137e4a38a | ||
|
|
b5989bbc25 | ||
|
|
c31ff7ceef | ||
|
|
9206c7642a | ||
|
|
d1b4f2b6c2 | ||
|
|
75066f2827 | ||
|
|
303f3aefef | ||
|
|
44fb5e0fd5 | ||
|
|
17a695120a | ||
|
|
6dc716eaf8 | ||
|
|
194be086d4 | ||
|
|
cca3900678 | ||
|
|
4fe32b7dbc | ||
|
|
c49603c25b | ||
|
|
8de85a4041 | ||
|
|
58a2135fa4 | ||
|
|
ab9a97db22 | ||
|
|
d291c241d5 | ||
|
|
24d4cb9b94 | ||
|
|
5b9adb799f | ||
|
|
38b41df36b | ||
|
|
34a9befe5c | ||
|
|
67fd579074 | ||
|
|
e2714b942d | ||
|
|
6b2556f870 | ||
|
|
43e6e9d201 | ||
|
|
131e0cc4c7 | ||
|
|
537be81b8f | ||
|
|
765168db7f | ||
|
|
1e16b06a24 | ||
|
|
42b59a644d | ||
|
|
d9fa9039bb | ||
|
|
cd4c93a5cb | ||
|
|
808961243d | ||
|
|
4d80e119f7 | ||
|
|
10c87edae1 | ||
|
|
0eb335d112 | ||
|
|
b8b26ccfe5 | ||
|
|
e89c23da4d | ||
|
|
f3da8956d9 | ||
|
|
b1147d77af | ||
|
|
66bc2fb41f | ||
|
|
4e538a6df8 | ||
|
|
ced087f8ae | ||
|
|
0f1eed0b1e | ||
|
|
95f15b77a3 | ||
|
|
f9ccfd5ca0 | ||
|
|
7207d7c847 | ||
|
|
00c4a524b7 | ||
|
|
9c3e0b5541 | ||
|
|
33bfe33eb3 | ||
|
|
3127c382a4 | ||
|
|
1748a390ec | ||
|
|
a7c0837049 | ||
|
|
44bf1eeae2 | ||
|
|
762b7a8ef1 | ||
|
|
102712a16e | ||
|
|
40810c59d7 | ||
|
|
35a10e86b5 | ||
|
|
c0c985494d | ||
|
|
8984ba7aef | ||
|
|
179869d481 | ||
|
|
5f29956f2b | ||
|
|
dbc4ba84c2 | ||
|
|
9e4a527675 | ||
|
|
45833542a7 | ||
|
|
1be6de30d7 | ||
|
|
981d78c8ba | ||
|
|
fbc7bedb6c | ||
|
|
4786b0c5d4 | ||
|
|
17bed26096 | ||
|
|
511e16f1d3 | ||
|
|
18204bc1f7 | ||
|
|
b58d97fad3 | ||
|
|
d2a67a53b5 | ||
|
|
c0b556000c | ||
|
|
462c3b0696 | ||
|
|
d34ad73439 | ||
|
|
2c21712d58 | ||
|
|
ce01e588c9 | ||
|
|
2a23082203 | ||
|
|
d373f924f6 | ||
|
|
eaf46ee006 | ||
|
|
d51355a0ad | ||
|
|
1e481a311a | ||
|
|
46abb23ee8 | ||
|
|
8555bb697c | ||
|
|
f821893653 | ||
|
|
75b3ea1f05 | ||
|
|
74f0018962 | ||
|
|
3a0f07d36f | ||
|
|
a047cf2e91 | ||
|
|
a8ae16e321 | ||
|
|
a53be31765 | ||
|
|
4475be51cc | ||
|
|
d53cbe7868 | ||
|
|
722746c78b | ||
|
|
e1f5607836 | ||
|
|
7cd0d78424 | ||
|
|
d740559749 | ||
|
|
399357f752 | ||
|
|
9de6b4f151 | ||
|
|
94cced8323 | ||
|
|
9b8ed16e37 | ||
|
|
a5e44cd229 | ||
|
|
eccc208229 | ||
|
|
79cfabb45d | ||
|
|
af6e1e2b99 | ||
|
|
4ad51c1b24 | ||
|
|
c44712167f | ||
|
|
1aabaff1f2 | ||
|
|
21c0383efb | ||
|
|
ebe018347b | ||
|
|
86fe6fe5ab | ||
|
|
9e828b1750 | ||
|
|
940d3d4567 | ||
|
|
6bd7b2b8bb | ||
|
|
f2d6fd7b08 | ||
|
|
b84c82880c | ||
|
|
fcc418b4a0 | ||
|
|
15c0bb4c9e | ||
|
|
8db4f914d8 | ||
|
|
f3f9211c9c | ||
|
|
a2a69840f7 | ||
|
|
3a4a7590c2 | ||
|
|
bcc8b7ce3c | ||
|
|
1c7fe6d134 | ||
|
|
c4039f52bd | ||
|
|
bd851d5e86 | ||
|
|
00e448c5d6 | ||
|
|
4aeec8afbf | ||
|
|
f10432bf3f | ||
|
|
f0efed8aa1 | ||
|
|
4a4931bee2 | ||
|
|
afcf12ebc9 | ||
|
|
8f86d3417d | ||
|
|
92dfc54c4c | ||
|
|
c93bcb8678 | ||
|
|
98b2da9123 | ||
|
|
cd5f1a1b28 | ||
|
|
0e2e495d09 | ||
|
|
84c6c7e2a6 | ||
|
|
c8ebf9c75a | ||
|
|
29852ff0a5 | ||
|
|
f06ca62589 | ||
|
|
3f39a2be12 | ||
|
|
575190a96d | ||
|
|
78559d98eb | ||
|
|
398964c747 | ||
|
|
a634565296 | ||
|
|
a5ecbec9a6 | ||
|
|
fe79978f88 | ||
|
|
978ec8bc75 | ||
|
|
6e77f5b068 | ||
|
|
c9dbb64269 | ||
|
|
546d32e3eb | ||
|
|
616f6401b4 | ||
|
|
d047190453 | ||
|
|
17504b1b9c | ||
|
|
5a0d3df689 | ||
|
|
871304c89b | ||
|
|
8155150e45 | ||
|
|
d9fb8edaa9 | ||
|
|
dda61679bd | ||
|
|
6ac10a8297 | ||
|
|
0695c11739 | ||
|
|
7a4297c4f1 | ||
|
|
2c9e5df27d | ||
|
|
6db37d35ed | ||
|
|
ceee4fe5cf | ||
|
|
130b4a57de | ||
|
|
1cee27e830 | ||
|
|
ba2ff053f9 | ||
|
|
227665439f | ||
|
|
1a2e043ec2 | ||
|
|
89500df0ac | ||
|
|
cb4e80f1bc |
164
.github/workflows/release-notify-wechat.yml
vendored
Normal file
164
.github/workflows/release-notify-wechat.yml
vendored
Normal file
@@ -0,0 +1,164 @@
|
||||
name: Release Notify Workflow
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [closed]
|
||||
|
||||
jobs:
|
||||
notify:
|
||||
if: >
|
||||
github.event.pull_request.merged == true &&
|
||||
startsWith(github.event.pull_request.base.ref, 'release')
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
# 防止 GitHub HEAD 未同步
|
||||
- run: sleep 3
|
||||
|
||||
# 1️⃣ 获取分支 HEAD
|
||||
- name: Get HEAD
|
||||
id: head
|
||||
run: |
|
||||
HEAD_SHA=$(curl -s \
|
||||
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
|
||||
https://api.github.com/repos/${{ github.repository }}/git/ref/heads/${{ github.event.pull_request.base.ref }} \
|
||||
| jq -r '.object.sha')
|
||||
echo "head_sha=$HEAD_SHA" >> $GITHUB_OUTPUT
|
||||
|
||||
# 2️⃣ 判断是否最终PR
|
||||
- name: Check Latest
|
||||
id: check
|
||||
run: |
|
||||
if [ "${{ github.event.pull_request.merge_commit_sha }}" = "${{ steps.head.outputs.head_sha }}" ]; then
|
||||
echo "ok=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "ok=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
# 3️⃣ 尝试从 PR body 提取 Sourcery 摘要
|
||||
- name: Extract Sourcery Summary
|
||||
if: steps.check.outputs.ok == 'true'
|
||||
id: sourcery
|
||||
env:
|
||||
PR_BODY: ${{ github.event.pull_request.body }}
|
||||
run: |
|
||||
python3 << 'PYEOF'
|
||||
import os, re
|
||||
|
||||
body = os.environ.get("PR_BODY", "") or ""
|
||||
match = re.search(
|
||||
r"## Summary by Sourcery\s*\n(.*?)(?=\n## |\Z)",
|
||||
body,
|
||||
re.DOTALL
|
||||
)
|
||||
|
||||
if match:
|
||||
summary = match.group(1).strip()
|
||||
found = "true"
|
||||
else:
|
||||
summary = ""
|
||||
found = "false"
|
||||
|
||||
with open("sourcery_summary.txt", "w", encoding="utf-8") as f:
|
||||
f.write(summary)
|
||||
|
||||
with open(os.environ["GITHUB_OUTPUT"], "a") as gh:
|
||||
gh.write(f"found={found}\n")
|
||||
gh.write("summary<<EOF\n")
|
||||
gh.write(summary + "\n")
|
||||
gh.write("EOF\n")
|
||||
PYEOF
|
||||
|
||||
# 4️⃣ Fallback: 获取 commits + 通义千问总结
|
||||
- name: Get Commits
|
||||
if: steps.check.outputs.ok == 'true' && steps.sourcery.outputs.found == 'false'
|
||||
run: |
|
||||
curl -s \
|
||||
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
|
||||
${{ github.event.pull_request.commits_url }} \
|
||||
| jq -r '.[].commit.message' | head -n 20 > commits.txt
|
||||
|
||||
- name: AI Summary (Qwen Fallback)
|
||||
if: steps.check.outputs.ok == 'true' && steps.sourcery.outputs.found == 'false'
|
||||
id: qwen
|
||||
env:
|
||||
DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }}
|
||||
run: |
|
||||
python3 << 'PYEOF'
|
||||
import json, os, urllib.request
|
||||
|
||||
with open("commits.txt", "r") as f:
|
||||
commits = f.read().strip()
|
||||
|
||||
prompt = "请用中文总结以下代码提交,输出3-5条要点,面向测试人员。直接输出编号列表,不要输出标题或前言:\n" + commits
|
||||
payload = {"model": "qwen-plus", "input": {"prompt": prompt}}
|
||||
data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
|
||||
|
||||
req = urllib.request.Request(
|
||||
"https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation",
|
||||
data=data,
|
||||
headers={
|
||||
"Authorization": "Bearer " + os.environ["DASHSCOPE_API_KEY"],
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
)
|
||||
resp = urllib.request.urlopen(req)
|
||||
result = json.loads(resp.read().decode())
|
||||
summary = result.get("output", {}).get("text", "AI 摘要生成失败")
|
||||
|
||||
with open(os.environ["GITHUB_OUTPUT"], "a") as gh:
|
||||
gh.write("summary<<EOF\n")
|
||||
gh.write(summary + "\n")
|
||||
gh.write("EOF\n")
|
||||
PYEOF
|
||||
|
||||
# 5️⃣ 企业微信通知(Markdown)
|
||||
- name: Notify WeChat
|
||||
if: steps.check.outputs.ok == 'true'
|
||||
env:
|
||||
WECHAT_WEBHOOK: ${{ secrets.WECHAT_WEBHOOK }}
|
||||
BRANCH: ${{ github.event.pull_request.base.ref }}
|
||||
AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||
PR_URL: ${{ github.event.pull_request.html_url }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
||||
SOURCERY_FOUND: ${{ steps.sourcery.outputs.found }}
|
||||
SOURCERY_SUMMARY: ${{ steps.sourcery.outputs.summary }}
|
||||
QWEN_SUMMARY: ${{ steps.qwen.outputs.summary }}
|
||||
run: |
|
||||
python3 << 'PYEOF'
|
||||
import json, os, urllib.request
|
||||
|
||||
if os.environ.get("SOURCERY_FOUND") == "true":
|
||||
label = "Summary by Sourcery"
|
||||
summary = os.environ.get("SOURCERY_SUMMARY", "")
|
||||
else:
|
||||
label = "AI变更摘要"
|
||||
summary = os.environ.get("QWEN_SUMMARY", "AI 摘要生成失败")
|
||||
|
||||
pr_number = os.environ.get("PR_NUMBER", "")
|
||||
short_sha = os.environ.get("MERGE_SHA", "")[:7]
|
||||
|
||||
content = (
|
||||
"## 🚀 Release 发布通知\n"
|
||||
"> <20> **分支**: " + os.environ["BRANCH"] + "\n"
|
||||
"> 👤 **提交人**: " + os.environ["AUTHOR"] + "\n"
|
||||
"> 📝 **标题**: " + os.environ["PR_TITLE"] + "\n"
|
||||
"> 🔢 **PR编号**: #" + pr_number + "\n"
|
||||
"> 🔖 **Commit**: " + short_sha + "\n\n"
|
||||
"### 🧠 " + label + "\n" +
|
||||
summary + "\n\n"
|
||||
"---\n"
|
||||
"🔗 [查看PR详情](" + os.environ["PR_URL"] + ")"
|
||||
)
|
||||
payload = {"msgtype": "markdown", "markdown": {"content": content}}
|
||||
data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
|
||||
req = urllib.request.Request(
|
||||
os.environ["WECHAT_WEBHOOK"],
|
||||
data=data,
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
resp = urllib.request.urlopen(req)
|
||||
print(resp.read().decode())
|
||||
PYEOF
|
||||
33
.github/workflows/sync-to-gitee.yml
vendored
Normal file
33
.github/workflows/sync-to-gitee.yml
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
name: Sync to Gitee
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- '**' # All branchs
|
||||
tags:
|
||||
- '**' # All version tags (v1.0.0, etc.)
|
||||
|
||||
jobs:
|
||||
sync:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout Source Code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Sync to Gitee
|
||||
run: |
|
||||
GITEE_URL="https://${{ secrets.GITEE_USERNAME }}:${{ secrets.GITEE_TOKEN }}@gitee.com/hangzhou-hongxiong-intelligent_1/MemoryBear.git"
|
||||
git remote add gitee "$GITEE_URL"
|
||||
|
||||
# 遍历并推送所有分支
|
||||
for branch in $(git branch -r | grep -v HEAD | sed 's/origin\///'); do
|
||||
echo "Syncing branch: $branch"
|
||||
git push -f gitee "origin/$branch:refs/heads/$branch"
|
||||
done
|
||||
|
||||
# 推送所有标签
|
||||
echo "Syncing tags..."
|
||||
git push gitee --tags --force
|
||||
9
.gitignore
vendored
9
.gitignore
vendored
@@ -18,16 +18,22 @@ examples/
|
||||
.kiro
|
||||
.vscode
|
||||
.idea
|
||||
.claude
|
||||
|
||||
# Temporary outputs
|
||||
.DS_Store
|
||||
.hypothesis/
|
||||
time.log
|
||||
celerybeat-schedule.db
|
||||
search_results.json
|
||||
redbear-mem-metrics/
|
||||
redbear-mem-benchmark/
|
||||
pitch-deck/
|
||||
|
||||
api/migrations/versions
|
||||
tmp
|
||||
files
|
||||
powers/
|
||||
|
||||
# Exclude dep files
|
||||
huggingface.co/
|
||||
@@ -36,5 +42,4 @@ tika-server*.jar*
|
||||
cl100k_base.tiktoken
|
||||
libssl*.deb
|
||||
|
||||
sandbox/lib/seccomp_python/target
|
||||
sandbox/lib/seccomp_nodejs/target
|
||||
sandbox/lib/seccomp_redbear/target
|
||||
|
||||
@@ -2,6 +2,10 @@
|
||||
|
||||
# MemoryBear empowers AI with human-like memory capabilities
|
||||
|
||||
[](LICENSE)
|
||||
[](https://www.python.org/)
|
||||
[](https://github.com/SuanmoSuanyangTechnology/MemoryBear/actions/workflows/sync-to-gitee.yml)
|
||||
|
||||
[中文](./README_CN.md) | English
|
||||
|
||||
### [Installation Guide](#memorybear-installation-guide)
|
||||
@@ -226,8 +230,8 @@ REDIS_PORT=6379
|
||||
REDIS_DB=1
|
||||
|
||||
# Celery (Using Redis as broker)
|
||||
BROKER_URL=redis://127.0.0.1:6379/0
|
||||
RESULT_BACKEND=redis://127.0.0.1:6379/0
|
||||
REDIS_DB_CELERY_BROKER=1
|
||||
REDIS_DB_CELERY_BACKEND=2
|
||||
|
||||
# JWT Secret Key (Formation method: openssl rand -hex 32)
|
||||
SECRET_KEY=your-secret-key-here
|
||||
|
||||
@@ -2,6 +2,10 @@
|
||||
|
||||
# MemoryBear 让AI拥有如同人类一样的记忆
|
||||
|
||||
[](LICENSE)
|
||||
[](https://www.python.org/)
|
||||
[](https://github.com/SuanmoSuanyangTechnology/MemoryBear/actions/workflows/sync-to-gitee.yml)
|
||||
|
||||
中文 | [English](./README.md)
|
||||
|
||||
### [安装教程](#memorybear安装教程)
|
||||
@@ -201,8 +205,8 @@ REDIS_PORT=6379
|
||||
REDIS_DB=1
|
||||
|
||||
# Celery (使用Redis作为broker)
|
||||
BROKER_URL=redis://127.0.0.1:6379/0
|
||||
RESULT_BACKEND=redis://127.0.0.1:6379/0
|
||||
REDIS_DB_CELERY_BROKER=1
|
||||
REDIS_DB_CELERY_BACKEND=2
|
||||
|
||||
# JWT密钥 (生成方式: openssl rand -hex 32)
|
||||
SECRET_KEY=your-secret-key-here
|
||||
|
||||
@@ -45,7 +45,8 @@ RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \
|
||||
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
|
||||
apt install -y libjemalloc-dev && \
|
||||
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
|
||||
apt install -y ghostscript
|
||||
apt install -y ghostscript && \
|
||||
apt install -y libmagic1
|
||||
|
||||
RUN if [ "$NEED_MIRROR" == "1" ]; then \
|
||||
pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
|
||||
|
||||
@@ -60,7 +60,12 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
sqlalchemy.url = postgresql://user:password@localhost/dbname
|
||||
# Database connection URL - DO NOT hardcode credentials here!
|
||||
# Connection string is set dynamically from environment variables in migrations/env.py
|
||||
# Required env vars: DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME
|
||||
# Example: postgresql://user:password@localhost:5432/dbname
|
||||
; sqlalchemy.url = postgresql://user:password@host:port/dbname
|
||||
sqlalchemy.url = driver://user:password@host:port/dbname
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
import redis.asyncio as redis
|
||||
from redis.asyncio import ConnectionPool
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# 设置日志记录器
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 创建连接池
|
||||
pool = ConnectionPool.from_url(
|
||||
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}",
|
||||
@@ -21,6 +23,51 @@ pool = ConnectionPool.from_url(
|
||||
)
|
||||
aio_redis = redis.StrictRedis(connection_pool=pool)
|
||||
|
||||
_REDIS_URL = f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}"
|
||||
|
||||
# Thread-local storage for connection pools.
|
||||
# Each thread (and each forked process) gets its own pool to avoid
|
||||
# "Future attached to a different loop" errors in Celery --pool=threads
|
||||
# and stale connections after fork in --pool=prefork.
|
||||
_thread_local = threading.local()
|
||||
|
||||
|
||||
def get_thread_safe_redis() -> redis.StrictRedis:
|
||||
"""Return a Redis client whose connection pool is bound to the current
|
||||
thread, process **and** event loop.
|
||||
|
||||
The pool is recreated when:
|
||||
- The PID changes (fork, Celery --pool=prefork)
|
||||
- The thread has no pool yet (Celery --pool=threads)
|
||||
- The previously-cached event loop has been closed (Celery tasks call
|
||||
``_shutdown_loop_gracefully`` which closes the loop after each run)
|
||||
"""
|
||||
current_pid = os.getpid()
|
||||
cached_loop = getattr(_thread_local, "loop", None)
|
||||
loop_stale = cached_loop is not None and cached_loop.is_closed()
|
||||
|
||||
if not hasattr(_thread_local, "pool") \
|
||||
or getattr(_thread_local, "pid", None) != current_pid \
|
||||
or loop_stale:
|
||||
_thread_local.pid = current_pid
|
||||
# Python 3.10+: get_event_loop() raises RuntimeError in threads
|
||||
# where no loop has been set yet (e.g. Celery --pool=threads).
|
||||
try:
|
||||
_thread_local.loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
_thread_local.loop = None
|
||||
_thread_local.pool = ConnectionPool.from_url(
|
||||
_REDIS_URL,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD,
|
||||
decode_responses=True,
|
||||
max_connections=5,
|
||||
health_check_interval=30,
|
||||
)
|
||||
|
||||
return redis.StrictRedis(connection_pool=_thread_local.pool)
|
||||
|
||||
|
||||
async def get_redis_connection():
|
||||
"""获取Redis连接"""
|
||||
try:
|
||||
@@ -29,7 +76,8 @@ async def get_redis_connection():
|
||||
logger.error(f"Redis连接失败: {str(e)}")
|
||||
return None
|
||||
|
||||
async def aio_redis_set(key: str, val: str|dict, expire: int = None):
|
||||
|
||||
async def aio_redis_set(key: str, val: str | dict, expire: int = None):
|
||||
"""设置Redis键值
|
||||
|
||||
Args:
|
||||
@@ -40,16 +88,15 @@ async def aio_redis_set(key: str, val: str|dict, expire: int = None):
|
||||
try:
|
||||
if isinstance(val, dict):
|
||||
val = json.dumps(val, ensure_ascii=False)
|
||||
|
||||
|
||||
if expire is not None:
|
||||
# 设置带过期时间的键值
|
||||
await aio_redis.set(key, val, ex=expire)
|
||||
else:
|
||||
# 设置永久键值
|
||||
await aio_redis.set(key, val)
|
||||
except Exception as e:
|
||||
logger.error(f"Redis set错误: {str(e)}")
|
||||
|
||||
|
||||
async def aio_redis_get(key: str):
|
||||
"""获取Redis键值"""
|
||||
try:
|
||||
@@ -58,6 +105,7 @@ async def aio_redis_get(key: str):
|
||||
logger.error(f"Redis get错误: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def aio_redis_delete(key: str):
|
||||
"""删除Redis键"""
|
||||
try:
|
||||
@@ -66,6 +114,7 @@ async def aio_redis_delete(key: str):
|
||||
logger.error(f"Redis delete错误: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
|
||||
"""发布消息到Redis频道"""
|
||||
try:
|
||||
@@ -78,9 +127,10 @@ async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
|
||||
logger.error(f"Redis发布错误: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
class RedisSubscriber:
|
||||
"""Redis订阅器"""
|
||||
|
||||
|
||||
def __init__(self, channel: str):
|
||||
self.channel = channel
|
||||
self.conn = None
|
||||
@@ -88,25 +138,25 @@ class RedisSubscriber:
|
||||
self.is_closed = False
|
||||
self._queue = asyncio.Queue()
|
||||
self._task = None
|
||||
|
||||
|
||||
async def start(self):
|
||||
"""开始订阅"""
|
||||
if self.is_closed or self._task:
|
||||
return
|
||||
|
||||
|
||||
self._task = asyncio.create_task(self._receive_messages())
|
||||
logger.info(f"开始订阅: {self.channel}")
|
||||
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""接收消息"""
|
||||
try:
|
||||
self.conn = await get_redis_connection()
|
||||
if not self.conn:
|
||||
return
|
||||
|
||||
|
||||
self.pubsub = self.conn.pubsub()
|
||||
await self.pubsub.subscribe(self.channel)
|
||||
|
||||
|
||||
while not self.is_closed:
|
||||
try:
|
||||
message = await self.pubsub.get_message(ignore_subscribe_messages=True, timeout=0.01)
|
||||
@@ -127,7 +177,7 @@ class RedisSubscriber:
|
||||
finally:
|
||||
await self._queue.put(None)
|
||||
await self._cleanup()
|
||||
|
||||
|
||||
async def _cleanup(self):
|
||||
"""清理资源"""
|
||||
if self.pubsub:
|
||||
@@ -141,7 +191,7 @@ class RedisSubscriber:
|
||||
await self.conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def get_message(self) -> Optional[Dict[str, Any]]:
|
||||
"""获取消息"""
|
||||
if self.is_closed:
|
||||
@@ -153,7 +203,7 @@ class RedisSubscriber:
|
||||
except Exception as e:
|
||||
logger.error(f"获取消息错误: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def close(self):
|
||||
"""关闭订阅器"""
|
||||
if self.is_closed:
|
||||
@@ -163,32 +213,33 @@ class RedisSubscriber:
|
||||
self._task.cancel()
|
||||
await self._cleanup()
|
||||
|
||||
|
||||
class RedisPubSubManager:
|
||||
"""Redis发布订阅管理器"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.subscribers = {}
|
||||
|
||||
|
||||
async def publish(self, channel: str, message: Dict[str, Any]) -> bool:
|
||||
return await aio_redis_publish(channel, message)
|
||||
|
||||
|
||||
def get_subscriber(self, channel: str) -> RedisSubscriber:
|
||||
if channel in self.subscribers:
|
||||
subscriber = self.subscribers[channel]
|
||||
if not subscriber.is_closed:
|
||||
return subscriber
|
||||
|
||||
|
||||
subscriber = RedisSubscriber(channel)
|
||||
self.subscribers[channel] = subscriber
|
||||
return subscriber
|
||||
|
||||
|
||||
def cancel_subscription(self, channel: str) -> bool:
|
||||
if channel in self.subscribers:
|
||||
asyncio.create_task(self.subscribers[channel].close())
|
||||
del self.subscribers[channel]
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def cancel_all_subscriptions(self) -> int:
|
||||
count = len(self.subscribers)
|
||||
for subscriber in self.subscribers.values():
|
||||
@@ -196,6 +247,6 @@ class RedisPubSubManager:
|
||||
self.subscribers.clear()
|
||||
return count
|
||||
|
||||
|
||||
# 全局实例
|
||||
pubsub_manager = RedisPubSubManager()
|
||||
|
||||
|
||||
5
api/app/cache/__init__.py
vendored
5
api/app/cache/__init__.py
vendored
@@ -3,9 +3,8 @@ Cache 缓存模块
|
||||
|
||||
提供各种缓存功能的统一入口
|
||||
"""
|
||||
from .memory import EmotionMemoryCache, ImplicitMemoryCache
|
||||
from .memory import InterestMemoryCache
|
||||
|
||||
__all__ = [
|
||||
"EmotionMemoryCache",
|
||||
"ImplicitMemoryCache",
|
||||
"InterestMemoryCache",
|
||||
]
|
||||
|
||||
8
api/app/cache/memory/__init__.py
vendored
8
api/app/cache/memory/__init__.py
vendored
@@ -3,10 +3,10 @@ Memory 缓存模块
|
||||
|
||||
提供记忆系统相关的缓存功能
|
||||
"""
|
||||
from .emotion_memory import EmotionMemoryCache
|
||||
from .implicit_memory import ImplicitMemoryCache
|
||||
from .interest_memory import InterestMemoryCache
|
||||
from .activity_stats_cache import ActivityStatsCache
|
||||
|
||||
__all__ = [
|
||||
"EmotionMemoryCache",
|
||||
"ImplicitMemoryCache",
|
||||
"InterestMemoryCache",
|
||||
"ActivityStatsCache",
|
||||
]
|
||||
|
||||
124
api/app/cache/memory/activity_stats_cache.py
vendored
Normal file
124
api/app/cache/memory/activity_stats_cache.py
vendored
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Recent Activity Stats Cache
|
||||
|
||||
记忆提取活动统计缓存模块
|
||||
用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储,24小时后释放
|
||||
查询命令:cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import get_thread_safe_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 缓存过期时间:24小时
|
||||
ACTIVITY_STATS_CACHE_EXPIRE = 86400
|
||||
|
||||
|
||||
class ActivityStatsCache:
|
||||
"""记忆提取活动统计缓存类"""
|
||||
|
||||
PREFIX = "cache:memory:activity_stats"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, workspace_id: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return f"{cls.PREFIX}:by_workspace:{workspace_id}"
|
||||
|
||||
@classmethod
|
||||
async def set_activity_stats(
|
||||
cls,
|
||||
workspace_id: str,
|
||||
stats: Dict[str, Any],
|
||||
expire: int = ACTIVITY_STATS_CACHE_EXPIRE,
|
||||
) -> bool:
|
||||
"""设置记忆提取活动统计缓存
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
stats: 统计数据,格式:
|
||||
{
|
||||
"chunk_count": int,
|
||||
"statements_count": int,
|
||||
"triplet_entities_count": int,
|
||||
"triplet_relations_count": int,
|
||||
"temporal_count": int,
|
||||
}
|
||||
expire: 过期时间(秒),默认24小时
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(workspace_id)
|
||||
payload = {
|
||||
"stats": stats,
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"workspace_id": workspace_id,
|
||||
"cached": True,
|
||||
}
|
||||
value = json.dumps(payload, ensure_ascii=False)
|
||||
await get_thread_safe_redis().set(key, value, ex=expire)
|
||||
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置活动统计缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_activity_stats(
|
||||
cls,
|
||||
workspace_id: str,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""获取记忆提取活动统计缓存
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
统计数据字典,缓存不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(workspace_id)
|
||||
value = await get_thread_safe_redis().get(key)
|
||||
if value:
|
||||
payload = json.loads(value)
|
||||
logger.info(f"命中活动统计缓存: {key}")
|
||||
return payload
|
||||
logger.info(f"活动统计缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取活动统计缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_activity_stats(
|
||||
cls,
|
||||
workspace_id: str,
|
||||
) -> bool:
|
||||
"""删除记忆提取活动统计缓存
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(workspace_id)
|
||||
result = await get_thread_safe_redis().delete(key)
|
||||
logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除活动统计缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
134
api/app/cache/memory/emotion_memory.py
vendored
134
api/app/cache/memory/emotion_memory.py
vendored
@@ -1,134 +0,0 @@
|
||||
"""
|
||||
Emotion Suggestions Cache
|
||||
|
||||
情绪个性化建议缓存模块
|
||||
用于缓存用户的情绪个性化建议数据
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmotionMemoryCache:
|
||||
"""情绪建议缓存类"""
|
||||
|
||||
# Key 前缀
|
||||
PREFIX = "cache:memory:emotion_memory"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, *parts: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
*parts: key 的各个部分
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return ":".join([cls.PREFIX] + list(parts))
|
||||
|
||||
@classmethod
|
||||
async def set_emotion_suggestions(
|
||||
cls,
|
||||
user_id: str,
|
||||
suggestions_data: Dict[str, Any],
|
||||
expire: int = 86400
|
||||
) -> bool:
|
||||
"""设置用户情绪建议缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
suggestions_data: 建议数据字典,包含:
|
||||
- health_summary: 健康状态摘要
|
||||
- suggestions: 建议列表
|
||||
- generated_at: 生成时间(可选)
|
||||
expire: 过期时间(秒),默认24小时(86400秒)
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
|
||||
# 添加生成时间戳
|
||||
if "generated_at" not in suggestions_data:
|
||||
suggestions_data["generated_at"] = datetime.now().isoformat()
|
||||
|
||||
# 添加缓存标记
|
||||
suggestions_data["cached"] = True
|
||||
|
||||
value = json.dumps(suggestions_data, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
logger.info(f"设置情绪建议缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置情绪建议缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_emotion_suggestions(cls, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取用户情绪建议缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
建议数据字典,如果不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
value = await aio_redis.get(key)
|
||||
|
||||
if value:
|
||||
data = json.loads(value)
|
||||
logger.info(f"成功获取情绪建议缓存: {key}")
|
||||
return data
|
||||
|
||||
logger.info(f"情绪建议缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取情绪建议缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_emotion_suggestions(cls, user_id: str) -> bool:
|
||||
"""删除用户情绪建议缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
result = await aio_redis.delete(key)
|
||||
logger.info(f"删除情绪建议缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除情绪建议缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_suggestions_ttl(cls, user_id: str) -> int:
|
||||
"""获取情绪建议缓存的剩余过期时间
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
剩余秒数,-1表示永不过期,-2表示key不存在
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
ttl = await aio_redis.ttl(key)
|
||||
logger.debug(f"情绪建议缓存TTL: {key} = {ttl}秒")
|
||||
return ttl
|
||||
except Exception as e:
|
||||
logger.error(f"获取情绪建议缓存TTL失败: {e}")
|
||||
return -2
|
||||
136
api/app/cache/memory/implicit_memory.py
vendored
136
api/app/cache/memory/implicit_memory.py
vendored
@@ -1,136 +0,0 @@
|
||||
"""
|
||||
Implicit Memory Profile Cache
|
||||
|
||||
隐式记忆用户画像缓存模块
|
||||
用于缓存用户的完整画像数据(偏好标签、四维画像、兴趣领域、行为习惯)
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImplicitMemoryCache:
|
||||
"""隐式记忆用户画像缓存类"""
|
||||
|
||||
# Key 前缀
|
||||
PREFIX = "cache:memory:implicit_memory"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, *parts: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
*parts: key 的各个部分
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return ":".join([cls.PREFIX] + list(parts))
|
||||
|
||||
@classmethod
|
||||
async def set_user_profile(
|
||||
cls,
|
||||
user_id: str,
|
||||
profile_data: Dict[str, Any],
|
||||
expire: int = 86400
|
||||
) -> bool:
|
||||
"""设置用户完整画像缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
profile_data: 画像数据字典,包含:
|
||||
- preferences: 偏好标签列表
|
||||
- portrait: 四维画像对象
|
||||
- interest_areas: 兴趣领域分布对象
|
||||
- habits: 行为习惯列表
|
||||
- generated_at: 生成时间(可选)
|
||||
expire: 过期时间(秒),默认24小时(86400秒)
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
|
||||
# 添加生成时间戳
|
||||
if "generated_at" not in profile_data:
|
||||
profile_data["generated_at"] = datetime.now().isoformat()
|
||||
|
||||
# 添加缓存标记
|
||||
profile_data["cached"] = True
|
||||
|
||||
value = json.dumps(profile_data, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
logger.info(f"设置用户画像缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置用户画像缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_user_profile(cls, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取用户完整画像缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
画像数据字典,如果不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
value = await aio_redis.get(key)
|
||||
|
||||
if value:
|
||||
data = json.loads(value)
|
||||
logger.info(f"成功获取用户画像缓存: {key}")
|
||||
return data
|
||||
|
||||
logger.info(f"用户画像缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户画像缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_user_profile(cls, user_id: str) -> bool:
|
||||
"""删除用户完整画像缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
result = await aio_redis.delete(key)
|
||||
logger.info(f"删除用户画像缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除用户画像缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_profile_ttl(cls, user_id: str) -> int:
|
||||
"""获取用户画像缓存的剩余过期时间
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
剩余秒数,-1表示永不过期,-2表示key不存在
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
ttl = await aio_redis.ttl(key)
|
||||
logger.debug(f"用户画像缓存TTL: {key} = {ttl}秒")
|
||||
return ttl
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户画像缓存TTL失败: {e}")
|
||||
return -2
|
||||
122
api/app/cache/memory/interest_memory.py
vendored
Normal file
122
api/app/cache/memory/interest_memory.py
vendored
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
Interest Distribution Cache
|
||||
|
||||
兴趣分布缓存模块
|
||||
用于缓存用户的兴趣分布标签数据,避免重复调用模型生成
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import get_thread_safe_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 缓存过期时间:24小时
|
||||
INTEREST_CACHE_EXPIRE = 86400
|
||||
|
||||
|
||||
class InterestMemoryCache:
|
||||
"""兴趣分布缓存类"""
|
||||
|
||||
PREFIX = "cache:memory:interest_distribution"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, end_user_id: str, language: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return f"{cls.PREFIX}:by_user:{end_user_id}:{language}"
|
||||
|
||||
@classmethod
|
||||
async def set_interest_distribution(
|
||||
cls,
|
||||
end_user_id: str,
|
||||
language: str,
|
||||
data: List[Dict[str, Any]],
|
||||
expire: int = INTEREST_CACHE_EXPIRE,
|
||||
) -> bool:
|
||||
"""设置用户兴趣分布缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
data: 兴趣分布列表,格式 [{"name": "...", "frequency": ...}, ...]
|
||||
expire: 过期时间(秒),默认24小时
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
payload = {
|
||||
"data": data,
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"cached": True,
|
||||
}
|
||||
value = json.dumps(payload, ensure_ascii=False)
|
||||
await get_thread_safe_redis().set(key, value, ex=expire)
|
||||
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置兴趣分布缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_interest_distribution(
|
||||
cls,
|
||||
end_user_id: str,
|
||||
language: str,
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""获取用户兴趣分布缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
|
||||
Returns:
|
||||
兴趣分布列表,缓存不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
value = await get_thread_safe_redis().get(key)
|
||||
if value:
|
||||
payload = json.loads(value)
|
||||
logger.info(f"命中兴趣分布缓存: {key}")
|
||||
return payload.get("data")
|
||||
logger.info(f"兴趣分布缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取兴趣分布缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_interest_distribution(
|
||||
cls,
|
||||
end_user_id: str,
|
||||
language: str,
|
||||
) -> bool:
|
||||
"""删除用户兴趣分布缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
result = await get_thread_safe_redis().delete(key)
|
||||
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除兴趣分布缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
@@ -1,20 +1,59 @@
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
from datetime import timedelta
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.core.config import settings
|
||||
from celery import Celery
|
||||
from celery.schedules import crontab
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _mask_url(url: str) -> str:
|
||||
"""隐藏 URL 中的密码部分,适用于 redis:// 和 amqp:// 等协议"""
|
||||
return re.sub(r'(://[^:]*:)[^@]+(@)', r'\1***\2', url)
|
||||
|
||||
|
||||
# macOS fork() safety - must be set before any Celery initialization
|
||||
if platform.system() == 'Darwin':
|
||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||
|
||||
# 创建 Celery 应用实例
|
||||
# broker: 任务队列(使用 Redis DB 0)
|
||||
# backend: 结果存储(使用 Redis DB 10)
|
||||
# broker: 优先使用环境变量 CELERY_BROKER_URL(支持 amqp:// 等任意协议),
|
||||
# 未配置则回退到 Redis 方案
|
||||
# backend: 结果存储(使用 Redis)
|
||||
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND,
|
||||
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
|
||||
|
||||
_broker_url = os.getenv("CELERY_BROKER_URL") or \
|
||||
f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
|
||||
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
|
||||
os.environ["CELERY_BROKER_URL"] = _broker_url
|
||||
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
|
||||
# Neutralize legacy Celery env vars that can be hijacked by Celery's CLI/Click
|
||||
# integration and accidentally override our canonical URLs.
|
||||
os.environ.pop("BROKER_URL", None)
|
||||
os.environ.pop("RESULT_BACKEND", None)
|
||||
os.environ.pop("CELERY_BROKER", None)
|
||||
os.environ.pop("CELERY_BACKEND", None)
|
||||
|
||||
celery_app = Celery(
|
||||
"redbear_tasks",
|
||||
broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}",
|
||||
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
|
||||
broker=_broker_url,
|
||||
backend=_backend_url,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Celery app initialized",
|
||||
extra={
|
||||
"broker": _mask_url(_broker_url),
|
||||
"backend": _mask_url(_backend_url),
|
||||
},
|
||||
)
|
||||
# Default queue for unrouted tasks
|
||||
celery_app.conf.task_default_queue = 'memory_tasks'
|
||||
|
||||
@@ -28,50 +67,70 @@ celery_app.conf.update(
|
||||
task_serializer='json',
|
||||
accept_content=['json'],
|
||||
result_serializer='json',
|
||||
|
||||
# 时区
|
||||
timezone='Asia/Shanghai',
|
||||
enable_utc=True,
|
||||
|
||||
|
||||
# # 时区
|
||||
# timezone='Asia/Shanghai',
|
||||
# enable_utc=False,
|
||||
|
||||
# 任务追踪
|
||||
task_track_started=True,
|
||||
task_ignore_result=False,
|
||||
|
||||
|
||||
# 超时设置
|
||||
task_time_limit=1800, # 30分钟硬超时
|
||||
task_soft_time_limit=1500, # 25分钟软超时
|
||||
|
||||
task_time_limit=3600, # 60分钟硬超时
|
||||
task_soft_time_limit=3000, # 50分钟软超时
|
||||
|
||||
# Worker 设置 (per-worker settings are in docker-compose command line)
|
||||
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
||||
|
||||
worker_redirect_stdouts_level='INFO', # stdout/print → INFO instead of WARNING
|
||||
|
||||
# 结果过期时间
|
||||
result_expires=3600, # 结果保存1小时
|
||||
|
||||
|
||||
# 任务确认设置
|
||||
task_acks_late=True,
|
||||
task_reject_on_worker_lost=True,
|
||||
worker_disable_rate_limits=True,
|
||||
|
||||
|
||||
# FLower setting
|
||||
worker_send_task_events=True,
|
||||
task_send_sent_event=True,
|
||||
|
||||
|
||||
# task routing
|
||||
task_routes={
|
||||
# Memory tasks → memory_tasks queue (threads worker)
|
||||
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||
|
||||
|
||||
# Long-term storage tasks → memory_tasks queue (batched write strategies)
|
||||
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
||||
|
||||
# Clustering tasks → memory_tasks queue (使用相同的 worker,避免 macOS fork 问题)
|
||||
'app.tasks.run_incremental_clustering': {'queue': 'memory_tasks'},
|
||||
|
||||
# Metadata extraction → memory_tasks queue
|
||||
'app.tasks.extract_user_metadata': {'queue': 'memory_tasks'},
|
||||
|
||||
# Document tasks → document_tasks queue (prefork worker)
|
||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||
|
||||
# Beat/periodic tasks → document_tasks queue (prefork worker)
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'},
|
||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'},
|
||||
'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
|
||||
|
||||
# GraphRAG tasks → graphrag_tasks queue (独立队列,避免阻塞文档解析)
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'graphrag_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_document': {'queue': 'graphrag_tasks'},
|
||||
|
||||
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.init_interest_distribution_for_users': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.init_community_clustering_for_users': {'queue': 'periodic_tasks'},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -79,10 +138,14 @@ celery_app.conf.update(
|
||||
celery_app.autodiscover_tasks(['app'])
|
||||
|
||||
# Celery Beat schedule for periodic tasks
|
||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||
memory_increment_schedule = crontab(hour=settings.MEMORY_INCREMENT_HOUR, minute=settings.MEMORY_INCREMENT_MINUTE)
|
||||
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
|
||||
workspace_reflection_schedule = timedelta(seconds=settings.WORKSPACE_REFLECTION_INTERVAL_SECONDS)
|
||||
forgetting_cycle_schedule = timedelta(hours=settings.FORGETTING_CYCLE_INTERVAL_HOURS)
|
||||
implicit_emotions_update_schedule = crontab(
|
||||
hour=settings.IMPLICIT_EMOTIONS_UPDATE_HOUR,
|
||||
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
|
||||
)
|
||||
|
||||
# 构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
@@ -103,16 +166,16 @@ beat_schedule_config = {
|
||||
"config_id": None, # 使用默认配置,可以通过环境变量配置
|
||||
},
|
||||
},
|
||||
"write-all-workspaces-memory": {
|
||||
"task": "app.tasks.write_all_workspaces_memory_task",
|
||||
"schedule": memory_increment_schedule,
|
||||
"args": (),
|
||||
},
|
||||
"update-implicit-emotions-storage": {
|
||||
"task": "app.tasks.update_implicit_emotions_storage",
|
||||
"schedule": implicit_emotions_update_schedule,
|
||||
"args": (),
|
||||
},
|
||||
}
|
||||
|
||||
# 如果配置了默认工作空间ID,则添加记忆总量统计任务
|
||||
if settings.DEFAULT_WORKSPACE_ID:
|
||||
beat_schedule_config["write-total-memory"] = {
|
||||
"task": "app.controllers.memory_storage_controller.search_all",
|
||||
"schedule": memory_increment_schedule,
|
||||
"kwargs": {
|
||||
"workspace_id": settings.DEFAULT_WORKSPACE_ID,
|
||||
},
|
||||
}
|
||||
|
||||
celery_app.conf.beat_schedule = beat_schedule_config
|
||||
|
||||
500
api/app/celery_task_scheduler.py
Normal file
500
api/app/celery_task_scheduler.py
Normal file
@@ -0,0 +1,500 @@
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import redis
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_named_logger
|
||||
from app.celery_app import celery_app
|
||||
|
||||
logger = get_named_logger("task_scheduler")
|
||||
|
||||
# per-user queue scheduler:uq:{user_id}
|
||||
USER_QUEUE_PREFIX = "scheduler:uq:"
|
||||
# User Collection of Pending Messages
|
||||
ACTIVE_USERS = "scheduler:active_users"
|
||||
# Set of users that can dispatch (ready signal)
|
||||
READY_SET = "scheduler:ready_users"
|
||||
# Metadata of tasks that have been dispatched and are pending completion
|
||||
PENDING_HASH = "scheduler:pending_tasks"
|
||||
# Dynamic Sharding: Instance Registry
|
||||
REGISTRY_KEY = "scheduler:instances"
|
||||
|
||||
TASK_TIMEOUT = 7800 # Task timeout (seconds), considered lost if exceeded
|
||||
HEARTBEAT_INTERVAL = 10 # Heartbeat interval (seconds)
|
||||
INSTANCE_TTL = 30 # Instance timeout (seconds)
|
||||
|
||||
LUA_ATOMIC_LOCK = """
|
||||
local dispatch_lock = KEYS[1]
|
||||
local lock_key = KEYS[2]
|
||||
local instance_id = ARGV[1]
|
||||
local dispatch_ttl = tonumber(ARGV[2])
|
||||
local lock_ttl = tonumber(ARGV[3])
|
||||
|
||||
if redis.call('SET', dispatch_lock, instance_id, 'NX', 'EX', dispatch_ttl) == false then
|
||||
return 0
|
||||
end
|
||||
|
||||
if redis.call('EXISTS', lock_key) == 1 then
|
||||
redis.call('DEL', dispatch_lock)
|
||||
return -1
|
||||
end
|
||||
|
||||
redis.call('SET', lock_key, 'dispatching', 'EX', lock_ttl)
|
||||
return 1
|
||||
"""
|
||||
|
||||
LUA_SAFE_DELETE = """
|
||||
if redis.call('GET', KEYS[1]) == ARGV[1] then
|
||||
return redis.call('DEL', KEYS[1])
|
||||
end
|
||||
return 0
|
||||
"""
|
||||
|
||||
|
||||
def stable_hash(value: str) -> int:
|
||||
return int.from_bytes(
|
||||
hashlib.md5(value.encode("utf-8")).digest(),
|
||||
"big"
|
||||
)
|
||||
|
||||
|
||||
def health_check_server(scheduler_ref):
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
|
||||
health_app = FastAPI()
|
||||
|
||||
@health_app.get("/")
|
||||
def health():
|
||||
return scheduler_ref.health()
|
||||
|
||||
port = int(os.environ.get("SCHEDULER_HEALTH_PORT", "8001"))
|
||||
threading.Thread(
|
||||
target=uvicorn.run,
|
||||
kwargs={
|
||||
"app": health_app,
|
||||
"host": "0.0.0.0",
|
||||
"port": port,
|
||||
"log_config": None,
|
||||
},
|
||||
daemon=True,
|
||||
).start()
|
||||
logger.info("[Health] Server started at http://0.0.0.0:%s", port)
|
||||
|
||||
|
||||
class RedisTaskScheduler:
|
||||
def __init__(self):
|
||||
self.redis = redis.Redis(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB_CELERY_BACKEND,
|
||||
password=settings.REDIS_PASSWORD,
|
||||
decode_responses=True,
|
||||
)
|
||||
self.running = False
|
||||
self.dispatched = 0
|
||||
self.errors = 0
|
||||
|
||||
self.instance_id = f"{socket.gethostname()}-{os.getpid()}"
|
||||
self._shard_index = 0
|
||||
self._shard_count = 1
|
||||
self._last_heartbeat = 0.0
|
||||
|
||||
def push_task(self, task_name, user_id, params):
|
||||
try:
|
||||
msg_id = str(uuid.uuid4())
|
||||
msg = json.dumps({
|
||||
"msg_id": msg_id,
|
||||
"task_name": task_name,
|
||||
"user_id": user_id,
|
||||
"params": json.dumps(params),
|
||||
})
|
||||
|
||||
lock_key = f"{task_name}:{user_id}"
|
||||
queue_key = f"{USER_QUEUE_PREFIX}{user_id}"
|
||||
|
||||
pipe = self.redis.pipeline()
|
||||
pipe.rpush(queue_key, msg)
|
||||
pipe.sadd(ACTIVE_USERS, user_id)
|
||||
pipe.set(
|
||||
f"task_tracker:{msg_id}",
|
||||
json.dumps({"status": "QUEUED", "task_id": None}),
|
||||
ex=86400,
|
||||
)
|
||||
pipe.execute()
|
||||
|
||||
if not self.redis.exists(lock_key):
|
||||
self.redis.sadd(READY_SET, user_id)
|
||||
|
||||
logger.info("Task pushed: msg_id=%s task=%s user=%s", msg_id, task_name, user_id)
|
||||
return msg_id
|
||||
except Exception as e:
|
||||
logger.error("Push task exception %s", e, exc_info=True)
|
||||
raise
|
||||
|
||||
def get_task_status(self, msg_id: str) -> dict:
|
||||
raw = self.redis.get(f"task_tracker:{msg_id}")
|
||||
if raw is None:
|
||||
return {"status": "NOT_FOUND"}
|
||||
|
||||
tracker = json.loads(raw)
|
||||
status = tracker["status"]
|
||||
task_id = tracker.get("task_id")
|
||||
result_content = tracker.get("result") or {}
|
||||
|
||||
if status == "DISPATCHED" and task_id:
|
||||
result_raw = self.redis.get(f"celery-task-meta-{task_id}")
|
||||
if result_raw:
|
||||
result_data = json.loads(result_raw)
|
||||
status = result_data.get("status", status)
|
||||
result_content = result_data.get("result")
|
||||
|
||||
return {"status": status, "task_id": task_id, "result": result_content}
|
||||
|
||||
def _cleanup_finished(self):
|
||||
pending = self.redis.hgetall(PENDING_HASH)
|
||||
if not pending:
|
||||
return
|
||||
|
||||
now = time.time()
|
||||
task_ids = list(pending.keys())
|
||||
|
||||
pipe = self.redis.pipeline()
|
||||
for task_id in task_ids:
|
||||
pipe.get(f"celery-task-meta-{task_id}")
|
||||
results = pipe.execute()
|
||||
|
||||
cleanup_pipe = self.redis.pipeline()
|
||||
has_cleanup = False
|
||||
ready_user_ids = set()
|
||||
|
||||
for task_id, raw_result in zip(task_ids, results):
|
||||
try:
|
||||
meta = json.loads(pending[task_id])
|
||||
lock_key = meta["lock_key"]
|
||||
dispatched_at = meta.get("dispatched_at", 0)
|
||||
age = now - dispatched_at
|
||||
|
||||
should_cleanup = False
|
||||
result_data = {}
|
||||
|
||||
if raw_result is not None:
|
||||
result_data = json.loads(raw_result)
|
||||
if result_data.get("status") in ("SUCCESS", "FAILURE", "REVOKED"):
|
||||
should_cleanup = True
|
||||
logger.info(
|
||||
"Task finished: %s state=%s", task_id,
|
||||
result_data.get("status"),
|
||||
)
|
||||
elif age > TASK_TIMEOUT:
|
||||
should_cleanup = True
|
||||
logger.warning(
|
||||
"Task expired or lost: %s age=%.0fs, force cleanup",
|
||||
task_id, age,
|
||||
)
|
||||
|
||||
if should_cleanup:
|
||||
final_status = (
|
||||
result_data.get("status", "UNKNOWN") if result_data else "EXPIRED"
|
||||
)
|
||||
|
||||
self.redis.eval(LUA_SAFE_DELETE, 1, lock_key, task_id)
|
||||
|
||||
cleanup_pipe.hdel(PENDING_HASH, task_id)
|
||||
|
||||
tracker_msg_id = meta.get("msg_id")
|
||||
if tracker_msg_id:
|
||||
cleanup_pipe.set(
|
||||
f"task_tracker:{tracker_msg_id}",
|
||||
json.dumps({
|
||||
"status": final_status,
|
||||
"task_id": task_id,
|
||||
"result": result_data.get("result") or {},
|
||||
}),
|
||||
ex=86400,
|
||||
)
|
||||
has_cleanup = True
|
||||
|
||||
parts = lock_key.split(":", 1)
|
||||
if len(parts) == 2:
|
||||
ready_user_ids.add(parts[1])
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True)
|
||||
self.errors += 1
|
||||
|
||||
if has_cleanup:
|
||||
cleanup_pipe.execute()
|
||||
|
||||
if ready_user_ids:
|
||||
self.redis.sadd(READY_SET, *ready_user_ids)
|
||||
|
||||
def _heartbeat(self):
|
||||
now = time.time()
|
||||
if now - self._last_heartbeat < HEARTBEAT_INTERVAL:
|
||||
return
|
||||
self._last_heartbeat = now
|
||||
|
||||
self.redis.hset(REGISTRY_KEY, self.instance_id, str(now))
|
||||
|
||||
all_instances = self.redis.hgetall(REGISTRY_KEY)
|
||||
|
||||
alive = []
|
||||
dead = []
|
||||
for iid, ts in all_instances.items():
|
||||
if now - float(ts) < INSTANCE_TTL:
|
||||
alive.append(iid)
|
||||
else:
|
||||
dead.append(iid)
|
||||
|
||||
if dead:
|
||||
pipe = self.redis.pipeline()
|
||||
for iid in dead:
|
||||
pipe.hdel(REGISTRY_KEY, iid)
|
||||
pipe.execute()
|
||||
logger.info("Cleaned dead instances: %s", dead)
|
||||
|
||||
alive.sort()
|
||||
self._shard_count = max(len(alive), 1)
|
||||
self._shard_index = (
|
||||
alive.index(self.instance_id) if self.instance_id in alive else 0
|
||||
)
|
||||
logger.debug(
|
||||
"Shard: %s/%s (instance=%s, alive=%d)",
|
||||
self._shard_index, self._shard_count,
|
||||
self.instance_id, len(alive),
|
||||
)
|
||||
|
||||
def _is_mine(self, user_id: str) -> bool:
|
||||
if self._shard_count <= 1:
|
||||
return True
|
||||
return stable_hash(user_id) % self._shard_count == self._shard_index
|
||||
|
||||
def _dispatch(self, msg_id, msg_data) -> bool:
    """Atomically claim a user/task pair and submit the task to Celery.

    Returns True when the task was handed to Celery (the caller may then
    pop the message from the user queue); False when another instance
    holds a lock or submission failed.
    """
    user_id = msg_data["user_id"]
    task_name = msg_data["task_name"]
    params = json.loads(msg_data.get("params", "{}"))

    # Per-user-per-task concurrency lock, plus a short-lived dispatch
    # guard keyed by message id to prevent double dispatch.
    lock_key = f"{task_name}:{user_id}"
    dispatch_lock = f"dispatch:{msg_id}"

    # NOTE(review): LUA_ATOMIC_LOCK is defined elsewhere; presumably it
    # tries to acquire both keys atomically (300s guard TTL, 3600s lock
    # TTL) and returns 0/-1 for the two contention cases — confirm
    # against the script source.
    result = self.redis.eval(
        LUA_ATOMIC_LOCK, 2,
        dispatch_lock, lock_key,
        self.instance_id, str(300), str(3600),
    )

    if result == 0:
        return False
    if result == -1:
        return False

    try:
        task = celery_app.send_task(task_name, kwargs=params)
    except Exception as e:
        # Submission failed: release both locks so the message can be
        # retried by any instance.
        pipe = self.redis.pipeline()
        pipe.delete(dispatch_lock)
        pipe.delete(lock_key)
        pipe.execute()
        self.errors += 1
        logger.error(
            "send_task failed for %s:%s msg=%s: %s",
            task_name, user_id, msg_id, e, exc_info=True,
        )
        return False

    try:
        # Record dispatch state: point the user lock at the Celery task
        # id, remember the pending task for the cleanup pass, drop the
        # dispatch guard, and expose progress via the task tracker key.
        pipe = self.redis.pipeline()
        pipe.set(lock_key, task.id, ex=3600)
        pipe.hset(PENDING_HASH, task.id, json.dumps({
            "lock_key": lock_key,
            "dispatched_at": time.time(),
            "msg_id": msg_id,
        }))
        pipe.delete(dispatch_lock)
        pipe.set(
            f"task_tracker:{msg_id}",
            json.dumps({"status": "DISPATCHED", "task_id": task.id}),
            ex=86400,
        )
        pipe.execute()
    except Exception as e:
        # The task is already submitted; state may be partially written.
        # The dispatch guard will expire on its own TTL.
        logger.error(
            "Post-dispatch state update failed for %s: %s",
            task.id, e, exc_info=True,
        )
        self.errors += 1

    self.dispatched += 1
    logger.info("Task dispatched: %s (msg=%s)", task.id, msg_id)
    return True
|
||||
|
||||
def _process_batch(self, user_ids):
    """Peek each user's queue head and dispatch the runnable messages.

    A message is only popped after a successful dispatch, so a failed
    dispatch leaves it at the head to be retried. Users whose queues are
    empty are removed from the active-users set.
    """
    if not user_ids:
        return

    # Batch-read the head of every user queue in one round trip.
    pipe = self.redis.pipeline()
    for uid in user_ids:
        pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0)
    heads = pipe.execute()

    candidates = []  # (user_id, msg_dict)
    empty_users = []

    for uid, head in zip(user_ids, heads):
        if head is None:
            empty_users.append(uid)
        else:
            try:
                candidates.append((uid, json.loads(head)))
            except (json.JSONDecodeError, TypeError) as e:
                # Poison message: drop it so the queue can make progress.
                logger.error("Bad message in queue for user %s: %s", uid, e)
                self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}")

    if empty_users:
        pipe = self.redis.pipeline()
        for uid in empty_users:
            pipe.srem(ACTIVE_USERS, uid)
        pipe.execute()

    if not candidates:
        return

    for uid, msg in candidates:
        # NOTE(review): lpop removes the current head, assumed to still
        # be the message just dispatched; this holds only if each user
        # queue is consumed by a single shard owner — confirm.
        if self._dispatch(msg["msg_id"], msg):
            self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}")
|
||||
|
||||
def schedule_loop(self):
    """Run one scheduling iteration.

    Heartbeats, reaps finished tasks, atomically drains READY_SET, and
    processes the drained users that hash onto this shard.

    Fix: draining READY_SET removes *every* ready user, including users
    owned by other shard instances. Previously those users were silently
    dropped and had to wait for the next full scan (up to 30s); they are
    now pushed back into READY_SET so their owning instance can pick
    them up promptly.
    """
    self._heartbeat()
    self._cleanup_finished()

    # Snapshot-and-clear atomically so two instances never double-drain.
    pipe = self.redis.pipeline()
    pipe.smembers(READY_SET)
    pipe.delete(READY_SET)
    results = pipe.execute()
    ready_users = results[0] or set()

    my_users = [uid for uid in ready_users if self._is_mine(uid)]

    # Re-queue users that belong to other shards instead of losing them.
    foreign_users = set(ready_users).difference(my_users)
    if foreign_users:
        self.redis.sadd(READY_SET, *foreign_users)

    if not my_users:
        time.sleep(0.5)
        return

    self._process_batch(my_users)
    time.sleep(0.1)
|
||||
|
||||
def _full_scan(self):
    """Periodic safety net: walk all active users and re-mark as ready
    those whose queue head is not blocked by a concurrency lock.

    Catches users whose READY_SET notification was lost (e.g. drained by
    an instance that then died before processing them).
    """
    cursor = 0
    ready_batch = []
    while True:
        cursor, user_ids = self.redis.sscan(
            ACTIVE_USERS, cursor=cursor, count=1000,
        )
        if user_ids:
            my_users = [uid for uid in user_ids if self._is_mine(uid)]
            if my_users:
                # Batch-read queue heads for this page of users.
                pipe = self.redis.pipeline()
                for uid in my_users:
                    pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0)
                heads = pipe.execute()

                for uid, head in zip(my_users, heads):
                    if head is None:
                        continue
                    try:
                        msg = json.loads(head)
                        lock_key = f"{msg['task_name']}:{uid}"
                        ready_batch.append((uid, lock_key))
                    except (json.JSONDecodeError, TypeError):
                        # Malformed head; _process_batch discards these.
                        continue

        # SSCAN returns cursor 0 once the iteration is complete.
        if cursor == 0:
            break

    if not ready_batch:
        return

    # Only users whose task lock is absent are actually runnable.
    pipe = self.redis.pipeline()
    for _, lock_key in ready_batch:
        pipe.exists(lock_key)
    lock_exists = pipe.execute()

    ready_uids = [
        uid for (uid, _), locked in zip(ready_batch, lock_exists)
        if not locked
    ]

    if ready_uids:
        self.redis.sadd(READY_SET, *ready_uids)
        logger.info("Full scan found %d ready users", len(ready_uids))
|
||||
|
||||
def run_server(self):
    """Start the health endpoint and run the scheduling loop until shutdown.

    Fix: the loop now tests ``self.running`` instead of ``while True`` so
    that :meth:`shutdown` (which flips the flag to False) actually
    terminates the loop, instead of relying solely on ``sys.exit`` inside
    a signal handler.
    """
    health_check_server(self)
    self.running = True

    last_full_scan = 0.0
    full_scan_interval = 30.0  # seconds between exhaustive scans

    logger.info(
        "Scheduler started: instance=%s", self.instance_id,
    )

    while self.running:
        try:
            self.schedule_loop()

            now = time.time()
            if now - last_full_scan > full_scan_interval:
                self._full_scan()
                last_full_scan = now

        except Exception as e:
            # Keep the scheduler alive on unexpected errors; back off
            # briefly to avoid a tight crash loop.
            logger.error("Scheduler exception %s", e, exc_info=True)
            self.errors += 1
            time.sleep(5)
|
||||
|
||||
def health(self) -> dict:
    """Return a point-in-time snapshot of scheduler state for the
    health endpoint."""
    r = self.redis
    snapshot = {
        "running": self.running,
        "active_users": r.scard(ACTIVE_USERS),
        "ready_users": r.scard(READY_SET),
        "pending_tasks": r.hlen(PENDING_HASH),
        "dispatched": self.dispatched,
        "errors": self.errors,
        "shard": f"{self._shard_index}/{self._shard_count}",
        "instance": self.instance_id,
    }
    return snapshot
|
||||
|
||||
def shutdown(self):
    """Stop the scheduling loop and deregister this instance."""
    logger.info("Scheduler shutting down: instance=%s", self.instance_id)
    self.running = False
    # Best effort: a failed deregistration only delays reaping until the
    # registry TTL expires, so log and move on.
    try:
        self.redis.hdel(REGISTRY_KEY, self.instance_id)
    except Exception as e:
        logger.error("Shutdown cleanup error: %s", e)
|
||||
|
||||
|
||||
# Module-level singleton used by the __main__ entry point and importers.
# Fix: the previous `scheduler = None` immediately followed by an
# `if scheduler is None:` guard was dead logic (the guard was always
# true at import time) — construct the instance directly.
scheduler: RedisTaskScheduler | None = RedisTaskScheduler()
|
||||
|
||||
if __name__ == "__main__":
    import signal
    import sys

    # Translate SIGTERM/SIGINT into a graceful shutdown: deregister this
    # instance from the registry, then exit.
    def _signal_handler(signum, frame):
        scheduler.shutdown()
        sys.exit(0)

    signal.signal(signal.SIGTERM, _signal_handler)
    signal.signal(signal.SIGINT, _signal_handler)

    scheduler.run_server()
|
||||
@@ -2,6 +2,8 @@
|
||||
Celery Worker 入口点
|
||||
用于启动 Celery Worker: celery -A app.celery_worker worker --loglevel=info
|
||||
"""
|
||||
from celery.signals import worker_process_init
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.logging_config import LoggingConfig, get_logger
|
||||
|
||||
@@ -13,4 +15,39 @@ logger.info("Celery worker logging initialized")
|
||||
# 导入任务模块以注册任务
|
||||
import app.tasks
|
||||
|
||||
__all__ = ['celery_app']
|
||||
|
||||
@worker_process_init.connect
def _reinit_db_pool(**kwargs):
    """Rebuild fork-tainted resources when a prefork child process starts.

    After fork() the child inherits from the parent:
    1. The SQLAlchemy connection pool — multiple processes sharing the
       same TCP sockets corrupts DB connections.
    2. Module-level ThreadPoolExecutors — thread state after fork is
       undefined, so the second submitted task would deadlock.
    """
    from concurrent.futures import ThreadPoolExecutor

    # Rebuild the DB connection pool.
    from app.db import engine
    engine.dispose()
    logger.info("DB connection pool disposed for forked worker process")

    # Recreate module-level ThreadPoolExecutors (unusable after fork).
    try:
        from app.core.rag.deepdoc.parser import figure_parser
        figure_parser.shared_executor = ThreadPoolExecutor(max_workers=10)
        logger.info("figure_parser.shared_executor recreated")
    except Exception as e:
        logger.warning(f"Failed to recreate figure_parser.shared_executor: {e}")

    try:
        import os

        from app.core.rag.utils import libre_office
        # Fix: call os.cpu_count() once instead of twice.
        cpu_count = os.cpu_count()
        max_workers = cpu_count * 2 if cpu_count else 4
        libre_office.executor = ThreadPoolExecutor(max_workers=max_workers)
        logger.info("libre_office.executor recreated")
    except Exception as e:
        logger.warning(f"Failed to recreate libre_office.executor: {e}")
|
||||
|
||||
|
||||
__all__ = ['celery_app']
|
||||
|
||||
1
api/app/config/__init__.py
Normal file
1
api/app/config/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Configuration module for application settings."""
|
||||
77
api/app/config/default_free_plan.py
Normal file
77
api/app/config/default_free_plan.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
社区版默认免费套餐配置
|
||||
当无法从 SaaS 版获取 premium 模块时,使用此配置作为兜底
|
||||
|
||||
可通过环境变量覆盖配额配置,格式:QUOTA_<QUOTA_NAME>
|
||||
例如:QUOTA_END_USER_QUOTA=100
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def _get_quota_from_env():
|
||||
"""从环境变量获取配额配置"""
|
||||
quota_keys = [
|
||||
"workspace_quota",
|
||||
"skill_quota",
|
||||
"app_quota",
|
||||
"knowledge_capacity_quota",
|
||||
"memory_engine_quota",
|
||||
"end_user_quota",
|
||||
"ontology_project_quota",
|
||||
"model_quota",
|
||||
"api_ops_rate_limit",
|
||||
]
|
||||
quotas = {}
|
||||
for key in quota_keys:
|
||||
env_key = f"QUOTA_{key.upper()}"
|
||||
env_value = os.getenv(env_key)
|
||||
if env_value is not None:
|
||||
try:
|
||||
quotas[key] = float(env_value) if '.' in env_value else int(env_value)
|
||||
except ValueError:
|
||||
pass
|
||||
return quotas
|
||||
|
||||
|
||||
def _build_default_free_plan():
|
||||
"""构建默认免费套餐配置"""
|
||||
base = {
|
||||
"name": "记忆体验版",
|
||||
"name_en": "Memory Experience",
|
||||
"category": "saas_personal",
|
||||
"tier_level": 0,
|
||||
"version": "1.0",
|
||||
"status": True,
|
||||
"price": 0,
|
||||
"billing_cycle": "permanent_free",
|
||||
"core_value": "感受永久记忆",
|
||||
"core_value_en": "Experience Permanent Memory",
|
||||
"tech_support": "社群交流",
|
||||
"tech_support_en": "Community Support",
|
||||
"sla_compliance": "无",
|
||||
"sla_compliance_en": "None",
|
||||
"page_customization": "无",
|
||||
"page_customization_en": "None",
|
||||
"theme_color": "#64748B",
|
||||
"quotas": {
|
||||
"workspace_quota": 1,
|
||||
"skill_quota": 5,
|
||||
"app_quota": 2,
|
||||
"knowledge_capacity_quota": 0.3,
|
||||
"memory_engine_quota": 1,
|
||||
"end_user_quota": 10,
|
||||
"ontology_project_quota": 3,
|
||||
"model_quota": 1,
|
||||
"api_ops_rate_limit": 50,
|
||||
},
|
||||
}
|
||||
|
||||
env_quotas = _get_quota_from_env()
|
||||
if env_quotas:
|
||||
base["quotas"].update(env_quotas)
|
||||
|
||||
return base
|
||||
|
||||
|
||||
DEFAULT_FREE_PLAN = _build_default_free_plan()
|
||||
239
api/app/config/default_ontology_config.py
Normal file
239
api/app/config/default_ontology_config.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""默认本体场景配置
|
||||
|
||||
本模块定义系统预设的本体场景和实体类型配置。
|
||||
这些配置用于在工作空间创建时自动初始化默认场景。
|
||||
支持中英文双语配置,根据用户语言偏好创建对应语言的场景。
|
||||
"""
|
||||
|
||||
# 在线教育场景配置
|
||||
ONLINE_EDUCATION_SCENE = {
|
||||
"name_chinese": "在线教育",
|
||||
"name_english": "Online Education",
|
||||
"description_chinese": "适用于在线教育平台的本体建模,包含学生、教师、课程等核心实体类型",
|
||||
"description_english": "Ontology modeling for online education platforms, including core entity types such as students, teachers, and courses",
|
||||
"types": [
|
||||
{
|
||||
"name_chinese": "学生",
|
||||
"name_english": "Student",
|
||||
"description_chinese": "在教育系统中接受教育的个体,包含姓名、学号、年级、班级等属性",
|
||||
"description_english": "Individuals receiving education in the education system, including attributes such as name, student ID, grade, and class"
|
||||
},
|
||||
{
|
||||
"name_chinese": "教师",
|
||||
"name_english": "Teacher",
|
||||
"description_chinese": "在教育系统中提供教学服务的个体,包含姓名、工号、任教学科、职称等属性",
|
||||
"description_english": "Individuals providing teaching services in the education system, including attributes such as name, employee ID, teaching subject, and title"
|
||||
},
|
||||
{
|
||||
"name_chinese": "课程",
|
||||
"name_english": "Course",
|
||||
"description_chinese": "教育系统中的教学内容单元,包含课程名称、课程代码、学分、学时等属性",
|
||||
"description_english": "Teaching content units in the education system, including attributes such as course name, course code, credits, and class hours"
|
||||
},
|
||||
{
|
||||
"name_chinese": "作业",
|
||||
"name_english": "Assignment",
|
||||
"description_chinese": "课程中布置的学习任务,包含作业标题、截止日期、所属课程、提交状态等属性",
|
||||
"description_english": "Learning tasks assigned in courses, including attributes such as assignment title, deadline, course, and submission status"
|
||||
},
|
||||
{
|
||||
"name_chinese": "成绩",
|
||||
"name_english": "Grade",
|
||||
"description_chinese": "学生学习成果的评价结果,包含分数、评级、考试类型、所属课程等属性",
|
||||
"description_english": "Evaluation results of student learning outcomes, including attributes such as score, rating, exam type, and course"
|
||||
},
|
||||
{
|
||||
"name_chinese": "考试",
|
||||
"name_english": "Exam",
|
||||
"description_chinese": "评估学生学习成果的测试活动,包含考试名称、时间、地点、科目等属性",
|
||||
"description_english": "Test activities to assess student learning outcomes, including attributes such as exam name, time, location, and subject"
|
||||
},
|
||||
{
|
||||
"name_chinese": "教室",
|
||||
"name_english": "Classroom",
|
||||
"description_chinese": "进行教学活动的物理或虚拟空间,包含教室编号、容量、设备等属性",
|
||||
"description_english": "Physical or virtual spaces for teaching activities, including attributes such as classroom number, capacity, and equipment"
|
||||
},
|
||||
{
|
||||
"name_chinese": "学科",
|
||||
"name_english": "Subject",
|
||||
"description_chinese": "知识的分类领域,包含学科名称、代码、所属院系等属性",
|
||||
"description_english": "Classification domains of knowledge, including attributes such as subject name, code, and department"
|
||||
},
|
||||
{
|
||||
"name_chinese": "教材",
|
||||
"name_english": "Textbook",
|
||||
"description_chinese": "教学使用的书籍或资料,包含书名、作者、出版社、ISBN等属性",
|
||||
"description_english": "Books or materials used for teaching, including attributes such as title, author, publisher, and ISBN"
|
||||
},
|
||||
{
|
||||
"name_chinese": "班级",
|
||||
"name_english": "Class",
|
||||
"description_chinese": "学生的组织单位,包含班级名称、年级、人数、班主任等属性",
|
||||
"description_english": "Organizational units of students, including attributes such as class name, grade, number of students, and class teacher"
|
||||
},
|
||||
{
|
||||
"name_chinese": "学期",
|
||||
"name_english": "Semester",
|
||||
"description_chinese": "教学时间的划分单位,包含学期名称、开始时间、结束时间等属性",
|
||||
"description_english": "Time division units for teaching, including attributes such as semester name, start time, and end time"
|
||||
},
|
||||
{
|
||||
"name_chinese": "课时",
|
||||
"name_english": "Class Hour",
|
||||
"description_chinese": "课程的时间单位,包含上课时间、地点、教师、课程等属性",
|
||||
"description_english": "Time units of courses, including attributes such as class time, location, teacher, and course"
|
||||
},
|
||||
{
|
||||
"name_chinese": "教学计划",
|
||||
"name_english": "Teaching Plan",
|
||||
"description_chinese": "课程的教学安排,包含教学目标、内容安排、进度计划等属性",
|
||||
"description_english": "Teaching arrangements for courses, including attributes such as teaching objectives, content arrangement, and progress plan"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# 情感陪伴场景配置
|
||||
EMOTIONAL_COMPANION_SCENE = {
|
||||
"name_chinese": "情感陪伴",
|
||||
"name_english": "Emotional Companion",
|
||||
"description_chinese": "适用于情感陪伴应用的本体建模,包含用户、情绪、活动等核心实体类型",
|
||||
"description_english": "Ontology modeling for emotional companion applications, including core entity types such as users, emotions, and activities",
|
||||
"types": [
|
||||
{
|
||||
"name_chinese": "用户",
|
||||
"name_english": "User",
|
||||
"description_chinese": "使用情感陪伴服务的个体,包含姓名、昵称、性格特征、偏好等属性",
|
||||
"description_english": "Individuals using emotional companion services, including attributes such as name, nickname, personality traits, and preferences"
|
||||
},
|
||||
{
|
||||
"name_chinese": "情绪",
|
||||
"name_english": "Emotion",
|
||||
"description_chinese": "用户的情感状态,包含情绪类型、强度、触发原因、持续时间等属性",
|
||||
"description_english": "Emotional states of users, including attributes such as emotion type, intensity, trigger cause, and duration"
|
||||
},
|
||||
{
|
||||
"name_chinese": "活动",
|
||||
"name_english": "Activity",
|
||||
"description_chinese": "用户参与的各类活动,包含活动名称、类型、参与者、时间地点等属性",
|
||||
"description_english": "Various activities users participate in, including attributes such as activity name, type, participants, time, and location"
|
||||
},
|
||||
{
|
||||
"name_chinese": "对话",
|
||||
"name_english": "Conversation",
|
||||
"description_chinese": "用户之间的交流记录,包含对话主题、参与者、时间、关键内容等属性",
|
||||
"description_english": "Communication records between users, including attributes such as conversation topic, participants, time, and key content"
|
||||
},
|
||||
{
|
||||
"name_chinese": "兴趣爱好",
|
||||
"name_english": "Hobby",
|
||||
"description_chinese": "用户的兴趣和爱好,包含爱好名称、类别、熟练程度、相关活动等属性",
|
||||
"description_english": "User interests and hobbies, including attributes such as hobby name, category, proficiency level, and related activities"
|
||||
},
|
||||
{
|
||||
"name_chinese": "日常事件",
|
||||
"name_english": "Daily Event",
|
||||
"description_chinese": "用户日常生活中的事件,包含事件描述、时间、地点、相关人物等属性",
|
||||
"description_english": "Events in users' daily lives, including attributes such as event description, time, location, and related people"
|
||||
},
|
||||
{
|
||||
"name_chinese": "关系",
|
||||
"name_english": "Relationship",
|
||||
"description_chinese": "用户之间的社会关系,包含关系类型、亲密度、建立时间等属性",
|
||||
"description_english": "Social relationships between users, including attributes such as relationship type, intimacy, and establishment time"
|
||||
},
|
||||
{
|
||||
"name_chinese": "回忆",
|
||||
"name_english": "Memory",
|
||||
"description_chinese": "用户的重要记忆片段,包含回忆内容、时间、地点、相关人物等属性",
|
||||
"description_english": "Important memory fragments of users, including attributes such as memory content, time, location, and related people"
|
||||
},
|
||||
{
|
||||
"name_chinese": "地点",
|
||||
"name_english": "Location",
|
||||
"description_chinese": "用户活动的地理位置,包含地点名称、地址、类型、相关事件等属性",
|
||||
"description_english": "Geographic locations of user activities, including attributes such as location name, address, type, and related events"
|
||||
},
|
||||
{
|
||||
"name_chinese": "时间节点",
|
||||
"name_english": "Time Point",
|
||||
"description_chinese": "重要的时间标记,包含日期、事件、意义等属性",
|
||||
"description_english": "Important time markers, including attributes such as date, event, and significance"
|
||||
},
|
||||
{
|
||||
"name_chinese": "目标",
|
||||
"name_english": "Goal",
|
||||
"description_chinese": "用户设定的目标,包含目标描述、截止时间、完成状态、相关活动等属性",
|
||||
"description_english": "Goals set by users, including attributes such as goal description, deadline, completion status, and related activities"
|
||||
},
|
||||
{
|
||||
"name_chinese": "成就",
|
||||
"name_english": "Achievement",
|
||||
"description_chinese": "用户获得的成就,包含成就名称、获得时间、描述、相关目标等属性",
|
||||
"description_english": "Achievements obtained by users, including attributes such as achievement name, acquisition time, description, and related goals"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Default scene list exported for workspace initialization.
DEFAULT_SCENES = [ONLINE_EDUCATION_SCENE, EMOTIONAL_COMPANION_SCENE]
|
||||
|
||||
|
||||
def get_scene_name(scene_config: dict, language: str = "zh") -> str:
    """Return the scene name in the requested language.

    Args:
        scene_config: scene configuration dict
        language: language code ("zh" or "en")

    Returns:
        The scene name for *language*; English falls back to Chinese.
    """
    chinese = scene_config.get("name_chinese")
    if language != "en":
        return chinese
    return scene_config.get("name_english", chinese)
|
||||
|
||||
|
||||
def get_scene_description(scene_config: dict, language: str = "zh") -> str:
    """Return the scene description in the requested language.

    Args:
        scene_config: scene configuration dict
        language: language code ("zh" or "en")

    Returns:
        The description for *language*; English falls back to Chinese.
    """
    chinese = scene_config.get("description_chinese")
    if language != "en":
        return chinese
    return scene_config.get("description_english", chinese)
|
||||
|
||||
|
||||
def get_type_name(type_config: dict, language: str = "zh") -> str:
    """Return the entity-type name in the requested language.

    Args:
        type_config: type configuration dict
        language: language code ("zh" or "en")

    Returns:
        The type name for *language*; English falls back to Chinese.
    """
    chinese = type_config.get("name_chinese")
    if language != "en":
        return chinese
    return type_config.get("name_english", chinese)
|
||||
|
||||
|
||||
def get_type_description(type_config: dict, language: str = "zh") -> str:
    """Return the entity-type description in the requested language.

    Args:
        type_config: type configuration dict
        language: language code ("zh" or "en")

    Returns:
        The description for *language*; English falls back to Chinese.
    """
    chinese = type_config.get("description_chinese")
    if language != "en":
        return chinese
    return type_config.get("description_english", chinese)
|
||||
249
api/app/config/default_ontology_initializer.py
Normal file
249
api/app/config/default_ontology_initializer.py
Normal file
@@ -0,0 +1,249 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""默认本体场景初始化器
|
||||
|
||||
本模块提供默认本体场景和类型的自动初始化功能。
|
||||
在工作空间创建时,自动添加预设的本体场景和实体类型。
|
||||
|
||||
Classes:
|
||||
DefaultOntologyInitializer: 默认本体场景初始化器
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config.default_ontology_config import (
|
||||
DEFAULT_SCENES,
|
||||
get_scene_name,
|
||||
get_scene_description,
|
||||
get_type_name,
|
||||
get_type_description,
|
||||
)
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.repositories.ontology_scene_repository import OntologySceneRepository
|
||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||
|
||||
|
||||
class DefaultOntologyInitializer:
    """Default ontology scene initializer.

    Automatically initializes the default ontology scenes and entity
    types when a workspace is created. Follows a minimal-intrusion
    principle: initialization failure must not block workspace creation.

    Attributes:
        db: database session
        scene_repo: scene repository
        class_repo: entity-type repository
        logger: business logger
    """

    def __init__(self, db: Session):
        """Initialize repositories and logger.

        Args:
            db: database session
        """
        self.db = db
        self.scene_repo = OntologySceneRepository(db)
        self.class_repo = OntologyClassRepository(db)
        self.logger = get_business_logger()

    def initialize_default_scenes(
        self,
        workspace_id: UUID,
        language: str = "zh"
    ) -> Tuple[bool, str]:
        """Initialize the default scenes for a workspace.

        Creates the two default scenes (online education, emotional
        companion) and their entity types. Failures are logged but no
        exception is raised to the caller.

        Args:
            workspace_id: workspace ID
            language: language code ("zh" or "en"), defaults to "zh"

        Returns:
            Tuple[bool, str]: (success flag, error message)
        """
        try:
            self.logger.info(
                f"开始初始化默认本体场景 - workspace_id={workspace_id}, language={language}"
            )

            scenes_created = 0
            total_types_created = 0

            # Iterate over the default scene configurations.
            for scene_config in DEFAULT_SCENES:
                scene_name = get_scene_name(scene_config, language)

                # Create the scene together with its entity types.
                scene_id = self._create_scene_with_types(workspace_id, scene_config, language)

                if scene_id:
                    scenes_created += 1
                    # Count the types configured for this scene.
                    types_count = len(scene_config.get("types", []))
                    total_types_created += types_count

                    self.logger.info(
                        f"场景创建成功 - scene_name={scene_name}, "
                        f"scene_id={scene_id}, types_count={types_count}, language={language}"
                    )
                else:
                    self.logger.warning(
                        f"场景创建失败 - scene_name={scene_name}, "
                        f"workspace_id={workspace_id}, language={language}"
                    )

            # Log the overall result.
            self.logger.info(
                f"默认场景初始化完成 - workspace_id={workspace_id}, "
                f"language={language}, scenes_created={scenes_created}, "
                f"total_types_created={total_types_created}"
            )

            # Treat the run as successful if at least one scene was created.
            if scenes_created > 0:
                return True, ""
            else:
                error_msg = "所有默认场景创建失败"
                self.logger.error(
                    f"默认场景初始化失败 - workspace_id={workspace_id}, "
                    f"language={language}, error={error_msg}"
                )
                return False, error_msg

        except Exception as e:
            error_msg = f"默认场景初始化异常: {str(e)}"
            self.logger.error(
                f"默认场景初始化异常 - workspace_id={workspace_id}, "
                f"language={language}, error={str(e)}",
                exc_info=True
            )
            return False, error_msg

    def _create_scene_with_types(
        self,
        workspace_id: UUID,
        scene_config: dict,
        language: str = "zh"
    ) -> Optional[UUID]:
        """Create one scene and its entity types.

        Args:
            workspace_id: workspace ID
            scene_config: scene configuration dict
            language: language code ("zh" or "en")

        Returns:
            Optional[UUID]: ID of the created scene; None on failure or
            when a scene with the same name already exists.
        """
        try:
            scene_name = get_scene_name(scene_config, language)
            scene_description = get_scene_description(scene_config, language)

            # Skip creation when a scene with the same name already exists
            # (backward compatibility).
            existing_scene = self.scene_repo.get_by_name(scene_name, workspace_id)
            if existing_scene:
                self.logger.info(
                    f"场景已存在,跳过创建 - scene_name={scene_name}, "
                    f"workspace_id={workspace_id}, scene_id={existing_scene.scene_id}, "
                    f"language={language}"
                )
                return None

            # Create the scene record with is_system_default=true.
            scene_data = {
                "scene_name": scene_name,
                "scene_description": scene_description
            }

            scene = self.scene_repo.create(scene_data, workspace_id)

            # Mark the scene as a system default.
            scene.is_system_default = True
            self.db.flush()

            self.logger.info(
                f"场景创建成功 - scene_name={scene_name}, "
                f"scene_id={scene.scene_id}, is_system_default=True, language={language}"
            )

            # Batch-create the entity types of this scene.
            types_config = scene_config.get("types", [])
            types_created = self._batch_create_types(scene.scene_id, types_config, language)

            self.logger.info(
                f"场景类型创建完成 - scene_id={scene.scene_id}, "
                f"types_created={types_created}/{len(types_config)}, language={language}"
            )

            return scene.scene_id

        except Exception as e:
            scene_name = get_scene_name(scene_config, language)
            self.logger.error(
                f"场景创建失败 - scene_name={scene_name}, "
                f"workspace_id={workspace_id}, language={language}, error={str(e)}",
                exc_info=True
            )
            return None

    def _batch_create_types(
        self,
        scene_id: UUID,
        types_config: List[dict],
        language: str = "zh"
    ) -> int:
        """Batch-create entity types for a scene.

        Args:
            scene_id: scene ID
            types_config: list of type configuration dicts
            language: language code ("zh" or "en")

        Returns:
            int: number of types created successfully
        """
        created_count = 0

        for type_config in types_config:
            try:
                type_name = get_type_name(type_config, language)
                type_description = get_type_description(type_config, language)

                # Build the type payload.
                class_data = {
                    "class_name": type_name,
                    "class_description": type_description
                }

                # Create the type.
                ontology_class = self.class_repo.create(class_data, scene_id)

                # Mark the type as a system default.
                ontology_class.is_system_default = True
                self.db.flush()

                created_count += 1

                self.logger.debug(
                    f"类型创建成功 - class_name={type_name}, "
                    f"class_id={ontology_class.class_id}, "
                    f"scene_id={scene_id}, is_system_default=True, language={language}"
                )

            except Exception as e:
                type_name = get_type_name(type_config, language)
                self.logger.warning(
                    f"单个类型创建失败,继续创建其他类型 - "
                    f"class_name={type_name}, scene_id={scene_id}, "
                    f"language={language}, error={str(e)}"
                )
                # Continue with the remaining types.
                continue

        return created_count
|
||||
@@ -8,6 +8,7 @@ from fastapi import APIRouter
|
||||
from . import (
|
||||
api_key_controller,
|
||||
app_controller,
|
||||
app_log_controller,
|
||||
auth_controller,
|
||||
chunk_controller,
|
||||
document_controller,
|
||||
@@ -16,17 +17,22 @@ from . import (
|
||||
file_controller,
|
||||
file_storage_controller,
|
||||
home_page_controller,
|
||||
i18n_controller,
|
||||
implicit_memory_controller,
|
||||
knowledge_controller,
|
||||
knowledgeshare_controller,
|
||||
mcp_market_controller,
|
||||
mcp_market_config_controller,
|
||||
memory_agent_controller,
|
||||
memory_dashboard_controller,
|
||||
memory_episodic_controller,
|
||||
memory_explicit_controller,
|
||||
memory_forget_controller,
|
||||
memory_perceptual_controller,
|
||||
memory_reflection_controller,
|
||||
memory_short_term_controller,
|
||||
memory_storage_controller,
|
||||
memory_working_controller,
|
||||
model_controller,
|
||||
multi_agent_controller,
|
||||
prompt_optimizer_controller,
|
||||
@@ -39,12 +45,10 @@ from . import (
|
||||
upload_controller,
|
||||
user_controller,
|
||||
user_memory_controllers,
|
||||
workflow_controller,
|
||||
workspace_controller,
|
||||
memory_forget_controller,
|
||||
home_page_controller,
|
||||
memory_perceptual_controller,
|
||||
memory_working_controller,
|
||||
ontology_controller,
|
||||
skill_controller,
|
||||
tenant_subscription_controller,
|
||||
)
|
||||
|
||||
# 创建管理端 API 路由器
|
||||
@@ -61,10 +65,13 @@ manager_router.include_router(model_controller.router)
|
||||
manager_router.include_router(file_controller.router)
|
||||
manager_router.include_router(document_controller.router)
|
||||
manager_router.include_router(knowledge_controller.router)
|
||||
manager_router.include_router(mcp_market_controller.router)
|
||||
manager_router.include_router(mcp_market_config_controller.router)
|
||||
manager_router.include_router(chunk_controller.router)
|
||||
manager_router.include_router(test_controller.router)
|
||||
manager_router.include_router(knowledgeshare_controller.router)
|
||||
manager_router.include_router(app_controller.router)
|
||||
manager_router.include_router(app_log_controller.router)
|
||||
manager_router.include_router(upload_controller.router)
|
||||
manager_router.include_router(memory_agent_controller.router)
|
||||
manager_router.include_router(memory_dashboard_controller.router)
|
||||
@@ -77,7 +84,6 @@ manager_router.include_router(release_share_controller.router)
|
||||
manager_router.include_router(public_share_controller.router) # 公开路由(无需认证)
|
||||
manager_router.include_router(memory_dashboard_controller.router)
|
||||
manager_router.include_router(multi_agent_controller.router)
|
||||
manager_router.include_router(workflow_controller.router)
|
||||
manager_router.include_router(emotion_controller.router)
|
||||
manager_router.include_router(emotion_config_controller.router)
|
||||
manager_router.include_router(prompt_optimizer_controller.router)
|
||||
@@ -90,5 +96,10 @@ manager_router.include_router(implicit_memory_controller.router)
|
||||
manager_router.include_router(memory_perceptual_controller.router)
|
||||
manager_router.include_router(memory_working_controller.router)
|
||||
manager_router.include_router(file_storage_controller.router)
|
||||
manager_router.include_router(ontology_controller.router)
|
||||
manager_router.include_router(skill_controller.router)
|
||||
manager_router.include_router(i18n_controller.router)
|
||||
manager_router.include_router(tenant_subscription_controller.router)
|
||||
manager_router.include_router(tenant_subscription_controller.public_router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
@@ -167,6 +167,8 @@ def update_api_key(
|
||||
|
||||
return success(data=api_key_schema.ApiKey.model_validate(api_key), msg="API Key 更新成功")
|
||||
|
||||
except BusinessException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"未知错误: {str(e)}", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import uuid
|
||||
import io
|
||||
from typing import Optional, Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
import yaml
|
||||
from fastapi import APIRouter, Depends, Path, Form, UploadFile, File
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
@@ -17,11 +20,15 @@ from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas import app_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
|
||||
from app.schemas.workflow_schema import WorkflowConfigUpdate
|
||||
from app.schemas.workflow_schema import WorkflowConfigUpdate, WorkflowImportSave
|
||||
from app.services import app_service, workspace_service
|
||||
from app.services.agent_config_helper import enrich_agent_config
|
||||
from app.services.app_service import AppService
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
from app.services.workflow_import_service import WorkflowImportService
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
from app.services.app_dsl_service import AppDslService
|
||||
from app.core.quota_stub import check_app_quota
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||
logger = get_business_logger()
|
||||
@@ -29,6 +36,7 @@ logger = get_business_logger()
|
||||
|
||||
@router.post("", summary="创建应用(可选创建 Agent 配置)")
|
||||
@cur_workspace_access_guard()
|
||||
@check_app_quota
|
||||
def create_app(
|
||||
payload: app_schema.AppCreate,
|
||||
db: Session = Depends(get_db),
|
||||
@@ -47,6 +55,7 @@ def list_apps(
|
||||
status: str | None = None,
|
||||
search: str | None = None,
|
||||
include_shared: bool = True,
|
||||
shared_only: bool = False,
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
ids: Optional[str] = None,
|
||||
@@ -58,16 +67,42 @@ def list_apps(
|
||||
- 默认包含本工作空间的应用和分享给本工作空间的应用
|
||||
- 设置 include_shared=false 可以只查看本工作空间的应用
|
||||
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页
|
||||
- search 参数支持:应用名称模糊搜索、API Key 精确搜索
|
||||
"""
|
||||
from sqlalchemy import select as sa_select
|
||||
from app.models.api_key_model import ApiKey
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
|
||||
# 当 ids 存在且不为 None 时,根据 ids 获取应用
|
||||
# 通过 search 参数搜索:支持应用名称模糊搜索和 API Key 精确搜索
|
||||
if search:
|
||||
search = search.strip()
|
||||
# 尝试作为 API Key 精确匹配(API Key 通常较长)
|
||||
if len(search) >= 10:
|
||||
matched_id = db.execute(
|
||||
sa_select(ApiKey.resource_id).where(
|
||||
ApiKey.workspace_id == workspace_id,
|
||||
ApiKey.api_key == search,
|
||||
ApiKey.resource_id.isnot(None),
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if matched_id:
|
||||
# 找到 API Key,直接返回关联的应用
|
||||
ids = str(matched_id)
|
||||
|
||||
# 当 ids 存在时,根据 ids 获取应用(不分页)
|
||||
if ids is not None:
|
||||
app_ids = [id.strip() for id in ids.split(',') if id.strip()]
|
||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||
return success(data=items)
|
||||
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
|
||||
if app_ids:
|
||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||
# 返回标准分页格式
|
||||
meta = PageMeta(page=1, pagesize=len(items), total=len(items), hasnext=False)
|
||||
return success(data=PageData(page=meta, items=items))
|
||||
# ids 为空时,返回空列表
|
||||
meta = PageMeta(page=1, pagesize=0, total=0, hasnext=False)
|
||||
return success(data=PageData(page=meta, items=[]))
|
||||
|
||||
# 正常分页查询
|
||||
items_orm, total = app_service.list_apps(
|
||||
@@ -78,6 +113,7 @@ def list_apps(
|
||||
status=status,
|
||||
search=search,
|
||||
include_shared=include_shared,
|
||||
shared_only=shared_only,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
)
|
||||
@@ -87,6 +123,37 @@ def list_apps(
|
||||
return success(data=PageData(page=meta, items=items))
|
||||
|
||||
|
||||
@router.get("/my-shared-out", summary="列出本工作空间主动分享出去的记录")
|
||||
@cur_workspace_access_guard()
|
||||
def list_my_shared_out(
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""列出本工作空间主动分享给其他工作空间的所有记录(我的共享)"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
shares = service.list_my_shared_out(workspace_id=workspace_id)
|
||||
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
||||
return success(data=data)
|
||||
|
||||
|
||||
@router.delete("/share/{target_workspace_id}", summary="取消对某工作空间的所有应用分享")
|
||||
@cur_workspace_access_guard()
|
||||
def unshare_all_apps_to_workspace(
|
||||
target_workspace_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""Cancel all app shares from current workspace to a target workspace."""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
count = service.unshare_all_apps_to_workspace(
|
||||
target_workspace_id=target_workspace_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
return success(msg=f"已取消 {count} 个应用的分享", data={"count": count})
|
||||
|
||||
|
||||
@router.get("/{app_id}", summary="获取应用详情")
|
||||
@cur_workspace_access_guard()
|
||||
def get_app(
|
||||
@@ -152,9 +219,11 @@ def delete_app(
|
||||
|
||||
@router.post("/{app_id}/copy", summary="复制应用")
|
||||
@cur_workspace_access_guard()
|
||||
@check_app_quota
|
||||
def copy_app(
|
||||
app_id: uuid.UUID,
|
||||
new_name: Optional[str] = None,
|
||||
payload: app_schema.CopyAppRequest = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
@@ -166,6 +235,8 @@ def copy_app(
|
||||
- 不影响原应用
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# body takes precedence over query param for backward compatibility
|
||||
new_name = (payload.new_name if payload else None) or new_name
|
||||
logger.info(
|
||||
"用户请求复制应用",
|
||||
extra={
|
||||
@@ -201,6 +272,19 @@ def update_agent_config(
|
||||
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
||||
|
||||
|
||||
@router.get("/{app_id}/model/parameters/default", summary="获取 Agent 模型参数默认配置")
|
||||
@cur_workspace_access_guard()
|
||||
def get_agent_model_parameters(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = AppService(db)
|
||||
model_parameters = service.get_default_model_parameters(app_id=app_id)
|
||||
return success(data=model_parameters, msg="获取 Agent 模型参数默认配置")
|
||||
|
||||
|
||||
@router.get("/{app_id}/config", summary="获取 Agent 配置")
|
||||
@cur_workspace_access_guard()
|
||||
def get_agent_config(
|
||||
@@ -215,6 +299,36 @@ def get_agent_config(
|
||||
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
||||
|
||||
|
||||
@router.get("/{app_id}/opening", summary="获取应用开场白配置")
|
||||
@cur_workspace_access_guard()
|
||||
def get_opening(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""返回开场白文本和预设问题,供前端对话界面初始化时展示"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 根据应用类型获取 features
|
||||
from app.models.app_model import App as AppModel
|
||||
app = db.get(AppModel, app_id)
|
||||
if app and app.type == "workflow":
|
||||
cfg = app_service.get_workflow_config(db=db, app_id=app_id, workspace_id=workspace_id)
|
||||
features = cfg.features or {}
|
||||
else:
|
||||
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
|
||||
features = cfg.features or {}
|
||||
if hasattr(features, "model_dump"):
|
||||
features = features.model_dump()
|
||||
|
||||
opening = features.get("opening_statement", {})
|
||||
return success(data=app_schema.OpeningResponse(
|
||||
enabled=opening.get("enabled", False),
|
||||
statement=opening.get("statement"),
|
||||
suggested_questions=opening.get("suggested_questions", []),
|
||||
))
|
||||
|
||||
|
||||
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
|
||||
@cur_workspace_access_guard()
|
||||
def publish_app(
|
||||
@@ -296,7 +410,8 @@ def share_app(
|
||||
app_id=app_id,
|
||||
target_workspace_ids=payload.target_workspace_ids,
|
||||
user_id=current_user.id,
|
||||
workspace_id=workspace_id
|
||||
workspace_id=workspace_id,
|
||||
permission=payload.permission
|
||||
)
|
||||
|
||||
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
||||
@@ -327,6 +442,32 @@ def unshare_app(
|
||||
return success(msg="应用分享已取消")
|
||||
|
||||
|
||||
@router.patch("/{app_id}/share/{target_workspace_id}", summary="更新共享权限")
|
||||
@cur_workspace_access_guard()
|
||||
def update_share_permission(
|
||||
app_id: uuid.UUID,
|
||||
target_workspace_id: uuid.UUID,
|
||||
payload: app_schema.UpdateSharePermissionRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""更新共享权限(readonly <-> editable)
|
||||
|
||||
- 只能修改自己工作空间应用的共享权限
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
service = app_service.AppService(db)
|
||||
share = service.update_share_permission(
|
||||
app_id=app_id,
|
||||
target_workspace_id=target_workspace_id,
|
||||
permission=payload.permission,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
return success(data=app_schema.AppShare.model_validate(share))
|
||||
|
||||
|
||||
@router.get("/{app_id}/shares", summary="列出应用的分享记录")
|
||||
@cur_workspace_access_guard()
|
||||
def list_app_shares(
|
||||
@@ -350,6 +491,46 @@ def list_app_shares(
|
||||
return success(data=data)
|
||||
|
||||
|
||||
@router.delete("/shared/{source_workspace_id}", summary="批量移除某来源工作空间的所有共享应用")
|
||||
@cur_workspace_access_guard()
|
||||
def remove_all_shared_apps_from_workspace(
|
||||
source_workspace_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""Remove all shared apps from a specific source workspace (recipient operation)."""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
count = service.remove_all_shared_apps_from_workspace(
|
||||
source_workspace_id=source_workspace_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
return success(msg=f"已移除 {count} 个共享应用", data={"count": count})
|
||||
|
||||
|
||||
@router.delete("/{app_id}/shared", summary="移除共享给我的应用")
|
||||
@cur_workspace_access_guard()
|
||||
def remove_shared_app(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""被共享者从自己的工作空间移除共享应用
|
||||
|
||||
- 不会删除源应用,只删除共享记录
|
||||
- 只能移除共享给自己工作空间的应用
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
service = app_service.AppService(db)
|
||||
service.remove_shared_app(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
return success(msg="已移除共享应用")
|
||||
|
||||
|
||||
@router.post("/{app_id}/draft/run", summary="试运行 Agent(使用当前草稿配置)")
|
||||
@cur_workspace_access_guard()
|
||||
async def draft_run(
|
||||
@@ -390,13 +571,13 @@ async def draft_run(
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
from app.services.app_service import AppService
|
||||
from app.services.multi_agent_service import MultiAgentService
|
||||
from app.models import AgentConfig, ModelConfig
|
||||
from app.models import AgentConfig, ModelConfig, AppRelease
|
||||
from sqlalchemy import select
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
|
||||
service = AppService(db)
|
||||
draft_service = DraftRunService(db)
|
||||
draft_service = AgentRunService(db)
|
||||
|
||||
# 1. 验证应用
|
||||
app = service._get_app_or_404(app_id)
|
||||
@@ -407,11 +588,12 @@ async def draft_run(
|
||||
service._validate_app_accessible(app, workspace_id)
|
||||
|
||||
if payload.user_id is None:
|
||||
# 先获取 app 的 workspace_id
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app_id,
|
||||
workspace_id=app.workspace_id,
|
||||
other_id=str(current_user.id),
|
||||
original_user_id=str(current_user.id) # Save original user_id to other_id
|
||||
)
|
||||
payload.user_id = str(new_end_user.id)
|
||||
|
||||
@@ -428,18 +610,29 @@ async def draft_run(
|
||||
service._check_agent_config(app_id)
|
||||
|
||||
# 2. 获取 Agent 配置
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
||||
agent_cfg = db.scalars(stmt).first()
|
||||
if not agent_cfg:
|
||||
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
# 共享应用:从最新发布版本读配置快照,而非草稿
|
||||
is_shared = app.workspace_id != workspace_id
|
||||
if is_shared:
|
||||
if not app.current_release_id:
|
||||
raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING)
|
||||
release = db.get(AppRelease, app.current_release_id)
|
||||
if not release:
|
||||
raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
agent_cfg = service._agent_config_from_release(release)
|
||||
model_config = db.get(ModelConfig, release.default_model_config_id) if release.default_model_config_id else None
|
||||
else:
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
||||
agent_cfg = db.scalars(stmt).first()
|
||||
if not agent_cfg:
|
||||
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
# 3. 获取模型配置
|
||||
model_config = None
|
||||
if agent_cfg.default_model_config_id:
|
||||
model_config = db.get(ModelConfig, agent_cfg.default_model_config_id)
|
||||
if not model_config:
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id))
|
||||
# 3. 获取模型配置
|
||||
model_config = None
|
||||
if agent_cfg.default_model_config_id:
|
||||
model_config = db.get(ModelConfig, agent_cfg.default_model_config_id)
|
||||
if not model_config:
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id))
|
||||
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
@@ -454,7 +647,8 @@ async def draft_run(
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -475,12 +669,13 @@ async def draft_run(
|
||||
"app_id": str(app_id),
|
||||
"message_length": len(payload.message),
|
||||
"has_conversation_id": bool(payload.conversation_id),
|
||||
"has_variables": bool(payload.variables)
|
||||
"has_variables": bool(payload.variables),
|
||||
"has_files": bool(payload.files)
|
||||
}
|
||||
)
|
||||
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
draft_service = AgentRunService(db)
|
||||
result = await draft_service.run(
|
||||
agent_config=agent_cfg,
|
||||
model_config=model_config,
|
||||
@@ -490,7 +685,8 @@ async def draft_run(
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
@@ -592,7 +788,17 @@ async def draft_run(
|
||||
msg="多 Agent 任务执行成功"
|
||||
)
|
||||
elif app.type == AppType.WORKFLOW: # 工作流
|
||||
config = workflow_service.check_config(app_id)
|
||||
# 共享应用:从最新发布版本读配置快照,而非草稿
|
||||
is_shared = app.workspace_id != workspace_id
|
||||
if is_shared:
|
||||
if not app.current_release_id:
|
||||
raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING)
|
||||
release = db.get(AppRelease, app.current_release_id)
|
||||
if not release:
|
||||
raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
config = service._workflow_config_from_release(release)
|
||||
else:
|
||||
config = workflow_service.check_config(app_id)
|
||||
# 3. 流式返回
|
||||
if payload.stream:
|
||||
logger.debug(
|
||||
@@ -735,6 +941,16 @@ async def draft_run_compare(
|
||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
service._validate_app_accessible(app, workspace_id)
|
||||
|
||||
if payload.user_id is None:
|
||||
# 先获取 app 的 workspace_id
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app_id,
|
||||
workspace_id=app.workspace_id,
|
||||
other_id=str(current_user.id),
|
||||
)
|
||||
payload.user_id = str(new_end_user.id)
|
||||
|
||||
# 2. 获取 Agent 配置
|
||||
from sqlalchemy import select
|
||||
from app.models import AgentConfig
|
||||
@@ -780,25 +996,33 @@ async def draft_run_compare(
|
||||
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
|
||||
})
|
||||
|
||||
# 从 features 中读取功能开关(与 draft_run 保持一致)
|
||||
features_config: dict = agent_cfg.features or {}
|
||||
if hasattr(features_config, 'model_dump'):
|
||||
features_config = features_config.model_dump()
|
||||
web_search_feature = features_config.get("web_search", {})
|
||||
web_search = isinstance(web_search_feature, dict) and web_search_feature.get("enabled", False)
|
||||
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
draft_service = AgentRunService(db)
|
||||
async for event in draft_service.run_compare_stream(
|
||||
agent_config=agent_cfg,
|
||||
models=model_configs,
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
user_id=payload.user_id,
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
web_search=True,
|
||||
web_search=web_search,
|
||||
memory=True,
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60
|
||||
timeout=payload.timeout or 60,
|
||||
files=payload.files
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -813,22 +1037,23 @@ async def draft_run_compare(
|
||||
)
|
||||
|
||||
# 非流式返回
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
draft_service = AgentRunService(db)
|
||||
result = await draft_service.run_compare(
|
||||
agent_config=agent_cfg,
|
||||
models=model_configs,
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
user_id=payload.user_id,
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
web_search=True,
|
||||
web_search=web_search,
|
||||
memory=True,
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60
|
||||
timeout=payload.timeout or 60,
|
||||
files=payload.files
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -870,10 +1095,73 @@ async def update_workflow_config(
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if payload.variables:
|
||||
from app.services.workflow_service import WorkflowService
|
||||
resolved = await WorkflowService(db)._resolve_variables_file_defaults(
|
||||
[v.model_dump() for v in payload.variables]
|
||||
)
|
||||
# Patch default values back into VariableDefinition objects
|
||||
for var_def, resolved_def in zip(payload.variables, resolved):
|
||||
var_def.default = resolved_def.get("default", var_def.default)
|
||||
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||
|
||||
|
||||
@router.get("/{app_id}/workflow/export")
|
||||
@cur_workspace_access_guard()
|
||||
async def export_workflow_config(
|
||||
app_id: uuid.UUID,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
):
|
||||
"""导出工作流配置为YAML文件"""
|
||||
workflow_service = WorkflowService(db)
|
||||
|
||||
return success(data={
|
||||
"content": workflow_service.export_workflow_dsl(app_id=app_id),
|
||||
})
|
||||
|
||||
|
||||
@router.post("/workflow/import")
|
||||
@cur_workspace_access_guard()
|
||||
async def import_workflow_config(
|
||||
file: UploadFile = File(...),
|
||||
platform: str = Form(...),
|
||||
app_id: str = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
|
||||
):
|
||||
"""从YAML内容导入工作流配置"""
|
||||
if not file.filename.lower().endswith((".yaml", ".yml")):
|
||||
return fail(msg="Only yaml file is allowed", code=BizCode.BAD_REQUEST)
|
||||
|
||||
raw_text = (await file.read()).decode("utf-8")
|
||||
import_service = WorkflowImportService(db)
|
||||
config = yaml.safe_load(raw_text)
|
||||
result = await import_service.upload_config(platform, config)
|
||||
return success(data=result)
|
||||
|
||||
|
||||
@router.post("/workflow/import/save")
|
||||
@cur_workspace_access_guard()
|
||||
@check_app_quota
|
||||
async def save_workflow_import(
|
||||
data: WorkflowImportSave,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
import_service = WorkflowImportService(db)
|
||||
app = await import_service.save_workflow(
|
||||
user_id=current_user.id,
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
temp_id=data.temp_id,
|
||||
name=data.name,
|
||||
description=data.description,
|
||||
)
|
||||
return success(data=app_schema.App.model_validate(app))
|
||||
|
||||
|
||||
@router.get("/{app_id}/statistics", summary="应用统计数据")
|
||||
@cur_workspace_access_guard()
|
||||
def get_app_statistics(
|
||||
@@ -884,12 +1172,14 @@ def get_app_statistics(
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""获取应用统计数据
|
||||
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
start_date: 开始时间戳(毫秒)
|
||||
end_date: 结束时间戳(毫秒)
|
||||
|
||||
db: 数据库连接
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
- daily_conversations: 每日会话数统计
|
||||
- total_conversations: 总会话数
|
||||
@@ -901,15 +1191,153 @@ def get_app_statistics(
|
||||
- total_tokens: 总token消耗
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
stats_service = AppStatisticsService(db)
|
||||
|
||||
|
||||
result = stats_service.get_app_statistics(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
|
||||
return success(data=result)
|
||||
|
||||
|
||||
@router.get("/workspace/api-statistics", summary="工作空间API调用统计")
|
||||
@cur_workspace_access_guard()
|
||||
def get_workspace_api_statistics(
|
||||
start_date: int,
|
||||
end_date: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""获取工作空间API调用统计
|
||||
|
||||
Args:
|
||||
start_date: 开始时间戳(毫秒)
|
||||
end_date: 结束时间戳(毫秒)
|
||||
db: 数据库连接
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
每日统计数据列表,每项包含:
|
||||
- date: 日期
|
||||
- total_calls: 当日总调用次数
|
||||
- app_calls: 当日应用调用次数
|
||||
- service_calls: 当日服务调用次数
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
stats_service = AppStatisticsService(db)
|
||||
|
||||
result = stats_service.get_workspace_api_statistics(
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
return success(data=result)
|
||||
|
||||
|
||||
@router.get("/{app_id}/export", summary="导出应用配置为 YAML 文件")
|
||||
@cur_workspace_access_guard()
|
||||
async def export_app(
|
||||
app_id: uuid.UUID,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
release_id: Optional[uuid.UUID] = None
|
||||
):
|
||||
"""导出 agent / multi_agent / workflow 应用配置为 YAML 文件流。
|
||||
release_id: 指定发布版本id,不传则导出当前草稿配置。
|
||||
"""
|
||||
yaml_str, filename = AppDslService(db).export_dsl(app_id, release_id)
|
||||
encoded = quote(filename, safe=".")
|
||||
yaml_bytes = yaml_str.encode("utf-8")
|
||||
file_stream = io.BytesIO(yaml_bytes)
|
||||
file_stream.seek(0)
|
||||
return StreamingResponse(
|
||||
file_stream,
|
||||
media_type="application/octet-stream; charset=utf-8",
|
||||
headers={"Content-Disposition": f"attachment; filename={encoded}",
|
||||
"Content-Length": str(len(yaml_bytes))}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/import", summary="从 YAML 文件导入应用")
|
||||
@cur_workspace_access_guard()
|
||||
async def import_app(
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
app_id: Optional[str] = Form(None),
|
||||
):
|
||||
"""从 YAML 文件导入 agent / multi_agent / workflow 应用。
|
||||
传入 app_id 时覆盖该应用的配置(类型必须一致),否则创建新应用。
|
||||
跨空间/跨租户导入时,模型/工具/知识库会按名称匹配,匹配不到则置空并返回 warnings。
|
||||
"""
|
||||
if not file.filename.lower().endswith((".yaml", ".yml")):
|
||||
return fail(msg="仅支持 YAML 文件", code=BizCode.BAD_REQUEST)
|
||||
|
||||
raw = (await file.read()).decode("utf-8")
|
||||
dsl = yaml.safe_load(raw)
|
||||
if not dsl or "app" not in dsl:
|
||||
return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)
|
||||
|
||||
target_app_id = uuid.UUID(app_id) if app_id else None
|
||||
# 仅新建应用时检查配额,覆盖已有应用时跳过
|
||||
if target_app_id is None:
|
||||
from app.core.quota_manager import _check_quota
|
||||
_check_quota(db, current_user.tenant_id, "app_quota", "app", workspace_id=current_user.current_workspace_id)
|
||||
result_app, warnings = AppDslService(db).import_dsl(
|
||||
dsl=dsl,
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
tenant_id=current_user.tenant_id,
|
||||
user_id=current_user.id,
|
||||
app_id=target_app_id,
|
||||
)
|
||||
return success(
|
||||
data={"app": app_schema.App.model_validate(result_app), "warnings": warnings},
|
||||
msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "")
|
||||
)
|
||||
|
||||
|
||||
@router.get("/citations/{document_id}/download", summary="下载引用文档原始文件")
|
||||
async def download_citation_file(
|
||||
document_id: uuid.UUID = Path(..., description="引用文档ID"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
下载引用文档的原始文件。
|
||||
仅当应用功能特性 citation.allow_download=true 时,前端才会展示此下载链接。
|
||||
路由本身不做权限校验,由业务层通过 allow_download 开关控制入口。
|
||||
"""
|
||||
import os
|
||||
from fastapi import HTTPException, status as http_status
|
||||
from fastapi.responses import FileResponse
|
||||
from app.core.config import settings
|
||||
from app.models.document_model import Document
|
||||
from app.models.file_model import File as FileModel
|
||||
|
||||
doc = db.query(Document).filter(Document.id == document_id).first()
|
||||
if not doc:
|
||||
raise HTTPException(status_code=http_status.HTTP_404_NOT_FOUND, detail="文档不存在")
|
||||
|
||||
file_record = db.query(FileModel).filter(FileModel.id == doc.file_id).first()
|
||||
if not file_record:
|
||||
raise HTTPException(status_code=http_status.HTTP_404_NOT_FOUND, detail="原始文件不存在")
|
||||
|
||||
file_path = os.path.join(
|
||||
settings.FILE_PATH,
|
||||
str(file_record.kb_id),
|
||||
str(file_record.parent_id),
|
||||
f"{file_record.id}{file_record.file_ext}"
|
||||
)
|
||||
if not os.path.exists(file_path):
|
||||
raise HTTPException(status_code=http_status.HTTP_404_NOT_FOUND, detail="文件未找到")
|
||||
|
||||
encoded_name = quote(doc.file_name)
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
filename=doc.file_name,
|
||||
media_type="application/octet-stream",
|
||||
headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_name}"}
|
||||
)
|
||||
|
||||
110
api/app/controllers/app_log_controller.py
Normal file
110
api/app/controllers/app_log_controller.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""应用日志(消息记录)接口"""
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail, AppLogMessage
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services.app_service import AppService
|
||||
from app.services.app_log_service import AppLogService
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["App Logs"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@router.get("/{app_id}/logs", summary="应用日志 - 会话列表")
|
||||
@cur_workspace_access_guard()
|
||||
def list_app_logs(
|
||||
app_id: uuid.UUID,
|
||||
page: int = Query(1, ge=1),
|
||||
pagesize: int = Query(20, ge=1, le=100),
|
||||
is_draft: Optional[bool] = Query(None, description="是否草稿会话(不传则返回全部)"),
|
||||
keyword: Optional[str] = Query(None, description="搜索关键词(匹配消息内容)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""查看应用下所有会话记录(分页)
|
||||
|
||||
- is_draft 不传则返回所有会话(草稿 + 正式)
|
||||
- is_draft=True 只返回草稿会话
|
||||
- is_draft=False 只返回发布会话
|
||||
- 支持按 keyword 搜索(匹配消息内容)
|
||||
- 按最新更新时间倒序排列
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 验证应用访问权限
|
||||
app_service = AppService(db)
|
||||
app = app_service.get_app(app_id, workspace_id)
|
||||
|
||||
# 使用 Service 层查询
|
||||
log_service = AppLogService(db)
|
||||
conversations, total = log_service.list_conversations(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
is_draft=is_draft,
|
||||
keyword=keyword,
|
||||
app_type=app.type,
|
||||
)
|
||||
|
||||
items = [AppLogConversation.model_validate(c) for c in conversations]
|
||||
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
||||
|
||||
return success(data=PageData(page=meta, items=items))
|
||||
|
||||
|
||||
@router.get("/{app_id}/logs/{conversation_id}", summary="应用日志 - 会话消息详情")
|
||||
@cur_workspace_access_guard()
|
||||
def get_app_log_detail(
|
||||
app_id: uuid.UUID,
|
||||
conversation_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""查看某会话的完整消息记录
|
||||
|
||||
- 返回会话基本信息 + 所有消息(按时间正序)
|
||||
- 消息 meta_data 包含模型名、token 用量等信息
|
||||
- 所有人(包括共享者和被共享者)都只能查看自己的会话详情
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 验证应用访问权限
|
||||
app_service = AppService(db)
|
||||
app = app_service.get_app(app_id, workspace_id)
|
||||
|
||||
# 使用 Service 层查询
|
||||
log_service = AppLogService(db)
|
||||
conversation, messages, node_executions_map = log_service.get_conversation_detail(
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
workspace_id=workspace_id,
|
||||
app_type=app.type
|
||||
)
|
||||
|
||||
# 构建基础会话信息(不经过 ORM relationship)
|
||||
base = AppLogConversation.model_validate(conversation)
|
||||
|
||||
# 单独处理 messages,避免触发 SQLAlchemy relationship 校验
|
||||
if messages and isinstance(messages[0], AppLogMessage):
|
||||
# 工作流:已经是 AppLogMessage 实例
|
||||
msg_list = messages
|
||||
else:
|
||||
# Agent:ORM Message 对象逐个转换
|
||||
msg_list = [AppLogMessage.model_validate(m) for m in messages]
|
||||
|
||||
detail = AppLogConversationDetail(
|
||||
**base.model_dump(),
|
||||
messages=msg_list,
|
||||
node_executions_map=node_executions_map,
|
||||
)
|
||||
|
||||
return success(data=detail)
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Callable
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -16,6 +17,7 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.dependencies import get_current_user, oauth2_scheme
|
||||
from app.models.user_model import User
|
||||
from app.i18n.dependencies import get_translator
|
||||
|
||||
# 获取专用日志器
|
||||
auth_logger = get_auth_logger()
|
||||
@@ -26,7 +28,8 @@ router = APIRouter(tags=["Authentication"])
|
||||
@router.post("/token", response_model=ApiResponse)
|
||||
async def login_for_access_token(
|
||||
form_data: TokenRequest,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""用户登录获取token"""
|
||||
auth_logger.info(f"用户登录请求: {form_data.email}")
|
||||
@@ -40,35 +43,38 @@ async def login_for_access_token(
|
||||
invite_info = workspace_service.validate_invite_token(db, form_data.invite)
|
||||
|
||||
if not invite_info.is_valid:
|
||||
raise BusinessException("邀请码无效或已过期", code=BizCode.BAD_REQUEST)
|
||||
raise BusinessException(t("auth.invite.invalid"), code=BizCode.BAD_REQUEST)
|
||||
|
||||
if invite_info.email != form_data.email:
|
||||
raise BusinessException("邀请邮箱与登录邮箱不匹配", code=BizCode.BAD_REQUEST)
|
||||
raise BusinessException(t("auth.invite.email_mismatch"), code=BizCode.BAD_REQUEST)
|
||||
auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}")
|
||||
try:
|
||||
# 尝试认证用户
|
||||
user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
|
||||
auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
|
||||
if form_data.invite:
|
||||
auth_service.bind_workspace_with_invite(db=db,
|
||||
user=user,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id)
|
||||
auth_service.bind_workspace_with_invite(
|
||||
db=db,
|
||||
user=user,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id
|
||||
)
|
||||
except BusinessException as e:
|
||||
# 用户不存在且有邀请码,尝试注册
|
||||
if e.code == BizCode.USER_NOT_FOUND:
|
||||
auth_logger.info(f"用户不存在,使用邀请码注册: {form_data.email}")
|
||||
user = auth_service.register_user_with_invite(
|
||||
db=db,
|
||||
email=form_data.email,
|
||||
password=form_data.password,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id
|
||||
)
|
||||
db=db,
|
||||
email=form_data.email,
|
||||
username=form_data.username,
|
||||
password=form_data.password,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id
|
||||
)
|
||||
elif e.code == BizCode.PASSWORD_ERROR:
|
||||
# 用户存在但密码错误
|
||||
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")
|
||||
raise BusinessException("接受邀请失败,密码验证错误", BizCode.LOGIN_FAILED)
|
||||
raise BusinessException(t("auth.invite.password_verification_failed"), BizCode.LOGIN_FAILED)
|
||||
else:
|
||||
# 其他认证失败情况,直接抛出
|
||||
raise
|
||||
@@ -81,7 +87,7 @@ async def login_for_access_token(
|
||||
except BusinessException as e:
|
||||
|
||||
# 其他认证失败情况,直接抛出
|
||||
raise BusinessException(e.message,BizCode.LOGIN_FAILED)
|
||||
raise BusinessException(e.message, BizCode.LOGIN_FAILED)
|
||||
|
||||
# 创建 tokens
|
||||
access_token, access_token_id = security.create_access_token(subject=user.id)
|
||||
@@ -109,14 +115,15 @@ async def login_for_access_token(
|
||||
expires_at=access_expires_at,
|
||||
refresh_expires_at=refresh_expires_at
|
||||
),
|
||||
msg="登录成功"
|
||||
msg=t("auth.login.success")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=ApiResponse)
|
||||
async def refresh_token(
|
||||
refresh_request: RefreshTokenRequest,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""刷新token"""
|
||||
auth_logger.info("收到token刷新请求")
|
||||
@@ -124,18 +131,18 @@ async def refresh_token(
|
||||
# 验证 refresh token
|
||||
userId = security.verify_token(refresh_request.refresh_token, "refresh")
|
||||
if not userId:
|
||||
raise BusinessException("无效的refresh token", code=BizCode.TOKEN_INVALID)
|
||||
raise BusinessException(t("auth.token.invalid_refresh_token"), code=BizCode.TOKEN_INVALID)
|
||||
|
||||
# 检查用户是否存在
|
||||
user = auth_service.get_user_by_id(db, userId)
|
||||
if not user:
|
||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NO_ACCESS)
|
||||
|
||||
# 检查 refresh token 黑名单
|
||||
if settings.ENABLE_SINGLE_SESSION:
|
||||
refresh_token_id = security.get_token_id(refresh_request.refresh_token)
|
||||
if refresh_token_id and await SessionService.is_token_blacklisted(refresh_token_id):
|
||||
raise BusinessException("Refresh token已失效", code=BizCode.TOKEN_BLACKLISTED)
|
||||
raise BusinessException(t("auth.token.refresh_token_blacklisted"), code=BizCode.TOKEN_BLACKLISTED)
|
||||
|
||||
# 生成新 tokens
|
||||
new_access_token, new_access_token_id = security.create_access_token(subject=user.id)
|
||||
@@ -166,7 +173,7 @@ async def refresh_token(
|
||||
expires_at=access_expires_at,
|
||||
refresh_expires_at=refresh_expires_at
|
||||
),
|
||||
msg="token刷新成功"
|
||||
msg=t("auth.token.refresh_success")
|
||||
)
|
||||
|
||||
|
||||
@@ -174,14 +181,15 @@ async def refresh_token(
|
||||
async def logout(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""登出当前用户:加入token黑名单并清理会话"""
|
||||
auth_logger.info(f"用户 {current_user.username} 请求登出")
|
||||
|
||||
token_id = security.get_token_id(token)
|
||||
if not token_id:
|
||||
raise BusinessException("无效的access token", code=BizCode.TOKEN_INVALID)
|
||||
raise BusinessException(t("auth.token.invalid"), code=BizCode.TOKEN_INVALID)
|
||||
|
||||
# 加入黑名单
|
||||
await SessionService.blacklist_token(token_id)
|
||||
@@ -191,5 +199,5 @@ async def logout(
|
||||
await SessionService.clear_user_session(current_user.username)
|
||||
|
||||
auth_logger.info(f"用户 {current_user.username} 登出成功")
|
||||
return success(msg="登出成功")
|
||||
return success(msg=t("auth.logout.success"))
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ from app.models.user_model import User
|
||||
from app.schemas import chunk_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
@@ -441,14 +442,14 @@ async def retrieve_chunks(
|
||||
# 1 participle search, 2 semantic search, 3 hybrid search
|
||||
match retrieve_data.retrieve_type:
|
||||
case chunk_schema.RetrieveType.PARTICIPLE:
|
||||
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
||||
return success(data=jsonable_encoder(rs), msg="retrieval successful")
|
||||
case chunk_schema.RetrieveType.SEMANTIC:
|
||||
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
||||
return success(data=jsonable_encoder(rs), msg="retrieval successful")
|
||||
case _:
|
||||
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
|
||||
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
|
||||
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
||||
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
||||
# Efficient deduplication
|
||||
seen_ids = set()
|
||||
unique_rs = []
|
||||
@@ -456,22 +457,24 @@ async def retrieve_chunks(
|
||||
if doc.metadata["doc_id"] not in seen_ids:
|
||||
seen_ids.add(doc.metadata["doc_id"])
|
||||
unique_rs.append(doc)
|
||||
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k)
|
||||
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k) if unique_rs else []
|
||||
if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph:
|
||||
kb_ids = [str(kb_id) for kb_id in private_kb_ids]
|
||||
workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids]
|
||||
llm_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.llm_id)
|
||||
emb_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.embedding_id)
|
||||
# Prepare to configure chat_mdl、embedding_model、vision_model information
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
||||
key=llm_key.api_key,
|
||||
model_name=llm_key.model_name,
|
||||
base_url=llm_key.api_base
|
||||
)
|
||||
embedding_model = OpenAIEmbed(
|
||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||
base_url=db_knowledge.embedding.api_keys[0].api_base
|
||||
key=emb_key.api_key,
|
||||
model_name=emb_key.model_name,
|
||||
base_url=emb_key.api_base
|
||||
)
|
||||
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids= kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||
if doc:
|
||||
rs.insert(0, doc)
|
||||
return success(data=jsonable_encoder(rs), msg="retrieval successful")
|
||||
@@ -314,8 +314,10 @@ async def parse_documents(
|
||||
)
|
||||
|
||||
# 4. Check if the file exists
|
||||
api_logger.debug(f"Constructed file path: {file_path}")
|
||||
api_logger.debug(f"File metadata - kb_id: {db_file.kb_id}, parent_id: {db_file.parent_id}, file_id: {db_file.id}, extension: {db_file.file_ext}")
|
||||
if not os.path.exists(file_path):
|
||||
api_logger.warning(f"File not found (possibly deleted): file_path={file_path}")
|
||||
api_logger.error(f"File not found (possibly deleted): file_path={file_path}, file_id={db_file.id}, document_id={document_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="File not found (possibly deleted)"
|
||||
|
||||
@@ -7,10 +7,11 @@ Routes:
|
||||
GET /memory/config/emotion - 获取情绪引擎配置
|
||||
POST /memory/config/emotion - 更新情绪引擎配置
|
||||
"""
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
from sqlalchemy.orm import Session
|
||||
from uuid import UUID
|
||||
|
||||
@@ -21,6 +22,7 @@ from app.schemas.response_schema import ApiResponse
|
||||
from app.services.emotion_config_service import EmotionConfigService
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.db import get_db
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -37,7 +39,7 @@ class EmotionConfigQuery(BaseModel):
|
||||
|
||||
class EmotionConfigUpdate(BaseModel):
|
||||
"""情绪配置更新请求模型"""
|
||||
config_id: UUID = Field(..., description="配置ID")
|
||||
config_id: Union[uuid.UUID, int, str]= Field(..., description="配置ID")
|
||||
emotion_enabled: bool = Field(..., description="是否启用情绪提取")
|
||||
emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID")
|
||||
emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词")
|
||||
@@ -46,7 +48,7 @@ class EmotionConfigUpdate(BaseModel):
|
||||
|
||||
@router.get("/read_config", response_model=ApiResponse)
|
||||
def get_emotion_config(
|
||||
config_id: UUID = Query(..., description="配置ID"),
|
||||
config_id: UUID|int = Query(..., description="配置ID"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
@@ -79,7 +81,7 @@ def get_emotion_config(
|
||||
f"用户 {current_user.username} 请求获取情绪配置",
|
||||
extra={"config_id": config_id}
|
||||
)
|
||||
|
||||
config_id=resolve_config_id(config_id, db)
|
||||
# 初始化服务
|
||||
config_service = EmotionConfigService(db)
|
||||
|
||||
@@ -158,6 +160,7 @@ def update_emotion_config(
|
||||
}
|
||||
}
|
||||
"""
|
||||
config.config_id=resolve_config_id(config.config_id, db)
|
||||
try:
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求更新情绪配置",
|
||||
|
||||
@@ -11,6 +11,7 @@ Routes:
|
||||
"""
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.dependencies import get_current_user, get_db
|
||||
@@ -45,11 +46,14 @@ emotion_service = EmotionAnalyticsService()
|
||||
@router.post("/tags", response_model=ApiResponse)
|
||||
async def get_emotion_tags(
|
||||
request: EmotionTagsRequest,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
|
||||
try:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪标签统计",
|
||||
extra={
|
||||
@@ -57,7 +61,8 @@ async def get_emotion_tags(
|
||||
"emotion_type": request.emotion_type,
|
||||
"start_date": request.start_date,
|
||||
"end_date": request.end_date,
|
||||
"limit": request.limit
|
||||
"limit": request.limit,
|
||||
"language_type": language
|
||||
}
|
||||
)
|
||||
|
||||
@@ -67,7 +72,8 @@ async def get_emotion_tags(
|
||||
emotion_type=request.emotion_type,
|
||||
start_date=request.start_date,
|
||||
end_date=request.end_date,
|
||||
limit=request.limit
|
||||
limit=request.limit,
|
||||
language=language
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
@@ -97,11 +103,14 @@ async def get_emotion_tags(
|
||||
@router.post("/wordcloud", response_model=ApiResponse)
|
||||
async def get_emotion_wordcloud(
|
||||
request: EmotionWordcloudRequest,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
|
||||
try:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪词云数据",
|
||||
extra={
|
||||
@@ -144,11 +153,14 @@ async def get_emotion_wordcloud(
|
||||
@router.post("/health", response_model=ApiResponse)
|
||||
async def get_emotion_health(
|
||||
request: EmotionHealthRequest,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
|
||||
try:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
# 验证时间范围参数
|
||||
if request.time_range not in ["7d", "30d", "90d"]:
|
||||
raise HTTPException(
|
||||
@@ -174,7 +186,7 @@ async def get_emotion_health(
|
||||
"情绪健康指数获取成功",
|
||||
extra={
|
||||
"end_user_id": request.end_user_id,
|
||||
"health_score": data.get("health_score", 0),
|
||||
"health_score": data.get("health_score") or 0,
|
||||
"level": data.get("level", "未知")
|
||||
}
|
||||
)
|
||||
@@ -196,14 +208,64 @@ async def get_emotion_health(
|
||||
|
||||
|
||||
|
||||
# @router.post("/check-data", response_model=ApiResponse)
|
||||
# async def check_emotion_data_exists(
|
||||
# request: EmotionSuggestionsRequest,
|
||||
# db: Session = Depends(get_db),
|
||||
# current_user: User = Depends(get_current_user),
|
||||
# ):
|
||||
# """检查用户情绪建议数据是否存在
|
||||
|
||||
# Args:
|
||||
# request: 包含 end_user_id
|
||||
# db: 数据库会话
|
||||
# current_user: 当前用户
|
||||
|
||||
# Returns:
|
||||
# 数据存在状态
|
||||
# """
|
||||
# try:
|
||||
# api_logger.info(
|
||||
# f"检查用户情绪建议数据是否存在: {request.end_user_id}",
|
||||
# extra={"end_user_id": request.end_user_id}
|
||||
# )
|
||||
|
||||
# # 从数据库获取建议
|
||||
# data = await emotion_service.get_cached_suggestions(
|
||||
# end_user_id=request.end_user_id,
|
||||
# db=db
|
||||
# )
|
||||
|
||||
# if data is None:
|
||||
# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据不存在")
|
||||
# return fail(
|
||||
# BizCode.NOT_FOUND,
|
||||
# "情绪建议数据不存在,请点击右上角刷新进行初始化",
|
||||
# {"exists": False}
|
||||
# )
|
||||
|
||||
# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据存在")
|
||||
# return success(data={"exists": True}, msg="情绪建议数据已存在")
|
||||
|
||||
# except Exception as e:
|
||||
# api_logger.error(
|
||||
# f"检查情绪建议数据失败: {str(e)}",
|
||||
# extra={"end_user_id": request.end_user_id},
|
||||
# exc_info=True
|
||||
# )
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
# detail=f"检查情绪建议数据失败: {str(e)}"
|
||||
# )
|
||||
|
||||
|
||||
@router.post("/suggestions", response_model=ApiResponse)
|
||||
async def get_emotion_suggestions(
|
||||
request: EmotionSuggestionsRequest,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取个性化情绪建议(从缓存读取)
|
||||
"""获取个性化情绪建议(从数据库读取)
|
||||
|
||||
Args:
|
||||
request: 包含 end_user_id 和可选的 config_id
|
||||
@@ -211,44 +273,42 @@ async def get_emotion_suggestions(
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
缓存的个性化情绪建议响应
|
||||
存储的个性化情绪建议响应
|
||||
"""
|
||||
try:
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)",
|
||||
f"用户 {current_user.username} 请求获取个性化情绪建议",
|
||||
extra={
|
||||
"end_user_id": request.end_user_id,
|
||||
"config_id": request.config_id
|
||||
}
|
||||
)
|
||||
|
||||
# 从缓存获取建议
|
||||
# 从数据库获取建议
|
||||
data = await emotion_service.get_cached_suggestions(
|
||||
end_user_id=request.end_user_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
if data is None:
|
||||
# 缓存不存在或已过期
|
||||
api_logger.info(
|
||||
f"用户 {request.end_user_id} 的建议缓存不存在或已过期",
|
||||
f"用户 {request.end_user_id} 的建议数据不存在",
|
||||
extra={"end_user_id": request.end_user_id}
|
||||
)
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"建议缓存不存在或已过期,请右上角刷新生成新建议",
|
||||
""
|
||||
return success(
|
||||
data={"exists": False},
|
||||
msg="情绪建议数据不存在,请点击右上角刷新进行初始化"
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
"个性化建议获取成功(缓存)",
|
||||
"个性化建议获取成功",
|
||||
extra={
|
||||
"end_user_id": request.end_user_id,
|
||||
"suggestions_count": len(data.get("suggestions", []))
|
||||
}
|
||||
)
|
||||
|
||||
return success(data=data, msg="个性化建议获取成功(缓存)")
|
||||
return success(data=data, msg="个性化建议获取成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
@@ -265,11 +325,11 @@ async def get_emotion_suggestions(
|
||||
@router.post("/generate_suggestions", response_model=ApiResponse)
|
||||
async def generate_emotion_suggestions(
|
||||
request: EmotionGenerateSuggestionsRequest,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""生成个性化情绪建议(调用LLM并缓存)
|
||||
"""生成个性化情绪建议(调用LLM并保存到数据库)
|
||||
|
||||
Args:
|
||||
request: 包含 end_user_id
|
||||
@@ -280,6 +340,9 @@ async def generate_emotion_suggestions(
|
||||
新生成的个性化情绪建议响应
|
||||
"""
|
||||
try:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求生成个性化情绪建议",
|
||||
extra={
|
||||
@@ -290,15 +353,15 @@ async def generate_emotion_suggestions(
|
||||
# 调用服务层生成建议
|
||||
data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=request.end_user_id,
|
||||
db=db
|
||||
db=db,
|
||||
language=language
|
||||
)
|
||||
|
||||
# 保存到缓存
|
||||
# 保存到数据库
|
||||
await emotion_service.save_suggestions_cache(
|
||||
end_user_id=request.end_user_id,
|
||||
suggestions_data=data,
|
||||
db=db,
|
||||
expires_hours=24
|
||||
db=db
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
@@ -320,4 +383,4 @@ async def generate_emotion_suggestions(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"生成个性化建议失败: {str(e)}"
|
||||
)
|
||||
)
|
||||
@@ -19,6 +19,7 @@ from app.models.user_model import User
|
||||
from app.schemas import file_schema, document_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import file_service, document_service
|
||||
from app.core.quota_stub import check_knowledge_capacity_quota
|
||||
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
@@ -131,6 +132,7 @@ async def create_folder(
|
||||
|
||||
|
||||
@router.post("/file", response_model=ApiResponse)
|
||||
@check_knowledge_capacity_quota
|
||||
async def upload_file(
|
||||
kb_id: uuid.UUID,
|
||||
parent_id: uuid.UUID,
|
||||
|
||||
@@ -14,8 +14,11 @@ Routes:
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any
|
||||
import httpx
|
||||
import mimetypes
|
||||
from urllib.parse import urlparse, unquote
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
|
||||
from fastapi.responses import FileResponse, RedirectResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -29,7 +32,7 @@ from app.core.storage_exceptions import (
|
||||
StorageUploadError,
|
||||
)
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies import get_current_user, get_share_user_id, ShareTokenData
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
@@ -47,6 +50,19 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
def _match_scheme(request: Request, url: str) -> str:
|
||||
"""
|
||||
将 presigned URL 的协议替换为与当前请求一致的协议(http/https)。
|
||||
解决反向代理场景下 presigned URL 协议与请求协议不匹配的问题。
|
||||
"""
|
||||
incoming_scheme = request.headers.get("x-forwarded-proto") or request.url.scheme
|
||||
if url.startswith("http://") and incoming_scheme == "https":
|
||||
return "https://" + url[7:]
|
||||
if url.startswith("https://") and incoming_scheme == "http":
|
||||
return "http://" + url[8:]
|
||||
return url
|
||||
|
||||
|
||||
@router.post("/files", response_model=ApiResponse)
|
||||
async def upload_file(
|
||||
file: UploadFile = File(...),
|
||||
@@ -78,7 +94,7 @@ async def upload_file(
|
||||
|
||||
if file_size > settings.MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
status_code=status.HTTP_413_CONTENT_TOO_LARGE,
|
||||
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
|
||||
)
|
||||
|
||||
@@ -143,8 +159,238 @@ async def upload_file(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/share/files", response_model=ApiResponse)
|
||||
async def upload_file_with_share_token(
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
):
|
||||
"""
|
||||
Upload a file to the configured storage backend using share_token authentication.
|
||||
"""
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.models.app_model import App
|
||||
from app.models.workspace_model import Workspace
|
||||
|
||||
# Get share and release info from share_token
|
||||
service = ReleaseShareService(db)
|
||||
|
||||
# Get share object to access app_id
|
||||
share = service.repo.get_by_share_token(share_data.share_token)
|
||||
if not share:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Shared app not found"
|
||||
)
|
||||
|
||||
# Get app to access workspace_id
|
||||
app = db.query(App).filter(
|
||||
App.id == share.app_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="App not found"
|
||||
)
|
||||
|
||||
# Get workspace to access tenant_id
|
||||
workspace = db.query(Workspace).filter(
|
||||
Workspace.id == app.workspace_id
|
||||
).first()
|
||||
|
||||
if not workspace:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Workspace not found"
|
||||
)
|
||||
|
||||
tenant_id = workspace.tenant_id
|
||||
workspace_id = app.workspace_id
|
||||
|
||||
api_logger.info(
|
||||
f"Storage upload request (share): tenant_id={tenant_id}, workspace_id={workspace_id}, "
|
||||
f"filename={file.filename}, share_token={share_data.share_token}"
|
||||
)
|
||||
|
||||
# Read file contents
|
||||
contents = await file.read()
|
||||
file_size = len(contents)
|
||||
|
||||
# Validate file size
|
||||
if file_size == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The file is empty."
|
||||
)
|
||||
|
||||
if file_size > settings.MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
|
||||
)
|
||||
|
||||
# Extract file extension
|
||||
_, file_extension = os.path.splitext(file.filename)
|
||||
file_ext = file_extension.lower()
|
||||
|
||||
# Generate file_id and file_key
|
||||
file_id = uuid.uuid4()
|
||||
file_key = generate_file_key(
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_id=file_id,
|
||||
file_ext=file_ext,
|
||||
)
|
||||
|
||||
# Create file metadata record with pending status
|
||||
file_metadata = FileMetadata(
|
||||
id=file_id,
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_key=file_key,
|
||||
file_name=file.filename,
|
||||
file_ext=file_ext,
|
||||
file_size=file_size,
|
||||
content_type=file.content_type,
|
||||
status="pending",
|
||||
)
|
||||
db.add(file_metadata)
|
||||
db.commit()
|
||||
db.refresh(file_metadata)
|
||||
|
||||
# Upload file to storage backend
|
||||
try:
|
||||
await storage_service.upload_file(
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_id=file_id,
|
||||
file_ext=file_ext,
|
||||
content=contents,
|
||||
content_type=file.content_type,
|
||||
)
|
||||
# Update status to completed
|
||||
file_metadata.status = "completed"
|
||||
db.commit()
|
||||
api_logger.info(f"File uploaded to storage (share): file_key={file_key}")
|
||||
except StorageUploadError as e:
|
||||
# Update status to failed
|
||||
file_metadata.status = "failed"
|
||||
db.commit()
|
||||
api_logger.error(f"Storage upload failed (share): {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"File storage failed: {str(e)}"
|
||||
)
|
||||
|
||||
api_logger.info(f"File upload successful (share): {file.filename} (file_id: {file_id})")
|
||||
|
||||
return success(
|
||||
data={"file_id": str(file_id), "file_key": file_key},
|
||||
msg="File upload successful"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/info-by-url", response_model=ApiResponse)
|
||||
async def get_file_info_by_url(
|
||||
url: str,
|
||||
):
|
||||
"""
|
||||
Get file information by network URL (no authentication required).
|
||||
|
||||
Fetches file metadata from a remote URL via HTTP HEAD request.
|
||||
Falls back to GET request if HEAD is not supported.
|
||||
Returns file type, name, and size.
|
||||
|
||||
Args:
|
||||
url: The network URL of the file.
|
||||
|
||||
Returns:
|
||||
ApiResponse with file information.
|
||||
"""
|
||||
api_logger.info(f"File info by URL request: url={url}")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
# Try HEAD request first
|
||||
response = await client.head(url, follow_redirects=True)
|
||||
|
||||
# If HEAD fails, try GET request (some servers don't support HEAD)
|
||||
if response.status_code != 200:
|
||||
api_logger.info(f"HEAD request failed with {response.status_code}, trying GET request")
|
||||
response = await client.get(url, follow_redirects=True)
|
||||
|
||||
if response.status_code != 200:
|
||||
api_logger.error(f"Failed to fetch file info: HTTP {response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unable to access file: HTTP {response.status_code}"
|
||||
)
|
||||
|
||||
# Get file size from Content-Length header or actual content
|
||||
file_size = response.headers.get("Content-Length")
|
||||
if file_size:
|
||||
file_size = int(file_size)
|
||||
elif hasattr(response, 'content'):
|
||||
file_size = len(response.content)
|
||||
else:
|
||||
file_size = None
|
||||
|
||||
# Get content type from Content-Type header
|
||||
content_type = response.headers.get("Content-Type", "application/octet-stream")
|
||||
# Remove charset and other parameters from content type
|
||||
content_type = content_type.split(';')[0].strip()
|
||||
|
||||
# Extract filename from Content-Disposition or URL
|
||||
file_name = None
|
||||
content_disposition = response.headers.get("Content-Disposition")
|
||||
if content_disposition and "filename=" in content_disposition:
|
||||
parts = content_disposition.split("filename=")
|
||||
if len(parts) > 1:
|
||||
file_name = parts[1].strip('"').strip("'")
|
||||
|
||||
if not file_name:
|
||||
parsed_url = urlparse(url)
|
||||
file_name = unquote(os.path.basename(parsed_url.path)) or "unknown"
|
||||
|
||||
# Extract file extension from filename
|
||||
_, file_ext = os.path.splitext(file_name)
|
||||
|
||||
# If no extension found, infer from content type
|
||||
if not file_ext:
|
||||
ext = mimetypes.guess_extension(content_type)
|
||||
if ext:
|
||||
file_ext = ext
|
||||
file_name = f"{file_name}{file_ext}"
|
||||
|
||||
api_logger.info(f"File info retrieved: name={file_name}, size={file_size}, type={content_type}")
|
||||
|
||||
return success(
|
||||
data={
|
||||
"url": url,
|
||||
"file_name": file_name,
|
||||
"file_ext": file_ext.lower() if file_ext else "",
|
||||
"file_size": file_size,
|
||||
"content_type": content_type,
|
||||
},
|
||||
msg="File information retrieved successfully"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve file information: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/{file_id}", response_model=Any)
|
||||
async def download_file(
|
||||
request: Request,
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
@@ -192,6 +438,7 @@ async def download_file(
|
||||
else:
|
||||
try:
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
||||
presigned_url = _match_scheme(request, presigned_url)
|
||||
api_logger.info(f"Redirecting to presigned URL: file_key={file_key}")
|
||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||
except FileNotFoundError:
|
||||
@@ -265,6 +512,7 @@ async def delete_file(
|
||||
|
||||
@router.get("/files/{file_id}/url", response_model=ApiResponse)
|
||||
async def get_file_url(
|
||||
request: Request,
|
||||
file_id: uuid.UUID,
|
||||
expires: int = None,
|
||||
permanent: bool = False,
|
||||
@@ -326,8 +574,13 @@ async def get_file_url(
|
||||
# For local storage, generate signed URL with expiration
|
||||
url = generate_signed_url(str(file_id), expires)
|
||||
else:
|
||||
# For remote storage (OSS/S3), get presigned URL
|
||||
url = await storage_service.get_file_url(file_key, expires=expires)
|
||||
# For remote storage (OSS/S3), get presigned URL with forced download
|
||||
url = await storage_service.get_file_url(
|
||||
file_key,
|
||||
expires=expires,
|
||||
file_name=file_metadata.file_name,
|
||||
)
|
||||
url = _match_scheme(request, url)
|
||||
|
||||
api_logger.info(f"Generated file URL: file_id={file_id}")
|
||||
return success(
|
||||
@@ -347,8 +600,54 @@ async def get_file_url(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/{file_id}/public-url", response_model=ApiResponse)
|
||||
async def get_permanent_file_url(
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
):
|
||||
"""
|
||||
获取文件的永久公开 URL(无过期时间)。
|
||||
|
||||
- 本地存储:返回 API 永久访问地址(基于 FILE_LOCAL_SERVER_URL 配置)
|
||||
- 远程存储(OSS/S3):返回 bucket 公读地址(需 bucket 已配置公共读权限)
|
||||
"""
|
||||
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||
if not file_metadata:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The file does not exist")
|
||||
|
||||
if file_metadata.status != "completed":
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"File upload not completed, status: {file_metadata.status}")
|
||||
|
||||
file_key = file_metadata.file_key
|
||||
storage = storage_service.storage
|
||||
|
||||
try:
|
||||
if isinstance(storage, LocalStorage):
|
||||
url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{file_id}"
|
||||
else:
|
||||
url = await storage.get_permanent_url(file_key)
|
||||
if not url:
|
||||
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Permanent URL not supported for current storage backend")
|
||||
|
||||
api_logger.info(f"Generated permanent URL: file_id={file_id}")
|
||||
return success(
|
||||
data={"url": url, "expires_in": None, "permanent": True, "file_name": file_metadata.file_name},
|
||||
msg="Permanent file URL generated successfully"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to generate permanent URL: {e}")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to generate permanent URL: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/public/{file_id}", response_model=Any)
|
||||
async def public_download_file(
|
||||
request: Request,
|
||||
file_id: uuid.UUID,
|
||||
expires: int = 0,
|
||||
signature: str = "",
|
||||
@@ -420,6 +719,7 @@ async def public_download_file(
|
||||
# For remote storage, redirect to presigned URL
|
||||
try:
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
||||
presigned_url = _match_scheme(request, presigned_url)
|
||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||
@@ -431,6 +731,7 @@ async def public_download_file(
|
||||
|
||||
@router.get("/permanent/{file_id}", response_model=Any)
|
||||
async def permanent_download_file(
|
||||
request: Request,
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
@@ -489,7 +790,8 @@ async def permanent_download_file(
|
||||
# For remote storage, redirect to presigned URL with long expiration
|
||||
try:
|
||||
# Use a very long expiration (7 days max for most cloud providers)
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=604800)
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=604800, file_name=file_metadata.file_name)
|
||||
presigned_url = _match_scheme(request, presigned_url)
|
||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||
@@ -497,3 +799,44 @@ async def permanent_download_file(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve file: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/{file_id}/status", response_model=ApiResponse)
|
||||
async def get_file_status(
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get file upload/processing status (no authentication required).
|
||||
|
||||
This endpoint is used to check if a file (e.g., TTS audio) is ready.
|
||||
Returns status: pending, completed, or failed.
|
||||
|
||||
Args:
|
||||
file_id: The UUID of the file.
|
||||
db: Database session.
|
||||
|
||||
Returns:
|
||||
ApiResponse with file status and metadata.
|
||||
"""
|
||||
api_logger.info(f"File status request: file_id={file_id}")
|
||||
|
||||
# Query file metadata from database
|
||||
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||
if not file_metadata:
|
||||
api_logger.warning(f"File not found in database: file_id={file_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The file does not exist"
|
||||
)
|
||||
|
||||
return success(
|
||||
data={
|
||||
"file_id": str(file_id),
|
||||
"status": file_metadata.status,
|
||||
"file_name": file_metadata.file_name,
|
||||
"file_size": file_metadata.file_size,
|
||||
"content_type": file_metadata.content_type,
|
||||
},
|
||||
msg="File status retrieved successfully"
|
||||
)
|
||||
|
||||
@@ -3,9 +3,10 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.db import get_db, SessionLocal
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.repositories.home_page_repository import HomePageRepository
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.home_page_service import HomePageService
|
||||
|
||||
@@ -31,9 +32,32 @@ def get_workspace_list(
|
||||
|
||||
@router.get("/version", response_model=ApiResponse)
|
||||
def get_system_version():
|
||||
"""获取系统版本号+说明"""
|
||||
current_version = settings.SYSTEM_VERSION
|
||||
version_info = HomePageService.load_version_introduction(current_version)
|
||||
"""获取系统版本号 + 说明"""
|
||||
current_version = None
|
||||
version_info = None
|
||||
|
||||
# 1️⃣ 优先从数据库获取最新已发布的版本
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
current_version, version_info = HomePageRepository.get_latest_version_introduction(db)
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# 2️⃣ 降级:使用环境变量中的版本号
|
||||
if not current_version:
|
||||
current_version = settings.SYSTEM_VERSION
|
||||
version_info = HomePageService.load_version_introduction(current_version)
|
||||
|
||||
# 3️⃣ 如果数据库和 JSON 都没有,返回基本信息
|
||||
if not version_info:
|
||||
version_info = {
|
||||
"introduction": {"codeName": "", "releaseDate": "", "upgradePosition": "", "coreUpgrades": []},
|
||||
"introduction_en": {"codeName": "", "releaseDate": "", "upgradePosition": "", "coreUpgrades": []}
|
||||
}
|
||||
|
||||
return success(
|
||||
data={
|
||||
"version": current_version,
|
||||
|
||||
833
api/app/controllers/i18n_controller.py
Normal file
833
api/app/controllers/i18n_controller.py
Normal file
@@ -0,0 +1,833 @@
|
||||
"""
|
||||
I18n Management API Controller
|
||||
|
||||
This module provides management APIs for:
|
||||
- Language management (list, get, add, update languages)
|
||||
- Translation management (get, update, reload translations)
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Callable, Optional
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, get_current_superuser
|
||||
from app.i18n.dependencies import get_translator
|
||||
from app.i18n.service import get_translation_service
|
||||
from app.models.user_model import User
|
||||
from app.schemas.i18n_schema import (
|
||||
LanguageInfo,
|
||||
LanguageListResponse,
|
||||
LanguageCreateRequest,
|
||||
LanguageUpdateRequest,
|
||||
TranslationResponse,
|
||||
TranslationUpdateRequest,
|
||||
MissingTranslationsResponse,
|
||||
ReloadResponse
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/i18n",
|
||||
tags=["I18n Management"],
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Language Management APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/languages", response_model=ApiResponse)
|
||||
def get_languages(
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get list of all supported languages.
|
||||
|
||||
Returns:
|
||||
List of language information including code, name, and status
|
||||
"""
|
||||
api_logger.info(f"Get languages request from user: {current_user.username}")
|
||||
|
||||
from app.core.config import settings
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Get available locales from translation service
|
||||
available_locales = translation_service.get_available_locales()
|
||||
|
||||
# Build language info list
|
||||
languages = []
|
||||
for locale in available_locales:
|
||||
is_default = locale == settings.I18N_DEFAULT_LANGUAGE
|
||||
is_enabled = locale in settings.I18N_SUPPORTED_LANGUAGES
|
||||
|
||||
# Get native names
|
||||
native_names = {
|
||||
"zh": "中文(简体)",
|
||||
"en": "English",
|
||||
"ja": "日本語",
|
||||
"ko": "한국어",
|
||||
"fr": "Français",
|
||||
"de": "Deutsch",
|
||||
"es": "Español"
|
||||
}
|
||||
|
||||
language_info = LanguageInfo(
|
||||
code=locale,
|
||||
name=f"{locale.upper()}",
|
||||
native_name=native_names.get(locale, locale),
|
||||
is_enabled=is_enabled,
|
||||
is_default=is_default
|
||||
)
|
||||
languages.append(language_info)
|
||||
|
||||
response = LanguageListResponse(languages=languages)
|
||||
|
||||
api_logger.info(f"Returning {len(languages)} languages")
|
||||
return success(data=response.dict(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/languages/{locale}", response_model=ApiResponse)
|
||||
def get_language(
|
||||
locale: str,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get information about a specific language.
|
||||
|
||||
Args:
|
||||
locale: Language code (e.g., 'zh', 'en')
|
||||
|
||||
Returns:
|
||||
Language information
|
||||
"""
|
||||
api_logger.info(f"Get language info request: locale={locale}, user={current_user.username}")
|
||||
|
||||
from app.core.config import settings
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if locale exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
# Build language info
|
||||
is_default = locale == settings.I18N_DEFAULT_LANGUAGE
|
||||
is_enabled = locale in settings.I18N_SUPPORTED_LANGUAGES
|
||||
|
||||
native_names = {
|
||||
"zh": "中文(简体)",
|
||||
"en": "English",
|
||||
"ja": "日本語",
|
||||
"ko": "한국어",
|
||||
"fr": "Français",
|
||||
"de": "Deutsch",
|
||||
"es": "Español"
|
||||
}
|
||||
|
||||
language_info = LanguageInfo(
|
||||
code=locale,
|
||||
name=f"{locale.upper()}",
|
||||
native_name=native_names.get(locale, locale),
|
||||
is_enabled=is_enabled,
|
||||
is_default=is_default
|
||||
)
|
||||
|
||||
api_logger.info(f"Returning language info for: {locale}")
|
||||
return success(data=language_info.dict(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/languages", response_model=ApiResponse)
|
||||
def add_language(
|
||||
request: LanguageCreateRequest,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Add a new language (admin only).
|
||||
|
||||
Note: This endpoint validates the request but actual language addition
|
||||
requires creating translation files in the locales directory.
|
||||
|
||||
Args:
|
||||
request: Language creation request
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Add language request: code={request.code}, admin={current_user.username}"
|
||||
)
|
||||
|
||||
from app.core.config import settings
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if language already exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if request.code in available_locales:
|
||||
api_logger.warning(f"Language already exists: {request.code}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=t("i18n.language.already_exists", locale=request.code)
|
||||
)
|
||||
|
||||
# Note: Actual language addition requires creating translation files
|
||||
# This endpoint serves as a validation and documentation point
|
||||
|
||||
api_logger.info(
|
||||
f"Language addition validated: {request.code}. "
|
||||
"Translation files need to be created manually."
|
||||
)
|
||||
|
||||
return success(
|
||||
msg=t(
|
||||
"i18n.language.add_instructions",
|
||||
locale=request.code,
|
||||
dir=settings.I18N_CORE_LOCALES_DIR
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.put("/languages/{locale}", response_model=ApiResponse)
|
||||
def update_language(
|
||||
locale: str,
|
||||
request: LanguageUpdateRequest,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Update language configuration (admin only).
|
||||
|
||||
Note: This endpoint validates the request but actual configuration
|
||||
changes require updating environment variables or config files.
|
||||
|
||||
Args:
|
||||
locale: Language code
|
||||
request: Language update request
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Update language request: locale={locale}, admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if language exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
# Note: Actual configuration changes require updating settings
|
||||
# This endpoint serves as a validation and documentation point
|
||||
|
||||
api_logger.info(
|
||||
f"Language update validated: {locale}. "
|
||||
"Configuration changes require environment variable updates."
|
||||
)
|
||||
|
||||
return success(msg=t("i18n.language.update_instructions", locale=locale))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Translation Management APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/translations", response_model=ApiResponse)
|
||||
def get_all_translations(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get all translations for all or specific locale.
|
||||
|
||||
Args:
|
||||
locale: Optional locale filter
|
||||
|
||||
Returns:
|
||||
All translations organized by locale and namespace
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get all translations request: locale={locale}, user={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
if locale:
|
||||
# Get translations for specific locale
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
translations = {
|
||||
locale: translation_service._cache.get(locale, {})
|
||||
}
|
||||
else:
|
||||
# Get all translations
|
||||
translations = translation_service._cache
|
||||
|
||||
response = TranslationResponse(translations=translations)
|
||||
|
||||
api_logger.info(f"Returning translations for: {locale or 'all locales'}")
|
||||
return success(data=response.dict(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/translations/{locale}", response_model=ApiResponse)
|
||||
def get_locale_translations(
|
||||
locale: str,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get all translations for a specific locale.
|
||||
|
||||
Args:
|
||||
locale: Language code
|
||||
|
||||
Returns:
|
||||
All translations for the locale organized by namespace
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get locale translations request: locale={locale}, user={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if locale exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
translations = translation_service._cache.get(locale, {})
|
||||
|
||||
api_logger.info(f"Returning {len(translations)} namespaces for locale: {locale}")
|
||||
return success(data={"locale": locale, "translations": translations}, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/translations/{locale}/{namespace}", response_model=ApiResponse)
|
||||
def get_namespace_translations(
|
||||
locale: str,
|
||||
namespace: str,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get translations for a specific namespace in a locale.
|
||||
|
||||
Args:
|
||||
locale: Language code
|
||||
namespace: Translation namespace (e.g., 'common', 'auth')
|
||||
|
||||
Returns:
|
||||
Translations for the specified namespace
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get namespace translations request: locale={locale}, "
|
||||
f"namespace={namespace}, user={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if locale exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
# Get namespace translations
|
||||
locale_translations = translation_service._cache.get(locale, {})
|
||||
namespace_translations = locale_translations.get(namespace, {})
|
||||
|
||||
if not namespace_translations:
|
||||
api_logger.warning(f"Namespace not found: {namespace} in locale: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.namespace.not_found", namespace=namespace, locale=locale)
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"Returning translations for namespace: {namespace} in locale: {locale}"
|
||||
)
|
||||
return success(
|
||||
data={
|
||||
"locale": locale,
|
||||
"namespace": namespace,
|
||||
"translations": namespace_translations
|
||||
},
|
||||
msg=t("common.success.retrieved")
|
||||
)
|
||||
|
||||
|
||||
@router.put("/translations/{locale}/{key:path}", response_model=ApiResponse)
|
||||
def update_translation(
|
||||
locale: str,
|
||||
key: str,
|
||||
request: TranslationUpdateRequest,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Update a single translation (admin only).
|
||||
|
||||
Note: This endpoint validates the request but actual translation updates
|
||||
require modifying translation files in the locales directory.
|
||||
|
||||
Args:
|
||||
locale: Language code
|
||||
key: Translation key (format: "namespace.key.subkey")
|
||||
request: Translation update request
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Update translation request: locale={locale}, key={key}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if locale exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
# Validate key format
|
||||
if "." not in key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=t("i18n.translation.invalid_key_format", key=key)
|
||||
)
|
||||
|
||||
# Note: Actual translation updates require modifying JSON files
|
||||
# This endpoint serves as a validation and documentation point
|
||||
|
||||
api_logger.info(
|
||||
f"Translation update validated: {locale}/{key}. "
|
||||
"Translation files need to be updated manually."
|
||||
)
|
||||
|
||||
return success(
|
||||
msg=t("i18n.translation.update_instructions", locale=locale, key=key)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/translations/missing", response_model=ApiResponse)
|
||||
def get_missing_translations(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get list of missing translations.
|
||||
|
||||
Compares translations across locales to find missing keys.
|
||||
|
||||
Args:
|
||||
locale: Optional locale to check (defaults to checking all non-default locales)
|
||||
|
||||
Returns:
|
||||
List of missing translation keys
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get missing translations request: locale={locale}, user={current_user.username}"
|
||||
)
|
||||
|
||||
from app.core.config import settings
|
||||
translation_service = get_translation_service()
|
||||
|
||||
default_locale = settings.I18N_DEFAULT_LANGUAGE
|
||||
available_locales = translation_service.get_available_locales()
|
||||
|
||||
# Get default locale translations as reference
|
||||
default_translations = translation_service._cache.get(default_locale, {})
|
||||
|
||||
# Collect all keys from default locale
|
||||
def collect_keys(data, prefix=""):
|
||||
keys = []
|
||||
for key, value in data.items():
|
||||
full_key = f"{prefix}.{key}" if prefix else key
|
||||
if isinstance(value, dict):
|
||||
keys.extend(collect_keys(value, full_key))
|
||||
else:
|
||||
keys.append(full_key)
|
||||
return keys
|
||||
|
||||
default_keys = set()
|
||||
for namespace, translations in default_translations.items():
|
||||
namespace_keys = collect_keys(translations, namespace)
|
||||
default_keys.update(namespace_keys)
|
||||
|
||||
# Find missing keys in target locale(s)
|
||||
missing_by_locale = {}
|
||||
|
||||
target_locales = [locale] if locale else [
|
||||
loc for loc in available_locales if loc != default_locale
|
||||
]
|
||||
|
||||
for target_locale in target_locales:
|
||||
if target_locale not in available_locales:
|
||||
continue
|
||||
|
||||
target_translations = translation_service._cache.get(target_locale, {})
|
||||
target_keys = set()
|
||||
|
||||
for namespace, translations in target_translations.items():
|
||||
namespace_keys = collect_keys(translations, namespace)
|
||||
target_keys.update(namespace_keys)
|
||||
|
||||
missing_keys = default_keys - target_keys
|
||||
if missing_keys:
|
||||
missing_by_locale[target_locale] = sorted(list(missing_keys))
|
||||
|
||||
response = MissingTranslationsResponse(missing_translations=missing_by_locale)
|
||||
|
||||
total_missing = sum(len(keys) for keys in missing_by_locale.values())
|
||||
api_logger.info(f"Found {total_missing} missing translations across {len(missing_by_locale)} locales")
|
||||
|
||||
return success(data=response.dict(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/reload", response_model=ApiResponse)
|
||||
def reload_translations(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Trigger hot reload of translation files (admin only).
|
||||
|
||||
Args:
|
||||
locale: Optional locale to reload (defaults to reloading all locales)
|
||||
|
||||
Returns:
|
||||
Reload status and statistics
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Reload translations request: locale={locale or 'all'}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
if not settings.I18N_ENABLE_HOT_RELOAD:
|
||||
api_logger.warning("Hot reload is disabled in configuration")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=t("i18n.reload.disabled")
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
try:
|
||||
# Reload translations
|
||||
translation_service.reload(locale)
|
||||
|
||||
# Get statistics
|
||||
available_locales = translation_service.get_available_locales()
|
||||
reloaded_locales = [locale] if locale else available_locales
|
||||
|
||||
response = ReloadResponse(
|
||||
success=True,
|
||||
reloaded_locales=reloaded_locales,
|
||||
total_locales=len(available_locales)
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"Successfully reloaded translations for: {', '.join(reloaded_locales)}"
|
||||
)
|
||||
|
||||
return success(data=response.dict(), msg=t("i18n.reload.success"))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to reload translations: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=t("i18n.reload.failed", error=str(e))
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Performance Monitoring APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/metrics", response_model=ApiResponse)
|
||||
def get_metrics(
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Get i18n performance metrics (admin only).
|
||||
|
||||
Returns:
|
||||
Performance metrics including:
|
||||
- Request counts
|
||||
- Missing translations
|
||||
- Timing statistics
|
||||
- Locale usage
|
||||
- Error counts
|
||||
"""
|
||||
api_logger.info(f"Get metrics request: admin={current_user.username}")
|
||||
|
||||
translation_service = get_translation_service()
|
||||
metrics = translation_service.get_metrics_summary()
|
||||
|
||||
api_logger.info("Returning i18n metrics")
|
||||
return success(data=metrics, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/metrics/cache", response_model=ApiResponse)
|
||||
def get_cache_stats(
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Get cache statistics (admin only).
|
||||
|
||||
Returns:
|
||||
Cache statistics including:
|
||||
- Hit/miss rates
|
||||
- LRU cache performance
|
||||
- Loaded locales
|
||||
- Memory usage
|
||||
"""
|
||||
api_logger.info(f"Get cache stats request: admin={current_user.username}")
|
||||
|
||||
translation_service = get_translation_service()
|
||||
cache_stats = translation_service.get_cache_stats()
|
||||
memory_usage = translation_service.get_memory_usage()
|
||||
|
||||
data = {
|
||||
"cache": cache_stats,
|
||||
"memory": memory_usage
|
||||
}
|
||||
|
||||
api_logger.info("Returning cache statistics")
|
||||
return success(data=data, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/metrics/prometheus")
|
||||
def get_prometheus_metrics(
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Get metrics in Prometheus format (admin only).
|
||||
|
||||
Returns:
|
||||
Prometheus-formatted metrics as plain text
|
||||
"""
|
||||
api_logger.info(f"Get Prometheus metrics request: admin={current_user.username}")
|
||||
|
||||
from app.i18n.metrics import get_metrics
|
||||
metrics = get_metrics()
|
||||
prometheus_output = metrics.export_prometheus()
|
||||
|
||||
from fastapi.responses import PlainTextResponse
|
||||
return PlainTextResponse(content=prometheus_output)
|
||||
|
||||
|
||||
@router.post("/metrics/reset", response_model=ApiResponse)
|
||||
def reset_metrics(
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Reset all metrics (admin only).
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(f"Reset metrics request: admin={current_user.username}")
|
||||
|
||||
from app.i18n.metrics import get_metrics
|
||||
metrics = get_metrics()
|
||||
metrics.reset()
|
||||
|
||||
translation_service = get_translation_service()
|
||||
translation_service.cache.reset_stats()
|
||||
|
||||
api_logger.info("Metrics reset completed")
|
||||
return success(msg=t("i18n.metrics.reset_success"))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Missing Translation Logging and Reporting APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/logs/missing", response_model=ApiResponse)
|
||||
def get_missing_translation_logs(
|
||||
locale: Optional[str] = None,
|
||||
limit: Optional[int] = 100,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Get missing translation logs (admin only).
|
||||
|
||||
Returns logged missing translations with context information.
|
||||
|
||||
Args:
|
||||
locale: Optional locale filter
|
||||
limit: Maximum number of entries to return (default: 100)
|
||||
|
||||
Returns:
|
||||
Missing translation logs with context
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get missing translation logs request: locale={locale}, "
|
||||
f"limit={limit}, admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
translation_logger = translation_service.translation_logger
|
||||
|
||||
# Get missing translations
|
||||
missing_translations = translation_logger.get_missing_translations(locale)
|
||||
|
||||
# Get missing with context
|
||||
missing_with_context = translation_logger.get_missing_with_context(locale, limit)
|
||||
|
||||
# Get statistics
|
||||
statistics = translation_logger.get_statistics()
|
||||
|
||||
data = {
|
||||
"missing_translations": missing_translations,
|
||||
"recent_context": missing_with_context,
|
||||
"statistics": statistics
|
||||
}
|
||||
|
||||
api_logger.info(
|
||||
f"Returning {statistics['total_missing']} missing translations"
|
||||
)
|
||||
return success(data=data, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/logs/missing/report", response_model=ApiResponse)
|
||||
def generate_missing_translation_report(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Generate a comprehensive missing translation report (admin only).
|
||||
|
||||
Args:
|
||||
locale: Optional locale filter
|
||||
|
||||
Returns:
|
||||
Comprehensive report with missing translations and statistics
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Generate missing translation report request: locale={locale}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
translation_logger = translation_service.translation_logger
|
||||
|
||||
# Generate report
|
||||
report = translation_logger.generate_report(locale)
|
||||
|
||||
api_logger.info(
|
||||
f"Generated report with {report['total_missing']} missing translations"
|
||||
)
|
||||
return success(data=report, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/logs/missing/export", response_model=ApiResponse)
|
||||
def export_missing_translations(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Export missing translations to JSON file (admin only).
|
||||
|
||||
Args:
|
||||
locale: Optional locale filter
|
||||
|
||||
Returns:
|
||||
Export status and file path
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Export missing translations request: locale={locale}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
from datetime import datetime
|
||||
translation_service = get_translation_service()
|
||||
translation_logger = translation_service.translation_logger
|
||||
|
||||
# Generate filename with timestamp
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
locale_suffix = f"_{locale}" if locale else "_all"
|
||||
output_file = f"logs/i18n/missing_translations{locale_suffix}_{timestamp}.json"
|
||||
|
||||
# Export to file
|
||||
translation_logger.export_to_json(output_file)
|
||||
|
||||
api_logger.info(f"Missing translations exported to: {output_file}")
|
||||
return success(
|
||||
data={"file_path": output_file},
|
||||
msg=t("i18n.logs.export_success", file=output_file)
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/logs/missing", response_model=ApiResponse)
|
||||
def clear_missing_translation_logs(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Clear missing translation logs (admin only).
|
||||
|
||||
Args:
|
||||
locale: Optional locale to clear (clears all if not specified)
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Clear missing translation logs request: locale={locale or 'all'}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
translation_logger = translation_service.translation_logger
|
||||
|
||||
# Clear logs
|
||||
translation_logger.clear(locale)
|
||||
|
||||
api_logger.info(f"Cleared missing translation logs for: {locale or 'all locales'}")
|
||||
return success(msg=t("i18n.logs.clear_success"))
|
||||
@@ -122,6 +122,48 @@ def validate_confidence_threshold(threshold: float) -> None:
|
||||
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
|
||||
|
||||
|
||||
@router.get("/check-data/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def check_user_data_exists(
|
||||
end_user_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
检查用户画像数据是否存在
|
||||
|
||||
Args:
|
||||
end_user_id: 目标用户ID
|
||||
|
||||
Returns:
|
||||
数据存在状态
|
||||
"""
|
||||
api_logger.info(f"检查用户画像数据是否存在: {end_user_id}")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return success(
|
||||
data={"exists": False},
|
||||
msg="画像数据不存在,请点击右上角刷新进行初始化"
|
||||
)
|
||||
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据存在")
|
||||
return success(data={"exists": True}, msg="画像数据已存在")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "检查画像数据", end_user_id)
|
||||
|
||||
|
||||
@router.get("/preferences/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_preference_tags(
|
||||
@@ -159,12 +201,8 @@ async def get_preference_tags(
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract preferences from cache
|
||||
preferences = cached_profile.get("preferences", [])
|
||||
@@ -230,12 +268,8 @@ async def get_dimension_portrait(
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract portrait from cache
|
||||
portrait = cached_profile.get("portrait", {})
|
||||
@@ -278,12 +312,8 @@ async def get_interest_area_distribution(
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract interest areas from cache
|
||||
interest_areas = cached_profile.get("interest_areas", {})
|
||||
@@ -330,12 +360,8 @@ async def get_behavior_habits(
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract habits from cache
|
||||
habits = cached_profile.get("habits", [])
|
||||
|
||||
@@ -9,13 +9,16 @@ from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.rag.common import settings
|
||||
from app.core.rag.integrations.feishu.client import FeishuAPIClient
|
||||
from app.core.rag.integrations.yuque.client import YuqueAPIClient
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.nlp import rag_tokenizer, search
|
||||
from app.core.rag.prompts.generator import graph_entity_types
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.response_utils import success
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import knowledge_model
|
||||
@@ -24,6 +27,7 @@ from app.schemas import knowledge_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import knowledge_service, document_service
|
||||
from app.services.model_service import ModelConfigService
|
||||
from app.core.quota_stub import check_knowledge_capacity_quota
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
@@ -176,6 +180,7 @@ async def get_knowledges(
|
||||
|
||||
|
||||
@router.post("/knowledge", response_model=ApiResponse)
|
||||
@check_knowledge_capacity_quota
|
||||
async def create_knowledge(
|
||||
create_data: knowledge_schema.KnowledgeCreate,
|
||||
db: Session = Depends(get_db),
|
||||
@@ -349,6 +354,7 @@ async def delete_knowledge(
|
||||
# 2. Soft-delete knowledge base
|
||||
api_logger.debug(f"Perform a soft delete: {db_knowledge.name} (ID: {knowledge_id})")
|
||||
db_knowledge.status = 2
|
||||
db_knowledge.updated_at = datetime.datetime.now()
|
||||
db.commit()
|
||||
api_logger.info(f"The knowledge base has been successfully deleted: {db_knowledge.name} (ID: {knowledge_id})")
|
||||
return success(msg="The knowledge base has been successfully deleted")
|
||||
@@ -484,3 +490,99 @@ async def rebuild_knowledge_graph(
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to rebuild knowledge graph: knowledge_id={knowledge_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/check/yuque/auth", response_model=ApiResponse)
|
||||
async def check_yuque_auth(
|
||||
yuque_user_id: str,
|
||||
yuque_token: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
check yuque auth info
|
||||
"""
|
||||
api_logger.info(f"check yuque auth info, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_client = YuqueAPIClient(
|
||||
user_id=yuque_user_id,
|
||||
token=yuque_token
|
||||
)
|
||||
async with api_client as client:
|
||||
repos = await client.get_user_repos()
|
||||
if repos:
|
||||
return success(msg="Successfully auth yuque info")
|
||||
return fail(BizCode.UNAUTHORIZED, msg="auth yuque info failed", error="user_id or token is incorrect")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"auth yuque info failed: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/check/feishu/auth", response_model=ApiResponse)
|
||||
async def check_feishu_auth(
|
||||
feishu_app_id: str,
|
||||
feishu_app_secret: str,
|
||||
feishu_folder_token: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
check feishu auth info
|
||||
"""
|
||||
api_logger.info(f"check feishu auth info, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_client = FeishuAPIClient(
|
||||
app_id=feishu_app_id,
|
||||
app_secret=feishu_app_secret
|
||||
)
|
||||
async with api_client as client:
|
||||
files = await client.list_all_folder_files(feishu_folder_token, recursive=True)
|
||||
if files:
|
||||
return success(msg="Successfully auth feishu info")
|
||||
return fail(BizCode.UNAUTHORIZED, msg="auth feishu info failed", error="app_id or app_secret or feishu_folder_token is incorrect")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"auth feishu info failed: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/{knowledge_id}/sync", response_model=ApiResponse)
|
||||
async def sync_knowledge(
|
||||
knowledge_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
sync knowledge base information based on knowledge_id
|
||||
"""
|
||||
api_logger.info(f"Obtain details of the knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query knowledge base information from the database
|
||||
api_logger.debug(f"Query knowledge base: {knowledge_id}")
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={knowledge_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The knowledge base does not exist or access is denied"
|
||||
)
|
||||
|
||||
# 2. sync knowledge
|
||||
# from app.tasks import sync_knowledge_for_kb
|
||||
# sync_knowledge_for_kb(kb_id)
|
||||
task = celery_app.send_task("app.core.rag.tasks.sync_knowledge_for_kb", args=[knowledge_id])
|
||||
result = {
|
||||
"task_id": task.id
|
||||
}
|
||||
return success(data=result, msg="Task accepted. sync knowledge is being processed in the background.")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to sync knowledge: knowledge_id={knowledge_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
474
api/app/controllers/mcp_market_config_controller.py
Normal file
474
api/app/controllers/mcp_market_config_controller.py
Normal file
@@ -0,0 +1,474 @@
|
||||
import datetime
|
||||
import json
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
import requests
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
from modelscope.hub.errors import raise_for_http_status
|
||||
from modelscope.hub.mcp_api import MCPApi
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import mcp_market_config_model
|
||||
from app.models.user_model import User
|
||||
from app.schemas import mcp_market_config_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import mcp_market_config_service, mcp_market_service
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
# Router for all MCP market config endpoints. The router-level
# get_current_user dependency enforces authentication on every route
# registered below.
router = APIRouter(
    prefix="/mcp_market_configs",
    tags=["mcp_market_configs"],
    dependencies=[Depends(get_current_user)]  # Apply auth to all routes in this controller
)
|
||||
|
||||
|
||||
@router.get("/mcp_servers", response_model=ApiResponse)
|
||||
async def get_mcp_servers(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
keywords: Optional[str] = Query(None, description="Search keywords (Optional search query string,e.g. Chinese service name, English service name, author/owner username)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Query the mcp servers list in pages
|
||||
- Support keyword search for name,author,owner
|
||||
- Return paging metadata + mcp server list
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Query mcp server list: tenant_id={current_user.tenant_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
|
||||
|
||||
# 1. parameter validation
|
||||
if page < 1 or pagesize < 1:
|
||||
api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The paging parameter must be greater than 0"
|
||||
)
|
||||
if page * pagesize > 100:
|
||||
api_logger.warning(f"Paging parameters exceed ModelScope limit: page={page}, pagesize={pagesize}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The maximum number of MCP services can view is 100. Please visit the ModelScope MCP Plaza."
|
||||
)
|
||||
|
||||
# 2. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
|
||||
mcp_market_config_id=mcp_market_config_id,
|
||||
current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
# 3. Execute paged query
|
||||
token = db_mcp_market_config.token
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="MCP market config token is not configured"
|
||||
)
|
||||
api = MCPApi()
|
||||
api.login(token)
|
||||
|
||||
body = {
|
||||
'filter': {},
|
||||
'page_number': page,
|
||||
'page_size': pagesize,
|
||||
'search': keywords
|
||||
}
|
||||
|
||||
try:
|
||||
cookies = api.get_cookies(token)
|
||||
headers=api.builder_headers(api.headers)
|
||||
headers['Authorization'] = f'Bearer {token}'
|
||||
r = api.session.put(
|
||||
url=api.mcp_base_url,
|
||||
headers=headers,
|
||||
json=body,
|
||||
cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except requests.exceptions.RequestException as e:
|
||||
api_logger.error(f"Failed to get MCP servers: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get MCP servers: {str(e)}"
|
||||
)
|
||||
|
||||
data = api._handle_response(r)
|
||||
total = data.get('total_count', 0)
|
||||
mcp_server_list = data.get('mcp_server_list', [])
|
||||
# items = [{
|
||||
# 'name': item.get('name', ''),
|
||||
# 'id': item.get('id', ''),
|
||||
# 'description': item.get('description', '')
|
||||
# } for item in mcp_server_list]
|
||||
|
||||
# 4. Return structured response
|
||||
result = {
|
||||
"items": mcp_server_list,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
# 5. Update mck_market.mcp_count
|
||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=db_mcp_market_config.mcp_market_id, current_user=current_user)
|
||||
if not db_mcp_market:
|
||||
api_logger.warning(f"The mcp market does not exist or access is denied: mcp_market_id={db_mcp_market_config.mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market does not exist or access is denied"
|
||||
)
|
||||
db_mcp_market.mcp_count = total
|
||||
db.commit()
|
||||
db.refresh(db_mcp_market)
|
||||
return success(data=result, msg="Query of mcp servers list successful")
|
||||
|
||||
|
||||
@router.get("/operational_mcp_servers", response_model=ApiResponse)
|
||||
async def get_operational_mcp_servers(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Query the operational mcp servers list in pages
|
||||
- Support keyword search for name,author,owner
|
||||
- Return paging metadata + operational mcp server list
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Query operational mcp server list: tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
||||
|
||||
# 1. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
|
||||
mcp_market_config_id=mcp_market_config_id,
|
||||
current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
# 2. Execute paged query
|
||||
token = db_mcp_market_config.token
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="MCP market config token is not configured"
|
||||
)
|
||||
api = MCPApi()
|
||||
api.login(token)
|
||||
|
||||
url = f'{api.mcp_base_url}/operational'
|
||||
headers = api.builder_headers(api.headers)
|
||||
headers['Authorization'] = f'Bearer {token}'
|
||||
|
||||
try:
|
||||
cookies = api.get_cookies(access_token=token, cookies_required=True)
|
||||
r = api.session.get(url, headers=headers, cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except requests.exceptions.RequestException as e:
|
||||
api_logger.error(f"Failed to get operational MCP servers: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get operational MCP servers: {str(e)}"
|
||||
)
|
||||
|
||||
data = api._handle_response(r)
|
||||
total = data.get('total_count', 0)
|
||||
mcp_server_list = data.get('mcp_server_list', [])
|
||||
# items = [{
|
||||
# 'name': item.get('name', ''),
|
||||
# 'id': item.get('id', ''),
|
||||
# 'description': item.get('description', '')
|
||||
# } for item in mcp_server_list]
|
||||
|
||||
# 3. Return structured response
|
||||
return success(data=mcp_server_list, msg="Query of operational mcp servers list successful")
|
||||
|
||||
|
||||
@router.get("/mcp_server", response_model=ApiResponse)
|
||||
async def get_mcp_server(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
server_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get detailed information for a specific MCP Server
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Query mcp server: tenant_id={current_user.tenant_id}, mcp_market_config_id={mcp_market_config_id}, server_id={server_id}, username: {current_user.username}")
|
||||
|
||||
# 1. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
|
||||
mcp_market_config_id=mcp_market_config_id,
|
||||
current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
# 2. Get detailed information for a specific MCP Server
|
||||
token = db_mcp_market_config.token
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="MCP market config token is not configured"
|
||||
)
|
||||
api = MCPApi()
|
||||
api.login(token)
|
||||
|
||||
result = api.get_mcp_server(server_id=server_id)
|
||||
return success(data=result, msg="Query of mcp servers list successful")
|
||||
|
||||
|
||||
@router.post("/mcp_market_config", response_model=ApiResponse)
|
||||
async def create_mcp_market_config(
|
||||
create_data: mcp_market_config_schema.McpMarketConfigCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
create mcp market config
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Request to create a mcp market config: mcp_market_id={create_data.mcp_market_id}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"Start creating the mcp market config: {create_data.mcp_market_id}")
|
||||
# 1. Validate token can access ModelScope MCP market
|
||||
if not create_data.token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Token is required to access ModelScope MCP market"
|
||||
)
|
||||
try:
|
||||
api = MCPApi()
|
||||
api.login(create_data.token)
|
||||
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
||||
cookies = api.get_cookies(create_data.token)
|
||||
headers = api.builder_headers(api.headers)
|
||||
headers['Authorization'] = f'Bearer {create_data.token}'
|
||||
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unable to access ModelScope MCP market with the provided token: {str(e)}"
|
||||
)
|
||||
# 2. Check if the mcp market name already exists
|
||||
db_mcp_market_config_exist = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=create_data.mcp_market_id, current_user=current_user)
|
||||
if db_mcp_market_config_exist:
|
||||
api_logger.warning(f"The mcp market id already exists: {create_data.mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The mcp market id already exists: {create_data.mcp_market_id}"
|
||||
)
|
||||
# 2. verify token
|
||||
create_data.status = 1
|
||||
try:
|
||||
api = MCPApi()
|
||||
token = create_data.token
|
||||
api.login(token)
|
||||
|
||||
body = {
|
||||
'filter': {},
|
||||
'page_number': 1,
|
||||
'page_size': 20,
|
||||
'search': ""
|
||||
}
|
||||
cookies = api.get_cookies(token)
|
||||
headers = api.builder_headers(api.headers)
|
||||
headers['Authorization'] = f'Bearer {token}'
|
||||
r = api.session.put(
|
||||
url=api.mcp_base_url,
|
||||
headers=headers,
|
||||
json=body,
|
||||
cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except requests.exceptions.RequestException as e:
|
||||
api_logger.error(f"Failed to get MCP servers: {str(e)}")
|
||||
create_data.status = 0
|
||||
# 3. create mcp_market_config
|
||||
db_mcp_market_config = mcp_market_config_service.create_mcp_market_config(db=db, mcp_market_config=create_data, current_user=current_user)
|
||||
api_logger.info(
|
||||
f"The mcp market config has been successfully created: (ID: {db_mcp_market_config.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="The mcp market config has been successfully created")
|
||||
except Exception as e:
|
||||
api_logger.error(f"The creation of the mcp market config failed: {create_data.mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/{mcp_market_config_id}", response_model=ApiResponse)
|
||||
async def get_mcp_market_config(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Retrieve mcp market config information based on mcp_market_config_id
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Obtain details of the mcp market config: mcp_market_config_id={mcp_market_config_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="Successfully obtained mcp market config information")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"mcp market config query failed: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/mcp_market_id/{mcp_market_id}", response_model=ApiResponse)
|
||||
async def get_mcp_market_config_by_mcp_market_id(
|
||||
mcp_market_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Retrieve mcp market config information based on mcp_market_id
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Request to create a mcp market config: mcp_market_id={mcp_market_id}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: mcp_market_id={mcp_market_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_id={mcp_market_id}")
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="Successfully obtained mcp market config information")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"mcp market config query failed: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/{mcp_market_config_id}", response_model=ApiResponse)
|
||||
async def update_mcp_market_config(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
update_data: mcp_market_config_schema.McpMarketConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
# 1. Check if the mcp market config exists
|
||||
api_logger.debug(f"Query the mcp market config to be updated: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
# 2. Validate new token if provided
|
||||
if update_data.token is not None:
|
||||
try:
|
||||
api = MCPApi()
|
||||
api.login(update_data.token)
|
||||
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
||||
cookies = api.get_cookies(update_data.token)
|
||||
headers = api.builder_headers(api.headers)
|
||||
headers['Authorization'] = f'Bearer {update_data.token}'
|
||||
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unable to access ModelScope MCP market with the provided token: {str(e)}"
|
||||
)
|
||||
|
||||
# 3. Update fields (only update non-null fields)
|
||||
api_logger.debug(f"Start updating the mcp market config fields: {mcp_market_config_id}")
|
||||
update_dict = update_data.dict(exclude_unset=True)
|
||||
updated_fields = []
|
||||
for field, value in update_dict.items():
|
||||
if hasattr(db_mcp_market_config, field):
|
||||
old_value = getattr(db_mcp_market_config, field)
|
||||
if old_value != value:
|
||||
# update value
|
||||
setattr(db_mcp_market_config, field, value)
|
||||
updated_fields.append(f"{field}: {old_value} -> {value}")
|
||||
|
||||
if updated_fields:
|
||||
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
||||
|
||||
# 4. Save to database
|
||||
try:
|
||||
db.commit()
|
||||
db.refresh(db_mcp_market_config)
|
||||
api_logger.info(f"The mcp market config has been successfully updated: (ID: {db_mcp_market_config.id})")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
api_logger.error(f"The mcp market config update failed: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"The mcp market config update failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 5. Return the updated mcp market config
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="The mcp market config information updated successfully")
|
||||
|
||||
|
||||
@router.delete("/{mcp_market_config_id}", response_model=ApiResponse)
|
||||
async def delete_mcp_market_config(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
delete mcp market config
|
||||
"""
|
||||
api_logger.info(f"Request to delete mcp market config: mcp_market_config_id={mcp_market_config_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Check whether the mcp market config exists
|
||||
api_logger.debug(f"Check whether the mcp market config exists: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
# 2. Deleting mcp market config
|
||||
mcp_market_config_service.delete_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
api_logger.info(f"The mcp market config has been successfully deleted: (ID: {mcp_market_config_id})")
|
||||
return success(msg="The mcp market config has been successfully deleted")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to delete from the mcp market config: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
|
||||
raise
|
||||
262
api/app/controllers/mcp_market_controller.py
Normal file
262
api/app/controllers/mcp_market_controller.py
Normal file
@@ -0,0 +1,262 @@
|
||||
import datetime
|
||||
import json
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import mcp_market_model
|
||||
from app.models.user_model import User
|
||||
from app.schemas import mcp_market_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import mcp_market_service
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/mcp_markets",
|
||||
tags=["mcp_markets"],
|
||||
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
|
||||
)
|
||||
|
||||
|
||||
@router.get("/mcp_markets", response_model=ApiResponse)
|
||||
async def get_mcp_markets(
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
orderby: Optional[str] = Query(None, description="Sort fields, such as: category, created_at"),
|
||||
desc: Optional[bool] = Query(False, description="Is it descending order"),
|
||||
keywords: Optional[str] = Query(None, description="Search keywords (mcp_market base name)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Query the mcp markets list in pages
|
||||
- Support keyword search for name,description
|
||||
- Support dynamic sorting
|
||||
- Return paging metadata + mcp_market list
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Query mcp market list: tenant_id={current_user.tenant_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
|
||||
|
||||
# 1. parameter validation
|
||||
if page < 1 or pagesize < 1:
|
||||
api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The paging parameter must be greater than 0"
|
||||
)
|
||||
|
||||
# 2. Construct query conditions
|
||||
filters = []
|
||||
|
||||
# Keyword search (fuzzy matching of mcp market name,description)
|
||||
if keywords:
|
||||
api_logger.debug(f"Add keyword search criteria: {keywords}")
|
||||
filters.append(
|
||||
or_(
|
||||
mcp_market_model.McpMarket.name.ilike(f"%{keywords}%"),
|
||||
mcp_market_model.McpMarket.description.ilike(f"%{keywords}%")
|
||||
)
|
||||
)
|
||||
# 3. Execute paged query
|
||||
try:
|
||||
api_logger.debug("Start executing mcp market paging query")
|
||||
total, items = mcp_market_service.get_mcp_markets_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
orderby=orderby,
|
||||
desc=desc,
|
||||
current_user=current_user
|
||||
)
|
||||
api_logger.info(f"mcp market query successful: total={total}, returned={len(items)} records")
|
||||
except Exception as e:
|
||||
api_logger.error(f"mcp market query failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Query failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 4. Return structured response
|
||||
result = {
|
||||
"items": items,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
return success(data=jsonable_encoder(result), msg="Query of mcp market list successful")
|
||||
|
||||
|
||||
@router.post("/mcp_market", response_model=ApiResponse)
|
||||
async def create_mcp_market(
|
||||
create_data: mcp_market_schema.McpMarketCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
create mcp market
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Request to create a mcp market: name={create_data.name}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"Start creating the mcp market: {create_data.name}")
|
||||
# 1. Check if the mcp market name already exists
|
||||
db_mcp_market_exist = mcp_market_service.get_mcp_market_by_name(db, name=create_data.name, current_user=current_user)
|
||||
if db_mcp_market_exist:
|
||||
api_logger.warning(f"The mcp market name already exists: {create_data.name}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The mcp market name already exists: {create_data.name}"
|
||||
)
|
||||
db_mcp_market = mcp_market_service.create_mcp_market(db=db, mcp_market=create_data, current_user=current_user)
|
||||
api_logger.info(
|
||||
f"The mcp market has been successfully created: {db_mcp_market.name} (ID: {db_mcp_market.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
|
||||
msg="The mcp market has been successfully created")
|
||||
except Exception as e:
|
||||
api_logger.error(f"The creation of the mcp market failed: {create_data.name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/{mcp_market_id}", response_model=ApiResponse)
|
||||
async def get_mcp_market(
|
||||
mcp_market_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Retrieve mcp market information based on mcp_market_id
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Obtain details of the mcp market: mcp_market_id={mcp_market_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query mcp market information from the database
|
||||
api_logger.debug(f"Query mcp market: {mcp_market_id}")
|
||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
if not db_mcp_market:
|
||||
api_logger.warning(f"The mcp market does not exist or access is denied: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market does not exist or access is denied"
|
||||
)
|
||||
|
||||
api_logger.info(f"mcp market query successful: {db_mcp_market.name} (ID: {db_mcp_market.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
|
||||
msg="Successfully obtained mcp market information")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"mcp market query failed: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/{mcp_market_id}", response_model=ApiResponse)
|
||||
async def update_mcp_market(
|
||||
mcp_market_id: uuid.UUID,
|
||||
update_data: mcp_market_schema.McpMarketUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
# 1. Check if the mcp market exists
|
||||
api_logger.debug(f"Query the mcp market to be updated: {mcp_market_id}")
|
||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
|
||||
if not db_mcp_market:
|
||||
api_logger.warning(
|
||||
f"The mcp market does not exist or you do not have permission to access it: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. not updating the name (name already exists)
|
||||
update_dict = update_data.dict(exclude_unset=True)
|
||||
if "name" in update_dict:
|
||||
name = update_dict["name"]
|
||||
if name != db_mcp_market.name:
|
||||
# Check if the mcp market name already exists
|
||||
db_mcp_market_exist = mcp_market_service.get_mcp_market_by_name(db, name=name, current_user=current_user)
|
||||
if db_mcp_market_exist:
|
||||
api_logger.warning(f"The mcp market name already exists: {name}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The mcp market name already exists: {name}"
|
||||
)
|
||||
# 3. Update fields (only update non-null fields)
|
||||
api_logger.debug(f"Start updating the mcp market fields: {mcp_market_id}")
|
||||
updated_fields = []
|
||||
for field, value in update_dict.items():
|
||||
if hasattr(db_mcp_market, field):
|
||||
old_value = getattr(db_mcp_market, field)
|
||||
if old_value != value:
|
||||
# update value
|
||||
setattr(db_mcp_market, field, value)
|
||||
updated_fields.append(f"{field}: {old_value} -> {value}")
|
||||
|
||||
if updated_fields:
|
||||
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
||||
|
||||
# 4. Save to database
|
||||
try:
|
||||
db.commit()
|
||||
db.refresh(db_mcp_market)
|
||||
api_logger.info(f"The mcp market has been successfully updated: {db_mcp_market.name} (ID: {db_mcp_market.id})")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
api_logger.error(f"The mcp market update failed: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"The mcp market update failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 5. Return the updated mcp market
|
||||
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
|
||||
msg="The mcp market information updated successfully")
|
||||
|
||||
|
||||
@router.delete("/{mcp_market_id}", response_model=ApiResponse)
|
||||
async def delete_mcp_market(
|
||||
mcp_market_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
delete mcp market
|
||||
"""
|
||||
api_logger.info(f"Request to delete mcp market: mcp_market_id={mcp_market_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Check whether the mcp market exists
|
||||
api_logger.debug(f"Check whether the mcp market exists: {mcp_market_id}")
|
||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
|
||||
if not db_mcp_market:
|
||||
api_logger.warning(
|
||||
f"The mcp market does not exist or you do not have permission to access it: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. Deleting mcp market
|
||||
mcp_market_service.delete_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
api_logger.info(f"The mcp market has been successfully deleted: (ID: {mcp_market_id})")
|
||||
return success(msg="The mcp market has been successfully deleted")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to delete from the mcp market: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
@@ -1,26 +1,32 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, Header
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from app.cache.memory.interest_memory import InterestMemoryCache
|
||||
from app.celery_app import celery_app
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.enums import SearchStrategy, Neo4jNodeType
|
||||
from app.core.memory.memory_service import MemoryService
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import cur_workspace_access_guard, get_current_user
|
||||
from app.models import ModelApiKey
|
||||
from app.models.user_model import User
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.repositories import knowledge_repository, WorkspaceRepository
|
||||
from app.repositories import knowledge_repository
|
||||
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import task_service, workspace_service
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.memory_agent_service import get_end_user_connected_config as get_config
|
||||
from app.services.model_service import ModelConfigService
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile,Header
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
load_dotenv()
|
||||
api_logger = get_api_logger()
|
||||
@@ -35,7 +41,7 @@ router = APIRouter(
|
||||
|
||||
@router.get("/health/status", response_model=ApiResponse)
|
||||
async def get_health_status(
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get latest health status written by Celery periodic task
|
||||
@@ -53,8 +59,9 @@ async def get_health_status(
|
||||
|
||||
@router.get("/download_log")
|
||||
async def download_log(
|
||||
log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
log_type: str = Query("file", regex="^(file|transmission)$",
|
||||
description="日志类型: file=完整文件, transmission=实时流式传输"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Download or stream agent service log file
|
||||
@@ -73,16 +80,16 @@ async def download_log(
|
||||
- transmission mode: StreamingResponse with SSE
|
||||
"""
|
||||
api_logger.info(f"Log download requested with log_type={log_type}")
|
||||
|
||||
|
||||
# Validate log_type parameter (FastAPI Query regex already validates, but explicit check for clarity)
|
||||
if log_type not in ["file", "transmission"]:
|
||||
api_logger.warning(f"Invalid log_type parameter: {log_type}")
|
||||
return fail(
|
||||
BizCode.BAD_REQUEST,
|
||||
"无效的log_type参数",
|
||||
BizCode.BAD_REQUEST,
|
||||
"无效的log_type参数",
|
||||
"log_type必须是'file'或'transmission'"
|
||||
)
|
||||
|
||||
|
||||
# Route to appropriate mode
|
||||
if log_type == "file":
|
||||
# File mode: Return complete log file content
|
||||
@@ -114,136 +121,150 @@ async def download_log(
|
||||
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
|
||||
|
||||
|
||||
@router.post("/writer_service", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def write_server(
|
||||
user_input: Write_UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Write service endpoint - processes write operations synchronously
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and end_user_id
|
||||
|
||||
Returns:
|
||||
Response with write operation status
|
||||
"""
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
if storage_type is None: storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
|
||||
# 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
||||
if storage_type == 'rag':
|
||||
if workspace_id:
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
else:
|
||||
api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
else:
|
||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
|
||||
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
try:
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
result = await memory_agent_service.write_memory(
|
||||
user_input.end_user_id,
|
||||
messages_list,
|
||||
config_id,
|
||||
db,
|
||||
storage_type,
|
||||
user_rag_memory_id
|
||||
)
|
||||
|
||||
return success(data=result, msg="写入成功")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
if hasattr(e, 'exceptions'):
|
||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
detailed_error = "; ".join(error_messages)
|
||||
api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
||||
api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
|
||||
|
||||
@router.post("/writer_service_async", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def write_server_async(
|
||||
user_input: Write_UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Async write service endpoint - enqueues write processing to Celery
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and end_user_id
|
||||
|
||||
Returns:
|
||||
Task ID for tracking async operation
|
||||
Use GET /memory/write_result/{task_id} to check task status and get result
|
||||
"""
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
if storage_type is None: storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
if workspace_id:
|
||||
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
task = celery_app.send_task(
|
||||
"app.core.memory.agent.write_message",
|
||||
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id]
|
||||
)
|
||||
api_logger.info(f"Write task queued: {task.id}")
|
||||
|
||||
return success(data={"task_id": task.id}, msg="写入任务已提交")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Async write operation failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
# @router.post("/writer_service", response_model=ApiResponse)
|
||||
# @cur_workspace_access_guard()
|
||||
# async def write_server(
|
||||
# user_input: Write_UserInput,
|
||||
# language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
# db: Session = Depends(get_db),
|
||||
# current_user: User = Depends(get_current_user)
|
||||
# ):
|
||||
# """
|
||||
# Write service endpoint - processes write operations synchronously
|
||||
#
|
||||
# Args:
|
||||
# user_input: Write request containing message and end_user_id
|
||||
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
#
|
||||
# Returns:
|
||||
# Response with write operation status
|
||||
# """
|
||||
# # 使用集中化的语言校验
|
||||
# language = get_language_from_header(language_type)
|
||||
#
|
||||
# config_id = user_input.config_id
|
||||
# workspace_id = current_user.current_workspace_id
|
||||
# api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
#
|
||||
# # 获取 storage_type,如果为 None 则使用默认值
|
||||
# storage_type = workspace_service.get_workspace_storage_type(
|
||||
# db=db,
|
||||
# workspace_id=workspace_id,
|
||||
# user=current_user
|
||||
# )
|
||||
# if storage_type is None: storage_type = 'neo4j'
|
||||
# user_rag_memory_id = ''
|
||||
#
|
||||
# # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
||||
# if storage_type == 'rag':
|
||||
# if workspace_id:
|
||||
# knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
# db=db,
|
||||
# name="USER_RAG_MERORY",
|
||||
# workspace_id=workspace_id
|
||||
# )
|
||||
# if knowledge:
|
||||
# user_rag_memory_id = str(knowledge.id)
|
||||
# else:
|
||||
# api_logger.warning(
|
||||
# f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
# storage_type = 'neo4j'
|
||||
# else:
|
||||
# api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
# storage_type = 'neo4j'
|
||||
#
|
||||
# api_logger.info(
|
||||
# f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
# try:
|
||||
# messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
# result = await memory_agent_service.write_memory(
|
||||
# user_input.end_user_id,
|
||||
# messages_list,
|
||||
# config_id,
|
||||
# db,
|
||||
# storage_type,
|
||||
# user_rag_memory_id,
|
||||
# language
|
||||
# )
|
||||
#
|
||||
# return success(data=result, msg="写入成功")
|
||||
# except BaseException as e:
|
||||
# # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
# if hasattr(e, 'exceptions'):
|
||||
# error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
# detailed_error = "; ".join(error_messages)
|
||||
# api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
||||
# return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
||||
# api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
||||
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
#
|
||||
#
|
||||
# @router.post("/writer_service_async", response_model=ApiResponse)
|
||||
# @cur_workspace_access_guard()
|
||||
# async def write_server_async(
|
||||
# user_input: Write_UserInput,
|
||||
# language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
# db: Session = Depends(get_db),
|
||||
# current_user: User = Depends(get_current_user)
|
||||
# ):
|
||||
# """
|
||||
# Async write service endpoint - enqueues write processing to Celery
|
||||
#
|
||||
# Args:
|
||||
# user_input: Write request containing message and end_user_id
|
||||
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
#
|
||||
# Returns:
|
||||
# Task ID for tracking async operation
|
||||
# Use GET /memory/write_result/{task_id} to check task status and get result
|
||||
# """
|
||||
# # 使用集中化的语言校验
|
||||
# language = get_language_from_header(language_type)
|
||||
#
|
||||
# config_id = user_input.config_id
|
||||
# workspace_id = current_user.current_workspace_id
|
||||
# api_logger.info(
|
||||
# f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
#
|
||||
# # 获取 storage_type,如果为 None 则使用默认值
|
||||
# storage_type = workspace_service.get_workspace_storage_type(
|
||||
# db=db,
|
||||
# workspace_id=workspace_id,
|
||||
# user=current_user
|
||||
# )
|
||||
# if storage_type is None: storage_type = 'neo4j'
|
||||
# user_rag_memory_id = ''
|
||||
# if workspace_id:
|
||||
#
|
||||
# knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
# db=db,
|
||||
# name="USER_RAG_MERORY",
|
||||
# workspace_id=workspace_id
|
||||
# )
|
||||
# if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
# api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
# try:
|
||||
# # 获取标准化的消息列表
|
||||
# messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
#
|
||||
# task = celery_app.send_task(
|
||||
# "app.core.memory.agent.write_message",
|
||||
# args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
||||
# )
|
||||
# api_logger.info(f"Write task queued: {task.id}")
|
||||
#
|
||||
# return success(data={"task_id": task.id}, msg="写入任务已提交")
|
||||
# except Exception as e:
|
||||
# api_logger.error(f"Async write operation failed: {str(e)}")
|
||||
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
|
||||
|
||||
@router.post("/read_service", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def read_server(
|
||||
user_input: UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
user_input: UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Read service endpoint - processes read operations synchronously
|
||||
@@ -278,35 +299,94 @@ async def read_server(
|
||||
)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
|
||||
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||
try:
|
||||
result = await memory_agent_service.read_memory(
|
||||
user_input.end_user_id,
|
||||
user_input.message,
|
||||
user_input.history,
|
||||
user_input.search_switch,
|
||||
config_id,
|
||||
db,
|
||||
storage_type,
|
||||
user_rag_memory_id
|
||||
)
|
||||
if str(user_input.search_switch) == "2":
|
||||
retrieve_info = result['answer']
|
||||
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
|
||||
query = user_input.message
|
||||
|
||||
# 调用 memory_agent_service 的方法生成最终答案
|
||||
result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
|
||||
api_logger.info(
|
||||
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||
try:
|
||||
# result = await memory_agent_service.read_memory(
|
||||
# user_input.end_user_id,
|
||||
# user_input.message,
|
||||
# user_input.history,
|
||||
# user_input.search_switch,
|
||||
# config_id,
|
||||
# db,
|
||||
# storage_type,
|
||||
# user_rag_memory_id
|
||||
# )
|
||||
# if str(user_input.search_switch) == "2":
|
||||
# retrieve_info = result['answer']
|
||||
# history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
|
||||
# user_input.end_user_id)
|
||||
# query = user_input.message
|
||||
#
|
||||
# # 调用 memory_agent_service 的方法生成最终答案
|
||||
# result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
|
||||
# end_user_id=user_input.end_user_id,
|
||||
# retrieve_info=retrieve_info,
|
||||
# history=history,
|
||||
# query=query,
|
||||
# config_id=config_id,
|
||||
# db=db
|
||||
# )
|
||||
# if "信息不足,无法回答" in result['answer']:
|
||||
# result['answer'] = retrieve_info
|
||||
memory_config = get_config(user_input.end_user_id, db)
|
||||
service = MemoryService(
|
||||
db,
|
||||
memory_config["memory_config_id"],
|
||||
end_user_id=user_input.end_user_id
|
||||
)
|
||||
search_result = await service.read(
|
||||
user_input.message,
|
||||
SearchStrategy(user_input.search_switch)
|
||||
)
|
||||
intermediate_outputs = []
|
||||
sub_queries = set()
|
||||
for memory in search_result.memories:
|
||||
sub_queries.add(str(memory.query))
|
||||
if user_input.search_switch in [SearchStrategy.DEEP, SearchStrategy.NORMAL]:
|
||||
intermediate_outputs.append({
|
||||
"type": "problem_split",
|
||||
"title": "问题拆分",
|
||||
"data": [
|
||||
{
|
||||
"id": f"Q{idx+1}",
|
||||
"question": question
|
||||
}
|
||||
for idx, question in enumerate(sub_queries)
|
||||
]
|
||||
})
|
||||
perceptual_data = [
|
||||
memory.data
|
||||
for memory in search_result.memories
|
||||
if memory.source == Neo4jNodeType.PERCEPTUAL
|
||||
]
|
||||
|
||||
intermediate_outputs.append({
|
||||
"type": "perceptual_retrieve",
|
||||
"title": "感知记忆检索",
|
||||
"data": perceptual_data,
|
||||
"total": len(perceptual_data),
|
||||
})
|
||||
intermediate_outputs.append({
|
||||
"type": "search_result",
|
||||
"title": f"合并检索结果 (共{len(sub_queries)}个查询,{len(search_result.memories)}条结果)",
|
||||
"result": search_result.content,
|
||||
"raw_result": search_result.memories,
|
||||
"total": len(search_result.memories),
|
||||
})
|
||||
result = {
|
||||
'answer': await memory_agent_service.generate_summary_from_retrieve(
|
||||
end_user_id=user_input.end_user_id,
|
||||
retrieve_info=retrieve_info,
|
||||
history=history,
|
||||
query=query,
|
||||
retrieve_info=search_result.content,
|
||||
history=[],
|
||||
query=user_input.message,
|
||||
config_id=config_id,
|
||||
db=db
|
||||
)
|
||||
if "信息不足,无法回答" in result['answer']:
|
||||
result['answer']=retrieve_info
|
||||
),
|
||||
"intermediate_outputs": intermediate_outputs
|
||||
}
|
||||
|
||||
return success(data=result, msg="回复对话消息成功")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
@@ -322,9 +402,10 @@ async def read_server(
|
||||
@router.post("/file", response_model=ApiResponse)
|
||||
async def file_update(
|
||||
files: List[UploadFile] = File(..., description="要上传的文件"),
|
||||
model_id:str = Form(..., description="模型ID"),
|
||||
model_id: str = Form(..., description="模型ID"),
|
||||
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
文件上传接口 - 支持图片识别
|
||||
@@ -337,9 +418,6 @@ async def file_update(
|
||||
Returns:
|
||||
文件处理结果
|
||||
"""
|
||||
|
||||
db_gen = get_db() # get_db 通常是一个生成器
|
||||
db = next(db_gen)
|
||||
api_logger.info(f"File upload requested, file count: {len(files)}")
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||
apiConfig: ModelApiKey = config.api_keys[0]
|
||||
@@ -348,7 +426,7 @@ async def file_update(
|
||||
for file in files:
|
||||
api_logger.debug(f"Processing file: {file.filename}, content_type: {file.content_type}")
|
||||
content = await file.read()
|
||||
|
||||
|
||||
if file.content_type and file.content_type.startswith("image/"):
|
||||
vision_model = QWenCV(
|
||||
key=apiConfig.api_key,
|
||||
@@ -362,12 +440,12 @@ async def file_update(
|
||||
else:
|
||||
api_logger.warning(f"Unsupported file type: {file.content_type}")
|
||||
file_content.append(f"[不支持的文件类型: {file.content_type}]")
|
||||
|
||||
|
||||
result_text = ';'.join(file_content)
|
||||
api_logger.info(f"File processing completed, result length: {len(result_text)}")
|
||||
|
||||
|
||||
return success(data=result_text, msg="转换文本成功")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"File processing failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "转换文本失败", str(e))
|
||||
@@ -417,8 +495,8 @@ async def read_server_async(
|
||||
|
||||
@router.get("/read_result/", response_model=ApiResponse)
|
||||
async def get_read_task_result(
|
||||
task_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
task_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get the status and result of an async read task
|
||||
@@ -439,7 +517,7 @@ async def get_read_task_result(
|
||||
try:
|
||||
result = task_service.get_task_memory_read_result(task_id)
|
||||
status = result.get("status")
|
||||
|
||||
|
||||
if status == "SUCCESS":
|
||||
# 任务成功完成
|
||||
task_result = result.get("result", {})
|
||||
@@ -457,7 +535,7 @@ async def get_read_task_result(
|
||||
else:
|
||||
# 旧格式:直接返回结果
|
||||
return success(data=task_result, msg="查询任务已完成")
|
||||
|
||||
|
||||
elif status == "FAILURE":
|
||||
# 任务失败
|
||||
error_info = result.get("result", "Unknown error")
|
||||
@@ -466,7 +544,7 @@ async def get_read_task_result(
|
||||
else:
|
||||
error_msg = str(error_info)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询任务失败", error_msg)
|
||||
|
||||
|
||||
elif status in ["PENDING", "STARTED"]:
|
||||
# 任务进行中
|
||||
return success(
|
||||
@@ -486,7 +564,7 @@ async def get_read_task_result(
|
||||
},
|
||||
msg=f"任务状态: {status}"
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Read task status check failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
|
||||
@@ -494,8 +572,8 @@ async def get_read_task_result(
|
||||
|
||||
@router.get("/write_result/", response_model=ApiResponse)
|
||||
async def get_write_task_result(
|
||||
task_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
task_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get the status and result of an async write task
|
||||
@@ -516,7 +594,7 @@ async def get_write_task_result(
|
||||
try:
|
||||
result = task_service.get_task_memory_write_result(task_id)
|
||||
status = result.get("status")
|
||||
|
||||
|
||||
if status == "SUCCESS":
|
||||
# 任务成功完成
|
||||
task_result = result.get("result", {})
|
||||
@@ -534,7 +612,7 @@ async def get_write_task_result(
|
||||
else:
|
||||
# 旧格式:直接返回结果
|
||||
return success(data=task_result, msg="写入任务已完成")
|
||||
|
||||
|
||||
elif status == "FAILURE":
|
||||
# 任务失败
|
||||
error_info = result.get("result", "Unknown error")
|
||||
@@ -543,7 +621,7 @@ async def get_write_task_result(
|
||||
else:
|
||||
error_msg = str(error_info)
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入任务失败", error_msg)
|
||||
|
||||
|
||||
elif status in ["PENDING", "STARTED"]:
|
||||
# 任务进行中
|
||||
return success(
|
||||
@@ -563,7 +641,7 @@ async def get_write_task_result(
|
||||
},
|
||||
msg=f"任务状态: {status}"
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Write task status check failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
|
||||
@@ -571,9 +649,9 @@ async def get_write_task_result(
|
||||
|
||||
@router.post("/status_type", response_model=ApiResponse)
|
||||
async def status_type(
|
||||
user_input: Write_UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
user_input: Write_UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Determine the type of user message (read or write)
|
||||
@@ -616,26 +694,21 @@ async def status_type(
|
||||
|
||||
@router.get("/stats/types", response_model=ApiResponse)
|
||||
async def get_knowledge_type_stats_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder | memory。
|
||||
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
|
||||
会对缺失类型补 0,返回字典形式。
|
||||
可选按状态过滤。
|
||||
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
||||
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤
|
||||
- 如果用户没有当前工作空间或未提供 end_user_id,对应的统计返回 0
|
||||
- 如果用户没有当前工作空间,对应的统计返回 0
|
||||
"""
|
||||
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||
api_logger.info(
|
||||
f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||
try:
|
||||
from app.db import get_db
|
||||
|
||||
# 获取数据库会话
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
# 调用service层函数
|
||||
result = await memory_agent_service.get_knowledge_type_stats(
|
||||
end_user_id=end_user_id,
|
||||
@@ -643,59 +716,70 @@ async def get_knowledge_type_stats_api(
|
||||
current_workspace_id=current_user.current_workspace_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
|
||||
return success(data=result, msg="获取知识库类型统计成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Knowledge type stats failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
|
||||
async def get_hot_memory_tags_by_user_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
limit: int = Query(20, description="返回标签数量限制"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session=Depends(get_db),
|
||||
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
|
||||
async def get_interest_distribution_by_user_api(
|
||||
end_user_id: str = Query(..., description="用户ID(必填)"),
|
||||
limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取指定用户的热门记忆标签
|
||||
获取指定用户的兴趣分布标签
|
||||
|
||||
与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习、创作等),
|
||||
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
|
||||
|
||||
返回格式:
|
||||
[
|
||||
{"name": "标签名", "frequency": 频次},
|
||||
{"name": "兴趣活动名", "frequency": 频次},
|
||||
...
|
||||
]
|
||||
"""
|
||||
|
||||
workspace_id=current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
if workspace_models:
|
||||
model_id = workspace_models.get("llm", None)
|
||||
else:
|
||||
model_id = None
|
||||
|
||||
api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
|
||||
language = get_language_from_header(language_type)
|
||||
api_logger.info(f"Interest distribution by user requested: end_user_id={end_user_id}, language={language}")
|
||||
try:
|
||||
result = await memory_agent_service.get_hot_memory_tags_by_user(
|
||||
# 优先读取缓存
|
||||
cached = await InterestMemoryCache.get_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
language_type=language_type,
|
||||
model_id=model_id,
|
||||
limit=limit
|
||||
language=language,
|
||||
)
|
||||
return success(data=result, msg="获取热门记忆标签成功")
|
||||
if cached is not None:
|
||||
api_logger.info(f"Interest distribution cache hit: end_user_id={end_user_id}")
|
||||
return success(data=cached, msg="获取兴趣分布标签成功")
|
||||
|
||||
# 缓存未命中,调用模型生成
|
||||
result = await memory_agent_service.get_interest_distribution_by_user(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
language=language
|
||||
)
|
||||
|
||||
# 写入缓存,24小时过期
|
||||
await InterestMemoryCache.set_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
language=language,
|
||||
data=result,
|
||||
)
|
||||
|
||||
return success(data=result, msg="获取兴趣分布标签成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Hot memory tags by user failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取热门记忆标签失败", str(e))
|
||||
api_logger.error(f"Interest distribution by user failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取兴趣分布标签失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/user_profile", response_model=ApiResponse)
|
||||
async def get_user_profile_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取用户详情,包含:
|
||||
@@ -733,17 +817,17 @@ async def get_user_profile_api(
|
||||
# ):
|
||||
# """
|
||||
# Get parsed API documentation (Public endpoint - no authentication required)
|
||||
|
||||
|
||||
# Args:
|
||||
# file_path: Optional path to API docs file. If None, uses default path.
|
||||
|
||||
|
||||
# Returns:
|
||||
# Parsed API documentation including title, meta info, and sections
|
||||
# """
|
||||
# api_logger.info(f"API docs requested, file_path: {file_path or 'default'}")
|
||||
# try:
|
||||
# result = await memory_agent_service.get_api_docs(file_path)
|
||||
|
||||
|
||||
# if result.get("success"):
|
||||
# return success(msg=result["msg"], data=result["data"])
|
||||
# else:
|
||||
@@ -759,9 +843,9 @@ async def get_user_profile_api(
|
||||
|
||||
@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
|
||||
async def get_end_user_connected_config(
|
||||
end_user_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
end_user_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取终端用户关联的记忆配置
|
||||
@@ -777,12 +861,9 @@ async def get_end_user_connected_config(
|
||||
Returns:
|
||||
包含 memory_config_id 和相关信息的响应
|
||||
"""
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config as get_config,
|
||||
)
|
||||
|
||||
|
||||
api_logger.info(f"Getting connected config for end_user: {end_user_id}")
|
||||
|
||||
|
||||
try:
|
||||
result = get_config(end_user_id, db)
|
||||
return success(data=result, msg="获取终端用户关联配置成功")
|
||||
@@ -791,4 +872,4 @@ async def get_end_user_connected_config(
|
||||
return fail(BizCode.NOT_FOUND, str(e))
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to get end user connected config: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
from app.core.response_utils import success
|
||||
@@ -9,6 +12,7 @@ from app.schemas.response_schema import ApiResponse
|
||||
|
||||
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
|
||||
from app.services.memory_agent_service import get_end_users_connected_configs_batch
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# 获取API专用日志器
|
||||
@@ -45,64 +49,64 @@ def get_workspace_total_end_users(
|
||||
|
||||
@router.get("/end_users", response_model=ApiResponse)
|
||||
async def get_workspace_end_users(
|
||||
workspace_id: Optional[uuid.UUID] = Query(None, description="工作空间ID(可选,默认当前用户工作空间)"),
|
||||
keyword: Optional[str] = Query(None, description="搜索关键词(同时模糊匹配 other_name 和 id)"),
|
||||
page: int = Query(1, ge=1, description="页码,从1开始"),
|
||||
pagesize: int = Query(10, ge=1, description="每页数量"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取工作空间的宿主列表(高性能优化版本 v2)
|
||||
|
||||
优化策略:
|
||||
1. 批量查询 end_users(一次查询而非循环)
|
||||
2. 并发查询所有用户的记忆数量(Neo4j)
|
||||
3. RAG 模式使用批量查询(一次 SQL)
|
||||
4. 只返回必要字段减少数据传输
|
||||
5. 添加短期缓存减少重复查询
|
||||
6. 并发执行配置查询和记忆数量查询
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"end_user": {"id": "uuid", "other_name": "名称"},
|
||||
"memory_num": {"total": 数量},
|
||||
"memory_config": {"memory_config_id": "id", "memory_config_name": "名称"}
|
||||
}
|
||||
获取工作空间的宿主列表(分页查询,支持模糊搜索)
|
||||
|
||||
返回工作空间下的宿主列表,支持分页查询和模糊搜索。
|
||||
通过 keyword 参数同时模糊匹配 other_name 和 id 字段。
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
||||
keyword: 搜索关键词(可选,同时模糊匹配 other_name 和 id)
|
||||
page: 页码(从1开始,默认1)
|
||||
pagesize: 每页数量(默认10)
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含宿主列表和分页信息
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 尝试从缓存获取(30秒缓存)
|
||||
cache_key = f"end_users:workspace:{workspace_id}"
|
||||
try:
|
||||
cached_data = await aio_redis_get(cache_key)
|
||||
if cached_data:
|
||||
api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}")
|
||||
return success(data=json.loads(cached_data), msg="宿主列表获取成功")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存读取失败: {str(e)}")
|
||||
|
||||
# 如果未提供 workspace_id,使用当前用户的工作空间
|
||||
if workspace_id is None:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 获取当前空间类型
|
||||
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
|
||||
|
||||
# 获取 end_users(已优化为批量查询)
|
||||
end_users = memory_dashboard_service.get_workspace_end_users(
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表, 类型: {current_workspace_type}")
|
||||
|
||||
# 获取分页的 end_users
|
||||
end_users_result = memory_dashboard_service.get_workspace_end_users_paginated(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user
|
||||
current_user=current_user,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
keyword=keyword
|
||||
)
|
||||
|
||||
end_users = end_users_result.get("items", [])
|
||||
total = end_users_result.get("total", 0)
|
||||
|
||||
if not end_users:
|
||||
api_logger.info("工作空间下没有宿主")
|
||||
# 缓存空结果,避免重复查询
|
||||
try:
|
||||
await aio_redis_set(cache_key, json.dumps([]), expire=30)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||
return success(data=[], msg="宿主列表获取成功")
|
||||
|
||||
api_logger.info(f"工作空间下没有宿主或当前页无数据: total={total}, page={page}")
|
||||
return success(data={
|
||||
"items": [],
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"hasnext": (page * pagesize) < total
|
||||
}
|
||||
}, msg="宿主列表获取成功")
|
||||
|
||||
end_user_ids = [str(user.id) for user in end_users]
|
||||
|
||||
|
||||
# 并发执行两个独立的查询任务
|
||||
async def get_memory_configs():
|
||||
"""获取记忆配置(在线程池中执行同步查询)"""
|
||||
@@ -114,7 +118,7 @@ async def get_workspace_end_users(
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
|
||||
return {}
|
||||
|
||||
|
||||
async def get_memory_nums():
|
||||
"""获取记忆数量"""
|
||||
if current_workspace_type == "rag":
|
||||
@@ -128,38 +132,45 @@ async def get_workspace_end_users(
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
|
||||
elif current_workspace_type == "neo4j":
|
||||
# Neo4j 模式:并发查询(带并发限制)
|
||||
# 使用信号量限制并发数,避免大量用户时压垮 Neo4j
|
||||
MAX_CONCURRENT_QUERIES = 10
|
||||
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES)
|
||||
|
||||
async def get_neo4j_memory_num(end_user_id: str):
|
||||
async with semaphore:
|
||||
try:
|
||||
return await memory_storage_service.search_all(end_user_id)
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}")
|
||||
return {"total": 0}
|
||||
|
||||
memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
|
||||
return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
|
||||
|
||||
# Neo4j 模式:批量查询(简化版本,只返回total)
|
||||
try:
|
||||
batch_result = await memory_storage_service.search_all_batch(end_user_ids)
|
||||
return {uid: {"total": count} for uid, count in batch_result.items()}
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取 Neo4j 记忆数量失败: {str(e)}")
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
|
||||
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
|
||||
try:
|
||||
from app.celery_app import celery_app as _celery_app
|
||||
_celery_app.send_task(
|
||||
"app.tasks.init_implicit_emotions_for_users",
|
||||
kwargs={"end_user_ids": end_user_ids},
|
||||
)
|
||||
_celery_app.send_task(
|
||||
"app.tasks.init_interest_distribution_for_users",
|
||||
kwargs={"end_user_ids": end_user_ids},
|
||||
)
|
||||
api_logger.info(f"已触发按需初始化任务,候选用户数: {len(end_user_ids)}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}")
|
||||
|
||||
# 并发执行配置查询和记忆数量查询
|
||||
memory_configs_map, memory_nums_map = await asyncio.gather(
|
||||
get_memory_configs(),
|
||||
get_memory_nums()
|
||||
)
|
||||
|
||||
# 构建结果(优化:使用列表推导式)
|
||||
result = []
|
||||
|
||||
# 构建结果列表
|
||||
items = []
|
||||
for end_user in end_users:
|
||||
user_id = str(end_user.id)
|
||||
config_info = memory_configs_map.get(user_id, {})
|
||||
result.append({
|
||||
items.append({
|
||||
'end_user': {
|
||||
'id': user_id,
|
||||
'other_name': end_user.other_name
|
||||
@@ -170,14 +181,27 @@ async def get_workspace_end_users(
|
||||
"memory_config_name": config_info.get("memory_config_name")
|
||||
}
|
||||
})
|
||||
|
||||
# 写入缓存(30秒过期)
|
||||
|
||||
# 触发社区聚类补全任务(异步,不阻塞接口响应)
|
||||
try:
|
||||
await aio_redis_set(cache_key, json.dumps(result), expire=30)
|
||||
from app.tasks import init_community_clustering_for_users
|
||||
init_community_clustering_for_users.delay(end_user_ids=end_user_ids, workspace_id=str(workspace_id))
|
||||
api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||
|
||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
|
||||
|
||||
# 构建分页响应
|
||||
result = {
|
||||
"items": items,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"hasnext": (page * pagesize) < total
|
||||
}
|
||||
}
|
||||
|
||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录,总计 {total} 条")
|
||||
return success(data=result, msg="宿主列表获取成功")
|
||||
|
||||
|
||||
@@ -386,14 +410,15 @@ def get_current_user_rag_total_num(
|
||||
@router.get("/rag_content", response_model=ApiResponse)
|
||||
def get_rag_content(
|
||||
end_user_id: str = Query(..., description="宿主ID"),
|
||||
limit: int = Query(15, description="返回记录数"),
|
||||
page: int = Query(1, gt=0, description="页码,从1开始"),
|
||||
pagesize: int = Query(15, gt=0, le=100, description="每页返回记录数"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取当前宿主知识库中的chunk内容
|
||||
获取当前宿主知识库中的chunk内容(分页)
|
||||
"""
|
||||
data = memory_dashboard_service.get_rag_content(end_user_id, limit, db, current_user)
|
||||
data = memory_dashboard_service.get_rag_content(end_user_id, page, pagesize, db, current_user)
|
||||
return success(data=data, msg="宿主RAGchunk数据获取成功")
|
||||
|
||||
|
||||
@@ -406,26 +431,18 @@ async def get_chunk_summary_tag(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取chunk总结、提取的标签和人物形象
|
||||
|
||||
读取RAG摘要、标签和人物形象(纯读库,不触发生成)。
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"summary": "chunk内容的总结",
|
||||
"tags": [
|
||||
{"tag": "标签1", "frequency": 5},
|
||||
{"tag": "标签2", "frequency": 3},
|
||||
...
|
||||
],
|
||||
"personas": [
|
||||
"产品设计师",
|
||||
"旅行爱好者",
|
||||
"摄影发烧友",
|
||||
...
|
||||
]
|
||||
"summary": "用户摘要",
|
||||
"tags": [{"tag": "标签1", "frequency": 5}, ...],
|
||||
"personas": ["产品设计师", ...],
|
||||
"generated": true/false // false表示尚未生产,请调用 /generate_rag_profile
|
||||
}
|
||||
"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk摘要、标签和人物形象")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 读取宿主 {end_user_id} 的RAG摘要/标签/人物形象")
|
||||
|
||||
data = await memory_dashboard_service.get_chunk_summary_and_tags(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -433,9 +450,8 @@ async def get_chunk_summary_tag(
|
||||
db=db,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取chunk摘要、{len(data.get('tags', []))} 个标签和 {len(data.get('personas', []))} 个人物形象")
|
||||
return success(data=data, msg="chunk摘要、标签和人物形象获取成功")
|
||||
|
||||
return success(data=data, msg="获取成功")
|
||||
|
||||
|
||||
@router.get("/chunk_insight", response_model=ApiResponse)
|
||||
@@ -446,29 +462,64 @@ async def get_chunk_insight(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取chunk的洞察内容
|
||||
|
||||
读取RAG洞察报告(纯读库,不触发生成)。
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"insight": "对chunk内容的深度洞察分析"
|
||||
"insight": "总体概述",
|
||||
"behavior_pattern": "行为模式",
|
||||
"key_findings": "关键发现",
|
||||
"growth_trajectory": "成长轨迹",
|
||||
"generated": true/false // false表示尚未生产,请调用 /generate_rag_profile
|
||||
}
|
||||
"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk洞察")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 读取宿主 {end_user_id} 的RAG洞察")
|
||||
|
||||
data = await memory_dashboard_service.get_chunk_insight(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
db=db,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info("成功获取chunk洞察")
|
||||
return success(data=data, msg="chunk洞察获取成功")
|
||||
|
||||
return success(data=data, msg="获取成功")
|
||||
|
||||
|
||||
class GenerateRagProfileRequest(BaseModel):
|
||||
end_user_id: str = Field(..., description="宿主ID")
|
||||
limit: int = Field(15, description="参与生成的chunk数量上限")
|
||||
max_tags: int = Field(10, description="最大标签数量")
|
||||
|
||||
|
||||
@router.post("/generate_rag_profile", response_model=ApiResponse)
|
||||
async def generate_rag_profile(
|
||||
body: GenerateRagProfileRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
生产接口:为RAG存储模式的宿主全量重新生成完整画像并持久化到end_user表。
|
||||
每次请求都会重新生成,覆盖已有数据。
|
||||
"""
|
||||
api_logger.info(f"用户 {current_user.username} 触发RAG画像生产: end_user_id={body.end_user_id}")
|
||||
|
||||
data = await memory_dashboard_service.generate_rag_profile(
|
||||
end_user_id=body.end_user_id,
|
||||
limit=body.limit,
|
||||
max_tags=body.max_tags,
|
||||
db=db,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
api_logger.info(f"RAG画像生产完成: {data}")
|
||||
return success(data=data, msg="RAG画像生产完成")
|
||||
|
||||
|
||||
@router.get("/dashboard_data", response_model=ApiResponse)
|
||||
async def dashboard_data(
|
||||
end_user_id: Optional[str] = Query(None, description="可选的用户ID"),
|
||||
start_date: Optional[int] = Query(None, description="开始时间戳(毫秒)"),
|
||||
end_date: Optional[int] = Query(None, description="结束时间戳(毫秒)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
@@ -503,6 +554,15 @@ async def dashboard_data(
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的dashboard整合数据")
|
||||
|
||||
# 如果没有提供时间范围,默认使用最近30天
|
||||
if start_date is None or end_date is None:
|
||||
from datetime import datetime, timedelta
|
||||
end_dt = datetime.now()
|
||||
start_dt = end_dt - timedelta(days=30)
|
||||
end_date = int(end_dt.timestamp() * 1000)
|
||||
start_date = int(start_dt.timestamp() * 1000)
|
||||
api_logger.info(f"使用默认时间范围: {start_dt} 到 {end_dt}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
@@ -531,7 +591,7 @@ async def dashboard_data(
|
||||
"total_api_call": None
|
||||
}
|
||||
|
||||
# 1. 获取记忆总量(total_memory)
|
||||
# 1. 获取记忆总量(total_memory)—— neo4j 独有逻辑:查询 neo4j 存储节点
|
||||
try:
|
||||
total_memory_data = await memory_dashboard_service.get_workspace_total_memory_count(
|
||||
db=db,
|
||||
@@ -540,41 +600,33 @@ async def dashboard_data(
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0)
|
||||
# total_app: 统计当前空间下的所有app数量
|
||||
from app.repositories import app_repository
|
||||
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
||||
neo4j_data["total_app"] = len(apps_orm)
|
||||
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}, 应用数量: {neo4j_data['total_app']}")
|
||||
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取记忆总量失败: {str(e)}")
|
||||
|
||||
# 2. 获取知识库类型统计(total_knowledge)
|
||||
try:
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
memory_agent_service = MemoryAgentService()
|
||||
knowledge_stats = await memory_agent_service.get_knowledge_type_stats(
|
||||
end_user_id=end_user_id,
|
||||
only_active=True,
|
||||
current_workspace_id=workspace_id,
|
||||
db=db
|
||||
)
|
||||
neo4j_data["total_knowledge"] = knowledge_stats.get("total", 0)
|
||||
api_logger.info(f"成功获取知识库类型统计total: {neo4j_data['total_knowledge']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取知识库类型统计失败: {str(e)}")
|
||||
# 2. 获取共享统计数据(total_app、total_knowledge、total_api_call)
|
||||
common_stats = memory_dashboard_service.get_dashboard_common_stats(db, workspace_id)
|
||||
neo4j_data.update(common_stats)
|
||||
api_logger.info(f"成功获取共享统计: app={common_stats['total_app']}, knowledge={common_stats['total_knowledge']}, api_call={common_stats['total_api_call']}")
|
||||
|
||||
# 3. 获取API调用增量(total_api_call,转换为整数)
|
||||
# 计算昨日对比
|
||||
try:
|
||||
api_increment = memory_dashboard_service.get_workspace_api_increment(
|
||||
changes = memory_dashboard_service.get_dashboard_yesterday_changes(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user
|
||||
storage_type=storage_type,
|
||||
today_data=neo4j_data
|
||||
)
|
||||
neo4j_data["total_api_call"] = api_increment
|
||||
api_logger.info(f"成功获取API调用增量: {neo4j_data['total_api_call']}")
|
||||
neo4j_data.update(changes)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取API调用增量失败: {str(e)}")
|
||||
|
||||
api_logger.warning(f"计算neo4j昨日对比失败: {str(e)}")
|
||||
neo4j_data.update({
|
||||
"total_memory_change": None,
|
||||
"total_app_change": None,
|
||||
"total_knowledge_change": None,
|
||||
"total_api_call_change": None,
|
||||
})
|
||||
|
||||
result["neo4j_data"] = neo4j_data
|
||||
api_logger.info("成功获取neo4j_data")
|
||||
|
||||
@@ -587,28 +639,37 @@ async def dashboard_data(
|
||||
"total_api_call": None
|
||||
}
|
||||
|
||||
# 获取RAG相关数据
|
||||
# 1. 获取记忆总量(total_memory)—— rag 独有逻辑:查询 document 表的 chunk_num
|
||||
try:
|
||||
# total_memory: 使用 total_chunk(总chunk数)
|
||||
total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
|
||||
total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
|
||||
rag_data["total_memory"] = total_chunk
|
||||
|
||||
# total_app: 统计当前空间下的所有app数量
|
||||
from app.repositories import app_repository
|
||||
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
||||
rag_data["total_app"] = len(apps_orm)
|
||||
|
||||
# total_knowledge: 使用 total_kb(总知识库数)
|
||||
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
|
||||
rag_data["total_knowledge"] = total_kb
|
||||
|
||||
# total_api_call: 固定值
|
||||
rag_data["total_api_call"] = 1024
|
||||
|
||||
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}")
|
||||
api_logger.info(f"成功获取RAG记忆总量: {total_chunk}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")
|
||||
api_logger.warning(f"获取RAG记忆总量失败: {str(e)}")
|
||||
|
||||
# 2. 获取共享统计数据(total_app、total_knowledge、total_api_call)
|
||||
common_stats = memory_dashboard_service.get_dashboard_common_stats(db, workspace_id)
|
||||
rag_data.update(common_stats)
|
||||
api_logger.info(f"成功获取共享统计: app={common_stats['total_app']}, knowledge={common_stats['total_knowledge']}, api_call={common_stats['total_api_call']}")
|
||||
|
||||
# 计算昨日对比
|
||||
try:
|
||||
changes = memory_dashboard_service.get_dashboard_yesterday_changes(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
storage_type=storage_type,
|
||||
today_data=rag_data
|
||||
)
|
||||
rag_data.update(changes)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"计算RAG昨日对比失败: {str(e)}")
|
||||
rag_data.update({
|
||||
"total_memory_change": None,
|
||||
"total_app_change": None,
|
||||
"total_knowledge_change": None,
|
||||
"total_api_call_change": None,
|
||||
})
|
||||
|
||||
result["rag_data"] = rag_data
|
||||
api_logger.info("成功获取rag_data")
|
||||
|
||||
|
||||
@@ -3,9 +3,10 @@
|
||||
包含情景记忆总览和详情查询接口
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Header
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.dependencies import get_current_user
|
||||
@@ -14,6 +15,7 @@ from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_episodic_schema import (
|
||||
EpisodicMemoryOverviewRequest,
|
||||
EpisodicMemoryDetailsRequest,
|
||||
translate_episodic_type,
|
||||
)
|
||||
from app.services.memory_episodic_service import memory_episodic_service
|
||||
|
||||
@@ -84,6 +86,7 @@ async def get_episodic_memory_overview_api(
|
||||
@router.post("/details", response_model=ApiResponse)
|
||||
async def get_episodic_memory_details_api(
|
||||
request: EpisodicMemoryDetailsRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
@@ -111,6 +114,11 @@ async def get_episodic_memory_details_api(
|
||||
summary_id=request.summary_id
|
||||
)
|
||||
|
||||
# 根据语言参数翻译 episodic_type
|
||||
language = get_language_from_header(language_type)
|
||||
if "episodic_type" in result:
|
||||
result["episodic_type"] = translate_episodic_type(result["episodic_type"], language)
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取情景记忆详情: end_user_id={request.end_user_id}, summary_id={request.summary_id}"
|
||||
)
|
||||
|
||||
@@ -4,7 +4,9 @@
|
||||
处理显性记忆相关的API接口,包括情景记忆和语义记忆的查询。
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
@@ -69,6 +71,140 @@ async def get_explicit_memory_overview_api(
|
||||
return fail(BizCode.INTERNAL_ERROR, "显性记忆总览查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/episodics", response_model=ApiResponse)
|
||||
async def get_episodic_memory_list_api(
|
||||
end_user_id: str = Query(..., description="end user ID"),
|
||||
page: int = Query(1, gt=0, description="page number, starting from 1"),
|
||||
pagesize: int = Query(10, gt=0, le=100, description="number of items per page, max 100"),
|
||||
start_date: Optional[int] = Query(None, description="start timestamp (ms)"),
|
||||
end_date: Optional[int] = Query(None, description="end timestamp (ms)"),
|
||||
episodic_type: str = Query("all", description="episodic type :all/conversation/project_work/learning/decision/important_event"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
获取情景记忆分页列表
|
||||
|
||||
返回指定用户的情景记忆列表,支持分页、时间范围筛选和情景类型筛选。
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID(必填)
|
||||
page: 页码(从1开始,默认1)
|
||||
pagesize: 每页数量(默认10,最大100)
|
||||
start_date: 开始时间戳(可选,毫秒),自动扩展到当天 00:00:00
|
||||
end_date: 结束时间戳(可选,毫秒),自动扩展到当天 23:59:59
|
||||
episodic_type: 情景类型筛选(可选,默认all)
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含情景记忆分页列表
|
||||
|
||||
Examples:
|
||||
- 基础分页查询:GET /episodics?end_user_id=xxx&page=1&pagesize=5
|
||||
返回第1页,每页5条数据
|
||||
- 按时间范围筛选:GET /episodics?end_user_id=xxx&page=1&pagesize=5&start_date=1738684800000&end_date=1738771199000
|
||||
返回指定时间范围内的数据
|
||||
- 按情景类型筛选:GET /episodics?end_user_id=xxx&page=1&pagesize=5&episodic_type=important_event
|
||||
返回类型为"重要事件"的数据
|
||||
|
||||
Notes:
|
||||
- start_date 和 end_date 必须同时提供或同时不提供
|
||||
- start_date 不能大于 end_date
|
||||
- episodic_type 可选值:all, conversation, project_work, learning, decision, important_event
|
||||
- total 为该用户情景记忆总数(不受筛选条件影响)
|
||||
- page.total 为筛选后的总条数
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询情景记忆列表但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"情景记忆分页查询: end_user_id={end_user_id}, "
|
||||
f"start_date={start_date}, end_date={end_date}, episodic_type={episodic_type}, "
|
||||
f"page={page}, pagesize={pagesize}, username={current_user.username}"
|
||||
)
|
||||
|
||||
# 1. 参数校验
|
||||
if page < 1 or pagesize < 1:
|
||||
api_logger.warning(f"分页参数错误: page={page}, pagesize={pagesize}")
|
||||
return fail(BizCode.INVALID_PARAMETER, "分页参数必须大于0")
|
||||
|
||||
valid_episodic_types = ["all", "conversation", "project_work", "learning", "decision", "important_event"]
|
||||
if episodic_type not in valid_episodic_types:
|
||||
api_logger.warning(f"无效的情景类型参数: {episodic_type}")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"无效的情景类型参数,可选值:{', '.join(valid_episodic_types)}")
|
||||
|
||||
# 时间戳参数校验
|
||||
if (start_date is not None and end_date is None) or (end_date is not None and start_date is None):
|
||||
return fail(BizCode.INVALID_PARAMETER, "start_date和end_date必须同时提供")
|
||||
|
||||
if start_date is not None and end_date is not None and start_date > end_date:
|
||||
return fail(BizCode.INVALID_PARAMETER, "start_date不能大于end_date")
|
||||
|
||||
# 2. 执行查询
|
||||
try:
|
||||
result = await memory_explicit_service.get_episodic_memory_list(
|
||||
end_user_id=end_user_id,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
episodic_type=episodic_type,
|
||||
)
|
||||
api_logger.info(
|
||||
f"情景记忆分页查询成功: end_user_id={end_user_id}, "
|
||||
f"total={result['total']}, 返回={len(result['items'])}条"
|
||||
)
|
||||
except Exception as e:
|
||||
api_logger.error(f"情景记忆分页查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "情景记忆分页查询失败", str(e))
|
||||
|
||||
# 3. 返回结构化响应
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
@router.get("/semantics", response_model=ApiResponse)
|
||||
async def get_semantic_memory_list_api(
|
||||
end_user_id: str = Query(..., description="终端用户ID"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
获取语义记忆列表
|
||||
|
||||
返回指定用户的全量语义记忆列表。
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID(必填)
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含语义记忆全量列表
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询语义记忆列表但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"语义记忆列表查询: end_user_id={end_user_id}, username={current_user.username}"
|
||||
)
|
||||
|
||||
try:
|
||||
result = await memory_explicit_service.get_semantic_memory_list(
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
api_logger.info(
|
||||
f"语义记忆列表查询成功: end_user_id={end_user_id}, total={len(result)}"
|
||||
)
|
||||
except Exception as e:
|
||||
api_logger.error(f"语义记忆列表查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "语义记忆列表查询失败", str(e))
|
||||
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
|
||||
@router.post("/details", response_model=ApiResponse)
|
||||
async def get_explicit_memory_details_api(
|
||||
request: ExplicitMemoryDetailsRequest,
|
||||
|
||||
@@ -31,10 +31,11 @@ from app.schemas.memory_storage_schema import (
|
||||
ForgettingCurveRequest,
|
||||
ForgettingCurveResponse,
|
||||
ForgettingCurvePoint,
|
||||
PendingNodesResponse,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.memory_forget_service import MemoryForgetService
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -84,7 +85,8 @@ async def trigger_forgetting_cycle(
|
||||
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
config_id = resolve_config_id((config_id), db)
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||
@@ -129,7 +131,7 @@ async def trigger_forgetting_cycle(
|
||||
|
||||
@router.get("/read_config", response_model=ApiResponse)
|
||||
async def read_forgetting_config(
|
||||
config_id: UUID,
|
||||
config_id: UUID|int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
@@ -158,6 +160,7 @@ async def read_forgetting_config(
|
||||
)
|
||||
|
||||
try:
|
||||
config_id=resolve_config_id(config_id, db)
|
||||
# 调用服务层读取配置
|
||||
config = forget_service.read_forgetting_config(db=db, config_id=config_id)
|
||||
|
||||
@@ -195,6 +198,8 @@ async def update_forgetting_config(
|
||||
ApiResponse: 包含更新结果的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
payload.config_id=resolve_config_id((payload.config_id), db)
|
||||
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
@@ -255,12 +260,10 @@ async def get_forgetting_stats(
|
||||
ApiResponse: 包含统计信息的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 如果提供了 end_user_id,通过它获取 config_id
|
||||
config_id = None
|
||||
if end_user_id:
|
||||
@@ -269,6 +272,7 @@ async def get_forgetting_stats(
|
||||
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||
@@ -305,6 +309,100 @@ async def get_forgetting_stats(
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
|
||||
|
||||
|
||||
@router.get("/pending-nodes", response_model=ApiResponse)
|
||||
async def get_pending_nodes(
|
||||
end_user_id: str,
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取待遗忘节点列表(独立分页接口)
|
||||
|
||||
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
|
||||
此接口独立分页,与 /stats 接口分离。
|
||||
|
||||
Args:
|
||||
end_user_id: 组ID(即 end_user_id,必填)
|
||||
page: 页码(从1开始,默认1)
|
||||
pagesize: 每页数量(默认10)
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含待遗忘节点列表和分页信息的响应
|
||||
|
||||
Examples:
|
||||
- 第1页,每页10条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=1&pagesize=10
|
||||
- 第2页,每页20条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=2&pagesize=20
|
||||
|
||||
Notes:
|
||||
- page 从1开始,pagesize 必须大于0
|
||||
- 返回格式:{"items": [...], "page": {"page": 1, "pagesize": 10, "total": 100, "hasnext": true}}
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 验证 end_user_id 必填
|
||||
if not end_user_id:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未提供 end_user_id")
|
||||
return fail(BizCode.INVALID_PARAMETER, "end_user_id 不能为空", "end_user_id is required")
|
||||
|
||||
# 通过 end_user_id 获取关联的 config_id
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||
|
||||
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
|
||||
|
||||
# 验证分页参数
|
||||
if page < 1:
|
||||
return fail(BizCode.INVALID_PARAMETER, "page 必须大于等于1", "page < 1")
|
||||
if pagesize < 1:
|
||||
return fail(BizCode.INVALID_PARAMETER, "pagesize 必须大于等于1", "pagesize < 1")
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取待遗忘节点: "
|
||||
f"end_user_id={end_user_id}, page={page}, pagesize={pagesize}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用服务层获取待遗忘节点列表
|
||||
result = await forget_service.get_pending_nodes(
|
||||
db=db,
|
||||
end_user_id=end_user_id,
|
||||
config_id=config_id,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response_data = PendingNodesResponse(**result)
|
||||
|
||||
return success(data=response_data.model_dump(), msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取待遗忘节点列表失败: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取待遗忘节点列表失败", str(e))
|
||||
|
||||
|
||||
@router.post("/forgetting_curve", response_model=ApiResponse)
|
||||
async def get_forgetting_curve(
|
||||
request: ForgettingCurveRequest,
|
||||
@@ -325,7 +423,7 @@ async def get_forgetting_curve(
|
||||
ApiResponse: 包含遗忘曲线数据的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
request.config_id = resolve_config_id((request.config_id), db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘曲线但未选择工作空间")
|
||||
|
||||
@@ -1,8 +1,25 @@
|
||||
"""
|
||||
Memory Reflection Controller
|
||||
|
||||
This module provides REST API endpoints for managing memory reflection configurations
|
||||
and operations. It handles reflection engine setup, configuration management, and
|
||||
execution of self-reflection processes across memory systems.
|
||||
|
||||
Key Features:
|
||||
- Reflection configuration management (save, retrieve, update)
|
||||
- Workspace-wide reflection execution across multiple applications
|
||||
- Individual configuration-based reflection runs
|
||||
- Multi-language support for reflection outputs
|
||||
- Integration with Neo4j memory storage and LLM models
|
||||
- Comprehensive error handling and logging
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
|
||||
ReflectionConfig,
|
||||
@@ -25,9 +42,15 @@ from fastapi import APIRouter, Depends, HTTPException, status,Header
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# Load environment variables for configuration
|
||||
load_dotenv()
|
||||
|
||||
# Initialize API logger for request tracking and debugging
|
||||
api_logger = get_api_logger()
|
||||
|
||||
# Configure router with prefix and tags for API organization
|
||||
router = APIRouter(
|
||||
prefix="/memory",
|
||||
tags=["Memory"],
|
||||
@@ -40,17 +63,49 @@ async def save_reflection_config(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Save reflection configuration to data_comfig table"""
|
||||
"""
|
||||
Save reflection configuration to memory config table
|
||||
|
||||
Persists reflection engine configuration settings to the data_config table,
|
||||
including reflection parameters, model settings, and evaluation criteria.
|
||||
Validates configuration parameters and ensures data consistency.
|
||||
|
||||
Args:
|
||||
request: Memory reflection configuration data including:
|
||||
- config_id: Configuration identifier to update
|
||||
- reflection_enabled: Whether reflection is enabled
|
||||
- reflection_period_in_hours: Reflection execution interval
|
||||
- reflexion_range: Scope of reflection (partial/all)
|
||||
- baseline: Reflection strategy (time/fact/hybrid)
|
||||
- reflection_model_id: LLM model for reflection operations
|
||||
- memory_verify: Enable memory verification checks
|
||||
- quality_assessment: Enable quality assessment evaluation
|
||||
current_user: Authenticated user saving the configuration
|
||||
db: Database session for data operations
|
||||
|
||||
Returns:
|
||||
dict: Success response with saved reflection configuration data
|
||||
|
||||
Raises:
|
||||
HTTPException 400: If config_id is missing or parameters are invalid
|
||||
HTTPException 500: If configuration save operation fails
|
||||
|
||||
Database Operations:
|
||||
- Updates memory_config table with reflection settings
|
||||
- Commits transaction and refreshes entity
|
||||
- Maintains configuration consistency
|
||||
"""
|
||||
try:
|
||||
config_id = request.config_id
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
if not config_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="缺少必需参数: config_id"
|
||||
)
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||
|
||||
# Update reflection configuration in database
|
||||
memory_config = MemoryConfigRepository.update_reflection_config(
|
||||
db,
|
||||
config_id=config_id,
|
||||
@@ -63,6 +118,7 @@ async def save_reflection_config(
|
||||
quality_assessment=request.quality_assessment
|
||||
)
|
||||
|
||||
# Commit transaction and refresh entity
|
||||
db.commit()
|
||||
db.refresh(memory_config)
|
||||
|
||||
@@ -99,51 +155,114 @@ async def start_workspace_reflection(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Activate the reflection function for all matching applications in the workspace"""
|
||||
"""
|
||||
Start reflection functionality for all matching applications in workspace
|
||||
|
||||
Initiates reflection processes across all applications within the user's current
|
||||
workspace that have valid memory configurations. Processes each application's
|
||||
configurations and associated end users, executing reflection operations
|
||||
with proper error isolation and transaction management.
|
||||
|
||||
This endpoint serves as a workspace-wide reflection orchestrator, ensuring
|
||||
that reflection failures for individual users don't affect other operations.
|
||||
|
||||
Args:
|
||||
current_user: Authenticated user initiating workspace reflection
|
||||
db: Database session for configuration queries
|
||||
|
||||
Returns:
|
||||
dict: Success response with reflection results for all processed applications:
|
||||
- app_id: Application identifier
|
||||
- config_id: Memory configuration identifier
|
||||
- end_user_id: End user identifier
|
||||
- reflection_result: Individual reflection operation result
|
||||
|
||||
Processing Logic:
|
||||
1. Retrieve all applications in the current workspace
|
||||
2. Filter applications with valid memory configurations
|
||||
3. For each configuration, find matching releases
|
||||
4. Execute reflection for each end user with isolated transactions
|
||||
5. Aggregate results with error handling per user
|
||||
|
||||
Error Handling:
|
||||
- Individual user reflection failures are isolated
|
||||
- Failed operations are logged and included in results
|
||||
- Database transactions are isolated per user to prevent cascading failures
|
||||
- Comprehensive error reporting for debugging
|
||||
|
||||
Raises:
|
||||
HTTPException 500: If workspace reflection initialization fails
|
||||
|
||||
Performance Notes:
|
||||
- Uses independent database sessions for each user operation
|
||||
- Prevents transaction failures from affecting other users
|
||||
- Comprehensive logging for operation tracking
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
reflection_service = MemoryReflectionService(db)
|
||||
|
||||
try:
|
||||
api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}")
|
||||
|
||||
service = WorkspaceAppService(db)
|
||||
result = service.get_workspace_apps_detailed(workspace_id)
|
||||
# Use independent database session to get workspace app details, avoiding transaction failures
|
||||
from app.db import get_db_context
|
||||
with get_db_context() as query_db:
|
||||
service = WorkspaceAppService(query_db)
|
||||
result = service.get_workspace_apps_detailed(workspace_id)
|
||||
|
||||
reflection_results = []
|
||||
|
||||
# Process each application in the workspace
|
||||
for data in result['apps_detailed_info']:
|
||||
if data['memory_configs'] == []:
|
||||
# Skip applications without configurations
|
||||
if not data['memory_configs']:
|
||||
api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过")
|
||||
continue
|
||||
|
||||
|
||||
releases = data['releases']
|
||||
memory_configs = data['memory_configs']
|
||||
end_users = data['end_users']
|
||||
|
||||
for base, config, user in zip(releases, memory_configs, end_users):
|
||||
# 安全地转换为整数,处理空字符串和None的情况
|
||||
print(base['config'])
|
||||
try:
|
||||
base_config = int(base['config']) if base['config'] else 0
|
||||
config_id = int(config['config_id']) if config['config_id'] else 0
|
||||
except (ValueError, TypeError):
|
||||
api_logger.warning(f"无效的配置ID: base['config']={base.get('config')}, config['config_id']={config.get('config_id')}")
|
||||
|
||||
# Execute reflection for each configuration and user combination
|
||||
for config in memory_configs:
|
||||
config_id_str = str(config['config_id'])
|
||||
|
||||
# Find all releases matching this configuration
|
||||
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
|
||||
|
||||
if not matching_releases:
|
||||
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
|
||||
continue
|
||||
|
||||
if base_config == config_id and base['app_id'] == user['app_id']:
|
||||
# 调用反思服务
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
||||
|
||||
reflection_result = await reflection_service.start_text_reflection(
|
||||
config_data=config,
|
||||
end_user_id=user['id']
|
||||
)
|
||||
|
||||
reflection_results.append({
|
||||
"app_id": base['app_id'],
|
||||
"config_id": config['config_id'],
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": reflection_result
|
||||
})
|
||||
|
||||
# Execute reflection for each user - using independent database sessions
|
||||
for user in end_users:
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}")
|
||||
|
||||
# Create independent database session for each user to avoid transaction failure impact
|
||||
with get_db_context() as user_db:
|
||||
try:
|
||||
reflection_service = MemoryReflectionService(user_db)
|
||||
reflection_result = await reflection_service.start_text_reflection(
|
||||
config_data=config,
|
||||
end_user_id=user['id']
|
||||
)
|
||||
|
||||
reflection_results.append({
|
||||
"app_id": data['id'],
|
||||
"config_id": config_id_str,
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": reflection_result
|
||||
})
|
||||
except Exception as e:
|
||||
api_logger.error(f"用户 {user['id']} 反思失败: {str(e)}")
|
||||
reflection_results.append({
|
||||
"app_id": data['id'],
|
||||
"config_id": config_id_str,
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": {
|
||||
"status": "错误",
|
||||
"message": f"反思失败: {str(e)}"
|
||||
}
|
||||
})
|
||||
|
||||
return success(data=reflection_results, msg="反思配置成功")
|
||||
|
||||
@@ -157,17 +276,57 @@ async def start_workspace_reflection(
|
||||
|
||||
@router.get("/reflection/configs")
|
||||
async def start_reflection_configs(
|
||||
config_id: uuid.UUID,
|
||||
config_id: uuid.UUID|int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""通过config_id查询memory_config表中的反思配置信息"""
|
||||
"""
|
||||
Query reflection configuration information by config_id
|
||||
|
||||
Retrieves detailed reflection configuration settings from the memory_config
|
||||
table for a specific configuration ID. Provides comprehensive reflection
|
||||
parameters including model settings, evaluation criteria, and operational flags.
|
||||
|
||||
Args:
|
||||
config_id: Configuration identifier (UUID or integer) to query
|
||||
current_user: Authenticated user making the request
|
||||
db: Database session for data operations
|
||||
|
||||
Returns:
|
||||
dict: Success response with detailed reflection configuration:
|
||||
- config_id: Resolved configuration identifier
|
||||
- reflection_enabled: Whether reflection is enabled for this config
|
||||
- reflection_period_in_hours: Reflection execution interval
|
||||
- reflexion_range: Scope of reflection operations (partial/all)
|
||||
- baseline: Reflection strategy (time/fact/hybrid)
|
||||
- reflection_model_id: LLM model identifier for reflection
|
||||
- memory_verify: Memory verification flag
|
||||
- quality_assessment: Quality assessment flag
|
||||
|
||||
Database Operations:
|
||||
- Queries memory_config table by resolved config_id
|
||||
- Retrieves all reflection-related configuration fields
|
||||
- Resolves configuration ID for consistent formatting
|
||||
|
||||
Raises:
|
||||
HTTPException 404: If configuration with specified ID is not found
|
||||
HTTPException 500: If configuration query operation fails
|
||||
|
||||
ID Resolution:
|
||||
- Supports both UUID and integer config_id formats
|
||||
- Automatically resolves to appropriate internal format
|
||||
- Maintains consistency across different ID representations
|
||||
"""
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
try:
|
||||
config_id=resolve_config_id(config_id,db)
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
# 构建返回数据
|
||||
memory_config_id = resolve_config_id(result.config_id, db)
|
||||
|
||||
# Build response data with comprehensive configuration details
|
||||
reflection_config = {
|
||||
"config_id": result.config_id,
|
||||
"config_id": memory_config_id,
|
||||
"reflection_enabled": result.enable_self_reflexion,
|
||||
"reflection_period_in_hours": result.iteration_period,
|
||||
"reflexion_range": result.reflexion_range,
|
||||
@@ -178,10 +337,12 @@ async def start_reflection_configs(
|
||||
}
|
||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||
return success(data=reflection_config, msg="反思配置查询成功")
|
||||
|
||||
|
||||
api_logger.info(f"Successfully queried reflection config, config_id: {config_id}")
|
||||
return success(data=reflection_config, msg="Reflection configuration query successful")
|
||||
|
||||
except HTTPException:
|
||||
# 重新抛出HTTP异常
|
||||
# Re-raise HTTP exceptions without modification
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"查询反思配置失败: {str(e)}")
|
||||
@@ -192,16 +353,71 @@ async def start_reflection_configs(
|
||||
|
||||
@router.get("/reflection/run")
|
||||
async def reflection_run(
|
||||
config_id: UUID,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
config_id: UUID|int,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Activate the reflection function for all matching applications in the workspace"""
|
||||
"""
|
||||
Execute reflection engine with specified configuration
|
||||
|
||||
Runs the reflection engine using configuration parameters from the database.
|
||||
Validates model availability, sets up the reflection engine with proper
|
||||
configuration, and executes the reflection process with multi-language support.
|
||||
|
||||
This endpoint provides a test run capability for reflection configurations,
|
||||
allowing users to validate their reflection settings and see results before
|
||||
deploying to production environments.
|
||||
|
||||
Args:
|
||||
config_id: Configuration identifier (UUID or integer) for reflection settings
|
||||
language_type: Language preference header for output localization (optional)
|
||||
current_user: Authenticated user executing the reflection
|
||||
db: Database session for configuration queries
|
||||
|
||||
Returns:
|
||||
dict: Success response with reflection execution results including:
|
||||
- baseline: Reflection strategy used
|
||||
- source_data: Input data processed
|
||||
- memory_verifies: Memory verification results (if enabled)
|
||||
- quality_assessments: Quality assessment results (if enabled)
|
||||
- reflexion_data: Generated reflection insights and solutions
|
||||
|
||||
Configuration Validation:
|
||||
- Verifies configuration exists in database
|
||||
- Validates LLM model availability
|
||||
- Falls back to default model if specified model is unavailable
|
||||
- Ensures all required parameters are properly set
|
||||
|
||||
Reflection Engine Setup:
|
||||
- Creates ReflectionConfig with database parameters
|
||||
- Initializes Neo4j connector for memory access
|
||||
- Sets up ReflectionEngine with validated model
|
||||
- Configures language preferences for output
|
||||
|
||||
Error Handling:
|
||||
- Model validation with fallback to default
|
||||
- Configuration validation and error reporting
|
||||
- Comprehensive logging for debugging
|
||||
- Graceful handling of missing configurations
|
||||
|
||||
Raises:
|
||||
HTTPException 404: If configuration is not found
|
||||
HTTPException 500: If reflection execution fails
|
||||
|
||||
Performance Notes:
|
||||
- Direct database query for configuration retrieval
|
||||
- Model validation to prevent runtime failures
|
||||
- Efficient reflection engine initialization
|
||||
- Language-aware output processing
|
||||
"""
|
||||
# Use centralized language validation for consistent localization
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
|
||||
# 使用MemoryConfigRepository查询反思配置
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
|
||||
# Query reflection configuration using MemoryConfigRepository
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
@@ -211,7 +427,7 @@ async def reflection_run(
|
||||
|
||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||
|
||||
# 验证模型ID是否存在
|
||||
# Validate model ID existence
|
||||
model_id = result.reflection_model_id
|
||||
if model_id:
|
||||
try:
|
||||
@@ -222,6 +438,7 @@ async def reflection_run(
|
||||
# 可以设置为None,让反思引擎使用默认模型
|
||||
model_id = None
|
||||
|
||||
# Create reflection configuration with database parameters
|
||||
config = ReflectionConfig(
|
||||
enabled=result.enable_self_reflexion,
|
||||
iteration_period=result.iteration_period,
|
||||
@@ -234,11 +451,13 @@ async def reflection_run(
|
||||
model_id=model_id,
|
||||
language_type=language_type
|
||||
)
|
||||
|
||||
# Initialize Neo4j connector and reflection engine
|
||||
connector = Neo4jConnector()
|
||||
engine = ReflectionEngine(
|
||||
config=config,
|
||||
neo4j_connector=connector,
|
||||
llm_client=model_id # 传入验证后的 model_id
|
||||
llm_client=model_id # Pass validated model_id
|
||||
)
|
||||
|
||||
result=await (engine.reflection_run())
|
||||
|
||||
@@ -1,18 +1,40 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status,Header
|
||||
"""
|
||||
Memory Short Term Controller
|
||||
|
||||
This module provides REST API endpoints for managing short-term and long-term memory
|
||||
data retrieval and analysis. It handles memory system statistics, data aggregation,
|
||||
and provides comprehensive memory insights for end users.
|
||||
|
||||
Key Features:
|
||||
- Short-term memory data retrieval and statistics
|
||||
- Long-term memory data aggregation
|
||||
- Entity count integration
|
||||
- Multi-language response support
|
||||
- Memory system analytics and reporting
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
|
||||
from app.services.memory_short_service import LongService, ShortService
|
||||
from app.services.memory_storage_service import search_entity
|
||||
from app.services.memory_short_service import ShortService,LongService
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
|
||||
# Load environment variables for configuration
|
||||
load_dotenv()
|
||||
|
||||
# Initialize API logger for request tracking and debugging
|
||||
api_logger = get_api_logger()
|
||||
|
||||
# Configure router with prefix and tags for API organization
|
||||
router = APIRouter(
|
||||
prefix="/memory/short",
|
||||
tags=["Memory"],
|
||||
@@ -20,25 +42,77 @@ router = APIRouter(
|
||||
@router.get("/short_term")
|
||||
async def short_term_configs(
|
||||
end_user_id: str,
|
||||
language_type:str = Header(default="zh", alias="X-Language-Type"),
|
||||
language_type:str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
# 获取短期记忆数据
|
||||
short_term=ShortService(end_user_id)
|
||||
short_result=short_term.get_short_databasets()
|
||||
short_count=short_term.get_short_count()
|
||||
"""
|
||||
Retrieve comprehensive short-term and long-term memory statistics
|
||||
|
||||
Provides a comprehensive overview of memory system data for a specific end user,
|
||||
including short-term memory entries, long-term memory aggregations, entity counts,
|
||||
and retrieval statistics. Supports multi-language responses based on request headers.
|
||||
|
||||
This endpoint serves as a central dashboard for memory system analytics, combining
|
||||
data from multiple memory subsystems to provide a holistic view of user memory state.
|
||||
|
||||
Args:
|
||||
end_user_id: Unique identifier for the end user whose memory data to retrieve
|
||||
language_type: Language preference header for response localization (optional)
|
||||
current_user: Authenticated user making the request (injected by dependency)
|
||||
db: Database session for data operations (injected by dependency)
|
||||
|
||||
Returns:
|
||||
dict: Success response containing comprehensive memory statistics:
|
||||
- short_term: List of short-term memory entries with detailed data
|
||||
- long_term: List of long-term memory aggregations and summaries
|
||||
- entity: Count of entities associated with the end user
|
||||
- retrieval_number: Total count of short-term memory retrievals
|
||||
- long_term_number: Total count of long-term memory entries
|
||||
|
||||
Response Structure:
|
||||
{
|
||||
"code": 200,
|
||||
"msg": "Short-term memory system data retrieved successfully",
|
||||
"data": {
|
||||
"short_term": [...], # Short-term memory entries
|
||||
"long_term": [...], # Long-term memory data
|
||||
"entity": 42, # Entity count
|
||||
"retrieval_number": 156, # Short-term retrieval count
|
||||
"long_term_number": 23 # Long-term memory count
|
||||
}
|
||||
}
|
||||
|
||||
Raises:
|
||||
HTTPException: If end_user_id is invalid or data retrieval fails
|
||||
|
||||
Performance Notes:
|
||||
- Combines multiple service calls for comprehensive data
|
||||
- Entity search is performed asynchronously for better performance
|
||||
- Response time depends on memory data volume for the specified user
|
||||
"""
|
||||
# Use centralized language validation for consistent localization
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
# Retrieve short-term memory data and statistics
|
||||
short_term = ShortService(end_user_id, db)
|
||||
short_result = short_term.get_short_databasets() # Get short-term memory entries
|
||||
short_count = short_term.get_short_count() # Get short-term retrieval count
|
||||
|
||||
long_term=LongService(end_user_id)
|
||||
long_result=long_term.get_long_databasets()
|
||||
# Retrieve long-term memory data and aggregations
|
||||
long_term = LongService(end_user_id, db)
|
||||
long_result = long_term.get_long_databasets() # Get long-term memory entries
|
||||
|
||||
# Get entity count for the specified end user
|
||||
entity_result = await search_entity(end_user_id)
|
||||
|
||||
# Compile comprehensive memory statistics response
|
||||
result = {
|
||||
'short_term': short_result,
|
||||
'long_term': long_result,
|
||||
'entity': entity_result.get('num', 0),
|
||||
"retrieval_number":short_count,
|
||||
"long_term_number":len(long_result)
|
||||
'short_term': short_result, # Short-term memory entries
|
||||
'long_term': long_result, # Long-term memory data
|
||||
'entity': entity_result.get('num', 0), # Entity count (default to 0 if not found)
|
||||
"retrieval_number": short_count, # Short-term retrieval statistics
|
||||
"long_term_number": len(long_result) # Long-term memory entry count
|
||||
}
|
||||
|
||||
return success(data=result, msg="短期记忆系统数据获取成功")
|
||||
@@ -1,8 +1,12 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
@@ -11,7 +15,6 @@ from app.models.user_model import User
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
ConfigPilotRun,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
@@ -23,7 +26,7 @@ from app.services.memory_storage_service import (
|
||||
analytics_hot_memory_tags,
|
||||
analytics_recent_activity_stats,
|
||||
kb_type_distribution,
|
||||
search_all,
|
||||
search_all_batch,
|
||||
search_chunk,
|
||||
search_detials,
|
||||
search_dialogue,
|
||||
@@ -31,10 +34,13 @@ from app.services.memory_storage_service import (
|
||||
search_entity,
|
||||
search_statement,
|
||||
)
|
||||
from fastapi import APIRouter, Depends
|
||||
from app.core.quota_stub import check_memory_engine_quota
|
||||
from fastapi import APIRouter, Depends, Header
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# Get API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
@@ -49,8 +55,8 @@ router = APIRouter(
|
||||
|
||||
@router.get("/info", response_model=ApiResponse)
|
||||
async def get_storage_info(
|
||||
storage_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
storage_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Example wrapper endpoint - retrieves storage information
|
||||
@@ -70,83 +76,20 @@ async def get_storage_info(
|
||||
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
|
||||
|
||||
|
||||
# --- DB connection dependency ---
|
||||
_CONN: Optional[object] = None
|
||||
|
||||
|
||||
"""PostgreSQL 连接生成与管理(使用 psycopg2)。"""
|
||||
# 这个可以转移,可能是已经有的
|
||||
# PostgreSQL 数据库连接
|
||||
def _make_pgsql_conn() -> Optional[object]: # 创建 PostgreSQL 数据库连接
|
||||
host = os.getenv("DB_HOST")
|
||||
user = os.getenv("DB_USER")
|
||||
password = os.getenv("DB_PASSWORD")
|
||||
database = os.getenv("DB_NAME")
|
||||
port_str = os.getenv("DB_PORT")
|
||||
try:
|
||||
import psycopg2 # type: ignore
|
||||
port = int(port_str) if port_str else 5432
|
||||
conn = psycopg2.connect(
|
||||
host=host or "localhost",
|
||||
port=port,
|
||||
user=user,
|
||||
password=password,
|
||||
dbname=database,
|
||||
)
|
||||
# 设置自动提交,避免显式事务管理
|
||||
conn.autocommit = True
|
||||
# 设置会话时区为中国标准时间(Asia/Shanghai),便于直接以本地时区展示
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute("SET TIME ZONE 'Asia/Shanghai'")
|
||||
cur.close()
|
||||
except Exception:
|
||||
# 时区设置失败不影响连接,仅记录但不抛出
|
||||
pass
|
||||
return conn
|
||||
except Exception as e:
|
||||
try:
|
||||
print(f"[PostgreSQL] 连接失败: {e}")
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
def get_db_conn() -> Optional[object]: # 获取 PostgreSQL 数据库连接
|
||||
global _CONN
|
||||
if _CONN is None:
|
||||
_CONN = _make_pgsql_conn()
|
||||
return _CONN
|
||||
|
||||
|
||||
def reset_db_conn() -> bool: # 重置 PostgreSQL 数据库连接
|
||||
"""Close and recreate the global DB connection."""
|
||||
global _CONN
|
||||
try:
|
||||
if _CONN:
|
||||
try:
|
||||
_CONN.close()
|
||||
except Exception:
|
||||
pass
|
||||
_CONN = _make_pgsql_conn()
|
||||
return _CONN is not None
|
||||
except Exception:
|
||||
_CONN = None
|
||||
return False
|
||||
|
||||
|
||||
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
||||
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
||||
@check_memory_engine_quota
|
||||
def create_config(
|
||||
payload: ConfigParamsCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
payload: ConfigParamsCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求创建配置: {payload.config_name}")
|
||||
try:
|
||||
# 将 workspace_id 注入到 payload 中(保持为 UUID 类型)
|
||||
@@ -154,46 +97,130 @@ def create_config(
|
||||
svc = DataConfigService(db)
|
||||
result = svc.create(payload)
|
||||
return success(data=result, msg="创建成功")
|
||||
except ValueError as e:
|
||||
err_str = str(e)
|
||||
if err_str.startswith("DUPLICATE_CONFIG_NAME:"):
|
||||
config_name = err_str.split(":", 1)[1]
|
||||
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
|
||||
f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
|
||||
f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Create config failed: {err_str}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
|
||||
except Exception as e:
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
if isinstance(e, IntegrityError) and "uq_workspace_config_name" in str(getattr(e, 'orig', '')):
|
||||
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
|
||||
f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
|
||||
f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Create config failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
|
||||
|
||||
|
||||
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
||||
def delete_config(
|
||||
config_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
config_id: UUID | int,
|
||||
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""删除记忆配置(带终端用户保护)
|
||||
|
||||
- 检查是否为默认配置,默认配置不允许删除
|
||||
- 检查是否有终端用户连接到该配置
|
||||
- 如果有连接且 force=False,返回警告
|
||||
- 如果 force=True,清除终端用户引用后删除配置
|
||||
|
||||
Query Parameters:
|
||||
force: 设置为 true 可强制删除(即使有终端用户正在使用)
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: {config_id}")
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: "
|
||||
f"config_id={config_id}, force={force}"
|
||||
)
|
||||
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
result = svc.delete(ConfigParamsDelete(config_id=config_id))
|
||||
return success(data=result, msg="删除成功")
|
||||
# 使用带保护的删除服务
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
config_service = MemoryConfigService(db)
|
||||
result = config_service.delete_config(config_id=config_id, force=force)
|
||||
|
||||
if result["status"] == "error":
|
||||
api_logger.warning(
|
||||
f"记忆配置删除被拒绝: config_id={config_id}, reason={result['message']}"
|
||||
)
|
||||
return fail(
|
||||
code=BizCode.FORBIDDEN,
|
||||
msg=result["message"],
|
||||
data={"config_id": str(config_id), "is_default": result.get("is_default", False)}
|
||||
)
|
||||
|
||||
if result["status"] == "warning":
|
||||
api_logger.warning(
|
||||
f"记忆配置正在使用,无法删除: config_id={config_id}, "
|
||||
f"connected_count={result['connected_count']}"
|
||||
)
|
||||
return fail(
|
||||
code=BizCode.RESOURCE_IN_USE,
|
||||
msg=result["message"],
|
||||
data={
|
||||
"connected_count": result["connected_count"],
|
||||
"force_required": result["force_required"]
|
||||
}
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"记忆配置删除成功: config_id={config_id}, "
|
||||
f"affected_users={result['affected_users']}"
|
||||
)
|
||||
return success(
|
||||
msg=result["message"],
|
||||
data={"affected_users": result["affected_users"]}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Delete config failed: {str(e)}")
|
||||
api_logger.error(f"Delete config failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e))
|
||||
|
||||
|
||||
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
|
||||
def update_config(
|
||||
payload: ConfigUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
payload: ConfigUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
# 校验至少有一个字段需要更新
|
||||
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段",
|
||||
"config_name, config_desc, scene_id 均为空")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
@@ -206,17 +233,17 @@ def update_config(
|
||||
|
||||
@router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选
|
||||
def update_config_extracted(
|
||||
payload: ConfigUpdateExtracted,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
payload: ConfigUpdateExtracted,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新提取配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
@@ -231,19 +258,19 @@ def update_config_extracted(
|
||||
# 遗忘引擎配置接口已迁移到 memory_forget_controller.py
|
||||
# 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config
|
||||
|
||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||
def read_config_extracted(
|
||||
config_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
config_id: UUID | int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取提取配置: {config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
@@ -253,18 +280,19 @@ def read_config_extracted(
|
||||
api_logger.error(f"Read config extracted failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
|
||||
|
||||
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
||||
|
||||
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
||||
def read_all_config(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取所有配置")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
@@ -278,17 +306,23 @@ def read_all_config(
|
||||
|
||||
@router.post("/pilot_run", response_model=None)
|
||||
async def pilot_run(
|
||||
payload: ConfigPilotRun,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
payload: ConfigPilotRun,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> StreamingResponse:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"Pilot run requested: config_id={payload.config_id}, "
|
||||
f"dialogue_text_length={len(payload.dialogue_text)}"
|
||||
f"dialogue_text_length={len(payload.dialogue_text)}, "
|
||||
f"custom_text_length={len(payload.custom_text) if payload.custom_text else 0}"
|
||||
)
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
svc = DataConfigService(db)
|
||||
return StreamingResponse(
|
||||
svc.pilot_run_stream(payload),
|
||||
svc.pilot_run_stream(payload, language=language),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
@@ -297,15 +331,14 @@ async def pilot_run(
|
||||
},
|
||||
)
|
||||
|
||||
"""
|
||||
以下为搜索与分析接口,直接挂载到同一 router,统一响应为 ApiResponse。
|
||||
"""
|
||||
|
||||
# ==================== Search & Analytics ====================
|
||||
|
||||
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
|
||||
async def get_kb_type_distribution(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await kb_type_distribution(end_user_id)
|
||||
@@ -314,12 +347,12 @@ async def get_kb_type_distribution(
|
||||
api_logger.error(f"KB type distribution failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "知识库类型分布查询失败", str(e))
|
||||
|
||||
|
||||
|
||||
@router.get("/search/dialogue", response_model=ApiResponse)
|
||||
async def search_dialogues_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_dialogue(end_user_id)
|
||||
@@ -331,9 +364,9 @@ async def search_dialogues_num(
|
||||
|
||||
@router.get("/search/chunk", response_model=ApiResponse)
|
||||
async def search_chunks_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_chunk(end_user_id)
|
||||
@@ -345,9 +378,9 @@ async def search_chunks_num(
|
||||
|
||||
@router.get("/search/statement", response_model=ApiResponse)
|
||||
async def search_statements_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search statement requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_statement(end_user_id)
|
||||
@@ -359,9 +392,9 @@ async def search_statements_num(
|
||||
|
||||
@router.get("/search/entity", response_model=ApiResponse)
|
||||
async def search_entities_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search entity requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_entity(end_user_id)
|
||||
@@ -373,12 +406,15 @@ async def search_entities_num(
|
||||
|
||||
@router.get("/search", response_model=ApiResponse)
|
||||
async def search_all_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_all(end_user_id)
|
||||
if not end_user_id:
|
||||
return success(data={"total": 0}, msg="查询成功")
|
||||
batch_result = await search_all_batch([end_user_id])
|
||||
result = {"total": batch_result.get(end_user_id, 0)}
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Search all failed: {str(e)}")
|
||||
@@ -387,9 +423,9 @@ async def search_all_num(
|
||||
|
||||
@router.get("/search/detials", response_model=ApiResponse)
|
||||
async def search_entities_detials(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search details requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_detials(end_user_id)
|
||||
@@ -401,9 +437,9 @@ async def search_entities_detials(
|
||||
|
||||
@router.get("/search/edges", response_model=ApiResponse)
|
||||
async def search_entity_edges(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search edges requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_edges(end_user_id)
|
||||
@@ -413,14 +449,12 @@ async def search_entity_edges(
|
||||
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
|
||||
async def get_hot_memory_tags_api(
|
||||
limit: int = 10,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
limit: int = 10,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
获取热门记忆标签(带Redis缓存)
|
||||
|
||||
@@ -431,17 +465,18 @@ async def get_hot_memory_tags_api(
|
||||
- 缓存未命中:~600-800ms(取决于LLM速度)
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
# 构建缓存键
|
||||
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
|
||||
|
||||
|
||||
api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}")
|
||||
|
||||
|
||||
try:
|
||||
# 尝试从Redis缓存获取
|
||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||
import json
|
||||
|
||||
|
||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||
|
||||
cached_result = await aio_redis_get(cache_key)
|
||||
if cached_result:
|
||||
api_logger.info(f"Cache hit for key: {cache_key}")
|
||||
@@ -450,11 +485,11 @@ async def get_hot_memory_tags_api(
|
||||
return success(data=data, msg="查询成功(缓存)")
|
||||
except json.JSONDecodeError:
|
||||
api_logger.warning(f"Failed to parse cached data, will refresh")
|
||||
|
||||
|
||||
# 缓存未命中,执行查询
|
||||
api_logger.info(f"Cache miss for key: {cache_key}, executing query")
|
||||
result = await analytics_hot_memory_tags(db, current_user, limit)
|
||||
|
||||
|
||||
# 写入缓存(过期时间:5分钟)
|
||||
# 注意:result是列表,需要转换为JSON字符串
|
||||
try:
|
||||
@@ -464,9 +499,9 @@ async def get_hot_memory_tags_api(
|
||||
except Exception as cache_error:
|
||||
# 缓存写入失败不影响主流程
|
||||
api_logger.warning(f"Failed to cache result: {str(cache_error)}")
|
||||
|
||||
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Hot memory tags failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e))
|
||||
@@ -474,8 +509,8 @@ async def get_hot_memory_tags_api(
|
||||
|
||||
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
|
||||
async def clear_hot_memory_tags_cache(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
清除热门标签缓存
|
||||
|
||||
@@ -485,12 +520,12 @@ async def clear_hot_memory_tags_cache(
|
||||
- 数据更新后立即生效
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}")
|
||||
|
||||
|
||||
try:
|
||||
from app.aioRedis import aio_redis_delete
|
||||
|
||||
|
||||
# 清除所有limit的缓存(常见的limit值)
|
||||
cleared_count = 0
|
||||
for limit in [5, 10, 15, 20, 30, 50]:
|
||||
@@ -499,12 +534,12 @@ async def clear_hot_memory_tags_cache(
|
||||
if result:
|
||||
cleared_count += 1
|
||||
api_logger.info(f"Cleared cache for key: {cache_key}")
|
||||
|
||||
|
||||
return success(
|
||||
data={"cleared_count": cleared_count},
|
||||
data={"cleared_count": cleared_count},
|
||||
msg=f"成功清除 {cleared_count} 个缓存"
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Clear cache failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e))
|
||||
@@ -512,13 +547,13 @@ async def clear_hot_memory_tags_cache(
|
||||
|
||||
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
||||
async def get_recent_activity_stats_api(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info("Recent activity stats requested")
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
|
||||
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
|
||||
try:
|
||||
result = await analytics_recent_activity_stats()
|
||||
result = await analytics_recent_activity_stats(workspace_id=workspace_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Recent activity stats failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.schemas import conversation_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.conversation_service import ConversationService
|
||||
|
||||
@@ -32,35 +33,47 @@ def get_memory_count(
|
||||
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
|
||||
def get_conversations(
|
||||
end_user_id: uuid.UUID,
|
||||
page: int = 1,
|
||||
pagesize: int = 20,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Retrieve all conversations for the current user in a specific group.
|
||||
Retrieve conversations for the current user in a specific group with pagination.
|
||||
|
||||
Args:
|
||||
end_user_id (UUID): The group identifier.
|
||||
page (int): Page number (1-based). Defaults to 1.
|
||||
pagesize (int): Number of items per page. Defaults to 20.
|
||||
current_user (User, optional): The authenticated user.
|
||||
db (Session, optional): SQLAlchemy session.
|
||||
|
||||
Returns:
|
||||
ApiResponse: Contains a list of conversation IDs.
|
||||
|
||||
Notes:
|
||||
- Initializes the ConversationService with the current DB session.
|
||||
- Returns only conversation IDs for lightweight response.
|
||||
- Logs can be added to trace requests in production.
|
||||
ApiResponse: Contains a paginated list of conversations.
|
||||
"""
|
||||
page = max(1, page)
|
||||
page_size = max(1, min(pagesize, 100)) # Limit page size between 1 and 100
|
||||
conversation_service = ConversationService(db)
|
||||
conversations = conversation_service.get_user_conversations(
|
||||
end_user_id
|
||||
conversations, total = conversation_service.get_user_conversations(
|
||||
end_user_id,
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
return success(data=[
|
||||
{
|
||||
"id": conversation.id,
|
||||
"title": conversation.title
|
||||
} for conversation in conversations
|
||||
], msg="get conversations success")
|
||||
return success(data={
|
||||
"items": [
|
||||
{
|
||||
"id": conversation.id,
|
||||
"title": conversation.title
|
||||
} for conversation in conversations
|
||||
],
|
||||
"total": total,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": page_size,
|
||||
"total": total,
|
||||
"hasnext": (page * page_size) < total
|
||||
},
|
||||
}, msg="get conversations success")
|
||||
|
||||
|
||||
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
|
||||
@@ -90,11 +103,7 @@ def get_messages(
|
||||
conversation_id,
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": message.role,
|
||||
"content": message.content,
|
||||
"created_at": int(message.created_at.timestamp() * 1000),
|
||||
}
|
||||
conversation_schema.Message.model_validate(message)
|
||||
for message in messages_obj
|
||||
]
|
||||
return success(data=messages, msg="get conversation history success")
|
||||
|
||||
@@ -7,7 +7,7 @@ from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
from app.models.models_model import ModelProvider, ModelType, LoadBalanceStrategy
|
||||
from app.models.user_model import User
|
||||
from app.repositories.model_repository import ModelConfigRepository
|
||||
from app.schemas import model_schema
|
||||
@@ -15,6 +15,7 @@ from app.core.response_utils import success
|
||||
from app.schemas.response_schema import ApiResponse, PageData
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService, ModelBaseService
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.quota_stub import check_model_quota, check_model_activation_quota
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -31,12 +32,18 @@ def get_model_types():
|
||||
|
||||
@router.get("/provider", response_model=ApiResponse)
|
||||
def get_model_providers():
|
||||
return success(msg="获取模型提供商成功", data=list(ModelProvider))
|
||||
providers = [p for p in ModelProvider if p != ModelProvider.COMPOSITE]
|
||||
return success(msg="获取模型提供商成功", data=providers)
|
||||
|
||||
@router.get("/strategy", response_model=ApiResponse)
|
||||
def get_model_strategies():
|
||||
return success(msg="获取模型策略成功", data=list(LoadBalanceStrategy))
|
||||
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_model_list(
|
||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
capability: Optional[list[str]] = Query(None, description="能力筛选(支持多个,如 ?capability=chat 或 ?capability=chat, embedding)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||
@@ -69,10 +76,21 @@ def get_model_list(
|
||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||
|
||||
capability_list = []
|
||||
if capability is not None:
|
||||
flat_capability = []
|
||||
for item in capability:
|
||||
split_items = [c.strip() for c in item.split(', ') if c.strip()]
|
||||
flat_capability.extend(split_items)
|
||||
|
||||
unique_flat_capability = list(dict.fromkeys(flat_capability))
|
||||
capability_list = unique_flat_capability
|
||||
|
||||
api_logger.error(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQuery(
|
||||
type=type_list,
|
||||
provider=provider,
|
||||
capability=capability_list,
|
||||
is_active=is_active,
|
||||
is_public=is_public,
|
||||
search=search,
|
||||
@@ -91,7 +109,7 @@ def get_model_list(
|
||||
|
||||
|
||||
@router.get("/new", response_model=ApiResponse)
|
||||
def get_model_list(
|
||||
def get_model_list_new(
|
||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于ModelConfig)"),
|
||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||
@@ -147,7 +165,7 @@ def get_model_plaza_list(
|
||||
type: Optional[ModelType] = Query(None, description="模型类型"),
|
||||
provider: Optional[ModelProvider] = Query(None, description="供应商"),
|
||||
is_official: Optional[bool] = Query(None, description="是否官方模型"),
|
||||
is_deprecated: Optional[bool] = Query(False, description="是否弃用"),
|
||||
is_deprecated: Optional[bool] = Query(None, description="是否弃用"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
@@ -198,6 +216,10 @@ def update_model_base(
|
||||
):
|
||||
"""更新基础模型"""
|
||||
|
||||
# 不允许更改type类型
|
||||
if data.type is not None or data.provider is not None:
|
||||
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||
|
||||
result = ModelBaseService.update_model_base(db=db, model_base_id=model_base_id, data=data)
|
||||
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型更新成功")
|
||||
|
||||
@@ -282,6 +304,7 @@ async def create_model(
|
||||
|
||||
|
||||
@router.post("/composite", response_model=ApiResponse)
|
||||
@check_model_quota
|
||||
async def create_composite_model(
|
||||
model_data: model_schema.CompositeModelCreate,
|
||||
db: Session = Depends(get_db),
|
||||
@@ -308,6 +331,7 @@ async def create_composite_model(
|
||||
|
||||
|
||||
@router.put("/composite/{model_id}", response_model=ApiResponse)
|
||||
@check_model_activation_quota
|
||||
async def update_composite_model(
|
||||
model_id: uuid.UUID,
|
||||
model_data: model_schema.CompositeModelCreate,
|
||||
@@ -318,6 +342,8 @@ async def update_composite_model(
|
||||
api_logger.info(f"更新组合模型请求: model_id={model_id}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
if model_data.type is not None:
|
||||
raise BusinessException("不允许更改模型类型", BizCode.INVALID_PARAMETER)
|
||||
result_orm = await ModelConfigService.update_composite_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"组合模型更新成功: {result_orm.name} (ID: {model_id})")
|
||||
|
||||
@@ -357,6 +383,14 @@ def update_model(
|
||||
更新模型配置
|
||||
"""
|
||||
api_logger.info(f"更新模型配置请求: model_id={model_id}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
if model_data.type is not None or model_data.provider is not None:
|
||||
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||
|
||||
if model_data.is_active:
|
||||
active_keys = ModelApiKeyService.get_api_keys_by_model(db=db, model_config_id=model_id, is_active=model_data.is_active)
|
||||
if not active_keys:
|
||||
raise BusinessException("请先为该模型配置可用的 API Key", BizCode.INVALID_PARAMETER)
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始更新模型配置: model_id={model_id}")
|
||||
@@ -455,13 +489,17 @@ async def create_model_api_key_by_provider(
|
||||
config=api_key_data.config,
|
||||
is_active=api_key_data.is_active,
|
||||
priority=api_key_data.priority,
|
||||
model_config_ids=model_config_ids
|
||||
model_config_ids=model_config_ids,
|
||||
capability=api_key_data.capability,
|
||||
is_omni=api_key_data.is_omni
|
||||
)
|
||||
created_keys = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)
|
||||
created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)
|
||||
|
||||
api_logger.info(f"API Key创建成功: 关联{len(created_keys)}个模型")
|
||||
result_list = [model_schema.ModelApiKey.model_validate(key) for key in created_keys]
|
||||
return success(data=result_list, msg=f"成功为 {len(created_keys)} 个模型创建API Key")
|
||||
# result_list = [model_schema.ModelApiKey.model_validate(key) for key in created_keys]
|
||||
result = "API Key已存在" if len(created_keys) == 0 and len(failed_models) == 0 else \
|
||||
f"成功为 {len(created_keys)} 个模型创建API Key, 失败模型列表{failed_models}"
|
||||
return success(data=result, msg=f"成功为 {len(created_keys)} 个模型创建API Key")
|
||||
except Exception as e:
|
||||
api_logger.error(f"创建API Key失败: {str(e)}")
|
||||
raise
|
||||
|
||||
1188
api/app/controllers/ontology_controller.py
Normal file
1188
api/app/controllers/ontology_controller.py
Normal file
File diff suppressed because it is too large
Load Diff
663
api/app/controllers/ontology_secondary_routes.py
Normal file
663
api/app/controllers/ontology_secondary_routes.py
Normal file
@@ -0,0 +1,663 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""本体场景和类型路由(续)
|
||||
|
||||
由于主Controller文件较大,将剩余路由放在此文件中。
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, Header
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger, get_business_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.ontology_schemas import (
|
||||
SceneResponse,
|
||||
SceneListResponse,
|
||||
PaginationInfo,
|
||||
ClassCreateRequest,
|
||||
ClassUpdateRequest,
|
||||
ClassResponse,
|
||||
ClassListResponse,
|
||||
ClassBatchCreateResponse,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.ontology_service import OntologyService
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||
|
||||
|
||||
api_logger = get_api_logger()
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def _get_dummy_ontology_service(db: Session) -> OntologyService:
|
||||
"""获取OntologyService实例(不需要LLM)
|
||||
|
||||
场景和类型管理不需要LLM,创建一个dummy配置。
|
||||
"""
|
||||
dummy_config = RedBearModelConfig(
|
||||
model_name="dummy",
|
||||
provider="openai",
|
||||
api_key="dummy",
|
||||
base_url="https://api.openai.com/v1"
|
||||
)
|
||||
llm_client = OpenAIClient(model_config=dummy_config)
|
||||
return OntologyService(llm_client=llm_client, db=db)
|
||||
|
||||
|
||||
# 这些函数将被导入到主Controller中
|
||||
|
||||
async def scenes_handler(
|
||||
workspace_id: Optional[str] = None,
|
||||
scene_name: Optional[str] = None,
|
||||
page: Optional[int] = None,
|
||||
pagesize: Optional[int] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取场景列表(支持模糊搜索和全量查询,全量查询支持分页)
|
||||
|
||||
当提供 scene_name 参数时,进行模糊搜索(不分页);
|
||||
当不提供 scene_name 参数时,返回所有场景(支持分页)。
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
||||
scene_name: 场景名称关键词(可选,支持模糊匹配)
|
||||
page: 页码(可选,从1开始,仅在全量查询时有效)
|
||||
pagesize: 每页数量(可选,仅在全量查询时有效)
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
"""
|
||||
operation = "search" if scene_name else "list"
|
||||
api_logger.info(
|
||||
f"Scene {operation} requested by user {current_user.id}, "
|
||||
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, pagesize={pagesize}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 确定工作空间ID
|
||||
if workspace_id:
|
||||
try:
|
||||
ws_uuid = UUID(workspace_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid workspace_id format: {workspace_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的工作空间ID格式")
|
||||
else:
|
||||
ws_uuid = current_user.current_workspace_id
|
||||
if not ws_uuid:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 根据是否提供 scene_name 决定查询方式
|
||||
if scene_name and scene_name.strip():
|
||||
# 验证分页参数(模糊搜索也支持分页)
|
||||
if page is not None and page < 1:
|
||||
api_logger.warning(f"Invalid page number: {page}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||
|
||||
if pagesize is not None and pagesize < 1:
|
||||
api_logger.warning(f"Invalid pagesize: {pagesize}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||
|
||||
# 如果只提供了page或pagesize中的一个,返回错误
|
||||
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||
|
||||
# 模糊搜索场景(支持分页)
|
||||
scenes = service.search_scenes_by_name(scene_name.strip(), ws_uuid)
|
||||
total = len(scenes)
|
||||
|
||||
# 如果提供了分页参数,进行分页处理
|
||||
if page is not None and pagesize is not None:
|
||||
start_idx = (page - 1) * pagesize
|
||||
end_idx = start_idx + pagesize
|
||||
scenes = scenes[start_idx:end_idx]
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for scene in scenes:
|
||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||
type_num = len(scene.classes) if scene.classes else 0
|
||||
|
||||
items.append(SceneResponse(
|
||||
scene_id=scene.scene_id,
|
||||
scene_name=scene.scene_name,
|
||||
scene_description=scene.scene_description,
|
||||
type_num=type_num,
|
||||
entity_type=entity_type,
|
||||
workspace_id=scene.workspace_id,
|
||||
created_at=scene.created_at,
|
||||
updated_at=scene.updated_at,
|
||||
classes_count=type_num,
|
||||
is_system_default=scene.is_system_default
|
||||
))
|
||||
|
||||
# 构建响应(包含分页信息)
|
||||
if page is not None and pagesize is not None:
|
||||
hasnext = (page * pagesize) < total
|
||||
pagination_info = PaginationInfo(
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
total=total,
|
||||
hasnext=hasnext
|
||||
)
|
||||
response = SceneListResponse(items=items, page=pagination_info)
|
||||
else:
|
||||
response = SceneListResponse(items=items)
|
||||
|
||||
api_logger.info(
|
||||
f"Scene search completed: found {len(items)} scenes matching '{scene_name}' "
|
||||
f"in workspace {ws_uuid}, total={total}"
|
||||
)
|
||||
else:
|
||||
# 获取所有场景(支持分页)
|
||||
if page is not None and page < 1:
|
||||
api_logger.warning(f"Invalid page number: {page}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||
|
||||
if pagesize is not None and pagesize < 1:
|
||||
api_logger.warning(f"Invalid pagesize: {pagesize}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||
|
||||
# 如果只提供了page或pagesize中的一个,返回错误
|
||||
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||
|
||||
scenes, total = service.list_scenes(ws_uuid, page, pagesize)
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for scene in scenes:
|
||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||
type_num = len(scene.classes) if scene.classes else 0
|
||||
|
||||
items.append(SceneResponse(
|
||||
scene_id=scene.scene_id,
|
||||
scene_name=scene.scene_name,
|
||||
scene_description=scene.scene_description,
|
||||
type_num=type_num,
|
||||
entity_type=entity_type,
|
||||
workspace_id=scene.workspace_id,
|
||||
created_at=scene.created_at,
|
||||
updated_at=scene.updated_at,
|
||||
classes_count=type_num,
|
||||
is_system_default=scene.is_system_default
|
||||
))
|
||||
|
||||
# 构建响应(包含分页信息)
|
||||
if page is not None and pagesize is not None:
|
||||
hasnext = (page * pagesize) < total
|
||||
pagination_info = PaginationInfo(
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
total=total,
|
||||
hasnext=hasnext
|
||||
)
|
||||
response = SceneListResponse(items=items, page=pagination_info)
|
||||
else:
|
||||
response = SceneListResponse(items=items)
|
||||
|
||||
api_logger.info(f"Scene list retrieved successfully, count={len(items)}, total={total}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in scene {operation}: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in scene {operation}: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in scene {operation}: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
|
||||
# ==================== 本体类型管理接口 ====================
|
||||
|
||||
async def create_class_handler(
|
||||
request: ClassCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
x_language_type: Optional[str] = None
|
||||
):
|
||||
"""创建本体类型(统一使用列表形式,支持单个或批量)"""
|
||||
|
||||
# 根据列表长度判断是单个还是批量
|
||||
count = len(request.classes)
|
||||
mode = "single" if count == 1 else "batch"
|
||||
|
||||
api_logger.info(
|
||||
f"Class creation ({mode}) requested by user {current_user.id}, "
|
||||
f"scene_id={request.scene_id}, count={count}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 准备类型数据
|
||||
classes_data = [
|
||||
{
|
||||
"class_name": item.class_name,
|
||||
"class_description": item.class_description
|
||||
}
|
||||
for item in request.classes
|
||||
]
|
||||
|
||||
if count == 1:
|
||||
# 单个创建 - 先检查重名
|
||||
class_data = classes_data[0]
|
||||
existing = OntologyClassRepository(db).get_by_name(class_data["class_name"], request.scene_id)
|
||||
if existing:
|
||||
raise ValueError(f"DUPLICATE_CLASS_NAME:{class_data['class_name']}")
|
||||
ontology_class = service.create_class(
|
||||
scene_id=request.scene_id,
|
||||
class_name=class_data["class_name"],
|
||||
class_description=class_data["class_description"],
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 构建单个响应
|
||||
response = ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
)
|
||||
|
||||
api_logger.info(f"Class created successfully: {ontology_class.class_id}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="类型创建成功")
|
||||
|
||||
else:
|
||||
# 批量创建
|
||||
created_classes, errors = service.create_classes_batch(
|
||||
scene_id=request.scene_id,
|
||||
classes=classes_data,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 构建批量响应
|
||||
items = []
|
||||
for ontology_class in created_classes:
|
||||
items.append(ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
))
|
||||
|
||||
response = ClassBatchCreateResponse(
|
||||
total=len(classes_data),
|
||||
success_count=len(created_classes),
|
||||
failed_count=len(errors),
|
||||
items=items,
|
||||
errors=errors if errors else None
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"Batch class creation completed: "
|
||||
f"success={len(created_classes)}, failed={len(errors)}"
|
||||
)
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="批量创建完成")
|
||||
|
||||
except ValueError as e:
|
||||
err_str = str(e)
|
||||
if err_str.startswith("DUPLICATE_CLASS_NAME:"):
|
||||
class_name = err_str.split(":", 1)[1]
|
||||
api_logger.warning(f"Duplicate class name '{class_name}' in scene {request.scene_id}")
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from fastapi.responses import JSONResponse
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.warning(f"Validation error in class creation: {err_str}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", err_str)
|
||||
|
||||
except RuntimeError as e:
|
||||
err_str = str(e)
|
||||
if "UniqueViolation" in err_str or "uq_scene_class_name" in err_str:
|
||||
api_logger.warning(f"Duplicate class name in scene {request.scene_id}")
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from fastapi.responses import JSONResponse
|
||||
lang = get_language_from_header(x_language_type)
|
||||
class_name = request.classes[0].class_name if request.classes else ""
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Runtime error in class creation: {err_str}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", err_str)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
|
||||
|
||||
|
||||
async def update_class_handler(
|
||||
class_id: str,
|
||||
request: ClassUpdateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新本体类型"""
|
||||
api_logger.info(
|
||||
f"Class update requested by user {current_user.id}, "
|
||||
f"class_id={class_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证UUID格式
|
||||
try:
|
||||
class_uuid = UUID(class_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid class_id format: {class_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 检查是否为系统默认类型
|
||||
class_repo = OntologyClassRepository(db)
|
||||
ontology_class = class_repo.get_by_id(class_uuid)
|
||||
if ontology_class and ontology_class.is_system_default:
|
||||
business_logger.warning(
|
||||
f"尝试修改系统默认类型: user_id={current_user.id}, "
|
||||
f"class_id={class_id}, class_name={ontology_class.class_name}"
|
||||
)
|
||||
return fail(
|
||||
BizCode.BAD_REQUEST,
|
||||
"系统默认类型不可修改",
|
||||
"该类型为系统预设类型,不允许修改"
|
||||
)
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 更新类型
|
||||
ontology_class = service.update_class(
|
||||
class_id=class_uuid,
|
||||
class_name=request.class_name,
|
||||
class_description=request.class_description,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response = ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
)
|
||||
|
||||
api_logger.info(f"Class updated successfully: {class_id}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="类型更新成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class update: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class update: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class update: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
|
||||
|
||||
|
||||
async def delete_class_handler(
|
||||
class_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除本体类型"""
|
||||
api_logger.info(
|
||||
f"Class deletion requested by user {current_user.id}, "
|
||||
f"class_id={class_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证UUID格式
|
||||
try:
|
||||
class_uuid = UUID(class_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid class_id format: {class_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 检查是否为系统默认类型
|
||||
class_repo = OntologyClassRepository(db)
|
||||
ontology_class = class_repo.get_by_id(class_uuid)
|
||||
if ontology_class and ontology_class.is_system_default:
|
||||
business_logger.warning(
|
||||
f"尝试删除系统默认类型: user_id={current_user.id}, "
|
||||
f"class_id={class_id}, class_name={ontology_class.class_name}"
|
||||
)
|
||||
return fail(
|
||||
BizCode.BAD_REQUEST,
|
||||
"系统默认类型不可删除",
|
||||
"该类型为系统预设类型,不允许删除"
|
||||
)
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 删除类型
|
||||
success_flag = service.delete_class(
|
||||
class_id=class_uuid,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
api_logger.info(f"Class deleted successfully: {class_id}")
|
||||
|
||||
return success(data={"deleted": success_flag}, msg="类型删除成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class deletion: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class deletion: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class deletion: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
|
||||
|
||||
|
||||
async def get_class_handler(
    class_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Fetch a single ontology class by its ID.

    Args:
        class_id: String form of the class UUID.
        db: Database session (injected).
        current_user: Authenticated user (injected); must have a current workspace.

    Returns:
        A success envelope wrapping the serialized ClassResponse, or a fail
        envelope for malformed IDs, missing workspace, unknown class, or
        internal errors.
    """
    api_logger.info(
        f"Get class requested by user {current_user.id}, "
        f"class_id={class_id}"
    )

    try:
        # Validate the UUID format up front so malformed IDs never reach the service.
        try:
            class_uuid = UUID(class_id)
        except ValueError:
            api_logger.warning(f"Invalid class_id format: {class_id}")
            return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")

        # Classes are workspace-scoped; the caller must have a current workspace.
        workspace_id = current_user.current_workspace_id
        if not workspace_id:
            api_logger.warning(f"User {current_user.id} has no current workspace")
            return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")

        # Build the service facade.
        service = _get_dummy_ontology_service(db)

        # Raises ValueError when the class does not exist or is not accessible
        # from this workspace; handled by the ValueError branch below.
        ontology_class = service.get_class_by_id(class_uuid, workspace_id)

        # Shape the entity into the API response schema.
        response = ClassResponse(
            class_id=ontology_class.class_id,
            class_name=ontology_class.class_name,
            class_description=ontology_class.class_description,
            scene_id=ontology_class.scene_id,
            created_at=ontology_class.created_at,
            updated_at=ontology_class.updated_at
        )

        api_logger.info(f"Class retrieved successfully: {class_id}")

        return success(data=response.model_dump(mode='json'), msg="查询成功")

    except ValueError as e:
        # Class not found or no access: report "not found" — the previous
        # "请求参数无效" message contradicted the NOT_FOUND code.
        api_logger.warning(f"Validation error in get class: {str(e)}")
        return fail(BizCode.NOT_FOUND, "类型不存在", str(e))

    except RuntimeError as e:
        api_logger.error(f"Runtime error in get class: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))

    except Exception as e:
        api_logger.error(f"Unexpected error in get class: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
|
||||
async def classes_handler(
    scene_id: str,
    class_name: Optional[str] = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """List ontology classes in a scene, with optional fuzzy name search.

    When ``class_name`` contains a non-blank keyword, a fuzzy search is
    performed; otherwise every class in the scene is returned.

    Args:
        scene_id: Scene UUID as a string (required).
        class_name: Optional keyword, fuzzy-matched against class names.
        db: Database session (injected).
        current_user: Authenticated user (injected); must have a current workspace.
    """
    # Normalize the keyword once so branching and logging stay consistent:
    # a blank or whitespace-only keyword is treated the same as no keyword.
    # (Previously the query branch stripped the keyword but the logging
    # branch did not, so a whitespace-only keyword ran a list query while
    # logging a search.)
    keyword = class_name.strip() if class_name else ""
    operation = "search" if keyword else "list"
    api_logger.info(
        f"Class {operation} requested by user {current_user.id}, "
        f"keyword={class_name}, scene_id={scene_id}"
    )

    try:
        # Validate the UUID format up front.
        try:
            scene_uuid = UUID(scene_id)
        except ValueError:
            api_logger.warning(f"Invalid scene_id format: {scene_id}")
            return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的场景ID格式")

        # Classes are workspace-scoped; the caller must have a current workspace.
        workspace_id = current_user.current_workspace_id
        if not workspace_id:
            api_logger.warning(f"User {current_user.id} has no current workspace")
            return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")

        # Build the service facade.
        service = _get_dummy_ontology_service(db)

        # The scene must exist and be visible from this workspace.
        scene = service.get_scene_by_id(scene_uuid, workspace_id)
        if not scene:
            api_logger.warning(f"Scene not found: {scene_id}")
            return fail(BizCode.NOT_FOUND, "场景不存在", f"未找到ID为 {scene_id} 的场景")

        # Fuzzy search when a keyword was supplied, full listing otherwise.
        if keyword:
            classes = service.search_classes_by_name(keyword, scene_uuid, workspace_id)
        else:
            classes = service.list_classes_by_scene(scene_uuid, workspace_id)

        # Shape each entity into the API response schema.
        items = [
            ClassResponse(
                class_id=ontology_class.class_id,
                class_name=ontology_class.class_name,
                class_description=ontology_class.class_description,
                scene_id=ontology_class.scene_id,
                created_at=ontology_class.created_at,
                updated_at=ontology_class.updated_at
            )
            for ontology_class in classes
        ]

        response = ClassListResponse(
            total=len(items),
            scene_id=scene_uuid,
            scene_name=scene.scene_name,
            scene_description=scene.scene_description,
            is_system_default=scene.is_system_default,
            items=items
        )

        if keyword:
            api_logger.info(
                f"Class search completed: found {len(items)} classes matching '{class_name}' "
                f"in scene {scene_id}"
            )
        else:
            api_logger.info(f"Class list retrieved successfully, count={len(items)}")

        return success(data=response.model_dump(mode='json'), msg="查询成功")

    except ValueError as e:
        api_logger.warning(f"Validation error in class {operation}: {str(e)}")
        return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))

    except RuntimeError as e:
        api_logger.error(f"Runtime error in class {operation}: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))

    except Exception as e:
        api_logger.error(f"Unexpected error in class {operation}: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
@@ -1,5 +1,5 @@
|
||||
import uuid
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -8,9 +8,13 @@ from starlette.responses import StreamingResponse
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.models.prompt_optimizer_model import RoleType
|
||||
from app.schemas.prompt_optimizer_schema import PromptOptMessage, PromptOptModelSet, CreateSessionResponse, \
|
||||
OptimizePromptResponse, SessionHistoryResponse, SessionMessage
|
||||
from app.schemas.prompt_optimizer_schema import (
|
||||
PromptOptMessage,
|
||||
CreateSessionResponse,
|
||||
SessionHistoryResponse,
|
||||
SessionMessage,
|
||||
PromptSaveRequest
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.prompt_optimizer_service import PromptOptimizerService
|
||||
|
||||
@@ -116,13 +120,15 @@ async def get_prompt_opt(
|
||||
session_id=session_id,
|
||||
user_id=current_user.id,
|
||||
current_prompt=data.current_prompt,
|
||||
user_require=data.message
|
||||
user_require=data.message,
|
||||
skill=data.skill
|
||||
):
|
||||
# chunk 是 prompt 的增量内容
|
||||
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
|
||||
yield f"event:message\ndata: {json.dumps(chunk, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
yield f"event:error\ndata: {json.dumps(
|
||||
{"error": str(e)}
|
||||
{"error": str(e)},
|
||||
ensure_ascii=False
|
||||
)}\n\n"
|
||||
yield "event:end\ndata: {}\n\n"
|
||||
|
||||
@@ -135,3 +141,109 @@ async def get_prompt_opt(
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.post(
    "/releases",
    # Fixed copy-pasted summary: this endpoint saves a release, it does not
    # run prompt optimization.
    summary="Save a prompt release",
    response_model=ApiResponse
)
def save_prompt(
    data: PromptSaveRequest,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """
    Save a prompt release for the current tenant.

    Args:
        data (PromptSaveRequest): Request body containing session_id, title, and prompt.
        db (Session): SQLAlchemy database session, injected via dependency.
        current_user: Currently authenticated user object, injected via dependency.

    Returns:
        ApiResponse: Standard API response containing the saved prompt release info:
            - id: UUID of the prompt release
            - session_id: associated session
            - title: prompt title
            - prompt: prompt content
            - created_at: timestamp of creation

    Raises:
        Any database or service exceptions are propagated to the global exception handler.
    """
    service = PromptOptimizerService(db)
    prompt_info = service.save_prompt(
        tenant_id=current_user.tenant_id,
        session_id=data.session_id,
        title=data.title,
        prompt=data.prompt
    )
    return success(data=prompt_info)
|
||||
|
||||
|
||||
@router.delete(
    "/releases/{prompt_id}",
    summary="Delete prompt (soft delete)",
    response_model=ApiResponse
)
def delete_prompt(
    prompt_id: uuid.UUID = Path(..., description="Prompt ID"),
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """
    Soft-delete a single prompt release owned by the caller's tenant.

    Args:
        prompt_id (uuid.UUID): Identifier of the prompt release to delete.
        db (Session): SQLAlchemy session, injected via dependency.
        current_user: Authenticated user, injected via dependency; its
            tenant_id scopes the deletion.

    Returns:
        ApiResponse: Success envelope confirming the deletion.
    """
    # Delegate to the service layer; tenant_id guarantees the caller can
    # only touch releases belonging to their own tenant.
    optimizer = PromptOptimizerService(db)
    optimizer.delete_prompt(tenant_id=current_user.tenant_id, prompt_id=prompt_id)
    return success(msg="Prompt deleted successfully")
|
||||
|
||||
|
||||
@router.get(
    "/releases/list",
    summary="Get paginated list of released prompts with optional filter",
    response_model=ApiResponse
)
def get_release_list(
    page: int = 1,
    page_size: int = 20,
    keyword: str | None = None,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """
    Return one page of released prompts for the current tenant, optionally
    filtered by a keyword matched against prompt titles.

    Args:
        page (int): 1-based page number; values below 1 are clamped to 1.
        page_size (int): Items per page; clamped to the range 1..100.
        keyword (str | None): Optional title filter keyword.
        db (Session): SQLAlchemy session, injected via dependency.
        current_user: Authenticated user, injected via dependency.

    Returns:
        ApiResponse: Paginated list of prompt releases plus pagination metadata.
    """
    # Clamp pagination inputs to sane bounds before they reach the query.
    safe_page = page if page > 1 else 1
    safe_size = page_size if page_size > 1 else 1
    if safe_size > 100:
        safe_size = 100

    optimizer = PromptOptimizerService(db)
    page_data = optimizer.get_release_list(
        tenant_id=current_user.tenant_id,
        page=safe_page,
        page_size=safe_size,
        filter_keyword=keyword
    )
    return success(data=page_data)
|
||||
|
||||
|
||||
|
||||
@@ -2,25 +2,34 @@ import hashlib
|
||||
import json
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.core.quota_manager import check_end_user_quota
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db, get_db_read
|
||||
from app.dependencies import get_share_user_id, ShareTokenData
|
||||
from app.models.app_model import AppType
|
||||
from app.repositories import knowledge_repository
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.repositories.workflow_repository import WorkflowConfigRepository
|
||||
from app.schemas import release_share_schema, conversation_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services import workspace_service
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.services.app_service import AppService
|
||||
from app.services.auth_service import create_access_token
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.services.shared_chat_service import SharedChatService
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, \
|
||||
from app.services.workflow_service import WorkflowService
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.utils.app_config_utils import workflow_config_4_app_release, \
|
||||
agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
|
||||
router = APIRouter(prefix="/public/share", tags=["Public Share"])
|
||||
@@ -206,15 +215,27 @@ def list_conversations(
|
||||
logger.debug(f"share_data:{share_data.user_id}")
|
||||
other_id = share_data.user_id
|
||||
service = SharedChatService(db)
|
||||
share, release = service._get_release_by_share_token(share_data.share_token, password)
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
share, release = service.get_release_by_share_token(share_data.share_token, password)
|
||||
end_user_repo = EndUserRepository(db)
|
||||
app_service = AppService(db)
|
||||
app = app_service._get_app_or_404(share.app_id)
|
||||
workspace_id = app.workspace_id
|
||||
|
||||
# 仅在新建终端用户时检查配额
|
||||
existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id)
|
||||
if existing_end_user is None:
|
||||
from app.core.quota_manager import _check_quota
|
||||
from app.models.workspace_model import Workspace
|
||||
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||
if ws:
|
||||
_check_quota(db, ws.tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
|
||||
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
workspace_id=workspace_id,
|
||||
other_id=other_id
|
||||
)
|
||||
logger.debug(new_end_user.id)
|
||||
service = SharedChatService(db)
|
||||
conversations, total = service.list_conversations(
|
||||
share_token=share_data.share_token,
|
||||
user_id=str(new_end_user.id),
|
||||
@@ -251,8 +272,41 @@ def get_conversation(
|
||||
conv_service = ConversationService(db)
|
||||
messages = conv_service.get_messages(conversation_id)
|
||||
|
||||
# 构建响应
|
||||
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump()
|
||||
file_ids = []
|
||||
message_file_id_map = {}
|
||||
|
||||
# 第一次遍历:解析 audio_url,收集所有有效的 file_id
|
||||
for idx, m in enumerate(messages):
|
||||
if m.role == "assistant" and m.meta_data:
|
||||
audio_url = m.meta_data.get("audio_url")
|
||||
if not audio_url:
|
||||
continue
|
||||
try:
|
||||
file_id = uuid.UUID(audio_url.rstrip("/").split("/")[-1])
|
||||
except (ValueError, IndexError):
|
||||
# audio_url 无法解析为 UUID,标记为 unknown
|
||||
m.meta_data["audio_status"] = "unknown"
|
||||
continue
|
||||
|
||||
file_ids.append(file_id)
|
||||
message_file_id_map[idx] = file_id
|
||||
|
||||
# 批量查询所有相关的 FileMetadata
|
||||
file_status_map = {}
|
||||
if file_ids:
|
||||
file_metas = (
|
||||
db.query(FileMetadata)
|
||||
.filter(FileMetadata.id.in_(set(file_ids)))
|
||||
.all()
|
||||
)
|
||||
file_status_map = {fm.id: fm.status for fm in file_metas}
|
||||
|
||||
# 第二次遍历:将查询结果映射回消息
|
||||
for idx, file_id in message_file_id_map.items():
|
||||
m = messages[idx]
|
||||
m.meta_data["audio_status"] = file_status_map.get(file_id, "unknown")
|
||||
|
||||
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump(mode="json")
|
||||
conv_dict["messages"] = [
|
||||
conversation_schema.Message.model_validate(m) for m in messages
|
||||
]
|
||||
@@ -293,40 +347,61 @@ async def chat(
|
||||
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
# 这样可以确保错误能正确返回,而不是在流式响应中间出错
|
||||
from app.models.app_model import AppType
|
||||
|
||||
try:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.services.app_service import AppService
|
||||
# 验证分享链接和密码
|
||||
share, release = service._get_release_by_share_token(share_token, password)
|
||||
share, release = service.get_release_by_share_token(share_token, password)
|
||||
|
||||
# # Create end_user_id by concatenating app_id with user_id
|
||||
# end_user_id = f"{share.app_id}_{user_id}"
|
||||
|
||||
# Store end_user_id in database with original user_id
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
end_user_repo = EndUserRepository(db)
|
||||
app_service = AppService(db)
|
||||
app = app_service._get_app_or_404(share.app_id)
|
||||
workspace_id = app.workspace_id
|
||||
|
||||
# 仅在新建终端用户时检查配额,已有用户复用不受限制
|
||||
existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id)
|
||||
logger.info(f"终端用户配额检查: workspace_id={workspace_id}, other_id={other_id}, existing={existing_end_user is not None}")
|
||||
if existing_end_user is None:
|
||||
from app.core.quota_manager import _check_quota
|
||||
from app.models.workspace_model import Workspace
|
||||
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||
if ws:
|
||||
logger.info(f"新终端用户,执行配额检查: tenant_id={ws.tenant_id}")
|
||||
_check_quota(db, ws.tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
|
||||
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
workspace_id=workspace_id,
|
||||
other_id=other_id,
|
||||
original_user_id=user_id # Save original user_id to other_id
|
||||
original_user_id=user_id
|
||||
)
|
||||
|
||||
# Only extract and set memory_config_id when the end user doesn't have one yet
|
||||
if not new_end_user.memory_config_id:
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
memory_config_service = MemoryConfigService(db)
|
||||
memory_config_id, _ = memory_config_service.extract_memory_config_id(release.type, release.config or {})
|
||||
if memory_config_id:
|
||||
new_end_user.memory_config_id = memory_config_id
|
||||
db.commit()
|
||||
db.refresh(new_end_user)
|
||||
end_user_id = str(new_end_user.id)
|
||||
|
||||
appid = share.app_id
|
||||
# appid = share.app_id
|
||||
"""获取存储类型和工作空间的ID"""
|
||||
|
||||
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
|
||||
from app.models.app_model import App
|
||||
app = db.query(App).filter(
|
||||
App.id == appid,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
if not app:
|
||||
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
||||
# app = db.query(App).filter(
|
||||
# App.id == appid,
|
||||
# App.is_active.is_(True)
|
||||
# ).first()
|
||||
# if not app:
|
||||
# raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
||||
|
||||
workspace_id = app.workspace_id
|
||||
# workspace_id = app.workspace_id
|
||||
|
||||
# 直接从 workspace 获取 storage_type(公开分享场景无需权限检查)
|
||||
storage_type = workspace_service.get_workspace_storage_type_without_auth(
|
||||
@@ -359,12 +434,12 @@ async def chat(
|
||||
app_type = release.app.type if release.app else None
|
||||
|
||||
# 根据应用类型验证配置
|
||||
if app_type == "agent":
|
||||
if app_type == AppType.AGENT:
|
||||
# Agent 类型:验证模型配置
|
||||
model_config_id = release.default_model_config_id
|
||||
if not model_config_id:
|
||||
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
elif app_type == "multi_agent":
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
# Multi-Agent 类型:验证多 Agent 配置
|
||||
config = release.config or {}
|
||||
if not config.get("sub_agents"):
|
||||
@@ -402,31 +477,10 @@ async def chat(
|
||||
# 流式返回
|
||||
agent_config = agent_config_4_app_release(release)
|
||||
|
||||
if payload.stream:
|
||||
# async def event_generator():
|
||||
# async for event in service.chat_stream(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# ):
|
||||
# yield event
|
||||
if not (agent_config.model_parameters.get("deep_thinking", False) and payload.thinking):
|
||||
agent_config.model_parameters["deep_thinking"] = False
|
||||
|
||||
# return StreamingResponse(
|
||||
# event_generator(),
|
||||
# media_type="text/event-stream",
|
||||
# headers={
|
||||
# "Cache-Control": "no-cache",
|
||||
# "Connection": "keep-alive",
|
||||
# "X-Accel-Buffering": "no"
|
||||
# }
|
||||
# )
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.agnet_chat_stream(
|
||||
message=payload.message,
|
||||
@@ -438,7 +492,8 @@ async def chat(
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
workspace_id=workspace_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -451,20 +506,6 @@ async def chat(
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
# 非流式返回
|
||||
# result = await service.chat(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# )
|
||||
# return success(data=conversation_schema.ChatResponse(**result))
|
||||
result = await app_chat_service.agnet_chat(
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
@@ -475,7 +516,8 @@ async def chat(
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
workspace_id=workspace_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
)
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
@@ -522,48 +564,6 @@ async def chat(
|
||||
)
|
||||
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
# 多 Agent 流式返回
|
||||
# if payload.stream:
|
||||
# async def event_generator():
|
||||
# async for event in service.multi_agent_chat_stream(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# ):
|
||||
# yield event
|
||||
|
||||
# return StreamingResponse(
|
||||
# event_generator(),
|
||||
# media_type="text/event-stream",
|
||||
# headers={
|
||||
# "Cache-Control": "no-cache",
|
||||
# "Connection": "keep-alive",
|
||||
# "X-Accel-Buffering": "no"
|
||||
# }
|
||||
# )
|
||||
|
||||
# # 多 Agent 非流式返回
|
||||
# result = await service.multi_agent_chat(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# )
|
||||
|
||||
# return success(data=conversation_schema.ChatResponse(**result))
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
config = workflow_config_4_app_release(release)
|
||||
if not config.id:
|
||||
@@ -578,6 +578,7 @@ async def chat(
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
files=payload.files,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
@@ -585,7 +586,8 @@ async def chat(
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=release.app_id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=release.id
|
||||
release_id=release.id,
|
||||
public=True
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
@@ -606,11 +608,11 @@ async def chat(
|
||||
|
||||
# 多 Agent 非流式返回
|
||||
result = await app_chat_service.workflow_chat(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
files=payload.files,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
@@ -634,6 +636,40 @@ async def chat(
|
||||
# return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
|
||||
else:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
|
||||
@router.get("/config", summary="获取应用启动配置")
async def config_query(
    password: str = Query(None, description="访问密码"),
    share_data: ShareTokenData = Depends(get_share_user_id),
    db: Session = Depends(get_db),
):
    """Return the startup configuration for a publicly shared app.

    Resolves the share token (and optional access password) to a release,
    then builds an app-type-specific payload describing the variables,
    memory setting, and feature flags the client UI needs to boot.
    """
    svc = SharedChatService(db)
    # Resolve the shared release; the share record itself is not needed here.
    _share, release = svc.get_release_by_share_token(share_data.share_token, password)

    app_type = release.app.type
    cfg = release.config

    if app_type == AppType.WORKFLOW:
        # Workflow apps derive variables/memory from the workflow graph config.
        wf = WorkflowService(db)
        content = {
            "app_type": app_type,
            "variables": wf.get_start_node_variables(cfg),
            "memory": wf.is_memory_enable(cfg),
            "features": cfg.get("features")
        }
    elif app_type == AppType.AGENT:
        # Agent apps expose their settings directly from the release config.
        content = {
            "app_type": app_type,
            "variables": cfg.get("variables"),
            "memory": cfg.get("memory", {}).get("enabled"),
            "features": cfg.get("features"),
            "model_parameters": cfg.get("model_parameters")
        }
    elif app_type == AppType.MULTI_AGENT:
        # Multi-agent apps have no start variables; only features are relevant.
        content = {
            "app_type": app_type,
            "variables": [],
            "features": cfg.get("features")
        }
    else:
        return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED)
    return success(data=content)
|
||||
|
||||
@@ -4,7 +4,18 @@
|
||||
认证方式: API Key
|
||||
"""
|
||||
from fastapi import APIRouter
|
||||
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller
|
||||
|
||||
from . import (
|
||||
app_api_controller,
|
||||
end_user_api_controller,
|
||||
memory_api_controller,
|
||||
memory_config_api_controller,
|
||||
rag_api_chunk_controller,
|
||||
rag_api_document_controller,
|
||||
rag_api_file_controller,
|
||||
rag_api_knowledge_controller,
|
||||
user_memory_api_controller,
|
||||
)
|
||||
|
||||
# 创建 V1 API 路由器
|
||||
service_router = APIRouter()
|
||||
@@ -16,5 +27,8 @@ service_router.include_router(rag_api_document_controller.router)
|
||||
service_router.include_router(rag_api_file_controller.router)
|
||||
service_router.include_router(rag_api_chunk_controller.router)
|
||||
service_router.include_router(memory_api_controller.router)
|
||||
service_router.include_router(end_user_api_controller.router)
|
||||
service_router.include_router(memory_config_api_controller.router)
|
||||
service_router.include_router(user_memory_api_controller.router)
|
||||
|
||||
__all__ = ["service_router"]
|
||||
|
||||
@@ -12,18 +12,19 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_app_or_workspace
|
||||
from app.models.app_model import App
|
||||
from app.models.app_model import AppType
|
||||
from app.models.app_release_model import AppRelease
|
||||
from app.repositories import knowledge_repository
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas import AppChatRequest, conversation_schema
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.services import workspace_service
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.services.conversation_service import ConversationService, get_conversation_service
|
||||
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
from app.services.app_service import get_app_service, AppService
|
||||
from app.services.conversation_service import ConversationService, get_conversation_service
|
||||
from app.utils.app_config_utils import workflow_config_4_app_release, \
|
||||
agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
|
||||
router = APIRouter(prefix="/app", tags=["V1 - App API"])
|
||||
logger = get_business_logger()
|
||||
@@ -34,6 +35,7 @@ async def list_apps():
|
||||
"""列出可访问的应用(占位)"""
|
||||
return success(data=[], msg="App API - Coming Soon")
|
||||
|
||||
|
||||
# /v1/app/chat
|
||||
|
||||
# @router.post("/chat")
|
||||
@@ -60,46 +62,68 @@ async def list_apps():
|
||||
# return success(data={"received": True}, msg="消息已接收")
|
||||
|
||||
|
||||
def _checkAppConfig(app: App):
|
||||
if app.type == AppType.AGENT:
|
||||
if not app.current_release.config:
|
||||
def _checkAppConfig(release: AppRelease):
|
||||
if release.type == AppType.AGENT:
|
||||
if not release.config:
|
||||
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
elif app.type == AppType.MULTI_AGENT:
|
||||
if not app.current_release.config:
|
||||
elif release.type == AppType.MULTI_AGENT:
|
||||
if not release.config:
|
||||
raise BusinessException("Multi-Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
elif app.type == AppType.WORKFLOW:
|
||||
if not app.current_release.config:
|
||||
elif release.type == AppType.WORKFLOW:
|
||||
if not release.config:
|
||||
raise BusinessException("工作流应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
else:
|
||||
raise BusinessException("不支持的应用类型", BizCode.AGENT_CONFIG_MISSING)
|
||||
raise BusinessException("不支持的应用类型", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
@require_api_key(scopes=["app"])
|
||||
async def chat(
|
||||
request:Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
conversation_service: Annotated[ConversationService, Depends(get_conversation_service)] = None,
|
||||
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
|
||||
app_service: Annotated[AppService, Depends(get_app_service)] = None,
|
||||
message: str = Body(..., description="聊天消息内容"),
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
conversation_service: Annotated[ConversationService, Depends(get_conversation_service)] = None,
|
||||
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
|
||||
app_service: Annotated[AppService, Depends(get_app_service)] = None,
|
||||
message: str = Body(..., description="聊天消息内容"),
|
||||
):
|
||||
"""
|
||||
Agent/Workflow 聊天接口
|
||||
|
||||
- 不传 version:使用当前生效版本(current_release,回滚后为回滚目标版本)
|
||||
- 传 version=release_id:使用指定版本uuid的历史快照,例如 {"version": "{{release_id}}"}
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = AppChatRequest(**body)
|
||||
|
||||
other_id = payload.user_id
|
||||
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
|
||||
|
||||
# 版本切换:指定 release_id 时查找对应历史快照,否则使用当前激活版本
|
||||
if payload.version is not None:
|
||||
active_release = app_service.get_release_by_id(app.id, payload.version)
|
||||
else:
|
||||
active_release = app.current_release
|
||||
other_id = payload.user_id
|
||||
workspace_id = app.workspace_id
|
||||
workspace_id = api_key_auth.workspace_id
|
||||
end_user_repo = EndUserRepository(db)
|
||||
|
||||
# 仅在新建终端用户时检查配额,已有用户复用不受限制
|
||||
existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id)
|
||||
if existing_end_user is None:
|
||||
from app.core.quota_manager import _check_quota
|
||||
from app.models.workspace_model import Workspace
|
||||
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||
if ws:
|
||||
_check_quota(db, ws.tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
|
||||
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
other_id=other_id,
|
||||
original_user_id=other_id # Save original user_id to other_id
|
||||
)
|
||||
end_user_id = str(new_end_user.id)
|
||||
web_search=True
|
||||
memory=True
|
||||
web_search = True
|
||||
memory = True
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
storage_type = workspace_service.get_workspace_storage_type_without_auth(
|
||||
db=db,
|
||||
@@ -126,36 +150,43 @@ async def chat(
|
||||
storage_type = 'neo4j'
|
||||
app_type = app.type
|
||||
# check app config
|
||||
_checkAppConfig(app)
|
||||
_checkAppConfig(active_release)
|
||||
|
||||
# 获取或创建会话(提前验证)
|
||||
conversation = conversation_service.create_or_get_conversation(
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=end_user_id,
|
||||
is_draft=False
|
||||
is_draft=False,
|
||||
conversation_id=payload.conversation_id
|
||||
)
|
||||
|
||||
if app_type == AppType.AGENT:
|
||||
|
||||
# print("="*50)
|
||||
# print(app.current_release.default_model_config_id)
|
||||
agent_config = agent_config_4_app_release(app.current_release)
|
||||
agent_config = agent_config_4_app_release(active_release)
|
||||
# print(agent_config.default_model_config_id)
|
||||
|
||||
# thinking 开关:仅当 agent 配置了 deep_thinking 且请求 thinking=True 时才启用
|
||||
if not (agent_config.model_parameters.get("deep_thinking", False) and payload.thinking):
|
||||
agent_config.model_parameters["deep_thinking"] = False
|
||||
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.agnet_chat_stream(
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id= end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
web_search=web_search,
|
||||
config=agent_config,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
web_search=web_search,
|
||||
config=agent_config,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -175,30 +206,31 @@ async def chat(
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config= agent_config,
|
||||
config=agent_config,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
workspace_id=workspace_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
)
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
# 多 Agent 流式返回
|
||||
config = multi_agent_config_4_app_release(app.current_release)
|
||||
config = multi_agent_config_4_app_release(active_release)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.multi_agent_chat_stream(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -228,23 +260,24 @@ async def chat(
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
# 多 Agent 流式返回
|
||||
config = workflow_config_4_app_release(app.current_release)
|
||||
config = workflow_config_4_app_release(active_release)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.workflow_chat_stream(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=new_end_user.id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id,
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
files=payload.files,
|
||||
config=config,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=active_release.id,
|
||||
public=True
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
@@ -263,21 +296,22 @@ async def chat(
|
||||
}
|
||||
)
|
||||
|
||||
# 多 Agent 非流式返回
|
||||
# workflow 非流式返回
|
||||
result = await app_chat_service.workflow_chat(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=new_end_user.id, # 转换为字符串
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
files=payload.files,
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id
|
||||
release_id=active_release.id
|
||||
)
|
||||
logger.debug(
|
||||
"工作流试运行返回结果",
|
||||
@@ -291,7 +325,4 @@ async def chat(
|
||||
msg="工作流任务执行成功"
|
||||
)
|
||||
else:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
|
||||
173
api/app/controllers/service/end_user_api_controller.py
Normal file
173
api/app/controllers/service/end_user_api_controller.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""End User 服务接口 - 基于 API Key 认证"""
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.controllers import user_memory_controllers
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.quota_stub import check_end_user_quota
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.end_user_info_schema import EndUserInfoUpdate
|
||||
from app.schemas.memory_api_schema import CreateEndUserRequest, CreateEndUserResponse
|
||||
from app.services import api_key_service
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
router = APIRouter(prefix="/end_user", tags=["V1 - End User API"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
def _get_current_user(api_key_auth: ApiKeyAuth, db: Session):
|
||||
"""Build a current_user object from API key auth
|
||||
|
||||
Args:
|
||||
api_key_auth: Validated API key auth info
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
User object with current_workspace_id set
|
||||
"""
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
return current_user
|
||||
|
||||
|
||||
@router.post("/create")
|
||||
@require_api_key(scopes=["memory"])
|
||||
@check_end_user_quota
|
||||
async def create_end_user(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
):
|
||||
"""
|
||||
Create or retrieve an end user for the workspace.
|
||||
|
||||
Creates a new end user and connects it to a memory configuration.
|
||||
If an end user with the same other_id already exists in the workspace,
|
||||
returns the existing one.
|
||||
|
||||
Optionally accepts a memory_config_id to connect the end user to a specific
|
||||
memory configuration. If not provided, falls back to the workspace default config.
|
||||
Optionally accepts an app_id to bind the end user to a specific app.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = CreateEndUserRequest(**body)
|
||||
workspace_id = api_key_auth.workspace_id
|
||||
|
||||
logger.info("Create end user request - other_id: %s, workspace_id: %s", payload.other_id, workspace_id)
|
||||
|
||||
# Resolve memory_config_id: explicit > workspace default
|
||||
memory_config_id = None
|
||||
config_service = MemoryConfigService(db)
|
||||
|
||||
if payload.memory_config_id:
|
||||
try:
|
||||
memory_config_id = uuid.UUID(payload.memory_config_id)
|
||||
except ValueError:
|
||||
raise BusinessException(
|
||||
f"Invalid memory_config_id format: {payload.memory_config_id}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
config = config_service.get_config_with_fallback(memory_config_id, workspace_id)
|
||||
if not config:
|
||||
raise BusinessException(
|
||||
f"Memory config not found: {payload.memory_config_id}",
|
||||
BizCode.MEMORY_CONFIG_NOT_FOUND
|
||||
)
|
||||
memory_config_id = config.config_id
|
||||
else:
|
||||
default_config = config_service.get_workspace_default_config(workspace_id)
|
||||
if default_config:
|
||||
memory_config_id = default_config.config_id
|
||||
logger.info(f"Using workspace default memory config: {memory_config_id}")
|
||||
else:
|
||||
logger.warning(f"No default memory config found for workspace: {workspace_id}")
|
||||
|
||||
# Resolve app_id: explicit from payload, otherwise None
|
||||
app_id = None
|
||||
if payload.app_id:
|
||||
try:
|
||||
app_id = uuid.UUID(payload.app_id)
|
||||
except ValueError:
|
||||
raise BusinessException(
|
||||
f"Invalid app_id format: {payload.app_id}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
end_user_repo = EndUserRepository(db)
|
||||
end_user = end_user_repo.get_or_create_end_user_with_config(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
other_id=payload.other_id,
|
||||
memory_config_id=memory_config_id,
|
||||
other_name=payload.other_name,
|
||||
)
|
||||
end_user.other_name = payload.other_name
|
||||
logger.info(f"End user ready: {end_user.id}")
|
||||
|
||||
result = {
|
||||
"id": str(end_user.id),
|
||||
"other_id": end_user.other_id or "",
|
||||
"other_name": end_user.other_name or "",
|
||||
"workspace_id": str(end_user.workspace_id),
|
||||
"memory_config_id": str(end_user.memory_config_id) if end_user.memory_config_id else None,
|
||||
}
|
||||
|
||||
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
|
||||
|
||||
|
||||
@router.get("/info")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def get_end_user_info(
|
||||
request: Request,
|
||||
end_user_id: str,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get end user info.
|
||||
|
||||
Retrieves the info record (aliases, meta_data, etc.) for the specified end user.
|
||||
Delegates to the manager-side controller for shared logic.
|
||||
"""
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
return await user_memory_controllers.get_end_user_info(
|
||||
end_user_id=end_user_id,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/info/update")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def update_end_user_info(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
):
|
||||
"""
|
||||
Update end user info.
|
||||
|
||||
Updates the info record (other_name, aliases, meta_data) for the specified end user.
|
||||
Delegates to the manager-side controller for shared logic.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = EndUserInfoUpdate(**body)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
return await user_memory_controllers.update_end_user_info(
|
||||
info_update=payload,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
)
|
||||
@@ -1,49 +1,84 @@
|
||||
"""Memory 服务接口 - 基于 API Key 认证"""
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Query, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.celery_task_scheduler import scheduler
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.quota_stub import check_end_user_quota
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.memory_api_schema import (
|
||||
MemoryReadRequest,
|
||||
MemoryReadResponse,
|
||||
MemoryReadSyncResponse,
|
||||
MemoryWriteRequest,
|
||||
MemoryWriteResponse,
|
||||
MemoryWriteSyncResponse,
|
||||
)
|
||||
from app.services.memory_api_service import MemoryAPIService
|
||||
from fastapi import APIRouter, Body, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
router = APIRouter(prefix="/memory", tags=["V1 - Memory API"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
def _sanitize_task_result(result: dict) -> dict:
|
||||
"""Make Celery task result JSON-serializable.
|
||||
|
||||
Converts UUID and other non-serializable values to strings.
|
||||
|
||||
Args:
|
||||
result: Raw task result dict from task_service
|
||||
|
||||
Returns:
|
||||
JSON-safe dict
|
||||
"""
|
||||
import uuid as _uuid
|
||||
from datetime import datetime
|
||||
|
||||
def _convert(obj):
|
||||
if isinstance(obj, dict):
|
||||
return {k: _convert(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_convert(i) for i in obj]
|
||||
if isinstance(obj, _uuid.UUID):
|
||||
return str(obj)
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
return obj
|
||||
|
||||
return _convert(result)
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_memory_info():
|
||||
"""获取记忆服务信息(占位)"""
|
||||
return success(data={}, msg="Memory API - Coming Soon")
|
||||
|
||||
|
||||
@router.post("/write_api_service")
|
||||
@router.post("/write")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def write_memory_api_service(
|
||||
async def write_memory(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
payload: MemoryWriteRequest = Body(..., embed=False),
|
||||
|
||||
message: str = Body(..., description="Message content"),
|
||||
):
|
||||
"""
|
||||
Write memory to storage.
|
||||
|
||||
Stores memory content for the specified end user using the Memory API Service.
|
||||
Submit a memory write task.
|
||||
|
||||
Validates the end user, then dispatches the write to a Celery background task
|
||||
with per-user fair locking. Returns a task_id for status polling.
|
||||
"""
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
|
||||
|
||||
body = await request.json()
|
||||
payload = MemoryWriteRequest(**body)
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
result = await memory_api_service.write_memory(
|
||||
|
||||
result = memory_api_service.write_memory(
|
||||
workspace_id=api_key_auth.workspace_id,
|
||||
end_user_id=payload.end_user_id,
|
||||
message=payload.message,
|
||||
@@ -51,29 +86,52 @@ async def write_memory_api_service(
|
||||
storage_type=payload.storage_type,
|
||||
user_rag_memory_id=payload.user_rag_memory_id,
|
||||
)
|
||||
|
||||
logger.info(f"Memory write successful for end_user: {payload.end_user_id}")
|
||||
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory written successfully")
|
||||
|
||||
logger.info(f"Memory write task submitted: task_id: {result['task_id']} end_user_id: {payload.end_user_id}")
|
||||
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted")
|
||||
|
||||
|
||||
@router.post("/read_api_service")
|
||||
@router.get("/write/status")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_memory_api_service(
|
||||
async def get_write_task_status(
|
||||
request: Request,
|
||||
task_id: str = Query(..., description="Celery task ID"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Check the status of a memory write task.
|
||||
|
||||
Returns the current status and result (if completed) of a previously submitted write task.
|
||||
"""
|
||||
logger.info(f"Write task status check - task_id: {task_id}")
|
||||
|
||||
result = scheduler.get_task_status(task_id)
|
||||
|
||||
return success(data=_sanitize_task_result(result), msg="Task status retrieved")
|
||||
|
||||
|
||||
@router.post("/read")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_memory(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
payload: MemoryReadRequest = Body(..., embed=False),
|
||||
message: str = Body(..., description="Query message"),
|
||||
):
|
||||
"""
|
||||
Read memory from storage.
|
||||
|
||||
Queries and retrieves memories for the specified end user with context-aware responses.
|
||||
Submit a memory read task.
|
||||
|
||||
Validates the end user, then dispatches the read to a Celery background task.
|
||||
Returns a task_id for status polling.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = MemoryReadRequest(**body)
|
||||
logger.info(f"Memory read request - end_user_id: {payload.end_user_id}")
|
||||
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
result = await memory_api_service.read_memory(
|
||||
|
||||
result = memory_api_service.read_memory(
|
||||
workspace_id=api_key_auth.workspace_id,
|
||||
end_user_id=payload.end_user_id,
|
||||
message=payload.message,
|
||||
@@ -82,6 +140,95 @@ async def read_memory_api_service(
|
||||
storage_type=payload.storage_type,
|
||||
user_rag_memory_id=payload.user_rag_memory_id,
|
||||
)
|
||||
|
||||
logger.info(f"Memory read successful for end_user: {payload.end_user_id}")
|
||||
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read successfully")
|
||||
|
||||
logger.info(f"Memory read task submitted: task_id={result['task_id']}, end_user_id: {payload.end_user_id}")
|
||||
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read task submitted")
|
||||
|
||||
|
||||
@router.get("/read/status")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def get_read_task_status(
|
||||
request: Request,
|
||||
task_id: str = Query(..., description="Celery task ID"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Check the status of a memory read task.
|
||||
|
||||
Returns the current status and result (if completed) of a previously submitted read task.
|
||||
"""
|
||||
logger.info(f"Read task status check - task_id: {task_id}")
|
||||
|
||||
from app.services.task_service import get_task_memory_read_result
|
||||
result = get_task_memory_read_result(task_id)
|
||||
|
||||
return success(data=_sanitize_task_result(result), msg="Task status retrieved")
|
||||
|
||||
|
||||
@router.post("/write/sync")
|
||||
@require_api_key(scopes=["memory"])
|
||||
@check_end_user_quota
|
||||
async def write_memory_sync(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(..., description="Message content"),
|
||||
):
|
||||
"""
|
||||
Write memory synchronously.
|
||||
|
||||
Blocks until the write completes and returns the result directly.
|
||||
For async processing with task polling, use /write instead.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = MemoryWriteRequest(**body)
|
||||
logger.info(f"Memory write (sync) request - end_user_id: {payload.end_user_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
result = await memory_api_service.write_memory_sync(
|
||||
workspace_id=api_key_auth.workspace_id,
|
||||
end_user_id=payload.end_user_id,
|
||||
message=payload.message,
|
||||
config_id=payload.config_id,
|
||||
storage_type=payload.storage_type,
|
||||
user_rag_memory_id=payload.user_rag_memory_id,
|
||||
)
|
||||
|
||||
logger.info(f"Memory write (sync) successful for end_user: {payload.end_user_id}")
|
||||
return success(data=MemoryWriteSyncResponse(**result).model_dump(), msg="Memory written successfully")
|
||||
|
||||
|
||||
@router.post("/read/sync")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_memory_sync(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(..., description="Query message"),
|
||||
):
|
||||
"""
|
||||
Read memory synchronously.
|
||||
|
||||
Blocks until the read completes and returns the answer directly.
|
||||
For async processing with task polling, use /read instead.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = MemoryReadRequest(**body)
|
||||
logger.info(f"Memory read (sync) request - end_user_id: {payload.end_user_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
result = await memory_api_service.read_memory_sync(
|
||||
workspace_id=api_key_auth.workspace_id,
|
||||
end_user_id=payload.end_user_id,
|
||||
message=payload.message,
|
||||
search_switch=payload.search_switch,
|
||||
config_id=payload.config_id,
|
||||
storage_type=payload.storage_type,
|
||||
user_rag_memory_id=payload.user_rag_memory_id,
|
||||
)
|
||||
|
||||
logger.info(f"Memory read (sync) successful for end_user: {payload.end_user_id}")
|
||||
return success(data=MemoryReadSyncResponse(**result).model_dump(), msg="Memory read successfully")
|
||||
|
||||
491
api/app/controllers/service/memory_config_api_controller.py
Normal file
491
api/app/controllers/service/memory_config_api_controller.py
Normal file
@@ -0,0 +1,491 @@
|
||||
"""Memory Config 服务接口 - 基于 API Key 认证"""
|
||||
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Header, Query, Request
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.controllers import memory_storage_controller
|
||||
from app.controllers import memory_forget_controller
|
||||
from app.controllers import ontology_controller
|
||||
from app.controllers import emotion_config_controller
|
||||
from app.controllers import memory_reflection_controller
|
||||
from app.schemas.memory_storage_schema import ForgettingConfigUpdateRequest
|
||||
from app.controllers.emotion_config_controller import EmotionConfigUpdate
|
||||
from app.schemas.memory_reflection_schemas import Memory_Reflection
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.memory_api_schema import (
|
||||
ConfigUpdateExtractedRequest,
|
||||
ConfigUpdateRequest,
|
||||
ListConfigsResponse,
|
||||
ConfigCreateRequest,
|
||||
ConfigUpdateForgettingRequest,
|
||||
EmotionConfigUpdateRequest,
|
||||
ReflectionConfigUpdateRequest,
|
||||
)
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigParamsCreate,
|
||||
)
|
||||
from app.services import api_key_service
|
||||
from app.services.memory_api_service import MemoryAPIService
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
router = APIRouter(prefix="/memory_config", tags=["V1 - Memory Config API"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
def _get_current_user(api_key_auth: ApiKeyAuth, db: Session):
|
||||
"""Build a current_user object from API key auth
|
||||
|
||||
Args:
|
||||
api_key_auth: Validated API key auth info
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
User object with current_workspace_id set
|
||||
"""
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
return current_user
|
||||
|
||||
|
||||
def _verify_config_ownership(config_id:str, workspace_id:uuid.UUID, db:Session):
|
||||
"""Verify that the config belongs to the workspace.
|
||||
|
||||
Args:
|
||||
config_id: The ID of the config to verify
|
||||
workspace_id: The workspace ID tocheck against
|
||||
db: Database session for querying
|
||||
Raises:
|
||||
BusinessException: If the config does not exist or does not belong to the workspace
|
||||
"""
|
||||
try:
|
||||
resolved_id = resolve_config_id(config_id, db)
|
||||
except ValueError as e:
|
||||
raise BusinessException(
|
||||
message=f"Invalid config_id: {e}",
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
)
|
||||
config = MemoryConfigRepository.get_by_id(db, resolved_id)
|
||||
if not config or config.workspace_id != workspace_id:
|
||||
raise BusinessException(
|
||||
message="Config not found or access denied",
|
||||
code=BizCode.MEMORY_CONFIG_NOT_FOUND,
|
||||
)
|
||||
|
||||
# @router.get("/configs")
|
||||
# @require_api_key(scopes=["memory"])
|
||||
# async def list_memory_configs(
|
||||
# request: Request,
|
||||
# api_key_auth: ApiKeyAuth = None,
|
||||
# db: Session = Depends(get_db),
|
||||
# ):
|
||||
# """
|
||||
# List all memory configs for the workspace.
|
||||
|
||||
# Returns all available memory configurations associated with the authorized workspace.
|
||||
# """
|
||||
# logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}")
|
||||
|
||||
# memory_api_service = MemoryAPIService(db)
|
||||
|
||||
# result = memory_api_service.list_memory_configs(
|
||||
# workspace_id=api_key_auth.workspace_id,
|
||||
# )
|
||||
|
||||
# logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}")
|
||||
# return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully")
|
||||
|
||||
@router.get("/read_all_config")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_all_config(
|
||||
request:Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
List all memory configs with full details (enhanced version).
|
||||
|
||||
Returns complete config fields for the authorized workspace.
|
||||
No config_id ownership check needed — results are filtered by workspace.
|
||||
"""
|
||||
logger.info(f"V1 get all configs (full) - workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
|
||||
return memory_storage_controller.read_all_config(
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
)
|
||||
|
||||
@router.get("/scenes/simple")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def get_ontology_scenes(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get available ontology scenes for the workspace.
|
||||
|
||||
Returns a simple list of scene_id and scene_name for dropdown selection.
|
||||
Used before creating a memory config to choose which ontology scene to associate.
|
||||
"""
|
||||
logger.info(f"V1 get scenes - workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
|
||||
return await ontology_controller.get_scenes_simple(
|
||||
db=db,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
@router.get("/read_config_extracted")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_config_extracted(
|
||||
request: Request,
|
||||
config_id: str = Query(..., description="config_id"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get extraction engine config details for a specific config.
|
||||
|
||||
Only configs belonging to the authorized workspace can be queried.
|
||||
"""
|
||||
logger.info(f"V1 read extracted config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
|
||||
return memory_storage_controller.read_config_extracted(
|
||||
config_id = config_id,
|
||||
current_user = current_user,
|
||||
db = db,
|
||||
)
|
||||
|
||||
@router.get("/read_config_forgetting")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_config_forgetting(
|
||||
request: Request,
|
||||
config_id: str = Query(..., description="config_id"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get forgetting settings for a specific memory config.
|
||||
|
||||
Only configs belonging to the authorized workspace can be queried.
|
||||
"""
|
||||
logger.info(f"V1 read forgetting config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
|
||||
result = await memory_forget_controller.read_forgetting_config(
|
||||
config_id = config_id,
|
||||
current_user = current_user,
|
||||
db = db,
|
||||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
|
||||
|
||||
@router.get("/read_config_emotion")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_config_emotion(
|
||||
request: Request,
|
||||
config_id: str = Query(..., description="config_id"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get emotion engine config details for a specific config.
|
||||
|
||||
Only configs belonging to the authorized workspace can be queried.
|
||||
"""
|
||||
logger.info(f"V1 read emotion config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
|
||||
return jsonable_encoder(emotion_config_controller.get_emotion_config(
|
||||
config_id=config_id,
|
||||
db=db,
|
||||
current_user=current_user,
|
||||
))
|
||||
|
||||
@router.get("/read_config_reflection")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def read_config_reflection(
|
||||
request: Request,
|
||||
config_id: str = Query(..., description="config_id"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get reflection engine config details for a specific config.
|
||||
|
||||
Only configs belonging to the authorized workspace can be queried.
|
||||
"""
|
||||
logger.info(f"V1 read reflection config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
|
||||
return jsonable_encoder(await memory_reflection_controller.start_reflection_configs(
|
||||
config_id=config_id,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
))
|
||||
|
||||
|
||||
@router.post("/create_config")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def create_memory_config(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
||||
):
|
||||
"""
|
||||
Create a new memory config for the workspace.
|
||||
|
||||
The config will be associated with the workspace of the API Key.
|
||||
config_name is required, other fields are optional.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = ConfigCreateRequest(**body)
|
||||
|
||||
logger.info(f"V1 create config - workspace: {api_key_auth.workspace_id}, config_name: {payload.config_name}")
|
||||
|
||||
# 构造管理端 Schema,workspace_id 从 API Key 注入
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
mgmt_payload = ConfigParamsCreate(
|
||||
config_name=payload.config_name,
|
||||
config_desc=payload.config_desc or "",
|
||||
scene_id=payload.scene_id,
|
||||
llm_id=payload.llm_id,
|
||||
embedding_id=payload.embedding_id,
|
||||
rerank_id=payload.rerank_id,
|
||||
reflection_model_id=payload.reflection_model_id,
|
||||
emotion_model_id=payload.emotion_model_id,
|
||||
)
|
||||
#将返回数据中UUID序列化处理
|
||||
result =memory_storage_controller.create_config(
|
||||
payload=mgmt_payload,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
x_language_type=x_language_type,
|
||||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
@router.put("/update_config")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def update_memory_config(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
):
|
||||
"""
|
||||
Update memory config basic info (name, description, scene).
|
||||
|
||||
Requires API Key with 'memory' scope
|
||||
Only configs belonging to the authorized workspace can be updated.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = ConfigUpdateRequest(**body)
|
||||
|
||||
logger.info(f"V1 update config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
mgmt_payload = ConfigUpdate(
|
||||
config_id = payload.config_id,
|
||||
config_name = payload.config_name,
|
||||
config_desc = payload.config_desc,
|
||||
scene_id = payload.scene_id,
|
||||
)
|
||||
|
||||
return memory_storage_controller.update_config(
|
||||
payload = mgmt_payload,
|
||||
current_user = current_user,
|
||||
db = db,
|
||||
)
|
||||
|
||||
@router.put("/update_config_extracted")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def update_memory_config_extracted(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
):
|
||||
"""
|
||||
update memory config extraction engine config (models, thresholds, chunking, pruning, etc.).
|
||||
|
||||
Requires API Key with 'memory' scope.
|
||||
Only configs belonging to the authorized workspace can be updated.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = ConfigUpdateExtractedRequest(**body)
|
||||
|
||||
logger.info(f"V1 update extracted config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
#校验权限
|
||||
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
update_fields = payload.model_dump(exclude_unset=True)
|
||||
mgmt_payload = ConfigUpdateExtracted(**update_fields)
|
||||
|
||||
return memory_storage_controller.update_config_extracted(
|
||||
payload = mgmt_payload,
|
||||
current_user = current_user,
|
||||
db = db,
|
||||
)
|
||||
|
||||
@router.put("/update_config_forgetting")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def update_memory_config_forgetting(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
):
|
||||
"""
|
||||
update memory config forgetting settings (forgetting strategy, parameters, etc.).
|
||||
|
||||
Requires API Key with 'memory' scope.
|
||||
Only configs belonging to the authorized workspace can be updated.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = ConfigUpdateForgettingRequest(**body)
|
||||
|
||||
logger.info(f"V1 update forgetting config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
#校验权限
|
||||
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
update_fields = payload.model_dump(exclude_unset=True)
|
||||
mgmt_payload = ForgettingConfigUpdateRequest(**update_fields)
|
||||
|
||||
#将返回数据中UUID序列化处理
|
||||
result = await memory_forget_controller.update_forgetting_config(
|
||||
payload = mgmt_payload,
|
||||
current_user = current_user,
|
||||
db = db,
|
||||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
@router.put("/update_config_emotion")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def update_config_emotion(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
):
|
||||
"""
|
||||
Update emotion engine config (full update).
|
||||
|
||||
All fields except emotion_model_id are required.
|
||||
Only configs belonging to the authorized workspace can be updated.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = EmotionConfigUpdateRequest(**body)
|
||||
|
||||
logger.info(f"V1 update emotion config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
update_fields = payload.model_dump(exclude_unset=True)
|
||||
mgmt_payload = EmotionConfigUpdate(**update_fields)
|
||||
return jsonable_encoder(emotion_config_controller.update_emotion_config(
|
||||
config=mgmt_payload,
|
||||
db=db,
|
||||
current_user=current_user,
|
||||
))
|
||||
|
||||
@router.put("/update_config_reflection")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def update_config_reflection(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
):
|
||||
"""
|
||||
Update reflection engine config (full update).
|
||||
|
||||
All fields are required.
|
||||
Only configs belonging to the authorized workspace can be updated.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = ReflectionConfigUpdateRequest(**body)
|
||||
|
||||
logger.info(f"V1 update reflection config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
update_fields = payload.model_dump(exclude_unset=True)
|
||||
mgmt_payload = Memory_Reflection(**update_fields)
|
||||
|
||||
return jsonable_encoder(await memory_reflection_controller.save_reflection_config(
|
||||
request=mgmt_payload,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
))
|
||||
|
||||
@router.delete("/delete_config")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def delete_memory_config(
|
||||
config_id: str,
|
||||
request: Request,
|
||||
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Delete a memory config.
|
||||
|
||||
- Default configs cannot be deleted.
|
||||
- If end users are connected and force=False, returns a warning.
|
||||
- If force=True, clears end user references and deletes the config.
|
||||
|
||||
Only configs belonging to the authorized workspace can be deleted.
|
||||
"""
|
||||
logger.info(f"V1 delete config - config_id: {config_id}, force: {force}, workspace: {api_key_auth.workspace_id}")
|
||||
|
||||
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||
|
||||
current_user = _get_current_user(api_key_auth, db)
|
||||
|
||||
return memory_storage_controller.delete_config(
|
||||
config_id=config_id,
|
||||
force=force,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
)
|
||||
@@ -246,3 +246,73 @@ async def rebuild_knowledge_graph(
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/check/yuque/auth", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def check_yuque_auth(
|
||||
yuque_user_id: str,
|
||||
yuque_token: str,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
check yuque auth info
|
||||
"""
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
api_logger.info(f"check yuque auth info, username: {current_user.username}")
|
||||
|
||||
return await knowledge_controller.check_yuque_auth(yuque_user_id=yuque_user_id,
|
||||
yuque_token=yuque_token,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/check/feishu/auth", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def check_feishu_auth(
|
||||
feishu_app_id: str,
|
||||
feishu_app_secret: str,
|
||||
feishu_folder_token: str,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
check feishu auth info
|
||||
"""
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
api_logger.info(f"check feishu auth info, username: {current_user.username}")
|
||||
|
||||
return await knowledge_controller.check_feishu_auth(feishu_app_id=feishu_app_id,
|
||||
feishu_app_secret=feishu_app_secret,
|
||||
feishu_folder_token=feishu_folder_token,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.post("/{knowledge_id}/sync", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def sync_knowledge(
|
||||
knowledge_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
sync knowledge base information based on knowledge_id
|
||||
"""
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await knowledge_controller.sync_knowledge(knowledge_id=knowledge_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
230
api/app/controllers/service/user_memory_api_controller.py
Normal file
230
api/app/controllers/service/user_memory_api_controller.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""User Memory 服务接口 — 基于 API Key 认证
|
||||
|
||||
包装 user_memory_controllers.py 和 memory_agent_controller.py 中的内部接口,
|
||||
提供基于 API Key 认证的对外服务:
|
||||
1./analytics/graph_data - 知识图谱数据接口
|
||||
2./analytics/community_graph - 社区图谱接口
|
||||
3./analytics/node_statistics - 记忆节点统计接口
|
||||
4./analytics/user_summary - 用户摘要接口
|
||||
5./analytics/memory_insight - 记忆洞察接口
|
||||
6./analytics/interest_distribution - 兴趣分布接口
|
||||
7./analytics/end_user_info - 终端用户信息接口
|
||||
8./analytics/generate_cache - 缓存生成接口
|
||||
|
||||
|
||||
路由前缀: /memory
|
||||
子路径: /analytics/...
|
||||
最终路径: /v1/memory/analytics/...
|
||||
认证方式: API Key (@require_api_key)
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Query, Request, Body
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.core.api_key_utils import get_current_user_from_api_key, validate_end_user_in_workspace
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.db import get_db
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
||||
|
||||
# 包装内部服务 controller
|
||||
from app.controllers import user_memory_controllers, memory_agent_controller
|
||||
|
||||
router = APIRouter(prefix="/memory", tags=["V1 - User Memory API"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
# ==================== 知识图谱 ====================
|
||||
|
||||
|
||||
@router.get("/analytics/graph_data")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def get_graph_data(
|
||||
request: Request,
|
||||
end_user_id: str = Query(..., description="End user ID"),
|
||||
node_types: Optional[str] = Query(None, description="Comma-separated node types filter"),
|
||||
limit: int = Query(100, description="Max nodes to return (auto-capped at 1000 in service layer)"),
|
||||
depth: int = Query(1, description="Graph traversal depth (auto-capped at 3 in service layer)"),
|
||||
center_node_id: Optional[str] = Query(None, description="Center node for subgraph"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get knowledge graph data (nodes + edges) for an end user."""
|
||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
||||
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
||||
|
||||
return await user_memory_controllers.get_graph_data_api(
|
||||
end_user_id=end_user_id,
|
||||
node_types=node_types,
|
||||
limit=limit,
|
||||
depth=depth,
|
||||
center_node_id=center_node_id,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/analytics/community_graph")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def get_community_graph(
|
||||
request: Request,
|
||||
end_user_id: str = Query(..., description="End user ID"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get community clustering graph for an end user."""
|
||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
||||
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
||||
|
||||
return await user_memory_controllers.get_community_graph_data_api(
|
||||
end_user_id=end_user_id,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
)
|
||||
|
||||
|
||||
# ==================== 节点统计 ====================
|
||||
|
||||
|
||||
@router.get("/analytics/node_statistics")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def get_node_statistics(
|
||||
request: Request,
|
||||
end_user_id: str = Query(..., description="End user ID"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get memory node type statistics for an end user."""
|
||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
||||
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
||||
|
||||
return await user_memory_controllers.get_node_statistics_api(
|
||||
end_user_id=end_user_id,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
)
|
||||
|
||||
|
||||
# ==================== 用户摘要 & 洞察 ====================
|
||||
|
||||
|
||||
@router.get("/analytics/user_summary")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def get_user_summary(
|
||||
request: Request,
|
||||
end_user_id: str = Query(..., description="End user ID"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get cached user summary for an end user."""
|
||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
||||
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
||||
|
||||
return await user_memory_controllers.get_user_summary_api(
|
||||
end_user_id=end_user_id,
|
||||
language_type=language_type,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/analytics/memory_insight")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def get_memory_insight(
|
||||
request: Request,
|
||||
end_user_id: str = Query(..., description="End user ID"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get cached memory insight report for an end user."""
|
||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
||||
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
||||
|
||||
return await user_memory_controllers.get_memory_insight_report_api(
|
||||
end_user_id=end_user_id,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
)
|
||||
|
||||
|
||||
# ==================== 兴趣分布 ====================
|
||||
|
||||
|
||||
@router.get("/analytics/interest_distribution")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def get_interest_distribution(
|
||||
request: Request,
|
||||
end_user_id: str = Query(..., description="End user ID"),
|
||||
limit: int = Query(5, le=5, description="Max interest tags to return"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get interest distribution tags for an end user."""
|
||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
||||
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
||||
|
||||
return await memory_agent_controller.get_interest_distribution_by_user_api(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
language_type=language_type,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
)
|
||||
|
||||
|
||||
# ==================== 终端用户信息 ====================
|
||||
|
||||
|
||||
@router.get("/analytics/end_user_info")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def get_end_user_info(
|
||||
request: Request,
|
||||
end_user_id: str = Query(..., description="End user ID"),
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get end user basic information (name, aliases, metadata)."""
|
||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
||||
validate_end_user_in_workspace(db, end_user_id, api_key_auth.workspace_id)
|
||||
|
||||
return await user_memory_controllers.get_end_user_info(
|
||||
end_user_id=end_user_id,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
)
|
||||
|
||||
|
||||
# ==================== 缓存生成 ====================
|
||||
|
||||
|
||||
@router.post("/analytics/generate_cache")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def generate_cache(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(None, description="Request body"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
):
|
||||
"""Trigger cache generation (user summary + memory insight) for an end user or all workspace users."""
|
||||
body = await request.json()
|
||||
cache_request = GenerateCacheRequest(**body)
|
||||
|
||||
current_user = get_current_user_from_api_key(db, api_key_auth)
|
||||
|
||||
if cache_request.end_user_id:
|
||||
validate_end_user_in_workspace(db, cache_request.end_user_id, api_key_auth.workspace_id)
|
||||
|
||||
return await user_memory_controllers.generate_cache_api(
|
||||
request=cache_request,
|
||||
language_type=language_type,
|
||||
current_user=current_user,
|
||||
db=db,
|
||||
)
|
||||
|
||||
|
||||
87
api/app/controllers/skill_controller.py
Normal file
87
api/app/controllers/skill_controller.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Skill Controller - 技能市场管理"""
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.schemas import skill_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services.skill_service import SkillService
|
||||
from app.core.response_utils import success
|
||||
from app.core.quota_stub import check_skill_quota
|
||||
|
||||
router = APIRouter(prefix="/skills", tags=["Skills"])
|
||||
|
||||
|
||||
@router.post("", summary="创建技能")
|
||||
@check_skill_quota
|
||||
def create_skill(
|
||||
data: skill_schema.SkillCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""创建技能 - 可以关联现有工具(内置、MCP、自定义)"""
|
||||
tenant_id = current_user.tenant_id
|
||||
skill = SkillService.create_skill(db, data, tenant_id)
|
||||
return success(data=skill_schema.Skill.model_validate(skill), msg="技能创建成功")
|
||||
|
||||
|
||||
@router.get("", summary="技能列表")
|
||||
def list_skills(
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
is_active: Optional[bool] = Query(None, description="是否激活"),
|
||||
is_public: Optional[bool] = Query(None, description="是否公开"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""技能市场列表 - 包含本工作空间和公开的技能"""
|
||||
tenant_id = current_user.tenant_id
|
||||
skills, total = SkillService.list_skills(
|
||||
db, tenant_id, search, is_active, is_public, page, pagesize
|
||||
)
|
||||
|
||||
items = [skill_schema.Skill.model_validate(s) for s in skills]
|
||||
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
||||
return success(data=PageData(page=meta, items=items), msg="技能市场列表获取成功")
|
||||
|
||||
|
||||
@router.get("/{skill_id}", summary="获取技能详情")
|
||||
def get_skill(
|
||||
skill_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取技能详情"""
|
||||
tenant_id = current_user.tenant_id
|
||||
skill = SkillService.get_skill(db, skill_id, tenant_id)
|
||||
return success(data=skill_schema.Skill.model_validate(skill), msg="获取技能详情成功")
|
||||
|
||||
|
||||
@router.put("/{skill_id}", summary="更新技能")
|
||||
def update_skill(
|
||||
skill_id: uuid.UUID,
|
||||
data: skill_schema.SkillUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新技能"""
|
||||
tenant_id = current_user.tenant_id
|
||||
skill = SkillService.update_skill(db, skill_id, data, tenant_id)
|
||||
return success(data=skill_schema.Skill.model_validate(skill), msg="技能更新成功")
|
||||
|
||||
|
||||
@router.delete("/{skill_id}", summary="删除技能")
|
||||
def delete_skill(
|
||||
skill_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除技能"""
|
||||
tenant_id = current_user.tenant_id
|
||||
SkillService.delete_skill(db, skill_id, tenant_id)
|
||||
return success(msg="技能删除成功")
|
||||
173
api/app/controllers/tenant_subscription_controller.py
Normal file
173
api/app/controllers/tenant_subscription_controller.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""
|
||||
租户套餐查询接口(普通用户可访问)
|
||||
"""
|
||||
import datetime
|
||||
from typing import Callable, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.i18n.dependencies import get_translator
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
|
||||
logger = get_api_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenant", tags=["Tenant"])
|
||||
public_router = APIRouter(tags=["Tenant"])
|
||||
|
||||
|
||||
@router.get("/subscription", response_model=ApiResponse, summary="获取当前用户所属租户的套餐信息")
|
||||
async def get_my_tenant_subscription(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
t: Callable = Depends(get_translator),
|
||||
):
|
||||
"""
|
||||
获取当前登录用户所属租户的有效套餐订阅信息。
|
||||
包含套餐名称、版本、配额、到期时间等。
|
||||
"""
|
||||
try:
|
||||
from premium.platform_admin.package_plan_service import TenantSubscriptionService
|
||||
|
||||
if not current_user.tenant:
|
||||
return JSONResponse(status_code=404, content=fail(code=404, msg="用户未关联租户"))
|
||||
|
||||
tenant_id = current_user.tenant.id
|
||||
svc = TenantSubscriptionService(db)
|
||||
sub = svc.get_subscription(tenant_id)
|
||||
|
||||
if not sub:
|
||||
# 无订阅记录时,兜底返回免费套餐信息
|
||||
free_plan = svc.plan_repo.get_free_plan()
|
||||
if not free_plan:
|
||||
return success(data=None, msg="暂无有效套餐")
|
||||
return success(data={
|
||||
"subscription_id": None,
|
||||
"tenant_id": str(tenant_id),
|
||||
"package_plan_id": str(free_plan.id),
|
||||
"package_version": free_plan.version,
|
||||
"package_plan": {
|
||||
"id": str(free_plan.id),
|
||||
"name": free_plan.name,
|
||||
"name_en": free_plan.name_en,
|
||||
"version": free_plan.version,
|
||||
"category": free_plan.category,
|
||||
"tier_level": free_plan.tier_level,
|
||||
"price": float(free_plan.price) if free_plan.price is not None else 0.0,
|
||||
"billing_cycle": free_plan.billing_cycle,
|
||||
"core_value": free_plan.core_value,
|
||||
"core_value_en": free_plan.core_value_en,
|
||||
"tech_support": free_plan.tech_support,
|
||||
"tech_support_en": free_plan.tech_support_en,
|
||||
"sla_compliance": free_plan.sla_compliance,
|
||||
"sla_compliance_en": free_plan.sla_compliance_en,
|
||||
"page_customization": free_plan.page_customization,
|
||||
"page_customization_en": free_plan.page_customization_en,
|
||||
"theme_color": free_plan.theme_color,
|
||||
},
|
||||
"started_at": None,
|
||||
"expired_at": None,
|
||||
"status": "active",
|
||||
"quotas": free_plan.quotas or {},
|
||||
"created_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
||||
"updated_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
||||
}, msg="免费套餐")
|
||||
|
||||
return success(data=svc.build_response(sub))
|
||||
|
||||
except ModuleNotFoundError:
|
||||
# 社区版无 premium 模块,从配置文件读取免费套餐
|
||||
if not current_user.tenant:
|
||||
return JSONResponse(status_code=404, content=fail(code=404, msg="用户未关联租户"))
|
||||
|
||||
from app.config.default_free_plan import DEFAULT_FREE_PLAN
|
||||
|
||||
plan = DEFAULT_FREE_PLAN
|
||||
response_data = {
|
||||
"subscription_id": None,
|
||||
"tenant_id": str(current_user.tenant.id),
|
||||
"package_plan_id": None,
|
||||
"package_version": plan["version"],
|
||||
"package_plan": {
|
||||
"id": None,
|
||||
"name": plan["name"],
|
||||
"name_en": plan.get("name_en"),
|
||||
"version": plan["version"],
|
||||
"category": plan["category"],
|
||||
"tier_level": plan["tier_level"],
|
||||
"price": float(plan["price"]),
|
||||
"billing_cycle": plan["billing_cycle"],
|
||||
"core_value": plan.get("core_value"),
|
||||
"core_value_en": plan.get("core_value_en"),
|
||||
"tech_support": plan.get("tech_support"),
|
||||
"tech_support_en": plan.get("tech_support_en"),
|
||||
"sla_compliance": plan.get("sla_compliance"),
|
||||
"sla_compliance_en": plan.get("sla_compliance_en"),
|
||||
"page_customization": plan.get("page_customization"),
|
||||
"page_customization_en": plan.get("page_customization_en"),
|
||||
"theme_color": plan.get("theme_color"),
|
||||
},
|
||||
"started_at": None,
|
||||
"expired_at": None,
|
||||
"status": "active",
|
||||
"quotas": plan["quotas"],
|
||||
"created_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
||||
"updated_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
||||
}
|
||||
return success(data=response_data, msg="社区版免费套餐")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取租户套餐信息失败: {e}", exc_info=True)
|
||||
return JSONResponse(status_code=500, content=fail(code=500, msg="获取套餐信息失败"))
|
||||
|
||||
|
||||
@public_router.get("/package-plans", response_model=ApiResponse, summary="获取套餐列表(公开)")
|
||||
async def list_package_plans_public(
|
||||
category: Optional[str] = None,
|
||||
status: Optional[bool] = None,
|
||||
search: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
公开接口,无需鉴权。
|
||||
SaaS 版从数据库读取套餐列表;社区版降级返回 default_free_plan.py 中的免费套餐。
|
||||
"""
|
||||
try:
|
||||
from premium.platform_admin.package_plan_service import PackagePlanService
|
||||
from premium.platform_admin.package_plan_schema import PackagePlanResponse
|
||||
svc = PackagePlanService(db)
|
||||
result = svc.get_list(page=1, size=9999, category=category, status=status, search=search)
|
||||
return success(data=[PackagePlanResponse.model_validate(p).model_dump(mode="json") for p in result["items"]])
|
||||
except ModuleNotFoundError:
|
||||
from app.config.default_free_plan import DEFAULT_FREE_PLAN
|
||||
plan = DEFAULT_FREE_PLAN
|
||||
return success(data=[{
|
||||
"id": None,
|
||||
"name": plan["name"],
|
||||
"name_en": plan.get("name_en"),
|
||||
"version": plan["version"],
|
||||
"category": plan["category"],
|
||||
"tier_level": plan["tier_level"],
|
||||
"price": float(plan["price"]),
|
||||
"billing_cycle": plan["billing_cycle"],
|
||||
"core_value": plan.get("core_value"),
|
||||
"core_value_en": plan.get("core_value_en"),
|
||||
"tech_support": plan.get("tech_support"),
|
||||
"tech_support_en": plan.get("tech_support_en"),
|
||||
"sla_compliance": plan.get("sla_compliance"),
|
||||
"sla_compliance_en": plan.get("sla_compliance_en"),
|
||||
"page_customization": plan.get("page_customization"),
|
||||
"page_customization_en": plan.get("page_customization_en"),
|
||||
"theme_color": plan.get("theme_color"),
|
||||
"status": plan.get("status", True),
|
||||
"quotas": plan["quotas"],
|
||||
}])
|
||||
except Exception as e:
|
||||
logger.error(f"获取套餐列表失败: {e}", exc_info=True)
|
||||
return JSONResponse(status_code=500, content=fail(code=500, msg="获取套餐列表失败"))
|
||||
@@ -3,8 +3,11 @@ from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.schemas.tool_schema import (
|
||||
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest, CustomToolTestRequest
|
||||
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest,
|
||||
CustomToolTestRequest, ToolActiveUpdate
|
||||
)
|
||||
|
||||
from app.core.response_utils import success
|
||||
@@ -14,6 +17,7 @@ from app.models import User
|
||||
from app.models.tool_model import ToolType, ToolStatus, AuthType
|
||||
from app.services.tool_service import ToolService
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.core.exceptions import BusinessException
|
||||
|
||||
router = APIRouter(prefix="/tools", tags=["Tool System"])
|
||||
|
||||
@@ -72,6 +76,8 @@ async def get_tool_methods(
|
||||
if methods is None:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
return success(data=methods, msg="获取工具方法成功")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -97,7 +103,13 @@ async def create_tool(
|
||||
):
|
||||
"""创建工具"""
|
||||
try:
|
||||
tool_id = service.create_tool(
|
||||
# 将 MCP 来源字段合并进 config
|
||||
if request.tool_type == ToolType.MCP:
|
||||
for key in ("source_channel", "market_id", "market_config_id", "mcp_service_id"):
|
||||
val = getattr(request, key, None)
|
||||
if val is not None:
|
||||
request.config[key] = val
|
||||
tool_id = await service.create_tool(
|
||||
name=request.name,
|
||||
tool_type=request.tool_type,
|
||||
tenant_id=current_user.tenant_id,
|
||||
@@ -107,8 +119,12 @@ async def create_tool(
|
||||
tags=request.tags
|
||||
)
|
||||
return success(data={"tool_id": tool_id}, msg="工具创建成功")
|
||||
except BusinessException as e:
|
||||
raise HTTPException(status_code=400, detail=e.message)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -137,6 +153,8 @@ async def update_tool(
|
||||
return success(msg="工具更新成功")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -147,7 +165,7 @@ async def delete_tool(
|
||||
current_user: User = Depends(get_current_user),
|
||||
service: ToolService = Depends(get_tool_service)
|
||||
):
|
||||
"""删除工具"""
|
||||
"""删除工具(逻辑删除,is_active=False)"""
|
||||
try:
|
||||
success_flag = service.delete_tool(tool_id, current_user.tenant_id)
|
||||
if not success_flag:
|
||||
@@ -155,6 +173,34 @@ async def delete_tool(
|
||||
return success(msg="工具删除成功")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.patch("/{tool_id}/active", response_model=ApiResponse)
|
||||
async def set_tool_active(
|
||||
tool_id: str,
|
||||
request: ToolActiveUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
service: ToolService = Depends(get_tool_service)
|
||||
):
|
||||
"""设置工具可用状态(启用/禁用)
|
||||
|
||||
- is_active=true: 启用工具
|
||||
- is_active=false: 禁用工具(等同于删除,但可恢复)
|
||||
"""
|
||||
try:
|
||||
success_flag = service.set_tool_active(tool_id, current_user.tenant_id, request.is_active)
|
||||
if not success_flag:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
action = "启用" if request.is_active else "禁用"
|
||||
return success(msg=f"工具已{action}")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -187,6 +233,8 @@ async def execute_tool(
|
||||
},
|
||||
msg="工具执行完成"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -203,6 +251,8 @@ async def parse_openapi_schema(
|
||||
if result["success"] is False:
|
||||
raise HTTPException(status_code=400, detail=result["message"])
|
||||
return success(data=result, msg="Schema解析完成")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -216,8 +266,10 @@ async def sync_mcp_tools(
|
||||
try:
|
||||
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
|
||||
if not result.get("success", False):
|
||||
raise HTTPException(status_code=400, detail=result.get("message", "同步失败"))
|
||||
raise BusinessException(result.get("message", "工具列表同步失败"), BizCode.BAD_REQUEST)
|
||||
return success(data=result, msg="MCP工具列表同步完成")
|
||||
except BusinessException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -240,8 +292,10 @@ async def test_tool_connection(
|
||||
# 普通连接测试
|
||||
result = await service.test_connection(tool_id, current_user.tenant_id)
|
||||
if result["success"] is False:
|
||||
raise HTTPException(status_code=400, detail=result["message"])
|
||||
raise BusinessException(result["message"], BizCode.SERVICE_UNAVAILABLE)
|
||||
return success(data=result, msg="连接测试完成")
|
||||
except BusinessException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@@ -1,16 +1,26 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
from typing import Callable
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, get_current_superuser
|
||||
from app.models.user_model import User
|
||||
from app.schemas import user_schema
|
||||
from app.schemas.user_schema import ChangePasswordRequest, AdminChangePasswordRequest
|
||||
from app.schemas.user_schema import (
|
||||
ChangePasswordRequest,
|
||||
AdminChangePasswordRequest,
|
||||
SendEmailCodeRequest,
|
||||
VerifyEmailCodeRequest,
|
||||
VerifyPasswordRequest)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import user_service
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.core.security import verify_password
|
||||
from app.i18n.dependencies import get_translator
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -25,7 +35,8 @@ router = APIRouter(
|
||||
def create_superuser(
|
||||
user: user_schema.UserCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_superuser: User = Depends(get_current_superuser)
|
||||
current_superuser: User = Depends(get_current_superuser),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""创建超级管理员(仅超级管理员可访问)"""
|
||||
api_logger.info(f"超级管理员创建请求: {user.username}, email: {user.email}")
|
||||
@@ -34,7 +45,7 @@ def create_superuser(
|
||||
api_logger.info(f"超级管理员创建成功: {result.username} (ID: {result.id})")
|
||||
|
||||
result_schema = user_schema.User.model_validate(result)
|
||||
return success(data=result_schema, msg="超级管理员创建成功")
|
||||
return success(data=result_schema, msg=t("users.create.superuser_success"))
|
||||
|
||||
|
||||
@router.delete("/{user_id}", response_model=ApiResponse)
|
||||
@@ -42,6 +53,7 @@ def delete_user(
|
||||
user_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""停用用户(软删除)"""
|
||||
api_logger.info(f"用户停用请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||
@@ -49,13 +61,14 @@ def delete_user(
|
||||
db=db, user_id_to_deactivate=user_id, current_user=current_user
|
||||
)
|
||||
api_logger.info(f"用户停用成功: {result.username} (ID: {result.id})")
|
||||
return success(msg="用户停用成功")
|
||||
return success(msg=t("users.delete.deactivate_success"))
|
||||
|
||||
@router.post("/{user_id}/activate", response_model=ApiResponse)
|
||||
def activate_user(
|
||||
user_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""激活用户"""
|
||||
api_logger.info(f"用户激活请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||
@@ -66,13 +79,14 @@ def activate_user(
|
||||
api_logger.info(f"用户激活成功: {result.username} (ID: {result.id})")
|
||||
|
||||
result_schema = user_schema.User.model_validate(result)
|
||||
return success(data=result_schema, msg="用户激活成功")
|
||||
return success(data=result_schema, msg=t("users.activate.success"))
|
||||
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_current_user_info(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前用户信息"""
|
||||
api_logger.info(f"当前用户信息请求: {current_user.username}")
|
||||
@@ -92,12 +106,27 @@ def get_current_user_info(
|
||||
result_schema.current_workspace_name = current_workspace.name
|
||||
|
||||
for ws in result.workspaces:
|
||||
if ws.workspace_id == current_user.current_workspace_id:
|
||||
if ws.workspace_id == current_user.current_workspace_id and ws.is_active:
|
||||
result_schema.role = ws.role
|
||||
break
|
||||
|
||||
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
|
||||
return success(data=result_schema, msg="用户信息获取成功")
|
||||
|
||||
# 设置权限:如果用户来自 SSO Source,则使用该 Source 的 permissions;否则返回 "all" 表示拥有所有权限
|
||||
if current_user.external_source:
|
||||
try:
|
||||
from premium.sso.models import SSOSource
|
||||
source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first()
|
||||
if source and source.permissions:
|
||||
result_schema.permissions = source.permissions
|
||||
else:
|
||||
result_schema.permissions = []
|
||||
except ModuleNotFoundError:
|
||||
result_schema.permissions = []
|
||||
else:
|
||||
result_schema.permissions = ["all"]
|
||||
|
||||
return success(data=result_schema, msg=t("users.info.get_success"))
|
||||
|
||||
|
||||
@router.get("/superusers", response_model=ApiResponse)
|
||||
@@ -105,6 +134,7 @@ def get_tenant_superusers(
|
||||
include_inactive: bool = False,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前租户下的超管账号列表(仅超级管理员可访问)"""
|
||||
api_logger.info(f"获取租户超管列表请求: {current_user.username}")
|
||||
@@ -117,7 +147,7 @@ def get_tenant_superusers(
|
||||
api_logger.info(f"租户超管列表获取成功: count={len(superusers)}")
|
||||
|
||||
superusers_schema = [user_schema.User.model_validate(u) for u in superusers]
|
||||
return success(data=superusers_schema, msg="租户超管列表获取成功")
|
||||
return success(data=superusers_schema, msg=t("users.list.superusers_success"))
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=ApiResponse)
|
||||
@@ -125,6 +155,7 @@ def get_user_info_by_id(
|
||||
user_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""根据用户ID获取用户信息"""
|
||||
api_logger.info(f"获取用户信息请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||
@@ -135,7 +166,7 @@ def get_user_info_by_id(
|
||||
api_logger.info(f"用户信息获取成功: {result.username}")
|
||||
|
||||
result_schema = user_schema.User.model_validate(result)
|
||||
return success(data=result_schema, msg="用户信息获取成功")
|
||||
return success(data=result_schema, msg=t("users.info.get_success"))
|
||||
|
||||
|
||||
@router.put("/change-password", response_model=ApiResponse)
|
||||
@@ -143,6 +174,7 @@ async def change_password(
|
||||
request: ChangePasswordRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""修改当前用户密码"""
|
||||
api_logger.info(f"用户密码修改请求: {current_user.username}")
|
||||
@@ -155,7 +187,7 @@ async def change_password(
|
||||
current_user=current_user
|
||||
)
|
||||
api_logger.info(f"用户密码修改成功: {current_user.username}")
|
||||
return success(msg="密码修改成功")
|
||||
return success(msg=t("auth.password.change_success"))
|
||||
|
||||
|
||||
@router.put("/admin/change-password", response_model=ApiResponse)
|
||||
@@ -163,6 +195,7 @@ async def admin_change_password(
|
||||
request: AdminChangePasswordRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""超级管理员修改指定用户的密码"""
|
||||
api_logger.info(f"管理员密码修改请求: 管理员 {current_user.username} 修改用户 {request.user_id}")
|
||||
@@ -177,7 +210,107 @@ async def admin_change_password(
|
||||
# 根据是否生成了随机密码来构造响应
|
||||
if request.new_password:
|
||||
api_logger.info(f"管理员密码修改成功: 用户 {request.user_id}")
|
||||
return success(msg="密码修改成功")
|
||||
return success(msg=t("auth.password.change_success"))
|
||||
else:
|
||||
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
|
||||
return success(data=generated_password, msg="密码重置成功")
|
||||
return success(data=generated_password, msg=t("auth.password.reset_success"))
|
||||
|
||||
|
||||
@router.post("/verify_pwd", response_model=ApiResponse)
|
||||
def verify_pwd(
|
||||
request: VerifyPasswordRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""验证当前用户密码"""
|
||||
api_logger.info(f"用户验证密码请求: {current_user.username}")
|
||||
|
||||
is_valid = verify_password(request.password, current_user.hashed_password)
|
||||
api_logger.info(f"用户密码验证结果: {current_user.username}, valid={is_valid}")
|
||||
if not is_valid:
|
||||
raise BusinessException(t("users.errors.password_verification_failed"), code=BizCode.VALIDATION_FAILED)
|
||||
return success(data={"valid": is_valid}, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/send-email-code", response_model=ApiResponse)
|
||||
async def send_email_code(
|
||||
request: SendEmailCodeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""发送邮箱验证码"""
|
||||
api_logger.info(f"用户请求发送邮箱验证码: {current_user.username}, email={request.email}")
|
||||
|
||||
await user_service.send_email_code_method(db=db, email=request.email, user_id=current_user.id)
|
||||
|
||||
api_logger.info(f"邮箱验证码已发送: {current_user.username}")
|
||||
return success(msg=t("users.email.code_sent"))
|
||||
|
||||
|
||||
@router.put("/change-email", response_model=ApiResponse)
|
||||
async def change_email(
|
||||
request: VerifyEmailCodeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""验证验证码并修改邮箱"""
|
||||
api_logger.info(f"用户修改邮箱: {current_user.username}, new_email={request.new_email}")
|
||||
|
||||
await user_service.verify_and_change_email(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
new_email=request.new_email,
|
||||
code=request.code
|
||||
)
|
||||
|
||||
api_logger.info(f"用户邮箱修改成功: {current_user.username}")
|
||||
return success(msg=t("users.email.change_success"))
|
||||
|
||||
|
||||
|
||||
@router.get("/me/language", response_model=ApiResponse)
|
||||
def get_current_user_language(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前用户的语言偏好"""
|
||||
api_logger.info(f"获取用户语言偏好: {current_user.username}")
|
||||
|
||||
language = user_service.get_user_language_preference(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info(f"用户语言偏好获取成功: {current_user.username}, language={language}")
|
||||
return success(
|
||||
data=user_schema.LanguagePreferenceResponse(language=language),
|
||||
msg=t("users.language.get_success")
|
||||
)
|
||||
|
||||
|
||||
@router.put("/me/language", response_model=ApiResponse)
|
||||
def update_current_user_language(
|
||||
request: user_schema.LanguagePreferenceRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""设置当前用户的语言偏好"""
|
||||
api_logger.info(f"更新用户语言偏好: {current_user.username}, language={request.language}")
|
||||
|
||||
updated_user = user_service.update_user_language_preference(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
language=request.language,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info(f"用户语言偏好更新成功: {current_user.username}, language={request.language}")
|
||||
return success(
|
||||
data=user_schema.LanguagePreferenceResponse(language=updated_user.preferred_language),
|
||||
msg=t("users.language.update_success")
|
||||
)
|
||||
|
||||
@@ -5,26 +5,29 @@
|
||||
from typing import Optional
|
||||
import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, Depends,Header
|
||||
from fastapi import APIRouter, Depends, Header
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.api_key_utils import timestamp_to_datetime
|
||||
from app.services.memory_base_service import Translation_English
|
||||
from app.services.user_memory_service import (
|
||||
UserMemoryService,
|
||||
analytics_memory_types,
|
||||
analytics_graph_data,
|
||||
analytics_community_graph_data,
|
||||
)
|
||||
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
||||
from app.services.memory_entity_relationship_service import MemoryEntityService, MemoryEmotion, MemoryInteraction
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
||||
from app.repositories.workspace_repository import WorkspaceRepository
|
||||
from app.schemas.end_user_schema import (
|
||||
EndUserProfileResponse,
|
||||
EndUserProfileUpdate,
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas.end_user_info_schema import (
|
||||
EndUserInfoResponse,
|
||||
EndUserInfoCreate,
|
||||
EndUserInfoUpdate,
|
||||
)
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.dependencies import get_current_user
|
||||
@@ -44,10 +47,9 @@ router = APIRouter(
|
||||
|
||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||
async def get_memory_insight_report_api(
|
||||
end_user_id: str,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
获取缓存的记忆洞察报告
|
||||
@@ -55,18 +57,10 @@ async def get_memory_insight_report_api(
|
||||
此接口仅查询数据库中已缓存的记忆洞察数据,不执行生成操作。
|
||||
如需生成新的洞察报告,请使用专门的生成接口。
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
if workspace_models:
|
||||
model_id = workspace_models.get("llm", None)
|
||||
else:
|
||||
model_id = None
|
||||
api_logger.info(f"记忆洞察报告查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_memory_insight(db, end_user_id,model_id,language_type)
|
||||
result = await user_memory_service.get_cached_memory_insight(db, end_user_id)
|
||||
|
||||
if result["is_cached"]:
|
||||
api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}")
|
||||
@@ -81,17 +75,24 @@ async def get_memory_insight_report_api(
|
||||
|
||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||
async def get_user_summary_api(
|
||||
end_user_id: str,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
获取缓存的用户摘要
|
||||
|
||||
此接口仅查询数据库中已缓存的用户摘要数据,不执行生成操作。
|
||||
如需生成新的用户摘要,请使用专门的生成接口。
|
||||
|
||||
语言控制:
|
||||
- 使用 X-Language-Type Header 指定语言
|
||||
- 如果未传 Header,默认使用中文 (zh)
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
@@ -103,7 +104,7 @@ async def get_user_summary_api(
|
||||
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language_type)
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id, model_id, language)
|
||||
|
||||
if result["is_cached"]:
|
||||
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
||||
@@ -118,16 +119,24 @@ async def get_user_summary_api(
|
||||
|
||||
@router.post("/analytics/generate_cache", response_model=ApiResponse)
|
||||
async def generate_cache_api(
|
||||
request: GenerateCacheRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
request: GenerateCacheRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
手动触发缓存生成
|
||||
|
||||
- 如果提供 end_user_id,只为该用户生成
|
||||
- 如果不提供,为当前工作空间的所有用户生成
|
||||
|
||||
语言控制:
|
||||
- 使用 X-Language-Type Header 指定语言 ("zh" 中文, "en" 英文)
|
||||
- 如果未传 Header,默认使用中文 (zh)
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -139,7 +148,7 @@ async def generate_cache_api(
|
||||
|
||||
api_logger.info(
|
||||
f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, "
|
||||
f"end_user_id={end_user_id if end_user_id else '全部用户'}"
|
||||
f"end_user_id={end_user_id if end_user_id else '全部用户'}, language={language}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -148,10 +157,12 @@ async def generate_cache_api(
|
||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
||||
|
||||
# 生成记忆洞察
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id)
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id,
|
||||
language=language)
|
||||
|
||||
# 生成用户摘要
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id)
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id,
|
||||
language=language)
|
||||
|
||||
# 构建响应
|
||||
result = {
|
||||
@@ -185,7 +196,7 @@ async def generate_cache_api(
|
||||
# 为整个工作空间生成
|
||||
api_logger.info(f"开始为工作空间 {workspace_id} 批量生成缓存")
|
||||
|
||||
result = await user_memory_service.generate_cache_for_workspace(db, workspace_id)
|
||||
result = await user_memory_service.generate_cache_for_workspace(db, workspace_id, language=language)
|
||||
|
||||
# 记录统计信息
|
||||
api_logger.info(
|
||||
@@ -202,9 +213,9 @@ async def generate_cache_api(
|
||||
|
||||
@router.get("/analytics/node_statistics", response_model=ApiResponse)
|
||||
async def get_node_statistics_api(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -213,7 +224,8 @@ async def get_node_statistics_api(
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
||||
api_logger.info(
|
||||
f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
||||
|
||||
try:
|
||||
# 调用新的记忆类型统计函数
|
||||
@@ -221,21 +233,23 @@ async def get_node_statistics_api(
|
||||
|
||||
# 计算总数用于日志
|
||||
total_count = sum(item["count"] for item in result)
|
||||
api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
||||
api_logger.info(
|
||||
f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/graph_data", response_model=ApiResponse)
|
||||
async def get_graph_data_api(
|
||||
end_user_id: str,
|
||||
node_types: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
depth: int = 1,
|
||||
center_node_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
node_types: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
depth: int = 1,
|
||||
center_node_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -289,107 +303,165 @@ async def get_graph_data_api(
|
||||
return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/read_end_user/profile", response_model=ApiResponse)
|
||||
async def get_end_user_profile(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
@router.get("/analytics/community_graph", response_model=ApiResponse)
|
||||
async def get_community_graph_data_api(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
if workspace_models:
|
||||
model_id = workspace_models.get("llm", None)
|
||||
else:
|
||||
model_id = None
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间")
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询社区图谱但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"用户信息查询请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"社区图谱查询请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 查询终端用户
|
||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
result = await analytics_community_graph_data(db=db, end_user_id=end_user_id)
|
||||
|
||||
if not end_user:
|
||||
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
|
||||
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
|
||||
# 构建响应数据
|
||||
profile_data = EndUserProfileResponse(
|
||||
id=end_user.id,
|
||||
other_name=end_user.other_name,
|
||||
position=end_user.position,
|
||||
department=end_user.department,
|
||||
contact=end_user.contact,
|
||||
phone=end_user.phone,
|
||||
hire_date=end_user.hire_date,
|
||||
updatetime_profile=end_user.updatetime_profile
|
||||
if "message" in result and result["statistics"]["total_nodes"] == 0:
|
||||
api_logger.warning(f"社区图谱查询返回空结果: {result.get('message')}")
|
||||
return success(data=result, msg=result.get("message", "查询成功"))
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取社区图谱: end_user_id={end_user_id}, "
|
||||
f"nodes={result['statistics']['total_nodes']}, "
|
||||
f"edges={result['statistics']['total_edges']}"
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取用户信息: end_user_id={end_user_id}")
|
||||
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="查询成功")
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"用户信息查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息查询失败", str(e))
|
||||
api_logger.error(f"社区图谱查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "社区图谱查询失败", str(e))
|
||||
|
||||
#=======================终端用户信息接口=======================
|
||||
|
||||
@router.post("/updated_end_user/profile", response_model=ApiResponse)
|
||||
async def update_end_user_profile(
|
||||
profile_update: EndUserProfileUpdate,
|
||||
@router.get("/end_user_info", response_model=ApiResponse)
|
||||
async def get_end_user_info(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
更新终端用户的基本信息
|
||||
查询终端用户信息记录
|
||||
|
||||
该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。
|
||||
所有字段都是可选的,只更新提供的字段。
|
||||
根据 end_user_id 查询单条终端用户信息记录。
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
end_user_id = profile_update.end_user_id
|
||||
|
||||
# 验证工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间")
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询终端用户信息但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"用户信息更新请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"查询终端用户信息请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
# 调用 Service 层处理业务逻辑
|
||||
result = user_memory_service.update_end_user_profile(db, end_user_id, profile_update)
|
||||
# 校验 end_user 是否属于当前工作空间
|
||||
end_user_repo = EndUserRepository(db)
|
||||
end_user = end_user_repo.get_end_user_by_id(end_user_id)
|
||||
if end_user is None:
|
||||
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", "end_user not found")
|
||||
if str(end_user.workspace_id) != str(workspace_id):
|
||||
api_logger.warning(
|
||||
f"用户 {current_user.username} 尝试查询不属于工作空间 {workspace_id} 的终端用户 {end_user_id}"
|
||||
)
|
||||
return fail(BizCode.PERMISSION_DENIED, "该终端用户不属于当前工作空间", "end_user workspace mismatch")
|
||||
|
||||
result = user_memory_service.get_end_user_info(db, end_user_id)
|
||||
|
||||
if result["success"]:
|
||||
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}")
|
||||
api_logger.info(f"成功查询终端用户信息: end_user_id={end_user_id}")
|
||||
return success(data=result["data"], msg="查询成功")
|
||||
else:
|
||||
error_msg = result["error"]
|
||||
api_logger.error(f"查询终端用户信息失败: end_user_id={end_user_id}, error={error_msg}")
|
||||
|
||||
if error_msg == "终端用户信息记录不存在":
|
||||
return fail(BizCode.USER_NOT_FOUND, "终端用户信息记录不存在", error_msg)
|
||||
elif error_msg == "无效的终端用户ID格式":
|
||||
return fail(BizCode.INVALID_USER_ID, "无效的终端用户ID格式", error_msg)
|
||||
else:
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询终端用户信息失败", error_msg)
|
||||
|
||||
|
||||
@router.post("/end_user_info/updated", response_model=ApiResponse)
|
||||
async def update_end_user_info(
|
||||
info_update: EndUserInfoUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
更新终端用户信息记录
|
||||
|
||||
根据 end_user_id 更新终端用户信息记录,支持批量更新多个别名。
|
||||
|
||||
示例请求体:
|
||||
{
|
||||
"end_user_id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
|
||||
"other_name": "张三1",
|
||||
"aliases": ["小张", "张工"],
|
||||
"meta_data": {"position": "工程师", "department": "技术部"}
|
||||
}
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
end_user_id = info_update.end_user_id
|
||||
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新终端用户信息但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"更新终端用户信息请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
# 校验 end_user 是否属于当前工作空间
|
||||
end_user_repo = EndUserRepository(db)
|
||||
end_user = end_user_repo.get_end_user_by_id(end_user_id)
|
||||
if end_user is None:
|
||||
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", "end_user not found")
|
||||
if str(end_user.workspace_id) != str(workspace_id):
|
||||
api_logger.warning(
|
||||
f"用户 {current_user.username} 尝试更新不属于工作空间 {workspace_id} 的终端用户 {end_user_id}"
|
||||
)
|
||||
return fail(BizCode.PERMISSION_DENIED, "该终端用户不属于当前工作空间", "end_user workspace mismatch")
|
||||
|
||||
# 获取更新数据(排除 end_user_id)
|
||||
update_data = info_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
|
||||
|
||||
result = user_memory_service.update_end_user_info(db, end_user_id, update_data)
|
||||
|
||||
if result["success"]:
|
||||
api_logger.info(f"成功更新终端用户信息: end_user_id={end_user_id}")
|
||||
return success(data=result["data"], msg="更新成功")
|
||||
else:
|
||||
error_msg = result["error"]
|
||||
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
|
||||
api_logger.error(f"终端用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
|
||||
|
||||
# 根据错误类型映射到合适的业务错误码
|
||||
if error_msg == "终端用户不存在":
|
||||
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg)
|
||||
elif error_msg == "无效的用户ID格式":
|
||||
return fail(BizCode.INVALID_USER_ID, "无效的用户ID格式", error_msg)
|
||||
if error_msg == "终端用户信息记录不存在":
|
||||
return fail(BizCode.USER_NOT_FOUND, "终端用户信息记录不存在", error_msg)
|
||||
elif error_msg == "无效的终端用户ID格式":
|
||||
return fail(BizCode.INVALID_USER_ID, "无效的终端用户ID格式", error_msg)
|
||||
else:
|
||||
# 只有未预期的错误才使用 INTERNAL_ERROR
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
|
||||
return fail(BizCode.INTERNAL_ERROR, "终端用户信息更新失败", error_msg)
|
||||
|
||||
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
||||
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
workspace_id=current_user.current_workspace_id
|
||||
async def memory_space_timeline_of_shared_memories(
|
||||
id: str, label: str,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
@@ -398,14 +470,16 @@ async def memory_space_timeline_of_shared_memories(id: str, label: str,language_
|
||||
else:
|
||||
model_id = None
|
||||
MemoryEntity = MemoryEntityService(id, label)
|
||||
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language_type)
|
||||
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
|
||||
|
||||
return success(data=timeline_memories_result, msg="共同记忆时间线")
|
||||
|
||||
|
||||
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
|
||||
async def memory_space_relationship_evolution(id: str, label: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
try:
|
||||
api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")
|
||||
|
||||
|
||||
@@ -1,610 +0,0 @@
|
||||
"""
|
||||
工作流 API 控制器
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Path, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
|
||||
from app.models.user_model import User
|
||||
from app.models.app_model import App
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
from app.schemas.workflow_schema import (
|
||||
WorkflowConfigCreate,
|
||||
WorkflowConfigUpdate,
|
||||
WorkflowConfig,
|
||||
WorkflowValidationResponse,
|
||||
WorkflowExecution,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowExecutionRequest,
|
||||
WorkflowExecutionResponse
|
||||
)
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["workflow"])
|
||||
|
||||
|
||||
# ==================== 工作流配置管理 ====================
|
||||
|
||||
@router.post("/{app_id}/workflow")
|
||||
@cur_workspace_access_guard()
|
||||
async def create_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
config: WorkflowConfigCreate,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""创建工作流配置
|
||||
|
||||
创建或更新应用的工作流配置。配置会进行基础验证,但允许保存不完整的配置(草稿)。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 验证应用类型
|
||||
if app.type != "workflow":
|
||||
return fail(
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
msg=f"应用类型必须为 workflow,当前为 {app.type}"
|
||||
)
|
||||
|
||||
# 创建工作流配置
|
||||
workflow_config = service.create_workflow_config(
|
||||
app_id=app_id,
|
||||
nodes=[node.model_dump() for node in config.nodes],
|
||||
edges=[edge.model_dump() for edge in config.edges],
|
||||
variables=[var.model_dump() for var in config.variables],
|
||||
execution_config=config.execution_config.model_dump(),
|
||||
triggers=[trigger.model_dump() for trigger in config.triggers],
|
||||
validate=True # 进行基础验证
|
||||
)
|
||||
|
||||
return success(
|
||||
data=WorkflowConfig.model_validate(workflow_config),
|
||||
msg="工作流配置创建成功"
|
||||
)
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"创建工作流配置失败: {e.message}")
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"创建工作流配置异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"创建工作流配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# @router.get("/{app_id}/workflow")
|
||||
# async def get_workflow_config(
|
||||
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
# db: Annotated[Session, Depends(get_db)],
|
||||
# current_user: Annotated[User, Depends(get_current_user)]
|
||||
#
|
||||
# ):
|
||||
# """获取工作流配置
|
||||
#
|
||||
# 获取应用的工作流配置详情。
|
||||
# """
|
||||
# try:
|
||||
# # 验证应用是否存在且属于当前工作空间
|
||||
# app = db.query(App).filter(
|
||||
# App.id == app_id,
|
||||
# App.workspace_id == current_user.current_workspace_id,
|
||||
# App.is_active == True
|
||||
# ).first()
|
||||
#
|
||||
# if not app:
|
||||
# return fail(
|
||||
# code=BizCode.NOT_FOUND,
|
||||
# msg="应用不存在或无权访问"
|
||||
# )
|
||||
#
|
||||
# # 获取工作流配置
|
||||
# service = WorkflowService(db)
|
||||
# workflow_config = service.get_workflow_config(app_id)
|
||||
#
|
||||
# if not workflow_config:
|
||||
# return fail(
|
||||
# code=BizCode.NOT_FOUND,
|
||||
# msg="工作流配置不存在"
|
||||
# )
|
||||
#
|
||||
# return success(
|
||||
# data=WorkflowConfig.model_validate(workflow_config)
|
||||
# )
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"获取工作流配置异常: {e}", exc_info=True)
|
||||
# return fail(
|
||||
# code=BizCode.INTERNAL_ERROR,
|
||||
# msg=f"获取工作流配置失败: {str(e)}"
|
||||
# )
|
||||
|
||||
|
||||
# @router.put("/{app_id}/workflow")
|
||||
# async def update_workflow_config(
|
||||
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
# config: WorkflowConfigUpdate,
|
||||
# db: Annotated[Session, Depends(get_db)],
|
||||
# current_user: Annotated[User, Depends(get_current_user)],
|
||||
# service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
# ):
|
||||
# """更新工作流配置
|
||||
|
||||
# 更新应用的工作流配置。可以部分更新,未提供的字段保持不变。
|
||||
# """
|
||||
# try:
|
||||
# # 验证应用是否存在且属于当前工作空间
|
||||
# app = db.query(App).filter(
|
||||
# App.id == app_id,
|
||||
# App.workspace_id == current_user.current_workspace_id,
|
||||
# App.is_active == True
|
||||
# ).first()
|
||||
|
||||
# if not app:
|
||||
# return fail(
|
||||
# code=BizCode.NOT_FOUND,
|
||||
# msg="应用不存在或无权访问"
|
||||
# )
|
||||
|
||||
# # 更新工作流配置
|
||||
# workflow_config = service.update_workflow_config(
|
||||
# app_id=app_id,
|
||||
# nodes=[node.model_dump() for node in config.nodes] if config.nodes else None,
|
||||
# edges=[edge.model_dump() for edge in config.edges] if config.edges else None,
|
||||
# variables=[var.model_dump() for var in config.variables] if config.variables else None,
|
||||
# execution_config=config.execution_config.model_dump() if config.execution_config else None,
|
||||
# triggers=[trigger.model_dump() for trigger in config.triggers] if config.triggers else None,
|
||||
# validate=True
|
||||
# )
|
||||
|
||||
# return success(
|
||||
# data=WorkflowConfig.model_validate(workflow_config),
|
||||
# msg="工作流配置更新成功"
|
||||
# )
|
||||
|
||||
# except BusinessException as e:
|
||||
# logger.warning(f"更新工作流配置失败: {e.message}")
|
||||
# return fail(code=e.error_code, msg=e.message)
|
||||
# except Exception as e:
|
||||
# logger.error(f"更新工作流配置异常: {e}", exc_info=True)
|
||||
# return fail(
|
||||
# code=BizCode.INTERNAL_ERROR,
|
||||
# msg=f"更新工作流配置失败: {str(e)}"
|
||||
# )
|
||||
|
||||
|
||||
@router.delete("/{app_id}/workflow")
|
||||
async def delete_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""删除工作流配置
|
||||
|
||||
删除应用的工作流配置。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 删除工作流配置
|
||||
deleted = service.delete_workflow_config(app_id)
|
||||
|
||||
if not deleted:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="工作流配置不存在"
|
||||
)
|
||||
|
||||
return success(msg="工作流配置删除成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"删除工作流配置异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"删除工作流配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{app_id}/workflow/validate")
|
||||
async def validate_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
|
||||
):
|
||||
"""验证工作流配置
|
||||
|
||||
验证工作流配置是否有效。可以选择是否进行发布级别的严格验证。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 验证工作流配置
|
||||
|
||||
if for_publish:
|
||||
is_valid, errors = service.validate_workflow_config_for_publish(app_id)
|
||||
else:
|
||||
workflow_config = service.get_workflow_config(app_id)
|
||||
if not workflow_config:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="工作流配置不存在"
|
||||
)
|
||||
|
||||
from app.core.workflow.validator import validate_workflow_config as validate_config
|
||||
config_dict = {
|
||||
"nodes": workflow_config.nodes,
|
||||
"edges": workflow_config.edges,
|
||||
"variables": workflow_config.variables,
|
||||
"execution_config": workflow_config.execution_config,
|
||||
"triggers": workflow_config.triggers
|
||||
}
|
||||
is_valid, errors = validate_config(config_dict, for_publish=False)
|
||||
|
||||
return success(
|
||||
data=WorkflowValidationResponse(
|
||||
is_valid=is_valid,
|
||||
errors=errors,
|
||||
warnings=[]
|
||||
)
|
||||
)
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"验证工作流配置失败: {e.message}")
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"验证工作流配置异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"验证工作流配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# ==================== 工作流执行管理 ====================
|
||||
|
||||
@router.get("/{app_id}/workflow/executions")
|
||||
async def get_workflow_executions(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||
offset: Annotated[int, Query(ge=0)] = 0
|
||||
):
|
||||
"""获取工作流执行记录列表
|
||||
|
||||
获取应用的工作流执行历史记录。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 获取执行记录
|
||||
executions = service.get_executions_by_app(app_id, limit, offset)
|
||||
|
||||
# 获取统计信息
|
||||
statistics = service.get_execution_statistics(app_id)
|
||||
|
||||
return success(
|
||||
data={
|
||||
"executions": [WorkflowExecution.model_validate(e) for e in executions],
|
||||
"statistics": statistics,
|
||||
"pagination": {
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"total": statistics["total"]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工作流执行记录异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"获取工作流执行记录失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/workflow/executions/{execution_id}")
|
||||
async def get_workflow_execution(
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""获取工作流执行详情
|
||||
|
||||
获取单个工作流执行的详细信息,包括所有节点的执行记录。
|
||||
"""
|
||||
try:
|
||||
# 获取执行记录
|
||||
execution = service.get_execution(execution_id)
|
||||
|
||||
if not execution:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="执行记录不存在"
|
||||
)
|
||||
|
||||
# 验证应用是否属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == execution.app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="无权访问该执行记录"
|
||||
)
|
||||
|
||||
# 获取节点执行记录
|
||||
node_executions = service.node_execution_repo.get_by_execution_id(execution.id)
|
||||
|
||||
return success(
|
||||
data={
|
||||
"execution": WorkflowExecution.model_validate(execution),
|
||||
"node_executions": [
|
||||
WorkflowNodeExecution.model_validate(ne) for ne in node_executions
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工作流执行详情异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"获取工作流执行详情失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# ==================== 工作流执行 ====================
|
||||
@router.post("/{app_id}/workflow/run")
|
||||
async def run_workflow(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
request: WorkflowExecutionRequest,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""执行工作流
|
||||
|
||||
执行工作流并返回结果。支持流式和非流式两种模式。
|
||||
|
||||
**非流式模式**:等待工作流执行完成后返回完整结果。
|
||||
|
||||
**流式模式**:实时返回执行过程中的事件(节点开始、节点完成、工作流完成等)。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 验证应用类型
|
||||
if app.type != "workflow":
|
||||
return fail(
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
msg=f"应用类型必须为 workflow,当前为 {app.type}"
|
||||
)
|
||||
|
||||
# 准备输入数据
|
||||
input_data = {
|
||||
"message": request.message or "",
|
||||
"variables": request.variables
|
||||
}
|
||||
|
||||
# 执行工作流
|
||||
|
||||
if request.stream:
|
||||
# 流式执行
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
|
||||
async def event_generator():
|
||||
"""生成 SSE 事件
|
||||
|
||||
SSE 格式:
|
||||
event: <event_type>
|
||||
data: <json_data>
|
||||
|
||||
支持的事件类型:
|
||||
- workflow_start: 工作流开始
|
||||
- workflow_end: 工作流结束
|
||||
- node_start: 节点开始执行
|
||||
- node_end: 节点执行完成
|
||||
- node_chunk: 中间节点的流式输出
|
||||
- message: 最终消息的流式输出(End 节点及其相邻节点)
|
||||
"""
|
||||
try:
|
||||
async for event in await service.run_workflow(
|
||||
app_id=app_id,
|
||||
input_data=input_data,
|
||||
triggered_by=current_user.id,
|
||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||
stream=True
|
||||
):
|
||||
# 提取事件类型和数据
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
# 转换为标准 SSE 格式(字符串)
|
||||
# event: <type>
|
||||
# data: <json>
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||
yield sse_message
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式执行异常: {e}", exc_info=True)
|
||||
# 发送错误事件
|
||||
sse_error = f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
|
||||
yield sse_error
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no" # 禁用 nginx 缓冲
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 非流式执行
|
||||
result = await service.run_workflow(
|
||||
app_id=app_id,
|
||||
input_data=input_data,
|
||||
triggered_by=current_user.id,
|
||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||
stream=False
|
||||
)
|
||||
|
||||
return success(
|
||||
data=WorkflowExecutionResponse(
|
||||
execution_id=result["execution_id"],
|
||||
status=result["status"],
|
||||
output=result.get("output"),
|
||||
output_data=result.get("output_data"),
|
||||
error_message=result.get("error_message"),
|
||||
elapsed_time=result.get("elapsed_time"),
|
||||
token_usage=result.get("token_usage")
|
||||
),
|
||||
msg="工作流执行完成"
|
||||
)
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"执行工作流失败: {e.message}")
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"执行工作流异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"执行工作流失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/workflow/executions/{execution_id}/cancel")
|
||||
async def cancel_workflow_execution(
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""取消工作流执行
|
||||
|
||||
取消正在运行的工作流执行。
|
||||
|
||||
**注意**:当前版本仅更新状态为 cancelled,实际的执行取消功能待实现。
|
||||
"""
|
||||
try:
|
||||
# 获取执行记录
|
||||
execution = service.get_execution(execution_id)
|
||||
|
||||
if not execution:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="执行记录不存在"
|
||||
)
|
||||
|
||||
# 验证应用是否属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == execution.app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="无权访问该执行记录"
|
||||
)
|
||||
|
||||
# 检查执行状态
|
||||
if execution.status not in ["pending", "running"]:
|
||||
return fail(
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
msg=f"无法取消状态为 {execution.status} 的执行"
|
||||
)
|
||||
|
||||
# 更新状态为 cancelled
|
||||
service.update_execution_status(execution_id, "cancelled")
|
||||
|
||||
return success(msg="工作流执行已取消")
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"取消工作流执行失败: {e.message}")
|
||||
return fail(code=e.code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"取消工作流执行异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"取消工作流执行失败: {str(e)}"
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
@@ -14,6 +14,12 @@ from app.dependencies import (
|
||||
get_current_user,
|
||||
workspace_access_guard,
|
||||
)
|
||||
from app.i18n.dependencies import get_current_language, get_translator
|
||||
from app.i18n.serializers import (
|
||||
WorkspaceSerializer,
|
||||
WorkspaceMemberSerializer,
|
||||
WorkspaceInviteSerializer
|
||||
)
|
||||
from app.models.tenant_model import Tenants
|
||||
from app.models.user_model import User
|
||||
from app.models.workspace_model import InviteStatus
|
||||
@@ -29,6 +35,7 @@ from app.schemas.workspace_schema import (
|
||||
WorkspaceUpdate,
|
||||
)
|
||||
from app.services import workspace_service
|
||||
from app.core.quota_stub import check_workspace_quota
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -65,7 +72,9 @@ def get_workspaces(
|
||||
include_current: bool = Query(True, description="是否包含当前工作空间"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_tenant: Tenants = Depends(get_current_tenant)
|
||||
current_tenant: Tenants = Depends(get_current_tenant),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前租户下用户参与的所有工作空间
|
||||
|
||||
@@ -88,25 +97,51 @@ def get_workspaces(
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取 {len(workspaces)} 个工作空间")
|
||||
workspaces_schema = [WorkspaceResponse.model_validate(w) for w in workspaces]
|
||||
return success(data=workspaces_schema, msg="工作空间列表获取成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceSerializer()
|
||||
workspaces_data = [WorkspaceResponse.model_validate(w).model_dump() for w in workspaces]
|
||||
workspaces_i18n = serializer.serialize_list(workspaces_data, language)
|
||||
|
||||
return success(data=workspaces_i18n, msg=t("workspace.list_retrieved"))
|
||||
|
||||
|
||||
@router.post("", response_model=ApiResponse)
|
||||
@check_workspace_quota
|
||||
def create_workspace(
|
||||
workspace: WorkspaceCreate,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""创建新的工作空间"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求创建工作空间: {workspace.name}")
|
||||
from app.core.language_utils import get_language_from_header
|
||||
|
||||
# 验证并获取语言参数
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求创建工作空间: {workspace.name}, "
|
||||
f"language={language}"
|
||||
)
|
||||
|
||||
result = workspace_service.create_workspace(
|
||||
db=db, workspace=workspace, user=current_user)
|
||||
db=db, workspace=workspace, user=current_user, language=language
|
||||
)
|
||||
|
||||
api_logger.info(f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, 创建者: {current_user.username}")
|
||||
result_schema = WorkspaceResponse.model_validate(result)
|
||||
return success(data=result_schema, msg="工作空间创建成功")
|
||||
api_logger.info(
|
||||
f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, "
|
||||
f"创建者: {current_user.username}, language={language}"
|
||||
)
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceSerializer()
|
||||
result_data = WorkspaceResponse.model_validate(result).model_dump()
|
||||
result_i18n = serializer.serialize(result_data, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.created"))
|
||||
|
||||
@router.put("", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
@@ -114,6 +149,8 @@ def update_workspace(
|
||||
workspace: WorkspaceUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""更新工作空间"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -126,14 +163,21 @@ def update_workspace(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间更新成功 - ID: {workspace_id}, 用户: {current_user.username}")
|
||||
result_schema = WorkspaceResponse.model_validate(result)
|
||||
return success(data=result_schema, msg="工作空间更新成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceSerializer()
|
||||
result_data = WorkspaceResponse.model_validate(result).model_dump()
|
||||
result_i18n = serializer.serialize(result_data, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.updated"))
|
||||
|
||||
@router.get("/members", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def get_cur_workspace_members(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取工作空间成员列表(关系序列化)"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {current_user.current_workspace_id} 的成员列表")
|
||||
@@ -144,8 +188,14 @@ def get_cur_workspace_members(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间成员列表获取成功 - ID: {current_user.current_workspace_id}, 数量: {len(members)}")
|
||||
|
||||
# 转换为表格项并使用序列化器添加国际化字段
|
||||
table_items = _convert_members_to_table_items(members)
|
||||
return success(data=table_items, msg="工作空间成员列表获取成功")
|
||||
serializer = WorkspaceMemberSerializer()
|
||||
members_data = [item.model_dump() for item in table_items]
|
||||
members_i18n = serializer.serialize_list(members_data, language)
|
||||
|
||||
return success(data=members_i18n, msg=t("workspace.members.list_retrieved"))
|
||||
|
||||
|
||||
@router.put("/members", response_model=ApiResponse)
|
||||
@@ -155,6 +205,7 @@ def update_workspace_members(
|
||||
updates: List[WorkspaceMemberUpdate],
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的成员角色")
|
||||
@@ -165,27 +216,28 @@ def update_workspace_members(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间成员角色更新成功 - ID: {workspace_id}, 数量: {len(members)}")
|
||||
return success(msg="成员角色更新成功")
|
||||
return success(msg=t("workspace.members.role_updated"))
|
||||
|
||||
|
||||
@router.delete("/members/{member_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def delete_workspace_member(
|
||||
async def delete_workspace_member(
|
||||
member_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||
|
||||
workspace_service.delete_workspace_member(
|
||||
await workspace_service.delete_workspace_member(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
member_id=member_id,
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间成员删除成功 - ID: {workspace_id}, 成员: {member_id}")
|
||||
return success(msg="成员删除成功")
|
||||
return success(msg=t("workspace.members.deleted"))
|
||||
|
||||
|
||||
# 创建空间协作邀请
|
||||
@@ -195,6 +247,8 @@ def create_workspace_invite(
|
||||
invite_data: WorkspaceInviteCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""创建工作空间邀请"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -207,7 +261,12 @@ def create_workspace_invite(
|
||||
user=current_user
|
||||
)
|
||||
api_logger.info(f"工作空间邀请创建成功 - 工作空间: {workspace_id}, 邮箱: {invite_data.email}")
|
||||
return success(data=result, msg="邀请创建成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
result_i18n = serializer.serialize(result, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.invites.created"))
|
||||
|
||||
|
||||
@router.get("/invites", response_model=ApiResponse)
|
||||
@@ -219,6 +278,8 @@ def get_workspace_invites(
|
||||
offset: int = Query(0, ge=0),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取工作空间邀请列表"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -233,18 +294,30 @@ def get_workspace_invites(
|
||||
offset=offset
|
||||
)
|
||||
api_logger.info(f"成功获取 {len(invites)} 个邀请记录")
|
||||
return success(data=invites, msg="邀请列表获取成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
invites_i18n = serializer.serialize_list(invites, language)
|
||||
|
||||
return success(data=invites_i18n, msg=t("workspace.invites.list_retrieved"))
|
||||
|
||||
|
||||
@public_router.get("/invites/validate/{token}", response_model=ApiResponse)
|
||||
def get_workspace_invite_info(
|
||||
token: str,
|
||||
db: Session = Depends(get_db),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取工作空间邀请用户信息(无需认证)"""
|
||||
result = workspace_service.validate_invite_token(db=db, token=token)
|
||||
api_logger.info(f"工作空间邀请验证成功 - 邀请: {token}")
|
||||
return success(data=result, msg="邀请验证成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
result_i18n = serializer.serialize(result, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.invites.validated"))
|
||||
|
||||
|
||||
@router.delete("/invites/{invite_id}", response_model=ApiResponse)
|
||||
@@ -254,6 +327,8 @@ def revoke_workspace_invite(
|
||||
invite_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""撤销工作空间邀请"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -266,7 +341,12 @@ def revoke_workspace_invite(
|
||||
user=current_user
|
||||
)
|
||||
api_logger.info(f"工作空间邀请撤销成功 - 邀请: {invite_id}")
|
||||
return success(data=result, msg="邀请撤销成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
result_i18n = serializer.serialize(result, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.invites.revoked"))
|
||||
|
||||
# ==================== 公开邀请接口(无需认证) ====================
|
||||
|
||||
@@ -289,6 +369,7 @@ def switch_workspace(
|
||||
workspace_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""切换工作空间"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求切换工作空间为 {workspace_id}")
|
||||
@@ -299,7 +380,7 @@ def switch_workspace(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"成功切换工作空间为 {workspace_id}")
|
||||
return success(msg="工作空间切换成功")
|
||||
return success(msg=t("workspace.switched"))
|
||||
|
||||
|
||||
@router.get("/storage", response_model=ApiResponse)
|
||||
@@ -307,6 +388,7 @@ def switch_workspace(
|
||||
def get_workspace_storage_type(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前工作空间的存储类型"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -318,7 +400,7 @@ def get_workspace_storage_type(
|
||||
user=current_user
|
||||
)
|
||||
api_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {storage_type}")
|
||||
return success(data={"storage_type": storage_type}, msg="存储类型获取成功")
|
||||
return success(data={"storage_type": storage_type}, msg=t("workspace.storage.type_retrieved"))
|
||||
|
||||
|
||||
@router.get("/workspace_models", response_model=ApiResponse)
|
||||
@@ -326,6 +408,8 @@ def get_workspace_storage_type(
|
||||
def workspace_models_configs(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前工作空间的模型配置(llm, embedding, rerank)"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -341,14 +425,14 @@ def workspace_models_configs(
|
||||
api_logger.warning(f"工作空间 {workspace_id} 不存在或无权访问")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作空间不存在或无权访问"
|
||||
detail=t("workspace.not_found")
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取工作空间 {workspace_id} 的模型配置: "
|
||||
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
|
||||
)
|
||||
return success(data=WorkspaceModelsConfig.model_validate(configs), msg="模型配置获取成功")
|
||||
return success(data=WorkspaceModelsConfig.model_validate(configs), msg=t("workspace.models.config_retrieved"))
|
||||
|
||||
|
||||
@router.put("/workspace_models", response_model=ApiResponse)
|
||||
@@ -357,6 +441,7 @@ def update_workspace_models_configs(
|
||||
models_update: WorkspaceModelsUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""更新当前工作空间的模型配置(llm, embedding, rerank)"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -373,5 +458,5 @@ def update_workspace_models_configs(
|
||||
f"成功更新工作空间 {workspace_id} 的模型配置: "
|
||||
f"llm={updated_workspace.llm}, embedding={updated_workspace.embedding}, rerank={updated_workspace.rerank}"
|
||||
)
|
||||
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg="模型配置更新成功")
|
||||
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg=t("workspace.models.config_updated"))
|
||||
|
||||
|
||||
4
api/app/core/__init__.py
Normal file
4
api/app/core/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 16:24
|
||||
162
api/app/core/agent/agent_middleware.py
Normal file
162
api/app/core/agent/agent_middleware.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""Agent Middleware - 动态技能过滤"""
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
|
||||
from app.services.skill_service import SkillService
|
||||
from app.repositories.skill_repository import SkillRepository
|
||||
|
||||
|
||||
class AgentMiddleware:
|
||||
"""Agent 中间件 - 用于动态过滤和加载技能"""
|
||||
|
||||
def __init__(self, skills: Optional[dict] = None):
|
||||
"""
|
||||
初始化中间件
|
||||
|
||||
Args:
|
||||
skills: 技能配置字典 {"enabled": bool, "all_skills": bool, "skill_ids": [...]}
|
||||
"""
|
||||
self.skills = skills or {}
|
||||
self.enabled = self.skills.get('enabled', False)
|
||||
self.all_skills = self.skills.get('all_skills', False)
|
||||
self.skill_ids = self.skills.get('skill_ids', [])
|
||||
|
||||
@staticmethod
|
||||
def filter_tools(
|
||||
tools: List,
|
||||
message: str = "",
|
||||
skill_configs: Dict[str, Any] = None,
|
||||
tool_to_skill_map: Dict[str, str] = None
|
||||
) -> tuple[List, List[str]]:
|
||||
"""
|
||||
根据消息内容和技能配置动态过滤工具
|
||||
|
||||
Args:
|
||||
tools: 所有可用工具列表
|
||||
message: 用户消息(可用于智能过滤)
|
||||
skill_configs: 技能配置字典 {skill_id: {"keywords": [...], "enabled": True, "prompt": "..."}}
|
||||
tool_to_skill_map: 工具到技能的映射 {tool_name: skill_id}
|
||||
|
||||
Returns:
|
||||
(过滤后的工具列表, 激活的技能ID列表)
|
||||
"""
|
||||
if not tools:
|
||||
return [], []
|
||||
|
||||
# 如果没有技能配置,返回所有工具
|
||||
if not skill_configs:
|
||||
return tools, []
|
||||
|
||||
# 基于关键词匹配激活技能
|
||||
activated_skill_ids = []
|
||||
message_lower = message.lower()
|
||||
|
||||
for skill_id, config in skill_configs.items():
|
||||
if not config.get('enabled', True):
|
||||
continue
|
||||
|
||||
keywords = config.get('keywords', [])
|
||||
# 如果没有关键词限制,或消息包含关键词,则激活该技能
|
||||
if not keywords or any(kw.lower() in message_lower for kw in keywords):
|
||||
activated_skill_ids.append(skill_id)
|
||||
|
||||
# 如果没有工具映射关系,返回所有工具
|
||||
if not tool_to_skill_map:
|
||||
return tools, activated_skill_ids
|
||||
|
||||
# 根据激活的技能过滤工具
|
||||
filtered_tools = []
|
||||
for tool in tools:
|
||||
tool_name = getattr(tool, 'name', str(id(tool)))
|
||||
# 如果工具不属于任何skill(base_tools),或者工具所属的skill被激活,则保留
|
||||
if tool_name not in tool_to_skill_map or tool_to_skill_map[tool_name] in activated_skill_ids:
|
||||
filtered_tools.append(tool)
|
||||
|
||||
return filtered_tools, activated_skill_ids
|
||||
|
||||
def load_skill_tools(self, db, tenant_id: uuid.UUID, base_tools: List = None) -> tuple[List, Dict[str, Any], Dict[str, str]]:
|
||||
"""
|
||||
加载技能关联的工具
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
tenant_id: 租户id
|
||||
base_tools: 基础工具列表
|
||||
|
||||
Returns:
|
||||
(工具列表, 技能配置字典, 工具到技能的映射 {tool_name: skill_id})
|
||||
"""
|
||||
|
||||
tools_dict = {}
|
||||
tool_to_skill_map = {} # 工具名称到技能ID的映射
|
||||
|
||||
if base_tools:
|
||||
for tool in base_tools:
|
||||
tool_name = getattr(tool, 'name', str(id(tool)))
|
||||
tools_dict[tool_name] = tool
|
||||
# base_tools 不属于任何 skill,不加入映射
|
||||
|
||||
skill_configs = {}
|
||||
skill_ids_to_load = []
|
||||
|
||||
# 如果启用技能且 all_skills 为 True,加载租户下所有激活的技能
|
||||
if self.enabled and self.all_skills:
|
||||
skills, _ = SkillRepository.list_skills(db, tenant_id, is_active=True, page=1, pagesize=1000)
|
||||
skill_ids_to_load = [str(skill.id) for skill in skills]
|
||||
elif self.enabled and self.skill_ids:
|
||||
skill_ids_to_load = self.skill_ids
|
||||
|
||||
if skill_ids_to_load:
|
||||
for skill_id in skill_ids_to_load:
|
||||
try:
|
||||
skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id)
|
||||
if skill and skill.is_active:
|
||||
# 保存技能配置(包含prompt)
|
||||
config = skill.config or {}
|
||||
config['prompt'] = skill.prompt
|
||||
config['name'] = skill.name
|
||||
skill_configs[skill_id] = config
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 加载技能工具并获取映射关系
|
||||
skill_tools, skill_tool_map = SkillService.load_skill_tools(db, skill_ids_to_load, tenant_id)
|
||||
|
||||
# 只添加不冲突的 skill_tools
|
||||
for tool in skill_tools:
|
||||
tool_name = getattr(tool, 'name', str(id(tool)))
|
||||
if tool_name not in tools_dict:
|
||||
tools_dict[tool_name] = tool
|
||||
# 复制映射关系
|
||||
if tool_name in skill_tool_map:
|
||||
tool_to_skill_map[tool_name] = skill_tool_map[tool_name]
|
||||
|
||||
return list(tools_dict.values()), skill_configs, tool_to_skill_map
|
||||
|
||||
@staticmethod
|
||||
def get_active_prompts(activated_skill_ids: List[str], skill_configs: Dict[str, Any]) -> str:
|
||||
"""
|
||||
根据激活的技能ID获取对应的提示词
|
||||
|
||||
Args:
|
||||
activated_skill_ids: 被激活的技能ID列表
|
||||
skill_configs: 技能配置字典
|
||||
|
||||
Returns:
|
||||
合并后的提示词
|
||||
"""
|
||||
prompts = []
|
||||
for skill_id in activated_skill_ids:
|
||||
config = skill_configs.get(skill_id, {})
|
||||
prompt = config.get('prompt')
|
||||
name = config.get('name', 'Skill')
|
||||
if prompt:
|
||||
prompts.append(f"# {name}\n{prompt}")
|
||||
|
||||
return "\n\n".join(prompts) if prompts else ""
|
||||
|
||||
@staticmethod
|
||||
def create_runnable():
|
||||
"""创建可运行的中间件"""
|
||||
return RunnablePassthrough()
|
||||
@@ -7,26 +7,18 @@ LangChain Agent 封装
|
||||
- 支持流式输出
|
||||
- 使用 RedBearLLM 支持多提供商
|
||||
"""
|
||||
import os
|
||||
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.errors import GraphRecursionError
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -34,16 +26,23 @@ logger = get_business_logger()
|
||||
class LangChainAgent:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
api_key: str,
|
||||
provider: str = "openai",
|
||||
api_base: Optional[str] = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
streaming: bool = False
|
||||
self,
|
||||
model_name: str,
|
||||
api_key: str,
|
||||
provider: str = "openai",
|
||||
api_base: Optional[str] = None,
|
||||
is_omni: bool = False,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
streaming: bool = False,
|
||||
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
|
||||
max_tool_consecutive_calls: int = 3, # 单个工具最大连续调用次数
|
||||
deep_thinking: bool = False, # 是否启用深度思考模式
|
||||
thinking_budget_tokens: Optional[int] = None, # 深度思考 token 预算
|
||||
json_output: bool = False, # 是否强制 JSON 输出
|
||||
capability: Optional[List[str]] = None # 模型能力列表,用于校验是否支持深度思考
|
||||
):
|
||||
"""初始化 LangChain Agent
|
||||
|
||||
@@ -56,28 +55,71 @@ class LangChainAgent:
|
||||
max_tokens: 最大 token 数
|
||||
system_prompt: 系统提示词
|
||||
tools: 工具列表(可选,框架自动走 ReAct 循环)
|
||||
streaming: 是否启用流式输出(默认 True)
|
||||
streaming: 是否启用流式输出
|
||||
max_iterations: 最大迭代次数(None 表示自动计算:基础 5 次 + 每个工具 2 次)
|
||||
max_tool_consecutive_calls: 单个工具最大连续调用次数(默认 3 次)
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.provider = provider
|
||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||
self.tools = tools or []
|
||||
self.streaming = streaming
|
||||
self.is_omni = is_omni
|
||||
self.max_tool_consecutive_calls = max_tool_consecutive_calls
|
||||
|
||||
# 创建 RedBearLLM(支持多提供商)
|
||||
# 工具调用计数器:记录每个工具的连续调用次数
|
||||
self.tool_call_counter: Dict[str, int] = {}
|
||||
self.last_tool_called: Optional[str] = None
|
||||
|
||||
# 根据工具数量动态调整最大迭代次数
|
||||
# 基础值 + 每个工具额外的调用机会
|
||||
if max_iterations is None:
|
||||
# 自动计算:基础 5 次 + 每个工具 2 次额外机会
|
||||
self.max_iterations = 5 + len(self.tools) * 2
|
||||
else:
|
||||
self.max_iterations = max_iterations
|
||||
|
||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||
|
||||
# ChatTongyi 要求 messages 含 'json' 字样才能使用 response_format
|
||||
# 在 system prompt 中注入 JSON 要求
|
||||
from app.models.models_model import ModelProvider
|
||||
if json_output and (
|
||||
(provider.lower() == ModelProvider.DASHSCOPE and not is_omni)
|
||||
or provider.lower() == ModelProvider.VOLCANO
|
||||
# 有工具时 response_format 会被移除,所有 provider 都需要 system prompt 注入保证 JSON 输出
|
||||
or bool(tools)
|
||||
):
|
||||
self.system_prompt += "\n请以JSON格式输出。"
|
||||
|
||||
logger.debug(
|
||||
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
|
||||
f"tool_count={len(self.tools)}, "
|
||||
f"max_tool_consecutive_calls={self.max_tool_consecutive_calls}, "
|
||||
f"auto_calculated={max_iterations is None}"
|
||||
)
|
||||
|
||||
# 创建 RedBearLLM,capability 校验由 RedBearModelConfig 统一处理
|
||||
model_config = RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
is_omni=is_omni,
|
||||
capability=capability,
|
||||
deep_thinking=deep_thinking,
|
||||
thinking_budget_tokens=thinking_budget_tokens,
|
||||
json_output=json_output,
|
||||
extra_params={
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"streaming": streaming # 使用参数控制流式
|
||||
"streaming": streaming
|
||||
}
|
||||
)
|
||||
|
||||
self.llm = RedBearLLM(model_config, type=ModelType.CHAT)
|
||||
# 从经过校验的 config 读取实际生效的能力开关
|
||||
self.deep_thinking = model_config.deep_thinking
|
||||
self.json_output = model_config.json_output
|
||||
|
||||
# 获取底层模型用于真正的流式调用
|
||||
self._underlying_llm = self.llm._model if hasattr(self.llm, '_model') else self.llm
|
||||
@@ -86,11 +128,14 @@ class LangChainAgent:
|
||||
if streaming and hasattr(self._underlying_llm, 'streaming'):
|
||||
self._underlying_llm.streaming = True
|
||||
|
||||
# 包装工具以跟踪连续调用次数
|
||||
wrapped_tools = self._wrap_tools_with_tracking(self.tools) if self.tools else None
|
||||
|
||||
# 使用 create_agent 创建 agent graph(LangChain 1.x 标准方式)
|
||||
# 无论是否有工具,都使用 agent 统一处理
|
||||
self.agent = create_agent(
|
||||
model=self.llm,
|
||||
tools=self.tools if self.tools else None,
|
||||
tools=wrapped_tools,
|
||||
system_prompt=self.system_prompt
|
||||
)
|
||||
|
||||
@@ -102,17 +147,92 @@ class LangChainAgent:
|
||||
"has_api_base": bool(api_base),
|
||||
"temperature": temperature,
|
||||
"streaming": streaming,
|
||||
"max_iterations": self.max_iterations,
|
||||
"max_tool_consecutive_calls": self.max_tool_consecutive_calls,
|
||||
"tool_count": len(self.tools),
|
||||
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
|
||||
"tool_count": len(self.tools)
|
||||
# "tool_count": len(self.tools)
|
||||
}
|
||||
)
|
||||
|
||||
def _wrap_tools_with_tracking(self, tools: Sequence[BaseTool]) -> List[BaseTool]:
|
||||
"""包装工具以跟踪连续调用次数
|
||||
|
||||
Args:
|
||||
tools: 原始工具列表
|
||||
|
||||
Returns:
|
||||
List[BaseTool]: 包装后的工具列表
|
||||
"""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from functools import wraps
|
||||
|
||||
wrapped_tools = []
|
||||
|
||||
for original_tool in tools:
|
||||
tool_name = original_tool.name
|
||||
original_func = original_tool.func if hasattr(original_tool, 'func') else None
|
||||
|
||||
if not original_func:
|
||||
# 如果无法获取原始函数,直接使用原工具
|
||||
wrapped_tools.append(original_tool)
|
||||
continue
|
||||
|
||||
# 创建包装函数
|
||||
def make_wrapped_func(tool_name, original_func):
|
||||
"""创建包装函数的工厂函数,避免闭包问题"""
|
||||
|
||||
@wraps(original_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
"""包装后的工具函数,跟踪连续调用次数"""
|
||||
# 检查是否是连续调用同一个工具
|
||||
if self.last_tool_called == tool_name:
|
||||
self.tool_call_counter[tool_name] = self.tool_call_counter.get(tool_name, 0) + 1
|
||||
else:
|
||||
# 切换到新工具,重置计数器
|
||||
self.tool_call_counter[tool_name] = 1
|
||||
self.last_tool_called = tool_name
|
||||
|
||||
current_count = self.tool_call_counter[tool_name]
|
||||
|
||||
logger.debug(
|
||||
f"工具调用: {tool_name}, 连续调用次数: {current_count}/{self.max_tool_consecutive_calls}"
|
||||
)
|
||||
|
||||
# 检查是否超过最大连续调用次数
|
||||
if current_count > self.max_tool_consecutive_calls:
|
||||
logger.warning(
|
||||
f"工具 '{tool_name}' 连续调用次数已达上限 ({self.max_tool_consecutive_calls}),"
|
||||
f"返回提示信息"
|
||||
)
|
||||
return (
|
||||
f"工具 '{tool_name}' 已连续调用 {self.max_tool_consecutive_calls} 次,"
|
||||
f"未找到有效结果。请尝试其他方法或直接回答用户的问题。"
|
||||
)
|
||||
|
||||
# 调用原始工具函数
|
||||
return original_func(*args, **kwargs)
|
||||
|
||||
return wrapped_func
|
||||
|
||||
# 使用 StructuredTool 创建新工具
|
||||
wrapped_tool = StructuredTool(
|
||||
name=original_tool.name,
|
||||
description=original_tool.description,
|
||||
func=make_wrapped_func(tool_name, original_func),
|
||||
args_schema=original_tool.args_schema if hasattr(original_tool, 'args_schema') else None
|
||||
)
|
||||
|
||||
wrapped_tools.append(wrapped_tool)
|
||||
|
||||
return wrapped_tools
|
||||
|
||||
def _prepare_messages(
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
) -> List[BaseMessage]:
|
||||
"""准备消息列表
|
||||
|
||||
@@ -120,14 +240,12 @@ class LangChainAgent:
|
||||
message: 用户消息
|
||||
history: 历史消息列表
|
||||
context: 上下文信息
|
||||
files: 多模态文件内容列表(已处理)
|
||||
|
||||
Returns:
|
||||
List[BaseMessage]: 消息列表
|
||||
"""
|
||||
messages = []
|
||||
|
||||
# 添加系统提示词
|
||||
messages.append(SystemMessage(content=self.system_prompt))
|
||||
messages: list = []
|
||||
|
||||
# 添加历史消息
|
||||
if history:
|
||||
@@ -142,112 +260,94 @@ class LangChainAgent:
|
||||
if context:
|
||||
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
|
||||
|
||||
messages.append(HumanMessage(content=user_content))
|
||||
# 构建用户消息(支持多模态)
|
||||
if files and len(files) > 0:
|
||||
content_parts = self._build_multimodal_content(user_content, files)
|
||||
messages.append(HumanMessage(content=content_parts))
|
||||
else:
|
||||
# 纯文本消息
|
||||
messages.append(HumanMessage(content=user_content))
|
||||
|
||||
return messages
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# async def term_memory_save(self,messages,end_user_end,aimessages):
|
||||
# '''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j'''
|
||||
# end_user_end=f"Term_{end_user_end}"
|
||||
# print(messages)
|
||||
# print(aimessages)
|
||||
# session_id = store.save_session(
|
||||
# userid=end_user_end,
|
||||
# messages=messages,
|
||||
# apply_id=end_user_end,
|
||||
# end_user_id=end_user_end,
|
||||
# aimessages=aimessages
|
||||
# )
|
||||
# store.delete_duplicate_sessions()
|
||||
# # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
||||
# return session_id
|
||||
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# async def term_memory_redis_read(self,end_user_end):
|
||||
# end_user_end = f"Term_{end_user_end}"
|
||||
# history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
|
||||
# # logger.info(f'Redis_Agent:{end_user_end};{history}')
|
||||
# messagss_list=[]
|
||||
# retrieved_content=[]
|
||||
# for messages in history:
|
||||
# query = messages.get("Query")
|
||||
# aimessages = messages.get("Answer")
|
||||
# messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
|
||||
# retrieved_content.append({query: aimessages})
|
||||
# return messagss_list,retrieved_content
|
||||
@staticmethod
|
||||
def _extract_tokens_from_message(msg) -> int:
|
||||
"""从 AIMessage 或类似对象中提取 total_tokens,兼容多种 provider 格式
|
||||
|
||||
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
|
||||
支持的格式:
|
||||
- response_metadata.token_usage.total_tokens (OpenAI/ChatOpenAI)
|
||||
- response_metadata.usage.total_tokens (部分 provider)
|
||||
- usage_metadata.total_tokens (LangChain 新版)
|
||||
"""
|
||||
写入记忆(支持结构化消息)
|
||||
total = 0
|
||||
# 1. response_metadata
|
||||
response_meta = getattr(msg, "response_metadata", None)
|
||||
if response_meta and isinstance(response_meta, dict):
|
||||
# 尝试 token_usage 路径
|
||||
token_usage = response_meta.get("token_usage") or response_meta.get("usage", {})
|
||||
if isinstance(token_usage, dict):
|
||||
total = token_usage.get("total_tokens", 0)
|
||||
# 2. usage_metadata(LangChain 新版 AIMessage 属性)
|
||||
if not total:
|
||||
usage_meta = getattr(msg, "usage_metadata", None)
|
||||
if usage_meta:
|
||||
if isinstance(usage_meta, dict):
|
||||
total = usage_meta.get("total_tokens", 0)
|
||||
else:
|
||||
total = getattr(usage_meta, "total_tokens", 0)
|
||||
return total or 0
|
||||
|
||||
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
构建多模态消息内容
|
||||
|
||||
Args:
|
||||
storage_type: 存储类型 (neo4j/rag)
|
||||
end_user_id: 终端用户ID
|
||||
user_message: 用户消息内容
|
||||
ai_message: AI 回复内容
|
||||
user_rag_memory_id: RAG 记忆ID
|
||||
actual_end_user_id: 实际用户ID
|
||||
actual_config_id: 配置ID
|
||||
|
||||
逻辑说明:
|
||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||
- Neo4j 模式:使用结构化消息列表
|
||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||
text: 文本内容
|
||||
files: 文件列表(已由 MultimodalService 处理为对应 provider 的格式)
|
||||
|
||||
Returns:
|
||||
List[Dict]: 消息内容列表
|
||||
"""
|
||||
if storage_type == "rag":
|
||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
else:
|
||||
# Neo4j 模式:使用结构化消息列表
|
||||
structured_messages = []
|
||||
# 根据 provider 使用不同的文本格式
|
||||
# if (self.provider.lower() in [ModelProvider.BEDROCK, ModelProvider.OPENAI, ModelProvider.XINFERENCE,
|
||||
# ModelProvider.GPUSTACK] or (
|
||||
# self.provider.lower() == ModelProvider.DASHSCOPE and self.is_omni)):
|
||||
# # Anthropic/Bedrock/Xinference/Gpustack/Openai: {"type": "text", "text": "..."}
|
||||
# content_parts = [{"type": "text", "text": text}]
|
||||
# else:
|
||||
# # 通义千问等: {"text": "..."}
|
||||
# content_parts = [{"type": "text", "text": text}]
|
||||
content_parts = [{"type": "text", "text": text}]
|
||||
|
||||
# 始终添加用户消息(如果不为空)
|
||||
if user_message:
|
||||
structured_messages.append({"role": "user", "content": user_message})
|
||||
# 添加文件内容
|
||||
# MultimodalService 已经根据 provider 返回了正确格式,直接使用
|
||||
content_parts.extend(files)
|
||||
|
||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||
if ai_message:
|
||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||
logger.debug(
|
||||
f"构建多模态消息: provider={self.provider}, "
|
||||
f"parts={len(content_parts)}, "
|
||||
f"files={len(files)}"
|
||||
)
|
||||
|
||||
# 如果没有消息,直接返回
|
||||
if not structured_messages:
|
||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||
return
|
||||
return content_parts
|
||||
|
||||
# 调用 Celery 任务,传递结构化消息列表
|
||||
# 数据流:
|
||||
# 1. structured_messages 传递给 write_message_task
|
||||
# 2. write_message_task 调用 memory_agent_service.write_memory
|
||||
# 3. write_memory 调用 write_tools.write,传递 messages 参数
|
||||
# 4. write_tools.write 调用 get_chunked_dialogs,传递 messages 参数
|
||||
# 5. get_chunked_dialogs 为每条消息创建独立的 Chunk,设置 speaker 字段
|
||||
# 6. 每个 Chunk 保存到 Neo4j,包含 speaker 字段
|
||||
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||
write_id = write_message_task.delay(
|
||||
actual_end_user_id, # end_user_id: 用户ID
|
||||
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
actual_config_id, # config_id: 配置ID
|
||||
storage_type, # storage_type: "neo4j"
|
||||
user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||
)
|
||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
@staticmethod
|
||||
def _extract_reasoning_content(msg) -> str:
|
||||
"""从 AIMessage 中提取深度思考内容(reasoning_content)
|
||||
|
||||
所有 provider 统一通过 additional_kwargs.reasoning_content 传递:
|
||||
- DeepSeek-R1 / QwQ: 原生字段
|
||||
- Volcano (Doubao-thinking): 由 VolcanoChatOpenAI 从 delta.reasoning_content 注入
|
||||
"""
|
||||
additional = getattr(msg, "additional_kwargs", None) or {}
|
||||
return additional.get("reasoning_content") or additional.get("reasoning", "")
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
config_id: Optional[str] = None, # 添加这个参数
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""执行对话
|
||||
|
||||
@@ -255,59 +355,15 @@ class LangChainAgent:
|
||||
message: 用户消息
|
||||
history: 历史消息列表 [{"role": "user/assistant", "content": "..."}]
|
||||
context: 上下文信息(如知识库检索结果)
|
||||
files: 多模态文件
|
||||
|
||||
Returns:
|
||||
Dict: 包含 content 和元数据的字典
|
||||
"""
|
||||
message_chat= message
|
||||
start_time = time.time()
|
||||
actual_config_id = config_id
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
if actual_config_id is None and end_user_id:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码
|
||||
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||
# history_term_memory = history_term_memory_result[0]
|
||||
# db_for_memory = next(get_db())
|
||||
# if memory_flag:
|
||||
# if len(history_term_memory)>=4 and storage_type != "rag":
|
||||
# history_term_memory = ';'.join(history_term_memory)
|
||||
# retrieved_content = history_term_memory_result[1]
|
||||
# print(retrieved_content)
|
||||
# # 为长期记忆操作获取新的数据库连接
|
||||
# try:
|
||||
# repo = LongTermMemoryRepository(db_for_memory)
|
||||
# repo.upsert(end_user_id, retrieved_content)
|
||||
# logger.info(
|
||||
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to write to LongTermMemory: {e}")
|
||||
# raise
|
||||
# finally:
|
||||
# db_for_memory.close()
|
||||
|
||||
# # 长期记忆写入(
|
||||
# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||
# # 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表
|
||||
messages = self._prepare_messages(message, history, context)
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
|
||||
logger.debug(
|
||||
"准备调用 LangChain Agent",
|
||||
@@ -315,27 +371,84 @@ class LangChainAgent:
|
||||
"has_context": bool(context),
|
||||
"has_history": bool(history),
|
||||
"has_tools": bool(self.tools),
|
||||
"message_count": len(messages)
|
||||
"has_files": bool(files),
|
||||
"message_count": len(messages),
|
||||
"max_iterations": self.max_iterations
|
||||
}
|
||||
)
|
||||
|
||||
# 统一使用 agent.invoke 调用
|
||||
result = await self.agent.ainvoke({"messages": messages})
|
||||
# 通过 recursion_limit 限制最大迭代次数,防止工具调用死循环
|
||||
try:
|
||||
result = await self.agent.ainvoke(
|
||||
{"messages": messages},
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
)
|
||||
except (RecursionError, GraphRecursionError) as e:
|
||||
logger.warning(
|
||||
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),可能存在工具调用循环",
|
||||
extra={"error": str(e)}
|
||||
)
|
||||
# 返回一个友好的错误提示
|
||||
return {
|
||||
"content": f"抱歉,我在处理您的请求时遇到了问题。已达到最大处理步骤限制({self.max_iterations}次)。请尝试简化您的问题或稍后再试。",
|
||||
"model": self.model_name,
|
||||
"elapsed_time": time.time() - start_time,
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
}
|
||||
|
||||
# 获取最后的 AI 消息
|
||||
output_messages = result.get("messages", [])
|
||||
content = ""
|
||||
|
||||
logger.debug(f"输出消息数量: {len(output_messages)}")
|
||||
total_tokens = 0
|
||||
reasoning_content = ""
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
content = msg.content
|
||||
logger.debug(f"找到 AI 消息,content 类型: {type(msg.content)}")
|
||||
logger.debug(f"AI 消息内容: {msg.content}")
|
||||
|
||||
# 处理多模态响应:content 可能是字符串或列表
|
||||
if isinstance(msg.content, str):
|
||||
content = msg.content
|
||||
logger.debug(f"提取字符串内容,长度: {len(content)}")
|
||||
elif isinstance(msg.content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
logger.debug(f"多模态响应,列表长度: {len(msg.content)}")
|
||||
text_parts = []
|
||||
for item in msg.content:
|
||||
logger.debug(f"处理项: {item}")
|
||||
if isinstance(item, dict):
|
||||
# 通义千问格式: {"text": "..."}
|
||||
if "text" in item:
|
||||
text = item.get("text", "")
|
||||
text_parts.append(text)
|
||||
logger.debug(f"提取文本: {text[:100]}...")
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
text_parts.append(text)
|
||||
logger.debug(f"提取文本: {text[:100]}...")
|
||||
elif isinstance(item, str):
|
||||
text_parts.append(item)
|
||||
logger.debug(f"提取字符串: {item[:100]}...")
|
||||
content = "".join(text_parts)
|
||||
logger.debug(f"合并后内容长度: {len(content)}")
|
||||
else:
|
||||
content = str(msg.content)
|
||||
logger.debug(f"转换为字符串: {content[:100]}...")
|
||||
total_tokens = self._extract_tokens_from_message(msg)
|
||||
reasoning_content = self._extract_reasoning_content(msg) if self.deep_thinking else ""
|
||||
break
|
||||
|
||||
logger.info(f"最终提取的内容长度: {len(content)}")
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
if memory_flag:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# await self.term_memory_save(message_chat, end_user_id, content)
|
||||
response = {
|
||||
"content": content,
|
||||
"model": self.model_name,
|
||||
@@ -343,9 +456,11 @@ class LangChainAgent:
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
}
|
||||
if reasoning_content:
|
||||
response["reasoning_content"] = reasoning_content
|
||||
|
||||
logger.debug(
|
||||
"Agent 调用完成",
|
||||
@@ -362,25 +477,24 @@ class LangChainAgent:
|
||||
raise
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
end_user_id:Optional[str] = None,
|
||||
config_id: Optional[str] = None,
|
||||
storage_type:Optional[str] = None,
|
||||
user_rag_memory_id:Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True
|
||||
) -> AsyncGenerator[str, None]:
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
) -> AsyncGenerator[str | int | dict[str, str], None]:
|
||||
"""执行流式对话
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
history: 历史消息列表
|
||||
context: 上下文信息
|
||||
files: 多模态文件
|
||||
|
||||
Yields:
|
||||
str: 消息内容块
|
||||
int: token 统计
|
||||
Dict: 深度思考内容 {"type": "reasoning", "content": "..."}
|
||||
"""
|
||||
logger.info("=" * 80)
|
||||
logger.info(" chat_stream 方法开始执行")
|
||||
@@ -388,98 +502,129 @@ class LangChainAgent:
|
||||
logger.info(f" Has tools: {bool(self.tools)}")
|
||||
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
|
||||
logger.info("=" * 80)
|
||||
message_chat = message
|
||||
actual_config_id = config_id
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
if actual_config_id is None and end_user_id:
|
||||
try:
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
# # TODO 乐力齐
|
||||
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||
# history_term_memory = history_term_memory_result[0]
|
||||
# if memory_flag:
|
||||
# if len(history_term_memory) >= 4 and storage_type != "rag":
|
||||
# history_term_memory = ';'.join(history_term_memory)
|
||||
# retrieved_content = history_term_memory_result[1]
|
||||
# db_for_memory = next(get_db())
|
||||
# try:
|
||||
# repo = LongTermMemoryRepository(db_for_memory)
|
||||
# repo.upsert(end_user_id, retrieved_content)
|
||||
# logger.info(
|
||||
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
# # 长期记忆写入
|
||||
# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id)
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to write to long term memory: {e}")
|
||||
# finally:
|
||||
# db_for_memory.close()
|
||||
|
||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表
|
||||
messages = self._prepare_messages(message, history, context)
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
|
||||
logger.debug(
|
||||
f"准备流式调用,has_tools={bool(self.tools)}, message_count={len(messages)}"
|
||||
f"准备流式调用,has_tools={bool(self.tools)}, has_files={bool(files)}, message_count={len(messages)}"
|
||||
)
|
||||
|
||||
chunk_count = 0
|
||||
yielded_content = False
|
||||
|
||||
# 统一使用 agent 的 astream_events 实现流式输出
|
||||
logger.debug("使用 Agent astream_events 实现流式输出")
|
||||
full_content=''
|
||||
full_content = ''
|
||||
full_reasoning = ''
|
||||
try:
|
||||
last_event = {}
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
version="v2"
|
||||
{"messages": messages},
|
||||
version="v2",
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
):
|
||||
last_event = event
|
||||
chunk_count += 1
|
||||
kind = event.get("event")
|
||||
|
||||
|
||||
# 处理所有可能的流式事件
|
||||
if kind == "on_chat_model_stream":
|
||||
# LLM 流式输出
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
full_content+=chunk.content
|
||||
if chunk and hasattr(chunk, "content") and chunk.content:
|
||||
yield chunk.content
|
||||
yielded_content = True
|
||||
|
||||
if chunk and hasattr(chunk, "content"):
|
||||
# 提取深度思考内容(仅在启用深度思考时)
|
||||
if self.deep_thinking:
|
||||
reasoning_chunk = self._extract_reasoning_content(chunk)
|
||||
if reasoning_chunk:
|
||||
full_reasoning += reasoning_chunk
|
||||
yield {"type": "reasoning", "content": reasoning_chunk}
|
||||
|
||||
# 处理多模态响应:content 可能是字符串或列表
|
||||
chunk_content = chunk.content
|
||||
if isinstance(chunk_content, str) and chunk_content:
|
||||
full_content += chunk_content
|
||||
yield chunk_content
|
||||
elif isinstance(chunk_content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
for item in chunk_content:
|
||||
if isinstance(item, dict):
|
||||
# 通义千问格式: {"text": "..."}
|
||||
if "text" in item:
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
elif isinstance(item, str):
|
||||
full_content += item
|
||||
yield item
|
||||
|
||||
elif kind == "on_llm_stream":
|
||||
# 另一种 LLM 流式事件
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
if chunk:
|
||||
if hasattr(chunk, "content") and chunk.content:
|
||||
full_content+=chunk.content
|
||||
yield chunk.content
|
||||
yielded_content = True
|
||||
if hasattr(chunk, "content"):
|
||||
# 提取深度思考内容(仅在启用深度思考时)
|
||||
if self.deep_thinking:
|
||||
reasoning_chunk = self._extract_reasoning_content(chunk)
|
||||
if reasoning_chunk:
|
||||
full_reasoning += reasoning_chunk
|
||||
yield {"type": "reasoning", "content": reasoning_chunk}
|
||||
|
||||
chunk_content = chunk.content
|
||||
if isinstance(chunk_content, str) and chunk_content:
|
||||
full_content += chunk_content
|
||||
yield chunk_content
|
||||
elif isinstance(chunk_content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
for item in chunk_content:
|
||||
if isinstance(item, dict):
|
||||
# 通义千问格式: {"text": "..."}
|
||||
if "text" in item:
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
elif isinstance(item, str):
|
||||
full_content += item
|
||||
yield item
|
||||
elif isinstance(chunk, str):
|
||||
full_content += chunk
|
||||
yield chunk
|
||||
yielded_content = True
|
||||
|
||||
|
||||
# 记录工具调用(可选)
|
||||
elif kind == "on_tool_start":
|
||||
logger.debug(f"工具调用开始: {event.get('name')}")
|
||||
elif kind == "on_tool_end":
|
||||
logger.debug(f"工具调用结束: {event.get('name')}")
|
||||
|
||||
|
||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||
if memory_flag:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# await self.term_memory_save(message_chat, end_user_id, full_content)
|
||||
|
||||
# 统计token消耗
|
||||
output_messages = last_event.get("data", {}).get("output", {}).get("messages", [])
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
stream_total_tokens = self._extract_tokens_from_message(msg)
|
||||
logger.info(f"流式 token 统计: total_tokens={stream_total_tokens}")
|
||||
yield stream_total_tokens
|
||||
break
|
||||
|
||||
except GraphRecursionError:
|
||||
logger.warning(
|
||||
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),模型可能不支持正确的工具调用停止判断"
|
||||
)
|
||||
if not full_content:
|
||||
yield "抱歉,我在处理您的请求时遇到了问题(已达最大处理步骤限制)。请尝试简化问题或更换模型后重试。"
|
||||
except Exception as e:
|
||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
@@ -493,5 +638,3 @@ class LangChainAgent:
|
||||
logger.info("=" * 80)
|
||||
logger.info("chat_stream 方法执行结束")
|
||||
logger.info("=" * 80)
|
||||
|
||||
|
||||
|
||||
@@ -70,6 +70,8 @@ def require_api_key(
|
||||
})
|
||||
raise BusinessException("API Key 无效或已过期", BizCode.API_KEY_INVALID)
|
||||
|
||||
ApiKeyAuthService.check_app_published(db, api_key_obj)
|
||||
|
||||
if scopes:
|
||||
missing_scopes = []
|
||||
for scope in scopes:
|
||||
@@ -97,7 +99,7 @@ def require_api_key(
|
||||
)
|
||||
|
||||
rate_limiter = RateLimiterService()
|
||||
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj)
|
||||
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj, db=db)
|
||||
if not is_allowed:
|
||||
logger.warning("API Key 限流触发", extra={
|
||||
"api_key_id": str(api_key_obj.id),
|
||||
@@ -106,10 +108,12 @@ def require_api_key(
|
||||
"error_msg": error_msg
|
||||
})
|
||||
# 根据错误消息判断限流类型
|
||||
if "QPS" in error_msg:
|
||||
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
|
||||
elif "Daily" in error_msg:
|
||||
if "Daily" in error_msg:
|
||||
code = BizCode.API_KEY_DAILY_LIMIT_EXCEEDED
|
||||
elif "Tenant" in error_msg:
|
||||
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED # 租户套餐速率超限,同属 QPS 类
|
||||
elif "QPS" in error_msg:
|
||||
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
|
||||
else:
|
||||
code = BizCode.API_KEY_QUOTA_EXCEEDED
|
||||
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
"""API Key 工具函数"""
|
||||
import secrets
|
||||
import uuid as _uuid
|
||||
from typing import Optional, Union
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session as _Session
|
||||
from app.core.error_codes import BizCode as _BizCode
|
||||
from app.core.exceptions import BusinessException as _BusinessException
|
||||
from app.models.end_user_model import EndUser as _EndUser
|
||||
from app.repositories.end_user_repository import EndUserRepository as _EndUserRepository
|
||||
|
||||
from app.models.api_key_model import ApiKeyType
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
@@ -65,3 +72,72 @@ def datetime_to_timestamp(dt: Optional[datetime]) -> Optional[int]:
|
||||
return None
|
||||
|
||||
return int(dt.timestamp() * 1000)
|
||||
|
||||
|
||||
def get_current_user_from_api_key(db: _Session, api_key_auth):
    """Build a ``current_user`` object from an API-Key authentication.

    Looks up the admin user who created the API key and attaches the
    key's workspace as that user's active workspace context, making the
    result equivalent to the user produced by the internal
    ``Depends(get_current_user)`` (JWT) dependency.

    Args:
        db: Database session.
        api_key_auth: API-Key authentication info (``ApiKeyAuth``).

    Returns:
        User ORM object with ``current_workspace_id`` set.
    """
    # Function-scope import — presumably to avoid a circular import with
    # the service layer (NOTE(review): confirm against module layout).
    from app.services import api_key_service

    workspace_id = api_key_auth.workspace_id
    key_record = api_key_service.ApiKeyService.get_api_key(
        db, api_key_auth.api_key_id, workspace_id
    )
    creator = key_record.creator
    creator.current_workspace_id = workspace_id
    return creator
|
||||
|
||||
|
||||
def validate_end_user_in_workspace(
    db: _Session,
    end_user_id: str,
    workspace_id,
) -> _EndUser:
    """Ensure an end user exists and belongs to the given workspace.

    Args:
        db: Database session.
        end_user_id: End-user identifier (must be a UUID string).
        workspace_id: Workspace identifier (UUID object or string).

    Returns:
        The ``EndUser`` ORM object when every check passes.

    Raises:
        BusinessException(INVALID_PARAMETER): ``end_user_id`` is not a
            valid UUID string.
        BusinessException(USER_NOT_FOUND): no such end user exists.
        BusinessException(PERMISSION_DENIED): the end user belongs to a
            different workspace.
    """
    # Guard: the id must parse as a UUID before we touch the database.
    try:
        _uuid.UUID(end_user_id)
    except (ValueError, AttributeError):
        raise _BusinessException(
            f"Invalid end_user_id format: {end_user_id}",
            _BizCode.INVALID_PARAMETER,
        )

    found = _EndUserRepository(db).get_end_user_by_id(end_user_id)
    if found is None:
        raise _BusinessException(
            "End user not found",
            _BizCode.USER_NOT_FOUND,
        )

    # Compare as strings so a UUID object and its string form match.
    if str(found.workspace_id) != str(workspace_id):
        raise _BusinessException(
            "End user does not belong to this workspace",
            _BizCode.PERMISSION_DENIED,
        )

    return found
|
||||
@@ -1,9 +1,9 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field, TypeAdapter
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -16,18 +16,18 @@ class Settings:
|
||||
# cloud: SaaS 云服务版(全功能,按量计费)
|
||||
# enterprise: 企业私有化版(License 控制)
|
||||
DEPLOYMENT_MODE: str = os.getenv("DEPLOYMENT_MODE", "community")
|
||||
|
||||
|
||||
# License 配置(企业版)
|
||||
LICENSE_FILE: str = os.getenv("LICENSE_FILE", "/etc/app/license.json")
|
||||
LICENSE_SERVER_URL: str = os.getenv("LICENSE_SERVER_URL", "https://license.yourcompany.com")
|
||||
|
||||
|
||||
# 计费服务配置(SaaS 版)
|
||||
BILLING_SERVICE_URL: str = os.getenv("BILLING_SERVICE_URL", "")
|
||||
|
||||
|
||||
# 基础 URL(用于 SSO 回调等)
|
||||
BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000")
|
||||
FRONTEND_URL: str = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
|
||||
|
||||
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
|
||||
# API Keys Configuration
|
||||
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
|
||||
@@ -57,7 +57,6 @@ class Settings:
|
||||
REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
|
||||
REDIS_DB: int = int(os.getenv("REDIS_DB", "1"))
|
||||
REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "")
|
||||
|
||||
|
||||
# ElasticSearch configuration
|
||||
ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1")
|
||||
@@ -91,13 +90,14 @@ class Settings:
|
||||
|
||||
# Single Sign-On configuration
|
||||
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
|
||||
|
||||
|
||||
# SSO 免登配置
|
||||
SSO_TOKEN_EXPIRE_SECONDS: int = int(os.getenv("SSO_TOKEN_EXPIRE_SECONDS", "300"))
|
||||
SSO_TRUSTED_SOURCES_CONFIG: str = os.getenv("SSO_TRUSTED_SOURCES_CONFIG", "{}")
|
||||
|
||||
# File Upload
|
||||
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
||||
MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20"))
|
||||
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
||||
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
||||
|
||||
@@ -115,6 +115,7 @@ class Settings:
|
||||
S3_ACCESS_KEY_ID: str = os.getenv("S3_ACCESS_KEY_ID", "")
|
||||
S3_SECRET_ACCESS_KEY: str = os.getenv("S3_SECRET_ACCESS_KEY", "")
|
||||
S3_BUCKET_NAME: str = os.getenv("S3_BUCKET_NAME", "")
|
||||
S3_ENDPOINT_URL: str = os.getenv("S3_ENDPOINT_URL", "")
|
||||
|
||||
# VOLC ASR settings
|
||||
VOLC_APP_KEY: str = os.getenv("VOLC_APP_KEY", "")
|
||||
@@ -130,7 +131,7 @@ class Settings:
|
||||
|
||||
# Server Configuration
|
||||
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
|
||||
FILE_LOCAL_SERVER_URL : str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
|
||||
FILE_LOCAL_SERVER_URL: str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
|
||||
|
||||
# ========================================================================
|
||||
# Internal Configuration (not in .env, used by application code)
|
||||
@@ -157,6 +158,49 @@ class Settings:
|
||||
if origin.strip()
|
||||
]
|
||||
|
||||
# Language Configuration
|
||||
# Supported values: "zh" (Chinese), "en" (English)
|
||||
# This controls the language used for memory summary titles and other generated content
|
||||
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
|
||||
|
||||
# ========================================================================
|
||||
# Internationalization (i18n) Configuration
|
||||
# ========================================================================
|
||||
# Default language for API responses
|
||||
I18N_DEFAULT_LANGUAGE: str = os.getenv("I18N_DEFAULT_LANGUAGE", "zh")
|
||||
|
||||
# Supported languages (comma-separated)
|
||||
I18N_SUPPORTED_LANGUAGES: list[str] = [
|
||||
lang.strip()
|
||||
for lang in os.getenv("I18N_SUPPORTED_LANGUAGES", "zh,en").split(",")
|
||||
if lang.strip()
|
||||
]
|
||||
|
||||
# Core locales directory (community edition)
|
||||
# Use absolute path to work from any working directory
|
||||
I18N_CORE_LOCALES_DIR: str = os.getenv(
|
||||
"I18N_CORE_LOCALES_DIR",
|
||||
os.path.join(os.path.dirname(os.path.dirname(__file__)), "locales")
|
||||
)
|
||||
|
||||
# Premium locales directory (enterprise edition, optional)
|
||||
I18N_PREMIUM_LOCALES_DIR: Optional[str] = os.getenv("I18N_PREMIUM_LOCALES_DIR", None)
|
||||
|
||||
# Enable translation cache
|
||||
I18N_ENABLE_TRANSLATION_CACHE: bool = os.getenv("I18N_ENABLE_TRANSLATION_CACHE", "true").lower() == "true"
|
||||
|
||||
# LRU cache size for hot translations
|
||||
I18N_LRU_CACHE_SIZE: int = int(os.getenv("I18N_LRU_CACHE_SIZE", "1000"))
|
||||
|
||||
# Enable hot reload of translation files
|
||||
I18N_ENABLE_HOT_RELOAD: bool = os.getenv("I18N_ENABLE_HOT_RELOAD", "false").lower() == "true"
|
||||
|
||||
# Fallback language when translation is missing
|
||||
I18N_FALLBACK_LANGUAGE: str = os.getenv("I18N_FALLBACK_LANGUAGE", "zh")
|
||||
|
||||
# Log missing translations
|
||||
I18N_LOG_MISSING_TRANSLATIONS: bool = os.getenv("I18N_LOG_MISSING_TRANSLATIONS", "true").lower() == "true"
|
||||
|
||||
# Logging settings
|
||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
||||
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
@@ -185,19 +229,47 @@ class Settings:
|
||||
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
|
||||
|
||||
# Celery configuration (internal)
|
||||
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
|
||||
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
|
||||
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
|
||||
# 详见 docs/celery-env-bug-report.md
|
||||
# 默认使用 Redis 作为 broker 和 backend,与业务缓存隔离
|
||||
# 如需使用 RabbitMQ,在 .env 中设置 CELERY_BROKER_URL=amqp://user:pass@host:5672/vhost
|
||||
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3"))
|
||||
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4"))
|
||||
|
||||
# SMTP Email Configuration
|
||||
SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com")
|
||||
SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587"))
|
||||
SMTP_USER: str = os.getenv("SMTP_USER", "")
|
||||
SMTP_PASSWORD: str = os.getenv("SMTP_PASSWORD", "")
|
||||
|
||||
SANDBOX_URL: str = os.getenv("SANDBOX_URL", "")
|
||||
|
||||
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
|
||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
|
||||
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||
|
||||
# Memory Cache Regeneration Configuration
|
||||
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
||||
|
||||
# Celery Beat Schedule Configuration (定时任务执行频率)
|
||||
MEMORY_INCREMENT_HOUR: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=0, le=23, description="cron hour [0, 23]")]
|
||||
).validate_python(int(os.getenv("MEMORY_INCREMENT_HOUR", "2")))
|
||||
MEMORY_INCREMENT_MINUTE: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=0, le=59, description="cron minute [0, 59]")]
|
||||
).validate_python(int(os.getenv("MEMORY_INCREMENT_MINUTE", "0")))
|
||||
WORKSPACE_REFLECTION_INTERVAL_SECONDS: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=1, description="reflection interval in seconds, must be >= 1")]
|
||||
).validate_python(int(os.getenv("WORKSPACE_REFLECTION_INTERVAL_SECONDS", "30")))
|
||||
FORGETTING_CYCLE_INTERVAL_HOURS: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=1, description="forgetting cycle interval in hours, must be >= 1")]
|
||||
).validate_python(int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24")))
|
||||
|
||||
IMPLICIT_EMOTIONS_UPDATE_HOUR: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_HOUR", "2"))
|
||||
# implicit_emotions_update: 每天几分执行(分钟,0-59)
|
||||
IMPLICIT_EMOTIONS_UPDATE_MINUTE: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_MINUTE", "0"))
|
||||
# Memory Module Configuration (internal)
|
||||
|
||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
|
||||
|
||||
@@ -210,9 +282,35 @@ class Settings:
|
||||
# official environment system version
|
||||
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1")
|
||||
|
||||
# model square loading
|
||||
LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true"
|
||||
|
||||
# workflow config
|
||||
WORKFLOW_IMPORT_CACHE_TIMEOUT: int = int(os.getenv("WORKFLOW_IMPORT_CACHE_TIMEOUT", 1800))
|
||||
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
||||
|
||||
# ========================================================================
|
||||
# General Ontology Type Configuration
|
||||
# ========================================================================
|
||||
# 通用本体文件路径列表(逗号分隔)
|
||||
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "api/app/core/memory/ontology_services/General_purpose_entity.ttl")
|
||||
|
||||
# 是否启用通用本体类型功能
|
||||
ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true"
|
||||
|
||||
# Prompt 中最大类型数量
|
||||
MAX_ONTOLOGY_TYPES_IN_PROMPT: int = int(os.getenv("MAX_ONTOLOGY_TYPES_IN_PROMPT", "50"))
|
||||
|
||||
# 核心通用类型列表(逗号分隔)
|
||||
CORE_GENERAL_TYPES: str = os.getenv(
|
||||
"CORE_GENERAL_TYPES",
|
||||
"Person,Organization,Company,GovernmentAgency,Place,Location,City,Country,Building,"
|
||||
"Event,SportsEvent,SocialEvent,Work,Book,Film,Software,Concept,TopicalConcept,AcademicSubject"
|
||||
)
|
||||
|
||||
# 实验模式开关(允许通过 API 动态切换本体配置)
|
||||
ONTOLOGY_EXPERIMENT_MODE: bool = os.getenv("ONTOLOGY_EXPERIMENT_MODE", "true").lower() == "true"
|
||||
|
||||
def get_memory_output_path(self, filename: str = "") -> str:
|
||||
"""
|
||||
Get the full path for memory module output files.
|
||||
|
||||
@@ -19,6 +19,7 @@ class BizCode(IntEnum):
|
||||
TENANT_NOT_FOUND = 3002
|
||||
WORKSPACE_NO_ACCESS = 3003
|
||||
WORKSPACE_INVITE_NOT_FOUND = 3004
|
||||
WORKSPACE_ACCESS_DENIED = 3005
|
||||
# API Key 管理(3xxx)
|
||||
API_KEY_NOT_FOUND = 3007
|
||||
API_KEY_DUPLICATE_NAME = 3008
|
||||
@@ -30,6 +31,9 @@ class BizCode(IntEnum):
|
||||
API_KEY_QPS_LIMIT_EXCEEDED = 3014
|
||||
API_KEY_DAILY_LIMIT_EXCEEDED = 3015
|
||||
API_KEY_QUOTA_EXCEEDED = 3016
|
||||
API_KEY_RATE_LIMIT_EXCEEDED = 3017
|
||||
QUOTA_EXCEEDED = 3018
|
||||
RATE_LIMIT_EXCEEDED = 3019
|
||||
# 资源(4xxx)
|
||||
NOT_FOUND = 4000
|
||||
USER_NOT_FOUND = 4001
|
||||
@@ -40,12 +44,14 @@ class BizCode(IntEnum):
|
||||
FILE_NOT_FOUND = 4006
|
||||
APP_NOT_FOUND = 4007
|
||||
RELEASE_NOT_FOUND = 4008
|
||||
USER_NO_ACCESS = 4009
|
||||
|
||||
# 冲突/状态(5xxx)
|
||||
DUPLICATE_NAME = 5001
|
||||
RESOURCE_ALREADY_EXISTS = 5002
|
||||
VERSION_ALREADY_EXISTS = 5003
|
||||
STATE_CONFLICT = 5004
|
||||
RESOURCE_IN_USE = 5005
|
||||
|
||||
# 应用发布(6xxx)
|
||||
PUBLISH_FAILED = 6001
|
||||
@@ -60,6 +66,7 @@ class BizCode(IntEnum):
|
||||
PERMISSION_DENIED = 6010
|
||||
INVALID_CONVERSATION = 6011
|
||||
CONFIG_MISSING = 6012
|
||||
APP_NOT_PUBLISHED = 6013
|
||||
|
||||
# 模型(7xxx)
|
||||
MODEL_CONFIG_INVALID = 7001
|
||||
@@ -112,8 +119,11 @@ HTTP_MAPPING = {
|
||||
BizCode.FORBIDDEN: 403,
|
||||
BizCode.TENANT_NOT_FOUND: 400,
|
||||
BizCode.WORKSPACE_NO_ACCESS: 403,
|
||||
BizCode.WORKSPACE_INVITE_NOT_FOUND: 400,
|
||||
BizCode.WORKSPACE_ACCESS_DENIED: 403,
|
||||
BizCode.NOT_FOUND: 400,
|
||||
BizCode.USER_NOT_FOUND: 200,
|
||||
BizCode.USER_NO_ACCESS: 401,
|
||||
BizCode.WORKSPACE_NOT_FOUND: 400,
|
||||
BizCode.MODEL_NOT_FOUND: 400,
|
||||
BizCode.KNOWLEDGE_NOT_FOUND: 400,
|
||||
@@ -125,6 +135,7 @@ HTTP_MAPPING = {
|
||||
BizCode.RESOURCE_ALREADY_EXISTS: 409,
|
||||
BizCode.VERSION_ALREADY_EXISTS: 409,
|
||||
BizCode.STATE_CONFLICT: 409,
|
||||
BizCode.RESOURCE_IN_USE: 409,
|
||||
BizCode.PUBLISH_FAILED: 500,
|
||||
BizCode.NO_DRAFT_TO_PUBLISH: 400,
|
||||
BizCode.ROLLBACK_TARGET_NOT_FOUND: 400,
|
||||
@@ -148,7 +159,8 @@ HTTP_MAPPING = {
|
||||
BizCode.API_KEY_QPS_LIMIT_EXCEEDED: 429,
|
||||
BizCode.API_KEY_DAILY_LIMIT_EXCEEDED: 429,
|
||||
BizCode.API_KEY_QUOTA_EXCEEDED: 429,
|
||||
|
||||
BizCode.QUOTA_EXCEEDED: 402,
|
||||
|
||||
BizCode.MODEL_CONFIG_INVALID: 400,
|
||||
BizCode.API_KEY_MISSING: 400,
|
||||
BizCode.PROVIDER_NOT_SUPPORTED: 400,
|
||||
@@ -177,4 +189,21 @@ HTTP_MAPPING = {
|
||||
BizCode.DB_ERROR: 500,
|
||||
BizCode.SERVICE_UNAVAILABLE: 503,
|
||||
BizCode.RATE_LIMITED: 429,
|
||||
BizCode.RATE_LIMIT_EXCEEDED: 429,
|
||||
}
|
||||
|
||||
ERROR_CODE_TO_BIZ_CODE = {
|
||||
"QUOTA_EXCEEDED": BizCode.QUOTA_EXCEEDED,
|
||||
"RATE_LIMIT_EXCEEDED": BizCode.RATE_LIMIT_EXCEEDED,
|
||||
"API_KEY_NOT_FOUND": BizCode.API_KEY_NOT_FOUND,
|
||||
"API_KEY_INVALID": BizCode.API_KEY_INVALID,
|
||||
"API_KEY_EXPIRED": BizCode.API_KEY_EXPIRED,
|
||||
"WORKSPACE_NOT_FOUND": BizCode.WORKSPACE_NOT_FOUND,
|
||||
"WORKSPACE_NO_ACCESS": BizCode.WORKSPACE_NO_ACCESS,
|
||||
"PERMISSION_DENIED": BizCode.PERMISSION_DENIED,
|
||||
"TOKEN_EXPIRED": BizCode.TOKEN_EXPIRED,
|
||||
"TOKEN_INVALID": BizCode.TOKEN_INVALID,
|
||||
"VALIDATION_FAILED": BizCode.VALIDATION_FAILED,
|
||||
"INVALID_PARAMETER": BizCode.INVALID_PARAMETER,
|
||||
"MISSING_PARAMETER": BizCode.MISSING_PARAMETER,
|
||||
}
|
||||
|
||||
82
api/app/core/language_utils.py
Normal file
82
api/app/core/language_utils.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""语言处理工具模块
|
||||
|
||||
本模块提供集中化的语言校验和处理功能,确保整个应用中语言参数的一致性。
|
||||
|
||||
Functions:
|
||||
validate_language: 校验语言参数,确保其为有效值
|
||||
get_language_from_header: 从请求头获取并校验语言参数
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Languages the application can serve.
SUPPORTED_LANGUAGES = {"zh", "en"}

# Language used whenever the requested one is missing or unsupported.
DEFAULT_LANGUAGE = "zh"


def validate_language(language: Optional[str]) -> str:
    """Normalize *language* to a supported code, falling back to the default.

    Args:
        language: Candidate language code; case-insensitive, surrounding
            whitespace ignored. May be None.

    Returns:
        A valid language code ("zh" or "en"): the normalized input when it
        is supported, otherwise ``DEFAULT_LANGUAGE``.

    Examples:
        >>> validate_language("zh")
        'zh'
        >>> validate_language(None)
        'zh'
    """
    if language is None:
        return DEFAULT_LANGUAGE

    # Normalize: lowercase and strip whitespace before the membership test.
    normalized = str(language).lower().strip()
    if normalized not in SUPPORTED_LANGUAGES:
        # Unsupported value: log and fall back rather than raising, so a
        # bad header can never break a request.
        logger.warning(
            f"无效的语言参数 '{language}',已回退到默认值 '{DEFAULT_LANGUAGE}'。"
            f"支持的语言: {SUPPORTED_LANGUAGES}"
        )
        return DEFAULT_LANGUAGE

    return normalized
|
||||
|
||||
|
||||
def get_language_from_header(language_type: Optional[str]) -> str:
    """Resolve the language carried by the ``X-Language-Type`` header.

    Thin convenience wrapper for the controller layer: delegates to
    :func:`validate_language`, so a missing or invalid header value
    silently falls back to the default language.

    Args:
        language_type: Raw value taken from the ``X-Language-Type``
            header, or None when the header was absent.

    Returns:
        A valid language code ("zh" or "en").
    """
    return validate_language(language_type)
|
||||
@@ -38,6 +38,56 @@ class SensitiveDataLoggingFilter(logging.Filter):
|
||||
return True
|
||||
|
||||
|
||||
class Neo4jSuccessNotificationFilter(logging.Filter):
    """Neo4j log filter: drop success/informational notifications, keep real warnings and errors.

    The Neo4j driver logs every database notification at WARNING level,
    including ones for perfectly successful operations. This filter rejects
    notifications carrying the following GQL status codes and lets everything
    else through:
    - 00000: successful completion
    - 00N00: no data
    - 00NA0: no data, informational notification

    Matching is done with compiled regular expressions so unrelated warnings
    that merely contain similar substrings are not filtered by accident.
    """

    import re

    # Pre-compiled patterns (compiled once at class-creation time).
    # Matches any "success/informational" GQL status code:
    # 00000 = successful completion, 00N00 = no data, 00NA0 = no-data informational
    GQL_STATUS_PATTERN = re.compile(r"gql_status=['\"](00000|00N00|00NA0)['\"]")

    # Matches status_description values describing successful completion / no data
    SUCCESS_DESC_PATTERN = re.compile(r"status_description=['\"]note:\s*(successful\s+completion|no\s+data)['\"]", re.IGNORECASE)

    def filter(self, record: logging.LogRecord) -> bool:
        """
        Filter Neo4j success notifications.

        Args:
            record: The log record under consideration.

        Returns:
            True to let the record through, False to reject (filter) it.
        """
        # Only INFO and WARNING records are candidates: the Neo4j driver
        # logs severity='INFORMATION' notifications at INFO and
        # severity='WARNING' notifications at WARNING.
        if record.levelno not in (logging.INFO, logging.WARNING):
            return True

        # Strict regex matching avoids dropping unrelated log lines that
        # merely contain these substrings.
        text = str(record.msg)
        is_success_notice = (
            self.GQL_STATUS_PATTERN.search(text) is not None
            or self.SUCCESS_DESC_PATTERN.search(text) is not None
        )

        # Reject success notifications; keep everything else
        # (including genuine warnings and errors).
        return not is_success_notice
|
||||
|
||||
|
||||
class LoggingConfig:
|
||||
"""全局日志配置类"""
|
||||
|
||||
@@ -65,6 +115,22 @@ class LoggingConfig:
|
||||
# 清除现有处理器
|
||||
root_logger.handlers.clear()
|
||||
|
||||
# Neo4j 通知过滤器 - 挂在 handler 上确保所有传播上来的日志都能被过滤
|
||||
neo4j_filter = Neo4jSuccessNotificationFilter()
|
||||
|
||||
# 抑制 Neo4j 通知日志
|
||||
# Neo4j 驱动内部会给 neo4j.notifications logger 配置自己的 handler,
|
||||
# 导致日志绕过根 logger 的 filter 直接输出。
|
||||
# 多管齐下确保过滤生效:
|
||||
# 1. 设置 neo4j.notifications 级别为 WARNING(过滤 INFO 级别的 00NA0 通知)
|
||||
# 2. 在所有 neo4j logger 上添加 filter(过滤 WARNING 级别的成功通知)
|
||||
# 3. 在根 handler 上也添加 filter(兜底)
|
||||
neo4j_notifications_logger = logging.getLogger("neo4j.notifications")
|
||||
neo4j_notifications_logger.setLevel(logging.WARNING)
|
||||
for neo4j_logger_name in ["neo4j", "neo4j.io", "neo4j.pool", "neo4j.notifications"]:
|
||||
neo4j_logger = logging.getLogger(neo4j_logger_name)
|
||||
neo4j_logger.addFilter(neo4j_filter)
|
||||
|
||||
# 创建格式化器
|
||||
formatter = logging.Formatter(
|
||||
fmt=settings.LOG_FORMAT,
|
||||
@@ -80,6 +146,7 @@ class LoggingConfig:
|
||||
console_handler.setFormatter(formatter)
|
||||
console_handler.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
|
||||
console_handler.addFilter(sensitive_filter)
|
||||
console_handler.addFilter(neo4j_filter)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
# 文件处理器(带轮转)
|
||||
@@ -93,6 +160,7 @@ class LoggingConfig:
|
||||
file_handler.setFormatter(formatter)
|
||||
file_handler.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
|
||||
file_handler.addFilter(sensitive_filter)
|
||||
file_handler.addFilter(neo4j_filter)
|
||||
root_logger.addHandler(file_handler)
|
||||
|
||||
cls._initialized = True
|
||||
@@ -461,8 +529,9 @@ def log_time(step_name: str, duration: float, log_file: str = "logs/time.log") -
|
||||
# Fallback to console only if file write fails
|
||||
print(f"Warning: Could not write to timing log: {e}")
|
||||
|
||||
# Always print to console (backward compatible behavior)
|
||||
print(f"✓ {step_name}: {duration:.2f}s")
|
||||
# Always log at INFO level (avoids Celery treating stdout as WARNING)
|
||||
_timing_logger = logging.getLogger(__name__)
|
||||
_timing_logger.info(f"✓ {step_name}: {duration:.2f}s")
|
||||
|
||||
|
||||
def get_agent_logger(name: str = "agent_service",
|
||||
|
||||
@@ -1,16 +1,45 @@
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
|
||||
from app.schemas.memory_agent_schema import AgentMemoryDataset
|
||||
|
||||
|
||||
def content_input_node(state: ReadState) -> ReadState:
    """
    Start node - extract message content while preserving state.

    Takes the content of the first message in the state (empty string when
    there are no messages), replaces every configured pronoun with the
    canonical agent name, and returns it in the ``data`` field.

    Args:
        state: ReadState containing messages and other state data

    Returns:
        ReadState: Updated state with the extracted content in the data field
    """
    messages = state.get('messages')
    text = messages[0].content if messages else ''

    # Substitute pronouns with the agent's canonical name before any
    # downstream processing.
    for alias in AgentMemoryDataset.PRONOUN:
        text = text.replace(alias, AgentMemoryDataset.NAME)

    return {"data": text}
|
||||
|
||||
|
||||
def content_input_write(state: WriteState) -> WriteState:
    """
    Start node - Extract content and maintain state information for write operations

    Extracts the content from the first message in the state, normalizes
    pronouns to the canonical agent name (mirroring content_input_node), and
    returns it in the ``data`` field.

    Args:
        state: WriteState containing messages and other state data

    Returns:
        WriteState: Updated state with extracted content in data field
    """
    content = state['messages'][0].content if state.get('messages') else ''

    # Bug fix: a stray early `return {"data": content}` previously made the
    # pronoun normalization below unreachable dead code; it now runs,
    # matching the behavior of content_input_node.
    for pronoun in AgentMemoryDataset.PRONOUN:
        content = content.replace(pronoun, AgentMemoryDataset.NAME)

    # Return content and maintain all state information
    return {"data": content}
|
||||
|
||||
@@ -0,0 +1,408 @@
|
||||
"""
|
||||
Perceptual Memory Retrieval Node & Service
|
||||
|
||||
Provides PerceptualSearchService for searching perceptual memories (vision, audio,
|
||||
text, conversation) from Neo4j using keyword fulltext + embedding semantic search
|
||||
with BM25+embedding fusion reranking.
|
||||
|
||||
Also provides the perceptual_retrieve_node for use as a LangGraph node.
|
||||
"""
|
||||
import asyncio
|
||||
import math
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState
|
||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||
from app.repositories.neo4j.graph_search import (
|
||||
search_perceptual_by_fulltext,
|
||||
search_perceptual_by_embedding,
|
||||
)
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class PerceptualSearchService:
    """
    Perceptual memory search service.

    Encapsulates the full pipeline of keyword fulltext search + vector
    semantic search + BM25/embedding fusion reranking. Callers only provide
    query / keywords, end_user_id and memory_config, and receive a formatted,
    ranked list of perceptual memories plus a concatenated text summary.

    Usage:
        service = PerceptualSearchService(end_user_id=..., memory_config=...)
        results = await service.search(query="...", keywords=[...], limit=10)
        # results = {"memories": [...], "content": "...", "keyword_raw": N, "embedding_raw": M}
    """

    # Fusion weight: content_score = alpha * BM25 + (1 - alpha) * embedding
    DEFAULT_ALPHA = 0.6
    # Minimum fused score a result must reach to be kept
    DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5

    def __init__(
        self,
        end_user_id: str,
        memory_config: Any,
        alpha: float = DEFAULT_ALPHA,
        content_score_threshold: float = DEFAULT_CONTENT_SCORE_THRESHOLD,
    ):
        self.end_user_id = end_user_id
        self.memory_config = memory_config
        self.alpha = alpha
        self.content_score_threshold = content_score_threshold

    async def search(
        self,
        query: str,
        keywords: Optional[List[str]] = None,
        limit: int = 10,
    ) -> Dict[str, Any]:
        """
        Run perceptual memory retrieval (keyword + vector in parallel) and
        return the fusion-reranked results.

        For results hit by the embedding search but missed by the keyword
        search, the fulltext index is queried again to obtain a BM25 score,
        so every result carries both a BM25 and an embedding score.

        Args:
            query: Original user query (used for vector search and BM25 backfill)
            keywords: Keyword list for fulltext search; defaults to [query] when None
            limit: Maximum number of results to return

        Returns:
            {
                "memories": [formatted memory dicts, ...],
                "content": "concatenated plain-text summary",
                "keyword_raw": int,
                "embedding_raw": int,
            }
        """
        if keywords is None:
            keywords = [query] if query else []

        connector = Neo4jConnector()
        try:
            kw_task = self._keyword_search(connector, keywords, limit)
            emb_task = self._embedding_search(connector, query, limit)

            # Run both searches concurrently; exceptions are captured so one
            # failing leg does not cancel the other.
            kw_results, emb_results = await asyncio.gather(
                kw_task, emb_task, return_exceptions=True
            )
            if isinstance(kw_results, Exception):
                logger.warning(f"[PerceptualSearch] keyword search error: {kw_results}")
                kw_results = []
            if isinstance(emb_results, Exception):
                logger.warning(f"[PerceptualSearch] embedding search error: {emb_results}")
                emb_results = []

            # BM25 backfill: find ids hit by embedding but not by keyword
            # search, then re-query the fulltext index with the original
            # query to obtain BM25 scores for those nodes
            kw_ids = {r.get("id") for r in kw_results if r.get("id")}
            emb_only_ids = {r.get("id") for r in emb_results if r.get("id") and r.get("id") not in kw_ids}

            if emb_only_ids and query:
                backfill = await self._bm25_backfill(connector, query, emb_only_ids, limit)
                # Inject the backfilled BM25 scores into the embedding results
                backfill_map = {r["id"]: r.get("score", 0) for r in backfill}
                for r in emb_results:
                    rid = r.get("id", "")
                    if rid in backfill_map:
                        r["bm25_backfill_score"] = backfill_map[rid]
                logger.info(
                    f"[PerceptualSearch] BM25 backfill: {len(emb_only_ids)} embedding-only ids, "
                    f"{len(backfill_map)} got BM25 scores"
                )

            reranked = self._rerank(kw_results, emb_results, limit)

            memories = []
            content_parts = []
            for record in reranked:
                fmt = self._format_result(record)
                fmt["score"] = round(record.get("content_score", 0), 4)
                memories.append(fmt)
                content_parts.append(self._build_content_text(fmt))

            logger.info(
                f"[PerceptualSearch] {len(memories)} results after rerank "
                f"(keyword_raw={len(kw_results)}, embedding_raw={len(emb_results)})"
            )
            return {
                "memories": memories,
                "content": "\n\n".join(content_parts),
                "keyword_raw": len(kw_results),
                "embedding_raw": len(emb_results),
            }
        finally:
            # Always release the Neo4j connection, even on failure.
            await connector.close()

    async def _bm25_backfill(
        self,
        connector: Neo4jConnector,
        query: str,
        target_ids: set,
        limit: int,
    ) -> List[dict]:
        """
        Backfill fulltext-index BM25 scores for the given id set.

        Queries the fulltext index with the original query and keeps only
        hits whose id is in target_ids.
        """
        escaped = escape_lucene_query(query)
        if not escaped.strip():
            return []
        try:
            r = await search_perceptual_by_fulltext(
                connector=connector, query=escaped,
                end_user_id=self.end_user_id,
                limit=limit * 5,  # over-fetch to improve the hit rate
            )
            all_hits = r.get("perceptuals", [])
            return [h for h in all_hits if h.get("id") in target_ids]
        except Exception as e:
            # Backfill is best-effort; a failure only means missing BM25 scores.
            logger.warning(f"[PerceptualSearch] BM25 backfill failed: {e}")
            return []

    async def _keyword_search(
        self,
        connector: Neo4jConnector,
        keywords: List[str],
        limit: int,
    ) -> List[dict]:
        """Run a fulltext search per keyword concurrently, de-duplicate by id, and return the top N raw results ordered by score descending."""
        seen_ids: set = set()
        all_results: List[dict] = []

        async def _one(kw: str):
            # One fulltext query for a single keyword (empty after Lucene
            # escaping -> no query issued).
            escaped = escape_lucene_query(kw)
            if not escaped.strip():
                return []
            r = await search_perceptual_by_fulltext(
                connector=connector, query=escaped,
                end_user_id=self.end_user_id, limit=limit,
            )
            return r.get("perceptuals", [])

        # Cap at 10 keywords to bound the query fan-out.
        tasks = [_one(kw) for kw in keywords[:10]]
        batch = await asyncio.gather(*tasks, return_exceptions=True)

        for result in batch:
            if isinstance(result, Exception):
                logger.warning(f"[PerceptualSearch] keyword sub-query error: {result}")
                continue
            for rec in result:
                rid = rec.get("id", "")
                if rid and rid not in seen_ids:
                    seen_ids.add(rid)
                    all_results.append(rec)

        all_results.sort(key=lambda x: float(x.get("score", 0)), reverse=True)
        return all_results[:limit]

    async def _embedding_search(
        self,
        connector: Neo4jConnector,
        query_text: str,
        limit: int,
    ) -> List[dict]:
        """Vector semantic search; returns raw results (no threshold filtering)."""
        try:
            # Imported lazily to avoid import cycles / heavy deps at module load.
            from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
            from app.core.models.base import RedBearModelConfig
            from app.db import get_db_context
            from app.services.memory_config_service import MemoryConfigService

            with get_db_context() as db:
                cfg = MemoryConfigService(db).get_embedder_config(
                    str(self.memory_config.embedding_model_id)
                )
            client = OpenAIEmbedderClient(RedBearModelConfig(**cfg))

            r = await search_perceptual_by_embedding(
                connector=connector, embedder_client=client,
                query_text=query_text, end_user_id=self.end_user_id,
                limit=limit,
            )
            return r.get("perceptuals", [])
        except Exception as e:
            # Embedding search is one of two legs; degrade to keyword-only.
            logger.warning(f"[PerceptualSearch] embedding search failed: {e}")
            return []

    def _rerank(
        self,
        keyword_results: List[dict],
        embedding_results: List[dict],
        limit: int,
    ) -> List[dict]:
        """BM25 + embedding fusion ranking.

        Embedding-result entries carrying a bm25_backfill_score are merged
        with the keyword results before normalization, so all BM25 scores
        share a single scale.
        """
        # Merge the backfilled BM25 scores into the keyword results for
        # joint normalization
        emb_backfill_items = []
        for item in embedding_results:
            backfill_score = item.get("bm25_backfill_score")
            if backfill_score is not None and item.get("id"):
                emb_backfill_items.append({"id": item["id"], "score": backfill_score})

        # Normalize all BM25 scores together on one scale
        all_bm25_items = keyword_results + emb_backfill_items
        all_bm25_items = self._normalize_scores(all_bm25_items)

        # Build id -> normalized BM25 score map
        bm25_norm_map: Dict[str, float] = {}
        for item in all_bm25_items:
            item_id = item.get("id", "")
            if item_id:
                bm25_norm_map[item_id] = float(item.get("normalized_score", 0))

        # Normalize embedding scores
        embedding_results = self._normalize_scores(embedding_results)

        # Merge both result sets keyed by id
        combined: Dict[str, dict] = {}
        for item in keyword_results:
            item_id = item.get("id", "")
            if not item_id:
                continue
            combined[item_id] = item.copy()
            combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
            combined[item_id]["embedding_score"] = 0.0

        for item in embedding_results:
            item_id = item.get("id", "")
            if not item_id:
                continue
            if item_id in combined:
                combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
            else:
                combined[item_id] = item.copy()
                combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
                combined[item_id]["embedding_score"] = item.get("normalized_score", 0)

        # Weighted fusion of the two normalized scores
        for item in combined.values():
            bm25 = float(item.get("bm25_score", 0) or 0)
            emb = float(item.get("embedding_score", 0) or 0)
            item["content_score"] = self.alpha * bm25 + (1 - self.alpha) * emb

        results = list(combined.values())
        before = len(results)
        results = [r for r in results if r["content_score"] >= self.content_score_threshold]
        results.sort(key=lambda x: x["content_score"], reverse=True)
        results = results[:limit]

        logger.info(
            f"[PerceptualSearch] rerank: merged={before}, after_threshold={len(results)} "
            f"(alpha={self.alpha}, threshold={self.content_score_threshold})"
        )
        return results

    @staticmethod
    def _normalize_scores(items: List[dict], field: str = "score") -> List[dict]:
        """Z-score + sigmoid normalization (mutates items, adding normalized_<field>)."""
        if not items:
            return items
        scores = [float(it.get(field, 0) or 0) for it in items]
        if len(scores) <= 1:
            # A single score has no distribution; treat it as maximal.
            for it in items:
                it[f"normalized_{field}"] = 1.0
            return items
        mean = sum(scores) / len(scores)
        var = sum((s - mean) ** 2 for s in scores) / len(scores)
        std = math.sqrt(var)
        if std == 0:
            # All scores identical: avoid division by zero, treat as maximal.
            for it in items:
                it[f"normalized_{field}"] = 1.0
        else:
            for it, s in zip(items, scores):
                z = (s - mean) / std
                it[f"normalized_{field}"] = 1 / (1 + math.exp(-z))
        return items

    @staticmethod
    def _format_result(record: dict) -> dict:
        # Project a raw Neo4j record onto the stable output schema.
        return {
            "id": record.get("id", ""),
            "perceptual_type": record.get("perceptual_type", ""),
            "file_name": record.get("file_name", ""),
            "file_path": record.get("file_path", ""),
            "summary": record.get("summary", ""),
            "topic": record.get("topic", ""),
            "domain": record.get("domain", ""),
            "keywords": record.get("keywords", []),
            "created_at": str(record.get("created_at", "")),
            "file_type": record.get("file_type", ""),
            "score": record.get("score", 0),
        }

    @staticmethod
    def _build_content_text(formatted: dict) -> str:
        # Assemble a single-line text summary from the formatted fields
        # (summary, topic, keywords, file name), skipping empty parts.
        parts = []
        if formatted["summary"]:
            parts.append(formatted["summary"])
        if formatted["topic"]:
            parts.append(f"[主题: {formatted['topic']}]")
        if formatted["keywords"]:
            kw_list = formatted["keywords"]
            if isinstance(kw_list, list):
                parts.append(f"[关键词: {', '.join(kw_list)}]")
        if formatted["file_name"]:
            parts.append(f"[文件: {formatted['file_name']}]")
        return " ".join(parts)
|
||||
|
||||
|
||||
def _extract_keywords_from_problems(problem_extension: dict) -> List[str]:
|
||||
"""Extract search keywords from problem extension results."""
|
||||
keywords = []
|
||||
context = problem_extension.get("context", {})
|
||||
if isinstance(context, dict):
|
||||
for original_q, extended_qs in context.items():
|
||||
keywords.append(original_q)
|
||||
if isinstance(extended_qs, list):
|
||||
keywords.extend(extended_qs)
|
||||
return keywords
|
||||
|
||||
|
||||
async def perceptual_retrieve_node(state: ReadState) -> ReadState:
    """
    LangGraph node: perceptual memory retrieval.

    Runs keyword + embedding search with BM25 fusion reranking through
    PerceptualSearchService and stores the outcome (plus an intermediate
    payload for the frontend) under state['perceptual_data'].
    """
    end_user_id = state.get("end_user_id", "")
    query = state.get("data", "")

    logger.info(f"Perceptual_Retrieve: start, end_user_id={end_user_id}")

    # Prefer keywords derived from the problem-extension step; fall back to
    # the raw query when none are available.
    keywords = _extract_keywords_from_problems(state.get("problem_extension", {}))
    if not keywords and query:
        keywords = [query]

    logger.info(f"Perceptual_Retrieve: {len(keywords)} keywords extracted")

    searcher = PerceptualSearchService(
        end_user_id=end_user_id,
        memory_config=state.get("memory_config", None),
    )
    outcome = await searcher.search(query=query, keywords=keywords, limit=10)

    memories = outcome["memories"]
    payload = {
        "memories": memories,
        "content": outcome["content"],
        # Intermediate output consumed by the frontend progress view.
        "_intermediate": {
            "type": "perceptual_retrieve",
            "title": "感知记忆检索",
            "data": memories,
            "query": query,
            "result_count": len(memories),
        },
    }
    return {"perceptual_data": payload}
|
||||
@@ -1,10 +1,10 @@
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.db import get_db
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
PROJECT_ROOT_,
|
||||
ReadState,
|
||||
@@ -12,27 +12,46 @@ from app.core.memory.agent.utils.llm_tools import (
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
from app.db import get_db_context
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
db_session = next(get_db())
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class ProblemNodeService(LLMServiceMixin):
    """
    Service object backing the problem-processing graph nodes.

    Combines the structured-LLM-call capabilities inherited from
    LLMServiceMixin with a Jinja2 template renderer used to build prompts.

    Attributes:
        template_service: Renders prompt templates rooted at template_root.
    """

    def __init__(self):
        super().__init__()
        self.template_service = TemplateService(template_root)
|
||||
|
||||
|
||||
# Shared module-level service instance used by the node functions below.
problem_service = ProblemNodeService()
|
||||
|
||||
|
||||
async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
"""问题分解节点"""
|
||||
"""
|
||||
Problem decomposition node
|
||||
|
||||
Breaks down complex user queries into smaller, more manageable sub-problems.
|
||||
Uses LLM to analyze the input and generate structured problem decomposition
|
||||
with question types and reasoning.
|
||||
|
||||
Args:
|
||||
state: ReadState containing user input and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with problem decomposition results
|
||||
"""
|
||||
# 从状态中获取数据
|
||||
content = state.get('data', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
@@ -53,18 +72,19 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
|
||||
try:
|
||||
# 使用优化的LLM服务
|
||||
structured = await problem_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=ProblemExtensionResponse,
|
||||
fallback_value=[]
|
||||
)
|
||||
with get_db_context() as db_session:
|
||||
structured = await problem_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=ProblemExtensionResponse,
|
||||
fallback_value=[]
|
||||
)
|
||||
|
||||
# 添加更详细的日志记录
|
||||
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
||||
|
||||
# 验证结构化响应
|
||||
# Validate structured response
|
||||
if not structured or not hasattr(structured, 'root'):
|
||||
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
@@ -106,17 +126,17 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# 提供更详细的错误信息
|
||||
# Provide more detailed error information
|
||||
error_details = {
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
"content_length": len(content),
|
||||
"llm_model_id": memory_config.llm_model_id if memory_config else None
|
||||
"llm_model_id": str(memory_config.llm_model_id) if memory_config else None
|
||||
}
|
||||
|
||||
logger.error(f"Split_The_Problem error details: {error_details}")
|
||||
|
||||
# 创建默认的空结果
|
||||
# Create default empty result
|
||||
result = {
|
||||
"context": json.dumps([], ensure_ascii=False),
|
||||
"original": content,
|
||||
@@ -130,13 +150,25 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
}
|
||||
}
|
||||
|
||||
# 返回更新后的状态,包含spit_context字段
|
||||
# Return updated state including spit_context field
|
||||
return {"spit_data": result}
|
||||
|
||||
|
||||
async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
"""问题扩展节点"""
|
||||
# 获取原始数据和分解结果
|
||||
"""
|
||||
Problem extension node
|
||||
|
||||
Extends the decomposed problems from Split_The_Problem node by generating
|
||||
additional related questions and organizing them by original question.
|
||||
Uses LLM to create comprehensive question extensions for better memory retrieval.
|
||||
|
||||
Args:
|
||||
state: ReadState containing decomposed problems and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with extended problem results
|
||||
"""
|
||||
# Get original data and decomposition results
|
||||
start = time.time()
|
||||
content = state.get('data', '')
|
||||
data = state.get('spit_data', '')['context']
|
||||
@@ -171,17 +203,18 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
|
||||
try:
|
||||
# 使用优化的LLM服务
|
||||
response_content = await problem_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=ProblemExtensionResponse,
|
||||
fallback_value=[]
|
||||
)
|
||||
with get_db_context() as db_session:
|
||||
response_content = await problem_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=ProblemExtensionResponse,
|
||||
fallback_value=[]
|
||||
)
|
||||
|
||||
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
|
||||
|
||||
# 验证结构化响应
|
||||
# Validate structured response
|
||||
if not response_content or not hasattr(response_content, 'root'):
|
||||
logger.warning("Problem_Extension: 结构化响应为空或格式不正确")
|
||||
aggregated_dict = {}
|
||||
@@ -215,12 +248,12 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# 提供更详细的错误信息
|
||||
# Provide more detailed error information
|
||||
error_details = {
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
"questions_count": len(databasets),
|
||||
"llm_model_id": memory_config.llm_model_id if memory_config else None
|
||||
"llm_model_id": str(memory_config.llm_model_id) if memory_config else None
|
||||
}
|
||||
|
||||
logger.error(f"Problem_Extension error details: {error_details}")
|
||||
@@ -230,7 +263,6 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
logger.info(f"Problem extension result: {aggregated_dict}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
print(time.time() - start)
|
||||
result = {
|
||||
"context": aggregated_dict,
|
||||
"original": data,
|
||||
|
||||
@@ -6,34 +6,41 @@ import os
|
||||
# ===== 第三方库 =====
|
||||
from langchain.agents import create_agent
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.db import get_db, get_db_context
|
||||
|
||||
from app.schemas import model_schema
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
from app.core.memory.agent.services.search_service import SearchService
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
COUNTState,
|
||||
ReadState,
|
||||
deduplicate_entries,
|
||||
merge_to_key_value_pairs,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.tools.tool import (
|
||||
create_hybrid_retrieval_tool_sync,
|
||||
create_time_retrieval_tool,
|
||||
extract_tool_message_content,
|
||||
)
|
||||
|
||||
from app.core.memory.agent.services.search_service import SearchService
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
ReadState,
|
||||
deduplicate_entries,
|
||||
merge_to_key_value_pairs,
|
||||
)
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
from app.schemas import model_schema
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
db = next(get_db())
|
||||
|
||||
|
||||
|
||||
async def rag_config(state):
|
||||
"""
|
||||
Configure RAG (Retrieval-Augmented Generation) settings
|
||||
|
||||
Creates configuration for knowledge base retrieval including similarity thresholds,
|
||||
weights, and reranker settings.
|
||||
|
||||
Args:
|
||||
state: Current state containing user_rag_memory_id
|
||||
|
||||
Returns:
|
||||
dict: RAG configuration dictionary
|
||||
"""
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
kb_config = {
|
||||
"knowledge_bases": [
|
||||
@@ -50,10 +57,25 @@ async def rag_config(state):
|
||||
"reranker_top_k": 10
|
||||
}
|
||||
return kb_config
|
||||
async def rag_knowledge(state,question):
|
||||
|
||||
|
||||
async def rag_knowledge(state, question):
|
||||
"""
|
||||
Retrieve knowledge using RAG approach
|
||||
|
||||
Performs knowledge retrieval from configured knowledge bases using the
|
||||
provided question and returns formatted results.
|
||||
|
||||
Args:
|
||||
state: Current state containing configuration
|
||||
question: Question to search for
|
||||
|
||||
Returns:
|
||||
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
|
||||
"""
|
||||
kb_config = await rag_config(state)
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
@@ -61,22 +83,34 @@ async def rag_knowledge(state,question):
|
||||
cleaned_query = question
|
||||
raw_results = clean_content
|
||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||
except Exception :
|
||||
retrieval_knowledge=[]
|
||||
except Exception:
|
||||
retrieval_knowledge = []
|
||||
clean_content = ''
|
||||
raw_results = ''
|
||||
cleaned_query = question
|
||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||
return retrieval_knowledge,clean_content,cleaned_query,raw_results
|
||||
return retrieval_knowledge, clean_content, cleaned_query, raw_results
|
||||
|
||||
|
||||
async def llm_infomation(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Get LLM configuration information from state
|
||||
|
||||
Retrieves model configuration details including model ID and tenant ID
|
||||
from the memory configuration in the current state.
|
||||
|
||||
Args:
|
||||
state: ReadState containing memory configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Model configuration as Pydantic model
|
||||
"""
|
||||
memory_config = state.get('memory_config', None)
|
||||
model_id = memory_config.llm_model_id
|
||||
tenant_id = memory_config.tenant_id
|
||||
|
||||
# 使用现有的 memory_config 而不是重新查询数据库
|
||||
# 或者使用线程安全的数据库访问
|
||||
# Use existing memory_config instead of re-querying database
|
||||
# or use thread-safe database access
|
||||
with get_db_context() as db:
|
||||
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id)
|
||||
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
|
||||
@@ -85,16 +119,20 @@ async def llm_infomation(state: ReadState) -> ReadState:
|
||||
|
||||
async def clean_databases(data) -> str:
|
||||
"""
|
||||
简化的数据库搜索结果清理函数
|
||||
Simplified database search result cleaning function
|
||||
|
||||
Processes and cleans search results from various sources including
|
||||
reranked results and time-based search results. Extracts text content
|
||||
from structured data and returns as formatted string.
|
||||
|
||||
Args:
|
||||
data: 搜索结果数据
|
||||
data: Search result data (can be string, dict, or other types)
|
||||
|
||||
Returns:
|
||||
清理后的内容字符串
|
||||
str: Cleaned content string
|
||||
"""
|
||||
try:
|
||||
# 解析JSON字符串
|
||||
# Parse JSON string
|
||||
if isinstance(data, str):
|
||||
try:
|
||||
data = json.loads(data)
|
||||
@@ -104,24 +142,24 @@ async def clean_databases(data) -> str:
|
||||
if not isinstance(data, dict):
|
||||
return str(data)
|
||||
|
||||
# 获取结果数据
|
||||
# Get result data
|
||||
# with open("搜索结果.json","w",encoding='utf-8') as f:
|
||||
# f.write(json.dumps(data, indent=4, ensure_ascii=False))
|
||||
results = data.get('results', data)
|
||||
if not isinstance(results, dict):
|
||||
return str(results)
|
||||
|
||||
# 收集所有内容
|
||||
# Collect all content
|
||||
content_list = []
|
||||
|
||||
# 处理重排序结果
|
||||
|
||||
# Process reranked results
|
||||
reranked = results.get('reranked_results', {})
|
||||
if reranked:
|
||||
for category in ['summaries', 'statements', 'chunks', 'entities']:
|
||||
for category in ['summaries', 'communities', 'statements', 'chunks', 'entities']:
|
||||
items = reranked.get(category, [])
|
||||
if isinstance(items, list):
|
||||
content_list.extend(items)
|
||||
# 处理时间搜索结果
|
||||
# Process time search results
|
||||
time_search = results.get('time_search', {})
|
||||
if time_search:
|
||||
if isinstance(time_search, dict):
|
||||
@@ -131,17 +169,23 @@ async def clean_databases(data) -> str:
|
||||
elif isinstance(time_search, list):
|
||||
content_list.extend(time_search)
|
||||
|
||||
# 提取文本内容
|
||||
# Extract text content,对 community 按 name 去重(多次 tool 调用会产生重复)
|
||||
text_parts = []
|
||||
seen_community_names = set()
|
||||
for item in content_list:
|
||||
if isinstance(item, dict):
|
||||
text = item.get('statement') or item.get('content', '')
|
||||
# community 节点用 name 去重
|
||||
if 'member_count' in item or 'core_entities' in item:
|
||||
community_name = item.get('name') or item.get('id', '')
|
||||
if community_name in seen_community_names:
|
||||
continue
|
||||
seen_community_names.add(community_name)
|
||||
text = item.get('statement') or item.get('content') or item.get('summary', '')
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
elif isinstance(item, str):
|
||||
text_parts.append(item)
|
||||
|
||||
|
||||
return '\n'.join(text_parts).strip()
|
||||
|
||||
except Exception as e:
|
||||
@@ -150,24 +194,33 @@ async def clean_databases(data) -> str:
|
||||
|
||||
|
||||
async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Retrieve information using simplified search approach
|
||||
|
||||
Processes extended problems from previous nodes and performs retrieval
|
||||
using either RAG or hybrid search based on storage type. Handles concurrent
|
||||
processing of multiple questions and deduplicates results.
|
||||
|
||||
Args:
|
||||
state: ReadState containing problem extensions and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with retrieval results and intermediate outputs
|
||||
"""
|
||||
|
||||
'''
|
||||
|
||||
模型信息
|
||||
'''
|
||||
|
||||
problem_extension=state.get('problem_extension', '')['context']
|
||||
storage_type=state.get('storage_type', '')
|
||||
user_rag_memory_id=state.get('user_rag_memory_id', '')
|
||||
end_user_id=state.get('end_user_id', '')
|
||||
problem_extension = state.get('problem_extension', '')['context']
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
original=state.get('data', '')
|
||||
problem_list=[]
|
||||
for key,values in problem_extension.items():
|
||||
original = state.get('data', '')
|
||||
problem_list = []
|
||||
for key, values in problem_extension.items():
|
||||
for data in values:
|
||||
problem_list.append(data)
|
||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
# 创建异步任务处理单个问题
|
||||
|
||||
# Create async task to process individual questions
|
||||
async def process_question_nodes(idx, question):
|
||||
try:
|
||||
# Prepare search parameters based on storage type
|
||||
@@ -213,7 +266,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
}
|
||||
}
|
||||
|
||||
# 并发处理所有问题
|
||||
# Process all questions concurrently
|
||||
tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)]
|
||||
databases_anser = await asyncio.gather(*tasks)
|
||||
databases_data = {
|
||||
@@ -244,7 +297,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
|
||||
send_verify = []
|
||||
for i, j in zip(keys, val, strict=False):
|
||||
if j!=['']:
|
||||
if j != ['']:
|
||||
send_verify.append({
|
||||
"Query_small": i,
|
||||
"Answer_Small": j
|
||||
@@ -257,15 +310,26 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
}
|
||||
|
||||
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
||||
return {'retrieve':dup_databases}
|
||||
|
||||
|
||||
return {'retrieve': dup_databases}
|
||||
|
||||
|
||||
async def retrieve(state: ReadState) -> ReadState:
|
||||
# 从state中获取end_user_id
|
||||
"""
|
||||
Advanced retrieve function using LangChain agents and tools
|
||||
|
||||
Uses LangChain agents with specialized retrieval tools (time-based and hybrid)
|
||||
to perform sophisticated information retrieval. Supports both RAG and traditional
|
||||
memory storage approaches with concurrent processing and result deduplication.
|
||||
|
||||
Args:
|
||||
state: ReadState containing problem extensions and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with retrieval results and intermediate outputs
|
||||
"""
|
||||
# Get end_user_id from state
|
||||
import time
|
||||
start=time.time()
|
||||
start = time.time()
|
||||
problem_extension = state.get('problem_extension', '')['context']
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
@@ -283,6 +347,7 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
with get_db_context() as db: # 使用同步数据库上下文管理器
|
||||
config_service = MemoryConfigService(db)
|
||||
return await llm_infomation(state)
|
||||
|
||||
llm_config = await get_llm_info()
|
||||
api_key_obj = llm_config.api_keys[0]
|
||||
api_key = api_key_obj.api_key
|
||||
@@ -296,28 +361,33 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
)
|
||||
|
||||
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
||||
search_params = { "end_user_id": end_user_id, "return_raw_results": True }
|
||||
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||
search_params = {
|
||||
"end_user_id": end_user_id,
|
||||
"return_raw_results": True,
|
||||
"include": ["summaries", "statements", "chunks", "entities", "communities"],
|
||||
}
|
||||
hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||
agent = create_agent(
|
||||
llm,
|
||||
tools=[time_retrieval_tool,hybrid_retrieval],
|
||||
tools=[time_retrieval_tool, hybrid_retrieval],
|
||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
||||
)
|
||||
|
||||
# 创建异步任务处理单个问题
|
||||
# Create async task to process individual questions
|
||||
import asyncio
|
||||
|
||||
# 在模块级别定义信号量,限制最大并发数
|
||||
SEMAPHORE = asyncio.Semaphore(5) # 限制最多5个并发数据库操作
|
||||
# Define semaphore at module level to limit maximum concurrency
|
||||
SEMAPHORE = asyncio.Semaphore(5) # Limit to maximum 5 concurrent database operations
|
||||
|
||||
async def process_question(idx, question):
|
||||
async with SEMAPHORE: # 限制并发
|
||||
async with SEMAPHORE: # Limit concurrency
|
||||
try:
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question)
|
||||
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
|
||||
question)
|
||||
else:
|
||||
cleaned_query = question
|
||||
# 使用 asyncio 在线程池中运行同步的 agent.invoke
|
||||
# Use asyncio to run synchronous agent.invoke in thread pool
|
||||
import asyncio
|
||||
response = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
@@ -331,8 +401,32 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
raw_results = tool_results['content']
|
||||
clean_content = await clean_databases(raw_results)
|
||||
|
||||
# 社区展开:从 tool 返回结果中提取命中的 community,
|
||||
# 沿 BELONGS_TO_COMMUNITY 关系拉取关联 Statement 追加到 clean_content
|
||||
_expanded_stmts_to_write = []
|
||||
try:
|
||||
results_dict = raw_results.get('results', {}) if isinstance(raw_results, dict) else {}
|
||||
reranked = results_dict.get('reranked_results', {})
|
||||
community_hits = reranked.get('communities', [])
|
||||
if not community_hits:
|
||||
community_hits = results_dict.get('communities', [])
|
||||
if community_hits:
|
||||
from app.core.memory.agent.services.search_service import expand_communities_to_statements
|
||||
_expanded_stmts_to_write, new_texts = await expand_communities_to_statements(
|
||||
community_results=community_hits,
|
||||
end_user_id=end_user_id,
|
||||
existing_content=clean_content,
|
||||
)
|
||||
if new_texts:
|
||||
clean_content = clean_content + '\n' + '\n'.join(new_texts)
|
||||
except Exception as parse_err:
|
||||
logger.warning(f"[Retrieve] 解析社区命中结果失败,跳过展开: {parse_err}")
|
||||
|
||||
try:
|
||||
raw_results = raw_results['results']
|
||||
# 写回展开结果,接口返回中可见(已在 helper 中清洗过字段)
|
||||
if _expanded_stmts_to_write and isinstance(raw_results, dict):
|
||||
raw_results.setdefault('reranked_results', {})['expanded_statements'] = _expanded_stmts_to_write
|
||||
except Exception:
|
||||
raw_results = []
|
||||
|
||||
@@ -366,7 +460,7 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
}
|
||||
}
|
||||
|
||||
# 并发处理所有问题
|
||||
# Process all questions concurrently
|
||||
import asyncio
|
||||
tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)]
|
||||
databases_anser = await asyncio.gather(*tasks)
|
||||
@@ -413,5 +507,3 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
# json.dump(dup_databases, f, indent=4)
|
||||
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
||||
return {'retrieve': dup_databases}
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
|
||||
PerceptualSearchService,
|
||||
)
|
||||
from app.core.memory.agent.models.summary_models import (
|
||||
RetrieveSummaryResponse,
|
||||
SummaryResponse,
|
||||
@@ -17,34 +19,144 @@ from app.core.memory.agent.utils.llm_tools import (
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.db import get_db
|
||||
from app.core.memory.enums import Neo4jNodeType
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
logger = get_agent_logger(__name__)
|
||||
db_session = next(get_db())
|
||||
|
||||
|
||||
class SummaryNodeService(LLMServiceMixin):
|
||||
"""总结节点服务类"""
|
||||
"""
|
||||
Summary node service class
|
||||
|
||||
Handles summary generation operations using LLM services. Inherits from
|
||||
LLMServiceMixin to provide structured LLM calling capabilities for
|
||||
generating summaries from retrieved information.
|
||||
|
||||
Attributes:
|
||||
template_service: Service for rendering Jinja2 templates
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
# 创建全局服务实例
|
||||
|
||||
# Create global service instance
|
||||
summary_service = SummaryNodeService()
|
||||
|
||||
|
||||
async def rag_config(state):
|
||||
"""
|
||||
Configure RAG (Retrieval-Augmented Generation) settings for summary operations
|
||||
|
||||
Creates configuration for knowledge base retrieval including similarity thresholds,
|
||||
weights, and reranker settings specifically for summary generation.
|
||||
|
||||
Args:
|
||||
state: Current state containing user_rag_memory_id
|
||||
|
||||
Returns:
|
||||
dict: RAG configuration dictionary with knowledge base settings
|
||||
"""
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
kb_config = {
|
||||
"knowledge_bases": [
|
||||
{
|
||||
"kb_id": user_rag_memory_id,
|
||||
"similarity_threshold": 0.7,
|
||||
"vector_similarity_weight": 0.5,
|
||||
"top_k": 10,
|
||||
"retrieve_type": "participle"
|
||||
}
|
||||
],
|
||||
"merge_strategy": "weight",
|
||||
"reranker_id": os.getenv('reranker_id'),
|
||||
"reranker_top_k": 10
|
||||
}
|
||||
return kb_config
|
||||
|
||||
|
||||
async def rag_knowledge(state, question):
|
||||
"""
|
||||
Retrieve knowledge using RAG approach for summary generation
|
||||
|
||||
Performs knowledge retrieval from configured knowledge bases using the
|
||||
provided question and returns formatted results for summary processing.
|
||||
|
||||
Args:
|
||||
state: Current state containing configuration
|
||||
question: Question to search for in knowledge base
|
||||
|
||||
Returns:
|
||||
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
|
||||
- retrieval_knowledge: List of retrieved knowledge chunks
|
||||
- clean_content: Formatted content string
|
||||
- cleaned_query: Processed query string
|
||||
- raw_results: Raw retrieval results
|
||||
"""
|
||||
kb_config = await rag_config(state)
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
clean_content = '\n\n'.join(retrieval_knowledge)
|
||||
cleaned_query = question
|
||||
raw_results = clean_content
|
||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||
except Exception:
|
||||
retrieval_knowledge = []
|
||||
clean_content = ''
|
||||
raw_results = ''
|
||||
cleaned_query = question
|
||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||
return retrieval_knowledge, clean_content, cleaned_query, raw_results
|
||||
|
||||
|
||||
async def summary_history(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Retrieve conversation history for summary context
|
||||
|
||||
Gets the conversation history for the current user to provide context
|
||||
for summary generation operations.
|
||||
|
||||
Args:
|
||||
state: ReadState containing end_user_id
|
||||
|
||||
Returns:
|
||||
ReadState: Conversation history data
|
||||
"""
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
return history
|
||||
|
||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
|
||||
|
||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
|
||||
search_mode) -> str:
|
||||
"""
|
||||
增强的summary_llm函数,包含更好的错误处理和数据验证
|
||||
Enhanced summary_llm function with better error handling and data validation
|
||||
|
||||
Generates summaries using LLM with structured output. Includes fallback mechanisms
|
||||
for handling LLM failures and provides robust error recovery.
|
||||
|
||||
Args:
|
||||
state: ReadState containing current context
|
||||
history: Conversation history for context
|
||||
retrieve_info: Retrieved information to summarize
|
||||
template_name: Jinja2 template name for prompt generation
|
||||
operation_name: Type of operation (summary, input_summary, retrieve_summary)
|
||||
response_model: Pydantic model for structured output
|
||||
search_mode: Search mode flag ("0" for simple, "1" for complex)
|
||||
|
||||
Returns:
|
||||
str: Generated summary text or fallback message
|
||||
"""
|
||||
data = state.get("data", '')
|
||||
|
||||
# 构建系统提示词
|
||||
|
||||
# Build system prompt
|
||||
if str(search_mode) == "0":
|
||||
system_prompt = await summary_service.template_service.render_template(
|
||||
template_name=template_name,
|
||||
@@ -61,40 +173,41 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
retrieve_info=retrieve_info
|
||||
)
|
||||
try:
|
||||
# 使用优化的LLM服务进行结构化输出
|
||||
structured = await summary_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=response_model,
|
||||
fallback_value=None
|
||||
)
|
||||
# 验证结构化响应
|
||||
# Use optimized LLM service for structured output
|
||||
with get_db_context() as db_session:
|
||||
structured = await summary_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=response_model,
|
||||
fallback_value=None
|
||||
)
|
||||
# Validate structured response
|
||||
if structured is None:
|
||||
logger.warning(f"LLM返回None,使用默认回答")
|
||||
logger.warning("LLM返回None,使用默认回答")
|
||||
return "信息不足,无法回答"
|
||||
|
||||
# 根据操作类型提取答案
|
||||
|
||||
# Extract answer based on operation type
|
||||
if operation_name == "summary":
|
||||
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
|
||||
else:
|
||||
# 处理RetrieveSummaryResponse
|
||||
# Handle RetrieveSummaryResponse
|
||||
if hasattr(structured, 'data') and structured.data:
|
||||
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
|
||||
else:
|
||||
logger.warning(f"结构化响应缺少data字段")
|
||||
logger.warning("结构化响应缺少data字段")
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
# 验证答案不为空
|
||||
|
||||
# Validate answer is not empty
|
||||
if not aimessages or aimessages.strip() == "":
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
|
||||
return aimessages
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"结构化输出失败: {e}", exc_info=True)
|
||||
|
||||
# 尝试非结构化输出作为fallback
|
||||
|
||||
# Try unstructured output as fallback
|
||||
try:
|
||||
logger.info("尝试非结构化输出作为fallback")
|
||||
response = await summary_service.call_llm_simple(
|
||||
@@ -103,24 +216,38 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
system_prompt=system_prompt,
|
||||
fallback_message="信息不足,无法回答"
|
||||
)
|
||||
|
||||
|
||||
if response and response.strip():
|
||||
# 简单清理响应
|
||||
# Simple response cleaning
|
||||
cleaned_response = response.strip()
|
||||
# 移除可能的JSON标记
|
||||
# Remove possible JSON markers
|
||||
if cleaned_response.startswith('```'):
|
||||
lines = cleaned_response.split('\n')
|
||||
cleaned_response = '\n'.join(lines[1:-1])
|
||||
|
||||
|
||||
return cleaned_response
|
||||
else:
|
||||
return "信息不足,无法回答"
|
||||
|
||||
|
||||
except Exception as fallback_error:
|
||||
logger.error(f"Fallback也失败: {fallback_error}")
|
||||
return "信息不足,无法回答"
|
||||
|
||||
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
||||
|
||||
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
||||
"""
|
||||
Save summary results to Redis session storage
|
||||
|
||||
Stores the generated summary and user query in Redis for session management
|
||||
and conversation history tracking.
|
||||
|
||||
Args:
|
||||
state: ReadState containing user and query information
|
||||
aimessages: Generated summary message to save
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state after saving to Redis
|
||||
"""
|
||||
data = state.get("data", '')
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
await SessionService(store).save_session(
|
||||
@@ -132,10 +259,26 @@ async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
||||
)
|
||||
await SessionService(store).cleanup_duplicates()
|
||||
logger.info(f"sessionid: {aimessages} 写入成功")
|
||||
async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
|
||||
storage_type=state.get("storage_type",'')
|
||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||
data=state.get("data", '')
|
||||
|
||||
|
||||
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
|
||||
"""
|
||||
Format summary results for different output types
|
||||
|
||||
Creates structured output formats for both input summary and retrieval summary
|
||||
operations, including metadata and intermediate results for frontend display.
|
||||
|
||||
Args:
|
||||
state: ReadState containing storage and user information
|
||||
aimessages: Generated summary message
|
||||
raw_results: Raw search/retrieval results
|
||||
|
||||
Returns:
|
||||
tuple: (input_summary, retrieve_summary) formatted result dictionaries
|
||||
"""
|
||||
storage_type = state.get("storage_type", '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
data = state.get("data", '')
|
||||
input_summary = {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
@@ -152,14 +295,14 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
retrieve={
|
||||
retrieve = {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "retrieval_summary",
|
||||
"title":"快速检索",
|
||||
"title": "快速检索",
|
||||
"summary": aimessages,
|
||||
"query": data,
|
||||
"storage_type": storage_type,
|
||||
@@ -167,31 +310,90 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
|
||||
}
|
||||
}
|
||||
|
||||
return input_summary,retrieve
|
||||
return input_summary, retrieve
|
||||
|
||||
|
||||
async def Input_Summary(state: ReadState) -> ReadState:
|
||||
start=time.time()
|
||||
storage_type=state.get("storage_type",'')
|
||||
"""
|
||||
Generate quick input summary from retrieved information
|
||||
|
||||
Performs fast retrieval and generates a quick summary response for user queries.
|
||||
This function prioritizes speed by only searching summary nodes and provides
|
||||
immediate feedback to users.
|
||||
|
||||
Args:
|
||||
state: ReadState containing user query, storage configuration, and context
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing summary results with status and metadata
|
||||
"""
|
||||
start = time.time()
|
||||
storage_type = state.get("storage_type", '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||
data=state.get("data", '')
|
||||
end_user_id=state.get("end_user_id", '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
data = state.get("data", '')
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
history = await summary_history( state)
|
||||
history = await summary_history(state)
|
||||
search_params = {
|
||||
"end_user_id": end_user_id,
|
||||
"question": data,
|
||||
"return_raw_results": True,
|
||||
"include": ["summaries"] # Only search summary nodes for faster performance
|
||||
"include": [Neo4jNodeType.MEMORYSUMMARY, Neo4jNodeType.COMMUNITY] # MemorySummary 和 Community 同为高维度概括节点
|
||||
}
|
||||
|
||||
try:
|
||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config)
|
||||
if storage_type != "rag":
|
||||
|
||||
async def _perceptual_search():
|
||||
service = PerceptualSearchService(
|
||||
end_user_id=end_user_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
return await service.search(query=data, limit=5)
|
||||
|
||||
hybrid_task = SearchService().execute_hybrid_search(
|
||||
**search_params,
|
||||
memory_config=memory_config,
|
||||
expand_communities=False,
|
||||
)
|
||||
perceptual_task = _perceptual_search()
|
||||
|
||||
gather_results = await asyncio.gather(
|
||||
hybrid_task, perceptual_task, return_exceptions=True
|
||||
)
|
||||
hybrid_result = gather_results[0]
|
||||
perceptual_results = gather_results[1]
|
||||
|
||||
# 处理 hybrid search 异常
|
||||
if isinstance(hybrid_result, Exception):
|
||||
raise hybrid_result
|
||||
retrieve_info, question, raw_results = hybrid_result
|
||||
|
||||
# 处理感知记忆结果
|
||||
if isinstance(perceptual_results, Exception):
|
||||
logger.warning(f"[Input_Summary] perceptual search failed: {perceptual_results}")
|
||||
perceptual_results = []
|
||||
|
||||
# 拼接感知记忆内容到 retrieve_info
|
||||
if perceptual_results and isinstance(perceptual_results, dict):
|
||||
perceptual_content = perceptual_results.get("content", "")
|
||||
if perceptual_content:
|
||||
retrieve_info = f"{retrieve_info}\n\n<history-files>\n{perceptual_content}"
|
||||
count = len(perceptual_results.get("memories", []))
|
||||
logger.info(f"[Input_Summary] appended {count} perceptual memories (reranked)")
|
||||
|
||||
# 调试:打印 community 检索结果数量
|
||||
if raw_results and isinstance(raw_results, dict):
|
||||
reranked = raw_results.get('reranked_results', {})
|
||||
community_hits = reranked.get('communities', [])
|
||||
logger.debug(f"[Input_Summary] community 命中数: {len(community_hits)}, "
|
||||
f"summary 命中数: {len(reranked.get('summaries', []))}")
|
||||
else:
|
||||
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
|
||||
except Exception as e:
|
||||
logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True )
|
||||
logger.error(f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True)
|
||||
retrieve_info, question, raw_results = "", data, []
|
||||
|
||||
|
||||
try:
|
||||
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
|
||||
# 'input_summary',RetrieveSummaryResponse)
|
||||
@@ -199,8 +401,8 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
summary_result = await summary_prompt(state, retrieve_info, retrieve_info)
|
||||
summary = summary_result[0]
|
||||
except Exception as e:
|
||||
logger.error( f"Input_Summary failed: {e}", exc_info=True )
|
||||
summary= {
|
||||
logger.error(f"Input_Summary failed: {e}", exc_info=True)
|
||||
summary = {
|
||||
"status": "fail",
|
||||
"summary_result": "信息不足,无法回答",
|
||||
"storage_type": storage_type,
|
||||
@@ -208,35 +410,58 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
"error": str(e)
|
||||
}
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
duration = end - start
|
||||
log_time('检索', duration)
|
||||
return {"summary":summary}
|
||||
return {"summary": summary}
|
||||
|
||||
async def Retrieve_Summary(state: ReadState)-> ReadState:
|
||||
retrieve=state.get("retrieve", '')
|
||||
history = await summary_history( state)
|
||||
|
||||
async def Retrieve_Summary(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Generate comprehensive summary from retrieved expansion issues
|
||||
|
||||
Processes retrieved expansion issues and generates a detailed summary using LLM.
|
||||
This function handles complex retrieval results and provides comprehensive answers
|
||||
based on expanded query results.
|
||||
|
||||
Args:
|
||||
state: ReadState containing retrieve data with expansion issues
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing comprehensive summary results
|
||||
"""
|
||||
retrieve = state.get("retrieve", '')
|
||||
history = await summary_history(state)
|
||||
import json
|
||||
with open("检索.json","w",encoding='utf-8') as f:
|
||||
with open("检索.json", "w", encoding='utf-8') as f:
|
||||
f.write(json.dumps(retrieve, indent=4, ensure_ascii=False))
|
||||
retrieve=retrieve.get("Expansion_issue", [])
|
||||
start=time.time()
|
||||
retrieve_info_str=[]
|
||||
retrieve = retrieve.get("Expansion_issue", [])
|
||||
start = time.time()
|
||||
retrieve_info_str = []
|
||||
for data in retrieve:
|
||||
if data=='':
|
||||
retrieve_info_str=''
|
||||
if data == '':
|
||||
retrieve_info_str = ''
|
||||
else:
|
||||
for key, value in data.items():
|
||||
if key=='Answer_Small':
|
||||
if key == 'Answer_Small':
|
||||
for i in value:
|
||||
retrieve_info_str.append(i)
|
||||
retrieve_info_str=list(set(retrieve_info_str))
|
||||
retrieve_info_str='\n'.join(retrieve_info_str)
|
||||
retrieve_info_str = list(set(retrieve_info_str))
|
||||
retrieve_info_str = '\n'.join(retrieve_info_str)
|
||||
|
||||
aimessages=await summary_llm(state,history,retrieve_info_str,
|
||||
'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1")
|
||||
# Merge perceptual memory content
|
||||
perceptual_data = state.get("perceptual_data", {})
|
||||
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
||||
if perceptual_content:
|
||||
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
||||
|
||||
aimessages = await summary_llm(
|
||||
state,
|
||||
history,
|
||||
retrieve_info_str,
|
||||
'direct_summary_prompt.jinja2',
|
||||
'retrieve_summary', RetrieveSummaryResponse,
|
||||
"1"
|
||||
)
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||
await summary_redis_save(state, aimessages)
|
||||
if aimessages == '':
|
||||
@@ -248,33 +473,52 @@ async def Retrieve_Summary(state: ReadState)-> ReadState:
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('Retrieval summary', duration)
|
||||
|
||||
# 修复协程调用 - 先await,然后访问返回值
|
||||
|
||||
# Fixed coroutine call - await first, then access return value
|
||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||
summary = summary_result[1]
|
||||
return {"summary":summary}
|
||||
return {"summary": summary}
|
||||
|
||||
|
||||
async def Summary(state: ReadState)-> ReadState:
|
||||
start=time.time()
|
||||
async def Summary(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Generate final comprehensive summary from verified data
|
||||
|
||||
Creates the final summary using verified expansion issues and conversation history.
|
||||
This function processes verified data to generate the most comprehensive and
|
||||
accurate response to user queries.
|
||||
|
||||
Args:
|
||||
state: ReadState containing verified data and query information
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing final summary results
|
||||
"""
|
||||
start = time.time()
|
||||
query = state.get("data", '')
|
||||
verify=state.get("verify", '')
|
||||
verify_expansion_issue=verify.get("verified_data", '')
|
||||
retrieve_info_str=''
|
||||
verify = state.get("verify", '')
|
||||
verify_expansion_issue = verify.get("verified_data", '')
|
||||
retrieve_info_str = ''
|
||||
for data in verify_expansion_issue:
|
||||
for key, value in data.items():
|
||||
if key=='answer_small':
|
||||
if key == 'answer_small':
|
||||
for i in value:
|
||||
retrieve_info_str+=i+'\n'
|
||||
history=await summary_history(state)
|
||||
retrieve_info_str += i + '\n'
|
||||
history = await summary_history(state)
|
||||
|
||||
# Merge perceptual memory content
|
||||
perceptual_data = state.get("perceptual_data", {})
|
||||
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
||||
if perceptual_content:
|
||||
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
||||
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
"retrieve_info": retrieve_info_str
|
||||
}
|
||||
aimessages=await summary_llm(state,history,data,
|
||||
'summary_prompt.jinja2','summary',SummaryResponse,0)
|
||||
aimessages = await summary_llm(state, history, data,
|
||||
'summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
||||
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||
await summary_redis_save(state, aimessages)
|
||||
@@ -286,14 +530,28 @@ async def Summary(state: ReadState)-> ReadState:
|
||||
duration = 0.0
|
||||
log_time('Retrieval summary', duration)
|
||||
|
||||
# 修复协程调用 - 先await,然后访问返回值
|
||||
# Fixed coroutine call - await first, then access return value
|
||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||
summary = summary_result[1]
|
||||
return {"summary":summary}
|
||||
return {"summary": summary}
|
||||
|
||||
async def Summary_fails(state: ReadState)-> ReadState:
|
||||
storage_type=state.get("storage_type", '')
|
||||
user_rag_memory_id=state.get("user_rag_memory_id", '')
|
||||
|
||||
async def Summary_fails(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Generate fallback summary when normal summary process fails
|
||||
|
||||
Provides a fallback summary generation mechanism when the standard summary
|
||||
process encounters errors or fails to produce satisfactory results. Uses
|
||||
a specialized failure template to handle edge cases.
|
||||
|
||||
Args:
|
||||
state: ReadState containing verified data and failure context
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing fallback summary results
|
||||
"""
|
||||
storage_type = state.get("storage_type", '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
history = await summary_history(state)
|
||||
query = state.get("data", '')
|
||||
verify = state.get("verify", '')
|
||||
@@ -304,17 +562,24 @@ async def Summary_fails(state: ReadState)-> ReadState:
|
||||
if key == 'answer_small':
|
||||
for i in value:
|
||||
retrieve_info_str += i + '\n'
|
||||
|
||||
# Merge perceptual memory content
|
||||
perceptual_data = state.get("perceptual_data", {})
|
||||
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
||||
if perceptual_content:
|
||||
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
||||
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
"retrieve_info": retrieve_info_str
|
||||
}
|
||||
aimessages = await summary_llm(state, history, data,
|
||||
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
||||
result= {
|
||||
aimessages = await summary_llm(state, history, data,
|
||||
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
||||
result = {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
return {"summary":result}
|
||||
return {"summary": result}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import asyncio
|
||||
import os
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.db import get_db
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.models.verification_models import VerificationResult
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
PROJECT_ROOT_,
|
||||
ReadState,
|
||||
@@ -10,29 +11,53 @@ from app.core.memory.agent.utils.llm_tools import (
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
from app.db import get_db_context
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
db_session = next(get_db())
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class VerificationNodeService(LLMServiceMixin):
|
||||
"""验证节点服务类"""
|
||||
"""
|
||||
Verification node service class
|
||||
|
||||
Handles data verification operations using LLM services. Inherits from
|
||||
LLMServiceMixin to provide structured LLM calling capabilities for
|
||||
verifying and validating retrieved information.
|
||||
|
||||
Attributes:
|
||||
template_service: Service for rendering Jinja2 templates
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
# 创建全局服务实例
|
||||
|
||||
# Create global service instance
|
||||
verification_service = VerificationNodeService()
|
||||
|
||||
|
||||
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||
"""处理验证结果并生成输出格式"""
|
||||
"""
|
||||
Process verification results and generate output format
|
||||
|
||||
Transforms VerificationResult objects into structured output format suitable
|
||||
for frontend consumption. Handles conversion of VerificationItem objects to
|
||||
dictionary format and adds metadata for tracking.
|
||||
|
||||
Args:
|
||||
state: ReadState containing storage and user configuration
|
||||
messages_deal: VerificationResult containing verification outcomes
|
||||
|
||||
Returns:
|
||||
dict: Formatted verification result with status and metadata
|
||||
"""
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
data = state.get('data', '')
|
||||
|
||||
# 将 VerificationItem 对象转换为字典列表
|
||||
|
||||
# Convert VerificationItem objects to dictionary list
|
||||
verified_data = []
|
||||
if messages_deal.expansion_issue:
|
||||
for item in messages_deal.expansion_issue:
|
||||
@@ -40,7 +65,7 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||
verified_data.append(item.model_dump())
|
||||
elif isinstance(item, dict):
|
||||
verified_data.append(item)
|
||||
|
||||
|
||||
Verify_result = {
|
||||
"status": messages_deal.split_result,
|
||||
"verified_data": verified_data,
|
||||
@@ -58,34 +83,37 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||
}
|
||||
}
|
||||
return Verify_result
|
||||
|
||||
|
||||
async def Verify(state: ReadState):
|
||||
logger.info("=== Verify 节点开始执行 ===")
|
||||
try:
|
||||
content = state.get('data', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
|
||||
|
||||
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
|
||||
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
||||
|
||||
retrieve = state.get("retrieve", {})
|
||||
logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
|
||||
|
||||
logger.info(
|
||||
f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
|
||||
|
||||
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
|
||||
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
|
||||
|
||||
|
||||
messages = {
|
||||
"Query": content,
|
||||
"Expansion_issue": retrieve_expansion
|
||||
}
|
||||
|
||||
logger.info("Verify: 开始渲染模板")
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
|
||||
# Generate JSON schema to guide LLM output format
|
||||
json_schema = VerificationResult.model_json_schema()
|
||||
|
||||
|
||||
system_prompt = await verification_service.template_service.render_template(
|
||||
template_name='split_verify_prompt.jinja2',
|
||||
operation_name='split_verify_prompt',
|
||||
@@ -94,29 +122,30 @@ async def Verify(state: ReadState):
|
||||
json_schema=json_schema
|
||||
)
|
||||
logger.info(f"Verify: 模板渲染完成,prompt length={len(system_prompt)}")
|
||||
|
||||
|
||||
# 使用优化的LLM服务,添加超时保护
|
||||
logger.info("Verify: 开始调用 LLM")
|
||||
try:
|
||||
# 添加 asyncio.wait_for 超时包裹,防止无限等待
|
||||
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
|
||||
import asyncio
|
||||
structured = await asyncio.wait_for(
|
||||
verification_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=VerificationResult,
|
||||
fallback_value={
|
||||
"query": content,
|
||||
"history": history if isinstance(history, list) else [],
|
||||
"expansion_issue": [],
|
||||
"split_result": "failed",
|
||||
"reason": "验证失败或超时"
|
||||
}
|
||||
),
|
||||
timeout=150.0 # 150秒超时
|
||||
)
|
||||
# Add asyncio.wait_for timeout wrapper to prevent infinite waiting
|
||||
# Timeout set to 150 seconds (slightly longer than LLM config's 120 seconds)
|
||||
|
||||
with get_db_context() as db_session:
|
||||
structured = await asyncio.wait_for(
|
||||
verification_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=VerificationResult,
|
||||
fallback_value={
|
||||
"query": content,
|
||||
"history": history if isinstance(history, list) else [],
|
||||
"expansion_issue": [],
|
||||
"split_result": "failed",
|
||||
"reason": "验证失败或超时"
|
||||
}
|
||||
),
|
||||
timeout=150.0 # 150 second timeout
|
||||
)
|
||||
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Verify: LLM 调用超时(150秒),使用 fallback 值")
|
||||
@@ -127,11 +156,11 @@ async def Verify(state: ReadState):
|
||||
split_result="failed",
|
||||
reason="LLM调用超时"
|
||||
)
|
||||
|
||||
|
||||
result = await Verify_prompt(state, structured)
|
||||
logger.info("=== Verify 节点执行完成 ===")
|
||||
return {"verify": result}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Verify 节点执行失败: {e}", exc_info=True)
|
||||
# 返回失败的验证结果
|
||||
@@ -152,4 +181,4 @@ async def Verify(state: ReadState):
|
||||
"user_rag_memory_id": state.get('user_rag_memory_id', '')
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from app.cache.memory.interest_memory import InterestMemoryCache
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.utils.write_tools import write
|
||||
from app.core.logging_config import get_agent_logger
|
||||
@@ -10,7 +11,7 @@ async def write_node(state: WriteState) -> WriteState:
|
||||
Write data to the database/file system.
|
||||
|
||||
Args:
|
||||
state: WriteState containing messages, end_user_id, and memory_config
|
||||
state: WriteState containing messages, end_user_id, memory_config, and language
|
||||
|
||||
Returns:
|
||||
dict: Contains 'write_result' with status and data fields
|
||||
@@ -18,6 +19,7 @@ async def write_node(state: WriteState) -> WriteState:
|
||||
messages = state.get('messages', [])
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', '')
|
||||
language = state.get('language', 'zh') # 默认中文
|
||||
|
||||
# Convert LangChain messages to structured format expected by write()
|
||||
structured_messages = []
|
||||
@@ -35,9 +37,19 @@ async def write_node(state: WriteState) -> WriteState:
|
||||
messages=structured_messages,
|
||||
end_user_id=end_user_id,
|
||||
memory_config=memory_config,
|
||||
language=language,
|
||||
)
|
||||
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
||||
|
||||
# 写入 neo4j 成功后,删除该用户的兴趣分布缓存,确保下次请求重新生成
|
||||
for lang in ["zh", "en"]:
|
||||
deleted = await InterestMemoryCache.delete_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
language=lang,
|
||||
)
|
||||
if deleted:
|
||||
logger.info(f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
|
||||
|
||||
write_result = {
|
||||
"status": "success",
|
||||
"data": structured_messages,
|
||||
|
||||
@@ -1,22 +1,20 @@
|
||||
#!/usr/bin/env python3
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.constants import START, END
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
|
||||
from app.db import get_db
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState
|
||||
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_node
|
||||
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
|
||||
perceptual_retrieve_node,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
|
||||
Split_The_Problem,
|
||||
Problem_Extension,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
|
||||
retrieve,
|
||||
retrieve_nodes,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
|
||||
Input_Summary,
|
||||
@@ -30,12 +28,26 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
|
||||
Retrieve_continue,
|
||||
Verify_continue,
|
||||
)
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_read_graph():
|
||||
"""创建并返回 LangGraph 工作流"""
|
||||
"""
|
||||
Create and return a LangGraph workflow for memory reading operations
|
||||
|
||||
Builds a state graph workflow that handles memory retrieval, problem analysis,
|
||||
verification, and summarization. The workflow includes nodes for content input,
|
||||
problem splitting, retrieval, verification, and various summary operations.
|
||||
|
||||
Yields:
|
||||
StateGraph: Compiled LangGraph workflow for memory reading
|
||||
|
||||
Raises:
|
||||
Exception: If workflow creation fails
|
||||
"""
|
||||
try:
|
||||
# Build workflow graph
|
||||
workflow = StateGraph(ReadState)
|
||||
@@ -43,135 +55,34 @@ async def make_read_graph():
|
||||
workflow.add_node("Split_The_Problem", Split_The_Problem)
|
||||
workflow.add_node("Problem_Extension", Problem_Extension)
|
||||
workflow.add_node("Input_Summary", Input_Summary)
|
||||
# workflow.add_node("Retrieve", retrieve_nodes)
|
||||
workflow.add_node("Retrieve", retrieve)
|
||||
workflow.add_node("Retrieve", retrieve_nodes)
|
||||
# workflow.add_node("Retrieve", retrieve)
|
||||
workflow.add_node("Perceptual_Retrieve", perceptual_retrieve_node)
|
||||
workflow.add_node("Verify", Verify)
|
||||
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
|
||||
workflow.add_node("Summary", Summary)
|
||||
workflow.add_node("Summary_fails", Summary_fails)
|
||||
|
||||
# 添加边
|
||||
|
||||
# Add edges to define workflow flow
|
||||
workflow.add_edge(START, "content_input")
|
||||
workflow.add_conditional_edges("content_input", Split_continue)
|
||||
workflow.add_edge("Input_Summary", END)
|
||||
workflow.add_edge("Split_The_Problem", "Problem_Extension")
|
||||
workflow.add_edge("Problem_Extension", "Retrieve")
|
||||
# After Problem_Extension, retrieve perceptual memory first, then main Retrieve
|
||||
workflow.add_edge("Problem_Extension", "Perceptual_Retrieve")
|
||||
workflow.add_edge("Perceptual_Retrieve", "Retrieve")
|
||||
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
||||
workflow.add_edge("Retrieve_Summary", END)
|
||||
workflow.add_conditional_edges("Verify", Verify_continue)
|
||||
workflow.add_edge("Summary_fails", END)
|
||||
workflow.add_edge("Summary", END)
|
||||
|
||||
|
||||
'''-----'''
|
||||
# workflow.add_edge("Retrieve", END)
|
||||
|
||||
# 编译工作流
|
||||
|
||||
# Compile workflow
|
||||
graph = workflow.compile()
|
||||
yield graph
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"创建工作流失败: {e}")
|
||||
logger.error(f"创建工作流失败: {e}")
|
||||
raise
|
||||
finally:
|
||||
print("工作流创建完成")
|
||||
|
||||
async def main():
|
||||
"""主函数 - 运行工作流"""
|
||||
message = "昨天有什么好看的电影"
|
||||
end_user_id = '88a459f5_text09' # 组ID
|
||||
storage_type = 'neo4j' # 存储类型
|
||||
search_switch = '1' # 搜索开关
|
||||
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
|
||||
|
||||
# 获取数据库会话
|
||||
db_session = next(get_db())
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=17, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
import time
|
||||
start=time.time()
|
||||
try:
|
||||
async with make_read_graph() as graph:
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
|
||||
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
|
||||
# 获取节点更新信息
|
||||
_intermediate_outputs = []
|
||||
summary = ''
|
||||
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
print(f"处理节点: {node_name}")
|
||||
|
||||
# 处理不同Summary节点的返回结构
|
||||
if 'Summary' in node_name:
|
||||
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
|
||||
summary = node_data['InputSummary']['summary_result']
|
||||
elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']:
|
||||
summary = node_data['RetrieveSummary']['summary_result']
|
||||
elif 'summary' in node_data and 'summary_result' in node_data['summary']:
|
||||
summary = node_data['summary']['summary_result']
|
||||
elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']:
|
||||
summary = node_data['SummaryFails']['summary_result']
|
||||
|
||||
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
|
||||
if spit_data and spit_data != [] and spit_data != {}:
|
||||
_intermediate_outputs.append(spit_data)
|
||||
|
||||
# Problem_Extension 节点
|
||||
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
|
||||
if problem_extension and problem_extension != [] and problem_extension != {}:
|
||||
_intermediate_outputs.append(problem_extension)
|
||||
|
||||
# Retrieve 节点
|
||||
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
|
||||
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
|
||||
_intermediate_outputs.extend(retrieve_node)
|
||||
|
||||
# Verify 节点
|
||||
verify_n = node_data.get('verify', {}).get('_intermediate', None)
|
||||
if verify_n and verify_n != [] and verify_n != {}:
|
||||
_intermediate_outputs.append(verify_n)
|
||||
|
||||
|
||||
# Summary 节点
|
||||
summary_n = node_data.get('summary', {}).get('_intermediate', None)
|
||||
if summary_n and summary_n != [] and summary_n != {}:
|
||||
_intermediate_outputs.append(summary_n)
|
||||
|
||||
# # 过滤掉空值
|
||||
# _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
|
||||
#
|
||||
# # 优化搜索结果
|
||||
# print("=== 开始优化搜索结果 ===")
|
||||
# optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
|
||||
# result=reorder_output_results(optimized_outputs)
|
||||
# # 保存优化后的结果到文件
|
||||
# with open('_intermediate_outputs_optimized.json', 'w', encoding='utf-8') as f:
|
||||
# import json
|
||||
# f.write(json.dumps(result, indent=4, ensure_ascii=False))
|
||||
#
|
||||
print(f"=== 最终摘要 ===")
|
||||
print(summary)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
end=time.time()
|
||||
print(100*'y')
|
||||
print(f"总耗时: {end-start}s")
|
||||
print(100*'y')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
counter = COUNTState(limit=3)
|
||||
def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
|
||||
|
||||
def Split_continue(state: ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
|
||||
@@ -25,6 +25,7 @@ def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summa
|
||||
return 'Input_Summary'
|
||||
return 'Split_The_Problem' # 默认情况
|
||||
|
||||
|
||||
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
@@ -43,8 +44,10 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
elif search_switch == '1':
|
||||
return 'Retrieve_Summary'
|
||||
return 'Retrieve_Summary' # Default based on business logic
|
||||
|
||||
|
||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||
status=state.get('verify', '')['status']
|
||||
status = state.get('verify', '')['status']
|
||||
# loop_count = counter.get_total()
|
||||
if "success" in status:
|
||||
# counter.reset()
|
||||
@@ -53,7 +56,7 @@ def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "co
|
||||
# if loop_count < 2: # Maximum loop count is 3
|
||||
# return "content_input"
|
||||
# else:
|
||||
# counter.reset()
|
||||
# counter.reset()
|
||||
return "Summary_fails"
|
||||
else:
|
||||
# Add default return value to avoid returning None
|
||||
|
||||
@@ -0,0 +1,295 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from app.celery_task_scheduler import scheduler
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.redis_tool import count_store
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
|
||||
|
||||
async def write(
|
||||
storage_type,
|
||||
end_user_id,
|
||||
user_message,
|
||||
ai_message,
|
||||
user_rag_memory_id,
|
||||
actual_end_user_id,
|
||||
actual_config_id,
|
||||
long_term_messages=None
|
||||
):
|
||||
"""
|
||||
Write memory with structured message support
|
||||
|
||||
Handles memory writing operations for different storage types (Neo4j/RAG).
|
||||
Supports both individual message pairs and batch long-term message processing.
|
||||
|
||||
Args:
|
||||
storage_type: Storage type identifier ("neo4j" or "rag")
|
||||
end_user_id: Terminal user identifier
|
||||
user_message: User message content
|
||||
ai_message: AI response content
|
||||
user_rag_memory_id: RAG memory identifier
|
||||
actual_end_user_id: Actual user identifier for storage
|
||||
actual_config_id: Configuration identifier
|
||||
long_term_messages: Optional list of structured messages for batch processing
|
||||
|
||||
Logic explanation:
|
||||
- RAG mode: Combines user_message and ai_message into string format, maintains original logic
|
||||
- Neo4j mode: Uses structured message lists
|
||||
1. If both user_message and ai_message are not empty: Creates paired messages [user, assistant]
|
||||
2. If only user_message exists: Creates single user message [user] (for historical memory scenarios)
|
||||
3. Each message is converted to independent Chunk, preserving speaker field
|
||||
"""
|
||||
|
||||
if long_term_messages is None:
|
||||
long_term_messages = []
|
||||
with get_db_context() as db:
|
||||
actual_config_id = resolve_config_id(actual_config_id, db)
|
||||
# Neo4j mode: Use structured message lists
|
||||
structured_messages = []
|
||||
|
||||
# Always add user message (if not empty)
|
||||
if isinstance(user_message, str) and user_message.strip() != "":
|
||||
structured_messages.append({"role": "user", "content": user_message})
|
||||
|
||||
# Only add assistant message when AI reply is not empty
|
||||
if isinstance(ai_message, str) and ai_message.strip() != "":
|
||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||
|
||||
# If long_term_messages provided, use it to replace structured_messages
|
||||
if long_term_messages and isinstance(long_term_messages, list):
|
||||
structured_messages = long_term_messages
|
||||
elif long_term_messages and isinstance(long_term_messages, str):
|
||||
# If it's a JSON string, parse it first
|
||||
try:
|
||||
structured_messages = json.loads(long_term_messages)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
|
||||
|
||||
# If no messages, return directly
|
||||
if not structured_messages:
|
||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||
# write_id = write_message_task.delay(
|
||||
# actual_end_user_id, # end_user_id: User ID
|
||||
# structured_messages, # message: JSON string format message list
|
||||
# str(actual_config_id), # config_id: Configuration ID string
|
||||
# storage_type, # storage_type: "neo4j"
|
||||
# user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
||||
# )
|
||||
scheduler.push_task(
|
||||
"app.core.memory.agent.write_message",
|
||||
str(actual_end_user_id),
|
||||
{
|
||||
"end_user_id": str(actual_end_user_id),
|
||||
"message": structured_messages,
|
||||
"config_id": str(actual_config_id),
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id or ""
|
||||
}
|
||||
)
|
||||
|
||||
# logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||
# write_status = get_task_memory_write_result(str(write_id))
|
||||
# logger.info(f'[WRITE] Task result - user={actual_end_user_id}')
|
||||
|
||||
|
||||
async def term_memory_save(end_user_id, strategy_type, scope):
|
||||
"""
|
||||
Save long-term memory data to database
|
||||
|
||||
Handles the storage of long-term memory data based on different strategies
|
||||
(chunk-based or aggregate-based) and manages the transition from short-term
|
||||
to long-term memory storage.
|
||||
|
||||
Args:
|
||||
end_user_id: User identifier for memory association
|
||||
strategy_type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
|
||||
scope: Scope/window size for memory processing
|
||||
"""
|
||||
with get_db_context() as db_session:
|
||||
repo = LongTermMemoryRepository(db_session)
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
result = write_store.get_session_by_userid(end_user_id)
|
||||
if not result:
|
||||
logger.warning(f"No write data found for user {end_user_id}")
|
||||
return
|
||||
if strategy_type in [AgentMemory_Long_Term.STRATEGY_CHUNK, AgentMemory_Long_Term.STRATEGY_AGGREGATE]:
|
||||
data = await format_parsing(result, "dict")
|
||||
chunk_data = data[:scope]
|
||||
if len(chunk_data) == scope:
|
||||
repo.upsert(end_user_id, chunk_data)
|
||||
logger.info(f'---------写入短长期-----------')
|
||||
else:
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
||||
long_messages = await messages_parse(long_time_data)
|
||||
repo.upsert(end_user_id, long_messages)
|
||||
logger.info(f'写入短长期:')
|
||||
|
||||
|
||||
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
|
||||
"""
|
||||
Process dialogue based on window size and write to Neo4j
|
||||
|
||||
Manages conversation data based on a sliding window approach. When the window
|
||||
reaches the specified scope size, it triggers long-term memory storage to Neo4j.
|
||||
|
||||
Args:
|
||||
end_user_id: Terminal user identifier
|
||||
memory_config: Memory configuration object containing settings
|
||||
langchain_messages: Original message data list
|
||||
scope: Window size determining when to trigger long-term storage
|
||||
"""
|
||||
is_end_user_has_history = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_has_history:
|
||||
end_user_visit_count, redis_messages = is_end_user_has_history
|
||||
else:
|
||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||
return
|
||||
end_user_visit_count += 1
|
||||
if end_user_visit_count < scope:
|
||||
redis_messages.extend(langchain_messages)
|
||||
count_store.update_sessions_count(end_user_id, end_user_visit_count, redis_messages)
|
||||
else:
|
||||
logger.info('写入长期记忆NEO4J')
|
||||
redis_messages.extend(langchain_messages)
|
||||
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
|
||||
if hasattr(memory_config, 'config_id'):
|
||||
config_id = memory_config.config_id
|
||||
else:
|
||||
config_id = memory_config
|
||||
|
||||
scheduler.push_task(
|
||||
"app.core.memory.agent.write_message",
|
||||
str(end_user_id),
|
||||
{
|
||||
"end_user_id": str(end_user_id),
|
||||
"message": redis_messages,
|
||||
"config_id": str(config_id),
|
||||
"storage_type": AgentMemory_Long_Term.STORAGE_NEO4J,
|
||||
"user_rag_memory_id": ""
|
||||
}
|
||||
)
|
||||
# write_message_task.delay(
|
||||
# end_user_id, # end_user_id: User ID
|
||||
# redis_messages, # message: JSON string format message list
|
||||
# config_id, # config_id: Configuration ID string
|
||||
# AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
|
||||
# "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
||||
# )
|
||||
count_store.update_sessions_count(end_user_id, 0, [])
|
||||
|
||||
|
||||
async def memory_long_term_storage(end_user_id, memory_config, time):
|
||||
"""
|
||||
Process memory storage based on time intervals and write to Neo4j
|
||||
|
||||
Retrieves Redis data based on time intervals and writes it to Neo4j for
|
||||
long-term storage. This function handles time-based memory consolidation.
|
||||
|
||||
Args:
|
||||
end_user_id: Terminal user identifier
|
||||
memory_config: Memory configuration object containing settings
|
||||
time: Time interval for data retrieval
|
||||
"""
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
||||
format_messages = long_time_data
|
||||
messages = []
|
||||
memory_config = memory_config.config_id
|
||||
for i in format_messages:
|
||||
message = json.loads(i['Query'])
|
||||
messages += message
|
||||
if format_messages:
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
memory_config, messages)
|
||||
|
||||
|
||||
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
    """
    Aggregation judgment: decide whether the input messages describe the same
    event as the user's stored conversation history.

    Uses LLM-based structured output to determine whether new messages should
    be aggregated with existing historical data or stored as a separate event.
    When the LLM judges the messages to be a *new* event, the processed output
    is written to Neo4j immediately.

    Args:
        end_user_id: Terminal user identifier
        ori_messages: Original message list, format like
            [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
        memory_config: Memory configuration object containing LLM settings

    Returns:
        dict: {"is_same_event": bool, "output": ...} on success; on failure a
        fallback dict echoing the input with an "error" field.
    """
    # Initialize to [] (not None) so the exception fallback always returns a
    # list — the original could leak None into the fallback "history" field.
    history = []
    try:
        # 1. Get historical session data; parse it only when something was
        # returned (the original called format_parsing twice redundantly).
        result = write_store.get_all_sessions_by_end_user_id(end_user_id)
        if result:
            history = await format_parsing(result)

        # 2. Render the judgment prompt from history + the new sentence.
        json_schema = WriteAggregateModel.model_json_schema()
        template_service = TemplateService(template_root)
        system_prompt = await template_service.render_template(
            template_name='write_aggregate_judgment.jinja2',
            operation_name='aggregate_judgment',
            history=history,
            sentence=ori_messages,
            json_schema=json_schema
        )

        # 3. Ask the LLM for a structured same-event judgment.
        with get_db_context() as db_session:
            factory = MemoryClientFactory(db_session)
            llm_client = factory.get_llm_client(memory_config.llm_model_id)
            messages = [
                {
                    "role": "user",
                    "content": system_prompt
                }
            ]
            structured = await llm_client.response_structured(
                messages=messages,
                response_model=WriteAggregateModel
            )

            output_value = structured.output
            if isinstance(output_value, list):
                # Normalize MessageItem objects back into plain role/content dicts.
                output_value = [
                    {"role": msg.role, "content": msg.content}
                    for msg in output_value
                ]

            result_dict = {
                "is_same_event": structured.is_same_event,
                "output": output_value
            }
            if not structured.is_same_event:
                logger.info(result_dict)
                # NOTE(review): sibling code passes AgentMemory_Long_Term.STORAGE_NEO4J
                # here instead of the bare "neo4j" literal — confirm and unify.
                await write("neo4j", end_user_id, "", "", None, end_user_id,
                            memory_config.config_id, output_value)
            return result_dict

    except Exception as e:
        logger.error(f"[aggregate_judgment] 发生错误: {e}", exc_info=True)
        # Fallback: treat as a distinct event and surface the error to callers.
        return {
            "is_same_event": False,
            "output": ori_messages,
            "messages": ori_messages,
            "history": history,
            "error": str(e)
        }
|
||||
@@ -2,41 +2,53 @@ import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
from app.core.memory.src.search import (
|
||||
search_by_temporal,
|
||||
search_by_keyword_temporal,
|
||||
)
|
||||
|
||||
|
||||
def extract_tool_message_content(response):
|
||||
"""从agent响应中提取ToolMessage内容和工具名称"""
|
||||
"""
|
||||
Extract ToolMessage content and tool names from agent response
|
||||
|
||||
Parses agent response messages to extract tool execution results and metadata.
|
||||
Handles JSON parsing and provides structured access to tool output data.
|
||||
|
||||
Args:
|
||||
response: Agent response dictionary containing messages
|
||||
|
||||
Returns:
|
||||
dict: Dictionary containing tool_name and parsed content, or None if no tool message found
|
||||
- tool_name: Name of the executed tool
|
||||
- content: Parsed tool execution result (JSON or raw text)
|
||||
"""
|
||||
messages = response.get('messages', [])
|
||||
|
||||
for message in messages:
|
||||
if hasattr(message, 'tool_call_id') and hasattr(message, 'content'):
|
||||
# 这是一个ToolMessage
|
||||
# This is a ToolMessage
|
||||
tool_content = message.content
|
||||
tool_name = None
|
||||
|
||||
# 尝试获取工具名称
|
||||
# Try to get tool name
|
||||
if hasattr(message, 'name'):
|
||||
tool_name = message.name
|
||||
elif hasattr(message, 'tool_name'):
|
||||
tool_name = message.tool_name
|
||||
|
||||
try:
|
||||
# 解析JSON内容
|
||||
# Parse JSON content
|
||||
parsed_content = json.loads(tool_content)
|
||||
return {
|
||||
'tool_name': tool_name,
|
||||
'content': parsed_content
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
# 如果不是JSON格式,直接返回内容
|
||||
# If not JSON format, return content directly
|
||||
return {
|
||||
'tool_name': tool_name,
|
||||
'content': tool_content
|
||||
@@ -46,38 +58,61 @@ def extract_tool_message_content(response):
|
||||
|
||||
|
||||
class TimeRetrievalInput(BaseModel):
    """
    Input schema for the time retrieval tool.

    Defines the expected input parameters for time-based retrieval operations.
    Used for validation and documentation of tool parameters.

    Attributes:
        context: User input query content for search
        end_user_id: Group ID for filtering search results, defaults to a test user
    """
    # Free-text query supplied by the user (description kept in Chinese for the LLM).
    context: str = Field(description="用户输入的查询内容")
    # Group identifier used to scope search results; default is a test fixture id.
    end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
||||
|
||||
|
||||
def create_time_retrieval_tool(end_user_id: str):
|
||||
"""
|
||||
创建一个带有特定end_user_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
||||
Create a TimeRetrieval tool with specific end_user_id (synchronous version) for searching statements by time range
|
||||
|
||||
Creates a specialized time-based retrieval tool that searches for statements within
|
||||
specified time ranges. Includes field cleaning functionality to remove unnecessary
|
||||
metadata from search results.
|
||||
|
||||
Args:
|
||||
end_user_id: User identifier for scoping search results
|
||||
|
||||
Returns:
|
||||
function: Configured TimeRetrievalWithGroupId tool function
|
||||
"""
|
||||
|
||||
|
||||
def clean_temporal_result_fields(data):
|
||||
"""
|
||||
清理时间搜索结果中不需要的字段,并修改结构
|
||||
Clean unnecessary fields from temporal search results and modify structure
|
||||
|
||||
Removes metadata fields that are not needed for end-user consumption and
|
||||
restructures the response format for better usability.
|
||||
|
||||
Args:
|
||||
data: 要清理的数据
|
||||
data: Data to be cleaned (dict, list, or other types)
|
||||
|
||||
Returns:
|
||||
清理后的数据
|
||||
Cleaned data with unnecessary fields removed
|
||||
"""
|
||||
# 需要过滤的字段列表
|
||||
# List of fields to filter out
|
||||
fields_to_remove = {
|
||||
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
|
||||
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
|
||||
'valid_at', 'invalid_at', 'statement_ids'
|
||||
}
|
||||
|
||||
|
||||
if isinstance(data, dict):
|
||||
cleaned = {}
|
||||
for key, value in data.items():
|
||||
if key == 'statements' and isinstance(value, dict) and 'statements' in value:
|
||||
# 将 statements: {"statements": [...]} 改为 time_search: {"statements": [...]}
|
||||
# Change statements: {"statements": [...]} to time_search: {"statements": [...]}
|
||||
cleaned_value = clean_temporal_result_fields(value)
|
||||
# 进一步将内部的 statements 改为 time_search
|
||||
# Further change internal statements to time_search
|
||||
if 'statements' in cleaned_value:
|
||||
cleaned['results'] = {
|
||||
'time_search': cleaned_value['statements']
|
||||
@@ -91,26 +126,35 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
return [clean_temporal_result_fields(item) for item in data]
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
@tool
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None,
|
||||
end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||
"""
|
||||
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||
显式接收参数:
|
||||
- context: 查询上下文内容
|
||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||
- end_user_id_param: 组ID(可选,用于覆盖默认组ID)
|
||||
- clean_output: 是否清理输出中的元数据字段
|
||||
-end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||
Optimized time retrieval tool, combines time range search only (synchronous version), automatically filters unnecessary metadata fields
|
||||
|
||||
Performs time-based search operations with automatic metadata filtering. Supports
|
||||
flexible date range specification and provides clean, user-friendly output.
|
||||
|
||||
Explicit parameters:
|
||||
- context: Query context content
|
||||
- start_date: Start time (optional, format: YYYY-MM-DD)
|
||||
- end_date: End time (optional, format: YYYY-MM-DD)
|
||||
- end_user_id_param: Group ID (optional, overrides default group ID)
|
||||
- clean_output: Whether to clean metadata fields from output
|
||||
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
|
||||
|
||||
Returns:
|
||||
str: JSON formatted search results with temporal data
|
||||
"""
|
||||
|
||||
async def _async_search():
|
||||
# 使用传入的参数或默认值
|
||||
# Use passed parameters or default values
|
||||
actual_end_user_id = end_user_id_param or end_user_id
|
||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||
|
||||
# 基本时间搜索
|
||||
# Basic time search
|
||||
results = await search_by_temporal(
|
||||
end_user_id=actual_end_user_id,
|
||||
start_date=actual_start_date,
|
||||
@@ -118,33 +162,43 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
limit=10
|
||||
)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
# Clean unnecessary fields from results
|
||||
if clean_output:
|
||||
cleaned_results = clean_temporal_result_fields(results)
|
||||
else:
|
||||
cleaned_results = results
|
||||
|
||||
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
@tool
|
||||
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str:
|
||||
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None,
|
||||
clean_output: bool = True) -> str:
|
||||
"""
|
||||
优化的关键词时间检索工具,结合关键词和时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||
显式接收参数:
|
||||
- context: 查询内容
|
||||
- days_back: 向前搜索的天数,默认7天
|
||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||
- clean_output: 是否清理输出中的元数据字段
|
||||
- end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||
Optimized keyword time retrieval tool, combines keyword and time range search (synchronous version), automatically filters unnecessary metadata fields
|
||||
|
||||
Performs combined keyword and temporal search operations with automatic metadata
|
||||
filtering. Provides more targeted search results by combining content relevance
|
||||
with time-based filtering.
|
||||
|
||||
Explicit parameters:
|
||||
- context: Query content for keyword matching
|
||||
- days_back: Number of days to search backwards, default 7 days
|
||||
- start_date: Start time (optional, format: YYYY-MM-DD)
|
||||
- end_date: End time (optional, format: YYYY-MM-DD)
|
||||
- clean_output: Whether to clean metadata fields from output
|
||||
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
|
||||
|
||||
Returns:
|
||||
str: JSON formatted search results combining keyword and temporal data
|
||||
"""
|
||||
|
||||
async def _async_search():
|
||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||
actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
|
||||
|
||||
# 关键词时间搜索
|
||||
# Keyword time search
|
||||
results = await search_by_keyword_temporal(
|
||||
query_text=context,
|
||||
end_user_id=end_user_id,
|
||||
@@ -153,7 +207,7 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
limit=15
|
||||
)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
# Clean unnecessary fields from results
|
||||
if clean_output:
|
||||
cleaned_results = clean_temporal_result_fields(results)
|
||||
else:
|
||||
@@ -162,50 +216,61 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
|
||||
return TimeRetrievalWithGroupId
|
||||
|
||||
|
||||
def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
"""
|
||||
创建混合检索工具,使用run_hybrid_search进行混合检索,优化输出格式并过滤不需要的字段
|
||||
Create hybrid retrieval tool using run_hybrid_search for hybrid retrieval, optimize output format and filter unnecessary fields
|
||||
|
||||
Creates an advanced hybrid search tool that combines multiple search strategies
|
||||
(keyword, vector, hybrid) with automatic result cleaning and formatting.
|
||||
|
||||
Args:
|
||||
memory_config: 内存配置对象
|
||||
**search_params: 搜索参数,包含end_user_id, limit, include等
|
||||
memory_config: Memory configuration object containing LLM and search settings
|
||||
**search_params: Search parameters including end_user_id, limit, include, etc.
|
||||
|
||||
Returns:
|
||||
function: Configured HybridSearch tool function with async capabilities
|
||||
"""
|
||||
|
||||
|
||||
def clean_result_fields(data):
|
||||
"""
|
||||
递归清理结果中不需要的字段
|
||||
Recursively clean unnecessary fields from results
|
||||
|
||||
Removes metadata fields that are not needed for end-user consumption,
|
||||
improving readability and reducing response size.
|
||||
|
||||
Args:
|
||||
data: 要清理的数据(可能是字典、列表或其他类型)
|
||||
data: Data to be cleaned (can be dict, list, or other types)
|
||||
|
||||
Returns:
|
||||
清理后的数据
|
||||
Cleaned data with unnecessary fields removed
|
||||
"""
|
||||
# 需要过滤的字段列表
|
||||
# List of fields to filter out
|
||||
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development
|
||||
fields_to_remove = {
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids","fact_summary"
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
|
||||
}
|
||||
|
||||
# 注意:'id' 字段保留,community 展开时需要用 community id 查询成员 statements
|
||||
|
||||
if isinstance(data, dict):
|
||||
# 对字典进行清理
|
||||
# Clean dictionary
|
||||
cleaned = {}
|
||||
for key, value in data.items():
|
||||
if key not in fields_to_remove:
|
||||
cleaned[key] = clean_result_fields(value) # 递归清理嵌套数据
|
||||
cleaned[key] = clean_result_fields(value) # Recursively clean nested data
|
||||
return cleaned
|
||||
elif isinstance(data, list):
|
||||
# 对列表中的每个元素进行清理
|
||||
# Clean each element in list
|
||||
return [clean_result_fields(item) for item in data]
|
||||
else:
|
||||
# 其他类型直接返回
|
||||
# Return other types directly
|
||||
return data
|
||||
|
||||
|
||||
@tool
|
||||
async def HybridSearch(
|
||||
context: str,
|
||||
@@ -215,57 +280,63 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
rerank_alpha: float = 0.6,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
clean_output: bool = True # 新增:是否清理输出字段
|
||||
clean_output: bool = True # New: whether to clean output fields
|
||||
) -> str:
|
||||
"""
|
||||
优化的混合检索工具,支持关键词、向量和混合搜索,自动过滤不需要的元数据字段
|
||||
Optimized hybrid retrieval tool, supports keyword, vector and hybrid search, automatically filters unnecessary metadata fields
|
||||
|
||||
Provides comprehensive search capabilities combining multiple search strategies
|
||||
with intelligent result ranking and automatic metadata filtering for clean output.
|
||||
|
||||
Args:
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
end_user_id: 组ID,用于过滤搜索结果
|
||||
rerank_alpha: 重排序权重参数
|
||||
use_forgetting_rerank: 是否使用遗忘重排序
|
||||
use_llm_rerank: 是否使用LLM重排序
|
||||
clean_output: 是否清理输出中的元数据字段
|
||||
context: Query content for search
|
||||
search_type: Search type ('keyword', 'embedding', 'hybrid')
|
||||
limit: Result quantity limit
|
||||
end_user_id: Group ID for filtering search results
|
||||
rerank_alpha: Reranking weight parameter for result scoring
|
||||
use_forgetting_rerank: Whether to use forgetting-based reranking
|
||||
use_llm_rerank: Whether to use LLM-based reranking
|
||||
clean_output: Whether to clean metadata fields from output
|
||||
|
||||
Returns:
|
||||
str: JSON formatted comprehensive search results
|
||||
"""
|
||||
try:
|
||||
# 导入run_hybrid_search函数
|
||||
# Import run_hybrid_search function
|
||||
from app.core.memory.src.search import run_hybrid_search
|
||||
|
||||
# 合并参数,优先使用传入的参数
|
||||
# Merge parameters, prioritize passed parameters
|
||||
final_params = {
|
||||
"query_text": context,
|
||||
"search_type": search_type,
|
||||
"end_user_id": end_user_id or search_params.get("end_user_id"),
|
||||
"limit": limit or search_params.get("limit", 10),
|
||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
|
||||
"output_path": None, # 不保存到文件
|
||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities", "communities"]),
|
||||
"output_path": None, # Don't save to file
|
||||
"memory_config": memory_config,
|
||||
"rerank_alpha": rerank_alpha,
|
||||
"use_forgetting_rerank": use_forgetting_rerank,
|
||||
"use_llm_rerank": use_llm_rerank
|
||||
}
|
||||
|
||||
# 执行混合检索
|
||||
# Execute hybrid retrieval
|
||||
raw_results = await run_hybrid_search(**final_params)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
# Clean unnecessary fields from results
|
||||
if clean_output:
|
||||
cleaned_results = clean_result_fields(raw_results)
|
||||
else:
|
||||
cleaned_results = raw_results
|
||||
|
||||
# 格式化返回结果
|
||||
# Format return results
|
||||
formatted_results = {
|
||||
"search_query": context,
|
||||
"search_type": search_type,
|
||||
"results": cleaned_results
|
||||
}
|
||||
|
||||
|
||||
return json.dumps(formatted_results, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_result = {
|
||||
"error": f"混合检索失败: {str(e)}",
|
||||
@@ -274,38 +345,52 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
return json.dumps(error_result, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
return HybridSearch
|
||||
|
||||
|
||||
def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
"""
|
||||
创建同步版本的混合检索工具,优化输出格式并过滤不需要的字段
|
||||
Create synchronous version of hybrid retrieval tool, optimize output format and filter unnecessary fields
|
||||
|
||||
Creates a synchronous wrapper around the async hybrid search functionality,
|
||||
making it compatible with synchronous tool execution environments.
|
||||
|
||||
Args:
|
||||
memory_config: 内存配置对象
|
||||
**search_params: 搜索参数
|
||||
memory_config: Memory configuration object containing search settings
|
||||
**search_params: Search parameters for configuration
|
||||
|
||||
Returns:
|
||||
function: Configured HybridSearchSync tool function
|
||||
"""
|
||||
|
||||
@tool
|
||||
def HybridSearchSync(
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
end_user_id: str = None,
|
||||
clean_output: bool = True
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
end_user_id: str = None,
|
||||
clean_output: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
优化的混合检索工具(同步版本),自动过滤不需要的元数据字段
|
||||
Optimized hybrid retrieval tool (synchronous version), automatically filters unnecessary metadata fields
|
||||
|
||||
Provides the same hybrid search capabilities as the async version but in a
|
||||
synchronous execution context. Automatically handles async-to-sync conversion.
|
||||
|
||||
Args:
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
end_user_id: 组ID,用于过滤搜索结果
|
||||
clean_output: 是否清理输出中的元数据字段
|
||||
context: Query content for search
|
||||
search_type: Search type ('keyword', 'embedding', 'hybrid')
|
||||
limit: Result quantity limit
|
||||
end_user_id: Group ID for filtering search results
|
||||
clean_output: Whether to clean metadata fields from output
|
||||
|
||||
Returns:
|
||||
str: JSON formatted search results
|
||||
"""
|
||||
|
||||
async def _async_search():
|
||||
# 创建异步工具并执行
|
||||
# Create async tool and execute
|
||||
async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
|
||||
return await async_tool.ainvoke({
|
||||
"context": context,
|
||||
@@ -314,7 +399,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
"end_user_id": end_user_id,
|
||||
"clean_output": clean_output
|
||||
})
|
||||
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
return HybridSearchSync
|
||||
|
||||
return HybridSearchSync
|
||||
|
||||
106
api/app/core/memory/agent/langgraph_graph/tools/write_tool.py
Normal file
106
api/app/core/memory/agent/langgraph_graph/tools/write_tool.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import json
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
|
||||
|
||||
async def format_parsing(messages: list, type: str = 'string'):
    """
    Convert stored session records into a flat, formatted message list.

    Each record in *messages* carries a 'messages' field holding
    newline-separated JSON arrays of {"role": ..., "content": ...} dicts.
    Depending on *type*, the output is either role-prefixed text lines or
    user→AI pair dictionaries.

    Args:
        messages: Message records pulled from storage.
        type: 'string' for role-prefixed text lines;
              'dict' for [{user_message: ai_reply}, ...] pairs.

    Returns:
        list: Formatted output in the requested shape.
    """
    formatted = []
    user_turns = []
    ai_turns = []

    for record in messages:
        raw_lines = record['messages'].strip().splitlines()
        for raw_line in raw_lines:
            for entry in json.loads(raw_line):
                role = entry['role']
                text = entry['content']
                is_user = role in ('human', 'user')
                if type == "string":
                    # Prefix each turn with its speaker label.
                    formatted.append(('用户:' if is_user else 'AI:') + text)
                if type == "dict":
                    (user_turns if is_user else ai_turns).append(text)

    if type == "dict":
        # Pair user turns with AI turns in order; unmatched turns are dropped.
        formatted.extend({question: answer} for question, answer in zip(user_turns, ai_turns))
    return formatted
|
||||
|
||||
|
||||
async def messages_parse(messages: list | dict):
|
||||
"""
|
||||
Parse messages from storage format into user-AI conversation pairs
|
||||
|
||||
Extracts and organizes conversation data from stored message format,
|
||||
separating user and AI messages and pairing them for database storage.
|
||||
|
||||
Args:
|
||||
messages: List or dictionary containing stored message data with Query fields
|
||||
|
||||
Returns:
|
||||
list: List of dictionaries containing user-AI message pairs for database storage
|
||||
"""
|
||||
user = []
|
||||
ai = []
|
||||
database = []
|
||||
for message in messages:
|
||||
Query = message['Query']
|
||||
Query = json.loads(Query)
|
||||
for data in Query:
|
||||
role = data['role']
|
||||
if role == "human":
|
||||
user.append(data['content'])
|
||||
if role == "ai":
|
||||
ai.append(data['content'])
|
||||
for key, values in zip(user, ai):
|
||||
database.append({key, values})
|
||||
return database
|
||||
|
||||
|
||||
async def agent_chat_messages(user_content, ai_content):
    """
    Build a two-turn chat message list for agent conversations.

    Coerces both contents to strings and wraps them in the standard
    role/content message structure used for agent processing and storage.

    Args:
        user_content: The user's message content.
        ai_content: The AI's response content.

    Returns:
        list: [{"role": "user", ...}, {"role": "assistant", ...}]
    """
    return [
        {"role": "user", "content": f"{user_content}"},
        {"role": "assistant", "content": f"{ai_content}"},
    ]
|
||||
@@ -1,93 +1,94 @@
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
|
||||
aggregate_judgment
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
@asynccontextmanager
|
||||
async def make_write_graph():
|
||||
|
||||
async def long_term_storage(
|
||||
long_term_type: str,
|
||||
langchain_messages: list,
|
||||
memory_config_id: str,
|
||||
end_user_id: str,
|
||||
scope: int = 6
|
||||
):
|
||||
"""
|
||||
Create a write graph workflow for memory operations.
|
||||
Handle long-term memory storage with different strategies
|
||||
|
||||
Supports multiple storage strategies including chunk-based, time-based,
|
||||
and aggregate judgment approaches for long-term memory persistence.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
tools: MCP tools loaded from session
|
||||
apply_id: Application identifier
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
|
||||
langchain_messages: List of messages to store
|
||||
memory_config_id: Memory configuration identifier
|
||||
end_user_id: User group identifier
|
||||
scope: Scope parameter for chunk-based storage (default: 6)
|
||||
"""
|
||||
# workflow = StateGraph(WriteState)
|
||||
# workflow.add_node("content_input", content_input_write)
|
||||
# workflow.add_node("save_neo4j", write_node)
|
||||
# workflow.add_edge(START, "content_input")
|
||||
# workflow.add_edge("content_input", "save_neo4j")
|
||||
# workflow.add_edge("save_neo4j", END)
|
||||
#
|
||||
# graph = workflow.compile()
|
||||
workflow = StateGraph(WriteState)
|
||||
workflow.add_node("save_neo4j", write_node)
|
||||
workflow.add_edge(START, "save_neo4j")
|
||||
workflow.add_edge("save_neo4j", END)
|
||||
|
||||
graph = workflow.compile()
|
||||
|
||||
yield graph
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数 - 运行工作流"""
|
||||
message = "今天周一"
|
||||
end_user_id = 'new_2025test1103' # 组ID
|
||||
|
||||
if langchain_messages is None:
|
||||
langchain_messages = []
|
||||
|
||||
write_store.save_session_write(end_user_id, langchain_messages)
|
||||
# 获取数据库会话
|
||||
db_session = next(get_db())
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=17, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
try:
|
||||
async with make_write_graph() as graph:
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config}
|
||||
|
||||
# 获取节点更新信息
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
if 'save_neo4j'==node_name:
|
||||
massages=node_data
|
||||
massages=massages.get('write_result')['status']
|
||||
print(massages) # | 更新数据: {node_data}
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
with get_db_context() as db_session:
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=memory_config_id, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
|
||||
# Dialogue window with 6 rounds of conversation
|
||||
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
|
||||
# Time-based strategy
|
||||
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
# Aggregate judgment
|
||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
async def write_long_term(
    storage_type: str,
    end_user_id: str,
    messages: list[dict],
    user_rag_memory_id: str,
    actual_config_id: str
):
    """
    Write long-term memory with different storage types.

    Handles both RAG-based storage and traditional memory storage approaches.
    For traditional storage, uses the chunk-based strategy with paired
    user-AI messages, then triggers the term-memory save routine.

    Args:
        storage_type: Type of storage; compared against
            AgentMemory_Long_Term.STORAGE_RAG, anything else falls through
            to the chunk strategy.
        end_user_id: User group identifier
        messages: Message list of {"role": ..., "content": ...} dicts
        user_rag_memory_id: RAG memory identifier
        actual_config_id: Actual configuration ID
    """
    # NOTE(review): import deferred to call time — presumably to avoid a
    # circular import with the routing module; confirm.
    from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
    if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
        # Flatten the conversation into "role:content" lines for RAG ingestion.
        message_content = []
        for message in messages:
            message_content.append(f'{message.get("role")}:{message.get("content")}')
        messages_string = "\n".join(message_content)
        await write_rag(end_user_id, messages_string, user_rag_memory_id)
    else:
        # AI reply writing (user messages and AI replies paired, written as
        # a complete dialogue at once) via the chunk strategy.
        CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
        SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
        await long_term_storage(long_term_type=CHUNK,
                                langchain_messages=messages,
                                memory_config_id=actual_config_id,
                                end_user_id=end_user_id,
                                scope=SCOPE)
        await term_memory_save(end_user_id, CHUNK, scope=SCOPE)
|
||||
|
||||
28
api/app/core/memory/agent/models/write_aggregate_model.py
Normal file
28
api/app/core/memory/agent/models/write_aggregate_model.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Pydantic models for write aggregate judgment operations."""
|
||||
|
||||
from typing import List, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MessageItem(BaseModel):
    """Individual message item in a conversation."""

    # Speaker role; expected values are "user" or "assistant".
    role: str = Field(..., description="角色:user 或 assistant")
    # Message body text.
    content: str = Field(..., description="消息内容")
|
||||
|
||||
|
||||
class WriteAggregateResponse(BaseModel):
    """Response model for aggregate judgment containing judgment result and output."""

    # True when the new input describes the same event as stored history.
    is_same_event: bool = Field(
        ...,
        description="是否是同一事件。True表示是同一事件,False表示不同事件"
    )
    # Per the schema contract: when is_same_event is True the LLM returns
    # False here; when False it returns the (possibly merged) message list.
    output: Union[List[MessageItem], bool] = Field(
        ...,
        description="如果is_same_event为True,返回False;如果is_same_event为False,返回消息列表"
    )


# Backward-compatible alias for the old class name.
WriteAggregateModel = WriteAggregateResponse
|
||||
@@ -7,21 +7,88 @@ and deduplication.
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.enums import Neo4jNodeType
|
||||
from app.core.memory.src.search import run_hybrid_search
|
||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
# 需要从展开结果中过滤的字段(含 Neo4j DateTime,不可 JSON 序列化)
|
||||
_EXPAND_FIELDS_TO_REMOVE = {
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary'
|
||||
}
|
||||
|
||||
|
||||
def _clean_expand_fields(obj):
|
||||
"""递归过滤展开结果中不可序列化的字段(DateTime 等)。"""
|
||||
if isinstance(obj, dict):
|
||||
return {k: _clean_expand_fields(v) for k, v in obj.items() if k not in _EXPAND_FIELDS_TO_REMOVE}
|
||||
if isinstance(obj, list):
|
||||
return [_clean_expand_fields(i) for i in obj]
|
||||
return obj
|
||||
|
||||
|
||||
async def expand_communities_to_statements(
|
||||
community_results: List[dict],
|
||||
end_user_id: str,
|
||||
existing_content: str = "",
|
||||
limit: int = 10,
|
||||
) -> Tuple[List[dict], List[str]]:
|
||||
"""
|
||||
社区展开 helper:给定命中的 community 列表,拉取关联 Statement。
|
||||
|
||||
- 对展开结果去重(过滤已在 existing_content 中出现的文本)
|
||||
- 过滤不可序列化字段
|
||||
- 返回 (cleaned_expanded_stmts, new_texts)
|
||||
- cleaned_expanded_stmts: 可直接写回 raw_results 的列表
|
||||
- new_texts: 去重后新增的 statement 文本列表,用于追加到 clean_content
|
||||
"""
|
||||
community_ids = [r.get("id") for r in community_results if r.get("id")]
|
||||
if not community_ids or not end_user_id:
|
||||
return [], []
|
||||
|
||||
from app.repositories.neo4j.graph_search import search_graph_community_expand
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
result = await search_graph_community_expand(
|
||||
connector=connector,
|
||||
community_ids=community_ids,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[expand_communities] 社区展开检索失败,跳过: {e}")
|
||||
return [], []
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
expanded_stmts = result.get("expanded_statements", [])
|
||||
if not expanded_stmts:
|
||||
return [], []
|
||||
|
||||
existing_lines = set(existing_content.splitlines())
|
||||
new_texts = [
|
||||
s["statement"] for s in expanded_stmts
|
||||
if s.get("statement") and s["statement"] not in existing_lines
|
||||
]
|
||||
cleaned = _clean_expand_fields(expanded_stmts)
|
||||
logger.info(
|
||||
f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements,新增 {len(new_texts)} 条,community_ids={community_ids}")
|
||||
return cleaned, new_texts
|
||||
|
||||
|
||||
class SearchService:
|
||||
"""Service for executing hybrid search and processing results."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the search service."""
|
||||
logger.info("SearchService initialized")
|
||||
|
||||
def extract_content_from_result(self, result: dict) -> str:
|
||||
|
||||
def extract_content_from_result(self, result: dict, node_type: str = "") -> str:
|
||||
"""
|
||||
Extract only meaningful content from search results, dropping all metadata.
|
||||
|
||||
@@ -30,35 +97,50 @@ class SearchService:
|
||||
- Entities: extract 'name' and 'fact_summary' fields
|
||||
- Summaries: extract 'content' field
|
||||
- Chunks: extract 'content' field
|
||||
- Communities: extract 'content' field (c.summary), prefixed with community name
|
||||
|
||||
Args:
|
||||
result: Search result dictionary
|
||||
node_type: Hint for node type ("community", "summary", etc.)
|
||||
|
||||
Returns:
|
||||
Clean content string without metadata
|
||||
"""
|
||||
if not isinstance(result, dict):
|
||||
return str(result)
|
||||
|
||||
|
||||
content_parts = []
|
||||
|
||||
|
||||
# Statements: extract statement field
|
||||
if 'statement' in result and result['statement']:
|
||||
content_parts.append(result['statement'])
|
||||
|
||||
# Summaries/Chunks: extract content field
|
||||
if 'content' in result and result['content']:
|
||||
if Neo4jNodeType.STATEMENT in result and result[Neo4jNodeType.STATEMENT]:
|
||||
content_parts.append(result[Neo4jNodeType.STATEMENT])
|
||||
|
||||
# Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定
|
||||
# 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要
|
||||
is_community = (
|
||||
node_type == Neo4jNodeType.COMMUNITY
|
||||
or 'member_count' in result
|
||||
or 'core_entities' in result
|
||||
)
|
||||
if is_community:
|
||||
name = result.get('name', '')
|
||||
content = result.get('content', '')
|
||||
if content:
|
||||
prefix = f"[主题:{name}] " if name else ""
|
||||
content_parts.append(f"{prefix}{content}")
|
||||
elif 'content' in result and result['content']:
|
||||
# Summaries / Chunks
|
||||
content_parts.append(result['content'])
|
||||
|
||||
|
||||
# Entities: extract name and fact_summary (commented out in original)
|
||||
# if 'name' in result and result['name']:
|
||||
# content_parts.append(result['name'])
|
||||
# if result.get('fact_summary'):
|
||||
# content_parts.append(result['fact_summary'])
|
||||
|
||||
|
||||
# Return concatenated content or empty string
|
||||
return '\n'.join(content_parts) if content_parts else ""
|
||||
|
||||
|
||||
def clean_query(self, query: str) -> str:
|
||||
"""
|
||||
Clean and escape query text for Lucene.
|
||||
@@ -74,32 +156,33 @@ class SearchService:
|
||||
Cleaned and escaped query string
|
||||
"""
|
||||
q = str(query).strip()
|
||||
|
||||
|
||||
# Remove wrapping quotes
|
||||
if (q.startswith("'") and q.endswith("'")) or (
|
||||
q.startswith('"') and q.endswith('"')
|
||||
q.startswith('"') and q.endswith('"')
|
||||
):
|
||||
q = q[1:-1]
|
||||
|
||||
|
||||
# Remove newlines and carriage returns
|
||||
q = q.replace('\r', ' ').replace('\n', ' ').strip()
|
||||
|
||||
|
||||
# Apply Lucene escaping
|
||||
q = escape_lucene_query(q)
|
||||
|
||||
|
||||
return q
|
||||
|
||||
|
||||
async def execute_hybrid_search(
|
||||
self,
|
||||
end_user_id: str,
|
||||
question: str,
|
||||
limit: int = 5,
|
||||
search_type: str = "hybrid",
|
||||
include: Optional[List[str]] = None,
|
||||
rerank_alpha: float = 0.4,
|
||||
output_path: str = "search_results.json",
|
||||
return_raw_results: bool = False,
|
||||
memory_config = None
|
||||
self,
|
||||
end_user_id: str,
|
||||
question: str,
|
||||
limit: int = 5,
|
||||
search_type: str = "hybrid",
|
||||
include: Optional[List[str]] = None,
|
||||
rerank_alpha: float = 0.4,
|
||||
output_path: str = "search_results.json",
|
||||
return_raw_results: bool = False,
|
||||
memory_config=None,
|
||||
expand_communities: bool = True,
|
||||
) -> Tuple[str, str, Optional[dict]]:
|
||||
"""
|
||||
Execute hybrid search and return clean content.
|
||||
@@ -114,17 +197,19 @@ class SearchService:
|
||||
output_path: Path to save search results (default: "search_results.json")
|
||||
return_raw_results: If True, also return the raw search results as third element (default: False)
|
||||
memory_config: Memory configuration object (required)
|
||||
expand_communities: If True, expand community hits to member statements (default: True).
|
||||
Set to False for quick-summary paths that only need community-level text.
|
||||
|
||||
Returns:
|
||||
Tuple of (clean_content, cleaned_query, raw_results)
|
||||
raw_results is None if return_raw_results=False
|
||||
"""
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries"]
|
||||
|
||||
include = [Neo4jNodeType.STATEMENT, Neo4jNodeType.CHUNK, Neo4jNodeType.EXTRACTEDENTITY, Neo4jNodeType.MEMORYSUMMARY, Neo4jNodeType.COMMUNITY]
|
||||
|
||||
# Clean query
|
||||
cleaned_query = self.clean_query(question)
|
||||
|
||||
|
||||
try:
|
||||
# Execute search
|
||||
answer = await run_hybrid_search(
|
||||
@@ -137,18 +222,18 @@ class SearchService:
|
||||
memory_config=memory_config,
|
||||
rerank_alpha=rerank_alpha
|
||||
)
|
||||
|
||||
|
||||
# Extract results based on search type and include parameter
|
||||
# Prioritize summaries as they contain synthesized contextual information
|
||||
answer_list = []
|
||||
|
||||
|
||||
# For hybrid search, use reranked_results
|
||||
if search_type == "hybrid":
|
||||
reranked_results = answer.get('reranked_results', {})
|
||||
|
||||
# Priority order: summaries first (most contextual), then statements, chunks, entities
|
||||
priority_order = ['summaries', 'statements', 'chunks', 'entities']
|
||||
|
||||
|
||||
# Priority order: summaries first (most contextual), then communities, statements, chunks, entities
|
||||
priority_order = [Neo4jNodeType.STATEMENT, Neo4jNodeType.CHUNK, Neo4jNodeType.EXTRACTEDENTITY, Neo4jNodeType.MEMORYSUMMARY, Neo4jNodeType.COMMUNITY]
|
||||
|
||||
for category in priority_order:
|
||||
if category in include and category in reranked_results:
|
||||
category_results = reranked_results[category]
|
||||
@@ -157,33 +242,46 @@ class SearchService:
|
||||
else:
|
||||
# For keyword or embedding search, results are directly in answer dict
|
||||
# Apply same priority order
|
||||
priority_order = ['summaries', 'statements', 'chunks', 'entities']
|
||||
|
||||
priority_order = [Neo4jNodeType.STATEMENT, Neo4jNodeType.CHUNK, Neo4jNodeType.EXTRACTEDENTITY, Neo4jNodeType.MEMORYSUMMARY, Neo4jNodeType.COMMUNITY]
|
||||
|
||||
for category in priority_order:
|
||||
if category in include and category in answer:
|
||||
category_results = answer[category]
|
||||
if isinstance(category_results, list):
|
||||
answer_list.extend(category_results)
|
||||
|
||||
# Extract clean content from all results
|
||||
content_list = [
|
||||
self.extract_content_from_result(ans)
|
||||
for ans in answer_list
|
||||
]
|
||||
|
||||
|
||||
# 对命中的 community 节点展开其成员 statements(路径 "0"/"1" 需要,路径 "2" 不需要)
|
||||
if expand_communities and Neo4jNodeType.COMMUNITY in include:
|
||||
community_results = (
|
||||
answer.get('reranked_results', {}).get(Neo4jNodeType.COMMUNITY.value, [])
|
||||
if search_type == "hybrid"
|
||||
else answer.get(Neo4jNodeType.COMMUNITY.value, [])
|
||||
)
|
||||
cleaned_stmts, new_texts = await expand_communities_to_statements(
|
||||
community_results=community_results,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
answer_list.extend(cleaned_stmts)
|
||||
|
||||
# Extract clean content from all results,按类型传入 node_type 区分 community
|
||||
content_list = []
|
||||
for ans in answer_list:
|
||||
# community 节点有 member_count 或 core_entities 字段
|
||||
ntype = Neo4jNodeType.COMMUNITY if ('member_count' in ans or 'core_entities' in ans) else ""
|
||||
content_list.append(self.extract_content_from_result(ans, node_type=ntype))
|
||||
|
||||
# Filter out empty strings and join with newlines
|
||||
clean_content = '\n'.join([c for c in content_list if c])
|
||||
|
||||
|
||||
# Log first 200 chars
|
||||
logger.info(f"检索接口搜索结果==>>:{clean_content[:200]}...")
|
||||
|
||||
|
||||
# Return raw results if requested
|
||||
if return_raw_results:
|
||||
return clean_content, cleaned_query, answer
|
||||
else:
|
||||
return clean_content, cleaned_query, None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Search failed for query '{question}' in group '{end_user_id}': {e}",
|
||||
|
||||
@@ -11,7 +11,7 @@ async def get_chunked_dialogs(
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
end_user_id: str = "group_1",
|
||||
messages: list = None,
|
||||
ref_id: str = "wyl_20251027",
|
||||
ref_id: str = "",
|
||||
config_id: str = None
|
||||
) -> List[DialogData]:
|
||||
"""Generate chunks from structured messages using the specified chunker strategy.
|
||||
@@ -21,7 +21,7 @@ async def get_chunked_dialogs(
|
||||
end_user_id: Group identifier
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference identifier
|
||||
config_id: Configuration ID for processing
|
||||
config_id: Configuration ID for processing (used to load pruning config)
|
||||
|
||||
Returns:
|
||||
List of DialogData objects with generated chunks
|
||||
@@ -40,12 +40,13 @@ async def get_chunked_dialogs(
|
||||
|
||||
role = msg['role']
|
||||
content = msg['content']
|
||||
files = msg.get("file_content", [])
|
||||
|
||||
if role not in ['user', 'assistant']:
|
||||
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
||||
|
||||
if content.strip():
|
||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
|
||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip(), files=files))
|
||||
|
||||
if not conversation_messages:
|
||||
raise ValueError("Message list cannot be empty after filtering")
|
||||
@@ -57,6 +58,63 @@ async def get_chunked_dialogs(
|
||||
end_user_id=end_user_id,
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
# 语义剪枝步骤(在分块之前)
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner
|
||||
from app.core.memory.models.config_models import PruningConfig
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
|
||||
# 加载剪枝配置
|
||||
pruning_config = None
|
||||
if config_id:
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
# 使用 MemoryConfigService 加载完整的 MemoryConfig 对象
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="semantic_pruning"
|
||||
)
|
||||
|
||||
if memory_config:
|
||||
pruning_config = PruningConfig(
|
||||
pruning_switch=memory_config.pruning_enabled,
|
||||
pruning_scene=memory_config.pruning_scene or "education",
|
||||
pruning_threshold=memory_config.pruning_threshold,
|
||||
scene_id=str(memory_config.scene_id) if memory_config.scene_id else None,
|
||||
ontology_class_infos=memory_config.ontology_class_infos,
|
||||
)
|
||||
logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}")
|
||||
|
||||
# 获取LLM客户端用于剪枝
|
||||
if pruning_config.pruning_switch:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# 执行剪枝 - 使用 prune_dataset 支持消息级剪枝
|
||||
pruner = SemanticPruner(config=pruning_config, llm_client=llm_client)
|
||||
original_msg_count = len(dialog_data.context.msgs)
|
||||
|
||||
# 使用 prune_dataset 而不是 prune_dialog
|
||||
# prune_dataset 会进行消息级剪枝,即使对话整体相关也会删除不重要消息
|
||||
pruned_dialogs = await pruner.prune_dataset([dialog_data])
|
||||
|
||||
if pruned_dialogs:
|
||||
dialog_data = pruned_dialogs[0]
|
||||
remaining_msg_count = len(dialog_data.context.msgs)
|
||||
deleted_count = original_msg_count - remaining_msg_count
|
||||
logger.info(f"[剪枝] 完成: 原始{original_msg_count}条 -> 保留{remaining_msg_count}条 (删除{deleted_count}条)")
|
||||
else:
|
||||
logger.warning("[剪枝] prune_dataset 返回空列表")
|
||||
else:
|
||||
logger.info("[剪枝] 配置中剪枝开关关闭,跳过剪枝")
|
||||
except Exception as e:
|
||||
logger.warning(f"[剪枝] 加载配置失败,跳过剪枝: {e}", exc_info=True)
|
||||
except Exception as e:
|
||||
logger.warning(f"[剪枝] 执行失败,跳过剪枝: {e}", exc_info=True)
|
||||
|
||||
chunker = DialogueChunker(chunker_strategy)
|
||||
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Optional
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client_fast
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
class LLMClientPool:
|
||||
"""LLM客户端连接池"""
|
||||
|
||||
def __init__(self, max_size: int = 5):
|
||||
self.max_size = max_size
|
||||
self.pools: Dict[str, asyncio.Queue] = {}
|
||||
self.active_clients: Dict[str, int] = {}
|
||||
|
||||
async def get_client(self, llm_model_id: str):
|
||||
"""获取LLM客户端"""
|
||||
if llm_model_id not in self.pools:
|
||||
self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size)
|
||||
self.active_clients[llm_model_id] = 0
|
||||
|
||||
pool = self.pools[llm_model_id]
|
||||
|
||||
try:
|
||||
# 尝试从池中获取客户端
|
||||
client = pool.get_nowait()
|
||||
logger.debug(f"从池中获取LLM客户端: {llm_model_id}")
|
||||
return client
|
||||
except asyncio.QueueEmpty:
|
||||
# 池为空,创建新客户端
|
||||
if self.active_clients[llm_model_id] < self.max_size:
|
||||
db_session = next(get_db())
|
||||
client = get_llm_client_fast(llm_model_id, db_session)
|
||||
self.active_clients[llm_model_id] += 1
|
||||
logger.debug(f"创建新LLM客户端: {llm_model_id}")
|
||||
return client
|
||||
else:
|
||||
# 等待可用客户端
|
||||
logger.debug(f"等待LLM客户端可用: {llm_model_id}")
|
||||
return await pool.get()
|
||||
|
||||
async def return_client(self, llm_model_id: str, client):
|
||||
"""归还LLM客户端到池中"""
|
||||
if llm_model_id in self.pools:
|
||||
try:
|
||||
self.pools[llm_model_id].put_nowait(client)
|
||||
logger.debug(f"归还LLM客户端到池: {llm_model_id}")
|
||||
except asyncio.QueueFull:
|
||||
# 池已满,丢弃客户端
|
||||
self.active_clients[llm_model_id] -= 1
|
||||
logger.debug(f"池已满,丢弃LLM客户端: {llm_model_id}")
|
||||
|
||||
# 全局客户端池
|
||||
llm_client_pool = LLMClientPool()
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Annotated, TypedDict
|
||||
@@ -8,16 +7,19 @@ from langgraph.graph import add_messages
|
||||
|
||||
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
|
||||
|
||||
|
||||
class WriteState(TypedDict):
|
||||
'''
|
||||
"""
|
||||
Langgrapg Writing TypedDict
|
||||
'''
|
||||
"""
|
||||
messages: Annotated[list[AnyMessage], add_messages]
|
||||
end_user_id: str
|
||||
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
||||
memory_config: object
|
||||
write_result: dict
|
||||
data: str
|
||||
language: str # 语言类型 ("zh" 中文, "en" 英文)
|
||||
|
||||
|
||||
class ReadState(TypedDict):
|
||||
"""
|
||||
@@ -42,18 +44,21 @@ class ReadState(TypedDict):
|
||||
config_id: str
|
||||
data: str # 新增字段用于传递内容
|
||||
spit_data: dict # 新增字段用于传递问题分解结果
|
||||
problem_extension:dict
|
||||
problem_extension: dict
|
||||
storage_type: str
|
||||
user_rag_memory_id: str
|
||||
llm_id: str
|
||||
embedding_id: str
|
||||
memory_config: object # 新增字段用于传递内存配置对象
|
||||
retrieve:dict
|
||||
retrieve: dict
|
||||
perceptual_data: dict
|
||||
RetrieveSummary: dict
|
||||
InputSummary: dict
|
||||
verify: dict
|
||||
SummaryFails: dict
|
||||
summary: dict
|
||||
|
||||
|
||||
class COUNTState:
|
||||
"""
|
||||
工作流对话检索内容计数器
|
||||
@@ -98,6 +103,7 @@ class COUNTState:
|
||||
self.total = 0
|
||||
print("[COUNTState] 已重置为 0")
|
||||
|
||||
|
||||
def deduplicate_entries(entries):
|
||||
seen = set()
|
||||
deduped = []
|
||||
@@ -108,6 +114,7 @@ def deduplicate_entries(entries):
|
||||
deduped.append(entry)
|
||||
return deduped
|
||||
|
||||
|
||||
def merge_to_key_value_pairs(data, query_key, result_key):
|
||||
grouped = defaultdict(list)
|
||||
for item in data:
|
||||
@@ -141,4 +148,4 @@ def convert_extended_question_to_question(data):
|
||||
return [convert_extended_question_to_question(item) for item in data]
|
||||
else:
|
||||
# 其他类型直接返回
|
||||
return data
|
||||
return data
|
||||
|
||||
@@ -39,6 +39,30 @@
|
||||
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
||||
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
||||
|
||||
## 指代消歧规则(Coreference Resolution):
|
||||
在拆分问题时,必须解析并替换所有指代词和抽象称呼,使问题具体化:
|
||||
|
||||
1. **"用户"的消歧**:
|
||||
- "用户是谁?" → 分析历史记录,找出对话发起者的姓名
|
||||
- 如果历史中有"我叫X"、"我的名字是X"、或多次提到某个人物,则"用户"指的就是这个人
|
||||
- 示例:历史中有"老李的原名叫李建国",则"用户是谁?"应拆分为"李建国是谁?"或"老李(李建国)是谁?"
|
||||
|
||||
2. **"我"的消歧**:
|
||||
- "我喜欢什么?" → 从历史中找出对话发起者的姓名,替换为"X喜欢什么?"
|
||||
- 示例:历史中有"张曼玉推荐了《小王子》",则"我推荐的书是什么?"应拆分为"张曼玉推荐的书是什么?"
|
||||
|
||||
3. **"他/她/它"的消歧**:
|
||||
- 从上下文或历史中找出最近提到的同类实体
|
||||
- 示例:历史中有"老李的同事叫他建国哥",则"他的同事怎么称呼他?"应拆分为"老李的同事怎么称呼他?"
|
||||
|
||||
4. **"那个人/这个人"的消歧**:
|
||||
- 从历史中找出最近提到的人物
|
||||
- 示例:历史中有"李建国",则"那个人的原名是什么?"应拆分为"李建国的原名是什么?"
|
||||
|
||||
5. **优先级**:
|
||||
- 如果历史记录中反复出现某个人物(如"老李"、"李建国"、"建国哥"),则"用户"很可能指的就是这个人
|
||||
- 如果无法从历史中确定指代对象,保留原问题,但在reason中说明"无法确定指代对象"
|
||||
|
||||
|
||||
|
||||
输出要求:
|
||||
@@ -71,6 +95,34 @@
|
||||
"reason": "输出原问题的关键要素"
|
||||
}
|
||||
]
|
||||
|
||||
## 指代消歧示例(重要):
|
||||
示例1 - "用户"的消歧:
|
||||
输入历史:[{'Query': '老李的原名叫什么?', 'Answer': '李建国'}, {'Query': '老李的同事叫他什么?', 'Answer': '建国哥'}]
|
||||
输入问题:"用户是谁?"
|
||||
输出:
|
||||
[
|
||||
{
|
||||
"original_question": "用户是谁?",
|
||||
"extended_question": "李建国是谁?",
|
||||
"type": "单跳",
|
||||
"reason": "历史中反复提到'老李/李建国/建国哥','用户'指的就是对话发起者李建国"
|
||||
}
|
||||
]
|
||||
|
||||
示例2 - "我"的消歧:
|
||||
输入历史:[{'Query': '张曼玉推荐了什么书?', 'Answer': '《小王子》'}]
|
||||
输入问题:"我推荐的书是什么?"
|
||||
输出:
|
||||
[
|
||||
{
|
||||
"original_question": "我推荐的书是什么?",
|
||||
"extended_question": "张曼玉推荐的书是什么?",
|
||||
"type": "单跳",
|
||||
"reason": "历史中提到张曼玉推荐了书,'我'指的就是张曼玉"
|
||||
}
|
||||
]
|
||||
|
||||
**Output format**
|
||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
||||
|
||||
@@ -27,6 +27,30 @@
|
||||
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
||||
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
||||
|
||||
## 指代消歧规则(Coreference Resolution):
|
||||
在拆分问题时,必须解析并替换所有指代词和抽象称呼,使问题具体化:
|
||||
|
||||
1. **"用户"的消歧**:
|
||||
- "用户是谁?" → 分析历史记录,找出对话发起者的姓名
|
||||
- 如果历史中有"我叫X"、"我的名字是X"、或多次提到某个人物(如"老李"、"李建国"),则"用户"指的就是这个人
|
||||
- 示例:历史中反复出现"老李/李建国/建国哥",则"用户是谁?"应拆分为"李建国是谁?"或"老李(李建国)是谁?"
|
||||
|
||||
2. **"我"的消歧**:
|
||||
- "我喜欢什么?" → 从历史中找出对话发起者的姓名,替换为"X喜欢什么?"
|
||||
- 示例:历史中有"张曼玉推荐了《小王子》",则"我推荐的书是什么?"应拆分为"张曼玉推荐的书是什么?"
|
||||
|
||||
3. **"他/她/它"的消歧**:
|
||||
- 从上下文或历史中找出最近提到的同类实体
|
||||
- 示例:历史中有"老李的同事叫他建国哥",则"他的同事怎么称呼他?"应拆分为"老李的同事怎么称呼他?"
|
||||
|
||||
4. **"那个人/这个人"的消歧**:
|
||||
- 从历史中找出最近提到的人物
|
||||
- 示例:历史中有"李建国",则"那个人的原名是什么?"应拆分为"李建国的原名是什么?"
|
||||
|
||||
5. **优先级**:
|
||||
- 如果历史记录中反复出现某个人物(如"老李"、"李建国"、"建国哥"),则"用户"很可能指的就是这个人
|
||||
- 如果无法从历史中确定指代对象,保留原问题,但在reason中说明"无法确定指代对象"
|
||||
|
||||
## 指令:
|
||||
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
|
||||
单跳(Single-hop)
|
||||
@@ -151,6 +175,34 @@
|
||||
]
|
||||
- 必须通过json.loads()的格式支持的形式输出
|
||||
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
||||
|
||||
## 指代消歧示例(重要):
|
||||
示例1 - "用户"的消歧:
|
||||
输入历史:[{'Query': '老李的原名叫什么?', 'Answer': '李建国'}, {'Query': '老李的同事叫他什么?', 'Answer': '建国哥'}]
|
||||
输入问题:"用户是谁?"
|
||||
输出:
|
||||
[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": "李建国是谁?",
|
||||
"type": "单跳",
|
||||
"reason": "历史中反复提到'老李/李建国/建国哥','用户'指的就是对话发起者李建国"
|
||||
}
|
||||
]
|
||||
|
||||
示例2 - "我"的消歧:
|
||||
输入历史:[{'Query': '张曼玉推荐了什么书?', 'Answer': '《小王子》'}]
|
||||
输入问题:"我推荐的书是什么?"
|
||||
输出:
|
||||
[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": "张曼玉推荐的书是什么?",
|
||||
"type": "单跳",
|
||||
"reason": "历史中提到张曼玉推荐了书,'我'指的就是张曼玉"
|
||||
}
|
||||
]
|
||||
|
||||
- 关键的JSON格式要求
|
||||
1.JSON结构仅使用标准ASCII双引号(“)-切勿使用中文引号(“”)或其他Unicode引号
|
||||
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
输入句子:{{sentence}}
|
||||
历史消息:{{history}}
|
||||
|
||||
# 你的角色
|
||||
你是一个擅长事件聚合与语义判断的专家。
|
||||
|
||||
# 你的任务
|
||||
结合历史消息和输入句子,判断它们是否在描述**同一件事件或同一事件链**。
|
||||
|
||||
以下情况视为"同一事件"(需要返回 is_same_event=True, output=False):
|
||||
- 描述的是同一个具体事件或事实
|
||||
- 存在明显的因果关系、前后发展关系
|
||||
- 是对同一事件的补充、解释、追问或延展
|
||||
- 逻辑上属于同一语境下的连续讨论
|
||||
|
||||
以下情况视为"不同事件"(需要返回 is_same_event=False, output=消息列表):
|
||||
- 话题不同,事件主体不同
|
||||
- 时间、地点、对象明显不同
|
||||
- 只是语义相似,但并非同一具体事件
|
||||
- 无直接事件、因果或逻辑关联
|
||||
|
||||
# 输出规则(非常重要)
|
||||
你必须按照以下JSON格式输出:
|
||||
|
||||
**如果是同一事件:**
|
||||
```json
|
||||
{
|
||||
"is_same_event": true,
|
||||
"output": false
|
||||
}
|
||||
```
|
||||
|
||||
**如果不是同一事件:**
|
||||
```json
|
||||
{
|
||||
"is_same_event": false,
|
||||
"output": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "输入句子的内容"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "对应的回复内容"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
# JSON Schema
|
||||
{{json_schema}}
|
||||
|
||||
# 注意事项
|
||||
- 必须严格按照上述格式输出
|
||||
- output 字段:如果是同一事件返回 false,如果不是同一事件返回完整的消息列表
|
||||
- 消息列表必须包含 role 和 content 字段
|
||||
- 不要输出任何解释、分析或多余内容
|
||||
186
api/app/core/memory/agent/utils/redis_base.py
Normal file
186
api/app/core/memory/agent/utils/redis_base.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import json
|
||||
from typing import Any, List, Dict, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
def serialize_messages(messages: Any) -> str:
|
||||
"""
|
||||
将消息序列化为 JSON 字符串,支持 LangChain 消息对象
|
||||
|
||||
Args:
|
||||
messages: 可以是 list、dict、string 或 LangChain 消息对象列表
|
||||
|
||||
Returns:
|
||||
str: JSON 字符串
|
||||
"""
|
||||
if isinstance(messages, str):
|
||||
return messages
|
||||
|
||||
if isinstance(messages, (list, tuple)):
|
||||
# 检查是否是 LangChain 消息对象列表
|
||||
serialized_list = []
|
||||
for msg in messages:
|
||||
if hasattr(msg, 'type') and hasattr(msg, 'content'):
|
||||
# LangChain 消息对象
|
||||
serialized_list.append({
|
||||
'type': msg.type,
|
||||
'content': msg.content,
|
||||
'role': getattr(msg, 'role', msg.type)
|
||||
})
|
||||
elif isinstance(msg, dict):
|
||||
serialized_list.append(msg)
|
||||
else:
|
||||
serialized_list.append(str(msg))
|
||||
return json.dumps(serialized_list, ensure_ascii=False)
|
||||
|
||||
if isinstance(messages, dict):
|
||||
return json.dumps(messages, ensure_ascii=False)
|
||||
|
||||
# 其他类型转为字符串
|
||||
return str(messages)
|
||||
|
||||
|
||||
def deserialize_messages(messages_str: str) -> Any:
|
||||
"""
|
||||
将 JSON 字符串反序列化为原始格式
|
||||
|
||||
Args:
|
||||
messages_str: JSON 字符串
|
||||
|
||||
Returns:
|
||||
反序列化后的对象(list、dict 或 string)
|
||||
"""
|
||||
if not messages_str:
|
||||
return []
|
||||
|
||||
try:
|
||||
return json.loads(messages_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return messages_str
|
||||
|
||||
|
||||
def fix_encoding(text: str) -> str:
|
||||
"""
|
||||
修复错误编码的文本
|
||||
|
||||
Args:
|
||||
text: 需要修复的文本
|
||||
|
||||
Returns:
|
||||
str: 修复后的文本
|
||||
"""
|
||||
if not text or not isinstance(text, str):
|
||||
return text
|
||||
try:
|
||||
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
|
||||
return text.encode('latin-1').decode('utf-8')
|
||||
except (UnicodeDecodeError, UnicodeEncodeError):
|
||||
# 如果修复失败,返回原文本
|
||||
return text
|
||||
|
||||
|
||||
def format_session_data(data: Dict[str, Any], include_time: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
格式化会话数据为统一的输出格式
|
||||
|
||||
Args:
|
||||
data: 原始会话数据
|
||||
include_time: 是否包含时间字段
|
||||
|
||||
Returns:
|
||||
Dict: 格式化后的数据 {"Query": "...", "Answer": "...", "starttime": "..."}
|
||||
"""
|
||||
result = {
|
||||
"Query": fix_encoding(data.get('messages', '')),
|
||||
"Answer": fix_encoding(data.get('aimessages', ''))
|
||||
}
|
||||
|
||||
if include_time:
|
||||
result["starttime"] = data.get('starttime', '')
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def filter_by_time_range(items: List[Dict], minutes: int) -> List[Dict]:
|
||||
"""
|
||||
根据时间范围过滤数据
|
||||
|
||||
Args:
|
||||
items: 包含 starttime 字段的数据列表
|
||||
minutes: 时间范围(分钟)
|
||||
|
||||
Returns:
|
||||
List[Dict]: 过滤后的数据列表
|
||||
"""
|
||||
time_threshold = datetime.now() - timedelta(minutes=minutes)
|
||||
time_threshold_str = time_threshold.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
filtered_items = []
|
||||
for item in items:
|
||||
starttime = item.get('starttime', '')
|
||||
if starttime and starttime >= time_threshold_str:
|
||||
filtered_items.append(item)
|
||||
|
||||
return filtered_items
|
||||
|
||||
|
||||
def sort_and_limit_results(items: List[Dict], limit: int = 6,
|
||||
remove_time: bool = True) -> List[Dict]:
|
||||
"""
|
||||
对结果进行排序、限制数量并移除时间字段
|
||||
|
||||
Args:
|
||||
items: 数据列表
|
||||
limit: 最大返回数量
|
||||
remove_time: 是否移除 starttime 字段
|
||||
|
||||
Returns:
|
||||
List[Dict]: 处理后的数据列表
|
||||
"""
|
||||
# 按时间降序排序(最新的在前)
|
||||
items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
|
||||
# 限制数量
|
||||
result_items = items[:limit]
|
||||
|
||||
# 移除 starttime 字段
|
||||
if remove_time:
|
||||
for item in result_items:
|
||||
item.pop('starttime', None)
|
||||
|
||||
# 如果结果少于1条,返回空列表
|
||||
if len(result_items) < 1:
|
||||
return []
|
||||
|
||||
return result_items
|
||||
|
||||
|
||||
def generate_session_key(session_id: str, key_type: str = "session") -> str:
|
||||
"""
|
||||
生成 Redis key
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
key_type: key 类型 ("session", "read", "write", "count")
|
||||
|
||||
Returns:
|
||||
str: Redis key
|
||||
"""
|
||||
if key_type == "count":
|
||||
return f"session:count:{session_id}"
|
||||
elif key_type == "write":
|
||||
return f"session:write:{session_id}"
|
||||
elif key_type == "session" or key_type == "read":
|
||||
return f"session:{session_id}"
|
||||
else:
|
||||
return f"session:{session_id}"
|
||||
|
||||
|
||||
def get_current_timestamp() -> str:
|
||||
"""
|
||||
获取当前时间戳字符串
|
||||
|
||||
Returns:
|
||||
str: 格式化的时间字符串 "YYYY-MM-DD HH:MM:SS"
|
||||
"""
|
||||
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
@@ -1,11 +1,37 @@
|
||||
import redis
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from app.core.config import settings
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.memory.agent.utils.redis_base import (
|
||||
serialize_messages,
|
||||
deserialize_messages,
|
||||
fix_encoding,
|
||||
format_session_data,
|
||||
filter_by_time_range,
|
||||
sort_and_limit_results,
|
||||
generate_session_key,
|
||||
get_current_timestamp
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RedisSessionStore:
|
||||
class RedisWriteStore:
|
||||
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
|
||||
Args:
|
||||
host: Redis 主机地址
|
||||
port: Redis 端口
|
||||
db: Redis 数据库编号
|
||||
password: Redis 密码
|
||||
session_id: 会话ID
|
||||
"""
|
||||
self.r = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
@@ -16,32 +42,439 @@ class RedisSessionStore:
|
||||
)
|
||||
self.uudi = session_id
|
||||
|
||||
def _fix_encoding(self, text):
|
||||
"""修复错误编码的文本"""
|
||||
if not text or not isinstance(text, str):
|
||||
return text
|
||||
try:
|
||||
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
|
||||
return text.encode('latin-1').decode('utf-8')
|
||||
except (UnicodeDecodeError, UnicodeEncodeError):
|
||||
# 如果修复失败,返回原文本
|
||||
return text
|
||||
|
||||
# 修改后的 save_session 方法
|
||||
def save_session(self, userid, messages, aimessages, apply_id, end_user_id):
|
||||
def save_session_write(self, userid: str, messages: str) -> str:
|
||||
"""
|
||||
写入一条会话数据,返回 session_id
|
||||
优化版本:确保写入时间不超过1秒
|
||||
|
||||
Args:
|
||||
userid: 用户ID
|
||||
messages: 用户消息
|
||||
|
||||
Returns:
|
||||
str: 新生成的 session_id
|
||||
"""
|
||||
try:
|
||||
session_id = str(uuid.uuid4()) # 为每次会话生成新的 ID
|
||||
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
key = f"session:{session_id}" # 使用新生成的 session_id 作为 key
|
||||
messages = serialize_messages(messages)
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="write")
|
||||
|
||||
# 使用 pipeline 批量写入,减少网络往返
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"sessionid": userid,
|
||||
"messages": messages,
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
result = pipe.execute()
|
||||
|
||||
# 直接写入数据,decode_responses=True 已经处理了编码
|
||||
logger.debug(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
logger.error(f"[save_session_write] 保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
|
||||
"""
|
||||
通过 save_session_write 的 userid 获取 sessionid 和 messages
|
||||
|
||||
Args:
|
||||
userid: 用户ID (对应 sessionid 字段)
|
||||
|
||||
Returns:
|
||||
List[Dict] 或 False: 如果找到数据返回 [{"sessionid": "...", "messages": "..."}, ...],否则返回 False
|
||||
"""
|
||||
try:
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
return False
|
||||
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 筛选符合 userid 的数据
|
||||
results = []
|
||||
for key, data in zip(keys, all_data):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == userid:
|
||||
# 从 key 中提取 session_id: session:write:{session_id}
|
||||
session_id = key.split(':')[-1]
|
||||
results.append({
|
||||
"sessionid": session_id,
|
||||
"messages": fix_encoding(data.get('messages', ''))
|
||||
})
|
||||
|
||||
if not results:
|
||||
return False
|
||||
|
||||
logger.debug(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"[get_session_by_userid] 查询失败: {e}")
|
||||
return False
|
||||
|
||||
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
|
||||
"""
|
||||
通过 end_user_id 获取所有 write 类型的会话数据
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID (对应 sessionid 字段)
|
||||
|
||||
Returns:
|
||||
List[Dict] 或 False: 如果找到数据返回完整的会话信息列表,否则返回 False
|
||||
|
||||
返回格式:
|
||||
[
|
||||
{
|
||||
"session_id": "uuid",
|
||||
"id": "...",
|
||||
"sessionid": "end_user_id",
|
||||
"messages": "...",
|
||||
"starttime": "timestamp"
|
||||
},
|
||||
...
|
||||
]
|
||||
"""
|
||||
try:
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
logger.debug(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
||||
return False
|
||||
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 筛选符合 end_user_id 的数据
|
||||
results = []
|
||||
for key, data in zip(keys, all_data):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == end_user_id:
|
||||
# 从 key 中提取 session_id: session:write:{session_id}
|
||||
session_id = key.split(':')[-1]
|
||||
|
||||
# 构建完整的会话信息
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
"id": data.get('id', ''),
|
||||
"sessionid": data.get('sessionid', ''),
|
||||
"messages": fix_encoding(data.get('messages', '')),
|
||||
"starttime": data.get('starttime', '')
|
||||
}
|
||||
results.append(session_info)
|
||||
|
||||
if not results:
|
||||
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
||||
return False
|
||||
|
||||
# 按时间排序(最新的在前)
|
||||
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
|
||||
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"[get_all_sessions_by_end_user_id] 查询失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def find_user_recent_sessions(self, userid: str,
|
||||
minutes: int = 5) -> List[Dict[str, str]]:
|
||||
"""
|
||||
根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据
|
||||
|
||||
Args:
|
||||
userid: 用户ID (对应 sessionid 字段)
|
||||
minutes: 查询最近几分钟的数据,默认5分钟
|
||||
|
||||
Returns:
|
||||
List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...]
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
logger.debug(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 筛选符合 userid 的数据
|
||||
matched_items = []
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == userid and data.get('starttime'):
|
||||
# write 类型没有 aimessages,所以 Answer 为空
|
||||
matched_items.append({
|
||||
"Query": fix_encoding(data.get('messages', '')),
|
||||
"Answer": "",
|
||||
"starttime": data.get('starttime', '')
|
||||
})
|
||||
|
||||
# 根据时间范围过滤
|
||||
filtered_items = filter_by_time_range(matched_items, minutes)
|
||||
# 排序并移除时间字段
|
||||
result_items = sort_and_limit_results(filtered_items)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.debug(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
||||
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
def delete_all_write_sessions(self) -> int:
|
||||
"""
|
||||
删除所有 write 类型的会话
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
keys = self.r.keys('session:write:*')
|
||||
if keys:
|
||||
return self.r.delete(*keys)
|
||||
return 0
|
||||
|
||||
|
||||
class RedisCountStore:
|
||||
"""Redis Count 类型存储类,用于管理访问次数统计相关的数据"""
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
|
||||
Args:
|
||||
host: Redis 主机地址
|
||||
port: Redis 端口
|
||||
db: Redis 数据库编号
|
||||
password: Redis 密码
|
||||
session_id: 会话ID
|
||||
"""
|
||||
self.r = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
db=db,
|
||||
password=password,
|
||||
decode_responses=True,
|
||||
encoding='utf-8'
|
||||
)
|
||||
self.uuid = session_id
|
||||
|
||||
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
|
||||
"""
|
||||
保存用户访问次数统计
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
count: 访问次数
|
||||
messages: 消息内容
|
||||
|
||||
Returns:
|
||||
str: 新生成的 session_id
|
||||
"""
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
index_key = f'session:count:index:{end_user_id}' # 索引键
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uuid,
|
||||
"end_user_id": end_user_id,
|
||||
"count": int(count),
|
||||
"messages": serialize_messages(messages),
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
|
||||
|
||||
# 创建索引:end_user_id -> session_id 映射
|
||||
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
|
||||
|
||||
result = pipe.execute()
|
||||
|
||||
logger.debug(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
||||
return session_id
|
||||
|
||||
def get_sessions_count(self, end_user_id: str) -> tuple[int, list[dict]] | bool:
|
||||
"""
|
||||
通过 end_user_id 查询访问次数统计
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
list 或 False: 如果找到返回 [count, messages],否则返回 False
|
||||
"""
|
||||
try:
|
||||
# 使用索引键快速查找
|
||||
index_key = f'session:count:index:{end_user_id}'
|
||||
|
||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||
try:
|
||||
key_type = self.r.type(index_key)
|
||||
if key_type != 'string' and key_type != 'none':
|
||||
self.r.delete(index_key)
|
||||
return False
|
||||
except Exception as type_error:
|
||||
logger.error(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
session_id = self.r.get(index_key)
|
||||
|
||||
if not session_id:
|
||||
return False
|
||||
|
||||
# 直接获取数据
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
data = self.r.hgetall(key)
|
||||
|
||||
if not data:
|
||||
# 索引存在但数据不存在,清理索引
|
||||
self.r.delete(index_key)
|
||||
return False
|
||||
|
||||
count = data.get('count')
|
||||
messages_str = data.get('messages')
|
||||
|
||||
if count is not None:
|
||||
messages: list[dict] = deserialize_messages(messages_str)
|
||||
return int(count), messages
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"[get_sessions_count] 查询失败: {e}")
|
||||
return False
|
||||
|
||||
def update_sessions_count(
|
||||
self,
|
||||
end_user_id: str,
|
||||
new_count: int,
|
||||
messages: Any
|
||||
) -> bool:
|
||||
"""
|
||||
通过 end_user_id 修改访问次数统计(优化版:使用索引)
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
new_count: 新的 count 值
|
||||
messages: 消息内容
|
||||
|
||||
Returns:
|
||||
bool: 更新成功返回 True,未找到记录返回 False
|
||||
"""
|
||||
try:
|
||||
# 使用索引键快速查找
|
||||
index_key = f'session:count:index:{end_user_id}'
|
||||
|
||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||
try:
|
||||
key_type = self.r.type(index_key)
|
||||
if key_type != 'string' and key_type != 'none':
|
||||
# 索引键类型错误,删除并返回 False
|
||||
logger.warning(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
||||
self.r.delete(index_key)
|
||||
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
return False
|
||||
except Exception as type_error:
|
||||
logger.error(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
session_id = self.r.get(index_key)
|
||||
|
||||
if not session_id:
|
||||
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
return False
|
||||
|
||||
# 直接更新数据
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
messages_str = serialize_messages(messages)
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, 'count', str(new_count))
|
||||
pipe.hset(key, 'messages', messages_str)
|
||||
result = pipe.execute()
|
||||
|
||||
logger.debug(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"[update_sessions_count] 更新失败: {e}")
|
||||
return False
|
||||
|
||||
def delete_all_count_sessions(self) -> int:
|
||||
"""
|
||||
删除所有 count 类型的会话
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
keys = self.r.keys('session:count:*')
|
||||
if keys:
|
||||
return self.r.delete(*keys)
|
||||
return 0
|
||||
|
||||
|
||||
class RedisSessionStore:
|
||||
"""Redis 会话存储类,用于管理会话数据"""
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
|
||||
Args:
|
||||
host: Redis 主机地址
|
||||
port: Redis 端口
|
||||
db: Redis 数据库编号
|
||||
password: Redis 密码
|
||||
session_id: 会话ID
|
||||
"""
|
||||
self.r = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
db=db,
|
||||
password=password,
|
||||
decode_responses=True,
|
||||
encoding='utf-8'
|
||||
)
|
||||
self.uudi = session_id
|
||||
|
||||
# ==================== 写入操作 ====================
|
||||
|
||||
def save_session(self, userid: str, messages: str, aimessages: str,
|
||||
apply_id: str, end_user_id: str) -> str:
|
||||
"""
|
||||
写入一条会话数据,返回 session_id
|
||||
|
||||
Args:
|
||||
userid: 用户ID
|
||||
messages: 用户消息
|
||||
aimessages: AI回复消息
|
||||
apply_id: 应用ID
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
str: 新生成的 session_id
|
||||
"""
|
||||
try:
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="read")
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"sessionid": userid,
|
||||
@@ -49,177 +482,195 @@ class RedisSessionStore:
|
||||
"end_user_id": end_user_id,
|
||||
"messages": messages,
|
||||
"aimessages": aimessages,
|
||||
"starttime": starttime
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
|
||||
# 可选:设置过期时间(例如30天),避免数据无限增长
|
||||
# pipe.expire(key, 30 * 24 * 60 * 60)
|
||||
|
||||
# 执行批量操作
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id # 返回新生成的 session_id
|
||||
logger.debug(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
print(f"保存会话失败: {e}")
|
||||
logger.error(f"[save_session] 保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
def save_sessions_batch(self, sessions_data):
|
||||
"""
|
||||
批量写入多条会话数据,返回 session_id 列表
|
||||
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id
|
||||
优化版本:批量操作,大幅提升性能
|
||||
"""
|
||||
try:
|
||||
session_ids = []
|
||||
pipe = self.r.pipeline()
|
||||
# ==================== 读取操作 ====================
|
||||
|
||||
for session in sessions_data:
|
||||
session_id = str(uuid.uuid4())
|
||||
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
key = f"session:{session_id}"
|
||||
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"sessionid": session.get('userid'),
|
||||
"apply_id": session.get('apply_id'),
|
||||
"end_user_id": session.get('end_user_id'),
|
||||
"messages": session.get('messages'),
|
||||
"aimessages": session.get('aimessages'),
|
||||
"starttime": starttime
|
||||
})
|
||||
|
||||
session_ids.append(session_id)
|
||||
|
||||
# 一次性执行所有写入操作
|
||||
results = pipe.execute()
|
||||
print(f"批量保存完成: {len(session_ids)} 条记录")
|
||||
return session_ids
|
||||
except Exception as e:
|
||||
print(f"批量保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
# ---------------- 读取 ----------------
|
||||
def get_session(self, session_id):
|
||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
读取一条会话数据
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
Dict 或 None: 会话数据
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
key = generate_session_key(session_id)
|
||||
data = self.r.hgetall(key)
|
||||
return data if data else None
|
||||
|
||||
def get_session_apply_group(self, sessionid, apply_id, end_user_id):
|
||||
def get_all_sessions(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据
|
||||
"""
|
||||
result_items = []
|
||||
|
||||
# 遍历所有会话数据
|
||||
for key in self.r.keys('session:*'):
|
||||
data = self.r.hgetall(key)
|
||||
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 检查三个条件是否都匹配
|
||||
if (data.get('sessionid') == sessionid and
|
||||
data.get('apply_id') == apply_id and
|
||||
data.get('end_user_id') == end_user_id):
|
||||
result_items.append(data)
|
||||
|
||||
return result_items
|
||||
|
||||
def get_all_sessions(self):
|
||||
"""
|
||||
获取所有会话数据
|
||||
获取所有会话数据(不包括 count 和 write 类型)
|
||||
|
||||
Returns:
|
||||
Dict: 所有会话数据,key 为 session_id
|
||||
"""
|
||||
sessions = {}
|
||||
for key in self.r.keys('session:*'):
|
||||
sid = key.split(':')[1]
|
||||
sessions[sid] = self.get_session(sid)
|
||||
# 排除 count 和 write 类型的 key
|
||||
if ':count:' not in key and ':write:' not in key:
|
||||
sid = key.split(':')[1]
|
||||
sessions[sid] = self.get_session(sid)
|
||||
return sessions
|
||||
|
||||
# ---------------- 更新 ----------------
|
||||
def update_session(self, session_id, field, value):
|
||||
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
||||
end_user_id: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
更新单个字段
|
||||
优化版本:使用 pipeline 减少网络往返
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
pipe = self.r.pipeline()
|
||||
pipe.exists(key)
|
||||
pipe.hset(key, field, value)
|
||||
results = pipe.execute()
|
||||
return bool(results[0]) # 返回 key 是否存在
|
||||
|
||||
# ---------------- 删除 ----------------
|
||||
def delete_session(self, session_id):
|
||||
"""
|
||||
删除单条会话
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
return self.r.delete(key)
|
||||
|
||||
def delete_all_sessions(self):
|
||||
"""
|
||||
删除所有会话
|
||||
"""
|
||||
keys = self.r.keys('session:*')
|
||||
if keys:
|
||||
return self.r.delete(*keys)
|
||||
return 0
|
||||
|
||||
def delete_duplicate_sessions(self):
|
||||
"""
|
||||
删除重复会话数据,条件:
|
||||
"sessionid"、"user_id"、"end_user_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除
|
||||
优化版本:使用 pipeline 批量操作,确保在1秒内完成
|
||||
根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条
|
||||
|
||||
Args:
|
||||
sessionid: 会话ID(支持模糊匹配)
|
||||
apply_id: 应用ID
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...]
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# 第一步:使用 pipeline 批量获取所有 key
|
||||
keys = self.r.keys('session:*')
|
||||
|
||||
if not keys:
|
||||
print("[delete_duplicate_sessions] 没有会话数据")
|
||||
return 0
|
||||
logger.debug(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 第二步:使用 pipeline 批量获取所有数据
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
# 排除 count 和 write 类型
|
||||
if ':count:' not in key and ':write:' not in key:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 第三步:在内存中识别重复数据
|
||||
seen = {} # 用字典记录:identifier -> key(保留第一个出现的 key)
|
||||
keys_to_delete = [] # 需要删除的 key 列表
|
||||
|
||||
for key, data in zip(keys, all_data, strict=False):
|
||||
# 筛选符合条件的数据
|
||||
matched_items = []
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 获取五个字段的值
|
||||
sessionid = data.get('sessionid', '')
|
||||
user_id = data.get('id', '')
|
||||
end_user_id = data.get('end_user_id', '')
|
||||
messages = data.get('messages', '')
|
||||
aimessages = data.get('aimessages', '')
|
||||
if (data.get('apply_id') == apply_id and
|
||||
data.get('end_user_id') == end_user_id):
|
||||
# 支持模糊匹配或完全匹配 sessionid
|
||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||
matched_items.append(format_session_data(data, include_time=True))
|
||||
|
||||
# 排序、限制数量并移除时间字段
|
||||
result_items = sort_and_limit_results(matched_items, limit=6)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.debug(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
# ==================== 更新操作 ====================
|
||||
|
||||
def update_session(self, session_id: str, field: str, value: Any) -> bool:
|
||||
"""
|
||||
更新单个字段
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
field: 字段名
|
||||
value: 字段值
|
||||
|
||||
Returns:
|
||||
bool: 是否更新成功
|
||||
"""
|
||||
key = generate_session_key(session_id)
|
||||
pipe = self.r.pipeline()
|
||||
pipe.exists(key)
|
||||
pipe.hset(key, field, value)
|
||||
results = pipe.execute()
|
||||
return bool(results[0])
|
||||
|
||||
# ==================== 删除操作 ====================
|
||||
|
||||
def delete_session(self, session_id: str) -> int:
|
||||
"""
|
||||
删除单条会话
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
key = generate_session_key(session_id)
|
||||
return self.r.delete(key)
|
||||
|
||||
def delete_all_sessions(self) -> int:
|
||||
"""
|
||||
删除所有会话(不包括 count 和 write 类型)
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
keys = self.r.keys('session:*')
|
||||
# 过滤掉 count 和 write 类型
|
||||
keys_to_delete = [k for k in keys if ':count:' not in k and ':write:' not in k]
|
||||
if keys_to_delete:
|
||||
return self.r.delete(*keys_to_delete)
|
||||
return 0
|
||||
|
||||
def delete_duplicate_sessions(self) -> int:
|
||||
"""
|
||||
删除重复会话数据(不包括 count 和 write 类型)
|
||||
条件:sessionid、user_id、end_user_id、messages、aimessages 五个字段都相同的只保留一个
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
keys = self.r.keys('session:*')
|
||||
if not keys:
|
||||
logger.debug("[delete_duplicate_sessions] 没有会话数据")
|
||||
return 0
|
||||
|
||||
# 批量获取所有数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
# 排除 count 和 write 类型
|
||||
if ':count:' not in key and ':write:' not in key:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 识别重复数据
|
||||
seen = {}
|
||||
keys_to_delete = []
|
||||
|
||||
for key, data in zip([k for k in keys if ':count:' not in k and ':write:' not in k], all_data, strict=False):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 用五元组作为唯一标识
|
||||
identifier = (sessionid, user_id, end_user_id, messages, aimessages)
|
||||
identifier = (
|
||||
data.get('sessionid', ''),
|
||||
data.get('id', ''),
|
||||
data.get('end_user_id', ''),
|
||||
data.get('messages', ''),
|
||||
data.get('aimessages', '')
|
||||
)
|
||||
|
||||
if identifier in seen:
|
||||
# 重复,标记为待删除
|
||||
keys_to_delete.append(key)
|
||||
else:
|
||||
# 第一次出现,记录
|
||||
seen[identifier] = key
|
||||
|
||||
# 第四步:使用 pipeline 批量删除重复的 key
|
||||
# 批量删除重复的 key
|
||||
deleted_count = 0
|
||||
if keys_to_delete:
|
||||
# 分批删除,避免单次操作过大
|
||||
batch_size = 1000
|
||||
for i in range(0, len(keys_to_delete), batch_size):
|
||||
batch = keys_to_delete[i:i + batch_size]
|
||||
@@ -230,82 +681,31 @@ class RedisSessionStore:
|
||||
deleted_count += len(batch)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||
logger.debug(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||
return deleted_count
|
||||
|
||||
def find_user_session(self, sessionid):
|
||||
user_id = sessionid
|
||||
|
||||
result_items = []
|
||||
for key, values in store.get_all_sessions().items():
|
||||
history = {}
|
||||
if user_id == str(values['sessionid']):
|
||||
history["Query"] = values['messages']
|
||||
history["Answer"] = values['aimessages']
|
||||
result_items.append(history)
|
||||
|
||||
if len(result_items) <= 1:
|
||||
result_items = []
|
||||
return (result_items)
|
||||
|
||||
def find_user_apply_group(self, sessionid, apply_id, end_user_id):
|
||||
"""
|
||||
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据,返回最新的6条
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
# 使用 pipeline 批量获取数据,提高性能
|
||||
keys = self.r.keys('session:*')
|
||||
|
||||
if not keys:
|
||||
print(f"查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 使用 pipeline 批量获取所有 hash 数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 解析并筛选符合条件的数据
|
||||
matched_items = []
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 检查是否符合三个条件
|
||||
|
||||
if (data.get('apply_id') == apply_id and
|
||||
data.get('end_user_id') == end_user_id):
|
||||
# 支持模糊匹配 sessionid 或者完全匹配
|
||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||
matched_items.append({
|
||||
"Query": self._fix_encoding(data.get('messages')),
|
||||
"Answer": self._fix_encoding(data.get('aimessages')),
|
||||
"starttime": data.get('starttime', '')
|
||||
})
|
||||
# 按时间降序排序(最新的在前)
|
||||
matched_items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
# 只保留最新的6条
|
||||
result_items = matched_items[:6]
|
||||
# # 移除 starttime 字段
|
||||
for item in result_items:
|
||||
item.pop('starttime', None)
|
||||
|
||||
# 如果结果少于等于1条,返回空列表
|
||||
if len(result_items) <= 1:
|
||||
result_items = []
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
|
||||
# 全局实例
|
||||
store = RedisSessionStore(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
||||
session_id=str(uuid.uuid4())
|
||||
)
|
||||
)
|
||||
|
||||
write_store = RedisWriteStore(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
||||
session_id=str(uuid.uuid4())
|
||||
)
|
||||
|
||||
count_store = RedisCountStore(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
||||
session_id=str(uuid.uuid4())
|
||||
)
|
||||
|
||||
@@ -4,15 +4,20 @@ Write Tools for Memory Knowledge Extraction Pipeline
|
||||
This module provides the main write function for executing the knowledge extraction
|
||||
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import _USER_PLACEHOLDER_NAMES
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \
|
||||
memory_summary_generation
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.log.logging_utils import log_time
|
||||
from app.db import get_db_context
|
||||
@@ -22,29 +27,30 @@ from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def write(
|
||||
end_user_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
messages: list,
|
||||
ref_id: str = "wyl20251027",
|
||||
end_user_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
messages: list,
|
||||
ref_id: str = "",
|
||||
language: str = "zh",
|
||||
) -> None:
|
||||
"""
|
||||
Execute the complete knowledge extraction pipeline.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference ID, defaults to "wyl20251027"
|
||||
ref_id: Reference ID, defaults to ""
|
||||
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
||||
"""
|
||||
if not ref_id:
|
||||
ref_id = uuid.uuid4().hex
|
||||
# Extract config values
|
||||
embedding_model_id = str(memory_config.embedding_model_id)
|
||||
chunker_strategy = memory_config.chunker_strategy
|
||||
@@ -93,12 +99,39 @@ async def write(
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
pipeline_config = get_pipeline_config(memory_config)
|
||||
|
||||
# Fetch ontology types if scene_id is configured
|
||||
ontology_types = None
|
||||
if memory_config.scene_id:
|
||||
try:
|
||||
from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_for_scene
|
||||
|
||||
with get_db_context() as db:
|
||||
ontology_types = load_ontology_types_for_scene(
|
||||
scene_id=memory_config.scene_id,
|
||||
workspace_id=memory_config.workspace_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
if ontology_types:
|
||||
logger.info(
|
||||
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"No ontology classes found for scene_id: {memory_config.scene_id}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch ontology types for scene_id {memory_config.scene_id}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=neo4j_connector,
|
||||
config=pipeline_config,
|
||||
embedding_id=embedding_model_id,
|
||||
language=language,
|
||||
ontology_types=ontology_types,
|
||||
)
|
||||
|
||||
# Run the complete extraction pipeline
|
||||
@@ -107,9 +140,11 @@ async def write(
|
||||
all_chunk_nodes,
|
||||
all_statement_nodes,
|
||||
all_entity_nodes,
|
||||
all_perceptual_nodes,
|
||||
all_statement_chunk_edges,
|
||||
all_statement_entity_edges,
|
||||
all_entity_entity_edges,
|
||||
all_perceptual_edges,
|
||||
all_dedup_details,
|
||||
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
||||
|
||||
@@ -117,29 +152,117 @@ async def write(
|
||||
|
||||
# Step 3: Save all data to Neo4j database
|
||||
step_start = time.time()
|
||||
from app.repositories.neo4j.create_indexes import create_fulltext_indexes
|
||||
|
||||
# Neo4j 写入前:清洗用户/AI助手实体之间的别名交叉污染
|
||||
# 从 Neo4j 查询已有的 AI 助手别名,与本轮实体中的 AI 助手别名合并,
|
||||
# 确保用户实体的 aliases 不包含 AI 助手的名字
|
||||
try:
|
||||
await create_fulltext_indexes()
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
|
||||
clean_cross_role_aliases,
|
||||
fetch_neo4j_assistant_aliases,
|
||||
)
|
||||
neo4j_assistant_aliases = set()
|
||||
if all_entity_nodes:
|
||||
_eu_id = all_entity_nodes[0].end_user_id
|
||||
if _eu_id:
|
||||
neo4j_assistant_aliases = await fetch_neo4j_assistant_aliases(neo4j_connector, _eu_id)
|
||||
clean_cross_role_aliases(all_entity_nodes, external_assistant_aliases=neo4j_assistant_aliases)
|
||||
logger.info(f"Neo4j 写入前别名清洗完成,AI助手别名排除集大小: {len(neo4j_assistant_aliases)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
||||
logger.warning(f"Neo4j 写入前别名清洗失败(不影响主流程): {e}")
|
||||
|
||||
# 添加死锁重试机制
|
||||
max_retries = 3
|
||||
retry_delay = 1 # 秒
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=all_dialogue_nodes,
|
||||
chunk_nodes=all_chunk_nodes,
|
||||
statement_nodes=all_statement_nodes,
|
||||
entity_nodes=all_entity_nodes,
|
||||
perceptual_nodes=all_perceptual_nodes,
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
perceptual_edges=all_perceptual_edges,
|
||||
connector=neo4j_connector,
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
|
||||
if all_entity_nodes:
|
||||
end_user_id = all_entity_nodes[0].end_user_id
|
||||
|
||||
# Neo4j 写入完成后,用 PgSQL 权威 aliases 覆盖 Neo4j 用户实体
|
||||
try:
|
||||
from app.repositories.end_user_info_repository import EndUserInfoRepository
|
||||
if end_user_id:
|
||||
with get_db_context() as db_session:
|
||||
info = EndUserInfoRepository(db_session).get_by_end_user_id(uuid.UUID(end_user_id))
|
||||
pg_aliases = info.aliases if info and info.aliases else []
|
||||
if info is not None:
|
||||
# 将 Python 侧占位名集合作为参数传入,避免 Cypher 硬编码
|
||||
placeholder_names = list(_USER_PLACEHOLDER_NAMES)
|
||||
await neo4j_connector.execute_query(
|
||||
"""
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.end_user_id = $end_user_id AND toLower(e.name) IN $placeholder_names
|
||||
SET e.aliases = $aliases
|
||||
""",
|
||||
end_user_id=end_user_id, aliases=pg_aliases,
|
||||
placeholder_names=placeholder_names,
|
||||
)
|
||||
logger.info(f"[AliasSync] Neo4j 用户实体 aliases 已用 PgSQL 权威源覆盖: {pg_aliases}")
|
||||
except Exception as sync_err:
|
||||
logger.warning(f"[AliasSync] PgSQL→Neo4j aliases 同步失败(不影响主流程): {sync_err}")
|
||||
|
||||
# 使用 Celery 异步任务触发聚类(不阻塞主流程)
|
||||
try:
|
||||
from app.tasks import run_incremental_clustering
|
||||
|
||||
new_entity_ids = [e.id for e in all_entity_nodes]
|
||||
task = run_incremental_clustering.apply_async(
|
||||
kwargs={
|
||||
"end_user_id": end_user_id,
|
||||
"new_entity_ids": new_entity_ids,
|
||||
"llm_model_id": str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||
"embedding_model_id": str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||
},
|
||||
priority=3,
|
||||
)
|
||||
logger.info(
|
||||
f"[Clustering] 增量聚类任务已提交到 Celery - "
|
||||
f"task_id={task.id}, end_user_id={end_user_id}, entity_count={len(new_entity_ids)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Clustering] 提交聚类任务失败(不影响主流程): {e}", exc_info=True)
|
||||
|
||||
break
|
||||
else:
|
||||
logger.warning("Failed to save some data to Neo4j")
|
||||
if attempt < max_retries - 1:
|
||||
logger.info(f"Retrying... (attempt {attempt + 2}/{max_retries})")
|
||||
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
# 检查是否是死锁错误
|
||||
if "DeadlockDetected" in error_msg or "deadlock" in error_msg.lower():
|
||||
if attempt < max_retries - 1:
|
||||
logger.warning(f"Deadlock detected, retrying... (attempt {attempt + 2}/{max_retries})")
|
||||
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||
else:
|
||||
logger.error(f"Failed after {max_retries} attempts due to deadlock: {e}")
|
||||
raise
|
||||
else:
|
||||
# 非死锁错误,直接抛出
|
||||
raise
|
||||
|
||||
try:
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=all_dialogue_nodes,
|
||||
chunk_nodes=all_chunk_nodes,
|
||||
statement_nodes=all_statement_nodes,
|
||||
entity_nodes=all_entity_nodes,
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
connector=neo4j_connector
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
else:
|
||||
logger.warning("Failed to save some data to Neo4j")
|
||||
finally:
|
||||
await neo4j_connector.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing Neo4j connector: {e}")
|
||||
|
||||
log_time("Neo4j Database Save", time.time() - step_start, log_file)
|
||||
|
||||
@@ -147,11 +270,10 @@ async def write(
|
||||
step_start = time.time()
|
||||
try:
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client
|
||||
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
|
||||
)
|
||||
|
||||
ms_connector = Neo4jConnector()
|
||||
try:
|
||||
ms_connector = Neo4jConnector()
|
||||
await add_memory_summary_nodes(summaries, ms_connector)
|
||||
await add_memory_summary_statement_edges(summaries, ms_connector)
|
||||
finally:
|
||||
@@ -172,5 +294,40 @@ async def write(
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
|
||||
|
||||
# 将提取统计写入 Redis,按 workspace_id 存储
|
||||
try:
|
||||
from app.cache.memory.activity_stats_cache import ActivityStatsCache
|
||||
|
||||
stats_to_cache = {
|
||||
"chunk_count": len(all_chunk_nodes) if all_chunk_nodes else 0,
|
||||
"statements_count": len(all_statement_nodes) if all_statement_nodes else 0,
|
||||
"triplet_entities_count": len(all_entity_nodes) if all_entity_nodes else 0,
|
||||
"triplet_relations_count": len(all_entity_entity_edges) if all_entity_entity_edges else 0,
|
||||
"temporal_count": 0,
|
||||
}
|
||||
await ActivityStatsCache.set_activity_stats(
|
||||
workspace_id=str(memory_config.workspace_id),
|
||||
stats=stats_to_cache,
|
||||
)
|
||||
logger.info(f"[WRITE] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}")
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||
|
||||
# Close LLM/Embedder underlying httpx clients to prevent
|
||||
# 'RuntimeError: Event loop is closed' during garbage collection
|
||||
for client_obj in (llm_client, embedder_client):
|
||||
try:
|
||||
underlying = getattr(client_obj, 'client', None) or getattr(client_obj, 'model', None)
|
||||
if underlying is None:
|
||||
continue
|
||||
# Unwrap RedBearLLM / RedBearEmbeddings to get the LangChain model
|
||||
inner = getattr(underlying, '_model', underlying)
|
||||
# LangChain OpenAI models expose async_client (httpx.AsyncClient)
|
||||
http_client = getattr(inner, 'async_client', None)
|
||||
if http_client is not None and hasattr(http_client, 'aclose'):
|
||||
await http_client.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info("=== Pipeline Complete ===")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user