Compare commits
1591 Commits
v0.1.1
...
release/v0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6c8dca6379 | ||
|
|
819d205166 | ||
|
|
9e17f65eda | ||
|
|
7373f68172 | ||
|
|
0999bd30d7 | ||
|
|
f01185a7fc | ||
|
|
7cd7303754 | ||
|
|
d19fec2155 | ||
|
|
df18868888 | ||
|
|
4438b08560 | ||
|
|
1029f94669 | ||
|
|
0a3acf446d | ||
|
|
5a7723553c | ||
|
|
975844eccf | ||
|
|
865ad31f2f | ||
|
|
b756f0c86c | ||
|
|
3e5f6176af | ||
|
|
ab5b165dc2 | ||
|
|
f9393c2f63 | ||
|
|
aa6638424c | ||
|
|
834387e254 | ||
|
|
9caa986c80 | ||
|
|
72b84dfc8f | ||
|
|
af10195025 | ||
|
|
22382423ad | ||
|
|
0f80c67cbd | ||
|
|
aa6473c1c7 | ||
|
|
cde61cb6ac | ||
|
|
b1368997c2 | ||
|
|
ec7dc448c1 | ||
|
|
254147265e | ||
|
|
479bba9a4e | ||
|
|
cfb39a6baa | ||
|
|
05c9ed1450 | ||
|
|
f53633a8b8 | ||
|
|
63882e9391 | ||
|
|
3c4dfb868f | ||
|
|
cae9105b8d | ||
|
|
2c9401ccfb | ||
|
|
2b0dedc81c | ||
|
|
16b87de0df | ||
|
|
8c3af7f4ff | ||
|
|
5f56cc8056 | ||
|
|
827ab27bef | ||
|
|
ccc67df8df | ||
|
|
82538c469f | ||
|
|
076ceee29d | ||
|
|
822b73b015 | ||
|
|
862bff51cb | ||
|
|
bccbeaabe4 | ||
|
|
03676b7adc | ||
|
|
af6fde414f | ||
|
|
d069809001 | ||
|
|
fc240849cf | ||
|
|
61d2a328fe | ||
|
|
fed0ae8e9c | ||
|
|
eaf0de453b | ||
|
|
e833db954a | ||
|
|
0b2651f4ed | ||
|
|
10c677a6fd | ||
|
|
3398c4737a | ||
|
|
a008f5fbef | ||
|
|
6a42e73667 | ||
|
|
7611db19f3 | ||
|
|
d3399dfaf5 | ||
|
|
248f0d95ac | ||
|
|
5c39d841ee | ||
|
|
87be67cb9a | ||
|
|
1a08bea864 | ||
|
|
bc4406cec6 | ||
|
|
4206c849c3 | ||
|
|
3f052b7798 | ||
|
|
f1c5f24f6b | ||
|
|
e981c95225 | ||
|
|
4ce4f53835 | ||
|
|
f16e369540 | ||
|
|
47bf93d65e | ||
|
|
5c2e0af33e | ||
|
|
aaa0410781 | ||
|
|
366b148f3d | ||
|
|
6a265de31c | ||
|
|
c3707f543c | ||
|
|
8de368348b | ||
|
|
d052c31ac5 | ||
|
|
31320afed6 | ||
|
|
7afe507296 | ||
|
|
4188443101 | ||
|
|
a1fc0fd394 | ||
|
|
71fe35533d | ||
|
|
a2ed335e59 | ||
|
|
8422a05d74 | ||
|
|
139ae3bcb4 | ||
|
|
a0a57d5fbb | ||
|
|
80fa88ac37 | ||
|
|
0fda1c752d | ||
|
|
6c2fc75199 | ||
|
|
2cb6aeb022 | ||
|
|
e0174f75b3 | ||
|
|
51d04746a3 | ||
|
|
3b08d6c320 | ||
|
|
495c5802a0 | ||
|
|
621b074b3d | ||
|
|
6df32983b5 | ||
|
|
9c9fe9dde7 | ||
|
|
128c1a6178 | ||
|
|
f90e102854 | ||
|
|
2e1eb9a5a6 | ||
|
|
60a95f6556 | ||
|
|
218637e81d | ||
|
|
404f78af0f | ||
|
|
130f15665c | ||
|
|
6301528301 | ||
|
|
6feea968e0 | ||
|
|
b5199b2eb9 | ||
|
|
78ce2a9a8b | ||
|
|
6ed542b007 | ||
|
|
5322b0c4a3 | ||
|
|
a72d5d2c77 | ||
|
|
16c1cbe24f | ||
|
|
0d8f4c76e7 | ||
|
|
e511b14933 | ||
|
|
b5ba53208e | ||
|
|
b8bfb4d0c5 | ||
|
|
1b666638bc | ||
|
|
2bd364eca3 | ||
|
|
f27fc51801 | ||
|
|
0f85eff76b | ||
|
|
0def474cc2 | ||
|
|
026e4376d4 | ||
|
|
cf571cf02b | ||
|
|
590ec3a446 | ||
|
|
23bfdcefef | ||
|
|
647a978865 | ||
|
|
86f72100f0 | ||
|
|
8b255259ba | ||
|
|
8aad8faae9 | ||
|
|
420f391f3c | ||
|
|
817221347f | ||
|
|
13dce5e265 | ||
|
|
850d9ee70b | ||
|
|
ba36ccb21f | ||
|
|
f712754927 | ||
|
|
efe3865aa4 | ||
|
|
53dbe2f436 | ||
|
|
720498084b | ||
|
|
f5eda38dc9 | ||
|
|
8ada221777 | ||
|
|
4ee198813a | ||
|
|
440e8acd99 | ||
|
|
218671ef06 | ||
|
|
34de0bb9c5 | ||
|
|
8e6cf09056 | ||
|
|
5929072b76 | ||
|
|
37325e9802 | ||
|
|
778bc4bd70 | ||
|
|
f78f59ec42 | ||
|
|
d4c4160215 | ||
|
|
85aea97c21 | ||
|
|
b075cad4de | ||
|
|
f326febc8a | ||
|
|
1738e45090 | ||
|
|
6e758faa37 | ||
|
|
32e79c5df0 | ||
|
|
aa69cd3a0c | ||
|
|
da4a1f536d | ||
|
|
b3af757167 | ||
|
|
82794f051a | ||
|
|
a726a81224 | ||
|
|
9aae6163f0 | ||
|
|
941527e7ee | ||
|
|
a3f05220d3 | ||
|
|
7446241735 | ||
|
|
6033d37537 | ||
|
|
1524d7b5ce | ||
|
|
e00341a4cc | ||
|
|
f5185d2e95 | ||
|
|
c041d24989 | ||
|
|
dc9003f9db | ||
|
|
07e0c70629 | ||
|
|
37f77e0990 | ||
|
|
aef1a57ea8 | ||
|
|
69af479224 | ||
|
|
f38223c97f | ||
|
|
1ac6702eb0 | ||
|
|
2510f60dce | ||
|
|
b9d7fb2598 | ||
|
|
a39ba564fa | ||
|
|
34310bfabe | ||
|
|
78fd189510 | ||
|
|
94836ed9af | ||
|
|
1d662fb63e | ||
|
|
d1933d2aef | ||
|
|
163872be6e | ||
|
|
14fcb66a9c | ||
|
|
c488eb0cd0 | ||
|
|
91d20f7272 | ||
|
|
c3d7963fe0 | ||
|
|
c31a92bf01 | ||
|
|
b5703c1b82 | ||
|
|
df34735a9b | ||
|
|
31bee889d7 | ||
|
|
b3ba0a6ed6 | ||
|
|
ce3b7897d7 | ||
|
|
9115ad6950 | ||
|
|
c6b76438f4 | ||
|
|
68c4c7429c | ||
|
|
8466c8e019 | ||
|
|
d899b27448 | ||
|
|
229eb5cc86 | ||
|
|
66c153f1ad | ||
|
|
bbb2c6c903 | ||
|
|
5edf3f2b8a | ||
|
|
006c6cd159 | ||
|
|
9675982555 | ||
|
|
c6c7a1827c | ||
|
|
3ac8a9431b | ||
|
|
5c42a84c3e | ||
|
|
8fdaebbe6e | ||
|
|
9a98ccff2c | ||
|
|
ee4027c561 | ||
|
|
7f36a06f26 | ||
|
|
0826a34d8b | ||
|
|
1792cb4d93 | ||
|
|
304ccef101 | ||
|
|
bdc22c892d | ||
|
|
a5034e84ba | ||
|
|
6e2de96fed | ||
|
|
2b6d86e591 | ||
|
|
8c6f4cb117 | ||
|
|
16d4b32eb7 | ||
|
|
45a64dbbac | ||
|
|
537668b463 | ||
|
|
07fea23dd0 | ||
|
|
cef14291f0 | ||
|
|
bbde0588af | ||
|
|
aa7d52568b | ||
|
|
f39c77ac70 | ||
|
|
aa733354e8 | ||
|
|
7cec966979 | ||
|
|
74865d2cf2 | ||
|
|
c9a8753473 | ||
|
|
ce8a2cbe34 | ||
|
|
c0fdd0c6d3 | ||
|
|
88bfcfe6cd | ||
|
|
c4dcf1fd65 | ||
|
|
6cebddf893 | ||
|
|
1738ed3664 | ||
|
|
37ddcb91ac | ||
|
|
574ab4506b | ||
|
|
81353538e5 | ||
|
|
5abfcdfbe8 | ||
|
|
9962a61c21 | ||
|
|
5cf2b08777 | ||
|
|
9be1c01b70 | ||
|
|
62b2ecdfc2 | ||
|
|
2ff9000d25 | ||
|
|
5829148ce4 | ||
|
|
8e15a340f6 | ||
|
|
1270b7cdd8 | ||
|
|
7c02fe8148 | ||
|
|
4ac63e1c23 | ||
|
|
4aeb653ed2 | ||
|
|
2d5c2de613 | ||
|
|
96590941cf | ||
|
|
0655ff4a91 | ||
|
|
0ba370052e | ||
|
|
4d59e04aba | ||
|
|
6db6c33564 | ||
|
|
ed0d963aec | ||
|
|
3a36d038ee | ||
|
|
3d068a9c96 | ||
|
|
87df352adc | ||
|
|
8b546b7366 | ||
|
|
77ea0680fb | ||
|
|
4c592bf7e3 | ||
|
|
6718553bf4 | ||
|
|
79dc6f3f69 | ||
|
|
8df72d2822 | ||
|
|
b9578bd08a | ||
|
|
035e56e42f | ||
|
|
3ce5926689 | ||
|
|
035464c0ac | ||
|
|
f1fcffbfc0 | ||
|
|
b79fe07052 | ||
|
|
e6aa0e0e10 | ||
|
|
54700e6fbe | ||
|
|
5a90d4776d | ||
|
|
f81fdca62a | ||
|
|
3a0671c661 | ||
|
|
1037729fb3 | ||
|
|
5f211620c5 | ||
|
|
cb6a3aae9e | ||
|
|
5e512df3d4 | ||
|
|
9916cf3265 | ||
|
|
729c283c63 | ||
|
|
c99f04314c | ||
|
|
dd9be2ed90 | ||
|
|
f7aed9dd98 | ||
|
|
5253cf3899 | ||
|
|
f7d92be5ea | ||
|
|
97d8168824 | ||
|
|
550bd4da23 | ||
|
|
2327be7557 | ||
|
|
a7ffc19ba1 | ||
|
|
bbaa39c569 | ||
|
|
d1de0250e7 | ||
|
|
2d731c6412 | ||
|
|
6a6e64f487 | ||
|
|
b9201c918a | ||
|
|
7dedad898a | ||
|
|
d497189352 | ||
|
|
fa4da8f467 | ||
|
|
e9ff742162 | ||
|
|
3849cfb835 | ||
|
|
c453af23c6 | ||
|
|
bcf2376f5a | ||
|
|
4f0b653a82 | ||
|
|
be2f56ae6a | ||
|
|
cbc9602495 | ||
|
|
616709acbb | ||
|
|
c72ce381c0 | ||
|
|
67053ab8ae | ||
|
|
33238d34c9 | ||
|
|
2ef54168fc | ||
|
|
b33ccf00f9 | ||
|
|
829eb4b3be | ||
|
|
6c49456c13 | ||
|
|
fc8f06ee14 | ||
|
|
120a524b7e | ||
|
|
bd037ac3a3 | ||
|
|
b8ea427029 | ||
|
|
275be47224 | ||
|
|
4ea9c7e660 | ||
|
|
92d78d9a52 | ||
|
|
a820001eea | ||
|
|
8273f6d217 | ||
|
|
bd63e0fce8 | ||
|
|
12ba3d473e | ||
|
|
0b9cc0f068 | ||
|
|
5ca397befa | ||
|
|
da735fe776 | ||
|
|
b4f69f2cff | ||
|
|
1885c00cbc | ||
|
|
1e4fdeb1a6 | ||
|
|
cb7dbb0ed4 | ||
|
|
44083aec79 | ||
|
|
4a9b743153 | ||
|
|
b462e17a5b | ||
|
|
b272a52b57 | ||
|
|
3f87c64e83 | ||
|
|
1795364f5f | ||
|
|
e69fbb2f97 | ||
|
|
32b40fc6bf | ||
|
|
f039ea7f56 | ||
|
|
41334f5f1e | ||
|
|
79b19b744e | ||
|
|
2103410694 | ||
|
|
2143d94e83 | ||
|
|
9ae2612945 | ||
|
|
3a09b26b6d | ||
|
|
e381449aec | ||
|
|
bacffc94d9 | ||
|
|
7044f705e7 | ||
|
|
6db4fe28a7 | ||
|
|
f966176694 | ||
|
|
bd24de4577 | ||
|
|
dc2ea5c007 | ||
|
|
4fb673077a | ||
|
|
b3a136ac03 | ||
|
|
22f1bfa3fa | ||
|
|
f6ad0aab94 | ||
|
|
371fdeb948 | ||
|
|
f7a0af75c4 | ||
|
|
b31e526e4d | ||
|
|
26abf7b586 | ||
|
|
d477e24e34 | ||
|
|
3ca3e8e023 | ||
|
|
3bd374495b | ||
|
|
b26f60ee8d | ||
|
|
df681eaf22 | ||
|
|
01458ac111 | ||
|
|
6c7a68802b | ||
|
|
e3074b833f | ||
|
|
1097d699f8 | ||
|
|
55b4e0ebd3 | ||
|
|
0011a8ce9f | ||
|
|
100bf4fa49 | ||
|
|
6da5b81311 | ||
|
|
787adf5423 | ||
|
|
01b500e7d1 | ||
|
|
e64603ea27 | ||
|
|
4219e12cc0 | ||
|
|
c86ccf0931 | ||
|
|
d4571fb75b | ||
|
|
ec2369c397 | ||
|
|
6ebd48408b | ||
|
|
7e7b54593c | ||
|
|
f93c9f5cd2 | ||
|
|
a810fbe008 | ||
|
|
600a914bd9 | ||
|
|
b1688950c4 | ||
|
|
d8e3f9b7b8 | ||
|
|
08d55e4463 | ||
|
|
55e2baa865 | ||
|
|
55174dc707 | ||
|
|
d57e3b3f64 | ||
|
|
aa42cd0aec | ||
|
|
ac6d9a39ec | ||
|
|
9b07775395 | ||
|
|
936fb8b8a1 | ||
|
|
6c8318b696 | ||
|
|
d554079e2b | ||
|
|
37464a101e | ||
|
|
c5674246b0 | ||
|
|
f076199e3f | ||
|
|
8326db1143 | ||
|
|
992e41e0a0 | ||
|
|
076e95d5c2 | ||
|
|
dfd79e5972 | ||
|
|
b16c9d53ef | ||
|
|
5fe85fb457 | ||
|
|
b45f470310 | ||
|
|
0ecda33ab8 | ||
|
|
7fcfca455a | ||
|
|
6a32154b8f | ||
|
|
132206677f | ||
|
|
30a8775548 | ||
|
|
045bc9aefc | ||
|
|
d5c46574cc | ||
|
|
37fea09403 | ||
|
|
063e8fae43 | ||
|
|
184c4fbf7f | ||
|
|
e19d27f640 | ||
|
|
ea96830758 | ||
|
|
d2edbc738d | ||
|
|
03bc8c8280 | ||
|
|
68908213da | ||
|
|
b3d5add89a | ||
|
|
7fe2d8fbe1 | ||
|
|
de545a69ca | ||
|
|
dc48ba540d | ||
|
|
81e92b4fa6 | ||
|
|
ebad5e00a3 | ||
|
|
bca03f1365 | ||
|
|
c89f55f0bd | ||
|
|
4d98bace87 | ||
|
|
dcdc899528 | ||
|
|
b57aa55001 | ||
|
|
d0c0168c20 | ||
|
|
af596a09cf | ||
|
|
6849c620b8 | ||
|
|
12598f0dca | ||
|
|
3f4ce4f16f | ||
|
|
4aaf0d8d5c | ||
|
|
65db056e09 | ||
|
|
232cef7cb9 | ||
|
|
73a432879a | ||
|
|
09afec17f9 | ||
|
|
ac47ab3deb | ||
|
|
8b3d7c168a | ||
|
|
60e8eb63ac | ||
|
|
4f29cd24b8 | ||
|
|
ba73ade2a0 | ||
|
|
7559305fc9 | ||
|
|
6985f553f9 | ||
|
|
8fc15df6d0 | ||
|
|
eb8160a5af | ||
|
|
16cf6eee9b | ||
|
|
320f684354 | ||
|
|
12062a5440 | ||
|
|
4423a9d979 | ||
|
|
1eb44defb6 | ||
|
|
e253fba2e9 | ||
|
|
c05d95924f | ||
|
|
2db583d62d | ||
|
|
59d8e1bf9f | ||
|
|
1001344c27 | ||
|
|
8a0e2da03f | ||
|
|
f58886be6f | ||
|
|
3c1d3b4d6a | ||
|
|
bbba995ff7 | ||
|
|
0033b5be80 | ||
|
|
87d53fb9b7 | ||
|
|
157031f23e | ||
|
|
8a37869489 | ||
|
|
5c10f11681 | ||
|
|
7b72bf0cd0 | ||
|
|
be29666916 | ||
|
|
8d4c5b5b33 | ||
|
|
52260f469a | ||
|
|
c566d22836 | ||
|
|
75f59a86c8 | ||
|
|
1eaf12446f | ||
|
|
efdd42426e | ||
|
|
62c557deae | ||
|
|
db1da4a61a | ||
|
|
db46c186aa | ||
|
|
677a603835 | ||
|
|
447d8790ad | ||
|
|
7a78f15a90 | ||
|
|
c1941809e9 | ||
|
|
623aaf8a0e | ||
|
|
7b3bf41120 | ||
|
|
0c3960eb0b | ||
|
|
fe3c31c08c | ||
|
|
94600cdbfc | ||
|
|
4e7ab3d7e3 | ||
|
|
47b25d7a26 | ||
|
|
0249666fa4 | ||
|
|
2e8504ce2f | ||
|
|
aca7d25001 | ||
|
|
2444309bc2 | ||
|
|
97c5a78d48 | ||
|
|
effdb88455 | ||
|
|
2f0ce3852e | ||
|
|
5475496399 | ||
|
|
b569d77a23 | ||
|
|
dfa7a2d4cf | ||
|
|
169e01276d | ||
|
|
07e698265e | ||
|
|
0632d7611f | ||
|
|
b3f39eedac | ||
|
|
46ed7e38bf | ||
|
|
8c5199d32d | ||
|
|
36ed833d64 | ||
|
|
47969ce61e | ||
|
|
06731e2026 | ||
|
|
123347169d | ||
|
|
f9101a744c | ||
|
|
97eb33000f | ||
|
|
60231ec88d | ||
|
|
3364374dc6 | ||
|
|
a3cf773e75 | ||
|
|
4092d5fbaf | ||
|
|
07e9fde9e8 | ||
|
|
9b4613630b | ||
|
|
f125d11b6d | ||
|
|
657d48a5f9 | ||
|
|
3735bdde19 | ||
|
|
3f906d81cb | ||
|
|
7c1f622797 | ||
|
|
cfe696ae8d | ||
|
|
021c50a8f2 | ||
|
|
95745ba869 | ||
|
|
adfae54816 | ||
|
|
10ed093eb8 | ||
|
|
c96df6bfa5 | ||
|
|
0126d18525 | ||
|
|
9e6e8f50f8 | ||
|
|
7e0b31626f | ||
|
|
1d9e249a77 | ||
|
|
88b89ef315 | ||
|
|
62b7925cb0 | ||
|
|
cc1528f550 | ||
|
|
1c8a83140b | ||
|
|
34276e2066 | ||
|
|
71abd16ae7 | ||
|
|
918e7285c4 | ||
|
|
056d422c71 | ||
|
|
5ee54f4e0e | ||
|
|
260c75e70c | ||
|
|
2d7401922f | ||
|
|
8c7a1348cf | ||
|
|
24fbdbd716 | ||
|
|
aad8f0e36b | ||
|
|
15cad44f08 | ||
|
|
0271454671 | ||
|
|
d0ddf288ca | ||
|
|
bc250ac377 | ||
|
|
7922fc3b0e | ||
|
|
161da723b9 | ||
|
|
514c19a247 | ||
|
|
41550d4a41 | ||
|
|
33cc3c1c3f | ||
|
|
7d15182202 | ||
|
|
8f0a1d9c6e | ||
|
|
72b5e5cf8e | ||
|
|
62aba2dd38 | ||
|
|
cdd6b80089 | ||
|
|
333836f5e7 | ||
|
|
a2dfda3471 | ||
|
|
2d28b4b05c | ||
|
|
87f9bcc6a3 | ||
|
|
48aca996ff | ||
|
|
c8c7e9b304 | ||
|
|
97ff023995 | ||
|
|
e273a336f8 | ||
|
|
34f0c3b90c | ||
|
|
7c2902d2b8 | ||
|
|
8e41afdffc | ||
|
|
7268886294 | ||
|
|
cbae900866 | ||
|
|
ffff138a6f | ||
|
|
88c95db8d0 | ||
|
|
56e657a0bb | ||
|
|
bc36b79105 | ||
|
|
5694bc0230 | ||
|
|
36130031f9 | ||
|
|
b8f1095f53 | ||
|
|
442fa09533 | ||
|
|
42ef2efbc8 | ||
|
|
ead3080b2b | ||
|
|
c6ea31c296 | ||
|
|
21eae29bb7 | ||
|
|
406740b524 | ||
|
|
9d30bc4062 | ||
|
|
fad91b64ab | ||
|
|
2132e71a81 | ||
|
|
bd8a451879 | ||
|
|
24dafa7359 | ||
|
|
3b5df793fb | ||
|
|
da835b6138 | ||
|
|
7e650d86a5 | ||
|
|
308e28cecc | ||
|
|
9a3c74fb64 | ||
|
|
f571f0688a | ||
|
|
1e9c32a102 | ||
|
|
8c69199689 | ||
|
|
3efb3e8a35 | ||
|
|
cfcb278406 | ||
|
|
9e195ea63b | ||
|
|
dc0d34c281 | ||
|
|
72076c218f | ||
|
|
151fd3b950 | ||
|
|
2d484fcb30 | ||
|
|
6e0407f404 | ||
|
|
8670aaba1e | ||
|
|
f27de7df35 | ||
|
|
63fa4dc8ec | ||
|
|
a191e32f71 | ||
|
|
9a38e8a4a0 | ||
|
|
6194222289 | ||
|
|
0d077eaeb7 | ||
|
|
b2c7a9a005 | ||
|
|
be01f1869e | ||
|
|
9f2b6390b0 | ||
|
|
e196f86e30 | ||
|
|
ec41d45234 | ||
|
|
567d1ba18b | ||
|
|
df8706983b | ||
|
|
8697498b32 | ||
|
|
af917c538a | ||
|
|
034e97dfa6 | ||
|
|
5e1e5f68e1 | ||
|
|
fb76f765cc | ||
|
|
7a3f57261d | ||
|
|
a1a460625d | ||
|
|
3f42ea2c61 | ||
|
|
940c594066 | ||
|
|
5e47fc45ab | ||
|
|
b471d56a86 | ||
|
|
61f8029205 | ||
|
|
e2f047d035 | ||
|
|
1aff4eda67 | ||
|
|
a6c5c44ed8 | ||
|
|
3f389d685a | ||
|
|
5d5351f0bc | ||
|
|
1224802ac6 | ||
|
|
e919f89caf | ||
|
|
bb8e7a68ea | ||
|
|
48f95e0ea4 | ||
|
|
931e9bcf0d | ||
|
|
67a3351c4c | ||
|
|
dfe5eeed7b | ||
|
|
3464573f17 | ||
|
|
9cf49c9c75 | ||
|
|
4e837cb90c | ||
|
|
e4fb58496b | ||
|
|
15a254c0cd | ||
|
|
d62746fc8c | ||
|
|
4b8b6fe407 | ||
|
|
6754834eb3 | ||
|
|
be98db561d | ||
|
|
574d0afc72 | ||
|
|
31c8ad611c | ||
|
|
b23730388d | ||
|
|
1b853aa893 | ||
|
|
36cb0a12ad | ||
|
|
5439eacf2d | ||
|
|
2687c3b80e | ||
|
|
fa009327ad | ||
|
|
838bd46e83 | ||
|
|
ccc2009aa8 | ||
|
|
d9aba92314 | ||
|
|
696b0475a8 | ||
|
|
e7370489e8 | ||
|
|
f1503b2238 | ||
|
|
cd4661e878 | ||
|
|
364e01ec7a | ||
|
|
ffb7b0ba38 | ||
|
|
22151eb49b | ||
|
|
d0354345f6 | ||
|
|
b1e61eb1e4 | ||
|
|
36e0ed15b6 | ||
|
|
095dfc2879 | ||
|
|
17dea9433e | ||
|
|
c285444e2f | ||
|
|
8ba402d080 | ||
|
|
88ab86734d | ||
|
|
504d87b0b0 | ||
|
|
b0d5818351 | ||
|
|
8826a01d32 | ||
|
|
cfb7a40841 | ||
|
|
8267761890 | ||
|
|
a651ae6ed4 | ||
|
|
a01911ba5f | ||
|
|
ee50b25d06 | ||
|
|
a67be85858 | ||
|
|
59c5a3973a | ||
|
|
d76d7343ff | ||
|
|
2b9638e7d3 | ||
|
|
3459a73705 | ||
|
|
bd480a466b | ||
|
|
4c34cb55b6 | ||
|
|
7347f9104c | ||
|
|
e137e4a38a | ||
|
|
b5989bbc25 | ||
|
|
c31ff7ceef | ||
|
|
9206c7642a | ||
|
|
d1b4f2b6c2 | ||
|
|
75066f2827 | ||
|
|
303f3aefef | ||
|
|
44fb5e0fd5 | ||
|
|
17a695120a | ||
|
|
6dc716eaf8 | ||
|
|
194be086d4 | ||
|
|
cca3900678 | ||
|
|
4fe32b7dbc | ||
|
|
c49603c25b | ||
|
|
8de85a4041 | ||
|
|
58a2135fa4 | ||
|
|
ab9a97db22 | ||
|
|
d291c241d5 | ||
|
|
24d4cb9b94 | ||
|
|
5b9adb799f | ||
|
|
38b41df36b | ||
|
|
34a9befe5c | ||
|
|
67fd579074 | ||
|
|
e2714b942d | ||
|
|
6b2556f870 | ||
|
|
43e6e9d201 | ||
|
|
131e0cc4c7 | ||
|
|
537be81b8f | ||
|
|
765168db7f | ||
|
|
1e16b06a24 | ||
|
|
42b59a644d | ||
|
|
d9fa9039bb | ||
|
|
cd4c93a5cb | ||
|
|
808961243d | ||
|
|
4d80e119f7 | ||
|
|
10c87edae1 | ||
|
|
0eb335d112 | ||
|
|
b8b26ccfe5 | ||
|
|
e89c23da4d | ||
|
|
f3da8956d9 | ||
|
|
b1147d77af | ||
|
|
66bc2fb41f | ||
|
|
4e538a6df8 | ||
|
|
ced087f8ae | ||
|
|
0f1eed0b1e | ||
|
|
95f15b77a3 | ||
|
|
f9ccfd5ca0 | ||
|
|
7207d7c847 | ||
|
|
00c4a524b7 | ||
|
|
9c3e0b5541 | ||
|
|
33bfe33eb3 | ||
|
|
3127c382a4 | ||
|
|
1748a390ec | ||
|
|
a7c0837049 | ||
|
|
44bf1eeae2 | ||
|
|
762b7a8ef1 | ||
|
|
102712a16e | ||
|
|
40810c59d7 | ||
|
|
35a10e86b5 | ||
|
|
c0c985494d | ||
|
|
8984ba7aef | ||
|
|
179869d481 | ||
|
|
5f29956f2b | ||
|
|
7e56c09620 | ||
|
|
dbc4ba84c2 | ||
|
|
9e4a527675 | ||
|
|
2e7f6afe3f | ||
|
|
45833542a7 | ||
|
|
1be6de30d7 | ||
|
|
981d78c8ba | ||
|
|
fbc7bedb6c | ||
|
|
9a4b1f0937 | ||
|
|
4786b0c5d4 | ||
|
|
17bed26096 | ||
|
|
511e16f1d3 | ||
|
|
18204bc1f7 | ||
|
|
e5e914903c | ||
|
|
7ba443afa5 | ||
|
|
b58d97fad3 | ||
|
|
d2a67a53b5 | ||
|
|
c0b556000c | ||
|
|
462c3b0696 | ||
|
|
d34ad73439 | ||
|
|
2c21712d58 | ||
|
|
2862db3534 | ||
|
|
bf3e30dac0 | ||
|
|
ce01e588c9 | ||
|
|
2a23082203 | ||
|
|
d373f924f6 | ||
|
|
eaf46ee006 | ||
|
|
d51355a0ad | ||
|
|
1e481a311a | ||
|
|
375660f232 | ||
|
|
46abb23ee8 | ||
|
|
8555bb697c | ||
|
|
f821893653 | ||
|
|
f6031baee4 | ||
|
|
75b3ea1f05 | ||
|
|
c818ba7bc7 | ||
|
|
74f0018962 | ||
|
|
3a0f07d36f | ||
|
|
8fb9e779a6 | ||
|
|
c5a794f1b5 | ||
|
|
3aa2cdd754 | ||
|
|
d93d52cf10 | ||
|
|
2abbd5a7fb | ||
|
|
2a10e9f7ee | ||
|
|
166d05afe9 | ||
|
|
2eff8d1962 | ||
|
|
93c9e76c4b | ||
|
|
021cb09b82 | ||
|
|
28e6939884 | ||
|
|
8847039d76 | ||
|
|
a047cf2e91 | ||
|
|
a8ae16e321 | ||
|
|
2694576a32 | ||
|
|
e4f10670f6 | ||
|
|
1324ba3a49 | ||
|
|
73c7810310 | ||
|
|
d160076267 | ||
|
|
a53be31765 | ||
|
|
ed8c1c7c19 | ||
|
|
159c8d1ff9 | ||
|
|
8932d455d8 | ||
|
|
3af183f6c3 | ||
|
|
4475be51cc | ||
|
|
c3ea3b751b | ||
|
|
e2c67d0c5b | ||
|
|
87731090ca | ||
|
|
80ca247435 | ||
|
|
a5b8d3afa5 | ||
|
|
1f615a06ad | ||
|
|
4123560a98 | ||
|
|
5267bd60a5 | ||
|
|
f76bffb482 | ||
|
|
51185c83c9 | ||
|
|
f1f887faae | ||
|
|
d53cbe7868 | ||
|
|
722746c78b | ||
|
|
46f0f3cee9 | ||
|
|
e1f5607836 | ||
|
|
ebc41b2eec | ||
|
|
7cd0d78424 | ||
|
|
d740559749 | ||
|
|
399357f752 | ||
|
|
3b4b474ce8 | ||
|
|
4534e46811 | ||
|
|
7bfa7b3f02 | ||
|
|
1cc34d8e62 | ||
|
|
2eff6b2e9d | ||
|
|
b046411302 | ||
|
|
6ab65b3626 | ||
|
|
cf321f9b09 | ||
|
|
8228d38859 | ||
|
|
c2e3110fa2 | ||
|
|
85681db7b7 | ||
|
|
1fc04c37d3 | ||
|
|
0fd8a122fb | ||
|
|
e3b6ede992 | ||
|
|
3601737869 | ||
|
|
9de6b4f151 | ||
|
|
4f4f55d67f | ||
|
|
714c624dc6 | ||
|
|
988a41f5e4 | ||
|
|
14946d9a1d | ||
|
|
94cced8323 | ||
|
|
9b8ed16e37 | ||
|
|
a5e44cd229 | ||
|
|
eccc208229 | ||
|
|
79cfabb45d | ||
|
|
af6e1e2b99 | ||
|
|
4ad51c1b24 | ||
|
|
1919580759 | ||
|
|
b27ffe57e6 | ||
|
|
c115bcde54 | ||
|
|
c44712167f | ||
|
|
1aabaff1f2 | ||
|
|
21c0383efb | ||
|
|
313f19eba4 | ||
|
|
c8591d7bca | ||
|
|
c6bcf53fea | ||
|
|
86812b34d1 | ||
|
|
27d1174dbb | ||
|
|
15f9c49418 | ||
|
|
6e18c92a13 | ||
|
|
c5e0df12ad | ||
|
|
7870c6c33f | ||
|
|
ebe018347b | ||
|
|
86fe6fe5ab | ||
|
|
9e828b1750 | ||
|
|
45adb9627a | ||
|
|
d56e168df9 | ||
|
|
940d3d4567 | ||
|
|
6bd7b2b8bb | ||
|
|
f2d6fd7b08 | ||
|
|
7219274d94 | ||
|
|
5dcc815240 | ||
|
|
b84c82880c | ||
|
|
fcc418b4a0 | ||
|
|
15c0bb4c9e | ||
|
|
8db4f914d8 | ||
|
|
f3f9211c9c | ||
|
|
ac160b6b41 | ||
|
|
51680b7077 | ||
|
|
acecdcc041 | ||
|
|
a2a69840f7 | ||
|
|
3a4a7590c2 | ||
|
|
5ced11999e | ||
|
|
bcc8b7ce3c | ||
|
|
4923708515 | ||
|
|
2cbbb829f7 | ||
|
|
1eacd3abe6 | ||
|
|
c5c2f84356 | ||
|
|
742e2f037b | ||
|
|
e3110d2f48 | ||
|
|
1c7fe6d134 | ||
|
|
29718b1c03 | ||
|
|
cd3b4d8dde | ||
|
|
5a3cddab0f | ||
|
|
15221005d1 | ||
|
|
da75abb223 | ||
|
|
8b32f80e27 | ||
|
|
ab9c2d81b0 | ||
|
|
c4039f52bd | ||
|
|
bd851d5e86 | ||
|
|
00e448c5d6 | ||
|
|
5ff8cdb13a | ||
|
|
44783574c0 | ||
|
|
1e7c53d944 | ||
|
|
655ae796fd | ||
|
|
93686dbc1e | ||
|
|
0356add7e0 | ||
|
|
9bea74fcef | ||
|
|
c08b10c20f | ||
|
|
16c0d9bb6c | ||
|
|
9f0d1616a8 | ||
|
|
fafab973ee | ||
|
|
4648ec04c7 | ||
|
|
64e4411048 | ||
|
|
4aeec8afbf | ||
|
|
f10432bf3f | ||
|
|
f0efed8aa1 | ||
|
|
4a4931bee2 | ||
|
|
afcf12ebc9 | ||
|
|
e901d3c9d6 | ||
|
|
fb25495f1b | ||
|
|
b6e6dbf27f | ||
|
|
bd5b97e69b | ||
|
|
1e5acd85ff | ||
|
|
6e1f6d886d | ||
|
|
940af67a87 | ||
|
|
c24fb73147 | ||
|
|
4e96c12634 | ||
|
|
37ef497f4c | ||
|
|
2e504f9c48 | ||
|
|
8f86d3417d | ||
|
|
92dfc54c4c | ||
|
|
3be3604125 | ||
|
|
6920deef63 | ||
|
|
6c30347219 | ||
|
|
d6b08b3c5c | ||
|
|
c93bcb8678 | ||
|
|
21ec923f24 | ||
|
|
3a0eab068c | ||
|
|
98b2da9123 | ||
|
|
8aa496f588 | ||
|
|
cd5f1a1b28 | ||
|
|
0e2e495d09 | ||
|
|
84c6c7e2a6 | ||
|
|
c8ebf9c75a | ||
|
|
29852ff0a5 | ||
|
|
f06ca62589 | ||
|
|
3f39a2be12 | ||
|
|
af7b9ee41c | ||
|
|
575190a96d | ||
|
|
78559d98eb | ||
|
|
398964c747 | ||
|
|
a634565296 | ||
|
|
a5ecbec9a6 | ||
|
|
fe79978f88 | ||
|
|
978ec8bc75 | ||
|
|
9e64cb574a | ||
|
|
783593a79d | ||
|
|
afed5e10fc | ||
|
|
a7c0789e36 | ||
|
|
b5b1a98bc4 | ||
|
|
91d3758691 | ||
|
|
c6030bbec8 | ||
|
|
cb62608dbd | ||
|
|
83fe793e72 | ||
|
|
9d36ec70bc | ||
|
|
6e77f5b068 | ||
|
|
c9dbb64269 | ||
|
|
546d32e3eb | ||
|
|
6b95cd05c8 | ||
|
|
804d87bca2 | ||
|
|
e518b57dea | ||
|
|
642587fc97 | ||
|
|
cd1a50a1d1 | ||
|
|
8881daf592 | ||
|
|
3ced895c9c | ||
|
|
75c1892611 | ||
|
|
9f0c4410f7 | ||
|
|
4976fccf7d | ||
|
|
ee2d3fd53a | ||
|
|
63baf3bd40 | ||
|
|
b37ad0e145 | ||
|
|
c255be8d09 | ||
|
|
616f6401b4 | ||
|
|
d047190453 | ||
|
|
17504b1b9c | ||
|
|
12a27dbcf7 | ||
|
|
547ce858e7 | ||
|
|
995b896b9d | ||
|
|
2d90b0c752 | ||
|
|
9d25b08641 | ||
|
|
5a0d3df689 | ||
|
|
004ec0da6d | ||
|
|
3da990ec77 | ||
|
|
ff6bdc1bed | ||
|
|
2891f2c068 | ||
|
|
9353053a23 | ||
|
|
de058e3b1d | ||
|
|
16fb9f59fe | ||
|
|
eb58e0ea63 | ||
|
|
6ba4b9e7bd | ||
|
|
26dd15ef83 | ||
|
|
46752420da | ||
|
|
49f6f27ffc | ||
|
|
3670674e6b | ||
|
|
3606000740 | ||
|
|
622e67e952 | ||
|
|
546d52149d | ||
|
|
825f257cf4 | ||
|
|
0489013ddd | ||
|
|
07760d55b7 | ||
|
|
2aca4ed67e | ||
|
|
21ae3cdd15 | ||
|
|
5b3bad17e2 | ||
|
|
79dc93664b | ||
|
|
c824ac2b72 | ||
|
|
c2c2b306a2 | ||
|
|
2b017139ef | ||
|
|
034559aac7 | ||
|
|
a6a18b7304 | ||
|
|
67d0b196b8 | ||
|
|
ade72bc949 | ||
|
|
88abdc49fe | ||
|
|
4365c8e95c | ||
|
|
d29321c1f2 | ||
|
|
f2b0d6243f | ||
|
|
cdbf8f64a2 | ||
|
|
605a5d27e7 | ||
|
|
935f3d54b3 | ||
|
|
7c1f040b7c | ||
|
|
c8613e8954 | ||
|
|
339f6280e1 | ||
|
|
437dc27586 | ||
|
|
62f33bba18 | ||
|
|
c2998154e0 | ||
|
|
871304c89b | ||
|
|
8155150e45 | ||
|
|
fcf9a92f11 | ||
|
|
73dc01dcee | ||
|
|
8c92b616bf | ||
|
|
f9a35d0cdc | ||
|
|
d6ce2b447f | ||
|
|
677f6f2cb4 | ||
|
|
26e4824d2a | ||
|
|
281746031c | ||
|
|
c4e6f5113b | ||
|
|
752f4a84e5 | ||
|
|
1fb18cc11c | ||
|
|
99d7061a4f | ||
|
|
fd3016122d | ||
|
|
d8fd585631 | ||
|
|
c9e64489b2 | ||
|
|
49b96e2ae7 | ||
|
|
9f0adee8b2 | ||
|
|
d1f44ef650 | ||
|
|
0ed78f7a62 | ||
|
|
000fbf6e98 | ||
|
|
d9fb8edaa9 | ||
|
|
dda61679bd | ||
|
|
cdfe43ce2c | ||
|
|
61f3a1805c | ||
|
|
3edca01dc9 | ||
|
|
d03a1a9a55 | ||
|
|
925d539174 | ||
|
|
973a0b2d47 | ||
|
|
ba30161559 | ||
|
|
89860e490e | ||
|
|
6dee1659bf | ||
|
|
ed38a4eb93 | ||
|
|
fe4a519d40 | ||
|
|
4c8da85050 | ||
|
|
2e1744c66b | ||
|
|
6010e9e4ff | ||
|
|
a3f053ed02 | ||
|
|
ec25efcb75 | ||
|
|
04eaf35567 | ||
|
|
59f24fb5b4 | ||
|
|
6ac10a8297 | ||
|
|
ea9a36a60a | ||
|
|
be285c85ec | ||
|
|
575f0ae334 | ||
|
|
3b032be694 | ||
|
|
6e5e708a36 | ||
|
|
8d2a3b7c9d | ||
|
|
d2eb10123b | ||
|
|
58d82df327 | ||
|
|
0d9cdd5039 | ||
|
|
7b6619b8de | ||
|
|
8262045b1e | ||
|
|
85e3d5a392 | ||
|
|
0b685b136f | ||
|
|
0695c11739 | ||
|
|
7a4297c4f1 | ||
|
|
92b144d7f5 | ||
|
|
a9901e0495 | ||
|
|
a84d23f69f | ||
|
|
167e6a0d11 | ||
|
|
7bbe56d20a | ||
|
|
1afab9248b | ||
|
|
8914e2caf4 | ||
|
|
5904ac80db | ||
|
|
9576a9a55e | ||
|
|
617ff706bc | ||
|
|
a6e7565919 | ||
|
|
b712325399 | ||
|
|
0fb6ab9ebd | ||
|
|
830e9dd6f9 | ||
|
|
567624c323 | ||
|
|
3ed6f9fad0 | ||
|
|
6452733c4e | ||
|
|
271d6b5d8d | ||
|
|
93ff64f130 | ||
|
|
2c9e5df27d | ||
|
|
6db37d35ed | ||
|
|
ceee4fe5cf | ||
|
|
130b4a57de | ||
|
|
fee22f83c9 | ||
|
|
1cee27e830 | ||
|
|
ba2ff053f9 | ||
|
|
8ed2d12da1 | ||
|
|
227665439f | ||
|
|
1a2e043ec2 | ||
|
|
5ec9ac1fba | ||
|
|
c166615ec8 | ||
|
|
b5a366ef5e | ||
|
|
cdcac262a3 | ||
|
|
89500df0ac | ||
|
|
cb4e80f1bc | ||
|
|
fbc1906fa2 | ||
|
|
a16c099f02 | ||
|
|
a6e1898e1b | ||
|
|
2e3e0f4ce9 | ||
|
|
d5e788ed6b | ||
|
|
78bb9315b7 | ||
|
|
9eb3e1329f | ||
|
|
255fb07615 | ||
|
|
a51eb7c7a0 | ||
|
|
938a99a59c | ||
|
|
95b61e9972 | ||
|
|
84e24ede04 | ||
|
|
7438fedd6b | ||
|
|
4448296e7b | ||
|
|
b5c8741803 | ||
|
|
e72ecfcb0a | ||
|
|
641e75bfd4 | ||
|
|
954d754c09 | ||
|
|
1159da111a | ||
|
|
b71f67f7df | ||
|
|
70cbda27eb | ||
|
|
99790551f9 | ||
|
|
bc88379139 | ||
|
|
734ef3c713 | ||
|
|
02de9a03ca | ||
|
|
e0ca2f5725 | ||
|
|
f320d6e002 | ||
|
|
eef0ee5f5c | ||
|
|
da5e8a3d59 | ||
|
|
f9898607ce | ||
|
|
c780d4be14 | ||
|
|
e60bc37fbf | ||
|
|
9a0c403c51 | ||
|
|
e187c01dc9 | ||
|
|
dec9fca8c2 | ||
|
|
7a5792ba01 | ||
|
|
ada63d9f5c | ||
|
|
8f114b0dfa | ||
|
|
2ba8bb58e0 | ||
|
|
7042f4fc1b | ||
|
|
9427584825 | ||
|
|
592c2ac217 | ||
|
|
dd7abc0d27 | ||
|
|
fe4a53563e | ||
|
|
ab02f610e5 | ||
|
|
0a73b18823 | ||
|
|
1ebab759b1 | ||
|
|
2f13cb4cbc | ||
|
|
f5e71f56e9 | ||
|
|
042a34d22f | ||
|
|
49fa6906ac | ||
|
|
8956939ae7 | ||
|
|
171aad78da | ||
|
|
e4f8ddca9a | ||
|
|
213e89c627 | ||
|
|
7741cffa03 | ||
|
|
65a86a9261 | ||
|
|
0433e17b34 | ||
|
|
df7298ee8c | ||
|
|
c01bddf5be | ||
|
|
bca4b22453 | ||
|
|
c4addc7e54 | ||
|
|
6a0cbd7d8e | ||
|
|
38253fa49a | ||
|
|
c1fba39496 | ||
|
|
d3b093c09d | ||
|
|
1ce5020c73 | ||
|
|
89b7d94bd1 | ||
|
|
4b7908f4fc | ||
|
|
6a78ed7c8a | ||
|
|
a006bc94b9 | ||
|
|
59dad5f3fc | ||
|
|
4e1ff99cc5 | ||
|
|
d477520b67 | ||
|
|
17318d8205 | ||
|
|
7eb0d84947 | ||
|
|
ea944d0ee2 | ||
|
|
18d4a5e865 | ||
|
|
84aefc9a79 | ||
|
|
a52e61fb87 | ||
|
|
5b05009123 | ||
|
|
6dbc411adb | ||
|
|
63f965fafe | ||
|
|
a10e07bd60 | ||
|
|
5422b32ad6 | ||
|
|
d957e27501 | ||
|
|
ffa81e66e7 | ||
|
|
b00049e94e | ||
|
|
f2390412d2 | ||
|
|
7a0746cf4e | ||
|
|
dd3ba59c1d | ||
|
|
617bb43274 | ||
|
|
d56cbed0bf | ||
|
|
eb7374cedc | ||
|
|
f6ca6a547f | ||
|
|
9722601bae | ||
|
|
2a12be310d | ||
|
|
d9f03a7e94 | ||
|
|
37cf9f208e | ||
|
|
24ace52e27 | ||
|
|
177d514d13 | ||
|
|
539821454a | ||
|
|
7d28717030 | ||
|
|
c54e471dc9 | ||
|
|
81508a25a8 | ||
|
|
5cb7962593 | ||
|
|
c5dd09cf50 | ||
|
|
009ceefa30 | ||
|
|
7167c2002f | ||
|
|
e05f33b286 | ||
|
|
50480dc506 | ||
|
|
a4af0f7432 | ||
|
|
7871663cae | ||
|
|
4ac010eda7 | ||
|
|
7165d53982 | ||
|
|
a1e8d858a2 | ||
|
|
2360ef64de | ||
|
|
e6a1643bea | ||
|
|
350fe04495 | ||
|
|
25ce86ae93 | ||
|
|
99b4a17f43 | ||
|
|
621bddd270 | ||
|
|
28eeb0c6f7 | ||
|
|
b3f8de3062 | ||
|
|
28eccd6ce9 | ||
|
|
7632b03d50 | ||
|
|
6e075d3fd8 | ||
|
|
89ac968b9f | ||
|
|
701d506c1f | ||
|
|
569a211810 | ||
|
|
75d5121234 | ||
|
|
04d79ac70f | ||
|
|
fc56e1b624 | ||
|
|
030a141c64 | ||
|
|
72c27273e4 | ||
|
|
bcb3d587a1 | ||
|
|
5fe8043ff8 | ||
|
|
c52b360068 | ||
|
|
cd76ccadc5 | ||
|
|
ba2220d7c8 | ||
|
|
957f8f83ff | ||
|
|
55e97e5588 | ||
|
|
18f0b86ce2 | ||
|
|
1083698a1f | ||
|
|
5040c603ff | ||
|
|
7a1131d8af | ||
|
|
1a3b85c2fc | ||
|
|
9409ec0b6c | ||
|
|
020d7445ec | ||
|
|
26947d85ae | ||
|
|
477404554e | ||
|
|
ef36123ebf | ||
|
|
d6b1c2effb | ||
|
|
eb51e04a18 | ||
|
|
fa6ee2ba2b | ||
|
|
c38f3b1691 | ||
|
|
6cc54a2576 | ||
|
|
070d9036b7 | ||
|
|
499d549e41 | ||
|
|
eabaae4a8f | ||
|
|
a716c607d7 | ||
|
|
3183f39535 | ||
|
|
6783375a14 | ||
|
|
2f825b02bf | ||
|
|
a940717ed0 | ||
|
|
9572924e64 | ||
|
|
9d0622b6cc | ||
|
|
492401f9b7 | ||
|
|
35a06c3cbe | ||
|
|
190155f438 | ||
|
|
f59f508c4d | ||
|
|
79d035ac02 | ||
|
|
a0f19ace92 | ||
|
|
85c7e531e4 | ||
|
|
ed62e92da6 | ||
|
|
888312fd7a | ||
|
|
8d3ec8c047 | ||
|
|
b2a1f6bc13 | ||
|
|
962b74a68a | ||
|
|
2fadf88a93 | ||
|
|
411525687e | ||
|
|
26fbb95454 | ||
|
|
5a5c5e5bf4 | ||
|
|
049642ae48 | ||
|
|
0300abc454 | ||
|
|
ebdf4e4c5e | ||
|
|
71c5b54532 | ||
|
|
e1e77f70f9 | ||
|
|
d4a87187cb | ||
|
|
05ec76f940 | ||
|
|
679e518574 | ||
|
|
29ccf956ec | ||
|
|
35db38c2de | ||
|
|
3400bea9ef | ||
|
|
43dd31d0c9 | ||
|
|
9587dfd905 | ||
|
|
af038bc63e | ||
|
|
3d4c807a87 | ||
|
|
1b9271f652 | ||
|
|
3e71e4d15e | ||
|
|
78207aca34 | ||
|
|
ab0e465760 | ||
|
|
c8b6e22143 | ||
|
|
a07e151c95 | ||
|
|
1fc81d1347 | ||
|
|
e8a5cfe7e3 | ||
|
|
d299c39c55 | ||
|
|
dd5e21a653 | ||
|
|
8625c0f266 | ||
|
|
50bdf2cc75 | ||
|
|
393fbee551 | ||
|
|
fc4cf418e0 | ||
|
|
63f0fa5da2 | ||
|
|
4685fd14ad | ||
|
|
5957eb9c1a | ||
|
|
1f6835a8e0 | ||
|
|
b56994b999 | ||
|
|
eaf2437633 | ||
|
|
fc831e04c1 | ||
|
|
bf6ede64bd | ||
|
|
55dac533d9 | ||
|
|
373b91143d | ||
|
|
2e00eec704 | ||
|
|
e1bc6b8597 | ||
|
|
f31341151f | ||
|
|
3fe2ef6611 | ||
|
|
3a3cd59d8e | ||
|
|
a66fb9eade | ||
|
|
c0b29dd938 | ||
|
|
6babd0b531 | ||
|
|
4e3b8870c5 | ||
|
|
4674d4d291 | ||
|
|
57da24220f | ||
|
|
a8c5368d49 | ||
|
|
b731389c81 | ||
|
|
aca5be2d1b | ||
|
|
c6cd2e5839 | ||
|
|
9dd3fc8d08 | ||
|
|
5f0e1694ce | ||
|
|
351be8aaf3 | ||
|
|
02c8fd0e3f | ||
|
|
5f6ae3a0ef | ||
|
|
742d54342b | ||
|
|
ce2346f709 | ||
|
|
b6cfd55aad | ||
|
|
21a33b84e5 | ||
|
|
24d68de98c | ||
|
|
3560038894 | ||
|
|
184150810b | ||
|
|
964f1c4fae | ||
|
|
4c706048de | ||
|
|
e08e761319 | ||
|
|
8081c15d11 | ||
|
|
d476d92b7d | ||
|
|
f5a057ddc5 | ||
|
|
f5afe36d60 | ||
|
|
f529525fbd | ||
|
|
8985d13635 | ||
|
|
4f9b090b34 | ||
|
|
9935459b32 | ||
|
|
b96a63fb22 | ||
|
|
8d6e773a10 | ||
|
|
68683e5d01 | ||
|
|
7f05a9c5c3 | ||
|
|
6b0ee1b74a | ||
|
|
7ce21afbb3 | ||
|
|
c78dc1fd47 | ||
|
|
306efb50ce | ||
|
|
07bcb54ed3 | ||
|
|
0475d80472 | ||
|
|
e6c35e5f5a | ||
|
|
73dad6f017 | ||
|
|
fd8466e002 | ||
|
|
76f760644a | ||
|
|
eabe12eebe | ||
|
|
3e6adb43a0 | ||
|
|
61926c29e5 | ||
|
|
a09d3be310 | ||
|
|
72afe68de9 | ||
|
|
37f72f919f | ||
|
|
3aa52cc676 | ||
|
|
93665180c8 | ||
|
|
ca029730a1 | ||
|
|
775d36b16b | ||
|
|
4e73820271 | ||
|
|
ca8d5f5cc3 | ||
|
|
61e6cc9e42 | ||
|
|
b0c58ec313 | ||
|
|
909c536b47 | ||
|
|
67d6286274 | ||
|
|
7377abe884 | ||
|
|
9108b713de | ||
|
|
82b9925448 | ||
|
|
8ea243c572 | ||
|
|
d1757796ad | ||
|
|
0e5397bcf4 | ||
|
|
eabf6bb1a9 | ||
|
|
724eb4f801 | ||
|
|
571baa19a5 | ||
|
|
ebd2abbfa0 | ||
|
|
a404f06366 | ||
|
|
262952c022 | ||
|
|
c2eff4f359 | ||
|
|
023cf5aa27 | ||
|
|
0078028992 | ||
|
|
9bedcadca4 | ||
|
|
1383f4abcf | ||
|
|
b09df4d009 | ||
|
|
f1a1d4afff | ||
|
|
0386d57f05 | ||
|
|
c9b02d0c83 | ||
|
|
8e893662f3 | ||
|
|
f93890f9aa | ||
|
|
8222a630e5 | ||
|
|
733b349df1 | ||
|
|
0b3fe0e799 | ||
|
|
7f823ee72e | ||
|
|
0fce86f76b | ||
|
|
624b79aa11 | ||
|
|
06e5f4f8ff | ||
|
|
e11c1bb233 | ||
|
|
7dd4db52df | ||
|
|
00456d5ed0 | ||
|
|
bf5ae25c9c | ||
|
|
38dd19b08a | ||
|
|
30a1f2afe9 | ||
|
|
89ae99078a | ||
|
|
ea4fc49c5a | ||
|
|
0a5beeb053 | ||
|
|
66f395a314 | ||
|
|
52bc67d91d | ||
|
|
a0a3997af2 | ||
|
|
00e061023e | ||
|
|
44aac44a05 | ||
|
|
d591e27c9f | ||
|
|
6887d7dc74 | ||
|
|
433d7b0c49 | ||
|
|
9f647e8357 | ||
|
|
dea328e42a | ||
|
|
c00e164567 | ||
|
|
88ec0f72de | ||
|
|
99c501f188 | ||
|
|
b8bb14966d | ||
|
|
ad2f52c037 | ||
|
|
8408f7d9b7 | ||
|
|
d5a7afb750 | ||
|
|
29ffd0d810 | ||
|
|
c9d89b94b3 | ||
|
|
bfed5404b4 | ||
|
|
ea411c13af | ||
|
|
e608d8f9d0 | ||
|
|
5edaeaf620 | ||
|
|
2b30a69b94 | ||
|
|
c5b15b7351 | ||
|
|
07668ee4c5 | ||
|
|
a1def533af | ||
|
|
578957f389 | ||
|
|
34dde34e61 | ||
|
|
00b2539e4f | ||
|
|
4c1ea155b0 | ||
|
|
35e84bb872 | ||
|
|
6879326429 | ||
|
|
42610f9cb0 | ||
|
|
aa9a82382e | ||
|
|
d1b51b9653 | ||
|
|
c4a5fff954 | ||
|
|
72acea990a | ||
|
|
480f721888 | ||
|
|
26263bdcf0 | ||
|
|
7d40d06b69 | ||
|
|
c6f588fc8c | ||
|
|
5736a70ccb | ||
|
|
d733047f87 | ||
|
|
cd325fe198 | ||
|
|
9f7bafe7fb | ||
|
|
773e785ce9 | ||
|
|
f78fc241a8 | ||
|
|
54ff151ed8 | ||
|
|
5a6f9dfc11 | ||
|
|
b1e69e154b | ||
|
|
497db0bea9 | ||
|
|
ad0a7ebcb9 | ||
|
|
1d97660a20 | ||
|
|
281aec23e3 | ||
|
|
aab8043c8d | ||
|
|
ad2f47029d | ||
|
|
9964f44fc8 | ||
|
|
e1bccff79b | ||
|
|
7b0ed80377 | ||
|
|
6d91f84e33 | ||
|
|
54c2c8f74f | ||
|
|
eaa47ad9f1 | ||
|
|
1027213fc7 | ||
|
|
62071ff96f | ||
|
|
bcec0ae401 | ||
|
|
7da3c5a8e8 | ||
|
|
6cd436d9b8 | ||
|
|
50a244af4d | ||
|
|
d53663fc78 | ||
|
|
3ab68e126f | ||
|
|
83a831d27a | ||
|
|
fa53ab1d6a | ||
|
|
46fe6fff89 | ||
|
|
29cc708f4f | ||
|
|
03e328ae65 | ||
|
|
c4423d609b | ||
|
|
7d56c14e42 | ||
|
|
e92a2f814b | ||
|
|
0a9c01cf33 | ||
|
|
9799ba510a | ||
|
|
44ceee3f42 | ||
|
|
ac2173743b | ||
|
|
1bc06e8204 | ||
|
|
36cab874fa | ||
|
|
d27ed6c419 | ||
|
|
af2b8531e9 | ||
|
|
9b2f603454 | ||
|
|
fb01d185e4 | ||
|
|
437b26ecfc | ||
|
|
1eb36f1aef | ||
|
|
694a0eb4e3 | ||
|
|
9cf22bfae2 | ||
|
|
db3d3dee85 |
4
.gitignore
vendored
4
.gitignore
vendored
@@ -21,6 +21,7 @@ examples/
|
||||
|
||||
# Temporary outputs
|
||||
.DS_Store
|
||||
.hypothesis/
|
||||
time.log
|
||||
celerybeat-schedule.db
|
||||
search_results.json
|
||||
@@ -28,6 +29,7 @@ search_results.json
|
||||
api/migrations/versions
|
||||
tmp
|
||||
files
|
||||
powers/
|
||||
|
||||
# Exclude dep files
|
||||
huggingface.co/
|
||||
@@ -35,3 +37,5 @@ nltk_data/
|
||||
tika-server*.jar*
|
||||
cl100k_base.tiktoken
|
||||
libssl*.deb
|
||||
|
||||
sandbox/lib/seccomp_redbear/target
|
||||
|
||||
18
README.md
18
README.md
@@ -226,8 +226,8 @@ REDIS_PORT=6379
|
||||
REDIS_DB=1
|
||||
|
||||
# Celery (Using Redis as broker)
|
||||
BROKER_URL=redis://127.0.0.1:6379/0
|
||||
RESULT_BACKEND=redis://127.0.0.1:6379/0
|
||||
REDIS_DB_CELERY_BROKER=1
|
||||
REDIS_DB_CELERY_BACKEND=2
|
||||
|
||||
# JWT Secret Key (Formation method: openssl rand -hex 32)
|
||||
SECRET_KEY=your-secret-key-here
|
||||
@@ -334,7 +334,13 @@ step6: Log In to the Frontend Interface.
|
||||
## License
|
||||
This project is licensed under the Apache License 2.0. For details, see the LICENSE file.
|
||||
|
||||
## Acknowledgements & Community
|
||||
- Feedback & Issues: Please submit an Issue in the repository for bug reports or discussions.
|
||||
- Contributions Welcome: When submitting a Pull Request, please create a feature branch and follow conventional commit message guidelines.
|
||||
- Contact: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com
|
||||
## Community & Support
|
||||
|
||||
Join our community to ask questions, share your work, and connect with fellow developers.
|
||||
|
||||
- **GitHub Issues**: Report bugs, request features, or track known issues via [GitHub Issues](https://github.com/SuanmoSuanyangTechnology/MemoryBear/issues).
|
||||
- **GitHub Pull Requests**: Contribute code improvements or fixes through [Pull Requests](https://github.com/SuanmoSuanyangTechnology/MemoryBear/pulls).
|
||||
- **GitHub Discussions**: Ask questions, share ideas, and engage with the community in [GitHub Discussions](https://github.com/SuanmoSuanyangTechnology/MemoryBear/discussions).
|
||||
- **WeChat**: Scan the QR code below to join our WeChat community group.
|
||||
- 
|
||||
- **Contact**: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com
|
||||
|
||||
@@ -201,8 +201,8 @@ REDIS_PORT=6379
|
||||
REDIS_DB=1
|
||||
|
||||
# Celery (使用Redis作为broker)
|
||||
BROKER_URL=redis://127.0.0.1:6379/0
|
||||
RESULT_BACKEND=redis://127.0.0.1:6379/0
|
||||
REDIS_DB_CELERY_BROKER=1
|
||||
REDIS_DB_CELERY_BACKEND=2
|
||||
|
||||
# JWT密钥 (生成方式: openssl rand -hex 32)
|
||||
SECRET_KEY=your-secret-key-here
|
||||
|
||||
@@ -10,7 +10,6 @@ from app.core.config import settings
|
||||
# 设置日志记录器
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 创建连接池
|
||||
pool = ConnectionPool.from_url(
|
||||
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}",
|
||||
@@ -21,6 +20,7 @@ pool = ConnectionPool.from_url(
|
||||
)
|
||||
aio_redis = redis.StrictRedis(connection_pool=pool)
|
||||
|
||||
|
||||
async def get_redis_connection():
|
||||
"""获取Redis连接"""
|
||||
try:
|
||||
@@ -29,7 +29,8 @@ async def get_redis_connection():
|
||||
logger.error(f"Redis连接失败: {str(e)}")
|
||||
return None
|
||||
|
||||
async def aio_redis_set(key: str, val: str|dict, expire: int = None):
|
||||
|
||||
async def aio_redis_set(key: str, val: str | dict, expire: int = None):
|
||||
"""设置Redis键值
|
||||
|
||||
Args:
|
||||
@@ -40,7 +41,7 @@ async def aio_redis_set(key: str, val: str|dict, expire: int = None):
|
||||
try:
|
||||
if isinstance(val, dict):
|
||||
val = json.dumps(val, ensure_ascii=False)
|
||||
|
||||
|
||||
if expire is not None:
|
||||
# 设置带过期时间的键值
|
||||
await aio_redis.set(key, val, ex=expire)
|
||||
@@ -50,6 +51,7 @@ async def aio_redis_set(key: str, val: str|dict, expire: int = None):
|
||||
except Exception as e:
|
||||
logger.error(f"Redis set错误: {str(e)}")
|
||||
|
||||
|
||||
async def aio_redis_get(key: str):
|
||||
"""获取Redis键值"""
|
||||
try:
|
||||
@@ -58,6 +60,7 @@ async def aio_redis_get(key: str):
|
||||
logger.error(f"Redis get错误: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def aio_redis_delete(key: str):
|
||||
"""删除Redis键"""
|
||||
try:
|
||||
@@ -66,6 +69,7 @@ async def aio_redis_delete(key: str):
|
||||
logger.error(f"Redis delete错误: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
|
||||
"""发布消息到Redis频道"""
|
||||
try:
|
||||
@@ -78,9 +82,10 @@ async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
|
||||
logger.error(f"Redis发布错误: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
class RedisSubscriber:
|
||||
"""Redis订阅器"""
|
||||
|
||||
|
||||
def __init__(self, channel: str):
|
||||
self.channel = channel
|
||||
self.conn = None
|
||||
@@ -88,25 +93,25 @@ class RedisSubscriber:
|
||||
self.is_closed = False
|
||||
self._queue = asyncio.Queue()
|
||||
self._task = None
|
||||
|
||||
|
||||
async def start(self):
|
||||
"""开始订阅"""
|
||||
if self.is_closed or self._task:
|
||||
return
|
||||
|
||||
|
||||
self._task = asyncio.create_task(self._receive_messages())
|
||||
logger.info(f"开始订阅: {self.channel}")
|
||||
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""接收消息"""
|
||||
try:
|
||||
self.conn = await get_redis_connection()
|
||||
if not self.conn:
|
||||
return
|
||||
|
||||
|
||||
self.pubsub = self.conn.pubsub()
|
||||
await self.pubsub.subscribe(self.channel)
|
||||
|
||||
|
||||
while not self.is_closed:
|
||||
try:
|
||||
message = await self.pubsub.get_message(ignore_subscribe_messages=True, timeout=0.01)
|
||||
@@ -127,7 +132,7 @@ class RedisSubscriber:
|
||||
finally:
|
||||
await self._queue.put(None)
|
||||
await self._cleanup()
|
||||
|
||||
|
||||
async def _cleanup(self):
|
||||
"""清理资源"""
|
||||
if self.pubsub:
|
||||
@@ -141,7 +146,7 @@ class RedisSubscriber:
|
||||
await self.conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def get_message(self) -> Optional[Dict[str, Any]]:
|
||||
"""获取消息"""
|
||||
if self.is_closed:
|
||||
@@ -153,7 +158,7 @@ class RedisSubscriber:
|
||||
except Exception as e:
|
||||
logger.error(f"获取消息错误: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def close(self):
|
||||
"""关闭订阅器"""
|
||||
if self.is_closed:
|
||||
@@ -163,32 +168,33 @@ class RedisSubscriber:
|
||||
self._task.cancel()
|
||||
await self._cleanup()
|
||||
|
||||
|
||||
class RedisPubSubManager:
|
||||
"""Redis发布订阅管理器"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.subscribers = {}
|
||||
|
||||
|
||||
async def publish(self, channel: str, message: Dict[str, Any]) -> bool:
|
||||
return await aio_redis_publish(channel, message)
|
||||
|
||||
|
||||
def get_subscriber(self, channel: str) -> RedisSubscriber:
|
||||
if channel in self.subscribers:
|
||||
subscriber = self.subscribers[channel]
|
||||
if not subscriber.is_closed:
|
||||
return subscriber
|
||||
|
||||
|
||||
subscriber = RedisSubscriber(channel)
|
||||
self.subscribers[channel] = subscriber
|
||||
return subscriber
|
||||
|
||||
|
||||
def cancel_subscription(self, channel: str) -> bool:
|
||||
if channel in self.subscribers:
|
||||
asyncio.create_task(self.subscribers[channel].close())
|
||||
del self.subscribers[channel]
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def cancel_all_subscriptions(self) -> int:
|
||||
count = len(self.subscribers)
|
||||
for subscriber in self.subscribers.values():
|
||||
@@ -196,6 +202,6 @@ class RedisPubSubManager:
|
||||
self.subscribers.clear()
|
||||
return count
|
||||
|
||||
|
||||
# 全局实例
|
||||
pubsub_manager = RedisPubSubManager()
|
||||
|
||||
|
||||
25
api/app/base/type.py
Normal file
25
api/app/base/type.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import TypeDecorator, JSON
|
||||
|
||||
|
||||
class PydanticType(TypeDecorator):
|
||||
impl = JSON
|
||||
|
||||
def __init__(self, pydantic_model: type[BaseModel]):
|
||||
super().__init__()
|
||||
self.model = pydantic_model
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
# 入库:Model -> dict
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, self.model):
|
||||
return value.dict()
|
||||
return value # 已经是 dict 也放行
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
# 出库:dict -> Model
|
||||
if value is None:
|
||||
return None
|
||||
# return self.model.parse_obj(value) # pydantic v1
|
||||
return self.model.model_validate(value) # pydantic v2
|
||||
10
api/app/cache/__init__.py
vendored
Normal file
10
api/app/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
Cache 缓存模块
|
||||
|
||||
提供各种缓存功能的统一入口
|
||||
"""
|
||||
from .memory import InterestMemoryCache
|
||||
|
||||
__all__ = [
|
||||
"InterestMemoryCache",
|
||||
]
|
||||
12
api/app/cache/memory/__init__.py
vendored
Normal file
12
api/app/cache/memory/__init__.py
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Memory 缓存模块
|
||||
|
||||
提供记忆系统相关的缓存功能
|
||||
"""
|
||||
from .interest_memory import InterestMemoryCache
|
||||
from .activity_stats_cache import ActivityStatsCache
|
||||
|
||||
__all__ = [
|
||||
"InterestMemoryCache",
|
||||
"ActivityStatsCache",
|
||||
]
|
||||
124
api/app/cache/memory/activity_stats_cache.py
vendored
Normal file
124
api/app/cache/memory/activity_stats_cache.py
vendored
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Recent Activity Stats Cache
|
||||
|
||||
记忆提取活动统计缓存模块
|
||||
用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储,24小时后释放
|
||||
查询命令:cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 缓存过期时间:24小时
|
||||
ACTIVITY_STATS_CACHE_EXPIRE = 86400
|
||||
|
||||
|
||||
class ActivityStatsCache:
|
||||
"""记忆提取活动统计缓存类"""
|
||||
|
||||
PREFIX = "cache:memory:activity_stats"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, workspace_id: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return f"{cls.PREFIX}:by_workspace:{workspace_id}"
|
||||
|
||||
@classmethod
|
||||
async def set_activity_stats(
|
||||
cls,
|
||||
workspace_id: str,
|
||||
stats: Dict[str, Any],
|
||||
expire: int = ACTIVITY_STATS_CACHE_EXPIRE,
|
||||
) -> bool:
|
||||
"""设置记忆提取活动统计缓存
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
stats: 统计数据,格式:
|
||||
{
|
||||
"chunk_count": int,
|
||||
"statements_count": int,
|
||||
"triplet_entities_count": int,
|
||||
"triplet_relations_count": int,
|
||||
"temporal_count": int,
|
||||
}
|
||||
expire: 过期时间(秒),默认24小时
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(workspace_id)
|
||||
payload = {
|
||||
"stats": stats,
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"workspace_id": workspace_id,
|
||||
"cached": True,
|
||||
}
|
||||
value = json.dumps(payload, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置活动统计缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_activity_stats(
|
||||
cls,
|
||||
workspace_id: str,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""获取记忆提取活动统计缓存
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
统计数据字典,缓存不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(workspace_id)
|
||||
value = await aio_redis.get(key)
|
||||
if value:
|
||||
payload = json.loads(value)
|
||||
logger.info(f"命中活动统计缓存: {key}")
|
||||
return payload
|
||||
logger.info(f"活动统计缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取活动统计缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_activity_stats(
|
||||
cls,
|
||||
workspace_id: str,
|
||||
) -> bool:
|
||||
"""删除记忆提取活动统计缓存
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(workspace_id)
|
||||
result = await aio_redis.delete(key)
|
||||
logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除活动统计缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
122
api/app/cache/memory/interest_memory.py
vendored
Normal file
122
api/app/cache/memory/interest_memory.py
vendored
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
Interest Distribution Cache
|
||||
|
||||
兴趣分布缓存模块
|
||||
用于缓存用户的兴趣分布标签数据,避免重复调用模型生成
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 缓存过期时间:24小时
|
||||
INTEREST_CACHE_EXPIRE = 86400
|
||||
|
||||
|
||||
class InterestMemoryCache:
|
||||
"""兴趣分布缓存类"""
|
||||
|
||||
PREFIX = "cache:memory:interest_distribution"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, end_user_id: str, language: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return f"{cls.PREFIX}:by_user:{end_user_id}:{language}"
|
||||
|
||||
@classmethod
|
||||
async def set_interest_distribution(
|
||||
cls,
|
||||
end_user_id: str,
|
||||
language: str,
|
||||
data: List[Dict[str, Any]],
|
||||
expire: int = INTEREST_CACHE_EXPIRE,
|
||||
) -> bool:
|
||||
"""设置用户兴趣分布缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
data: 兴趣分布列表,格式 [{"name": "...", "frequency": ...}, ...]
|
||||
expire: 过期时间(秒),默认24小时
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
payload = {
|
||||
"data": data,
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"cached": True,
|
||||
}
|
||||
value = json.dumps(payload, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置兴趣分布缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_interest_distribution(
|
||||
cls,
|
||||
end_user_id: str,
|
||||
language: str,
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""获取用户兴趣分布缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
|
||||
Returns:
|
||||
兴趣分布列表,缓存不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
value = await aio_redis.get(key)
|
||||
if value:
|
||||
payload = json.loads(value)
|
||||
logger.info(f"命中兴趣分布缓存: {key}")
|
||||
return payload.get("data")
|
||||
logger.info(f"兴趣分布缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取兴趣分布缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_interest_distribution(
|
||||
cls,
|
||||
end_user_id: str,
|
||||
language: str,
|
||||
) -> bool:
|
||||
"""删除用户兴趣分布缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
result = await aio_redis.delete(key)
|
||||
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除兴趣分布缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
@@ -1,40 +1,60 @@
|
||||
import os
|
||||
import platform
|
||||
from datetime import timedelta
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.core.config import settings
|
||||
from celery import Celery
|
||||
from celery.schedules import crontab
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# macOS fork() safety - must be set before any Celery initialization
|
||||
if platform.system() == 'Darwin':
|
||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||
|
||||
# 创建 Celery 应用实例
|
||||
# broker: 任务队列(使用 Redis DB 0)
|
||||
# backend: 结果存储(使用 Redis DB 10)
|
||||
# broker: 任务队列(使用 Redis DB,由 CELERY_BROKER_DB 指定)
|
||||
# backend: 结果存储(使用 Redis DB,由 CELERY_BACKEND_DB 指定)
|
||||
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND,
|
||||
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
|
||||
|
||||
# Build canonical broker/backend URLs and force them into os.environ so that
|
||||
# Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first)
|
||||
# cannot be overridden by stray env vars.
|
||||
# See: https://github.com/celery/celery/issues/4284
|
||||
_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
|
||||
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
|
||||
os.environ["CELERY_BROKER_URL"] = _broker_url
|
||||
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
|
||||
# Neutralize legacy Celery env vars that can be hijacked by Celery's CLI/Click
|
||||
# integration and accidentally override our canonical URLs.
|
||||
os.environ.pop("BROKER_URL", None)
|
||||
os.environ.pop("RESULT_BACKEND", None)
|
||||
os.environ.pop("CELERY_BROKER", None)
|
||||
os.environ.pop("CELERY_BACKEND", None)
|
||||
|
||||
celery_app = Celery(
|
||||
"redbear_tasks",
|
||||
broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}",
|
||||
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
|
||||
broker=_broker_url,
|
||||
backend=_backend_url,
|
||||
)
|
||||
|
||||
# 配置使用本地队列,避免与远程 worker 冲突
|
||||
celery_app.conf.task_default_queue = 'localhost_test_wyl'
|
||||
celery_app.conf.task_default_exchange = 'localhost_test_wyl'
|
||||
celery_app.conf.task_default_routing_key = 'localhost_test_wyl'
|
||||
logger.info(
|
||||
"Celery app initialized",
|
||||
extra={
|
||||
"broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"),
|
||||
"backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"),
|
||||
},
|
||||
)
|
||||
# Default queue for unrouted tasks
|
||||
celery_app.conf.task_default_queue = 'memory_tasks'
|
||||
|
||||
# macOS 兼容性配置
|
||||
import platform
|
||||
|
||||
if platform.system() == 'Darwin': # macOS
|
||||
# 设置环境变量解决 fork 问题
|
||||
if platform.system() == 'Darwin':
|
||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||
|
||||
# 使用 solo 池避免多进程问题
|
||||
celery_app.conf.worker_pool = 'solo'
|
||||
|
||||
# 设置唯一的节点名称
|
||||
import socket
|
||||
import time
|
||||
hostname = socket.gethostname()
|
||||
timestamp = int(time.time())
|
||||
celery_app.conf.worker_name = f"celery@{hostname}-{timestamp}"
|
||||
|
||||
# Celery 配置
|
||||
celery_app.conf.update(
|
||||
@@ -52,47 +72,65 @@ celery_app.conf.update(
|
||||
task_ignore_result=False,
|
||||
|
||||
# 超时设置
|
||||
task_time_limit=30 * 60, # 30 分钟硬超时
|
||||
task_soft_time_limit=25 * 60, # 25 分钟软超时
|
||||
task_time_limit=3600, # 60分钟硬超时
|
||||
task_soft_time_limit=3000, # 50分钟软超时
|
||||
|
||||
# Worker 设置 - 针对 macOS 优化
|
||||
worker_prefetch_multiplier=1, # 减少预取任务数,避免内存堆积
|
||||
worker_max_tasks_per_child=10, # 大幅减少每个 worker 执行的任务数,频繁重启防止内存泄漏
|
||||
worker_max_memory_per_child=200000, # 200MB 内存限制,超过后重启 worker
|
||||
# Worker 设置 (per-worker settings are in docker-compose command line)
|
||||
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
||||
|
||||
# 结果过期时间
|
||||
result_expires=3600, # 结果保存 1 小时
|
||||
result_expires=3600, # 结果保存1小时
|
||||
|
||||
# 任务确认设置
|
||||
task_acks_late=True, # 任务完成后才确认,避免任务丢失
|
||||
worker_disable_rate_limits=True, # 禁用速率限制
|
||||
task_acks_late=True,
|
||||
task_reject_on_worker_lost=True,
|
||||
worker_disable_rate_limits=True,
|
||||
|
||||
# 任务路由(可选,用于不同队列)
|
||||
# task_routes={
|
||||
# 'app.core.rag.tasks.parse_document': {'queue': 'document_processing'},
|
||||
# 'app.core.memory.agent.read_message': {'queue': 'memory_processing'},
|
||||
# 'app.core.memory.agent.write_message': {'queue': 'memory_processing'},
|
||||
# 'tasks.process_item': {'queue': 'default'},
|
||||
# },
|
||||
# FLower setting
|
||||
worker_send_task_events=True,
|
||||
task_send_sent_event=True,
|
||||
|
||||
# task routing
|
||||
task_routes={
|
||||
# Memory tasks → memory_tasks queue (threads worker)
|
||||
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||
|
||||
# Long-term storage tasks → memory_tasks queue (batched write strategies)
|
||||
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
||||
|
||||
# Document tasks → document_tasks queue (prefork worker)
|
||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
|
||||
|
||||
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
|
||||
},
|
||||
)
|
||||
|
||||
# 自动发现任务模块
|
||||
celery_app.autodiscover_tasks(['app'])
|
||||
|
||||
# Celery Beat schedule for periodic tasks
|
||||
reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS)
|
||||
health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS)
|
||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||
memory_increment_schedule = crontab(hour=settings.MEMORY_INCREMENT_HOUR, minute=settings.MEMORY_INCREMENT_MINUTE)
|
||||
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||
# 构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
workspace_reflection_schedule = timedelta(seconds=settings.WORKSPACE_REFLECTION_INTERVAL_SECONDS)
|
||||
forgetting_cycle_schedule = timedelta(hours=settings.FORGETTING_CYCLE_INTERVAL_HOURS)
|
||||
implicit_emotions_update_schedule = crontab(
|
||||
hour=settings.IMPLICIT_EMOTIONS_UPDATE_HOUR,
|
||||
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
|
||||
)
|
||||
|
||||
# "check-read-service": {
|
||||
# "task": "app.core.memory.agent.health.check_read_service",
|
||||
# "schedule": health_schedule,
|
||||
# "args": (),
|
||||
# },
|
||||
#构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
"run-workspace-reflection": {
|
||||
"task": "app.tasks.workspace_reflection_task",
|
||||
"schedule": workspace_reflection_schedule,
|
||||
@@ -103,16 +141,23 @@ beat_schedule_config = {
|
||||
"schedule": memory_cache_regeneration_schedule,
|
||||
"args": (),
|
||||
},
|
||||
"run-forgetting-cycle": {
|
||||
"task": "app.tasks.run_forgetting_cycle_task",
|
||||
"schedule": forgetting_cycle_schedule,
|
||||
"kwargs": {
|
||||
"config_id": None, # 使用默认配置,可以通过环境变量配置
|
||||
},
|
||||
},
|
||||
"write-all-workspaces-memory": {
|
||||
"task": "app.tasks.write_all_workspaces_memory_task",
|
||||
"schedule": memory_increment_schedule,
|
||||
"args": (),
|
||||
},
|
||||
"update-implicit-emotions-storage": {
|
||||
"task": "app.tasks.update_implicit_emotions_storage",
|
||||
"schedule": implicit_emotions_update_schedule,
|
||||
"args": (),
|
||||
},
|
||||
}
|
||||
|
||||
# 如果配置了默认工作空间ID,则添加记忆总量统计任务
|
||||
if settings.DEFAULT_WORKSPACE_ID:
|
||||
beat_schedule_config["write-total-memory"] = {
|
||||
"task": "app.controllers.memory_storage_controller.search_all",
|
||||
"schedule": memory_increment_schedule,
|
||||
"kwargs": {
|
||||
"workspace_id": settings.DEFAULT_WORKSPACE_ID,
|
||||
},
|
||||
}
|
||||
|
||||
celery_app.conf.beat_schedule = beat_schedule_config
|
||||
|
||||
@@ -3,6 +3,12 @@ Celery Worker 入口点
|
||||
用于启动 Celery Worker: celery -A app.celery_worker worker --loglevel=info
|
||||
"""
|
||||
from app.celery_app import celery_app
|
||||
from app.core.logging_config import LoggingConfig, get_logger
|
||||
|
||||
# Initialize logging system for Celery worker
|
||||
LoggingConfig.setup_logging()
|
||||
logger = get_logger(__name__)
|
||||
logger.info("Celery worker logging initialized")
|
||||
|
||||
# 导入任务模块以注册任务
|
||||
import app.tasks
|
||||
|
||||
1
api/app/config/__init__.py
Normal file
1
api/app/config/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Configuration module for application settings."""
|
||||
239
api/app/config/default_ontology_config.py
Normal file
239
api/app/config/default_ontology_config.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""默认本体场景配置
|
||||
|
||||
本模块定义系统预设的本体场景和实体类型配置。
|
||||
这些配置用于在工作空间创建时自动初始化默认场景。
|
||||
支持中英文双语配置,根据用户语言偏好创建对应语言的场景。
|
||||
"""
|
||||
|
||||
# 在线教育场景配置
|
||||
ONLINE_EDUCATION_SCENE = {
|
||||
"name_chinese": "在线教育",
|
||||
"name_english": "Online Education",
|
||||
"description_chinese": "适用于在线教育平台的本体建模,包含学生、教师、课程等核心实体类型",
|
||||
"description_english": "Ontology modeling for online education platforms, including core entity types such as students, teachers, and courses",
|
||||
"types": [
|
||||
{
|
||||
"name_chinese": "学生",
|
||||
"name_english": "Student",
|
||||
"description_chinese": "在教育系统中接受教育的个体,包含姓名、学号、年级、班级等属性",
|
||||
"description_english": "Individuals receiving education in the education system, including attributes such as name, student ID, grade, and class"
|
||||
},
|
||||
{
|
||||
"name_chinese": "教师",
|
||||
"name_english": "Teacher",
|
||||
"description_chinese": "在教育系统中提供教学服务的个体,包含姓名、工号、任教学科、职称等属性",
|
||||
"description_english": "Individuals providing teaching services in the education system, including attributes such as name, employee ID, teaching subject, and title"
|
||||
},
|
||||
{
|
||||
"name_chinese": "课程",
|
||||
"name_english": "Course",
|
||||
"description_chinese": "教育系统中的教学内容单元,包含课程名称、课程代码、学分、学时等属性",
|
||||
"description_english": "Teaching content units in the education system, including attributes such as course name, course code, credits, and class hours"
|
||||
},
|
||||
{
|
||||
"name_chinese": "作业",
|
||||
"name_english": "Assignment",
|
||||
"description_chinese": "课程中布置的学习任务,包含作业标题、截止日期、所属课程、提交状态等属性",
|
||||
"description_english": "Learning tasks assigned in courses, including attributes such as assignment title, deadline, course, and submission status"
|
||||
},
|
||||
{
|
||||
"name_chinese": "成绩",
|
||||
"name_english": "Grade",
|
||||
"description_chinese": "学生学习成果的评价结果,包含分数、评级、考试类型、所属课程等属性",
|
||||
"description_english": "Evaluation results of student learning outcomes, including attributes such as score, rating, exam type, and course"
|
||||
},
|
||||
{
|
||||
"name_chinese": "考试",
|
||||
"name_english": "Exam",
|
||||
"description_chinese": "评估学生学习成果的测试活动,包含考试名称、时间、地点、科目等属性",
|
||||
"description_english": "Test activities to assess student learning outcomes, including attributes such as exam name, time, location, and subject"
|
||||
},
|
||||
{
|
||||
"name_chinese": "教室",
|
||||
"name_english": "Classroom",
|
||||
"description_chinese": "进行教学活动的物理或虚拟空间,包含教室编号、容量、设备等属性",
|
||||
"description_english": "Physical or virtual spaces for teaching activities, including attributes such as classroom number, capacity, and equipment"
|
||||
},
|
||||
{
|
||||
"name_chinese": "学科",
|
||||
"name_english": "Subject",
|
||||
"description_chinese": "知识的分类领域,包含学科名称、代码、所属院系等属性",
|
||||
"description_english": "Classification domains of knowledge, including attributes such as subject name, code, and department"
|
||||
},
|
||||
{
|
||||
"name_chinese": "教材",
|
||||
"name_english": "Textbook",
|
||||
"description_chinese": "教学使用的书籍或资料,包含书名、作者、出版社、ISBN等属性",
|
||||
"description_english": "Books or materials used for teaching, including attributes such as title, author, publisher, and ISBN"
|
||||
},
|
||||
{
|
||||
"name_chinese": "班级",
|
||||
"name_english": "Class",
|
||||
"description_chinese": "学生的组织单位,包含班级名称、年级、人数、班主任等属性",
|
||||
"description_english": "Organizational units of students, including attributes such as class name, grade, number of students, and class teacher"
|
||||
},
|
||||
{
|
||||
"name_chinese": "学期",
|
||||
"name_english": "Semester",
|
||||
"description_chinese": "教学时间的划分单位,包含学期名称、开始时间、结束时间等属性",
|
||||
"description_english": "Time division units for teaching, including attributes such as semester name, start time, and end time"
|
||||
},
|
||||
{
|
||||
"name_chinese": "课时",
|
||||
"name_english": "Class Hour",
|
||||
"description_chinese": "课程的时间单位,包含上课时间、地点、教师、课程等属性",
|
||||
"description_english": "Time units of courses, including attributes such as class time, location, teacher, and course"
|
||||
},
|
||||
{
|
||||
"name_chinese": "教学计划",
|
||||
"name_english": "Teaching Plan",
|
||||
"description_chinese": "课程的教学安排,包含教学目标、内容安排、进度计划等属性",
|
||||
"description_english": "Teaching arrangements for courses, including attributes such as teaching objectives, content arrangement, and progress plan"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# 情感陪伴场景配置
|
||||
EMOTIONAL_COMPANION_SCENE = {
|
||||
"name_chinese": "情感陪伴",
|
||||
"name_english": "Emotional Companion",
|
||||
"description_chinese": "适用于情感陪伴应用的本体建模,包含用户、情绪、活动等核心实体类型",
|
||||
"description_english": "Ontology modeling for emotional companion applications, including core entity types such as users, emotions, and activities",
|
||||
"types": [
|
||||
{
|
||||
"name_chinese": "用户",
|
||||
"name_english": "User",
|
||||
"description_chinese": "使用情感陪伴服务的个体,包含姓名、昵称、性格特征、偏好等属性",
|
||||
"description_english": "Individuals using emotional companion services, including attributes such as name, nickname, personality traits, and preferences"
|
||||
},
|
||||
{
|
||||
"name_chinese": "情绪",
|
||||
"name_english": "Emotion",
|
||||
"description_chinese": "用户的情感状态,包含情绪类型、强度、触发原因、持续时间等属性",
|
||||
"description_english": "Emotional states of users, including attributes such as emotion type, intensity, trigger cause, and duration"
|
||||
},
|
||||
{
|
||||
"name_chinese": "活动",
|
||||
"name_english": "Activity",
|
||||
"description_chinese": "用户参与的各类活动,包含活动名称、类型、参与者、时间地点等属性",
|
||||
"description_english": "Various activities users participate in, including attributes such as activity name, type, participants, time, and location"
|
||||
},
|
||||
{
|
||||
"name_chinese": "对话",
|
||||
"name_english": "Conversation",
|
||||
"description_chinese": "用户之间的交流记录,包含对话主题、参与者、时间、关键内容等属性",
|
||||
"description_english": "Communication records between users, including attributes such as conversation topic, participants, time, and key content"
|
||||
},
|
||||
{
|
||||
"name_chinese": "兴趣爱好",
|
||||
"name_english": "Hobby",
|
||||
"description_chinese": "用户的兴趣和爱好,包含爱好名称、类别、熟练程度、相关活动等属性",
|
||||
"description_english": "User interests and hobbies, including attributes such as hobby name, category, proficiency level, and related activities"
|
||||
},
|
||||
{
|
||||
"name_chinese": "日常事件",
|
||||
"name_english": "Daily Event",
|
||||
"description_chinese": "用户日常生活中的事件,包含事件描述、时间、地点、相关人物等属性",
|
||||
"description_english": "Events in users' daily lives, including attributes such as event description, time, location, and related people"
|
||||
},
|
||||
{
|
||||
"name_chinese": "关系",
|
||||
"name_english": "Relationship",
|
||||
"description_chinese": "用户之间的社会关系,包含关系类型、亲密度、建立时间等属性",
|
||||
"description_english": "Social relationships between users, including attributes such as relationship type, intimacy, and establishment time"
|
||||
},
|
||||
{
|
||||
"name_chinese": "回忆",
|
||||
"name_english": "Memory",
|
||||
"description_chinese": "用户的重要记忆片段,包含回忆内容、时间、地点、相关人物等属性",
|
||||
"description_english": "Important memory fragments of users, including attributes such as memory content, time, location, and related people"
|
||||
},
|
||||
{
|
||||
"name_chinese": "地点",
|
||||
"name_english": "Location",
|
||||
"description_chinese": "用户活动的地理位置,包含地点名称、地址、类型、相关事件等属性",
|
||||
"description_english": "Geographic locations of user activities, including attributes such as location name, address, type, and related events"
|
||||
},
|
||||
{
|
||||
"name_chinese": "时间节点",
|
||||
"name_english": "Time Point",
|
||||
"description_chinese": "重要的时间标记,包含日期、事件、意义等属性",
|
||||
"description_english": "Important time markers, including attributes such as date, event, and significance"
|
||||
},
|
||||
{
|
||||
"name_chinese": "目标",
|
||||
"name_english": "Goal",
|
||||
"description_chinese": "用户设定的目标,包含目标描述、截止时间、完成状态、相关活动等属性",
|
||||
"description_english": "Goals set by users, including attributes such as goal description, deadline, completion status, and related activities"
|
||||
},
|
||||
{
|
||||
"name_chinese": "成就",
|
||||
"name_english": "Achievement",
|
||||
"description_chinese": "用户获得的成就,包含成就名称、获得时间、描述、相关目标等属性",
|
||||
"description_english": "Achievements obtained by users, including attributes such as achievement name, acquisition time, description, and related goals"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# 导出默认场景列表
|
||||
DEFAULT_SCENES = [ONLINE_EDUCATION_SCENE, EMOTIONAL_COMPANION_SCENE]
|
||||
|
||||
|
||||
def get_scene_name(scene_config: dict, language: str = "zh") -> str:
|
||||
"""获取场景名称(根据语言)
|
||||
|
||||
Args:
|
||||
scene_config: 场景配置字典
|
||||
language: 语言类型 ("zh" 或 "en")
|
||||
|
||||
Returns:
|
||||
对应语言的场景名称
|
||||
"""
|
||||
if language == "en":
|
||||
return scene_config.get("name_english", scene_config.get("name_chinese"))
|
||||
return scene_config.get("name_chinese")
|
||||
|
||||
|
||||
def get_scene_description(scene_config: dict, language: str = "zh") -> str:
|
||||
"""获取场景描述(根据语言)
|
||||
|
||||
Args:
|
||||
scene_config: 场景配置字典
|
||||
language: 语言类型 ("zh" 或 "en")
|
||||
|
||||
Returns:
|
||||
对应语言的场景描述
|
||||
"""
|
||||
if language == "en":
|
||||
return scene_config.get("description_english", scene_config.get("description_chinese"))
|
||||
return scene_config.get("description_chinese")
|
||||
|
||||
|
||||
def get_type_name(type_config: dict, language: str = "zh") -> str:
|
||||
"""获取类型名称(根据语言)
|
||||
|
||||
Args:
|
||||
type_config: 类型配置字典
|
||||
language: 语言类型 ("zh" 或 "en")
|
||||
|
||||
Returns:
|
||||
对应语言的类型名称
|
||||
"""
|
||||
if language == "en":
|
||||
return type_config.get("name_english", type_config.get("name_chinese"))
|
||||
return type_config.get("name_chinese")
|
||||
|
||||
|
||||
def get_type_description(type_config: dict, language: str = "zh") -> str:
|
||||
"""获取类型描述(根据语言)
|
||||
|
||||
Args:
|
||||
type_config: 类型配置字典
|
||||
language: 语言类型 ("zh" 或 "en")
|
||||
|
||||
Returns:
|
||||
对应语言的类型描述
|
||||
"""
|
||||
if language == "en":
|
||||
return type_config.get("description_english", type_config.get("description_chinese"))
|
||||
return type_config.get("description_chinese")
|
||||
249
api/app/config/default_ontology_initializer.py
Normal file
249
api/app/config/default_ontology_initializer.py
Normal file
@@ -0,0 +1,249 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""默认本体场景初始化器
|
||||
|
||||
本模块提供默认本体场景和类型的自动初始化功能。
|
||||
在工作空间创建时,自动添加预设的本体场景和实体类型。
|
||||
|
||||
Classes:
|
||||
DefaultOntologyInitializer: 默认本体场景初始化器
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config.default_ontology_config import (
|
||||
DEFAULT_SCENES,
|
||||
get_scene_name,
|
||||
get_scene_description,
|
||||
get_type_name,
|
||||
get_type_description,
|
||||
)
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.repositories.ontology_scene_repository import OntologySceneRepository
|
||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||
|
||||
|
||||
class DefaultOntologyInitializer:
|
||||
"""默认本体场景初始化器
|
||||
|
||||
负责在工作空间创建时自动初始化默认的本体场景和类型。
|
||||
遵循最小侵入原则,确保初始化失败不阻止工作空间创建。
|
||||
|
||||
Attributes:
|
||||
db: 数据库会话
|
||||
scene_repo: 场景Repository
|
||||
class_repo: 类型Repository
|
||||
logger: 业务日志记录器
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
"""初始化
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
"""
|
||||
self.db = db
|
||||
self.scene_repo = OntologySceneRepository(db)
|
||||
self.class_repo = OntologyClassRepository(db)
|
||||
self.logger = get_business_logger()
|
||||
|
||||
def initialize_default_scenes(
|
||||
self,
|
||||
workspace_id: UUID,
|
||||
language: str = "zh"
|
||||
) -> Tuple[bool, str]:
|
||||
"""为工作空间初始化默认场景
|
||||
|
||||
创建两个默认场景(在线教育、情感陪伴)及其对应的实体类型。
|
||||
如果创建失败,记录错误日志但不抛出异常。
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
language: 语言类型 ("zh" 或 "en"),默认为 "zh"
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否成功, 错误信息)
|
||||
"""
|
||||
try:
|
||||
self.logger.info(
|
||||
f"开始初始化默认本体场景 - workspace_id={workspace_id}, language={language}"
|
||||
)
|
||||
|
||||
scenes_created = 0
|
||||
total_types_created = 0
|
||||
|
||||
# 遍历默认场景配置
|
||||
for scene_config in DEFAULT_SCENES:
|
||||
scene_name = get_scene_name(scene_config, language)
|
||||
|
||||
# 创建场景及其类型
|
||||
scene_id = self._create_scene_with_types(workspace_id, scene_config, language)
|
||||
|
||||
if scene_id:
|
||||
scenes_created += 1
|
||||
# 统计类型数量
|
||||
types_count = len(scene_config.get("types", []))
|
||||
total_types_created += types_count
|
||||
|
||||
self.logger.info(
|
||||
f"场景创建成功 - scene_name={scene_name}, "
|
||||
f"scene_id={scene_id}, types_count={types_count}, language={language}"
|
||||
)
|
||||
else:
|
||||
self.logger.warning(
|
||||
f"场景创建失败 - scene_name={scene_name}, "
|
||||
f"workspace_id={workspace_id}, language={language}"
|
||||
)
|
||||
|
||||
# 记录总体结果
|
||||
self.logger.info(
|
||||
f"默认场景初始化完成 - workspace_id={workspace_id}, "
|
||||
f"language={language}, scenes_created={scenes_created}, "
|
||||
f"total_types_created={total_types_created}"
|
||||
)
|
||||
|
||||
# 如果至少创建了一个场景,视为成功
|
||||
if scenes_created > 0:
|
||||
return True, ""
|
||||
else:
|
||||
error_msg = "所有默认场景创建失败"
|
||||
self.logger.error(
|
||||
f"默认场景初始化失败 - workspace_id={workspace_id}, "
|
||||
f"language={language}, error={error_msg}"
|
||||
)
|
||||
return False, error_msg
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"默认场景初始化异常: {str(e)}"
|
||||
self.logger.error(
|
||||
f"默认场景初始化异常 - workspace_id={workspace_id}, "
|
||||
f"language={language}, error={str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
return False, error_msg
|
||||
|
||||
def _create_scene_with_types(
|
||||
self,
|
||||
workspace_id: UUID,
|
||||
scene_config: dict,
|
||||
language: str = "zh"
|
||||
) -> Optional[UUID]:
|
||||
"""创建场景及其类型
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
scene_config: 场景配置字典
|
||||
language: 语言类型 ("zh" 或 "en")
|
||||
|
||||
Returns:
|
||||
Optional[UUID]: 创建的场景ID,失败返回None
|
||||
"""
|
||||
try:
|
||||
scene_name = get_scene_name(scene_config, language)
|
||||
scene_description = get_scene_description(scene_config, language)
|
||||
|
||||
# 检查是否已存在同名场景(支持向后兼容)
|
||||
existing_scene = self.scene_repo.get_by_name(scene_name, workspace_id)
|
||||
if existing_scene:
|
||||
self.logger.info(
|
||||
f"场景已存在,跳过创建 - scene_name={scene_name}, "
|
||||
f"workspace_id={workspace_id}, scene_id={existing_scene.scene_id}, "
|
||||
f"language={language}"
|
||||
)
|
||||
return None
|
||||
|
||||
# 创建场景记录,设置 is_system_default=true
|
||||
scene_data = {
|
||||
"scene_name": scene_name,
|
||||
"scene_description": scene_description
|
||||
}
|
||||
|
||||
scene = self.scene_repo.create(scene_data, workspace_id)
|
||||
|
||||
# 设置系统默认标识
|
||||
scene.is_system_default = True
|
||||
self.db.flush()
|
||||
|
||||
self.logger.info(
|
||||
f"场景创建成功 - scene_name={scene_name}, "
|
||||
f"scene_id={scene.scene_id}, is_system_default=True, language={language}"
|
||||
)
|
||||
|
||||
# 批量创建类型
|
||||
types_config = scene_config.get("types", [])
|
||||
types_created = self._batch_create_types(scene.scene_id, types_config, language)
|
||||
|
||||
self.logger.info(
|
||||
f"场景类型创建完成 - scene_id={scene.scene_id}, "
|
||||
f"types_created={types_created}/{len(types_config)}, language={language}"
|
||||
)
|
||||
|
||||
return scene.scene_id
|
||||
|
||||
except Exception as e:
|
||||
scene_name = get_scene_name(scene_config, language)
|
||||
self.logger.error(
|
||||
f"场景创建失败 - scene_name={scene_name}, "
|
||||
f"workspace_id={workspace_id}, language={language}, error={str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
def _batch_create_types(
|
||||
self,
|
||||
scene_id: UUID,
|
||||
types_config: List[dict],
|
||||
language: str = "zh"
|
||||
) -> int:
|
||||
"""批量创建实体类型
|
||||
|
||||
Args:
|
||||
scene_id: 场景ID
|
||||
types_config: 类型配置列表
|
||||
language: 语言类型 ("zh" 或 "en")
|
||||
|
||||
Returns:
|
||||
int: 成功创建的类型数量
|
||||
"""
|
||||
created_count = 0
|
||||
|
||||
for type_config in types_config:
|
||||
try:
|
||||
type_name = get_type_name(type_config, language)
|
||||
type_description = get_type_description(type_config, language)
|
||||
|
||||
# 创建类型数据
|
||||
class_data = {
|
||||
"class_name": type_name,
|
||||
"class_description": type_description
|
||||
}
|
||||
|
||||
# 创建类型
|
||||
ontology_class = self.class_repo.create(class_data, scene_id)
|
||||
|
||||
# 设置系统默认标识
|
||||
ontology_class.is_system_default = True
|
||||
self.db.flush()
|
||||
|
||||
created_count += 1
|
||||
|
||||
self.logger.debug(
|
||||
f"类型创建成功 - class_name={type_name}, "
|
||||
f"class_id={ontology_class.class_id}, "
|
||||
f"scene_id={scene_id}, is_system_default=True, language={language}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
type_name = get_type_name(type_config, language)
|
||||
self.logger.warning(
|
||||
f"单个类型创建失败,继续创建其他类型 - "
|
||||
f"class_name={type_name}, scene_id={scene_id}, "
|
||||
f"language={language}, error={str(e)}"
|
||||
)
|
||||
# 继续创建其他类型
|
||||
continue
|
||||
|
||||
return created_count
|
||||
@@ -4,37 +4,49 @@
|
||||
认证方式: JWT Token
|
||||
"""
|
||||
from fastapi import APIRouter
|
||||
|
||||
from . import (
|
||||
model_controller,
|
||||
task_controller,
|
||||
test_controller,
|
||||
user_controller,
|
||||
auth_controller,
|
||||
workspace_controller,
|
||||
setup_controller,
|
||||
file_controller,
|
||||
document_controller,
|
||||
knowledge_controller,
|
||||
chunk_controller,
|
||||
knowledgeshare_controller,
|
||||
api_key_controller,
|
||||
app_controller,
|
||||
upload_controller,
|
||||
auth_controller,
|
||||
chunk_controller,
|
||||
document_controller,
|
||||
emotion_config_controller,
|
||||
emotion_controller,
|
||||
file_controller,
|
||||
file_storage_controller,
|
||||
home_page_controller,
|
||||
implicit_memory_controller,
|
||||
knowledge_controller,
|
||||
knowledgeshare_controller,
|
||||
mcp_market_controller,
|
||||
mcp_market_config_controller,
|
||||
memory_agent_controller,
|
||||
memory_dashboard_controller,
|
||||
memory_storage_controller,
|
||||
memory_dashboard_controller,
|
||||
memory_episodic_controller,
|
||||
memory_explicit_controller,
|
||||
memory_forget_controller,
|
||||
memory_perceptual_controller,
|
||||
memory_reflection_controller,
|
||||
api_key_controller,
|
||||
release_share_controller,
|
||||
public_share_controller,
|
||||
memory_short_term_controller,
|
||||
memory_storage_controller,
|
||||
memory_working_controller,
|
||||
model_controller,
|
||||
multi_agent_controller,
|
||||
workflow_controller,
|
||||
emotion_controller,
|
||||
emotion_config_controller,
|
||||
prompt_optimizer_controller,
|
||||
public_share_controller,
|
||||
release_share_controller,
|
||||
setup_controller,
|
||||
task_controller,
|
||||
test_controller,
|
||||
tool_controller,
|
||||
upload_controller,
|
||||
user_controller,
|
||||
user_memory_controllers,
|
||||
workspace_controller,
|
||||
ontology_controller,
|
||||
skill_controller
|
||||
)
|
||||
from . import user_memory_controllers
|
||||
|
||||
# 创建管理端 API 路由器
|
||||
manager_router = APIRouter()
|
||||
@@ -50,6 +62,8 @@ manager_router.include_router(model_controller.router)
|
||||
manager_router.include_router(file_controller.router)
|
||||
manager_router.include_router(document_controller.router)
|
||||
manager_router.include_router(knowledge_controller.router)
|
||||
manager_router.include_router(mcp_market_controller.router)
|
||||
manager_router.include_router(mcp_market_config_controller.router)
|
||||
manager_router.include_router(chunk_controller.router)
|
||||
manager_router.include_router(test_controller.router)
|
||||
manager_router.include_router(knowledgeshare_controller.router)
|
||||
@@ -59,16 +73,26 @@ manager_router.include_router(memory_agent_controller.router)
|
||||
manager_router.include_router(memory_dashboard_controller.router)
|
||||
manager_router.include_router(memory_storage_controller.router)
|
||||
manager_router.include_router(user_memory_controllers.router)
|
||||
manager_router.include_router(memory_episodic_controller.router)
|
||||
manager_router.include_router(memory_explicit_controller.router)
|
||||
manager_router.include_router(api_key_controller.router)
|
||||
manager_router.include_router(release_share_controller.router)
|
||||
manager_router.include_router(public_share_controller.router) # 公开路由(无需认证)
|
||||
manager_router.include_router(memory_dashboard_controller.router)
|
||||
manager_router.include_router(multi_agent_controller.router)
|
||||
manager_router.include_router(workflow_controller.router)
|
||||
manager_router.include_router(emotion_controller.router)
|
||||
manager_router.include_router(emotion_config_controller.router)
|
||||
manager_router.include_router(prompt_optimizer_controller.router)
|
||||
manager_router.include_router(memory_reflection_controller.router)
|
||||
manager_router.include_router(memory_short_term_controller.router)
|
||||
manager_router.include_router(tool_controller.router)
|
||||
manager_router.include_router(memory_forget_controller.router)
|
||||
manager_router.include_router(home_page_controller.router)
|
||||
manager_router.include_router(implicit_memory_controller.router)
|
||||
manager_router.include_router(memory_perceptual_controller.router)
|
||||
manager_router.include_router(memory_working_controller.router)
|
||||
manager_router.include_router(file_storage_controller.router)
|
||||
manager_router.include_router(ontology_controller.router)
|
||||
manager_router.include_router(skill_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
@@ -1,25 +1,29 @@
|
||||
import uuid
|
||||
from typing import Optional, Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
import yaml
|
||||
from fastapi import APIRouter, Depends, Path, Form, UploadFile, File
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
from app.models import User
|
||||
from app.models.app_model import AppType, App
|
||||
from app.models.app_model import AppType
|
||||
from app.repositories import knowledge_repository
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas import app_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.schemas.workflow_schema import WorkflowConfigUpdate
|
||||
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
|
||||
from app.schemas.workflow_schema import WorkflowConfigUpdate, WorkflowImportSave
|
||||
from app.services import app_service, workspace_service
|
||||
from app.services.agent_config_helper import enrich_agent_config
|
||||
from app.services.app_service import AppService
|
||||
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
from app.services.workflow_import_service import WorkflowImportService
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||
@@ -29,9 +33,9 @@ logger = get_business_logger()
|
||||
@router.post("", summary="创建应用(可选创建 Agent 配置)")
|
||||
@cur_workspace_access_guard()
|
||||
def create_app(
|
||||
payload: app_schema.AppCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
payload: app_schema.AppCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
app = app_service.create_app(db, user_id=current_user.id, workspace_id=workspace_id, data=payload)
|
||||
@@ -41,22 +45,34 @@ def create_app(
|
||||
@router.get("", summary="应用列表(分页)")
|
||||
@cur_workspace_access_guard()
|
||||
def list_apps(
|
||||
type: str | None = None,
|
||||
visibility: str | None = None,
|
||||
status: str | None = None,
|
||||
search: str | None = None,
|
||||
include_shared: bool = True,
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
type: str | None = None,
|
||||
visibility: str | None = None,
|
||||
status: str | None = None,
|
||||
search: str | None = None,
|
||||
include_shared: bool = True,
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
ids: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""列出应用
|
||||
|
||||
- 默认包含本工作空间的应用和分享给本工作空间的应用
|
||||
- 设置 include_shared=false 可以只查看本工作空间的应用
|
||||
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
|
||||
# 当 ids 存在且不为 None 时,根据 ids 获取应用
|
||||
if ids is not None:
|
||||
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
|
||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||
return success(data=items)
|
||||
|
||||
# 正常分页查询
|
||||
items_orm, total = app_service.list_apps(
|
||||
db,
|
||||
workspace_id=workspace_id,
|
||||
@@ -69,18 +85,17 @@ def list_apps(
|
||||
pagesize=pagesize,
|
||||
)
|
||||
|
||||
# 使用 AppService 的转换方法来设置 is_shared 字段
|
||||
service = app_service.AppService(db)
|
||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
||||
return success(data=PageData(page=meta, items=items))
|
||||
|
||||
|
||||
@router.get("/{app_id}", summary="获取应用详情")
|
||||
@cur_workspace_access_guard()
|
||||
def get_app(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""获取应用详细信息
|
||||
|
||||
@@ -99,10 +114,10 @@ def get_app(
|
||||
@router.put("/{app_id}", summary="更新应用基本信息")
|
||||
@cur_workspace_access_guard()
|
||||
def update_app(
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.AppUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.AppUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
app = app_service.update_app(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||
@@ -112,9 +127,9 @@ def update_app(
|
||||
@router.delete("/{app_id}", summary="删除应用")
|
||||
@cur_workspace_access_guard()
|
||||
def delete_app(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""删除应用
|
||||
|
||||
@@ -141,10 +156,10 @@ def delete_app(
|
||||
@router.post("/{app_id}/copy", summary="复制应用")
|
||||
@cur_workspace_access_guard()
|
||||
def copy_app(
|
||||
app_id: uuid.UUID,
|
||||
new_name: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
new_name: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""复制应用(包括基础信息和配置)
|
||||
|
||||
@@ -178,10 +193,10 @@ def copy_app(
|
||||
@router.put("/{app_id}/config", summary="更新 Agent 配置")
|
||||
@cur_workspace_access_guard()
|
||||
def update_agent_config(
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.AgentConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.AgentConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
cfg = app_service.update_agent_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||
@@ -192,9 +207,9 @@ def update_agent_config(
|
||||
@router.get("/{app_id}/config", summary="获取 Agent 配置")
|
||||
@cur_workspace_access_guard()
|
||||
def get_agent_config(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
|
||||
@@ -206,10 +221,10 @@ def get_agent_config(
|
||||
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
|
||||
@cur_workspace_access_guard()
|
||||
def publish_app(
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.PublishRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.PublishRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
release = app_service.publish(
|
||||
@@ -217,7 +232,7 @@ def publish_app(
|
||||
app_id=app_id,
|
||||
publisher_id=current_user.id,
|
||||
workspace_id=workspace_id,
|
||||
version_name = payload.version_name,
|
||||
version_name=payload.version_name,
|
||||
release_notes=payload.release_notes
|
||||
)
|
||||
return success(data=app_schema.AppRelease.model_validate(release))
|
||||
@@ -226,9 +241,9 @@ def publish_app(
|
||||
@router.get("/{app_id}/release", summary="获取当前发布版本")
|
||||
@cur_workspace_access_guard()
|
||||
def get_current_release(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
release = app_service.get_current_release(db, app_id=app_id, workspace_id=workspace_id)
|
||||
@@ -240,9 +255,9 @@ def get_current_release(
|
||||
@router.get("/{app_id}/releases", summary="列出历史发布版本(倒序)")
|
||||
@cur_workspace_access_guard()
|
||||
def list_releases(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
releases = app_service.list_releases(db, app_id=app_id, workspace_id=workspace_id)
|
||||
@@ -253,10 +268,10 @@ def list_releases(
|
||||
@router.post("/{app_id}/rollback/{version}", summary="回滚到指定版本")
|
||||
@cur_workspace_access_guard()
|
||||
def rollback(
|
||||
app_id: uuid.UUID,
|
||||
version: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
version: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
release = app_service.rollback(db, app_id=app_id, version=version, workspace_id=workspace_id)
|
||||
@@ -266,10 +281,10 @@ def rollback(
|
||||
@router.post("/{app_id}/share", summary="分享应用到其他工作空间")
|
||||
@cur_workspace_access_guard()
|
||||
def share_app(
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.AppShareCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.AppShareCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""分享应用到其他工作空间
|
||||
|
||||
@@ -294,10 +309,10 @@ def share_app(
|
||||
@router.delete("/{app_id}/share/{target_workspace_id}", summary="取消应用分享")
|
||||
@cur_workspace_access_guard()
|
||||
def unshare_app(
|
||||
app_id: uuid.UUID,
|
||||
target_workspace_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
target_workspace_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""取消应用分享
|
||||
|
||||
@@ -318,9 +333,9 @@ def unshare_app(
|
||||
@router.get("/{app_id}/shares", summary="列出应用的分享记录")
|
||||
@cur_workspace_access_guard()
|
||||
def list_app_shares(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""列出应用的所有分享记录
|
||||
|
||||
@@ -337,14 +352,15 @@ def list_app_shares(
|
||||
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
||||
return success(data=data)
|
||||
|
||||
|
||||
@router.post("/{app_id}/draft/run", summary="试运行 Agent(使用当前草稿配置)")
|
||||
@cur_workspace_access_guard()
|
||||
async def draft_run(
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.DraftRunRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
workflow_service: Annotated[WorkflowService, Depends(get_workflow_service)] = None
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.DraftRunRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
workflow_service: Annotated[WorkflowService, Depends(get_workflow_service)] = None
|
||||
):
|
||||
"""
|
||||
试运行 Agent,使用当前的草稿配置(未发布的配置)
|
||||
@@ -361,7 +377,7 @@ async def draft_run(
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
if storage_type is None:
|
||||
if storage_type is None:
|
||||
storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
if workspace_id:
|
||||
@@ -371,20 +387,19 @@ async def draft_run(
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge:
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
|
||||
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
from app.services.app_service import AppService
|
||||
from app.services.multi_agent_service import MultiAgentService
|
||||
from app.models import AgentConfig, ModelConfig
|
||||
from sqlalchemy import select
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
|
||||
service = AppService(db)
|
||||
draft_service = DraftRunService(db)
|
||||
draft_service = AgentRunService(db)
|
||||
|
||||
# 1. 验证应用
|
||||
app = service._get_app_or_404(app_id)
|
||||
@@ -394,13 +409,22 @@ async def draft_run(
|
||||
# 只读操作,允许访问共享应用
|
||||
service._validate_app_accessible(app, workspace_id)
|
||||
|
||||
if payload.user_id is None:
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app_id,
|
||||
other_id=str(current_user.id),
|
||||
original_user_id=str(current_user.id) # Save original user_id to other_id
|
||||
)
|
||||
payload.user_id = str(new_end_user.id)
|
||||
|
||||
# 处理会话ID(创建或验证)
|
||||
conversation_id = await draft_service._ensure_conversation(
|
||||
conversation_id=payload.conversation_id,
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=payload.user_id
|
||||
)
|
||||
conversation_id=payload.conversation_id,
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=payload.user_id
|
||||
)
|
||||
payload.conversation_id = conversation_id
|
||||
|
||||
if app.type == AppType.AGENT:
|
||||
@@ -424,17 +448,17 @@ async def draft_run(
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
|
||||
|
||||
async for event in draft_service.run_stream(
|
||||
agent_config=agent_cfg,
|
||||
model_config=model_config,
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
agent_config=agent_cfg,
|
||||
model_config=model_config,
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -455,12 +479,13 @@ async def draft_run(
|
||||
"app_id": str(app_id),
|
||||
"message_length": len(payload.message),
|
||||
"has_conversation_id": bool(payload.conversation_id),
|
||||
"has_variables": bool(payload.variables)
|
||||
"has_variables": bool(payload.variables),
|
||||
"has_files": bool(payload.files)
|
||||
}
|
||||
)
|
||||
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
draft_service = AgentRunService(db)
|
||||
result = await draft_service.run(
|
||||
agent_config=agent_cfg,
|
||||
model_config=model_config,
|
||||
@@ -470,7 +495,8 @@ async def draft_run(
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
@@ -506,7 +532,7 @@ async def draft_run(
|
||||
multi_agent_request = MultiAgentRunRequest(
|
||||
message=payload.message,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables or {},
|
||||
use_llm_routing=True # 默认启用 LLM 路由
|
||||
)
|
||||
@@ -528,10 +554,10 @@ async def draft_run(
|
||||
|
||||
# 调用多智能体服务的流式方法
|
||||
async for event in multiservice.run_stream(
|
||||
app_id=app_id,
|
||||
request=multi_agent_request,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
app_id=app_id,
|
||||
request=multi_agent_request,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
|
||||
):
|
||||
yield event
|
||||
@@ -571,7 +597,7 @@ async def draft_run(
|
||||
data=result,
|
||||
msg="多 Agent 任务执行成功"
|
||||
)
|
||||
elif app.type == AppType.WORKFLOW: #工作流
|
||||
elif app.type == AppType.WORKFLOW: # 工作流
|
||||
config = workflow_service.check_config(app_id)
|
||||
# 3. 流式返回
|
||||
if payload.stream:
|
||||
@@ -592,17 +618,18 @@ async def draft_run(
|
||||
data: <json_data>
|
||||
"""
|
||||
import json
|
||||
|
||||
|
||||
# 调用工作流服务的流式方法
|
||||
async for event in workflow_service.run_stream(
|
||||
app_id=app_id,
|
||||
payload=payload,
|
||||
config=config
|
||||
config=config,
|
||||
workspace_id=current_user.current_workspace_id
|
||||
):
|
||||
# 提取事件类型和数据
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
|
||||
# 转换为标准 SSE 格式(字符串)
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||
yield sse_message
|
||||
@@ -627,7 +654,7 @@ async def draft_run(
|
||||
}
|
||||
)
|
||||
|
||||
result = await workflow_service.run(app_id, payload,config)
|
||||
result = await workflow_service.run(app_id, payload, config, current_user.current_workspace_id)
|
||||
|
||||
logger.debug(
|
||||
"工作流试运行返回结果",
|
||||
@@ -640,16 +667,20 @@ async def draft_run(
|
||||
data=result,
|
||||
msg="工作流任务执行成功"
|
||||
)
|
||||
|
||||
else:
|
||||
return fail(
|
||||
msg="未知应用类型",
|
||||
code=422
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{app_id}/draft/run/compare", summary="多模型对比试运行")
|
||||
@cur_workspace_access_guard()
|
||||
async def draft_run_compare(
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.DraftRunCompareRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.DraftRunCompareRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
多模型对比试运行
|
||||
@@ -674,7 +705,7 @@ async def draft_run_compare(
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
if storage_type is None:
|
||||
if storage_type is None:
|
||||
storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
if workspace_id:
|
||||
@@ -683,7 +714,7 @@ async def draft_run_compare(
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge:
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
|
||||
logger.info(
|
||||
@@ -728,9 +759,23 @@ async def draft_run_compare(
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
|
||||
|
||||
# 获取 agent_cfg.model_parameters,如果是 ModelParameters 对象则转为字典
|
||||
agent_model_params = agent_cfg.model_parameters
|
||||
if hasattr(agent_model_params, 'model_dump'):
|
||||
agent_model_params = agent_model_params.model_dump()
|
||||
elif not isinstance(agent_model_params, dict):
|
||||
agent_model_params = {}
|
||||
|
||||
# 获取 model_item.model_parameters,如果是 ModelParameters 对象则转为字典
|
||||
item_model_params = model_item.model_parameters
|
||||
if hasattr(item_model_params, 'model_dump'):
|
||||
item_model_params = item_model_params.model_dump()
|
||||
elif not isinstance(item_model_params, dict):
|
||||
item_model_params = {}
|
||||
|
||||
merged_parameters = {
|
||||
**(agent_cfg.model_parameters or {}),
|
||||
**(model_item.model_parameters or {})
|
||||
**(agent_model_params or {}),
|
||||
**(item_model_params or {})
|
||||
}
|
||||
|
||||
model_configs.append({
|
||||
@@ -744,22 +789,23 @@ async def draft_run_compare(
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
draft_service = AgentRunService(db)
|
||||
async for event in draft_service.run_compare_stream(
|
||||
agent_config=agent_cfg,
|
||||
models=model_configs,
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
web_search=True,
|
||||
memory=True,
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60
|
||||
agent_config=agent_cfg,
|
||||
models=model_configs,
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
web_search=True,
|
||||
memory=True,
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60,
|
||||
files=payload.files
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -774,8 +820,8 @@ async def draft_run_compare(
|
||||
)
|
||||
|
||||
# 非流式返回
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
draft_service = AgentRunService(db)
|
||||
result = await draft_service.run_compare(
|
||||
agent_config=agent_cfg,
|
||||
models=model_configs,
|
||||
@@ -789,7 +835,8 @@ async def draft_run_compare(
|
||||
web_search=True,
|
||||
memory=True,
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60
|
||||
timeout=payload.timeout or 60,
|
||||
files=payload.files
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -821,15 +868,145 @@ async def get_workflow_config(
|
||||
# 配置总是存在(不存在时返回默认模板)
|
||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||
|
||||
|
||||
@router.put("/{app_id}/workflow", summary="更新 Workflow 配置")
|
||||
@cur_workspace_access_guard()
|
||||
async def update_workflow_config(
|
||||
app_id: uuid.UUID,
|
||||
payload: WorkflowConfigUpdate,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
app_id: uuid.UUID,
|
||||
payload: WorkflowConfigUpdate,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||
|
||||
|
||||
@router.get("/{app_id}/workflow/export")
|
||||
@cur_workspace_access_guard()
|
||||
async def export_workflow_config(
|
||||
app_id: uuid.UUID,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
):
|
||||
"""导出工作流配置为YAML文件"""
|
||||
workflow_service = WorkflowService(db)
|
||||
|
||||
return success(data={
|
||||
"content": workflow_service.export_workflow_dsl(app_id=app_id),
|
||||
})
|
||||
|
||||
|
||||
@router.post("/workflow/import")
|
||||
@cur_workspace_access_guard()
|
||||
async def import_workflow_config(
|
||||
file: UploadFile = File(...),
|
||||
platform: str = Form(...),
|
||||
app_id: str = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
|
||||
):
|
||||
"""从YAML内容导入工作流配置"""
|
||||
if not file.filename.lower().endswith((".yaml", ".yml")):
|
||||
return fail(msg="Only yaml file is allowed", code=BizCode.BAD_REQUEST)
|
||||
|
||||
raw_text = (await file.read()).decode("utf-8")
|
||||
import_service = WorkflowImportService(db)
|
||||
config = yaml.safe_load(raw_text)
|
||||
result = await import_service.upload_config(platform, config)
|
||||
return success(data=result)
|
||||
|
||||
|
||||
@router.post("/workflow/import/save")
|
||||
@cur_workspace_access_guard()
|
||||
async def save_workflow_import(
|
||||
data: WorkflowImportSave,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
import_service = WorkflowImportService(db)
|
||||
app = await import_service.save_workflow(
|
||||
user_id=current_user.id,
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
temp_id=data.temp_id,
|
||||
name=data.name,
|
||||
description=data.description,
|
||||
)
|
||||
return success(data=app_schema.App.model_validate(app))
|
||||
|
||||
|
||||
@router.get("/{app_id}/statistics", summary="应用统计数据")
|
||||
@cur_workspace_access_guard()
|
||||
def get_app_statistics(
|
||||
app_id: uuid.UUID,
|
||||
start_date: int,
|
||||
end_date: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""获取应用统计数据
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
start_date: 开始时间戳(毫秒)
|
||||
end_date: 结束时间戳(毫秒)
|
||||
db: 数据库连接
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
- daily_conversations: 每日会话数统计
|
||||
- total_conversations: 总会话数
|
||||
- daily_new_users: 每日新增用户数
|
||||
- total_new_users: 总新增用户数
|
||||
- daily_api_calls: 每日API调用次数
|
||||
- total_api_calls: 总API调用次数
|
||||
- daily_tokens: 每日token消耗
|
||||
- total_tokens: 总token消耗
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
stats_service = AppStatisticsService(db)
|
||||
|
||||
result = stats_service.get_app_statistics(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
return success(data=result)
|
||||
|
||||
|
||||
@router.get("/workspace/api-statistics", summary="工作空间API调用统计")
|
||||
@cur_workspace_access_guard()
|
||||
def get_workspace_api_statistics(
|
||||
start_date: int,
|
||||
end_date: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""获取工作空间API调用统计
|
||||
|
||||
Args:
|
||||
start_date: 开始时间戳(毫秒)
|
||||
end_date: 结束时间戳(毫秒)
|
||||
db: 数据库连接
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
每日统计数据列表,每项包含:
|
||||
- date: 日期
|
||||
- total_calls: 当日总调用次数
|
||||
- app_calls: 当日应用调用次数
|
||||
- service_calls: 当日服务调用次数
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
stats_service = AppStatisticsService(db)
|
||||
|
||||
result = stats_service.get_workspace_api_statistics(
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
return success(data=result)
|
||||
|
||||
@@ -61,6 +61,7 @@ async def login_for_access_token(
|
||||
user = auth_service.register_user_with_invite(
|
||||
db=db,
|
||||
email=form_data.email,
|
||||
username=form_data.username,
|
||||
password=form_data.password,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id
|
||||
|
||||
@@ -1,24 +1,28 @@
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.core.config import settings
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.rag.common.settings import kg_retriever
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.models.document_model import Document
|
||||
from app.models import knowledge_model, knowledgeshare_model
|
||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import knowledge_model, knowledgeshare_model
|
||||
from app.models.document_model import Document
|
||||
from app.models.user_model import User
|
||||
from app.schemas import chunk_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.core.response_utils import success
|
||||
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
@@ -141,7 +145,7 @@ async def get_preview_chunks(
|
||||
}
|
||||
}
|
||||
api_logger.info(f"Querying the document block preview list successful: total={total}, returned={len(chunks)} records")
|
||||
return success(data=result, msg="Querying the document block preview list succeeded")
|
||||
return success(data=jsonable_encoder(result), msg="Querying the document block preview list succeeded")
|
||||
|
||||
|
||||
@router.get("/{kb_id}/{document_id}/chunks", response_model=ApiResponse)
|
||||
@@ -199,7 +203,7 @@ async def get_chunks(
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
return success(data=result, msg="Query of document chunk list succeeded")
|
||||
return success(data=jsonable_encoder(result), msg="Query of document chunk list succeeded")
|
||||
|
||||
|
||||
@router.post("/{kb_id}/{document_id}/chunk", response_model=ApiResponse)
|
||||
@@ -260,7 +264,7 @@ async def create_chunk(
|
||||
db_document.chunk_num += 1
|
||||
db.commit()
|
||||
|
||||
return success(data=chunk, msg="Document chunk creation successful")
|
||||
return success(data=jsonable_encoder(chunk), msg="Document chunk creation successful")
|
||||
|
||||
|
||||
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
||||
@@ -287,7 +291,7 @@ async def get_chunk(
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
total, items = vector_service.get_by_segment(doc_id=doc_id)
|
||||
if total:
|
||||
return success(data=items[0], msg="Document chunk query successful")
|
||||
return success(data=jsonable_encoder(items[0]), msg="Document chunk query successful")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
@@ -324,7 +328,7 @@ async def update_chunk(
|
||||
chunk = items[0]
|
||||
chunk.page_content = content
|
||||
vector_service.update_by_segment(chunk)
|
||||
return success(data=chunk, msg="The document chunk has been successfully updated")
|
||||
return success(data=jsonable_encoder(chunk), msg="The document chunk has been successfully updated")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
@@ -389,36 +393,41 @@ async def retrieve_chunks(
|
||||
knowledge_model.Knowledge.chunk_num > 0,
|
||||
knowledge_model.Knowledge.status == 1
|
||||
]
|
||||
existing_ids = knowledge_service.get_chunded_knowledgeids(
|
||||
private_items = knowledge_service.get_chunked_knowledgeids(
|
||||
db=db,
|
||||
filters=filters,
|
||||
current_user=current_user
|
||||
)
|
||||
private_kb_ids = [item[0] for item in private_items]
|
||||
private_workspace_ids = [item[1] for item in private_items]
|
||||
filters = [
|
||||
knowledge_model.Knowledge.id.in_(retrieve_data.kb_ids),
|
||||
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
|
||||
knowledge_model.Knowledge.chunk_num > 0,
|
||||
knowledge_model.Knowledge.status == 1
|
||||
]
|
||||
share_ids = knowledge_service.get_chunded_knowledgeids(
|
||||
items = knowledge_service.get_chunked_knowledgeids(
|
||||
db=db,
|
||||
filters=filters,
|
||||
current_user=current_user
|
||||
)
|
||||
if share_ids:
|
||||
if items:
|
||||
filters = [
|
||||
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(retrieve_data.kb_ids)
|
||||
]
|
||||
items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id(
|
||||
share_items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id(
|
||||
db=db,
|
||||
filters=filters,
|
||||
current_user=current_user
|
||||
)
|
||||
existing_ids.extend(items)
|
||||
if not existing_ids:
|
||||
share_kb_ids = [item[0] for item in share_items]
|
||||
share_workspace_ids = [item[1] for item in share_items]
|
||||
private_kb_ids.extend(share_kb_ids)
|
||||
private_workspace_ids.extend(share_workspace_ids)
|
||||
if not private_kb_ids:
|
||||
return success(data=[], msg="retrieval successful")
|
||||
kb_id = existing_ids[0]
|
||||
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids]
|
||||
kb_id = private_kb_ids[0]
|
||||
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in private_kb_ids]
|
||||
indices = ",".join(uuid_strs)
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
@@ -432,14 +441,14 @@ async def retrieve_chunks(
|
||||
# 1 participle search, 2 semantic search, 3 hybrid search
|
||||
match retrieve_data.retrieve_type:
|
||||
case chunk_schema.RetrieveType.PARTICIPLE:
|
||||
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
|
||||
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
case chunk_schema.RetrieveType.SEMANTIC:
|
||||
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
|
||||
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
case _:
|
||||
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
|
||||
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
|
||||
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
||||
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
||||
# Efficient deduplication
|
||||
seen_ids = set()
|
||||
unique_rs = []
|
||||
@@ -448,4 +457,21 @@ async def retrieve_chunks(
|
||||
seen_ids.add(doc.metadata["doc_id"])
|
||||
unique_rs.append(doc)
|
||||
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k)
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph:
|
||||
kb_ids = [str(kb_id) for kb_id in private_kb_ids]
|
||||
workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids]
|
||||
# Prepare to configure chat_mdl、embedding_model、vision_model information
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
||||
)
|
||||
embedding_model = OpenAIEmbed(
|
||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||
base_url=db_knowledge.embedding.api_keys[0].api_base
|
||||
)
|
||||
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids= kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||
if doc:
|
||||
rs.insert(0, doc)
|
||||
return success(data=jsonable_encoder(rs), msg="retrieval successful")
|
||||
@@ -1,23 +1,26 @@
|
||||
import datetime
|
||||
import os
|
||||
from typing import Optional
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.controllers import file_controller
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.models import document_model
|
||||
from app.models.user_model import User
|
||||
from app.schemas import document_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.core.response_utils import success
|
||||
from app.services import document_service, file_service, knowledge_service
|
||||
from app.controllers import file_controller
|
||||
from app.celery_app import celery_app
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
@@ -106,7 +109,7 @@ async def get_documents(
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
return success(data=result, msg="Query of document list succeeded")
|
||||
return success(data=jsonable_encoder(result), msg="Query of document list succeeded")
|
||||
|
||||
|
||||
@router.post("/document", response_model=ApiResponse)
|
||||
@@ -124,7 +127,7 @@ async def create_document(
|
||||
api_logger.debug(f"Start creating a document: {create_data.file_name}")
|
||||
db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)
|
||||
api_logger.info(f"Document created successfully: {db_document.file_name} (ID: {db_document.id})")
|
||||
return success(data=document_schema.Document.model_validate(db_document), msg="Document creation successful")
|
||||
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="Document creation successful")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Document creation failed: {create_data.file_name} - {str(e)}")
|
||||
raise
|
||||
@@ -153,7 +156,7 @@ async def get_document(
|
||||
)
|
||||
|
||||
api_logger.info(f"Document query successful: {db_document.file_name} (ID: {db_document.id})")
|
||||
return success(data=document_schema.Document.model_validate(db_document), msg="Successfully obtained document information")
|
||||
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="Successfully obtained document information")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -221,7 +224,7 @@ async def update_document(
|
||||
)
|
||||
|
||||
# 5. Return the updated document
|
||||
return success(data=document_schema.Document.model_validate(db_document), msg="Document information updated successfully")
|
||||
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="Document information updated successfully")
|
||||
|
||||
|
||||
@router.delete("/{document_id}", response_model=ApiResponse)
|
||||
|
||||
@@ -7,11 +7,13 @@ Routes:
|
||||
GET /memory/config/emotion - 获取情绪引擎配置
|
||||
POST /memory/config/emotion - 更新情绪引擎配置
|
||||
"""
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
from sqlalchemy.orm import Session
|
||||
from uuid import UUID
|
||||
|
||||
from app.core.response_utils import success
|
||||
from app.dependencies import get_current_user
|
||||
@@ -20,6 +22,7 @@ from app.schemas.response_schema import ApiResponse
|
||||
from app.services.emotion_config_service import EmotionConfigService
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.db import get_db
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -32,11 +35,11 @@ router = APIRouter(
|
||||
|
||||
class EmotionConfigQuery(BaseModel):
|
||||
"""情绪配置查询请求模型"""
|
||||
config_id: int = Field(..., description="配置ID")
|
||||
config_id: UUID = Field(..., description="配置ID")
|
||||
|
||||
class EmotionConfigUpdate(BaseModel):
|
||||
"""情绪配置更新请求模型"""
|
||||
config_id: int = Field(..., description="配置ID")
|
||||
config_id: Union[uuid.UUID, int, str]= Field(..., description="配置ID")
|
||||
emotion_enabled: bool = Field(..., description="是否启用情绪提取")
|
||||
emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID")
|
||||
emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词")
|
||||
@@ -45,7 +48,7 @@ class EmotionConfigUpdate(BaseModel):
|
||||
|
||||
@router.get("/read_config", response_model=ApiResponse)
|
||||
def get_emotion_config(
|
||||
config_id: int = Query(..., description="配置ID"),
|
||||
config_id: UUID|int = Query(..., description="配置ID"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
@@ -78,7 +81,7 @@ def get_emotion_config(
|
||||
f"用户 {current_user.username} 请求获取情绪配置",
|
||||
extra={"config_id": config_id}
|
||||
)
|
||||
|
||||
config_id=resolve_config_id(config_id, db)
|
||||
# 初始化服务
|
||||
config_service = EmotionConfigService(db)
|
||||
|
||||
@@ -157,6 +160,7 @@ def update_emotion_config(
|
||||
}
|
||||
}
|
||||
"""
|
||||
config.config_id=resolve_config_id(config.config_id, db)
|
||||
try:
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求更新情绪配置",
|
||||
|
||||
@@ -11,6 +11,7 @@ Routes:
|
||||
"""
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.dependencies import get_current_user, get_db
|
||||
@@ -18,19 +19,20 @@ from app.models.user_model import User
|
||||
from app.schemas.emotion_schema import (
|
||||
EmotionHealthRequest,
|
||||
EmotionSuggestionsRequest,
|
||||
EmotionGenerateSuggestionsRequest,
|
||||
EmotionTagsRequest,
|
||||
EmotionWordcloudRequest,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status,Header
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/emotion",
|
||||
prefix="/memory/emotion-memory",
|
||||
tags=["Emotion Analysis"],
|
||||
dependencies=[Depends(get_current_user)] # 所有路由都需要认证
|
||||
)
|
||||
@@ -44,45 +46,51 @@ emotion_service = EmotionAnalyticsService()
|
||||
@router.post("/tags", response_model=ApiResponse)
|
||||
async def get_emotion_tags(
|
||||
request: EmotionTagsRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
|
||||
try:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪标签统计",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"emotion_type": request.emotion_type,
|
||||
"start_date": request.start_date,
|
||||
"end_date": request.end_date,
|
||||
"limit": request.limit
|
||||
"limit": request.limit,
|
||||
"language_type": language
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.get_emotion_tags(
|
||||
end_user_id=request.group_id,
|
||||
end_user_id=request.end_user_id,
|
||||
emotion_type=request.emotion_type,
|
||||
start_date=request.start_date,
|
||||
end_date=request.end_date,
|
||||
limit=request.limit
|
||||
limit=request.limit,
|
||||
language=language
|
||||
)
|
||||
|
||||
|
||||
api_logger.info(
|
||||
"情绪标签统计获取成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"total_count": data.get("total_count", 0),
|
||||
"tags_count": len(data.get("tags", []))
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return success(data=data, msg="情绪标签获取成功")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取情绪标签统计失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
@@ -95,40 +103,44 @@ async def get_emotion_tags(
|
||||
@router.post("/wordcloud", response_model=ApiResponse)
|
||||
async def get_emotion_wordcloud(
|
||||
request: EmotionWordcloudRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
|
||||
try:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪词云数据",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"emotion_type": request.emotion_type,
|
||||
"limit": request.limit
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.get_emotion_wordcloud(
|
||||
end_user_id=request.group_id,
|
||||
end_user_id=request.end_user_id,
|
||||
emotion_type=request.emotion_type,
|
||||
limit=request.limit
|
||||
)
|
||||
|
||||
|
||||
api_logger.info(
|
||||
"情绪词云数据获取成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"total_keywords": data.get("total_keywords", 0)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return success(data=data, msg="情绪词云获取成功")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取情绪词云数据失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
@@ -141,48 +153,52 @@ async def get_emotion_wordcloud(
|
||||
@router.post("/health", response_model=ApiResponse)
|
||||
async def get_emotion_health(
|
||||
request: EmotionHealthRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
|
||||
try:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
# 验证时间范围参数
|
||||
if request.time_range not in ["7d", "30d", "90d"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="时间范围参数无效,必须是 7d、30d 或 90d"
|
||||
)
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪健康指数",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"time_range": request.time_range
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.calculate_emotion_health_index(
|
||||
end_user_id=request.group_id,
|
||||
end_user_id=request.end_user_id,
|
||||
time_range=request.time_range
|
||||
)
|
||||
|
||||
|
||||
api_logger.info(
|
||||
"情绪健康指数获取成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"health_score": data.get("health_score", 0),
|
||||
"end_user_id": request.end_user_id,
|
||||
"health_score": data.get("health_score") or 0,
|
||||
"level": data.get("level", "未知")
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return success(data=data, msg="情绪健康指数获取成功")
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取情绪健康指数失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
@@ -192,78 +208,179 @@ async def get_emotion_health(
|
||||
|
||||
|
||||
|
||||
# @router.post("/check-data", response_model=ApiResponse)
|
||||
# async def check_emotion_data_exists(
|
||||
# request: EmotionSuggestionsRequest,
|
||||
# db: Session = Depends(get_db),
|
||||
# current_user: User = Depends(get_current_user),
|
||||
# ):
|
||||
# """检查用户情绪建议数据是否存在
|
||||
|
||||
# Args:
|
||||
# request: 包含 end_user_id
|
||||
# db: 数据库会话
|
||||
# current_user: 当前用户
|
||||
|
||||
# Returns:
|
||||
# 数据存在状态
|
||||
# """
|
||||
# try:
|
||||
# api_logger.info(
|
||||
# f"检查用户情绪建议数据是否存在: {request.end_user_id}",
|
||||
# extra={"end_user_id": request.end_user_id}
|
||||
# )
|
||||
|
||||
# # 从数据库获取建议
|
||||
# data = await emotion_service.get_cached_suggestions(
|
||||
# end_user_id=request.end_user_id,
|
||||
# db=db
|
||||
# )
|
||||
|
||||
# if data is None:
|
||||
# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据不存在")
|
||||
# return fail(
|
||||
# BizCode.NOT_FOUND,
|
||||
# "情绪建议数据不存在,请点击右上角刷新进行初始化",
|
||||
# {"exists": False}
|
||||
# )
|
||||
|
||||
# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据存在")
|
||||
# return success(data={"exists": True}, msg="情绪建议数据已存在")
|
||||
|
||||
# except Exception as e:
|
||||
# api_logger.error(
|
||||
# f"检查情绪建议数据失败: {str(e)}",
|
||||
# extra={"end_user_id": request.end_user_id},
|
||||
# exc_info=True
|
||||
# )
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
# detail=f"检查情绪建议数据失败: {str(e)}"
|
||||
# )
|
||||
|
||||
|
||||
@router.post("/suggestions", response_model=ApiResponse)
|
||||
async def get_emotion_suggestions(
|
||||
request: EmotionSuggestionsRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取个性化情绪建议
|
||||
|
||||
"""获取个性化情绪建议(从数据库读取)
|
||||
|
||||
Args:
|
||||
request: 包含 group_id 和可选的 config_id
|
||||
request: 包含 end_user_id 和可选的 config_id
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
|
||||
Returns:
|
||||
个性化情绪建议响应
|
||||
存储的个性化情绪建议响应
|
||||
"""
|
||||
try:
|
||||
# 验证 config_id(如果提供)
|
||||
# 获取终端用户关联的配置
|
||||
config_id = request.config_id
|
||||
if config_id is None:
|
||||
# 如果没有提供 config_id,尝试获取用户关联的配置
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
connected_config = get_end_user_connected_config(request.group_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
except ValueError as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "无法获取用户关联的配置", str(e))
|
||||
else:
|
||||
# 如果提供了 config_id,验证其有效性
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
try:
|
||||
config_service = MemoryConfigService(db)
|
||||
config = config_service.get_config_by_id(config_id)
|
||||
if not config:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", f"配置 {config_id} 不存在")
|
||||
except Exception as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID验证失败", str(e))
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取个性化情绪建议",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"config_id": config_id
|
||||
"end_user_id": request.end_user_id,
|
||||
"config_id": request.config_id
|
||||
}
|
||||
)
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=request.group_id,
|
||||
|
||||
# 从数据库获取建议
|
||||
data = await emotion_service.get_cached_suggestions(
|
||||
end_user_id=request.end_user_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
|
||||
if data is None:
|
||||
api_logger.info(
|
||||
f"用户 {request.end_user_id} 的建议数据不存在",
|
||||
extra={"end_user_id": request.end_user_id}
|
||||
)
|
||||
return success(
|
||||
data={"exists": False},
|
||||
msg="情绪建议数据不存在,请点击右上角刷新进行初始化"
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
"个性化建议获取成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"suggestions_count": len(data.get("suggestions", []))
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return success(data=data, msg="个性化建议获取成功")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取个性化建议失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取个性化建议失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/generate_suggestions", response_model=ApiResponse)
|
||||
async def generate_emotion_suggestions(
|
||||
request: EmotionGenerateSuggestionsRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""生成个性化情绪建议(调用LLM并保存到数据库)
|
||||
|
||||
Args:
|
||||
request: 包含 end_user_id
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
新生成的个性化情绪建议响应
|
||||
"""
|
||||
try:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求生成个性化情绪建议",
|
||||
extra={
|
||||
"end_user_id": request.end_user_id
|
||||
}
|
||||
)
|
||||
|
||||
# 调用服务层生成建议
|
||||
data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=request.end_user_id,
|
||||
db=db,
|
||||
language=language
|
||||
)
|
||||
|
||||
# 保存到数据库
|
||||
await emotion_service.save_suggestions_cache(
|
||||
end_user_id=request.end_user_id,
|
||||
suggestions_data=data,
|
||||
db=db
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
"个性化建议生成成功",
|
||||
extra={
|
||||
"end_user_id": request.end_user_id,
|
||||
"suggestions_count": len(data.get("suggestions", []))
|
||||
}
|
||||
)
|
||||
|
||||
return success(data=data, msg="个性化建议生成成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"生成个性化建议失败: {str(e)}",
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"生成个性化建议失败: {str(e)}"
|
||||
)
|
||||
@@ -1,22 +1,25 @@
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from typing import Any, Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.models import file_model
|
||||
from app.models.user_model import User
|
||||
from app.schemas import file_schema, document_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.core.response_utils import success
|
||||
from app.services import file_service, document_service
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
@@ -93,11 +96,11 @@ async def get_files(
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
return success(data=result, msg="Query of file list succeeded")
|
||||
return success(data=jsonable_encoder(result), msg="Query of file list succeeded")
|
||||
|
||||
|
||||
@router.post("/folder", response_model=ApiResponse)
|
||||
def create_folder(
|
||||
async def create_folder(
|
||||
kb_id: uuid.UUID,
|
||||
parent_id: uuid.UUID,
|
||||
folder_name: str = '/',
|
||||
@@ -121,7 +124,7 @@ def create_folder(
|
||||
)
|
||||
db_file = file_service.create_file(db=db, file=create_folder, current_user=current_user)
|
||||
api_logger.info(f"Folder created successfully: {db_file.file_name} (ID: {db_file.id})")
|
||||
return success(data=file_schema.File.model_validate(db_file), msg="Folder creation successful")
|
||||
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="Folder creation successful")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Folder creation failed: {folder_name} - {str(e)}")
|
||||
raise
|
||||
@@ -207,7 +210,7 @@ async def upload_file(
|
||||
db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)
|
||||
|
||||
api_logger.info(f"File upload successfully: {file.filename} (file_id: {db_file.id}, document_id: {db_document.id})")
|
||||
return success(data=document_schema.Document.model_validate(db_document), msg="File upload successful")
|
||||
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="File upload successful")
|
||||
|
||||
|
||||
@router.post("/customtext", response_model=ApiResponse)
|
||||
@@ -288,7 +291,7 @@ async def custom_text(
|
||||
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
|
||||
|
||||
api_logger.info(f"custom text upload successfully: {create_data.title} (file_id: {db_file.id}, document_id: {db_document.id})")
|
||||
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
|
||||
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="custom text upload successful")
|
||||
|
||||
|
||||
@router.get("/{file_id}", response_model=Any)
|
||||
@@ -362,7 +365,7 @@ async def update_file(
|
||||
# 2. Update fields (only update non-null fields)
|
||||
api_logger.debug(f"Start updating the file fields: {file_id}")
|
||||
updated_fields = []
|
||||
for field, value in update_data.items():
|
||||
for field, value in update_data.dict(exclude_unset=True).items():
|
||||
if hasattr(db_file, field):
|
||||
old_value = getattr(db_file, field)
|
||||
if old_value != value:
|
||||
@@ -387,7 +390,7 @@ async def update_file(
|
||||
)
|
||||
|
||||
# 4. Return the updated file
|
||||
return success(data=file_schema.File.model_validate(db_file), msg="File information updated successfully")
|
||||
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="File information updated successfully")
|
||||
|
||||
|
||||
@router.delete("/{file_id}", response_model=ApiResponse)
|
||||
|
||||
634
api/app/controllers/file_storage_controller.py
Normal file
634
api/app/controllers/file_storage_controller.py
Normal file
@@ -0,0 +1,634 @@
|
||||
"""
|
||||
File storage controller module.
|
||||
|
||||
This module provides API endpoints for file storage operations using the
|
||||
configurable storage backend. It is a new controller that does not modify
|
||||
the existing file_controller.py.
|
||||
|
||||
Routes:
|
||||
POST /storage/files - Upload a file
|
||||
GET /storage/files/{file_id} - Download a file
|
||||
DELETE /storage/files/{file_id} - Delete a file
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
||||
from fastapi.responses import FileResponse, RedirectResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.core.storage import LocalStorage
|
||||
from app.core.storage.url_signer import generate_signed_url, verify_signed_url
|
||||
from app.core.storage_exceptions import (
|
||||
StorageDeleteError,
|
||||
StorageUploadError,
|
||||
)
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, get_share_user_id, ShareTokenData
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.file_storage_service import (
|
||||
FileStorageService,
|
||||
generate_file_key,
|
||||
get_file_storage_service,
|
||||
)
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/storage",
|
||||
tags=["storage"]
|
||||
)
|
||||
|
||||
|
||||
@router.post("/files", response_model=ApiResponse)
|
||||
async def upload_file(
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
):
|
||||
"""
|
||||
Upload a file to the configured storage backend.
|
||||
"""
|
||||
tenant_id = current_user.tenant_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
api_logger.info(
|
||||
f"Storage upload request: tenant_id={tenant_id}, workspace_id={workspace_id}, "
|
||||
f"filename={file.filename}, username={current_user.username}"
|
||||
)
|
||||
|
||||
# Read file contents
|
||||
contents = await file.read()
|
||||
file_size = len(contents)
|
||||
|
||||
# Validate file size
|
||||
if file_size == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The file is empty."
|
||||
)
|
||||
|
||||
if file_size > settings.MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
|
||||
)
|
||||
|
||||
# Extract file extension
|
||||
_, file_extension = os.path.splitext(file.filename)
|
||||
file_ext = file_extension.lower()
|
||||
|
||||
# Generate file_id and file_key
|
||||
file_id = uuid.uuid4()
|
||||
file_key = generate_file_key(
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_id=file_id,
|
||||
file_ext=file_ext,
|
||||
)
|
||||
|
||||
# Create file metadata record with pending status
|
||||
file_metadata = FileMetadata(
|
||||
id=file_id,
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_key=file_key,
|
||||
file_name=file.filename,
|
||||
file_ext=file_ext,
|
||||
file_size=file_size,
|
||||
content_type=file.content_type,
|
||||
status="pending",
|
||||
)
|
||||
db.add(file_metadata)
|
||||
db.commit()
|
||||
db.refresh(file_metadata)
|
||||
|
||||
# Upload file to storage backend
|
||||
try:
|
||||
await storage_service.upload_file(
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_id=file_id,
|
||||
file_ext=file_ext,
|
||||
content=contents,
|
||||
content_type=file.content_type,
|
||||
)
|
||||
# Update status to completed
|
||||
file_metadata.status = "completed"
|
||||
db.commit()
|
||||
api_logger.info(f"File uploaded to storage: file_key={file_key}")
|
||||
except StorageUploadError as e:
|
||||
# Update status to failed
|
||||
file_metadata.status = "failed"
|
||||
db.commit()
|
||||
api_logger.error(f"Storage upload failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"File storage failed: {str(e)}"
|
||||
)
|
||||
|
||||
api_logger.info(f"File upload successful: {file.filename} (file_id: {file_id})")
|
||||
|
||||
return success(
|
||||
data={"file_id": str(file_id), "file_key": file_key},
|
||||
msg="File upload successful"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/share/files", response_model=ApiResponse)
|
||||
async def upload_file_with_share_token(
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
):
|
||||
"""
|
||||
Upload a file to the configured storage backend using share_token authentication.
|
||||
"""
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.models.app_model import App
|
||||
from app.models.workspace_model import Workspace
|
||||
|
||||
# Get share and release info from share_token
|
||||
service = ReleaseShareService(db)
|
||||
share_info = service.get_shared_release_info(share_token=share_data.share_token)
|
||||
|
||||
# Get share object to access app_id
|
||||
share = service.repo.get_by_share_token(share_data.share_token)
|
||||
if not share:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Shared app not found"
|
||||
)
|
||||
|
||||
# Get app to access workspace_id
|
||||
app = db.query(App).filter(
|
||||
App.id == share.app_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="App not found"
|
||||
)
|
||||
|
||||
# Get workspace to access tenant_id
|
||||
workspace = db.query(Workspace).filter(
|
||||
Workspace.id == app.workspace_id
|
||||
).first()
|
||||
|
||||
if not workspace:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Workspace not found"
|
||||
)
|
||||
|
||||
tenant_id = workspace.tenant_id
|
||||
workspace_id = app.workspace_id
|
||||
|
||||
api_logger.info(
|
||||
f"Storage upload request (share): tenant_id={tenant_id}, workspace_id={workspace_id}, "
|
||||
f"filename={file.filename}, share_token={share_data.share_token}"
|
||||
)
|
||||
|
||||
# Read file contents
|
||||
contents = await file.read()
|
||||
file_size = len(contents)
|
||||
|
||||
# Validate file size
|
||||
if file_size == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The file is empty."
|
||||
)
|
||||
|
||||
if file_size > settings.MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
|
||||
)
|
||||
|
||||
# Extract file extension
|
||||
_, file_extension = os.path.splitext(file.filename)
|
||||
file_ext = file_extension.lower()
|
||||
|
||||
# Generate file_id and file_key
|
||||
file_id = uuid.uuid4()
|
||||
file_key = generate_file_key(
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_id=file_id,
|
||||
file_ext=file_ext,
|
||||
)
|
||||
|
||||
# Create file metadata record with pending status
|
||||
file_metadata = FileMetadata(
|
||||
id=file_id,
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_key=file_key,
|
||||
file_name=file.filename,
|
||||
file_ext=file_ext,
|
||||
file_size=file_size,
|
||||
content_type=file.content_type,
|
||||
status="pending",
|
||||
)
|
||||
db.add(file_metadata)
|
||||
db.commit()
|
||||
db.refresh(file_metadata)
|
||||
|
||||
# Upload file to storage backend
|
||||
try:
|
||||
await storage_service.upload_file(
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_id=file_id,
|
||||
file_ext=file_ext,
|
||||
content=contents,
|
||||
content_type=file.content_type,
|
||||
)
|
||||
# Update status to completed
|
||||
file_metadata.status = "completed"
|
||||
db.commit()
|
||||
api_logger.info(f"File uploaded to storage (share): file_key={file_key}")
|
||||
except StorageUploadError as e:
|
||||
# Update status to failed
|
||||
file_metadata.status = "failed"
|
||||
db.commit()
|
||||
api_logger.error(f"Storage upload failed (share): {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"File storage failed: {str(e)}"
|
||||
)
|
||||
|
||||
api_logger.info(f"File upload successful (share): {file.filename} (file_id: {file_id})")
|
||||
|
||||
return success(
|
||||
data={"file_id": str(file_id), "file_key": file_key},
|
||||
msg="File upload successful"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/{file_id}", response_model=Any)
|
||||
async def download_file(
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
) -> Any:
|
||||
"""
|
||||
Download a file from the configured storage backend.
|
||||
"""
|
||||
api_logger.info(f"Storage download request: file_id={file_id}")
|
||||
|
||||
# Query file metadata from database
|
||||
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||
if not file_metadata:
|
||||
api_logger.warning(f"File not found in database: file_id={file_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The file does not exist"
|
||||
)
|
||||
|
||||
if file_metadata.status != "completed":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"File upload not completed, status: {file_metadata.status}"
|
||||
)
|
||||
|
||||
file_key = file_metadata.file_key
|
||||
storage = storage_service.storage
|
||||
|
||||
if isinstance(storage, LocalStorage):
|
||||
full_path = storage._get_full_path(file_key)
|
||||
|
||||
if not full_path.exists():
|
||||
api_logger.warning(f"File not found on disk: file_key={file_key}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="File not found (possibly deleted)"
|
||||
)
|
||||
|
||||
api_logger.info(f"Serving local file: file_key={file_key}")
|
||||
return FileResponse(
|
||||
path=str(full_path),
|
||||
filename=file_metadata.file_name,
|
||||
media_type=file_metadata.content_type or "application/octet-stream"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
||||
api_logger.info(f"Redirecting to presigned URL: file_key={file_key}")
|
||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||
except FileNotFoundError:
|
||||
api_logger.warning(f"File not found in remote storage: file_key={file_key}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="File not found in storage"
|
||||
)
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve file: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/files/{file_id}", response_model=ApiResponse)
|
||||
async def delete_file(
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
):
|
||||
"""
|
||||
Delete a file from the configured storage backend.
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Storage delete request: file_id={file_id}, username={current_user.username}"
|
||||
)
|
||||
|
||||
# Query file metadata from database
|
||||
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||
if not file_metadata:
|
||||
api_logger.warning(f"File not found in database: file_id={file_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The file does not exist"
|
||||
)
|
||||
|
||||
file_key = file_metadata.file_key
|
||||
|
||||
# Delete file from storage
|
||||
try:
|
||||
deleted = await storage_service.delete_file(file_key)
|
||||
if deleted:
|
||||
api_logger.info(f"File deleted from storage: file_key={file_key}")
|
||||
else:
|
||||
api_logger.info(f"File did not exist in storage: file_key={file_key}")
|
||||
except StorageDeleteError as e:
|
||||
api_logger.error(f"Storage delete failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to delete file from storage: {str(e)}"
|
||||
)
|
||||
|
||||
# Delete database record
|
||||
try:
|
||||
db.delete(file_metadata)
|
||||
db.commit()
|
||||
api_logger.info(f"File record deleted from database: file_id={file_id}")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Database delete failed: {e}")
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to delete file record: {str(e)}"
|
||||
)
|
||||
|
||||
return success(msg="File deleted successfully")
|
||||
|
||||
|
||||
@router.get("/files/{file_id}/url", response_model=ApiResponse)
|
||||
async def get_file_url(
|
||||
file_id: uuid.UUID,
|
||||
expires: int = None,
|
||||
permanent: bool = False,
|
||||
db: Session = Depends(get_db),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
):
|
||||
"""
|
||||
Get an access URL for a file (no authentication required).
|
||||
|
||||
Args:
|
||||
file_id: The UUID of the file.
|
||||
expires: URL validity period in seconds (default from FILE_URL_EXPIRES env).
|
||||
permanent: If True, return a permanent URL without expiration.
|
||||
db: Database session.
|
||||
storage_service: The file storage service.
|
||||
|
||||
Returns:
|
||||
ApiResponse with the access URL.
|
||||
"""
|
||||
if expires is None:
|
||||
expires = settings.FILE_URL_EXPIRES
|
||||
|
||||
api_logger.info(f"Get file URL request: file_id={file_id}, expires={expires}, permanent={permanent}")
|
||||
|
||||
# Query file metadata from database
|
||||
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||
if not file_metadata:
|
||||
api_logger.warning(f"File not found in database: file_id={file_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The file does not exist"
|
||||
)
|
||||
|
||||
if file_metadata.status != "completed":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"File upload not completed, status: {file_metadata.status}"
|
||||
)
|
||||
|
||||
file_key = file_metadata.file_key
|
||||
storage = storage_service.storage
|
||||
|
||||
try:
|
||||
if permanent:
|
||||
# Generate permanent URL (no expiration check)
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
url = f"{server_url}/storage/permanent/{file_id}"
|
||||
return success(
|
||||
data={
|
||||
"url": url,
|
||||
"expires_in": None,
|
||||
"permanent": True,
|
||||
"file_name": file_metadata.file_name,
|
||||
},
|
||||
msg="Permanent file URL generated successfully"
|
||||
)
|
||||
|
||||
if isinstance(storage, LocalStorage):
|
||||
# For local storage, generate signed URL with expiration
|
||||
url = generate_signed_url(str(file_id), expires)
|
||||
else:
|
||||
# For remote storage (OSS/S3), get presigned URL
|
||||
url = await storage_service.get_file_url(file_key, expires=expires)
|
||||
|
||||
api_logger.info(f"Generated file URL: file_id={file_id}")
|
||||
return success(
|
||||
data={
|
||||
"url": url,
|
||||
"expires_in": expires,
|
||||
"permanent": False,
|
||||
"file_name": file_metadata.file_name,
|
||||
},
|
||||
msg="File URL generated successfully"
|
||||
)
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to generate file URL: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to generate file URL: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/public/{file_id}", response_model=Any)
|
||||
async def public_download_file(
|
||||
file_id: uuid.UUID,
|
||||
expires: int = 0,
|
||||
signature: str = "",
|
||||
db: Session = Depends(get_db),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
) -> Any:
|
||||
"""
|
||||
Public file download endpoint with signature verification.
|
||||
|
||||
This endpoint allows downloading files without authentication,
|
||||
but requires a valid signature and non-expired timestamp.
|
||||
|
||||
Args:
|
||||
file_id: The UUID of the file.
|
||||
expires: Expiration timestamp.
|
||||
signature: HMAC signature for verification.
|
||||
db: Database session.
|
||||
storage_service: The file storage service.
|
||||
|
||||
Returns:
|
||||
FileResponse for the requested file.
|
||||
"""
|
||||
api_logger.info(f"Public download request: file_id={file_id}")
|
||||
|
||||
# Verify signature
|
||||
is_valid, error_msg = verify_signed_url(str(file_id), expires, signature)
|
||||
if not is_valid:
|
||||
api_logger.warning(f"Invalid signed URL: file_id={file_id}, error={error_msg}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=error_msg
|
||||
)
|
||||
|
||||
# Query file metadata from database
|
||||
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||
if not file_metadata:
|
||||
api_logger.warning(f"File not found in database: file_id={file_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The file does not exist"
|
||||
)
|
||||
|
||||
if file_metadata.status != "completed":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"File upload not completed, status: {file_metadata.status}"
|
||||
)
|
||||
|
||||
file_key = file_metadata.file_key
|
||||
storage = storage_service.storage
|
||||
|
||||
if isinstance(storage, LocalStorage):
|
||||
full_path = storage._get_full_path(file_key)
|
||||
|
||||
if not full_path.exists():
|
||||
api_logger.warning(f"File not found on disk: file_key={file_key}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="File not found"
|
||||
)
|
||||
|
||||
api_logger.info(f"Serving public file: file_key={file_key}")
|
||||
return FileResponse(
|
||||
path=str(full_path),
|
||||
filename=file_metadata.file_name,
|
||||
media_type=file_metadata.content_type or "application/octet-stream"
|
||||
)
|
||||
else:
|
||||
# For remote storage, redirect to presigned URL
|
||||
try:
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve file: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/permanent/{file_id}", response_model=Any)
|
||||
async def permanent_download_file(
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
) -> Any:
|
||||
"""
|
||||
Permanent file download endpoint (no expiration, no signature required).
|
||||
|
||||
This endpoint allows downloading files without authentication or expiration.
|
||||
Use with caution as URLs are permanently accessible.
|
||||
|
||||
Args:
|
||||
file_id: The UUID of the file.
|
||||
db: Database session.
|
||||
storage_service: The file storage service.
|
||||
|
||||
Returns:
|
||||
FileResponse for the requested file.
|
||||
"""
|
||||
api_logger.info(f"Permanent download request: file_id={file_id}")
|
||||
|
||||
# Query file metadata from database
|
||||
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||
if not file_metadata:
|
||||
api_logger.warning(f"File not found in database: file_id={file_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The file does not exist"
|
||||
)
|
||||
|
||||
if file_metadata.status != "completed":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"File upload not completed, status: {file_metadata.status}"
|
||||
)
|
||||
|
||||
file_key = file_metadata.file_key
|
||||
storage = storage_service.storage
|
||||
|
||||
if isinstance(storage, LocalStorage):
|
||||
full_path = storage._get_full_path(file_key)
|
||||
|
||||
if not full_path.exists():
|
||||
api_logger.warning(f"File not found on disk: file_key={file_key}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="File not found"
|
||||
)
|
||||
|
||||
api_logger.info(f"Serving permanent file: file_key={file_key}")
|
||||
return FileResponse(
|
||||
path=str(full_path),
|
||||
filename=file_metadata.file_name,
|
||||
media_type=file_metadata.content_type or "application/octet-stream"
|
||||
)
|
||||
else:
|
||||
# For remote storage, redirect to presigned URL with long expiration
|
||||
try:
|
||||
# Use a very long expiration (7 days max for most cloud providers)
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=604800)
|
||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve file: {str(e)}"
|
||||
)
|
||||
44
api/app/controllers/home_page_controller.py
Normal file
44
api/app/controllers/home_page_controller.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.home_page_service import HomePageService
|
||||
|
||||
router = APIRouter(prefix="/home-page", tags=["Home Page"])
|
||||
|
||||
@router.get("/statistics", response_model=ApiResponse)
|
||||
def get_home_statistics(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取首页统计数据"""
|
||||
statistics = HomePageService.get_home_statistics(db, current_user.tenant_id)
|
||||
return success(data=statistics, msg="统计数据获取成功")
|
||||
|
||||
@router.get("/workspaces", response_model=ApiResponse)
|
||||
def get_workspace_list(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取工作空间列表"""
|
||||
workspace_list = HomePageService.get_workspace_list(db, current_user.tenant_id)
|
||||
return success(data=workspace_list, msg="工作空间列表获取成功")
|
||||
|
||||
@router.get("/version", response_model=ApiResponse)
|
||||
def get_system_version():
|
||||
"""获取系统版本号+说明"""
|
||||
current_version = settings.SYSTEM_VERSION
|
||||
version_info = HomePageService.load_version_introduction(current_version)
|
||||
return success(
|
||||
data={
|
||||
"version": current_version,
|
||||
"introduction": version_info.get("introduction"),
|
||||
"introduction_en": version_info.get("introduction_en")
|
||||
},
|
||||
msg="系统版本获取成功"
|
||||
)
|
||||
457
api/app/controllers/implicit_memory_controller.py
Normal file
457
api/app/controllers/implicit_memory_controller.py
Normal file
@@ -0,0 +1,457 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import (
|
||||
cur_workspace_access_guard,
|
||||
get_current_user,
|
||||
)
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.implicit_memory_schema import GenerateProfileRequest
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/implicit-memory",
|
||||
tags=["Implicit Memory"],
|
||||
)
|
||||
|
||||
|
||||
def handle_implicit_memory_error(e: Exception, operation: str, user_id: str = None) -> dict:
|
||||
"""
|
||||
Centralized error handling for implicit memory operations.
|
||||
|
||||
Args:
|
||||
e: The exception that occurred
|
||||
operation: Description of the operation that failed
|
||||
user_id: Optional user ID for logging context
|
||||
|
||||
Returns:
|
||||
Standardized error response
|
||||
"""
|
||||
error_context = f"user_id={user_id}" if user_id else "unknown user"
|
||||
|
||||
if isinstance(e, ValueError):
|
||||
if "user" in str(e).lower() and "not found" in str(e).lower():
|
||||
api_logger.warning(f"Invalid user ID for {operation}: {error_context}")
|
||||
return fail(BizCode.INVALID_USER_ID, "无效的用户ID", str(e))
|
||||
elif "insufficient" in str(e).lower() or "no data" in str(e).lower():
|
||||
api_logger.warning(f"Insufficient data for {operation}: {error_context}")
|
||||
return fail(BizCode.INSUFFICIENT_DATA, "数据不足,无法进行分析", str(e))
|
||||
else:
|
||||
api_logger.warning(f"Invalid parameters for {operation}: {error_context}")
|
||||
return fail(BizCode.INVALID_FILTER_PARAMS, "无效的参数", str(e))
|
||||
|
||||
elif isinstance(e, KeyError):
|
||||
api_logger.warning(f"Missing required data for {operation}: {error_context}")
|
||||
return fail(BizCode.INSUFFICIENT_DATA, "缺少必要的数据", str(e))
|
||||
|
||||
elif isinstance(e, (ConnectionError, TimeoutError)):
|
||||
api_logger.error(f"Service unavailable for {operation}: {error_context}")
|
||||
return fail(BizCode.SERVICE_UNAVAILABLE, "服务暂时不可用", str(e))
|
||||
|
||||
elif "analysis" in str(e).lower() or "llm" in str(e).lower():
|
||||
api_logger.error(f"Analysis failed for {operation}: {error_context}", exc_info=True)
|
||||
return fail(BizCode.ANALYSIS_FAILED, "分析处理失败", str(e))
|
||||
|
||||
elif "storage" in str(e).lower() or "database" in str(e).lower():
|
||||
api_logger.error(f"Storage error for {operation}: {error_context}", exc_info=True)
|
||||
return fail(BizCode.PROFILE_STORAGE_ERROR, "数据存储失败", str(e))
|
||||
|
||||
else:
|
||||
api_logger.error(f"Unexpected error for {operation}: {error_context}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, f"{operation}失败", str(e))
|
||||
|
||||
|
||||
def validate_user_id(user_id: str) -> None:
|
||||
"""
|
||||
Validate user ID format and constraints.
|
||||
|
||||
Args:
|
||||
user_id: User ID to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If user ID is invalid
|
||||
"""
|
||||
if not user_id or not user_id.strip():
|
||||
raise ValueError("User ID cannot be empty")
|
||||
|
||||
if len(user_id.strip()) < 1:
|
||||
raise ValueError("User ID is too short")
|
||||
|
||||
|
||||
def validate_date_range(start_date: Optional[datetime], end_date: Optional[datetime]) -> None:
|
||||
"""
|
||||
Validate date range parameters.
|
||||
|
||||
Args:
|
||||
start_date: Start date
|
||||
end_date: End date
|
||||
|
||||
Raises:
|
||||
ValueError: If date range is invalid
|
||||
"""
|
||||
if (start_date and not end_date) or (end_date and not start_date):
|
||||
raise ValueError("Both start_date and end_date must be provided together")
|
||||
|
||||
if start_date and end_date and start_date >= end_date:
|
||||
raise ValueError("start_date must be before end_date")
|
||||
|
||||
if start_date and start_date > datetime.now():
|
||||
raise ValueError("start_date cannot be in the future")
|
||||
|
||||
|
||||
def validate_confidence_threshold(threshold: float) -> None:
|
||||
"""
|
||||
Validate confidence threshold parameter.
|
||||
|
||||
Args:
|
||||
threshold: Confidence threshold to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If threshold is invalid
|
||||
"""
|
||||
if not 0.0 <= threshold <= 1.0:
|
||||
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
|
||||
|
||||
|
||||
@router.get("/check-data/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def check_user_data_exists(
|
||||
end_user_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
检查用户画像数据是否存在
|
||||
|
||||
Args:
|
||||
end_user_id: 目标用户ID
|
||||
|
||||
Returns:
|
||||
数据存在状态
|
||||
"""
|
||||
api_logger.info(f"检查用户画像数据是否存在: {end_user_id}")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return success(
|
||||
data={"exists": False},
|
||||
msg="画像数据不存在,请点击右上角刷新进行初始化"
|
||||
)
|
||||
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据存在")
|
||||
return success(data={"exists": True}, msg="画像数据已存在")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "检查画像数据", end_user_id)
|
||||
|
||||
|
||||
@router.get("/preferences/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_preference_tags(
|
||||
end_user_id: str,
|
||||
confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"),
|
||||
tag_category: Optional[str] = Query(None, description="Filter by tag category"),
|
||||
start_date: Optional[datetime] = Query(None, description="Filter start date"),
|
||||
end_date: Optional[datetime] = Query(None, description="Filter end date"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
Get user preference tags from cache.
|
||||
|
||||
Args:
|
||||
end_user_id: Target end user ID
|
||||
confidence_threshold: Minimum confidence score (0.0-1.0)
|
||||
tag_category: Optional category filter
|
||||
start_date: Optional start date filter
|
||||
end_date: Optional end date filter
|
||||
|
||||
Returns:
|
||||
List of preference tags from cache
|
||||
"""
|
||||
api_logger.info(f"Preference tags requested for user: {end_user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract preferences from cache
|
||||
preferences = cached_profile.get("preferences", [])
|
||||
|
||||
# Apply filters (client-side filtering on cached data)
|
||||
filtered_preferences = []
|
||||
for pref in preferences:
|
||||
# Filter by confidence threshold
|
||||
if confidence_threshold is not None and pref.get("confidence_score", 0) < confidence_threshold:
|
||||
continue
|
||||
|
||||
# Filter by category if specified
|
||||
if tag_category and pref.get("category") != tag_category:
|
||||
continue
|
||||
|
||||
# Filter by date range if specified
|
||||
if start_date or end_date:
|
||||
created_at_ts = pref.get("created_at")
|
||||
if created_at_ts:
|
||||
created_at = datetime.fromtimestamp(created_at_ts / 1000)
|
||||
if start_date and created_at < start_date:
|
||||
continue
|
||||
if end_date and created_at > end_date:
|
||||
continue
|
||||
|
||||
filtered_preferences.append(pref)
|
||||
|
||||
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {end_user_id} (from cache)")
|
||||
return success(data=filtered_preferences, msg="偏好标签获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "偏好标签获取", end_user_id)
|
||||
|
||||
|
||||
@router.get("/portrait/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_dimension_portrait(
|
||||
end_user_id: str,
|
||||
include_history: bool = Query(False, description="Include historical trends"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
Get user's four-dimension personality portrait from cache.
|
||||
|
||||
Args:
|
||||
end_user_id: Target end user ID
|
||||
include_history: Whether to include historical trend data (ignored for cached data)
|
||||
|
||||
Returns:
|
||||
Four-dimension personality portrait from cache
|
||||
"""
|
||||
api_logger.info(f"Dimension portrait requested for user: {end_user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract portrait from cache
|
||||
portrait = cached_profile.get("portrait", {})
|
||||
|
||||
api_logger.info(f"Dimension portrait retrieved for user: {end_user_id} (from cache)")
|
||||
return success(data=portrait, msg="四维画像获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "四维画像获取", end_user_id)
|
||||
|
||||
|
||||
@router.get("/interest-areas/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_interest_area_distribution(
|
||||
end_user_id: str,
|
||||
include_trends: bool = Query(False, description="Include trend analysis"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
Get user's interest area distribution from cache.
|
||||
|
||||
Args:
|
||||
end_user_id: Target end user ID
|
||||
include_trends: Whether to include trend analysis data (ignored for cached data)
|
||||
|
||||
Returns:
|
||||
Interest area distribution from cache
|
||||
"""
|
||||
api_logger.info(f"Interest area distribution requested for user: {end_user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract interest areas from cache
|
||||
interest_areas = cached_profile.get("interest_areas", {})
|
||||
|
||||
api_logger.info(f"Interest area distribution retrieved for user: {end_user_id} (from cache)")
|
||||
return success(data=interest_areas, msg="兴趣领域分布获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "兴趣领域分布获取", end_user_id)
|
||||
|
||||
|
||||
@router.get("/habits/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_behavior_habits(
|
||||
end_user_id: str,
|
||||
confidence_level: Optional[str] = Query(None, regex="^(high|medium|low)$", description="Filter by confidence level"),
|
||||
frequency_pattern: Optional[str] = Query(None, regex="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$", description="Filter by frequency pattern"),
|
||||
time_period: Optional[str] = Query(None, regex="^(current|past)$", description="Filter by time period"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
Get user's behavioral habits from cache.
|
||||
|
||||
Args:
|
||||
end_user_id: Target end user ID
|
||||
confidence_level: Filter by confidence level (high, medium, low)
|
||||
frequency_pattern: Filter by frequency pattern (daily, weekly, monthly, seasonal, occasional, event_triggered)
|
||||
time_period: Filter by time period (current, past)
|
||||
|
||||
Returns:
|
||||
List of behavioral habits from cache
|
||||
"""
|
||||
api_logger.info(f"Behavior habits requested for user: {end_user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract habits from cache
|
||||
habits = cached_profile.get("habits", [])
|
||||
|
||||
# Apply filters (client-side filtering on cached data)
|
||||
filtered_habits = []
|
||||
for habit in habits:
|
||||
# Filter by confidence level
|
||||
if confidence_level:
|
||||
confidence_mapping = {
|
||||
"high": 85,
|
||||
"medium": 50,
|
||||
"low": 20
|
||||
}
|
||||
numerical_confidence = confidence_mapping.get(confidence_level.lower())
|
||||
if habit.get("confidence_level", 0) < numerical_confidence:
|
||||
continue
|
||||
|
||||
# Filter by frequency pattern
|
||||
if frequency_pattern and habit.get("frequency_pattern") != frequency_pattern:
|
||||
continue
|
||||
|
||||
# Filter by time period
|
||||
if time_period:
|
||||
is_current = habit.get("is_current", True)
|
||||
if time_period.lower() == "current" and not is_current:
|
||||
continue
|
||||
elif time_period.lower() == "past" and is_current:
|
||||
continue
|
||||
|
||||
filtered_habits.append(habit)
|
||||
|
||||
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {end_user_id} (from cache)")
|
||||
return success(data=filtered_habits, msg="行为习惯获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "行为习惯获取", end_user_id)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post("/generate_profile", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def generate_implicit_memory_profile(
|
||||
request: GenerateProfileRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
Generate complete user profile (all 4 modules) and cache it.
|
||||
|
||||
Args:
|
||||
request: Generate profile request with end_user_id
|
||||
db: Database session
|
||||
current_user: Current authenticated user
|
||||
|
||||
Returns:
|
||||
Complete user profile with all modules
|
||||
"""
|
||||
end_user_id = request.end_user_id
|
||||
api_logger.info(f"Generate profile requested for user: {end_user_id}")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Generate complete profile (calls LLM for all 4 modules)
|
||||
api_logger.info(f"开始生成完整用户画像: user={end_user_id}")
|
||||
profile_data = await service.generate_complete_profile(user_id=end_user_id)
|
||||
|
||||
# Save to cache
|
||||
await service.save_profile_cache(
|
||||
end_user_id=end_user_id,
|
||||
profile_data=profile_data,
|
||||
db=db,
|
||||
expires_hours=168 # 7 days
|
||||
)
|
||||
|
||||
api_logger.info(f"用户画像生成并缓存成功: user={end_user_id}")
|
||||
|
||||
# Add metadata
|
||||
profile_data["end_user_id"] = end_user_id
|
||||
profile_data["cached"] = False
|
||||
|
||||
return success(data=profile_data, msg="用户画像生成成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"生成用户画像失败: user={end_user_id}, error={str(e)}", exc_info=True)
|
||||
return handle_implicit_memory_error(e, "用户画像生成", end_user_id)
|
||||
@@ -1,26 +1,32 @@
|
||||
from typing import Optional
|
||||
import datetime
|
||||
import json
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.models import knowledge_model, document_model, file_model
|
||||
from app.schemas import knowledge_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.core.response_utils import success
|
||||
from app.services import knowledge_service, document_service
|
||||
from app.celery_app import celery_app
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.rag.common import settings
|
||||
from app.core.rag.integrations.feishu.client import FeishuAPIClient
|
||||
from app.core.rag.integrations.yuque.client import YuqueAPIClient
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.nlp import rag_tokenizer, search
|
||||
from app.core.rag.prompts.generator import graph_entity_types
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.rag.nlp import rag_tokenizer, search
|
||||
from app.core.rag.common import settings
|
||||
from app.celery_app import celery_app
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import knowledge_model
|
||||
from app.models.user_model import User
|
||||
from app.schemas import knowledge_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import knowledge_service, document_service
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
@@ -47,6 +53,45 @@ def get_parser_types():
|
||||
return success(msg="Successfully obtained the knowledge parser type", data=list(knowledge_model.ParserType))
|
||||
|
||||
|
||||
@router.get("/knowledge_graph_entity_types", response_model=ApiResponse)
|
||||
async def get_knowledge_graph_entity_types(
|
||||
llm_id: uuid.UUID,
|
||||
scenario: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
get knowledge graph entity types based on llm_id
|
||||
"""
|
||||
api_logger.info(f"Obtain details of the knowledge graph: llm_id={llm_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Check whether the model exists
|
||||
api_logger.debug(f"Check whether the model exists: {llm_id}")
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=llm_id)
|
||||
|
||||
if not config:
|
||||
api_logger.warning(
|
||||
f"The model does not exist or you do not have permission to access it: llm_id={llm_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The model does not exist or you do not have permission to access it"
|
||||
)
|
||||
# 2. Prepare to configure chat_mdl information
|
||||
chat_model = Base(
|
||||
key=config.api_keys[0].api_key,
|
||||
model_name=config.api_keys[0].model_name,
|
||||
base_url=config.api_keys[0].api_base
|
||||
)
|
||||
response = graph_entity_types(chat_model, scenario)
|
||||
return success(data=response, msg="Successfully obtained knowledge graph entity types")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"get knowledge graph entity types failed: llm_id={llm_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/knowledges", response_model=ApiResponse)
|
||||
async def get_knowledges(
|
||||
parent_id: Optional[uuid.UUID] = Query(None, description="parent folder id"),
|
||||
@@ -130,7 +175,7 @@ async def get_knowledges(
|
||||
"has_next": True if page*pagesize < total else False
|
||||
}
|
||||
}
|
||||
return success(data=result, msg="Query of knowledge base list successful")
|
||||
return success(data=jsonable_encoder(result), msg="Query of knowledge base list successful")
|
||||
|
||||
|
||||
@router.post("/knowledge", response_model=ApiResponse)
|
||||
@@ -156,7 +201,7 @@ async def create_knowledge(
|
||||
)
|
||||
db_knowledge = knowledge_service.create_knowledge(db=db, knowledge=create_data, current_user=current_user)
|
||||
api_logger.info(f"The knowledge base has been successfully created: {db_knowledge.name} (ID: {db_knowledge.id})")
|
||||
return success(data=knowledge_schema.Knowledge.model_validate(db_knowledge), msg="The knowledge base has been successfully created")
|
||||
return success(data=jsonable_encoder(knowledge_schema.Knowledge.model_validate(db_knowledge)), msg="The knowledge base has been successfully created")
|
||||
except Exception as e:
|
||||
api_logger.error(f"The creation of the knowledge base failed: {create_data.name} - {str(e)}")
|
||||
raise
|
||||
@@ -185,7 +230,7 @@ async def get_knowledge(
|
||||
)
|
||||
|
||||
api_logger.info(f"Knowledge base query successful: {db_knowledge.name} (ID: {db_knowledge.id})")
|
||||
return success(data=knowledge_schema.Knowledge.model_validate(db_knowledge), msg="Successfully obtained knowledge base information")
|
||||
return success(data=jsonable_encoder(knowledge_schema.Knowledge.model_validate(db_knowledge)), msg="Successfully obtained knowledge base information")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -202,7 +247,7 @@ async def update_knowledge(
|
||||
):
|
||||
api_logger.info(f"Update knowledge base request: knowledge_id={knowledge_id}, username: {current_user.username}")
|
||||
db_knowledge = await _update_knowledge(knowledge_id=knowledge_id, update_data=update_data, db=db, current_user=current_user)
|
||||
return success(data=knowledge_schema.Knowledge.model_validate(db_knowledge), msg="The knowledge base information has been successfully updated")
|
||||
return success(data=jsonable_encoder(knowledge_schema.Knowledge.model_validate(db_knowledge)), msg="The knowledge base information has been successfully updated")
|
||||
|
||||
|
||||
async def _update_knowledge(
|
||||
@@ -379,7 +424,7 @@ async def delete_knowledge_graph(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Soft-delete knowledge graph
|
||||
delete knowledge graph
|
||||
"""
|
||||
api_logger.info(f"Request to delete knowledge graph: knowledge_id={knowledge_id}, username: {current_user.username}")
|
||||
|
||||
@@ -444,40 +489,97 @@ async def rebuild_knowledge_graph(
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/{knowledge_id}/knowledge_graph_entity_types", response_model=ApiResponse)
|
||||
async def get_knowledge_graph_entity_types(
|
||||
knowledge_id: uuid.UUID,
|
||||
scenario: str,
|
||||
@router.get("/check/yuque/auth", response_model=ApiResponse)
|
||||
async def check_yuque_auth(
|
||||
yuque_user_id: str,
|
||||
yuque_token: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
get knowledge graph entity types based on knowledge_id
|
||||
check yuque auth info
|
||||
"""
|
||||
api_logger.info(f"Obtain details of the knowledge graph: knowledge_id={knowledge_id}, username: {current_user.username}")
|
||||
api_logger.info(f"check yuque auth info, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Check whether the knowledge base exists
|
||||
api_logger.debug(f"Check whether the knowledge base exists: {knowledge_id}")
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
|
||||
|
||||
if not db_knowledge:
|
||||
api_logger.warning(
|
||||
f"The knowledge base does not exist or you do not have permission to access it: knowledge_id={knowledge_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The knowledge base does not exist or you do not have permission to access it"
|
||||
)
|
||||
# 2. Prepare to configure chat_mdl information
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
||||
api_client = YuqueAPIClient(
|
||||
user_id=yuque_user_id,
|
||||
token=yuque_token
|
||||
)
|
||||
response = graph_entity_types(chat_model, scenario)
|
||||
return success(data=response, msg="Successfully obtained knowledge graph entity types")
|
||||
async with api_client as client:
|
||||
repos = await client.get_user_repos()
|
||||
if repos:
|
||||
return success(msg="Successfully auth yuque info")
|
||||
return fail(BizCode.UNAUTHORIZED, msg="auth yuque info failed", error="user_id or token is incorrect")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"get knowledge graph entity types failed: knowledge_id={knowledge_id} - {str(e)}")
|
||||
api_logger.error(f"auth yuque info failed: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/check/feishu/auth", response_model=ApiResponse)
|
||||
async def check_feishu_auth(
|
||||
feishu_app_id: str,
|
||||
feishu_app_secret: str,
|
||||
feishu_folder_token: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
check feishu auth info
|
||||
"""
|
||||
api_logger.info(f"check feishu auth info, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_client = FeishuAPIClient(
|
||||
app_id=feishu_app_id,
|
||||
app_secret=feishu_app_secret
|
||||
)
|
||||
async with api_client as client:
|
||||
files = await client.list_all_folder_files(feishu_folder_token, recursive=True)
|
||||
if files:
|
||||
return success(msg="Successfully auth feishu info")
|
||||
return fail(BizCode.UNAUTHORIZED, msg="auth feishu info failed", error="app_id or app_secret or feishu_folder_token is incorrect")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"auth feishu info failed: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/{knowledge_id}/sync", response_model=ApiResponse)
|
||||
async def sync_knowledge(
|
||||
knowledge_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
sync knowledge base information based on knowledge_id
|
||||
"""
|
||||
api_logger.info(f"Obtain details of the knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query knowledge base information from the database
|
||||
api_logger.debug(f"Query knowledge base: {knowledge_id}")
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={knowledge_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The knowledge base does not exist or access is denied"
|
||||
)
|
||||
|
||||
# 2. sync knowledge
|
||||
# from app.tasks import sync_knowledge_for_kb
|
||||
# sync_knowledge_for_kb(kb_id)
|
||||
task = celery_app.send_task("app.core.rag.tasks.sync_knowledge_for_kb", args=[knowledge_id])
|
||||
result = {
|
||||
"task_id": task.id
|
||||
}
|
||||
return success(data=result, msg="Task accepted. sync knowledge is being processed in the background.")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to sync knowledge: knowledge_id={knowledge_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
395
api/app/controllers/mcp_market_config_controller.py
Normal file
395
api/app/controllers/mcp_market_config_controller.py
Normal file
@@ -0,0 +1,395 @@
|
||||
import datetime
|
||||
import json
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
import requests
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
from modelscope.hub.errors import raise_for_http_status
|
||||
from modelscope.hub.mcp_api import MCPApi
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import mcp_market_config_model
|
||||
from app.models.user_model import User
|
||||
from app.schemas import mcp_market_config_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import mcp_market_config_service
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/mcp_market_configs",
|
||||
tags=["mcp_market_configs"],
|
||||
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
|
||||
)
|
||||
|
||||
|
||||
@router.get("/mcp_servers", response_model=ApiResponse)
|
||||
async def get_mcp_servers(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
keywords: Optional[str] = Query(None, description="Search keywords (Optional search query string,e.g. Chinese service name, English service name, author/owner username)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Query the mcp servers list in pages
|
||||
- Support keyword search for name,author,owner
|
||||
- Return paging metadata + mcp server list
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Query mcp server list: tenant_id={current_user.tenant_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
|
||||
|
||||
# 1. parameter validation
|
||||
if page < 1 or pagesize < 1:
|
||||
api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The paging parameter must be greater than 0"
|
||||
)
|
||||
|
||||
# 2. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
|
||||
mcp_market_config_id=mcp_market_config_id,
|
||||
current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
# 3. Execute paged query
|
||||
api = MCPApi()
|
||||
token = db_mcp_market_config.token
|
||||
api.login(token)
|
||||
|
||||
body = {
|
||||
'filter': {},
|
||||
'page_number': page,
|
||||
'page_size': pagesize,
|
||||
'search': keywords
|
||||
}
|
||||
|
||||
try:
|
||||
cookies = api.get_cookies(token)
|
||||
r = api.session.put(
|
||||
url=api.mcp_base_url,
|
||||
headers=api.builder_headers(api.headers),
|
||||
json=body,
|
||||
cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except requests.exceptions.RequestException as e:
|
||||
api_logger.error(f"Failed to get MCP servers: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get MCP servers: {str(e)}"
|
||||
)
|
||||
|
||||
data = api._handle_response(r)
|
||||
total = data.get('total_count', 0)
|
||||
mcp_server_list = data.get('mcp_server_list', [])
|
||||
# items = [{
|
||||
# 'name': item.get('name', ''),
|
||||
# 'id': item.get('id', ''),
|
||||
# 'description': item.get('description', '')
|
||||
# } for item in mcp_server_list]
|
||||
|
||||
# 4. Return structured response
|
||||
result = {
|
||||
"items": mcp_server_list,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
return success(data=result, msg="Query of mcp servers list successful")
|
||||
|
||||
|
||||
@router.get("/operational_mcp_servers", response_model=ApiResponse)
|
||||
async def get_operational_mcp_servers(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Query the operational mcp servers list in pages
|
||||
- Support keyword search for name,author,owner
|
||||
- Return paging metadata + operational mcp server list
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Query operational mcp server list: tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
||||
|
||||
# 1. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
|
||||
mcp_market_config_id=mcp_market_config_id,
|
||||
current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
# 2. Execute paged query
|
||||
api = MCPApi()
|
||||
token = db_mcp_market_config.token
|
||||
api.login(token)
|
||||
|
||||
url = f'{api.mcp_base_url}/operational'
|
||||
headers = api.builder_headers(api.headers)
|
||||
|
||||
try:
|
||||
cookies = api.get_cookies(access_token=token, cookies_required=True)
|
||||
r = api.session.get(url, headers=headers, cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except requests.exceptions.RequestException as e:
|
||||
api_logger.error(f"Failed to get operational MCP servers: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get operational MCP servers: {str(e)}"
|
||||
)
|
||||
|
||||
data = api._handle_response(r)
|
||||
total = data.get('total_count', 0)
|
||||
mcp_server_list = data.get('mcp_server_list', [])
|
||||
# items = [{
|
||||
# 'name': item.get('name', ''),
|
||||
# 'id': item.get('id', ''),
|
||||
# 'description': item.get('description', '')
|
||||
# } for item in mcp_server_list]
|
||||
|
||||
# 3. Return structured response
|
||||
return success(data=mcp_server_list, msg="Query of operational mcp servers list successful")
|
||||
|
||||
|
||||
@router.get("/mcp_server", response_model=ApiResponse)
|
||||
async def get_mcp_server(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
server_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get detailed information for a specific MCP Server
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Query mcp server: tenant_id={current_user.tenant_id}, mcp_market_config_id={mcp_market_config_id}, server_id={server_id}, username: {current_user.username}")
|
||||
|
||||
# 1. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
|
||||
mcp_market_config_id=mcp_market_config_id,
|
||||
current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
# 2. Get detailed information for a specific MCP Server
|
||||
api = MCPApi()
|
||||
token = db_mcp_market_config.token
|
||||
api.login(token)
|
||||
|
||||
result = api.get_mcp_server(server_id=server_id)
|
||||
return success(data=result, msg="Query of mcp servers list successful")
|
||||
|
||||
|
||||
@router.post("/mcp_market_config", response_model=ApiResponse)
|
||||
async def create_mcp_market_config(
|
||||
create_data: mcp_market_config_schema.McpMarketConfigCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
create mcp market config
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Request to create a mcp market config: mcp_market_id={create_data.mcp_market_id}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"Start creating the mcp market config: {create_data.mcp_market_id}")
|
||||
# 1. Check if the mcp market name already exists
|
||||
db_mcp_market_config_exist = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=create_data.mcp_market_id, current_user=current_user)
|
||||
if db_mcp_market_config_exist:
|
||||
api_logger.warning(f"The mcp market id already exists: {create_data.mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The mcp market id already exists: {create_data.mcp_market_id}"
|
||||
)
|
||||
db_mcp_market_config = mcp_market_config_service.create_mcp_market_config(db=db, mcp_market_config=create_data, current_user=current_user)
|
||||
api_logger.info(
|
||||
f"The mcp market config has been successfully created: (ID: {db_mcp_market_config.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="The mcp market config has been successfully created")
|
||||
except Exception as e:
|
||||
api_logger.error(f"The creation of the mcp market config failed: {create_data.mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/{mcp_market_config_id}", response_model=ApiResponse)
|
||||
async def get_mcp_market_config(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Retrieve mcp market config information based on mcp_market_config_id
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Obtain details of the mcp market config: mcp_market_config_id={mcp_market_config_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="Successfully obtained mcp market config information")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"mcp market config query failed: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/mcp_market_id/{mcp_market_id}", response_model=ApiResponse)
|
||||
async def get_mcp_market_config_by_mcp_market_id(
|
||||
mcp_market_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Retrieve mcp market config information based on mcp_market_id
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Request to create a mcp market config: mcp_market_id={mcp_market_id}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: mcp_market_id={mcp_market_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="Successfully obtained mcp market config information")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"mcp market config query failed: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/{mcp_market_config_id}", response_model=ApiResponse)
|
||||
async def update_mcp_market_config(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
update_data: mcp_market_config_schema.McpMarketConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
# 1. Check if the mcp market config exists
|
||||
api_logger.debug(f"Query the mcp market config to be updated: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. Update fields (only update non-null fields)
|
||||
api_logger.debug(f"Start updating the mcp market config fields: {mcp_market_config_id}")
|
||||
update_dict = update_data.dict(exclude_unset=True)
|
||||
updated_fields = []
|
||||
for field, value in update_dict.items():
|
||||
if hasattr(db_mcp_market_config, field):
|
||||
old_value = getattr(db_mcp_market_config, field)
|
||||
if old_value != value:
|
||||
# update value
|
||||
setattr(db_mcp_market_config, field, value)
|
||||
updated_fields.append(f"{field}: {old_value} -> {value}")
|
||||
|
||||
if updated_fields:
|
||||
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
||||
|
||||
# 3. Save to database
|
||||
try:
|
||||
db.commit()
|
||||
db.refresh(db_mcp_market_config)
|
||||
api_logger.info(f"The mcp market config has been successfully updated: (ID: {db_mcp_market_config.id})")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
api_logger.error(f"The mcp market config update failed: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"The mcp market config update failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 4. Return the updated mcp market config
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="The mcp market config information updated successfully")
|
||||
|
||||
|
||||
@router.delete("/{mcp_market_config_id}", response_model=ApiResponse)
|
||||
async def delete_mcp_market_config(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
delete mcp market config
|
||||
"""
|
||||
api_logger.info(f"Request to delete mcp market config: mcp_market_config_id={mcp_market_config_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Check whether the mcp market config exists
|
||||
api_logger.debug(f"Check whether the mcp market config exists: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. Deleting mcp market config
|
||||
mcp_market_config_service.delete_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
api_logger.info(f"The mcp market config has been successfully deleted: (ID: {mcp_market_config_id})")
|
||||
return success(msg="The mcp market config has been successfully deleted")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to delete from the mcp market config: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
|
||||
raise
|
||||
262
api/app/controllers/mcp_market_controller.py
Normal file
262
api/app/controllers/mcp_market_controller.py
Normal file
@@ -0,0 +1,262 @@
|
||||
import datetime
|
||||
import json
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import mcp_market_model
|
||||
from app.models.user_model import User
|
||||
from app.schemas import mcp_market_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import mcp_market_service
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/mcp_markets",
|
||||
tags=["mcp_markets"],
|
||||
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
|
||||
)
|
||||
|
||||
|
||||
@router.get("/mcp_markets", response_model=ApiResponse)
|
||||
async def get_mcp_markets(
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
orderby: Optional[str] = Query(None, description="Sort fields, such as: category, created_at"),
|
||||
desc: Optional[bool] = Query(False, description="Is it descending order"),
|
||||
keywords: Optional[str] = Query(None, description="Search keywords (mcp_market base name)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Query the mcp markets list in pages
|
||||
- Support keyword search for name,description
|
||||
- Support dynamic sorting
|
||||
- Return paging metadata + mcp_market list
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Query mcp market list: tenant_id={current_user.tenant_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
|
||||
|
||||
# 1. parameter validation
|
||||
if page < 1 or pagesize < 1:
|
||||
api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The paging parameter must be greater than 0"
|
||||
)
|
||||
|
||||
# 2. Construct query conditions
|
||||
filters = []
|
||||
|
||||
# Keyword search (fuzzy matching of mcp market name,description)
|
||||
if keywords:
|
||||
api_logger.debug(f"Add keyword search criteria: {keywords}")
|
||||
filters.append(
|
||||
or_(
|
||||
mcp_market_model.McpMarket.name.ilike(f"%{keywords}%"),
|
||||
mcp_market_model.McpMarket.description.ilike(f"%{keywords}%")
|
||||
)
|
||||
)
|
||||
# 3. Execute paged query
|
||||
try:
|
||||
api_logger.debug("Start executing mcp market paging query")
|
||||
total, items = mcp_market_service.get_mcp_markets_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
orderby=orderby,
|
||||
desc=desc,
|
||||
current_user=current_user
|
||||
)
|
||||
api_logger.info(f"mcp market query successful: total={total}, returned={len(items)} records")
|
||||
except Exception as e:
|
||||
api_logger.error(f"mcp market query failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Query failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 4. Return structured response
|
||||
result = {
|
||||
"items": items,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
return success(data=jsonable_encoder(result), msg="Query of mcp market list successful")
|
||||
|
||||
|
||||
@router.post("/mcp_market", response_model=ApiResponse)
|
||||
async def create_mcp_market(
|
||||
create_data: mcp_market_schema.McpMarketCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
create mcp market
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Request to create a mcp market: name={create_data.name}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"Start creating the mcp market: {create_data.name}")
|
||||
# 1. Check if the mcp market name already exists
|
||||
db_mcp_market_exist = mcp_market_service.get_mcp_market_by_name(db, name=create_data.name, current_user=current_user)
|
||||
if db_mcp_market_exist:
|
||||
api_logger.warning(f"The mcp market name already exists: {create_data.name}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The mcp market name already exists: {create_data.name}"
|
||||
)
|
||||
db_mcp_market = mcp_market_service.create_mcp_market(db=db, mcp_market=create_data, current_user=current_user)
|
||||
api_logger.info(
|
||||
f"The mcp market has been successfully created: {db_mcp_market.name} (ID: {db_mcp_market.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
|
||||
msg="The mcp market has been successfully created")
|
||||
except Exception as e:
|
||||
api_logger.error(f"The creation of the mcp market failed: {create_data.name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/{mcp_market_id}", response_model=ApiResponse)
|
||||
async def get_mcp_market(
|
||||
mcp_market_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Retrieve mcp market information based on mcp_market_id
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Obtain details of the mcp market: mcp_market_id={mcp_market_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query mcp market information from the database
|
||||
api_logger.debug(f"Query mcp market: {mcp_market_id}")
|
||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
if not db_mcp_market:
|
||||
api_logger.warning(f"The mcp market does not exist or access is denied: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market does not exist or access is denied"
|
||||
)
|
||||
|
||||
api_logger.info(f"mcp market query successful: {db_mcp_market.name} (ID: {db_mcp_market.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
|
||||
msg="Successfully obtained mcp market information")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"mcp market query failed: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/{mcp_market_id}", response_model=ApiResponse)
|
||||
async def update_mcp_market(
|
||||
mcp_market_id: uuid.UUID,
|
||||
update_data: mcp_market_schema.McpMarketUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
# 1. Check if the mcp market exists
|
||||
api_logger.debug(f"Query the mcp market to be updated: {mcp_market_id}")
|
||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
|
||||
if not db_mcp_market:
|
||||
api_logger.warning(
|
||||
f"The mcp market does not exist or you do not have permission to access it: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. not updating the name (name already exists)
|
||||
update_dict = update_data.dict(exclude_unset=True)
|
||||
if "name" in update_dict:
|
||||
name = update_dict["name"]
|
||||
if name != db_mcp_market.name:
|
||||
# Check if the mcp market name already exists
|
||||
db_mcp_market_exist = mcp_market_service.get_mcp_market_by_name(db, name=name, current_user=current_user)
|
||||
if db_mcp_market_exist:
|
||||
api_logger.warning(f"The mcp market name already exists: {name}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The mcp market name already exists: {name}"
|
||||
)
|
||||
# 3. Update fields (only update non-null fields)
|
||||
api_logger.debug(f"Start updating the mcp market fields: {mcp_market_id}")
|
||||
updated_fields = []
|
||||
for field, value in update_dict.items():
|
||||
if hasattr(db_mcp_market, field):
|
||||
old_value = getattr(db_mcp_market, field)
|
||||
if old_value != value:
|
||||
# update value
|
||||
setattr(db_mcp_market, field, value)
|
||||
updated_fields.append(f"{field}: {old_value} -> {value}")
|
||||
|
||||
if updated_fields:
|
||||
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
||||
|
||||
# 4. Save to database
|
||||
try:
|
||||
db.commit()
|
||||
db.refresh(db_mcp_market)
|
||||
api_logger.info(f"The mcp market has been successfully updated: {db_mcp_market.name} (ID: {db_mcp_market.id})")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
api_logger.error(f"The mcp market update failed: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"The mcp market update failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 5. Return the updated mcp market
|
||||
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
|
||||
msg="The mcp market information updated successfully")
|
||||
|
||||
|
||||
@router.delete("/{mcp_market_id}", response_model=ApiResponse)
|
||||
async def delete_mcp_market(
|
||||
mcp_market_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
delete mcp market
|
||||
"""
|
||||
api_logger.info(f"Request to delete mcp market: mcp_market_id={mcp_market_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Check whether the mcp market exists
|
||||
api_logger.debug(f"Check whether the mcp market exists: {mcp_market_id}")
|
||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
|
||||
if not db_mcp_market:
|
||||
api_logger.warning(
|
||||
f"The mcp market does not exist or you do not have permission to access it: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. Deleting mcp market
|
||||
mcp_market_service.delete_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
api_logger.info(f"The mcp market has been successfully deleted: (ID: {mcp_market_id})")
|
||||
return success(msg="The mcp market has been successfully deleted")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to delete from the mcp market: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
@@ -1,8 +1,17 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, Header
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from app.cache.memory.interest_memory import InterestMemoryCache
|
||||
from app.celery_app import celery_app
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
@@ -15,10 +24,6 @@ from app.schemas.response_schema import ApiResponse
|
||||
from app.services import task_service, workspace_service
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_service import ModelConfigService
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
load_dotenv()
|
||||
api_logger = get_api_logger()
|
||||
@@ -33,7 +38,7 @@ router = APIRouter(
|
||||
|
||||
@router.get("/health/status", response_model=ApiResponse)
|
||||
async def get_health_status(
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get latest health status written by Celery periodic task
|
||||
@@ -51,8 +56,9 @@ async def get_health_status(
|
||||
|
||||
@router.get("/download_log")
|
||||
async def download_log(
|
||||
log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
log_type: str = Query("file", regex="^(file|transmission)$",
|
||||
description="日志类型: file=完整文件, transmission=实时流式传输"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Download or stream agent service log file
|
||||
@@ -71,16 +77,16 @@ async def download_log(
|
||||
- transmission mode: StreamingResponse with SSE
|
||||
"""
|
||||
api_logger.info(f"Log download requested with log_type={log_type}")
|
||||
|
||||
|
||||
# Validate log_type parameter (FastAPI Query regex already validates, but explicit check for clarity)
|
||||
if log_type not in ["file", "transmission"]:
|
||||
api_logger.warning(f"Invalid log_type parameter: {log_type}")
|
||||
return fail(
|
||||
BizCode.BAD_REQUEST,
|
||||
"无效的log_type参数",
|
||||
BizCode.BAD_REQUEST,
|
||||
"无效的log_type参数",
|
||||
"log_type必须是'file'或'transmission'"
|
||||
)
|
||||
|
||||
|
||||
# Route to appropriate mode
|
||||
if log_type == "file":
|
||||
# File mode: Return complete log file content
|
||||
@@ -115,23 +121,28 @@ async def download_log(
|
||||
@router.post("/writer_service", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def write_server(
|
||||
user_input: Write_UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
user_input: Write_UserInput,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Write service endpoint - processes write operations synchronously
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and group_id
|
||||
user_input: Write request containing message and end_user_id
|
||||
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
|
||||
Returns:
|
||||
Response with write operation status
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
@@ -140,7 +151,7 @@ async def write_server(
|
||||
)
|
||||
if storage_type is None: storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
|
||||
|
||||
# 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
||||
if storage_type == 'rag':
|
||||
if workspace_id:
|
||||
@@ -152,22 +163,27 @@ async def write_server(
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
else:
|
||||
api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
api_logger.warning(
|
||||
f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
else:
|
||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
|
||||
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
|
||||
api_logger.info(
|
||||
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
try:
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
result = await memory_agent_service.write_memory(
|
||||
user_input.group_id,
|
||||
user_input.message,
|
||||
user_input.end_user_id,
|
||||
messages_list,
|
||||
config_id,
|
||||
db,
|
||||
storage_type,
|
||||
user_rag_memory_id
|
||||
storage_type,
|
||||
user_rag_memory_id,
|
||||
language
|
||||
)
|
||||
|
||||
return success(data=result, msg="写入成功")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
@@ -183,23 +199,29 @@ async def write_server(
|
||||
@router.post("/writer_service_async", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def write_server_async(
|
||||
user_input: Write_UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
user_input: Write_UserInput,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Async write service endpoint - enqueues write processing to Celery
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and group_id
|
||||
user_input: Write request containing message and end_user_id
|
||||
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
|
||||
Returns:
|
||||
Task ID for tracking async operation
|
||||
Use GET /memory/write_result/{task_id} to check task status and get result
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
api_logger.info(
|
||||
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
@@ -219,12 +241,15 @@ async def write_server_async(
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
task = celery_app.send_task(
|
||||
"app.core.memory.agent.write_message",
|
||||
args=[user_input.group_id, user_input.message, config_id, storage_type, user_rag_memory_id]
|
||||
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
||||
)
|
||||
api_logger.info(f"Write task queued: {task.id}")
|
||||
|
||||
|
||||
return success(data={"task_id": task.id}, msg="写入任务已提交")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Async write operation failed: {str(e)}")
|
||||
@@ -234,9 +259,9 @@ async def write_server_async(
|
||||
@router.post("/read_service", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def read_server(
|
||||
user_input: UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
user_input: UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Read service endpoint - processes read operations synchronously
|
||||
@@ -247,16 +272,14 @@ async def read_server(
|
||||
- "2": Direct answer based on context
|
||||
|
||||
Args:
|
||||
user_input: Read request with message, history, search_switch, and group_id
|
||||
user_input: Read request with message, history, search_switch, and end_user_id
|
||||
|
||||
Returns:
|
||||
Response with query answer
|
||||
"""
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
@@ -271,12 +294,14 @@ async def read_server(
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
|
||||
api_logger.info(f"Read service: group={user_input.group_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
|
||||
api_logger.info(
|
||||
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||
try:
|
||||
result = await memory_agent_service.read_memory(
|
||||
user_input.group_id,
|
||||
user_input.end_user_id,
|
||||
user_input.message,
|
||||
user_input.history,
|
||||
user_input.search_switch,
|
||||
@@ -285,6 +310,23 @@ async def read_server(
|
||||
storage_type,
|
||||
user_rag_memory_id
|
||||
)
|
||||
if str(user_input.search_switch) == "2":
|
||||
retrieve_info = result['answer']
|
||||
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
|
||||
user_input.end_user_id)
|
||||
query = user_input.message
|
||||
|
||||
# 调用 memory_agent_service 的方法生成最终答案
|
||||
result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
|
||||
end_user_id=user_input.end_user_id,
|
||||
retrieve_info=retrieve_info,
|
||||
history=history,
|
||||
query=query,
|
||||
config_id=config_id,
|
||||
db=db
|
||||
)
|
||||
if "信息不足,无法回答" in result['answer']:
|
||||
result['answer'] = retrieve_info
|
||||
return success(data=result, msg="回复对话消息成功")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
@@ -300,9 +342,10 @@ async def read_server(
|
||||
@router.post("/file", response_model=ApiResponse)
|
||||
async def file_update(
|
||||
files: List[UploadFile] = File(..., description="要上传的文件"),
|
||||
model_id:str = Form(..., description="模型ID"),
|
||||
model_id: str = Form(..., description="模型ID"),
|
||||
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
文件上传接口 - 支持图片识别
|
||||
@@ -315,9 +358,6 @@ async def file_update(
|
||||
Returns:
|
||||
文件处理结果
|
||||
"""
|
||||
|
||||
db_gen = get_db() # get_db 通常是一个生成器
|
||||
db = next(db_gen)
|
||||
api_logger.info(f"File upload requested, file count: {len(files)}")
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||
apiConfig: ModelApiKey = config.api_keys[0]
|
||||
@@ -326,7 +366,7 @@ async def file_update(
|
||||
for file in files:
|
||||
api_logger.debug(f"Processing file: {file.filename}, content_type: {file.content_type}")
|
||||
content = await file.read()
|
||||
|
||||
|
||||
if file.content_type and file.content_type.startswith("image/"):
|
||||
vision_model = QWenCV(
|
||||
key=apiConfig.api_key,
|
||||
@@ -340,12 +380,12 @@ async def file_update(
|
||||
else:
|
||||
api_logger.warning(f"Unsupported file type: {file.content_type}")
|
||||
file_content.append(f"[不支持的文件类型: {file.content_type}]")
|
||||
|
||||
|
||||
result_text = ';'.join(file_content)
|
||||
api_logger.info(f"File processing completed, result length: {len(result_text)}")
|
||||
|
||||
|
||||
return success(data=result_text, msg="转换文本成功")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"File processing failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "转换文本失败", str(e))
|
||||
@@ -382,7 +422,7 @@ async def read_server_async(
|
||||
try:
|
||||
task = celery_app.send_task(
|
||||
"app.core.memory.agent.read_message",
|
||||
args=[user_input.group_id, user_input.message, user_input.history, user_input.search_switch,
|
||||
args=[user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch,
|
||||
config_id, storage_type, user_rag_memory_id]
|
||||
)
|
||||
api_logger.info(f"Read task queued: {task.id}")
|
||||
@@ -395,8 +435,8 @@ async def read_server_async(
|
||||
|
||||
@router.get("/read_result/", response_model=ApiResponse)
|
||||
async def get_read_task_result(
|
||||
task_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
task_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get the status and result of an async read task
|
||||
@@ -417,7 +457,7 @@ async def get_read_task_result(
|
||||
try:
|
||||
result = task_service.get_task_memory_read_result(task_id)
|
||||
status = result.get("status")
|
||||
|
||||
|
||||
if status == "SUCCESS":
|
||||
# 任务成功完成
|
||||
task_result = result.get("result", {})
|
||||
@@ -426,7 +466,7 @@ async def get_read_task_result(
|
||||
return success(
|
||||
data={
|
||||
"result": task_result.get("result"),
|
||||
"group_id": task_result.get("group_id"),
|
||||
"end_user_id": task_result.get("end_user_id"),
|
||||
"elapsed_time": task_result.get("elapsed_time"),
|
||||
"task_id": task_id
|
||||
},
|
||||
@@ -435,7 +475,7 @@ async def get_read_task_result(
|
||||
else:
|
||||
# 旧格式:直接返回结果
|
||||
return success(data=task_result, msg="查询任务已完成")
|
||||
|
||||
|
||||
elif status == "FAILURE":
|
||||
# 任务失败
|
||||
error_info = result.get("result", "Unknown error")
|
||||
@@ -444,7 +484,7 @@ async def get_read_task_result(
|
||||
else:
|
||||
error_msg = str(error_info)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询任务失败", error_msg)
|
||||
|
||||
|
||||
elif status in ["PENDING", "STARTED"]:
|
||||
# 任务进行中
|
||||
return success(
|
||||
@@ -464,7 +504,7 @@ async def get_read_task_result(
|
||||
},
|
||||
msg=f"任务状态: {status}"
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Read task status check failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
|
||||
@@ -472,8 +512,8 @@ async def get_read_task_result(
|
||||
|
||||
@router.get("/write_result/", response_model=ApiResponse)
|
||||
async def get_write_task_result(
|
||||
task_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
task_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get the status and result of an async write task
|
||||
@@ -494,7 +534,7 @@ async def get_write_task_result(
|
||||
try:
|
||||
result = task_service.get_task_memory_write_result(task_id)
|
||||
status = result.get("status")
|
||||
|
||||
|
||||
if status == "SUCCESS":
|
||||
# 任务成功完成
|
||||
task_result = result.get("result", {})
|
||||
@@ -503,7 +543,7 @@ async def get_write_task_result(
|
||||
return success(
|
||||
data={
|
||||
"result": task_result.get("result"),
|
||||
"group_id": task_result.get("group_id"),
|
||||
"end_user_id": task_result.get("end_user_id"),
|
||||
"elapsed_time": task_result.get("elapsed_time"),
|
||||
"task_id": task_id
|
||||
},
|
||||
@@ -512,7 +552,7 @@ async def get_write_task_result(
|
||||
else:
|
||||
# 旧格式:直接返回结果
|
||||
return success(data=task_result, msg="写入任务已完成")
|
||||
|
||||
|
||||
elif status == "FAILURE":
|
||||
# 任务失败
|
||||
error_info = result.get("result", "Unknown error")
|
||||
@@ -521,7 +561,7 @@ async def get_write_task_result(
|
||||
else:
|
||||
error_msg = str(error_info)
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入任务失败", error_msg)
|
||||
|
||||
|
||||
elif status in ["PENDING", "STARTED"]:
|
||||
# 任务进行中
|
||||
return success(
|
||||
@@ -541,7 +581,7 @@ async def get_write_task_result(
|
||||
},
|
||||
msg=f"任务状态: {status}"
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Write task status check failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
|
||||
@@ -549,23 +589,38 @@ async def get_write_task_result(
|
||||
|
||||
@router.post("/status_type", response_model=ApiResponse)
|
||||
async def status_type(
|
||||
user_input: Write_UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
user_input: Write_UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Determine the type of user message (read or write)
|
||||
|
||||
Args:
|
||||
user_input: Request containing user message and group_id
|
||||
user_input: Request containing user message and end_user_id
|
||||
|
||||
Returns:
|
||||
Type classification result
|
||||
"""
|
||||
api_logger.info(f"Status type check requested for group {user_input.group_id}")
|
||||
api_logger.info(f"Status type check requested for group {user_input.end_user_id}")
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
# 将消息列表转换为字符串用于分类
|
||||
# 只取最后一条用户消息进行分类
|
||||
last_user_message = ""
|
||||
for msg in reversed(messages_list):
|
||||
if msg.get('role') == 'user':
|
||||
last_user_message = msg.get('content', '')
|
||||
break
|
||||
|
||||
if not last_user_message:
|
||||
# 如果没有用户消息,使用所有消息的内容
|
||||
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
|
||||
|
||||
result = await memory_agent_service.classify_message_type(
|
||||
user_input.message,
|
||||
last_user_message,
|
||||
user_input.config_id,
|
||||
db
|
||||
)
|
||||
@@ -579,26 +634,21 @@ async def status_type(
|
||||
|
||||
@router.get("/stats/types", response_model=ApiResponse)
|
||||
async def get_knowledge_type_stats_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder | memory。
|
||||
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
|
||||
会对缺失类型补 0,返回字典形式。
|
||||
可选按状态过滤。
|
||||
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
||||
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (group_id) 过滤
|
||||
- 如果用户没有当前工作空间或未提供 end_user_id,对应的统计返回 0
|
||||
- 如果用户没有当前工作空间,对应的统计返回 0
|
||||
"""
|
||||
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||
api_logger.info(
|
||||
f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||
try:
|
||||
from app.db import get_db
|
||||
|
||||
# 获取数据库会话
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
# 调用service层函数
|
||||
result = await memory_agent_service.get_knowledge_type_stats(
|
||||
end_user_id=end_user_id,
|
||||
@@ -606,45 +656,70 @@ async def get_knowledge_type_stats_api(
|
||||
current_workspace_id=current_user.current_workspace_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
|
||||
return success(data=result, msg="获取知识库类型统计成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Knowledge type stats failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
|
||||
async def get_hot_memory_tags_by_user_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
limit: int = Query(20, description="返回标签数量限制"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
|
||||
async def get_interest_distribution_by_user_api(
|
||||
end_user_id: str = Query(..., description="用户ID(必填)"),
|
||||
limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取指定用户的热门记忆标签
|
||||
获取指定用户的兴趣分布标签
|
||||
|
||||
与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习、创作等),
|
||||
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
|
||||
|
||||
返回格式:
|
||||
[
|
||||
{"name": "标签名", "frequency": 频次},
|
||||
{"name": "兴趣活动名", "frequency": 频次},
|
||||
...
|
||||
]
|
||||
"""
|
||||
api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
|
||||
language = get_language_from_header(language_type)
|
||||
api_logger.info(f"Interest distribution by user requested: end_user_id={end_user_id}, language={language}")
|
||||
try:
|
||||
result = await memory_agent_service.get_hot_memory_tags_by_user(
|
||||
# 优先读取缓存
|
||||
cached = await InterestMemoryCache.get_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit
|
||||
language=language,
|
||||
)
|
||||
return success(data=result, msg="获取热门记忆标签成功")
|
||||
if cached is not None:
|
||||
api_logger.info(f"Interest distribution cache hit: end_user_id={end_user_id}")
|
||||
return success(data=cached, msg="获取兴趣分布标签成功")
|
||||
|
||||
# 缓存未命中,调用模型生成
|
||||
result = await memory_agent_service.get_interest_distribution_by_user(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
language=language
|
||||
)
|
||||
|
||||
# 写入缓存,24小时过期
|
||||
await InterestMemoryCache.set_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
language=language,
|
||||
data=result,
|
||||
)
|
||||
|
||||
return success(data=result, msg="获取兴趣分布标签成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Hot memory tags by user failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取热门记忆标签失败", str(e))
|
||||
api_logger.error(f"Interest distribution by user failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取兴趣分布标签失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/user_profile", response_model=ApiResponse)
|
||||
async def get_user_profile_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取用户详情,包含:
|
||||
@@ -682,17 +757,17 @@ async def get_user_profile_api(
|
||||
# ):
|
||||
# """
|
||||
# Get parsed API documentation (Public endpoint - no authentication required)
|
||||
|
||||
|
||||
# Args:
|
||||
# file_path: Optional path to API docs file. If None, uses default path.
|
||||
|
||||
|
||||
# Returns:
|
||||
# Parsed API documentation including title, meta info, and sections
|
||||
# """
|
||||
# api_logger.info(f"API docs requested, file_path: {file_path or 'default'}")
|
||||
# try:
|
||||
# result = await memory_agent_service.get_api_docs(file_path)
|
||||
|
||||
|
||||
# if result.get("success"):
|
||||
# return success(msg=result["msg"], data=result["data"])
|
||||
# else:
|
||||
@@ -708,9 +783,9 @@ async def get_user_profile_api(
|
||||
|
||||
@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
|
||||
async def get_end_user_connected_config(
|
||||
end_user_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
end_user_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取终端用户关联的记忆配置
|
||||
@@ -729,9 +804,9 @@ async def get_end_user_connected_config(
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config as get_config,
|
||||
)
|
||||
|
||||
|
||||
api_logger.info(f"Getting connected config for end_user: {end_user_id}")
|
||||
|
||||
|
||||
try:
|
||||
result = get_config(end_user_id, db)
|
||||
return success(data=result, msg="获取终端用户关联配置成功")
|
||||
@@ -740,4 +815,4 @@ async def get_end_user_connected_config(
|
||||
return fail(BizCode.NOT_FOUND, str(e))
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to get end user connected config: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))
|
||||
|
||||
@@ -1,18 +1,15 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
from app.repositories.end_user_repository import update_end_user_other_name
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.memory_agent_schema import End_User_Information
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.app_schema import App as AppSchema
|
||||
|
||||
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
|
||||
from app.services.memory_agent_service import get_end_users_connected_configs_batch
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# 获取API专用日志器
|
||||
@@ -43,54 +40,7 @@ def get_workspace_total_end_users(
|
||||
api_logger.info(f"成功获取最新用户总数: total_num={total_end_users.get('total_num', 0)}")
|
||||
return success(data=total_end_users, msg="用户数量获取成功")
|
||||
|
||||
@router.post("/update/end_users", response_model=ApiResponse)
|
||||
async def update_workspace_end_users(
|
||||
user_input: End_User_Information,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
更新工作空间的宿主信息
|
||||
"""
|
||||
username = user_input.end_user_name # 要更新的用户名
|
||||
end_user_input_id = user_input.id # 宿主ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的宿主信息")
|
||||
api_logger.info(f"更新参数: username={username}, end_user_id={end_user_input_id}")
|
||||
|
||||
try:
|
||||
# 导入更新函数
|
||||
from app.repositories.end_user_repository import update_end_user_other_name
|
||||
import uuid
|
||||
|
||||
# 转换 end_user_id 为 UUID 类型
|
||||
end_user_uuid = uuid.UUID(end_user_input_id)
|
||||
|
||||
# 直接更新数据库中的 other_name 字段
|
||||
updated_count = update_end_user_other_name(
|
||||
db=db,
|
||||
end_user_id=end_user_uuid,
|
||||
other_name=username
|
||||
)
|
||||
|
||||
api_logger.info(f"成功更新宿主 {end_user_input_id} 的 other_name 为: {username}")
|
||||
|
||||
return success(
|
||||
data={
|
||||
"updated_count": updated_count,
|
||||
"end_user_id": end_user_input_id,
|
||||
"updated_other_name": username
|
||||
},
|
||||
msg=f"成功更新 {updated_count} 个宿主的信息"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"更新宿主信息失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"更新宿主信息失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -100,36 +50,134 @@ async def get_workspace_end_users(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取工作空间的宿主列表
|
||||
获取工作空间的宿主列表(高性能优化版本 v2)
|
||||
|
||||
返回格式与原 memory_list 接口中的 end_users 字段相同
|
||||
优化策略:
|
||||
1. 批量查询 end_users(一次查询而非循环)
|
||||
2. 并发查询所有用户的记忆数量(Neo4j)
|
||||
3. RAG 模式使用批量查询(一次 SQL)
|
||||
4. 只返回必要字段减少数据传输
|
||||
5. 添加短期缓存减少重复查询
|
||||
6. 并发执行配置查询和记忆数量查询
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"end_user": {"id": "uuid", "other_name": "名称"},
|
||||
"memory_num": {"total": 数量},
|
||||
"memory_config": {"memory_config_id": "id", "memory_config_name": "名称"}
|
||||
}
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 尝试从缓存获取(30秒缓存)
|
||||
cache_key = f"end_users:workspace:{workspace_id}"
|
||||
try:
|
||||
cached_data = await aio_redis_get(cache_key)
|
||||
if cached_data:
|
||||
api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}")
|
||||
return success(data=json.loads(cached_data), msg="宿主列表获取成功")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存读取失败: {str(e)}")
|
||||
|
||||
# 获取当前空间类型
|
||||
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
|
||||
|
||||
# 获取 end_users(已优化为批量查询)
|
||||
end_users = memory_dashboard_service.get_workspace_end_users(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user
|
||||
)
|
||||
if not end_users:
|
||||
api_logger.info("工作空间下没有宿主")
|
||||
# 缓存空结果,避免重复查询
|
||||
try:
|
||||
await aio_redis_set(cache_key, json.dumps([]), expire=30)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||
return success(data=[], msg="宿主列表获取成功")
|
||||
|
||||
end_user_ids = [str(user.id) for user in end_users]
|
||||
|
||||
# 并发执行两个独立的查询任务
|
||||
async def get_memory_configs():
|
||||
"""获取记忆配置(在线程池中执行同步查询)"""
|
||||
try:
|
||||
return await asyncio.to_thread(
|
||||
get_end_users_connected_configs_batch,
|
||||
end_user_ids, db
|
||||
)
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
|
||||
return {}
|
||||
|
||||
async def get_memory_nums():
|
||||
"""获取记忆数量"""
|
||||
if current_workspace_type == "rag":
|
||||
# RAG 模式:批量查询
|
||||
try:
|
||||
chunk_map = await asyncio.to_thread(
|
||||
memory_dashboard_service.get_users_total_chunk_batch,
|
||||
end_user_ids, db, current_user
|
||||
)
|
||||
return {uid: {"total": count} for uid, count in chunk_map.items()}
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
elif current_workspace_type == "neo4j":
|
||||
# Neo4j 模式:并发查询(带并发限制)
|
||||
# 使用信号量限制并发数,避免大量用户时压垮 Neo4j
|
||||
MAX_CONCURRENT_QUERIES = 10
|
||||
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES)
|
||||
|
||||
async def get_neo4j_memory_num(end_user_id: str):
|
||||
async with semaphore:
|
||||
try:
|
||||
return await memory_storage_service.search_all(end_user_id)
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}")
|
||||
return {"total": 0}
|
||||
|
||||
memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
|
||||
return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
|
||||
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
# 并发执行配置查询和记忆数量查询
|
||||
memory_configs_map, memory_nums_map = await asyncio.gather(
|
||||
get_memory_configs(),
|
||||
get_memory_nums()
|
||||
)
|
||||
|
||||
# 构建结果(优化:使用列表推导式)
|
||||
result = []
|
||||
for end_user in end_users:
|
||||
memory_num = {}
|
||||
if current_workspace_type == "neo4j":
|
||||
# EndUser 是 Pydantic 模型,直接访问属性而不是使用 .get()
|
||||
memory_num = await memory_storage_service.search_all(str(end_user.id))
|
||||
elif current_workspace_type == "rag":
|
||||
memory_num = {
|
||||
"total":memory_dashboard_service.get_current_user_total_chunk(str(end_user.id), db, current_user)
|
||||
user_id = str(end_user.id)
|
||||
config_info = memory_configs_map.get(user_id, {})
|
||||
result.append({
|
||||
'end_user': {
|
||||
'id': user_id,
|
||||
'other_name': end_user.other_name
|
||||
},
|
||||
'memory_num': memory_nums_map.get(user_id, {"total": 0}),
|
||||
'memory_config': {
|
||||
"memory_config_id": config_info.get("memory_config_id"),
|
||||
"memory_config_name": config_info.get("memory_config_name")
|
||||
}
|
||||
result.append(
|
||||
{
|
||||
'end_user':end_user,
|
||||
'memory_num':memory_num
|
||||
}
|
||||
)
|
||||
|
||||
})
|
||||
|
||||
# 写入缓存(30秒过期)
|
||||
try:
|
||||
await aio_redis_set(cache_key, json.dumps(result), expire=30)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||
|
||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||
return success(data=result, msg="宿主列表获取成功")
|
||||
|
||||
@@ -422,6 +470,8 @@ async def get_chunk_insight(
|
||||
@router.get("/dashboard_data", response_model=ApiResponse)
|
||||
async def dashboard_data(
|
||||
end_user_id: Optional[str] = Query(None, description="可选的用户ID"),
|
||||
start_date: Optional[int] = Query(None, description="开始时间戳(毫秒)"),
|
||||
end_date: Optional[int] = Query(None, description="结束时间戳(毫秒)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
@@ -456,6 +506,15 @@ async def dashboard_data(
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的dashboard整合数据")
|
||||
|
||||
# 如果没有提供时间范围,默认使用最近30天
|
||||
if start_date is None or end_date is None:
|
||||
from datetime import datetime, timedelta
|
||||
end_dt = datetime.now()
|
||||
start_dt = end_dt - timedelta(days=30)
|
||||
end_date = int(end_dt.timestamp() * 1000)
|
||||
start_date = int(start_dt.timestamp() * 1000)
|
||||
api_logger.info(f"使用默认时间范围: {start_dt} 到 {end_dt}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
@@ -465,7 +524,6 @@ async def dashboard_data(
|
||||
if storage_type is None:
|
||||
storage_type = 'neo4j'
|
||||
|
||||
user_rag_memory_id = None
|
||||
|
||||
# 根据 storage_type 决定返回哪个数据对象
|
||||
# 如果是 'rag',neo4j_data 为 null;否则 rag_data 为 null
|
||||
@@ -517,17 +575,22 @@ async def dashboard_data(
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取知识库类型统计失败: {str(e)}")
|
||||
|
||||
# 3. 获取API调用增量(total_api_call,转换为整数)
|
||||
# 3. 获取API调用统计(total_api_call)
|
||||
try:
|
||||
api_increment = memory_dashboard_service.get_workspace_api_increment(
|
||||
db=db,
|
||||
# 使用 AppStatisticsService 获取真实的API调用统计
|
||||
app_stats_service = AppStatisticsService(db)
|
||||
api_stats = app_stats_service.get_workspace_api_statistics(
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
neo4j_data["total_api_call"] = api_increment
|
||||
api_logger.info(f"成功获取API调用增量: {neo4j_data['total_api_call']}")
|
||||
# 计算总调用次数
|
||||
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
|
||||
neo4j_data["total_api_call"] = total_api_calls
|
||||
api_logger.info(f"成功获取API调用统计: {neo4j_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取API调用增量失败: {str(e)}")
|
||||
api_logger.error(f"获取API调用统计失败: {str(e)}")
|
||||
neo4j_data["total_api_call"] = 0
|
||||
|
||||
result["neo4j_data"] = neo4j_data
|
||||
api_logger.info("成功获取neo4j_data")
|
||||
@@ -543,8 +606,8 @@ async def dashboard_data(
|
||||
|
||||
# 获取RAG相关数据
|
||||
try:
|
||||
# total_memory: 使用 total_chunk(总chunk数)
|
||||
total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
|
||||
# total_memory: 只统计用户知识库(permission_id='Memory')的chunk数
|
||||
total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
|
||||
rag_data["total_memory"] = total_chunk
|
||||
|
||||
# total_app: 统计当前空间下的所有app数量
|
||||
@@ -556,10 +619,23 @@ async def dashboard_data(
|
||||
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
|
||||
rag_data["total_knowledge"] = total_kb
|
||||
|
||||
# total_api_call: 固定值
|
||||
rag_data["total_api_call"] = 1024
|
||||
# total_api_call: 使用 AppStatisticsService 获取真实的API调用统计
|
||||
try:
|
||||
app_stats_service = AppStatisticsService(db)
|
||||
api_stats = app_stats_service.get_workspace_api_statistics(
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
# 计算总调用次数
|
||||
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
|
||||
rag_data["total_api_call"] = total_api_calls
|
||||
api_logger.info(f"成功获取RAG模式API调用统计: {rag_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取RAG模式API调用统计失败,使用默认值: {str(e)}")
|
||||
rag_data["total_api_call"] = 0
|
||||
|
||||
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}")
|
||||
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")
|
||||
|
||||
|
||||
133
api/app/controllers/memory_episodic_controller.py
Normal file
133
api/app/controllers/memory_episodic_controller.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
情景记忆相关的控制器
|
||||
包含情景记忆总览和详情查询接口
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, Header
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_episodic_schema import (
|
||||
EpisodicMemoryOverviewRequest,
|
||||
EpisodicMemoryDetailsRequest,
|
||||
translate_episodic_type,
|
||||
)
|
||||
from app.services.memory_episodic_service import memory_episodic_service
|
||||
|
||||
# Get API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/episodic-memory",
|
||||
tags=["Episodic Memory"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/overview", response_model=ApiResponse)
|
||||
async def get_episodic_memory_overview_api(
|
||||
request: EpisodicMemoryOverviewRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
获取情景记忆总览
|
||||
|
||||
返回指定用户的所有情景记忆列表,包括标题和创建时间。
|
||||
支持通过时间范围、情景类型和标题关键词进行筛选。
|
||||
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询情景记忆总览但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 验证参数
|
||||
valid_time_ranges = ["all", "today", "this_week", "this_month"]
|
||||
valid_episodic_types = ["all", "conversation", "project_work", "learning", "decision", "important_event"]
|
||||
|
||||
if request.time_range not in valid_time_ranges:
|
||||
return fail(BizCode.INVALID_PARAMETER, f"无效的时间范围参数,可选值:{', '.join(valid_time_ranges)}")
|
||||
|
||||
if request.episodic_type not in valid_episodic_types:
|
||||
return fail(BizCode.INVALID_PARAMETER, f"无效的情景类型参数,可选值:{', '.join(valid_episodic_types)}")
|
||||
|
||||
# 处理 title_keyword(去除首尾空格)
|
||||
title_keyword = request.title_keyword.strip() if request.title_keyword else None
|
||||
|
||||
api_logger.info(
|
||||
f"情景记忆总览查询请求: end_user_id={request.end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}, time_range={request.time_range}, episodic_type={request.episodic_type}, "
|
||||
f"title_keyword={title_keyword}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用Service层方法
|
||||
result = await memory_episodic_service.get_episodic_memory_overview(
|
||||
request.end_user_id, request.time_range, request.episodic_type, title_keyword
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取情景记忆总览: end_user_id={request.end_user_id}, "
|
||||
f"total={result['total']}"
|
||||
)
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"情景记忆总览查询失败: end_user_id={request.end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "情景记忆总览查询失败", str(e))
|
||||
|
||||
|
||||
@router.post("/details", response_model=ApiResponse)
|
||||
async def get_episodic_memory_details_api(
|
||||
request: EpisodicMemoryDetailsRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
获取情景记忆详情
|
||||
|
||||
返回指定情景记忆的详细信息,包括涉及对象、情景类型、内容记录和情绪。
|
||||
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询情景记忆详情但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"情景记忆详情查询请求: end_user_id={request.end_user_id}, summary_id={request.summary_id}, "
|
||||
f"user={current_user.username}, workspace={workspace_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用Service层方法
|
||||
result = await memory_episodic_service.get_episodic_memory_details(
|
||||
end_user_id=request.end_user_id,
|
||||
summary_id=request.summary_id
|
||||
)
|
||||
|
||||
# 根据语言参数翻译 episodic_type
|
||||
language = get_language_from_header(language_type)
|
||||
if "episodic_type" in result:
|
||||
result["episodic_type"] = translate_episodic_type(result["episodic_type"], language)
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取情景记忆详情: end_user_id={request.end_user_id}, summary_id={request.summary_id}"
|
||||
)
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
# 处理情景记忆不存在的情况
|
||||
api_logger.warning(f"情景记忆不存在: end_user_id={request.end_user_id}, summary_id={request.summary_id}, error={str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, "情景记忆不存在", str(e))
|
||||
except Exception as e:
|
||||
api_logger.error(f"情景记忆详情查询失败: end_user_id={request.end_user_id}, summary_id={request.summary_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "情景记忆详情查询失败", str(e))
|
||||
115
api/app/controllers/memory_explicit_controller.py
Normal file
115
api/app/controllers/memory_explicit_controller.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
显性记忆控制器
|
||||
|
||||
处理显性记忆相关的API接口,包括情景记忆和语义记忆的查询。
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.services.memory_explicit_service import MemoryExplicitService
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_explicit_schema import (
|
||||
ExplicitMemoryOverviewRequest,
|
||||
ExplicitMemoryDetailsRequest,
|
||||
)
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
|
||||
# Get API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
# Initialize service
|
||||
memory_explicit_service = MemoryExplicitService()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/explicit-memory",
|
||||
tags=["Explicit Memory"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/overview", response_model=ApiResponse)
|
||||
async def get_explicit_memory_overview_api(
|
||||
request: ExplicitMemoryOverviewRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
获取显性记忆总览
|
||||
|
||||
返回指定用户的所有显性记忆列表,包括标题、完整内容、创建时间和情绪信息。
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询显性记忆总览但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"显性记忆总览查询请求: end_user_id={request.end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用Service层方法
|
||||
result = await memory_explicit_service.get_explicit_memory_overview(
|
||||
request.end_user_id
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取显性记忆总览: end_user_id={request.end_user_id}, "
|
||||
f"total={result['total']}"
|
||||
)
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"显性记忆总览查询失败: end_user_id={request.end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "显性记忆总览查询失败", str(e))
|
||||
|
||||
|
||||
@router.post("/details", response_model=ApiResponse)
|
||||
async def get_explicit_memory_details_api(
|
||||
request: ExplicitMemoryDetailsRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
获取显性记忆详情
|
||||
|
||||
根据 memory_id 返回情景记忆或语义记忆的详细信息。
|
||||
- 情景记忆:包括标题、内容、情绪、创建时间
|
||||
- 语义记忆:包括名称、核心定义、详细笔记、创建时间
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询显性记忆详情但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"显性记忆详情查询请求: end_user_id={request.end_user_id}, memory_id={request.memory_id}, "
|
||||
f"user={current_user.username}, workspace={workspace_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用Service层方法
|
||||
result = await memory_explicit_service.get_explicit_memory_details(
|
||||
end_user_id=request.end_user_id,
|
||||
memory_id=request.memory_id
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取显性记忆详情: end_user_id={request.end_user_id}, memory_id={request.memory_id}, "
|
||||
f"memory_type={result.get('memory_type')}"
|
||||
)
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
# 处理记忆不存在的情况
|
||||
api_logger.warning(f"显性记忆不存在: end_user_id={request.end_user_id}, memory_id={request.memory_id}, error={str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, "显性记忆不存在", str(e))
|
||||
except Exception as e:
|
||||
api_logger.error(f"显性记忆详情查询失败: end_user_id={request.end_user_id}, memory_id={request.memory_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "显性记忆详情查询失败", str(e))
|
||||
367
api/app/controllers/memory_forget_controller.py
Normal file
367
api/app/controllers/memory_forget_controller.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
遗忘引擎控制器模块
|
||||
|
||||
本模块提供遗忘引擎的 REST API 接口,包括:
|
||||
1. 手动触发遗忘周期
|
||||
2. 获取和更新配置
|
||||
3. 获取统计信息
|
||||
4. 获取遗忘曲线数据
|
||||
|
||||
所有接口都需要用户认证,并自动关联到当前工作空间。
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ForgettingTriggerRequest,
|
||||
ForgettingConfigResponse,
|
||||
ForgettingConfigUpdateRequest,
|
||||
ForgettingStatsResponse,
|
||||
ForgettingReportResponse,
|
||||
ForgettingCurveRequest,
|
||||
ForgettingCurveResponse,
|
||||
ForgettingCurvePoint,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.memory_forget_service import MemoryForgetService
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/forget-memory",
|
||||
tags=["Memory Forgetting Engine"],
|
||||
dependencies=[Depends(get_current_user)] # 所有路由都需要认证
|
||||
)
|
||||
|
||||
# 初始化服务
|
||||
forget_service = MemoryForgetService()
|
||||
|
||||
|
||||
# ==================== API 端点 ====================
|
||||
|
||||
@router.post("/trigger", response_model=ApiResponse)
|
||||
async def trigger_forgetting_cycle(
|
||||
payload: ForgettingTriggerRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
手动触发遗忘周期
|
||||
|
||||
执行一次完整的遗忘周期,识别并融合低激活值节点。
|
||||
|
||||
Args:
|
||||
payload: 触发请求参数
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含遗忘报告的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
end_user_id = payload.end_user_id # 从 payload 中获取 end_user_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试触发遗忘周期但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 通过 end_user_id 获取关联的 config_id
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
config_id = resolve_config_id((config_id), db)
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||
|
||||
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求触发遗忘周期: "
|
||||
f"end_user_id={end_user_id}, config_id={config_id}, max_batch={payload.max_merge_batch_size}, "
|
||||
f"min_days={payload.min_days_since_access}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用服务层执行遗忘周期
|
||||
report = await forget_service.trigger_forgetting_cycle(
|
||||
db=db,
|
||||
end_user_id=end_user_id, # 服务层方法的参数名是 end_user_id
|
||||
max_merge_batch_size=payload.max_merge_batch_size,
|
||||
min_days_since_access=payload.min_days_since_access,
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response_data = ForgettingReportResponse(**report)
|
||||
|
||||
return success(data=response_data.model_dump(), msg="遗忘周期执行成功")
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.warning(f"遗忘周期执行被拒绝: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "RuntimeError")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"触发遗忘周期失败: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "触发遗忘周期失败", str(e))
|
||||
|
||||
|
||||
@router.get("/read_config", response_model=ApiResponse)
|
||||
async def read_forgetting_config(
|
||||
config_id: UUID|int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取遗忘引擎配置
|
||||
|
||||
读取指定配置ID的遗忘引擎参数。
|
||||
|
||||
Args:
|
||||
config_id: 配置ID
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含配置信息的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试读取遗忘引擎配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取遗忘引擎配置: {config_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
config_id=resolve_config_id(config_id, db)
|
||||
# 调用服务层读取配置
|
||||
config = forget_service.read_forgetting_config(db=db, config_id=config_id)
|
||||
|
||||
# 构建响应
|
||||
response_data = ForgettingConfigResponse(**config)
|
||||
|
||||
return success(data=response_data.model_dump(), msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"配置不存在: config_id={config_id}, 错误: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"配置不存在: {config_id}", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"读取遗忘引擎配置失败: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询遗忘引擎配置失败", str(e))
|
||||
|
||||
|
||||
@router.post("/update_config", response_model=ApiResponse)
|
||||
async def update_forgetting_config(
|
||||
payload: ForgettingConfigUpdateRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
更新遗忘引擎配置
|
||||
|
||||
更新指定配置ID的遗忘引擎参数。
|
||||
|
||||
Args:
|
||||
payload: 配置更新请求
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含更新结果的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
payload.config_id=resolve_config_id((payload.config_id), db)
|
||||
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新遗忘引擎配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新遗忘引擎配置: {payload.config_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 构建更新字段字典(排除 None 值和 config_id)
|
||||
update_data = {
|
||||
key: value
|
||||
for key, value in payload.model_dump(exclude_none=True).items()
|
||||
if key != 'config_id'
|
||||
}
|
||||
|
||||
# 调用服务层更新配置
|
||||
config = forget_service.update_forgetting_config(
|
||||
db=db,
|
||||
config_id=payload.config_id,
|
||||
update_fields=update_data
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response_data = ForgettingConfigResponse(**config)
|
||||
|
||||
return success(data=response_data.model_dump(), msg="更新成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"配置不存在: config_id={payload.config_id}, 错误: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
api_logger.error(f"更新遗忘引擎配置失败: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "更新遗忘引擎配置失败", str(e))
|
||||
|
||||
|
||||
@router.get("/stats", response_model=ApiResponse)
|
||||
async def get_forgetting_stats(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取遗忘引擎统计信息
|
||||
|
||||
返回知识层节点统计、激活值分布等信息。
|
||||
|
||||
Args:
|
||||
end_user_id: 组ID(即 end_user_id,可选)
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含统计信息的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
# 如果提供了 end_user_id,通过它获取 config_id
|
||||
config_id = None
|
||||
if end_user_id:
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||
|
||||
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: "
|
||||
f"end_user_id={end_user_id}, config_id={config_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用服务层获取统计信息
|
||||
stats = await forget_service.get_forgetting_stats(
|
||||
db=db,
|
||||
end_user_id=end_user_id,
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response_data = ForgettingStatsResponse(**stats)
|
||||
|
||||
return success(data=response_data.model_dump(), msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取遗忘引擎统计失败: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
|
||||
|
||||
|
||||
@router.post("/forgetting_curve", response_model=ApiResponse)
|
||||
async def get_forgetting_curve(
|
||||
request: ForgettingCurveRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取遗忘曲线数据
|
||||
|
||||
生成遗忘曲线数据用于可视化,模拟记忆激活值随时间的衰减。
|
||||
|
||||
Args:
|
||||
request: 遗忘曲线请求参数
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含遗忘曲线数据的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
request.config_id = resolve_config_id((request.config_id), db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘曲线但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘曲线: "
|
||||
f"importance_score={request.importance_score}, days={request.days}, config_id={request.config_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用服务层生成遗忘曲线
|
||||
result = await forget_service.get_forgetting_curve(
|
||||
db=db,
|
||||
importance_score=request.importance_score,
|
||||
days=request.days,
|
||||
config_id=request.config_id
|
||||
)
|
||||
|
||||
# 转换为响应格式
|
||||
curve_points = [
|
||||
ForgettingCurvePoint(**point)
|
||||
for point in result['curve_data']
|
||||
]
|
||||
|
||||
# 构建响应
|
||||
response_data = ForgettingCurveResponse(
|
||||
curve_data=curve_points,
|
||||
config=result['config']
|
||||
)
|
||||
|
||||
return success(data=response_data.model_dump(), msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取遗忘曲线失败: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取遗忘曲线失败", str(e))
|
||||
255
api/app/controllers/memory_perceptual_controller.py
Normal file
255
api/app/controllers/memory_perceptual_controller.py
Normal file
@@ -0,0 +1,255 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.models.memory_perceptual_model import PerceptualType
|
||||
from app.schemas.memory_perceptual_schema import (
|
||||
PerceptualQuerySchema,
|
||||
PerceptualFilter
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/perceptual",
|
||||
tags=["Perceptual Memory System"],
|
||||
dependencies=[Depends(get_current_user)]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{end_user_id}/count", response_model=ApiResponse)
|
||||
def get_memory_count(
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve perceptual memory statistics for a user group.
|
||||
|
||||
Args:
|
||||
end_user_id: ID of the user group (usually end_user_id in this context)
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Response containing memory count statistics
|
||||
"""
|
||||
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, end_user_id={end_user_id}")
|
||||
|
||||
try:
|
||||
service = MemoryPerceptualService(db)
|
||||
count_stats = service.get_memory_count(end_user_id)
|
||||
|
||||
api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}")
|
||||
|
||||
return success(
|
||||
data=count_stats,
|
||||
msg="Memory statistics retrieved successfully"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch memory statistics: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch memory statistics",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{end_user_id}/last_visual", response_model=ApiResponse)
|
||||
def get_last_visual_memory(
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve the most recent VISION-type memory for a user.
|
||||
|
||||
Args:
|
||||
end_user_id: ID of the user group
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Metadata of the latest visual memory
|
||||
"""
|
||||
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, end_user_id={end_user_id}")
|
||||
|
||||
try:
|
||||
service = MemoryPerceptualService(db)
|
||||
visual_memory = service.get_latest_visual_memory(end_user_id)
|
||||
|
||||
if visual_memory is None:
|
||||
api_logger.info(f"No visual memory found: end_user_id={end_user_id}")
|
||||
return success(
|
||||
data=None,
|
||||
msg="No visual memory available"
|
||||
)
|
||||
|
||||
api_logger.info(f"Latest visual memory retrieved successfully: file={visual_memory.get('file_name')}")
|
||||
|
||||
return success(
|
||||
data=visual_memory,
|
||||
msg="Latest visual memory retrieved successfully"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch latest visual memory: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch latest visual memory",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{end_user_id}/last_listen", response_model=ApiResponse)
|
||||
def get_last_memory_listen(
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve the most recent AUDIO-type memory for a user.
|
||||
|
||||
Args:
|
||||
end_user_id: ID of the user group
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Metadata of the latest audio memory
|
||||
"""
|
||||
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, end_user_id={end_user_id}")
|
||||
|
||||
try:
|
||||
service = MemoryPerceptualService(db)
|
||||
audio_memory = service.get_latest_audio_memory(end_user_id)
|
||||
|
||||
if audio_memory is None:
|
||||
api_logger.info(f"No audio memory found: end_user_id={end_user_id}")
|
||||
return success(
|
||||
data=None,
|
||||
msg="No audio memory available"
|
||||
)
|
||||
|
||||
api_logger.info(f"Latest audio memory retrieved successfully: file={audio_memory.get('file_name')}")
|
||||
|
||||
return success(
|
||||
data=audio_memory,
|
||||
msg="Latest audio memory retrieved successfully"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch latest audio memory: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch latest audio memory",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{end_user_id}/last_text", response_model=ApiResponse)
|
||||
def get_last_text_memory(
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve the most recent TEXT-type memory for a user.
|
||||
|
||||
Args:
|
||||
end_user_id: ID of the user group
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Metadata of the latest text memory
|
||||
"""
|
||||
api_logger.info(f"Fetching latest text memory: user={current_user.username}, end_user_id={end_user_id}")
|
||||
|
||||
try:
|
||||
# 调用服务层获取最近的文本记忆
|
||||
service = MemoryPerceptualService(db)
|
||||
text_memory = service.get_latest_text_memory(end_user_id)
|
||||
|
||||
if text_memory is None:
|
||||
api_logger.info(f"No text memory found: end_user_id={end_user_id}")
|
||||
return success(
|
||||
data=None,
|
||||
msg="No text memory available"
|
||||
)
|
||||
|
||||
api_logger.info(f"Latest text memory retrieved successfully: file={text_memory.get('file_name')}")
|
||||
|
||||
return success(
|
||||
data=text_memory,
|
||||
msg="Latest text memory retrieved successfully"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch latest text memory: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch latest text memory",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{end_user_id}/timeline", response_model=ApiResponse)
|
||||
def get_memory_time_line(
|
||||
end_user_id: uuid.UUID,
|
||||
perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页大小"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve a timeline of perceptual memories for a user group.
|
||||
|
||||
Args:
|
||||
end_user_id: ID of the user group
|
||||
perceptual_type: Optional filter for perceptual type
|
||||
page: Page number for pagination
|
||||
page_size: Number of items per page
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Timeline data of perceptual memories
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Fetching perceptual memory timeline: user={current_user.username}, "
|
||||
f"end_user_id={end_user_id}, type={perceptual_type}, page={page}"
|
||||
)
|
||||
|
||||
try:
|
||||
query = PerceptualQuerySchema(
|
||||
filter=PerceptualFilter(type=perceptual_type),
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
|
||||
service = MemoryPerceptualService(db)
|
||||
timeline_data = service.get_time_line(end_user_id, query)
|
||||
|
||||
api_logger.info(
|
||||
f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, "
|
||||
f"returned={len(timeline_data.memories)}"
|
||||
)
|
||||
|
||||
return success(
|
||||
data=timeline_data.model_dump(),
|
||||
msg="Perceptual memory timeline retrieved successfully"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"Failed to fetch perceptual memory timeline: end_user_id={end_user_id}, "
|
||||
f"error={str(e)}"
|
||||
)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch perceptual memory timeline",
|
||||
)
|
||||
@@ -1,16 +1,19 @@
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
|
||||
ReflectionConfig,
|
||||
ReflectionEngine,
|
||||
ReflectionEngine, ReflectionRange, ReflectionBaseline,
|
||||
)
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_reflection_schemas import Memory_Reflection
|
||||
from app.services.memory_reflection_service import (
|
||||
@@ -19,10 +22,12 @@ from app.services.memory_reflection_service import (
|
||||
)
|
||||
from app.services.model_service import ModelConfigService
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status,Header
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
load_dotenv()
|
||||
api_logger = get_api_logger()
|
||||
|
||||
@@ -39,64 +44,40 @@ async def save_reflection_config(
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Save reflection configuration to data_comfig table"""
|
||||
|
||||
|
||||
|
||||
try:
|
||||
config_id = request.config_id
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
if not config_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="缺少必需参数: config_id"
|
||||
)
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||
|
||||
update_params = {
|
||||
"enable_self_reflexion": request.reflection_enabled,
|
||||
"iteration_period": request.reflection_period_in_hours,
|
||||
"reflexion_range": request.reflexion_range,
|
||||
"baseline": request.baseline,
|
||||
"reflection_model_id": request.reflection_model_id,
|
||||
"memory_verify": request.memory_verify,
|
||||
"quality_assessment": request.quality_assessment,
|
||||
}
|
||||
memory_config = MemoryConfigRepository.update_reflection_config(
|
||||
db,
|
||||
config_id=config_id,
|
||||
enable_self_reflexion=request.reflection_enabled,
|
||||
iteration_period=request.reflection_period_in_hours,
|
||||
reflexion_range=request.reflexion_range,
|
||||
baseline=request.baseline,
|
||||
reflection_model_id=request.reflection_model_id,
|
||||
memory_verify=request.memory_verify,
|
||||
quality_assessment=request.quality_assessment
|
||||
)
|
||||
|
||||
|
||||
|
||||
query, params = DataConfigRepository.build_update_reflection(config_id, **update_params)
|
||||
|
||||
result = db.execute(text(query), params)
|
||||
if result.rowcount == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"未找到config_id为 {config_id} 的配置"
|
||||
)
|
||||
|
||||
db.commit()
|
||||
|
||||
# 查询更新后的配置
|
||||
select_query, select_params = DataConfigRepository.build_select_reflection(config_id)
|
||||
result = db.execute(text(select_query), select_params).fetchone()
|
||||
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"更新后未找到config_id为 {config_id} 的配置"
|
||||
)
|
||||
|
||||
api_logger.info(f"成功保存反思配置到数据库,config_id: {config_id}")
|
||||
db.refresh(memory_config)
|
||||
|
||||
reflection_result={
|
||||
"config_id": result.config_id,
|
||||
"enable_self_reflexion": result.enable_self_reflexion,
|
||||
"iteration_period": result.iteration_period,
|
||||
"reflexion_range": result.reflexion_range,
|
||||
"baseline": result.baseline,
|
||||
"reflection_model_id": result.reflection_model_id,
|
||||
"memory_verify": result.memory_verify,
|
||||
"quality_assessment": result.quality_assessment,
|
||||
"user_id": result.user_id}
|
||||
"config_id": memory_config.config_id,
|
||||
"enable_self_reflexion": memory_config.enable_self_reflexion,
|
||||
"iteration_period": memory_config.iteration_period,
|
||||
"reflexion_range": memory_config.reflexion_range,
|
||||
"baseline": memory_config.baseline,
|
||||
"reflection_model_id": memory_config.reflection_model_id,
|
||||
"memory_verify": memory_config.memory_verify,
|
||||
"quality_assessment": memory_config.quality_assessment}
|
||||
|
||||
return success(data=reflection_result, msg="反思配置成功")
|
||||
|
||||
@@ -116,48 +97,76 @@ async def save_reflection_config(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/reflection")
|
||||
@router.get("/reflection")
|
||||
async def start_workspace_reflection(
|
||||
config_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Activate the reflection function for all matching applications in the workspace"""
|
||||
"""启动工作空间中所有匹配应用的反思功能"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
reflection_service = MemoryReflectionService(db)
|
||||
|
||||
try:
|
||||
api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}")
|
||||
|
||||
service = WorkspaceAppService(db)
|
||||
result = service.get_workspace_apps_detailed(workspace_id)
|
||||
# 使用独立的数据库会话来获取工作空间应用详情,避免事务失败
|
||||
from app.db import get_db_context
|
||||
with get_db_context() as query_db:
|
||||
service = WorkspaceAppService(query_db)
|
||||
result = service.get_workspace_apps_detailed(workspace_id)
|
||||
|
||||
reflection_results = []
|
||||
|
||||
for data in result['apps_detailed_info']:
|
||||
if data['data_configs'] == []:
|
||||
# 跳过没有配置的应用
|
||||
if not data['memory_configs']:
|
||||
api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过")
|
||||
continue
|
||||
|
||||
|
||||
releases = data['releases']
|
||||
data_configs = data['data_configs']
|
||||
memory_configs = data['memory_configs']
|
||||
end_users = data['end_users']
|
||||
|
||||
for base, config, user in zip(releases, data_configs, end_users):
|
||||
if int(base['config']) == int(config['config_id']) and base['app_id'] == user['app_id']:
|
||||
# 调用反思服务
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
||||
|
||||
reflection_result = await reflection_service.start_reflection_from_data(
|
||||
config_data=config,
|
||||
end_user_id=user['id']
|
||||
)
|
||||
|
||||
reflection_results.append({
|
||||
"app_id": base['app_id'],
|
||||
"config_id": config['config_id'],
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": reflection_result
|
||||
})
|
||||
|
||||
# 为每个配置和用户组合执行反思
|
||||
for config in memory_configs:
|
||||
config_id_str = str(config['config_id'])
|
||||
|
||||
# 找到匹配此配置的所有release
|
||||
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
|
||||
|
||||
if not matching_releases:
|
||||
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
|
||||
continue
|
||||
|
||||
# 为每个用户执行反思 - 使用独立的数据库会话
|
||||
for user in end_users:
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}")
|
||||
|
||||
# 为每个用户创建独立的数据库会话,避免事务失败影响其他用户
|
||||
with get_db_context() as user_db:
|
||||
try:
|
||||
reflection_service = MemoryReflectionService(user_db)
|
||||
reflection_result = await reflection_service.start_text_reflection(
|
||||
config_data=config,
|
||||
end_user_id=user['id']
|
||||
)
|
||||
|
||||
reflection_results.append({
|
||||
"app_id": data['id'],
|
||||
"config_id": config_id_str,
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": reflection_result
|
||||
})
|
||||
except Exception as e:
|
||||
api_logger.error(f"用户 {user['id']} 反思失败: {str(e)}")
|
||||
reflection_results.append({
|
||||
"app_id": data['id'],
|
||||
"config_id": config_id_str,
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": {
|
||||
"status": "错误",
|
||||
"message": f"反思失败: {str(e)}"
|
||||
}
|
||||
})
|
||||
|
||||
return success(data=reflection_results, msg="反思配置成功")
|
||||
|
||||
@@ -171,35 +180,27 @@ async def start_workspace_reflection(
|
||||
|
||||
@router.get("/reflection/configs")
|
||||
async def start_reflection_configs(
|
||||
config_id: int,
|
||||
config_id: uuid.UUID|int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""通过config_id查询data_config表中的反思配置信息"""
|
||||
"""通过config_id查询memory_config表中的反思配置信息"""
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
try:
|
||||
config_id=resolve_config_id(config_id,db)
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
|
||||
# 使用DataConfigRepository查询反思配置
|
||||
select_query, select_params = DataConfigRepository.build_select_reflection(config_id)
|
||||
result = db.execute(text(select_query), select_params).fetchone()
|
||||
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"未找到config_id为 {config_id} 的配置"
|
||||
)
|
||||
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
memory_config_id = resolve_config_id(result.config_id, db)
|
||||
# 构建返回数据
|
||||
reflection_config = {
|
||||
"config_id": result.config_id,
|
||||
"config_id": memory_config_id,
|
||||
"reflection_enabled": result.enable_self_reflexion,
|
||||
"reflection_period_in_hours": result.iteration_period,
|
||||
"reflexion_range": result.reflexion_range,
|
||||
"baseline": result.baseline,
|
||||
"reflection_model_id": result.reflection_model_id,
|
||||
"memory_verify": result.memory_verify,
|
||||
"quality_assessment": result.quality_assessment,
|
||||
"user_id": result.user_id
|
||||
"quality_assessment": result.quality_assessment
|
||||
}
|
||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||
return success(data=reflection_config, msg="反思配置查询成功")
|
||||
@@ -217,19 +218,19 @@ async def start_reflection_configs(
|
||||
|
||||
@router.get("/reflection/run")
|
||||
async def reflection_run(
|
||||
config_id: int,
|
||||
language_type: str = "zh",
|
||||
config_id: UUID|int,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Activate the reflection function for all matching applications in the workspace"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
|
||||
# 使用DataConfigRepository查询反思配置
|
||||
select_query, select_params = DataConfigRepository.build_select_reflection(config_id)
|
||||
result = db.execute(text(select_query), select_params).fetchone()
|
||||
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
# 使用MemoryConfigRepository查询反思配置
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
@@ -242,7 +243,7 @@ async def reflection_run(
|
||||
model_id = result.reflection_model_id
|
||||
if model_id:
|
||||
try:
|
||||
ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||
ModelConfigService.get_model_by_id(db=db, model_id=uuid.UUID(model_id))
|
||||
api_logger.info(f"模型ID验证成功: {model_id}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"模型ID '{model_id}' 不存在,将使用默认模型: {str(e)}")
|
||||
@@ -252,8 +253,8 @@ async def reflection_run(
|
||||
config = ReflectionConfig(
|
||||
enabled=result.enable_self_reflexion,
|
||||
iteration_period=result.iteration_period,
|
||||
reflexion_range=result.reflexion_range,
|
||||
baseline=result.baseline,
|
||||
reflexion_range=ReflectionRange(result.reflexion_range),
|
||||
baseline=ReflectionBaseline(result.baseline),
|
||||
output_example='',
|
||||
memory_verify=result.memory_verify,
|
||||
quality_assessment=result.quality_assessment,
|
||||
|
||||
50
api/app/controllers/memory_short_term_controller.py
Normal file
50
api/app/controllers/memory_short_term_controller.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.services.memory_short_service import LongService, ShortService
|
||||
from app.services.memory_storage_service import search_entity
|
||||
|
||||
load_dotenv()
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/short",
|
||||
tags=["Memory"],
|
||||
)
|
||||
@router.get("/short_term")
|
||||
async def short_term_configs(
|
||||
end_user_id: str,
|
||||
language_type:str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
# 获取短期记忆数据
|
||||
short_term=ShortService(end_user_id, db)
|
||||
short_result=short_term.get_short_databasets()
|
||||
short_count=short_term.get_short_count()
|
||||
|
||||
long_term=LongService(end_user_id, db)
|
||||
long_result=long_term.get_long_databasets()
|
||||
|
||||
entity_result = await search_entity(end_user_id)
|
||||
result = {
|
||||
'short_term': short_result,
|
||||
'long_term': long_result,
|
||||
'entity': entity_result.get('num', 0),
|
||||
"retrieval_number":short_count,
|
||||
"long_term_number":len(long_result)
|
||||
}
|
||||
|
||||
return success(data=result, msg="短期记忆系统数据获取成功")
|
||||
@@ -1,29 +1,23 @@
|
||||
import datetime
|
||||
import os
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.utils.self_reflexion_utils import self_reflexion
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.user_model import User
|
||||
from app.schemas.end_user_schema import (
|
||||
EndUserProfileResponse,
|
||||
EndUserProfileUpdate,
|
||||
)
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
ConfigPilotRun,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
GenerateCacheRequest,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.memory_storage_service import (
|
||||
@@ -38,13 +32,14 @@ from app.services.memory_storage_service import (
|
||||
search_dialogue,
|
||||
search_edges,
|
||||
search_entity,
|
||||
search_entity_graph,
|
||||
search_statement,
|
||||
)
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Header
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# Get API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
@@ -80,68 +75,9 @@ async def get_storage_info(
|
||||
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
|
||||
|
||||
|
||||
# --- DB connection dependency ---
|
||||
_CONN: Optional[object] = None
|
||||
|
||||
|
||||
"""PostgreSQL 连接生成与管理(使用 psycopg2)。"""
|
||||
# 这个可以转移,可能是已经有的
|
||||
# PostgreSQL 数据库连接
|
||||
def _make_pgsql_conn() -> Optional[object]: # 创建 PostgreSQL 数据库连接
|
||||
host = os.getenv("DB_HOST")
|
||||
user = os.getenv("DB_USER")
|
||||
password = os.getenv("DB_PASSWORD")
|
||||
database = os.getenv("DB_NAME")
|
||||
port_str = os.getenv("DB_PORT")
|
||||
try:
|
||||
import psycopg2 # type: ignore
|
||||
port = int(port_str) if port_str else 5432
|
||||
conn = psycopg2.connect(
|
||||
host=host or "localhost",
|
||||
port=port,
|
||||
user=user,
|
||||
password=password,
|
||||
dbname=database,
|
||||
)
|
||||
# 设置自动提交,避免显式事务管理
|
||||
conn.autocommit = True
|
||||
# 设置会话时区为中国标准时间(Asia/Shanghai),便于直接以本地时区展示
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute("SET TIME ZONE 'Asia/Shanghai'")
|
||||
cur.close()
|
||||
except Exception:
|
||||
# 时区设置失败不影响连接,仅记录但不抛出
|
||||
pass
|
||||
return conn
|
||||
except Exception as e:
|
||||
try:
|
||||
print(f"[PostgreSQL] 连接失败: {e}")
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
def get_db_conn() -> Optional[object]: # 获取 PostgreSQL 数据库连接
|
||||
global _CONN
|
||||
if _CONN is None:
|
||||
_CONN = _make_pgsql_conn()
|
||||
return _CONN
|
||||
|
||||
|
||||
def reset_db_conn() -> bool: # 重置 PostgreSQL 数据库连接
|
||||
"""Close and recreate the global DB connection."""
|
||||
global _CONN
|
||||
try:
|
||||
if _CONN:
|
||||
try:
|
||||
_CONN.close()
|
||||
except Exception:
|
||||
pass
|
||||
_CONN = _make_pgsql_conn()
|
||||
return _CONN is not None
|
||||
except Exception:
|
||||
_CONN = None
|
||||
return False
|
||||
|
||||
|
||||
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
||||
@@ -149,9 +85,9 @@ def create_config(
|
||||
payload: ConfigParamsCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
|
||||
@@ -164,46 +100,125 @@ def create_config(
|
||||
svc = DataConfigService(db)
|
||||
result = svc.create(payload)
|
||||
return success(data=result, msg="创建成功")
|
||||
except ValueError as e:
|
||||
err_str = str(e)
|
||||
if err_str.startswith("DUPLICATE_CONFIG_NAME:"):
|
||||
config_name = err_str.split(":", 1)[1]
|
||||
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Create config failed: {err_str}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
|
||||
except Exception as e:
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
if isinstance(e, IntegrityError) and "uq_workspace_config_name" in str(getattr(e, 'orig', '')):
|
||||
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Create config failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
|
||||
|
||||
|
||||
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
||||
def delete_config(
|
||||
config_id: str,
|
||||
config_id: UUID|int,
|
||||
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
) -> dict:
|
||||
"""删除记忆配置(带终端用户保护)
|
||||
|
||||
- 检查是否为默认配置,默认配置不允许删除
|
||||
- 检查是否有终端用户连接到该配置
|
||||
- 如果有连接且 force=False,返回警告
|
||||
- 如果 force=True,清除终端用户引用后删除配置
|
||||
|
||||
Query Parameters:
|
||||
force: 设置为 true 可强制删除(即使有终端用户正在使用)
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
config_id=resolve_config_id(config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: {config_id}")
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: "
|
||||
f"config_id={config_id}, force={force}"
|
||||
)
|
||||
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
result = svc.delete(ConfigParamsDelete(config_id=config_id))
|
||||
return success(data=result, msg="删除成功")
|
||||
# 使用带保护的删除服务
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
config_service = MemoryConfigService(db)
|
||||
result = config_service.delete_config(config_id=config_id, force=force)
|
||||
|
||||
if result["status"] == "error":
|
||||
api_logger.warning(
|
||||
f"记忆配置删除被拒绝: config_id={config_id}, reason={result['message']}"
|
||||
)
|
||||
return fail(
|
||||
code=BizCode.FORBIDDEN,
|
||||
msg=result["message"],
|
||||
data={"config_id": str(config_id), "is_default": result.get("is_default", False)}
|
||||
)
|
||||
|
||||
if result["status"] == "warning":
|
||||
api_logger.warning(
|
||||
f"记忆配置正在使用,无法删除: config_id={config_id}, "
|
||||
f"connected_count={result['connected_count']}"
|
||||
)
|
||||
return fail(
|
||||
code=BizCode.RESOURCE_IN_USE,
|
||||
msg=result["message"],
|
||||
data={
|
||||
"connected_count": result["connected_count"],
|
||||
"force_required": result["force_required"]
|
||||
}
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"记忆配置删除成功: config_id={config_id}, "
|
||||
f"affected_users={result['affected_users']}"
|
||||
)
|
||||
return success(
|
||||
msg=result["message"],
|
||||
data={"affected_users": result["affected_users"]}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Delete config failed: {str(e)}")
|
||||
api_logger.error(f"Delete config failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e))
|
||||
|
||||
|
||||
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
|
||||
def update_config(
|
||||
payload: ConfigUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 校验至少有一个字段需要更新
|
||||
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
@@ -219,9 +234,9 @@ def update_config_extracted(
|
||||
payload: ConfigUpdateExtracted,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
|
||||
@@ -238,37 +253,17 @@ def update_config_extracted(
|
||||
|
||||
|
||||
# --- Forget config params ---
|
||||
@router.post("/update_config_forget", response_model=ApiResponse) # 更新遗忘引擎配置参数(固定路径)
|
||||
def update_config_forget(
|
||||
payload: ConfigUpdateForget,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新遗忘引擎配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新遗忘引擎配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
result = svc.update_forget(payload)
|
||||
return success(data=result, msg="更新成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Update config forget failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "更新遗忘引擎配置失败", str(e))
|
||||
|
||||
# 遗忘引擎配置接口已迁移到 memory_forget_controller.py
|
||||
# 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config
|
||||
|
||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||
def read_config_extracted(
|
||||
config_id: str,
|
||||
config_id: UUID | int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
|
||||
@@ -283,33 +278,11 @@ def read_config_extracted(
|
||||
api_logger.error(f"Read config extracted failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
|
||||
|
||||
@router.get("/read_config_forget", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||
def read_config_forget(
|
||||
config_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试读取遗忘引擎配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取遗忘引擎配置: {config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
result = svc.get_forget(ConfigKey(config_id=config_id))
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Read config forget failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询遗忘引擎配置失败", str(e))
|
||||
|
||||
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
||||
def read_all_config(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -331,16 +304,22 @@ def read_all_config(
|
||||
@router.post("/pilot_run", response_model=None)
|
||||
async def pilot_run(
|
||||
payload: ConfigPilotRun,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> StreamingResponse:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"Pilot run requested: config_id={payload.config_id}, "
|
||||
f"dialogue_text_length={len(payload.dialogue_text)}"
|
||||
f"dialogue_text_length={len(payload.dialogue_text)}, "
|
||||
f"custom_text_length={len(payload.custom_text) if payload.custom_text else 0}"
|
||||
)
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
svc = DataConfigService(db)
|
||||
return StreamingResponse(
|
||||
svc.pilot_run_stream(payload),
|
||||
svc.pilot_run_stream(payload, language=language),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
@@ -349,9 +328,8 @@ async def pilot_run(
|
||||
},
|
||||
)
|
||||
|
||||
"""
|
||||
以下为搜索与分析接口,直接挂载到同一 router,统一响应为 ApiResponse。
|
||||
"""
|
||||
|
||||
# ==================== Search & Analytics ====================
|
||||
|
||||
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
|
||||
async def get_kb_type_distribution(
|
||||
@@ -464,21 +442,7 @@ async def search_entity_edges(
|
||||
api_logger.error(f"Search edges failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
|
||||
|
||||
@router.get("/search/entity_graph", response_model=ApiResponse)
|
||||
async def search_for_entity_graph(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
搜索所有实体之间的关系网络
|
||||
"""
|
||||
api_logger.info(f"Search entity graph requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_entity_graph(end_user_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Search entity graph failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "实体图查询失败", str(e))
|
||||
|
||||
|
||||
|
||||
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
|
||||
@@ -487,39 +451,106 @@ async def get_hot_memory_tags_api(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Hot memory tags requested for current_user: {current_user.id}")
|
||||
"""
|
||||
获取热门记忆标签(带Redis缓存)
|
||||
|
||||
缓存策略:
|
||||
- 缓存键:workspace_id + limit
|
||||
- 过期时间:5分钟(300秒)
|
||||
- 缓存命中:~50ms
|
||||
- 缓存未命中:~600-800ms(取决于LLM速度)
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 构建缓存键
|
||||
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
|
||||
|
||||
api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}")
|
||||
|
||||
try:
|
||||
# 尝试从Redis缓存获取
|
||||
import json
|
||||
|
||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||
|
||||
cached_result = await aio_redis_get(cache_key)
|
||||
if cached_result:
|
||||
api_logger.info(f"Cache hit for key: {cache_key}")
|
||||
try:
|
||||
data = json.loads(cached_result)
|
||||
return success(data=data, msg="查询成功(缓存)")
|
||||
except json.JSONDecodeError:
|
||||
api_logger.warning(f"Failed to parse cached data, will refresh")
|
||||
|
||||
# 缓存未命中,执行查询
|
||||
api_logger.info(f"Cache miss for key: {cache_key}, executing query")
|
||||
result = await analytics_hot_memory_tags(db, current_user, limit)
|
||||
|
||||
# 写入缓存(过期时间:5分钟)
|
||||
# 注意:result是列表,需要转换为JSON字符串
|
||||
try:
|
||||
cache_data = json.dumps(result, ensure_ascii=False)
|
||||
await aio_redis_set(cache_key, cache_data, expire=300)
|
||||
api_logger.info(f"Cached result for key: {cache_key}")
|
||||
except Exception as cache_error:
|
||||
# 缓存写入失败不影响主流程
|
||||
api_logger.warning(f"Failed to cache result: {str(cache_error)}")
|
||||
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Hot memory tags failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e))
|
||||
|
||||
|
||||
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
|
||||
async def clear_hot_memory_tags_cache(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
清除热门标签缓存
|
||||
|
||||
用于:
|
||||
- 手动刷新数据
|
||||
- 调试和测试
|
||||
- 数据更新后立即生效
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}")
|
||||
|
||||
try:
|
||||
from app.aioRedis import aio_redis_delete
|
||||
|
||||
# 清除所有limit的缓存(常见的limit值)
|
||||
cleared_count = 0
|
||||
for limit in [5, 10, 15, 20, 30, 50]:
|
||||
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
|
||||
result = await aio_redis_delete(cache_key)
|
||||
if result:
|
||||
cleared_count += 1
|
||||
api_logger.info(f"Cleared cache for key: {cache_key}")
|
||||
|
||||
return success(
|
||||
data={"cleared_count": cleared_count},
|
||||
msg=f"成功清除 {cleared_count} 个缓存"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Clear cache failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
||||
async def get_recent_activity_stats_api(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info("Recent activity stats requested")
|
||||
) -> dict:
|
||||
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
|
||||
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
|
||||
try:
|
||||
result = await analytics_recent_activity_stats()
|
||||
result = await analytics_recent_activity_stats(workspace_id=workspace_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Recent activity stats failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("/self_reflexion")
|
||||
async def self_reflexion_endpoint(host_id: uuid.UUID) -> str:
|
||||
"""
|
||||
自我反思接口,自动对检索出的信息进行自我反思并返回自我反思结果。
|
||||
|
||||
Args:
|
||||
None
|
||||
Returns:
|
||||
自我反思结果。
|
||||
"""
|
||||
return await self_reflexion(host_id)
|
||||
|
||||
|
||||
134
api/app/controllers/memory_working_controller.py
Normal file
134
api/app/controllers/memory_working_controller.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.conversation_service import ConversationService
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/work",
|
||||
tags=["Working Memory System"],
|
||||
dependencies=[Depends(get_current_user)]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{end_user_id}/count", response_model=ApiResponse)
|
||||
def get_memory_count(
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
|
||||
def get_conversations(
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Retrieve all conversations for the current user in a specific group.
|
||||
|
||||
Args:
|
||||
end_user_id (UUID): The group identifier.
|
||||
current_user (User, optional): The authenticated user.
|
||||
db (Session, optional): SQLAlchemy session.
|
||||
|
||||
Returns:
|
||||
ApiResponse: Contains a list of conversation IDs.
|
||||
|
||||
Notes:
|
||||
- Initializes the ConversationService with the current DB session.
|
||||
- Returns only conversation IDs for lightweight response.
|
||||
- Logs can be added to trace requests in production.
|
||||
"""
|
||||
conversation_service = ConversationService(db)
|
||||
conversations = conversation_service.get_user_conversations(
|
||||
end_user_id
|
||||
)
|
||||
return success(data=[
|
||||
{
|
||||
"id": conversation.id,
|
||||
"title": conversation.title
|
||||
} for conversation in conversations
|
||||
], msg="get conversations success")
|
||||
|
||||
|
||||
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
|
||||
def get_messages(
|
||||
conversation_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Retrieve the message history for a specific conversation.
|
||||
|
||||
Args:
|
||||
conversation_id (UUID): The ID of the conversation to fetch messages from.
|
||||
current_user (User, optional): The authenticated user.
|
||||
db (Session, optional): SQLAlchemy session.
|
||||
|
||||
Returns:
|
||||
ApiResponse: Contains the list of messages in the conversation.
|
||||
|
||||
Notes:
|
||||
- Uses ConversationService to fetch messages.
|
||||
- Consider paginating results if message history is large.
|
||||
- Logging can be added for audit and debugging.
|
||||
"""
|
||||
conversation_service = ConversationService(db)
|
||||
messages_obj = conversation_service.get_messages(
|
||||
conversation_id,
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": message.role,
|
||||
"content": message.content,
|
||||
"created_at": int(message.created_at.timestamp() * 1000),
|
||||
}
|
||||
for message in messages_obj
|
||||
]
|
||||
return success(data=messages, msg="get conversation history success")
|
||||
|
||||
|
||||
@router.get("/{end_user_id}/detail", response_model=ApiResponse)
|
||||
async def get_conversation_detail(
|
||||
conversation_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Retrieve detailed information about a specific conversation.
|
||||
|
||||
This endpoint will fetch the conversation detail for the user. If the detail
|
||||
does not exist or is outdated, it will trigger the LLM to generate a new summary.
|
||||
|
||||
Args:
|
||||
conversation_id (UUID): The ID of the conversation.
|
||||
current_user (User, optional): The authenticated user making the request.
|
||||
db (Session, optional): SQLAlchemy session.
|
||||
|
||||
Returns:
|
||||
ApiResponse: Contains the conversation detail serialized as a dictionary.
|
||||
|
||||
Notes:
|
||||
- Uses async ConversationService to fetch or generate the conversation detail.
|
||||
- Handles workspace and user-specific context automatically.
|
||||
- Logging and exception handling should be implemented for production monitoring.
|
||||
"""
|
||||
conversation_service = ConversationService(db)
|
||||
detail = await conversation_service.get_conversation_detail(
|
||||
user=current_user,
|
||||
conversation_id=conversation_id,
|
||||
workspace_id=current_user.current_workspace_id
|
||||
)
|
||||
return success(data=detail.model_dump(), msg="get conversation detail success")
|
||||
@@ -3,15 +3,17 @@ from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
from app.models.models_model import ModelProvider, ModelType, LoadBalanceStrategy
|
||||
from app.models.user_model import User
|
||||
from app.repositories.model_repository import ModelConfigRepository
|
||||
from app.schemas import model_schema
|
||||
from app.core.response_utils import success
|
||||
from app.schemas.response_schema import ApiResponse, PageData
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService, ModelBaseService
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# 获取API专用日志器
|
||||
@@ -24,24 +26,83 @@ router = APIRouter(
|
||||
|
||||
@router.get("/type", response_model=ApiResponse)
|
||||
def get_model_types():
|
||||
|
||||
return success(msg="获取模型类型成功", data=list(ModelType))
|
||||
|
||||
|
||||
@router.get("/provider", response_model=ApiResponse)
|
||||
def get_model_providers():
|
||||
return success(msg="获取模型提供商成功", data=list(ModelProvider))
|
||||
providers = [p for p in ModelProvider if p != ModelProvider.COMPOSITE]
|
||||
return success(msg="获取模型提供商成功", data=providers)
|
||||
|
||||
@router.get("/strategy", response_model=ApiResponse)
|
||||
def get_model_strategies():
|
||||
return success(msg="获取模型策略成功", data=list(LoadBalanceStrategy))
|
||||
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_model_list(
|
||||
type: Optional[str] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取模型配置列表
|
||||
|
||||
支持多个 type 参数:
|
||||
- 单个:?type=LLM
|
||||
- 多个(逗号分隔):?type=LLM,EMBEDDING
|
||||
- 多个(重复参数):?type=LLM&type=EMBEDDING
|
||||
"""
|
||||
api_logger.info(
|
||||
f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
# 解析 type 参数(支持逗号分隔)
|
||||
type_list = []
|
||||
if type is not None:
|
||||
flat_type = []
|
||||
for item in type:
|
||||
split_items = [t.strip() for t in item.split(',') if t.strip()]
|
||||
flat_type.extend(split_items)
|
||||
|
||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||
|
||||
api_logger.error(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQuery(
|
||||
type=type_list,
|
||||
provider=provider,
|
||||
is_active=is_active,
|
||||
is_public=is_public,
|
||||
search=search,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
api_logger.debug(f"开始获取模型配置列表: {query.dict()}")
|
||||
result_orm = ModelConfigService.get_model_list(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||
result = PageData.model_validate(result_orm)
|
||||
api_logger.info(f"模型配置列表获取成功: 总数={result.page.total}, 当前页={len(result.items)}")
|
||||
return success(data=result, msg="模型配置列表获取成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取模型配置列表失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/new", response_model=ApiResponse)
|
||||
def get_model_list_new(
|
||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于ModelConfig)"),
|
||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
is_composite: Optional[bool] = Query(None, description="组合模型筛选"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
@@ -53,36 +114,127 @@ def get_model_list(
|
||||
- 多个(逗号分隔):?type=LLM,EMBEDDING
|
||||
- 多个(重复参数):?type=LLM&type=EMBEDDING
|
||||
"""
|
||||
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
|
||||
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
# 解析 type 参数(支持逗号分隔)
|
||||
type_list = None
|
||||
if type:
|
||||
type_values = [t.strip() for t in type.split(',')]
|
||||
type_list = [model_schema.ModelType(t.lower()) for t in type_values if t]
|
||||
type_list = []
|
||||
if type is not None:
|
||||
flat_type = []
|
||||
for item in type:
|
||||
split_items = [t.strip() for t in item.split(',') if t.strip()]
|
||||
flat_type.extend(split_items)
|
||||
|
||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||
|
||||
api_logger.error(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQuery(
|
||||
api_logger.info(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQueryNew(
|
||||
type=type_list,
|
||||
provider=provider,
|
||||
is_active=is_active,
|
||||
is_public=is_public,
|
||||
search=search,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
is_composite=is_composite,
|
||||
search=search
|
||||
)
|
||||
|
||||
api_logger.debug(f"开始获取模型配置列表: {query.dict()}")
|
||||
result_orm = ModelConfigService.get_model_list(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||
result = PageData.model_validate(result_orm)
|
||||
api_logger.info(f"模型配置列表获取成功: 总数={result.page.total}, 当前页={len(result.items)}")
|
||||
api_logger.debug(f"开始获取模型配置列表: {query.model_dump()}")
|
||||
result = ModelConfigService.get_model_list_new(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"模型配置列表获取成功: 分组数={len(result)}, 总模型数={sum(len(item['models']) for item in result)}")
|
||||
return success(data=result, msg="模型配置列表获取成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取模型配置列表失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/model_plaza", response_model=ApiResponse)
|
||||
def get_model_plaza_list(
|
||||
type: Optional[ModelType] = Query(None, description="模型类型"),
|
||||
provider: Optional[ModelProvider] = Query(None, description="供应商"),
|
||||
is_official: Optional[bool] = Query(None, description="是否官方模型"),
|
||||
is_deprecated: Optional[bool] = Query(None, description="是否弃用"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""模型广场查询接口(按供应商分组)"""
|
||||
|
||||
query = model_schema.ModelBaseQuery(
|
||||
type=type,
|
||||
provider=provider,
|
||||
is_official=is_official,
|
||||
is_deprecated=is_deprecated,
|
||||
search=search
|
||||
)
|
||||
result = ModelBaseService.get_model_base_list(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||
return success(data=result, msg="模型广场列表获取成功")
|
||||
|
||||
|
||||
@router.get("/model_plaza/{model_base_id}", response_model=ApiResponse)
|
||||
def get_model_base_by_id(
|
||||
model_base_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取基础模型详情"""
|
||||
|
||||
result = ModelBaseService.get_model_base_by_id(db=db, model_base_id=model_base_id)
|
||||
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型获取成功")
|
||||
|
||||
|
||||
@router.post("/model_plaza", response_model=ApiResponse)
|
||||
def create_model_base(
|
||||
data: model_schema.ModelBaseCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""创建基础模型"""
|
||||
|
||||
result = ModelBaseService.create_model_base(db=db, data=data)
|
||||
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型创建成功")
|
||||
|
||||
|
||||
@router.put("/model_plaza/{model_base_id}", response_model=ApiResponse)
|
||||
def update_model_base(
|
||||
model_base_id: uuid.UUID,
|
||||
data: model_schema.ModelBaseUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新基础模型"""
|
||||
|
||||
# 不允许更改type类型
|
||||
if data.type is not None or data.provider is not None:
|
||||
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||
|
||||
result = ModelBaseService.update_model_base(db=db, model_base_id=model_base_id, data=data)
|
||||
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型更新成功")
|
||||
|
||||
|
||||
@router.delete("/model_plaza/{model_base_id}", response_model=ApiResponse)
|
||||
def delete_model_base(
|
||||
model_base_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除基础模型"""
|
||||
|
||||
ModelBaseService.delete_model_base(db=db, model_base_id=model_base_id)
|
||||
return success(msg="基础模型删除成功")
|
||||
|
||||
|
||||
@router.post("/model_plaza/{model_base_id}/add", response_model=ApiResponse)
|
||||
def add_model_from_plaza(
|
||||
model_base_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""从模型广场添加模型到模型列表"""
|
||||
|
||||
result = ModelBaseService.add_model_from_plaza(db=db, model_base_id=model_base_id, tenant_id=current_user.tenant_id)
|
||||
return success(data=model_schema.ModelConfig.model_validate(result), msg="模型添加成功")
|
||||
|
||||
|
||||
@router.get("/{model_id}", response_model=ApiResponse)
|
||||
def get_model_by_id(
|
||||
model_id: uuid.UUID,
|
||||
@@ -138,6 +290,73 @@ async def create_model(
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/composite", response_model=ApiResponse)
|
||||
async def create_composite_model(
|
||||
model_data: model_schema.CompositeModelCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
创建组合模型
|
||||
|
||||
- 绑定一个或多个现有的 API Key
|
||||
- 所有 API Key 必须来自非组合模型
|
||||
- 所有 API Key 关联的模型类型必须与组合模型类型一致
|
||||
"""
|
||||
api_logger.info(f"创建组合模型请求: {model_data.name}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
result_orm = await ModelConfigService.create_composite_model(db=db, model_data=model_data, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"组合模型创建成功: {result_orm.name} (ID: {result_orm.id})")
|
||||
|
||||
result = model_schema.ModelConfig.model_validate(result_orm)
|
||||
return success(data=result, msg="组合模型创建成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"创建组合模型失败: {model_data.name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/composite/{model_id}", response_model=ApiResponse)
|
||||
async def update_composite_model(
|
||||
model_id: uuid.UUID,
|
||||
model_data: model_schema.CompositeModelCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新组合模型"""
|
||||
api_logger.info(f"更新组合模型请求: model_id={model_id}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
if model_data.type is not None:
|
||||
raise BusinessException("不允许更改模型类型", BizCode.INVALID_PARAMETER)
|
||||
result_orm = await ModelConfigService.update_composite_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"组合模型更新成功: {result_orm.name} (ID: {model_id})")
|
||||
|
||||
result = model_schema.ModelConfig.model_validate(result_orm)
|
||||
return success(data=result, msg="组合模型更新成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"更新组合模型失败: model_id={model_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.delete("/composite/{model_id}", response_model=ApiResponse)
|
||||
def delete_composite_model(
|
||||
model_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除组合模型"""
|
||||
api_logger.info(f"删除组合模型请求: model_id={model_id}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
ModelConfigService.delete_model(db=db, model_id=model_id, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"组合模型删除成功: model_id={model_id}")
|
||||
return success(msg="组合模型删除成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"删除组合模型失败: model_id={model_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/{model_id}", response_model=ApiResponse)
|
||||
def update_model(
|
||||
model_id: uuid.UUID,
|
||||
@@ -149,6 +368,14 @@ def update_model(
|
||||
更新模型配置
|
||||
"""
|
||||
api_logger.info(f"更新模型配置请求: model_id={model_id}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
if model_data.type is not None or model_data.provider is not None:
|
||||
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||
|
||||
if model_data.is_active:
|
||||
active_keys = ModelApiKeyService.get_api_keys_by_model(db=db, model_config_id=model_id, is_active=model_data.is_active)
|
||||
if not active_keys:
|
||||
raise BusinessException("请先为该模型配置可用的 API Key", BizCode.INVALID_PARAMETER)
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始更新模型配置: model_id={model_id}")
|
||||
@@ -214,6 +441,55 @@ def get_model_api_keys(
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/provider/apikeys", response_model=ApiResponse)
|
||||
async def create_model_api_key_by_provider(
|
||||
api_key_data: model_schema.ModelApiKeyCreateByProvider,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
根据供应商为所有匹配的模型创建API Key
|
||||
"""
|
||||
api_logger.info(f"创建API Key请求: provider={api_key_data.provider}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 根据tenant_id和provider筛选model_config_id列表
|
||||
model_config_ids = api_key_data.model_config_ids
|
||||
if not model_config_ids:
|
||||
model_config_ids = ModelConfigRepository.get_model_config_ids_by_provider(
|
||||
db=db,
|
||||
tenant_id=current_user.tenant_id,
|
||||
provider=api_key_data.provider
|
||||
)
|
||||
|
||||
if not model_config_ids:
|
||||
raise BusinessException(f"未找到供应商 {api_key_data.provider} 的模型配置", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
# 构造schema并调用service
|
||||
create_data = model_schema.ModelApiKeyCreateByProvider(
|
||||
provider=api_key_data.provider,
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
description=api_key_data.description,
|
||||
config=api_key_data.config,
|
||||
is_active=api_key_data.is_active,
|
||||
priority=api_key_data.priority,
|
||||
model_config_ids=model_config_ids,
|
||||
capability=api_key_data.capability,
|
||||
is_omni=api_key_data.is_omni
|
||||
)
|
||||
created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)
|
||||
|
||||
api_logger.info(f"API Key创建成功: 关联{len(created_keys)}个模型")
|
||||
# result_list = [model_schema.ModelApiKey.model_validate(key) for key in created_keys]
|
||||
result = "API Key已存在" if len(created_keys) == 0 and len(failed_models) == 0 else \
|
||||
f"成功为 {len(created_keys)} 个模型创建API Key, 失败模型列表{failed_models}"
|
||||
return success(data=result, msg=f"成功为 {len(created_keys)} 个模型创建API Key")
|
||||
except Exception as e:
|
||||
api_logger.error(f"创建API Key失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/{model_id}/apikeys", response_model=ApiResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_model_api_key(
|
||||
model_id: uuid.UUID,
|
||||
@@ -228,11 +504,12 @@ async def create_model_api_key(
|
||||
|
||||
try:
|
||||
# 设置模型配置ID
|
||||
api_key_data.model_config_id = model_id
|
||||
api_key_data.model_config_ids = [model_id]
|
||||
|
||||
api_logger.debug(f"开始创建模型API Key: {api_key_data.model_name}")
|
||||
result = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data)
|
||||
api_logger.info(f"模型API Key创建成功: {result.model_name} (ID: {result.id})")
|
||||
result_orm = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data)
|
||||
api_logger.info(f"模型API Key创建成功: {result_orm.model_name} (ID: {result_orm.id})")
|
||||
result = model_schema.ModelApiKey.model_validate(result_orm)
|
||||
return success(data=result, msg="模型API Key创建成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"创建模型API Key失败: {api_key_data.model_name} - {str(e)}")
|
||||
@@ -334,5 +611,3 @@ async def validate_model_config(
|
||||
return success(data=model_schema.ModelValidateResponse(**result), msg="验证完成")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ def get_multi_agent_configs(
|
||||
"app_id": str(app_id),
|
||||
"default_model_config_id": None,
|
||||
"model_parameters": None,
|
||||
"orchestration_mode": "conditional",
|
||||
"orchestration_mode": "supervisor",
|
||||
"sub_agents": [],
|
||||
"routing_rules": [],
|
||||
"execution_config": {
|
||||
|
||||
1184
api/app/controllers/ontology_controller.py
Normal file
1184
api/app/controllers/ontology_controller.py
Normal file
File diff suppressed because it is too large
Load Diff
663
api/app/controllers/ontology_secondary_routes.py
Normal file
663
api/app/controllers/ontology_secondary_routes.py
Normal file
@@ -0,0 +1,663 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""本体场景和类型路由(续)
|
||||
|
||||
由于主Controller文件较大,将剩余路由放在此文件中。
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, Header
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger, get_business_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.ontology_schemas import (
|
||||
SceneResponse,
|
||||
SceneListResponse,
|
||||
PaginationInfo,
|
||||
ClassCreateRequest,
|
||||
ClassUpdateRequest,
|
||||
ClassResponse,
|
||||
ClassListResponse,
|
||||
ClassBatchCreateResponse,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.ontology_service import OntologyService
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||
|
||||
|
||||
api_logger = get_api_logger()
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def _get_dummy_ontology_service(db: Session) -> OntologyService:
|
||||
"""获取OntologyService实例(不需要LLM)
|
||||
|
||||
场景和类型管理不需要LLM,创建一个dummy配置。
|
||||
"""
|
||||
dummy_config = RedBearModelConfig(
|
||||
model_name="dummy",
|
||||
provider="openai",
|
||||
api_key="dummy",
|
||||
base_url="https://api.openai.com/v1"
|
||||
)
|
||||
llm_client = OpenAIClient(model_config=dummy_config)
|
||||
return OntologyService(llm_client=llm_client, db=db)
|
||||
|
||||
|
||||
# 这些函数将被导入到主Controller中
|
||||
|
||||
async def scenes_handler(
|
||||
workspace_id: Optional[str] = None,
|
||||
scene_name: Optional[str] = None,
|
||||
page: Optional[int] = None,
|
||||
pagesize: Optional[int] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取场景列表(支持模糊搜索和全量查询,全量查询支持分页)
|
||||
|
||||
当提供 scene_name 参数时,进行模糊搜索(不分页);
|
||||
当不提供 scene_name 参数时,返回所有场景(支持分页)。
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
||||
scene_name: 场景名称关键词(可选,支持模糊匹配)
|
||||
page: 页码(可选,从1开始,仅在全量查询时有效)
|
||||
pagesize: 每页数量(可选,仅在全量查询时有效)
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
"""
|
||||
operation = "search" if scene_name else "list"
|
||||
api_logger.info(
|
||||
f"Scene {operation} requested by user {current_user.id}, "
|
||||
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, pagesize={pagesize}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 确定工作空间ID
|
||||
if workspace_id:
|
||||
try:
|
||||
ws_uuid = UUID(workspace_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid workspace_id format: {workspace_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的工作空间ID格式")
|
||||
else:
|
||||
ws_uuid = current_user.current_workspace_id
|
||||
if not ws_uuid:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 根据是否提供 scene_name 决定查询方式
|
||||
if scene_name and scene_name.strip():
|
||||
# 验证分页参数(模糊搜索也支持分页)
|
||||
if page is not None and page < 1:
|
||||
api_logger.warning(f"Invalid page number: {page}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||
|
||||
if pagesize is not None and pagesize < 1:
|
||||
api_logger.warning(f"Invalid pagesize: {pagesize}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||
|
||||
# 如果只提供了page或pagesize中的一个,返回错误
|
||||
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||
|
||||
# 模糊搜索场景(支持分页)
|
||||
scenes = service.search_scenes_by_name(scene_name.strip(), ws_uuid)
|
||||
total = len(scenes)
|
||||
|
||||
# 如果提供了分页参数,进行分页处理
|
||||
if page is not None and pagesize is not None:
|
||||
start_idx = (page - 1) * pagesize
|
||||
end_idx = start_idx + pagesize
|
||||
scenes = scenes[start_idx:end_idx]
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for scene in scenes:
|
||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||
type_num = len(scene.classes) if scene.classes else 0
|
||||
|
||||
items.append(SceneResponse(
|
||||
scene_id=scene.scene_id,
|
||||
scene_name=scene.scene_name,
|
||||
scene_description=scene.scene_description,
|
||||
type_num=type_num,
|
||||
entity_type=entity_type,
|
||||
workspace_id=scene.workspace_id,
|
||||
created_at=scene.created_at,
|
||||
updated_at=scene.updated_at,
|
||||
classes_count=type_num,
|
||||
is_system_default=scene.is_system_default
|
||||
))
|
||||
|
||||
# 构建响应(包含分页信息)
|
||||
if page is not None and pagesize is not None:
|
||||
hasnext = (page * pagesize) < total
|
||||
pagination_info = PaginationInfo(
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
total=total,
|
||||
hasnext=hasnext
|
||||
)
|
||||
response = SceneListResponse(items=items, page=pagination_info)
|
||||
else:
|
||||
response = SceneListResponse(items=items)
|
||||
|
||||
api_logger.info(
|
||||
f"Scene search completed: found {len(items)} scenes matching '{scene_name}' "
|
||||
f"in workspace {ws_uuid}, total={total}"
|
||||
)
|
||||
else:
|
||||
# 获取所有场景(支持分页)
|
||||
if page is not None and page < 1:
|
||||
api_logger.warning(f"Invalid page number: {page}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||
|
||||
if pagesize is not None and pagesize < 1:
|
||||
api_logger.warning(f"Invalid pagesize: {pagesize}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||
|
||||
# 如果只提供了page或pagesize中的一个,返回错误
|
||||
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||
|
||||
scenes, total = service.list_scenes(ws_uuid, page, pagesize)
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for scene in scenes:
|
||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||
type_num = len(scene.classes) if scene.classes else 0
|
||||
|
||||
items.append(SceneResponse(
|
||||
scene_id=scene.scene_id,
|
||||
scene_name=scene.scene_name,
|
||||
scene_description=scene.scene_description,
|
||||
type_num=type_num,
|
||||
entity_type=entity_type,
|
||||
workspace_id=scene.workspace_id,
|
||||
created_at=scene.created_at,
|
||||
updated_at=scene.updated_at,
|
||||
classes_count=type_num,
|
||||
is_system_default=scene.is_system_default
|
||||
))
|
||||
|
||||
# 构建响应(包含分页信息)
|
||||
if page is not None and pagesize is not None:
|
||||
hasnext = (page * pagesize) < total
|
||||
pagination_info = PaginationInfo(
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
total=total,
|
||||
hasnext=hasnext
|
||||
)
|
||||
response = SceneListResponse(items=items, page=pagination_info)
|
||||
else:
|
||||
response = SceneListResponse(items=items)
|
||||
|
||||
api_logger.info(f"Scene list retrieved successfully, count={len(items)}, total={total}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in scene {operation}: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in scene {operation}: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in scene {operation}: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
|
||||
# ==================== 本体类型管理接口 ====================
|
||||
|
||||
async def create_class_handler(
|
||||
request: ClassCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
x_language_type: Optional[str] = None
|
||||
):
|
||||
"""创建本体类型(统一使用列表形式,支持单个或批量)"""
|
||||
|
||||
# 根据列表长度判断是单个还是批量
|
||||
count = len(request.classes)
|
||||
mode = "single" if count == 1 else "batch"
|
||||
|
||||
api_logger.info(
|
||||
f"Class creation ({mode}) requested by user {current_user.id}, "
|
||||
f"scene_id={request.scene_id}, count={count}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 准备类型数据
|
||||
classes_data = [
|
||||
{
|
||||
"class_name": item.class_name,
|
||||
"class_description": item.class_description
|
||||
}
|
||||
for item in request.classes
|
||||
]
|
||||
|
||||
if count == 1:
|
||||
# 单个创建 - 先检查重名
|
||||
class_data = classes_data[0]
|
||||
existing = OntologyClassRepository(db).get_by_name(class_data["class_name"], request.scene_id)
|
||||
if existing:
|
||||
raise ValueError(f"DUPLICATE_CLASS_NAME:{class_data['class_name']}")
|
||||
ontology_class = service.create_class(
|
||||
scene_id=request.scene_id,
|
||||
class_name=class_data["class_name"],
|
||||
class_description=class_data["class_description"],
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 构建单个响应
|
||||
response = ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
)
|
||||
|
||||
api_logger.info(f"Class created successfully: {ontology_class.class_id}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="类型创建成功")
|
||||
|
||||
else:
|
||||
# 批量创建
|
||||
created_classes, errors = service.create_classes_batch(
|
||||
scene_id=request.scene_id,
|
||||
classes=classes_data,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 构建批量响应
|
||||
items = []
|
||||
for ontology_class in created_classes:
|
||||
items.append(ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
))
|
||||
|
||||
response = ClassBatchCreateResponse(
|
||||
total=len(classes_data),
|
||||
success_count=len(created_classes),
|
||||
failed_count=len(errors),
|
||||
items=items,
|
||||
errors=errors if errors else None
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"Batch class creation completed: "
|
||||
f"success={len(created_classes)}, failed={len(errors)}"
|
||||
)
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="批量创建完成")
|
||||
|
||||
except ValueError as e:
|
||||
err_str = str(e)
|
||||
if err_str.startswith("DUPLICATE_CLASS_NAME:"):
|
||||
class_name = err_str.split(":", 1)[1]
|
||||
api_logger.warning(f"Duplicate class name '{class_name}' in scene {request.scene_id}")
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from fastapi.responses import JSONResponse
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.warning(f"Validation error in class creation: {err_str}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", err_str)
|
||||
|
||||
except RuntimeError as e:
|
||||
err_str = str(e)
|
||||
if "UniqueViolation" in err_str or "uq_scene_class_name" in err_str:
|
||||
api_logger.warning(f"Duplicate class name in scene {request.scene_id}")
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from fastapi.responses import JSONResponse
|
||||
lang = get_language_from_header(x_language_type)
|
||||
class_name = request.classes[0].class_name if request.classes else ""
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Runtime error in class creation: {err_str}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", err_str)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
|
||||
|
||||
|
||||
async def update_class_handler(
|
||||
class_id: str,
|
||||
request: ClassUpdateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新本体类型"""
|
||||
api_logger.info(
|
||||
f"Class update requested by user {current_user.id}, "
|
||||
f"class_id={class_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证UUID格式
|
||||
try:
|
||||
class_uuid = UUID(class_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid class_id format: {class_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 检查是否为系统默认类型
|
||||
class_repo = OntologyClassRepository(db)
|
||||
ontology_class = class_repo.get_by_id(class_uuid)
|
||||
if ontology_class and ontology_class.is_system_default:
|
||||
business_logger.warning(
|
||||
f"尝试修改系统默认类型: user_id={current_user.id}, "
|
||||
f"class_id={class_id}, class_name={ontology_class.class_name}"
|
||||
)
|
||||
return fail(
|
||||
BizCode.BAD_REQUEST,
|
||||
"系统默认类型不可修改",
|
||||
"该类型为系统预设类型,不允许修改"
|
||||
)
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 更新类型
|
||||
ontology_class = service.update_class(
|
||||
class_id=class_uuid,
|
||||
class_name=request.class_name,
|
||||
class_description=request.class_description,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response = ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
)
|
||||
|
||||
api_logger.info(f"Class updated successfully: {class_id}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="类型更新成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class update: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class update: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class update: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
|
||||
|
||||
|
||||
async def delete_class_handler(
|
||||
class_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除本体类型"""
|
||||
api_logger.info(
|
||||
f"Class deletion requested by user {current_user.id}, "
|
||||
f"class_id={class_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证UUID格式
|
||||
try:
|
||||
class_uuid = UUID(class_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid class_id format: {class_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 检查是否为系统默认类型
|
||||
class_repo = OntologyClassRepository(db)
|
||||
ontology_class = class_repo.get_by_id(class_uuid)
|
||||
if ontology_class and ontology_class.is_system_default:
|
||||
business_logger.warning(
|
||||
f"尝试删除系统默认类型: user_id={current_user.id}, "
|
||||
f"class_id={class_id}, class_name={ontology_class.class_name}"
|
||||
)
|
||||
return fail(
|
||||
BizCode.BAD_REQUEST,
|
||||
"系统默认类型不可删除",
|
||||
"该类型为系统预设类型,不允许删除"
|
||||
)
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 删除类型
|
||||
success_flag = service.delete_class(
|
||||
class_id=class_uuid,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
api_logger.info(f"Class deleted successfully: {class_id}")
|
||||
|
||||
return success(data={"deleted": success_flag}, msg="类型删除成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class deletion: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class deletion: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class deletion: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
|
||||
|
||||
|
||||
async def get_class_handler(
|
||||
class_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取单个本体类型"""
|
||||
api_logger.info(
|
||||
f"Get class requested by user {current_user.id}, "
|
||||
f"class_id={class_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证UUID格式
|
||||
try:
|
||||
class_uuid = UUID(class_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid class_id format: {class_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 获取类型(会抛出ValueError如果不存在)
|
||||
ontology_class = service.get_class_by_id(class_uuid, workspace_id)
|
||||
|
||||
# 构建响应
|
||||
response = ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
)
|
||||
|
||||
api_logger.info(f"Class retrieved successfully: {class_id}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
# 类型不存在或无权限访问
|
||||
api_logger.warning(f"Validation error in get class: {str(e)}")
|
||||
return fail(BizCode.NOT_FOUND, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in get class: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in get class: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
|
||||
async def classes_handler(
|
||||
scene_id: str,
|
||||
class_name: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取类型列表(支持模糊搜索和全量查询)
|
||||
|
||||
当提供 class_name 参数时,进行模糊搜索;
|
||||
当不提供 class_name 参数时,返回场景下的所有类型。
|
||||
|
||||
Args:
|
||||
scene_id: 场景ID(必填)
|
||||
class_name: 类型名称关键词(可选,支持模糊匹配)
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
"""
|
||||
operation = "search" if class_name else "list"
|
||||
api_logger.info(
|
||||
f"Class {operation} requested by user {current_user.id}, "
|
||||
f"keyword={class_name}, scene_id={scene_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证UUID格式
|
||||
try:
|
||||
scene_uuid = UUID(scene_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid scene_id format: {scene_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的场景ID格式")
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 获取场景信息
|
||||
scene = service.get_scene_by_id(scene_uuid, workspace_id)
|
||||
if not scene:
|
||||
api_logger.warning(f"Scene not found: {scene_id}")
|
||||
return fail(BizCode.NOT_FOUND, "场景不存在", f"未找到ID为 {scene_id} 的场景")
|
||||
|
||||
# 根据是否提供 class_name 决定查询方式
|
||||
if class_name and class_name.strip():
|
||||
# 模糊搜索类型
|
||||
classes = service.search_classes_by_name(class_name.strip(), scene_uuid, workspace_id)
|
||||
else:
|
||||
# 获取所有类型
|
||||
classes = service.list_classes_by_scene(scene_uuid, workspace_id)
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for ontology_class in classes:
|
||||
items.append(ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
))
|
||||
|
||||
response = ClassListResponse(
|
||||
total=len(items),
|
||||
scene_id=scene_uuid,
|
||||
scene_name=scene.scene_name,
|
||||
scene_description=scene.scene_description,
|
||||
is_system_default=scene.is_system_default,
|
||||
items=items
|
||||
)
|
||||
|
||||
if class_name:
|
||||
api_logger.info(
|
||||
f"Class search completed: found {len(items)} classes matching '{class_name}' "
|
||||
f"in scene {scene_id}"
|
||||
)
|
||||
else:
|
||||
api_logger.info(f"Class list retrieved successfully, count={len(items)}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class {operation}: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class {operation}: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class {operation}: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
@@ -1,14 +1,20 @@
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.models.prompt_optimizer_model import RoleType
|
||||
from app.schemas.prompt_optimizer_schema import PromptOptMessage, PromptOptModelSet, CreateSessionResponse, \
|
||||
OptimizePromptResponse, SessionHistoryResponse, SessionMessage
|
||||
from app.schemas.prompt_optimizer_schema import (
|
||||
PromptOptMessage,
|
||||
CreateSessionResponse,
|
||||
SessionHistoryResponse,
|
||||
SessionMessage,
|
||||
PromptSaveRequest
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.prompt_optimizer_service import PromptOptimizerService
|
||||
|
||||
@@ -70,12 +76,12 @@ def get_prompt_session(
|
||||
SessionMessage(role=role, content=content)
|
||||
for role, content in history
|
||||
]
|
||||
|
||||
|
||||
result = SessionHistoryResponse(
|
||||
session_id=session_id,
|
||||
messages=messages
|
||||
)
|
||||
|
||||
|
||||
return success(data=result)
|
||||
|
||||
|
||||
@@ -104,35 +110,139 @@ async def get_prompt_opt(
|
||||
ApiResponse: Contains the optimized prompt, description, and a list of variables.
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
service.create_message(
|
||||
tenant_id=current_user.tenant_id,
|
||||
session_id=session_id,
|
||||
user_id=current_user.id,
|
||||
role=RoleType.USER,
|
||||
content=data.message
|
||||
|
||||
async def event_generator():
|
||||
yield "event:start\ndata: {}\n\n"
|
||||
try:
|
||||
async for chunk in service.optimize_prompt(
|
||||
tenant_id=current_user.tenant_id,
|
||||
model_id=data.model_id,
|
||||
session_id=session_id,
|
||||
user_id=current_user.id,
|
||||
current_prompt=data.current_prompt,
|
||||
user_require=data.message,
|
||||
skill=data.skill
|
||||
):
|
||||
# chunk 是 prompt 的增量内容
|
||||
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
|
||||
except Exception as e:
|
||||
yield f"event:error\ndata: {json.dumps(
|
||||
{"error": str(e)}
|
||||
)}\n\n"
|
||||
yield "event:end\ndata: {}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
opt_result = await service.optimize_prompt(
|
||||
tenant_id=current_user.tenant_id,
|
||||
model_id=data.model_id,
|
||||
session_id=session_id,
|
||||
user_id=current_user.id,
|
||||
current_prompt=data.current_prompt,
|
||||
user_require=data.message
|
||||
)
|
||||
service.create_message(
|
||||
tenant_id=current_user.tenant_id,
|
||||
session_id=session_id,
|
||||
user_id=current_user.id,
|
||||
role=RoleType.ASSISTANT,
|
||||
content=opt_result.desc
|
||||
)
|
||||
variables = service.parser_prompt_variables(opt_result.prompt)
|
||||
result = {
|
||||
"prompt": opt_result.prompt,
|
||||
"desc": opt_result.desc,
|
||||
"variables": variables
|
||||
}
|
||||
result_schema = OptimizePromptResponse.model_validate(result)
|
||||
return success(data=result_schema)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/releases",
|
||||
summary="Get prompt optimization",
|
||||
response_model=ApiResponse
|
||||
)
|
||||
def save_prompt(
|
||||
data: PromptSaveRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Save a prompt release for the current tenant.
|
||||
|
||||
Args:
|
||||
data (PromptSaveRequest): Request body containing session_id, title, and prompt.
|
||||
db (Session): SQLAlchemy database session, injected via dependency.
|
||||
current_user: Currently authenticated user object, injected via dependency.
|
||||
|
||||
Returns:
|
||||
ApiResponse: Standard API response containing the saved prompt release info:
|
||||
- id: UUID of the prompt release
|
||||
- session_id: associated session
|
||||
- title: prompt title
|
||||
- prompt: prompt content
|
||||
- created_at: timestamp of creation
|
||||
|
||||
Raises:
|
||||
Any database or service exceptions are propagated to the global exception handler.
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
prompt_info = service.save_prompt(
|
||||
tenant_id=current_user.tenant_id,
|
||||
session_id=data.session_id,
|
||||
title=data.title,
|
||||
prompt=data.prompt
|
||||
)
|
||||
return success(data=prompt_info)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/releases/{prompt_id}",
|
||||
summary="Delete prompt (soft delete)",
|
||||
response_model=ApiResponse
|
||||
)
|
||||
def delete_prompt(
|
||||
prompt_id: uuid.UUID = Path(..., description="Prompt ID"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Soft delete a prompt release.
|
||||
|
||||
Args:
|
||||
prompt_id
|
||||
db (Session): Database session
|
||||
current_user: Current logged-in user
|
||||
|
||||
Returns:
|
||||
ApiResponse: Success message confirming deletion
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
service.delete_prompt(
|
||||
tenant_id=current_user.tenant_id,
|
||||
prompt_id=prompt_id
|
||||
)
|
||||
return success(msg="Prompt deleted successfully")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/releases/list",
|
||||
summary="Get paginated list of released prompts with optional filter",
|
||||
response_model=ApiResponse
|
||||
)
|
||||
def get_release_list(
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
keyword: str | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Retrieve paginated list of released prompts for the current tenant.
|
||||
Optionally filter by keyword in title.
|
||||
|
||||
Args:
|
||||
page (int): Page number (starting from 1)
|
||||
page_size (int): Number of items per page (max 100)
|
||||
keyword (str | None): Optional keyword to filter prompt titles
|
||||
db (Session): Database session
|
||||
current_user: Current logged-in user
|
||||
|
||||
Returns:
|
||||
ApiResponse: Contains paginated list of prompt releases with metadata
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
result = service.get_release_list(
|
||||
tenant_id=current_user.tenant_id,
|
||||
page=max(1, page),
|
||||
page_size=min(max(1, page_size), 100),
|
||||
filter_keyword=keyword
|
||||
)
|
||||
return success(data=result)
|
||||
|
||||
|
||||
|
||||
@@ -1,22 +1,34 @@
|
||||
import hashlib
|
||||
import json
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db, get_db_read
|
||||
from app.dependencies import get_share_user_id, ShareTokenData
|
||||
from app.models.app_model import App
|
||||
from app.models.app_model import AppType
|
||||
from app.repositories import knowledge_repository
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.repositories.workflow_repository import WorkflowConfigRepository
|
||||
from app.schemas import release_share_schema, conversation_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services import workspace_service
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.services.auth_service import create_access_token
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.services.shared_chat_service import SharedChatService
|
||||
from app.services.workflow_service import WorkflowService
|
||||
from app.utils.app_config_utils import workflow_config_4_app_release, \
|
||||
agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
|
||||
router = APIRouter(prefix="/public/share", tags=["Public Share"])
|
||||
logger = get_business_logger()
|
||||
@@ -62,10 +74,10 @@ def get_or_generate_user_id(payload_user_id: str, request: Request) -> str:
|
||||
summary="获取访问 token"
|
||||
)
|
||||
def get_access_token(
|
||||
share_token: str,
|
||||
payload: release_share_schema.TokenRequest,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
share_token: str,
|
||||
payload: release_share_schema.TokenRequest,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取访问 token
|
||||
|
||||
@@ -110,9 +122,9 @@ def get_access_token(
|
||||
response_model=None
|
||||
)
|
||||
def get_shared_release(
|
||||
password: str = Query(None, description="访问密码(如果需要)"),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
password: str = Query(None, description="访问密码(如果需要)"),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取公开分享的发布版本信息
|
||||
|
||||
@@ -134,9 +146,9 @@ def get_shared_release(
|
||||
summary="验证访问密码"
|
||||
)
|
||||
def verify_password(
|
||||
payload: release_share_schema.PasswordVerifyRequest,
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
payload: release_share_schema.PasswordVerifyRequest,
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""验证分享的访问密码
|
||||
|
||||
@@ -156,11 +168,11 @@ def verify_password(
|
||||
summary="获取嵌入代码"
|
||||
)
|
||||
def get_embed_code(
|
||||
width: str = Query("100%", description="iframe 宽度"),
|
||||
height: str = Query("600px", description="iframe 高度"),
|
||||
request: Request = None,
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
width: str = Query("100%", description="iframe 宽度"),
|
||||
height: str = Query("600px", description="iframe 高度"),
|
||||
request: Request = None,
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取嵌入代码
|
||||
|
||||
@@ -180,7 +192,6 @@ def get_embed_code(
|
||||
return success(data=embed_code)
|
||||
|
||||
|
||||
|
||||
# ---------- 会话管理接口 ----------
|
||||
|
||||
@router.get(
|
||||
@@ -188,11 +199,11 @@ def get_embed_code(
|
||||
summary="获取会话列表"
|
||||
)
|
||||
def list_conversations(
|
||||
password: str = Query(None, description="访问密码"),
|
||||
page: int = Query(1, ge=1),
|
||||
pagesize: int = Query(20, ge=1, le=100),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
password: str = Query(None, description="访问密码"),
|
||||
page: int = Query(1, ge=1),
|
||||
pagesize: int = Query(20, ge=1, le=100),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取分享应用的会话列表
|
||||
|
||||
@@ -202,15 +213,13 @@ def list_conversations(
|
||||
logger.debug(f"share_data:{share_data.user_id}")
|
||||
other_id = share_data.user_id
|
||||
service = SharedChatService(db)
|
||||
share, release = service._get_release_by_share_token(share_data.share_token, password)
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
share, release = service.get_release_by_share_token(share_data.share_token, password)
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
other_id=other_id
|
||||
)
|
||||
app_id=share.app_id,
|
||||
other_id=other_id
|
||||
)
|
||||
logger.debug(new_end_user.id)
|
||||
service = SharedChatService(db)
|
||||
conversations, total = service.list_conversations(
|
||||
share_token=share_data.share_token,
|
||||
user_id=str(new_end_user.id),
|
||||
@@ -230,10 +239,10 @@ def list_conversations(
|
||||
summary="获取会话详情(含消息)"
|
||||
)
|
||||
def get_conversation(
|
||||
conversation_id: uuid.UUID,
|
||||
password: str = Query(None, description="访问密码"),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
conversation_id: uuid.UUID,
|
||||
password: str = Query(None, description="访问密码"),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取会话详情和消息历史"""
|
||||
chat_service = SharedChatService(db)
|
||||
@@ -263,9 +272,10 @@ def get_conversation(
|
||||
summary="发送消息(支持流式和非流式)"
|
||||
)
|
||||
async def chat(
|
||||
payload: conversation_schema.ChatRequest,
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db)
|
||||
payload: conversation_schema.ChatRequest,
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
|
||||
):
|
||||
"""发送消息并获取回复
|
||||
|
||||
@@ -288,33 +298,31 @@ async def chat(
|
||||
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
# 这样可以确保错误能正确返回,而不是在流式响应中间出错
|
||||
from app.models.app_model import AppType
|
||||
|
||||
try:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.services.app_service import AppService
|
||||
# 验证分享链接和密码
|
||||
share, release = service._get_release_by_share_token(share_token, password)
|
||||
share, release = service.get_release_by_share_token(share_token, password)
|
||||
|
||||
# # Create end_user_id by concatenating app_id with user_id
|
||||
# end_user_id = f"{share.app_id}_{user_id}"
|
||||
|
||||
# Store end_user_id in database with original user_id
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
other_id=other_id,
|
||||
original_user_id=user_id # Save original user_id to other_id
|
||||
)
|
||||
end_user_id = str(new_end_user.id)
|
||||
|
||||
|
||||
appid=share.app_id
|
||||
appid = share.app_id
|
||||
"""获取存储类型和工作空间的ID"""
|
||||
|
||||
# 直接通过 SQLAlchemy 查询 app
|
||||
from app.models.app_model import App
|
||||
app = db.query(App).filter(App.id == appid).first()
|
||||
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
|
||||
app = db.query(App).filter(
|
||||
App.id == appid,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
if not app:
|
||||
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
||||
|
||||
@@ -351,16 +359,19 @@ async def chat(
|
||||
app_type = release.app.type if release.app else None
|
||||
|
||||
# 根据应用类型验证配置
|
||||
if app_type == "agent":
|
||||
if app_type == AppType.AGENT:
|
||||
# Agent 类型:验证模型配置
|
||||
model_config_id = release.default_model_config_id
|
||||
if not model_config_id:
|
||||
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
elif app_type == "multi_agent":
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
# Multi-Agent 类型:验证多 Agent 配置
|
||||
config = release.config or {}
|
||||
if not config.get("sub_agents"):
|
||||
raise BusinessException("多 Agent 应用未配置子 Agent", BizCode.AGENT_CONFIG_MISSING)
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
# Multi-Agent 类型:验证多 Agent 配置
|
||||
pass
|
||||
else:
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
@@ -389,19 +400,46 @@ async def chat(
|
||||
|
||||
if app_type == AppType.AGENT:
|
||||
# 流式返回
|
||||
agent_config = agent_config_4_app_release(release)
|
||||
|
||||
if payload.stream:
|
||||
# async def event_generator():
|
||||
# async for event in service.chat_stream(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# ):
|
||||
# yield event
|
||||
|
||||
# return StreamingResponse(
|
||||
# event_generator(),
|
||||
# media_type="text/event-stream",
|
||||
# headers={
|
||||
# "Cache-Control": "no-cache",
|
||||
# "Connection": "keep-alive",
|
||||
# "X-Accel-Buffering": "no"
|
||||
# }
|
||||
# )
|
||||
async def event_generator():
|
||||
async for event in service.chat_stream(
|
||||
share_token=share_token,
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=str(new_end_user.id), # 转换为字符串
|
||||
variables=payload.variables,
|
||||
password=password,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
async for event in app_chat_service.agnet_chat_stream(
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=str(new_end_user.id), # 转换为字符串
|
||||
variables=payload.variables,
|
||||
web_search=payload.web_search,
|
||||
config=agent_config,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -414,34 +452,48 @@ async def chat(
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
# 非流式返回
|
||||
result = await service.chat(
|
||||
share_token=share_token,
|
||||
# result = await service.chat(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# )
|
||||
# return success(data=conversation_schema.ChatResponse(**result))
|
||||
result = await app_chat_service.agnet_chat(
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=str(new_end_user.id), # 转换为字符串
|
||||
variables=payload.variables,
|
||||
password=password,
|
||||
config=agent_config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
)
|
||||
return success(data=conversation_schema.ChatResponse(**result))
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
# 多 Agent 流式返回
|
||||
# config = workflow_config_4_app_release(release)
|
||||
config = multi_agent_config_4_app_release(release)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in service.multi_agent_chat_stream(
|
||||
share_token=share_token,
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=str(new_end_user.id), # 转换为字符串
|
||||
variables=payload.variables,
|
||||
password=password,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
async for event in app_chat_service.multi_agent_chat_stream(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=str(new_end_user.id), # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
@@ -458,22 +510,162 @@ async def chat(
|
||||
)
|
||||
|
||||
# 多 Agent 非流式返回
|
||||
result = await service.multi_agent_chat(
|
||||
share_token=share_token,
|
||||
result = await app_chat_service.multi_agent_chat(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=str(new_end_user.id), # 转换为字符串
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
password=password,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
return success(data=conversation_schema.ChatResponse(**result))
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
# 多 Agent 流式返回
|
||||
# if payload.stream:
|
||||
# async def event_generator():
|
||||
# async for event in service.multi_agent_chat_stream(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# ):
|
||||
# yield event
|
||||
|
||||
# return StreamingResponse(
|
||||
# event_generator(),
|
||||
# media_type="text/event-stream",
|
||||
# headers={
|
||||
# "Cache-Control": "no-cache",
|
||||
# "Connection": "keep-alive",
|
||||
# "X-Accel-Buffering": "no"
|
||||
# }
|
||||
# )
|
||||
|
||||
# # 多 Agent 非流式返回
|
||||
# result = await service.multi_agent_chat(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# )
|
||||
|
||||
# return success(data=conversation_schema.ChatResponse(**result))
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
config = workflow_config_4_app_release(release)
|
||||
if not config.id:
|
||||
with get_db_read() as db:
|
||||
source_config = WorkflowConfigRepository(db).get_by_app_id(release.app_id)
|
||||
config.id = source_config.id
|
||||
config.id = uuid.UUID(config.id)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.workflow_chat_stream(
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
files=payload.files,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=release.app_id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=release.id,
|
||||
public=True
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
# 转换为标准 SSE 格式(字符串)
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data, default=str, ensure_ascii=False)}\n\n"
|
||||
yield sse_message
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
# 多 Agent 非流式返回
|
||||
result = await app_chat_service.workflow_chat(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=release.app_id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=release.id
|
||||
)
|
||||
logger.debug(
|
||||
"工作流试运行返回结果",
|
||||
extra={
|
||||
"result_type": str(type(result)),
|
||||
"has_response": "response" in result if isinstance(result, dict) else False
|
||||
}
|
||||
)
|
||||
return success(
|
||||
data=result,
|
||||
msg="工作流任务执行成功"
|
||||
)
|
||||
# return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
|
||||
else:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
pass
|
||||
|
||||
|
||||
@router.get("/config", summary="获取应用启动配置")
|
||||
async def config_query(
|
||||
password: str = Query(None, description="访问密码"),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
share_service = SharedChatService(db)
|
||||
share_token = share_data.share_token
|
||||
share, release = share_service.get_release_by_share_token(share_token, password)
|
||||
if release.app.type == AppType.WORKFLOW:
|
||||
workflow_service = WorkflowService(db)
|
||||
content = {
|
||||
"app_type": release.app.type,
|
||||
"variables": workflow_service.get_start_node_variables(release.config)
|
||||
}
|
||||
elif release.app.type == AppType.AGENT:
|
||||
content = {
|
||||
"app_type": release.app.type,
|
||||
"variables": release.config.get("variables")
|
||||
}
|
||||
elif release.app.type == AppType.MULTI_AGENT:
|
||||
content = {
|
||||
"app_type": release.app.type,
|
||||
"variables": []
|
||||
}
|
||||
else:
|
||||
return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
return success(data=content)
|
||||
|
||||
@@ -4,14 +4,17 @@
|
||||
认证方式: API Key
|
||||
"""
|
||||
from fastapi import APIRouter
|
||||
from . import app_api_controller, rag_api_controller, memory_api_controller
|
||||
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller
|
||||
|
||||
# 创建 V1 API 路由器
|
||||
service_router = APIRouter()
|
||||
|
||||
# 注册子路由
|
||||
service_router.include_router(app_api_controller.router)
|
||||
service_router.include_router(rag_api_controller.router)
|
||||
service_router.include_router(rag_api_knowledge_controller.router)
|
||||
service_router.include_router(rag_api_document_controller.router)
|
||||
service_router.include_router(rag_api_file_controller.router)
|
||||
service_router.include_router(rag_api_chunk_controller.router)
|
||||
service_router.include_router(memory_api_controller.router)
|
||||
|
||||
__all__ = ["service_router"]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""App 服务接口 - 基于 API Key 认证"""
|
||||
import json
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Request, Body
|
||||
@@ -11,7 +12,6 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_app_or_workspace
|
||||
from app.models.app_model import App
|
||||
from app.models.app_model import AppType
|
||||
from app.repositories import knowledge_repository
|
||||
@@ -20,9 +20,10 @@ from app.schemas import AppChatRequest, conversation_schema
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.services import workspace_service
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.services.conversation_service import ConversationService, get_conversation_service
|
||||
from app.utils.app_config_utils import dict_to_multi_agent_config, dict_to_workflow_config, agent_config_4_app_release
|
||||
from app.services.app_service import get_app_service, AppService
|
||||
from app.services.conversation_service import ConversationService, get_conversation_service
|
||||
from app.utils.app_config_utils import workflow_config_4_app_release, \
|
||||
agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
|
||||
router = APIRouter(prefix="/app", tags=["V1 - App API"])
|
||||
logger = get_business_logger()
|
||||
@@ -33,6 +34,7 @@ async def list_apps():
|
||||
"""列出可访问的应用(占位)"""
|
||||
return success(data=[], msg="App API - Coming Soon")
|
||||
|
||||
|
||||
# /v1/app/chat
|
||||
|
||||
# @router.post("/chat")
|
||||
@@ -72,21 +74,21 @@ def _checkAppConfig(app: App):
|
||||
else:
|
||||
raise BusinessException("不支持的应用类型", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
@require_api_key(scopes=["app"])
|
||||
async def chat(
|
||||
request:Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
conversation_service: Annotated[ConversationService, Depends(get_conversation_service)] = None,
|
||||
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
|
||||
app_service: Annotated[AppService, Depends(get_app_service)] = None,
|
||||
message: str = Body(..., description="聊天消息内容"),
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
conversation_service: Annotated[ConversationService, Depends(get_conversation_service)] = None,
|
||||
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
|
||||
app_service: Annotated[AppService, Depends(get_app_service)] = None,
|
||||
message: str = Body(..., description="聊天消息内容"),
|
||||
):
|
||||
body = await request.json()
|
||||
payload = AppChatRequest(**body)
|
||||
|
||||
other_id = payload.user_id
|
||||
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
|
||||
other_id = payload.user_id
|
||||
workspace_id = app.workspace_id
|
||||
@@ -97,8 +99,8 @@ async def chat(
|
||||
original_user_id=other_id # Save original user_id to other_id
|
||||
)
|
||||
end_user_id = str(new_end_user.id)
|
||||
web_search=True
|
||||
memory=True
|
||||
web_search = True
|
||||
memory = True
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
storage_type = workspace_service.get_workspace_storage_type_without_auth(
|
||||
db=db,
|
||||
@@ -132,28 +134,31 @@ async def chat(
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=end_user_id,
|
||||
is_draft=False
|
||||
is_draft=False,
|
||||
conversation_id=payload.conversation_id
|
||||
)
|
||||
|
||||
if app_type == AppType.AGENT:
|
||||
|
||||
print("="*50)
|
||||
print(app.current_release.default_model_config_id)
|
||||
# print("="*50)
|
||||
# print(app.current_release.default_model_config_id)
|
||||
agent_config = agent_config_4_app_release(app.current_release)
|
||||
print(agent_config.default_model_config_id)
|
||||
# print(agent_config.default_model_config_id)
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.agnet_chat_stream(
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id= end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
web_search=web_search,
|
||||
config=agent_config,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
web_search=web_search,
|
||||
config=agent_config,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -173,27 +178,29 @@ async def chat(
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config= agent_config,
|
||||
config=agent_config,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
)
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
# 多 Agent 流式返回
|
||||
config = dict_to_multi_agent_config(app.current_release.config,app.id)
|
||||
config = multi_agent_config_4_app_release(app.current_release)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.multi_agent_chat_stream(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
@@ -211,7 +218,6 @@ async def chat(
|
||||
|
||||
# 多 Agent 非流式返回
|
||||
result = await app_chat_service.multi_agent_chat(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
@@ -226,22 +232,31 @@ async def chat(
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
# 多 Agent 流式返回
|
||||
config = dict_to_workflow_config(app.current_release.config,app.id)
|
||||
config = workflow_config_4_app_release(app.current_release)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.workflow_chat_stream(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
files=payload.files,
|
||||
config=config,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id,
|
||||
public=True
|
||||
):
|
||||
yield event
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
# 转换为标准 SSE 格式(字符串)
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||
yield sse_message
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
@@ -253,7 +268,7 @@ async def chat(
|
||||
}
|
||||
)
|
||||
|
||||
# 非流式返回
|
||||
# 多 Agent 非流式返回
|
||||
result = await app_chat_service.workflow_chat(
|
||||
|
||||
message=payload.message,
|
||||
@@ -264,12 +279,23 @@ async def chat(
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id
|
||||
)
|
||||
logger.debug(
|
||||
"工作流试运行返回结果",
|
||||
extra={
|
||||
"result_type": str(type(result)),
|
||||
"has_response": "response" in result if isinstance(result, dict) else False
|
||||
}
|
||||
)
|
||||
return success(
|
||||
data=result,
|
||||
msg="工作流任务执行成功"
|
||||
)
|
||||
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
else:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
pass
|
||||
|
||||
@@ -39,7 +39,7 @@ async def write_memory_api_service(
|
||||
|
||||
Stores memory content for the specified end user using the Memory API Service.
|
||||
"""
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}")
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
|
||||
221
api/app/controllers/service/rag_api_chunk_controller.py
Normal file
221
api/app/controllers/service/rag_api_chunk_controller.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""RAG 服务接口 - 基于 API Key 认证"""
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Request, status, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.controllers import chunk_controller
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.rag.models.chunk import QAChunk
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.schemas import chunk_schema
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import api_key_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/chunks", tags=["V1 - RAG API"])
|
||||
api_logger = get_business_logger()
|
||||
|
||||
|
||||
@router.get("/{kb_id}/{document_id}/previewchunks", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def get_preview_chunks(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
keywords: Optional[str] = Query(None, description="The keywords used to match chunk content")
|
||||
):
|
||||
"""
|
||||
Paged query document block preview list
|
||||
- Support filtering by document_id
|
||||
- Support keyword search for segmented content
|
||||
- Return paging metadata + file list
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await chunk_controller.get_preview_chunks(kb_id=kb_id,
|
||||
document_id=document_id,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
keywords=keywords,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/{kb_id}/{document_id}/chunks", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def get_chunks(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
keywords: Optional[str] = Query(None, description="The keywords used to match chunk content")
|
||||
):
|
||||
"""
|
||||
Paged query document chunk list
|
||||
- Support filtering by document_id
|
||||
- Support keyword search for segmented content
|
||||
- Return paging metadata + file list
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await chunk_controller.get_chunks(kb_id=kb_id,
|
||||
document_id=document_id,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
keywords=keywords,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.post("/{kb_id}/{document_id}/chunk", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def create_chunk(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
content: Union[str, QAChunk] = Body(..., description="Content can be either a string or a QAChunk object"),
|
||||
):
|
||||
"""
|
||||
create chunk
|
||||
"""
|
||||
body = await request.json()
|
||||
create_data = chunk_schema.ChunkCreate(**body)
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await chunk_controller.create_chunk(kb_id=kb_id,
|
||||
document_id=document_id,
|
||||
create_data=create_data,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def get_chunk(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
doc_id: str,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Retrieve document chunk information based on doc_id
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await chunk_controller.get_chunk(kb_id=kb_id,
|
||||
document_id=document_id,
|
||||
doc_id=doc_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.put("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def update_chunk(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
doc_id: str,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
content: Union[str, QAChunk] = Body(..., description="Content can be either a string or a QAChunk object"),
|
||||
):
|
||||
"""
|
||||
Update document chunk content
|
||||
"""
|
||||
body = await request.json()
|
||||
update_data = chunk_schema.ChunkUpdate(**body)
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await chunk_controller.update_chunk(kb_id=kb_id,
|
||||
document_id=document_id,
|
||||
doc_id=doc_id,
|
||||
update_data=update_data,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.delete("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def delete_chunk(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
doc_id: str,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
delete document chunk
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await chunk_controller.delete_chunk(kb_id=kb_id,
|
||||
document_id=document_id,
|
||||
doc_id=doc_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/retrieve_type", response_model=ApiResponse)
|
||||
def get_retrieve_types():
|
||||
return success(msg="Successfully obtained the retrieval type", data=list(chunk_schema.RetrieveType))
|
||||
|
||||
|
||||
@router.post("/retrieval", response_model=Any, status_code=status.HTTP_200_OK)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def retrieve_chunks(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
query: str = Body(..., description="question"),
|
||||
):
|
||||
"""
|
||||
retrieve chunk
|
||||
"""
|
||||
body = await request.json()
|
||||
retrieve_data = chunk_schema.ChunkRetrieve(**body)
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await chunk_controller.retrieve_chunks(retrieve_data=retrieve_data,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
"""RAG 服务接口 - 基于 API Key 认证"""
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.response_utils import success
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
router = APIRouter(prefix="/knowledge", tags=["V1 - RAG API"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_knowledge():
|
||||
"""列出可访问的知识库(占位)"""
|
||||
return success(data=[], msg="RAG API - Coming Soon")
|
||||
172
api/app/controllers/service/rag_api_document_controller.py
Normal file
172
api/app/controllers/service/rag_api_document_controller.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""RAG 服务接口 - 基于 API Key 认证"""
|
||||
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Request, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.controllers import document_controller
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.db import get_db
|
||||
from app.schemas import document_schema
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import api_key_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/documents", tags=["V1 - RAG API"])
|
||||
api_logger = get_business_logger()
|
||||
|
||||
|
||||
@router.get("/{kb_id}/documents", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def get_documents(
|
||||
kb_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
parent_id: Optional[uuid.UUID] = Query(None, description="parent folder id when type is Folder"),
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),
|
||||
desc: Optional[bool] = Query(False, description="Is it descending order"),
|
||||
keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
|
||||
document_ids: Optional[str] = Query(None, description="document ids, separated by commas")
|
||||
):
|
||||
"""
|
||||
Paged query document list
|
||||
- Support filtering by kb_id and parent_id
|
||||
- Support keyword search for file names
|
||||
- Support dynamic sorting
|
||||
- Return paging metadata + file list
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await document_controller.get_documents(kb_id=kb_id,
|
||||
parent_id=parent_id,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
orderby=orderby,
|
||||
desc=desc,
|
||||
keywords=keywords,
|
||||
document_ids=document_ids,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.post("/document", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def create_document(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
kb_id: uuid.UUID = Body(..., description="kb id"),
|
||||
file_name: str = Body(..., description="file name"),
|
||||
):
|
||||
"""
|
||||
create document
|
||||
"""
|
||||
body = await request.json()
|
||||
create_data = document_schema.DocumentCreate(**body)
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await document_controller.create_document(create_data=create_data,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/{document_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def get_document(
|
||||
document_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Retrieve document information based on document_id
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await document_controller.get_document(document_id=document_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.put("/{document_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def update_document(
|
||||
document_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
file_name: str = Body(None, description="file name (optional)"),
|
||||
):
|
||||
"""
|
||||
Update document information
|
||||
"""
|
||||
body = await request.json()
|
||||
update_data = document_schema.DocumentUpdate(**body)
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await document_controller.update_document(document_id=document_id,
|
||||
update_data=update_data,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.delete("/{document_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def delete_document(
|
||||
document_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Delete document
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await document_controller.delete_document(document_id=document_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.post("/{document_id}/chunks", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def parse_documents(
|
||||
document_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
parse document
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await document_controller.parse_documents(document_id=document_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
198
api/app/controllers/service/rag_api_file_controller.py
Normal file
198
api/app/controllers/service/rag_api_file_controller.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""RAG 服务接口 - 基于 API Key 认证"""
|
||||
|
||||
from typing import Any, Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Request, Query, File, UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.controllers import file_controller
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.db import get_db
|
||||
from app.schemas import file_schema
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import api_key_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/files", tags=["V1 - RAG API"])
|
||||
api_logger = get_business_logger()
|
||||
|
||||
|
||||
@router.get("/{kb_id}/{parent_id}/files", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def get_files(
|
||||
kb_id: uuid.UUID,
|
||||
parent_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at"),
|
||||
desc: Optional[bool] = Query(False, description="Is it descending order"),
|
||||
keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
|
||||
):
|
||||
"""
|
||||
Paged query file list
|
||||
- Support filtering by kb_id and parent_id
|
||||
- Support keyword search for file names
|
||||
- Support dynamic sorting
|
||||
- Return paging metadata + file list
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id=api_key_auth.workspace_id
|
||||
|
||||
return await file_controller.get_files(kb_id=kb_id,
|
||||
parent_id=parent_id,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
orderby=orderby,
|
||||
desc=desc,
|
||||
keywords=keywords,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.post("/folder", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def create_folder(
|
||||
kb_id: uuid.UUID,
|
||||
parent_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
folder_name: str = '/'
|
||||
):
|
||||
"""
|
||||
Create a new folder
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await file_controller.create_folder(kb_id=kb_id,
|
||||
parent_id=parent_id,
|
||||
folder_name=folder_name,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.post("/file", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def upload_file(
|
||||
kb_id: uuid.UUID,
|
||||
parent_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
file: UploadFile = File(...),
|
||||
):
|
||||
"""
|
||||
upload file
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await file_controller.upload_file(kb_id=kb_id,
|
||||
parent_id=parent_id,
|
||||
file=file,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.post("/customtext", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def custom_text(
|
||||
kb_id: uuid.UUID,
|
||||
parent_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
title: str = Body(..., description="title"),
|
||||
content: str = Body(..., description="content"),
|
||||
):
|
||||
"""
|
||||
custom text
|
||||
"""
|
||||
body = await request.json()
|
||||
create_data = file_schema.CustomTextFileCreate(**body)
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await file_controller.custom_text(kb_id=kb_id,
|
||||
parent_id=parent_id,
|
||||
create_data=create_data,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/{file_id}", response_model=Any)
|
||||
async def get_file(
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Download the file based on the file_id
|
||||
- Query file information from the database
|
||||
- Construct the file path and check if it exists
|
||||
- Return a FileResponse to download the file
|
||||
"""
|
||||
return await file_controller.get_file(file_id=file_id,
|
||||
db=db)
|
||||
|
||||
|
||||
@router.put("/{file_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def update_file(
|
||||
file_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
file_name: str = Body(None, description="file name (optional)"),
|
||||
):
|
||||
"""
|
||||
Update file information (such as file name)
|
||||
- Only specified fields such as file_name are allowed to be modified
|
||||
"""
|
||||
body = await request.json()
|
||||
update_data = file_schema.FileUpdate(**body)
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await file_controller.update_file(file_id=file_id,
|
||||
update_data=update_data,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.delete("/{file_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def delete_file(
|
||||
file_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Delete a file or folder
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await file_controller.delete_file(file_id=file_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
318
api/app/controllers/service/rag_api_knowledge_controller.py
Normal file
318
api/app/controllers/service/rag_api_knowledge_controller.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""RAG 服务接口 - 基于 API Key 认证"""
|
||||
|
||||
from typing import Optional, Dict
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Request, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.controllers import knowledge_controller
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.models import knowledge_model
|
||||
from app.schemas import knowledge_schema
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import api_key_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/knowledges", tags=["V1 - RAG API"])
|
||||
api_logger = get_business_logger()
|
||||
|
||||
|
||||
@router.get("/knowledgetype", response_model=ApiResponse)
|
||||
def get_knowledge_types():
|
||||
return success(msg="Successfully obtained the knowledge type", data=list(knowledge_model.KnowledgeType))
|
||||
|
||||
|
||||
@router.get("/permissiontype", response_model=ApiResponse)
|
||||
def get_permission_types():
|
||||
return success(msg="Successfully obtained the knowledge permission type", data=list(knowledge_model.PermissionType))
|
||||
|
||||
|
||||
@router.get("/parsertype", response_model=ApiResponse)
|
||||
def get_parser_types():
|
||||
return success(msg="Successfully obtained the knowledge parser type", data=list(knowledge_model.ParserType))
|
||||
|
||||
|
||||
@router.get("/knowledge_graph_entity_types", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def get_knowledge_graph_entity_types(
|
||||
llm_id: uuid.UUID,
|
||||
scenario: str,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
get knowledge graph entity types based on llm_id
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await knowledge_controller.get_knowledge_graph_entity_types(llm_id=llm_id,
|
||||
scenario=scenario,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/knowledges", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def get_knowledges(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
parent_id: Optional[uuid.UUID] = Query(None, description="parent folder id"),
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),
|
||||
desc: Optional[bool] = Query(False, description="Is it descending order"),
|
||||
keywords: Optional[str] = Query(None, description="Search keywords (knowledge base name)"),
|
||||
kb_ids: Optional[str] = Query(None, description="Knowledge base ids, separated by commas")
|
||||
):
|
||||
"""
|
||||
Query the knowledge base list in pages
|
||||
- Support filtering by parent_id
|
||||
- Support keyword search for knowledge base names
|
||||
- Support dynamic sorting
|
||||
- Return paging metadata + file list
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await knowledge_controller.get_knowledges(parent_id=parent_id,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
orderby=orderby,
|
||||
desc=desc,
|
||||
keywords=keywords,
|
||||
kb_ids=kb_ids,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.post("/knowledge", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def create_knowledge(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
name: str = Body(..., description="KB name"),
|
||||
):
|
||||
"""
|
||||
create knowledge
|
||||
"""
|
||||
body = await request.json()
|
||||
create_data = knowledge_schema.KnowledgeCreate(**body)
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await knowledge_controller.create_knowledge(create_data=create_data,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/{knowledge_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def get_knowledge(
|
||||
knowledge_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Retrieve knowledge base information based on knowledge_id
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await knowledge_controller.get_knowledge(knowledge_id=knowledge_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.put("/{knowledge_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def update_knowledge(
|
||||
knowledge_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
name: str = Body(None, description="KB name (optional)"),
|
||||
):
|
||||
body = await request.json()
|
||||
update_data = knowledge_schema.KnowledgeUpdate(**body)
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await knowledge_controller.update_knowledge(knowledge_id=knowledge_id,
|
||||
update_data=update_data,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.delete("/{knowledge_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def delete_knowledge(
|
||||
knowledge_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Soft-delete knowledge base
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await knowledge_controller.delete_knowledge(knowledge_id=knowledge_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/{knowledge_id}/knowledge_graph", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def get_knowledge_graph(
|
||||
knowledge_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Retrieve knowledge_graph base information based on knowledge_id
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await knowledge_controller.get_knowledge_graph(knowledge_id=knowledge_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.delete("/{knowledge_id}/knowledge_graph", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def delete_knowledge_graph(
|
||||
knowledge_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
delete knowledge graph
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await knowledge_controller.delete_knowledge_graph(knowledge_id=knowledge_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.post("/{knowledge_id}/knowledge_graph", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def rebuild_knowledge_graph(
|
||||
knowledge_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
rebuild knowledge graph
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await knowledge_controller.rebuild_knowledge_graph(knowledge_id=knowledge_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/check/yuque/auth", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def check_yuque_auth(
|
||||
yuque_user_id: str,
|
||||
yuque_token: str,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
check yuque auth info
|
||||
"""
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
api_logger.info(f"check yuque auth info, username: {current_user.username}")
|
||||
|
||||
return await knowledge_controller.check_yuque_auth(yuque_user_id=yuque_user_id,
|
||||
yuque_token=yuque_token,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/check/feishu/auth", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def check_feishu_auth(
|
||||
feishu_app_id: str,
|
||||
feishu_app_secret: str,
|
||||
feishu_folder_token: str,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
check feishu auth info
|
||||
"""
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
api_logger.info(f"check feishu auth info, username: {current_user.username}")
|
||||
|
||||
return await knowledge_controller.check_feishu_auth(feishu_app_id=feishu_app_id,
|
||||
feishu_app_secret=feishu_app_secret,
|
||||
feishu_folder_token=feishu_folder_token,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.post("/{knowledge_id}/sync", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def sync_knowledge(
|
||||
knowledge_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
sync knowledge base information based on knowledge_id
|
||||
"""
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await knowledge_controller.sync_knowledge(knowledge_id=knowledge_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
85
api/app/controllers/skill_controller.py
Normal file
85
api/app/controllers/skill_controller.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Skill Controller - 技能市场管理"""
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.schemas import skill_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services.skill_service import SkillService
|
||||
from app.core.response_utils import success
|
||||
|
||||
router = APIRouter(prefix="/skills", tags=["Skills"])
|
||||
|
||||
|
||||
@router.post("", summary="创建技能")
|
||||
def create_skill(
|
||||
data: skill_schema.SkillCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""创建技能 - 可以关联现有工具(内置、MCP、自定义)"""
|
||||
tenant_id = current_user.tenant_id
|
||||
skill = SkillService.create_skill(db, data, tenant_id)
|
||||
return success(data=skill_schema.Skill.model_validate(skill), msg="技能创建成功")
|
||||
|
||||
|
||||
@router.get("", summary="技能列表")
|
||||
def list_skills(
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
is_active: Optional[bool] = Query(None, description="是否激活"),
|
||||
is_public: Optional[bool] = Query(None, description="是否公开"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""技能市场列表 - 包含本工作空间和公开的技能"""
|
||||
tenant_id = current_user.tenant_id
|
||||
skills, total = SkillService.list_skills(
|
||||
db, tenant_id, search, is_active, is_public, page, pagesize
|
||||
)
|
||||
|
||||
items = [skill_schema.Skill.model_validate(s) for s in skills]
|
||||
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
||||
return success(data=PageData(page=meta, items=items), msg="技能市场列表获取成功")
|
||||
|
||||
|
||||
@router.get("/{skill_id}", summary="获取技能详情")
|
||||
def get_skill(
|
||||
skill_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取技能详情"""
|
||||
tenant_id = current_user.tenant_id
|
||||
skill = SkillService.get_skill(db, skill_id, tenant_id)
|
||||
return success(data=skill_schema.Skill.model_validate(skill), msg="获取技能详情成功")
|
||||
|
||||
|
||||
@router.put("/{skill_id}", summary="更新技能")
|
||||
def update_skill(
|
||||
skill_id: uuid.UUID,
|
||||
data: skill_schema.SkillUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新技能"""
|
||||
tenant_id = current_user.tenant_id
|
||||
skill = SkillService.update_skill(db, skill_id, data, tenant_id)
|
||||
return success(data=skill_schema.Skill.model_validate(skill), msg="技能更新成功")
|
||||
|
||||
|
||||
@router.delete("/{skill_id}", summary="删除技能")
|
||||
def delete_skill(
|
||||
skill_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除技能"""
|
||||
tenant_id = current_user.tenant_id
|
||||
SkillService.delete_skill(db, skill_id, tenant_id)
|
||||
return success(msg="技能删除成功")
|
||||
@@ -1,23 +1,22 @@
|
||||
from fastapi import APIRouter, Depends, status, Query, HTTPException
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from fastapi import APIRouter, Depends, status, HTTPException, Body, Path
|
||||
from fastapi.responses import StreamingResponse
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
|
||||
|
||||
from app.core.models import RedBearLLM, RedBearRerank
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models.embedding import RedBearEmbeddings
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.models_model import ModelApiKey, ModelProvider, ModelType
|
||||
from app.models.user_model import User
|
||||
from app.schemas import model_schema
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.core.response_utils import success
|
||||
from app.schemas.response_schema import ApiResponse, PageData
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.app_schema import AppChatRequest
|
||||
from app.services.model_service import ModelConfigService
|
||||
from app.services.handoffs_service import get_handoffs_service_for_app, reset_handoffs_service_cache
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.dependencies import get_current_user
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -28,6 +27,8 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
# ==================== 原有测试接口 ====================
|
||||
|
||||
@router.get("/llm/{model_id}", response_model=ApiResponse)
|
||||
def test_llm(
|
||||
model_id: uuid.UUID,
|
||||
@@ -50,7 +51,6 @@ def test_llm(
|
||||
template = """Question: {question}
|
||||
|
||||
Answer: Let's think step by step."""
|
||||
# ChatPromptTemplate
|
||||
prompt = ChatPromptTemplate.from_template(template)
|
||||
chain = prompt | llm
|
||||
answer = chain.invoke({"question": "What is LangChain?"})
|
||||
@@ -80,13 +80,13 @@ def test_embedding(
|
||||
base_url=apiConfig.api_base
|
||||
))
|
||||
|
||||
data = [
|
||||
"最近哪家咖啡店评价最好?",
|
||||
"附近有没有推荐的咖啡厅?",
|
||||
"明天天气预报说会下雨。",
|
||||
"北京是中国的首都。",
|
||||
"我想找一个适合学习的地方。"
|
||||
]
|
||||
data = [
|
||||
"最近哪家咖啡店评价最好?",
|
||||
"附近有没有推荐的咖啡厅?",
|
||||
"明天天气预报说会下雨。",
|
||||
"北京是中国的首都。",
|
||||
"我想找一个适合学习的地方。"
|
||||
]
|
||||
embeddings = model.embed_documents(data)
|
||||
print(embeddings)
|
||||
query = "我想找一个适合学习的地方。"
|
||||
@@ -114,13 +114,123 @@ def test_rerank(
|
||||
base_url=apiConfig.api_base
|
||||
))
|
||||
query = "最近哪家咖啡店评价最好?"
|
||||
data = [
|
||||
"最近哪家咖啡店评价最好?",
|
||||
"附近有没有推荐的咖啡厅?",
|
||||
"明天天气预报说会下雨。",
|
||||
"北京是中国的首都。",
|
||||
"我想找一个适合学习的地方。"
|
||||
]
|
||||
data = [
|
||||
"最近哪家咖啡店评价最好?",
|
||||
"附近有没有推荐的咖啡厅?",
|
||||
"明天天气预报说会下雨。",
|
||||
"北京是中国的首都。",
|
||||
"我想找一个适合学习的地方。"
|
||||
]
|
||||
scores = model.rerank(query=query, documents=data, top_n=3)
|
||||
print(scores)
|
||||
return success(msg="测试Rerank成功", data={"query": query, "documents": data, "scores": scores})
|
||||
|
||||
|
||||
# ==================== Handoffs 测试接口 ====================
|
||||
|
||||
@router.post("/handoffs/{app_id}")
|
||||
async def test_handoffs(
|
||||
app_id: uuid.UUID = Path(..., description="应用 ID"),
|
||||
request: AppChatRequest = Body(...),
|
||||
current_user=Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""测试 Agent Handoffs 功能
|
||||
|
||||
演示 LangGraph 实现的多 Agent 协作和动态切换
|
||||
|
||||
- 从数据库 multi_agent_config 获取 Agent 配置
|
||||
- 根据用户问题自动切换到合适的 Agent
|
||||
- 使用 conversation_id 保持会话状态
|
||||
- 通过 stream 参数控制是否流式输出
|
||||
|
||||
事件类型(流式):
|
||||
- start: 开始执行
|
||||
- agent: 当前 Agent 信息
|
||||
- message: 流式消息内容
|
||||
- handoff: Agent 切换事件
|
||||
- end: 执行结束
|
||||
- error: 错误信息
|
||||
"""
|
||||
try:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 获取或创建会话
|
||||
conversation_service = ConversationService(db)
|
||||
|
||||
if request.conversation_id:
|
||||
# 验证会话存在
|
||||
conversation = conversation_service.get_conversation(uuid.UUID(request.conversation_id))
|
||||
if not conversation:
|
||||
raise HTTPException(status_code=404, detail="会话不存在")
|
||||
conversation_id = str(conversation.id)
|
||||
else:
|
||||
# 创建新会话
|
||||
conversation = conversation_service.create_or_get_conversation(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=request.user_id,
|
||||
is_draft=True
|
||||
)
|
||||
conversation_id = str(conversation.id)
|
||||
|
||||
# 根据 stream 参数决定返回方式
|
||||
if request.stream:
|
||||
# 流式返回
|
||||
service = get_handoffs_service_for_app(app_id, db, streaming=True)
|
||||
return StreamingResponse(
|
||||
service.chat_stream(
|
||||
message=request.message,
|
||||
conversation_id=conversation_id
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 非流式返回
|
||||
service = get_handoffs_service_for_app(app_id, db, streaming=False)
|
||||
result = await service.chat(
|
||||
message=request.message,
|
||||
conversation_id=conversation_id
|
||||
)
|
||||
return success(data=result, msg="Handoffs 测试成功")
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"Handoffs 测试失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/handoffs/{app_id}/agents", response_model=ApiResponse)
|
||||
def get_handoff_agents(
|
||||
app_id: uuid.UUID = Path(..., description="应用 ID"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user)
|
||||
):
|
||||
"""获取应用的 Handoff Agent 列表"""
|
||||
try:
|
||||
service = get_handoffs_service_for_app(app_id, db, streaming=False)
|
||||
agents = service.get_agents()
|
||||
return success(data={"agents": agents}, msg="获取 Agent 列表成功")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取 Agent 列表失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/handoffs/{app_id}/reset")
|
||||
def reset_handoff_service(
|
||||
app_id: uuid.UUID = Path(..., description="应用 ID"),
|
||||
current_user=Depends(get_current_user)
|
||||
):
|
||||
"""重置指定应用的 Handoff 服务缓存"""
|
||||
reset_handoffs_service_cache(app_id)
|
||||
return success(msg="Handoff 服务已重置")
|
||||
|
||||
@@ -60,6 +60,22 @@ async def list_tools(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{tool_id}/methods", response_model=ApiResponse)
|
||||
async def get_tool_methods(
|
||||
tool_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
service: ToolService = Depends(get_tool_service)
|
||||
):
|
||||
"""获取工具的所有方法"""
|
||||
try:
|
||||
methods = await service.get_tool_methods(tool_id, current_user.tenant_id)
|
||||
if methods is None:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
return success(data=methods, msg="获取工具方法成功")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{tool_id}", response_model=ApiResponse)
|
||||
async def get_tool(
|
||||
tool_id: str,
|
||||
@@ -159,7 +175,8 @@ async def execute_tool(
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
timeout=request.timeout
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
raise HTTPException(status_code=400, detail=result["error"])
|
||||
return success(
|
||||
data={
|
||||
"success": result.success,
|
||||
@@ -198,8 +215,8 @@ async def sync_mcp_tools(
|
||||
"""同步MCP工具列表"""
|
||||
try:
|
||||
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
|
||||
if result["success"] is False:
|
||||
raise HTTPException(status_code=404, detail=result["message"])
|
||||
if not result.get("success", False):
|
||||
raise HTTPException(status_code=400, detail=result.get("message", "同步失败"))
|
||||
return success(data=result, msg="MCP工具列表同步完成")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -2,15 +2,23 @@ from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, get_current_superuser
|
||||
from app.models.user_model import User
|
||||
from app.schemas import user_schema
|
||||
from app.schemas.user_schema import ChangePasswordRequest, AdminChangePasswordRequest
|
||||
from app.schemas.user_schema import (
|
||||
ChangePasswordRequest,
|
||||
AdminChangePasswordRequest,
|
||||
SendEmailCodeRequest,
|
||||
VerifyEmailCodeRequest,
|
||||
VerifyPasswordRequest)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import user_service
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.core.security import verify_password
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -92,7 +100,7 @@ def get_current_user_info(
|
||||
result_schema.current_workspace_name = current_workspace.name
|
||||
|
||||
for ws in result.workspaces:
|
||||
if ws.workspace_id == current_user.current_workspace_id:
|
||||
if ws.workspace_id == current_user.current_workspace_id and ws.is_active:
|
||||
result_schema.role = ws.role
|
||||
break
|
||||
|
||||
@@ -120,6 +128,7 @@ def get_tenant_superusers(
|
||||
return success(data=superusers_schema, msg="租户超管列表获取成功")
|
||||
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=ApiResponse)
|
||||
def get_user_info_by_id(
|
||||
user_id: uuid.UUID,
|
||||
@@ -180,4 +189,54 @@ async def admin_change_password(
|
||||
return success(msg="密码修改成功")
|
||||
else:
|
||||
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
|
||||
return success(data=generated_password, msg="密码重置成功")
|
||||
return success(data=generated_password, msg="密码重置成功")
|
||||
|
||||
|
||||
@router.post("/verify_pwd", response_model=ApiResponse)
|
||||
def verify_pwd(
|
||||
request: VerifyPasswordRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""验证当前用户密码"""
|
||||
api_logger.info(f"用户验证密码请求: {current_user.username}")
|
||||
|
||||
is_valid = verify_password(request.password, current_user.hashed_password)
|
||||
api_logger.info(f"用户密码验证结果: {current_user.username}, valid={is_valid}")
|
||||
if not is_valid:
|
||||
raise BusinessException("密码验证失败", code=BizCode.VALIDATION_FAILED)
|
||||
return success(data={"valid": is_valid}, msg="验证完成")
|
||||
|
||||
|
||||
@router.post("/send-email-code", response_model=ApiResponse)
|
||||
async def send_email_code(
|
||||
request: SendEmailCodeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""发送邮箱验证码"""
|
||||
api_logger.info(f"用户请求发送邮箱验证码: {current_user.username}, email={request.email}")
|
||||
|
||||
await user_service.send_email_code_method(db=db, email=request.email, user_id=current_user.id)
|
||||
|
||||
api_logger.info(f"邮箱验证码已发送: {current_user.username}")
|
||||
return success(msg="验证码已发送到您的邮箱,请查收")
|
||||
|
||||
|
||||
@router.put("/change-email", response_model=ApiResponse)
|
||||
async def change_email(
|
||||
request: VerifyEmailCodeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""验证验证码并修改邮箱"""
|
||||
api_logger.info(f"用户修改邮箱: {current_user.username}, new_email={request.new_email}")
|
||||
|
||||
await user_service.verify_and_change_email(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
new_email=request.new_email,
|
||||
code=request.code
|
||||
)
|
||||
|
||||
api_logger.info(f"用户邮箱修改成功: {current_user.username}")
|
||||
return success(msg="邮箱修改成功")
|
||||
|
||||
@@ -5,20 +5,23 @@
|
||||
from typing import Optional
|
||||
import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends,Header
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.api_key_utils import timestamp_to_datetime
|
||||
from app.services.user_memory_service import (
|
||||
UserMemoryService,
|
||||
analytics_node_statistics,
|
||||
analytics_memory_types,
|
||||
analytics_graph_data,
|
||||
)
|
||||
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
||||
from app.repositories.workspace_repository import WorkspaceRepository
|
||||
from app.schemas.end_user_schema import (
|
||||
EndUserProfileResponse,
|
||||
EndUserProfileUpdate,
|
||||
@@ -41,24 +44,27 @@ router = APIRouter(
|
||||
|
||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||
async def get_memory_insight_report_api(
|
||||
end_user_id: str, # 使用 end_user_id
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""获取缓存的记忆洞察报告"""
|
||||
api_logger.info(f"记忆洞察报告请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
) -> dict:
|
||||
"""
|
||||
获取缓存的记忆洞察报告
|
||||
|
||||
此接口仅查询数据库中已缓存的记忆洞察数据,不执行生成操作。
|
||||
如需生成新的洞察报告,请使用专门的生成接口。
|
||||
"""
|
||||
api_logger.info(f"记忆洞察报告查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_memory_insight(db, end_user_id)
|
||||
|
||||
|
||||
if result["is_cached"]:
|
||||
# 缓存存在,返回缓存数据
|
||||
api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}")
|
||||
return success(data=result, msg="查询成功")
|
||||
else:
|
||||
# 缓存不存在,返回提示消息
|
||||
api_logger.info(f"记忆洞察报告缓存不存在: end_user_id={end_user_id}")
|
||||
return success(data=result, msg="查询成功")
|
||||
return success(data=result, msg="数据尚未生成")
|
||||
except Exception as e:
|
||||
api_logger.error(f"记忆洞察报告查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告查询失败", str(e))
|
||||
@@ -66,24 +72,43 @@ async def get_memory_insight_report_api(
|
||||
|
||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||
async def get_user_summary_api(
|
||||
end_user_id: str, # 使用 end_user_id
|
||||
end_user_id: str,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""获取缓存的用户摘要"""
|
||||
api_logger.info(f"用户摘要请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
) -> dict:
|
||||
"""
|
||||
获取缓存的用户摘要
|
||||
|
||||
此接口仅查询数据库中已缓存的用户摘要数据,不执行生成操作。
|
||||
如需生成新的用户摘要,请使用专门的生成接口。
|
||||
|
||||
语言控制:
|
||||
- 使用 X-Language-Type Header 指定语言
|
||||
- 如果未传 Header,默认使用中文 (zh)
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
if workspace_models:
|
||||
model_id = workspace_models.get("llm", None)
|
||||
else:
|
||||
model_id = None
|
||||
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id)
|
||||
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language)
|
||||
|
||||
if result["is_cached"]:
|
||||
# 缓存存在,返回缓存数据
|
||||
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
||||
return success(data=result, msg="查询成功")
|
||||
else:
|
||||
# 缓存不存在,返回提示消息
|
||||
api_logger.info(f"用户摘要缓存不存在: end_user_id={end_user_id}")
|
||||
return success(data=result, msg="查询成功")
|
||||
return success(data=result, msg="数据尚未生成")
|
||||
except Exception as e:
|
||||
api_logger.error(f"用户摘要查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户摘要查询失败", str(e))
|
||||
@@ -92,48 +117,56 @@ async def get_user_summary_api(
|
||||
@router.post("/analytics/generate_cache", response_model=ApiResponse)
|
||||
async def generate_cache_api(
|
||||
request: GenerateCacheRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
手动触发缓存生成
|
||||
|
||||
|
||||
- 如果提供 end_user_id,只为该用户生成
|
||||
- 如果不提供,为当前工作空间的所有用户生成
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
语言控制:
|
||||
- 使用 X-Language-Type Header 指定语言 ("zh" 中文, "en" 英文)
|
||||
- 如果未传 Header,默认使用中文 (zh)
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
group_id = request.end_user_id
|
||||
|
||||
|
||||
end_user_id = request.end_user_id
|
||||
|
||||
api_logger.info(
|
||||
f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, "
|
||||
f"end_user_id={group_id if group_id else '全部用户'}"
|
||||
f"end_user_id={end_user_id if end_user_id else '全部用户'}, language={language}"
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
if group_id:
|
||||
if end_user_id:
|
||||
# 为单个用户生成
|
||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={group_id}")
|
||||
|
||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
||||
|
||||
# 生成记忆洞察
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, group_id, workspace_id)
|
||||
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, language=language)
|
||||
|
||||
# 生成用户摘要
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, group_id, workspace_id)
|
||||
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, language=language)
|
||||
|
||||
# 构建响应
|
||||
result = {
|
||||
"end_user_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"insight_success": insight_result["success"],
|
||||
"summary_success": summary_result["success"],
|
||||
"errors": []
|
||||
}
|
||||
|
||||
|
||||
# 收集错误信息
|
||||
if not insight_result["success"]:
|
||||
result["errors"].append({
|
||||
@@ -145,29 +178,29 @@ async def generate_cache_api(
|
||||
"type": "summary",
|
||||
"error": summary_result.get("error")
|
||||
})
|
||||
|
||||
|
||||
# 记录结果
|
||||
if result["insight_success"] and result["summary_success"]:
|
||||
api_logger.info(f"成功为用户 {group_id} 生成缓存")
|
||||
api_logger.info(f"成功为用户 {end_user_id} 生成缓存")
|
||||
else:
|
||||
api_logger.warning(f"用户 {group_id} 的缓存生成部分失败: {result['errors']}")
|
||||
|
||||
api_logger.warning(f"用户 {end_user_id} 的缓存生成部分失败: {result['errors']}")
|
||||
|
||||
return success(data=result, msg="生成完成")
|
||||
|
||||
|
||||
else:
|
||||
# 为整个工作空间生成
|
||||
api_logger.info(f"开始为工作空间 {workspace_id} 批量生成缓存")
|
||||
|
||||
result = await user_memory_service.generate_cache_for_workspace(db, workspace_id)
|
||||
|
||||
|
||||
result = await user_memory_service.generate_cache_for_workspace(db, workspace_id, language=language)
|
||||
|
||||
# 记录统计信息
|
||||
api_logger.info(
|
||||
f"工作空间 {workspace_id} 批量生成完成: "
|
||||
f"总数={result['total_users']}, 成功={result['successful']}, 失败={result['failed']}"
|
||||
)
|
||||
|
||||
|
||||
return success(data=result, msg="批量生成完成")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"缓存生成失败: user={current_user.username}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "缓存生成失败", str(e))
|
||||
@@ -180,18 +213,18 @@ async def get_node_statistics_api(
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
||||
|
||||
|
||||
try:
|
||||
# 调用新的记忆类型统计函数
|
||||
result = await analytics_memory_types(db, end_user_id)
|
||||
|
||||
|
||||
# 计算总数用于日志
|
||||
total_count = sum(item["count"] for item in result)
|
||||
api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
||||
@@ -211,31 +244,31 @@ async def get_graph_data_api(
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询图数据但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
# 参数验证
|
||||
if limit > 1000:
|
||||
limit = 1000
|
||||
api_logger.warning("limit 参数超过最大值,已调整为 1000")
|
||||
|
||||
|
||||
if depth > 3:
|
||||
depth = 3
|
||||
api_logger.warning("depth 参数超过最大值,已调整为 3")
|
||||
|
||||
|
||||
# 解析 node_types 参数
|
||||
node_types_list = None
|
||||
if node_types:
|
||||
node_types_list = [t.strip() for t in node_types.split(",") if t.strip()]
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"图数据查询请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}, node_types={node_types_list}, limit={limit}, depth={depth}"
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
result = await analytics_graph_data(
|
||||
db=db,
|
||||
@@ -245,19 +278,18 @@ async def get_graph_data_api(
|
||||
depth=depth,
|
||||
center_node_id=center_node_id
|
||||
)
|
||||
|
||||
# 检查是否有错误消息
|
||||
if "message" in result and result["statistics"]["total_nodes"] == 0:
|
||||
api_logger.warning(f"图数据查询返回空结果: {result.get('message')}")
|
||||
return success(data=result, msg=result.get("message", "查询成功"))
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取图数据: end_user_id={end_user_id}, "
|
||||
f"nodes={result['statistics']['total_nodes']}, "
|
||||
f"edges={result['statistics']['total_edges']}"
|
||||
)
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"图数据查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e))
|
||||
@@ -270,25 +302,30 @@ async def get_end_user_profile(
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
if workspace_models:
|
||||
model_id = workspace_models.get("llm", None)
|
||||
else:
|
||||
model_id = None
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"用户信息查询请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
# 查询终端用户
|
||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
|
||||
|
||||
if not end_user:
|
||||
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
|
||||
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
|
||||
|
||||
# 构建响应数据
|
||||
profile_data = EndUserProfileResponse(
|
||||
id=end_user.id,
|
||||
@@ -300,10 +337,10 @@ async def get_end_user_profile(
|
||||
hire_date=end_user.hire_date,
|
||||
updatetime_profile=end_user.updatetime_profile
|
||||
)
|
||||
|
||||
|
||||
api_logger.info(f"成功获取用户信息: end_user_id={end_user_id}")
|
||||
return success(data=profile_data.model_dump(), msg="查询成功")
|
||||
|
||||
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"用户信息查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息查询失败", str(e))
|
||||
@@ -317,65 +354,90 @@ async def update_end_user_profile(
|
||||
) -> dict:
|
||||
"""
|
||||
更新终端用户的基本信息
|
||||
|
||||
|
||||
该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。
|
||||
所有字段都是可选的,只更新提供的字段。
|
||||
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
end_user_id = profile_update.end_user_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
|
||||
# 验证工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"用户信息更新请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
# 调用 Service 层处理业务逻辑
|
||||
result = user_memory_service.update_end_user_profile(db, end_user_id, profile_update)
|
||||
|
||||
if result["success"]:
|
||||
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}")
|
||||
return success(data=result["data"], msg="更新成功")
|
||||
else:
|
||||
error_msg = result["error"]
|
||||
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
|
||||
|
||||
# 根据错误类型映射到合适的业务错误码
|
||||
if error_msg == "终端用户不存在":
|
||||
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg)
|
||||
elif error_msg == "无效的用户ID格式":
|
||||
return fail(BizCode.INVALID_USER_ID, "无效的用户ID格式", error_msg)
|
||||
else:
|
||||
# 只有未预期的错误才使用 INTERNAL_ERROR
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
|
||||
|
||||
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
||||
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
workspace_id=current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
if workspace_models:
|
||||
model_id = workspace_models.get("llm", None)
|
||||
else:
|
||||
model_id = None
|
||||
MemoryEntity = MemoryEntityService(id, label)
|
||||
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
|
||||
|
||||
return success(data=timeline_memories_result, msg="共同记忆时间线")
|
||||
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
|
||||
async def memory_space_relationship_evolution(id: str, label: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
try:
|
||||
# 查询终端用户
|
||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
|
||||
if not end_user:
|
||||
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
|
||||
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
|
||||
|
||||
# 更新字段(只更新提供的字段,排除 end_user_id)
|
||||
# 允许 None 值来重置字段(如 hire_date)
|
||||
update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
|
||||
for field, value in update_data.items():
|
||||
setattr(end_user, field, value)
|
||||
|
||||
# 更新 updated_at 时间戳
|
||||
end_user.updated_at = datetime.datetime.now()
|
||||
|
||||
# 更新 updatetime_profile 为当前时间戳(毫秒)
|
||||
current_timestamp = int(datetime.datetime.now().timestamp() * 1000)
|
||||
end_user.updatetime_profile = current_timestamp
|
||||
|
||||
# 提交更改
|
||||
db.commit()
|
||||
db.refresh(end_user)
|
||||
|
||||
# 构建响应数据
|
||||
profile_data = EndUserProfileResponse(
|
||||
id=end_user.id,
|
||||
other_name=end_user.other_name,
|
||||
position=end_user.position,
|
||||
department=end_user.department,
|
||||
contact=end_user.contact,
|
||||
phone=end_user.phone,
|
||||
hire_date=end_user.hire_date,
|
||||
updatetime_profile=end_user.updatetime_profile
|
||||
)
|
||||
|
||||
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}, updatetime_profile={current_timestamp}")
|
||||
return success(data=profile_data.model_dump(), msg="更新成功")
|
||||
|
||||
api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")
|
||||
|
||||
# 获取情绪数据
|
||||
emotion = MemoryEmotion(id, label)
|
||||
emotion_result = await emotion.get_emotion()
|
||||
|
||||
# 获取交互数据
|
||||
interaction = MemoryInteraction(id, label)
|
||||
interaction_result = await interaction.get_interaction_frequency()
|
||||
|
||||
# 关闭连接
|
||||
await emotion.close()
|
||||
await interaction.close()
|
||||
|
||||
result = {
|
||||
"emotion": emotion_result,
|
||||
"interaction": interaction_result
|
||||
}
|
||||
|
||||
api_logger.info(f"关系演变查询成功: id={id}, table={label}")
|
||||
return success(data=result, msg="关系演变")
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", str(e))
|
||||
api_logger.error(f"关系演变查询失败: id={id}, table={label}, error={str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "关系演变查询失败", str(e))
|
||||
|
||||
@@ -1,611 +0,0 @@
|
||||
"""
|
||||
工作流 API 控制器
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Path, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
|
||||
from app.models.user_model import User
|
||||
from app.models.app_model import App
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
from app.schemas.workflow_schema import (
|
||||
WorkflowConfigCreate,
|
||||
WorkflowConfigUpdate,
|
||||
WorkflowConfig,
|
||||
WorkflowValidationResponse,
|
||||
WorkflowExecution,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowExecutionRequest,
|
||||
WorkflowExecutionResponse
|
||||
)
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["workflow"])
|
||||
|
||||
|
||||
# ==================== 工作流配置管理 ====================
|
||||
|
||||
@router.post("/{app_id}/workflow")
|
||||
@cur_workspace_access_guard()
|
||||
async def create_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
config: WorkflowConfigCreate,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""创建工作流配置
|
||||
|
||||
创建或更新应用的工作流配置。配置会进行基础验证,但允许保存不完整的配置(草稿)。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 验证应用类型
|
||||
if app.type != "workflow":
|
||||
return fail(
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
msg=f"应用类型必须为 workflow,当前为 {app.type}"
|
||||
)
|
||||
|
||||
# 创建工作流配置
|
||||
workflow_config = service.create_workflow_config(
|
||||
app_id=app_id,
|
||||
nodes=[node.model_dump() for node in config.nodes],
|
||||
edges=[edge.model_dump() for edge in config.edges],
|
||||
variables=[var.model_dump() for var in config.variables],
|
||||
execution_config=config.execution_config.model_dump(),
|
||||
triggers=[trigger.model_dump() for trigger in config.triggers],
|
||||
validate=True # 进行基础验证
|
||||
)
|
||||
|
||||
return success(
|
||||
data=WorkflowConfig.model_validate(workflow_config),
|
||||
msg="工作流配置创建成功"
|
||||
)
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"创建工作流配置失败: {e.message}")
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"创建工作流配置异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"创建工作流配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
#
|
||||
# @router.get("/{app_id}/workflow")
|
||||
# async def get_workflow_config(
|
||||
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
# db: Annotated[Session, Depends(get_db)],
|
||||
# current_user: Annotated[User, Depends(get_current_user)]
|
||||
#
|
||||
# ):
|
||||
# """获取工作流配置
|
||||
#
|
||||
# 获取应用的工作流配置详情。
|
||||
# """
|
||||
# try:
|
||||
# # 验证应用是否存在且属于当前工作空间
|
||||
# app = db.query(App).filter(
|
||||
# App.id == app_id,
|
||||
# App.workspace_id == current_user.current_workspace_id,
|
||||
# App.is_active == True
|
||||
# ).first()
|
||||
#
|
||||
# if not app:
|
||||
# return fail(
|
||||
# code=BizCode.NOT_FOUND,
|
||||
# msg="应用不存在或无权访问"
|
||||
# )
|
||||
#
|
||||
# # 获取工作流配置
|
||||
# service = WorkflowService(db)
|
||||
# workflow_config = service.get_workflow_config(app_id)
|
||||
#
|
||||
# if not workflow_config:
|
||||
# return fail(
|
||||
# code=BizCode.NOT_FOUND,
|
||||
# msg="工作流配置不存在"
|
||||
# )
|
||||
#
|
||||
# return success(
|
||||
# data=WorkflowConfig.model_validate(workflow_config)
|
||||
# )
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"获取工作流配置异常: {e}", exc_info=True)
|
||||
# return fail(
|
||||
# code=BizCode.INTERNAL_ERROR,
|
||||
# msg=f"获取工作流配置失败: {str(e)}"
|
||||
# )
|
||||
|
||||
|
||||
# @router.put("/{app_id}/workflow")
|
||||
# async def update_workflow_config(
|
||||
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
# config: WorkflowConfigUpdate,
|
||||
# db: Annotated[Session, Depends(get_db)],
|
||||
# current_user: Annotated[User, Depends(get_current_user)],
|
||||
# service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
# ):
|
||||
# """更新工作流配置
|
||||
|
||||
# 更新应用的工作流配置。可以部分更新,未提供的字段保持不变。
|
||||
# """
|
||||
# try:
|
||||
# # 验证应用是否存在且属于当前工作空间
|
||||
# app = db.query(App).filter(
|
||||
# App.id == app_id,
|
||||
# App.workspace_id == current_user.current_workspace_id,
|
||||
# App.is_active == True
|
||||
# ).first()
|
||||
|
||||
# if not app:
|
||||
# return fail(
|
||||
# code=BizCode.NOT_FOUND,
|
||||
# msg="应用不存在或无权访问"
|
||||
# )
|
||||
|
||||
# # 更新工作流配置
|
||||
# workflow_config = service.update_workflow_config(
|
||||
# app_id=app_id,
|
||||
# nodes=[node.model_dump() for node in config.nodes] if config.nodes else None,
|
||||
# edges=[edge.model_dump() for edge in config.edges] if config.edges else None,
|
||||
# variables=[var.model_dump() for var in config.variables] if config.variables else None,
|
||||
# execution_config=config.execution_config.model_dump() if config.execution_config else None,
|
||||
# triggers=[trigger.model_dump() for trigger in config.triggers] if config.triggers else None,
|
||||
# validate=True
|
||||
# )
|
||||
|
||||
# return success(
|
||||
# data=WorkflowConfig.model_validate(workflow_config),
|
||||
# msg="工作流配置更新成功"
|
||||
# )
|
||||
|
||||
# except BusinessException as e:
|
||||
# logger.warning(f"更新工作流配置失败: {e.message}")
|
||||
# return fail(code=e.error_code, msg=e.message)
|
||||
# except Exception as e:
|
||||
# logger.error(f"更新工作流配置异常: {e}", exc_info=True)
|
||||
# return fail(
|
||||
# code=BizCode.INTERNAL_ERROR,
|
||||
# msg=f"更新工作流配置失败: {str(e)}"
|
||||
# )
|
||||
|
||||
|
||||
@router.delete("/{app_id}/workflow")
|
||||
async def delete_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""删除工作流配置
|
||||
|
||||
删除应用的工作流配置。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 删除工作流配置
|
||||
deleted = service.delete_workflow_config(app_id)
|
||||
|
||||
if not deleted:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="工作流配置不存在"
|
||||
)
|
||||
|
||||
return success(msg="工作流配置删除成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"删除工作流配置异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"删除工作流配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{app_id}/workflow/validate")
|
||||
async def validate_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
|
||||
):
|
||||
"""验证工作流配置
|
||||
|
||||
验证工作流配置是否有效。可以选择是否进行发布级别的严格验证。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 验证工作流配置
|
||||
|
||||
if for_publish:
|
||||
is_valid, errors = service.validate_workflow_config_for_publish(app_id)
|
||||
else:
|
||||
workflow_config = service.get_workflow_config(app_id)
|
||||
if not workflow_config:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="工作流配置不存在"
|
||||
)
|
||||
|
||||
from app.core.workflow.validator import validate_workflow_config as validate_config
|
||||
config_dict = {
|
||||
"nodes": workflow_config.nodes,
|
||||
"edges": workflow_config.edges,
|
||||
"variables": workflow_config.variables,
|
||||
"execution_config": workflow_config.execution_config,
|
||||
"triggers": workflow_config.triggers
|
||||
}
|
||||
is_valid, errors = validate_config(config_dict, for_publish=False)
|
||||
|
||||
return success(
|
||||
data=WorkflowValidationResponse(
|
||||
is_valid=is_valid,
|
||||
errors=errors,
|
||||
warnings=[]
|
||||
)
|
||||
)
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"验证工作流配置失败: {e.message}")
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"验证工作流配置异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"验证工作流配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# ==================== 工作流执行管理 ====================
|
||||
|
||||
@router.get("/{app_id}/workflow/executions")
|
||||
async def get_workflow_executions(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||
offset: Annotated[int, Query(ge=0)] = 0
|
||||
):
|
||||
"""获取工作流执行记录列表
|
||||
|
||||
获取应用的工作流执行历史记录。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 获取执行记录
|
||||
executions = service.get_executions_by_app(app_id, limit, offset)
|
||||
|
||||
# 获取统计信息
|
||||
statistics = service.get_execution_statistics(app_id)
|
||||
|
||||
return success(
|
||||
data={
|
||||
"executions": [WorkflowExecution.model_validate(e) for e in executions],
|
||||
"statistics": statistics,
|
||||
"pagination": {
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"total": statistics["total"]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工作流执行记录异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"获取工作流执行记录失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/workflow/executions/{execution_id}")
|
||||
async def get_workflow_execution(
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""获取工作流执行详情
|
||||
|
||||
获取单个工作流执行的详细信息,包括所有节点的执行记录。
|
||||
"""
|
||||
try:
|
||||
# 获取执行记录
|
||||
execution = service.get_execution(execution_id)
|
||||
|
||||
if not execution:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="执行记录不存在"
|
||||
)
|
||||
|
||||
# 验证应用是否属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == execution.app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="无权访问该执行记录"
|
||||
)
|
||||
|
||||
# 获取节点执行记录
|
||||
node_executions = service.node_execution_repo.get_by_execution_id(execution.id)
|
||||
|
||||
return success(
|
||||
data={
|
||||
"execution": WorkflowExecution.model_validate(execution),
|
||||
"node_executions": [
|
||||
WorkflowNodeExecution.model_validate(ne) for ne in node_executions
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工作流执行详情异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"获取工作流执行详情失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
# ==================== 工作流执行 ====================
|
||||
|
||||
@router.post("/{app_id}/workflow/run")
|
||||
async def run_workflow(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
request: WorkflowExecutionRequest,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""执行工作流
|
||||
|
||||
执行工作流并返回结果。支持流式和非流式两种模式。
|
||||
|
||||
**非流式模式**:等待工作流执行完成后返回完整结果。
|
||||
|
||||
**流式模式**:实时返回执行过程中的事件(节点开始、节点完成、工作流完成等)。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 验证应用类型
|
||||
if app.type != "workflow":
|
||||
return fail(
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
msg=f"应用类型必须为 workflow,当前为 {app.type}"
|
||||
)
|
||||
|
||||
# 准备输入数据
|
||||
input_data = {
|
||||
"message": request.message or "",
|
||||
"variables": request.variables
|
||||
}
|
||||
|
||||
# 执行工作流
|
||||
|
||||
if request.stream:
|
||||
# 流式执行
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
|
||||
async def event_generator():
|
||||
"""生成 SSE 事件
|
||||
|
||||
SSE 格式:
|
||||
event: <event_type>
|
||||
data: <json_data>
|
||||
|
||||
支持的事件类型:
|
||||
- workflow_start: 工作流开始
|
||||
- workflow_end: 工作流结束
|
||||
- node_start: 节点开始执行
|
||||
- node_end: 节点执行完成
|
||||
- node_chunk: 中间节点的流式输出
|
||||
- message: 最终消息的流式输出(End 节点及其相邻节点)
|
||||
"""
|
||||
try:
|
||||
async for event in await service.run_workflow(
|
||||
app_id=app_id,
|
||||
input_data=input_data,
|
||||
triggered_by=current_user.id,
|
||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||
stream=True
|
||||
):
|
||||
# 提取事件类型和数据
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
# 转换为标准 SSE 格式(字符串)
|
||||
# event: <type>
|
||||
# data: <json>
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||
yield sse_message
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式执行异常: {e}", exc_info=True)
|
||||
# 发送错误事件
|
||||
sse_error = f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
|
||||
yield sse_error
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no" # 禁用 nginx 缓冲
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 非流式执行
|
||||
result = await service.run_workflow(
|
||||
app_id=app_id,
|
||||
input_data=input_data,
|
||||
triggered_by=current_user.id,
|
||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||
stream=False
|
||||
)
|
||||
|
||||
return success(
|
||||
data=WorkflowExecutionResponse(
|
||||
execution_id=result["execution_id"],
|
||||
status=result["status"],
|
||||
output=result.get("output"),
|
||||
output_data=result.get("output_data"),
|
||||
error_message=result.get("error_message"),
|
||||
elapsed_time=result.get("elapsed_time"),
|
||||
token_usage=result.get("token_usage")
|
||||
),
|
||||
msg="工作流执行完成"
|
||||
)
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"执行工作流失败: {e.message}")
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"执行工作流异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"执行工作流失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/workflow/executions/{execution_id}/cancel")
|
||||
async def cancel_workflow_execution(
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""取消工作流执行
|
||||
|
||||
取消正在运行的工作流执行。
|
||||
|
||||
**注意**:当前版本仅更新状态为 cancelled,实际的执行取消功能待实现。
|
||||
"""
|
||||
try:
|
||||
# 获取执行记录
|
||||
execution = service.get_execution(execution_id)
|
||||
|
||||
if not execution:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="执行记录不存在"
|
||||
)
|
||||
|
||||
# 验证应用是否属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == execution.app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="无权访问该执行记录"
|
||||
)
|
||||
|
||||
# 检查执行状态
|
||||
if execution.status not in ["pending", "running"]:
|
||||
return fail(
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
msg=f"无法取消状态为 {execution.status} 的执行"
|
||||
)
|
||||
|
||||
# 更新状态为 cancelled
|
||||
service.update_execution_status(execution_id, "cancelled")
|
||||
|
||||
return success(msg="工作流执行已取消")
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"取消工作流执行失败: {e.message}")
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"取消工作流执行异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"取消工作流执行失败: {str(e)}"
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
@@ -95,16 +95,29 @@ def get_workspaces(
|
||||
@router.post("", response_model=ApiResponse)
|
||||
def create_workspace(
|
||||
workspace: WorkspaceCreate,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
):
|
||||
"""创建新的工作空间"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求创建工作空间: {workspace.name}")
|
||||
from app.core.language_utils import get_language_from_header
|
||||
|
||||
# 验证并获取语言参数
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求创建工作空间: {workspace.name}, "
|
||||
f"language={language}"
|
||||
)
|
||||
|
||||
result = workspace_service.create_workspace(
|
||||
db=db, workspace=workspace, user=current_user)
|
||||
db=db, workspace=workspace, user=current_user, language=language
|
||||
)
|
||||
|
||||
api_logger.info(f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, 创建者: {current_user.username}")
|
||||
api_logger.info(
|
||||
f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, "
|
||||
f"创建者: {current_user.username}, language={language}"
|
||||
)
|
||||
result_schema = WorkspaceResponse.model_validate(result)
|
||||
return success(data=result_schema, msg="工作空间创建成功")
|
||||
|
||||
|
||||
4
api/app/core/__init__.py
Normal file
4
api/app/core/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 16:24
|
||||
162
api/app/core/agent/agent_middleware.py
Normal file
162
api/app/core/agent/agent_middleware.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""Agent Middleware - 动态技能过滤"""
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
|
||||
from app.services.skill_service import SkillService
|
||||
from app.repositories.skill_repository import SkillRepository
|
||||
|
||||
|
||||
class AgentMiddleware:
|
||||
"""Agent 中间件 - 用于动态过滤和加载技能"""
|
||||
|
||||
def __init__(self, skills: Optional[dict] = None):
|
||||
"""
|
||||
初始化中间件
|
||||
|
||||
Args:
|
||||
skills: 技能配置字典 {"enabled": bool, "all_skills": bool, "skill_ids": [...]}
|
||||
"""
|
||||
self.skills = skills or {}
|
||||
self.enabled = self.skills.get('enabled', False)
|
||||
self.all_skills = self.skills.get('all_skills', False)
|
||||
self.skill_ids = self.skills.get('skill_ids', [])
|
||||
|
||||
@staticmethod
|
||||
def filter_tools(
|
||||
tools: List,
|
||||
message: str = "",
|
||||
skill_configs: Dict[str, Any] = None,
|
||||
tool_to_skill_map: Dict[str, str] = None
|
||||
) -> tuple[List, List[str]]:
|
||||
"""
|
||||
根据消息内容和技能配置动态过滤工具
|
||||
|
||||
Args:
|
||||
tools: 所有可用工具列表
|
||||
message: 用户消息(可用于智能过滤)
|
||||
skill_configs: 技能配置字典 {skill_id: {"keywords": [...], "enabled": True, "prompt": "..."}}
|
||||
tool_to_skill_map: 工具到技能的映射 {tool_name: skill_id}
|
||||
|
||||
Returns:
|
||||
(过滤后的工具列表, 激活的技能ID列表)
|
||||
"""
|
||||
if not tools:
|
||||
return [], []
|
||||
|
||||
# 如果没有技能配置,返回所有工具
|
||||
if not skill_configs:
|
||||
return tools, []
|
||||
|
||||
# 基于关键词匹配激活技能
|
||||
activated_skill_ids = []
|
||||
message_lower = message.lower()
|
||||
|
||||
for skill_id, config in skill_configs.items():
|
||||
if not config.get('enabled', True):
|
||||
continue
|
||||
|
||||
keywords = config.get('keywords', [])
|
||||
# 如果没有关键词限制,或消息包含关键词,则激活该技能
|
||||
if not keywords or any(kw.lower() in message_lower for kw in keywords):
|
||||
activated_skill_ids.append(skill_id)
|
||||
|
||||
# 如果没有工具映射关系,返回所有工具
|
||||
if not tool_to_skill_map:
|
||||
return tools, activated_skill_ids
|
||||
|
||||
# 根据激活的技能过滤工具
|
||||
filtered_tools = []
|
||||
for tool in tools:
|
||||
tool_name = getattr(tool, 'name', str(id(tool)))
|
||||
# 如果工具不属于任何skill(base_tools),或者工具所属的skill被激活,则保留
|
||||
if tool_name not in tool_to_skill_map or tool_to_skill_map[tool_name] in activated_skill_ids:
|
||||
filtered_tools.append(tool)
|
||||
|
||||
return filtered_tools, activated_skill_ids
|
||||
|
||||
def load_skill_tools(self, db, tenant_id: uuid.UUID, base_tools: List = None) -> tuple[List, Dict[str, Any], Dict[str, str]]:
|
||||
"""
|
||||
加载技能关联的工具
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
tenant_id: 租户id
|
||||
base_tools: 基础工具列表
|
||||
|
||||
Returns:
|
||||
(工具列表, 技能配置字典, 工具到技能的映射 {tool_name: skill_id})
|
||||
"""
|
||||
|
||||
tools_dict = {}
|
||||
tool_to_skill_map = {} # 工具名称到技能ID的映射
|
||||
|
||||
if base_tools:
|
||||
for tool in base_tools:
|
||||
tool_name = getattr(tool, 'name', str(id(tool)))
|
||||
tools_dict[tool_name] = tool
|
||||
# base_tools 不属于任何 skill,不加入映射
|
||||
|
||||
skill_configs = {}
|
||||
skill_ids_to_load = []
|
||||
|
||||
# 如果启用技能且 all_skills 为 True,加载租户下所有激活的技能
|
||||
if self.enabled and self.all_skills:
|
||||
skills, _ = SkillRepository.list_skills(db, tenant_id, is_active=True, page=1, pagesize=1000)
|
||||
skill_ids_to_load = [str(skill.id) for skill in skills]
|
||||
elif self.enabled and self.skill_ids:
|
||||
skill_ids_to_load = self.skill_ids
|
||||
|
||||
if skill_ids_to_load:
|
||||
for skill_id in skill_ids_to_load:
|
||||
try:
|
||||
skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id)
|
||||
if skill and skill.is_active:
|
||||
# 保存技能配置(包含prompt)
|
||||
config = skill.config or {}
|
||||
config['prompt'] = skill.prompt
|
||||
config['name'] = skill.name
|
||||
skill_configs[skill_id] = config
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 加载技能工具并获取映射关系
|
||||
skill_tools, skill_tool_map = SkillService.load_skill_tools(db, skill_ids_to_load, tenant_id)
|
||||
|
||||
# 只添加不冲突的 skill_tools
|
||||
for tool in skill_tools:
|
||||
tool_name = getattr(tool, 'name', str(id(tool)))
|
||||
if tool_name not in tools_dict:
|
||||
tools_dict[tool_name] = tool
|
||||
# 复制映射关系
|
||||
if tool_name in skill_tool_map:
|
||||
tool_to_skill_map[tool_name] = skill_tool_map[tool_name]
|
||||
|
||||
return list(tools_dict.values()), skill_configs, tool_to_skill_map
|
||||
|
||||
@staticmethod
|
||||
def get_active_prompts(activated_skill_ids: List[str], skill_configs: Dict[str, Any]) -> str:
|
||||
"""
|
||||
根据激活的技能ID获取对应的提示词
|
||||
|
||||
Args:
|
||||
activated_skill_ids: 被激活的技能ID列表
|
||||
skill_configs: 技能配置字典
|
||||
|
||||
Returns:
|
||||
合并后的提示词
|
||||
"""
|
||||
prompts = []
|
||||
for skill_id in activated_skill_ids:
|
||||
config = skill_configs.get(skill_id, {})
|
||||
prompt = config.get('prompt')
|
||||
name = config.get('name', 'Skill')
|
||||
if prompt:
|
||||
prompts.append(f"# {name}\n{prompt}")
|
||||
|
||||
return "\n\n".join(prompts) if prompts else ""
|
||||
|
||||
@staticmethod
|
||||
def create_runnable():
|
||||
"""创建可运行的中间件"""
|
||||
return RunnablePassthrough()
|
||||
@@ -7,17 +7,18 @@ LangChain Agent 封装
|
||||
- 支持流式输出
|
||||
- 使用 RedBearLLM 支持多提供商
|
||||
"""
|
||||
import os
|
||||
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from app.models.models_model import ModelType, ModelProvider
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
@@ -28,16 +29,19 @@ logger = get_business_logger()
|
||||
class LangChainAgent:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
api_key: str,
|
||||
provider: str = "openai",
|
||||
api_base: Optional[str] = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
streaming: bool = False
|
||||
self,
|
||||
model_name: str,
|
||||
api_key: str,
|
||||
provider: str = "openai",
|
||||
api_base: Optional[str] = None,
|
||||
is_omni: bool = False,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
streaming: bool = False,
|
||||
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
|
||||
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
|
||||
):
|
||||
"""初始化 LangChain Agent
|
||||
|
||||
@@ -50,13 +54,37 @@ class LangChainAgent:
|
||||
max_tokens: 最大 token 数
|
||||
system_prompt: 系统提示词
|
||||
tools: 工具列表(可选,框架自动走 ReAct 循环)
|
||||
streaming: 是否启用流式输出(默认 True)
|
||||
streaming: 是否启用流式输出
|
||||
max_iterations: 最大迭代次数(None 表示自动计算:基础 5 次 + 每个工具 2 次)
|
||||
max_tool_consecutive_calls: 单个工具最大连续调用次数(默认 3 次)
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.provider = provider
|
||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||
self.tools = tools or []
|
||||
self.streaming = streaming
|
||||
self.is_omni = is_omni
|
||||
self.max_tool_consecutive_calls = max_tool_consecutive_calls
|
||||
|
||||
# 工具调用计数器:记录每个工具的连续调用次数
|
||||
self.tool_call_counter: Dict[str, int] = {}
|
||||
self.last_tool_called: Optional[str] = None
|
||||
|
||||
# 根据工具数量动态调整最大迭代次数
|
||||
# 基础值 + 每个工具额外的调用机会
|
||||
if max_iterations is None:
|
||||
# 自动计算:基础 5 次 + 每个工具 2 次额外机会
|
||||
self.max_iterations = 5 + len(self.tools) * 2
|
||||
else:
|
||||
self.max_iterations = max_iterations
|
||||
|
||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||
|
||||
logger.debug(
|
||||
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
|
||||
f"tool_count={len(self.tools)}, "
|
||||
f"max_tool_consecutive_calls={self.max_tool_consecutive_calls}, "
|
||||
f"auto_calculated={max_iterations is None}"
|
||||
)
|
||||
|
||||
# 创建 RedBearLLM(支持多提供商)
|
||||
model_config = RedBearModelConfig(
|
||||
@@ -64,6 +92,7 @@ class LangChainAgent:
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
is_omni=is_omni,
|
||||
extra_params={
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
@@ -80,11 +109,14 @@ class LangChainAgent:
|
||||
if streaming and hasattr(self._underlying_llm, 'streaming'):
|
||||
self._underlying_llm.streaming = True
|
||||
|
||||
# 包装工具以跟踪连续调用次数
|
||||
wrapped_tools = self._wrap_tools_with_tracking(self.tools) if self.tools else None
|
||||
|
||||
# 使用 create_agent 创建 agent graph(LangChain 1.x 标准方式)
|
||||
# 无论是否有工具,都使用 agent 统一处理
|
||||
self.agent = create_agent(
|
||||
model=self.llm,
|
||||
tools=self.tools if self.tools else None,
|
||||
tools=wrapped_tools,
|
||||
system_prompt=self.system_prompt
|
||||
)
|
||||
|
||||
@@ -96,17 +128,92 @@ class LangChainAgent:
|
||||
"has_api_base": bool(api_base),
|
||||
"temperature": temperature,
|
||||
"streaming": streaming,
|
||||
"max_iterations": self.max_iterations,
|
||||
"max_tool_consecutive_calls": self.max_tool_consecutive_calls,
|
||||
"tool_count": len(self.tools),
|
||||
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
|
||||
"tool_count": len(self.tools)
|
||||
# "tool_count": len(self.tools)
|
||||
}
|
||||
)
|
||||
|
||||
def _wrap_tools_with_tracking(self, tools: Sequence[BaseTool]) -> List[BaseTool]:
|
||||
"""包装工具以跟踪连续调用次数
|
||||
|
||||
Args:
|
||||
tools: 原始工具列表
|
||||
|
||||
Returns:
|
||||
List[BaseTool]: 包装后的工具列表
|
||||
"""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from functools import wraps
|
||||
|
||||
wrapped_tools = []
|
||||
|
||||
for original_tool in tools:
|
||||
tool_name = original_tool.name
|
||||
original_func = original_tool.func if hasattr(original_tool, 'func') else None
|
||||
|
||||
if not original_func:
|
||||
# 如果无法获取原始函数,直接使用原工具
|
||||
wrapped_tools.append(original_tool)
|
||||
continue
|
||||
|
||||
# 创建包装函数
|
||||
def make_wrapped_func(tool_name, original_func):
|
||||
"""创建包装函数的工厂函数,避免闭包问题"""
|
||||
|
||||
@wraps(original_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
"""包装后的工具函数,跟踪连续调用次数"""
|
||||
# 检查是否是连续调用同一个工具
|
||||
if self.last_tool_called == tool_name:
|
||||
self.tool_call_counter[tool_name] = self.tool_call_counter.get(tool_name, 0) + 1
|
||||
else:
|
||||
# 切换到新工具,重置计数器
|
||||
self.tool_call_counter[tool_name] = 1
|
||||
self.last_tool_called = tool_name
|
||||
|
||||
current_count = self.tool_call_counter[tool_name]
|
||||
|
||||
logger.debug(
|
||||
f"工具调用: {tool_name}, 连续调用次数: {current_count}/{self.max_tool_consecutive_calls}"
|
||||
)
|
||||
|
||||
# 检查是否超过最大连续调用次数
|
||||
if current_count > self.max_tool_consecutive_calls:
|
||||
logger.warning(
|
||||
f"工具 '{tool_name}' 连续调用次数已达上限 ({self.max_tool_consecutive_calls}),"
|
||||
f"返回提示信息"
|
||||
)
|
||||
return (
|
||||
f"工具 '{tool_name}' 已连续调用 {self.max_tool_consecutive_calls} 次,"
|
||||
f"未找到有效结果。请尝试其他方法或直接回答用户的问题。"
|
||||
)
|
||||
|
||||
# 调用原始工具函数
|
||||
return original_func(*args, **kwargs)
|
||||
|
||||
return wrapped_func
|
||||
|
||||
# 使用 StructuredTool 创建新工具
|
||||
wrapped_tool = StructuredTool(
|
||||
name=original_tool.name,
|
||||
description=original_tool.description,
|
||||
func=make_wrapped_func(tool_name, original_func),
|
||||
args_schema=original_tool.args_schema if hasattr(original_tool, 'args_schema') else None
|
||||
)
|
||||
|
||||
wrapped_tools.append(wrapped_tool)
|
||||
|
||||
return wrapped_tools
|
||||
|
||||
def _prepare_messages(
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
) -> List[BaseMessage]:
|
||||
"""准备消息列表
|
||||
|
||||
@@ -114,6 +221,7 @@ class LangChainAgent:
|
||||
message: 用户消息
|
||||
history: 历史消息列表
|
||||
context: 上下文信息
|
||||
files: 多模态文件内容列表(已处理)
|
||||
|
||||
Returns:
|
||||
List[BaseMessage]: 消息列表
|
||||
@@ -136,45 +244,49 @@ class LangChainAgent:
|
||||
if context:
|
||||
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
|
||||
|
||||
messages.append(HumanMessage(content=user_content))
|
||||
# 构建用户消息(支持多模态)
|
||||
if files and len(files) > 0:
|
||||
content_parts = self._build_multimodal_content(user_content, files)
|
||||
messages.append(HumanMessage(content=content_parts))
|
||||
else:
|
||||
# 纯文本消息
|
||||
messages.append(HumanMessage(content=user_content))
|
||||
|
||||
return messages
|
||||
async def term_memory_save(self,messages,end_user_end,aimessages):
|
||||
'''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j'''
|
||||
end_user_end=f"Term_{end_user_end}"
|
||||
print(messages)
|
||||
print(aimessages)
|
||||
session_id = store.save_session(
|
||||
userid=end_user_end,
|
||||
messages=messages,
|
||||
apply_id=end_user_end,
|
||||
group_id=end_user_end,
|
||||
aimessages=aimessages
|
||||
)
|
||||
store.delete_duplicate_sessions()
|
||||
# logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
||||
return session_id
|
||||
async def term_memory_redis_read(self,end_user_end):
|
||||
end_user_end = f"Term_{end_user_end}"
|
||||
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
|
||||
# logger.info(f'Redis_Agent:{end_user_end};{history}')
|
||||
messagss_list=[]
|
||||
for messages in history:
|
||||
query = messages.get("Query")
|
||||
aimessages = messages.get("Answer")
|
||||
messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
|
||||
return messagss_list
|
||||
|
||||
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
构建多模态消息内容
|
||||
|
||||
Args:
|
||||
text: 文本内容
|
||||
files: 文件列表(已由 MultimodalService 处理为对应 provider 的格式)
|
||||
|
||||
Returns:
|
||||
List[Dict]: 消息内容列表
|
||||
"""
|
||||
# 根据 provider 使用不同的文本格式
|
||||
# if (self.provider.lower() in [ModelProvider.BEDROCK, ModelProvider.OPENAI, ModelProvider.XINFERENCE,
|
||||
# ModelProvider.GPUSTACK] or (
|
||||
# self.provider.lower() == ModelProvider.DASHSCOPE and self.is_omni)):
|
||||
# # Anthropic/Bedrock/Xinference/Gpustack/Openai: {"type": "text", "text": "..."}
|
||||
# content_parts = [{"type": "text", "text": text}]
|
||||
# else:
|
||||
# # 通义千问等: {"text": "..."}
|
||||
# content_parts = [{"type": "text", "text": text}]
|
||||
content_parts = [{"type": "text", "text": text}]
|
||||
|
||||
async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id):
|
||||
if storage_type == "rag":
|
||||
await write_rag(end_user_id, message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
else:
|
||||
write_id = write_message_task.delay(actual_end_user_id, content, actual_config_id, storage_type,
|
||||
user_rag_memory_id)
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'Agent:{actual_end_user_id};{write_status}')
|
||||
# 添加文件内容
|
||||
# MultimodalService 已经根据 provider 返回了正确格式,直接使用
|
||||
content_parts.extend(files)
|
||||
|
||||
logger.debug(
|
||||
f"构建多模态消息: provider={self.provider}, "
|
||||
f"parts={len(content_parts)}, "
|
||||
f"files={len(files)}"
|
||||
)
|
||||
|
||||
return content_parts
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
@@ -185,7 +297,8 @@ class LangChainAgent:
|
||||
config_id: Optional[str] = None, # 添加这个参数
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
) -> Dict[str, Any]:
|
||||
"""执行对话
|
||||
|
||||
@@ -197,13 +310,12 @@ class LangChainAgent:
|
||||
Returns:
|
||||
Dict: 包含 content 和元数据的字典
|
||||
"""
|
||||
message_chat= message
|
||||
message_chat = message
|
||||
start_time = time.time()
|
||||
actual_config_id = config_id
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
if actual_config_id is None and end_user_id:
|
||||
try:
|
||||
from app.db import get_db
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
@@ -218,19 +330,11 @@ class LangChainAgent:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
|
||||
history_term_memory=await self.term_memory_redis_read(end_user_id)
|
||||
if memory_flag:
|
||||
if len(history_term_memory)>=4 and storage_type != "rag":
|
||||
history_term_memory=';'.join(history_term_memory)
|
||||
logger.info(f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id)
|
||||
await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id)
|
||||
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
try:
|
||||
# 准备消息列表
|
||||
messages = self._prepare_messages(message, history, context)
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
|
||||
logger.debug(
|
||||
"准备调用 LangChain Agent",
|
||||
@@ -238,25 +342,86 @@ class LangChainAgent:
|
||||
"has_context": bool(context),
|
||||
"has_history": bool(history),
|
||||
"has_tools": bool(self.tools),
|
||||
"message_count": len(messages)
|
||||
"has_files": bool(files),
|
||||
"message_count": len(messages),
|
||||
"max_iterations": self.max_iterations
|
||||
}
|
||||
)
|
||||
|
||||
# 统一使用 agent.invoke 调用
|
||||
result = await self.agent.ainvoke({"messages": messages})
|
||||
# 通过 recursion_limit 限制最大迭代次数,防止工具调用死循环
|
||||
try:
|
||||
result = await self.agent.ainvoke(
|
||||
{"messages": messages},
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
)
|
||||
except RecursionError as e:
|
||||
logger.warning(
|
||||
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),可能存在工具调用循环",
|
||||
extra={"error": str(e)}
|
||||
)
|
||||
# 返回一个友好的错误提示
|
||||
return {
|
||||
"content": f"抱歉,我在处理您的请求时遇到了问题。已达到最大处理步骤限制({self.max_iterations}次)。请尝试简化您的问题或稍后再试。",
|
||||
"model": self.model_name,
|
||||
"elapsed_time": time.time() - start_time,
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
}
|
||||
|
||||
# 获取最后的 AI 消息
|
||||
output_messages = result.get("messages", [])
|
||||
content = ""
|
||||
|
||||
logger.debug(f"输出消息数量: {len(output_messages)}")
|
||||
total_tokens = 0
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
content = msg.content
|
||||
logger.debug(f"找到 AI 消息,content 类型: {type(msg.content)}")
|
||||
logger.debug(f"AI 消息内容: {msg.content}")
|
||||
|
||||
# 处理多模态响应:content 可能是字符串或列表
|
||||
if isinstance(msg.content, str):
|
||||
content = msg.content
|
||||
logger.debug(f"提取字符串内容,长度: {len(content)}")
|
||||
elif isinstance(msg.content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
logger.debug(f"多模态响应,列表长度: {len(msg.content)}")
|
||||
text_parts = []
|
||||
for item in msg.content:
|
||||
logger.debug(f"处理项: {item}")
|
||||
if isinstance(item, dict):
|
||||
# 通义千问格式: {"text": "..."}
|
||||
if "text" in item:
|
||||
text = item.get("text", "")
|
||||
text_parts.append(text)
|
||||
logger.debug(f"提取文本: {text[:100]}...")
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
text_parts.append(text)
|
||||
logger.debug(f"提取文本: {text[:100]}...")
|
||||
elif isinstance(item, str):
|
||||
text_parts.append(item)
|
||||
logger.debug(f"提取字符串: {item[:100]}...")
|
||||
content = "".join(text_parts)
|
||||
logger.debug(f"合并后内容长度: {len(content)}")
|
||||
else:
|
||||
content = str(msg.content)
|
||||
logger.debug(f"转换为字符串: {content[:100]}...")
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
||||
break
|
||||
|
||||
logger.info(f"最终提取的内容长度: {len(content)}")
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
if memory_flag:
|
||||
await self.write(storage_type,end_user_id,content,user_rag_memory_id,actual_end_user_id,content,actual_config_id)
|
||||
await self.term_memory_save(message_chat,end_user_id,content)
|
||||
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
|
||||
actual_config_id)
|
||||
response = {
|
||||
"content": content,
|
||||
"model": self.model_name,
|
||||
@@ -264,7 +429,7 @@ class LangChainAgent:
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
}
|
||||
|
||||
@@ -283,15 +448,16 @@ class LangChainAgent:
|
||||
raise
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
end_user_id:Optional[str] = None,
|
||||
config_id: Optional[str] = None,
|
||||
storage_type:Optional[str] = None,
|
||||
user_rag_memory_id:Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
config_id: Optional[str] = None,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""执行流式对话
|
||||
|
||||
@@ -314,10 +480,6 @@ class LangChainAgent:
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
if actual_config_id is None and end_user_id:
|
||||
try:
|
||||
from app.db import get_db
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
@@ -329,22 +491,13 @@ class LangChainAgent:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
|
||||
history_term_memory = await self.term_memory_redis_read(end_user_id)
|
||||
if memory_flag:
|
||||
if len(history_term_memory) >= 4 and storage_type != "rag":
|
||||
history_term_memory = ';'.join(history_term_memory)
|
||||
logger.info(
|
||||
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
|
||||
history_term_memory, actual_config_id)
|
||||
|
||||
await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id)
|
||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表
|
||||
messages = self._prepare_messages(message, history, context)
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
|
||||
logger.debug(
|
||||
f"准备流式调用,has_tools={bool(self.tools)}, message_count={len(messages)}"
|
||||
f"准备流式调用,has_tools={bool(self.tools)}, has_files={bool(files)}, message_count={len(messages)}"
|
||||
)
|
||||
|
||||
chunk_count = 0
|
||||
@@ -352,47 +505,106 @@ class LangChainAgent:
|
||||
|
||||
# 统一使用 agent 的 astream_events 实现流式输出
|
||||
logger.debug("使用 Agent astream_events 实现流式输出")
|
||||
full_content=''
|
||||
full_content = ''
|
||||
try:
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
version="v2"
|
||||
{"messages": messages},
|
||||
version="v2",
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
):
|
||||
chunk_count += 1
|
||||
kind = event.get("event")
|
||||
|
||||
|
||||
# 处理所有可能的流式事件
|
||||
if kind == "on_chat_model_stream":
|
||||
# LLM 流式输出
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
full_content+=chunk.content
|
||||
if chunk and hasattr(chunk, "content") and chunk.content:
|
||||
yield chunk.content
|
||||
yielded_content = True
|
||||
|
||||
if chunk and hasattr(chunk, "content"):
|
||||
# 处理多模态响应:content 可能是字符串或列表
|
||||
chunk_content = chunk.content
|
||||
if isinstance(chunk_content, str) and chunk_content:
|
||||
full_content += chunk_content
|
||||
yield chunk_content
|
||||
yielded_content = True
|
||||
elif isinstance(chunk_content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
for item in chunk_content:
|
||||
if isinstance(item, dict):
|
||||
# 通义千问格式: {"text": "..."}
|
||||
if "text" in item:
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
elif isinstance(item, str):
|
||||
full_content += item
|
||||
yield item
|
||||
yielded_content = True
|
||||
|
||||
elif kind == "on_llm_stream":
|
||||
# 另一种 LLM 流式事件
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
if chunk:
|
||||
if hasattr(chunk, "content") and chunk.content:
|
||||
full_content+=chunk.content
|
||||
yield chunk.content
|
||||
yielded_content = True
|
||||
if hasattr(chunk, "content"):
|
||||
chunk_content = chunk.content
|
||||
if isinstance(chunk_content, str) and chunk_content:
|
||||
full_content += chunk_content
|
||||
yield chunk_content
|
||||
yielded_content = True
|
||||
elif isinstance(chunk_content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
for item in chunk_content:
|
||||
if isinstance(item, dict):
|
||||
# 通义千问格式: {"text": "..."}
|
||||
if "text" in item:
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
elif isinstance(item, str):
|
||||
full_content += item
|
||||
yield item
|
||||
yielded_content = True
|
||||
elif isinstance(chunk, str):
|
||||
full_content += chunk
|
||||
yield chunk
|
||||
yielded_content = True
|
||||
|
||||
|
||||
# 记录工具调用(可选)
|
||||
elif kind == "on_tool_start":
|
||||
logger.debug(f"工具调用开始: {event.get('name')}")
|
||||
elif kind == "on_tool_end":
|
||||
logger.debug(f"工具调用结束: {event.get('name')}")
|
||||
|
||||
|
||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||
# 统计token消耗
|
||||
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
|
||||
0) if response_meta else 0
|
||||
yield total_tokens
|
||||
break
|
||||
if memory_flag:
|
||||
await self.write(storage_type, end_user_id,full_content, user_rag_memory_id, end_user_id,full_content, actual_config_id)
|
||||
await self.term_memory_save(message_chat, end_user_id, full_content)
|
||||
|
||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
|
||||
actual_config_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
@@ -406,5 +618,3 @@ class LangChainAgent:
|
||||
logger.info("=" * 80)
|
||||
logger.info("chat_stream 方法执行结束")
|
||||
logger.info("=" * 80)
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import secrets
|
||||
from typing import Optional, Union
|
||||
from datetime import datetime
|
||||
|
||||
from app.schemas.api_key_schema import ApiKeyType
|
||||
from app.models.api_key_model import ApiKeyType
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
|
||||
@@ -1,23 +1,44 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Annotated, Any, Dict, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field, TypeAdapter
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class Settings:
|
||||
# ========================================================================
|
||||
# Deployment Mode Configuration
|
||||
# ========================================================================
|
||||
# community: 社区版(开源,功能受限)
|
||||
# cloud: SaaS 云服务版(全功能,按量计费)
|
||||
# enterprise: 企业私有化版(License 控制)
|
||||
DEPLOYMENT_MODE: str = os.getenv("DEPLOYMENT_MODE", "community")
|
||||
|
||||
# License 配置(企业版)
|
||||
LICENSE_FILE: str = os.getenv("LICENSE_FILE", "/etc/app/license.json")
|
||||
LICENSE_SERVER_URL: str = os.getenv("LICENSE_SERVER_URL", "https://license.yourcompany.com")
|
||||
|
||||
# 计费服务配置(SaaS 版)
|
||||
BILLING_SERVICE_URL: str = os.getenv("BILLING_SERVICE_URL", "")
|
||||
|
||||
# 基础 URL(用于 SSO 回调等)
|
||||
BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000")
|
||||
FRONTEND_URL: str = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
|
||||
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
|
||||
# API Keys Configuration
|
||||
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
|
||||
DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "")
|
||||
|
||||
|
||||
# Neo4j Configuration (记忆系统数据库)
|
||||
NEO4J_URI: str = os.getenv("NEO4J_URI", "bolt://1.94.111.67:7687")
|
||||
NEO4J_USERNAME: str = os.getenv("NEO4J_USERNAME", "neo4j")
|
||||
NEO4J_PASSWORD: str = os.getenv("NEO4J_PASSWORD", "")
|
||||
|
||||
|
||||
# Database configuration (Postgres)
|
||||
DB_HOST: str = os.getenv("DB_HOST", "127.0.0.1")
|
||||
DB_PORT: int = int(os.getenv("DB_PORT", "5432"))
|
||||
@@ -37,7 +58,7 @@ class Settings:
|
||||
REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
|
||||
REDIS_DB: int = int(os.getenv("REDIS_DB", "1"))
|
||||
REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "")
|
||||
|
||||
|
||||
# ElasticSearch configuration
|
||||
ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1")
|
||||
ELASTICSEARCH_PORT: int = int(os.getenv("ELASTICSEARCH_PORT", "9200"))
|
||||
@@ -48,7 +69,7 @@ class Settings:
|
||||
ELASTICSEARCH_REQUEST_TIMEOUT: int = int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", "100000"))
|
||||
ELASTICSEARCH_RETRY_ON_TIMEOUT: bool = os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", "True").lower() == "true"
|
||||
ELASTICSEARCH_MAX_RETRIES: int = int(os.getenv("ELASTICSEARCH_MAX_RETRIES", "10"))
|
||||
|
||||
|
||||
# Xinference configuration
|
||||
XINFERENCE_URL: str = os.getenv("XINFERENCE_URL", "http://127.0.0.1")
|
||||
|
||||
@@ -57,23 +78,43 @@ class Settings:
|
||||
LANGCHAIN_TRACING: bool = os.getenv("LANGCHAIN_TRACING", "false").lower() == "true"
|
||||
LANGCHAIN_API_KEY: str = os.getenv("LANGCHAIN_API_KEY", "")
|
||||
LANGCHAIN_ENDPOINT: str = os.getenv("LANGCHAIN_ENDPOINT", "")
|
||||
|
||||
|
||||
# LLM Request Configuration
|
||||
LLM_TIMEOUT: float = float(os.getenv("LLM_TIMEOUT", "120.0"))
|
||||
LLM_MAX_RETRIES: int = int(os.getenv("LLM_MAX_RETRIES", "2"))
|
||||
|
||||
|
||||
# JWT Token Configuration
|
||||
SECRET_KEY: str = os.getenv("SECRET_KEY", "a_default_secret_key_that_is_long_and_random")
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7"))
|
||||
|
||||
|
||||
# Single Sign-On configuration
|
||||
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
|
||||
|
||||
# SSO 免登配置
|
||||
SSO_TOKEN_EXPIRE_SECONDS: int = int(os.getenv("SSO_TOKEN_EXPIRE_SECONDS", "300"))
|
||||
SSO_TRUSTED_SOURCES_CONFIG: str = os.getenv("SSO_TRUSTED_SOURCES_CONFIG", "{}")
|
||||
|
||||
# File Upload
|
||||
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
||||
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
||||
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
||||
|
||||
# Storage Configuration
|
||||
STORAGE_TYPE: str = os.getenv("STORAGE_TYPE", "local")
|
||||
|
||||
# Aliyun OSS Configuration
|
||||
OSS_ENDPOINT: str = os.getenv("OSS_ENDPOINT", "")
|
||||
OSS_ACCESS_KEY_ID: str = os.getenv("OSS_ACCESS_KEY_ID", "")
|
||||
OSS_ACCESS_KEY_SECRET: str = os.getenv("OSS_ACCESS_KEY_SECRET", "")
|
||||
OSS_BUCKET_NAME: str = os.getenv("OSS_BUCKET_NAME", "")
|
||||
|
||||
# AWS S3 Configuration
|
||||
S3_REGION: str = os.getenv("S3_REGION", "")
|
||||
S3_ACCESS_KEY_ID: str = os.getenv("S3_ACCESS_KEY_ID", "")
|
||||
S3_SECRET_ACCESS_KEY: str = os.getenv("S3_SECRET_ACCESS_KEY", "")
|
||||
S3_BUCKET_NAME: str = os.getenv("S3_BUCKET_NAME", "")
|
||||
|
||||
# VOLC ASR settings
|
||||
VOLC_APP_KEY: str = os.getenv("VOLC_APP_KEY", "")
|
||||
@@ -86,19 +127,20 @@ class Settings:
|
||||
LANGFUSE_PUBLIC_KEY: str = os.getenv("LANGFUSE_PUBLIC_KEY", "")
|
||||
LANGFUSE_SECRET_KEY: str = os.getenv("LANGFUSE_SECRET_KEY", "")
|
||||
LANGFUSE_HOST: str = os.getenv("LANGFUSE_HOST", "")
|
||||
|
||||
|
||||
# Server Configuration
|
||||
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
|
||||
FILE_LOCAL_SERVER_URL: str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
|
||||
|
||||
# ========================================================================
|
||||
# Internal Configuration (not in .env, used by application code)
|
||||
# ========================================================================
|
||||
|
||||
|
||||
# Superuser settings (internal defaults)
|
||||
FIRST_SUPERUSER_EMAIL: str = os.getenv("FIRST_SUPERUSER_EMAIL", "admin@example.com")
|
||||
FIRST_SUPERUSER_USERNAME: str = os.getenv("FIRST_SUPERUSER_USERNAME", "admin")
|
||||
FIRST_SUPERUSER_PASSWORD: str = os.getenv("FIRST_SUPERUSER_PASSWORD", "admin_password")
|
||||
|
||||
|
||||
# Generic File Upload (internal)
|
||||
GENERIC_FILE_PATH: str = os.getenv("GENERIC_FILE_PATH", "/uploads")
|
||||
ENABLE_FILE_COMPRESSION: bool = os.getenv("ENABLE_FILE_COMPRESSION", "false").lower() == "true"
|
||||
@@ -115,6 +157,11 @@ class Settings:
|
||||
if origin.strip()
|
||||
]
|
||||
|
||||
# Language Configuration
|
||||
# Supported values: "zh" (Chinese), "en" (English)
|
||||
# This controls the language used for memory summary titles and other generated content
|
||||
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
|
||||
|
||||
# Logging settings
|
||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
||||
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
@@ -123,7 +170,7 @@ class Settings:
|
||||
LOG_BACKUP_COUNT: int = int(os.getenv("LOG_BACKUP_COUNT", "5"))
|
||||
LOG_TO_CONSOLE: bool = os.getenv("LOG_TO_CONSOLE", "true").lower() == "true"
|
||||
LOG_TO_FILE: bool = os.getenv("LOG_TO_FILE", "true").lower() == "true"
|
||||
|
||||
|
||||
# Sensitive Data Filtering
|
||||
ENABLE_SENSITIVE_DATA_FILTER: bool = os.getenv("ENABLE_SENSITIVE_DATA_FILTER", "true").lower() == "true"
|
||||
|
||||
@@ -142,29 +189,87 @@ class Settings:
|
||||
LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192")) # 8KB
|
||||
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
|
||||
|
||||
|
||||
# Celery configuration (internal)
|
||||
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
|
||||
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
|
||||
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
|
||||
# 详见 docs/celery-env-bug-report.md
|
||||
# 默认使用 Redis DB 3 (broker) 和 DB 4 (backend),与业务缓存 (DB 1/2) 隔离
|
||||
# 多人共用同一 Redis 时,每位开发者应在 .env 中配置不同的 DB 编号避免任务互相干扰
|
||||
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3"))
|
||||
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4"))
|
||||
|
||||
# SMTP Email Configuration
|
||||
SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com")
|
||||
SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587"))
|
||||
SMTP_USER: str = os.getenv("SMTP_USER", "")
|
||||
SMTP_PASSWORD: str = os.getenv("SMTP_PASSWORD", "")
|
||||
|
||||
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
|
||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
|
||||
REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||
|
||||
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||
|
||||
# Memory Cache Regeneration Configuration
|
||||
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
||||
|
||||
# Celery Beat Schedule Configuration (定时任务执行频率)
|
||||
MEMORY_INCREMENT_HOUR: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=0, le=23, description="cron hour [0, 23]")]
|
||||
).validate_python(int(os.getenv("MEMORY_INCREMENT_HOUR", "2")))
|
||||
MEMORY_INCREMENT_MINUTE: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=0, le=59, description="cron minute [0, 59]")]
|
||||
).validate_python(int(os.getenv("MEMORY_INCREMENT_MINUTE", "0")))
|
||||
WORKSPACE_REFLECTION_INTERVAL_SECONDS: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=1, description="reflection interval in seconds, must be >= 1")]
|
||||
).validate_python(int(os.getenv("WORKSPACE_REFLECTION_INTERVAL_SECONDS", "30")))
|
||||
FORGETTING_CYCLE_INTERVAL_HOURS: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=1, description="forgetting cycle interval in hours, must be >= 1")]
|
||||
).validate_python(int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24")))
|
||||
|
||||
IMPLICIT_EMOTIONS_UPDATE_HOUR: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_HOUR", "2"))
|
||||
# implicit_emotions_update: 每天几分执行(分钟,0-59)
|
||||
IMPLICIT_EMOTIONS_UPDATE_MINUTE: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_MINUTE", "0"))
|
||||
# Memory Module Configuration (internal)
|
||||
|
||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
|
||||
|
||||
|
||||
# Tool Management Configuration
|
||||
TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools")
|
||||
TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60"))
|
||||
TOOL_MAX_CONCURRENCY: int = int(os.getenv("TOOL_MAX_CONCURRENCY", "10"))
|
||||
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
|
||||
|
||||
|
||||
# official environment system version
|
||||
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1")
|
||||
|
||||
# model square loading
|
||||
LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true"
|
||||
|
||||
# workflow config
|
||||
WORKFLOW_IMPORT_CACHE_TIMEOUT: int = int(os.getenv("WORKFLOW_IMPORT_CACHE_TIMEOUT", 1800))
|
||||
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
||||
|
||||
# ========================================================================
|
||||
# General Ontology Type Configuration
|
||||
# ========================================================================
|
||||
# 通用本体文件路径列表(逗号分隔)
|
||||
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "api/app/core/memory/ontology_services/General_purpose_entity.ttl")
|
||||
|
||||
# 是否启用通用本体类型功能
|
||||
ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true"
|
||||
|
||||
# Prompt 中最大类型数量
|
||||
MAX_ONTOLOGY_TYPES_IN_PROMPT: int = int(os.getenv("MAX_ONTOLOGY_TYPES_IN_PROMPT", "50"))
|
||||
|
||||
# 核心通用类型列表(逗号分隔)
|
||||
CORE_GENERAL_TYPES: str = os.getenv(
|
||||
"CORE_GENERAL_TYPES",
|
||||
"Person,Organization,Company,GovernmentAgency,Place,Location,City,Country,Building,"
|
||||
"Event,SportsEvent,SocialEvent,Work,Book,Film,Software,Concept,TopicalConcept,AcademicSubject"
|
||||
)
|
||||
|
||||
# 实验模式开关(允许通过 API 动态切换本体配置)
|
||||
ONTOLOGY_EXPERIMENT_MODE: bool = os.getenv("ONTOLOGY_EXPERIMENT_MODE", "true").lower() == "true"
|
||||
|
||||
def get_memory_output_path(self, filename: str = "") -> str:
|
||||
"""
|
||||
Get the full path for memory module output files.
|
||||
@@ -179,7 +284,7 @@ class Settings:
|
||||
if filename:
|
||||
return str(base_path / filename)
|
||||
return str(base_path)
|
||||
|
||||
|
||||
def ensure_memory_output_dir(self) -> None:
|
||||
"""
|
||||
Ensure the memory output directory exists.
|
||||
|
||||
@@ -46,6 +46,7 @@ class BizCode(IntEnum):
|
||||
RESOURCE_ALREADY_EXISTS = 5002
|
||||
VERSION_ALREADY_EXISTS = 5003
|
||||
STATE_CONFLICT = 5004
|
||||
RESOURCE_IN_USE = 5005
|
||||
|
||||
# 应用发布(6xxx)
|
||||
PUBLISH_FAILED = 6001
|
||||
@@ -82,6 +83,13 @@ class BizCode(IntEnum):
|
||||
MEMORY_WRITE_FAILED = 9501
|
||||
MEMORY_READ_FAILED = 9502
|
||||
MEMORY_CONFIG_NOT_FOUND = 9503
|
||||
|
||||
# Implicit Memory API(96xx)
|
||||
INVALID_USER_ID = 9601
|
||||
INSUFFICIENT_DATA = 9602
|
||||
INVALID_FILTER_PARAMS = 9603
|
||||
ANALYSIS_FAILED = 9604
|
||||
PROFILE_STORAGE_ERROR = 9605
|
||||
|
||||
# 系统(100xx)
|
||||
INTERNAL_ERROR = 10001
|
||||
@@ -103,24 +111,25 @@ HTTP_MAPPING = {
|
||||
BizCode.TOKEN_EXPIRED: 401,
|
||||
BizCode.TOKEN_BLACKLISTED: 401,
|
||||
BizCode.FORBIDDEN: 403,
|
||||
BizCode.TENANT_NOT_FOUND: 404,
|
||||
BizCode.TENANT_NOT_FOUND: 400,
|
||||
BizCode.WORKSPACE_NO_ACCESS: 403,
|
||||
BizCode.NOT_FOUND: 404,
|
||||
BizCode.NOT_FOUND: 400,
|
||||
BizCode.USER_NOT_FOUND: 200,
|
||||
BizCode.WORKSPACE_NOT_FOUND: 404,
|
||||
BizCode.MODEL_NOT_FOUND: 404,
|
||||
BizCode.KNOWLEDGE_NOT_FOUND: 404,
|
||||
BizCode.DOCUMENT_NOT_FOUND: 404,
|
||||
BizCode.FILE_NOT_FOUND: 404,
|
||||
BizCode.APP_NOT_FOUND: 404,
|
||||
BizCode.RELEASE_NOT_FOUND: 404,
|
||||
BizCode.WORKSPACE_NOT_FOUND: 400,
|
||||
BizCode.MODEL_NOT_FOUND: 400,
|
||||
BizCode.KNOWLEDGE_NOT_FOUND: 400,
|
||||
BizCode.DOCUMENT_NOT_FOUND: 400,
|
||||
BizCode.FILE_NOT_FOUND: 400,
|
||||
BizCode.APP_NOT_FOUND: 400,
|
||||
BizCode.RELEASE_NOT_FOUND: 400,
|
||||
BizCode.DUPLICATE_NAME: 409,
|
||||
BizCode.RESOURCE_ALREADY_EXISTS: 409,
|
||||
BizCode.VERSION_ALREADY_EXISTS: 409,
|
||||
BizCode.STATE_CONFLICT: 409,
|
||||
BizCode.RESOURCE_IN_USE: 409,
|
||||
BizCode.PUBLISH_FAILED: 500,
|
||||
BizCode.NO_DRAFT_TO_PUBLISH: 400,
|
||||
BizCode.ROLLBACK_TARGET_NOT_FOUND: 404,
|
||||
BizCode.ROLLBACK_TARGET_NOT_FOUND: 400,
|
||||
BizCode.APP_TYPE_NOT_SUPPORTED: 400,
|
||||
BizCode.AGENT_CONFIG_MISSING: 400,
|
||||
BizCode.SHARE_DISABLED: 403,
|
||||
@@ -159,6 +168,13 @@ HTTP_MAPPING = {
|
||||
BizCode.MEMORY_READ_FAILED: 500,
|
||||
BizCode.MEMORY_CONFIG_NOT_FOUND: 400,
|
||||
|
||||
# Implicit Memory API 错误码映射
|
||||
BizCode.INVALID_USER_ID: 400,
|
||||
BizCode.INSUFFICIENT_DATA: 400,
|
||||
BizCode.INVALID_FILTER_PARAMS: 400,
|
||||
BizCode.ANALYSIS_FAILED: 500,
|
||||
BizCode.PROFILE_STORAGE_ERROR: 500,
|
||||
|
||||
BizCode.INTERNAL_ERROR: 500,
|
||||
BizCode.DB_ERROR: 500,
|
||||
BizCode.SERVICE_UNAVAILABLE: 503,
|
||||
|
||||
82
api/app/core/language_utils.py
Normal file
82
api/app/core/language_utils.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""语言处理工具模块
|
||||
|
||||
本模块提供集中化的语言校验和处理功能,确保整个应用中语言参数的一致性。
|
||||
|
||||
Functions:
|
||||
validate_language: 校验语言参数,确保其为有效值
|
||||
get_language_from_header: 从请求头获取并校验语言参数
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 支持的语言列表
|
||||
SUPPORTED_LANGUAGES = {"zh", "en"}
|
||||
|
||||
# 默认回退语言
|
||||
DEFAULT_LANGUAGE = "zh"
|
||||
|
||||
|
||||
def validate_language(language: Optional[str]) -> str:
|
||||
"""
|
||||
校验语言参数,确保其为有效值。
|
||||
|
||||
Args:
|
||||
language: 待校验的语言代码,可以是 None、"zh"、"en" 或其他值
|
||||
|
||||
Returns:
|
||||
有效的语言代码("zh" 或 "en")
|
||||
|
||||
Examples:
|
||||
>>> validate_language("zh")
|
||||
'zh'
|
||||
>>> validate_language("en")
|
||||
'en'
|
||||
>>> validate_language("EN") # 大小写不敏感
|
||||
'en'
|
||||
>>> validate_language(None) # None 回退到默认值
|
||||
'zh'
|
||||
>>> validate_language("fr") # 不支持的语言回退到默认值
|
||||
'zh'
|
||||
"""
|
||||
if language is None:
|
||||
return DEFAULT_LANGUAGE
|
||||
|
||||
# 标准化:转小写并去除空白
|
||||
lang = str(language).lower().strip()
|
||||
|
||||
if lang in SUPPORTED_LANGUAGES:
|
||||
return lang
|
||||
|
||||
logger.warning(
|
||||
f"无效的语言参数 '{language}',已回退到默认值 '{DEFAULT_LANGUAGE}'。"
|
||||
f"支持的语言: {SUPPORTED_LANGUAGES}"
|
||||
)
|
||||
return DEFAULT_LANGUAGE
|
||||
|
||||
|
||||
def get_language_from_header(language_type: Optional[str]) -> str:
|
||||
"""
|
||||
从请求头获取并校验语言参数。
|
||||
|
||||
这是一个便捷函数,用于在 controller 层统一处理 X-Language-Type Header。
|
||||
|
||||
Args:
|
||||
language_type: 从 X-Language-Type Header 获取的语言值
|
||||
|
||||
Returns:
|
||||
有效的语言代码("zh" 或 "en")
|
||||
|
||||
Examples:
|
||||
>>> get_language_from_header(None) # Header 未传递
|
||||
'zh'
|
||||
>>> get_language_from_header("en")
|
||||
'en'
|
||||
>>> get_language_from_header("invalid") # 无效值回退
|
||||
'zh'
|
||||
"""
|
||||
return validate_language(language_type)
|
||||
@@ -38,6 +38,56 @@ class SensitiveDataLoggingFilter(logging.Filter):
|
||||
return True
|
||||
|
||||
|
||||
class Neo4jSuccessNotificationFilter(logging.Filter):
|
||||
"""Neo4j 日志过滤器:过滤成功/信息性状态的通知,保留真正的警告和错误
|
||||
|
||||
Neo4j 驱动会以 WARNING 级别记录所有数据库通知,包括成功的操作。
|
||||
这个过滤器会过滤掉以下 GQL 状态码的通知,只保留真正的警告和错误:
|
||||
- 00000: 成功完成 (successful completion)
|
||||
- 00N00: 无数据 (no data)
|
||||
- 00NA0: 无数据,信息性通知 (no data, informational notification)
|
||||
|
||||
使用正则表达式进行更严格的匹配,避免误过滤无关的警告。
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
# 编译正则表达式以提高性能
|
||||
# 匹配所有"成功/信息性"的 GQL 状态码:
|
||||
# 00000 = 成功完成, 00N00 = 无数据, 00NA0 = 无数据信息性通知
|
||||
GQL_STATUS_PATTERN = re.compile(r"gql_status=['\"](00000|00N00|00NA0)['\"]")
|
||||
|
||||
# 匹配 status_description 中的成功完成或信息性通知消息
|
||||
SUCCESS_DESC_PATTERN = re.compile(r"status_description=['\"]note:\s*(successful\s+completion|no\s+data)['\"]", re.IGNORECASE)
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
"""
|
||||
过滤 Neo4j 成功通知
|
||||
|
||||
Args:
|
||||
record: 日志记录
|
||||
|
||||
Returns:
|
||||
True表示允许记录,False表示拒绝(过滤掉)
|
||||
"""
|
||||
# 只处理 INFO 和 WARNING 级别的日志
|
||||
# Neo4j 驱动对 severity='INFORMATION' 的通知使用 INFO 级别,
|
||||
# 对 severity='WARNING' 的通知使用 WARNING 级别
|
||||
if record.levelno not in (logging.INFO, logging.WARNING):
|
||||
return True
|
||||
|
||||
# 检查是否是 Neo4j 的成功通知
|
||||
message = str(record.msg)
|
||||
|
||||
# 使用正则表达式进行更严格的匹配
|
||||
# 这样可以避免误过滤包含这些子字符串但不是 Neo4j 通知的日志
|
||||
if self.GQL_STATUS_PATTERN.search(message) or self.SUCCESS_DESC_PATTERN.search(message):
|
||||
return False # 过滤掉这条日志
|
||||
|
||||
# 保留其他所有日志(包括真正的警告和错误)
|
||||
return True
|
||||
|
||||
|
||||
class LoggingConfig:
|
||||
"""全局日志配置类"""
|
||||
|
||||
@@ -65,6 +115,22 @@ class LoggingConfig:
|
||||
# 清除现有处理器
|
||||
root_logger.handlers.clear()
|
||||
|
||||
# Neo4j 通知过滤器 - 挂在 handler 上确保所有传播上来的日志都能被过滤
|
||||
neo4j_filter = Neo4jSuccessNotificationFilter()
|
||||
|
||||
# 抑制 Neo4j 通知日志
|
||||
# Neo4j 驱动内部会给 neo4j.notifications logger 配置自己的 handler,
|
||||
# 导致日志绕过根 logger 的 filter 直接输出。
|
||||
# 多管齐下确保过滤生效:
|
||||
# 1. 设置 neo4j.notifications 级别为 WARNING(过滤 INFO 级别的 00NA0 通知)
|
||||
# 2. 在所有 neo4j logger 上添加 filter(过滤 WARNING 级别的成功通知)
|
||||
# 3. 在根 handler 上也添加 filter(兜底)
|
||||
neo4j_notifications_logger = logging.getLogger("neo4j.notifications")
|
||||
neo4j_notifications_logger.setLevel(logging.WARNING)
|
||||
for neo4j_logger_name in ["neo4j", "neo4j.io", "neo4j.pool", "neo4j.notifications"]:
|
||||
neo4j_logger = logging.getLogger(neo4j_logger_name)
|
||||
neo4j_logger.addFilter(neo4j_filter)
|
||||
|
||||
# 创建格式化器
|
||||
formatter = logging.Formatter(
|
||||
fmt=settings.LOG_FORMAT,
|
||||
@@ -80,6 +146,7 @@ class LoggingConfig:
|
||||
console_handler.setFormatter(formatter)
|
||||
console_handler.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
|
||||
console_handler.addFilter(sensitive_filter)
|
||||
console_handler.addFilter(neo4j_filter)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
# 文件处理器(带轮转)
|
||||
@@ -93,6 +160,7 @@ class LoggingConfig:
|
||||
file_handler.setFormatter(formatter)
|
||||
file_handler.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
|
||||
file_handler.addFilter(sensitive_filter)
|
||||
file_handler.addFilter(neo4j_filter)
|
||||
root_logger.addHandler(file_handler)
|
||||
|
||||
cls._initialized = True
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
"""
|
||||
LangGraph Graph package for memory agent.
|
||||
|
||||
This package provides the LangGraph workflow orchestrator with modular
|
||||
node implementations, routing logic, and state management.
|
||||
|
||||
Package structure:
|
||||
- read_graph: Main graph factory for read operations
|
||||
- write_graph: Main graph factory for write operations
|
||||
- nodes: LangGraph node implementations
|
||||
- routing: State routing logic
|
||||
- state: State management utilities
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||||
|
||||
__all__ = ['make_read_graph']
|
||||
@@ -4,7 +4,7 @@ LangGraph node implementations.
|
||||
This module contains custom node implementations for the LangGraph workflow.
|
||||
"""
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode
|
||||
from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message
|
||||
|
||||
__all__ = ["ToolExecutionNode", "create_input_message"]
|
||||
# from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode
|
||||
# from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message
|
||||
#
|
||||
# __all__ = ["ToolExecutionNode", "create_input_message"]
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
|
||||
|
||||
|
||||
def content_input_node(state: ReadState) -> ReadState:
|
||||
"""开始节点 - 提取内容并保持状态信息"""
|
||||
|
||||
content = state['messages'][0].content if state.get('messages') else ''
|
||||
# 返回内容并保持所有状态信息
|
||||
return {"data": content}
|
||||
|
||||
def content_input_write(state: WriteState) -> WriteState:
|
||||
"""开始节点 - 提取内容并保持状态信息"""
|
||||
|
||||
content = state['messages'][0].content if state.get('messages') else ''
|
||||
# 返回内容并保持所有状态信息
|
||||
return {"data": content}
|
||||
@@ -1,150 +0,0 @@
|
||||
"""
|
||||
Input node for LangGraph workflow entry point.
|
||||
|
||||
This module provides the create_input_message function which processes initial
|
||||
user input with multimodal support and creates the first tool call message.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
|
||||
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def create_input_message(
|
||||
state: Dict[str, Any],
|
||||
tool_name: str,
|
||||
session_id: str,
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
multimodal_processor: MultimodalProcessor,
|
||||
memory_config: MemoryConfig,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create initial tool call message from user input.
|
||||
|
||||
This function:
|
||||
1. Extracts the last message content from state
|
||||
2. Processes multimodal inputs (images/audio) using the multimodal processor
|
||||
3. Generates a unique message ID
|
||||
4. Extracts namespace from session_id
|
||||
5. Handles verified_data extraction for backward compatibility
|
||||
6. Returns AIMessage with complete tool_calls structure
|
||||
|
||||
Args:
|
||||
state: LangGraph state dictionary containing messages
|
||||
tool_name: Name of the tool to invoke (typically "Split_The_Problem")
|
||||
session_id: Session identifier (format: "call_id_{namespace}")
|
||||
search_switch: Search routing parameter
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
multimodal_processor: Processor for handling image/audio inputs
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
|
||||
Returns:
|
||||
State update with AIMessage containing tool_call
|
||||
|
||||
Examples:
|
||||
>>> state = {"messages": [HumanMessage(content="What is AI?")]}
|
||||
>>> result = await create_input_message(
|
||||
... state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor, config
|
||||
... )
|
||||
>>> result["messages"][0].tool_calls[0]["name"]
|
||||
'Split_The_Problem'
|
||||
"""
|
||||
messages = state.get("messages", [])
|
||||
|
||||
# Extract last message content
|
||||
if messages:
|
||||
last_message = messages[-1].content if hasattr(messages[-1], 'content') else str(messages[-1])
|
||||
else:
|
||||
logger.warning("[create_input_message] No messages in state, using empty string")
|
||||
last_message = ""
|
||||
|
||||
logger.debug(f"[create_input_message] Original input: {last_message[:100]}...")
|
||||
|
||||
# Process multimodal input (images/audio)
|
||||
try:
|
||||
processed_content = await multimodal_processor.process_input(last_message)
|
||||
if processed_content != last_message:
|
||||
logger.info(
|
||||
f"[create_input_message] Multimodal processing converted input "
|
||||
f"from {len(last_message)} to {len(processed_content)} chars"
|
||||
)
|
||||
last_message = processed_content
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[create_input_message] Multimodal processing failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Continue with original content
|
||||
|
||||
# Generate unique message ID
|
||||
uuid_str = uuid.uuid4()
|
||||
time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
# Extract namespace from session_id
|
||||
# Expected format: "call_id_{namespace}" or similar
|
||||
try:
|
||||
namespace = str(session_id).split('_id_')[1]
|
||||
except (IndexError, AttributeError):
|
||||
logger.warning(
|
||||
f"[create_input_message] Could not extract namespace from session_id: {session_id}"
|
||||
)
|
||||
namespace = "unknown"
|
||||
|
||||
# Handle verified_data extraction (backward compatibility)
|
||||
# This regex-based extraction is kept for compatibility with existing data formats
|
||||
if 'verified_data' in str(last_message):
|
||||
try:
|
||||
messages_last = str(last_message).replace('\\n', '').replace('\\', '')
|
||||
query_match = re.findall(r'"query": "(.*?)",', messages_last)
|
||||
if query_match:
|
||||
last_message = query_match[0]
|
||||
logger.debug(
|
||||
f"[create_input_message] Extracted query from verified_data: {last_message}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[create_input_message] Failed to extract query from verified_data: {e}"
|
||||
)
|
||||
|
||||
# Construct tool call message
|
||||
tool_call_id = f"{session_id}_{uuid_str}"
|
||||
|
||||
logger.info(
|
||||
f"[create_input_message] Creating tool call for '{tool_name}' "
|
||||
f"with ID: {tool_call_id}"
|
||||
)
|
||||
|
||||
# Build tool arguments
|
||||
tool_args = {
|
||||
"sentence": last_message,
|
||||
"sessionid": session_id,
|
||||
"messages_id": str(uuid_str),
|
||||
"search_switch": search_switch,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id,
|
||||
"memory_config": memory_config,
|
||||
}
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{
|
||||
"name": tool_name,
|
||||
"args": tool_args,
|
||||
"id": tool_call_id
|
||||
}]
|
||||
)
|
||||
]
|
||||
}
|
||||
250
api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py
Normal file
250
api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py
Normal file
@@ -0,0 +1,250 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
PROJECT_ROOT_,
|
||||
ReadState,
|
||||
)
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.db import get_db_context
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class ProblemNodeService(LLMServiceMixin):
|
||||
"""问题处理节点服务类"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
problem_service = ProblemNodeService()
|
||||
|
||||
|
||||
async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
"""问题分解节点"""
|
||||
# 从状态中获取数据
|
||||
content = state.get('data', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||
|
||||
system_prompt = await problem_service.template_service.render_template(
|
||||
template_name='problem_breakdown_prompt.jinja2',
|
||||
operation_name='split_the_problem',
|
||||
history=history,
|
||||
sentence=content,
|
||||
json_schema=json_schema
|
||||
)
|
||||
|
||||
try:
|
||||
# 使用优化的LLM服务
|
||||
with get_db_context() as db_session:
|
||||
structured = await problem_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=ProblemExtensionResponse,
|
||||
fallback_value=[]
|
||||
)
|
||||
|
||||
# 添加更详细的日志记录
|
||||
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
||||
|
||||
# 验证结构化响应
|
||||
if not structured or not hasattr(structured, 'root'):
|
||||
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
elif not structured.root:
|
||||
logger.warning("Split_The_Problem: 结构化响应的root为空")
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
else:
|
||||
split_result = json.dumps(
|
||||
[item.model_dump() for item in structured.root],
|
||||
ensure_ascii=False
|
||||
)
|
||||
|
||||
split_result_dict = []
|
||||
for index, item in enumerate(json.loads(split_result)):
|
||||
split_data = {
|
||||
"id": f"Q{index + 1}",
|
||||
"question": item['extended_question'],
|
||||
"type": item['type'],
|
||||
"reason": item['reason']
|
||||
}
|
||||
split_result_dict.append(split_data)
|
||||
|
||||
logger.info(f"Split_The_Problem: 成功生成 {len(structured.root) if structured.root else 0} 个分解项")
|
||||
|
||||
result = {
|
||||
"context": split_result,
|
||||
"original": content,
|
||||
"_intermediate": {
|
||||
"type": "problem_split",
|
||||
"title": "问题拆分",
|
||||
"data": split_result_dict,
|
||||
"original_query": content
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Split_The_Problem failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# 提供更详细的错误信息
|
||||
error_details = {
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
"content_length": len(content),
|
||||
"llm_model_id": str(memory_config.llm_model_id) if memory_config else None
|
||||
}
|
||||
|
||||
logger.error(f"Split_The_Problem error details: {error_details}")
|
||||
|
||||
# 创建默认的空结果
|
||||
result = {
|
||||
"context": json.dumps([], ensure_ascii=False),
|
||||
"original": content,
|
||||
"error": str(e),
|
||||
"_intermediate": {
|
||||
"type": "problem_split",
|
||||
"title": "问题拆分",
|
||||
"data": [],
|
||||
"original_query": content,
|
||||
"error": error_details
|
||||
}
|
||||
}
|
||||
|
||||
# 返回更新后的状态,包含spit_context字段
|
||||
return {"spit_data": result}
|
||||
|
||||
|
||||
async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
"""问题扩展节点"""
|
||||
# 获取原始数据和分解结果
|
||||
start = time.time()
|
||||
content = state.get('data', '')
|
||||
data = state.get('spit_data', '')['context']
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
|
||||
databasets = {}
|
||||
try:
|
||||
data = json.loads(data)
|
||||
for i in data:
|
||||
databasets[i['extended_question']] = i['type']
|
||||
except (json.JSONDecodeError, KeyError, TypeError) as e:
|
||||
logger.error(f"Problem_Extension: 数据解析失败: {e}")
|
||||
# 使用空字典作为fallback
|
||||
databasets = {}
|
||||
data = []
|
||||
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||
|
||||
system_prompt = await problem_service.template_service.render_template(
|
||||
template_name='Problem_Extension_prompt.jinja2',
|
||||
operation_name='problem_extension',
|
||||
history=history,
|
||||
questions=databasets,
|
||||
json_schema=json_schema
|
||||
)
|
||||
|
||||
try:
|
||||
# 使用优化的LLM服务
|
||||
with get_db_context() as db_session:
|
||||
response_content = await problem_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=ProblemExtensionResponse,
|
||||
fallback_value=[]
|
||||
)
|
||||
|
||||
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
|
||||
|
||||
# 验证结构化响应
|
||||
if not response_content or not hasattr(response_content, 'root'):
|
||||
logger.warning("Problem_Extension: 结构化响应为空或格式不正确")
|
||||
aggregated_dict = {}
|
||||
elif not response_content.root:
|
||||
logger.warning("Problem_Extension: 结构化响应的root为空")
|
||||
aggregated_dict = {}
|
||||
else:
|
||||
# Aggregate results by original question
|
||||
aggregated_dict = {}
|
||||
for item in response_content.root:
|
||||
try:
|
||||
key = getattr(item, "original_question", None) or (
|
||||
item.get("original_question") if isinstance(item, dict) else None
|
||||
)
|
||||
value = getattr(item, "extended_question", None) or (
|
||||
item.get("extended_question") if isinstance(item, dict) else None
|
||||
)
|
||||
if not key or not value:
|
||||
logger.warning(f"Problem_Extension: 跳过无效项: key={key}, value={value}")
|
||||
continue
|
||||
aggregated_dict.setdefault(key, []).append(value)
|
||||
except Exception as item_error:
|
||||
logger.warning(f"Problem_Extension: 处理项目时出错: {item_error}")
|
||||
continue
|
||||
|
||||
logger.info(f"Problem_Extension: 成功生成 {len(aggregated_dict)} 个扩展问题组")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM call failed for Problem_Extension: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# 提供更详细的错误信息
|
||||
error_details = {
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
"questions_count": len(databasets),
|
||||
"llm_model_id": str(memory_config.llm_model_id) if memory_config else None
|
||||
}
|
||||
|
||||
logger.error(f"Problem_Extension error details: {error_details}")
|
||||
aggregated_dict = {}
|
||||
|
||||
logger.info("Problem extension")
|
||||
logger.info(f"Problem extension result: {aggregated_dict}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
print(time.time() - start)
|
||||
result = {
|
||||
"context": aggregated_dict,
|
||||
"original": data,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "problem_extension",
|
||||
"title": "问题扩展",
|
||||
"data": aggregated_dict,
|
||||
"original_query": content,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
|
||||
return {"problem_extension": result}
|
||||
@@ -0,0 +1,411 @@
|
||||
# ===== 标准库 =====
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
|
||||
# ===== 第三方库 =====
|
||||
from langchain.agents import create_agent
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.langgraph_graph.tools.tool import (
|
||||
create_hybrid_retrieval_tool_sync,
|
||||
create_time_retrieval_tool,
|
||||
extract_tool_message_content,
|
||||
)
|
||||
from app.core.memory.agent.services.search_service import SearchService
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
ReadState,
|
||||
deduplicate_entries,
|
||||
merge_to_key_value_pairs,
|
||||
)
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
from app.schemas import model_schema
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def rag_config(state):
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
kb_config = {
|
||||
"knowledge_bases": [
|
||||
{
|
||||
"kb_id": user_rag_memory_id,
|
||||
"similarity_threshold": 0.7,
|
||||
"vector_similarity_weight": 0.5,
|
||||
"top_k": 10,
|
||||
"retrieve_type": "participle"
|
||||
}
|
||||
],
|
||||
"merge_strategy": "weight",
|
||||
"reranker_id": os.getenv('reranker_id'),
|
||||
"reranker_top_k": 10
|
||||
}
|
||||
return kb_config
|
||||
|
||||
|
||||
async def rag_knowledge(state, question):
|
||||
kb_config = await rag_config(state)
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
clean_content = '\n\n'.join(retrieval_knowledge)
|
||||
cleaned_query = question
|
||||
raw_results = clean_content
|
||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||
except Exception:
|
||||
retrieval_knowledge = []
|
||||
clean_content = ''
|
||||
raw_results = ''
|
||||
cleaned_query = question
|
||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||
return retrieval_knowledge, clean_content, cleaned_query, raw_results
|
||||
|
||||
|
||||
async def llm_infomation(state: ReadState) -> ReadState:
|
||||
memory_config = state.get('memory_config', None)
|
||||
model_id = memory_config.llm_model_id
|
||||
tenant_id = memory_config.tenant_id
|
||||
|
||||
# 使用现有的 memory_config 而不是重新查询数据库
|
||||
# 或者使用线程安全的数据库访问
|
||||
with get_db_context() as db:
|
||||
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id)
|
||||
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
|
||||
return result_pydantic
|
||||
|
||||
|
||||
async def clean_databases(data) -> str:
|
||||
"""
|
||||
简化的数据库搜索结果清理函数
|
||||
|
||||
Args:
|
||||
data: 搜索结果数据
|
||||
|
||||
Returns:
|
||||
清理后的内容字符串
|
||||
"""
|
||||
try:
|
||||
# 解析JSON字符串
|
||||
if isinstance(data, str):
|
||||
try:
|
||||
data = json.loads(data)
|
||||
except json.JSONDecodeError:
|
||||
return data
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return str(data)
|
||||
|
||||
# 获取结果数据
|
||||
# with open("搜索结果.json","w",encoding='utf-8') as f:
|
||||
# f.write(json.dumps(data, indent=4, ensure_ascii=False))
|
||||
results = data.get('results', data)
|
||||
if not isinstance(results, dict):
|
||||
return str(results)
|
||||
|
||||
# 收集所有内容
|
||||
content_list = []
|
||||
|
||||
# 处理重排序结果
|
||||
reranked = results.get('reranked_results', {})
|
||||
if reranked:
|
||||
for category in ['summaries', 'statements', 'chunks', 'entities']:
|
||||
items = reranked.get(category, [])
|
||||
if isinstance(items, list):
|
||||
content_list.extend(items)
|
||||
# 处理时间搜索结果
|
||||
time_search = results.get('time_search', {})
|
||||
if time_search:
|
||||
if isinstance(time_search, dict):
|
||||
statements = time_search.get('statements', time_search.get('time_search', []))
|
||||
if isinstance(statements, list):
|
||||
content_list.extend(statements)
|
||||
elif isinstance(time_search, list):
|
||||
content_list.extend(time_search)
|
||||
|
||||
# 提取文本内容
|
||||
text_parts = []
|
||||
for item in content_list:
|
||||
if isinstance(item, dict):
|
||||
text = item.get('statement') or item.get('content', '')
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
elif isinstance(item, str):
|
||||
text_parts.append(item)
|
||||
|
||||
return '\n'.join(text_parts).strip()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"clean_databases failed: {e}", exc_info=True)
|
||||
return str(data)
|
||||
|
||||
|
||||
async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
'''
|
||||
|
||||
模型信息
|
||||
'''
|
||||
|
||||
problem_extension = state.get('problem_extension', '')['context']
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
original = state.get('data', '')
|
||||
problem_list = []
|
||||
for key, values in problem_extension.items():
|
||||
for data in values:
|
||||
problem_list.append(data)
|
||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
|
||||
# 创建异步任务处理单个问题
|
||||
async def process_question_nodes(idx, question):
|
||||
try:
|
||||
# Prepare search parameters based on storage type
|
||||
search_params = {
|
||||
"end_user_id": end_user_id,
|
||||
"question": question,
|
||||
"return_raw_results": True
|
||||
}
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question)
|
||||
else:
|
||||
clean_content, cleaned_query, raw_results = await SearchService().execute_hybrid_search(
|
||||
**search_params, memory_config=memory_config
|
||||
)
|
||||
|
||||
return {
|
||||
"Query_small": cleaned_query,
|
||||
"Result_small": clean_content,
|
||||
"_intermediate": {
|
||||
"type": "search_result",
|
||||
"query": cleaned_query,
|
||||
"raw_results": raw_results,
|
||||
"index": idx + 1,
|
||||
"total": len(problem_list)
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve: hybrid_search failed for question '{question}': {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty result for this question
|
||||
return {
|
||||
"Query_small": question,
|
||||
"Result_small": "",
|
||||
"_intermediate": {
|
||||
"type": "search_result",
|
||||
"query": question,
|
||||
"raw_results": [],
|
||||
"index": idx + 1,
|
||||
"total": len(problem_list)
|
||||
}
|
||||
}
|
||||
|
||||
# 并发处理所有问题
|
||||
tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)]
|
||||
databases_anser = await asyncio.gather(*tasks)
|
||||
databases_data = {
|
||||
"Query": original,
|
||||
"Expansion_issue": databases_anser
|
||||
}
|
||||
|
||||
# Collect intermediate outputs before deduplication
|
||||
intermediate_outputs = []
|
||||
for item in databases_anser:
|
||||
if '_intermediate' in item:
|
||||
intermediate_outputs.append(item['_intermediate'])
|
||||
|
||||
# Deduplicate and merge results
|
||||
deduplicated_data = deduplicate_entries(databases_data['Expansion_issue'])
|
||||
deduplicated_data_merged = merge_to_key_value_pairs(
|
||||
deduplicated_data,
|
||||
'Query_small',
|
||||
'Result_small'
|
||||
)
|
||||
|
||||
# Restructure for Verify/Retrieve_Summary compatibility
|
||||
keys, val = [], []
|
||||
for item in deduplicated_data_merged:
|
||||
for items_key, items_value in item.items():
|
||||
keys.append(items_key)
|
||||
val.append(items_value)
|
||||
|
||||
send_verify = []
|
||||
for i, j in zip(keys, val, strict=False):
|
||||
if j != ['']:
|
||||
send_verify.append({
|
||||
"Query_small": i,
|
||||
"Answer_Small": j
|
||||
})
|
||||
|
||||
dup_databases = {
|
||||
"Query": original,
|
||||
"Expansion_issue": send_verify,
|
||||
"_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs
|
||||
}
|
||||
|
||||
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
||||
return {'retrieve': dup_databases}
|
||||
|
||||
|
||||
async def retrieve(state: ReadState) -> ReadState:
|
||||
# 从state中获取end_user_id
|
||||
import time
|
||||
start = time.time()
|
||||
problem_extension = state.get('problem_extension', '')['context']
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
original = state.get('data', '')
|
||||
problem_list = []
|
||||
for key, values in problem_extension.items():
|
||||
for data in values:
|
||||
problem_list.append(data)
|
||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
databases_anser = []
|
||||
|
||||
async def get_llm_info():
|
||||
with get_db_context() as db: # 使用同步数据库上下文管理器
|
||||
config_service = MemoryConfigService(db)
|
||||
return await llm_infomation(state)
|
||||
|
||||
llm_config = await get_llm_info()
|
||||
api_key_obj = llm_config.api_keys[0]
|
||||
api_key = api_key_obj.api_key
|
||||
api_base = api_key_obj.api_base
|
||||
model_name = api_key_obj.model_name
|
||||
llm = ChatOpenAI(
|
||||
model=model_name,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
temperature=0.2,
|
||||
)
|
||||
|
||||
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
||||
search_params = {"end_user_id": end_user_id, "return_raw_results": True}
|
||||
hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||
agent = create_agent(
|
||||
llm,
|
||||
tools=[time_retrieval_tool, hybrid_retrieval],
|
||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
||||
)
|
||||
|
||||
# 创建异步任务处理单个问题
|
||||
import asyncio
|
||||
|
||||
# 在模块级别定义信号量,限制最大并发数
|
||||
SEMAPHORE = asyncio.Semaphore(5) # 限制最多5个并发数据库操作
|
||||
|
||||
async def process_question(idx, question):
|
||||
async with SEMAPHORE: # 限制并发
|
||||
try:
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
|
||||
question)
|
||||
else:
|
||||
cleaned_query = question
|
||||
# 使用 asyncio 在线程池中运行同步的 agent.invoke
|
||||
import asyncio
|
||||
response = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: agent.invoke({"messages": question})
|
||||
)
|
||||
tool_results = extract_tool_message_content(response)
|
||||
if tool_results == None:
|
||||
raw_results = []
|
||||
clean_content = ''
|
||||
else:
|
||||
raw_results = tool_results['content']
|
||||
clean_content = await clean_databases(raw_results)
|
||||
|
||||
try:
|
||||
raw_results = raw_results['results']
|
||||
except Exception:
|
||||
raw_results = []
|
||||
|
||||
return {
|
||||
"Query_small": cleaned_query,
|
||||
"Result_small": clean_content,
|
||||
"_intermediate": {
|
||||
"type": "search_result",
|
||||
"query": cleaned_query,
|
||||
"raw_results": raw_results,
|
||||
"index": idx + 1,
|
||||
"total": len(problem_list)
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve: hybrid_search failed for question '{question}': {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty result for this question
|
||||
return {
|
||||
"Query_small": question,
|
||||
"Result_small": "",
|
||||
"_intermediate": {
|
||||
"type": "search_result",
|
||||
"query": question,
|
||||
"raw_results": [],
|
||||
"index": idx + 1,
|
||||
"total": len(problem_list)
|
||||
}
|
||||
}
|
||||
|
||||
# 并发处理所有问题
|
||||
import asyncio
|
||||
tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)]
|
||||
databases_anser = await asyncio.gather(*tasks)
|
||||
databases_data = {
|
||||
"Query": original,
|
||||
"Expansion_issue": databases_anser
|
||||
}
|
||||
|
||||
# Collect intermediate outputs before deduplication
|
||||
intermediate_outputs = []
|
||||
for item in databases_anser:
|
||||
if '_intermediate' in item:
|
||||
intermediate_outputs.append(item['_intermediate'])
|
||||
|
||||
# Deduplicate and merge results
|
||||
deduplicated_data = deduplicate_entries(databases_data['Expansion_issue'])
|
||||
deduplicated_data_merged = merge_to_key_value_pairs(
|
||||
deduplicated_data,
|
||||
'Query_small',
|
||||
'Result_small'
|
||||
)
|
||||
|
||||
# Restructure for Verify/Retrieve_Summary compatibility
|
||||
keys, val = [], []
|
||||
for item in deduplicated_data_merged:
|
||||
for items_key, items_value in item.items():
|
||||
keys.append(items_key)
|
||||
val.append(items_value)
|
||||
|
||||
send_verify = []
|
||||
for i, j in zip(keys, val, strict=False):
|
||||
if j != ['']:
|
||||
send_verify.append({
|
||||
"Query_small": i,
|
||||
"Answer_Small": j
|
||||
})
|
||||
|
||||
dup_databases = {
|
||||
"Query": original,
|
||||
"Expansion_issue": send_verify,
|
||||
"_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs
|
||||
}
|
||||
# with open('retrieve_text.json', 'w') as f:
|
||||
# json.dump(dup_databases, f, indent=4)
|
||||
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
||||
return {'retrieve': dup_databases}
|
||||
371
api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py
Normal file
371
api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py
Normal file
@@ -0,0 +1,371 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.models.summary_models import (
|
||||
RetrieveSummaryResponse,
|
||||
SummaryResponse,
|
||||
)
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
from app.core.memory.agent.services.search_service import SearchService
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
PROJECT_ROOT_,
|
||||
ReadState,
|
||||
)
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class SummaryNodeService(LLMServiceMixin):
|
||||
"""总结节点服务类"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
summary_service = SummaryNodeService()
|
||||
|
||||
|
||||
async def rag_config(state):
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
kb_config = {
|
||||
"knowledge_bases": [
|
||||
{
|
||||
"kb_id": user_rag_memory_id,
|
||||
"similarity_threshold": 0.7,
|
||||
"vector_similarity_weight": 0.5,
|
||||
"top_k": 10,
|
||||
"retrieve_type": "participle"
|
||||
}
|
||||
],
|
||||
"merge_strategy": "weight",
|
||||
"reranker_id": os.getenv('reranker_id'),
|
||||
"reranker_top_k": 10
|
||||
}
|
||||
return kb_config
|
||||
|
||||
|
||||
async def rag_knowledge(state, question):
|
||||
kb_config = await rag_config(state)
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
clean_content = '\n\n'.join(retrieval_knowledge)
|
||||
cleaned_query = question
|
||||
raw_results = clean_content
|
||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||
except Exception:
|
||||
retrieval_knowledge = []
|
||||
clean_content = ''
|
||||
raw_results = ''
|
||||
cleaned_query = question
|
||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||
return retrieval_knowledge, clean_content, cleaned_query, raw_results
|
||||
|
||||
|
||||
async def summary_history(state: ReadState) -> ReadState:
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
return history
|
||||
|
||||
|
||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
|
||||
search_mode) -> str:
|
||||
"""
|
||||
增强的summary_llm函数,包含更好的错误处理和数据验证
|
||||
"""
|
||||
data = state.get("data", '')
|
||||
|
||||
# 构建系统提示词
|
||||
if str(search_mode) == "0":
|
||||
system_prompt = await summary_service.template_service.render_template(
|
||||
template_name=template_name,
|
||||
operation_name=operation_name,
|
||||
data=retrieve_info,
|
||||
query=data
|
||||
)
|
||||
else:
|
||||
system_prompt = await summary_service.template_service.render_template(
|
||||
template_name=template_name,
|
||||
operation_name=operation_name,
|
||||
query=data,
|
||||
history=history,
|
||||
retrieve_info=retrieve_info
|
||||
)
|
||||
try:
|
||||
# 使用优化的LLM服务进行结构化输出
|
||||
with get_db_context() as db_session:
|
||||
structured = await summary_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=response_model,
|
||||
fallback_value=None
|
||||
)
|
||||
# 验证结构化响应
|
||||
if structured is None:
|
||||
logger.warning("LLM返回None,使用默认回答")
|
||||
return "信息不足,无法回答"
|
||||
|
||||
# 根据操作类型提取答案
|
||||
if operation_name == "summary":
|
||||
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
|
||||
else:
|
||||
# 处理RetrieveSummaryResponse
|
||||
if hasattr(structured, 'data') and structured.data:
|
||||
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
|
||||
else:
|
||||
logger.warning("结构化响应缺少data字段")
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
# 验证答案不为空
|
||||
if not aimessages or aimessages.strip() == "":
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
return aimessages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"结构化输出失败: {e}", exc_info=True)
|
||||
|
||||
# 尝试非结构化输出作为fallback
|
||||
try:
|
||||
logger.info("尝试非结构化输出作为fallback")
|
||||
response = await summary_service.call_llm_simple(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
fallback_message="信息不足,无法回答"
|
||||
)
|
||||
|
||||
if response and response.strip():
|
||||
# 简单清理响应
|
||||
cleaned_response = response.strip()
|
||||
# 移除可能的JSON标记
|
||||
if cleaned_response.startswith('```'):
|
||||
lines = cleaned_response.split('\n')
|
||||
cleaned_response = '\n'.join(lines[1:-1])
|
||||
|
||||
return cleaned_response
|
||||
else:
|
||||
return "信息不足,无法回答"
|
||||
|
||||
except Exception as fallback_error:
|
||||
logger.error(f"Fallback也失败: {fallback_error}")
|
||||
return "信息不足,无法回答"
|
||||
|
||||
|
||||
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
||||
data = state.get("data", '')
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
await SessionService(store).save_session(
|
||||
user_id=end_user_id,
|
||||
query=data,
|
||||
apply_id=end_user_id,
|
||||
end_user_id=end_user_id,
|
||||
ai_response=aimessages
|
||||
)
|
||||
await SessionService(store).cleanup_duplicates()
|
||||
logger.info(f"sessionid: {aimessages} 写入成功")
|
||||
|
||||
|
||||
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
|
||||
storage_type = state.get("storage_type", '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
data = state.get("data", '')
|
||||
input_summary = {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "input_summary",
|
||||
"title": "快速答案",
|
||||
"summary": aimessages,
|
||||
"query": data,
|
||||
"raw_results": raw_results,
|
||||
"search_mode": "quick_search",
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
retrieve = {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "retrieval_summary",
|
||||
"title": "快速检索",
|
||||
"summary": aimessages,
|
||||
"query": data,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
|
||||
return input_summary, retrieve
|
||||
|
||||
|
||||
async def Input_Summary(state: ReadState) -> ReadState:
|
||||
start = time.time()
|
||||
storage_type = state.get("storage_type", '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
data = state.get("data", '')
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
history = await summary_history(state)
|
||||
search_params = {
|
||||
"end_user_id": end_user_id,
|
||||
"question": data,
|
||||
"return_raw_results": True,
|
||||
"include": ["summaries"] # Only search summary nodes for faster performance
|
||||
}
|
||||
|
||||
try:
|
||||
if storage_type != "rag":
|
||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params,
|
||||
memory_config=memory_config)
|
||||
else:
|
||||
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
|
||||
except Exception as e:
|
||||
logger.error(f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True)
|
||||
retrieve_info, question, raw_results = "", data, []
|
||||
try:
|
||||
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
|
||||
# 'input_summary',RetrieveSummaryResponse)
|
||||
# logger.info(f"快速答案总结==>>:{storage_type}--{user_rag_memory_id}--{aimessages}")
|
||||
summary_result = await summary_prompt(state, retrieve_info, retrieve_info)
|
||||
summary = summary_result[0]
|
||||
except Exception as e:
|
||||
logger.error(f"Input_Summary failed: {e}", exc_info=True)
|
||||
summary = {
|
||||
"status": "fail",
|
||||
"summary_result": "信息不足,无法回答",
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"error": str(e)
|
||||
}
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('检索', duration)
|
||||
return {"summary": summary}
|
||||
|
||||
|
||||
async def Retrieve_Summary(state: ReadState) -> ReadState:
|
||||
retrieve = state.get("retrieve", '')
|
||||
history = await summary_history(state)
|
||||
import json
|
||||
with open("检索.json", "w", encoding='utf-8') as f:
|
||||
f.write(json.dumps(retrieve, indent=4, ensure_ascii=False))
|
||||
retrieve = retrieve.get("Expansion_issue", [])
|
||||
start = time.time()
|
||||
retrieve_info_str = []
|
||||
for data in retrieve:
|
||||
if data == '':
|
||||
retrieve_info_str = ''
|
||||
else:
|
||||
for key, value in data.items():
|
||||
if key == 'Answer_Small':
|
||||
for i in value:
|
||||
retrieve_info_str.append(i)
|
||||
retrieve_info_str = list(set(retrieve_info_str))
|
||||
retrieve_info_str = '\n'.join(retrieve_info_str)
|
||||
|
||||
aimessages = await summary_llm(state, history, retrieve_info_str,
|
||||
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||
await summary_redis_save(state, aimessages)
|
||||
if aimessages == '':
|
||||
aimessages = '信息不足,无法回答'
|
||||
logger.info(f"Summary after retrieval: {aimessages}")
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('Retrieval summary', duration)
|
||||
|
||||
# 修复协程调用 - 先await,然后访问返回值
|
||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||
summary = summary_result[1]
|
||||
return {"summary": summary}
|
||||
|
||||
|
||||
async def Summary(state: ReadState) -> ReadState:
|
||||
start = time.time()
|
||||
query = state.get("data", '')
|
||||
verify = state.get("verify", '')
|
||||
verify_expansion_issue = verify.get("verified_data", '')
|
||||
retrieve_info_str = ''
|
||||
for data in verify_expansion_issue:
|
||||
for key, value in data.items():
|
||||
if key == 'answer_small':
|
||||
for i in value:
|
||||
retrieve_info_str += i + '\n'
|
||||
history = await summary_history(state)
|
||||
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
"retrieve_info": retrieve_info_str
|
||||
}
|
||||
aimessages = await summary_llm(state, history, data,
|
||||
'summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
||||
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||
await summary_redis_save(state, aimessages)
|
||||
if aimessages == '':
|
||||
aimessages = '信息不足,无法回答'
|
||||
try:
|
||||
duration = time.time() - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('Retrieval summary', duration)
|
||||
|
||||
# 修复协程调用 - 先await,然后访问返回值
|
||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||
summary = summary_result[1]
|
||||
return {"summary": summary}
|
||||
|
||||
|
||||
async def Summary_fails(state: ReadState) -> ReadState:
|
||||
storage_type = state.get("storage_type", '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
history = await summary_history(state)
|
||||
query = state.get("data", '')
|
||||
verify = state.get("verify", '')
|
||||
verify_expansion_issue = verify.get("verified_data", '')
|
||||
retrieve_info_str = ''
|
||||
for data in verify_expansion_issue:
|
||||
for key, value in data.items():
|
||||
if key == 'answer_small':
|
||||
for i in value:
|
||||
retrieve_info_str += i + '\n'
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
"retrieve_info": retrieve_info_str
|
||||
}
|
||||
aimessages = await summary_llm(state, history, data,
|
||||
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
||||
result = {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
return {"summary": result}
|
||||
@@ -1,234 +0,0 @@
|
||||
"""
|
||||
Tool execution node for LangGraph workflow.
|
||||
|
||||
This module provides the ToolExecutionNode class which wraps tool execution
|
||||
with parameter transformation logic using the ParameterBuilder service.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.state.extractors import (
|
||||
extract_content_payload,
|
||||
extract_tool_call_id,
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolExecutionNode:
|
||||
"""
|
||||
Custom LangGraph node that wraps tool execution with parameter transformation.
|
||||
|
||||
This node extracts content from previous tool results, transforms parameters
|
||||
based on tool type using ParameterBuilder, and invokes the tool with the
|
||||
correct argument structure.
|
||||
|
||||
Attributes:
|
||||
tool_node: LangGraph ToolNode wrapping the actual tool
|
||||
id: Node identifier for message IDs
|
||||
tool_name: Name of the tool being executed
|
||||
namespace: Namespace for session management
|
||||
search_switch: Search routing parameter
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
parameter_builder: Service for building tool-specific arguments
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool: Callable,
|
||||
node_id: str,
|
||||
namespace: str,
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
parameter_builder: ParameterBuilder,
|
||||
storage_type: str,
|
||||
user_rag_memory_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
):
|
||||
"""
|
||||
Initialize the tool execution node.
|
||||
|
||||
Args:
|
||||
tool: The tool function to execute
|
||||
node_id: Identifier for this node (used in message IDs)
|
||||
namespace: Namespace for session management
|
||||
search_switch: Search routing parameter
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
parameter_builder: Service for building tool-specific arguments
|
||||
storage_type: Storage type for the workspace
|
||||
user_rag_memory_id: User RAG memory identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
"""
|
||||
self.tool_node = ToolNode([tool])
|
||||
self.id = node_id
|
||||
self.tool_name = tool.name if hasattr(tool, 'name') else str(tool)
|
||||
self.namespace = namespace
|
||||
self.search_switch = search_switch
|
||||
self.apply_id = apply_id
|
||||
self.group_id = group_id
|
||||
self.parameter_builder = parameter_builder
|
||||
self.storage_type = storage_type
|
||||
self.user_rag_memory_id = user_rag_memory_id
|
||||
self.memory_config = memory_config
|
||||
|
||||
logger.info(
|
||||
f"[ToolExecutionNode] Initialized node '{self.id}' for tool '{self.tool_name}'"
|
||||
)
|
||||
|
||||
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute the tool with transformed parameters.
|
||||
|
||||
This method:
|
||||
1. Extracts the last message from state
|
||||
2. Extracts tool call ID using state extractors
|
||||
3. Extracts content payload using state extractors
|
||||
4. Builds tool arguments using parameter builder
|
||||
5. Constructs AIMessage with tool_calls
|
||||
6. Invokes the tool and returns the result
|
||||
|
||||
Args:
|
||||
state: LangGraph state dictionary
|
||||
|
||||
Returns:
|
||||
Updated state with tool result in messages
|
||||
"""
|
||||
messages = state.get("messages", [])
|
||||
logger.debug( self.tool_name)
|
||||
|
||||
if not messages:
|
||||
logger.warning(f"[ToolExecutionNode] {self.id} - No messages in state")
|
||||
return {"messages": [AIMessage(content="Error: No messages in state")]}
|
||||
|
||||
last_message = messages[-1]
|
||||
logger.debug(
|
||||
f"[ToolExecutionNode] {self.id} - Processing message at {time.time()}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Extract tool call ID using state extractors
|
||||
tool_call_id = extract_tool_call_id(last_message)
|
||||
logger.debug(f"[ToolExecutionNode] {self.id} - Extracted tool_call_id: {tool_call_id}")
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"[ToolExecutionNode] {self.id} - Failed to extract tool call ID: {e}"
|
||||
)
|
||||
return {"messages": [AIMessage(content=f"Error: {str(e)}")]}
|
||||
|
||||
try:
|
||||
# Extract content payload using state extractors
|
||||
content = extract_content_payload(last_message)
|
||||
logger.debug(
|
||||
f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}, content_keys: {list(content.keys()) if isinstance(content, dict) else 'N/A'}"
|
||||
)
|
||||
# Log raw message content for debugging
|
||||
if hasattr(last_message, 'content'):
|
||||
raw = last_message.content
|
||||
logger.debug(f"[ToolExecutionNode] {self.id} - Raw message content (first 500 chars): {str(raw)[:500]}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[ToolExecutionNode] {self.id} - Failed to extract content: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
content = {}
|
||||
|
||||
try:
|
||||
# Build tool arguments using parameter builder
|
||||
tool_args = self.parameter_builder.build_tool_args(
|
||||
tool_name=self.tool_name,
|
||||
content=content,
|
||||
tool_call_id=tool_call_id,
|
||||
search_switch=self.search_switch,
|
||||
apply_id=self.apply_id,
|
||||
group_id=self.group_id,
|
||||
memory_config=self.memory_config,
|
||||
storage_type=self.storage_type,
|
||||
user_rag_memory_id=self.user_rag_memory_id,
|
||||
)
|
||||
logger.debug(
|
||||
f"[ToolExecutionNode] {self.id} - Built tool args with keys: {list(tool_args.keys())}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[ToolExecutionNode] {self.id} - Failed to build tool args: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {"messages": [AIMessage(content=f"Error building arguments: {str(e)}")]}
|
||||
|
||||
# Construct tool input message
|
||||
tool_input = {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{
|
||||
"name": self.tool_name,
|
||||
"args": tool_args,
|
||||
"id": f"{self.id}_{tool_call_id}",
|
||||
}]
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
try:
|
||||
# Invoke the tool
|
||||
result = await self.tool_node.ainvoke(tool_input)
|
||||
|
||||
logger.debug(
|
||||
f"[ToolExecutionNode] {self.id} - Tool execution completed"
|
||||
)
|
||||
|
||||
# Check for error in tool response
|
||||
error_entry = None
|
||||
if result and "messages" in result:
|
||||
for msg in result["messages"]:
|
||||
if hasattr(msg, 'content'):
|
||||
try:
|
||||
import json
|
||||
content = msg.content
|
||||
if isinstance(content, str):
|
||||
parsed = json.loads(content)
|
||||
if isinstance(parsed, dict) and "error" in parsed:
|
||||
error_msg = parsed["error"]
|
||||
logger.warning(
|
||||
f"[ToolExecutionNode] {self.id} - Tool returned error: {error_msg}"
|
||||
)
|
||||
error_entry = {"tool": self.tool_name, "error": error_msg, "node_id": self.id}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# Return result with error tracking if error was found
|
||||
if error_entry:
|
||||
result["errors"] = [error_entry]
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[ToolExecutionNode] {self.id} - Tool execution failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Track error in state and return error message
|
||||
from langchain_core.messages import ToolMessage
|
||||
error_entry = {"tool": self.tool_name, "error": str(e), "node_id": self.id}
|
||||
return {
|
||||
"messages": [
|
||||
ToolMessage(
|
||||
content=f"Error executing tool: {str(e)}",
|
||||
tool_call_id=f"{self.id}_{tool_call_id}"
|
||||
)
|
||||
],
|
||||
"errors": [error_entry]
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.models.verification_models import VerificationResult
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
PROJECT_ROOT_,
|
||||
ReadState,
|
||||
)
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.db import get_db_context
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class VerificationNodeService(LLMServiceMixin):
|
||||
"""验证节点服务类"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
verification_service = VerificationNodeService()
|
||||
|
||||
|
||||
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||
"""处理验证结果并生成输出格式"""
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
data = state.get('data', '')
|
||||
|
||||
# 将 VerificationItem 对象转换为字典列表
|
||||
verified_data = []
|
||||
if messages_deal.expansion_issue:
|
||||
for item in messages_deal.expansion_issue:
|
||||
if hasattr(item, 'model_dump'):
|
||||
verified_data.append(item.model_dump())
|
||||
elif isinstance(item, dict):
|
||||
verified_data.append(item)
|
||||
|
||||
Verify_result = {
|
||||
"status": messages_deal.split_result,
|
||||
"verified_data": verified_data,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "verification",
|
||||
"title": "Data Verification",
|
||||
"result": messages_deal.split_result,
|
||||
"reason": messages_deal.reason or "验证完成",
|
||||
"query": messages_deal.query,
|
||||
"verified_count": len(verified_data),
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
return Verify_result
|
||||
|
||||
|
||||
async def Verify(state: ReadState):
|
||||
logger.info("=== Verify 节点开始执行 ===")
|
||||
try:
|
||||
content = state.get('data', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
|
||||
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
|
||||
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
||||
|
||||
retrieve = state.get("retrieve", {})
|
||||
logger.info(
|
||||
f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
|
||||
|
||||
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
|
||||
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
|
||||
|
||||
messages = {
|
||||
"Query": content,
|
||||
"Expansion_issue": retrieve_expansion
|
||||
}
|
||||
|
||||
logger.info("Verify: 开始渲染模板")
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
json_schema = VerificationResult.model_json_schema()
|
||||
|
||||
system_prompt = await verification_service.template_service.render_template(
|
||||
template_name='split_verify_prompt.jinja2',
|
||||
operation_name='split_verify_prompt',
|
||||
history=history,
|
||||
sentence=messages,
|
||||
json_schema=json_schema
|
||||
)
|
||||
logger.info(f"Verify: 模板渲染完成,prompt length={len(system_prompt)}")
|
||||
|
||||
# 使用优化的LLM服务,添加超时保护
|
||||
logger.info("Verify: 开始调用 LLM")
|
||||
try:
|
||||
# 添加 asyncio.wait_for 超时包裹,防止无限等待
|
||||
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
|
||||
|
||||
with get_db_context() as db_session:
|
||||
structured = await asyncio.wait_for(
|
||||
verification_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=VerificationResult,
|
||||
fallback_value={
|
||||
"query": content,
|
||||
"history": history if isinstance(history, list) else [],
|
||||
"expansion_issue": [],
|
||||
"split_result": "failed",
|
||||
"reason": "验证失败或超时"
|
||||
}
|
||||
),
|
||||
timeout=150.0 # 150秒超时
|
||||
)
|
||||
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Verify: LLM 调用超时(150秒),使用 fallback 值")
|
||||
structured = VerificationResult(
|
||||
query=content,
|
||||
history=history if isinstance(history, list) else [],
|
||||
expansion_issue=[],
|
||||
split_result="failed",
|
||||
reason="LLM调用超时"
|
||||
)
|
||||
|
||||
result = await Verify_prompt(state, structured)
|
||||
logger.info("=== Verify 节点执行完成 ===")
|
||||
return {"verify": result}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Verify 节点执行失败: {e}", exc_info=True)
|
||||
# 返回失败的验证结果
|
||||
return {
|
||||
"verify": {
|
||||
"status": "failed",
|
||||
"verified_data": [],
|
||||
"storage_type": state.get('storage_type', ''),
|
||||
"user_rag_memory_id": state.get('user_rag_memory_id', ''),
|
||||
"_intermediate": {
|
||||
"type": "verification",
|
||||
"title": "Data Verification",
|
||||
"result": "failed",
|
||||
"reason": f"验证过程出错: {str(e)}",
|
||||
"query": state.get('data', ''),
|
||||
"verified_count": 0,
|
||||
"storage_type": state.get('storage_type', ''),
|
||||
"user_rag_memory_id": state.get('user_rag_memory_id', '')
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
from app.cache.memory.interest_memory import InterestMemoryCache
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.utils.write_tools import write
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def write_node(state: WriteState) -> WriteState:
|
||||
"""
|
||||
Write data to the database/file system.
|
||||
|
||||
Args:
|
||||
state: WriteState containing messages, end_user_id, memory_config, and language
|
||||
|
||||
Returns:
|
||||
dict: Contains 'write_result' with status and data fields
|
||||
"""
|
||||
messages = state.get('messages', [])
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', '')
|
||||
language = state.get('language', 'zh') # 默认中文
|
||||
|
||||
# Convert LangChain messages to structured format expected by write()
|
||||
structured_messages = []
|
||||
for msg in messages:
|
||||
if hasattr(msg, 'type') and hasattr(msg, 'content'):
|
||||
# Map LangChain message types to role names
|
||||
role = 'user' if msg.type == 'human' else 'assistant' if msg.type == 'ai' else msg.type
|
||||
structured_messages.append({
|
||||
"role": role,
|
||||
"content": msg.content # content is now guaranteed to be a string
|
||||
})
|
||||
|
||||
try:
|
||||
result = await write(
|
||||
messages=structured_messages,
|
||||
end_user_id=end_user_id,
|
||||
memory_config=memory_config,
|
||||
language=language,
|
||||
)
|
||||
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
||||
|
||||
# 写入 neo4j 成功后,删除该用户的兴趣分布缓存,确保下次请求重新生成
|
||||
for lang in ["zh", "en"]:
|
||||
deleted = await InterestMemoryCache.delete_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
language=lang,
|
||||
)
|
||||
if deleted:
|
||||
logger.info(f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
|
||||
|
||||
write_result = {
|
||||
"status": "success",
|
||||
"data": structured_messages,
|
||||
"config_id": memory_config.config_id,
|
||||
"config_name": memory_config.config_name,
|
||||
}
|
||||
return {"write_result": write_result}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Data_write failed: {e}", exc_info=True)
|
||||
write_result = {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
}
|
||||
return {"write_result": write_result}
|
||||
@@ -1,469 +1,179 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import warnings
|
||||
#!/usr/bin/env python3
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Literal
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.langgraph_graph.nodes import (
|
||||
ToolExecutionNode,
|
||||
create_input_message,
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
|
||||
from app.core.memory.agent.utils.llm_tools import COUNTState, ReadState
|
||||
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from dotenv import load_dotenv
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.constants import END, START
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.constants import START, END
|
||||
from langgraph.graph import StateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
from app.db import get_db
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
load_dotenv()
|
||||
redishost=os.getenv("REDISHOST")
|
||||
redisport=os.getenv('REDISPORT')
|
||||
redisdb=os.getenv('REDISDB')
|
||||
redispassword=os.getenv('REDISPASSWORD')
|
||||
counter = COUNTState(limit=3)
|
||||
|
||||
# Update loop count in workflow
|
||||
async def update_loop_count(state):
|
||||
"""Update loop counter"""
|
||||
current_count = state.get("loop_count", 0)
|
||||
return {"loop_count": current_count + 1}
|
||||
|
||||
|
||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||
messages = state["messages"]
|
||||
|
||||
# Add boundary check
|
||||
if not messages:
|
||||
return END
|
||||
counter.add(1) # Increment by 1
|
||||
|
||||
loop_count = counter.get_total()
|
||||
logger.debug(f"[should_continue] Current loop count: {loop_count}")
|
||||
|
||||
last_message = messages[-1]
|
||||
last_message_str = str(last_message).replace('\\', '')
|
||||
status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str)
|
||||
logger.debug(f"Status tools: {status_tools}")
|
||||
|
||||
if "success" in status_tools:
|
||||
counter.reset()
|
||||
return "Summary"
|
||||
elif "failed" in status_tools:
|
||||
if loop_count < 2: # Maximum loop count is 3
|
||||
return "content_input"
|
||||
else:
|
||||
counter.reset()
|
||||
return "Summary_fails"
|
||||
else:
|
||||
# Add default return value to avoid returning None
|
||||
counter.reset()
|
||||
return "Summary" # Default based on business requirements
|
||||
|
||||
|
||||
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
|
||||
Args:
|
||||
state: State dictionary containing search_switch
|
||||
|
||||
Returns:
|
||||
Next node to execute
|
||||
"""
|
||||
# Direct dictionary access instead of regex parsing
|
||||
search_switch = state.get("search_switch")
|
||||
|
||||
# Handle case where search_switch might be in messages
|
||||
if search_switch is None and "messages" in state:
|
||||
messages = state.get("messages", [])
|
||||
if messages:
|
||||
last_message = messages[-1]
|
||||
# Try to extract from tool_calls args
|
||||
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
|
||||
for tool_call in last_message.tool_calls:
|
||||
if isinstance(tool_call, dict) and "args" in tool_call:
|
||||
search_switch = tool_call["args"].get("search_switch")
|
||||
break
|
||||
|
||||
# Convert to string for comparison if needed
|
||||
if search_switch is not None:
|
||||
search_switch = str(search_switch)
|
||||
if search_switch == '0':
|
||||
return 'Verify'
|
||||
elif search_switch == '1':
|
||||
return 'Retrieve_Summary'
|
||||
|
||||
# Add default return value to avoid returning None
|
||||
return 'Retrieve_Summary' # Default based on business logic
|
||||
|
||||
|
||||
def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
|
||||
Args:
|
||||
state: State dictionary containing search_switch
|
||||
|
||||
Returns:
|
||||
Next node to execute
|
||||
"""
|
||||
logger.debug(f"Split_continue state: {state}")
|
||||
|
||||
# Direct dictionary access instead of regex parsing
|
||||
search_switch = state.get("search_switch")
|
||||
|
||||
# Handle case where search_switch might be in messages
|
||||
if search_switch is None and "messages" in state:
|
||||
messages = state.get("messages", [])
|
||||
if messages:
|
||||
last_message = messages[-1]
|
||||
# Try to extract from tool_calls args
|
||||
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
|
||||
for tool_call in last_message.tool_calls:
|
||||
if isinstance(tool_call, dict) and "args" in tool_call:
|
||||
search_switch = tool_call["args"].get("search_switch")
|
||||
break
|
||||
|
||||
# Convert to string for comparison if needed
|
||||
if search_switch is not None:
|
||||
search_switch = str(search_switch)
|
||||
if search_switch == '2':
|
||||
return 'Input_Summary'
|
||||
return 'Split_The_Problem' # Default case
|
||||
|
||||
|
||||
class ProblemExtensionNode:
|
||||
def __init__(self, tool, id, namespace, search_switch, apply_id, group_id, storage_type="", user_rag_memory_id=""):
|
||||
self.tool_node = ToolNode([tool])
|
||||
self.id = id
|
||||
self.tool_name = tool.name if hasattr(tool, 'name') else str(tool)
|
||||
self.namespace = namespace
|
||||
self.search_switch = search_switch
|
||||
self.apply_id = apply_id
|
||||
self.group_id = group_id
|
||||
self.storage_type = storage_type
|
||||
self.user_rag_memory_id = user_rag_memory_id
|
||||
|
||||
async def __call__(self, state):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1] if messages else ""
|
||||
logger.debug(f"ProblemExtensionNode {self.id} - Current time: {time.time()} - Message: {last_message}")
|
||||
if self.tool_name == 'Input_Summary':
|
||||
tool_call = re.findall("'id': '(.*?)'", str(last_message))[0]
|
||||
else:
|
||||
tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1]
|
||||
|
||||
# Try to extract actual content payload from previous tool result
|
||||
raw_msg = last_message.content if hasattr(last_message, 'content') else str(last_message)
|
||||
extracted_payload = None
|
||||
# Capture ToolMessage content field (supports single/double quotes), avoid greedy matching
|
||||
m = re.search(r"content=(?:\"|\')(.*?)(?:\"|\'),\s*name=", raw_msg, flags=re.S)
|
||||
if m:
|
||||
extracted_payload = m.group(1)
|
||||
else:
|
||||
# Fallback: use raw string directly
|
||||
extracted_payload = raw_msg
|
||||
|
||||
# Try to parse content as JSON first
|
||||
try:
|
||||
content = json.loads(extracted_payload)
|
||||
except Exception:
|
||||
# Try to extract JSON fragment from text and parse
|
||||
parsed = None
|
||||
candidates = re.findall(r"[\[{].*[\]}]", extracted_payload, flags=re.S)
|
||||
for cand in candidates:
|
||||
try:
|
||||
parsed = json.loads(cand)
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
# If still fails, use raw string as content
|
||||
content = parsed if parsed is not None else extracted_payload
|
||||
|
||||
# Build correct parameters based on tool name
|
||||
tool_args = {}
|
||||
|
||||
if self.tool_name == "Verify":
|
||||
# Verify tool requires context and usermessages parameters
|
||||
if isinstance(content, dict):
|
||||
tool_args["context"] = content
|
||||
else:
|
||||
tool_args["context"] = {"content": content}
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name == "Retrieve":
|
||||
# Retrieve tool requires context and usermessages parameters
|
||||
if isinstance(content, dict):
|
||||
tool_args["context"] = content
|
||||
else:
|
||||
tool_args["context"] = {"content": content}
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["search_switch"] = str(self.search_switch)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name == "Summary":
|
||||
# Summary tool requires string type context parameter
|
||||
if isinstance(content, dict):
|
||||
# Convert dict to JSON string
|
||||
tool_args["context"] = json.dumps(content, ensure_ascii=False)
|
||||
else:
|
||||
tool_args["context"] = str(content)
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name == "Summary_fails":
|
||||
# Summary_fails tool requires string type context parameter
|
||||
if isinstance(content, dict):
|
||||
# Convert dict to JSON string
|
||||
tool_args["context"] = json.dumps(content, ensure_ascii=False)
|
||||
else:
|
||||
tool_args["context"] = str(content)
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name == 'Input_Summary':
|
||||
tool_args["context"] = str(last_message)
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["search_switch"] = str(self.search_switch)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
tool_args["storage_type"] = getattr(self, 'storage_type', "")
|
||||
tool_args["user_rag_memory_id"] = getattr(self, 'user_rag_memory_id', "")
|
||||
elif self.tool_name == 'Retrieve_Summary':
|
||||
# Retrieve_Summary expects dict directly, not JSON string
|
||||
# content might be a JSON string, try to parse it
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
parsed_content = json.loads(content)
|
||||
# Check if it has a "context" key
|
||||
if isinstance(parsed_content, dict) and "context" in parsed_content:
|
||||
tool_args["context"] = parsed_content["context"]
|
||||
else:
|
||||
tool_args["context"] = parsed_content
|
||||
except json.JSONDecodeError:
|
||||
# If parsing fails, wrap the string
|
||||
tool_args["context"] = {"content": content}
|
||||
elif isinstance(content, dict):
|
||||
# Check if content has a "context" key that needs unwrapping
|
||||
if "context" in content:
|
||||
tool_args["context"] = content["context"]
|
||||
else:
|
||||
tool_args["context"] = content
|
||||
else:
|
||||
tool_args["context"] = {"content": str(content)}
|
||||
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
else:
|
||||
# Other tools use context parameter
|
||||
if isinstance(content, dict):
|
||||
tool_args["context"] = content
|
||||
else:
|
||||
tool_args["context"] = {"content": content}
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
|
||||
|
||||
tool_input = {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{
|
||||
"name": self.tool_name,
|
||||
"args": tool_args,
|
||||
"id": self.id + f"{tool_call}",
|
||||
}]
|
||||
)
|
||||
]
|
||||
}
|
||||
result = await self.tool_node.ainvoke(tool_input)
|
||||
result_text = str(result)
|
||||
|
||||
return {"messages": [AIMessage(content=result_text)]}
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState
|
||||
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_node
|
||||
from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
|
||||
Split_The_Problem,
|
||||
Problem_Extension,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
|
||||
retrieve,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
|
||||
Input_Summary,
|
||||
Retrieve_Summary,
|
||||
Summary_fails,
|
||||
Summary,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.nodes.verification_nodes import Verify
|
||||
from app.core.memory.agent.langgraph_graph.routing.routers import (
|
||||
Split_continue,
|
||||
Retrieve_continue,
|
||||
Verify_continue,
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_read_graph(namespace, tools, search_switch, apply_id, group_id, memory_config: MemoryConfig, storage_type=None, user_rag_memory_id=None):
|
||||
"""
|
||||
Create a read graph workflow for memory operations.
|
||||
|
||||
Args:
|
||||
namespace: Namespace identifier
|
||||
tools: MCP tools loaded from session
|
||||
search_switch: Search mode switch ("0", "1", or "2")
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type (optional)
|
||||
user_rag_memory_id: User RAG memory ID (optional)
|
||||
"""
|
||||
memory = InMemorySaver()
|
||||
tool = [i.name for i in tools]
|
||||
logger.info(f"Initializing read graph with tools: {tool}")
|
||||
logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})")
|
||||
|
||||
# Extract tool functions
|
||||
Split_The_Problem_ = next((t for t in tools if t.name == "Split_The_Problem"), None)
|
||||
Problem_Extension_ = next((t for t in tools if t.name == "Problem_Extension"), None)
|
||||
Retrieve_ = next((t for t in tools if t.name == "Retrieve"), None)
|
||||
Verify_ = next((t for t in tools if t.name == "Verify"), None)
|
||||
Summary_ = next((t for t in tools if t.name == "Summary"), None)
|
||||
Summary_fails_ = next((t for t in tools if t.name == "Summary_fails"), None)
|
||||
Retrieve_Summary_ = next((t for t in tools if t.name == "Retrieve_Summary"), None)
|
||||
Input_Summary_ = next((t for t in tools if t.name == "Input_Summary"), None)
|
||||
|
||||
# Instantiate services
|
||||
parameter_builder = ParameterBuilder()
|
||||
multimodal_processor = MultimodalProcessor()
|
||||
|
||||
# Create nodes using new modular components
|
||||
Split_The_Problem_node = ToolNode([Split_The_Problem_])
|
||||
|
||||
Problem_Extension_node = ToolExecutionNode(
|
||||
tool=Problem_Extension_,
|
||||
node_id="Problem_Extension_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
async def make_read_graph():
|
||||
"""创建并返回 LangGraph 工作流"""
|
||||
try:
|
||||
# Build workflow graph
|
||||
workflow = StateGraph(ReadState)
|
||||
workflow.add_node("content_input", content_input_node)
|
||||
workflow.add_node("Split_The_Problem", Split_The_Problem)
|
||||
workflow.add_node("Problem_Extension", Problem_Extension)
|
||||
workflow.add_node("Input_Summary", Input_Summary)
|
||||
# workflow.add_node("Retrieve", retrieve_nodes)
|
||||
workflow.add_node("Retrieve", retrieve)
|
||||
workflow.add_node("Verify", Verify)
|
||||
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
|
||||
workflow.add_node("Summary", Summary)
|
||||
workflow.add_node("Summary_fails", Summary_fails)
|
||||
|
||||
# 添加边
|
||||
workflow.add_edge(START, "content_input")
|
||||
workflow.add_conditional_edges("content_input", Split_continue)
|
||||
workflow.add_edge("Input_Summary", END)
|
||||
workflow.add_edge("Split_The_Problem", "Problem_Extension")
|
||||
workflow.add_edge("Problem_Extension", "Retrieve")
|
||||
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
||||
workflow.add_edge("Retrieve_Summary", END)
|
||||
workflow.add_conditional_edges("Verify", Verify_continue)
|
||||
workflow.add_edge("Summary_fails", END)
|
||||
workflow.add_edge("Summary", END)
|
||||
|
||||
'''-----'''
|
||||
# workflow.add_edge("Retrieve", END)
|
||||
|
||||
# 编译工作流
|
||||
graph = workflow.compile()
|
||||
yield graph
|
||||
|
||||
except Exception as e:
|
||||
print(f"创建工作流失败: {e}")
|
||||
raise
|
||||
finally:
|
||||
print("工作流创建完成")
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数 - 运行工作流"""
|
||||
message = "昨天有什么好看的电影"
|
||||
end_user_id = '88a459f5_text09' # 组ID
|
||||
storage_type = 'neo4j' # 存储类型
|
||||
search_switch = '1' # 搜索开关
|
||||
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
|
||||
|
||||
# 获取数据库会话
|
||||
db_session = next(get_db())
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=17, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
import time
|
||||
start = time.time()
|
||||
try:
|
||||
async with make_read_graph() as graph:
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
|
||||
"end_user_id": end_user_id
|
||||
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
|
||||
"memory_config": memory_config}
|
||||
# 获取节点更新信息
|
||||
_intermediate_outputs = []
|
||||
summary = ''
|
||||
|
||||
Retrieve_node = ToolExecutionNode(
|
||||
tool=Retrieve_,
|
||||
node_id="Retrieve_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
print(f"处理节点: {node_name}")
|
||||
|
||||
Verify_node = ToolExecutionNode(
|
||||
tool=Verify_,
|
||||
node_id="Verify_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
|
||||
Summary_node = ToolExecutionNode(
|
||||
tool=Summary_,
|
||||
node_id="Summary_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
# 处理不同Summary节点的返回结构
|
||||
if 'Summary' in node_name:
|
||||
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
|
||||
summary = node_data['InputSummary']['summary_result']
|
||||
elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']:
|
||||
summary = node_data['RetrieveSummary']['summary_result']
|
||||
elif 'summary' in node_data and 'summary_result' in node_data['summary']:
|
||||
summary = node_data['summary']['summary_result']
|
||||
elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']:
|
||||
summary = node_data['SummaryFails']['summary_result']
|
||||
|
||||
Summary_fails_node = ToolExecutionNode(
|
||||
tool=Summary_fails_,
|
||||
node_id="Summary_fails_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
|
||||
if spit_data and spit_data != [] and spit_data != {}:
|
||||
_intermediate_outputs.append(spit_data)
|
||||
|
||||
Retrieve_Summary_node = ToolExecutionNode(
|
||||
tool=Retrieve_Summary_,
|
||||
node_id="Retrieve_Summary_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
# Problem_Extension 节点
|
||||
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
|
||||
if problem_extension and problem_extension != [] and problem_extension != {}:
|
||||
_intermediate_outputs.append(problem_extension)
|
||||
|
||||
Input_Summary_node = ToolExecutionNode(
|
||||
tool=Input_Summary_,
|
||||
node_id="Input_Summary_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
# Retrieve 节点
|
||||
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
|
||||
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
|
||||
_intermediate_outputs.extend(retrieve_node)
|
||||
|
||||
async def content_input_node(state):
|
||||
state_search_switch = state.get("search_switch", search_switch)
|
||||
# Verify 节点
|
||||
verify_n = node_data.get('verify', {}).get('_intermediate', None)
|
||||
if verify_n and verify_n != [] and verify_n != {}:
|
||||
_intermediate_outputs.append(verify_n)
|
||||
|
||||
tool_name = "Input_Summary" if state_search_switch == '2' else "Split_The_Problem"
|
||||
session_prefix = "input_summary_call_id" if state_search_switch == '2' else "split_call_id"
|
||||
# Summary 节点
|
||||
summary_n = node_data.get('summary', {}).get('_intermediate', None)
|
||||
if summary_n and summary_n != [] and summary_n != {}:
|
||||
_intermediate_outputs.append(summary_n)
|
||||
|
||||
return await create_input_message(
|
||||
state=state,
|
||||
tool_name=tool_name,
|
||||
session_id=f"{session_prefix}_{namespace}",
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
multimodal_processor=multimodal_processor,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
# # 过滤掉空值
|
||||
# _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
|
||||
#
|
||||
# # 优化搜索结果
|
||||
# print("=== 开始优化搜索结果 ===")
|
||||
# optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
|
||||
# result=reorder_output_results(optimized_outputs)
|
||||
# # 保存优化后的结果到文件
|
||||
# with open('_intermediate_outputs_optimized.json', 'w', encoding='utf-8') as f:
|
||||
# import json
|
||||
# f.write(json.dumps(result, indent=4, ensure_ascii=False))
|
||||
#
|
||||
print(f"=== 最终摘要 ===")
|
||||
print(summary)
|
||||
|
||||
|
||||
# Build workflow graph
|
||||
workflow = StateGraph(ReadState)
|
||||
workflow.add_node("content_input", content_input_node)
|
||||
workflow.add_node("Split_The_Problem", Split_The_Problem_node)
|
||||
workflow.add_node("Problem_Extension", Problem_Extension_node)
|
||||
workflow.add_node("Retrieve", Retrieve_node)
|
||||
workflow.add_node("Verify", Verify_node)
|
||||
workflow.add_node("Summary", Summary_node)
|
||||
workflow.add_node("Summary_fails", Summary_fails_node)
|
||||
workflow.add_node("Retrieve_Summary", Retrieve_Summary_node)
|
||||
workflow.add_node("Input_Summary", Input_Summary_node)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
db_session.close()
|
||||
|
||||
# Add edges using imported routers
|
||||
workflow.add_edge(START, "content_input")
|
||||
workflow.add_conditional_edges("content_input", Split_continue)
|
||||
workflow.add_edge("Input_Summary", END)
|
||||
workflow.add_edge("Split_The_Problem", "Problem_Extension")
|
||||
workflow.add_edge("Problem_Extension", "Retrieve")
|
||||
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
||||
workflow.add_edge("Retrieve_Summary", END)
|
||||
workflow.add_conditional_edges("Verify", Verify_continue)
|
||||
workflow.add_edge("Summary_fails", END)
|
||||
workflow.add_edge("Summary", END)
|
||||
end = time.time()
|
||||
print(100 * 'y')
|
||||
print(f"总耗时: {end - start}s")
|
||||
print(100 * 'y')
|
||||
|
||||
graph = workflow.compile(checkpointer=memory)
|
||||
yield graph
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
"""LangGraph routing logic."""
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.routing.routers import (
|
||||
Verify_continue,
|
||||
Retrieve_continue,
|
||||
Split_continue,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Verify_continue",
|
||||
"Retrieve_continue",
|
||||
"Split_continue",
|
||||
]
|
||||
@@ -1,123 +1,61 @@
|
||||
"""
|
||||
Routing functions for LangGraph conditional edges.
|
||||
|
||||
This module provides routing functions that determine the next node to execute
|
||||
based on state values. All functions return Literal types for type safety.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.state.extractors import extract_search_switch
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global counter for Verify routing
|
||||
logger = get_agent_logger(__name__)
|
||||
counter = COUNTState(limit=3)
|
||||
def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
|
||||
Args:
|
||||
state: State dictionary containing search_switch
|
||||
|
||||
Returns:
|
||||
Next node to execute
|
||||
"""
|
||||
logger.debug(f"Split_continue state: {state}")
|
||||
search_switch = state.get('search_switch', '')
|
||||
if search_switch is not None:
|
||||
search_switch = str(search_switch)
|
||||
if search_switch == '2':
|
||||
return 'Input_Summary'
|
||||
return 'Split_The_Problem' # 默认情况
|
||||
|
||||
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
|
||||
Args:
|
||||
state: State dictionary containing search_switch
|
||||
|
||||
Returns:
|
||||
Next node to execute
|
||||
"""
|
||||
search_switch = state.get('search_switch', '')
|
||||
if search_switch is not None:
|
||||
search_switch = str(search_switch)
|
||||
if search_switch == '0':
|
||||
return 'Verify'
|
||||
elif search_switch == '1':
|
||||
return 'Retrieve_Summary'
|
||||
return 'Retrieve_Summary' # Default based on business logic
|
||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||
"""
|
||||
Determine routing after Verify node based on verification result.
|
||||
|
||||
This function checks the verification result in the last message and routes to:
|
||||
- Summary: if verification succeeded
|
||||
- content_input: if verification failed and retry limit not reached
|
||||
- Summary_fails: if verification failed and retry limit reached
|
||||
|
||||
Args:
|
||||
state: LangGraph state containing messages
|
||||
|
||||
Returns:
|
||||
Next node name as Literal type
|
||||
"""
|
||||
messages = state.get("messages", [])
|
||||
|
||||
# Boundary check
|
||||
if not messages:
|
||||
logger.warning("[Verify_continue] No messages in state, defaulting to Summary")
|
||||
counter.reset()
|
||||
status=state.get('verify', '')['status']
|
||||
# loop_count = counter.get_total()
|
||||
if "success" in status:
|
||||
# counter.reset()
|
||||
return "Summary"
|
||||
|
||||
# Increment counter
|
||||
counter.add(1)
|
||||
loop_count = counter.get_total()
|
||||
logger.debug(f"[Verify_continue] Current loop count: {loop_count}")
|
||||
|
||||
# Extract verification result from last message
|
||||
last_message = messages[-1]
|
||||
last_message_str = str(last_message).replace('\\', '')
|
||||
status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str)
|
||||
logger.debug(f"[Verify_continue] Status tools: {status_tools}")
|
||||
|
||||
# Route based on verification result
|
||||
if "success" in status_tools:
|
||||
counter.reset()
|
||||
return "Summary"
|
||||
elif "failed" in status_tools:
|
||||
if loop_count < 2: # Max retry count is 2
|
||||
return "content_input"
|
||||
else:
|
||||
counter.reset()
|
||||
return "Summary_fails"
|
||||
elif "failed" in status:
|
||||
# if loop_count < 2: # Maximum loop count is 3
|
||||
# return "content_input"
|
||||
# else:
|
||||
# counter.reset()
|
||||
return "Summary_fails"
|
||||
else:
|
||||
# Default to Summary if status is unclear
|
||||
counter.reset()
|
||||
return "Summary"
|
||||
|
||||
|
||||
def Retrieve_continue(state: dict) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
"""
|
||||
Determine routing after Retrieve node based on search_switch value.
|
||||
|
||||
This function routes based on the search_switch parameter:
|
||||
- search_switch == '0': Route to Verify (verification needed)
|
||||
- search_switch == '1': Route to Retrieve_Summary (direct summary)
|
||||
|
||||
Args:
|
||||
state: LangGraph state dictionary
|
||||
|
||||
Returns:
|
||||
Next node name as Literal type
|
||||
"""
|
||||
search_switch = extract_search_switch(state)
|
||||
|
||||
logger.debug(f"[Retrieve_continue] search_switch: {search_switch}")
|
||||
|
||||
if search_switch == '0':
|
||||
return 'Verify'
|
||||
elif search_switch == '1':
|
||||
return 'Retrieve_Summary'
|
||||
|
||||
# Default to Retrieve_Summary
|
||||
logger.debug("[Retrieve_continue] No valid search_switch, defaulting to Retrieve_Summary")
|
||||
return 'Retrieve_Summary'
|
||||
|
||||
|
||||
def Split_continue(state: dict) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
"""
|
||||
Determine routing after content_input node based on search_switch value.
|
||||
|
||||
This function routes based on the search_switch parameter:
|
||||
- search_switch == '2': Route to Input_Summary (direct input summary)
|
||||
- Otherwise: Route to Split_The_Problem (problem decomposition)
|
||||
|
||||
Args:
|
||||
state: LangGraph state dictionary
|
||||
|
||||
Returns:
|
||||
Next node name as Literal type
|
||||
"""
|
||||
logger.debug(f"[Split_continue] state keys: {state.keys()}")
|
||||
|
||||
search_switch = extract_search_switch(state)
|
||||
|
||||
logger.debug(f"[Split_continue] search_switch: {search_switch}")
|
||||
|
||||
if search_switch == '2':
|
||||
return 'Input_Summary'
|
||||
|
||||
# Default to Split_The_Problem
|
||||
return 'Split_The_Problem'
|
||||
# Add default return value to avoid returning None
|
||||
# counter.reset()
|
||||
return "Summary" # Default based on business requirements
|
||||
|
||||
@@ -0,0 +1,238 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
|
||||
|
||||
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.core.memory.agent.utils.redis_tool import count_store
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context, get_db
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
logger = get_agent_logger(__name__)
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
|
||||
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
|
||||
actual_config_id, long_term_messages=[]):
|
||||
"""
|
||||
写入记忆(支持结构化消息)
|
||||
|
||||
Args:
|
||||
storage_type: 存储类型 (neo4j/rag)
|
||||
end_user_id: 终端用户ID
|
||||
user_message: 用户消息内容
|
||||
ai_message: AI 回复内容
|
||||
user_rag_memory_id: RAG 记忆ID
|
||||
actual_end_user_id: 实际用户ID
|
||||
actual_config_id: 配置ID
|
||||
|
||||
逻辑说明:
|
||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||
- Neo4j 模式:使用结构化消息列表
|
||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||
"""
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
actual_config_id = resolve_config_id(actual_config_id, db)
|
||||
# Neo4j 模式:使用结构化消息列表
|
||||
structured_messages = []
|
||||
|
||||
# 始终添加用户消息(如果不为空)
|
||||
if isinstance(user_message, str) and user_message.strip() != "":
|
||||
structured_messages.append({"role": "user", "content": user_message})
|
||||
|
||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||
if isinstance(ai_message, str) and ai_message.strip() != "":
|
||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||
|
||||
# 如果提供了 long_term_messages,使用它替代 structured_messages
|
||||
if long_term_messages and isinstance(long_term_messages, list):
|
||||
structured_messages = long_term_messages
|
||||
elif long_term_messages and isinstance(long_term_messages, str):
|
||||
# 如果是 JSON 字符串,先解析
|
||||
try:
|
||||
structured_messages = json.loads(long_term_messages)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
|
||||
|
||||
# 如果没有消息,直接返回
|
||||
if not structured_messages:
|
||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||
write_id = write_message_task.delay(
|
||||
actual_end_user_id, # end_user_id: 用户ID
|
||||
structured_messages, # message: JSON 字符串格式的消息列表
|
||||
str(actual_config_id), # config_id: 配置ID字符串
|
||||
storage_type, # storage_type: "neo4j"
|
||||
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||
)
|
||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
|
||||
with get_db_context() as db_session:
|
||||
repo = LongTermMemoryRepository(db_session)
|
||||
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
result = write_store.get_session_by_userid(end_user_id)
|
||||
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
data = await format_parsing(result, "dict")
|
||||
chunk_data = data[:scope]
|
||||
if len(chunk_data)==scope:
|
||||
repo.upsert(end_user_id, chunk_data)
|
||||
logger.info(f'---------写入短长期-----------')
|
||||
else:
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
||||
long_messages = await messages_parse(long_time_data)
|
||||
repo.upsert(end_user_id, long_messages)
|
||||
logger.info(f'写入短长期:')
|
||||
|
||||
|
||||
|
||||
'''根据窗口'''
|
||||
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||
'''
|
||||
根据窗口获取redis数据,写入neo4j:
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_config: 内存配置对象
|
||||
langchain_messages:原始数据LIST
|
||||
scope:窗口大小
|
||||
'''
|
||||
scope=scope
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_id is not False:
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||
redis_messages = count_store.get_sessions_count(end_user_id)[1]
|
||||
if is_end_user_id and int(is_end_user_id) != int(scope):
|
||||
is_end_user_id += 1
|
||||
langchain_messages += redis_messages
|
||||
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
||||
elif int(is_end_user_id) == int(scope):
|
||||
logger.info('写入长期记忆NEO4J')
|
||||
formatted_messages = (redis_messages)
|
||||
# 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用)
|
||||
if hasattr(memory_config, 'config_id'):
|
||||
config_id = memory_config.config_id
|
||||
else:
|
||||
config_id = memory_config
|
||||
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
config_id, formatted_messages)
|
||||
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
||||
else:
|
||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||
|
||||
|
||||
"""根据时间"""
|
||||
async def memory_long_term_storage(end_user_id,memory_config,time):
|
||||
'''
|
||||
根据时间获取redis数据,写入neo4j:
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_config: 内存配置对象
|
||||
'''
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
||||
format_messages = (long_time_data)
|
||||
messages=[]
|
||||
memory_config=memory_config.config_id
|
||||
for i in format_messages:
|
||||
message=json.loads(i['Query'])
|
||||
messages+= message
|
||||
if format_messages!=[]:
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
memory_config, messages)
|
||||
'''聚合判断'''
|
||||
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
||||
"""
|
||||
聚合判断函数:判断输入句子和历史消息是否描述同一事件
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
memory_config: 内存配置对象
|
||||
"""
|
||||
|
||||
try:
|
||||
# 1. 获取历史会话数据(使用新方法)
|
||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
||||
history = await format_parsing(result)
|
||||
if not result:
|
||||
history = []
|
||||
else:
|
||||
history = await format_parsing(result)
|
||||
json_schema = WriteAggregateModel.model_json_schema()
|
||||
template_service = TemplateService(template_root)
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='write_aggregate_judgment.jinja2',
|
||||
operation_name='aggregate_judgment',
|
||||
history=history,
|
||||
sentence=ori_messages,
|
||||
json_schema=json_schema
|
||||
)
|
||||
with get_db_context() as db_session:
|
||||
factory = MemoryClientFactory(db_session)
|
||||
llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": system_prompt
|
||||
}
|
||||
]
|
||||
structured = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=WriteAggregateModel
|
||||
)
|
||||
output_value = structured.output
|
||||
if isinstance(output_value, list):
|
||||
output_value = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in output_value
|
||||
]
|
||||
|
||||
result_dict = {
|
||||
"is_same_event": structured.is_same_event,
|
||||
"output": output_value
|
||||
}
|
||||
if not structured.is_same_event:
|
||||
logger.info(result_dict)
|
||||
await write("neo4j", end_user_id, "", "", None, end_user_id,
|
||||
memory_config.config_id, output_value)
|
||||
return result_dict
|
||||
|
||||
except Exception as e:
|
||||
print(f"[aggregate_judgment] 发生错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
return {
|
||||
"is_same_event": False,
|
||||
"output": ori_messages,
|
||||
"messages": ori_messages,
|
||||
"history": history if 'history' in locals() else [],
|
||||
"error": str(e)
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
"""LangGraph state management utilities."""
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.state.extractors import (
|
||||
extract_search_switch,
|
||||
extract_tool_call_id,
|
||||
extract_content_payload,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"extract_search_switch",
|
||||
"extract_tool_call_id",
|
||||
"extract_content_payload",
|
||||
]
|
||||
@@ -1,179 +0,0 @@
|
||||
"""
|
||||
State extraction utilities for type-safe access to LangGraph state values.
|
||||
|
||||
This module provides utility functions for extracting values from LangGraph state
|
||||
dictionaries with proper error handling and sensible defaults.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def extract_search_switch(state: dict) -> Optional[str]:
|
||||
"""
|
||||
Extract search_switch from state or messages.
|
||||
"""
|
||||
|
||||
search_switch = state.get("search_switch")
|
||||
|
||||
if search_switch is not None:
|
||||
return str(search_switch)
|
||||
|
||||
# Try to extract from messages
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# 从最新的消息开始查找
|
||||
for message in reversed(messages):
|
||||
# 尝试从 tool_calls 中提取
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
# 从 tool_call 的 args 中提取
|
||||
if "args" in tool_call and isinstance(tool_call["args"], dict):
|
||||
search_switch = tool_call["args"].get("search_switch")
|
||||
if search_switch is not None:
|
||||
return str(search_switch)
|
||||
# 直接从 tool_call 中提取
|
||||
search_switch = tool_call.get("search_switch")
|
||||
if search_switch is not None:
|
||||
return str(search_switch)
|
||||
|
||||
# 尝试从 content 中提取(如果是 JSON 格式)
|
||||
if hasattr(message, "content"):
|
||||
try:
|
||||
import json
|
||||
if isinstance(message.content, str):
|
||||
content_data = json.loads(message.content)
|
||||
if isinstance(content_data, dict):
|
||||
search_switch = content_data.get("search_switch")
|
||||
if search_switch is not None:
|
||||
return str(search_switch)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_tool_call_id(message: Any) -> str:
|
||||
"""
|
||||
Extract tool call ID from message using structured attributes.
|
||||
|
||||
This function extracts the tool call ID from a message object, handling both
|
||||
direct attribute access and tool_calls list structures.
|
||||
|
||||
Args:
|
||||
message: Message object (typically ToolMessage or AIMessage)
|
||||
|
||||
Returns:
|
||||
Tool call ID as string
|
||||
|
||||
Raises:
|
||||
ValueError: If tool call ID cannot be extracted
|
||||
|
||||
Examples:
|
||||
>>> message = ToolMessage(content="...", tool_call_id="call_123")
|
||||
>>> extract_tool_call_id(message)
|
||||
'call_123'
|
||||
"""
|
||||
# Try direct attribute access for ToolMessage
|
||||
if hasattr(message, "tool_call_id"):
|
||||
tool_call_id = message.tool_call_id
|
||||
if tool_call_id:
|
||||
return str(tool_call_id)
|
||||
|
||||
# Try extracting from tool_calls list for AIMessage
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
tool_call = message.tool_calls[0]
|
||||
if isinstance(tool_call, dict) and "id" in tool_call:
|
||||
return str(tool_call["id"])
|
||||
|
||||
# Try extracting from id attribute
|
||||
if hasattr(message, "id"):
|
||||
message_id = message.id
|
||||
if message_id:
|
||||
return str(message_id)
|
||||
|
||||
# If all else fails, raise an error
|
||||
raise ValueError(f"Could not extract tool call ID from message: {type(message)}")
|
||||
|
||||
|
||||
def extract_content_payload(message: Any) -> Any:
|
||||
"""
|
||||
Extract content payload from ToolMessage, parsing JSON if needed.
|
||||
|
||||
This function extracts the content from a message and attempts to parse it as JSON
|
||||
if it appears to be a JSON string. It handles various message formats and provides
|
||||
sensible fallbacks.
|
||||
|
||||
Args:
|
||||
message: Message object (typically ToolMessage)
|
||||
|
||||
Returns:
|
||||
Parsed content (dict, list, or str)
|
||||
|
||||
Examples:
|
||||
>>> message = ToolMessage(content='{"key": "value"}')
|
||||
>>> extract_content_payload(message)
|
||||
{'key': 'value'}
|
||||
|
||||
>>> message = ToolMessage(content='plain text')
|
||||
>>> extract_content_payload(message)
|
||||
'plain text'
|
||||
"""
|
||||
# Extract raw content
|
||||
# For ToolMessages (responses from tools), extract from content
|
||||
if hasattr(message, "content"):
|
||||
raw_content = message.content
|
||||
logger.info(f"extract_content_payload: raw_content type={type(raw_content)}, value={str(raw_content)[:500]}")
|
||||
|
||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
||||
if isinstance(raw_content, list):
|
||||
for block in raw_content:
|
||||
if isinstance(block, dict) and block.get('type') == 'text':
|
||||
raw_content = block.get('text', '')
|
||||
logger.info(f"extract_content_payload: extracted text from MCP format: {str(raw_content)[:300]}")
|
||||
break
|
||||
|
||||
# If content is empty and this is an AIMessage with tool_calls,
|
||||
# extract from args (this handles the initial tool call from content_input)
|
||||
if not raw_content and hasattr(message, "tool_calls") and message.tool_calls:
|
||||
tool_call = message.tool_calls[0]
|
||||
if isinstance(tool_call, dict) and "args" in tool_call:
|
||||
return tool_call["args"]
|
||||
else:
|
||||
raw_content = str(message)
|
||||
|
||||
# If content is already a dict or list, return it directly
|
||||
if isinstance(raw_content, (dict, list)):
|
||||
logger.info(f"extract_content_payload: returning raw dict/list with keys={list(raw_content.keys()) if isinstance(raw_content, dict) else 'list'}")
|
||||
return raw_content
|
||||
|
||||
# Try to parse as JSON
|
||||
if isinstance(raw_content, str):
|
||||
# First, try direct JSON parsing
|
||||
try:
|
||||
parsed = json.loads(raw_content)
|
||||
logger.info(f"extract_content_payload: parsed JSON, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}")
|
||||
return parsed
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
# If that fails, try to extract JSON from the string
|
||||
# This handles cases where the content is embedded in a larger string
|
||||
import re
|
||||
json_candidates = re.findall(r'[\[{].*[\]}]', raw_content, flags=re.DOTALL)
|
||||
for candidate in json_candidates:
|
||||
try:
|
||||
parsed = json.loads(candidate)
|
||||
logger.info(f"extract_content_payload: parsed JSON from candidate, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}")
|
||||
return parsed
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
continue
|
||||
|
||||
# If all parsing attempts fail, return the raw content
|
||||
logger.info(f"extract_content_payload: returning raw content (parsing failed)")
|
||||
return raw_content
|
||||
321
api/app/core/memory/agent/langgraph_graph/tools/tool.py
Normal file
321
api/app/core/memory/agent/langgraph_graph/tools/tool.py
Normal file
@@ -0,0 +1,321 @@
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
from app.core.memory.src.search import (
|
||||
search_by_temporal,
|
||||
search_by_keyword_temporal,
|
||||
)
|
||||
|
||||
def extract_tool_message_content(response):
|
||||
"""从agent响应中提取ToolMessage内容和工具名称"""
|
||||
messages = response.get('messages', [])
|
||||
|
||||
for message in messages:
|
||||
if hasattr(message, 'tool_call_id') and hasattr(message, 'content'):
|
||||
# 这是一个ToolMessage
|
||||
tool_content = message.content
|
||||
tool_name = None
|
||||
|
||||
# 尝试获取工具名称
|
||||
if hasattr(message, 'name'):
|
||||
tool_name = message.name
|
||||
elif hasattr(message, 'tool_name'):
|
||||
tool_name = message.tool_name
|
||||
|
||||
try:
|
||||
# 解析JSON内容
|
||||
parsed_content = json.loads(tool_content)
|
||||
return {
|
||||
'tool_name': tool_name,
|
||||
'content': parsed_content
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
# 如果不是JSON格式,直接返回内容
|
||||
return {
|
||||
'tool_name': tool_name,
|
||||
'content': tool_content
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class TimeRetrievalInput(BaseModel):
|
||||
"""时间检索工具的输入模式"""
|
||||
context: str = Field(description="用户输入的查询内容")
|
||||
end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
||||
|
||||
def create_time_retrieval_tool(end_user_id: str):
|
||||
"""
|
||||
创建一个带有特定end_user_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
||||
"""
|
||||
|
||||
def clean_temporal_result_fields(data):
|
||||
"""
|
||||
清理时间搜索结果中不需要的字段,并修改结构
|
||||
|
||||
Args:
|
||||
data: 要清理的数据
|
||||
|
||||
Returns:
|
||||
清理后的数据
|
||||
"""
|
||||
# 需要过滤的字段列表
|
||||
fields_to_remove = {
|
||||
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
|
||||
'valid_at', 'invalid_at', 'statement_ids'
|
||||
}
|
||||
|
||||
if isinstance(data, dict):
|
||||
cleaned = {}
|
||||
for key, value in data.items():
|
||||
if key == 'statements' and isinstance(value, dict) and 'statements' in value:
|
||||
# 将 statements: {"statements": [...]} 改为 time_search: {"statements": [...]}
|
||||
cleaned_value = clean_temporal_result_fields(value)
|
||||
# 进一步将内部的 statements 改为 time_search
|
||||
if 'statements' in cleaned_value:
|
||||
cleaned['results'] = {
|
||||
'time_search': cleaned_value['statements']
|
||||
}
|
||||
else:
|
||||
cleaned['results'] = cleaned_value
|
||||
elif key not in fields_to_remove:
|
||||
cleaned[key] = clean_temporal_result_fields(value)
|
||||
return cleaned
|
||||
elif isinstance(data, list):
|
||||
return [clean_temporal_result_fields(item) for item in data]
|
||||
else:
|
||||
return data
|
||||
|
||||
@tool
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||
"""
|
||||
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||
显式接收参数:
|
||||
- context: 查询上下文内容
|
||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||
- end_user_id_param: 组ID(可选,用于覆盖默认组ID)
|
||||
- clean_output: 是否清理输出中的元数据字段
|
||||
-end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||
"""
|
||||
async def _async_search():
|
||||
# 使用传入的参数或默认值
|
||||
actual_end_user_id = end_user_id_param or end_user_id
|
||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||
|
||||
# 基本时间搜索
|
||||
results = await search_by_temporal(
|
||||
end_user_id=actual_end_user_id,
|
||||
start_date=actual_start_date,
|
||||
end_date=actual_end_date,
|
||||
limit=10
|
||||
)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
if clean_output:
|
||||
cleaned_results = clean_temporal_result_fields(results)
|
||||
else:
|
||||
cleaned_results = results
|
||||
|
||||
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
@tool
|
||||
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str:
|
||||
"""
|
||||
优化的关键词时间检索工具,结合关键词和时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||
显式接收参数:
|
||||
- context: 查询内容
|
||||
- days_back: 向前搜索的天数,默认7天
|
||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||
- clean_output: 是否清理输出中的元数据字段
|
||||
- end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||
"""
|
||||
async def _async_search():
|
||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||
actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
|
||||
|
||||
# 关键词时间搜索
|
||||
results = await search_by_keyword_temporal(
|
||||
query_text=context,
|
||||
end_user_id=end_user_id,
|
||||
start_date=actual_start_date,
|
||||
end_date=actual_end_date,
|
||||
limit=15
|
||||
)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
if clean_output:
|
||||
cleaned_results = clean_temporal_result_fields(results)
|
||||
else:
|
||||
cleaned_results = results
|
||||
|
||||
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
return TimeRetrievalWithGroupId
|
||||
|
||||
|
||||
def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
"""
|
||||
创建混合检索工具,使用run_hybrid_search进行混合检索,优化输出格式并过滤不需要的字段
|
||||
|
||||
Args:
|
||||
memory_config: 内存配置对象
|
||||
**search_params: 搜索参数,包含end_user_id, limit, include等
|
||||
"""
|
||||
|
||||
def clean_result_fields(data):
|
||||
"""
|
||||
递归清理结果中不需要的字段
|
||||
|
||||
Args:
|
||||
data: 要清理的数据(可能是字典、列表或其他类型)
|
||||
|
||||
Returns:
|
||||
清理后的数据
|
||||
"""
|
||||
# 需要过滤的字段列表
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
fields_to_remove = {
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
|
||||
}
|
||||
|
||||
if isinstance(data, dict):
|
||||
# 对字典进行清理
|
||||
cleaned = {}
|
||||
for key, value in data.items():
|
||||
if key not in fields_to_remove:
|
||||
cleaned[key] = clean_result_fields(value) # 递归清理嵌套数据
|
||||
return cleaned
|
||||
elif isinstance(data, list):
|
||||
# 对列表中的每个元素进行清理
|
||||
return [clean_result_fields(item) for item in data]
|
||||
else:
|
||||
# 其他类型直接返回
|
||||
return data
|
||||
|
||||
@tool
|
||||
async def HybridSearch(
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
end_user_id: str = None,
|
||||
rerank_alpha: float = 0.6,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
clean_output: bool = True # 新增:是否清理输出字段
|
||||
) -> str:
|
||||
"""
|
||||
优化的混合检索工具,支持关键词、向量和混合搜索,自动过滤不需要的元数据字段
|
||||
|
||||
Args:
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
end_user_id: 组ID,用于过滤搜索结果
|
||||
rerank_alpha: 重排序权重参数
|
||||
use_forgetting_rerank: 是否使用遗忘重排序
|
||||
use_llm_rerank: 是否使用LLM重排序
|
||||
clean_output: 是否清理输出中的元数据字段
|
||||
"""
|
||||
try:
|
||||
# 导入run_hybrid_search函数
|
||||
from app.core.memory.src.search import run_hybrid_search
|
||||
|
||||
# 合并参数,优先使用传入的参数
|
||||
final_params = {
|
||||
"query_text": context,
|
||||
"search_type": search_type,
|
||||
"end_user_id": end_user_id or search_params.get("end_user_id"),
|
||||
"limit": limit or search_params.get("limit", 10),
|
||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
|
||||
"output_path": None, # 不保存到文件
|
||||
"memory_config": memory_config,
|
||||
"rerank_alpha": rerank_alpha,
|
||||
"use_forgetting_rerank": use_forgetting_rerank,
|
||||
"use_llm_rerank": use_llm_rerank
|
||||
}
|
||||
|
||||
# 执行混合检索
|
||||
raw_results = await run_hybrid_search(**final_params)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
if clean_output:
|
||||
cleaned_results = clean_result_fields(raw_results)
|
||||
else:
|
||||
cleaned_results = raw_results
|
||||
|
||||
# 格式化返回结果
|
||||
formatted_results = {
|
||||
"search_query": context,
|
||||
"search_type": search_type,
|
||||
"results": cleaned_results
|
||||
}
|
||||
|
||||
return json.dumps(formatted_results, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
except Exception as e:
|
||||
error_result = {
|
||||
"error": f"混合检索失败: {str(e)}",
|
||||
"search_query": context,
|
||||
"search_type": search_type,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
return json.dumps(error_result, ensure_ascii=False, indent=2)
|
||||
|
||||
return HybridSearch
|
||||
|
||||
|
||||
def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
"""
|
||||
创建同步版本的混合检索工具,优化输出格式并过滤不需要的字段
|
||||
|
||||
Args:
|
||||
memory_config: 内存配置对象
|
||||
**search_params: 搜索参数
|
||||
"""
|
||||
@tool
|
||||
def HybridSearchSync(
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
end_user_id: str = None,
|
||||
clean_output: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
优化的混合检索工具(同步版本),自动过滤不需要的元数据字段
|
||||
|
||||
Args:
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
end_user_id: 组ID,用于过滤搜索结果
|
||||
clean_output: 是否清理输出中的元数据字段
|
||||
"""
|
||||
async def _async_search():
|
||||
# 创建异步工具并执行
|
||||
async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
|
||||
return await async_tool.ainvoke({
|
||||
"context": context,
|
||||
"search_type": search_type,
|
||||
"limit": limit,
|
||||
"end_user_id": end_user_id,
|
||||
"clean_output": clean_output
|
||||
})
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
return HybridSearchSync
|
||||
@@ -0,0 +1,72 @@
|
||||
import json
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
async def format_parsing(messages: list,type:str='string'):
|
||||
"""
|
||||
格式化解析消息列表
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
type: 返回类型 ('string' 或 'dict')
|
||||
|
||||
Returns:
|
||||
格式化后的消息列表
|
||||
"""
|
||||
result = []
|
||||
user=[]
|
||||
ai=[]
|
||||
|
||||
for message in messages:
|
||||
hstory_messages = message['messages']
|
||||
for history_messag in hstory_messages.strip().splitlines():
|
||||
history_messag = json.loads(history_messag)
|
||||
for content in history_messag:
|
||||
role = content['role']
|
||||
content = content['content']
|
||||
if type == "string":
|
||||
if role == 'human' or role=="user":
|
||||
content = '用户:' + content
|
||||
else:
|
||||
content = 'AI:' + content
|
||||
result.append(content)
|
||||
if type == "dict" :
|
||||
if role == 'human' or role=="user":
|
||||
user.append( content)
|
||||
else:
|
||||
ai.append(content)
|
||||
if type == "dict":
|
||||
for key,values in zip(user,ai):
|
||||
result.append({key:values})
|
||||
return result
|
||||
|
||||
async def messages_parse(messages: list | dict):
|
||||
user=[]
|
||||
ai=[]
|
||||
database=[]
|
||||
for message in messages:
|
||||
Query = message['Query']
|
||||
Query = json.loads(Query)
|
||||
for data in Query:
|
||||
role = data['role']
|
||||
if role == "human":
|
||||
user.append(data['content'])
|
||||
if role == "ai":
|
||||
ai.append(data['content'])
|
||||
for key, values in zip(user, ai):
|
||||
database.append({key, values})
|
||||
return database
|
||||
|
||||
|
||||
async def agent_chat_messages(user_content,ai_content):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"{user_content}"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"{ai_content}"
|
||||
}
|
||||
|
||||
]
|
||||
return messages
|
||||
@@ -3,17 +3,18 @@ import json
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
from app.db import get_db, get_db_context
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
@@ -21,60 +22,83 @@ if sys.platform.startswith("win"):
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_write_graph(user_id, tools, apply_id, group_id, memory_config: MemoryConfig):
|
||||
async def make_write_graph():
|
||||
"""
|
||||
Create a write graph workflow for memory operations.
|
||||
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
tools: MCP tools loaded from session
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
"""
|
||||
logger.info("Loading MCP tools: %s", [t.name for t in tools])
|
||||
logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})")
|
||||
|
||||
data_write_tool = next((t for t in tools if t.name == "Data_write"), None)
|
||||
|
||||
if not data_write_tool:
|
||||
logger.error("Data_write tool not found", exc_info=True)
|
||||
raise ValueError("Data_write tool not found")
|
||||
|
||||
write_node = ToolNode([data_write_tool])
|
||||
|
||||
async def call_model(state):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
content = last_message[1] if isinstance(last_message, tuple) else last_message.content
|
||||
|
||||
# Call Data_write directly with memory_config
|
||||
write_params = {
|
||||
"content": content,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
"memory_config": memory_config,
|
||||
}
|
||||
logger.debug(f"Passing memory_config to Data_write: {memory_config.config_id}")
|
||||
|
||||
write_result = await data_write_tool.ainvoke(write_params)
|
||||
|
||||
if isinstance(write_result, dict):
|
||||
result_content = write_result.get("data", str(write_result))
|
||||
else:
|
||||
result_content = str(write_result)
|
||||
logger.info("Write content: %s", result_content)
|
||||
return {"messages": [AIMessage(content=result_content)]}
|
||||
|
||||
workflow = StateGraph(WriteState)
|
||||
workflow.add_node("content_input", call_model)
|
||||
workflow.add_node("save_neo4j", write_node)
|
||||
workflow.add_edge(START, "content_input")
|
||||
workflow.add_edge("content_input", "save_neo4j")
|
||||
workflow.add_edge(START, "save_neo4j")
|
||||
workflow.add_edge("save_neo4j", END)
|
||||
|
||||
graph = workflow.compile()
|
||||
|
||||
|
||||
yield graph
|
||||
|
||||
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
write_store.save_session_write(end_user_id, (langchain_messages))
|
||||
# 获取数据库会话
|
||||
with get_db_context() as db_session:
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=memory_config, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
if long_term_type=='chunk':
|
||||
'''方案一:对话窗口6轮对话'''
|
||||
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
||||
if long_term_type=='time':
|
||||
"""时间"""
|
||||
await memory_long_term_storage(end_user_id, memory_config,5)
|
||||
if long_term_type=='aggregate':
|
||||
"""方案三:聚合判断"""
|
||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
|
||||
|
||||
|
||||
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
||||
else:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
||||
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
|
||||
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
|
||||
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
|
||||
|
||||
# async def main():
|
||||
# """主函数 - 运行工作流"""
|
||||
# langchain_messages = [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": "今天周五去爬山"
|
||||
# },
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": "好耶"
|
||||
# }
|
||||
#
|
||||
# ]
|
||||
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
||||
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
||||
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
||||
#
|
||||
#
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# import asyncio
|
||||
# asyncio.run(main())
|
||||
@@ -1,28 +0,0 @@
|
||||
"""
|
||||
MCP Server package for memory agent.
|
||||
|
||||
This package provides the FastMCP server implementation with context-based
|
||||
dependency injection for tool functions.
|
||||
|
||||
Package structure:
|
||||
- server: FastMCP server initialization and context setup
|
||||
- tools: MCP tool implementations
|
||||
- models: Pydantic response models
|
||||
- services: Business logic services
|
||||
"""
|
||||
# from app.core.memory.agent.mcp_server.server import (
|
||||
# mcp,
|
||||
# initialize_context,
|
||||
# main,
|
||||
# get_context_resource
|
||||
# )
|
||||
|
||||
# # Import tools to register them (but don't export them)
|
||||
# from app.core.memory.agent.mcp_server import tools
|
||||
|
||||
# __all__ = [
|
||||
# 'mcp',
|
||||
# 'initialize_context',
|
||||
# 'main',
|
||||
# 'get_context_resource',
|
||||
# ]
|
||||
@@ -1,11 +0,0 @@
|
||||
"""
|
||||
MCP Server Instance
|
||||
|
||||
This module contains the FastMCP server instance that is shared across all modules.
|
||||
It's in a separate file to avoid circular import issues.
|
||||
"""
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
# Initialize FastMCP server instance
|
||||
# This instance is shared across all tool modules
|
||||
mcp = FastMCP('data_flow')
|
||||
@@ -1,14 +0,0 @@
|
||||
"""Pydantic models for verification operations."""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VerificationResult(BaseModel):
|
||||
"""Result model for verification operation."""
|
||||
|
||||
query: str
|
||||
expansion_issue: List[Dict[str, Any]]
|
||||
split_result: str
|
||||
reason: Optional[str] = None
|
||||
history: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
@@ -1,159 +0,0 @@
|
||||
"""
|
||||
MCP Server initialization with FastMCP context setup.
|
||||
|
||||
This module initializes the FastMCP server and registers shared resources
|
||||
in the context for dependency injection into tool functions.
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.services.search_service import SearchService
|
||||
from app.core.memory.agent.mcp_server.services.session_service import SessionService
|
||||
from app.core.memory.agent.mcp_server.services.template_service import TemplateService
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
def get_context_resource(ctx, resource_name: str):
|
||||
"""
|
||||
Helper function to retrieve a resource from the FastMCP context.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP Context object (passed to tool functions)
|
||||
resource_name: Name of the resource to retrieve
|
||||
|
||||
Returns:
|
||||
The requested resource
|
||||
|
||||
Raises:
|
||||
AttributeError: If the resource doesn't exist
|
||||
|
||||
Example:
|
||||
@mcp.tool()
|
||||
async def my_tool(ctx: Context):
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
"""
|
||||
if not hasattr(ctx, 'fastmcp') or ctx.fastmcp is None:
|
||||
raise RuntimeError("Context does not have fastmcp attribute")
|
||||
|
||||
if not hasattr(ctx.fastmcp, resource_name):
|
||||
raise AttributeError(
|
||||
f"Resource '{resource_name}' not found in context. "
|
||||
f"Available resources: {[k for k in dir(ctx.fastmcp) if not k.startswith('_')]}"
|
||||
)
|
||||
|
||||
return getattr(ctx.fastmcp, resource_name)
|
||||
|
||||
|
||||
def initialize_context():
|
||||
"""
|
||||
Initialize and register shared resources in FastMCP context.
|
||||
|
||||
This function sets up all shared resources that will be available
|
||||
to tool functions via dependency injection through the context parameter.
|
||||
|
||||
Resources are stored as attributes on the FastMCP instance and can be
|
||||
accessed via ctx.fastmcp in tool functions.
|
||||
|
||||
Resources registered:
|
||||
- session_store: RedisSessionStore for session management
|
||||
- llm_client: LLM client for structured API calls
|
||||
- app_settings: Application settings (renamed to avoid conflict with FastMCP settings)
|
||||
- template_service: Service for template rendering
|
||||
- search_service: Service for hybrid search
|
||||
- session_service: Service for session operations
|
||||
"""
|
||||
try:
|
||||
# Register Redis session store
|
||||
logger.info("Registering session_store in context")
|
||||
mcp.session_store = store
|
||||
|
||||
# Note: LLM client is NOT loaded at server startup
|
||||
# It should be loaded dynamically when needed, with config_id passed explicitly
|
||||
# to make_write_graph or make_read_graph functions
|
||||
logger.info("LLM client will be loaded dynamically with config_id when needed")
|
||||
mcp.llm_client = None # Placeholder - actual client loaded per-request with config_id
|
||||
|
||||
# Register application settings (renamed to avoid conflict with FastMCP's settings)
|
||||
logger.info("Registering app_settings in context")
|
||||
mcp.app_settings = settings
|
||||
|
||||
# Register template service
|
||||
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
|
||||
# logger.info(f"Registering template_service in context with root: {template_root}")
|
||||
template_service = TemplateService(template_root)
|
||||
mcp.template_service = template_service
|
||||
|
||||
# Register search service
|
||||
# logger.info("Registering search_service in context")
|
||||
search_service = SearchService()
|
||||
mcp.search_service = search_service
|
||||
|
||||
# Register session service
|
||||
# logger.info("Registering session_service in context")
|
||||
session_service = SessionService(store)
|
||||
mcp.session_service = session_service
|
||||
|
||||
# logger.info("All context resources registered successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize context: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main entry point for the MCP server.
|
||||
|
||||
Initializes context and starts the server with SSE transport.
|
||||
"""
|
||||
try:
|
||||
logger.info("Starting MCP server initialization")
|
||||
# Initialize context resources
|
||||
initialize_context()
|
||||
|
||||
# Import and register tools (imports trigger tool registration)
|
||||
from app.core.memory.agent.mcp_server.tools import ( # noqa: F401
|
||||
data_tools,
|
||||
problem_tools,
|
||||
retrieval_tools,
|
||||
summary_tools,
|
||||
verification_tools,
|
||||
)
|
||||
|
||||
# Tools are registered via imports above
|
||||
|
||||
# Get MCP port from environment (default: 8081)
|
||||
mcp_port = int(os.getenv("MCP_PORT", "8081"))
|
||||
logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport")
|
||||
|
||||
# Configure DNS rebinding protection for Docker container compatibility
|
||||
from mcp.server.fastmcp.server import TransportSecuritySettings
|
||||
|
||||
# Disable DNS rebinding protection to allow Docker container hostnames
|
||||
# This allows containers to connect using service names like 'mcp-server'
|
||||
mcp.settings.transport_security = TransportSecuritySettings(
|
||||
enable_dns_rebinding_protection=False,
|
||||
)
|
||||
logger.info("DNS rebinding protection: disabled for Docker container compatibility")
|
||||
|
||||
# logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport")
|
||||
|
||||
# Run the server with SSE transport for HTTP connections
|
||||
import uvicorn
|
||||
app = mcp.sse_app()
|
||||
uvicorn.run(app, host=settings.SERVER_IP, port=mcp_port, log_level="info")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start MCP server: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,27 +0,0 @@
|
||||
"""
|
||||
MCP Tools module.
|
||||
|
||||
This module contains all MCP tool implementations organized by functionality.
|
||||
|
||||
Tools are organized into the following modules:
|
||||
- problem_tools: Question segmentation and extension
|
||||
- retrieval_tools: Database and context retrieval
|
||||
- verification_tools: Data verification
|
||||
- summary_tools: Summarization and summary retrieval
|
||||
- data_tools: Data type differentiation and writing
|
||||
"""
|
||||
|
||||
# Import all tool modules to register them with the MCP server
|
||||
from . import problem_tools
|
||||
from . import retrieval_tools
|
||||
from . import verification_tools
|
||||
from . import summary_tools
|
||||
from . import data_tools
|
||||
|
||||
__all__ = [
|
||||
'problem_tools',
|
||||
'retrieval_tools',
|
||||
'verification_tools',
|
||||
'summary_tools',
|
||||
'data_tools',
|
||||
]
|
||||
@@ -1,155 +0,0 @@
|
||||
"""
|
||||
Data Tools for data type differentiation and writing.
|
||||
|
||||
This module contains MCP tools for distinguishing data types and writing data.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.models.retrieval_models import (
|
||||
DistinguishTypeResponse,
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.write_tools import write
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Data_type_differentiation(
|
||||
ctx: Context,
|
||||
context: str,
|
||||
memory_config: MemoryConfig,
|
||||
) -> dict:
|
||||
"""
|
||||
Distinguish the type of data (read or write).
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Text to analyze for type differentiation
|
||||
memory_config: MemoryConfig object containing LLM configuration
|
||||
|
||||
Returns:
|
||||
dict: Contains 'context' with the original text and 'type' field
|
||||
"""
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
|
||||
# Get LLM client from memory_config using factory pattern
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Render template
|
||||
try:
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='distinguish_types_prompt.jinja2',
|
||||
operation_name='status_typle',
|
||||
user_query=context
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for Data_type_differentiation: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"type": "error",
|
||||
"message": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
# Call LLM with structured response
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=DistinguishTypeResponse
|
||||
)
|
||||
|
||||
result = structured.model_dump()
|
||||
|
||||
# Add context to result
|
||||
result["context"] = context
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM call failed for Data_type_differentiation: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": context,
|
||||
"type": "error",
|
||||
"message": f"LLM call failed: {str(e)}"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Data_type_differentiation failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": context,
|
||||
"type": "error",
|
||||
"message": str(e)
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Data_write(
|
||||
ctx: Context,
|
||||
content: str,
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
) -> dict:
|
||||
"""
|
||||
Write data to the database/file system.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
content: Data content to write
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
|
||||
Returns:
|
||||
dict: Contains 'status', 'saved_to', and 'data' fields
|
||||
"""
|
||||
try:
|
||||
# Ensure output directory exists
|
||||
os.makedirs("data_output", exist_ok=True)
|
||||
file_path = os.path.join("data_output", "user_data.csv")
|
||||
|
||||
# Write data - clients are constructed inside write() from memory_config
|
||||
await write(
|
||||
content=content,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"saved_to": file_path,
|
||||
"data": content,
|
||||
"config_id": memory_config.config_id,
|
||||
"config_name": memory_config.config_name,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Data_write failed: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
}
|
||||
@@ -1,304 +0,0 @@
|
||||
"""
|
||||
Problem Tools for question segmentation and extension.
|
||||
|
||||
This module contains MCP tools for breaking down and extending user questions.
|
||||
LLM clients are constructed from MemoryConfig when needed.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.models.problem_models import (
|
||||
ProblemBreakdownResponse,
|
||||
ProblemExtensionResponse,
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Split_The_Problem(
|
||||
ctx: Context,
|
||||
sentence: str,
|
||||
sessionid: str,
|
||||
messages_id: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
) -> dict:
|
||||
"""
|
||||
Segment the dialogue or sentence into sub-problems.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
sentence: Original sentence to split
|
||||
sessionid: Session identifier
|
||||
messages_id: Message identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
|
||||
Returns:
|
||||
dict: Contains 'context' (JSON string of split results) and 'original' sentence
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Extract user ID from session
|
||||
user_id = session_service.resolve_user_id(sessionid)
|
||||
|
||||
# Get conversation history
|
||||
history = await session_service.get_history(user_id, apply_id, group_id)
|
||||
# Override with empty list for now (as in original)
|
||||
history = []
|
||||
|
||||
# Render template
|
||||
try:
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='problem_breakdown_prompt.jinja2',
|
||||
operation_name='split_the_problem',
|
||||
history=history,
|
||||
sentence=sentence
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for Split_The_Problem: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": json.dumps([], ensure_ascii=False),
|
||||
"original": sentence,
|
||||
"error": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
# Call LLM with structured response
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=ProblemBreakdownResponse
|
||||
)
|
||||
|
||||
# Handle RootModel response with .root attribute access
|
||||
if structured is None:
|
||||
# LLM returned None, use empty list as fallback
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
elif hasattr(structured, 'root') and structured.root is not None:
|
||||
split_result = json.dumps(
|
||||
[item.model_dump() for item in structured.root],
|
||||
ensure_ascii=False
|
||||
)
|
||||
elif isinstance(structured, list):
|
||||
# Fallback: treat structured itself as the list
|
||||
split_result = json.dumps(
|
||||
[item.model_dump() for item in structured],
|
||||
ensure_ascii=False
|
||||
)
|
||||
else:
|
||||
# Last resort: use empty list
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM call failed for Split_The_Problem: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
|
||||
logger.info("Problem splitting")
|
||||
logger.info(f"Problem split result: {split_result}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
result = {
|
||||
"context": split_result,
|
||||
"original": sentence,
|
||||
"_intermediate": {
|
||||
"type": "problem_split",
|
||||
"data": json.loads(split_result) if split_result else [],
|
||||
"original_query": sentence
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Split_The_Problem failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": json.dumps([], ensure_ascii=False),
|
||||
"original": sentence,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
finally:
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('Problem splitting', duration)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Problem_Extension(
|
||||
ctx: Context,
|
||||
context: dict,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Extend the problem with additional sub-questions.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Dictionary containing split problem results
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'context' (aggregated questions) and 'original' question
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID from usermessages
|
||||
from app.core.memory.agent.utils.messages_tool import Resolve_username
|
||||
sessionid = Resolve_username(usermessages)
|
||||
|
||||
# Get conversation history
|
||||
history = await session_service.get_history(sessionid, apply_id, group_id)
|
||||
# Override with empty list for now (as in original)
|
||||
history = []
|
||||
|
||||
# Process context to extract questions
|
||||
extent_quest, original = await Problem_Extension_messages_deal(context)
|
||||
|
||||
# Format questions for template rendering
|
||||
questions_formatted = []
|
||||
for msg in extent_quest:
|
||||
if msg.get("role") == "user":
|
||||
questions_formatted.append(msg.get("content", ""))
|
||||
|
||||
# Render template
|
||||
try:
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='Problem_Extension_prompt.jinja2',
|
||||
operation_name='problem_extension',
|
||||
history=history,
|
||||
questions=questions_formatted
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for Problem_Extension: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": {},
|
||||
"original": original,
|
||||
"error": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
# Call LLM with structured response
|
||||
try:
|
||||
response_content = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=ProblemExtensionResponse
|
||||
)
|
||||
|
||||
# Aggregate results by original question
|
||||
aggregated_dict = {}
|
||||
for item in response_content.root:
|
||||
key = getattr(item, "original_question", None) or (
|
||||
item.get("original_question") if isinstance(item, dict) else None
|
||||
)
|
||||
value = getattr(item, "extended_question", None) or (
|
||||
item.get("extended_question") if isinstance(item, dict) else None
|
||||
)
|
||||
if not key or not value:
|
||||
continue
|
||||
aggregated_dict.setdefault(key, []).append(value)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM call failed for Problem_Extension: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
aggregated_dict = {}
|
||||
|
||||
logger.info("Problem extension")
|
||||
logger.info(f"Problem extension result: {aggregated_dict}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
result = {
|
||||
"context": aggregated_dict,
|
||||
"original": original,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "problem_extension",
|
||||
"data": aggregated_dict,
|
||||
"original_query": original,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Problem_Extension failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": {},
|
||||
"original": context.get("original", ""),
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
finally:
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('Problem extension', duration)
|
||||
@@ -1,294 +0,0 @@
|
||||
"""
|
||||
Retrieval Tools for database and context retrieval.
|
||||
|
||||
This module contains MCP tools for retrieving data using hybrid search.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
deduplicate_entries,
|
||||
merge_to_key_value_pairs,
|
||||
)
|
||||
from app.core.memory.agent.utils.messages_tool import Retriev_messages_deal
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from dotenv import load_dotenv
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
load_dotenv()
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Retrieve(
|
||||
ctx: Context,
|
||||
context,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Retrieve data from the database using hybrid search.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Dictionary or string containing query information
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
|
||||
user_rag_memory_id: User RAG memory identifier
|
||||
|
||||
Returns:
|
||||
dict: Contains 'context' with Query and Expansion_issue results
|
||||
"""
|
||||
kb_config = {
|
||||
"knowledge_bases": [
|
||||
{
|
||||
"kb_id": user_rag_memory_id,
|
||||
"similarity_threshold": 0.7,
|
||||
"vector_similarity_weight": 0.5,
|
||||
"top_k": 10,
|
||||
"retrieve_type": "participle"
|
||||
}
|
||||
],
|
||||
"merge_strategy": "weight",
|
||||
"reranker_id": os.getenv('reranker_id'),
|
||||
"reranker_top_k": 10
|
||||
}
|
||||
start = time.time()
|
||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
logger.info(f"Retrieve: context type={type(context)}, context={str(context)[:500]}")
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
search_service = get_context_resource(ctx, 'search_service')
|
||||
|
||||
databases_anser = []
|
||||
|
||||
# Handle both dict and string context
|
||||
if isinstance(context, dict):
|
||||
# Process dict context with extended questions
|
||||
all_items = []
|
||||
logger.info(f"Retrieve: context keys={list(context.keys())}")
|
||||
content, original = await Retriev_messages_deal(context)
|
||||
logger.info(f"Retrieve: after Retriev_messages_deal - content_type={type(content)}, content={str(content)[:300]}")
|
||||
logger.info(f"Retrieve: original='{original[:100] if original else 'EMPTY'}'")
|
||||
|
||||
if not original:
|
||||
logger.warning(f"Retrieve: original query is empty! context={context}")
|
||||
|
||||
# Extract all query items from content
|
||||
# content is like {original_question: [extended_questions...], ...}
|
||||
for key, values in content.items():
|
||||
if isinstance(values, list):
|
||||
all_items.extend(values)
|
||||
elif isinstance(values, str):
|
||||
all_items.append(values)
|
||||
elif values is not None:
|
||||
# Fallback: convert non-empty non-list values to string
|
||||
all_items.append(str(values))
|
||||
|
||||
# Execute search for each question
|
||||
for idx, question in enumerate(all_items):
|
||||
try:
|
||||
# Prepare search parameters based on storage type
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"question": question,
|
||||
"return_raw_results": True
|
||||
}
|
||||
|
||||
# Add storage-specific parameters
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config,[str(group_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
clean_content = '\n\n'.join(retrieval_knowledge)
|
||||
cleaned_query=question
|
||||
raw_results=clean_content
|
||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||
except:
|
||||
clean_content = ''
|
||||
raw_results=''
|
||||
cleaned_query = question
|
||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||
else:
|
||||
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(
|
||||
**search_params, memory_config=memory_config
|
||||
)
|
||||
|
||||
databases_anser.append({
|
||||
"Query_small": cleaned_query,
|
||||
"Result_small": clean_content,
|
||||
"_intermediate": {
|
||||
"type": "search_result",
|
||||
"query": cleaned_query,
|
||||
"raw_results": raw_results,
|
||||
"index": idx + 1,
|
||||
"total": len(all_items)
|
||||
}
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve: hybrid_search failed for question '{question}': {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Continue with empty result for this question
|
||||
databases_anser.append({
|
||||
"Query_small": question,
|
||||
"Result_small": ""
|
||||
})
|
||||
|
||||
# Build initial database data structure
|
||||
databases_data = {
|
||||
"Query": original,
|
||||
"Expansion_issue": databases_anser
|
||||
}
|
||||
|
||||
# Collect intermediate outputs before deduplication
|
||||
intermediate_outputs = []
|
||||
for item in databases_anser:
|
||||
if '_intermediate' in item:
|
||||
intermediate_outputs.append(item['_intermediate'])
|
||||
|
||||
# Deduplicate and merge results
|
||||
deduplicated_data = deduplicate_entries(databases_data['Expansion_issue'])
|
||||
deduplicated_data_merged = merge_to_key_value_pairs(
|
||||
deduplicated_data,
|
||||
'Query_small',
|
||||
'Result_small'
|
||||
)
|
||||
|
||||
# Restructure for Verify/Retrieve_Summary compatibility
|
||||
keys, val = [], []
|
||||
for item in deduplicated_data_merged:
|
||||
for items_key, items_value in item.items():
|
||||
keys.append(items_key)
|
||||
val.append(items_value)
|
||||
|
||||
send_verify = []
|
||||
for i, j in zip(keys, val, strict=False):
|
||||
send_verify.append({
|
||||
"Query_small": i,
|
||||
"Answer_Small": j
|
||||
})
|
||||
|
||||
dup_databases = {
|
||||
"Query": original,
|
||||
"Expansion_issue": send_verify,
|
||||
"_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs
|
||||
}
|
||||
|
||||
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
||||
|
||||
else:
|
||||
# Handle string context (simple query)
|
||||
query = str(context).strip()
|
||||
|
||||
try:
|
||||
# Prepare search parameters based on storage type
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"question": query,
|
||||
"return_raw_results": True
|
||||
}
|
||||
|
||||
# Add storage-specific parameters
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
clean_content = '\n\n'.join(retrieval_knowledge)
|
||||
cleaned_query = query
|
||||
raw_results = clean_content
|
||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||
except:
|
||||
clean_content = ''
|
||||
raw_results = ''
|
||||
cleaned_query = query
|
||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||
else:
|
||||
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(
|
||||
**search_params, memory_config=memory_config
|
||||
)
|
||||
# Keep structure for Verify/Retrieve_Summary compatibility
|
||||
dup_databases = {
|
||||
"Query": cleaned_query,
|
||||
"Expansion_issue": [{
|
||||
"Query_small": cleaned_query,
|
||||
"Answer_Small": clean_content,
|
||||
"_intermediate": {
|
||||
"type": "search_result",
|
||||
"query": cleaned_query,
|
||||
"raw_results": raw_results,
|
||||
"index": 1,
|
||||
"total": 1
|
||||
}
|
||||
}]
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve: hybrid_search failed for query '{query}': {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty results on failure
|
||||
dup_databases = {
|
||||
"Query": query,
|
||||
"Expansion_issue": []
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Retrieval: {storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, "
|
||||
f"Expansion_issue count={len(dup_databases.get('Expansion_issue', []))}"
|
||||
)
|
||||
|
||||
# Build result with intermediate outputs
|
||||
result = {
|
||||
"context": dup_databases,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
|
||||
# Add intermediate outputs list if they exist
|
||||
intermediate_outputs = dup_databases.get('_intermediate_outputs', [])
|
||||
if intermediate_outputs:
|
||||
result['_intermediates'] = intermediate_outputs
|
||||
logger.info(f"Adding {len(intermediate_outputs)} intermediate outputs to result")
|
||||
else:
|
||||
logger.warning("No intermediate outputs found in dup_databases")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": {
|
||||
"Query": "",
|
||||
"Expansion_issue": []
|
||||
},
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
finally:
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('Retrieval', duration)
|
||||
@@ -1,666 +0,0 @@
|
||||
"""
|
||||
Summary Tools for data summarization.
|
||||
|
||||
This module contains MCP tools for summarizing retrieved data and generating responses.
|
||||
LLM clients are constructed from MemoryConfig when needed.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.models.summary_models import (
|
||||
RetrieveSummaryResponse,
|
||||
SummaryResponse,
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.messages_tool import (
|
||||
Resolve_username,
|
||||
Summary_messages_deal,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from dotenv import load_dotenv
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
load_dotenv()
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Summary(
|
||||
ctx: Context,
|
||||
context: str,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Summarize the verified data.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: JSON string containing verified data
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'status' and 'summary_result'
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages)
|
||||
|
||||
# Process context to extract answer and query
|
||||
answer_small, query = await Summary_messages_deal(context)
|
||||
|
||||
|
||||
start_time= time.time()
|
||||
history = await session_service.get_history(sessionid, apply_id, group_id)
|
||||
end_time=time.time()
|
||||
logger.info(f"Retrieve_Summary-REDIS搜索:{end_time - start_time}")
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
"retrieve_info": answer_small
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Summary: initialization failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"summary_result": "信息不足,无法回答"
|
||||
}
|
||||
|
||||
try:
|
||||
# Render template
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='summary_prompt.jinja2',
|
||||
operation_name='summary',
|
||||
data=data,
|
||||
query=query
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for Summary: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
try:
|
||||
# Call LLM with structured response
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=SummaryResponse
|
||||
)
|
||||
|
||||
aimessages = structured.query_answer or ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM call failed for Summary: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
aimessages = ""
|
||||
|
||||
try:
|
||||
# Save session
|
||||
if aimessages != "":
|
||||
await session_service.save_session(
|
||||
user_id=sessionid,
|
||||
query=query,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
ai_response=aimessages
|
||||
)
|
||||
logger.info(f"sessionid: {aimessages} 写入成功")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"sessionid: {sessionid} 写入失败,错误信息:{str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e)
|
||||
}
|
||||
|
||||
# Cleanup duplicate sessions
|
||||
await session_service.cleanup_duplicates()
|
||||
|
||||
# Use fallback if empty
|
||||
if aimessages == '':
|
||||
aimessages = '信息不足,无法回答'
|
||||
|
||||
logger.info(f"Summary after verification: {aimessages}")
|
||||
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('Summary', duration)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Retrieve_Summary(
|
||||
ctx: Context,
|
||||
context: dict,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Summarize data directly from retrieval results.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Dictionary containing Query and Expansion_issue from Retrieve
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'status' and 'summary_result'
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages)
|
||||
|
||||
|
||||
|
||||
# Handle both 'content' and 'context' keys (LangGraph uses 'content')
|
||||
logger.debug(f"Retrieve_Summary: raw context type={type(context)}, keys={list(context.keys()) if isinstance(context, dict) else 'N/A'}")
|
||||
|
||||
if isinstance(context, dict):
|
||||
if "content" in context:
|
||||
inner = context["content"]
|
||||
# If it's a JSON string, parse it
|
||||
if isinstance(inner, str):
|
||||
try:
|
||||
parsed = json.loads(inner)
|
||||
logger.info("Retrieve_Summary: successfully parsed JSON")
|
||||
except json.JSONDecodeError:
|
||||
# Try unescaping first
|
||||
try:
|
||||
unescaped = inner.encode('utf-8').decode('unicode_escape')
|
||||
parsed = json.loads(unescaped)
|
||||
logger.info("Retrieve_Summary: parsed after unescaping")
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||
logger.error(
|
||||
f"Retrieve_Summary: parsing failed even after unescape: {e}"
|
||||
)
|
||||
context_dict = {"Query": "", "Expansion_issue": []}
|
||||
parsed = None
|
||||
|
||||
if parsed:
|
||||
# Check if parsed has 'context' wrapper
|
||||
if isinstance(parsed, dict) and "context" in parsed:
|
||||
context_dict = parsed["context"]
|
||||
else:
|
||||
context_dict = parsed
|
||||
elif isinstance(inner, dict):
|
||||
context_dict = inner
|
||||
else:
|
||||
context_dict = {"Query": "", "Expansion_issue": []}
|
||||
elif "context" in context:
|
||||
context_dict = context["context"] if isinstance(context["context"], dict) else context
|
||||
else:
|
||||
context_dict = context
|
||||
else:
|
||||
context_dict = {"Query": "", "Expansion_issue": []}
|
||||
|
||||
query = context_dict.get("Query", "")
|
||||
expansion_issue = context_dict.get("Expansion_issue", [])
|
||||
|
||||
logger.debug(f"Retrieve_Summary: query='{query}', expansion_issue count={len(expansion_issue)}")
|
||||
logger.debug(f"Retrieve_Summary: expansion_issue={expansion_issue[:2] if expansion_issue else 'empty'}")
|
||||
|
||||
# Extract retrieve_info from expansion_issue
|
||||
retrieve_info = []
|
||||
for item in expansion_issue:
|
||||
# Check for both Answer_Small and Answer_Small (typo) for backward compatibility
|
||||
answer = None
|
||||
if isinstance(item, dict):
|
||||
if "Answer_Small" in item:
|
||||
answer = item["Answer_Small"]
|
||||
|
||||
|
||||
if answer is not None:
|
||||
# Handle both string and list formats
|
||||
if isinstance(answer, list):
|
||||
# Join list of characters/strings into a single string
|
||||
retrieve_info.append(''.join(str(x) for x in answer))
|
||||
elif isinstance(answer, str):
|
||||
retrieve_info.append(answer)
|
||||
else:
|
||||
retrieve_info.append(str(answer))
|
||||
|
||||
# Join all retrieve_info into a single string
|
||||
retrieve_info_str = '\n\n'.join(retrieve_info) if retrieve_info else ""
|
||||
|
||||
start_time=time.time()
|
||||
history = await session_service.get_history(sessionid, apply_id, group_id)
|
||||
# Override with empty list for now (as in original)
|
||||
end_time=time.time()
|
||||
logger.info(f"Retrieve_Summary-REDIS搜索:{end_time - start_time}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve_Summary: initialization failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"summary_result": "信息不足,无法回答"
|
||||
}
|
||||
|
||||
try:
|
||||
# Render template
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='Retrieve_Summary_prompt.jinja2',
|
||||
operation_name='retrieve_summary',
|
||||
query=query,
|
||||
history=history,
|
||||
retrieve_info=retrieve_info_str
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for Retrieve_Summary: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
try:
|
||||
# Call LLM with structured response
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=RetrieveSummaryResponse
|
||||
)
|
||||
|
||||
# Handle case where structured response might be None or incomplete
|
||||
if structured and hasattr(structured, 'data') and structured.data:
|
||||
aimessages = structured.data.query_answer or ""
|
||||
else:
|
||||
logger.warning("Structured response is None or incomplete, using default message")
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
|
||||
# Check for insufficient information response
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages)!="":
|
||||
# Save session
|
||||
await session_service.save_session(
|
||||
user_id=sessionid,
|
||||
query=query,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
ai_response=aimessages
|
||||
)
|
||||
logger.info(f"sessionid: {aimessages} 写入成功")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve_Summary: LLM call failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
aimessages = ""
|
||||
# Cleanup duplicate sessions
|
||||
await session_service.cleanup_duplicates()
|
||||
|
||||
# Use fallback if empty
|
||||
if aimessages == '':
|
||||
aimessages = '信息不足,无法回答'
|
||||
|
||||
logger.info(f"Summary after retrieval: {aimessages}")
|
||||
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('Retrieval summary', duration)
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "retrieval_summary",
|
||||
"summary": aimessages,
|
||||
"query": query,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Input_Summary(
|
||||
ctx: Context,
|
||||
context: str,
|
||||
usermessages: str,
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Generate a quick summary for direct input without verification.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: String containing the input sentence
|
||||
usermessages: User messages identifier
|
||||
search_switch: Search switch value for routing ('2' for summaries only)
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
|
||||
user_rag_memory_id: User RAG memory identifier
|
||||
|
||||
Returns:
|
||||
dict: Contains 'query_answer' with the summary result
|
||||
"""
|
||||
start = time.time()
|
||||
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
search_service = get_context_resource(ctx, "search_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages) or ""
|
||||
sessionid = sessionid.replace('call_id_', '')
|
||||
|
||||
start_time=time.time()
|
||||
history = await session_service.get_history(
|
||||
str(sessionid),
|
||||
str(apply_id),
|
||||
str(group_id)
|
||||
)
|
||||
end_time=time.time()
|
||||
logger.info(f"Input_Summary-REDIS搜索:{end_time - start_time}")
|
||||
# Override with empty list for now (as in original)
|
||||
|
||||
# Log the raw context for debugging
|
||||
logger.info(f"Input_Summary: Received context type={type(context)}, value={context[:200] if isinstance(context, str) else context}")
|
||||
|
||||
# Extract sentence from context
|
||||
# Context can be a string or might contain the sentence in various formats
|
||||
try:
|
||||
# Try to parse as JSON first
|
||||
if isinstance(context, str) and (context.startswith('{') or context.startswith('[')):
|
||||
try:
|
||||
import json
|
||||
context_dict = json.loads(context)
|
||||
if isinstance(context_dict, dict):
|
||||
query = context_dict.get('sentence', context_dict.get('content', context))
|
||||
else:
|
||||
query = context
|
||||
except json.JSONDecodeError:
|
||||
# Not valid JSON, try regex
|
||||
match = re.search(r"'sentence':\s*['\"]?(.*?)['\"]?\s*,", context)
|
||||
query = match.group(1) if match else context
|
||||
else:
|
||||
query = context
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract query from context: {e}")
|
||||
query = context
|
||||
|
||||
# Clean query
|
||||
query = str(query).strip().strip("\"'")
|
||||
|
||||
logger.debug(f"Input_Summary: Extracted query='{query}' from context type={type(context)}")
|
||||
|
||||
# Execute search based on search_switch and storage_type
|
||||
try:
|
||||
logger.info(f"search_switch: {search_switch}, storage_type: {storage_type}")
|
||||
|
||||
# Prepare search parameters based on storage type
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"question": query,
|
||||
"return_raw_results": True
|
||||
}
|
||||
|
||||
# Add storage-specific parameters
|
||||
|
||||
# Retrieval
|
||||
if search_switch == '2':
|
||||
search_params["include"] = ["summaries"]
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
raw_results = []
|
||||
retrieve_info = ""
|
||||
kb_config={
|
||||
"knowledge_bases": [
|
||||
{
|
||||
"kb_id": user_rag_memory_id,
|
||||
"similarity_threshold": 0.7,
|
||||
"vector_similarity_weight": 0.5,
|
||||
"top_k": 10,
|
||||
"retrieve_type": "participle"
|
||||
}
|
||||
],
|
||||
"merge_strategy": "weight",
|
||||
"reranker_id":os.getenv('reranker_id'),
|
||||
"reranker_top_k": 10
|
||||
}
|
||||
|
||||
retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
retrieve_info = '\n\n'.join(retrieval_knowledge)
|
||||
raw_results=[retrieve_info]
|
||||
logger.info(f"Input_Summary: Using RAG storage with memory_id={user_rag_memory_id}")
|
||||
except:
|
||||
retrieve_info=''
|
||||
raw_results=['']
|
||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||
else:
|
||||
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(
|
||||
**search_params, memory_config=memory_config
|
||||
)
|
||||
logger.info("Input_Summary: Using summary for retrieval")
|
||||
else:
|
||||
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(
|
||||
**search_params, memory_config=memory_config
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Input_Summary: hybrid_search failed, using empty results: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
retrieve_info, question, raw_results = "", query, []
|
||||
|
||||
|
||||
# Render template
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='Retrieve_Summary_prompt.jinja2',
|
||||
operation_name='input_summary',
|
||||
query=query,
|
||||
history=history,
|
||||
retrieve_info=retrieve_info
|
||||
)
|
||||
|
||||
# Call LLM with structured response
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=RetrieveSummaryResponse
|
||||
)
|
||||
aimessages = structured.data.query_answer or "信息不足,无法回答"
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Input_Summary: response_structured failed, using default answer: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
logger.info(f"Quick answer summary: {storage_type}--{user_rag_memory_id}--{aimessages}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "input_summary",
|
||||
"title": "快速答案",
|
||||
"summary": aimessages,
|
||||
"query": query,
|
||||
"raw_results": raw_results,
|
||||
"search_mode": "quick_search",
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Input_Summary failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "fail",
|
||||
"summary_result": "信息不足,无法回答",
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
finally:
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('Retrieval', duration)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Summary_fails(
|
||||
ctx: Context,
|
||||
context: str,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
) -> dict:
|
||||
"""
|
||||
Handle workflow failure when summary cannot be generated.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Failure context string
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'query_answer' with failure message
|
||||
"""
|
||||
try:
|
||||
# Extract services from context
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
|
||||
# Parse session ID from usermessages
|
||||
usermessages_parts = usermessages.split('_')[1:]
|
||||
sessionid = '_'.join(usermessages_parts[:-1])
|
||||
|
||||
# Cleanup duplicate sessions
|
||||
await session_service.cleanup_duplicates()
|
||||
|
||||
logger.info("没有相关数据")
|
||||
logger.debug(f"Summary_fails called with apply_id: {apply_id}, group_id: {group_id}")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"summary_result": "没有相关数据",
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Summary_fails failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "fail",
|
||||
"summary_result": "没有相关数据",
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"error": str(e)
|
||||
}
|
||||
@@ -1,174 +0,0 @@
|
||||
"""
|
||||
Verification Tools for data verification.
|
||||
|
||||
This module contains MCP tools for verifying retrieved data.
|
||||
"""
|
||||
import time
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.messages_tool import (
|
||||
Resolve_username,
|
||||
Retrieve_verify_tool_messages_deal,
|
||||
Verify_messages_deal,
|
||||
)
|
||||
from app.core.memory.agent.utils.verify_tool import VerifyTool
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from jinja2 import Template
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Verify(
|
||||
ctx: Context,
|
||||
context: dict,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
) -> dict:
|
||||
"""
|
||||
Verify the retrieved data.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Dictionary containing query and expansion issues
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'status' and 'verified_data' with verification results
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
|
||||
# Load verification prompt template
|
||||
file_path = PROJECT_ROOT_ + '/agent/utils/prompt/split_verify_prompt.jinja2'
|
||||
|
||||
# Read template file directly (VerifyTool expects raw template content)
|
||||
from app.core.memory.agent.utils.messages_tool import read_template_file
|
||||
system_prompt = await read_template_file(file_path)
|
||||
|
||||
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages)
|
||||
|
||||
# Get conversation history
|
||||
history = await session_service.get_history(sessionid, apply_id, group_id)
|
||||
|
||||
template = Template(system_prompt)
|
||||
system_prompt = template.render(history=history, sentence=context)
|
||||
|
||||
# Process context to extract query and results
|
||||
Query_small, Result_small, query = await Verify_messages_deal(context)
|
||||
|
||||
# Build query list for verification
|
||||
query_list = []
|
||||
for query_small, anser in zip(Query_small, Result_small, strict=False):
|
||||
query_list.append({
|
||||
'Query_small': query_small,
|
||||
'Answer_Small': anser
|
||||
})
|
||||
|
||||
messages = {
|
||||
"Query": query,
|
||||
"Expansion_issue": query_list
|
||||
}
|
||||
|
||||
|
||||
|
||||
# Call verification workflow with LLM model ID from memory_config
|
||||
verify_tool = VerifyTool(
|
||||
system_prompt=system_prompt,
|
||||
verify_data=messages,
|
||||
llm_model_id=str(memory_config.llm_model_id)
|
||||
)
|
||||
verify_result = await verify_tool.verify()
|
||||
|
||||
# Parse LLM verification result with error handling
|
||||
try:
|
||||
messages_deal = await Retrieve_verify_tool_messages_deal(
|
||||
verify_result,
|
||||
history,
|
||||
query
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve_verify_tool_messages_deal parsing failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Fallback to avoid 500 errors
|
||||
messages_deal = {
|
||||
"data": {
|
||||
"query": query,
|
||||
"expansion_issue": []
|
||||
},
|
||||
"split_result": "failed",
|
||||
"reason": str(e),
|
||||
"history": history,
|
||||
}
|
||||
|
||||
logger.info(f"Verification result: {messages_deal}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
"status": "success",
|
||||
"verified_data": messages_deal,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "verification",
|
||||
"title": "Data Verification",
|
||||
"result": messages_deal.get("split_result", "unknown"),
|
||||
"reason": messages_deal.get("reason", ""),
|
||||
"query": query,
|
||||
"verified_count": len(query_list),
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Verify failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"verified_data": {
|
||||
"data": {
|
||||
"query": "",
|
||||
"expansion_issue": []
|
||||
},
|
||||
"split_result": "failed",
|
||||
"reason": str(e),
|
||||
"history": [],
|
||||
}
|
||||
}
|
||||
|
||||
finally:
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('Verification', duration)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user